
import time
import torch
import torch.nn as nn
import numpy as np


import fastcore.all as fc
import tools

def reshape(tnsr):
    """Merge the last two axes of ``tnsr`` into one, keeping all leading axes.

    Inputs with fewer than two axes are simply flattened.
    """
    leading_axes = tnsr.shape[:-2]
    return tnsr.reshape(*leading_axes, -1)
def to_np(x):
    """Detach ``x`` from autograd, copy it to host memory and return it as a NumPy array."""
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()

def goal_reward(state, goal):
    """Cosine-max reward between the deterministic model state and a goal.

    Both ``state["deter"]`` and ``goal`` are divided by the larger of their two
    L2 norms (each norm offset by 1e-12), then dotted over the last axis. This
    equals the cosine similarity scaled by ``min(|g|, |h|) / max(|g|, |h|)``.
    ``goal`` is detached, so no gradient flows into it.
    """
    deter = state["deter"]
    target = goal.detach()
    eps = 1e-12
    target_norm = torch.linalg.norm(target, dim=-1, keepdim=True) + eps
    deter_norm = torch.linalg.norm(deter, dim=-1, keepdim=True) + eps
    shared_scale = torch.maximum(target_norm, deter_norm)
    return torch.einsum("...i,...i->...", target / shared_scale, deter / shared_scale)

def expl_reward(state, goal_enc, goal_dec):
    """Exploration reward: goal-autoencoder reconstruction error of the state.

    ``state["deter"]`` is encoded with ``goal_enc``. A skill is sampled from the
    encoder distribution, and everything after its first two axes is flattened.
    That flattened skill is decoded with ``goal_dec``. The reward is the mean
    squared error between the decoder's mode and ``state["deter"]`` over the
    last axis.
    """
    deter = state["deter"]
    skill = goal_enc(deter).sample()
    # Keep the first two axes (presumably time, batch — per the original note)
    # and flatten the skill's remaining axes into one feature axis.
    flat_skill = skill.reshape(skill.shape[0], skill.shape[1], -1)
    reconstruction = goal_dec(flat_skill).mode()
    # NOTE: this could instead be a log-likelihood under the decoder distribution.
    return (reconstruction - deter).pow(2).mean(-1)

