import os
import argparse
from typing import Sequence, Tuple, Dict
import numpy as np
import cv2 # OpenCV
import scipy.signal
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import scipy.signal
import skimage.io

# Command-line options for running this file directly. They are not used during
# evaluation; they only exist to make experimenting/debugging easier.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--image",
    type=str,
    default="Lighthouse_bggr.png",
    help="Image to load.",
)
parser.add_argument(
    "--task",
    type=str,
    default="bayer",
    help="Selected task.",
)

def bayerDemosaic(bayerImage : np.ndarray) -> Dict[str, np.ndarray]:
    """Demosaic a single-channel Bayer image with bilinear interpolation.

    The sensor layout is not assumed: the image is reconstructed under all four
    2x2 Bayer layouts so the caller can pick the correct one.

    Args:
        bayerImage: 2-D array of raw sensor values. Results are stored as uint8,
            so values are expected in the 0..255 range (e.g. a uint8 image); a
            float image in [0, 1] must be scaled by 255 first.

    Returns:
        Dict mapping each layout name ("rggb", "bggr", "gbrg", "grbg") to an
        (H, W, 3) uint8 RGB image. Interpolated values are rounded to the
        nearest integer and saturated to 0..255.

    Raises:
        ValueError: if bayerImage is not two-dimensional.
    """
    ## SECONDARY FUNCTION ##
    def get_mask(shape):
        """Return boolean R/G/B sampling masks of `shape` for every Bayer layout."""
        four_masks = {}
        for layout in ("rggb", "bggr", "gbrg", "grbg"):
            mask = {color: np.zeros(shape, dtype=bool) for color in "RGB"}
            # The layout name lists the colours of the repeating 2x2 tile in the
            # order (0,0), (0,1), (1,0), (1,1).
            tile = zip(((0, 0), (0, 1), (1, 0), (1, 1)), layout.upper())
            for (row, col), color in tile:
                mask[color][row::2, col::2] = True
            four_masks[layout] = mask
        return four_masks

    ## MAIN FUNCTION
    bayerImage = np.asarray(bayerImage)
    if bayerImage.ndim != 2:
        raise ValueError(f"bayerImage must be 2-D, got shape {bayerImage.shape}")

    # Bilinear kernels: R and B occupy 1/4 of the sites (average of the 2 or 4
    # nearest samples), G occupies 1/2 (average of the 4 axial neighbours).
    rb_kernel = np.array([[1, 2, 1],
                          [2, 4, 2],
                          [1, 2, 1]]) / 4
    g_kernel = np.array([[0, 1, 0],
                         [1, 4, 1],
                         [0, 1, 0]]) / 4
    kernels = {"R": rb_kernel, "G": g_kernel, "B": rb_kernel}

    # Reflect-pad by an even amount (2) so every original pixel keeps its Bayer
    # phase and border pixels get mirrored, same-phase neighbours.
    padded = np.pad(bayerImage.astype(np.float64), pad_width=2, mode='reflect')
    masks = get_mask(padded.shape)

    results = {}
    for key, mask in masks.items():
        channels = []
        for color in "RGB":
            # 'same' keeps padded coordinates; dropping the 2-pixel border
            # restores the original image size.
            interpolated = scipy.signal.convolve2d(
                padded * mask[color], kernels[color], mode='same'
            )[2:-2, 2:-2]
            channels.append(interpolated)
        # A bare astype(np.uint8) truncates (0.75 -> 0, 99.9999 -> 99) and wraps
        # out-of-range values; round to nearest and saturate instead.
        stacked = np.stack(channels, axis=2)
        results[key] = np.clip(np.rint(stacked), 0, 255).astype(np.uint8)
    return results

