#!/usr/bin/env python
# coding: utf-8

# In[2]:


import argparse, functools, datetime, os, shutil, pathlib, sys
from collections.abc import Mapping

os.environ["MUJOCO_GL"] = "osmesa"

import numpy as np, ruamel.yaml as yaml
from functools import partial
import time, random
import envs.wrappers as wrappers
from parallel import Parallel, Damy

import torch
from torch import nn
from torch.nn import functional as F
from torch import distributions as torchd
from torch.utils.tensorboard import SummaryWriter

import torch_utils as tu
import director_models
import tools
import models
import networks
import mem_tools

import matplotlib.pyplot as plt
import fastcore.all as fc
from fastai_utils import to_cpu, show_image, scale_img, count_steps, make_dataset, make_env, make_animation
from fastrl_utils import *
from AnotherDirector import ImaginationActorCritic, ImaginationAgent, ImaginationAE

to_np = lambda x: x.detach().cpu().numpy()


# In[3]:


# Seed every random number generator this script relies on (Python, NumPy,
# PyTorch and PyTorch's CUDA generator) with the same value for repeatability.
seed_value = 29
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(seed_value)


# In[4]:


# Run configuration: load the base config for the chosen environment, then
# override individual fields for this experiment and build the `notes` run label.
# USE_HOOKS switches on the activation-statistics hooks and their plots further down.
USE_HOOKS = False
# config = get_config(env_name="pinpad")
config = get_config(env_name="MemoryMaze")

# Get this current filename
import inspect
filename = inspect.getframeinfo(inspect.currentframe()).filename
print(filename)
if 'ipykernel' in filename: # in jupyter (cant use argparse)
    debug = False
    config.human = False
    notebook = True
else: 
    # Script mode: read the --debug and --human flags from the command line.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--human', action='store_true')
    debug = parser.parse_args().debug
    config.human = parser.parse_args().human
    notebook = False
# debug = False
config.prehuman = not config.human # Don't run the human mode, but save model periodically

# config = get_config(env_name="scalars/(.*)lossMemoryMaze")
# config.task = "MemoryMaze_9x9"
config.train_ratio = 128 # 32
config.manager_lr = 3e-4; config.worker_lr = 3e-4
config.spoofed_human_path = None# "12-26_22-31-52_2023_pinpad_five_4x4_128_emamix" if config.human else None
config.load_pretrained_main_model_path = None #'01-03_00-54-33_2024_pinpad_five_spoofed_human4x4_128_emamix'
config.human_reward = 0 #0.001 if config.spoofed_human_path is None else 0.0
config.human_i_horizon = 12
# NOTE(review): the previous comment here read "don't cluster just show raw trajectories",
# which appears to describe the value 0 (see the `== 0` check on the next line), not 8 -- confirm.
config.n_human_clusters = 8
config.n_skill_samples = 16 if config.n_human_clusters == 0 else 128
# A spoofed human overrides the sample count chosen on the previous line.
config.n_skill_samples = 512 if config.spoofed_human_path is not None else config.n_skill_samples

config.skill_alpha = 0.5 # blending between manager skill distribution and uniform distribution
train_steps = 100000
config.eval_steps = 250
  
# Pretraining length and the number of steps between evaluations depend on the mode.
if config.human:
    config.pretrain = 100 if config.load_pretrained_main_model_path is None else 1
    config.train_period = 3000 if config.load_pretrained_main_model_path is None else 1
else:
    config.pretrain = 100; config.train_period = 10000

# Debug mode shrinks every schedule to the minimum so the whole script runs quickly.
if debug:
    config.pretrain = 1; config.train_period = 1; config.eval_steps = 1; train_steps = 10

PARTIAL_TRAIN_FIRST = False # partial training to step in and mess with skill generation
DELETE_PREV_DATASET = True
LOAD_PRETRAINED_WM = False
config.WORKER_PRIMARY = WORKER_PRIMARY = False
# World-model training is skipped exactly when a pretrained world model is loaded.
config.skip_wm_training = DONT_TRAIN_WM = True if LOAD_PRETRAINED_WM else False
# config.manager_lr = 3e-3; config.worker_lr = 3e-4

config.skill_shape = (8, 8)
config.reward_EMA = True
# config.pretrain_wm_ae_only = True

