from collections import defaultdict
from typing import Tuple, Optional

from centralized_verification.MultiAgentAPEnv import MultiAgentSafetyEnv
from centralized_verification.shields.partial_obs.asym_obs_shields import PartialObsCentralizedShieldSpec
from centralized_verification.shields.partial_obs.pobs_label_extractor import PobsLabelExtractor
from centralized_verification.shields.shield import ShieldResult, AbstractShield
from centralized_verification.shields.utils import create_random_priority

# Shield state carried from one step to the next by PobsLabelShieldCentralized:
# ``None`` before the first evaluation, otherwise a pair of
# (shield-spec state number, joint action that was actually taken).
pobs_centralized_shield_state_type = Optional[Tuple[int, Tuple[int, ...]]]


def find_lowest_edit_distance(proposed_action, allowed_actions):
    """Return the allowed action that differs from ``proposed_action`` in the fewest positions.

    Actions are compared element by element with ``zip``, so positions beyond the
    shorter of the two sequences are ignored. When several candidates are equally
    close, the first one produced by iterating ``allowed_actions`` is kept.

    :param proposed_action: sized sequence of per-position action values.
    :param allowed_actions: iterable of candidate action sequences.
    :return: the closest candidate, or ``None`` if ``allowed_actions`` is empty.
    """

    def count_mismatches(candidate):
        # Number of aligned positions whose values differ.
        return sum(1 for wanted, offered in zip(proposed_action, candidate) if wanted != offered)

    closest = None
    # Sentinel strictly larger than any achievable mismatch count, so the first
    # candidate is always accepted.
    fewest_mismatches = len(proposed_action) + 1
    for candidate in allowed_actions:
        mismatches = count_mismatches(candidate)
        if mismatches >= fewest_mismatches:
            continue  # not strictly better; keep the earlier candidate
        closest, fewest_mismatches = candidate, mismatches
    return closest


class PobsLabelShieldCentralized(AbstractShield[MultiAgentSafetyEnv, pobs_centralized_shield_state_type]):
    """Centralized shield that follows a shield spec using labels extracted from observations.

    The shield state threaded between calls is ``None`` before the first step and
    ``(state_num, joint_action_taken)`` afterwards. On every step the current spec
    state is recovered by intersecting
      * the successor states of the previous ``(state_num, action)`` pair
        (or the spec's initial states on the first step), with
      * the spec states whose ``agent_observations`` equal the current label.
    """

    def __init__(self, shield_spec: PartialObsCentralizedShieldSpec, label_extractor: PobsLabelExtractor, **kwargs):
        """Index the shield spec by observation label and collect its initial states.

        :param shield_spec: spec whose ``states`` mapping (state number -> state) and
            ``labels`` are read by this shield.
        :param label_extractor: callable invoked as ``label_extractor(state, joint_obs)``.
        :param kwargs: forwarded unchanged to the base-class constructor.
        """
        super().__init__(**kwargs)
        self.label_extractor = label_extractor
        self.shield_spec = shield_spec

        # Group spec state numbers by the observation label attached to each state.
        states_by_label = defaultdict(set)
        for state_num, state in self.shield_spec.states.items():
            states_by_label[state.agent_observations].add(state_num)

        # Freeze the index so later lookups cannot mutate it by accident.
        self.label_state_map = {obs: frozenset(states) for obs, states in states_by_label.items()}
        self.initial_states = frozenset(
            state_num for state_num, state in self.shield_spec.states.items() if state.initial_state)

    def get_initial_shield_state(self, state, initial_joint_obs) -> pobs_centralized_shield_state_type:
        """Return ``None``: the first ``evaluate_joint_action`` call starts from the spec's initial states."""
        return None

    def evaluate_joint_action(
            self, state, joint_obs, proposed_action, shield_state: pobs_centralized_shield_state_type
    ) -> Tuple[ShieldResult, pobs_centralized_shield_state_type]:
        """Locate the current spec state and shield the proposed joint action against it.

        :param state: environment state, passed through to the label extractor.
        :param joint_obs: joint observation, passed through to the label extractor.
        :param proposed_action: sized sequence with one proposed action per agent.
        :param shield_state: ``None`` on the first step, otherwise the
            ``(state_num, action_taken)`` pair returned by the previous call.
        :return: ``(shield_result, (state_num, action_taken))`` where ``action_taken``
            is the tuple of ``real_action.action`` values from the shield result.
        :raises ValueError: if no spec state, or more than one, is consistent with both
            the previous step and the current label.
        :raises KeyError: if the extracted label matches no state of the spec (plain
            dict lookup in ``label_state_map``).
        """
        if shield_state is None:
            possible_states_given_prev = self.initial_states
        else:
            prev_state_num, prev_action = shield_state
            possible_states_given_prev = self.shield_spec.states[prev_state_num].actions[prev_action]

        # Re-order each agent's label values to follow the spec's per-agent label order,
        # producing the same key shape that ``label_state_map`` was built with.
        label = self.label_extractor(state, joint_obs)
        ordered_label = tuple(
            tuple(indiv_label[s] for s in self.shield_spec.labels[agent_num])
            for agent_num, indiv_label in enumerate(label))

        possible_states_given_label = self.label_state_map[ordered_label]

        possible_states = list(possible_states_given_prev.intersection(possible_states_given_label))
        if not possible_states:
            raise ValueError("Shield is not descriptive of the environment.")
        if len(possible_states) > 1:
            raise ValueError("Shield is not deterministic.")
        state_num = possible_states[0]

        action_set = frozenset(self.shield_spec.states[state_num].actions.keys())

        # Fall back to the allowed joint action closest to what the agents proposed.
        default_action = find_lowest_edit_distance(proposed_action, action_set)
        # ``neuter_agents_in_priority`` is not defined in this class (presumably provided
        # by the base class); only the first element of its result is used here.
        shield_result, _ = self.neuter_agents_in_priority(proposed_action, action_set,
                                                          create_random_priority(len(proposed_action)),
                                                          default_action, default_action)

        action_taken = tuple(agent_shield_result.real_action.action for agent_shield_result in shield_result)

        return shield_result, (state_num, action_taken)