def medianCut(image : np.ndarray, numColors : int) -> Tuple[np.ndarray, np.ndarray]:
    """Quantize an RGB image with the median-cut algorithm.

    Pixels are split recursively at the median of the colour channel with the
    largest value range. The colour budget is divided between the two halves in
    integers, and each final bucket is replaced by its mean colour.

    Args:
        image: (H, W, C) array with C >= 3; only the first three channels are used.
        numColors: requested palette size. Values <= 1 yield a single colour.
            Fewer colours are produced if the image has fewer pixels.

    Returns:
        (palette, idxImage):
            palette: (K, 3) uint64 array of mean colours rounded to nearest,
                with K <= max(numColors, 1).
            idxImage: (H, W) uint64 array of palette indices, so that
                ``palette[idxImage]`` is the quantized image.

    Raises:
        ValueError: if image does not have shape (H, W, C) with C >= 3.
    """
    image = np.asarray(image)
    if image.ndim != 3 or image.shape[2] < 3:
        raise ValueError(f"image must have shape (H, W, C) with C >= 3, got {image.shape}")
    height, width = image.shape[:2]

    ## flattening the image to rows of (R, G, B, row, col); the coordinates travel
    ## with each pixel through the sorts so the index image can be filled in later
    rows, cols = np.indices((height, width))
    flat_img = np.column_stack((
        image[:, :, :3].reshape(-1, 3).astype(np.int64),
        rows.ravel(),
        cols.ravel(),
    ))

    ## create output arrays
    palette_indexing = np.zeros((height, width), dtype=np.uint64)
    color_palette = []

    ## SECONDARY FUNCTION ##
    def quantization(bucket):
        """Append the bucket's mean colour to the palette and point its pixels at it."""
        color_palette.append(bucket[:, :3].mean(axis=0))
        palette_indexing[bucket[:, 3], bucket[:, 4]] = len(color_palette) - 1

    ## SECONDARY FUNCTION ##
    def bucket_split(bucket, num_of_colors):
        """Recursively split `bucket` until it has been divided into `num_of_colors` parts."""
        if len(bucket) == 0:
            return
        if num_of_colors <= 1:
            quantization(bucket)
            return

        ## sort by the channel with the greatest range
        colors = bucket[:, :3]
        max_range_index = np.argmax(colors.max(axis=0) - colors.min(axis=0))
        bucket = bucket[np.argsort(bucket[:, max_range_index], kind="stable")]

        # len // 2 is the median split: both halves are non-empty once len >= 2
        # (the previous (len - 1) / 2 left a 2-pixel bucket's left half empty).
        middle_index = len(bucket) // 2
        # Integer split of the budget: the halves receive exactly num_of_colors in
        # total (halving with `/` over-allocates, e.g. 3 -> 1.5 + 1.5 -> 4 colours).
        left_colors = num_of_colors // 2
        bucket_split(bucket[:middle_index], left_colors)
        bucket_split(bucket[middle_index:], num_of_colors - left_colors)

    ## make the median-cut; int() keeps the budget arithmetic integral so the
    ## recursion is guaranteed to terminate
    bucket_split(flat_img, int(numColors))

    ## round mean colours to nearest (a plain integer cast truncates them)
    palette = np.rint(
        np.asarray(color_palette, dtype=np.float64).reshape(-1, 3)
    ).astype(np.uint64)
    idxImage = palette_indexing
    return palette, idxImage

def main(args : argparse.Namespace):
    """Run the exercise tasks on the sample images in the ``data`` folder next to this file.

    ``args`` is accepted for experimentation but is not read here.
    """
    # Build paths with os.path.join: the previous "\\data\\..." concatenation is
    # Windows-only (on POSIX it yields a file name containing backslashes), and an
    # empty dirname turned it into a drive-root path on Windows.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    data_dir = os.path.join(current_dir, "data")

    ## TASK 1
    # NOTE: mpimg.imread returns PNGs as floats in [0, 1]; bayerDemosaic stores
    # uint8 results, so scale by 255 (or load with skimage.io.imread) first.
    # test_loaded_image = mpimg.imread(os.path.join(data_dir, "Lighthouse_bggr.png"))
    # results = bayerDemosaic(test_loaded_image)
    # for key in results:
    #     plt.imshow(results[key])
    #     plt.show()

    # TASK 2
    test_loaded_image = skimage.io.imread(os.path.join(data_dir, "im36.jpg"))
    palette, indexes = medianCut(test_loaded_image, 32)
    plt.imshow(indexes)
    plt.show()
    plt.imshow(palette[indexes])
    plt.show()

if __name__ == "__main__":
    # Without __file__ (e.g. an interactive session) parse an empty argument list
    # rather than whatever sys.argv the host process was started with.
    cli_args = None if "__file__" in globals() else []
    args = parser.parse_args(cli_args)
    main(args)
