import copy
import torch
from torch import nn
import numpy as np
from PIL import ImageColor, Image, ImageDraw, ImageFont
import random

import cv2
import networks
import tools
import models as dv3_models
import torch_utils as tu
import fastrl_utils as frlu

import fastcore.all as fc
import models

from torchviz import make_dot
from IPython import embed as ipshell

from uiux import UIUX

def to_np(x):
    """Detach ``x`` from autograd, move it to the CPU and return a NumPy array."""
    return x.detach().cpu().numpy()


def reshape(tnsr):
    """Merge the last two dimensions of ``tnsr`` into a single dimension.

    All leading dimensions are kept. The merged size is the product of the
    last two sizes, so square grids (e.g. ``(..., 8, 8) -> (..., 64)``) and
    rectangular grids are both handled.
    """
    merged = tnsr.shape[-2] * tnsr.shape[-1]
    return tnsr.reshape(list(tnsr.shape[:-2]) + [merged])

class ImaginationActorCritic(nn.Module):
    def __init__(self, config, world_model, input_size, num_actions, stop_grad_actor=True, prefix=''):
        """Build the actor head, the critic, and one Adam optimizer for each.

        Args:
            config: Hyper-parameter object; fields are read by attribute
                access (e.g. ``config.units``, ``config.value_lr``).
            world_model: Stored as ``self._wm``; ``_train`` reads
                ``self._wm.dynamics`` from it.
            input_size: Input size handed to both ``networks.ActionHead``
                and ``networks.MLP``.
            num_actions: Output size handed to ``networks.ActionHead``.
            stop_grad_actor: If true, ``_compute_actor_loss`` detaches the
                features before feeding them to the actor.
            prefix: ``"manager"`` or ``"worker"``. Selects the actor
                hyper-parameters and is used to prefix metric names.

        Raises:
            NotImplementedError: If ``prefix`` is neither ``"manager"`` nor
                ``"worker"``.
        """
        super().__init__()
        self._config = config
        self._wm = world_model
        self._stop_grad_actor = stop_grad_actor
        self._prefix = prefix

        if prefix == "manager":
            actor_act = config.manager_actor['act']
            actor_norm = config.manager_actor['norm']
            actor_dist = config.manager_actor['dist']
            actor_outscale = config.manager_actor['outscale']
            actor_lr = config.manager_lr
        elif prefix == "worker":
            # Only the activation and learning rate come from the config; the
            # norm, distribution and output scale are pinned here and the
            # matching ``config.worker_actor`` entries are not consulted.
            actor_act = config.worker_actor['act']
            actor_norm = 'LayerNorm'
            actor_dist = "onehot"
            actor_outscale = 1.0
            actor_lr = config.worker_lr
        else:
            raise NotImplementedError(prefix)

        self.actor = networks.ActionHead(
            input_size,
            num_actions,
            config.actor_layers,
            config.units,
            actor_act,
            actor_norm,
            actor_dist,
            config.actor_init_std,
            config.actor_min_std,
            config.actor_max_std,
            config.actor_temp,
            outscale=actor_outscale,
            unimix_ratio=config.action_unimix_ratio,
        )

        # Critic with a fixed "symlog_disc" output head over 255 bins.
        self.value = networks.MLP(
            input_size,
            (255,),
            config.value_layers,
            config.units,
            config.act,
            config.norm,
            "symlog_disc",
            outscale=0.1,
            device=config.device,
        )

        if config.slow_value_target:
            # Slow copy of the critic, refreshed by ``_update_slow_target``.
            self._slow_value = copy.deepcopy(self.value)
            self._updates = 0

        # Plain Adam optimizers are used here; the ``tools.Optimizer`` wrapper
        # (grad clipping / AMP) is intentionally not wired in for this class.
        self._simple_actor_opt = torch.optim.Adam(
            self.actor.parameters(),
            lr=actor_lr,
            weight_decay=config.weight_decay,
        )
        self._simple_value_opt = torch.optim.Adam(
            self.value.parameters(),
            lr=config.value_lr,
            weight_decay=config.weight_decay,
        )

        if self._config.reward_EMA:
            self.reward_ema = dv3_models.RewardEMA(device=self._config.device)
    
    def _train(
        self,
        start=None,
        objective=None,
        repeats=None,
        itraj=None,
    ):
        """Run one actor update and one critic update on an imagined trajectory.

        Args:
            start: Unused; kept for signature compatibility.
            objective: Unused; kept for signature compatibility.
            repeats: Unused; kept for signature compatibility.
            itraj: Dict holding the imagined trajectory. Keys read here:
                ``"feat"``, ``"state"``, ``"action"``, ``"reward"``,
                ``"cont"`` and optionally ``"goal"`` and ``"weights"``.

        Returns:
            Tuple ``(ifeat, istate, iaction, weights, metrics)`` where
            ``weights`` comes from ``_compute_target`` and every metric key
            contains ``self._prefix`` (it is prepended when missing).

        Raises:
            ValueError: If ``itraj`` is not provided.
            NotImplementedError: If ``self._prefix`` is not ``"worker"`` or
                ``"manager"``.
        """
        if itraj is None:
            raise ValueError("Must provide imag_traj")

        self._update_slow_target()
        metrics = {}

        ifeat = itraj["feat"]
        istate = itraj["state"]
        iaction = itraj["action"]
        ireward = itraj["reward"]
        icont = itraj["cont"]
        igoal = itraj["goal"] if "goal" in itraj else None  # NOTE: goals become actions for the manager

        with tools.RequiresGrad(self.actor):
            state_ent = self._wm.dynamics.get_dist(istate).entropy()
            target, weights, base = self._compute_target(
                icont, ifeat, istate, iaction, igoal, ireward, self.calc_value
            )

            if 'weights' in itraj:
                assert itraj['weights'].shape == weights.shape, f"itraj['weights'].shape {itraj['weights'].shape} != weights.shape {weights.shape}"

            # NOTE: there's a lot of target regularization in DHafner's version
            actor_loss, mets = self._compute_actor_loss(
                ifeat,
                istate,
                iaction,
                igoal,
                target,
                state_ent,
                weights,
                base,
            )
            metrics.update(mets)

            if self._prefix not in ("worker", "manager"):
                raise NotImplementedError
            # Both roles currently feed the plain features to the critic.
            value_input = ifeat

        ### Train critics
        # The return targets above are built from the active critic
        # (``self.calc_value``). When ``slow_value_target`` is enabled the loss
        # additionally pulls the active critic towards the slow copy's
        # prediction, i.e.
        #   -value.log_prob(target) - value.log_prob(slow_value_prediction)
        with tools.RequiresGrad(self.value):
            critic_input = value_input[:-1].detach()
            extr_dist = self.value(critic_input)
            # original note: (time, batch, 1), (time, batch, 1) -> (time, batch)
            target = torch.stack(target, dim=1)
            value_loss = -extr_dist.log_prob(target.detach())
            if self._config.slow_value_target:
                # ``_slow_value`` only exists when this option is enabled, so
                # it must not be touched outside of this branch.
                slow_target = self._slow_value(critic_input).mode()
                value_loss = value_loss - extr_dist.log_prob(
                    slow_target.detach()
                )

            # original note: (time, batch, 1), (time, batch, 1) -> (1,)
            value_loss = torch.mean(weights[:-1] * value_loss[:, :, None])

        metrics["state_entropy"] = to_np(torch.mean(state_ent))
        with tools.RequiresGrad(self):
            self._simple_actor_opt.zero_grad()
            actor_loss.backward()
            self._simple_actor_opt.step()

            self._simple_value_opt.zero_grad()
            value_loss.backward()
            self._simple_value_opt.step()

        metrics.update(tools.tensorstats(extr_dist.mode(), "value"))
        metrics.update(tools.tensorstats(value_loss.detach(), "value_loss"))
        metrics.update(tools.tensorstats(target, "target"))
        metrics.update(tools.tensorstats(ireward, "ireward"))

        if self._config.actor_dist in ["onehot"]:
            metrics.update(
                tools.tensorstats(
                    torch.argmax(iaction, dim=-1).float(), "iaction"
                )
            )
        else:
            metrics.update(tools.tensorstats(iaction, "iaction"))
        metrics.update(tools.tensorstats(actor_loss.detach(), "actor_loss"))

        # Name the metrics with the prefix (unless the key already contains it).
        prefixxed_metrics = {
            (k if self._prefix in k else self._prefix + "_" + k): v
            for k, v in metrics.items()
        }

        return ifeat, istate, iaction, weights, prefixxed_metrics

    def _compute_target(
        self, icont, ifeat, istate, iaction, igoal, reward, value_network_fn
    ):
        """Compute lambda-return targets and cumulative discount weights.

        Args:
            icont: Continuation values; multiplied by ``config.discount``.
            ifeat: Features handed to ``value_network_fn``.
            istate, iaction, igoal: Unused here; kept for the call signature.
            reward: Rewards passed to ``tools.lambda_return``.
            value_network_fn: Callable mapping features to values, so that
                targets can be computed for different value networks.

        Returns:
            ``(target, weights, value[:-1])`` where ``target`` is whatever
            ``tools.lambda_return`` returns and ``weights`` is the detached
            cumulative product of the discounts, starting at one.

        Raises:
            NotImplementedError: If ``config.future_entropy`` is set or the
                prefix is neither ``"worker"`` nor ``"manager"``.
        """
        cfg = self._config
        discount = cfg.discount * icont
        if cfg.future_entropy:
            raise NotImplementedError
        if self._prefix not in ("worker", "manager"):
            raise NotImplementedError(self._prefix)

        # Single-element concat: the critic input is just the features.
        inp = torch.cat([ifeat], dim=-1)
        # A value function is passed in so the target can be computed for
        # separate value networks.  NOTE: DHafner def does this better
        value = value_network_fn(inp)

        target = tools.lambda_return(
            reward,
            value[:-1],
            discount[1:],
            bootstrap=value[-1],
            lambda_=cfg.discount_lambda,
            axis=0,
        )

        # Shift the discounts by one step so the first weight is exactly one.
        leading_one = torch.ones_like(discount[:1])
        shifted = torch.cat([leading_one, discount[:-1]], 0)
        weights = torch.cumprod(shifted, 0).detach()

        # The last value is only used for bootstrapping, so it is dropped to
        # line the values up with the rewards.
        return target, weights, value[:-1]

    def calc_value(self, inp):
        """Return the critic's mode prediction for ``inp``, scaled by ``config.extr_reward_weight``.

        The same weight is used for the worker (where it plays the role of the
        goal reward weight). TODO: separate variables for clarity.
        """
        predicted = self.value(inp).mode()
        weight = self._config.extr_reward_weight
        return weight * predicted

    def _compute_actor_loss(
        self,
        ifeat,
        istate,
        iaction,
        igoal,
        target,
        state_ent,
        weights,
        base,
    ):
        """Compute the REINFORCE actor loss plus an optional entropy bonus.

        Args:
            ifeat: Features fed to the actor (detached first when
                ``self._stop_grad_actor`` is set).
            istate, igoal, state_ent, base: Unused here; kept for the call
                signature.
            iaction: Actions whose log-probability is scored by the policy.
            target: Sequence of tensors, stacked along ``dim=1``.
            weights: Per-step weights; ``weights[:-1]`` scales the objective.

        Returns:
            ``(actor_loss, metrics)``.

        Raises:
            NotImplementedError: If ``config.imag_gradient`` is not
                ``"reinforce"``.
        """
        metrics = {}
        inp = ifeat.detach() if self._stop_grad_actor else ifeat
        policy = self.actor(inp)
        actor_ent = policy.entropy()

        # Q-val for actor is not transformed using symlog
        target = torch.stack(target, dim=1)
        if self._config.reward_EMA:
            # Calling the EMA also updates its running statistics.
            # NOTE(review): the normalized target is only logged; the
            # advantage below uses the raw target - confirm this is intended.
            offset, scale = self.reward_ema(target)
            normed_target = (target - offset) / scale
            metrics.update(tools.tensorstats(normed_target, "normed_target"))
            values = self.reward_ema.values
            metrics["EMA_005"] = to_np(values[0])
            metrics["EMA_095"] = to_np(values[1])

        if self._config.imag_gradient != "reinforce":
            raise NotImplementedError(self._config.imag_gradient)
        log_prob = policy.log_prob(iaction)[:-1][:, :, None]
        advantage = (target - self.calc_value(inp[:-1])).detach()
        actor_target = log_prob * advantage

        # The bonus stays ``None`` when the entropy scale is not positive, so
        # the objective is then the plain REINFORCE term.
        entropy_bonus = None
        if self._config.actor_entropy > 0:
            entropy_bonus = self._config.actor_entropy * actor_ent[:-1][:, :, None]
            metrics["actor_entropy"] = to_np(torch.mean(entropy_bonus))
            metrics["entropy_scale"] = self._config.actor_entropy

        metrics["actor_target_unnormalized"] = to_np(actor_target.mean())
        metrics["actor_entropy_unweighted"] = to_np(actor_ent.mean())

        if entropy_bonus is not None:
            actor_target = actor_target + entropy_bonus
        actor_loss = -torch.mean(weights[:-1] * actor_target)
        return actor_loss, metrics

    def _update_slow_target(self):
        """Refresh the slow critic every ``config.slow_target_update`` calls.

        Does nothing when ``config.slow_value_target`` is disabled. Otherwise
        the call counter ``self._updates`` is incremented on every call, and
        on every ``slow_target_update``-th call the slow critic's parameters
        are overwritten with a mix of the active and slow parameters.
        """
        if not self._config.slow_value_target:
            return

        if self._updates % self._config.slow_target_update == 0:
            # With mix == 1.0 this is a hard copy of the active critic.
            mix = 1.0  # self._config.slow_target_fraction # NOTE: value set from dhafner impl
            pairs = zip(self.value.parameters(), self._slow_value.parameters())
            for src, dst in pairs:
                dst.data = mix * src.data + (1 - mix) * dst.data
        self._updates += 1

