import re
from typing import Dict, Tuple, FrozenSet, NamedTuple, List, Any

import regex

from centralized_verification.shields.combine_identical_states import ShieldSpec


# One state of a centralized shield, with what each agent can observe in it.
PartialObsCentralizedShieldState = NamedTuple("PartialObsCentralizedShieldState", [
    # Whether this state is an initial state of the shield.
    ("initial_state", bool),
    # Indexed by agent: the set of observation tuples that agent may see in this state.
    ("agent_observations", Tuple[FrozenSet[Tuple[Any, ...]], ...]),
    # Joint action (one int per agent) -> set of successors (ints; presumably state numbers).
    ("actions", Dict[Tuple[int, ...], FrozenSet[int]]),
    # One set of possible values per hidden label (presumably aligned with the spec's hidden_labels).
    ("hidden_values", Tuple[FrozenSet[Any], ...]),
    # Identifier of this state.
    ("state_num", int),
])


# A complete partial-observation shield: its states plus the names of the labels they carry.
PartialObsCentralizedShieldSpec = NamedTuple("PartialObsCentralizedShieldSpec", [
    # State number -> state.
    ("states", Dict[int, PartialObsCentralizedShieldState]),
    # Per agent: the short names of its observations, in the order of that agent's observation tuples.
    ("labels", Tuple[Tuple[str, ...], ...]),
    # Names of the hidden (unobserved) labels.
    ("hidden_labels", Tuple[str, ...]),
])


def to_partial_observations_shield(sspec: ShieldSpec, num_agents: int, obs_info: List[Tuple[str, List[int], str]],
                                   hidden_labels: List[str]) -> PartialObsCentralizedShieldSpec:
    """Project a fully-labelled shield onto per-agent partial observations.

    Each ``obs_info`` entry is ``(full label name, agents that see it, short name)``.
    The entry's position in ``obs_info`` is the index of that label inside every
    possible label tuple of a state. For each agent, the indices (and short names)
    of the labels it observes are collected. Then every state's possible label
    tuples are restricted to those indices.

    The returned spec is keyed by the same state numbers as ``sspec``. Its
    ``labels`` hold each agent's short observation names, and its
    ``hidden_labels`` is ``hidden_labels`` as a tuple.
    """
    # Per agent: positions inside a full label tuple, and the matching short names.
    positions_per_agent: List[List[int]] = [[] for _ in range(num_agents)]
    names_per_agent: List[List[str]] = [[] for _ in range(num_agents)]

    for position, (_full_name, observers, short_name) in enumerate(obs_info):
        for observer in observers:
            positions_per_agent[observer].append(position)
            names_per_agent[observer].append(short_name)

    def project(full_labels, positions):
        # Keep only the entries at `positions` from every possible full label tuple.
        return frozenset(tuple(full[pos] for pos in positions) for full in full_labels)

    projected_states = {}
    for state_key, full_state in sspec.items():
        observations = tuple(project(full_state.label, positions) for positions in positions_per_agent)
        projected_states[state_key] = PartialObsCentralizedShieldState(
            initial_state=full_state.initial_state,
            agent_observations=observations,
            actions=full_state.actions,
            hidden_values=tuple(map(frozenset, full_state.hidden_values)),
            state_num=full_state.state_num,
        )

    return PartialObsCentralizedShieldSpec(
        projected_states,
        tuple(map(tuple, names_per_agent)),
        tuple(hidden_labels),
    )


