import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

from centralized_verification.agents.decentralized_training.independent_agents.q_learner import QLearner
from centralized_verification.agents.utils import transform_potentially_multidiscrete_obs, get_space_params
from centralized_verification.models.rollout_buffer import CircularRolloutBuffer, TraceStep
from centralized_verification.utils import Updater
from experiments.utils.parallel_run import DEVICE

# Number of transitions sampled from the replay buffer per gradient step.
BATCH_SIZE: int = 32


class DeepQLearner(QLearner):
    """Q-learner whose Q-function is a neural network, trained Double-DQN style.

    An online network (``self.model``) is trained from a replay buffer that is
    sampled with priorities and whose priorities are refreshed from TD errors.
    A separate target network (``self.target_model``) evaluates the bootstrap
    values and is periodically overwritten with the online network's weights.

    Args:
        obs_space: Observation space; turned into the network input shape (and
            ``low``/``high`` bounds) via ``get_space_params``.
        action_space: Action space; determines the number of Q-value outputs.
        model_class: Callable ``(obs_shape, num_actions, low, high)`` returning
            an ``nn.Module``; instantiated twice (online and target network).
        make_multidiscrete_one_hot: Forwarded to ``get_space_params`` and
            ``transform_potentially_multidiscrete_obs``.
        train_every_steps: Period passed to ``self.trainer.update_every`` —
            presumably "one gradient step every N ``update_q`` calls"; see
            ``Updater`` for exact semantics.
        buffer_size: Capacity of the replay buffer.
        clip_gradients: If True, clip the online network's gradient norm to 1.
        target_update_every: Period passed to
            ``self.target_agent_updater.update_every`` for syncing the target
            network (default 200, the previously hard-coded value).
        **kwargs: Forwarded to ``QLearner.__init__``.
    """

    def __init__(self, obs_space, action_space, model_class, make_multidiscrete_one_hot: bool = False,
                 train_every_steps=4,
                 buffer_size=int(1e5), clip_gradients=False, target_update_every=200, **kwargs):
        super().__init__(obs_space, action_space, **kwargs)
        self.obs_shape, self.num_actions, _obs_dtype, low, high = get_space_params(self.obs_space, action_space,
                                                                                   make_multidiscrete_one_hot)

        self.make_multidiscrete_one_hot = make_multidiscrete_one_hot

        self.model: nn.Module = model_class(self.obs_shape, self.num_actions, low, high).to(DEVICE)
        self.target_model: nn.Module = model_class(self.obs_shape, self.num_actions, low, high).to(DEVICE)
        # Start the target network as an exact copy of the online network. Otherwise the
        # bootstrap targets would come from an unrelated random initialisation until the first sync.
        self.target_model.load_state_dict(self.model.state_dict())

        self.replay_buffer = CircularRolloutBuffer(capacity=buffer_size, input_shape=self.obs_shape,
                                                   state_dtype=torch.float32, device=DEVICE)
        self.trainer = Updater(lambda: self.perform_training())
        self.target_agent_updater = Updater(lambda: self.target_model.load_state_dict(self.model.state_dict()))
        self.optimizer = Adam(self.model.parameters())
        self.last_training_loss = None
        self.train_every_steps = train_every_steps
        self.target_update_every = target_update_every
        self.clip_gradients = clip_gradients

    def transform_obs(self, obs):
        """Preprocess a raw observation with ``transform_potentially_multidiscrete_obs``."""
        return transform_potentially_multidiscrete_obs(obs, self.obs_space, self.make_multidiscrete_one_hot)

    def get_greedy_action(self, observation):
        """Return the index of the action with the highest online-network Q-value.

        The observation is reshaped to a batch of one with shape ``(1, *obs_shape)``
        and cast to float before the forward pass. No autograd graph is built,
        because this is pure inference.
        """
        with torch.no_grad():
            model_input_batched = torch.as_tensor(observation, device=DEVICE).view((1, *self.obs_shape)).float()
            model_output = self.model(model_input_batched).view((self.num_actions,))
        return int(model_output.argmax())

    def update_q(self, obs, action, next_obs, rew, done, step_num, training_progress):
        """Record one transition and, once the buffer is warm, drive training and target syncs.

        ``step_num`` and ``training_progress`` are accepted for interface
        compatibility but are not used here.
        """
        self.replay_buffer.add_step(TraceStep(state=torch.as_tensor(obs, device=DEVICE), action=action, reward=rew,
                                              next_state=torch.as_tensor(next_obs, device=DEVICE), done=done))

        # Only start learning once the buffer holds ten batches' worth of transitions.
        if self.replay_buffer.num_filled_approx() >= BATCH_SIZE * 10:
            self.trainer.update_every(self.train_every_steps)
            self.target_agent_updater.update_every(self.target_update_every)

    def perform_training(self):
        """Run one importance-weighted Double-DQN gradient step on a prioritized minibatch.

        Side effects: updates replay-buffer priorities from the TD errors, steps
        the optimizer, and stores the scalar loss in ``self.last_training_loss``.
        """
        self.optimizer.zero_grad()
        rollout_sample, indices, importance = self.replay_buffer.sample(BATCH_SIZE, priority_scale=0.7)

        importance = torch.pow(importance, 0.9)  # So that high-priority states aren't _too_ overrepresented

        # Bootstrap targets are constants w.r.t. the loss, so build no autograd graph for them.
        with torch.no_grad():
            # Double DQN: the online network picks the best next action...
            arg_q_max = torch.argmax(self.model(rollout_sample.next_states), dim=1)

            # ...and the target network evaluates that action.
            future_q_values = self.target_model(rollout_sample.next_states)
            batch_range = range(future_q_values.shape[0])  # actual sampled batch size
            double_q = future_q_values[batch_range, arg_q_max]

            # Bellman target; terminal transitions get no bootstrap term
            # (``self.discount`` is presumably provided by QLearner).
            target_q = rollout_sample.rewards + (self.discount * double_q * (~rollout_sample.dones).float())

        # Q-values the current online network predicts for the actions actually taken.
        q_values = self.model(rollout_sample.states)
        action_q_values = q_values[batch_range, rollout_sample.actions]

        # Feed TD errors back as priorities so poorly-predicted transitions are sampled more often.
        error = action_q_values - target_q
        self.replay_buffer.set_priorities(indices=indices, errors=error.detach())

        # Importance-weighted MSE, then optimize the online network.
        loss = F.mse_loss(input=action_q_values, target=target_q, reduction='none')
        loss = (loss * importance).mean()
        loss.backward()

        if self.clip_gradients:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)

        self.last_training_loss = float(loss)
        self.optimizer.step()

    def state_dict(self):
        """Return a checkpoint dict: optimizer, both networks, last loss, and replay buffer."""
        return {
            "optimizer": self.optimizer.state_dict(),
            "model": self.model.state_dict(),
            "target_model": self.target_model.state_dict(),
            "last_training_loss": self.last_training_loss,
            "replay_buffer": self.replay_buffer.state_dict()
        }

    def load_state_dict(self, state_dict):
        """Restore from a dict produced by ``state_dict``; ``last_training_loss`` is optional."""
        self.optimizer.load_state_dict(state_dict["optimizer"])
        self.model.load_state_dict(state_dict["model"])
        self.target_model.load_state_dict(state_dict["target_model"])
        self.last_training_loss = state_dict.get("last_training_loss")
        self.replay_buffer.load_state_dict(state_dict["replay_buffer"])

    def get_log_dict(self):
        """Extend the base class's log dict with the most recent training loss (None before training)."""
        ld = super().get_log_dict()
        ld.update({
            "training_loss": self.last_training_loss
        })
        return ld
