from collections.abc import Mapping
import argparse
import ruamel.yaml as yaml
import os
import time
import pathlib
import tools
import datetime
import collections
from parallel import Parallel, Damy

from IPython import embed as ipshell
import numpy as np
import torch
import torch.distributions as torchd

import envs.wrappers as wrappers

from fastai_utils import to_cpu, show_image, scale_img, count_steps, make_dataset, make_env, make_animation

def recursive_update(base, update):
    """Merge ``update`` into ``base`` in place, descending into nested dicts.

    When a value in ``update`` is a ``dict`` and its key already exists in
    ``base``, it is merged recursively into ``base[key]``. Every other value
    overwrites (or adds) ``base[key]``. Returns ``None``.
    """
    for key in update:
        incoming = update[key]
        should_merge = isinstance(incoming, dict) and key in base
        if not should_merge:
            base[key] = incoming
            continue
        recursive_update(base[key], incoming)


def get_config(path="~/workspace/fastrl", env_name="pinpad"):
    """Load the ``defaults`` + ``env_name`` sections of ``<path>/configs.yaml``.

    Sections are merged in that order with ``recursive_update``, so keys in
    ``env_name`` override ``defaults``. Each top-level key is then registered
    as an ``argparse`` option typed by ``tools.args_type``. The parser runs on
    an empty argument list, so every value is coerced through its type but
    ``sys.argv`` is never read.

    Args:
        path: Directory containing ``configs.yaml``; ``~`` is expanded.
        env_name: Config section layered on top of ``defaults``.

    Returns:
        argparse.Namespace holding the merged configuration.

    Raises:
        KeyError: if ``defaults`` or ``env_name`` is not a section of the file.
    """
    y = yaml.YAML(typ="safe", pure=True)
    config_path = pathlib.Path(os.path.expanduser(path)) / "configs.yaml"
    configs = y.load(config_path.read_text())

    defaults = {}
    for name in ("defaults", env_name):
        recursive_update(defaults, configs[name])

    parser = argparse.ArgumentParser()
    for key in sorted(defaults):
        value = defaults[key]
        arg_type = tools.args_type(value)
        parser.add_argument(f"--{key}", type=arg_type, default=arg_type(value))
    # Parse an explicit empty list: this helper is called from notebooks and
    # other scripts whose argv is unrelated (and may contain -h or unknown flags).
    config = parser.parse_args([])
    print(env_name, config.task)
    return config

def setup_paths_and_logging(config, logdir, delete_train_eps=True, notes=''):
    """Seed RNGs, resolve episode directories and create the run logger.

    Side effects on ``config`` (do not call twice with the same object):
      * ``traindir`` / ``evaldir`` / ``humandir`` default to
        ``logdir/{train,eval,human}_eps`` and are stored as ``pathlib.Path``;
      * ``steps``, ``eval_every``, ``log_every`` and ``time_limit`` are
        floor-divided by ``action_repeat`` (a second call divides again).

    Args:
        config: Run configuration namespace (mutated in place).
        logdir: Root log directory, ``str`` or ``pathlib.Path``.
        delete_train_eps: If true, recursively delete any existing train /
            human / eval episode directories before recreating them empty.
            This removes whatever those paths point at, including
            user-supplied directories from the config.
        notes: Free-form suffix appended to the timestamped run directory.

    Returns:
        ``(logger, full_logdir)`` where ``full_logdir`` is
        ``logdir/<MM-DD_HH-MM-SS_YYYY>_<task>_<notes>`` and ``logger`` is
        ``tools.Logger(full_logdir, 0)``.
    """
    import shutil

    tools.set_seed_everywhere(config.seed)
    if config.deterministic_run:
        tools.enable_deterministic_run()

    # Coerce to Path so plain strings work: YAML configs yield str paths, and
    # the code below relies on `/`, `.exists()` and `.mkdir()`.
    logdir = pathlib.Path(logdir)
    config.traindir = pathlib.Path(config.traindir or logdir / "train_eps")
    config.evaldir = pathlib.Path(config.evaldir or logdir / "eval_eps")
    config.humandir = pathlib.Path(config.humandir or logdir / "human_eps")
    config.steps //= config.action_repeat
    config.eval_every //= config.action_repeat
    config.log_every //= config.action_repeat
    config.time_limit //= config.action_repeat

    if delete_train_eps:
        # Start every run from empty episode directories.
        for episode_dir in (config.traindir, config.humandir, config.evaldir):
            if episode_dir.exists():
                shutil.rmtree(episode_dir)

    print("Logdir", logdir)
    logdir.mkdir(parents=True, exist_ok=True)
    config.traindir.mkdir(parents=True, exist_ok=True)
    config.evaldir.mkdir(parents=True, exist_ok=True)
    config.humandir.mkdir(parents=True, exist_ok=True)

    datetime_string = f"{datetime.datetime.now().strftime('%m-%d_%H-%M-%S_%Y')}_{config.task}_{notes}"
    full_logdir = logdir / datetime_string
    logger = tools.Logger(full_logdir, 0)
    return logger, full_logdir

