from __future__ import annotations

import argparse
import dataclasses
import random
from collections import OrderedDict
from typing import List

import numpy as np
from logzero import logger as log

import conf_tmaze
import env_tmaze
import logging_tmaze
from conf_tmaze import MCTSIndex, MCTSMode, MCTSPolicy
from env_tmaze import Action, Observation, TMazeEnv, TMazeState

# Module-level logger handle, None until assigned. No active code in this file
# assigns it; the commented-out main() below shows it being replaced via
# `global logger`.
# NOTE(review): the functions in this module also read a global `conf`
# (conf.BETA, conf.C, conf.MCTS_INDEX, conf.GAMMA) that is not assigned at
# module level here (only inside the commented-out `__main__` block), so it
# must be injected before use — TODO confirm where callers set it.
logger = None


@dataclasses.dataclass
class NodeStats:
    """Visit statistics of a (possibly not yet expanded) tree node.

    ``n`` starts at 1.0 rather than 0. For nodes updated by
    ``MCTS.backpropagate`` it grows by one per update, so it stays one above
    the number of real visits. ``q`` is the running mean of the returns
    backed up into the node.
    """

    n: float = dataclasses.field(default=1.0)  # real visit count + 1
    q: float = dataclasses.field(default=0.0)  # mean backed-up return


def softmax(x: np.ndarray) -> np.ndarray:
    """Return softmax(conf.BETA * x) as a probability vector.

    The maximum of ``x`` is subtracted before exponentiating so that large
    indices cannot overflow ``np.exp``.
    """
    logits = conf.BETA * (x - np.max(x))
    weights = np.exp(logits)
    return weights / np.sum(weights)


def ucb(c: MCTSNode, log_sum_n: float) -> float:
    """UCB index of a child (Eq. 16): q + C * sqrt(log(sum_n) / n).

    Args:
        c: Child statistics providing ``q`` and ``n``.
        log_sum_n: Natural log of the summed ``n`` of all siblings.
    """
    exploration_bonus = conf.C * np.sqrt(log_sum_n / c.n)
    return c.q + exploration_bonus


def ucb_alphazero(c: NodeStats, p: float, sqrt_sum_n: float) -> float:
    """AlphaZero-style PUCT index of a child: q + C * p * sqrt(N) / (1 + visits).

    ``c.n`` carries the initial ``NodeStats.n`` pseudo-count, so
    ``c.n - NodeStats.n`` is the number of real visits. The denominator below
    therefore equals real visits + 1.0001.

    Args:
        c: Child statistics (a placeholder ``NodeStats`` or an expanded
            ``MCTSNode``) providing ``q`` and ``n``.
        p: Prior probability of the action that leads to ``c``.
        sqrt_sum_n: Square root of the siblings' summed real visit counts,
            as computed in ``MCTSNode.comp_children_indices``.
    """
    return c.q + conf.C * p * sqrt_sum_n / (c.n + 1.0001 - NodeStats.n)


