# TODO: Implement more efficient monotonic heuristic
#
# Every function receives the coordinates of two grid points and returns the estimated distance between them.
# Each argument is a tuple of two or three integer coordinates.
# See file task.md for description of all grids.

import math
import numpy as np
from graphs import Grid2D, GridDiagonal2D, GridQueen2D, GridGreatKing2D, GridRook2D, GridJumper2D, Grid3D, GridFaceDiagonal3D, GridAllDiagonal3D

# For two points x and y in d-dimensional space, return the d-dimensional point r such that r_i = | x_i - y_i | for i = 1...d
def distance_in_each_coordinate(x, y):
    """Return a list of the absolute per-axis differences |x[i] - y[i]|.

    If the two points have different lengths, the extra coordinates of the
    longer one are ignored (the pairing stops at the shorter sequence).
    """
    return list(map(lambda p, q: abs(p - q), x, y))

def calculate_distance(graph, size):
    """
    Breadth-first distances from (0,0) over the square 0 <= a, b < size.

    Returns a list of lists `dists` where dists[a][b] is the number of edges on
    the shortest path found from (0,0) to (a,b), or -1 if (a,b) was never
    reached.

    Every neighbour returned by `graph.neighbours` is folded into the
    non-negative quadrant by taking the absolute value of each coordinate.
    This assumes the graph is symmetric under reflecting either axis.
    Neighbours with a coordinate >= size are discarded, so only paths that stay
    inside the folded square are explored. For graphs with long moves, this
    can overestimate distances near the border.

    Warning: Vertices of the graph must be two dimensional coordinates.
    """
    dists = [[-1] * size for _ in range(size)]
    dists[0][0] = 0
    frontier = [(0, 0)]
    # A for-loop over a list also visits items appended during iteration,
    # so `frontier` behaves as a FIFO queue without popping from the front.
    for vertex in frontier:
        next_dist = dists[vertex[0]][vertex[1]] + 1
        for neighbour in graph.neighbours(vertex):
            a, b = abs(neighbour[0]), abs(neighbour[1])
            if a >= size or b >= size or dists[a][b] != -1:
                continue
            dists[a][b] = next_dist
            frontier.append((a, b))
    return dists

## Each heuristic below returns a real estimate (instead of the placeholder `return 0`) and passed the tests.

def grid_2D_heuristic(current, destination):
    """Manhattan distance: the sum of the two per-axis distances.

    Unpacking into two names means non-2D points raise ValueError.
    """
    deltas = distance_in_each_coordinate(current, destination)
    dx, dy = deltas
    return dx + dy

def grid_diagonal_2D_heuristic(current, destination):
    """Chebyshev distance between two 2D points.

    Going diagonally for min(dx, dy) steps and then straight for the remaining
    |dx - dy| steps takes max(dx, dy) steps in total.
    """
    dx, dy = distance_in_each_coordinate(current, destination)
    return max(dx, dy)

def grid_3D_heuristic(current, destination):
    """Manhattan distance in 3D: the sum of the three per-axis distances.

    Unpacking into three names means non-3D points raise ValueError.
    """
    dx, dy, dz = distance_in_each_coordinate(current, destination)
    return sum((dx, dy, dz))

def grid_face_diagonal_3D_heuristic(current, destination):
    """Return the larger of two lower bounds on the number of moves.

    The two bounds are the largest single-axis distance, and half of the total
    per-axis distance rounded up. They assume that a move changes each
    coordinate by at most 1 and at most two coordinates at once.
    NOTE(review): confirm the move set against task.md.
    """
    deltas = distance_in_each_coordinate(current, destination)
    longest_axis = max(deltas)
    half_of_total = math.ceil(sum(deltas) / 2)
    return half_of_total if half_of_total >= longest_axis else longest_axis

def grid_all_diagonal_3D_heuristic(current, destination):
    """Chebyshev distance: the largest per-axis distance between the points.

    Assume one move may change any subset of coordinates by 1 each. The
    strategy is to move diagonally across every axis that still differs, and
    drop an axis once it matches. That reaches the destination in exactly
    max(dx, dy, dz) moves, so the maximum is returned directly.
    """
    return max(distance_in_each_coordinate(current, destination))

def grid_queen_2D_heuristic(current, destination):
    """Lower bound on queen moves: ceil(max(dx, dy) / 8).

    This assumes, like the other 8-step heuristics in this module, that one
    queen move changes each coordinate by at most 8. Under that assumption the
    larger axis distance shrinks by at most 8 per move. The estimate is
    therefore admissible, and it changes by at most 1 across any single move,
    so it is monotonic (consistent).

    Why not "diagonal first, then straight", i.e.
    ceil(min/8) + ceil((max-min)/8)? That estimate overestimates: for a delta
    of (23, 9) it gives 4, yet three moves suffice: (7,-7) + (8,8) + (8,8).
    """
    longest = max(distance_in_each_coordinate(current, destination))
    # Integer ceil(longest / 8) for a non-negative distance.
    return (longest + 7) // 8



def grid_great_king_2D_heuristic(current, destination):
    """Return ceil(largest per-axis distance / 8).

    This assumes a single move can advance up to 8 along every axis at once.
    NOTE(review): confirm against task.md. The maximum is taken over all
    coordinates given.
    """
    longest = max(distance_in_each_coordinate(current, destination))
    whole_moves, leftover = divmod(longest, 8)
    if leftover:
        whole_moves += 1
    return whole_moves

def grid_rook_2D_heuristic(current, destination):
    """Sum over both axes of ceil(axis_distance / 8).

    This assumes a rook move travels along a single axis by at most 8 squares,
    so each axis needs its own ceil(distance / 8) moves.
    NOTE(review): confirm against task.md. Exactly two coordinates are
    expected; other dimensions raise ValueError.
    """
    def moves_along(axis_distance):
        # ceil(axis_distance / 8) for a non-negative integer distance
        return axis_distance // 8 + (1 if axis_distance % 8 else 0)

    dx, dy = distance_in_each_coordinate(current, destination)
    return moves_along(dx) + moves_along(dy)

def grid_jumper_2D_heuristic(current, destination):
    """Lower-bound the number of jumps between two 2D points.

    With a and b the per-axis distances, the base bound is the smallest integer
    that is at least a/3, b/3 and (a+b)/5. The divisors suggest a jump moves at
    most 3 along either axis and at most 5 in total, e.g. a (2, 3) leaper.
    NOTE(review): confirm against task.md.

    The bound is then raised by one when needed so that bound + a + b is even.
    This gives the estimate the same parity as a + b, presumably because every
    jump flips that parity. Only the first two coordinates are used.
    """
    deltas = distance_in_each_coordinate(current, destination)
    a, b = deltas[0], deltas[1]
    bound = math.ceil(max(a / 3, b / 3, (a + b) / 5))
    if (bound + a + b) % 2 != 0:
        bound += 1
    return bound