class ImaginationAE(nn.Module):
    def __init__(self, config):
        """Create the goal encoder/decoder, their optimizer and the skill prior.

        Args:
            config: Hyper-parameter object, stored on the instance by
                ``fc.store_attr()``. Fields read here include ``dyn_deter``,
                ``skill_shape``, ``goal_encoder``, ``goal_decoder``, the
                ``goal_ae_*`` optimizer settings, ``opt``, ``device``,
                ``action_unimix_ratio`` and ``kl_autoadapt``.
        """
        super().__init__()
        fc.store_attr()

        skill_shape = config.skill_shape
        deter_size = config.dyn_deter

        # Encoder: deterministic state -> skill grid (original note: [8x8]).
        self.goal_enc = networks.MLP(deter_size, skill_shape, **config.goal_encoder)
        # Decoder: flattened skill -> deterministic state.
        self.goal_dec = networks.MLP(np.prod(skill_shape), deter_size, **config.goal_decoder)

        # One optimizer updates the encoder and the decoder jointly.
        ae_params = list(self.goal_enc.parameters()) + list(self.goal_dec.parameters())
        self.goal_ae_opt = tools.Optimizer(
            "goal_ae",
            ae_params,
            config.goal_ae_lr,
            config.goal_ae_opt_eps,
            config.goal_ae_grad_clip,
            wd=config.goal_ae_wd,
            opt=config.opt,
            use_amp=False,
        )

        # Prior built from all-zero logits; the last dimension is treated as
        # the event dimension.
        prior_logits = torch.zeros(*skill_shape, device=config.device)
        prior_base = tools.OneHotDist(
            logits=prior_logits, unimix_ratio=self.config.action_unimix_ratio
        )
        self.skill_prior = torch.distributions.independent.Independent(prior_base, 1)

        self.kl_autoadapt = tu.AutoAdapt((), **config.kl_autoadapt)


    def _train(self, batch):
        """Run one optimization step of the goal autoencoder.

        Args:
            batch: Mapping with a ``"deter"`` entry. The entry is encoded into
                a sampled skill, decoded back, and scored against itself.

        Returns:
            Dict of metrics: whatever ``self.goal_ae_opt`` returns, plus the
            ``goalkl_*`` entries when ``config.goal_vae_kl`` is enabled.
        """
        feat = batch["deter"]
        metrics = {}
        ae_params = list(self.goal_enc.parameters()) + list(self.goal_dec.parameters())
        with tools.RequiresGrad(self.goal_enc), tools.RequiresGrad(self.goal_dec):
            enc = self.goal_enc(feat)
            x = enc.sample()  # discrete one-hot grid
            x = x.reshape([x.shape[0], x.shape[1], -1])  # (time, batch, feat0*feat1)
            dec = self.goal_dec(x)
            # The reconstruction target is detached (DHafner's original).
            rec = -dec.log_prob(feat.detach())
            if self.config.goal_vae_kl:
                # NOTE: Should kl and recreation loss be scaled by scheduled constants as in the world model?
                kl_nonorm = torch.distributions.kl.kl_divergence(enc, self.skill_prior)
                if self.config.goal_vae_kl_autoadapt:
                    kl, mets = self.kl_autoadapt(kl_nonorm)
                    metrics.update({f'goalkl_{k}': to_np(v) for k, v in mets.items()})
                else:
                    # Simple fixed-beta scaling. Done out of place so that
                    # ``kl_nonorm`` keeps the unscaled value for logging.
                    kl = kl_nonorm * self.config.goal_vae_kl_beta
                metrics.update({'goalkl_nonorm': to_np(kl_nonorm.mean())})
                assert kl.shape == rec.shape, (kl.shape, rec.shape)
                loss = torch.mean(rec + kl)
            else:
                loss = torch.mean(rec)

            metrics.update(self.goal_ae_opt(loss, ae_params))
        return metrics

