from IPython.display import display, clear_output, Image, HTML, update_display
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import fastrl_utils as frlu
import cv2
import time
import torch
from uiux import UIUX
from IPython import embed as ipshell
import tools

def to_np(x):
    '''Detach a tensor from autograd, move it to the CPU and return it as a numpy array.'''
    return x.detach().cpu().numpy()


def reshape(tnsr):
    '''
    Merge the last two axes of `tnsr` into a single axis of length `tnsr.shape[-1] ** 2`.

    All leading axes are kept as they are. The target length is derived from the
    last axis only, so the reshape succeeds only when the last two axes have
    equal size.
    '''
    flat_shape = [*tnsr.shape[:-2], tnsr.shape[-1] ** 2]
    return tnsr.reshape(flat_shape)

def cluster_itrajs(start, agent, num_skill_samples=8, sample=True):
    '''
    Placeholder for clustering imagined trajectories.

    Not implemented yet: every argument is ignored and None is returned.
    TODO: take clustering out of uiux.py and put it here
    '''
    return None

def goal_img_value(agent, goal_img, latent, action, is_first):
    '''
    Score a goal image with the manager's "reward" critic.

    The image is min-max rescaled to [-0.5, 0.5], encoded with the world-model
    encoder, turned into a posterior state with `dynamics.observe` (starting
    from `state=None`), and the resulting features are passed to
    `agent.manager.critics["reward"]`.

    Args:
        agent: agent exposing `_wm.encoder`, `_wm.dynamics` and `manager.critics`.
        goal_img: image array supporting `.min()`, `.max()` and arithmetic
            (presumably a numpy array -- confirm against callers). Not modified.
        latent: unused; kept so existing callers keep working.
        action: action passed to `dynamics.observe` with a new leading axis.
        is_first: is-first flag passed to `dynamics.observe` with a new leading axis.

    Returns:
        The critic's mode as a numpy array.
    '''
    # Rescale to [-0.5, 0.5]. The previous code computed this but then built the
    # tensor from the raw `goal_img`, so the encoder never saw the rescaled image.
    lo, hi = goal_img.min(), goal_img.max()
    span = hi - lo
    if span > 0:
        img = (goal_img - lo) / span - 0.5
    else:
        # Constant image: avoid 0/0 (NaNs) and map every pixel to the lower bound.
        img = goal_img * 0.0 - 0.5
    img = torch.Tensor(img).cuda()

    # Pure inference: nothing here needs an autograd graph.
    with torch.no_grad():
        embed = agent._wm.encoder({"image": img})

        p, _ = agent._wm.dynamics.observe(embed[None, None], action[None], is_first[None], state=None)
        inp = agent._wm.dynamics.get_feat(p)
        v = agent.manager.critics["reward"](inp).mode().detach().cpu().numpy()
    return v


def start_to_itrajs_with_skills(start, agent, num_skill_samples=8, sample=True):
    '''
    Imagine skill-conditioned trajectories from `start` and decode them into images.

    `num_skill_samples` skills are drawn from a blend of the skill prior's logits
    and the manager actor's logits, weighted by the value returned from
    `agent.uiux.refresh_params()`. The sampled skills are decoded into goals and
    handed to `frlu.imagine_with_skills` together with
    `agent.config.human_i_horizon + 1` as its horizon argument.

    Returns `(dec_images, skills, starting_goals, iactions, metrics)`, where
    `metrics["hri_skill_ent"]` is the mean entropy of the sampling distribution.
    '''
    t0 = time.time()  # NOTE(review): only read by the commented-out timing prints
    metrics = {}

    alpha = agent.uiux.refresh_params()

    # Developer switches: with both off, skills come from the blended distribution.
    use_mngr = False
    use_unif = False

    with torch.no_grad():
        feat = agent._wm.dynamics.get_feat(start)
        mngr_dist = agent.manager.actor(feat).base_dist
        uniform_dist = agent.goalae.skill_prior.base_dist

        combined_logits = alpha * uniform_dist.logits[None] + (1 - alpha) * mngr_dist.logits
        if not (use_mngr or use_unif):
            # TODO: normalize probabilities
            skill_dist = tools.get_skill_dist(combined_logits, agent.config.unimix_ratio)
            skills = skill_dist.sample((num_skill_samples,)).transpose(0, 1)
        elif use_mngr:
            skill_dist = mngr_dist
            skills = skill_dist.sample((num_skill_samples,)).transpose(0, 1)
        else:
            skill_dist = uniform_dist
            skills = skill_dist.sample((num_skill_samples,))[None]

        metrics["hri_skill_ent"] = skill_dist.entropy().detach().mean().item()

        # Flatten the last two axes of the sampled skills, then decode them to goals.
        flat_skills = reshape(skills)
        starting_goals = agent.goalae.goal_dec(flat_skills).mode()
        starting_skill = flat_skills.detach()

        feats, states, iactions, skills, goals = frlu.imagine_with_skills(
            agent.config,
            agent._wm.dynamics,
            start,
            agent.worker.actor,
            agent.manager.actor,
            agent.goalae.goal_dec,
            agent.config.human_i_horizon + 1,  # agent.config.imag_horizon+1
            starting_skill,
            starting_goals,
            sample=sample,
        )

        # Move the leading axis to third position, then drop every size-1 axis.
        # NOTE(review): a bare squeeze() also removes any other axis that happens
        # to have size 1 (e.g. a single sample) -- confirm that is intended.
        lead_to_third = (1, 2, 0, 3)
        feats, iactions, skills, goals = (
            t.permute(*lead_to_third).squeeze() for t in (feats, iactions, skills, goals)
        )
        for key, order in (
            ("deter", (1, 2, 0, 3)),
            ("stoch", (1, 2, 0, 3, 4)),
            ("logit", (1, 2, 0, 3, 4)),
        ):
            states[key] = states[key].permute(*order).squeeze()

    # Decode the imagined states into images (runs outside the no_grad block).
    ifeats = agent._wm.dynamics.get_feat(states)
    decoder = agent._wm.heads["decoder"]
    dec_images = decoder(ifeats)["image"].mode().squeeze()
    assert dec_images.shape[:-3] == goals.shape[:-1], "images and goals should have the same batch size"

    return dec_images, skills, starting_goals, iactions, metrics