def abstract_traj(config, orig_traj):
    """Collapse an imagined trajectory onto the skill (manager) time scale.

    With ``k = config.train_skill_duration``, the steps are grouped into
    consecutive chunks of ``k``. Each entry is reduced to one value per chunk:

    * keys containing ``"reward"``: sum over the chunk, each step weighted by the
      running product of ``cont[:-1]`` within that chunk;
    * ``"cont"``: its first value, followed by the product of ``cont[1:]`` per chunk;
    * every other key: the first step of each chunk, followed by the final step.

    ``"skill"`` is renamed to ``"action"``, and a ``"weight"`` entry is added. That
    entry is the cumulative product of ``config.imag_discount * cont`` divided by
    ``config.imag_discount``. ``orig_traj`` itself is not modified (shallow copy).

    Reward entries must have ``len(cont) - 1`` steps. Other entries must have
    ``N*k + 1`` steps.
    """
    k = config.train_skill_duration

    def chunk(x):
        # (T, ...) -> (T // k, k, ...)
        return x.reshape([x.shape[0] // k, k] + list(x.shape[1:]))

    out = orig_traj.copy()
    out["action"] = out.pop("skill")
    # Computed before the loop below overwrites out["cont"].
    step_weights = torch.cumprod(chunk(out["cont"][:-1]), 1).to(config.device)

    for name, value in list(out.items()):
        if "reward" in name:
            collapsed = (chunk(value) * step_weights).sum(1)  # weighted sum per chunk
        elif name == "cont":
            collapsed = torch.concat([value[:1], chunk(value[1:]).prod(1)], 0)
        else:
            collapsed = torch.concat([chunk(value[:-1])[:, 0], value[-1:]], 0)
        out[name] = collapsed

    discount = config.imag_discount
    out["weight"] = torch.cumprod(discount * out["cont"], 0) / discount
    return out

def split_traj(config, orig_traj):
    """Split an imagined trajectory into overlapping per-skill sub-trajectories.

    With ``k = config.train_skill_duration``, the trajectory must have ``N*k + 1``
    steps. Each entry of shape ``(N*k + 1, B, ...)`` becomes ``(k + 1, N*B, ...)``.
    Segment ``i`` covers steps ``i*k .. (i+1)*k`` inclusive, so neighbouring
    segments share their boundary step. For example, with k=3::

        (1 2 3 4 5 6 7 8 9 10) -> ((1 2 3 4) (4 5 6 7) (7 8 9 10))

    Keys containing ``"reward"`` have one step fewer. They are left-padded with a
    zero step for the split, and the padding row is dropped afterwards, giving
    ``(k, N*B, ...)``.

    The last ``"goal"`` row is replaced by the first, so each segment bootstraps
    against its own goal. A ``"weight"`` entry is added: the cumulative product of
    ``config.imag_discount * cont`` over the segment axis, divided by
    ``config.imag_discount``. ``orig_traj`` itself is not modified (shallow copy).

    Raises:
        AssertionError: if the trajectory length is not of the form ``N*k + 1``.
    """
    k = config.train_skill_duration
    length = len(orig_traj["action"])
    # Previously the message sat after a stray ';' and was never attached to the
    # assert. The check `length % k == 1` also rejected every length when k == 1.
    assert (length - 1) % k == 0, (
        f"Trajectory length must be N*k+1 for k={k}, got {length}")

    def chunk(x):
        # (T, ...) -> (T // k, k, ...)
        return x.reshape([x.shape[0] // k, k] + list(x.shape[1:]))

    traj = orig_traj.copy()
    for key, val in list(traj.items()):
        is_reward = "reward" in key
        if is_reward:
            # Rewards are one step shorter; pad so they split like everything else.
            val = torch.concat([0 * val[:1], val], 0)
        # Append each segment's end step (= next segment's start): (N, k+1, B, ...)
        val = torch.concat([chunk(val[:-1]), val[k::k][:, None]], 1)
        # (N, k+1, B, ...) -> (k+1, N, B, ...)
        val = val.permute([1, 0] + list(range(2, len(val.shape))))
        # (k+1, N, B, ...) -> (k+1, N*B, ...)
        val = val.reshape(
            [val.shape[0], int(np.prod(val.shape[1:3]))] + list(val.shape[3:]))
        traj[key] = val[1:] if is_reward else val

    # Bootstrap each sub-trajectory against its current goal, not the next one.
    traj["goal"] = torch.concat([traj["goal"][:-1], traj["goal"][:1]], 0)
    traj["weight"] = torch.cumprod(
        config.imag_discount * traj["cont"], dim=0) / config.imag_discount
    return traj

class Director(nn.Module):
    """Hierarchical agent: world model + goal autoencoder + manager + worker.

    The manager samples a skill every ``config.train_skill_duration`` steps. The
    goal decoder maps the skill to a goal in the world model's deterministic-state
    space (it is trained to reconstruct ``post["deter"]``). The worker picks
    primitive actions.

    NOTE: in the current code the worker's input does not include the goal (see
    ``_policy`` / ``imagine_carry``), and only the worker is trained in ``train``;
    manager training is disabled there.
    """

    def __init__(self, config, logger, train_dataset, wm, goal_enc, goal_dec, goal_ae_opt, manager, worker):
        super().__init__()
        # Stores every constructor argument as an attribute of the same name
        # (self.config, self.wm, ...); nn.Module arguments become submodules.
        fc.store_attr()
        self.step = 0  # _policy calls since reset_step(); drives goal resampling
        self._step = 0  # env steps seen while training; drives train_every
        self.start_time = time.time()
        self.pretrained = False
        self.metrics = {}  # metric name -> list of logged values

        # Zero-logit (uniform) skill prior for the goal-autoencoder KL term.
        self.skill_prior = torch.distributions.independent.Independent(
            tools.OneHotDist(
                logits=torch.zeros(*config.skill_shape, device=config.device),
                unimix_ratio=config.action_unimix_ratio,
            ),
            1,
        )

    def reset_step(self):
        """Restart the skill clock so the next ``_policy`` call samples a new goal."""
        self.step = 0

    def modules(self):
        """Return ``(wm, goal_enc, goal_dec, goal_ae_opt, manager, worker)``.

        NOTE: this shadows ``nn.Module.modules()``, which normally yields every
        submodule recursively; callers expecting that iterator get this tuple.
        """
        return self.wm, self.goal_enc, self.goal_dec, self.goal_ae_opt, self.manager, self.worker

    def expand_goals(self, goals, target_shape=None):
        """Repeat each per-skill goal over the time steps it is active for.

        Args:
            goals: tensor indexed as ``goals[:, i, ...]``, with one goal per skill
                switch (e.g. sampled every ``train_skill_duration`` steps).
            target_shape: shape of the result. Defaults to
                ``(config.batch_size, config.batch_length, goals.shape[-1])``;
                the feature size used to be hard-coded to 1024.

        Returns:
            A tensor of ``target_shape`` on ``config.device``. Time steps
            ``[i*k, (i+1)*k)`` hold ``goals[:, i]``, where
            ``k = config.train_skill_duration``. Steps beyond the available goals
            stay zero.
        """
        k = self.config.train_skill_duration
        if target_shape is None:
            target_shape = (self.config.batch_size, self.config.batch_length, goals.shape[-1])
        expanded = torch.zeros(target_shape, device=self.config.device)
        # repeat_interleave also fills a trailing partial segment when the length
        # is not a multiple of k; the old loop left those steps at zero.
        repeated = goals.repeat_interleave(k, dim=1)[:, :target_shape[1]]
        expanded[:, :repeated.shape[1]] = repeated
        return expanded

    def run_batch(self, obs):
        """Run the manager -> goal -> worker pipeline over a batch of sequences.

        Inspection helper: encodes ``obs`` with the world model and samples a skill
        at every ``train_skill_duration``-th step. It then decodes and expands the
        goals, and samples worker actions from ``[stoch, deter, goal]``.

        Returns:
            The sampled worker actions. The original returned nothing, so existing
            callers are unaffected.
        """
        obs = self.wm.preprocess(obs)
        embed = self.wm.encoder(obs)
        post, _ = self.wm.dynamics.observe(embed, obs["action"], obs["is_first"])

        stoch = post["stoch"]
        stoch = stoch.reshape(*stoch.shape[:2], -1)
        feat = torch.cat([stoch, post["deter"]], dim=-1)

        # The manager only acts at skill boundaries.
        goal_embed = feat.detach()[:, ::self.config.train_skill_duration, ...]
        skill = self.manager.actor.forward(goal_embed).sample()
        skill = skill.reshape(*skill.shape[0:2], -1)

        goal = self.goal_dec(skill).mode()
        goal = self.expand_goals(goal, target_shape=(*feat.shape[:2], goal.shape[-1]))
        worker_inp = torch.cat([feat, goal], dim=-1)
        return self.worker.actor.forward(worker_inp).sample()

    def sample_goal(self, latent, human=False):
        """Sample a skill from the manager for ``latent`` and decode it into a goal.

        Args:
            latent: world-model state accepted by ``wm.dynamics.get_feat``.
            human: unused; kept for interface compatibility.

        Returns:
            ``(skill, goal)``, both detached. ``goal`` is the decoder mode of the
            skill, flattened to ``(-1, prod(config.skill_shape))``.
        """
        feat = self.wm.dynamics.get_feat(latent).detach()
        skill = self.manager.actor.forward(feat).sample().detach()
        # np.product was removed in NumPy 2.0; np.prod is the supported spelling.
        skill_size = int(np.prod(self.config.skill_shape))
        goal = self.goal_dec(skill.reshape(-1, skill_size)).mode().detach()
        return skill, goal

    def imagine_carry(self, start_wm, policy, horizon):
        """Roll ``policy`` forward in imagination for ``horizon`` steps.

        A (skill, goal) pair is sampled from the manager on the first step and at
        every ``train_skill_duration``-th step, and carried over in between.

        Args:
            start_wm: dict of world-model state tensors; the first two axes are
                flattened into one batch axis and detached.
            policy: module whose ``forward(inp)`` returns a distribution with
                ``sample()``.
            horizon: number of imagined steps.

        Returns:
            ``(feats, states, actions, skills, goals)`` as stacked by
            ``tools.static_scan`` (assumed leading time axis). ``states[key][t]`` is
            the state the step-``t`` action was taken from; index 0 is the start state.
        """
        dynamics = self.wm.dynamics

        def flatten(x):
            return x.reshape([-1] + list(x.shape[2:]))

        # Imagination never backpropagates into the starting posterior.
        start = {k: flatten(v).detach() for k, v in start_wm.items()}

        def step(prev, t):
            state, _, _, skill, goal = prev
            feat = dynamics.get_feat(state)  # z_state + h
            if goal is None or t % self.config.train_skill_duration == 0:
                skill, goal = self.sample_goal(state)
            # NOTE: the worker is currently not goal-conditioned; the conditioned
            # input would be torch.cat([feat.detach(), goal], -1).
            inp = feat.detach()
            action = policy.forward(inp).sample().to(self.config.device)
            succ = dynamics.img_step(state, action, sample=self.config.imag_sample)
            return succ, feat, action, skill, goal

        succ, feats, actions, skills, goals = tools.static_scan(
            step, [torch.arange(horizon)], (start, None, None, None, None)
        )
        # Shift successors by one so each state lines up with the action taken from it.
        states = {k: torch.cat([start[k].unsqueeze(0), v[:-1]], 0) for k, v in succ.items()}
        return feats, states, actions, skills, goals

    def _policy(self, obs, state):
        """Compute the action for one environment step.

        Args:
            obs: raw observation batch; must contain ``"image"`` and ``"is_first"``.
            state: ``(latent, previous_action)`` from the previous call. If ``None``,
                start from ``dynamics.initial`` with an all-zero previous action.

        Returns:
            ``({"action": action, "logprob": tensor([0])}, (latent, action))``.
        """
        if state is None:
            batch_size = len(obs["image"])
            latent = self.wm.dynamics.initial(batch_size)
            action = torch.zeros((batch_size, self.config.num_actions)).to(self.config.device)
        else:
            latent, action = state

        obs = self.wm.preprocess(obs)
        embed = self.wm.encoder(obs)
        latent, _ = self.wm.dynamics.obs_step(latent, action, embed, obs["is_first"], sample=True)

        stoch = reshape(latent["stoch"])
        deter = latent["deter"]
        model_state = torch.cat([stoch, deter], dim=-1)

        # The manager picks a new goal every train_skill_duration calls.
        if self.step % self.config.train_skill_duration == 0:
            skill = reshape(self.manager.actor.forward(model_state).sample())
            self.current_goal_deter = self.goal_dec(skill).mode()
        self.step += 1

        # NOTE: the worker is currently not goal-conditioned; the conditioned input
        # would be torch.cat([stoch, deter, self.current_goal_deter], dim=-1).
        action = self.worker.actor.forward(model_state).sample()
        return {"action": action, "logprob": torch.Tensor([0])}, (latent, action)

    def update_metrics(self, metrics):
        """Append every value in ``metrics`` to ``self.metrics[name]``."""
        for name, value in metrics.items():
            self.metrics.setdefault(name, []).append(value)

    def avg_loss(self, metrics):
        """Return the SUM (despite the name) of all metrics whose key ends in "loss".

        Values may be scalar tensors, NumPy scalars or plain Python numbers
        (``.item()`` used to crash on plain floats).
        """
        return sum(float(v) for k, v in metrics.items() if k.endswith("loss"))

    def _pretrain(self):
        """Pretrain in two phases: the world model + goal AE only, then full train steps."""
        for i in range(self.config.pure_model_pretrain):
            *_, metrics = self.train_world_and_ae(next(self.train_dataset))
            self.update_metrics(metrics)
            if i % 9 == 0:
                print(f"pure model pretrain avg_loss {self.avg_loss(metrics):1.2f} step {i} t {time.time() - self.start_time:1.2f}")
        for i in range(self.config.pretrain):
            metrics = self.train(next(self.train_dataset))
            self.update_metrics(metrics)
            if i % 9 == 0:
                print(f"pretrain avg_loss {self.avg_loss(metrics):1.2f} step {i} t {time.time() - self.start_time:1.2f}")

    def __call__(self, obs, done, agent_state, training=True):
        """Optionally train, then act.

        When ``training``, the first call runs the pretraining phases if
        ``config.pretrain`` is set. Otherwise one ``train`` step runs every
        ``config.train_every`` env steps. Afterwards, the logger and ``_step``
        counters are advanced by ``len(done)``.

        Returns:
            ``(policy_out, state)`` from ``_policy``.
        """
        if training:
            if self.config.pretrain and not self.pretrained:
                self.pretrained = True
                self._pretrain()
            elif self._step % self.config.train_every == 0:
                metrics = self.train(next(self.train_dataset))
                self.update_metrics(metrics)
                if self._step % 5 == 0:
                    print(f"train avg_loss {self.avg_loss(metrics):1.2f} step {self._step:03} t {time.time() - self.start_time:1.2f}")

        policy_out, state = self._policy(obs, agent_state)

        if training:
            self.logger.step += self.config.action_repeat * len(done)
            self._step += len(done)
        return policy_out, state

    def train_world_and_ae(self, batch, metrics=None):
        """Take one world-model step and one goal-autoencoder step on ``batch``.

        The autoencoder encodes ``post["deter"]`` into a skill distribution and
        decodes a sample back. Its loss is the reconstruction NLL plus
        ``config.goal_vae_kl_beta`` times the KL to ``self.skill_prior``.

        Args:
            batch: replay batch passed to ``wm._train``.
            metrics: optional dict updated in place. A fresh dict is used when
                omitted; the old ``{}`` default was shared across all calls.

        Returns:
            ``(post, context, metrics)``.
        """
        metrics = {} if metrics is None else metrics
        post, context, wm_metrics = self.wm._train(batch)
        metrics.update(wm_metrics)

        enc_dist = self.goal_enc(post["deter"])
        skill = enc_dist.sample()
        dec_dist = self.goal_dec(skill.reshape(*skill.shape[:2], -1))
        recon_nll = -dec_dist.log_prob(post["deter"].detach())

        kl = torch.distributions.kl.kl_divergence(enc_dist, self.skill_prior)
        kl = kl * self.config.goal_vae_kl_beta

        assert kl.shape == recon_nll.shape, (kl.shape, recon_nll.shape)
        loss = torch.mean(recon_nll + kl)
        ae_params = list(self.goal_enc.parameters()) + list(self.goal_dec.parameters())
        metrics.update(self.goal_ae_opt(loss, ae_params))
        return post, context, metrics

    def train(self, batch=True, metrics=None):
        """Run one full training step on a replay ``batch`` and return its metrics.

        Trains the world model and goal AE, then imagines
        ``config.imag_horizon + 1`` steps with the worker. It computes the
        extrinsic, exploration and goal rewards, and trains the worker on the
        imagined trajectory.

        NOTE: this name shadows ``nn.Module.train(mode)``. A bool argument
        (including the default) is forwarded to ``nn.Module.train``, so
        ``.train()`` / ``.eval()`` toggle training mode instead of crashing.

        Args:
            batch: replay batch dict (must contain ``"is_terminal"``), or a bool mode.
            metrics: optional dict updated in place; a fresh dict is used when omitted.

        Returns:
            The metrics dict, or ``self`` when called with a bool (as ``nn.Module.train``).
        """
        if isinstance(batch, bool):
            return super().train(batch)

        post, context, metrics = self.train_world_and_ae(batch, metrics)
        with torch.no_grad():
            feats, states, actions, skills, goals = self.imagine_carry(
                post, self.worker.actor, self.config.imag_horizon + 1)

        # Rewards start from t=1 (the step each action leads into).
        reward_extr = self.wm.heads["reward"](self.wm.dynamics.get_feat(states)).mean()[1:]
        reward_expl = expl_reward(states, self.goal_enc, self.goal_dec)[1:][..., None]
        reward_goal = goal_reward(states, goals)[1:][..., None]

        cont = self.wm.heads["cont"](feats).mean
        # The first continuation flag comes from the real batch, the rest from the model.
        first_cont = torch.Tensor((1 - batch['is_terminal']).reshape(1, -1, 1)).to(self.config.device)
        imag_cont = torch.cat([first_cont, cont[1:]], dim=0)

        imag_traj = {
            "stoch": states["stoch"],
            "deter": states["deter"],
            "logit": states["logit"],
            "feat": feats,
            "action": actions,
            "reward_extr": reward_extr,
            "reward_expl": reward_expl,
            "reward_goal": reward_goal,
            "reward": reward_extr,
            "goal": goals,
            "skill": skills,
            "cont": imag_cont,
        }
        metrics.update({
            "extr_R": torch.mean(reward_extr).item(),
            "expl_R": torch.mean(reward_expl).item(),
            "worker_R": torch.mean(reward_goal).item(),
        })

        # Disabled: manager training on the skill-level trajectory.
        # mtraj = abstract_traj(self.config, imag_traj)
        # mtraj["reward"] = mtraj["reward_extr"] + 0.1 * mtraj["reward_expl"]
        # mtraj["state"] = {"stoch": mtraj["stoch"], "deter": mtraj["deter"], "logit": mtraj["logit"]}
        # *_, manager_metrics = self.manager._train(imag_traj=mtraj)
        # metrics.update(manager_metrics)

        # Disabled: worker training on per-skill sub-trajectories.
        # wtraj = split_traj(self.config, imag_traj)
        # wtraj["reward"] = wtraj["reward_goal"]
        # wtraj["state"] = {"stoch": wtraj["stoch"], "deter": wtraj["deter"], "logit": wtraj["logit"]}
        # *_, worker_metrics = self.worker._train(imag_traj=wtraj)

        imag_traj["state"] = {"stoch": imag_traj["stoch"], "deter": imag_traj["deter"], "logit": imag_traj["logit"]}
        # NOTE: the worker is currently trained on the extrinsic reward; the goal
        # reward is overwritten (it is only logged above as worker_R).
        imag_traj["reward_goal"] = reward_extr
        *_, worker_metrics = self.worker._train(imag_traj=imag_traj)
        metrics.update(worker_metrics)
        return metrics

    def save(self, logdir, note=''):
        """Save ``state_dict()`` to ``logdir / f"model{note}.pt"`` (logdir must support ``/``, e.g. pathlib.Path)."""
        torch.save(self.state_dict(), logdir / f"model{note}.pt")

    def load(self, logdir, note=''):
        """Load a ``state_dict`` saved by ``save``.

        NOTE: ``torch.load`` may unpickle arbitrary objects; only load trusted checkpoints.
        """
        self.load_state_dict(torch.load(logdir / f"model{note}.pt"))