import argparse
import random
from typing import Tuple

import torch
from torch.distributions import Categorical

import conf_tmaze
import env_tmaze
import lazymcts_tmaze
import rpg_tmaze
import logging_tmaze
from env_tmaze import Action, Observation, Reward, TMazeEnv, TMazeState
from lazymcts_tmaze import LazyMCTS
from rpg_tmaze import RPG

# Module-wide logger handle. It stays None until main() replaces it with a
# logging_tmaze.Logger; PGMCTS.train() calls methods on it, so main() (or an
# equivalent assignment) has to run first.
logger = None


class PGMCTS:
    """Train an RPG policy together with lazy MCTS trees on the T-maze.

    At every step the action is sampled from a mixture of the policy returned
    by ``RPG.policy`` and the policy returned by ``LazyMCTS.policy``.  After
    each episode the RPG is updated with ``policy_gradient`` and the MCTS tree
    that was used in the episode is updated with ``backpropagate``.

    NOTE: the methods read the module-level names ``conf`` and ``logger``.  In
    this file ``conf`` is assigned only under the ``__main__`` guard and
    ``logger`` is ``None`` until ``main()`` runs, so ``conf`` must be set
    before an instance is created and ``logger`` before ``train()`` is called.
    """

    def __init__(self, dev: torch.device) -> None:
        """Create the RPG learner, one MCTS root per start observation, and ``min_rho``.

        Args:
            dev: torch device handed to ``RPG`` and used for the tensors this
                class moves with ``.to(...)``.
        """
        # rpg settings
        self.rpg = RPG(dev=dev)
        # mcts settings: one independent tree per possible initial observation,
        # both rooted at conf.INITIAL_POSITION
        self.mcts_roots = {
            Observation.START_UP: LazyMCTS(init_position=conf.INITIAL_POSITION, init_obs=Observation.START_UP),
            Observation.START_DOWN: LazyMCTS(init_position=conf.INITIAL_POSITION, init_obs=Observation.START_DOWN),
        }
        # pg-mcts setting: threshold used when clipping rho in train();
        # it is shrunk after every episode
        self.min_rho = torch.tensor([conf.MIN_RHO]).to(dev)
        # etc
        self.dev = dev

    def train(self) -> None:
        """Run ``conf.NUM_EPISODE`` training episodes.

        Each episode picks a random start observation, creates a fresh
        ``TMazeEnv``, resets the RPG and the matching MCTS root, and then loops
        until ``state.episode_end``: mix the two policies, sample an action,
        step the MCTS with it, and record the step in the RPG history.  After
        the episode the RPG and the MCTS are updated, ``self.min_rho`` is
        decayed, and the result is passed to the module-level ``logger``.
        """
        for ep in range(conf.NUM_EPISODE):
            # get initial observation (uniformly from the two start observations)
            init_obs: Observation = random.choice([Observation.START_UP, Observation.START_DOWN])
            # create environment
            env = TMazeEnv(init_obs=init_obs)

            # initialize rpg
            self.rpg.init_histories()
            # select the mcts root that belongs to this start observation and initialize it
            mcts = self.mcts_roots[init_obs]
            mcts.init_histories()

            # retain tmaze progression state and hc_state for rpg
            state: TMazeState = TMazeState(position=conf.INITIAL_POSITION, obs=init_obs)
            # hc_state is None on the first step of an episode and afterwards
            # holds whatever RPG.policy returned on the previous step
            # (presumably the recurrent hidden/cell state -- confirm in rpg_tmaze)
            hc_state: Tuple[torch.Tensor, torch.Tensor] = None
            while not state.episode_end:
                # choose policy by mixing policies (Eq. 6)
                obs = state.obs.to_tensor().to(self.dev)
                rpg_policy, hc_state, baseline, mix_coef = self.rpg.policy(obs, hc_state)
                mcts_policy = torch.tensor(mcts.policy()).to(self.dev)
                # operations inside no_grad are not recorded by autograd, so the
                # mixed policy is only used for sampling and for the ratios below
                with torch.no_grad():
                    if conf.MIXING_COEFFICIENT_ADAPTATION:
                        # use the mixing coefficient returned by RPG.policy
                        # (commented-out experiment kept from the original:)
                        # if mcts.expanded:
                        #    mix_coef = torch.tensor([0.5]).to(self.dev)
                        lambda_ajst = mix_coef
                    else:
                        # fixed coefficient from the config, capped at 0.1 once
                        # mcts.expanded is truthy
                        lambda_ajst = conf.LAMBDA
                        if mcts.expanded:
                            lambda_ajst = torch.min(torch.tensor([conf.LAMBDA, 0.1]))
                    mixed_policy = (1 - lambda_ajst) * rpg_policy + lambda_ajst * mcts_policy
                # choose action
                action_idx = Categorical(mixed_policy).sample()
                action = Action(action_idx.item())
                # execute action by stepping mcts with the selected action
                mcts.step_action(env, action)

                # observe a reward and new history by getting last state from mcts
                state = mcts.last_state

                # add RPG history (Eqs. 6, 8)
                # taken outside no_grad, unlike the ratios below;
                # rpg_policy is indexed as 2-D here, mcts_policy as 1-D
                rpg_action_prob = rpg_policy[:, action.value]
                with torch.no_grad():
                    mix_action_prob = mixed_policy[:, action.value]
                    # rho = (1 - lambda) * pi_rpg(a) / pi_mix(a)
                    rho = (1 - lambda_ajst) * rpg_action_prob / mix_action_prob
                    if conf.MIXING_COEFFICIENT_ADAPTATION:
                        # varrho = (pi_mcts(a) - pi_rpg(a)) / pi_mix(a); only
                        # defined and only consumed in the adaptive setting
                        varrho = ( mcts_policy[action.value] - rpg_action_prob ) / mix_action_prob
                    # soft lower clip (for non-negative rho): values already
                    # >= min_rho are unchanged, values in
                    # [min_rho / 1.414, min_rho) become min_rho, and smaller
                    # values are scaled up by 1.414
                    rho = torch.maximum(torch.minimum(self.min_rho, rho * 1.414), rho)
                if conf.MIXING_COEFFICIENT_ADAPTATION:
                    self.rpg.add_history(
                        rpg_action_prob, state.reward, baseline=baseline, mix_coef=mix_coef, rho=rho, varrho=varrho, action=action
                    )
                else:
                    self.rpg.add_history(rpg_action_prob, state.reward, baseline=baseline, rho=rho, action=action)

                # logging (per the original note: only works if conf.LOG_LEVEL <= logging.DEBUG)
                logger.during_episode(ep, env, state, action, mixed_policy[0],
                                      etc={'rpg_policy': rpg_policy[0], 'mcts_policy':mcts_policy, 'rpg':self.rpg},
                                      freq=0.0003, freq_at_episode_end=0.03)

            # update RPG
            self.rpg.policy_gradient(use_rho=True)
            # update MCTS (only the tree used in this episode)
            mcts.backpropagate(ep)
            # update min_rho (monotonically decreasing): multiply by
            # sqrt((ep + 100) / (ep + 101)), a factor below 1
            self.min_rho = torch.sqrt(torch.tensor([(ep+100)/(ep+100+1)])).to(self.dev) * self.min_rho

            # report whether the episode ended with the GOAL reward, together
            # with self.rpg.first_return
            logger.after_episode(ep, state.reward == Reward.GOAL, self.rpg.first_return)

        logger.after_training()