config.unfold_manager_traj = False
config.use_ccw_random_agent = False

# note base
# NOTE(review): config.prehuman is `not config.human`, so one of the two overrides on the
# second line below always fires and the 'highlr_'/'joint' base is never kept -- confirm intent.
notes = f'highlr_' if (config.manager_lr > 3e-4 or config.worker_lr > 3e-4) else 'joint'
notes = "human" if config.human else notes; notes = "prehuman" if config.prehuman else notes
notes = "spoofed_human" if config.spoofed_human_path else notes

# note additions
# Suffixes: skill grid shape, train ratio, then optional tags for hooks, unfolded manager
# trajectories, pretrained world model, reward EMA and slow-target mixing.
notes += f'{config.skill_shape[0]}x{config.skill_shape[1]}'; notes += f'_{config.train_ratio}'; notes += '_hook' if USE_HOOKS else ''; notes += 'ufldmgr' if config.unfold_manager_traj else ''
notes += '_wm' if LOAD_PRETRAINED_WM else ''; notes += '_ema' if config.reward_EMA else ''; notes += 'mix' if config.slow_target_fraction < 1.0 else ''


# In[ ]:


# Resolve the log root, create the run directory and logger, then build the
# train / eval / human datasets, environments and episode caches.
logdir = pathlib.Path(config.logdir).expanduser()
logger, full_logdir = setup_paths_and_logging(
    config,
    logdir=logdir,
    delete_train_eps=DELETE_PREV_DATASET,
    notes=notes,
)
(
    train_dataset, eval_dataset, human_dataset,
    train_envs, eval_envs, human_envs,
    train_eps, eval_eps, human_eps,
) = get_datasets(config, logger, logdir=logdir)


# In[6]:


def get_model():
    """Build the world model, goal autoencoder, worker, manager and the agent wrapping them.

    Reads the module-level ``config``, ``train_envs``, ``train_dataset``, ``logger`` and
    ``logdir`` as well as the ``WORKER_PRIMARY`` and ``LOAD_PRETRAINED_WM`` switches.

    When ``LOAD_PRETRAINED_WM`` is set, the world model and goal autoencoder weights are
    loaded from ``bk_wm_15000.pth`` / ``bk_goalae_15000.pth`` under ``logdir``, switched to
    eval mode and frozen.

    Returns:
        Tuple ``(iagent, worker, manager, goal_ae, wm)``.
    """
    wm = models.WorldModel(train_envs[0].observation_space, train_envs[0].action_space, step=None, config=config).to(config.device)
    if config.compile: wm = torch.compile(wm)

    z_size = config.dyn_stoch * config.dyn_discrete # one-hot vector grid size
    feat_size = z_size + config.dyn_deter

    goal_ae = ImaginationAE(config).to(config.device)

    worker_critic_scales = {"reward_goal": 1.0} #{"reward": 0.5, "reward_goal": 0.5}
    # worker_critic_scales = {"reward_goal": 1.0} if not WORKER_PRIMARY else {"reward": 1.0}
    # Unless the worker is primary, its input is widened by dyn_deter
    # (presumably for the goal's deterministic state appended to the features -- confirm).
    worker_feat_size = feat_size + config.dyn_deter if not WORKER_PRIMARY else feat_size
    iac = worker = ImaginationActorCritic(config, wm, input_size=worker_feat_size, num_actions=config.num_actions, value_heads=worker_critic_scales, stop_grad_actor=True, prefix="worker").to(config.device)
    print(f"Worker is using {worker.critics.keys()} critic(s)")

    manager_critic_scales = {"reward": 0.9, "reward_expl": 0.1} # {"reward": 1.0} #
    manager = ImaginationActorCritic(config, wm, input_size=feat_size, num_actions=config.skill_shape, value_heads=manager_critic_scales, stop_grad_actor=True, prefix="manager").to(config.device)
    print(f"manager is using {manager.critics.keys()} critic(s)")
    iagent = ImaginationAgent(config, logger, worker=iac, manager=manager, goalae=goal_ae, _wm = wm, dataset=train_dataset)

    if LOAD_PRETRAINED_WM:
        print("Loading pretrained wm")
        # map_location lets checkpoints saved on another device load onto config.device.
        wm_weights = torch.load(logdir / "bk_wm_15000.pth", map_location=config.device)
        wm.load_state_dict(wm_weights); wm.eval(); wm.requires_grad_(False)
        goal_ae_weights = torch.load(logdir / "bk_goalae_15000.pth", map_location=config.device)
        goal_ae.load_state_dict(goal_ae_weights); goal_ae.eval(); goal_ae.requires_grad_(False)

    return iagent, worker, manager, goal_ae, wm


