from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Categorical

from centralized_verification.agents.decentralized_training.independent_agents.sac_learner import FakeLogger
from centralized_verification.agents.single_agent import SingleAgentLearner
from centralized_verification.agents.utils import transform_potentially_multidiscrete_obs, get_space_params
from centralized_verification.models.rollout_buffer import CircularRolloutBuffer, TraceStep
from centralized_verification.shields.shield import AgentResult
from centralized_verification.utils import Updater
from experiments.utils.parallel_run import DEVICE

# Number of transitions requested from the rollout buffer for each training update.
BATCH_SIZE = 32


class PPOAgent(SingleAgentLearner):
    """Discrete-action actor-critic learner using a PPO-style clipped policy loss.

    The agent owns a policy network (action logits), a Q-value critic, and a
    periodically synchronised target copy of the critic that is used for
    bootstrapping. Transitions are stored in a circular rollout buffer and
    sampled in mini-batches of ``BATCH_SIZE`` for each update.
    """

    def __init__(self, obs_space, action_space, policy_model_class, critic_model_class,
                 make_multidiscrete_one_hot: bool = False, discount=0.9, buffer_size=int(4e4)):
        """Build the networks, optimizer, update schedulers and rollout buffer.

        Args:
            obs_space: Observation space, forwarded to ``get_space_params`` and
                used when transforming raw observations.
            action_space: Action space, forwarded to ``get_space_params``.
            policy_model_class: Callable building the policy network from the
                keyword arguments ``obs_space``, ``num_outputs``, ``low``, ``high``.
            critic_model_class: Callable building the critic network from the
                same keyword arguments.
            make_multidiscrete_one_hot: Forwarded to the observation helpers.
            discount: Discount factor applied to bootstrapped future values.
            buffer_size: Capacity passed to the rollout buffer.
        """
        self.obs_space = obs_space
        self.discount = discount
        self.make_multidiscrete_one_hot = make_multidiscrete_one_hot

        self.obs_shape, self.num_actions, obs_dtype, low, high = get_space_params(
            self.obs_space, action_space, make_multidiscrete_one_hot)

        nn_params = {
            "obs_space": self.obs_shape,
            "num_outputs": self.num_actions,
            "low": low,
            "high": high
        }

        self.policy: nn.Module = policy_model_class(**nn_params).to(DEVICE)
        self.critic: nn.Module = critic_model_class(**nn_params).to(DEVICE)
        self.critic_target: nn.Module = critic_model_class(**nn_params).to(DEVICE)
        # The target network is constructed with its own random weights; make
        # it an exact copy of the critic before it is used for bootstrapping.
        self._sync_target_critic()

        # Only the policy and the online critic are optimised; the target
        # critic changes solely through periodic weight copies.
        self.optimizer = torch.optim.Adam(list(self.policy.parameters()) + list(self.critic.parameters()))

        self.trainer = Updater(lambda: self.train())
        self.target_critic_updater = Updater(self._sync_target_critic)

        self.buffer = CircularRolloutBuffer(buffer_size, self.obs_shape, obs_dtype, DEVICE)
        self.clip_epsilon = 0.2
        self.logger = FakeLogger()

    def _sync_target_critic(self):
        """Copy the online critic's weights into the target critic."""
        self.critic_target.load_state_dict(self.critic.state_dict())

    def train(self):
        """Run one gradient update of the policy and critic on a sampled batch."""
        rollout_sample, _, _ = self.buffer.sample(BATCH_SIZE)

        # --- Bootstrapped targets (no gradients flow through them) ---
        with torch.no_grad():
            next_policy = Categorical(logits=self.policy(rollout_sample.next_states))

            # Expected next-state value under the current policy, estimated by
            # the target critic. The expectation is a probability-weighted SUM
            # over actions; averaging would shrink it by 1 / num_actions.
            next_q_values = self.critic_target(rollout_sample.next_states)
            future_q_values = (next_q_values * next_policy.probs).sum(dim=1)

            # Bellman target; terminal transitions contribute no future value.
            not_done = (~rollout_sample.dones).float()
            target_q = rollout_sample.rewards + (self.discount * future_q_values * not_done)

        # --- Critic loss ---
        q_values = self.critic(rollout_sample.states)
        # Index by the actual batch dimension rather than assuming BATCH_SIZE rows.
        batch_indices = range(q_values.shape[0])
        action_q_values = q_values[batch_indices, rollout_sample.actions]
        critic_loss = F.mse_loss(input=action_q_values, target=target_q)

        # --- Actor loss ---
        policy_current_state = Categorical(logits=self.policy(rollout_sample.states))
        # State value is the expectation of Q under the policy, used as a baseline.
        value_current_state = (q_values * policy_current_state.probs).sum(dim=1).detach()
        advantage_current_state = target_q - value_current_state

        # NOTE(review): the denominator is the detached *current* policy, so the
        # ratio is always ~1 and the clip below never becomes active. A true PPO
        # ratio would need the behaviour policy's probabilities, which are not
        # part of the TraceStep fields stored by this class.
        taken_action_probs = policy_current_state.probs[batch_indices, rollout_sample.actions]
        probability_ratio_current_state = taken_action_probs / (taken_action_probs.detach() + 1e-8)
        clipped_probability_ratio = torch.clamp(probability_ratio_current_state,
                                                1 - self.clip_epsilon,
                                                1 + self.clip_epsilon)
        actor_loss = -(torch.min(probability_ratio_current_state * advantage_current_state,
                                 clipped_probability_ratio * advantage_current_state)).mean()

        # Negative entropy: minimising it pushes the policy towards exploration.
        entropy_loss = -policy_current_state.entropy().mean()

        loss = actor_loss + 0.1 * critic_loss + 0.01 * entropy_loss

        self.logger.log("entropy_loss", entropy_loss.item())
        self.logger.log("actor_loss", actor_loss.item())
        self.logger.log("critic_loss", critic_loss.item())

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update the target critic network
        # (presumably the Updater fires its callback once every 200 calls — confirm in Updater)
        self.target_critic_updater.update_every(200)

    def transform_obs(self, obs):
        """Convert a raw observation into a float32 tensor on ``DEVICE``."""
        return torch.tensor(
            transform_potentially_multidiscrete_obs(obs, self.obs_space, self.make_multidiscrete_one_hot),
            device=DEVICE, dtype=torch.float32)

    def observe_transition(self, obs, shield_result: AgentResult, next_obs, rew, done, step_num, training_progress):
        """Store an environment transition and possibly trigger training.

        The action actually executed (``shield_result.real_action``) is always
        recorded. If the shield also supplies an ``augmented_action``, a second
        transition with that action and its modified reward is recorded for the
        same state pair. ``step_num`` and ``training_progress`` are not used.
        """
        obs = self.transform_obs(obs)
        next_obs = self.transform_obs(next_obs)

        self.buffer.add_step(TraceStep(state=obs, action=shield_result.real_action.action,
                                       reward=shield_result.real_action.get_modified_reward(rew), next_state=next_obs,
                                       done=done))
        if shield_result.augmented_action:
            self.buffer.add_step(TraceStep(state=obs, action=shield_result.augmented_action.action,
                                           reward=shield_result.augmented_action.get_modified_reward(rew),
                                           next_state=next_obs, done=done))

        # Training is only attempted once the buffer is at least half full.
        if self.buffer.num_filled_approx() >= self.buffer.capacity / 2:
            self.trainer.update_every(4)

    def state_dict(self):
        """Return a checkpoint dict with the network weights and the buffer state."""
        return {
            "policy": self.policy.state_dict(),
            "critic": self.critic.state_dict(),
            "critic_target": self.critic_target.state_dict(),
            "buffer": self.buffer.state_dict()
        }

    def load_state_dict(self, state_dict):
        """Restore the agent from a dict produced by :meth:`state_dict`.

        Checkpoints without a ``"critic_target"`` entry are still accepted; in
        that case the target critic is initialised from the critic weights so
        it never keeps its random construction-time parameters.
        """
        self.policy.load_state_dict(state_dict["policy"])
        self.critic.load_state_dict(state_dict["critic"])
        self.critic_target.load_state_dict(state_dict.get("critic_target", state_dict["critic"]))
        self.buffer.load_state_dict(state_dict["buffer"])

    def get_action(self, observation, step_num: Optional[int]):
        """Sample an action index from the current policy for one observation.

        ``step_num`` is accepted for interface compatibility and is not used.
        """
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            obs_tens = self.transform_obs(observation).view(1, -1)
            action_logits = self.policy(obs_tens)
            action = Categorical(logits=action_logits).sample()
        return action.item()

    def get_log_dict(self):
        """Return the logger's accumulated ``log_dict``."""
        return self.logger.log_dict
