from typing import Optional

import lib.pytorch_sac.agent.actor as psac_actor
import lib.pytorch_sac.agent.critic as psac_critic
import lib.pytorch_sac.agent.sac as sac
import lib.pytorch_sac.replay_buffer as replay_buffer
from centralized_verification.agents.single_agent import SingleAgentLearner
from centralized_verification.agents.utils import transform_potentially_multidiscrete_obs, get_space_params
from centralized_verification.shields.shield import AgentResult
from centralized_verification.utils import convert_gym_space_to_q_shape, TrainingProgress
from experiments.utils.parallel_run import DEVICE


class FakeLogger:
    """Minimal in-memory logger that remembers the latest value logged for each key.

    Only scalar logging via :meth:`log` stores anything; histogram and parameter
    logging are accepted and ignored.
    """

    def __init__(self):
        # Latest value recorded for each key.
        self.log_dict = {}
        # Step associated with the latest value of each key, when one was supplied.
        self.log_dict_step = {}

    def log(self, key, value, step=None):
        """Store ``value`` under ``key``, and ``step`` too when one is given.

        :param key: name of the logged quantity.
        :param value: value to remember; replaces any previous value for ``key``.
        :param step: optional step number; ``None`` means "no step supplied".
        """
        self.log_dict[key] = value
        # Compare against None explicitly: step 0 is a valid step and must be recorded,
        # which a plain truthiness check would skip.
        if step is not None:
            self.log_dict_step[key] = step

    def log_histogram(self, key, value, step):
        """Accept a histogram log call and discard it."""
        pass

    def log_param(self, key, value, step):
        """Accept a parameter log call and discard it."""
        pass


class SACAgent(SingleAgentLearner):
    """Single-agent learner backed by the ``lib.pytorch_sac`` SAC implementation.

    Builds a ``CategoricalActor`` plus an online and a target ``DoubleQCritic``, hands
    them to ``sac.SACAgent``, and trains that agent from a ``ReplayBuffer`` that is
    filled by :meth:`observe_transition`.
    """

    def __init__(self, obs_space, action_space, buffer_size, make_multidiscrete_one_hot: bool = False,
                 discount=0.99):
        """Create the networks, the wrapped SAC agent and the replay buffer.

        :param obs_space: raw observation space; stored so observations can be transformed later.
        :param action_space: action space from which the number of actions is derived.
        :param buffer_size: capacity handed to the replay buffer.
        :param make_multidiscrete_one_hot: forwarded to ``get_space_params`` and to
            ``transform_potentially_multidiscrete_obs``.
        :param discount: discount value stored on the learner and forwarded to the SAC agent.
        :raises Exception: if the observation shape reported by ``get_space_params`` is not 1-D.
        """
        self.raw_obs_space = obs_space
        # NOTE(review): this assignment is overwritten by the get_space_params result below.
        self.num_actions = convert_gym_space_to_q_shape(action_space)[0]
        self.discount = discount
        self.logger = FakeLogger()

        space_params = get_space_params(self.raw_obs_space, action_space, make_multidiscrete_one_hot)
        # Only the shape and the action count are used; the remaining three values are ignored.
        self.obs_shape, self.num_actions, _obs_dtype, _low, _high = space_params

        if len(self.obs_shape) != 1:
            raise Exception("SAC learner must have a 1-D observation space")

        self.make_multidiscrete_one_hot = make_multidiscrete_one_hot

        obs_dim = self.obs_shape[0]
        action_dim = self.num_actions

        # The actor and both critics are built with the same dimensions and hidden layout.
        network_kwargs = dict(obs_dim=obs_dim, action_dim=action_dim, hidden_dim=1024, hidden_depth=2)
        online_critic = psac_critic.DoubleQCritic(**network_kwargs)
        target_critic = psac_critic.DoubleQCritic(**network_kwargs)
        policy = psac_actor.CategoricalActor(**network_kwargs)

        # Temperature (alpha), actor and critic all share one learning rate and one betas pair.
        shared_lr = 1e-4
        shared_betas = (0.9, 0.999)
        self.sac = sac.SACAgent(
            obs_dim=obs_dim,
            action_dim=action_dim,
            device=DEVICE,
            critic=online_critic,
            critic_target=target_critic,
            actor=policy,
            discount=discount,
            init_temperature=0.1,
            alpha_lr=shared_lr,
            alpha_betas=shared_betas,
            actor_lr=shared_lr,
            actor_betas=shared_betas,
            actor_update_frequency=1,
            critic_lr=shared_lr,
            critic_betas=shared_betas,
            critic_tau=0.005,
            critic_target_update_frequency=2,
            batch_size=1024,
            learnable_temperature=True,
        )

        self.rollout_buffer = replay_buffer.ReplayBuffer(
            obs_shape=self.obs_shape,
            action_shape=(),
            capacity=buffer_size,
            device=DEVICE,
        )

    def observe_transition(self, obs, shield_result: AgentResult, next_obs, rew, done, step_num,
                           training_progress: TrainingProgress):
        """Store the observed transition and run a SAC update once enough data is buffered.

        The shield's real action is always stored; its augmented action, when truthy, is
        stored as a second transition with the same observations. Each stored reward is
        whatever the action's ``get_modified_reward(rew)`` returns. An update is run only
        while the buffer holds more than 5000 entries.
        """
        current = transform_potentially_multidiscrete_obs(obs, self.raw_obs_space,
                                                          self.make_multidiscrete_one_hot)
        following = transform_potentially_multidiscrete_obs(next_obs, self.raw_obs_space,
                                                            self.make_multidiscrete_one_hot)
        # The same done flag fills both trailing arguments of ReplayBuffer.add.
        # NOTE(review): presumably these are `done` and `done_no_max` — confirm against the buffer.
        terminal = float(done)

        def store(action_result):
            # Push one transition for the given action into the replay buffer.
            self.rollout_buffer.add(current, action_result.action, action_result.get_modified_reward(rew),
                                    following, terminal, terminal)

        store(shield_result.real_action)
        if shield_result.augmented_action:
            store(shield_result.augmented_action)

        if len(self.rollout_buffer) <= 5000:
            return
        self.sac.update(replay_buffer=self.rollout_buffer, logger=self.logger,
                        step=training_progress.global_step_count)

    def state_dict(self):
        """Return the replay buffer and SAC agent state under ``"rollout_buffer"`` and ``"sac"``."""
        snapshot = {}
        snapshot["rollout_buffer"] = self.rollout_buffer.state_dict()
        snapshot["sac"] = self.sac.state_dict()
        return snapshot

    def load_state_dict(self, state_dict):
        """Restore the replay buffer and SAC agent from a dict shaped like :meth:`state_dict`'s output."""
        buffer_state = state_dict["rollout_buffer"]
        self.rollout_buffer.load_state_dict(buffer_state)
        agent_state = state_dict["sac"]
        self.sac.load_state_dict(agent_state)

    def get_action(self, observation, step_num: Optional[int]):
        """Transform ``observation`` and return the wrapped SAC agent's action for it.

        ``step_num`` is accepted but not used here.
        """
        transformed_obs = transform_potentially_multidiscrete_obs(
            observation, self.raw_obs_space, self.make_multidiscrete_one_hot)
        # NOTE(review): the positional True is presumably the "sample" flag of sac.act — confirm.
        return self.sac.act(transformed_obs, True)

    def get_log_dict(self):
        """Return a shallow copy of the values most recently logged by the SAC agent."""
        latest_values = self.logger.log_dict
        return dict(latest_values)