def main():
    """Share the selected config with the helper modules, then train PG-MCTS.

    Reads the module-level ``conf`` (assigned under the ``__main__`` guard),
    copies it into every project module that keeps its own ``conf`` global,
    installs the module-level ``logger`` and runs ``PGMCTS.train()``.

    The agent is placed on ``cuda:0`` when CUDA is available and on the CPU
    otherwise, so the script no longer fails on hosts without a GPU.
    """
    # re-initialize config: each of these modules reads its own `conf` global
    for module in (
        env_tmaze,
        rpg_tmaze,
        lazymcts_tmaze,
        lazymcts_tmaze.mcts_tmaze,
        logging_tmaze,
    ):
        module.conf = conf
    # set logger
    global logger
    logger = logging_tmaze.Logger(alg_name='PG-MCTS')
    # pick the device: keep the original cuda:0 when present, else fall back to CPU
    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # run
    agent = PGMCTS(dev=dev)
    agent.train()


if __name__ == "__main__":
    # Script entry point: the -c/--config_name option names an attribute of
    # conf_tmaze; that attribute becomes the module-level `conf` used by
    # main() and PGMCTS.
    cli = argparse.ArgumentParser()
    cli.add_argument("-c", "--config_name", type=str, default="pgmcts_conf_demo")
    conf = getattr(conf_tmaze, cli.parse_args().config_name)
    main()