class RandomAgent:
    """Agent that samples random actions, one row per environment.

    Action spaces exposing a ``discrete`` attribute use ``tools.OneHotDist``
    built from an all-zero ``(config.envs, config.num_actions)`` tensor.
    Otherwise an independent ``Uniform(acts.low, acts.high)`` per action
    dimension is used, batched over ``config.envs``.
    """

    def __init__(self, acts, config) -> None:
        n_envs = config.envs
        if not hasattr(acts, "discrete"):
            low = torch.Tensor(acts.low).repeat(n_envs, 1)
            high = torch.Tensor(acts.high).repeat(n_envs, 1)
            self.random_actor = torchd.independent.Independent(
                torchd.uniform.Uniform(low, high), 1
            )
            return
        zero_logits = torch.zeros(config.num_actions).repeat(n_envs, 1)
        self.random_actor = tools.OneHotDist(zero_logits)

    def __call__(self, o, d, s):
        """Sample an action batch; ``o``/``d``/``s`` are ignored and the state returned is None."""
        dist = self.random_actor
        sampled = dist.sample()
        return {"action": sampled, "logprob": dist.log_prob(sampled)}, None

class CCW_PP_Agent(): # go counter clockwise around the pinpad
    """Scripted agent that steps through a fixed cycle of one-hot actions.

    The chosen action index cycles through ``action_idx_order`` (1, 3, 2, 4),
    advancing each time the call counter reaches a multiple of 10. The
    log-prob of each scripted action is scored under the random policy of an
    internal ``RandomAgent``.
    """
    # Not read by this class; kept as a class attribute for external code.
    pp_actions = [(0, 0), (0, 1), (0, -1), (1, 0), (-1, 0)]

    def __init__(self, acts, config) -> None:
        self.action_idx_order = [1, 3, 2, 4]
        self.current_action_idx = 0
        self.action_count = 0
        self.config = config
        # Used only for its distribution, to compute log-probs of scripted actions.
        self.ra = RandomAgent(acts, config)

    def __call__(self, o, d, s):
        """Return the next scripted one-hot action for every env; inputs are ignored."""
        self.action_count += 1
        if self.action_count % 10 == 0:
            self.current_action_idx = (self.current_action_idx + 1) % len(self.action_idx_order)
        one_hot = torch.zeros(self.config.num_actions)
        one_hot[self.action_idx_order[self.current_action_idx]] = 1
        # One row per environment; previously a single row was returned even
        # when config.envs > 1. Shape is unchanged for a single env.
        action = one_hot.repeat(self.config.envs, 1)
        logprob = self.ra.random_actor.log_prob(action)
        return {"action": action, "logprob": logprob}, None


