import sys,gc,traceback
import torch
import numpy as np
from collections.abc import Mapping
import fastcore.all as fc
import matplotlib.pyplot as plt
import tools
import envs.wrappers as wrappers

def count_steps(folder):
    """Return the total step count of the ``*.npz`` episodes directly inside ``folder``.

    Each filename is expected to end in ``-<N>.npz``. Every episode contributes
    ``N - 1`` steps (presumably ``N`` also counts the initial reset observation;
    TODO confirm against the episode writer). Only the top level of ``folder``
    is scanned; subdirectories are not searched.
    """
    total = 0
    for episode_path in folder.glob("*.npz"):
        # Text after the last '-' is "<N>.npz"; drop the 4-char ".npz" suffix.
        length_text = str(episode_path).rsplit("-", 1)[-1][:-4]
        total += int(length_text) - 1
    return total

def make_dataset(episodes, config, batch_length=None, batch_size=None):
    """Build a batched dataset that samples sequences from ``episodes``.

    ``batch_length`` and ``batch_size`` fall back to ``config.batch_length`` and
    ``config.batch_size`` when not given. Sampling and batching are delegated to
    ``tools.sample_episodes`` and ``tools.from_generator``. Presumably the former
    yields sequences of ``batch_length`` steps; confirm in ``tools``.
    """
    if batch_length is None:
        batch_length = config.batch_length
    sampler = tools.sample_episodes(episodes, batch_length)
    if batch_size is None:
        batch_size = config.batch_size
    return tools.from_generator(sampler, batch_size)

def make_env(config, mode):
    """Create a wrapped environment for ``config.task``.

    ``config.task`` is split on its first underscore into ``<suite>_<task>``.
    Supported suites:

    * ``pinpad``     -- ``envs.pinpad.PinPad(task)`` with one-hot actions;
      requires ``config.size`` to be 64x64.
    * ``MemoryMaze`` -- ``MemoryMaze(task, seed=config.seed,
      action_repeat=config.action_repeat)`` with one-hot actions.
    * ``calvin``     -- the env instantiated from the Hydra config
      ``config_data_collection.yaml`` (cameras override
      ``static_and_gripper``; EGL, GUI and VR disabled; scene info enabled).

    Every env is then wrapped in ``TimeLimit(config.time_limit)``,
    ``SelectAction(key="action")`` and ``UUID``.

    Args:
        config: Run configuration. Reads ``task`` and ``time_limit``, plus
            ``size`` / ``seed`` / ``action_repeat`` depending on the suite.
        mode: Not used by this function; kept for call-site compatibility
            (presumably "train"/"eval" -- confirm against callers).

    Returns:
        The wrapped environment.

    Raises:
        NotImplementedError: if the suite is not one of the above.
        AssertionError: for ``pinpad`` when ``config.size`` is not 64x64.
    """
    suite, task = config.task.split("_", 1)
    if suite == 'pinpad':
        import envs.pinpad as pinpad
        # Compare as a tuple so a list-valued size (e.g. loaded from a config
        # file) is accepted too; `[64, 64] == (64, 64)` is False in Python.
        assert tuple(config.size) == (64, 64), "PinPad only supports 64x64 images " + str(config.size)
        env = pinpad.PinPad(task)
        env = wrappers.OneHotAction(env)
    elif suite == "MemoryMaze":
        from envs.memorymaze import MemoryMaze
        env = MemoryMaze(task, seed=config.seed, action_repeat=config.action_repeat)
        env = wrappers.OneHotAction(env)
    elif suite == "calvin":
        import hydra
        # Hydra resolves `config_path` relative to the calling module's file.
        with hydra.initialize(config_path="../calvin/calvin_env/conf/"):
            cfg = hydra.compose(config_name="config_data_collection.yaml", overrides=["cameras=static_and_gripper"])
            cfg.env["use_egl"] = False
            cfg.env["show_gui"] = False
            cfg.env["use_vr"] = False
            cfg.env["use_scene_info"] = True
            print(cfg.env)
        env = hydra.utils.instantiate(cfg.env)
    else:
        raise NotImplementedError(suite)
    env = wrappers.TimeLimit(env, config.time_limit)
    env = wrappers.SelectAction(env, key="action")
    env = wrappers.UUID(env)
    return env
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt
from IPython.display import HTML

def make_animation(batch, batch_idx=0, dt=100, vscode=False, frames=None):
    """Animate one sequence of images from a batch.

    Args:
        batch: Either a mapping with an ``"image"`` entry, or an array/tensor
            used directly. In both cases frame ``t`` is read as
            ``images[batch_idx, t, ...]``, so the leading axes are
            (batch, time). The per-frame layout is whatever ``imshow``
            accepts (presumably H x W x C -- TODO confirm).
        batch_idx: Index of the sequence within the batch to animate.
        dt: Delay between frames in milliseconds (``FuncAnimation`` interval).
        vscode: If True, display with ``plt.show()`` and return None.
            Otherwise close the figure and return an HTML animation.
        frames: Number of timesteps to animate. Defaults to the full length of
            the selected sequence (``len(images[batch_idx])``). This used to be
            hard-coded to 64, which raised an index error on shorter sequences
            and silently truncated longer ones. Pass ``frames=64`` for the old cap.

    Returns:
        ``IPython.display.HTML`` holding the JS animation, or None when
        ``vscode`` is True.
    """
    # Resolve where the images live once, instead of re-checking every frame.
    if hasattr(batch, "keys") and "image" in batch:
        images = batch['image']
    else:
        images = batch
    if frames is None:
        frames = len(images[batch_idx])

    fig, ax = plt.subplots()

    def do_one(d):
        ax.clear()
        ax.imshow(images[batch_idx, d, ...])

    ani = FuncAnimation(fig, do_one, frames=frames, interval=dt, repeat=False)
    if vscode:
        plt.show()
        return None
    # Close this specific figure so the notebook backend does not also
    # render it as a static plot next to the animation.
    plt.close(fig)
    return HTML(ani.to_jshtml())