def parse_shield_labels(label_names):
    """Classify raw shield label names into action, observation and hidden labels.

    Recognised label forms (matched from the start of the name):

    * ``agent_<i>_action``: the action label of agent ``i``;
    * ``agent_<i>[_<j>...]_obs_<name>``: observation ``<name>`` shared by agents ``i, j, ...``;
    * ``violation``: a required marker label that is consumed and not returned;
    * anything else: a hidden label.

    Returns:
        ``(raw_obs_labels, obs_labels, raw_action_labels, hidden_labels)``.
        ``obs_labels`` holds ``(label, [agent indices], obs_name)`` triples in
        input order, and ``raw_obs_labels`` holds just their label names.
        ``raw_action_labels`` holds the action label names sorted by agent
        index. ``hidden_labels`` holds the remaining names in input order.

    Raises:
        ValueError: if no ``violation`` label is present. ValueError subclasses
            Exception, so existing ``except Exception`` handlers still catch it.
    """
    action_labels_re = re.compile(r"agent_(\d+)_action")
    # Group 1 captures the whole "<i>_<j>_..._<k>" agent list. Splitting it on "_"
    # yields the indices in order, so no repeated-capture support (third-party
    # `regex` .captures()) is needed. The backtracking is identical to the old
    # pattern, so the same labels match with the same split.
    obs_labels_re = re.compile(r"agent_((?:\d+_)*\d+)_obs_(.+)")

    action_labels = []
    obs_labels = []
    hidden_labels = []
    seen_violation = False

    for label in label_names:
        action_match = action_labels_re.match(label)
        if action_match:
            action_labels.append((label, int(action_match.group(1))))
            continue

        obs_match = obs_labels_re.match(label)
        if obs_match:
            agents_with_this_obs = [int(i) for i in obs_match.group(1).split("_")]
            obs_labels.append((label, agents_with_this_obs, obs_match.group(2)))
            continue

        if label == "violation":
            seen_violation = True
        else:
            hidden_labels.append(label)

    if not seen_violation:
        raise ValueError("No violation label found")

    # Stable sort: action labels end up ordered by their agent index.
    action_labels.sort(key=lambda a: a[1])

    raw_obs_labels = [label for label, _, _ in obs_labels]
    raw_action_labels = [label for label, _ in action_labels]

    return raw_obs_labels, obs_labels, raw_action_labels, hidden_labels



def partial_obs_centralized_shield_to_json(sspec: PartialObsCentralizedShieldSpec):
    """Convert a PartialObsCentralizedShieldSpec into plain JSON-serialisable containers.

    The result has these keys:

    * ``"obs_names"``: ``sspec.labels``. The json module writes tuples as arrays.
    * ``"hidden_obs_names"``: ``sspec.hidden_labels``.
    * ``"action_space"``: per agent, one more than the largest action index seen
      in any joint action, or 0 if none. It starts with ``len(sspec.labels)``
      entries.
    * ``"shield_states"``: the state number as a string, mapped to a dict with
      keys ``"observations"``, ``"actions"``, ``"initial"`` and ``"hidden"``.

    Sets are emitted as lists in their iteration order.
    """
    action_space = [0] * len(sspec.labels)
    shield_states = {}

    for state_num, state in sspec.states.items():
        action_entries = []
        for joint_action, successors in state.actions.items():
            action_entries.append({
                "action": list(joint_action),
                "successors": list(successors),
            })
            # NOTE(review): zip() stops at the shorter sequence, so a joint action shorter
            # than the agent count would shrink action_space; presumably lengths always match.
            action_space = [max(best, chosen + 1) for best, chosen in zip(action_space, joint_action)]

        # JSON object keys must be strings, hence str(state_num).
        shield_states[str(state_num)] = {
            "observations": [list(per_agent) for per_agent in state.agent_observations],
            "actions": action_entries,
            "initial": state.initial_state,
            "hidden": [list(values) for values in state.hidden_values],
        }

    return {
        "obs_names": sspec.labels,
        "hidden_obs_names": sspec.hidden_labels,
        "action_space": action_space,
        "shield_states": shield_states,
    }


def partial_obs_centralized_shield_from_json(json_dict) -> PartialObsCentralizedShieldSpec:
    """Rebuild a PartialObsCentralizedShieldSpec from its JSON dictionary form.

    Reads the ``"shield_states"``, ``"obs_names"`` and ``"hidden_obs_names"`` keys,
    the layout produced by ``partial_obs_centralized_shield_to_json``. JSON arrays
    are turned back into hashable containers:

    * joint actions become tuples and successors become frozensets;
    * each agent's observations become a frozenset of tuples;
    * hidden values become frozensets;
    * string state keys become ints.
    """
    def decode_state(key, raw):
        # Decode one serialised state. Lookups run in a fixed order, so a malformed
        # entry fails on the same missing key every time.
        joint_actions = {
            tuple(entry["action"]): frozenset(entry["successors"])
            for entry in raw["actions"]
        }
        return PartialObsCentralizedShieldState(
            state_num=int(key),
            initial_state=raw["initial"],
            hidden_values=tuple(frozenset(values) for values in raw["hidden"]),
            agent_observations=tuple(
                frozenset(map(tuple, per_agent)) for per_agent in raw["observations"]
            ),
            actions=joint_actions,
        )

    decoded_states = {
        int(key): decode_state(key, raw)
        for key, raw in json_dict["shield_states"].items()
    }

    return PartialObsCentralizedShieldSpec(
        states=decoded_states,
        labels=tuple(tuple(names) for names in json_dict["obs_names"]),
        hidden_labels=tuple(json_dict["hidden_obs_names"]),
    )