def get_datasets(config, logger, logdir):
    """Create environments, load or prefill episode caches, and build datasets.

    In order:
      1. Load at most one eval episode from ``config.evaldir``.
      2. Build ``config.envs`` train / eval / human envs with ``make_env`` and
         wrap each in ``Damy`` (parallel envs are not supported).
      3. Set ``config.num_actions`` from the train action space (mutates config).
      4. Load train and human episodes from disk (``limit=config.dataset_size``).
      5. If ``config.offline_traindir`` is falsy, prefill ``config.traindir`` by
         running a ``RandomAgent`` (or ``CCW_PP_Agent`` when
         ``config.use_ccw_random_agent`` is set) through ``tools.simulate`` for
         however many steps are still missing. Otherwise, replace ``train_eps``
         with the episodes loaded from every directory in
         ``config.offline_traindir``.
      6. Run ``tools.simulate`` on the human envs with ``steps=1``.
      7. Wrap the episode caches with ``make_dataset``. The human dataset uses
         a fixed batch size of 2 and batch length of 64.

    Args:
        config: Run configuration namespace (read and mutated).
        logger: Logger forwarded to ``tools.simulate``.
        logdir: Not used by this function.

    Returns:
        ``(train_dataset, eval_dataset, human_dataset, train_envs, eval_envs,
        human_envs, train_eps, eval_eps, human_eps)``.

    Raises:
        NotImplementedError: if ``config.offline_evaldir`` or ``config.parallel``
            is set.
    """
    if config.offline_evaldir: raise NotImplementedError

    directory = config.evaldir
    # Only a single eval episode is loaded (limit=1).
    eval_eps = tools.load_episodes(directory, limit=1)
    make = lambda mode: make_env(config, mode)
    train_envs = [make("train") for _ in range(config.envs)]
    eval_envs = [make("eval") for _ in range(config.envs)]
    human_envs = [make("human") for _ in range(config.envs)]

    if config.parallel: raise NotImplementedError

    train_envs = [Damy(env) for env in train_envs]
    eval_envs = [Damy(env) for env in eval_envs]
    # human_envs = [wrappers.HumanReward(env, config.human_reward) for env in human_envs]
    human_envs = [Damy(env) for env in human_envs]

    acts = train_envs[0].action_space
    # Mutates config: RandomAgent / CCW_PP_Agent read config.num_actions.
    config.num_actions = acts.n if hasattr(acts, "n") else acts.shape[0]

    train_eps = tools.load_episodes(config.traindir, limit=config.dataset_size, reverse=False)
    human_eps = tools.load_episodes(config.humandir, limit=config.dataset_size, reverse=False)

    state = None  # result of the prefill simulate call below; not read afterwards
    ra = RandomAgent(acts, config)
    # Optional scripted agent that cycles a fixed action order instead of sampling.
    if hasattr(config, "use_ccw_random_agent") and config.use_ccw_random_agent: ra = CCW_PP_Agent(acts, config)
    
    if not config.offline_traindir:
        print(config.prefill)
        # Only collect the steps still missing from what is already on disk.
        prefill = max(0, config.prefill - count_steps(config.traindir))
        print(f"Prefill dataset ({prefill} steps).")

        state = tools.simulate(
            ra,
            train_envs,
            train_eps,
            config.traindir,
            logger,
            limit=config.dataset_size,
            steps=prefill,
        )
    else:
        print("Prefilling from offline dataset ", config.offline_traindir)
        # load the npz files into train_eps
        # This discards the train_eps loaded from config.traindir above; each
        # offline directory is loaded with its own dataset_size limit.
        train_eps = collections.OrderedDict()
        for traindir in config.offline_traindir:
            tmp_eps = tools.load_episodes(traindir, limit=config.dataset_size, reverse=False)
            train_eps.update(tmp_eps)
        
        print(f"Loaded {len(train_eps)} episodes from offline dataset")

    # Single-step run on the human envs, only to set up the human episode cache
    # (per the inline note below).
    tools.simulate(
        ra,
        human_envs,
        human_eps,
        config.humandir,
        logger,
        limit=config.dataset_size,
        steps=1, # Just set things up
    )


    train_dataset = make_dataset(train_eps, config)
    eval_dataset = make_dataset(eval_eps, config)

    hubatch_size, hubatch_length = 2, 64 
    human_dataset = make_dataset(human_eps, config, batch_size=hubatch_size, batch_length=hubatch_length)

    # NOTE(review): this sums image.shape[0] (or len) over cached episodes. If an
    # episode stores one frame per step, this is a step count rather than a
    # trajectory count, and "Total steps" below over-counts by batch_length.
    # Confirm against the on-disk data layout.
    total_trajectories = 0
    for k,v in train_eps.items():
        # print(f"{k} {v['image'].shape if hasattr(v['image'], 'shape') else v['image'].__len__()}")
        total_trajectories += v["image"].shape[0] if hasattr(v['image'], 'shape') else len(v['image'])

    print(f"found {total_trajectories} trajectories. Assumed length is {config.batch_length}. Total steps {total_trajectories * config.batch_length:0.1e}")
    return train_dataset, eval_dataset, human_dataset, train_envs, eval_envs, human_envs, train_eps, eval_eps, human_eps

