import numpy
from minesweeper_common import UNKNOWN, MINE, get_neighbors
import constraint as ct

# When enabled, Player.preprocessing re-verifies its cached neighborhood counters
# (self.unknown and self.mines) against a full recomputation after each call.
RUN_TESTS: bool = False

"""
    TODO: Improve the strategy for playing Minesweeper provided in this file.
    The provided strategy simply counts the number of unexplored mines in the neighborhood of each cell.
    If this counting concludes that a cell has to contain a mine, then it is marked.
    If it concludes that a cell cannot contain a mine, then it is explored, i.e. function Player.preprocessing returns its coordinates.
    If the simple counting algorithm does not find an unexplored cell that provably contains no mine,
     function Player.probability_player is called to find an unexplored cell with the minimal probability of containing a mine.
    A recommended approach is implementing the function Player.get_each_mine_probability.
    You can adapt this file as you like, but you have to keep the interface so that your player works properly on ReCodEx, i.e.
        * Player.__init__ is called in the beginning of every game.
        * Player.turn is called to explore one cell.
"""

class Player:
    """Minesweeper player combining simple counting with exact mine probabilities.

    Interface used by the game driver:
        * Player.__init__ is called in the beginning of every game.
        * Player.turn is called to explore one cell.
    """

    # Probability above which an unexplored cell is treated as a certain mine.
    # Slightly below 1 because floating-point arithmetic may not be exact.
    CERTAIN_MINE_PRB = 0.9999

    def __init__(self, rows, columns, game, mine_prb):
        # Initialize a player for a game on a board of given size with the probability of a mine on each cell.
        self.rows = rows
        self.columns = columns
        self.mine_prb = mine_prb

        # Matrix of the game. Every cell contains a number with the following meaning:
        # - A non-negative integer for explored cells means the number of mines in neighbors
        # - MINE: A cell is marked that it contains a mine
        # - UNKNOWN: We do not know whether it contains a mine or not.
        # The matrix is stored by reference; this player writes MINE into it when it deduces a mine.
        self.game = game

        # Matrix which for every cell contains the list of all neighbors.
        self.neighbors = get_neighbors(rows, columns)

        # Matrix of numbers of missing mines in the neighborhood of every cell.
        # -1 if a cell is unexplored.
        self.mines = numpy.full((rows, columns), -1)

        # Matrix of the numbers of unexplored neighborhood cells, excluding known mines.
        self.unknown = numpy.full((rows, columns), 0)
        for i in range(self.rows):
            for j in range(self.columns):
                self.unknown[i, j] = len(self.neighbors[i, j])

        # A set of cells for which the precomputed values self.mines and self.unknown need to be updated.
        self.invalid = set()

    def turn(self):
        """Return the position (row, column) of the next cell to explore.

        A cell proven safe by simple counting is preferred; otherwise the unexplored
        cell with the smallest probability of containing a mine is chosen.
        """
        pos = self.preprocessing()
        if pos is None:
            pos = self.probability_player()
        # The chosen cell is about to be revealed, so its cached counts and those of
        # its neighbors become stale.
        self.invalidate_with_neighbors(pos)
        return pos

    def probability_player(self):
        """Return an unexplored cell with the minimal probability of containing a mine.

        As a side effect, every *other* unexplored cell whose probability exceeds
        CERTAIN_MINE_PRB is marked as MINE. Ties are broken in row-major order.

        Raises:
            RuntimeError: if there is no unexplored cell left.
        """
        prb = self.get_each_mine_probability()

        best_position = None
        min_prb = float("inf")
        for i in range(self.rows):
            for j in range(self.columns):
                if self.game[i, j] == UNKNOWN and prb[i, j] < min_prb:
                    min_prb = prb[i, j]
                    best_position = (i, j)
        if best_position is None:
            raise RuntimeError("There is no unexplored cell left to explore.")

        # Mark near-certain mines, but never the cell we are about to explore:
        # returning a cell that was just marked as a mine would be contradictory.
        for i in range(self.rows):
            for j in range(self.columns):
                if ((i, j) != best_position
                        and self.game[i, j] == UNKNOWN
                        and prb[i, j] > self.CERTAIN_MINE_PRB):
                    self.game[i, j] = MINE
                    self.invalidate_with_neighbors((i, j))
        return best_position

    def invalidate_with_neighbors(self, pos):
        """Insert pos and its whole neighborhood into the set of cells whose
        precomputed information (self.mines, self.unknown) must be recomputed."""
        self.invalid.add(pos)
        self.invalid.update(self.neighbors[pos])

    def preprocess_all(self):
        """Recompute the precomputed information of every cell.

        The preprocessing is expected to find no cell that is provably safe.
        """
        self.invalid = {(i, j) for i in range(self.rows) for j in range(self.columns)}
        pos = self.preprocessing()
        assert pos is None, "Preprocessing is incomplete"

    def preprocessing(self):
        """
            Update precomputed information of cells in the set self.invalid.
            Using a simple counting, check cells which have to contain a mine.
            If this simple counting finds a cell which cannot contain a mine, then returns its position.
            Otherwise, returns None.
        """
        while self.invalid:
            pos = self.invalid.pop()

            # Count the numbers of unexplored neighborhood cells, excluding known mines.
            self.unknown[pos] = sum(1 if self.game[neigh] == UNKNOWN else 0 for neigh in self.neighbors[pos])

            if self.game[pos] >= 0:
                # If the cell pos is explored, count the number of missing mines in its neighborhood.
                self.mines[pos] = self.game[pos] - sum(1 if self.game[neigh] == MINE else 0 for neigh in self.neighbors[pos])
                assert 0 <= self.mines[pos] <= self.unknown[pos]

                if self.unknown[pos] > 0:
                    if self.mines[pos] == self.unknown[pos]:
                        # All unexplored neighbors have to contain a mine, so mark them.
                        for neigh in self.neighbors[pos]:
                            if self.game[neigh] == UNKNOWN:
                                self.game[neigh] = MINE
                                self.invalidate_with_neighbors(neigh)

                    elif self.mines[pos] == 0:
                        # All mines in the neighborhood were found, so explore the rest.
                        # pos is re-queued because it may have further unexplored neighbors.
                        self.invalid.add(pos)
                        for neigh in self.neighbors[pos]:
                            if self.game[neigh] == UNKNOWN:
                                return neigh
                        assert False, "There has to be at least one unexplored neighbor."

        if not RUN_TESTS:
            return None

        # The invalid set is empty, so self.unknown and self.mines should be correct.
        # Verify it to be sure.
        for i in range(self.rows):
            for j in range(self.columns):
                assert self.unknown[i, j] == sum(1 if self.game[neigh] == UNKNOWN else 0 for neigh in self.neighbors[i, j])
                if self.game[i, j] >= 0:
                    assert self.mines[i, j] == self.game[i, j] - sum(1 if self.game[neigh] == MINE else 0 for neigh in self.neighbors[i, j])
        return None

    def get_each_mine_probability(self):
        """Return a (rows, columns) matrix with the probability of a mine on each cell.

        Marked mines get 1 and explored cells get 0. An unexplored cell adjacent to an
        explored cell gets its exact probability given all visible numbers, assuming
        every cell independently contains a mine with probability self.mine_prb.
        The remaining unexplored cells keep that prior. If the constraints of a
        component are contradictory, its cells also keep the prior.

        Under this independent prior, groups of constraints that share no variable
        (even transitively) are independent. Each group is therefore enumerated on
        its own. This gives the same probabilities as enumerating the whole frontier
        at once, but is much cheaper when the frontier splits into several parts.
        """
        probability = numpy.zeros((self.rows, self.columns))
        probability.fill(self.mine_prb)

        constraints = self._collect_constraints(probability)
        for component in self._split_into_components(constraints):
            for variable, cell_prb in self._component_mine_probabilities(component).items():
                # Decode the integer variable back into board coordinates.
                x, y = divmod(variable, self.columns)
                probability[x, y] = cell_prb
        return probability

    def _collect_constraints(self, probability):
        """Build one exact-sum constraint per explored cell with unexplored neighbors.

        Writes the known cells into probability: 1 for MINE, 0 for explored cells.
        Returns a list of (variables, required_sum) pairs. variables lists the
        unexplored neighbors encoded as row * self.columns + column. required_sum is
        the number of mines still missing around the explored cell.
        """
        constraints = []
        for i in range(self.rows):
            for j in range(self.columns):
                value = self.game[i, j]
                if value == MINE:
                    probability[i, j] = 1
                elif value >= 0:
                    # The cell is explored, so it obviously is not a mine.
                    probability[i, j] = 0
                    unknown_cells = []
                    known_mines = 0
                    for x, y in self.neighbors[i, j]:
                        if self.game[x, y] == UNKNOWN:
                            unknown_cells.append(x * self.columns + y)
                        elif self.game[x, y] == MINE:
                            # Known mines are excluded from the constraint; they only
                            # reduce the number of mines still required.
                            known_mines += 1
                    if unknown_cells:
                        # int() turns a possible numpy scalar into a plain Python int.
                        constraints.append((unknown_cells, int(value) - known_mines))
        return constraints

    @staticmethod
    def _split_into_components(constraints):
        """Group constraints that (transitively) share variables, via union-find.

        Returns a list of constraint lists. Constraints in different lists have
        disjoint variable sets.
        """
        parent = {}

        def find(variable):
            while parent[variable] != variable:
                parent[variable] = parent[parent[variable]]  # Path halving.
                variable = parent[variable]
            return variable

        for cells, _ in constraints:
            for cell in cells:
                parent.setdefault(cell, cell)
            root = find(cells[0])
            for cell in cells[1:]:
                other = find(cell)
                if other != root:
                    parent[other] = root

        components = {}
        for constraint in constraints:
            components.setdefault(find(constraint[0][0]), []).append(constraint)
        return list(components.values())

    def _component_mine_probabilities(self, constraints):
        """Return {variable: probability of a mine} for one independent component.

        Every assignment satisfying all constraints is weighted by its prior
        probability: mine_prb ** mines * (1 - mine_prb) ** non_mines.
        Returns an empty dict if no assignment has a positive weight, for example
        when the constraints are contradictory.
        """
        variables = sorted({cell for cells, _ in constraints for cell in cells})
        problem = ct.Problem()
        for variable in variables:
            problem.addVariable(variable, [0, 1])
        for cells, required_sum in constraints:
            # Variables are 0/1 mine indicators, so their sum is the number of mines among them.
            problem.addConstraint(ct.ExactSumConstraint(required_sum), cells)

        mine_weight = dict.fromkeys(variables, 0.0)
        total_weight = 0.0
        for solution in problem.getSolutionIter():
            mines = sum(solution[variable] for variable in variables)
            weight = self.mine_prb ** mines * (1 - self.mine_prb) ** (len(variables) - mines)
            total_weight += weight
            for variable in variables:
                if solution[variable] == 1:
                    mine_weight[variable] += weight

        if total_weight == 0:
            return {}
        return {variable: weight / total_weight for variable, weight in mine_weight.items()}
