from typing import Tuple

from centralized_verification.MultiAgentAPEnv import MultiAgentSafetyEnv
from centralized_verification.shields.partial_obs.pobs_label_extractor import PobsLabelExtractor
from centralized_verification.shields.partial_obs.pobs_label_shield_spec import PobsLabelShieldSpec, label_step_type, \
    agent_shield_state_type, shield_state_type
from centralized_verification.shields.shield import ShieldResult, AgentResult, AgentUpdate, AbstractShield


class PobsLabelShield(AbstractShield[MultiAgentSafetyEnv, shield_state_type]):
    """Shield that permits or replaces each agent's action based on its recent label history.

    On every step the label extractor is applied to the environment state and joint
    observation, each agent's labels are appended to that agent's (bounded) history, and
    the history is used as a lookup key into ``shield_spec.allowed_actions[agent_num]``
    to decide whether the agent's proposed action may be executed.
    """

    # Sentinel used to detect an empty collection of allowed actions without relying on
    # a specific value (e.g. ``None``) never being a legitimate action.
    _NO_ACTION = object()

    def __init__(self, shield_spec: PobsLabelShieldSpec, label_extractor: PobsLabelExtractor, **kwargs):
        """Store the shield specification and the label extractor.

        Args:
            shield_spec: Supplies ``label_names`` (one sequence of label names per agent),
                ``history_len`` and ``allowed_actions`` (indexed by agent number, then by
                that agent's label history).
            label_extractor: Called as ``label_extractor(state, joint_obs)``; must return one
                entry per agent that can be indexed by that agent's label names.
            **kwargs: Forwarded unchanged to ``AbstractShield.__init__``.
        """
        super().__init__(**kwargs)
        self.label_extractor = label_extractor
        self.shield_spec = shield_spec

    def get_initial_shield_state(self, state, initial_joint_obs) -> shield_state_type:
        """Return an empty label history for every agent.

        One empty tuple is produced per entry of ``shield_spec.label_names`` (i.e. per
        agent). ``state`` and ``initial_joint_obs`` are not used here.
        """
        return tuple(() for _ in self.shield_spec.label_names)

    def evaluate_individual_action(self, agent_num, shield_state: agent_shield_state_type,
                                   proposed_action) -> AgentResult:
        """Decide whether a single agent's proposed action is allowed.

        Args:
            agent_num: Index of the agent in the joint action.
            shield_state: This agent's label history, used as the lookup key into
                ``shield_spec.allowed_actions[agent_num]``.
            proposed_action: The action the agent wants to take.

        Returns:
            An ``AgentResult`` executing ``proposed_action`` if it is allowed; otherwise
            the result of ``replace_action_agent_result`` with the first allowed action
            (in the iteration order of the allowed-actions collection).

        Raises:
            KeyError: If ``agent_num`` or ``shield_state`` is not present in the spec.
            ValueError: If the spec lists no allowed actions for this history.
        """
        allowed_actions = self.shield_spec.allowed_actions[agent_num][shield_state]
        if proposed_action in allowed_actions:
            return AgentResult(real_action=AgentUpdate(action=proposed_action))

        # NOTE(review): if allowed_actions is a set, "first" depends on set iteration order.
        chosen_safe_action = next(iter(allowed_actions), self._NO_ACTION)
        if chosen_safe_action is self._NO_ACTION:
            # Fail with a descriptive error rather than letting a bare StopIteration escape
            # (which PEP 479 converts into an opaque RuntimeError inside generators).
            raise ValueError(
                f"Shield spec has no allowed actions for agent {agent_num} "
                f"in shield state {shield_state!r}"
            )
        return self.replace_action_agent_result(proposed_action, chosen_safe_action)

    def add_to_shield_state(self, prev: shield_state_type, new: label_step_type) -> shield_state_type:
        """Append one step of labels to every agent's history and bound its length.

        Each agent's history keeps only its newest ``history_len + 1`` entries: the
        labels just appended plus up to ``history_len`` earlier steps.
        """
        keep = self.shield_spec.history_len + 1
        return tuple((history + (labels,))[-keep:] for history, labels in zip(prev, new))

    def evaluate_joint_action(self, state, joint_obs, proposed_action, shield_state: shield_state_type) -> Tuple[
        ShieldResult, shield_state_type]:
        """Shield every agent's proposed action for the current step.

        Args:
            state: Environment state, passed to the label extractor.
            joint_obs: Joint observation, passed to the label extractor.
            proposed_action: One proposed action per agent.
            shield_state: Per-agent label histories from the previous step.

        Returns:
            A pair ``(shield_result, new_shield_state)``: a list with one ``AgentResult``
            per agent, and the per-agent histories including this step's labels.
        """
        joint_label = self.label_extractor(state, joint_obs)
        # Turn each agent's labels into a tuple ordered by that agent's label names so it
        # can be stored in the history tuple used as an allowed-actions lookup key.
        joint_int_labels = tuple(
            tuple(indiv_label[name] for name in indiv_label_names)
            for indiv_label, indiv_label_names in zip(joint_label, self.shield_spec.label_names)
        )
        # Record the current labels first so each decision is based on them as well.
        shield_state = self.add_to_shield_state(shield_state, joint_int_labels)
        shield_result = [
            self.evaluate_individual_action(agent_num, agent_history, agent_action)
            for agent_num, (agent_history, agent_action) in enumerate(zip(shield_state, proposed_action))
        ]
        return shield_result, shield_state
