import argparse
import random

import numpy as np
from logzero import logger as log

import conf_tmaze
import env_tmaze
import mcts_tmaze
import logging_tmaze
from conf_tmaze import MCTSMode, MCTSPolicy
from env_tmaze import Action, Observation, Reward, TMazeEnv, TMazeState
from mcts_tmaze import MCTS, MCTSNode, NodeStats

# Module-level logging_tmaze.Logger instance; None until main() creates it.
logger = None


class LazyMCTS:
    """Step-wise ("lazy") MCTS driver for the T-Maze.

    Instead of running a whole simulation internally, this wrapper lets the
    caller drive the search one environment step at a time:

    1. ``init_histories()`` at the start of an episode,
    2. ``policy()`` to obtain a probability vector over ``Action``,
    3. sample an action and feed it back through ``step_action()``,
    4. repeat until ``last_state.episode_end``, then call ``backpropagate()``.

    While ``expanded`` is False the walk descends through / extends the tree.
    After a node has been expanded (or, in Lipschitz mode, an expansion was
    declined) the remaining steps of the episode are rollout steps that only
    record rewards, and ``policy()`` becomes uniform over ``Action``.

    Settings are read from the module-level ``conf`` object.
    """

    def __init__(self, init_position: int, init_obs: Observation):
        self.mcts = MCTS(init_position, init_obs)
        # State reached by the most recent step (tree step or rollout step).
        self.last_state = None
        # Whether the current episode has left the tree; reset by init_histories().
        self.expanded = None

    @property
    def last_node(self) -> MCTSNode:
        """Most recently added node on the current path, or the root if the path is empty."""
        try:
            return self.mcts.path[-1]
        except IndexError:
            return self.mcts.root

    def init_histories(self):
        """Reset per-episode bookkeeping so the next walk starts from the root."""
        self.mcts.init_histories()
        self.last_state = self.mcts.root.state
        self.expanded = False

    @property
    def first_return(self) -> float:
        """Delegate to ``self.mcts.first_return``."""
        return self.mcts.first_return

    def add_histories(self, node: MCTSNode):
        """Delegate to ``self.mcts.add_histories(node)``."""
        self.mcts.add_histories(node)

    def policy(self, prior: np.ndarray = None) -> np.ndarray:
        """Return the action distribution for the next step.

        Args:
            prior: optional prior forwarded to ``virtual_select`` while still
                inside the tree.

        Returns:
            The selection/expansion distribution from ``virtual_select`` while
            in the tree (may be None, see that method), otherwise a uniform
            distribution over all ``Action`` members for the rollout.
        """
        if not self.expanded:
            # selection or expansion prob
            return self.virtual_select(prior)
        # rollout prob is uniform over all actions
        all_acts = list(Action)
        return np.array([1 / len(all_acts)] * len(all_acts))

    def virtual_select(self, prior: np.ndarray = None) -> np.ndarray:
        """Return the in-tree action distribution at ``last_node`` without moving.

        - Node not fully expanded: ``target.untried_prob`` (uniform over the
          untried actions).
        - Fully expanded and ``conf.MCTS_POLICY`` is Greedy: one-hot vector
          from ``greedy_children(prior)``.
        - Fully expanded and ``conf.MCTS_POLICY`` is Softmax:
          ``softmax_children(prior)``.

        Returns None when the node's state ends the episode or when
        ``conf.MCTS_POLICY`` matches neither branch.
        """
        target: MCTSNode = self.last_node
        if not target.state.episode_end:
            if not target.fully_expanded():
                return target.untried_prob
            if conf.MCTS_POLICY == MCTSPolicy.Greedy:
                return target.greedy_children(prior)
            elif conf.MCTS_POLICY == MCTSPolicy.Softmax:
                # prob vector by softmax(conf.BETA * indices)
                return target.softmax_children(prior)
        return None

    def step_action(self, env: TMazeEnv, action: Action):
        """Apply ``action``: an in-tree step before expansion, a rollout step after."""
        if not self.expanded:
            self.select_action(env, action)
        else:
            self.rollout_action(env, action)

    def select_action(self, env: TMazeEnv, action: Action):
        """Take one in-tree step from ``last_node`` with ``action``.

        If ``action`` is still untried at that node, the node is expanded
        (one child, or all children when ``conf.EXPANSION_ALL_AT_ONCE``). In
        Lipschitz mode, a non-root node is only expanded with probability
        ``min(1, target.n - NodeStats.n)``. Otherwise this step becomes a rollout
        step and no further expansion happens this episode. After expansion
        (or if the action was already tried) the corresponding child is
        appended to the path and becomes the new ``last_state``.

        Raises:
            NotImplementedError: if ``conf.STATE_ACTION`` is set (state-action
                nodes for stochastic transitions are not implemented).
        """
        target: MCTSNode = self.last_node
        if target.state.episode_end:
            return
        if target.untried(action):
            # Short-circuit order matters: the uniform draw is only consumed in
            # Lipschitz mode at a non-root node.
            # NOTE: alternative acceptance probabilities previously tried:
            #   min(1, sum(c.n for c in target.parent.children.values()) - len(Action) * NodeStats.n)
            #   min(1, target.parent.n - 1)
            expand_now = (
                conf.MCTS_MODE != MCTSMode.Lipschitz
                or target.parent is None
                or np.random.uniform() <= np.minimum(1, target.n - NodeStats.n)
            )
            if not expand_now:
                self.rollout_action(env, action)
                # no more expansion in this episode
                self.expanded = True
                return
            if conf.EXPANSION_ALL_AT_ONCE:
                self.expand_all_at_once(env, target)
            else:
                self.expand_action(env, target, action)
        if conf.STATE_ACTION:
            # TODO: handle stochastic state transitions, which require
            # state-action nodes. Fail explicitly instead of continuing
            # without a child node.
            raise NotImplementedError(
                "LazyMCTS.select_action does not support conf.STATE_ACTION yet"
            )
        child = target.find_child(action)
        assert child.state.timestep == target.state.timestep + 1
        self.add_histories(child)
        self.last_state = child.state

    def expand_action(self, env: TMazeEnv, target: MCTSNode, action: Action):
        """Create the single child of ``target`` reached by ``action`` and mark the episode as expanded."""
        state: TMazeState = target.state
        action = target.pop_action(action)
        new_state = env.step(state, action)
        target.children[action] = MCTSNode(state=new_state, parent=target, caused_by=action)
        self.expanded = True

    def expand_all_at_once(self, env: TMazeEnv, target: MCTSNode):
        """Create children of ``target`` for every untried action, then mark the episode as expanded."""
        state: TMazeState = target.state
        while not target.fully_expanded():
            action = target.pop_untried_action()
            new_state = env.step(state, action)
            target.children[action] = MCTSNode(state=new_state, parent=target, caused_by=action)
        self.expanded = True

    def rollout_action(self, env: TMazeEnv, action: Action):
        """Advance ``last_state`` by ``action`` outside the tree, recording only the reward."""
        state: TMazeState = self.last_state
        if not state.episode_end:
            state = env.step(state, action)
            # append reward only (no tree node is created)
            self.mcts.rewards.append(state.reward.value)
            self.last_state = state

    def backpropagate(self, num_episode: int):
        """Update the statistics of the nodes on this episode's path.

        In Standard mode this delegates to ``self.mcts.backpropagate()``.

        In Lipschitz mode returns are computed via ``comp_return()``. Then, for
        each (node, return) pair, ``y = 1 / node.n`` and the step size is
        ``kappa = min(conf.M / (num_episode + 1), y)``. ``node.q`` moves
        toward the return by ``kappa``, and ``node.n`` is set to
        ``1 / (y - y / (y + 1) * kappa)``.

        Args:
            num_episode: index of the current episode (``main`` passes the
                0-based loop counter).
        """
        if conf.MCTS_MODE == MCTSMode.Standard:
            self.mcts.backpropagate()
        elif conf.MCTS_MODE == MCTSMode.Lipschitz:
            self.mcts.comp_return()
            for node, q_target in zip(self.mcts.path, self.mcts.returns):
                y = 1 / node.n
                kappa = np.minimum(conf.M / (num_episode + 1), y)
                node.q += (q_target - node.q) * kappa
                y -= y / (y + 1) * kappa
                node.n = 1 / y


