#!/usr/bin/env python
# coding: utf-8

# In[ ]:


import argparse, functools, datetime, os, shutil, pathlib, sys
from collections.abc import Mapping

# Select the MuJoCo rendering backend via the MUJOCO_GL environment variable.
# It is set before the env/project imports below, presumably so the backend is
# fixed before any module that loads MuJoCo is imported — confirm.
# "osmesa" is the commented-out alternative backend.
os.environ["MUJOCO_GL"] = "glfw" #"osmesa"

import numpy as np, ruamel.yaml as yaml
from functools import partial
import time, random
import envs.wrappers as wrappers
from parallel import Parallel, Damy

import torch
from torch import nn
from torch.nn import functional as F
from torch import distributions as torchd
from torch.utils.tensorboard import SummaryWriter

import torch_utils as tu
import director_models
import tools
import models
import networks
import mem_tools

import matplotlib.pyplot as plt
import fastcore.all as fc
from fastai_utils import to_cpu, show_image, scale_img, count_steps, make_dataset, make_env, make_animation
from fastrl_utils import *
from AnotherDirector import ImaginationActorCritic, ImaginationAgent, ImaginationAE, ImaginationAgentNonHier


def to_np(x):
    """Detach ``x`` from the autograd graph, move it to the CPU and return a NumPy array."""
    return x.detach().cpu().numpy()


# Seed the Python, NumPy, PyTorch CPU and (current-device) CUDA generators
# with the same value, in that order.
seed_value = 29
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed_value)
del _seed_fn


# In[ ]:


# env_name = "pinpad"
env_name = "MemoryMaze"
config = get_config(env_name=env_name)

# Detect whether this code is running inside a Jupyter kernel. The check relies
# on the executing file's path containing "ipykernel" (assumption: this is how
# the kernel names cell sources). argparse is skipped there because the
# kernel's own command line is not meant for this script.
import inspect
filename = inspect.getframeinfo(inspect.currentframe()).filename
print(filename)
if 'ipykernel' in filename: # in jupyter (cant use argparse)
    debug = False
    config.human = False
    notebook = True
else:
    # argparse is already imported at the top of the file.
    parser = argparse.ArgumentParser()
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--human', action='store_true')
    # Parse sys.argv once and reuse the namespace for every flag.
    cli_args = parser.parse_args()
    debug = cli_args.debug
    config.human = cli_args.human
    notebook = False
# debug = False
config.prehuman = not config.human # Don't run the human mode, but save model periodically

# --- Training hyper-parameters ------------------------------------------------
# config = get_config(env_name="scalars/(.*)lossMemoryMaze")
# config.task = "MemoryMaze_9x9"
config.train_ratio = 128  # previously 32
# Learning rates: 3e-4 was the earlier setting; 3e-5 is the low-LR variant.
config.manager_lr = 3e-5
config.worker_lr = 3e-5

# --- Human-in-the-loop settings (commented values are from earlier runs) -------
config.spoofed_human_path = None  # "12-26_22-31-52_2023_pinpad_five_4x4_128_emamix" if config.human else None
config.load_pretrained_main_model_path = None  # '01-03_00-54-33_2024_pinpad_five_spoofed_human4x4_128_emamix'
config.human_reward = 0  # 0.001 if config.spoofed_human_path is None else 0.0
config.human_i_horizon = 12
config.n_human_clusters = 0  # don't cluster just show raw trajectories
# Number of skills sampled: many when replaying a spoofed human, few when raw
# trajectories are shown without clustering.
if config.spoofed_human_path is not None:
    config.n_skill_samples = 512
elif config.n_human_clusters == 0:
    config.n_skill_samples = 16
else:
    config.n_skill_samples = 128

config.skill_alpha = 0.5  # blending between manager skill distribution and uniform distribution
train_steps = 100000
config.eval_steps = 250
config.prefill = 1000

# --- Pre-training / periodic training schedule ---------------------------------
# In human mode without a pretrained model, train_period is 0.
if not config.human:
    config.pretrain = 100
    config.train_period = 5000