def imagine_carry(config, dynamics, start_wm, policy, horizon, stop_grad_actor=True):
    """Imagine a ``horizon``-step latent rollout driven by ``policy``.

    The first two dimensions of every tensor in ``start_wm`` are merged
    (``x.reshape([-1] + list(x.shape[2:]))``) to form the batch of start states.
    At each step:
      * the policy acts on the detached features of the current state;
      * the sampled action is moved to ``config.device``;
      * ``dynamics.img_step`` produces the next state, with sampling controlled
        by ``config.imag_sample``.

    The skill and goal slots are carried through the scan unchanged as
    ``torch.ones([1, 1])`` placeholders; goal conditioning is currently disabled.

    Args:
        config: Needs ``device`` and ``imag_sample``.
        dynamics: Exposes ``get_feat(state)`` and ``img_step(state, action, sample=...)``.
        start_wm: Dict of start-state tensors, presumably (T, B, ...). TODO confirm.
        policy: Module whose ``forward(inp)`` returns an object with ``.sample()``.
        horizon: Number of imagination steps.
        stop_grad_actor: Detach the start states so no gradient flows back
            through them. Defaults to ``True``, the previously hard-coded behavior.

    Returns:
        ``(feats, states, actions, skills, goals)`` as produced by
        ``tools.static_scan``. ``states`` holds the start state followed by
        every imagined state except the last.
    """
    def flatten(x):
        return x.reshape([-1] + list(x.shape[2:]))

    start = {k: flatten(v) for k, v in start_wm.items()}
    if stop_grad_actor:
        start = {k: v.detach() for k, v in start.items()}

    def step(prev, _step):
        state, _, _, skill, goal = prev

        feat = dynamics.get_feat(state)  # z_state + h

        # NOTE: the step index is ignored (the original skill-resampling logic
        # is disabled):
        # if goal is None or step % self.config.train_skill_duration == 0:
        #     skill, goal = self.sample_goal(state)
        # inp = torch.cat([feat.detach(), goal], -1)
        inp = feat.detach()

        action = policy.forward(inp).sample().to(config.device)
        succ = dynamics.img_step(state, action, sample=config.imag_sample)
        return succ, feat, action, skill, goal

    # Skill and goal placeholders are ones([1, 1]), passed through unchanged.
    placeholders = (torch.ones([1, 1]), torch.ones([1, 1]))
    succ, sc_feats, sc_actions, sc_skills, sc_goals = tools.static_scan(
        step, [torch.arange(horizon)], (start, None, None, *placeholders)
    )
    sc_states = {k: torch.cat([start[k].unsqueeze(0), v[:-1]], 0) for k, v in succ.items()}
    return sc_feats, sc_states, sc_actions, sc_skills, sc_goals