# In[7]:


# Build the main agent, then optionally attach a spoofed human, restore a
# pretrained checkpoint, and register activation-statistics hooks.
iagent, *_ = get_model()
if config.spoofed_human_path is not None:
    # Read in the pretrained model as a human proxy, and attach it to the uiux. This is a questionable design choice, but it's intended to keep a clear line between "inside the model" and "external feedback"
    spoofed_human_path = logdir / config.spoofed_human_path / "iagent.pth"
    print(f"Loading spoofed model from {spoofed_human_path}")
    spoofed_human_model, *_ = get_model()
    # map_location lets a checkpoint saved on another device load onto config.device.
    spoofed_human_model.load_state_dict(torch.load(spoofed_human_path, map_location=config.device))
    iagent.uiux.expert_model = spoofed_human_model

if config.load_pretrained_main_model_path is not None:
    print(f"Loading pretrained model from {config.load_pretrained_main_model_path}")
    minstep, maxstep = 2000, 4000
    pretrained_model_path = logdir / config.load_pretrained_main_model_path
    # Candidate checkpoints are named iagent_{num}.pth; names containing '-' are
    # timestamped final saves and are skipped.
    iagent_candidates = [f for f in pathlib.Path(pretrained_model_path).expanduser().iterdir() if f.name.startswith("iagent_") and f.name.endswith(".pth") and '-' not in f.stem]
    if not iagent_candidates:
        raise FileNotFoundError(f"No iagent_<step>.pth checkpoint found in {pretrained_model_path}")
    # Pick the checkpoint whose step number is closest to minstep (first match wins on ties).
    iagent_path = min(iagent_candidates, key=lambda x: abs(int(x.stem.split("_")[1]) - minstep))
    print(f"\tLoading {iagent_path}")
    iagent.load_state_dict(torch.load(iagent_path, map_location=config.device))

if USE_HOOKS:
    from miniai.activations import Hooks, append_stats
    # model = fc.filter_ex(iagent.worker.actor.modules(), fc.noop)
    # Hook the activation modules of the worker actor first, then those of every worker critic.
    model_names = []
    model = fc.filter_ex(iagent.worker.actor.modules(), fc.risinstance(getattr(nn, config.worker_actor['act'])))
    model_names += [f"actor{i}{m.__class__.__name__}" for i, m in enumerate(model)]


    critic_models = []
    for v in iagent.worker.critics.values():
        critic_models += fc.filter_ex(v.modules(), fc.risinstance(getattr(nn, config.act)))
    model += critic_models
    model_names += [f"value{i}{m.__class__.__name__}" for i, m in enumerate(critic_models)]
    # model = fc.filter_ex(iagent.worker.actor.modules(), fc.risinstance(nn.Linear))
    for m, n in zip(model, model_names):
        print(m, n)
    # Drop hooks left over from a previous run of this cell before registering new ones.
    if 'hooks' in globals(): hooks.remove()

    hooks = Hooks(model, append_stats)


# In[8]:


# Main training driver: alternate between collecting/training for `train_for`
# environment steps, an optional human-interaction episode, and an evaluation rollout.
save_config(logger, config)

# Keep a copy of the agent implementation next to the run's logs.
shutil.copyfile('AnotherDirector.py', full_logdir / 'AnotherDirector.py')

# `steps` is the env-step budget: train_steps divided by the per-env-step update ratio below.
steps = train_steps / (config.train_ratio / (config.batch_size * config.batch_length))

train_ratio = (config.train_ratio / (config.batch_size * config.batch_length)); run_steps = 0
train_for = int(config.train_period / train_ratio)

print(f"Training for {steps:0.1e} steps to train for {train_steps} steps. Training for {train_for} steps at a time.")

