import abc

import numpy as np
from gym import spaces

from centralized_verification.MultiAgentAPEnv import MultiAgentSafetyEnv
from centralized_verification.shields.shield import AbstractShield
from centralized_verification.shields.utils import env_to_action_set
from centralized_verification.utils import convert_gym_space_to_q_shape, is_gym_space_suitable_for_table_idx


class AheadOfTimeShield(AbstractShield[MultiAgentSafetyEnv, None], abc.ABC):
    """Shield that obtains the allowed joint actions for a state from ``env_to_action_set``.

    When caching is possible, results are memoised in a dense boolean table of shape
    ``(*state_shape, *joint_action_shape)``. Caching is possible when the state space is usable
    as a table index and every agent action space is ``spaces.Discrete``. The table can
    optionally be filled for every state at construction time.
    """

    def __init__(self, env: MultiAgentSafetyEnv, punish_unsafe_orig_action=False, unsafe_action_punishment=0,
                 cache_shield_result: bool = True, calc_ahead_of_time: bool = False):
        """
        :param env: environment queried for allowed action sets.
        :param punish_unsafe_orig_action: forwarded to ``AbstractShield``.
        :param unsafe_action_punishment: forwarded to ``AbstractShield``.
        :param cache_shield_result: memoise per-state results in a table. Caching is silently
            disabled when the state or action spaces cannot back such a table.
        :param calc_ahead_of_time: when caching is enabled, compute the result for every state
            during construction. It has no effect when caching is disabled.
        """
        super().__init__(env, punish_unsafe_orig_action, unsafe_action_punishment)

        # Short-circuit so the env is only inspected when caching was actually requested.
        cache_shield_result = cache_shield_result and self._can_cache(env)
        self.cache_shield_result = cache_shield_result

        if cache_shield_result:
            state_space = convert_gym_space_to_q_shape(env.state_space())
            # _can_cache guarantees every action space is Discrete, so each agent gets an axis.
            joint_action_space = tuple(space.n for space in env.agent_actions_spaces())

            # The actual per-state result.
            self.shield_result_cache = np.zeros((*state_space, *joint_action_space), dtype=bool)
            # True if a given cache slot already contains the result.
            self.shield_result_cache_flag = np.zeros(state_space, dtype=bool)

            if calc_ahead_of_time:
                # Go through get_action_set (not the table directly) so subclass overrides apply.
                for index in np.ndindex(*self.shield_result_cache_flag.shape):
                    self.get_action_set(index)

    @staticmethod
    def _can_cache(env) -> bool:
        """Return True when the env's spaces can back the dense cache table."""
        if not is_gym_space_suitable_for_table_idx(env.state_space()):
            return False
        # A non-discrete action space has no finite size to use as a table axis. Skipping it
        # would give the table the wrong shape, so caching is disabled instead.
        return all(isinstance(space, spaces.Discrete) for space in env.agent_actions_spaces())

    @staticmethod
    def _cache_index(state):
        """Turn ``state`` into a numpy index that selects one cache slot.

        A list or ndarray used as an index triggers numpy advanced indexing, which would select
        several rows instead of one slot. Such states are flattened into a tuple of ints. Any
        other state (e.g. a tuple of ints, as produced by ``np.ndindex``) is used unchanged.
        """
        if isinstance(state, (list, np.ndarray)):
            return tuple(int(component) for component in np.ravel(state))
        return state

    def get_action_set(self, state):
        """Return the allowed joint actions for ``state``.

        Without caching, this is exactly ``env_to_action_set(environment=self.env,
        current_env_state=state)``. With caching, the result is computed once per state, stored
        in ``shield_result_cache``, and later calls return it without recomputation.

        :param state: environment state. When caching is enabled, it must identify a single slot
            of the cache table.
        :return: with caching enabled, a fresh boolean ndarray (a copy of the cached entry), so
            callers may modify it freely.
        """
        if not self.cache_shield_result:
            return env_to_action_set(environment=self.env, current_env_state=state)

        index = self._cache_index(state)
        if not self.shield_result_cache_flag[index]:
            self.shield_result_cache[index] = env_to_action_set(environment=self.env, current_env_state=state)
            self.shield_result_cache_flag[index] = True
        # Copy so callers cannot corrupt the cache through a view, and so that hits and misses
        # return the same type.
        return self.shield_result_cache[index].copy()
