import abc
from typing import Tuple

from centralized_verification.shields.shield import Shield, S, ShieldResult, AgentResult, AgentUpdate


class ActionTransformer(abc.ABC):
    """Interface for converting actions between environment space and shield space.

    Concrete subclasses must implement both directions.
    ``ActionTransformerShieldWrapper`` unpacks the result of
    ``transform_env_to_shield`` into a ``(shield_action, aux_info)`` pair. It
    later hands that same ``aux_info`` back to ``transform_shield_to_env`` when
    mapping the shield's answer back to the environment.
    """

    @abc.abstractmethod
    def transform_env_to_shield(self, env_action):
        """Split an environment action into ``(shield_action, aux_info)``."""

    @abc.abstractmethod
    def transform_shield_to_env(self, shield_action, aux_info):
        """Rebuild an environment action from a shield action and its aux info.

        ``aux_info`` is the value produced for this agent by
        ``transform_env_to_shield``.
        """


class ActionTransformerShieldWrapper(Shield):
    """Run a shield that works in a different action space than the environment.

    Each environment action in a proposed joint action is split by
    ``action_transformer`` into a shield-space action plus per-agent auxiliary
    info. The wrapped shield evaluates the shield-space joint action. Every
    action in its result is then mapped back to environment space with the
    auxiliary info recorded for the same agent.
    """

    def __init__(self, shield: Shield, action_transformer, **kwargs):
        """Wrap ``shield``.

        :param shield: shield that evaluates shield-space actions.
        :param action_transformer: object providing ``transform_env_to_shield``
            and ``transform_shield_to_env`` (e.g. an ``ActionTransformer``).
        :param kwargs: forwarded unchanged to ``Shield.__init__``.
        """
        super().__init__(**kwargs)
        self.shield = shield
        self.action_transformer = action_transformer

    def get_initial_shield_state(self, state, initial_joint_obs) -> S:
        """Delegate initial shield-state construction to the wrapped shield."""
        wrapped = self.shield
        return wrapped.get_initial_shield_state(state, initial_joint_obs)

    def transform_agent_update(self, agent_update: AgentUpdate, aux_info) -> AgentUpdate:
        """Return a new ``AgentUpdate`` whose action is mapped back to env space.

        ``reward_modifier`` and ``absolute_reward`` are copied over unchanged.
        """
        env_action = self.action_transformer.transform_shield_to_env(agent_update.action, aux_info)
        return AgentUpdate(
            action=env_action,
            reward_modifier=agent_update.reward_modifier,
            absolute_reward=agent_update.absolute_reward,
        )

    def transform_agent_result(self, agent_result: AgentResult, aux_info) -> AgentResult:
        """Map both updates of ``agent_result`` back to env space.

        ``augmented_action`` may be ``None``. In that case it stays ``None``.
        """
        real_update = self.transform_agent_update(agent_result.real_action, aux_info)
        if agent_result.augmented_action is None:
            augmented_update = None
        else:
            augmented_update = self.transform_agent_update(agent_result.augmented_action, aux_info)
        return AgentResult(real_action=real_update, augmented_action=augmented_update)

    def evaluate_joint_action(self, state, joint_obs, proposed_action, shield_state: S) -> Tuple[ShieldResult, S]:
        """Evaluate an env-space joint action with the wrapped shield.

        The per-agent results come back as a list, in agent order, with their
        actions converted back to env space. The wrapped shield's new state is
        returned unchanged.
        """
        converted_actions = []
        aux_per_agent = []
        for env_action in proposed_action:
            converted, aux = self.action_transformer.transform_env_to_shield(env_action)
            converted_actions.append(converted)
            aux_per_agent.append(aux)

        inner_result, next_state = self.shield.evaluate_joint_action(
            state, joint_obs, tuple(converted_actions), shield_state)

        # Pair each agent's result with the aux info captured for that agent.
        # zip stops at the shorter sequence.
        mapped_results = [
            self.transform_agent_result(result, aux)
            for result, aux in zip(inner_result, aux_per_agent)
        ]
        return mapped_results, next_state


class FlashlightActionTransformer(ActionTransformer):
    """Split a flat env action into a shield action and a high-order part.

    An environment action ``a`` is decoded as
    ``a == shield_action + num_shield_actions * aux``. The low part
    ``shield_action = a % num_shield_actions`` goes to the shield. The high
    part ``aux = a // num_shield_actions`` is carried as auxiliary info.
    ``aux`` is presumably the flashlight component of the action; confirm
    against the environment.

    For any integer ``a``, the encoding round-trips:
    ``transform_shield_to_env(*transform_env_to_shield(a)) == a``.
    """

    # Previously hard-coded as 5 in both methods; kept as the default so
    # existing no-argument construction behaves exactly as before.
    NUM_SHIELD_ACTIONS = 5

    def __init__(self, num_shield_actions: int = NUM_SHIELD_ACTIONS):
        """Create the transformer.

        :param num_shield_actions: size of the shield action space, i.e. the
            modulus used to split env actions. Defaults to 5.
        :raises ValueError: if ``num_shield_actions`` is not positive.
        """
        if num_shield_actions <= 0:
            raise ValueError(f"num_shield_actions must be positive, got {num_shield_actions!r}")
        self.num_shield_actions = num_shield_actions

    def transform_env_to_shield(self, env_action):
        """Return ``(env_action % n, env_action // n)`` with ``n = num_shield_actions``."""
        high, low = divmod(env_action, self.num_shield_actions)
        return low, high

    def transform_shield_to_env(self, shield_action, aux_info):
        """Inverse of ``transform_env_to_shield``: ``shield_action + n * aux_info``."""
        return shield_action + self.num_shield_actions * aux_info