# Simulator states are carried across iterations so each call resumes where the last one stopped.
tool_state, eval_state, human_state = None, None, None
while run_steps < steps:
    print(f"Training for {train_for} env steps. Total env steps so far {run_steps}. Total train steps {iagent._train_steps}.")
    tool_state = tools.simulate(
        iagent,
        train_envs,
        train_eps,
        config.traindir,
        logger,
        is_eval=False,
        limit=config.dataset_size,
        state=tool_state,
        # episodes=1,
        steps=train_for,
    ); run_steps += train_for
    # ); run_steps += config.time_limit

    if config.human:
        print("Human mode")
        # save the current model NOTE: this uses 178 mb so use sparingly
        # torch.save(iagent.state_dict(), full_logdir / f"iagent_{iagent._train_steps}.pth")


        # save the optimizer states
        # torch.save(iagent.worker.optimizer.state_dict(), full_logdir / 'opts' / f"worker_opt_{iagent._train_steps}.pth")
        # torch.save(iagent.goalae.opt.state_dict(), full_logdir / 'opts' / f"goalae_opt_{iagent._train_steps}.pth")
        # train_for = int(250 / train_ratio)
        # NOTE(review): this replaces the train_for computed before the loop for every
        # later iteration, not just this one -- confirm that is intended.
        train_for = 500

        # Run a single episode in which the agent is called with human=True and training disabled.
        human_agent = functools.partial(iagent, training=False, human=True)
        # human_agent = functools.partial(iagent, training=False, human=True)
        human_state = tools.simulate(
            human_agent,
            human_envs,
            human_eps,
            config.humandir,
            logger,
            is_eval=False,
            limit=config.dataset_size,
            # limit=human_dataset_size,
            state = human_state,
            episodes=1,
            additional_cache=train_eps, # add the human dataset to the training dataset
            # steps = int(config.human_i_horizon * 25),
            uiux=iagent.uiux,
        )
        iagent.uiux._reset_session_state()

        # train over the human dataset for a bit
        # batch = next(human_dataset)
        # print("Training over human dataset. Batch shape ", batch[list(batch.keys())[0]].shape)
        
        # since we're not training during human mode, let's try to keep the number of updates consistent
        # balance_updates = int(iagent.train_every._ratio * config.time_limit)
        # for _ in range(balance_updates):
            # iagent._train(next(human_dataset))
            # iagent._train(next(iagent.dataset)); iagent._train_steps += 1

    # elif config.prehuman:
        # save the current model
        # torch.save(iagent.state_dict(), full_logdir / f"iagent_{iagent._train_steps}.pth")
        # NOTE: Make sure the optimizer state is being saved.
        # check if the saved model includes optimizer state
        # optimizer_state = torch.load(full_logdir / f"iagent_{iagent._train_steps}.pth", map_location=config.device)
        # for key in optimizer_state.keys():
        #     if "opt" in key:
        #         print("Optimizer state is being saved", key, optimizer_state[key].keys())

    # Evaluation rollout with training disabled; t0 times it for the log line below.
    print("Evaluating"); t0 = time.time()
    eval_agent = functools.partial(iagent, training=False)
    # NOTE(review): the evaluation rollout is given config.traindir as its directory,
    # the same one used for training above -- confirm this should not be a separate eval directory.
    eval_state = tools.simulate(
        eval_agent,
        eval_envs,
        eval_eps,
        config.traindir,
        logger,
        is_eval=True,
        limit=config.dataset_size,
        state=eval_state,
        steps=config.eval_steps
        # episodes=2,
    )
    # print(f"Done evaluating took {time.time() - t0:1.2f}. Saving model at run step {run_steps} and train step {iagent._train_steps}.")
    # torch.save(iagent.state_dict(), full_logdir / f"iagent_{iagent._train_steps}.pth")
    print(f"Done evaluating took {time.time() - t0:1.2f}. Skipping model save.")


print(f"Trained for {iagent._train_steps} steps over a dataset of {count_steps(config.traindir):0.1e} steps")
play_complete_sound(n=2)


# In[9]:


# Save the final weights of the full agent, its world model and its goal autoencoder
# into the run directory. File names carry a month-day_hour-minute-second timestamp
# (no year); the world-model and goal-AE files also carry the configured train_steps.
now = datetime.datetime.now().strftime("%m-%d_%H-%M-%S")
torch.save(iagent.state_dict(), full_logdir / f"iagent_{now}.pth")
torch.save(iagent._wm.state_dict(), full_logdir / f"wm_{train_steps}_{now}.pth")
torch.save(iagent.goalae.state_dict(), full_logdir / f"goalae_{train_steps}_{now}.pth")


