import abc
from typing import Optional, Any, List, NamedTuple, Generic, TypeVar, Tuple, AbstractSet


# Shield result:
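# real_action corresponds to the action that will actually be taken in the environment.
# It should be a safe joint action, because this action will actually occur.
# However, the shield may also attach an augmented_action (e.g. the originally proposed,
# unsafe action) so that an extra transition can be generated for training.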
#
# The shield can allow the action (None) or replace it (return a tuple for the joint action),
# and can punish (negative) or leave the reward unchanged (zero). NOTE: this applies to the transition that is actually taken.
# If the action is replaced, the shield can generate a "fake" extra transition for training (true) or keep only the transition actually taken (false).
# If the extra transition is created, it can be assigned a punishment (negative) or have its reward left unchanged (zero).


class AgentUpdate(NamedTuple):
    """One agent's action together with an optional reward adjustment.

    ``absolute_reward`` takes precedence over ``reward_modifier``: when it is
    set, the reward coming from the environment is ignored entirely.
    """

    action: Any
    # Added to the environment reward (only consulted when absolute_reward is None).
    reward_modifier: Optional[int] = None
    # When set, replaces the environment reward outright.
    absolute_reward: Optional[int] = None

    def get_modified_reward(self, rew_from_env):
        """Return the reward to record for this update.

        Precedence: ``absolute_reward`` if set; otherwise ``rew_from_env``
        shifted by ``reward_modifier`` if that is set; otherwise
        ``rew_from_env`` unchanged. ``None`` (not falsiness) marks "unset",
        so an explicit 0 is honoured.
        """
        if self.absolute_reward is None:
            if self.reward_modifier is None:
                return rew_from_env
            return self.reward_modifier + rew_from_env
        return self.absolute_reward


class AgentResult(NamedTuple):
    """The shield's decision for a single agent at a single step.

    ``real_action`` is the update that is actually applied in the environment.
    ``augmented_action``, when not None, describes an extra transition that is
    not executed (e.g. the original unsafe proposal), intended for training.
    """

    real_action: AgentUpdate
    augmented_action: Optional[AgentUpdate] = None


# One AgentResult per agent, indexed by agent position in the joint action.
ShieldResult = List[AgentResult]

# If there is some centralized training component, all augmented_actions should be treated as a single joint action

# T: type of the environment a shield operates on (see AbstractShield's `env: T`).
T = TypeVar("T")
# S: type of the shield's own state, returned by get_initial_shield_state and
# threaded through evaluate_joint_action.
S = TypeVar("S")


class Shield(Generic[T, S], abc.ABC):
    """Interface for a shield that vets proposed joint actions.

    ``T`` is the environment type and ``S`` is the type of the state the
    shield carries from one call of ``evaluate_joint_action`` to the next.
    """

    @abc.abstractmethod
    def get_initial_shield_state(self, state, initial_joint_obs) -> S:
        """Return the shield state to use before the first joint action is evaluated."""

    @abc.abstractmethod
    def evaluate_joint_action(self, state, joint_obs, proposed_action, shield_state: S) -> Tuple[ShieldResult, S]:
        """Vet ``proposed_action`` and return ``(per-agent results, next shield state)``."""


class AbstractShield(Shield[T, S], Generic[T, S], abc.ABC):
    """Shield base class with helpers for replacing unsafe agent actions.

    Subclasses still have to implement ``get_initial_shield_state`` and
    ``evaluate_joint_action``.

    Args:
        env: The environment the shield operates on.
        punish_unsafe_orig_action: If True, every replaced action also yields an
            augmented transition for the original (unsafe) action.
        unsafe_action_punishment: Reward modifier attached to that augmented
            transition.
    """

    def __init__(self, env: T, punish_unsafe_orig_action=False, unsafe_action_punishment=0):
        self.punish_unsafe_orig_action = punish_unsafe_orig_action
        self.unsafe_action_punishment = unsafe_action_punishment
        self.env = env

    def replace_action_agent_result(self, unsafe_action, safe_action) -> AgentResult:
        """Build the result for an agent whose proposed action was replaced.

        ``safe_action`` is the action that is actually executed. When
        ``punish_unsafe_orig_action`` is set, ``unsafe_action`` is additionally
        recorded as the augmented action with ``unsafe_action_punishment`` as its
        reward modifier.
        """
        if self.punish_unsafe_orig_action:
            return AgentResult(real_action=AgentUpdate(action=safe_action),
                               augmented_action=AgentUpdate(action=unsafe_action,
                                                            reward_modifier=self.unsafe_action_punishment))
        else:
            return AgentResult(real_action=AgentUpdate(action=safe_action))

    def _results_for_joint_action(self, proposed_action, taken_action) -> ShieldResult:
        """Build per-agent results, treating only agents whose action changed as replaced."""
        return [AgentResult(AgentUpdate(action=taken_indiv_action))
                if taken_indiv_action == proposed_indiv_action
                else self.replace_action_agent_result(proposed_indiv_action, taken_indiv_action)
                for proposed_indiv_action, taken_indiv_action in zip(proposed_action, taken_action)]

    def neuter_agents_in_priority(self, proposed_action, allowed_actions: AbstractSet[Tuple[int, ...]],
                                  priority: List[int], action_to_modify_to, default_action):
        """Make a proposed joint action safe by overriding agents one at a time.

        Agents are overridden in ``priority`` order, each being set to its entry
        in ``action_to_modify_to``. Overrides accumulate, so once an agent is
        overridden it stays overridden. The first accumulated joint action found
        in ``allowed_actions`` is returned. If none is found, ``default_action``
        is used instead.

        Args:
            proposed_action: Joint action proposed by the agents, indexable by agent.
            allowed_actions: Set of joint actions (tuples) considered safe.
            priority: Agent indices in the order they should be overridden.
            action_to_modify_to: Per-agent replacement actions, indexable by agent.
            default_action: Fallback joint action, used when overriding the agents
                in ``priority`` never yields an allowed joint action.

        Returns:
            A ``(shield_result, joint_action)`` pair. ``shield_result`` holds one
            AgentResult per agent. ``joint_action`` is the action to execute:
            ``proposed_action`` if it is already allowed, a tuple if a partial
            override succeeded, otherwise ``default_action`` as given.
        """
        if proposed_action in allowed_actions:
            return [AgentResult(AgentUpdate(action=action)) for action in proposed_action], proposed_action

        candidate = list(proposed_action)
        for agent_to_neuter in priority:  # Set individual agents to their replacement until the joint action is safe
            candidate[agent_to_neuter] = action_to_modify_to[agent_to_neuter]
            candidate_joint = tuple(candidate)
            if candidate_joint in allowed_actions:
                # Return a tuple (not the internal working list), consistent with allowed_actions' element type.
                # Agents whose replacement equals their proposal are not marked/punished as replaced.
                return self._results_for_joint_action(proposed_action, candidate_joint), candidate_joint

        return self._results_for_joint_action(proposed_action, default_action), default_action
