import abc
from typing import Optional

from centralized_verification.shields.shield import AgentResult


class SingleAgent(abc.ABC):
    """Abstract interface for an agent that selects actions from observations.

    Subclasses must implement :meth:`get_action` and :meth:`get_log_dict`.
    :meth:`new_episode` is an optional hook whose base implementation does
    nothing.
    """

    @abc.abstractmethod
    def get_action(self, observation, step_num: Optional[int]):
        """Select an action for ``observation``.

        :param observation: The observation to act on.
        :param step_num: Step counter supplied by the caller; ``None`` when
            running in testing.
        """
        ...

    def new_episode(self):
        """Optional per-episode hook; the base implementation is a no-op."""
        ...

    @abc.abstractmethod
    def get_log_dict(self):
        """Return data to be logged for this agent.

        NOTE(review): the exact shape (presumably a mapping of metric names
        to values) is defined by subclasses; confirm against callers.
        """
        ...


class SingleAgentLearner(SingleAgent, abc.ABC):
    """A :class:`SingleAgent` that also receives transitions and exposes its state.

    In addition to the abstract methods inherited from :class:`SingleAgent`,
    subclasses must implement :meth:`observe_transition`, :meth:`state_dict`
    and :meth:`load_state_dict`.
    """

    @abc.abstractmethod
    def observe_transition(
        self,
        obs,
        shield_result: AgentResult,
        next_obs,
        rew,
        done,
        step_num,
        training_progress,
    ):
        """Receive a single transition.

        NOTE(review): the parameter meanings below are inferred from their
        names; confirm against callers.

        :param obs: Observation before the transition (presumably).
        :param shield_result: The shield's :class:`AgentResult` for this agent.
        :param next_obs: Observation after the transition (presumably).
        :param rew: Reward for the transition (presumably).
        :param done: Episode-termination flag (presumably).
        :param step_num: Step counter supplied by the caller.
        :param training_progress: Training-progress indicator; range/units
            TODO confirm.
        """
        ...

    @abc.abstractmethod
    def state_dict(self):
        """Return this learner's state.

        Presumably meant for checkpointing and paired with
        :meth:`load_state_dict`; confirm the expected format.
        """
        ...

    @abc.abstractmethod
    def load_state_dict(self, state_dict):
        """Restore this learner from ``state_dict``.

        Presumably accepts a value produced by :meth:`state_dict`; confirm.
        """
        ...