# In[10]:


# Plot the statistics collected by the activation hooks: per-layer curves of
# stats[0] and stats[1] (titled 'Means' and 'Stdevs'), then one histogram image per hook.
if USE_HOOKS:
    # Stack the per-step entries of stats[2] into a 2-D tensor and log-scale it for display.
    def get_hist(h): return torch.stack(h.stats[2]).t().float().log1p()
    # Share of each step's stats[2] total that falls in its first entry
    # (presumably the lowest histogram bin, i.e. near-zero activations -- confirm against append_stats).
    def get_min(h):
        h1 = torch.stack(h.stats[2]).t().float()
        return h1[0]/h1.sum(0)
    fig,axs = plt.subplots(1,2, figsize=(11,5))
    for idx, h in enumerate(iter(hooks)):
        # Hooks that never collected anything have no `stats` attribute; skip them.
        if not hasattr(h, "stats"): continue
        for i in 0,1: axs[i].plot(h.stats[i], label=model_names[idx])

    axs[0].set_title('Means'); axs[1].set_title('Stdevs'); plt.legend(model_names)
    # plt.legend()

    # JHoward explanation of histogram: https://youtu.be/9YZaYjRKuEc?si=-PcZkEEBMph-gGpS&t=4662

    from miniai.datasets import get_grid
    hc = hooks
    fig,axes = get_grid(len(hc), 3, 1, figsize=(11,5))
    for ax,h in zip(axes.flat, hc):
        show_image(get_hist(h), ax, origin='lower')
    # plt.imshow(get_hist(hc[0]), origin='lower')


# In[11]:


# Relies on `get_grid`, `hc` and `get_min`, which are defined in the preceding USE_HOOKS section.
if USE_HOOKS:
    def dead_units():
        """Plot, for every hook in `hc`, the curve returned by `get_min` on a 0-1 y-axis."""
        fig,axes = get_grid(len(hc), figsize=(11,5))
        # This shows the ratio of activations that are near 0 (1.0 on the graph). Being near the top of this graph is bad. 
        for ax,h in zip(axes.flatten(), hc):
            ax.plot(get_min(h))
            ax.set_ylim(0,1)
    dead_units()


# In[12]:


# eval_policy = functools.partial(iac, training=False)
# for step in range(train_steps):
#     batch = next(train_dataset)
#     post, context, metrics = wm._train(batch)
#     with torch.no_grad():
#         # feats, states, actions, skills, goals = imagine_carry(config, wm.dynamics, post, iac.actor, config.imag_horizon+1)
#         feats, states, actions = imagine(config, wm.dynamics, post, iac.actor, config.imag_horizon+1)
#     cont = wm.heads["cont"](feats).mean
#     first_cont = torch.Tensor((1 - batch['is_terminal']).reshape(1, -1, 1)).to(config.device)
#     imag_cont = torch.cat([first_cont, cont[1:]], dim=0)
#     imag_reward = wm.heads["reward"](feats).mode()[1:]

#     # print(f"imag_reward {imag_reward.shape}, {imag_reward.mean():1.2f}, {imag_reward.std():1.2f}")
#     # print(f"imag_actions {actions.shape}, {actions[0, 0, :]}, {actions.std():1.2f}")
#     imag_traj = {
#         # "image": images,
#         "stoch": states["stoch"],
#         "deter": states["deter"],
#         "logit": states["logit"],
#         "feat": feats,
#         "action": actions,

#     imag_traj["reward_goal"] = imag_traj["reward"]
#     *_, metrics = iac._train(imag_traj=imag_traj)
#     [logger.scalar(f"train/{k}", v) for k, v in metrics.items()]
#     logger.step += 1

#     if step % 10 == 0:
#         # logger.write()
#         tools.simulate(
#             iac,
#             train_envs,
#             train_eps,
#             config.traindir,
#             logger,
#             is_eval=False,
#             limit=config.dataset_size,
#             episodes=1,
#         )
        


# In[29]:


