from collections import defaultdict
from typing import NamedTuple, Tuple, Dict, List, FrozenSet, Callable, Set, Any

import tqdm.auto as tqdm

# Label names grouped by role (visible / action / hidden).
# NOTE(review): nothing in this part of the file reads these lists; presumably
# they name the columns of the state labels, agent actions and per-agent hidden
# values used elsewhere -- confirm against the code that consumes them.
VISIBLE_LABELS = ["rel_pos_x", "rel_pos_y"]
ACTION_LABELS = ["agent_0_action", "agent_1_action"]
HIDDEN_LABELS = ["agent_0_xmove", "agent_0_ymove", "agent_1_xmove", "agent_1_ymove"]


class HiddenValueShieldState(NamedTuple):
    """Shield state in which the action is recorded on the state it leads into.

    ``move_actions_to_outgoing_transitions`` converts a dictionary of these
    into ``OutgoingActionShieldState`` objects, where actions are stored on the
    state they leave from instead.
    """
    # Copied unchanged into the converted state's ``label``.
    label: FrozenSet[Tuple[Any, ...]]
    # Used by ``move_actions_to_outgoing_transitions`` as the action that
    # transitions *into* this state (successors are grouped by this value).
    prev_actions: Tuple[int, ...]
    # Numbers of the successor states; looked up as keys of the same dictionary.
    next_states: FrozenSet[int]
    # One value per hidden variable; each is wrapped in a singleton set on conversion.
    hidden_values: Tuple[Any, ...]
    # Not read in this part of the file -- presumably flags an unsafe state; confirm.
    violation: bool
    # Not read in this part of the file (the dictionary key is used instead).
    state_num: int


class OutgoingActionShieldState(NamedTuple):
    """Shield state that stores its transitions as outgoing actions.

    Equality and hashing deliberately consider only ``label``,
    ``initial_state`` and ``actions``. ``hidden_values`` and ``state_num`` are
    ignored, so that states which differ only in those fields can be detected
    as duplicates (see ``combine_identical_states``).
    """
    label: FrozenSet[Tuple[int, ...]]  # Set of possible observations
    initial_state: bool
    actions: Dict[Tuple[int, ...], FrozenSet[int]]  # Action -> set of successor state numbers
    hidden_values: Tuple[Set[Any], ...]  # One set of collected values per hidden variable
    state_num: int

    def __hash__(self):
        # ``actions`` is a dict and therefore unhashable; hash its items instead.
        return hash((self.label, frozenset(self.actions.items()), self.initial_state))

    def __eq__(self, other):
        return (
            isinstance(other, OutgoingActionShieldState)
            and self.label == other.label
            and self.initial_state == other.initial_state
            and self.actions == other.actions
        )

    def __ne__(self, other):
        # ``tuple`` provides its own ``__ne__`` that compares *every* field, so
        # without this override two states could be both ``==`` and ``!=``.
        return not self.__eq__(other)


# A shield: maps each state number to its outgoing-action state.
ShieldSpec = Dict[int, OutgoingActionShieldState]


class ActionPermutation(NamedTuple):
    """
    As long as all agents select the same action permutation at the same time (e.g. using a PRNG with the same seed),
    the resulting joint action will be safe for everyone.
    """
    agent_order: Tuple[int, ...]  # Doesn't actually matter at runtime, this is just for info
    # NOTE(review): not constructed in this part of the file -- presumably one
    # tuple of allowed actions per agent; confirm against the producer.
    actions: List[Tuple[int, ...]]
    # Presumably the state numbers reachable under this permutation; confirm.
    next_states: FrozenSet[int]


class DecentralizedShieldState(NamedTuple):
    """State of a decentralized shield, described by its action permutations.

    NOTE(review): nothing in this part of the file builds or reads this type;
    the field notes below are inferred from the annotations and from the
    similarly named fields of ``OutgoingActionShieldState`` -- confirm.
    """
    label: Tuple[int, ...]
    initial_state: bool
    # Alternatives the agents can choose between (see ``ActionPermutation``).
    action_permutations: List[ActionPermutation]
    # Presumably one set of collected values per hidden variable; confirm.
    hidden_values: Tuple[Set[int], ...]


# A decentralized shield: maps an integer key (presumably the state number) to its state.
DecentralizedShieldSpec = Dict[int, DecentralizedShieldState]


def move_actions_to_outgoing_transitions(
        input_dict: Dict[int, HiddenValueShieldState],
        initial_cond: Callable[[HiddenValueShieldState], bool],
) -> Dict[int, OutgoingActionShieldState]:
    """Re-express a shield so that actions label outgoing transitions.

    In the input every state carries ``prev_actions``, the action that leads
    into it. For each state, this function groups its successors by their
    ``prev_actions`` and stores that grouping as the state's ``actions``.

    :param input_dict: state number -> state; every number listed in a state's
        ``next_states`` must be a key of this dictionary.
    :param initial_cond: predicate evaluated on each input state to fill in
        ``initial_state``.
    :return: state number -> converted state, with the same keys as
        ``input_dict``. Each hidden value is wrapped in a singleton set.
    :raises KeyError: if a successor number is missing from ``input_dict``.
    """
    converted: Dict[int, OutgoingActionShieldState] = {}

    progress = tqdm.tqdm(input_dict.items(), desc="Processing next-state actions")
    for number, source_state in progress:
        # Bucket successors by the action recorded on the successor itself.
        successors_by_action = defaultdict(set)
        for successor in source_state.next_states:
            entering_action = input_dict[successor].prev_actions
            successors_by_action[entering_action].add(successor)

        frozen_actions = {
            action: frozenset(successors)
            for action, successors in successors_by_action.items()
        }
        converted[number] = OutgoingActionShieldState(
            label=source_state.label,
            initial_state=initial_cond(source_state),
            actions=frozen_actions,
            hidden_values=tuple({value} for value in source_state.hidden_values),
            state_num=number,
        )

    return converted