def main():
    """Run LazyMCTS on the T-Maze for ``conf.NUM_EPISODE`` episodes.

    One LazyMCTS tree is kept per starting observation. Each episode:
    1. pick a starting observation uniformly at random;
    2. walk the episode by sampling actions from that tree's ``policy()``;
    3. backpropagate and log the outcome.

    Reads the module-level ``conf`` (set in the ``__main__`` block) and
    stores the created logger in the module-level ``logger``.
    """
    log.info("Start T-Maze LazyMCTS simulation")
    log.info(f"Len: {conf.CORRIDOR_LENGTH}, Ini_posi: {conf.INITIAL_POSITION}, Iter: {conf.NUM_EPISODE}")
    log.info(f"Config: {conf}")
    # point the helper modules' module-level `conf` at the selected config
    env_tmaze.conf = conf
    mcts_tmaze.conf = conf
    logging_tmaze.conf = conf
    global logger
    logger = logging_tmaze.Logger(alg_name='LazyMCTS')

    start_observations = (Observation.START_UP, Observation.START_DOWN)
    # the action space is fixed; build the sampling population once
    actions = list(Action)
    # one independent search tree per starting observation
    mcts_roots = {
        obs: LazyMCTS(init_position=conf.INITIAL_POSITION, init_obs=obs)
        for obs in start_observations
    }
    for ep in range(conf.NUM_EPISODE):
        init_obs: Observation = random.choice(start_observations)
        env = TMazeEnv(init_obs=init_obs)
        mcts = mcts_roots[init_obs]
        mcts.init_histories()
        while not mcts.last_state.episode_end:
            policy = mcts.policy()
            action = np.random.choice(actions, p=policy)
            mcts.step_action(env, action)
            # logging (only works if conf.LOG_LEVEL <= logging.DEBUG)
            logger.during_episode(ep, env, mcts.last_state, action, policy=None,
                                  etc={},
                                  freq=0.0003, freq_at_episode_end=0.03)

        # update tree
        mcts.backpropagate(ep)

        # count goaled state for last conf.LOG_INTERVAL training iteration
        logger.after_episode(ep, mcts.last_state.reward == Reward.GOAL, mcts.first_return)

        # debug aid: show the last 5 simulated paths
        # if ep > conf.NUM_EPISODE - 6:
        #    log.info(f'{"=" * 5} {ep}: {init_obs.get_action()} {"=" * 5}')
        #    for p in mcts.mcts.path:
        #        log.debug(f"\tN: {p.n}, Q: {p.q}, Pos: {p.state.position}")

    logger.after_training()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run LazyMCTS on the T-Maze environment.")
    parser.add_argument("-c", "--config_name", type=str, default="lazymcts_conf_demo",
                        help="name of a config object defined in conf_tmaze")
    args = parser.parse_args()
    # `conf` is a module-level global read by LazyMCTS and main()
    try:
        conf = getattr(conf_tmaze, args.config_name)
    except AttributeError:
        # report a bad name as a usage error instead of a raw traceback
        parser.error(f"unknown config name {args.config_name!r} (not found in conf_tmaze)")
    main()