# Notebook leftover: inspect the keys of the most recent observation.
# `obs` is only assigned further down in this file, so on a top-to-bottom script
# run it may not exist yet; guard the lookup instead of raising NameError.
if "obs" in globals():
    print(obs.keys())


# In[33]:


# loaded_iagent, *_ = get_model()
# loaded_iagent.load_state_dict(torch.load(logdir / "iagent.pth"))
# loaded_agent, *_ = get_model()
# loaded_agent.load_state_dict(torch.load(logdir / "12-19_03-47-23_2023_pinpad_five_prehuman4x4_256_emamix" / "iagent_10101.pth"))
# loaded_agent_state = None

# State for the interactive rollout below: which agent drives it, plus the
# rolling variables the loop reads and updates.
agent = iagent
# Optional second agent for side-by-side skill values; None disables every `loaded_agent` branch below.
loaded_agent = None

# agent = spoofed_human_model
# agent.uiux.expert_model = loaded_agent
# agent = loaded_agent
# agent.uiux.expert_model = spoofed_human_model
agent.uiux.reset()

# World-model state and previous action; None makes agent_update start from the initial state.
state = None
action = None
is_first = None
goal_img = None

total_r = 0
imgs = []
# Current skill; None forces a new goal to be chosen on the first loop iteration.
skill = None

def agent_update(agent, state, obs, action):
    """Run one world-model observation step for `agent`.

    When `state` is None, a zero action (sized from the global `pp` action space) and
    the dynamics' initial state are used and the step is flagged as the first one.
    `obs` is updated in place with "is_first" and "is_terminal" entries before it is
    preprocessed and encoded.

    Returns:
        Tuple ``(post, prior, state, action, is_first)``, where `state` and `action`
        are the ones that were fed into the step.
    """
    assert "image" in obs.keys(), "Observation must contain an image for agent update."
    world_model = agent._wm
    starting_fresh = state is None
    if starting_fresh:
        action = torch.zeros((1, pp.action_space.n)).to(config.device)
        state = world_model.dynamics.initial(action.shape[0])
    is_first = torch.Tensor([1.0 if starting_fresh else 0.0]).to(config.device)

    obs.update({"is_first": is_first, "is_terminal": torch.Tensor([0.])})
    processed = world_model.preprocess(obs)
    # Add a leading axis to the embedding before handing it to the dynamics.
    embed = world_model.encoder(processed)[None, :]
    post, prior = world_model.dynamics.obs_step(state, action, embed, is_first)
    return post, prior, state, action, is_first

def goal_img_value(agent, goal_img, latent, action, is_first):
    """Return the manager's "reward" critic value for a goal image.

    The image is min-max scaled to [-0.5, 0.5], encoded by the agent's world model,
    passed through a fresh ``dynamics.observe`` (``state=None``), and the resulting
    features are scored by ``agent.manager.critics["reward"]``.

    Args:
        agent: object exposing ``_wm.encoder``, ``_wm.dynamics`` and ``manager.critics``.
        goal_img: array-like image; it is not modified.
        latent: unused; kept so existing callers keep working.
        action: tensor for the previous action. The image tensor is placed on its
            device, since both are passed to the same ``dynamics.observe`` call.
        is_first: tensor flag passed to ``dynamics.observe``.

    Returns:
        NumPy array with the mode of the critic's output.
    """
    img = np.array(goal_img, dtype=np.float32)
    # normalize the image to lie between 0 and 1, then center it on 0
    lo = img.min()
    span = img.max() - lo
    # A constant image has zero span; map it to all zeros instead of dividing by zero.
    img = (img - lo) / span if span > 0 else np.zeros_like(img)
    img = img - 0.5
    # Use the normalized image (not the raw goal_img) and follow the action's device
    # instead of assuming CUDA is available.
    img = torch.as_tensor(img, dtype=torch.float32, device=action.device)

    embed = agent._wm.encoder({"image": img})

    p, _ = agent._wm.dynamics.observe(embed, action[None], is_first[None], state=None)
    inp = agent._wm.dynamics.get_feat(p)
    v = agent.manager.critics["reward"](inp).mode().detach().cpu().numpy()
    return v