def imagine(config, dynamics, start, policy, horizon, repeats=None):
    """Roll ``policy`` through ``dynamics`` for ``horizon`` imagined steps.

    The first two dimensions of each start tensor are merged into one batch
    dimension. At every step the policy sees the state features, detached
    when ``config.behavior_stop_grad`` is set. It samples an action, and
    ``dynamics.img_step`` (sampling per ``config.imag_sample``) yields the
    next state.

    Args:
        config: Needs ``behavior_stop_grad`` and ``imag_sample``.
        dynamics: Exposes ``get_feat`` and ``img_step``.
        start: Dict of start-state tensors, presumably (T, B, ...). TODO confirm.
        policy: Callable returning an object with ``.sample()``.
        horizon: Number of imagination steps.
        repeats: Unsupported; any truthy value raises.

    Returns:
        ``(feats, states, actions)`` as produced by ``tools.static_scan``.
        ``states`` holds the start state followed by every imagined state
        except the last.

    Raises:
        NotImplementedError: if ``repeats`` is truthy.
    """
    if repeats:
        # Must be NotImplementedError: the `NotImplemented` constant is not
        # callable, so calling it raised an unrelated TypeError.
        raise NotImplementedError("repeats is not implemented in this version")

    def flatten(x):
        return x.reshape([-1] + list(x.shape[2:]))

    start = {k: flatten(v) for k, v in start.items()}

    def step(prev, _):
        state, _, _ = prev
        feat = dynamics.get_feat(state)
        inp = feat.detach() if config.behavior_stop_grad else feat
        action = policy(inp).sample()
        succ = dynamics.img_step(state, action, sample=config.imag_sample)
        return succ, feat, action

    succ, feats, actions = tools.static_scan(
        step, [torch.arange(horizon)], (start, None, None)
    )
    states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}

    return feats, states, actions

def imagine_with_skills(config, dynamics, start, policy, skill_policy, skill_decoder, horizon, starting_skill=None, starting_goal=None, sample=True):
    """Imagine a rollout where a skill policy periodically sets the worker's goal.

    The first two dimensions of each start tensor are merged into one batch
    dimension. A skill is sampled from ``skill_policy(feat)`` in two cases:
      * on the first step, when none was supplied;
      * whenever the step index is a positive multiple of
        ``config.train_skill_duration``.
    Its goal is ``skill_decoder(skill).mode()``. Unless
    ``config.WORKER_PRIMARY`` is set, the worker policy's input is the features
    concatenated with the goal. ``sample`` selects ``.sample()`` vs ``.mode()``
    for the worker action, and is also passed to ``dynamics.img_step`` in place
    of ``config.imag_sample``.

    When ``starting_skill`` is given (human-interaction path), every start
    tensor is tiled ``starting_goal.shape[1]`` times along the batch dimension
    and given a leading axis of size 1, so the start batch matches the goals.

    Args:
        config: Needs ``train_skill_duration``, ``behavior_stop_grad`` and
            ``WORKER_PRIMARY``.
        dynamics: Exposes ``get_feat`` and ``img_step``.
        start: Dict of start-state tensors, presumably (T, B, ...). TODO confirm.
        policy: Worker policy returning an object with ``.sample()`` / ``.mode()``.
        skill_policy: Returns a distribution over skills given features.
        skill_decoder: Maps a flattened skill to a distribution with ``.mode()``.
        horizon: Number of imagination steps.
        starting_skill: Optional initial skill; requires ``starting_goal``.
        starting_goal: Initial goal, used together with ``starting_skill``.
        sample: Sample (True) or take the mode (False) for actions and dynamics.

    Returns:
        ``(feats, states, actions, skills, goals)`` as produced by
        ``tools.static_scan``. ``states`` holds the start state followed by
        every imagined state except the last.

    Raises:
        ValueError: if ``starting_skill`` is given without ``starting_goal``.
    """
    def flatten(x):
        return x.reshape([-1] + list(x.shape[2:]))

    start = {k: flatten(v) for k, v in start.items()}

    if starting_skill is not None:
        if starting_goal is None:
            raise ValueError("starting_goal is required when starting_skill is given")
        n_goals = starting_goal.shape[1]
        # Tile every latent tensor (not only deter/stoch/logit), whatever its
        # rank, then add a leading axis of size 1.
        start = {
            k: v.repeat(n_goals, *([1] * (v.dim() - 1)))[None, ...]
            for k, v in start.items()
        }
        assert start["deter"].shape[:2] == starting_goal.shape[:2], f"{start['deter'].shape} {starting_goal.shape}"

    def step(prev, t):
        state, _, _, skill, goal = prev
        feat = dynamics.get_feat(state)

        if skill is None or (t % config.train_skill_duration == 0 and t > 0):
            skill = skill_policy(feat).sample()
            skill = skill.reshape([*skill.shape[:-2], -1])
            goal = skill_decoder(skill).mode()

        inp = feat.detach() if config.behavior_stop_grad else feat
        if not config.WORKER_PRIMARY:
            inp = torch.cat([inp, goal], -1)

        action_dist = policy(inp)
        action = action_dist.sample() if sample else action_dist.mode()
        succ = dynamics.img_step(state, action, sample=sample)
        return succ, feat, action, skill, goal

    succ, feats, actions, skills, goals = tools.static_scan(
        step, [torch.arange(horizon)], (start, None, None, starting_skill, starting_goal)
    )
    states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}

    return feats, states, actions, skills, goals