elif config.load_pretrained_main_model_path is None:
    config.pretrain = 100
    config.train_period = 0
else:
    config.pretrain = 1
    config.train_period = 1

if debug:
    # Tiny schedules for debugging.
    print("Configuring for DEBUG mode")
    config.pretrain = 1
    config.train_period = 1
    config.eval_steps = 1
    train_steps = 10
    config.prefill = 10
    config.time_limit = 100

PARTIAL_TRAIN_FIRST = False  # partial training to step in and mess with skill generation
DELETE_PREV_DATASET = True
DONT_TRAIN_WM = False
config.skip_wm_training = DONT_TRAIN_WM

config.skill_shape = (8, 8)
config.reward_EMA = True
config.unfold_manager_traj = False
config.use_ccw_random_agent = False

logdir = pathlib.Path(config.logdir).expanduser()
USE_OFFLINE_DATASET = False  # use a dataset of human demonstrations
if USE_OFFLINE_DATASET:
    print("Using offline dataset.")
    config.offline_traindir = [logdir / "pinpad_human_demonstrations" / "npzs"]
    config.pretrain = 1

# TODO: make a non-hierarchical option that clips out the manager and the goal
# autoencoder; the worker's state input must then exclude goals. Possibly
# simpler to make an alternate agent depending on how ImaginationAgent is
# structured.
# NOTE: yikes — one flag is exposed three ways (config.single_level_agent,
# config.WORKER_PRIMARY and the module-level WORKER_PRIMARY).
config.single_level_agent = config.WORKER_PRIMARY = WORKER_PRIMARY = False

# Run label (passed on as `notes`). The mode prefix is chosen by priority:
# spoofed human > prehuman > human > low learning rate.
# NOTE(review): config.prehuman is set to `not config.human` earlier in the
# script, so the human/prehuman branches always win and 'lowlr_' is never
# used — confirm whether it was meant to be combined with the mode prefix.
if config.spoofed_human_path:
    notes = "spoofed_human"
elif config.prehuman:
    notes = "prehuman"
elif config.human:
    notes = "human"
elif config.manager_lr < 3e-4 or config.worker_lr < 3e-4:
    notes = 'lowlr_'
else:
    notes = ''

# Configuration tags appended to the mode prefix.
notes += "".join((
    f'rpt{config.action_repeat}_',
    'single_level' if config.single_level_agent else '',
    'OFF' if USE_OFFLINE_DATASET else '',
    f'{config.skill_shape[0]}x{config.skill_shape[1]}',
    f'_{config.train_ratio}',
))
# notes += '_ema' if config.reward_EMA else ''; notes += 'mix' if config.slow_target_fraction < 1.0 else ''


# In[ ]:


# Weights & Biases run. sync_tensorboard=True has W&B pick up the TensorBoard
# summaries written during the run.
import wandb

# wandb.init("white-computer-pinpad5-hier", config=config)

# Project name encodes the agent type and run mode.
project_name = "".join((
    "pops-mm-lowLR",
    "-flat" if config.single_level_agent else "-hier",
    "-human" if config.human else "",
    "-off" if USE_OFFLINE_DATASET else "",
))
wandb.init(project=project_name, sync_tensorboard=True)


# In[ ]:


# Set up the run directory and logger, then build the train/eval/human
# datasets, environments and episode caches. delete_train_eps is driven by
# DELETE_PREV_DATASET (presumably it removes episodes left by earlier runs).
logger, full_logdir = setup_paths_and_logging(
    config,
    logdir=logdir,
    notes=notes,
    delete_train_eps=DELETE_PREV_DATASET,
)
(
    train_dataset, eval_dataset, human_dataset,
    train_envs, eval_envs, human_envs,
    train_eps, eval_eps, human_eps,
) = get_datasets(config, logger, logdir=logdir)


# In[ ]:


def get_model():
    """Build the world model and the agent that learns inside it.

    Reads the module-level ``config``, ``train_envs``, ``logger``,
    ``train_dataset`` and ``WORKER_PRIMARY``.

    Returns:
        ``(agent, worker, manager, goal_ae, world_model)``. When
        ``config.single_level_agent`` is set, the worker, manager and goal
        autoencoder slots are ``None``.
    """
    first_env = train_envs[0]
    world_model = models.WorldModel(
        first_env.observation_space,
        first_env.action_space,
        step=None,
        config=config,
    ).to(config.device)
    if config.compile:
        world_model = torch.compile(world_model)

    # Feature size = flattened stochastic one-hot grid + deterministic state.
    feat_size = config.dyn_stoch * config.dyn_discrete + config.dyn_deter

    if config.single_level_agent:
        print("Configuring single-level agent")
        flat_policy = ImaginationActorCritic(
            config, world_model,
            input_size=feat_size,
            num_actions=config.num_actions,
            value_heads={"reward": 1.0},  # extrinsic reward only
            stop_grad_actor=True,
            prefix="worker",
        ).to(config.device)
        agent = ImaginationAgentNonHier(config, logger, worker=flat_policy, _wm=world_model, dataset=train_dataset)
        return agent, None, None, None, world_model

    print("Configuring hierarchical agent")
    goal_ae = ImaginationAE(config).to(config.device)

    # Unless WORKER_PRIMARY, the worker input is widened by dyn_deter
    # (presumably the goal vector, sized like the deterministic state).
    worker_input = feat_size if WORKER_PRIMARY else feat_size + config.dyn_deter
    worker = ImaginationActorCritic(
        config, world_model,
        input_size=worker_input,
        num_actions=config.num_actions,
        value_heads={"reward_goal": 1.0},  # alternative tried: {"reward": 0.5, "reward_goal": 0.5}
        stop_grad_actor=True,
        prefix="worker",
    ).to(config.device)
    print(f"Worker is using {worker.critics.keys()} critic(s)")

    manager = ImaginationActorCritic(
        config, world_model,
        input_size=feat_size,
        num_actions=config.skill_shape,
        value_heads={"reward": 0.9, "reward_expl": 0.1},  # extrinsic vs exploration reward scales
        stop_grad_actor=True,
        prefix="manager",
    ).to(config.device)
    print(f"manager is using {manager.critics.keys()} critic(s)")

    agent = ImaginationAgent(config, logger, worker=worker, manager=manager, goalae=goal_ae, _wm=world_model, dataset=train_dataset)
    return agent, worker, manager, goal_ae, world_model


# In[ ]:


iagent, *_ = get_model()

if config.load_pretrained_main_model_path is not None:
    print(f"Loading pretrained model from {config.load_pretrained_main_model_path}")
    minstep, maxstep = 2000, 4000  # NOTE: maxstep is currently unused
    pretrained_model_path = logdir / config.load_pretrained_main_model_path

    # Collect checkpoints named iagent_<step>.pth, mapping path -> step. Names
    # containing '-' or a non-numeric step are skipped instead of crashing int().
    candidates = {}
    for ckpt in pathlib.Path(pretrained_model_path).expanduser().iterdir():
        if ckpt.name.startswith("iagent_") and ckpt.name.endswith(".pth") and '-' not in ckpt.stem:
            step_str = ckpt.stem.split("_")[1]
            if step_str.isdecimal():
                candidates[ckpt] = int(step_str)
    if not candidates:
        raise FileNotFoundError(
            f"No iagent_<step>.pth checkpoints found in {pretrained_model_path}"
        )

    # Checkpoint whose step is closest to minstep (ties: first in directory order).
    iagent_path = min(candidates, key=lambda p: abs(candidates[p] - minstep))
    print(f"\tLoading {iagent_path}")
    # map_location lets a checkpoint saved on another device (e.g. GPU) load here.
    iagent.load_state_dict(torch.load(iagent_path, map_location=config.device))


# In[ ]:


save_config(logger, config)

# Snapshot the agent implementation next to the run's logs.
shutil.copyfile('AnotherDirector.py', full_logdir / 'AnotherDirector.py')

