import copy
import torch
from torch import nn
import numpy as np
from PIL import ImageColor, Image, ImageDraw, ImageFont
import random

import cv2
import networks
import tools
import models as dv3_models
import torch_utils as tu

from torchviz import make_dot
from IPython import embed as ipshell

from uiux import UIUX

def to_np(x):
    """Detach tensor ``x`` from autograd, move it to the CPU and return it as a NumPy array."""
    return x.detach().cpu().numpy()

class HierarchyBehavior(nn.Module):
    def __init__(self, config, world_model, stop_grad_actor=True, spoofed_world_model=None):
        """Build the goal autoencoder, the manager/worker actor-critics and the skill prior.

        Args:
            config: Experiment configuration. Attributes read here include
                ``precision``, ``dyn_deter``, ``dyn_stoch``, ``dyn_discrete``,
                ``skill_shape``, ``num_actions``, ``device`` and the
                ``goal_*`` / ``kl_autoadapt`` options.
            world_model: World model kept on the instance and passed on to the
                manager, the worker and (optionally) the UI helper.
            stop_grad_actor: Stored on the instance and forwarded to both
                actor-critics.
            spoofed_world_model: Accepted for call compatibility; it is not
                used by this constructor.

        Raises:
            NotImplementedError: If ``config.precision == 16`` (AMP).
        """
        super().__init__()
        self._use_amp = bool(config.precision == 16)
        if self._use_amp:
            raise NotImplementedError("AMP is not implemented in this version")
        self._config = config
        self._world_model = world_model
        self._stop_grad_actor = stop_grad_actor

        # NOTE: The RSSM has MLPs that support an n x n goal space. Use that to
        # figure out why this isn't working.

        # ---- Goal autoencoder ----
        # Encoder: deterministic state (size dyn_deter) -> skill of shape
        # config.skill_shape. Decoder: flattened skill -> deterministic state.
        self.goal_enc = networks.MLP(config.dyn_deter, config.skill_shape, **config.goal_encoder)
        self.goal_dec = networks.MLP(np.prod(config.skill_shape), config.dyn_deter, **config.goal_decoder)
        goal_ae_params = list(self.goal_enc.parameters()) + list(self.goal_dec.parameters())
        self.goal_ae_opt = tools.Optimizer(
            "goal_ae",
            goal_ae_params,
            config.goal_ae_lr,
            config.goal_ae_opt_eps,
            config.goal_ae_grad_clip,
            wd=config.goal_ae_wd,
            opt=config.opt,
            use_amp=self._use_amp,
        )

        self.kl_autoadapt = tu.AutoAdapt((), **config.kl_autoadapt)

        # Size of the flattened stochastic state, with or without discrete latents.
        if config.dyn_discrete:
            stoch_size = config.dyn_stoch * config.dyn_discrete
        else:
            stoch_size = config.dyn_stoch
        feat_size = stoch_size + config.dyn_deter  # z_stoch + h_deter
        goal_size = config.dyn_deter  # h_goal

        # Manager: takes the world state {h, z} and outputs a skill (z_goal).
        self.manager = ImagActorCritic(
            config,
            world_model,
            input_shape=feat_size,
            num_actions=config.skill_shape,
            stop_grad_actor=stop_grad_actor,
            prefix="manager",
        )

        # Worker: takes the world state and the goal {h, z, h_goal} and outputs
        # an action. Its reward does not consider z because it cannot control
        # z: https://github.com/danijar/director/issues/7
        self.worker = ImagActorCritic(
            config,
            world_model,
            input_shape=feat_size + goal_size,
            num_actions=config.num_actions,
            stop_grad_actor=stop_grad_actor,
            prefix="worker",
        )

        # Prior over skills built from all-zero logits.
        prior_logits = torch.zeros(*config.skill_shape, device=config.device)
        self.skill_prior = torch.distributions.independent.Independent(
            tools.OneHotDist(logits=prior_logits, unimix_ratio=self._config.action_unimix_ratio), 1
        )
        print(f"skill_prior_entropy:{self.skill_prior.entropy():1.2f} {self.skill_prior.base_dist.entropy()}")

        # The interactive UI is only created when one of these flags is set.
        wants_ui = self._config.human_policy or self._config.debug_display_all
        if wants_ui:
            self.uiux = UIUX(config, world_model, self.manager, self.goal_dec, self.skill_prior, None)

    def goal_video_pred(self, data):
        """Build a comparison video: ground truth, model reconstruction, goal-autoencoder reconstruction.

        The first 6 sequences and first 64 steps of ``data`` are run through
        the world model. The posterior deterministic state is encoded to a
        skill, decoded back to a deterministic state, and rendered with the
        world-model image decoder.

        Args:
            data: Raw batch accepted by ``self._world_model.preprocess``; the
                keys ``"action"``, ``"is_first"`` and ``"image"`` are read.

        Returns:
            The three image stacks (each shifted by +0.5) concatenated along
            dimension 2.
        """
        num_seqs, num_frames = 6, 64
        wm = self._world_model
        render = wm.heads["decoder"]

        def crop(x):
            return x[:num_seqs, :num_frames]

        data = wm.preprocess(data)
        embed = wm.encoder(data)
        post, _ = wm.dynamics.observe(crop(embed), crop(data["action"]), crop(data["is_first"]))
        model_frames = render(wm.dynamics.get_feat(post))["image"].mode()[:num_seqs]

        # deter -> sampled skill -> flattened skill -> reconstructed deter
        skill = self.goal_enc(post["deter"]).sample()
        skill = skill.reshape([*skill.shape[:-2], -1])  # merge the two skill dims
        goal_deter = self.goal_dec(skill).mode()

        # The image decoder needs a stochastic part as well; derive it from the decoded deter.
        goal_stoch = wm.dynamics.get_stoch(goal_deter)
        goal_stoch = goal_stoch.reshape([*goal_stoch.shape[:-2], -1])
        goal_frames = render(torch.cat([goal_stoch, goal_deter], dim=-1))["image"].mode()

        truth = crop(data["image"]) + 0.5
        model_frames = model_frames + 0.5
        goal_frames += 0.5

        return torch.cat([truth, model_frames, goal_frames], 2)
    
    def sample_many_goals(self, data, step, logdir):
        """Sample a grid of manager goals for one sequence and save one PNG per time step.

        ``n_width ** 2`` skills are sampled from the manager actor for every
        state, decoded to goals, rendered with the world-model image decoder
        and tiled into an ``n_width x n_width`` mosaic per time step.

        Args:
            data: Raw batch accepted by ``self._world_model.preprocess``; the
                keys ``"action"`` and ``"is_first"`` are read.
            step: Value embedded in the output file names.
            logdir: Directory object supporting ``/`` (e.g. ``pathlib.Path``).

        Side effects:
            Writes ``goal_manager_samples_{step}_{i}.png`` files to ``logdir``.
            The manager is put in eval mode for the sampling and its previous
            training mode is restored afterwards.
        """
        n_frames = 8
        n_width = 10
        wm = self._world_model

        # Remember the mode so this visualisation does not leave the manager
        # stuck in eval mode for the rest of training.
        was_training = self.manager.training
        self.manager.eval()
        try:
            # Nothing here is used for a backward pass.
            with torch.no_grad():
                data = wm.preprocess(data)
                embed = wm.encoder(data)
                post, _ = wm.dynamics.observe(
                    embed[:6, :n_frames], data["action"][:6, :n_frames], data["is_first"][:6, :n_frames]
                )
                feat = wm.dynamics.get_feat(post)
                skill_samples = self.manager.actor(feat).sample((n_width**2,))
                # merge the two skill dims -> (samples, batch, time, feat*feat)
                skill_samples = skill_samples.reshape([*skill_samples.shape[:-2], -1])
                goals = self.goal_dec(skill_samples).mode()
                stoch_goals = wm.dynamics.get_stoch(goals)
                stoch_goals = stoch_goals.reshape([*stoch_goals.shape[:-2], -1])
                inp = torch.cat([stoch_goals, goals], dim=-1)
                # Keep only the first sequence of the batch. Plain indexing is
                # used because squeeze(0) would drop the batch axis when it
                # has size 1 and the following [0] would then pick a sample.
                inp = inp[:, 0]
                goal_imgs = wm.heads["decoder"](inp)["image"].mode()
        finally:
            self.manager.train(was_training)

        goal_imgs = to_np(goal_imgs)
        # (n_width**2, time, ...image dims) -> (n_width, n_width, time, ...image dims)
        goal_imgs = goal_imgs.reshape(n_width, n_width, *goal_imgs.shape[1:])
        # Tile the grid: first along image axis 4, then along image axis 3.
        goal_imgs = np.concatenate(np.split(goal_imgs, n_width, axis=0), axis=4)
        goal_imgs = np.concatenate(np.split(goal_imgs, n_width, axis=1), axis=3)
        mosaics = goal_imgs[0, 0]
        mosaics = ((mosaics + 0.5) * 255).clip(0, 255).astype(np.uint8)
        # One PNG per time step actually present in the data.
        for i, mosaic in enumerate(mosaics):
            Image.fromarray(mosaic).save(logdir / f"goal_manager_samples_{step}_{i}.png")

    def goal_pred(self, data):
        """Render goal-autoencoder reconstructions for a random subset of the batch.

        Four batch rows (drawn with replacement) are passed through the world
        model; the deterministic state is encoded/decoded by the goal
        autoencoder and rendered with the image decoder. A single randomly
        chosen index along dimension 1 is kept for the output.

        Args:
            data: Raw batch accepted by ``self._world_model.preprocess``; the
                keys ``"image"``, ``"action"`` and ``"is_first"`` are read.

        Returns:
            Ground truth, reconstruction and rescaled error images
            concatenated along dimension 2.
        """
        wm = self._world_model
        data = wm.preprocess(data)

        def merge_last_two(x):
            return x.reshape([*x.shape[:-2], -1])

        # Random subset of the batch and one random index along dimension 1.
        batch_idx = random.choices(range(data["image"].shape[0]), k=4)
        step_idx = random.choices(range(data["image"].shape[1]), k=1)

        actions = data["action"][batch_idx, :]
        is_first = data["is_first"][batch_idx, :]
        embed = wm.encoder(data)[batch_idx, :]
        states, _ = wm.dynamics.observe(embed, actions, is_first)

        # Round-trip the deterministic state through the goal autoencoder.
        skill = merge_last_two(self.goal_enc(states["deter"]).sample())
        deter = self.goal_dec(skill).mode()

        # The world-model decoder needs the stochastic state too, so derive it
        # from the reconstructed deterministic state.
        stoch = merge_last_two(wm.dynamics.get_stoch(deter))
        decoder_in = torch.cat([stoch, deter], dim=-1)

        model = wm.heads["decoder"](decoder_in)["image"].mode() + 0.5
        truth = data["image"][batch_idx, :] + 0.5
        error = (model - truth + 1.0) / 2.0

        panels = [x[:, step_idx] for x in (truth, model, error)]
        return torch.cat(panels, dim=2)
    

    def viz_goal(self, goal, stoch=None):
        if stoch == None:
            stoch = self._world_model.dynamics.get_stoch(goal)        
        stoch = stoch.reshape([*stoch.shape[:-2], -1])
        inp = torch.cat([stoch, goal], dim=-1)
        goal_img = self._world_model.heads["decoder"](inp)["image"].mode() + 0.5
        return goal_img

    def extr_reward(self, state):
        # extrR = self._world_model.heads["reward"](self._world_model.dynamics.get_feat(state)).mode() # NOTE: Original uses mean()[1:] 
        extrR = self._world_model.heads["reward"](self._world_model.dynamics.get_feat(state)).mean()
        return extrR

    def expl_reward(self, state):
        # elbo (Evidence Lower BOund) reward
        feat = state["deter"]
        enc = self.goal_enc(feat) 
        x = enc.sample() # produces a skill_dim x skill_dim
        x = x.reshape([x.shape[0], x.shape[1], -1]) # (time, batch, feat0*feat1)
        dec = self.goal_dec(x) # decode back the feature dimensions

        return ((dec.mode() - feat) ** 2).mean(-1) # NOTE: This can probably be a log likelihood under the MSE distribution

    def goal_reward(self, state, goal):
        # cosine_max reward
        h = state["deter"]
        goal = goal.detach()
        gnorm = torch.linalg.norm(goal, dim=-1, keepdim=True) + 1e-12
        hnorm = torch.linalg.norm(h, dim=-1, keepdim=True) + 1e-12
        norm = torch.max(gnorm, hnorm)
        return torch.einsum("...i,...i->...", goal / norm, h / norm) # NOTE

    def split_traj(self, traj):
        traj = traj.copy()
        k = self._config.train_skill_duration
        # print(f"Trajectory length must be divisible by k+1 {len(traj['action'])} % {k} != 1")
        assert len(traj['action']) % k == 1; (len(traj['action']) % k), "Trajectory length must be divisible by k+1"
        reshape = lambda x: x.reshape([x.shape[0] // k, k] + list(x.shape[1:]))
        for key, val in list(traj.items()):
            # print(f"director_models::split_traj::key: {key}, val.shape: {val.shape if hasattr(val, 'shape') else None}")
            val = torch.concat([0 * val[:1], val], 0) if 'reward' in key else val
            # (1 2 3 4 5 6 7 8 9 10) -> ((1 2 3 4) (4 5 6 7) (7 8 9 10))
            val = torch.concat([reshape(val[:-1]), val[k::k][:, None]], 1)
            # N val K val B val F... -> K val (N B) val F...
            val = val.permute([1, 0] + list(range(2, len(val.shape))))
            val = val.reshape(
                [val.shape[0], np.prod(val.shape[1:3])] + list(val.shape[3:]))
            val = val[1:] if 'reward' in key else val
            traj[key] = val
        # Bootstrap sub trajectory against current not next goal.
        traj['goal'] = torch.concat([traj['goal'][:-1], traj['goal'][:1]], 0)
        traj['weight'] = torch.cumprod(
            self._config.imag_discount * traj['cont'], axis=0) / self._config.imag_discount
        return traj

    def abstract_traj(self, orig_traj):
        traj = orig_traj.copy()
        traj["action"] = traj.pop("skill")
        k = self._config.train_skill_duration
        reshape = lambda x: x.reshape([x.shape[0] // k, k] + list(x.shape[1:]))
        weights = torch.cumprod(reshape(traj["cont"][:-1]), 1).to(self._config.device)
        for key, value in list(traj.items()):
            if 'reward' in key:
                # if (traj[key] > 1.0).any():
                #     a = 1
                # NOTE: Paper says to use the sum, DHafner code uses the mean
                # traj[key] = (reshape(value) * weights).mean(1) # weighted mean of each subtrajectory
                traj[key] = (reshape(value) * weights).sum(1) # weighted sum of each subtrajectory
            elif key == 'cont':
                traj[key] = torch.concat([value[:1], reshape(value[1:]).prod(1)], 0)
            else:
                traj[key] = torch.concat([reshape(value[:-1])[:, 0], value[-1:]], 0)
        traj['weight'] = torch.cumprod(
            self._config.imag_discount * traj['cont'], 0) / self._config.imag_discount
    

        # NOTE: This block, and other blocks around like it, are for investigating the possibility of an off-by-one error in the 
        #  the manager/worker trajectory split
        # if (traj["reward_extr"] > 1.0).any():
        #     r = traj["reward_extr"].detach().cpu().numpy()
        #     reward_index = np.unravel_index(np.argmax(r), r.shape)
        #     imagined_ts_index, time_index, _ = reward_index
        #     history = max(time_index - 3, 0)
        #     mr = self._world_model.heads["reward"](traj["feat"].detach()).mean()
            # if mr[reward_index] != r[reward_index]:
            #     print(f"director_models::abstract_traj::reward_extr: {r[reward_index]}, reward_model: {mr[reward_index]}")
            #     ipshell()

        return traj

    def train_goal_vae(self, data, metrics):
        """Run one optimisation step of the goal autoencoder and record its statistics.

        The deterministic state ``data["deter"]`` is encoded to a sampled
        skill, decoded back, and trained on the negative log-probability of
        the input under the decoder output. When ``config.goal_vae_kl`` is set
        a KL term against ``self.skill_prior`` is added, scaled either by
        ``self.kl_autoadapt`` or by ``config.goal_vae_kl_beta``.

        Args:
            data: Mapping with a ``"deter"`` tensor; it is detached before use.
            metrics: Dict that is updated in place with the new entries.

        Returns:
            The same ``metrics`` dict.
        """
        feat = data["deter"].detach()
        with tools.RequiresGrad(self.goal_enc):
            with tools.RequiresGrad(self.goal_dec):
                enc = self.goal_enc(feat)
                x = enc.sample()  # discrete one-hot grid
                x = x.reshape([x.shape[0], x.shape[1], -1])  # merge the skill dims
                dec = self.goal_dec(x)
                rec = -dec.log_prob(feat.detach())  # DHafner's original
                if self._config.goal_vae_kl:
                    # NOTE: Should kl and reconstruction loss be scaled by
                    # scheduled constants as in the world model?
                    kl_nonorm = torch.distributions.kl.kl_divergence(enc, self.skill_prior)
                    if self._config.goal_vae_kl_autoadapt:
                        kl, mets = self.kl_autoadapt(kl_nonorm)
                        metrics.update({f'goalkl_{k}': to_np(v) for k, v in mets.items()})
                    else:
                        # Simple fixed-beta scaling. This must be out of place:
                        # an in-place `*=` would also rescale `kl_nonorm`,
                        # which is logged below as the unscaled KL.
                        kl = kl_nonorm * self._config.goal_vae_kl_beta
                    metrics.update({'goalkl_nonorm': to_np(kl_nonorm.mean())})
                    assert kl.shape == rec.shape, (kl.shape, rec.shape)
                    loss = torch.mean(rec + kl)
                else:
                    loss = torch.mean(rec)

                if self._config.make_dots:
                    # Named parameters of both networks, prefixed to avoid
                    # name clashes between encoder and decoder.
                    named_parameters = {
                        **{'goal_enc.' + k: v for k, v in self.goal_enc.named_parameters()},
                        **{'goal_dec.' + k: v for k, v in self.goal_dec.named_parameters()},
                    }
                else:
                    named_parameters = None

                goal_ae_params = list(self.goal_enc.parameters()) + list(self.goal_dec.parameters())
                metrics.update(self.goal_ae_opt(loss, goal_ae_params, named_parameters=named_parameters))

        encoder_ent = enc.entropy().mean()
        metrics.update(tools.tensorstats(encoder_ent.detach().cpu(), "goal_ae_ent"))
        metrics.update(tools.tensorstats(rec.detach().cpu(), "goal_ae_rec"))
        metrics.update(tools.tensorstats(loss.detach().cpu(), "goal_ae_loss"))
        if self._config.goal_vae_kl:
            metrics.update(tools.tensorstats(kl.detach().cpu(), "goal_ae_kl"))

        return metrics
    
    def _train(
        self,
        start,
        context,
        action=None,
        extr_reward=None,
        imagine=None,
        tape=None,
        repeats=None,
    ):
        """Run one training step of the goal autoencoder, the manager and the worker.

        Steps, in order:
          1. Train the goal autoencoder on ``context`` (``train_goal_vae``).
          2. Imagine a rollout of ``config.imag_horizon + 1`` steps from
             ``start`` with the worker actor (``_imagine_carry``).
          3. Compute extrinsic, exploration and goal rewards, each with the
             first step dropped.
          4. Build the manager trajectory (``abstract_traj``) and the worker
             trajectory (``split_traj``) and train both actor-critics.

        Args:
            start: World-model state dict handed to ``_imagine_carry``.
            context: Mapping passed to ``train_goal_vae``; ``'is_terminal'``
                is also read from it.
            action, extr_reward, imagine, tape: Not used in this method body.
            repeats: Forwarded to ``_imagine_carry``.

        Returns:
            ``[metrics]`` when ``config.debug_only_goal_ae`` is set, otherwise
            the tuple ``(None, metrics)``. In both cases ``result[-1]`` is the
            metrics dict.
        """
        metrics = {}

        # Train the goal autoencoder on world model representations
        # NOTE: These are posterior representations of the world model (s_t|s_t-1, a_t-1, x_t)
                
        metrics = self.train_goal_vae(context, metrics)

        # Debug switch: stop after the autoencoder update.
        if self._config.debug_only_goal_ae:
            return [metrics]

        with tools.RequiresGrad(self.manager.actor):
            with tools.RequiresGrad(self.worker.actor):
            # with torch.no_grad():
                # Given the output world model starting state, do imagined rollouts at each step with the worker's action policy 
                imag_feat, imag_state, imag_action, imag_skills, imag_goals = self._imagine_carry(
                    start, self.worker.actor, self._config.imag_horizon+1, repeats
                ) # plus one to horizon to match dhafner implementation for worker/manager trajectory splits.

                # Compute the rewards. All rewards start from ts=1 
                reward_extr = self.extr_reward(imag_state)[1:]
                reward_expl = self.expl_reward(imag_state)[1:]
                reward_goal = self.goal_reward(imag_state, imag_goals)[1:]

                # expl/goal rewards get a trailing singleton dim added here
                # (presumably to match reward_extr's layout; TODO confirm).
                reward_expl = reward_expl.unsqueeze(-1)
                reward_goal = reward_goal.unsqueeze(-1)


                # The rollout needs to be split two ways. 
                # The manager takes every kth step and sums the weighted rewards for 0 to k-1
                # The worker is trained on k step snippets that have the same goal
                '''
                debug tensor for checking trajetory splits
                '''
                # Each element of `debug` stores its own index along dimension 0
                # (component 0 of the index meshgrid), so the time index of any
                # entry can be read back after the splits.
                shape = list(imag_feat.shape)
                dims = [np.arange(s) for s in shape]
                meshgrid = np.meshgrid(*dims, indexing='ij')
                debug = torch.Tensor(np.stack(meshgrid, axis=-1)).to(self._config.device)
                debug = debug[:, :, :, 0]
            
                traj = {
                    "stoch": imag_state["stoch"],
                    "deter": imag_state["deter"],
                    "logit": imag_state["logit"],
                    "feat": imag_feat,
                    "action": imag_action,
                    "reward_extr": reward_extr,
                    "reward_expl": reward_expl,
                    "reward_goal": reward_goal,
                    "goal": imag_goals,
                    "skill": imag_skills,
                    "debug": debug
                }

                # Continuation flags: the first step comes from the real data
                # (1 - is_terminal); the rest from the world model's "cont" head.
                cont = self._world_model.heads["cont"](imag_feat).mean
                first_cont = torch.Tensor((1 - context['is_terminal']).reshape(1, -1, 1)).to(self._config.device)
                traj["cont"] = torch.cat([first_cont, cont[1:]], dim=0)

                if (self._config.debug):
                    print(f"Director")
                    for k,v in traj.items():
                        print(f"\t{k}: {v.shape if hasattr(v, 'shape') else 'None'}")


                # manager trajectory must be split to only include the goal selection steps and sum reward between them
                mtraj = self.abstract_traj(traj)
                # mtraj["reward"] = self._config.extr_reward_weight * mtraj["reward_extr"]
                mtraj["reward"] = self._config.expl_reward_weight * mtraj["reward_expl"] + self._config.extr_reward_weight * mtraj["reward_extr"]
                mtraj["state"] = {"stoch": mtraj["stoch"], "deter": mtraj["deter"], "logit": mtraj["logit"]}

                # worker trajectory must be split into goal horizon length chunks
                wtraj = self.split_traj(traj)
                if self._config.debug_worker_extrinsic_reward:
                    wtraj["reward"] = wtraj["reward_extr"] + wtraj["reward_goal"]
                else:
                    wtraj["reward"] = wtraj["reward_goal"]

                wtraj["state"] = {"stoch": wtraj["stoch"], "deter": wtraj["deter"], "logit": wtraj["logit"]}

                # if True:
                # Debug output: shapes of every entry before and after the two splits.
                if (self._config.debug):
                    print(f"original action dim: {imag_action.shape} --> {wtraj['action'].shape} worker")
                    print(f"                     {imag_action.shape} --> {mtraj['action'].shape} manager")

                    print(f"Manager Traj")
                    for key, value in list(mtraj.items()):
                        if key == "action":
                            # the manager's "action" is the original "skill" entry
                            old = traj["skill"]
                        elif key not in traj:
                            print(f"\t{key} shape:  None --> {value.shape if hasattr(value, 'shape') else 'None'}")
                            continue
                        else:
                            old = traj[key]
                        print(f"\t{key} shape:  {old.shape} --> {value.shape}")
                        
                    print(f"Worker Traj")
                    for key, value in list(wtraj.items()):
                        if key not in traj:
                            print(f"\t{key} shape:  None --> {value.shape if hasattr(value, 'shape') else 'None'}")
                            continue
                        else:
                            old = traj[key]
                        print(f"\t{key} shape:  {old.shape} --> {value.shape}")

        # JS GARBAGE
        # traj['reward'] = traj['reward_extr'] + 0.1 * traj['reward_expl']
        # traj['action'] = traj['skill']
        # traj['state'] = {'stoch': traj['stoch'], 'deter': traj['deter'], 'logit': traj['logit']}
        # metrics.update(self.manager._train(imag_traj=traj)[-1])

        # Each actor-critic returns a sequence whose last element is its metrics dict.
        metrics.update(self.manager._train(imag_traj=mtraj)[-1])
        metrics.update(self.worker._train(imag_traj=wtraj)[-1])

        # Statistics of the raw rewards, before any weighting or splitting.
        metrics.update(tools.tensorstats(reward_extr.detach().cpu(), "reward_extr_unweighted"))
        metrics.update(tools.tensorstats(reward_expl.detach().cpu(), "reward_expl_unweighted"))
        metrics.update(tools.tensorstats(reward_goal.detach().cpu(), "reward_goal_unweighted"))

        return None, metrics
    

    def _imagine_carry(self, start_wm, policy, horizon, repeats=None):
        dynamics = self._world_model.dynamics
        if repeats:
            raise NotImplemented("repeats is not implemented in this version")
        flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
        start = {k: flatten(v) for k, v in start_wm.items()}
        if self._stop_grad_actor:
            start = {k: v.detach() for k, v in start.items()}

        def step(prev, step):
            state, _, _, skill, goal = prev
            
            feat = dynamics.get_feat(state) # z_state + h

            # NOTE: step is ignored in original code. does it mess with the flatten?
            if goal is None or step % self._config.train_skill_duration == 0:
                skill, goal, *_ = self.sample_goal(state)

            inp = torch.cat([feat.detach(), goal], -1)
            action = policy(inp).sample()

            # print(f"director_models::_imagine::step"); ipshell()
            succ = dynamics.img_step(state, action, sample=self._config.imag_sample)
            return succ, feat, action, skill, goal

        succ, sc_feats, sc_actions, sc_skills, sc_goals = tools.static_scan(
            step, [torch.arange(horizon)], (start, None, None, None, None)
        )
        sc_states = {k: torch.cat([start[k].unsqueeze(0), v[:-1]], 0) for k, v in succ.items()}
        if repeats:
            raise NotImplemented("repeats is not implemented in this version")

        feats = sc_feats
        actions = sc_actions
        goals = sc_goals
        skills = sc_skills
        states = sc_states
        return feats, states, actions, skills, goals
    
    def sample_goal(self, latent, human=False):
        skill_probs = np.zeros_like(self._config.skill_shape)
        skill_dist = None
        uiux_code = None
        if human or self._config.debug_uiux:
            # sample, cluster, accept ui
            skill, uiux_code = self.uiux.interface(latent, self.worker.actor) # NOTE: Blocking call
            if uiux_code == "defer" or uiux_code == "exit": # NOTE: -1 indicates the interface has elected to defer to the system i.e. "Trust Robot"
                feat = self._world_model.dynamics.get_feat(latent).detach()
                skill = self.manager.actor(feat).sample()
            skill = skill.detach()
        else:
            feat = self._world_model.dynamics.get_feat(latent).detach()
            skill_dist = self.manager.actor(feat)
            skill = skill_dist.sample().detach()
            skill_probs = skill_dist.base_dist.probs.detach().cpu().numpy()
            # format skill probs so they only show 2 decimal places with a leading space
            skill_probs = np.round(skill_probs, 2)

        skill_entropy = to_np(skill_dist.entropy().mean()) if skill_dist is not None else -1337

        goal = self.goal_dec(skill.reshape(-1, np.product(self._config.skill_shape))).mode()
        goal = goal.detach()
        return skill, goal, skill_probs, skill_entropy, uiux_code

class ImagActorCritic(nn.Module):
    def __init__(self, config, world_model, input_shape, num_actions, stop_grad_actor=True, prefix=''):
        """Create the actor, the value network(s) and their optimisers.

        Args:
            config: Experiment configuration.
            world_model: World model kept on the instance.
            input_shape: Input size for the actor and value networks.
            num_actions: Action size (an ``int``) or action shape (a sequence).
            stop_grad_actor: Stored on the instance.
            prefix: ``"manager"``, ``"worker"`` or ``"jsdebug"``; selects the
                actor hyper-parameters and names the optimisers. Only the
                manager gets a second (exploration) value network.

        Raises:
            NotImplementedError: For any other ``prefix``.
        """
        super().__init__()
        self._use_amp = bool(config.precision == 16)
        self._config = config
        self._world_model = world_model
        self._stop_grad_actor = stop_grad_actor
        self._prefix = prefix

        # Role-specific actor settings.
        if prefix == "manager":
            actor_cfg = config.manager_actor
            actor_lr = config.manager_lr
            ac_opt_eps = config.manager_opt_eps
        elif prefix == "worker" or prefix == "jsdebug":
            actor_cfg = config.worker_actor
            actor_lr = config.worker_lr
            ac_opt_eps = config.ac_opt_eps
        else:
            raise NotImplementedError(prefix)
        value_head = config.value_head

        self.actor = networks.ActionHead(
            input_shape,
            num_actions,
            config.actor_layers,
            config.units,
            actor_cfg['act'],
            actor_cfg['norm'],
            actor_cfg['dist'],
            config.actor_init_std,
            config.actor_min_std,
            config.actor_max_std,
            config.actor_temp,
            outscale=actor_cfg['outscale'],
            unimix_ratio=config.action_unimix_ratio,
        )

        # The two value-head variants differ only in output shape and in the
        # output scale of the main value network.
        discrete_value = value_head == "symlog_disc"
        if not discrete_value:
            print(f"director_models::ImagActorCritic::value_head: {value_head}")
        value_shape = (255,) if discrete_value else []

        def make_value_net(outscale):
            return networks.MLP(
                input_shape,
                value_shape,
                config.value_layers,
                config.units,
                config.act,
                config.norm,
                value_head,
                outscale=outscale,
                device=config.device,
            )

        self.value = make_value_net(0.1 if discrete_value else 0.0)
        # Built only for the manager; evaluated lazily so no network is
        # created (and no RNG consumed) for the other roles.
        self.expl_value = make_value_net(0.0) if self._prefix == "manager" else None

        if config.slow_value_target:
            self._slow_value = copy.deepcopy(self.value)
            self._slow_value_expl = copy.deepcopy(self.expl_value) if self.expl_value is not None else None
            self._updates = 0

        kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
        self._actor_opt = tools.Optimizer(
            f"{self._prefix}_actor",
            self.actor.parameters(),
            actor_lr,
            ac_opt_eps,
            config.actor_grad_clip,
            **kw,
        )
        self._value_opt = tools.Optimizer(
            f"{self._prefix}_value",
            self.value.parameters(),
            config.value_lr,
            ac_opt_eps,
            config.value_grad_clip,
            **kw,
        )
        self._expl_value_opt = tools.Optimizer(
            f"{self._prefix}_expl_value",
            self.expl_value.parameters(),
            config.value_lr,
            ac_opt_eps,
            config.value_grad_clip,
            **kw,
        ) if self.expl_value is not None else None

        if self._config.reward_EMA:
            self.reward_ema = dv3_models.RewardEMA(device=self._config.device)

        if self._config.actent_perdim:
            # Scalar for an int action count, otherwise the action shape
            # without its last dimension.
            shape = () if isinstance(num_actions, int) else num_actions[:-1]
            self.actent = tu.AutoAdapt(shape, **self._config.actent, inverse=True).to(self._config.device)

    def _train(
        self,
        start=None,
        objective=None,
        repeats=None,
        imag_traj=None,
    ):
        # if start == None and imag_traj == None:
        if imag_traj == None:
            raise ValueError("Must provide either start or imag_traj")
        
        self._update_slow_target()
        metrics = {}

        # if imag_traj is None:
        #     imag_feat, imag_state, imag_action = self._imagine(
        #         start, self.actor, self._config.imag_horizon, repeats
        #     )
        #     rewards = objective(imag_feat, imag_state, imag_action)
        #     # reduce the rewards to a single entry for each timestep
        #     imag_reward = sum(rewards)
        # else:
        imag_feat = imag_traj["feat"]
        imag_state = imag_traj["state"]
        imag_action = imag_traj["action"]
        imag_reward = imag_traj["reward"]
        imag_cont = imag_traj["cont"]
        imag_goal = imag_traj["goal"] if "goal" in imag_traj else None # NOTE: goals become actions for the manager
        
        
        with tools.RequiresGrad(self.actor):
            with torch.cuda.amp.autocast(self._use_amp):
                state_ent = self._world_model.dynamics.get_dist(imag_state).entropy()
                # this target is not scaled
                # slow is flag to indicate whether slow_target is used for lambda-return
                target, weights, base = self._compute_target(imag_cont,imag_feat, imag_state, imag_action, imag_goal, imag_reward, self.calc_value)
                # if self._config.make_dots: make_dot(base, params=dict(self.actor.named_parameters())).render(f"{self._prefix}_base", format="png")

                # NOTE: there's a lot of target regularization in DHafner's version
                actor_loss, mets = self._compute_actor_loss(
                    imag_feat,
                    imag_state,
                    imag_action,
                    imag_goal,
                    target,
                    state_ent,
                    weights,
                    base,
                )
                
                metrics.update(mets)
                if self._prefix == "manager" or self._prefix == "jsdebug":
                    value_input = imag_feat 
                    target_reward = imag_traj["reward_extr"]
                elif self._prefix == "worker":
                    if self._config.debug_worker_extrinsic_reward:
                        target_reward = imag_traj["reward_extr"]
                    else:
                        target_reward = imag_traj["reward_goal"]

                    value_input = torch.cat([imag_feat], dim=-1)
                    # value_input = torch.cat([imag_feat, imag_goal], dim=-1)
                else:
                    raise NotImplementedError

        if self._config.make_dots: 
            actor_named_params = {k: v for k, v in self.actor.named_parameters()} if self._config.make_dots else None
            make_dot(actor_loss.mean(), params=actor_named_params).render(f"{self._prefix}_actor_loss", format="png")

        detached_imag_state = {k: v.detach() for k, v in imag_state.items()}
        ### Train critics
        '''
        # NOTE: Targets are calculated with the slower target networks, then loss is calculated with the active value networks agaunst the slower targets. The alternate way in models.py (dreamerv3?) is to use the active netowrk for the targets, and then add the negative log-likelihood of the slow network targets. i.e. -value.log_prob(value_target) - value.log_prob(slow_value_target)
        '''
        with tools.RequiresGrad(self.value):
            with torch.cuda.amp.autocast(self._use_amp):
                # value_fn = lambda x: self.value(x.detach()).mode()
                slow_value_fn = lambda x: self._slow_value(x.detach()).mode()
                extr_target, extr_weights, _ = self._compute_target(imag_cont,
                    imag_feat.detach(), detached_imag_state, imag_action.detach(), imag_goal.detach(), target_reward.detach(), slow_value_fn
                )
                extr_dist = self.value(value_input[:-1].detach())
                # (time, batch, 1), (time, batch, 1) -> (time, batch)
                extr_target = torch.stack(extr_target, dim=1)
                extr_target = extr_target.detach()
                value_loss = extr_dist.log_prob(extr_target)

                # (time, batch, 1), (time, batch, 1) -> (1,)
                value_loss = -torch.mean(extr_weights[:-1] * value_loss[:, :, None])
                # value_loss = torch.mean(extr_weights[:-1] * value_loss[:, :, None])
                if self._config.make_dots: 
                    make_dot(value_loss, params=dict(self.value.named_parameters())).render(f"{self._prefix}_value_loss", format="png")


        if self.expl_value is not None:
            with tools.RequiresGrad(self.expl_value):
                with torch.cuda.amp.autocast(self._use_amp):
                    # value_fn = lambda x: self.expl_value(x.detach()).mode()
                    slow_value_fn = lambda x: self._slow_value_expl(x.detach()).mean()
                    expl_target, expl_weights, _ = self._compute_target(imag_cont,
                        imag_feat.detach(), detached_imag_state, imag_action.detach(), imag_goal.detach(), imag_traj["reward_expl"].detach(), slow_value_fn
                    )
                    expl_target = torch.stack(expl_target, dim=1)
                    expl_target = expl_target.detach()
                    expl_dist = self.expl_value(value_input[:-1].detach())
                    # (time, batch, 1), (time, batch, 1) -> (time, batch)
                    expl_value_loss = expl_dist.log_prob(expl_target)

                    # (time, batch, 1), (time, batch, 1) -> (1,)
                    expl_value_loss = -torch.mean(expl_weights[:-1] * expl_value_loss[:, :, None])
                    if self._config.make_dots: 
                        make_dot(expl_value_loss, params=dict(self.expl_value.named_parameters())).render(f"{self._prefix}_expl_value_loss", format="png")
                    # expl_value_loss = torch.mean(expl_weights[:-1] * expl_value_loss[:, :, None])


        metrics["state_entropy"] = to_np(torch.mean(state_ent))
        target = torch.stack(target, dim=1)
        with tools.RequiresGrad(self):
            actor_named_params = {k: v for k, v in self.actor.named_parameters()} if self._config.make_dots else None
            value_named_params = {k: v for k, v in self.value.named_parameters()} if self._config.make_dots else None


            # if self._prefix == "manager" or self._prefix == "jsdebug":
            #     inp = imag_feat
            #     inp = inp.detach() if self._stop_grad_actor else inp
            #     policy = self.actor(inp)
            #     pre_act_log_probs = policy.log_prob(imag_action).detach().mean().cpu().numpy()
            #     pre_train_entropy = policy.entropy().detach().mean().cpu().numpy()
            #     pre_train_entropy_str = [f'{ent:1.2f}' for ent in  policy.base_dist.entropy().detach().mean(axis=(0, 1)).cpu().numpy()]
            #     # print(policy.base_dist.probs.detach().mean(axis=(0, 1)))
            
            # if self._prefix == "manager":
            # if False: # ise the simple optimizer for debugging purposes
                # self._simple_actor_opt.zero_grad()
                # actor_loss.backward()
                # self._simple_actor_opt.step()
            # else:
            metrics.update(self._actor_opt(actor_loss, self.actor.parameters(), named_parameters=actor_named_params))
            # if self._prefix == "manager" or self._prefix == "jsdebug":
            #     inp = imag_feat
            #     inp = inp.detach() if self._stop_grad_actor else inp
            #     policy = self.actor(inp)
            #     post_act_log_probs = policy.log_prob(imag_action).detach().mean().cpu().numpy()
            #     post_train_entropy = policy.entropy().detach().mean().cpu().numpy()
            #     post_train_entropy_str = [f'{ent:1.2f}' for ent in  policy.base_dist.entropy().detach().mean(axis=(0, 1)).cpu().numpy()]
            #     print(f"\tactor_loss: {actor_loss:1.2f} reward: {imag_reward.detach().cpu().numpy().mean():1.2f} value_loss: {value_loss:1.2f} target: {to_np(target).mean():1.2f}")
            #     print(f"\t{pre_train_entropy:1.2f} -> {post_train_entropy:1.2f} || {[f'{pre} -> {post}' for pre, post in zip(pre_train_entropy_str, post_train_entropy_str)]}")
            #     print(f"\t{self.actent.scale():} actent scale")
            #     print(f"\t{pre_act_log_probs:1.2f} -> {post_act_log_probs:1.2f} logprobs")
            
            metrics.update(self._value_opt(value_loss, self.value.parameters(), named_parameters=value_named_params))
            if self.expl_value is not None:
                expl_value_named_params = {k: v for k, v in self.expl_value.named_parameters()} if self._config.make_dots else None
                metrics.update(self._expl_value_opt(expl_value_loss, self.expl_value.parameters(), named_parameters=expl_value_named_params))

        metrics.update(tools.tensorstats(extr_dist.mode(), "extr"))
        metrics.update(tools.tensorstats(value_loss, "extr_loss"))
        metrics.update(tools.tensorstats(target, "target"))
        metrics.update(tools.tensorstats(extr_target, "extr_target"))
        if self.expl_value is not None:
            metrics.update(tools.tensorstats(expl_dist.mode(), "expl"))
            metrics.update(tools.tensorstats(expl_value_loss, "expl_loss"))
            metrics.update(tools.tensorstats(expl_target, "expl_target"))
        metrics.update(tools.tensorstats(imag_reward, "imag_reward"))
        
        if self._config.actor_dist in ["onehot"]:
            metrics.update(
                tools.tensorstats(
                    torch.argmax(imag_action, dim=-1).float(), "imag_action"
                )
            )
        else:
            metrics.update(tools.tensorstats(imag_action, "imag_action"))
        metrics.update(tools.tensorstats(actor_loss, "actor_loss"))

        # Name the metrics with the prefix
        prefixxed_metrics = {}
        for k,v in metrics.items():
            if self._prefix not in k:
                prefixxed_metrics[self._prefix+"_"+k] = v
            else:
                prefixxed_metrics[k] = v


        return imag_feat, imag_state, imag_action, weights, prefixxed_metrics

    def _imagine(self, start, policy, horizon, repeats=None):
        dynamics = self._world_model.dynamics
        if repeats:
            raise NotImplemented("repeats is not implemented in this version")
        flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
        start = {k: flatten(v) for k, v in start.items()}

        def step(prev, _):
            state, _, _ = prev
            feat = dynamics.get_feat(state)
            inp = feat.detach() if self._stop_grad_actor else feat
            action = policy(inp).sample()
            succ = dynamics.img_step(state, action, sample=self._config.imag_sample)
            return succ, feat, action

        succ, feats, actions = tools.static_scan(
            step, [torch.arange(horizon)], (start, None, None)
        )
        states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}
        if repeats:
            raise NotImplemented("repeats is not implemented in this version")

        return feats, states, actions

    def _compute_target(
        self, imag_cont, imag_feat, imag_state, imag_action, imag_goal, reward, value_network_fn
    ):
        discount = self._config.discount * imag_cont
        if self._config.future_entropy: raise NotImplementedError

        if self._prefix == "manager" or self._prefix == "jsdebug":
            inp = imag_feat
        elif self._prefix == "worker":
            inp = torch.cat([imag_feat], dim=-1)
            # inp = torch.cat([imag_feat, imag_goal], dim=-1)
        else:
            raise NotImplementedError(self._prefix)
        
        value = value_network_fn(inp) # pass in a value function so we can compute the target for separate value networks # NOTE: DHafner def does this better
        # value = self.value(inp).mode() 
        # if self.expl_value is not None:
        #     value = value + self.expl_value(inp).mode()
        # value(15, 960, ch)
        # action(15, 960, ch)
        # discount(15, 960, ch)
        # NOTE: value and discount start from one to match DHafner. Likely to align with norm to have a reward be the consequence of an early state-action rather than the current step's state-action.
        # NOTE: This is a departure from this repo's implementation of the lambda return which ignore the last index. Keep a lookout for other places that make ths same assumption
        
        if self._config.make_dots:
            make_dot(value, params=dict(self.value.named_parameters())).render(f"{self._prefix}_value_pre_lambda_return", format="png")

        target = tools.lambda_return( 
            reward, 
            value[:-1],
            discount[1:],
            bootstrap=value[-1],
            lambda_=self._config.discount_lambda,
            axis=0,
        )
        weights = torch.cumprod(
            torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0
        ).detach()

        return target, weights, value[:-1] # see note above about lining up value and reward

    def calc_value(self, inp):
        """Return the weighted sum of extrinsic and exploration value modes.

        When ``slow_value_target`` is set the slow critics are queried,
        otherwise the online critics. Without an exploration critic its
        contribution is 0.
        """
        use_slow = self._config.slow_value_target

        # extr_reward_weight == goal_reward_weight for the worker
        # TODO: separate variables for clarity
        extr_net = self._slow_value if use_slow else self.value
        extr_value = extr_net(inp).mode()

        if self.expl_value is None:
            expl_value = 0
        else:
            # The slow exploration critic is only looked up when one exists.
            expl_net = self._slow_value_expl if use_slow else self.expl_value
            expl_value = expl_net(inp).mode()

        weighted_extr = self._config.extr_reward_weight * extr_value
        weighted_expl = self._config.expl_reward_weight * expl_value
        return weighted_extr + weighted_expl

    def _compute_actor_loss(
        self,
        imag_feat,
        imag_state,
        imag_action,
        imag_goal,
        target,
        state_ent,
        weights,
        base,
    ):
        """Compute the actor loss and its metrics for an imagined rollout.

        Args:
            imag_feat: Features fed to the actor.
            imag_state: Not read here.
            imag_action: Imagined actions, used for the REINFORCE log-probs.
            imag_goal: Not read here (goal conditioning is disabled).
            target: Sequence of return tensors; stacked on dim 1 and detached.
            state_ent: Not read here.
            weights: Per-step weights; the last step is dropped.
            base: Baseline value used for the advantage.

        Returns:
            Tuple ``(actor_loss, metrics)`` where ``actor_loss`` is the mean of
            the weighted actor target plus the entropy term.

        Raises:
            NotImplementedError: For an unknown prefix or ``imag_gradient``.
        """
        metrics = {}
        if self._prefix == "manager" or self._prefix == "jsdebug":
            inp = imag_feat
        elif self._prefix == "worker":
            # Goal conditioning of the worker actor input is currently disabled.
            inp = torch.cat([imag_feat], dim=-1)
        else:
            raise NotImplementedError(self._prefix)
        inp = inp.detach() if self._stop_grad_actor else inp
        policy = self.actor.forward(inp)

        # Kept attached to the graph: the fixed-scale entropy regularizer below
        # reuses this tensor.
        actor_ent = policy.entropy()

        # Q-val for actor is not transformed using symlog
        target = torch.stack(target, dim=1).detach()
        adv = None
        if self._config.reward_EMA:
            offset, scale = self.reward_ema(target)
            normed_target = (target - offset) / scale
            normed_base = (base - offset) / scale

            adv = normed_target - normed_base
            metrics.update(tools.tensorstats(normed_target, "normed_target"))
            values = self.reward_ema.values
            metrics["EMA_005"] = to_np(values[0])
            metrics["EMA_095"] = to_np(values[1])

        if self._config.imag_gradient == "dynamics":  # NOTE: "backprop" in DHafner
            if adv is None:
                # Without reward normalization fall back to the raw advantage
                # instead of failing on an unassigned ``adv``.
                adv = target - base
            # NOTE(review): unlike the "reinforce" branch this term carries no
            # leading minus sign, so the loss decreases with the advantage --
            # confirm the intended sign.
            actor_target = adv
        elif self._config.imag_gradient == "reinforce":
            # NOTE: these leave out the last action instead of the first.
            # Conflicts with the changes to _compute_target?
            # NOTE(review): log_prob is detached here, so this term contributes
            # no gradient to the actor -- confirm this is intended.
            actor_target = (
                -policy.log_prob(imag_action)[:-1][:, :, None].detach()
                * (target - self.calc_value(inp[:-1])).detach()
            )
            if self._config.make_dots:
                make_dot(
                    actor_target.mean(), params=dict(self.actor.named_parameters())
                ).render("actor_target", format="png")
        else:
            raise NotImplementedError(self._config.imag_gradient)

        actor_entropy = 0
        if (self._config.auto_adapt_entropy_manager and self._prefix == "manager") or (
            self._config.auto_adapt_entropy_worker and self._prefix == "worker"
        ):
            # NOTE: The worker seems happy with the standard entropy regularizer
            # below, but the manager is not exploring enough, so here's the
            # DHafner entropy regularizer.
            if len(self.actent._shape) > 0:
                assert isinstance(policy, torch.distributions.Independent), type(policy)
                assert isinstance(policy.base_dist, tools.OneHotDist), type(policy.base_dist)
                ent = policy.base_dist.entropy()[:-1]
                # Rescale the entropy with lo/hi bounds derived from the
                # actor's output size.
                lo = 0.0
                hi = np.prod(self.actor._size[:-1]) * np.log(self.actor._size[-1])
                lo /= ent.shape[-1]
                hi /= ent.shape[-1]
                ent = (ent - lo) / (hi - lo)
                ent_loss, ent_mets = self.actent(ent)
                ent_loss = ent_loss.sum(axis=-1)  # NOTE: original DHafner impl
            else:
                lo = 0.0
                hi = np.log(self.actor._size)  # NOTE: Only works for onehot
                ent = policy.entropy()[:-1]
                ent = (ent - lo) / (hi - lo)
                ent_loss, ent_mets = self.actent(ent)
            actor_entropy = ent_loss.unsqueeze(-1)

            metrics["actor_entropy_loss"] = to_np(ent_loss.mean())
            metrics.update({f'actent_{k}': to_np(v) for k, v in ent_mets.items()})

        elif self._config.actor_entropy > 0:
            # Fixed-scale entropy bonus, expressed as a loss term.
            actor_entropy = -self._config.actor_entropy * actor_ent[:-1][:, :, None]
            metrics["actor_entropy"] = to_np(torch.mean(actor_entropy))
            metrics["entropy_scale"] = self._config.actor_entropy

        metrics["actor_target_unnormalized"] = to_np(actor_target.mean())
        metrics["actor_entropy_unweighted"] = to_np(actor_ent.mean())

        actor_target = actor_target + actor_entropy
        actor_loss = torch.mean(weights[:-1].detach() * actor_target)
        return actor_loss, metrics

    def _update_slow_target(self):
        if self._config.slow_value_target:
            if self._updates % self._config.slow_target_update == 0:
                mix = 1.0 # self._config.slow_target_fraction # NOTE: value set from dhafner impl
                for s, d in zip(self.value.parameters(), self._slow_value.parameters()):
                    d.data = mix * s.data + (1 - mix) * d.data
                if self.expl_value is not None:
                    for s, d in zip(self.expl_value.parameters(), self._slow_value_expl.parameters()):
                        d.data = mix * s.data + (1 - mix) * d.data
            self._updates += 1
