from __future__ import annotations
from typing import TYPE_CHECKING

import os
import inspect
import logging
import logzero
import datetime
import json
import numpy as np

from conf_tmaze import Config as conf
from env_tmaze import Action
if TYPE_CHECKING:
    import torch
    from env_tmaze import TMazeEnv, TMazeState


def stack_info(stack_level: int = 3) -> dict:
    """Describe the call-stack frame ``stack_level`` levels above this function.

    Level 0 is ``stack_info`` itself, level 1 its direct caller, and so on.
    The returned mapping is suitable for the ``extra=`` argument of a logging
    call whose format string refers to these keys.

    Args:
        stack_level: Index into the current call stack of the frame to report.

    Returns:
        A dict with the keys ``stack_filename`` (file path of the frame),
        ``stack_module`` (file base name without extension), ``stack_lineno``
        and ``stack_function``.

    Raises:
        IndexError: If the call stack has no frame at ``stack_level``.
    """
    # context=0 skips loading source lines for every frame on the stack;
    # only filename / lineno / function are needed here.
    frame = inspect.stack(0)[stack_level]
    module_name, _ = os.path.splitext(os.path.basename(frame.filename))
    return {
        'stack_filename': frame.filename,
        'stack_module': module_name,
        'stack_lineno': frame.lineno,
        'stack_function': frame.function,
    }


def to_json(item):
    """Fallback serializer intended for the ``default=`` hook of ``json``.

    Instances that carry a ``__dict__`` are represented by that attribute
    dictionary. Everything else, including classes themselves, is represented
    by its ``str()`` form.
    """
    is_plain_instance = not isinstance(item, type) and hasattr(item, '__dict__')
    return item.__dict__ if is_plain_instance else str(item)


