from typing import Tuple

from centralized_verification.shields.ahead_of_time_shield import AheadOfTimeShield
from centralized_verification.shields.shield import ShieldResult, T
from centralized_verification.shields.utils import create_random_priority


class CentralizedShieldOracle(AheadOfTimeShield):
    """Centralized shield that vets the whole joint action in one step.

    The shield keeps no state of its own: its initial shield state is
    ``None``, and it reports ``None`` as the next shield state after every
    evaluation.
    """

    def get_initial_shield_state(self, state, initial_joint_obs) -> T:
        """Return the starting shield state, which is always ``None``.

        Both ``state`` and ``initial_joint_obs`` are ignored because this
        shield carries nothing between steps.
        """
        return None

    def evaluate_joint_action(self, state, _, proposed_action, __) -> Tuple[ShieldResult, None]:
        """Check ``proposed_action`` against the action set for ``state``.

        The second and fourth positional arguments are unused here.

        Agents are handled in the order returned by ``create_random_priority``.
        An all-zeros joint action of the same length as ``proposed_action`` is
        passed as the default. NOTE(review): it is presumably the action
        substituted for agents the shield overrides; confirm against
        ``neuter_agents_in_priority``.

        Returns:
            ``(shield_result, None)``. The second element of the helper's
            result is discarded, and ``None`` is the (empty) next shield state.
        """
        allowed_actions = self.get_action_set(state)

        num_agents = len(proposed_action)
        # One zero per agent; the same tuple object is passed for both
        # default-action arguments below.
        fallback_action = (0,) * num_agents
        agent_order = create_random_priority(num_agents)

        shield_result, _unused = self.neuter_agents_in_priority(
            proposed_action,
            allowed_actions,
            agent_order,
            fallback_action,
            fallback_action,
        )
        return shield_result, None
