import random
from copy import copy
from typing import Sequence, Any, Tuple

from gym import spaces as spaces

from centralized_verification.envs.fast_grid_world import FastGridWorld
from centralized_verification.envs.fast_grid_world_nearby_obs_simple import posn_dist_manhattan


class FlashlightGridWorld(FastGridWorld):
    """
    Grid world in which agents can temporarily extend their sight with a flashlight.

    Agents have ten actions:
    0 = Do nothing
    1 = Move up
    2 = Right
    3 = Down
    4 = Left
    5 = Turn on flashlight and do nothing
    6 = Turn on flashlight and move up
    7 = Turn on flashlight and move right
    8 = Turn on flashlight and move down
    9 = Turn on flashlight and move left

    The agents can only turn on the flashlight once every flashlight_recharge_time steps (flashlight_recharge_time=1 means agents can use it on every step).
    It symmetrically increases the observation radius of agents from obs_radius to flashlight_obs_radius:
    a pair of agents within flashlight_obs_radius see each other if either of them has its flashlight on.

    The environment state is the agents' grid-position indices followed by one recharge counter per agent
    (0 = flashlight ready, otherwise the number of steps left before it can be used again).
    """

    def __init__(self, *args, obs_radius=1, flashlight_obs_radius=5, flashlight_recharge_time=4, **kwargs):
        """
        :param obs_radius: Manhattan distance within which agents always see each other.
        :param flashlight_obs_radius: Manhattan distance within which agents see each other
            when at least one of the two has its flashlight on.
        :param flashlight_recharge_time: Minimum number of steps between two flashlight uses; must be positive.
        :param args: Forwarded unchanged to FastGridWorld.
        :param kwargs: Forwarded unchanged to FastGridWorld.
        """
        # NOTE(review): attributes are set before super().__init__ — presumably the base constructor
        # may already query the space methods below; confirm against FastGridWorld before reordering.
        self.obs_radius = obs_radius
        self.flashlight_obs_radius = flashlight_obs_radius
        assert flashlight_recharge_time > 0, "Flashlight recharge time must be positive"
        self.flashlight_recharge_time = flashlight_recharge_time
        super().__init__(*args, **kwargs)

    def agent_obs_spaces(self) -> Sequence[spaces.Space]:
        """
        Per-agent observation space: [own position, own flashlight flag, position of each other agent].

        Other agents' entries have one extra value (len(grid_posns)) that encodes "not visible".
        """
        num_posns = len(self.grid_posns)
        single_obs = [num_posns] + [2] + [num_posns + 1] * (self.num_agents - 1)
        return [spaces.MultiDiscrete(single_obs)] * self.num_agents

    def agent_actions_spaces(self) -> Sequence[spaces.Space]:
        """Per-agent action space: the ten discrete actions listed in the class docstring."""
        return [spaces.Discrete(10)] * self.num_agents

    def state_space(self) -> spaces.Space:
        """
        Full state space: one position index per agent followed by one recharge counter per agent.

        Recharge counters take values in 0 .. flashlight_recharge_time - 1.
        """
        return spaces.MultiDiscrete(
            [len(self.grid_posns)] * self.num_agents + [self.flashlight_recharge_time] * self.num_agents)

    def initial_state(self):
        """
        Build the initial environment state and the matching per-agent observations.

        With randomize_starts, agents get distinct random positions and random recharge counters;
        otherwise they start at start_idx with flashlights ready. No flashlight is on in the
        initial observation.

        :return: (state, obs) where state is a list of positions followed by recharge counters,
            and obs is a list with one observation tuple per agent.
        """
        if self.randomize_starts:
            starting_locs = random.sample(range(len(self.grid_posns)), self.num_agents)
            # randrange excludes the upper bound, so every counter stays inside state_space()
            # (randint would occasionally yield flashlight_recharge_time itself, which is out of range).
            starting_flashlight_recharge = [random.randrange(self.flashlight_recharge_time) for _ in
                                            range(self.num_agents)]
        else:
            # list(...) both copies start_idx and guarantees the concatenation below works
            # even if start_idx is stored as a tuple.
            starting_locs = list(self.start_idx)
            starting_flashlight_recharge = [0] * self.num_agents

        starting_flashlight_succeeded = [False] * self.num_agents

        init_state = starting_locs + starting_flashlight_recharge
        obs = [self.project_obs_agent(starting_locs, i, starting_flashlight_succeeded) for i in range(self.num_agents)]

        return init_state, obs

    def step(self, environment_state, joint_action: Sequence[Any]) -> Tuple[
        Any, Sequence[Any], Sequence[float], bool, bool]:
        """
        Advance the environment by one joint action.

        Movement (action % 5) is delegated to FastGridWorld.step on the position part of the state;
        the flashlight part (action // 5) is resolved here.

        :param environment_state: Positions of all agents followed by their recharge counters.
        :param joint_action: One action in 0..9 per agent.
        :return: (new state as a tuple, per-agent observations, rewards, done, safe); rewards, done and
            safe are passed through from the parent step.
        """
        env_state_movement = environment_state[:self.num_agents]
        joint_action_movement = tuple(ja % 5 for ja in joint_action)
        new_env_state_movement, _, rewards, done, safe = super().step(env_state_movement, joint_action_movement)

        env_state_flashlight = environment_state[self.num_agents:]
        attempted_flashlight_actions = tuple(ja // 5 for ja in joint_action)

        flashlight_succeeded = [False] * self.num_agents
        new_env_state_flashlight = list(env_state_flashlight)

        for i, (attempted_flashlight_action, flashlight_recharge_remaining) in enumerate(
                zip(attempted_flashlight_actions, env_state_flashlight)):
            if attempted_flashlight_action == 1 and flashlight_recharge_remaining == 0:
                # Flashlight fires; the current step counts as the first step of the recharge period.
                flashlight_succeeded[i] = True
                new_env_state_flashlight[i] = self.flashlight_recharge_time - 1
            else:
                # Not requested, or still recharging: keep counting down towards 0.
                new_env_state_flashlight[i] = max(0, flashlight_recharge_remaining - 1)

        obs = [self.project_obs_agent(new_env_state_movement, i, flashlight_succeeded) for i in range(self.num_agents)]
        # tuple(...) on the movement part is a no-op for a tuple and keeps the concatenation valid otherwise.
        new_state = tuple(new_env_state_movement) + tuple(new_env_state_flashlight)
        return new_state, obs, rewards, done, safe

    def project_obs_agent(self, state, agent_num, flashlights: Sequence[bool]):
        """
        Compute the observation of a single agent.

        :param state: Position indices of the agents (indexable by agent number).
        :param agent_num: Index of the observing agent.
        :param flashlights: For each agent, whether its flashlight is on during this step.
        :return: Tuple (own position, own flashlight flag as 0/1, then for every other agent its
            position index if visible, else len(grid_posns)).
        """
        agent_place_num = state[agent_num]
        agent_flashlight_on = flashlights[agent_num]
        agent_posn = self.grid_posns[agent_place_num]
        not_visible = len(self.grid_posns)
        obs = [agent_place_num, int(agent_flashlight_on)]

        for i, (other_agent_place, other_agent_flashlight) in enumerate(zip(state, flashlights)):
            if i == agent_num:
                continue

            distance = posn_dist_manhattan(agent_posn, self.grid_posns[other_agent_place])
            # Either agent's flashlight lights up the pair, so visibility is symmetric.
            any_flashlight_on = agent_flashlight_on or other_agent_flashlight
            can_see_other_agent = distance <= self.obs_radius or (
                    any_flashlight_on and distance <= self.flashlight_obs_radius)
            obs.append(other_agent_place if can_see_other_agent else not_visible)

        return tuple(obs)