# for step in range(config.time_limit):
# Interactive rollout (notebook only): step the environment `pp` with the worker
# policy, choose a new goal every `config.imag_horizon` steps (from the human UI or
# the manager), and show the observation next to the decoded goal image.
# NOTE(review): `pp`, `cv2`, `HUMAN`, `hriu`, `fru` and `reshape` are not defined in
# this file; presumably they come from `from fastrl_utils import *` -- confirm.
obs = pp.reset()
for step in range(2000):
    if not notebook: continue
    # img = pp.render()
    img = obs["image"]
    obs = {"image": img}

    post, prior, state, action, is_first = agent_update(agent, state, obs, action)

    # Loaded agent (keep the rolling posterior up to date)
    if loaded_agent is not None:
        lpost, lprior, loaded_agent_state, laction, lis_first = agent_update(loaded_agent, loaded_agent_state, obs, action)

    feat = agent._wm.dynamics.get_feat(post)

    # Choose a new skill/goal at the start and then every imag_horizon steps.
    if skill is None or step % (config.imag_horizon) == 0:
        goal_img = None
        manager_chooses_skill = not HUMAN
        if HUMAN:
            skill, goal_deter, exit_code, metrics = hriu.human_interaction(post, agent, obs, num_skill_samples=256, sample=False)
            agent._metrics.update(metrics)
            goal_stoch = agent._wm.dynamics.get_stoch(goal_deter)
            inp = torch.cat([reshape(goal_stoch), goal_deter], axis=-1)

            goal_img = agent._wm.heads["decoder"](inp[None])["image"].mode().detach().cpu().numpy()
            # A negative exit code hands the choice back to the manager.
            if exit_code < 0:
                print(f"Human selected goal with exit code {exit_code}.")
                manager_chooses_skill = True

        # if manager_chooses_skill: skill, goal_deter, goal_stoch, goal_img = loaded_agent.get_goal(loaded_agent._wm.dynamics.get_feat(lpost), sample=True, gen_image=True)
        if manager_chooses_skill: skill, goal_deter, goal_stoch, goal_img = agent.get_goal(feat, sample=True, gen_image=True)
        skill_value = agent.manager.critics["reward"](torch.cat([reshape(goal_stoch), goal_deter], axis=-1)).mode().detach().cpu().numpy()
        if goal_img is not None:
            # skill_value = goal_img_value(agent, goal_img, post, action, is_first)
            if loaded_agent is not None: lskill_value = goal_img_value(loaded_agent, goal_img, lpost, laction, lis_first)
            goal_img = fru.norm_dec_img(goal_img)


    # print(agent._wm.heads["reward"](feat).mode().detach().cpu().numpy())

    # The worker acts on the current features concatenated with the goal's deterministic state.
    inp = torch.cat([feat, goal_deter], -1)
    onehot_dist = agent.worker.actor.forward(inp)
    # action = onehot_dist.mode()
    action = onehot_dist.sample()
    action_idx = torch.argmax(action)


    # TODO: standardize environment return structure (or alt?)
    # obs, *_ = pp.step(action_idx) # Pinpad
    obs, reward, done, info = pp.step(action_idx) # MemoryMaze
    obs["reward"] = reward

    if reward > 1e-3: print(f"Reward {reward:1.2f}")

    total_r += obs["reward"]

    to_show = obs["image"]
    # add the goal to the image
    if goal_img is not None:
        goal_img = goal_img.squeeze()
        to_show = np.concatenate([obs["image"], goal_img], axis=1)

    # triple the size of to_show
    to_show = cv2.resize(to_show, (0,0), fx=3, fy=3)

    # write the skill value onto the top-left of the frame
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(to_show, f"skill: {skill_value.mean():1.2f}", (10, 30), font, 1, (0, 0, 0), 2, cv2.LINE_AA)
    if loaded_agent is not None: cv2.putText(to_show, f"lskill: {lskill_value.mean():1.2f}", (10, 60), font, 1, (0, 0, 0), 2, cv2.LINE_AA)

    cv2.imshow('pp', to_show)
    cv2.waitKey(30)
    if obs["is_terminal"]:
        # Episode over: continue from the reset observation and drop the recurrent
        # state so the next agent_update starts from the initial state.
        obs = pp.reset()
        state = None
        if loaded_agent is not None: loaded_agent_state = None
    else:
        state = post

cv2.destroyAllWindows()
print("R ", total_r)