class MCTSNode(NodeStats):
    """A node of the MCTS search tree.

    The node's own statistics (``n``, ``q``) come from ``NodeStats``.
    ``children`` maps every ``Action`` either to a placeholder ``NodeStats``
    (action not expanded yet) or to the expanded child ``MCTSNode``.
    """

    def __init__(
        self,
        state: TMazeState | None = None,
        parent: MCTSNode | None = None,
        caused_by: Action | None = None,
    ):
        """Create a tree node.

        Args:
            state: Environment state at this node. When omitted, a fresh
                ``TMazeState()`` is built for this node instead of sharing one
                instance created at import time.
            parent: Parent node, or None for the root.
            caused_by: Action taken in ``parent`` that led here (None for the root).
        """
        # Put n / q on the instance rather than relying on the class attributes.
        super().__init__()
        self.state = TMazeState() if state is None else state
        self.parent = parent
        self.caused_by: Action = caused_by
        # Every action starts with a placeholder NodeStats so that selection
        # indices can be computed before all children have been expanded.
        self.children: OrderedDict[Action, NodeStats] = OrderedDict((a, NodeStats()) for a in Action)
        self.untried_actions: List[Action] = list(Action)
        # Uniform distribution over the still-untried actions (LazyMCTS).
        self.untried_prob: np.ndarray = np.full(len(Action), 1 / len(Action))

    def pop_untried_action(self) -> Action:
        """Remove and return a uniformly random untried action.

        Unlike ``pop_action`` this does not refresh ``untried_prob``. Callers
        should check ``fully_expanded()`` first, because there is nothing to pop
        once every action has been tried.
        """
        aidx = np.random.randint(len(self.untried_actions))
        return self.untried_actions.pop(aidx)

    def fully_expanded(self) -> bool:
        """Return True once every action has been tried from this node."""
        return len(self.untried_actions) == 0

    def comp_children_indices(self, prior: np.ndarray = None) -> list:
        """Return the selection index of every child, in ``Action`` order.

        ``conf.MCTS_INDEX`` selects the formula: UCB (Eq. 16) or the
        AlphaZero-style PUCT index.

        Args:
            prior: Per-action prior probabilities in ``Action`` order. Required
                for ``MCTSIndex.AlphaZero``; ignored for UCB.

        Raises:
            ValueError: if ``prior`` is missing for the AlphaZero index, or
                ``conf.MCTS_INDEX`` is not a supported index.
        """
        stats = list(self.children.values())
        sum_n = sum(c.n for c in stats)
        if conf.MCTS_INDEX == MCTSIndex.UCB:
            log_sum_n = np.log(sum_n)
            return [ucb(c, log_sum_n) for c in stats]
        if conf.MCTS_INDEX == MCTSIndex.AlphaZero:
            if prior is None:
                raise ValueError("the AlphaZero index requires a prior over actions")
            # Remove the per-child pseudo-counts so only real visits are summed.
            sqrt_sum_n = np.sqrt(sum_n - len(Action) * NodeStats.n + 0.0001)
            return [ucb_alphazero(c, prior[i], sqrt_sum_n) for i, c in enumerate(stats)]
        # Fail loudly instead of returning None to callers that argmax over it.
        raise ValueError(f"unsupported MCTS index: {conf.MCTS_INDEX}")

    def best_child(self, prior: np.ndarray = None) -> MCTSNode:
        """Return the child with the highest selection index (first one on ties).

        Only guaranteed to be an ``MCTSNode`` when the node is fully expanded;
        otherwise it may be a placeholder ``NodeStats``.
        """
        indices = self.comp_children_indices(prior)
        return list(self.children.values())[np.argmax(indices)]

    def pop_action(self, action: Action) -> Action:
        """Remove ``action`` from the untried actions and return it (for LazyMCTS).

        Also resets ``untried_prob`` to a uniform distribution over the actions
        that remain untried, or to None once none remain.

        Raises:
            ValueError: if ``action`` is not currently untried.
        """
        action = self.untried_actions.pop(self.untried_actions.index(action))
        if self.fully_expanded():
            self.untried_prob = None
        else:
            prob = 1 / len(self.untried_actions)
            self.untried_prob = np.array([prob if a in self.untried_actions else 0 for a in Action])
        return action

    def greedy_children(self, prior: np.ndarray = None) -> np.ndarray:
        """Return a one-hot vector (``Action`` order) on the best child (for LazyMCTS)."""
        indices = self.comp_children_indices(prior)
        prob = np.zeros(len(Action))
        prob[np.argmax(indices)] = 1.0
        return prob

    def softmax_children(self, prior: np.ndarray = None) -> np.ndarray:
        """Return softmax probabilities over the children's indices (for LazyMCTS)."""
        indices = self.comp_children_indices(prior)
        return softmax(indices)

    def untried(self, action: Action) -> bool:
        """Return True if ``action`` has not been tried from this node (for LazyMCTS)."""
        return action in self.untried_actions

    def find_child(self, action: Action) -> NodeStats | None:
        """Return the ``children`` entry for ``action`` (for LazyMCTS).

        Before the action is expanded this is the placeholder ``NodeStats``
        created in ``__init__``. The result is None if ``action`` is not a key
        of ``children``.
        """
        return self.children.get(action)


