import argparse
import random
from typing import Tuple

import torch
from logzero import logger as log
from torch.distributions import Categorical

import conf_tmaze
import env_tmaze
import lazymcts_tmaze
import rpg_tmaze
import logging_tmaze
from env_tmaze import Action, Observation, Reward, TMazeEnv, TMazeState
from lazymcts_tmaze import LazyMCTS
from rpg_tmaze import RPG

# Module-level logger used by LazyAlphaZero.train(); it stays None until main()
# rebinds it (via ``global logger``) to a logging_tmaze.Logger instance.
logger = None


class LazyAlphaZero:
    """Trainer that couples an ``RPG`` policy with ``LazyMCTS`` search on the T-maze.

    One ``LazyMCTS`` root is kept per start observation (``START_UP`` and
    ``START_DOWN``) and is reused by every episode that begins with that
    observation.
    """

    def __init__(self, dev: torch.device):
        """Create the RPG model and the two search trees.

        Args:
            dev: device passed to ``RPG`` and used for tensors built in ``train``.
        """
        # device on which observation / policy tensors are placed
        self.dev = dev
        # rpg settings
        self.rpg = RPG(dev=dev)
        # mcts settings: one root per possible start observation
        self.mcts_roots = {
            start_obs: LazyMCTS(init_position=conf.INITIAL_POSITION, init_obs=start_obs)
            for start_obs in (Observation.START_UP, Observation.START_DOWN)
        }

    def train(self):
        """Run ``conf.NUM_EPISODE`` training episodes.

        Every episode draws a random start observation, builds a ``TMazeEnv``
        for it, resets the histories of the RPG model and of the matching
        tree, and then steps until ``state.episode_end`` is set. Afterwards
        ``rpg.policy_gradient`` and ``mcts.backpropagate`` are called and the
        episode outcome is reported to the module-level ``logger``.
        """
        start_observations = [Observation.START_UP, Observation.START_DOWN]

        for episode in range(conf.NUM_EPISODE):
            # draw the start observation and build the environment for it
            first_obs: Observation = random.choice(start_observations)
            maze = TMazeEnv(init_obs=first_obs)

            # reset per-episode histories of the RPG model and of the chosen tree
            self.rpg.init_histories()
            tree = self.mcts_roots[first_obs]
            tree.init_histories()

            # progression state of the maze plus the state threaded through rpg.policy
            current: TMazeState = TMazeState(position=conf.INITIAL_POSITION, obs=first_obs)
            recurrent: Tuple[torch.Tensor, torch.Tensor] = None

            while not current.episode_end:
                # query the RPG model for the current observation
                obs_tensor = current.obs.to_tensor().to(self.dev)
                raw_probs, recurrent, baseline, _ = self.rpg.policy(obs_tensor, recurrent)
                net_probs = torch.nan_to_num(raw_probs)

                if not tree.last_node.fully_expanded():
                    # node not fully expanded: blend the RPG output with a
                    # prior-free tree policy, weighted by conf.LAMBDA, and
                    # flag the step as is_update=False
                    tree_probs = torch.tensor(tree.policy(prior=None)).to(self.dev)
                    behaviour = (1 - conf.LAMBDA) * net_probs + conf.LAMBDA * tree_probs
                    train_rpg = False
                else:
                    # node fully expanded: the tree policy is computed with the
                    # RPG output as prior and the step is flagged is_update=True
                    prior = net_probs[0].detach().cpu().numpy()
                    behaviour = torch.tensor(tree.policy(prior=prior)).to(self.dev)
                    train_rpg = True

                # sample an action and execute it by stepping the tree
                sampled = Categorical(behaviour).sample()
                action = Action(sampled.item())
                tree.step_action(maze, action)

                # the tree exposes the resulting state (including the reward)
                current = tree.last_state

                # record the step in the RPG history; the probability is
                # detached whenever tree.expanded is truthy
                chosen_prob = net_probs[:, action.value]
                chosen_prob = chosen_prob.detach() if tree.expanded else chosen_prob
                self.rpg.add_history(
                    chosen_prob,
                    current.reward,
                    baseline=baseline,
                    is_update=train_rpg,
                    action=action,
                )

                # per-step logging (original author's note: only active when
                # conf.LOG_LEVEL <= logging.DEBUG)
                logger.during_episode(
                    episode, maze, current, action, net_probs[0],
                    etc={'rpg': self.rpg},
                    freq=0.0003, freq_at_episode_end=0.03,
                )

            # update RPG and MCTS params
            self.rpg.policy_gradient(use_rho=False, ml_update=True)
            tree.backpropagate(episode)

            # report whether the episode ended with the goal reward
            reached_goal = current.reward == Reward.GOAL
            logger.after_episode(episode, reached_goal, self.rpg.first_return)

        logger.after_training()


def main():
    """Install the selected configuration, create the logger and run training.

    Relies on the module-level ``conf`` having been assigned before the call;
    the ``__main__`` guard of this file does so from ``--config_name``.
    """
    # re-initialize config: hand the selected configuration to each helper
    # module through its ``conf`` attribute
    env_tmaze.conf = conf
    rpg_tmaze.conf = conf
    lazymcts_tmaze.conf = conf
    lazymcts_tmaze.mcts_tmaze.conf = conf
    logging_tmaze.conf = conf
    # set logger
    global logger
    logger = logging_tmaze.Logger(alg_name='alphaZero')
    # Prefer the first CUDA device, but fall back to the CPU when CUDA is not
    # available instead of failing on the hard-coded "cuda:0".
    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # run
    agent = LazyAlphaZero(dev=dev)
    agent.train()


if __name__ == "__main__":
    # Script entry point: look up the configuration object named on the
    # command line in conf_tmaze, bind it to the module-level ``conf`` that
    # main() and LazyAlphaZero read, then start training.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-c",
        "--config_name",
        type=str,
        default="alphaZero_conf_demo",
    )
    chosen_name = cli.parse_args().config_name
    conf = getattr(conf_tmaze, chosen_name)
    main()