def play_sound(n=1, soundfn='bell.oga'):
    """Play ``/usr/share/sounds/freedesktop/stereo/<soundfn>`` ``n`` times via ``paplay``.

    Best effort: runs without a shell, so file names with spaces or shell
    metacharacters are passed through literally. If ``paplay`` is missing or
    cannot be executed, the error is ignored, just as a failing shell command
    was under ``os.system``. Sleeps 0.1 s after each attempt.
    """
    import subprocess

    sound_path = f'/usr/share/sounds/freedesktop/stereo/{soundfn}'
    for _ in range(n):
        try:
            subprocess.run(['paplay', sound_path], check=False)
        except OSError:
            pass  # notification sounds are optional; never fail the caller
        time.sleep(0.1)


def play_complete_sound(n=2):
    """Play the ``bell.oga`` sound ``n`` times."""
    play_sound(n, 'bell.oga')


def play_intermediate_sound(n=1):
    """Play the ``message.oga`` sound ``n`` times."""
    play_sound(n, 'message.oga')

def save_config(logger, config):
    """Write ``vars(config)`` to ``config.yaml`` inside ``logger._logdir``.

    Uses ruamel's ``unsafe`` dumper, presumably so that non-primitive values
    in the config (e.g. ``pathlib.Path`` directories) can be serialized.
    """
    target = logger._logdir / "config.yaml"
    with open(target, "w") as stream:
        dumper = yaml.YAML(typ="unsafe", pure=True)
        dumper.dump(vars(config), stream)

    
def deep_copy_dict(d, detach=False):
    """Return a recursive copy of ``d`` whose leaf values are ``.clone()``-d.

    Nested mappings become plain ``dict`` objects. With ``detach=True`` each
    leaf is detached before cloning; otherwise clones stay attached to the
    autograd graph. Leaves must support ``.clone()`` (and ``.detach()`` when
    detaching).
    """
    copied = {}
    for key, value in d.items():
        if isinstance(value, Mapping):
            copied[key] = deep_copy_dict(value, detach)
        elif detach:
            copied[key] = value.detach().clone()
        else:
            copied[key] = value.clone()
    return copied
    
def norm_dec_img(img):
    '''
    Min-max normalize ``img`` to the 0-255 range and return it as uint8.

    The minimum maps to 0 and the maximum to 255; fractional values are
    truncated by the uint8 cast. A constant image has no range to stretch and
    is returned as all zeros. Previously it divided by zero and cast NaN to
    uint8, which is undefined.

    Args:
        img: Array-like of numbers (converted with ``np.asarray``).

    Returns:
        np.ndarray of dtype uint8 with the same shape as ``img``.
    '''
    img = np.asarray(img)
    lo, hi = img.min(), img.max()
    span = hi - lo
    if span == 0:
        return np.zeros(img.shape, dtype=np.uint8)
    img = (img - lo) / span * 255
    img = np.clip(img, 0, 255).astype(np.uint8)
    return img