import math
import random
from builtins import Exception
from itertools import chain
from typing import Sequence, Any, Tuple, List

from gym import spaces as spaces

from centralized_verification.MultiAgentAPEnv import MultiAgentSafetyEnv
from centralized_verification.envs.utils import add_posn


def flatten(iterator_to_flatten):
    """Concatenate the items of an iterable of iterables into a single list.

    Exactly one level of nesting is removed, e.g. ``[(1, 2), (3,)]`` becomes
    ``[1, 2, 3]``.
    """
    return [item for inner in iterator_to_flatten for item in inner]


def posn_floor(p):
    """Return the integer grid cell ``(floor(x), floor(y))`` containing point ``p``.

    ``p`` must unpack into exactly two numbers.
    """
    x_coord, y_coord = p
    return tuple(map(math.floor, (x_coord, y_coord)))


def posn_dist(p1, p2):
    """Return the Euclidean distance between points ``p1`` and ``p2``.

    Only the first two coordinates of each point are used. ``math.hypot`` is
    used rather than squaring and taking the square root by hand, so very
    large coordinate differences do not overflow to ``inf`` and very small
    ones do not underflow to ``0``.
    """
    return math.hypot(p1[0] - p2[0], p1[1] - p2[1])


class ContinuousGridWorld(MultiAgentSafetyEnv):
    """
    A multi-agent grid world in which agents have continuous positions.

    The walkable area is the union of the unit cells listed in ``grid_posns``;
    a point ``(x, y)`` belongs to cell ``(floor(x), floor(y))``. Agents start at
    the centre of their starting cell (offset ``(.5, .5)`` from its coordinates).

    The environment state is a flat tuple ``(x_0, y_0, x_1, y_1, ...)`` holding
    one position per agent.

    The observations are as such: every agent receives the same vector, namely
    the environment state with each x mapped to ``2 * x / (width + 1) - 1`` and
    each y mapped to ``2 * y / (height + 1) - 1``, where ``width``/``height`` are
    the largest cell x/y indices. For non-negative cell indices every entry lies
    in ``[-1, 1)``.

    Two agents collide when they end a step less than 1 unit apart. The single
    AP name reported by ``ap_names`` is "collision"; the last value returned by
    ``step`` is True when no collision occurred during that step.

    Agents have five actions. A moving action applies an offset of length ``d``,
    with ``d`` drawn from ``random.random()`` (i.e. in [0, 1)):
    0 = Do nothing
    1 = Move up, offset (0, -d)
    2 = Right, offset (d, 0)
    3 = Down, offset (0, d)
    4 = Left, offset (-d, 0)

    (All decreased by one if no_idle is true, with do nothing not allowed)

    A move whose destination is not inside any known cell is cancelled and the
    agent stays where it was (this counts as hitting a wall).
    """

    def __init__(self, grid_posns, num_agents, start_idx, ending_idx, randomize_starts: bool = False,
                 collision_reward=-30, agents_bounce: bool = False, terminate_on_collision: bool = False,
                 no_idle: bool = False, reward_shaping: bool = False, reward_shaping_factor: float = 1.0):
        """
        :param grid_posns: ``(x, y)`` coordinates of every walkable cell. They are
            compared against ``posn_floor`` results and used as dict keys, so they
            are presumably integer tuples -- confirm against callers.
        :param num_agents: number of agents.
        :param start_idx: per-agent index into ``grid_posns`` of the starting cell
            (unused when ``randomize_starts`` is True).
        :param ending_idx: per-agent index into ``grid_posns`` of the goal cell.
        :param randomize_starts: choose random starting cells whose coordinates are
            pairwise at least 2 apart instead of using ``start_idx``.
        :param collision_reward: added to the reward of every agent that collides.
        :param agents_bounce: not supported here; passing True raises.
        :param terminate_on_collision: end the episode as soon as any collision occurs.
        :param no_idle: remove the "do nothing" action (see the class docstring).
        :param reward_shaping: add a distance-to-goal shaping term to each reward.
        :param reward_shaping_factor: multiplier applied to the shaping term.
        """

        # Want to have same constructor interface as regular gridworld, but can't support all of the params
        if agents_bounce:
            raise Exception("Agents bounce not supported on continuous gridworld")

        self.grid_posns = grid_posns
        self.num_agents = num_agents
        # Largest cell x / y index (not the number of columns / rows).
        self.width = max(pos[0] for pos in self.grid_posns)
        self.height = max(pos[1] for pos in self.grid_posns)

        # posn -> idx
        self.grid_posn_inv = {pos: idx for idx, pos in enumerate(self.grid_posns)}

        assert len(start_idx) == num_agents
        assert len(ending_idx) == num_agents

        self.start_idx = start_idx
        self.goal_idx = ending_idx

        self.randomize_starts = randomize_starts
        self.collision_cost = collision_reward
        self.terminate_on_collision = terminate_on_collision
        self.no_idle = no_idle
        self.reward_shaping = reward_shaping
        self.reward_shaping_factor = reward_shaping_factor

    def ap_names(self) -> List[str]:
        """Return the names of the atomic propositions this environment reports."""
        return ["collision"]

    def agent_obs_spaces(self) -> Sequence[spaces.Space]:
        """Return one observation space per agent: the normalized positions of all agents."""
        return [spaces.Box(-1, 1, shape=(2 * self.num_agents,))] * self.num_agents

    def agent_actions_spaces(self) -> Sequence[spaces.Space]:
        """Return one discrete action space per agent (4 actions with no_idle, else 5)."""
        num_actions = 4 if self.no_idle else 5
        return [spaces.Discrete(num_actions)] * self.num_agents

    def state_space(self) -> spaces.Space:
        """Return the space of flat ``(x_0, y_0, ..., x_n, y_n)`` environment states."""
        # A position inside the cell with the largest index k can lie anywhere in
        # [k, k + 1), so the upper bound must be one past the largest index.
        # NOTE(review): the lower bound of 0 assumes non-negative cell indices.
        return spaces.Box(0, max(self.width, self.height) + 1, shape=(2 * self.num_agents,))

    def initial_state(self):
        """Return ``(state, observations)`` for the start of a new episode."""
        if self.randomize_starts:
            starting_locs = []
            for _ in range(self.num_agents):
                # random.choice raises IndexError if no cell is at least 2 away
                # from every start chosen so far.
                good_locs = [p for p in self.grid_posns if all(posn_dist(sl, p) >= 2 for sl in starting_locs)]
                starting_locs.append(random.choice(good_locs))
        else:
            starting_locs = [self.grid_posns[idx] for idx in self.start_idx]

        # Place each agent in the middle of its starting cell.
        starting_space_in_loc = .5, .5

        new_locs = [add_posn(loc, starting_space_in_loc) for loc in starting_locs]
        starting_state = tuple(flatten(new_locs))

        return starting_state, self.project_obs(starting_state)

    def get_next_loc(self, loc, act):
        """
        Return the position reached by taking action ``act`` from ``loc``.

        One value is drawn from ``random.random()`` on every call, regardless of
        ``act``. If the destination is not inside a known cell, ``loc`` itself is
        returned unchanged.
        """
        dist_to_move = random.random()
        actions = [(0, 0),
                   (0, -dist_to_move),
                   (dist_to_move, 0),
                   (0, dist_to_move),
                   (-dist_to_move, 0)]

        next_loc = add_posn(loc, actions[act])
        grid_loc = posn_floor(next_loc)
        if grid_loc in self.grid_posn_inv:
            return next_loc
        else:
            return loc

    def step(self, environment_state, joint_action: Sequence[Any]) -> Tuple[
        Any, Sequence[Any], Sequence[float], bool, bool]:
        """
        Apply one action per agent and advance the environment by one step.

        Rewards per agent: 100 if every agent is inside its goal cell after the
        step, otherwise -10 if the agent's non-idle move was cancelled (wall),
        otherwise -1. With reward shaping, ``(old_dist - new_dist) * factor`` is
        added, where distances are measured to the goal cell's stored coordinates.
        Colliding agents additionally receive ``collision_reward``.

        :returns: ``(new_state, observations, rewards, done, safe)`` where ``safe``
            is True when no two agents collided during this step.
        """
        if self.no_idle:
            # Shift actions up by one so that action 0 ("do nothing") is never used.
            joint_action = list(a + 1 for a in joint_action)

        locs = [environment_state[2 * i:2 * i + 2] for i in range(self.num_agents)]
        new_locs = [self.get_next_loc(loc, act) for loc, act in zip(locs, joint_action)]

        collision = [False] * self.num_agents
        for i in range(self.num_agents):
            for j in range(i + 1, self.num_agents):
                if posn_dist(new_locs[i], new_locs[j]) < 1:
                    collision[i] = True
                    collision[j] = True

        # Success requires all agents to be inside their goal cells simultaneously.
        reached_goal = all(posn_floor(loc) == self.grid_posns[goal] for loc, goal in
                           zip(new_locs, self.goal_idx))
        done = reached_goal

        def base_rew_for_agent(loc, old_loc, action):
            if reached_goal:
                return 100
            elif loc == old_loc and action != 0:  # Move was cancelled: hit a wall
                return -10
            else:
                return -1

        def rew_for_agent(loc, old_loc, action, goal_loc):
            base_rew = base_rew_for_agent(loc, old_loc, action)
            if self.reward_shaping:
                # Positive when this step brought the agent closer to its goal.
                old_dist_to_goal = posn_dist(old_loc, goal_loc)
                new_dist_to_goal = posn_dist(loc, goal_loc)
                return base_rew + ((old_dist_to_goal - new_dist_to_goal) * self.reward_shaping_factor)
            else:
                return base_rew

        # rew_for_agent expects the NEW position first and the OLD position second;
        # passing them the other way round would reward moving away from the goal.
        rewards = [rew_for_agent(new_loc, old_loc, act, self.grid_posns[goal])
                   for old_loc, new_loc, act, goal in zip(locs, new_locs, joint_action, self.goal_idx)]

        if any(collision):
            rewards = [rew + self.collision_cost if coll else rew for rew, coll in zip(rewards, collision)]

            if self.terminate_on_collision:
                done = True

        new_env_state = tuple(flatten(new_locs))
        return new_env_state, self.project_obs(new_env_state), rewards, done, (not any(collision))

    def project_obs(self, state) -> Sequence[Any]:
        """
        Map a flat environment state to one (identical) observation per agent.

        Each x is scaled by ``2 / (width + 1)`` and each y by ``2 / (height + 1)``,
        then shifted by -1.
        """
        obs = tuple(
            ((num * 2) / div) - 1 for num, div in zip(state, [self.width + 1, self.height + 1] * self.num_agents))
        return tuple([obs] * self.num_agents)