# Gradient updates per environment step. Assumes config.train_ratio counts
# replayed time steps per env step (DreamerV3 convention) — confirm.
train_ratio = config.train_ratio / (config.batch_size * config.batch_length)
run_steps = 0
# Environment steps needed to reach `train_steps` gradient updates.
steps = train_steps / train_ratio
# Environment steps collected per outer-loop iteration.
train_for = int(config.train_period / train_ratio)
if train_for <= 0 and not config.human:
    # Outside human mode train_for never changes inside the loop, so a zero
    # chunk would leave run_steps stuck and the loop would never terminate.
    raise ValueError(
        f"config.train_period={config.train_period} yields train_for={train_for} "
        f"env steps per chunk (train_ratio={train_ratio:g}); increase train_period."
    )

print(f"Training for {steps:0.1e} steps to train for {train_steps} steps. Training for {train_for} steps at a time.")

# Evaluation episodes get their own directory so they are not mixed into the
# training episodes on disk (nor counted by count_steps(config.traindir) at the
# end). Falls back to traindir when no evaldir is configured.
eval_dir = getattr(config, "evaldir", None) or config.traindir
pathlib.Path(eval_dir).expanduser().mkdir(parents=True, exist_ok=True)

tool_state, eval_state, human_state = None, None, None
eval_run_steps = 0
while run_steps < steps:
    print(f"Training for {train_for} env steps. Total env steps so far {run_steps}. Total train steps {iagent._train_steps}.")
    tool_state = tools.simulate(
        iagent,
        train_envs,
        train_eps,
        config.traindir,
        logger,
        is_eval=False,
        limit=config.dataset_size,
        state=tool_state,
        steps=train_for,
    )
    run_steps += train_for

    if config.human:
        print("Human mode")
        # Checkpointing is disabled: a full iagent checkpoint uses ~178 MB.
        # torch.save(iagent.state_dict(), full_logdir / f"iagent_{iagent._train_steps}.pth")
        # torch.save(iagent.worker.optimizer.state_dict(), full_logdir / 'opts' / f"worker_opt_{iagent._train_steps}.pth")
        # torch.save(iagent.goalae.opt.state_dict(), full_logdir / 'opts' / f"goalae_opt_{iagent._train_steps}.pth")

        # From now on, collect 1000 env steps between human sessions.
        train_for = 1000

        human_agent = functools.partial(iagent, training=False, human=True)
        human_state = tools.simulate(
            human_agent,
            human_envs,
            human_eps,
            config.humandir,
            logger,
            is_eval=False,
            limit=config.dataset_size,
            state=human_state,
            episodes=1,
            additional_cache=train_eps,  # add the human dataset to the training dataset
            uiux=iagent.uiux,
        )
        iagent.uiux._reset_session_state()

        # Earlier experiment (disabled): also train over the human dataset to
        # keep the number of updates consistent with non-human mode.
        # balance_updates = int(iagent.train_every._ratio * config.time_limit)
        # for _ in range(balance_updates):
        #     iagent._train(next(human_dataset))

    eval_run_steps += train_for  # need this to be consistent between human and non-human mode
    if eval_run_steps >= config.eval_every:
        eval_run_steps = 0
        print("Evaluating"); t0 = time.time()
        eval_agent = functools.partial(iagent, training=False)
        eval_state = tools.simulate(
            eval_agent,
            eval_envs,
            eval_eps,
            eval_dir,
            logger,
            is_eval=True,
            limit=config.dataset_size,
            state=eval_state,
            steps=config.eval_steps,
        )
        # torch.save(iagent.state_dict(), full_logdir / f"iagent_{iagent._train_steps}.pth")
        print(f"Done evaluating took {time.time() - t0:1.2f}. Skipping model save.")


print(f"Trained for {iagent._train_steps} steps over a dataset of {count_steps(config.traindir):0.1e} steps")
play_complete_sound(n=2)


# In[ ]:


# sync up the logs
# Close the W&B run opened by wandb.init so the remaining logged data
# (including the synced TensorBoard summaries) is uploaded.
wandb.finish()