def to_cpu(x):
    """Recursively detach tensors and move them to the CPU.

    Mappings are returned as plain dicts, lists as lists and tuples as plain
    tuples, with every element converted recursively. Tensors (anything with a
    ``detach`` method) are detached and moved to CPU, and ``float16`` results
    are upcast to ``float32``. Any other leaf (ints, strings, None, ...) is
    returned unchanged instead of raising ``AttributeError``.
    """
    if isinstance(x, Mapping): return {k: to_cpu(v) for k, v in x.items()}
    if isinstance(x, list): return [to_cpu(o) for o in x]
    if isinstance(x, tuple): return tuple(to_cpu(o) for o in x)
    # Non-tensor leaves (e.g. metadata stored in a batch dict) pass through.
    if not hasattr(x, 'detach'): return x
    res = x.detach().cpu()
    return res.float() if res.dtype == torch.float16 else res

def show_image(im, ax=None, figsize=None, title=None, noframe=True, **kwargs):
    """Display a PIL image, NumPy array or tensor on a Matplotlib axis.

    Tensor-like inputs (having ``cpu``, ``permute`` and ``detach``) are detached
    and moved to CPU. A 3-D tensor whose first dimension is smaller than 5 is
    treated as channels-first and permuted to channels-last. Any other
    non-ndarray input is converted with ``np.array``. A trailing channel axis of
    size 1 is dropped. A new figure is created when ``ax`` is None. Ticks are
    removed, and the axis is hidden entirely when ``noframe`` is True. Extra
    keyword arguments go to ``ax.imshow``. Returns the axis used.
    """
    if fc.hasattrs(im, ('cpu', 'permute', 'detach')):
        im = im.detach().cpu()
        channels_first = len(im.shape) == 3 and im.shape[0] < 5
        if channels_first:
            im = im.permute(1, 2, 0)
    elif not isinstance(im, np.ndarray):
        im = np.array(im)

    if im.shape[-1] == 1:
        im = im[..., 0]

    if ax is None:
        ax = plt.subplots(figsize=figsize)[1]
    ax.imshow(im, **kwargs)
    if title is not None:
        ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    if noframe:
        ax.axis('off')
    return ax

def scale_img(decoded_img):
    """Min-max rescale a tensor so its values span [0, 255].

    If every element is equal, return ``torch.zeros_like(decoded_img)`` to
    avoid a division by zero.
    """
    lo, hi = torch.min(decoded_img), torch.max(decoded_img)
    if hi == lo:
        return torch.zeros_like(decoded_img)
    normalized = (decoded_img - lo) / (hi - lo)
    return normalized * 255

def clean_ipython_hist():
    """Blank out the running IPython shell's input history.

    Flushes the shell's display hook and removes the ``_i<n>`` input variables
    from the user namespace. It also resets ``_i``/``_ii``/``_iii`` to empty
    strings and overwrites the history manager's parsed and raw input lists
    (and its ``_i``/``_ii``/``_iii``/``_i00`` attributes) with empty strings.
    Does nothing when IPython is unavailable or no shell is running.
    """
    # Code in this function mainly copied from IPython source
    # NOTE: this module never imports or defines `get_ipython`, so the former
    # `'get_ipython' in globals()` guard was always False and made the function
    # a silent no-op. Ask IPython for the active shell instead.
    try:
        from IPython import get_ipython
    except ImportError:
        return
    ip = get_ipython()
    if ip is None: return  # IPython installed, but not running inside a shell
    user_ns = ip.user_ns
    ip.displayhook.flush()
    pc = ip.displayhook.prompt_count + 1
    for n in range(1, pc): user_ns.pop('_i' + repr(n), None)
    user_ns.update(dict(_i='', _ii='', _iii=''))
    hm = ip.history_manager
    hm.input_hist_parsed[:] = [''] * pc
    hm.input_hist_raw[:] = [''] * pc
    hm._i = hm._ii = hm._iii = hm._i00 = ''

# %% ../nbs/11_initializing.ipynb 12
def clean_tb():
    """Drop the interpreter's record of the last unhandled exception.

    Clears the frames of the stored traceback(s) with
    ``traceback.clear_frames``. Then deletes ``sys.last_traceback``,
    ``sys.last_type``, ``sys.last_value`` and ``sys.last_exc`` (the Python 3.12+
    replacement for the other three) when they are set, so the locals of those
    frames are no longer kept alive. Safe to call when none are set.
    """
    # h/t Piotr Czapla
    last_exc = getattr(sys, 'last_exc', None)
    for tb in (getattr(sys, 'last_traceback', None), getattr(last_exc, '__traceback__', None)):
        if tb is not None: traceback.clear_frames(tb)
    for attr in ('last_traceback', 'last_type', 'last_value', 'last_exc'):
        if hasattr(sys, attr): delattr(sys, attr)

# %% ../nbs/11_initializing.ipynb 13
def clean_mem():
    """Drop lingering references, then run garbage collection and empty the CUDA cache.

    Calls, in this order: ``clean_tb()``, ``clean_ipython_hist()``,
    ``gc.collect()`` and ``torch.cuda.empty_cache()``.
    """
    # References are dropped first so the collection pass and the CUDA cache
    # release that follow can reclaim what those references were holding.
    for step in (clean_tb, clean_ipython_hist, gc.collect, torch.cuda.empty_cache):
        step()