class Logger:
    """Progress reporter for a T-Maze training run.

    Holds two ``logzero`` loggers: ``default_logger`` (formatted with
    ``conf.LOG_DEFAULT_FORMAT``) and ``simple_logger`` (formatted with
    ``conf.LOG_SIMPLE_FORMAT``). It also accumulates success/return statistics
    over windows of ``conf.LOG_INTERVAL`` episodes and can write the collected
    results and the configuration to files at the end of training.
    """

    # Single layout for every four-way action distribution written to the
    # debug log; the first placeholder is the label of the distribution.
    _POLICY_FORMAT = "%s: up=%.03f, right=%.03f, down=%.03f, left=%0.3f"

    def __init__(self, alg_name: str):
        """Create both loggers, emit the start-up banner and zero the counters.

        Args:
            alg_name: Name of the algorithm, used in log messages and in the
                saved config file.
        """
        self.alg_name: str = alg_name
        self.level = conf.LOG_LEVEL
        self.default_logger: logging.Logger = logzero.setup_logger(
            name='default',
            level=conf.LOG_LEVEL,
            formatter=logging.Formatter(conf.LOG_DEFAULT_FORMAT, datefmt='%Y/%m/%d %H:%M:%S'),
        )
        self.simple_logger: logging.Logger = logzero.setup_logger(
            name='simple',
            level=conf.LOG_LEVEL,
            formatter=logging.Formatter(conf.LOG_SIMPLE_FORMAT),
        )
        # Counters of the current logging window (cleared by _reset()).
        self.num_total: int = 0
        self.num_success: float = 0.0
        self.sum_return: float = 0.0
        # One entry per completed logging window.
        self.episodes: list = []
        self.success_rates: list = []
        self.average_returns: list = []
        self._at_first()

    def _reset(self):
        """Clear the counters of the current logging window."""
        self.num_total = 0
        self.num_success = 0.0
        self.sum_return = 0.0

    def _at_first(self):
        """Log the start-up banner and the configuration, then reset counters."""
        # Depth 3 = stack_info -> _at_first -> __init__ -> code that built the Logger.
        self.default_logger.info(f"Start T-Maze {self.alg_name} simulation", extra=stack_info(3))
        self.simple_logger.info(f"Len: {conf.CORRIDOR_LENGTH}, Ini_posi: {conf.INITIAL_POSITION}, Iter: {conf.NUM_EPISODE}")
        self.simple_logger.info(json.dumps({"config": conf}, default=to_json, indent=6))
        self._reset()

    def during_episode(
        self, ep: int, env: TMazeEnv, state: TMazeState, action: Action,
        policy: torch.Tensor | None, etc: dict | None = None,
        freq: float = .001, freq_at_episode_end: float = .01
    ):
        """Write a randomly sampled debug trace of the current step.

        Nothing is logged unless ``conf.LOG_LEVEL`` is ``logging.DEBUG`` or
        lower. Otherwise one uniform random number is drawn and the step is
        logged when it falls below ``freq`` (or ``freq_at_episode_end`` when
        ``state.episode_end`` is truthy).

        Args:
            ep: Episode index, used in the log message.
            env: Environment; its ``init_obs`` is logged.
            state: Current state; ``episode_end``, ``timestep``, ``position``
                and ``reward`` are read.
            action: Action selected at this step.
            policy: Distribution over the four actions, or ``None`` to skip it.
                Must provide ``tolist()`` yielding four numbers.
            etc: Optional extras that are only logged at episode end. The keys
                read are ``'rpg_policy'`` and ``'mcts_policy'`` (same shape as
                ``policy``) and ``'rpg'`` (an object with ``rewards`` and
                ``actions`` sequences).
            freq: Sampling probability for steps inside an episode.
            freq_at_episode_end: Sampling probability for the final step.
        """
        # A fresh dict per call instead of a shared mutable default argument.
        if etc is None:
            etc = {}

        def log_policy(label, probs):
            self.simple_logger.debug(self._POLICY_FORMAT % (label, *probs.tolist()))

        def basic_log():
            # Depth 3 = stack_info -> basic_log -> during_episode -> caller.
            self.default_logger.debug(f"During episode {ep}", extra=stack_info(3))
            self.simple_logger.debug(f"episode end: {state.episode_end}")
            self.simple_logger.debug(env.init_obs)
            self.simple_logger.debug(f"time: {state.timestep}, position: {state.position}")
            self.simple_logger.debug(f"selected action: {action}")
            if policy is not None:
                log_policy("policy", policy)

        if conf.LOG_LEVEL > logging.DEBUG:
            return

        at_episode_end = bool(state.episode_end)
        threshold = freq_at_episode_end if at_episode_end else freq
        # Exactly one random draw per call, whichever threshold applies.
        if np.random.uniform(0, 1) < threshold:
            basic_log()
            if at_episode_end:
                for key in ('rpg_policy', 'mcts_policy'):
                    if key in etc:
                        log_policy(key, etc[key])
                if 'rpg' in etc:
                    rpg = etc['rpg']
                    self.simple_logger.debug(state.reward)
                    self.simple_logger.debug(
                        f"T: {len(rpg.rewards)}, "
                        f"left: {rpg.actions.count(Action.LEFT)}, "
                        f"right: {rpg.actions.count(Action.RIGHT)}, "
                        f"up: {rpg.actions.count(Action.UP)}, "
                        f"down: {rpg.actions.count(Action.DOWN)}"
                    )

    def after_episode(self, ep: int, is_success: bool, initial_state_return: float):
        """Accumulate one episode and report once a logging window is full.

        After ``conf.LOG_INTERVAL`` calls the success rate and the average
        return of the window are stored, logged, and the counters are reset.

        Args:
            ep: Episode index; ``ep + 1`` is stored as the window's episode count.
            is_success: Whether the episode counted as a success.
            initial_state_return: Return of the episode's initial state.
        """
        self.num_total += 1
        self.num_success += int(is_success)
        self.sum_return += initial_state_return
        if self.num_total != conf.LOG_INTERVAL:
            return

        success_rate = self.num_success / conf.LOG_INTERVAL
        average_return = self.sum_return / conf.LOG_INTERVAL
        self.episodes.append(ep + 1)
        self.success_rates.append(success_rate)
        self.average_returns.append(average_return)
        self.default_logger.info(
            f"Episode {ep} ~ success rate: {success_rate:.02f}, average return: {average_return:.03f}",
            # Depth 2 = stack_info -> after_episode -> caller.
            extra=stack_info(2)
        )
        self._reset()

    def after_training(self):
        """Print the per-window result table and optionally save it to disk.

        When ``conf.LOG_OUTPUT_FILES`` is truthy the table is written as CSV
        and the algorithm name plus configuration as JSON. Both file names are
        built from ``conf.LOG_RESULT_FILENAME`` / ``conf.LOG_CONFIG_FILENAME``
        formatted with the current time, the corridor length and the initial
        position.
        """
        # One row per logging window: (episode, success rate, average return).
        result = np.transpose(np.array([self.episodes, self.success_rates, self.average_returns]))

        # NOTE(review): depth 3 here is stack_info -> after_training -> caller
        # -> caller's caller, whereas after_episode reports its direct caller
        # (depth 2). It also raises IndexError when the stack has fewer than
        # four frames. Confirm this depth is intended.
        self.default_logger.info(f"Finished {conf.NUM_EPISODE} episodes of training !!!", extra=stack_info(3))
        self.simple_logger.info(" episode \t success rate \t average return")
        for row in result:
            self.simple_logger.info(' %07d \t     %.02f     \t     %.03f' % tuple(row))

        if conf.LOG_OUTPUT_FILES:
            now = datetime.datetime.now()
            result_filename = conf.LOG_RESULT_FILENAME.format(now, conf.CORRIDOR_LENGTH, conf.INITIAL_POSITION)
            config_filename = conf.LOG_CONFIG_FILENAME.format(now, conf.CORRIDOR_LENGTH, conf.INITIAL_POSITION)
            np.savetxt(
                result_filename, result, delimiter=",",
                fmt=["%.0f", "%.2f", "%.3f"],
                header="episode,success rate,average return", comments=""
            )
            with open(config_filename, mode='w', encoding='utf-8') as f:
                json.dump([{'algo': self.alg_name}, {'config': conf}], f, default=to_json, indent=2)

            self.simple_logger.info(
                f'saved the result and config as \"{result_filename}\" and \"{config_filename}\"')

        # Normal termination marker.
        self.simple_logger.info('end of training ...',  extra=stack_info(3))
