import argparse
import random
from typing import Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.utils as utils
from torch.distributions import Categorical
from torch.optim import SGD

import conf_tmaze
import env_tmaze
import logging_tmaze
from conf_tmaze import RNNType
from env_tmaze import Action, Observation, Reward, TMazeEnv, TMazeState

# Placeholder for the run's logging_tmaze.Logger; main() replaces it before training starts.
logger: Optional["logging_tmaze.Logger"] = None


class RNN(nn.Module):
    """Recurrent policy/baseline network used by the PPO agent.

    A GRU or LSTM (chosen by ``conf.RNN_TYPE``) reads observation bit vectors.
    Its output is concatenated with the raw observation and fed to:
      * ``qvalue``   -- linear action preferences, softmaxed into the policy,
      * ``baseline`` -- linear scalar state-value estimate,
      * ``mix_coef`` -- bias-free linear map of the observation alone, squashed by
        ``sigmoid(conf.LAMBDA * .)``.
    """

    def __init__(self):
        super().__init__()
        if conf.RNN_TYPE == RNNType.GRU:
            self.rnn = nn.GRU(Observation.num_bit(), conf.RNN_SIZE)
        elif conf.RNN_TYPE == RNNType.LTSM:  # (sic) member name as spelled in conf_tmaze.RNNType
            self.rnn = nn.LSTM(Observation.num_bit(), conf.RNN_SIZE)
        else:
            raise ValueError(f"Unknown RNN specified: {conf.RNN_TYPE}")
        # Decide once whether dropout is applied, so forward() does not depend on
        # conf staying unchanged. Neither module has parameters, so state_dict keys
        # (and therefore saved checkpoints) are unaffected.
        self.dropout = nn.Dropout(conf.DROPOUT) if conf.DROPOUT > 0 else nn.Identity()
        hidden_size = Observation.num_bit() + conf.RNN_SIZE
        #self.fc = nn.Linear(hidden_size, hidden_size)
        self.qvalue = nn.Linear(hidden_size, len(Action))
        self.baseline = nn.Linear(hidden_size, 1)
        #self.policy_mixing = nn.Linear(hidden_size, 1, bias=False)
        self.mix_coef = nn.Linear(Observation.num_bit(), 1, bias=False)

    def forward(
        self, x: torch.Tensor, hc_state: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one (or more) recurrent steps.

        Args:
            x: observation tensor. The recurrent layers use the PyTorch default
                ``batch_first=False``, and the indexing below needs 3-D input, so
                this is expected as (seq_len, batch, Observation.num_bit()).
            hc_state: recurrent state from the previous call, or ``None`` to
                start fresh. For an LSTM this is an ``(h, c)`` tuple.

        Returns:
            ``(policy_prob, hc_state, baseline, mix_coef)``:
              * ``policy_prob`` -- softmax over actions at the last time step,
                shape (batch, len(Action)),
              * ``hc_state`` -- updated recurrent state,
              * ``baseline`` and ``mix_coef`` -- last time step of the first
                batch element, shape (1,).
        """
        out, hc_state = self.rnn(x, hc_state)
        out = self.dropout(out)
        # Skip-connect the raw observation next to the recurrent features.
        out = torch.cat([x, out], dim=-1)
        #out = self.fc(out)
        qvalue = self.qvalue(out)
        # Every head reads the LAST time step, so the policy, baseline and mixing
        # coefficient all refer to the same step. With single-step input
        # (seq_len == 1) this equals the previous index-0 behavior.
        policy_prob = torch.softmax(qvalue, dim=-1)[-1, :, :]
        baseline = self.baseline(out)[-1, 0]
        #mix_coef = torch.sigmoid(conf.LAMBDA * self.policy_mixing(out)[0, 0])
        mix_coef = torch.sigmoid(conf.LAMBDA * self.mix_coef(x)[-1, 0])
        return policy_prob, hc_state, baseline, mix_coef


class PPO:
    """Recurrent PPO agent for the T-maze environment.

    Holds an ``RNN`` policy/baseline network on ``dev`` and a plain SGD optimizer
    (learning rate ``conf.ALPHA``). ``train`` runs ``conf.NUM_EPISODE`` episodes.
    After each episode, ``policy_gradient`` performs ``conf.NUM_EPOCHS`` clipped
    PPO updates on that episode's stored history.
    """

    def __init__(self, dev: torch.device):
        self.policy = RNN().to(dev)
        self.dev = dev
        self.optim = SGD(self.policy.parameters(), lr=conf.ALPHA)

    def load_state_dicts(self, model_to_resume: str) -> int:
        """Restore network and optimizer state from a checkpoint file.

        Args:
            model_to_resume: path to a checkpoint containing the keys
                ``"model_state_dict"``, ``"optim_state_dict"`` and
                ``"next_starting_epoch"``.

        Returns:
            The stored ``"next_starting_epoch"`` value.
        """
        # map_location lets a checkpoint saved on a different device (another GPU,
        # or GPU vs. CPU) be restored onto this agent's device.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        model_pth = torch.load(model_to_resume, map_location=self.dev)
        self.policy.load_state_dict(model_pth["model_state_dict"])
        self.optim.load_state_dict(model_pth["optim_state_dict"])
        return model_pth["next_starting_epoch"]

    def init_histories(self):
        """Reset every per-episode history (called at the start of each episode)."""
        # network outputs, accumulated as 1-D tensors on the agent's device
        self.action_probs = torch.empty(0, device=self.dev)
        self.baselines = torch.empty(0, device=self.dev)
        self.mix_coefs = torch.empty(0, device=self.dev)
        # pgmcts rhos & mcts_action_probs (this class's own train loop never fills them)
        self.rhos = torch.empty(0, device=self.dev)
        self.varrhos = torch.empty(0, device=self.dev)
        # plain Python histories
        self.obss = []
        self.rewards = []
        self.returns = []
        # histories kept for debugging / logging
        self.probs = []
        self.actions = []
        self.positions = []
        self.is_updates = []

    @property
    def first_return(self) -> float:
        """Discounted return from the first step of the episode (valid after comp_returns)."""
        return self.returns[0]

    def add_history(
        self,
        obs: torch.Tensor,
        action_prob: torch.Tensor,
        reward: Reward,
        *,
        baseline: Optional[torch.Tensor] = None,
        mix_coef: Optional[torch.Tensor] = None,
        rho: Optional[torch.Tensor] = None,
        varrho: Optional[torch.Tensor] = None,
        prob: Optional[torch.Tensor] = None,
        action: Optional[Action] = None,
        position: Optional[int] = None,
        is_update: Optional[bool] = None,
    ):
        """Record one environment step.

        Args:
            obs: observation tensor that was fed to the policy. It is stored as is,
                because ``recompute_policy_output`` replays it.
            action_prob: probability of the taken action, concatenated into
                ``self.action_probs``.
            reward: step reward; its ``.value`` is appended to ``self.rewards``.
            baseline, mix_coef, rho, varrho: optional tensors, concatenated into
                their histories only when given.
            prob, action, position, is_update: appended to their lists as given
                (``None`` included). ``action`` is required by
                ``recompute_policy_output``.
        """
        # store to reward history
        self.rewards.append(reward.value)
        # store to policy history
        self.action_probs = torch.cat((self.action_probs, action_prob))
        if baseline is not None:
            self.baselines = torch.cat((self.baselines, baseline))
        if mix_coef is not None:
            self.mix_coefs = torch.cat((self.mix_coefs, mix_coef))
        # store to pgmcts rho history
        if rho is not None:
            self.rhos = torch.cat((self.rhos, rho))
        if varrho is not None:
            self.varrhos = torch.cat((self.varrhos, varrho))
        # store others
        self.obss.append(obs)
        self.probs.append(prob)
        self.actions.append(action)
        self.positions.append(position)
        self.is_updates.append(is_update)

    def recompute_policy_output(self):
        """Re-run the current policy over the stored episode.

        Replays ``self.obss`` step by step from a fresh (``None``) recurrent state.
        It then overwrites ``self.action_probs`` (probability of each stored action)
        and ``self.baselines`` with new, differentiable outputs.
        """
        action_probs_list = []
        baselines_list = []
        hc_state = None
        for obs, action in zip(self.obss, self.actions):
            policy, hc_state, baseline, _ = self.policy(obs, hc_state)
            action_probs_list.append(policy[:, action.value])
            baselines_list.append(baseline)
        self.action_probs = torch.stack(action_probs_list).squeeze(-1)
        self.baselines = torch.stack(baselines_list).squeeze(-1)

    def policy_gradient(self, use_rho: bool = False, ml_update: bool = False):
        """Update the network with ``conf.NUM_EPOCHS`` clipped-PPO passes over the episode.

        Advantages (``returns - baselines``) and the old action probabilities are
        frozen from the rollout. They are optionally weighted by per-step discounts
        when ``conf.IS_LOSS_DISCOUNTED`` is set. Epoch 0 reuses the rollout's
        outputs; each later epoch first calls ``recompute_policy_output``.

        Args:
            use_rho, ml_update: currently unused (the rho-weighted variant below
                is commented out).
        """
        # get the return (discounted cumulative rewards) and discount (discount rates) at each time-step
        returns, discounts = self.comp_returns()

        clip = conf.CLIP_PPO
        action_probs_old = self.action_probs.detach().clone()
        advantages = (returns - self.baselines).detach().clone()
        if conf.IS_LOSS_DISCOUNTED:
            advantages = advantages * discounts

        # adjust clip for avoiding peaky policy
        #clip = clip * torch.min(torch.tensor(1.0), (1.0 - action_probs_old) / 0.5)  # clip is proportional to action_probs (not used)

        for e in range(conf.NUM_EPOCHS):
            if e > 0:  # the previous step changed the parameters: recompute outputs
                self.recompute_policy_output()

            # importance ratios pi_new(a|h) / pi_old(a|h)
            ratios = self.action_probs / action_probs_old
            # unclipped REINFORCE term of each step
            reinforces = ratios * advantages
            if clip is not None:
                # pessimistic (min) combination with the clipped surrogate
                clipped_reinforces = torch.clamp(ratios, 1.0 - clip, 1.0 + clip) * advantages
                reinforces = torch.min(reinforces, clipped_reinforces)

            loss_pg = -reinforces
            loss_baseline = (returns - self.baselines) ** 2
            if conf.IS_LOSS_DISCOUNTED:
                loss_baseline = discounts * loss_baseline

            # # if called as pgmcts, use rhos (Eq. 7)
            # if use_rho:
            #     loss_pg = self.rhos.detach() * loss_pg  # / (1-conf.LAMBDA)
            # if conf.MIXING_COEFFICIENT_ADAPTATION:
            #     loss_pg -= self.mix_coefs * self.varrhos.detach() * advantages.detach() * discounts

            # The PG term is summed over the episode. The baseline term is halved
            # and averaged over the episode length.
            loss = (
                torch.sum(loss_pg) + .5 * torch.sum(loss_baseline) / len(self.rewards)
            )  # gradient descent

            self.optim.zero_grad()
            loss.backward()  # BPTT through the recurrent history
            utils.clip_grad_norm_(self.policy.parameters(), conf.CLIP_GRAD_NORM)
            self.optim.step()

    def comp_returns(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute per-step discounted returns and discount factors.

        ``self.returns`` is replaced with the returns, rather than extended, so
        repeated calls within one episode do not accumulate stale values.

        Returns:
            ``(returns, discounts)`` on ``self.dev``, where
            ``returns[t] = sum_k GAMMA**k * rewards[t + k]`` and
            ``discounts[t] = GAMMA**t``.
        """
        returns = []
        running = 0
        # Accumulate backwards, then flip, instead of O(n) list.insert(0, ...) per step.
        for r in reversed(self.rewards):
            running = r + conf.GAMMA * running
            returns.append(running)
        returns.reverse()
        self.returns = returns
        discounts = [conf.GAMMA ** t for t in range(len(self.rewards))]
        return torch.tensor(returns).to(self.dev), torch.tensor(discounts).to(self.dev)

    def train(self):
        """Run ``conf.NUM_EPISODE`` episodes, with one PPO update after each.

        Each episode starts from a random ``START_UP``/``START_DOWN`` observation.
        It samples actions from the policy until ``state.episode_end``, and then
        calls ``policy_gradient`` and the module-level logger.
        """
        for ep in range(conf.NUM_EPISODE):
            init_obs: Observation = random.choice([Observation.START_UP, Observation.START_DOWN])
            env = TMazeEnv(init_obs=init_obs)

            self.init_histories()

            # retain tmaze progression state and rnn hc_state
            state: TMazeState = TMazeState(position=conf.INITIAL_POSITION, obs=init_obs)
            hc_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
            while not state.episode_end:
                obs = state.obs.to_tensor().to(self.dev)
                policy, hc_state, baseline, _ = self.policy(obs, hc_state)

                # draw next action from policy prob
                action_idx = Categorical(policy).sample()
                action = Action(action_idx.item())

                # move according to action, then get observation and reward
                state = env.step(state, action)

                action_prob = policy[:, action.value]  # torch.Size([1])
                self.add_history(obs, action_prob, state.reward, baseline=baseline, action=action)

                # logging (only works if conf.LOG_LEVEL <= logging.DEBUG)
                logger.during_episode(ep, env, state, action, policy[0], etc={'rpg':self},
                                      freq=0.0003, freq_at_episode_end=0.03)

            # after an episode, update parameters with PPO
            self.policy_gradient()

            # count goaled state for last conf.LOG_INTERVAL training iteration
            logger.after_episode(ep, state.reward==Reward.GOAL, self.first_return)

        logger.after_training()


def main():
    """Wire the selected config into helper modules, create the logger and train.

    Reads the module-level ``conf`` (assigned by the ``__main__`` block) and
    assigns the module-level ``logger``.
    """
    # re-initialize config in modules that read their own ``conf`` global
    env_tmaze.conf = conf
    logging_tmaze.conf = conf
    global logger
    logger = logging_tmaze.Logger(alg_name='PPO')
    # Fall back to CPU so the script also runs on machines without CUDA.
    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    agent = PPO(dev=dev)
    agent.train()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config_name", type=str, default="ppo_conf_demo",
        help="name of a config object defined in conf_tmaze",
    )
    args = parser.parse_args()
    # Report an unknown config name as a CLI usage error instead of a raw traceback.
    if not hasattr(conf_tmaze, args.config_name):
        parser.error(f"unknown config name {args.config_name!r}: not defined in conf_tmaze")
    # Module-level global read by RNN, PPO and main().
    conf = getattr(conf_tmaze, args.config_name)
    main()
