import itertools
from typing import List, Tuple

import numpy as np
from gym import spaces

from centralized_verification.shields.shield import ShieldResult, AgentResult, AgentUpdate, T, AbstractShield
from centralized_verification.shields.utils import decentralize_actions
from centralized_verification.violation_specification.specified_env import SpecifiedEnvironment
from centralized_verification.violation_specification.violation_specification import ViolationSpecification


def spec_to_action_set(env: SpecifiedEnvironment, spec: ViolationSpecification, current_spec_state) -> np.ndarray:
    """
    Build a boolean table marking which joint actions keep ``spec`` satisfied.

    The table has one axis per agent listed in ``spec.agent_nums()`` (in that order) whose
    action space is a ``spaces.Discrete``; each axis has length ``n`` of that space. The
    entry at index ``(a_0, a_1, ...)`` is True exactly when
    ``spec.violates_spec(current_spec_state, (a_0, a_1, ...))`` is falsy.

    NOTE(review): agents whose action space is not ``spaces.Discrete`` are skipped when
    building the shape, so the joint-action tuples handed to ``violates_spec`` would then
    have fewer entries than ``spec.agent_nums()``. Confirm that every spec agent is Discrete.
    """
    action_spaces = env.agent_actions_spaces()
    spec_spaces = [action_spaces[agent] for agent in spec.agent_nums()]
    discrete_sizes = tuple(space.n for space in spec_spaces if isinstance(space, spaces.Discrete))

    is_safe = np.zeros(discrete_sizes, dtype=bool)
    # Enumerate every joint action (cartesian product of each agent's action indices)
    all_joint_actions = itertools.product(*map(range, discrete_sizes))
    for joint_action in all_joint_actions:
        violated = spec.violates_spec(current_spec_state, joint_action)
        is_safe[joint_action] = not violated
    return is_safe


def translate_priority(global_priority: List[int], relevant_agents: List[int]) -> List[int]:
    """
    Project a global priority ordering onto a subset of agents.

    Agents in ``global_priority`` that do not appear in ``relevant_agents`` are dropped.
    Each remaining agent is replaced by its position within ``relevant_agents``. The result
    therefore orders the relevant agents by their relative order in ``global_priority``,
    expressed as local (0-based) indices into ``relevant_agents``.

    Args:
        global_priority: Agent indices in global priority order.
        relevant_agents: The agents of interest; their positions define the local numbering.
            Any sequence works (list or tuple).

    Returns:
        The projected priority as a list of local indices into ``relevant_agents``.

    Example:
        global_priority [2, 0, 1, 3, 4], relevant_agents [3, 1, 4]
        -> filtered to [1, 3, 4] -> renumbered to [1, 0, 2]
    """
    # Map each relevant agent to its local index once, instead of scanning the list for
    # every agent. setdefault keeps the first occurrence on duplicates, matching list.index.
    local_index = {}
    for idx, agent in enumerate(relevant_agents):
        local_index.setdefault(agent, idx)

    return [local_index[agent] for agent in global_priority if agent in local_index]


class ViolationSpecificationShield(AbstractShield[SpecifiedEnvironment, None]):
    """
    Decentralized shield that checks each agent's proposed action against every violation
    specification that agent participates in.

    The shield keeps no state of its own (its shield state is always ``None``). If an
    agent's proposed action is unsafe under any specification it can currently observe,
    the action is replaced by action ``0``. This shield treats action ``0`` as always safe.
    """

    def __init__(self, env: SpecifiedEnvironment, punish_unsafe_orig_action=False, unsafe_action_punishment=0):
        super().__init__(env, punish_unsafe_orig_action, unsafe_action_punishment)
        # Every violation specification defined by the environment, fetched once up front
        self.specs = env.get_all_specifications()

    def get_initial_shield_state(self, state, initial_joint_obs) -> T:
        """Return the initial shield state; this shield is stateless, so always ``None``."""
        return None

    def shield_single_agent(self, agent_num, obs, proposed_action, priority) -> AgentResult:
        """
        Shield a single agent's proposed action using only that agent's observation.

        Args:
            agent_num: Global index of the agent being shielded.
            obs: This agent's observation.
            proposed_action: The action the agent would like to take.
            priority: Global agent priority order (a list of agent indices). It is projected
                onto each specification's agents with ``translate_priority``.

        Returns:
            An ``AgentResult`` that keeps ``proposed_action`` if it is safe under every
            specification whose state this agent can extract. Otherwise, the result of
            ``replace_action_agent_result`` substituting action ``0``.
        """
        for spec in self.specs:
            spec_agents = spec.agent_nums()
            if agent_num not in spec_agents:  # Specification concerns other agents only
                continue

            spec_state = spec.extract_state_from_obs(obs, agent_num)
            if spec_state is None:  # Visibility constraint is not met
                continue

            spec_priority = translate_priority(priority, spec_agents)
            local_safe_action_set = spec_to_action_set(self.env, spec, spec_state)
            # The all-zeros joint action is handed to decentralize_actions as the guaranteed-safe
            # action; this presumes action 0 is always jointly safe. TODO confirm per environment.
            guaranteed_safe_action = [0] * len(spec_priority)
            decentralized_safe_actions = decentralize_actions(local_safe_action_set, guaranteed_safe_action,
                                                              spec_priority)

            # Per-agent results are indexed by the agent's position within spec.agent_nums()
            safe_factorized_spec_actions = decentralized_safe_actions[spec_agents.index(agent_num)]

            if not safe_factorized_spec_actions[proposed_action]:  # It is unsafe
                return self.replace_action_agent_result(proposed_action, 0)

        # We've gone through all the specs and none were violated
        return AgentResult(AgentUpdate(action=proposed_action))

    def evaluate_joint_action(self, _, joint_obs, proposed_action, __) -> Tuple[ShieldResult, None]:
        """
        Shield every agent independently and collect the per-agent results.

        The first and last positional arguments are unused by this shield.

        Args:
            joint_obs: One observation per agent, in agent order.
            proposed_action: The proposed joint action, indexed by agent number.

        Returns:
            A ``(results, None)`` pair. ``results`` holds one ``AgentResult`` per entry of
            ``joint_obs``, and ``None`` is the (unchanged, empty) shield state.
        """
        # Assume the priority is generated on all agents independently based on shared random seed
        # noinspection PyTypeChecker
        priority: List[int] = np.random.permutation(np.arange(len(proposed_action))).tolist()  # Agents in random order

        res = [self.shield_single_agent(agent_num, single_obs, proposed_action[agent_num], priority)
               for agent_num, single_obs in enumerate(joint_obs)]
        return res, None