def start_to_itrajs(start, agent, num_skill_samples=8, sample=True):
    '''
    Imagine a trajectory from `start` using `agent.worker.actor` and decode it into images.

    `frlu.imagine` is called with `agent.config.human_i_horizon + 1` as its
    horizon argument. `num_skill_samples` and `sample` are accepted but not
    used by this function.

    Returns `(dec_images, iactions, metrics)`; `metrics` is always an empty dict.
    '''
    t0 = time.time()  # NOTE(review): only read by the commented-out timing print
    metrics = {}
    wm = agent._wm

    with torch.no_grad():
        feat = wm.dynamics.get_feat(start)  # NOTE(review): result is not used below

        feats, states, iactions = frlu.imagine(
            agent.config, wm.dynamics, start, agent.worker.actor, agent.config.human_i_horizon + 1
        )

        # Prepend a size-1 leading axis to every rollout tensor, in place.
        for tensor in (feats, iactions, states["deter"], states["stoch"], states["logit"]):
            tensor.unsqueeze_(0)

    # Decode the imagined states into images (runs outside the no_grad block).
    ifeats = wm.dynamics.get_feat(states)
    dec_images = wm.heads["decoder"](ifeats)["image"].mode().squeeze()

    return dec_images, iactions, metrics

def human_interaction(start, agent, live_obs=None, num_skill_samples=64, sample=True):
    '''
    Choose a skill and goal for the agent, either with an expert model or through the UI.

    If `agent.uiux.expert_model` is set (deprecated path), candidate trajectories are
    scored by the expert model and the best-scoring index is chosen. Otherwise
    `uiux.interface` is shown repeatedly until it returns an exit code other than
    -3 ("refresh").

    Trajectory generation is currently stubbed out: `gen_samples` returns zero
    tensors and no skills or goals, so `start` and `sample` are not used.

    Args:
        start: starting state; only referenced by the disabled generation code.
        agent: agent exposing `uiux`, `_wm`, `manager` and `config`.
        live_obs: must not be None when the expert model is used.
        num_skill_samples: default sample count for `gen_samples`.
        sample: only referenced by the disabled generation code.

    Returns:
        `(skill, goal_deter, exit_code, metrics)`. `skill` and `goal_deter` are
        None when the exit code is -4 (atomic action), when nothing was chosen,
        or when no skills/goals were generated for the chosen index.
    '''
    uiux: UIUX = agent.uiux
    metrics = {}

    spoofed_iimage_size = torch.Size([13, 64, 64, 3])
    spoofed_iaction_size = torch.Size([1, 13, 5])

    def gen_samples(n_samples=num_skill_samples):
        # if agent.config.single_level_agent:
            # dec_images, iactions, mets = start_to_itrajs(start, agent, n_samples, sample=sample)
            # skills, starting_goals = None, None
        # else:
        #     dec_images, skills, starting_goals, iactions, mets = start_to_itrajs_with_skills(start, agent, n_samples, sample=sample)
        # normalize the images
        # iimages = dec_images.detach().cpu().numpy()
        # iimages = (dec_images - dec_images.min()) / (dec_images.max() - dec_images.min())
        # return iimages, skills, starting_goals, iactions, mets

        # NOTE: Temporary code for RLC to speed up atomic action entry (we don't need to generate irollouts)
        print("WARN: skipping human irollouts for RLC.")
        skills, starting_goals, mets = None, None, {}
        spoofed_iimages = torch.zeros(spoofed_iimage_size)
        spoofed_iactions = torch.zeros(spoofed_iaction_size)
        return spoofed_iimages, skills, starting_goals, spoofed_iactions, mets

    exit_code = 0
    if uiux.expert_model is not None: # Deprecated, but possibly useful in the future
        iimages, skills, starting_goals, iactions, mets = gen_samples()
        assert live_obs is not None, "Need to pass in live observations to use expert model"

        expert_model = agent.uiux.expert_model
        ewm = expert_model._wm

        # Developer switches selecting how the expert scores the candidates.
        EXPERT_MODEL_USES_SKILL_VALUE = True
        EXPERT_MODEL_USES_WORLD_VALUE = False

        with torch.no_grad():
            if EXPERT_MODEL_USES_SKILL_VALUE:
                iimages = iimages - 0.5 # normalize the images to be between -0.5 and 0.5
                # decode the starting_goals with the existing agent to get the images
                goal_deter = starting_goals
                goal_stoch = reshape(agent._wm.dynamics.get_stoch(goal_deter))
                inp = torch.cat([goal_stoch, goal_deter], axis=-1)
                goal_imgs = agent._wm.heads["decoder"](inp)["image"].mode()

                # renormalize the images to be between -0.5 and 0.5 (expected input range)
                goal_imgs = (goal_imgs - goal_imgs.min()) / (goal_imgs.max() - goal_imgs.min()) - 0.5

                # encode the images and observe them with the first actions to get posteriors
                embed = ewm.encoder({"image": goal_imgs})
                p, _ = ewm.dynamics.observe(embed, iactions[:, 0][None], torch.zeros_like(iactions[:, 0][None]))

                # pass the posteriors into the manager's critic to get values.
                inp = ewm.dynamics.get_feat(p)
                skill_values = agent.uiux.expert_model.manager.critics["reward"](inp).mode()

                chosen_index = skill_values.argmax().item()
                metrics.update({"expert_reward": skill_values[0, chosen_index].item(), "expert_reward_std": skill_values.std().item()})

            elif EXPERT_MODEL_USES_WORLD_VALUE:
                iimages = iimages - 0.5
                data = {"image": iimages.permute(1, 0, 2, 3, 4)}
                embed = ewm.encoder(data)
                actions = iactions.permute(1, 0, 2)
                is_first = torch.zeros(embed.shape[:2], dtype=torch.bool, device=embed.device)

                # swap the first two axes, leaving the rest in place
                swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
                embed, actions, is_first = swap(embed), swap(actions), swap(is_first)

                # assert hasattr(expert_model, "latent"), "Need to attach latent to use expert model"
                post, _ = ewm.dynamics.observe(embed, actions, is_first)

                feats = ewm.dynamics.get_feat(post)
                rewards = ewm.heads["reward"](feats).mode()

                # pick the skill that has the greatest increase in value between the first and last frame
                # skill_values = rewards[:, -1] - rewards[:, 0]

                # pick the skill that has the greatest summed value above the average
                skill_values = rewards.sum(axis=1) - rewards.mean(axis=1)

                chosen_index = skill_values.argmax().item()
                metrics.update({"expert_reward": skill_values[chosen_index].item(), "expert_reward_std": skill_values.std().item()})

            def play_trajectory(i, title):
                # Show trajectory `i` frame by frame; silently skip out-of-range indices.
                if i >= iimages.shape[0]: return
                toshow = iimages[i].detach().cpu().numpy()
                for frame_idx in range(toshow.shape[0]):
                    # iimages were shifted by -0.5 above; shift back for display
                    cv2.imshow(title, toshow[frame_idx] + 0.5)
                    cv2.waitKey(20)

            # ipshell()
            play_trajectory(chosen_index, 'best trajectory')
            play_trajectory(skill_values.argmin().item(), "worst trajectory")
            # print(f"\t {skill_values.mean():.2e} std {skill_values.std():.2e}. Best is {skill_values.max()}")

    else:
        exit_code = -3
        while exit_code == -3: # -3 : "refresh"
            n_clusters = int(uiux.ncluster_input.get())
            n_samples = int(uiux.nsample_input.get())
            iimages, skills, starting_goals, iactions, mets = gen_samples(n_samples)
            iimages = iimages * 255 # convert to 0-255
            chosen_index, exit_code = uiux.interface(iimages, n_clusters, iactions)
        mets.update({"human_time_cum": uiux.total_time_in_loop, "human_action_count": uiux.total_loop_actions})

    metrics.update(mets)
    if exit_code == -4: # atomic action exit code
        skill, goal_deter = None, None
    elif chosen_index is None:
        print("No image chosen")
        skill, goal_deter = None, None
    elif skills is None or starting_goals is None:
        # gen_samples currently returns no skills/goals; indexing them would raise
        # AttributeError/TypeError, so report it and fall back to "no skill".
        print("WARN: an index was chosen but no skills/goals were generated; returning no skill.")
        skill, goal_deter = None, None
    else:
        print(f"Chosen index: {chosen_index} of {starting_goals.shape}")
        skill, goal_deter = skills[chosen_index, 0, :], starting_goals[:, chosen_index, :]

    # skill_value = agent.manager.critics["reward"](torch.cat([goal_stoch, goal_deter], axis=-1)).mode().detach().cpu().numpy()

    return skill, goal_deter, exit_code, metrics

