import random
from typing import Callable

import torch
import torch.nn.functional as F
from torch.optim import Adam

from centralized_verification.agents.single_agent import SingleAgentLearner
from centralized_verification.agents.utils import transform_potentially_multidiscrete_obs, get_space_params, \
    default_epsilon_schedule
from centralized_verification.models.rollout_buffer import CircularRolloutBuffer, TraceStep
from centralized_verification.models.stateful_module import StatefulModule
from centralized_verification.shields.shield import AgentUpdate, AgentResult
from centralized_verification.utils import Updater, TrainingProgress, convert_gym_space_to_q_shape
from experiments.utils.parallel_run import DEVICE

# Number of sequences requested from the replay buffer for each training step;
# also the batch dimension used when creating the models' initial recurrent state.
BATCH_SIZE = 8
# Length (in steps) of each sequence requested from the replay buffer.
SEQUENCE_LEN = 4
# Total number of transitions in one training batch; used to flatten the
# (batch, sequence) Q-value tensors before indexing them by action.
NUM_SAMPLES = BATCH_SIZE * SEQUENCE_LEN


class DeepRecurrentQLearner(SingleAgentLearner):
    """
    Double deep Q-learner whose Q-network is a stateful (recurrent) module.

    The agent keeps a per-episode recurrent state (``self.model_state``) that is advanced on every call to
    ``get_action`` and reset by ``new_episode``. Transitions are stored in a ``CircularRolloutBuffer`` and
    training samples ``BATCH_SIZE`` sequences of ``SEQUENCE_LEN`` steps from it.
    """

    def __init__(self, obs_space, action_space, model_class: Callable[..., StatefulModule], discount=0.9,
                 alpha_index=1,
                 epsilon_scheduler: Callable[[TrainingProgress], float] = default_epsilon_schedule,
                 evaluation_epsilon: float = 0.0,
                 make_multidiscrete_one_hot: bool = False,
                 train_every_steps=4,
                 buffer_size=int(1e5), clip_gradients=False, **kwargs):
        """
        :param obs_space: Observation space of the environment
        :param action_space: Action space of the environment
        :param model_class: Factory called as ``model_class(obs_shape, num_actions, low, high)`` to build both the
            online and the target network
        :param discount: Discount factor applied to the bootstrapped next-state value
        :param alpha_index: Stored on the instance; not used by this class itself
        :param epsilon_scheduler: Maps the training progress to the exploration rate used while training
        :param evaluation_epsilon: Exploration rate used when ``get_action`` receives no training progress
        :param make_multidiscrete_one_hot: Forwarded to the observation-space helpers
        :param train_every_steps: Interval passed to the training ``Updater``
        :param buffer_size: Capacity of the replay buffer
        :param clip_gradients: If true, clip the gradient norm to 1 before each optimizer step
        :param kwargs: Accepted for call-site compatibility and ignored
        """
        self.obs_space = obs_space
        self.discount = discount
        self.alpha_index = alpha_index
        self.epsilon_scheduler = epsilon_scheduler
        self.evaluation_epsilon = evaluation_epsilon
        self.log_last_eps = epsilon_scheduler(TrainingProgress(0, 0))
        self.obs_shape, self.num_actions, obs_dtype, low, high = get_space_params(self.obs_space, action_space,
                                                                                  make_multidiscrete_one_hot)
        self.make_multidiscrete_one_hot = make_multidiscrete_one_hot

        self.model: StatefulModule = model_class(self.obs_shape, self.num_actions, low, high).to(DEVICE)
        self.target_model: StatefulModule = model_class(self.obs_shape, self.num_actions, low, high).to(DEVICE)
        # Start the target network as an exact copy of the online network; otherwise the first bootstrapped
        # targets would come from an unrelated, randomly initialised network.
        self.target_model.load_state_dict(self.model.state_dict())

        self.replay_buffer = CircularRolloutBuffer(capacity=buffer_size, input_shape=self.obs_shape,
                                                   state_dtype=torch.float32, device=DEVICE)
        self.trainer = Updater(lambda: self.perform_training())
        self.target_agent_updater = Updater(lambda: self.target_model.load_state_dict(self.model.state_dict()))
        self.optimizer = Adam(self.model.parameters())
        self.last_training_loss = None
        self.train_every_steps = train_every_steps
        self.clip_gradients = clip_gradients

        # Recurrent state of the online model for the episode currently being played
        self.model_state = None

    def new_episode(self):
        """Forget the recurrent state so the next action starts from the model's initial state."""
        # TODO Potentially log in rollout buffer
        self.model_state = None

    def transform_obs(self, obs):
        """Convert a raw observation into the representation the model and replay buffer work with."""
        return transform_potentially_multidiscrete_obs(obs, self.obs_space, self.make_multidiscrete_one_hot)

    def get_action(self, observation, training_progress: TrainingProgress):
        """
        Epsilon-greedy action selection.

        Slightly different from get_action in QLearner,
        because we need to update the model state even if we take a random action

        :param observation: Raw (untransformed) observation
        :param training_progress: Current training progress, or None to use ``evaluation_epsilon``
        :return: Index of the chosen action
        """
        epsilon = self.epsilon_scheduler(
            training_progress) if training_progress is not None else self.evaluation_epsilon
        self.log_last_eps = epsilon
        # Always run the model, so the recurrent state advances even when exploring
        best_action = self.get_greedy_action(self.transform_obs(observation))
        if random.random() < epsilon:
            return random.randint(0, self.num_actions - 1)
        else:
            return best_action

    def get_greedy_action(self, observation):
        """
        Advance the recurrent state by one step and return the action with the highest predicted Q-value.

        :param observation: Already-transformed observation; it is reshaped to ``(1, 1, *obs_shape)``
        :return: Index of the greedy action as a Python int
        """
        if self.model_state is None:
            self.model_state = self.model.get_initial_state(1)

        model_input_batched = torch.as_tensor(observation, device=DEVICE).view((1, 1, *self.obs_shape)).float()
        # Acting never backpropagates. Without no_grad, the stored recurrent state would keep a reference to the
        # autograd graph of every previous step, which grows for the whole length of the episode.
        with torch.no_grad():
            model_output, self.model_state = self.model(model_input_batched, self.model_state)
        return int(model_output.view((self.num_actions,)).argmax())

    def update_q(self, obs, action, next_obs, rew, done, step_num, training_progress):
        """
        Store one transition in the replay buffer and, once enough data has been collected, give the training
        and target-sync updaters a chance to run.

        ``step_num`` and ``training_progress`` are accepted but not used here.
        """
        # Gradients are cleared in perform_training right before the backward pass, so nothing to reset here.
        self.replay_buffer.add_step(TraceStep(state=torch.as_tensor(obs, device=DEVICE), action=action, reward=rew,
                                              next_state=torch.as_tensor(next_obs, device=DEVICE), done=done))

        # Warm-up: do not train until the buffer holds roughly ten batches' worth of steps
        if self.replay_buffer.num_filled_approx() >= BATCH_SIZE * 10:
            self.trainer.update_every(self.train_every_steps)
            self.target_agent_updater.update_every(200)

    def update_q_with_agent_update(self, obs, agent_update: AgentUpdate, next_obs, rew, done, step_num,
                                   training_progress):
        """Unpack the action and the modified reward from ``agent_update`` and forward them to ``update_q``."""
        self.update_q(obs, agent_update.action, next_obs, agent_update.get_modified_reward(rew), done, step_num,
                      training_progress)

    def observe_transition(self, obs, shield_result: AgentResult, next_obs, rew, done, step_num, training_progress):
        """
        Record an environment transition.

        Learns from ``shield_result.real_action``; when an ``augmented_action`` is present it is used instead
        with probability 0.5.
        """
        obs = self.transform_obs(obs)
        next_obs = self.transform_obs(next_obs)
        # NOTE(review): both fields are passed on as an AgentUpdate (`.action` / `.get_modified_reward` are
        # accessed downstream) - confirm against the AgentResult definition.
        action_to_update = shield_result.real_action
        if shield_result.augmented_action is not None:
            if random.random() < 0.5:  # TODO Experiment with this
                action_to_update = shield_result.augmented_action

        self.update_q_with_agent_update(obs, action_to_update, next_obs, rew, done, step_num, training_progress)

    def _select_action_values(self, q_values, actions):
        """Pick, for every sampled step, the Q-value of the given action; result has shape (batch, sequence)."""
        flat_q = q_values.view((NUM_SAMPLES, self.num_actions))
        flat_actions = actions.view(NUM_SAMPLES)
        return flat_q[range(NUM_SAMPLES), flat_actions].view((BATCH_SIZE, SEQUENCE_LEN))

    def perform_training(self):
        """
        Run one double-DQN gradient step on a batch of sequences sampled from the replay buffer.

        The per-step squared error is multiplied by ``valid_indices`` (as returned by the buffer) before
        averaging, and the resulting loss is stored in ``last_training_loss``.
        """
        self.optimizer.zero_grad()
        rollout_sample, valid_indices = self.replay_buffer.sample_sequences(batch_size=BATCH_SIZE,
                                                                            sequence_len=SEQUENCE_LEN)

        # The targets are constants with respect to the loss, so no autograd graph is needed for them.
        with torch.no_grad():
            # Estimate best action in new states using main Q network
            q_max, _ = self.model(rollout_sample.next_states, self.model.get_initial_state(BATCH_SIZE))
            arg_q_max = torch.argmax(q_max, dim=2)

            # Target DQN estimates q-values
            future_q_values, _ = self.target_model(rollout_sample.next_states,
                                                   self.target_model.get_initial_state(BATCH_SIZE))
            double_q = self._select_action_values(future_q_values, arg_q_max)

            # Calculate targets (bellman equation); terminal steps do not bootstrap
            target_q = rollout_sample.rewards + (self.discount * double_q * (~rollout_sample.dones).float())

        # What are the q-values that the current agent predicts for the actions it took
        q_values, _ = self.model(rollout_sample.states, self.model.get_initial_state(BATCH_SIZE))
        action_q_values = self._select_action_values(q_values, rollout_sample.actions)

        # Actually train the neural network
        loss = F.mse_loss(input=action_q_values, target=target_q, reduction='none')
        loss = (loss * valid_indices).mean()
        loss.backward()

        if self.clip_gradients:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)

        self.last_training_loss = float(loss)
        self.optimizer.step()

    def state_dict(self):
        """Return a checkpoint containing the optimizer, both networks, the last loss and the replay buffer."""
        return {
            "optimizer": self.optimizer.state_dict(),
            "model": self.model.state_dict(),
            "target_model": self.target_model.state_dict(),
            "last_training_loss": self.last_training_loss,
            "replay_buffer": self.replay_buffer.state_dict()
        }

    def load_state_dict(self, state_dict):
        """Restore a checkpoint produced by ``state_dict``; a missing ``last_training_loss`` becomes None."""
        self.optimizer.load_state_dict(state_dict["optimizer"])
        self.model.load_state_dict(state_dict["model"])
        self.target_model.load_state_dict(state_dict["target_model"])
        self.last_training_loss = state_dict.get("last_training_loss")
        self.replay_buffer.load_state_dict(state_dict["replay_buffer"])

    def get_log_dict(self):
        """Return the most recent exploration rate and training loss for logging."""
        ld = {
            "epsilon": self.log_last_eps,
            "training_loss": self.last_training_loss
        }
        return ld