class ImaginationAgent(nn.Module):
    def __init__(self, config, logger, worker: ImaginationActorCritic, manager: ImaginationActorCritic, goalae: ImaginationAE, _wm: models.WorldModel, dataset, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        fc.store_attr()
        self._steps = frlu.count_steps(self.config.traindir); self._train_steps = 0; 
        self.train_period = 2; self.log_period = 1e3; self.log_vid_period = 5e3
        self._pretrained = False; self._metrics = {}; 
        self.logger.step = self._steps

    def __call__(self, obs, done, state, training=True):
        if state is not None and done.any(): # zero out the state if we're done
            mask = 1 - done
            for key in state[0].keys():
                for i in range(state[0][key].shape[0]):
                    state[0][key][i] *= mask[i]
            for i in range(len(state[1])):
                state[1][i] *= mask[i]

        if training:
            if self.config.pretrain and not self._pretrained:
                self._pretrained = True
                for i in range(self.config.pretrain):
                    self._train(next(self.dataset)); self._train_steps += 1
                self._log()

            elif self._steps % self.train_period == 0:
                self._train(next(self.dataset)); self._train_steps += 1

            if self._steps % self.log_period == 0: self._log()
            if self._steps % self.log_vid_period == 0: self._log_video()

        self._steps += 1
        self.logger.step = self._steps
        
        return self._policy(obs, state, training)

    def _policy(self, obs, state, training=False):
        if state is None:
            batch_size = len(obs["image"])
            latent = self._wm.dynamics.initial(len(obs["image"]))
            action = torch.zeros((batch_size, self.config.num_actions)).to(
                self.config.device
            )
        else:
            latent, action = state

        obs = self._wm.preprocess(obs)
        embed = self._wm.encoder(obs)
        latent, _ = self._wm.dynamics.obs_step(
            latent, action, embed, obs["is_first"], self.config.collect_dyn_sample
        )

        feat = self._wm.dynamics.get_feat(latent)
        actor_dist = self.worker.actor(feat)
        action = actor_dist.sample() if training else actor_dist.mode()

        # detach the latent state so we don't backpropagate through it
        latent = {k: v.detach() for k, v in latent.items()}
        action = action.detach()        

        log_prob = actor_dist.log_prob(action)        

        state = (latent, action)
        policy_output = {"action": action, "logprob": log_prob}
        return policy_output, state
    
    def _train(self, batch):
        """Run one update of the world model, goal autoencoder, manager and worker.

        Metrics from every component are merged into ``self._metrics``.
        Nothing is returned.
        """
        post, _prior, wm_metrics = self._wm._train(batch)
        self._metrics.update(wm_metrics)

        self._metrics.update(self.goalae._train(post))

        # Imagine a rollout of ``imag_horizon + 1`` steps from the posterior.
        horizon = self.config.imag_horizon + 1
        with tools.RequiresGrad(self.worker.actor):
            ifeats, istates, iactions, iskills, igoals = frlu.imagine_with_skills(
                self.config,
                self._wm.dynamics,
                post,
                self.worker.actor,
                self.manager.actor,
                self.goalae.goal_dec,
                horizon,
            )
        feats = self._wm.dynamics.get_feat(istates)  # use start plus successor states
        icont = self._wm.heads["cont"](feats).mean
        # The reward prediction for the first step is dropped.
        ireward = self._wm.heads["reward"](feats).mode()[1:]

        itraj = {
            "stoch": istates["stoch"],
            "deter": istates["deter"],
            "logit": istates["logit"],
            "feat": ifeats,
            "action": iactions,
            "reward": ireward,
            "skill": iskills,
            "cont": icont,
        }

        # Split itraj for manager and worker. Avoid a double deep copy by
        # giving the _second_ consumer the original dict: the worker gets a
        # detached copy, the manager reuses (and mutates) ``itraj`` itself.
        worker_copy = frlu.deep_copy_dict(itraj, detach=True)
        wtraj = self.split_traj(worker_copy, deep_copy_input=False)
        mtraj = self.abstract_traj(itraj, deep_copy_input=False)

        state_keys = ("stoch", "deter", "logit")

        mtraj["state"] = {key: mtraj[key] for key in state_keys}
        *_, manager_metrics = self.manager._train(itraj=mtraj)
        self._metrics.update(manager_metrics)

        # The worker is currently trained on the plain reward; the goal
        # reward is not applied here.
        wtraj["reward_goal"] = wtraj["reward"]
        wtraj["state"] = {key: wtraj[key] for key in state_keys}
        *_, worker_metrics = self.worker._train(itraj=wtraj)
        self._metrics.update(worker_metrics)

    def _log(self):
        self._metrics["update_count"] = self._train_steps
        for k, v in self._metrics.items():
            self.logger.scalar(k, np.mean(v))
        self._metrics = {}
        self.logger.write()

    def _log_video(self):
        openl = self._wm.video_pred(next(self.dataset))
        self.logger.video("train_openl", to_np(openl))

    def get_goal(self, feat, sample=True, gen_image=False):
        skill = self.manager.actor(feat).sample() if sample else self.manager.actor(feat).mode()
        goal_deter = self.goalae.goal_dec(reshape(skill)).mode()
        goal_stoch = self._wm.dynamics.get_stoch(goal_deter)
        if gen_image:
            inp = torch.cat([reshape(goal_stoch), goal_deter], -1)
            # pad the tensor shape
            img = self._wm.heads["decoder"](inp[None, ...])["image"].mode().squeeze().detach().cpu().numpy()
            img = (img - img.min()) / (img.max() - img.min()) * 255
            print(img.min(), img.max())
            img = np.clip(img, 0, 255).astype(np.uint8)

        return skill, goal_deter, goal_stoch, img

    def split_traj(self, orig_traj, deep_copy_input=True):
        traj = frlu.deep_copy_dict(orig_traj) if deep_copy_input else orig_traj
        k = self.config.train_skill_duration
        assert len(traj['action']) % k == 1; (len(traj['action']) % k), "Trajectory length must be divisible by k+1"
        reshape = lambda x: x.reshape([x.shape[0] // k, k] + list(x.shape[1:]))
        for key, val in list(traj.items()):
            # print(f"director_models::split_traj::key: {key}, val.shape: {val.shape if hasattr(val, 'shape') else None}")
            val = torch.concat([0 * val[:1], val], 0) if 'reward' in key else val
            # (1 2 3 4 5 6 7 8 9 10) -> ((1 2 3 4) (4 5 6 7) (7 8 9 10))
            val = torch.concat([reshape(val[:-1]), val[k::k][:, None]], 1)
            # N val K val B val F... -> K val (N B) val F...
            val = val.permute([1, 0] + list(range(2, len(val.shape))))
            val = val.reshape(
                [val.shape[0], np.prod(val.shape[1:3])] + list(val.shape[3:]))
            val = val[1:] if 'reward' in key else val
            traj[key] = val
        # Bootstrap sub trajectory against current not next goal.

        if 'goal' in traj: traj['goal'] = torch.concat([traj['goal'][:-1], traj['goal'][:1]], 0)

        traj['weight'] = torch.cumprod(
            self.config.imag_discount * traj['cont'], axis=0) / self.config.imag_discount
        return traj
    
    ''' NOTE:
    These two deep copies use a huge amount of memory. Is there a way to combine them?
    '''

    def abstract_traj(self, orig_traj, deep_copy_input=True):
        traj = frlu.deep_copy_dict(orig_traj) if deep_copy_input else orig_traj
        traj["action"] = traj.pop("skill")
        k = self.config.train_skill_duration
        def reshape(x):
            # unfold into every length k subtrajectory
            x = x.unfold(0, k, 1) # (n, ..., k)
            return x.reshape([x.shape[0], k, *x.shape[1:-1]]) # reshape to match (n, k, ...)
        # reshape = lambda x: x.reshape([x.shape[0] // k, k] + list(x.shape[1:]))
        weights = torch.cumprod(reshape(traj["cont"][:-1]), 1).to(self.config.device)
        for key, value in list(traj.items()):
            if 'reward' in key:
                # if (traj[key] > 1.0).any():
                #     a = 1
                # NOTE: Paper says to use the sum, DHafner code uses the mean
                # traj[key] = (reshape(value) * weights).mean(1) # weighted mean of each subtrajectory
                traj[key] = (reshape(value) * weights).sum(1) # weighted sum of each subtrajectory
            elif key == 'cont':
                traj[key] = torch.concat([value[:1], reshape(value[1:]).prod(1)], 0)
            else:
                traj[key] = torch.concat([reshape(value[:-1])[:, 0], value[-1:]], 0)
        traj['weight'] = torch.cumprod(
            self.config.imag_discount * traj['cont'], 0) / self.config.imag_discount
        
        ashape = traj["action"].shape
        traj["action"] = traj["action"].reshape(*ashape[:len(ashape)-1], *self.config.skill_shape)
        return traj
        
    def goal_reward(self, state, goal):
        # cosine_max reward
        h = state["deter"]
        goal = goal.detach()
        gnorm = torch.linalg.norm(goal, dim=-1, keepdim=True) + 1e-12
        hnorm = torch.linalg.norm(h, dim=-1, keepdim=True) + 1e-12
        norm = torch.max(gnorm, hnorm)
        return torch.einsum("...i,...i->...", goal / norm, h / norm) # NOTE

    def expl_reward(self, state):
        # elbo (Evidence Lower BOund) reward
        feat = state["deter"]
        enc = self.goal_enc(feat) 
        x = enc.sample() # produces a skill_dim x skill_dim
        x = x.reshape([x.shape[0], x.shape[1], -1]) # (time, batch, feat0*feat1)
        dec = self.goal_dec(x) # decode back the feature dimensions

        return ((dec.mode() - feat) ** 2).mean(-1) # NOTE: This can probably be a log likelihood under the MSE distribution

    def save(self, logdir, note=''):
        torch.save(self.state_dict(), logdir / f"model{note}.pt")

    def load(self, logdir, note=''):
        self.load_state_dict(torch.load(logdir / f"model{note}.pt"))