def combine_identical_states(input_dict: ShieldSpec) -> ShieldSpec:
    """Merge states that compare equal and renumber the result from 0.

    Two states are merged when ``OutgoingActionShieldState.__eq__`` says they
    are equal, i.e. they share ``label``, ``initial_state`` and ``actions``.
    The first state seen in each group is its representative; the merged
    state's ``hidden_values`` are the element-wise unions over the group.
    Successor numbers inside ``actions`` are rewritten to the new numbering.

    Successors are compared by their *old* numbers, so states that only become
    identical after renumbering are merged by a later call, not by this one.

    The input is not modified: hidden values are accumulated in private copies.

    :param input_dict: state number -> state; every successor number used in
        ``actions`` must be a key of this dictionary.
    :return: new state number -> merged state, numbered consecutively from 0
        in first-seen order.
    :raises KeyError: if a successor number is missing from ``input_dict``.
    """
    # First-seen state of each group, indexed by its new state number.
    representatives: List[OutgoingActionShieldState] = []
    # Merged hidden values of each group, indexed like ``representatives``.
    merged_hidden_values: List[Tuple[Set[Any], ...]] = []
    new_num_of_state: Dict[OutgoingActionShieldState, int] = {}
    input_shield_num_mapping: Dict[int, int] = {}

    # Find all states which have the same label, initial flag and actions.
    for input_shield_num, input_shield_state in tqdm.tqdm(input_dict.items(), desc="Finding identical states"):
        new_num = new_num_of_state.get(input_shield_state)
        if new_num is None:
            new_num = len(representatives)
            new_num_of_state[input_shield_state] = new_num
            representatives.append(input_shield_state)
            # Copy the sets so later merges never write into the caller's states.
            merged_hidden_values.append(
                tuple(set(values) for values in input_shield_state.hidden_values))
        else:
            for collected, values in zip(merged_hidden_values[new_num], input_shield_state.hidden_values):
                collected.update(values)

        input_shield_num_mapping[input_shield_num] = new_num

    def remap_single_shield_state(new_num: int,
                                  shield_state: OutgoingActionShieldState) -> OutgoingActionShieldState:
        # Rebuild the representative with renumbered successors and merged hidden values.
        remapped_actions = {
            action: frozenset(input_shield_num_mapping[next_state] for next_state in next_states)
            for action, next_states in shield_state.actions.items()
        }
        return OutgoingActionShieldState(label=shield_state.label,
                                         initial_state=shield_state.initial_state,
                                         actions=remapped_actions,
                                         hidden_values=merged_hidden_values[new_num],
                                         state_num=new_num)

    return {
        new_num: remap_single_shield_state(new_num, shield_state)
        for new_num, shield_state in enumerate(representatives)
    }


def eliminate_bad_states(input_dict: ShieldSpec) -> ShieldSpec:
    """Remove deadlock states and every action that could lead into one.

    A state is "bad" when it has no outgoing actions. Bad states are dropped,
    and in the remaining states any action whose successor set contains a bad
    state is removed. State numbers are kept as they are.

    This is a single pass: a surviving state may end up with no actions left
    and is only removed by a subsequent call.

    :param input_dict: state number -> state.
    :return: a new dictionary containing the surviving states.
    """
    output_dict: ShieldSpec = {}

    # The bar is driven manually, so use it as a context manager to make sure
    # it is closed (and its display finalised) even if an error occurs.
    with tqdm.tqdm(total=len(input_dict) * 2, desc="Removing deadlock states") as t:
        # Pass 1: collect the states without any outgoing action.
        bad_states = set()
        for state_num, state in input_dict.items():
            t.update(1)
            if not state.actions:
                bad_states.add(state_num)

        # Pass 2: copy the remaining states, keeping only actions that cannot reach a bad state.
        for state_num, state in input_dict.items():
            t.update(1)
            if state_num in bad_states:
                continue

            outgoing_actions = {
                action: successors
                for action, successors in state.actions.items()
                if successors.isdisjoint(bad_states)
            }
            output_dict[state_num] = OutgoingActionShieldState(state.label, state.initial_state, outgoing_actions,
                                                               state.hidden_values, state_num)

    return output_dict


def iterate_shield_cleanup(shield: ShieldSpec):
    """Repeat merging and deadlock removal until the state count stops shrinking.

    Each round runs ``combine_identical_states`` followed by
    ``eliminate_bad_states``. Both passes can only reduce the number of
    states, so an unchanged count means the round had nothing left to do.

    :param shield: the shield to clean up.
    :return: the shield as it was at the start of the first round that did not
        reduce the state count (for an already clean shield, the argument itself).
    """
    current = shield
    while True:
        cleaned = eliminate_bad_states(combine_identical_states(current))
        if len(cleaned) != len(current):
            # Something was merged or removed; go around again.
            current = cleaned
            continue
        return current