class MCTS:
    """Plain MCTS over ``TMazeEnv``: select, expand, rollout, backpropagate.

    ``path`` holds the tree nodes visited in the current simulation (the root
    itself is not included). ``rewards`` holds the rewards of those nodes'
    states, followed by the rewards seen during the rollout. ``returns`` holds
    the discounted returns computed from ``rewards`` by ``comp_return``.
    """

    def __init__(self, init_position: int, init_obs: Observation):
        state: TMazeState = TMazeState(position=init_position, obs=init_obs)
        self.root = MCTSNode(state=state)
        self.rewards = []
        self.path = []
        self.returns = []

    def init_histories(self):
        """Clear the per-simulation path and rewards (``returns`` is kept)."""
        self.rewards = []
        self.path = []

    @property
    def first_return(self) -> float:
        """Discounted return from the first step of the last simulation.

        Raises IndexError if no returns have been computed yet.
        """
        return self.returns[0]

    def add_histories(self, node: MCTSNode):
        """Append ``node`` to the path and the reward of its state to ``rewards``."""
        self.rewards.append(node.state.reward.value)
        self.path.append(node)

    def select(self, env: TMazeEnv):
        """Descend from the root via ``best_child`` until a node can be expanded.

        Returns the newly expanded child of the first node that still has
        untried actions. If every node on the way is fully expanded, returns
        the node where the episode ends.
        """
        target: MCTSNode = self.root
        while not target.state.episode_end:
            if not target.fully_expanded():
                return self.expand(env, target)
            target = target.best_child()
            self.add_histories(target)
        return target

    def expand(self, env: TMazeEnv, target: MCTSNode) -> MCTSNode:
        """Try one random untried action of ``target`` and attach the resulting child."""
        action: Action = target.pop_untried_action()
        state: TMazeState = env.step(target.state, action)
        new_child = MCTSNode(state=state, parent=target, caused_by=action)
        # Replace the placeholder NodeStats for this action with the real node.
        target.children[action] = new_child
        self.add_histories(new_child)
        return new_child

    def rollout(self, env: TMazeEnv, target: MCTSNode):
        """Play uniformly random actions from ``target`` until the episode ends.

        Only rewards are recorded; no tree nodes are created.
        """
        state: TMazeState = target.state
        possible_acts = list(Action)  # loop-invariant
        while not state.episode_end:
            action = possible_acts[np.random.randint(len(possible_acts))]
            state = env.step(state, action)
            self.rewards.append(state.reward.value)

    def comp_return(self, gamma: float | None = None):
        """Fill ``returns`` with G_t = r_t + gamma * G_{t+1} for every recorded reward.

        Args:
            gamma: Discount factor; defaults to ``conf.GAMMA``.
        """
        if gamma is None:
            gamma = conf.GAMMA
        returns = []
        new_r = 0
        for r in reversed(self.rewards):
            new_r = r + gamma * new_r
            returns.append(new_r)
        # Built back-to-front; one reverse is O(n), unlike repeated insert(0, ...).
        returns.reverse()
        self.returns = returns

    def backpropagate(self):
        """Update ``n`` and the running-mean ``q`` of every node on the path.

        Rollout rewards affect the returns, but only path nodes are updated.
        """
        self.comp_return()
        for node, q_target in zip(self.path, self.returns):
            # Incremental mean: n is one above the visit count before the update.
            node.q += (q_target - node.q) / node.n
            node.n += 1

#
# def main():
#     # re-initialize config
#     env_tmaze.conf = conf
#     logging_tmaze.conf = conf
#     # check config parameters
#     assert conf.MCTS_INDEX != MCTSIndex.UCB
#     assert conf.MCTS_MODE != MCTSMode.Standard
#     assert conf.MCTS_POLICY != MCTSPolicy.Greedy
#     # set logger
#     global logger
#     logger = logging_tmaze.Logger(alg_name='MCTS')
#     # simulate with mcts
#     init_obs: Observation = random.choice([Observation.START_UP, Observation.START_DOWN])
#     mcts = MCTS(init_position=conf.INITIAL_POSITION, init_obs=init_obs)
#     env = TMazeEnv(position=conf.INITIAL_POSITION, init_obs=init_obs)
#     for ep in range(args.num_episode):
#         mcts.init_histories()
#         target = mcts.select(env)
#         mcts.rollout(env, target)
#         mcts.backpropagate()
#         # show last 3 simulated paths
#         if ep > args.num_episode - 4:
#             log.info(f'{"=" * 5} {ep} {"=" * 5}')
#             for p in mcts.path:
#                 log.debug(f"\tN: {p.n}, Q: {p.q}, Pos: {p.state.position}")
#
#
# if __name__ == "__main__":
#     parser = argparse.ArgumentParser()
#     parser.add_argument("-c", "--config_name", type=str, default="mcts_conf_demo")
#     args = parser.parse_args()
#     conf = getattr(conf_tmaze, args.config_name)
#     main()
