import torch
import torch.nn.functional as F
from tkinter import Tk, Canvas
import tkinter as tk
import numpy as np
import random, math, time
import cv2
from PIL import Image, ImageTk
from collections import defaultdict
import fastrl_utils as frlu
import pygame

from sklearn.cluster import KMeans
from IPython import embed as ipshell

import pygame
import os

from nov20.second_wind.peripheral import get_memorymaze_action_from_joystick, get_calvin_action_from_joystick, get_pinpad_action_from_joystick

def to_np(x):
    """Detach a tensor, move it to the CPU and return it as a numpy array."""
    return x.detach().cpu().numpy()


def unnormalize(x):
    """Map values from the [-0.5, 0.5] range onto [0, 255].

    For inputs in [-1, 1] the mapping would instead be ``((x + 1) / 2) * 255``.
    """
    return (x + 0.5) * 255


# Meaning of the non-positive codes produced by the interaction loop.
EXIT_CODES = {
    0: "human_selected_cluster",
    -1: "defer",
    -2: "exit",
    -3: "refresh",
    -4: "atomic_action",
}

class UIUX:
    def __init__(self, config):
        """Initialise interaction bookkeeping, discover joysticks and build the window.

        Args:
            config: Experiment configuration object. It is stored as-is; other
                methods of this class read ``config.task`` and ``config.device``.
        """
        self.expert_model = None  # may be None
        self.exited = 0  # number of upcoming sessions `should_skip` will still skip
        self.deferred = False  # Temporarily defer human interaction for this episode
        self.live_img = None
        self.total_loop_actions = 0
        self.total_time_in_loop = 0
        self.config = config

        # The hint must be in the environment before the joystick subsystem is
        # initialised, otherwise SDL never sees it and joystick input stops as
        # soon as the window loses focus.
        os.environ['SDL_JOYSTICK_ALLOW_BACKGROUND_EVENTS'] = "1"
        pygame.joystick.init()
        try:
            self.joysticks = [pygame.joystick.Joystick(x) for x in range(pygame.joystick.get_count())]
            print(f"Found {len(self.joysticks)} joysticks.")
        except Exception:
            # Best effort: the interface is fully usable with the mouse only.
            # (`Exception` rather than a bare except so Ctrl-C still propagates.)
            self.joysticks = []
            print("No joysticks found.")

        # self.pygame_display = pygame.display.set_mode((200, 200))

        self.reset()

    # make expert model a setter
    @property
    def expert_model(self):
        return self._expert_model
    
    @expert_model.setter
    def expert_model(self, expert_model):
        if expert_model is not None: print(f"WARN: Setting expert model. All interfaces will be spoofed using this model.")
        self._expert_model = expert_model

    def should_skip(self):
        """Report whether the upcoming interaction session should be skipped.

        A session is skipped while the "End Session" countdown (``self.exited``)
        is still positive, or while the user has deferred to the robot. Each call
        made during the countdown consumes one skipped session; a sound is played
        when the countdown reaches zero.

        Returns:
            bool: True if the session should be skipped.
        """
        session_ended = self.exited > 0
        if session_ended:
            # Only count down while the countdown is active, so the counter
            # rests at 0 instead of drifting into negative values.
            self.exited -= 1
            if self.exited == 0:
                frlu.play_intermediate_sound(n=2)
        return session_ended or self.deferred

    def _reset_session_state(self):
        """Restore the per-session variables before a new human interaction session."""
        self.selected_video_cluster = None
        self.update_clusters = True

        # Play a sound when a deferral is being cleared.
        if self.deferred:
            frlu.play_intermediate_sound(n=1)
        self.deferred = False

        self.atomic_action_mode = False
        self.atomic_action_idx = -1

        # Re-enable the session buttons (tk.NORMAL is the active state).
        for item in (self.trust_robot_button, self.trust_robot_text, self.end_session_button):
            self.canvas.itemconfig(item, state=tk.NORMAL)

    def refresh_params(self):
        """Pull the current slider values into the interface state.

        Stores the fps slider value in ``self.cluster_fps``.

        Returns:
            The mix slider value divided by 100.
        """
        ratio = self.mix_slider.get() / 100
        self.cluster_fps = self.fps_slider.get()
        return ratio

    def reset(self):
        """(Re)build the Tk window and every widget of the interface.

        Creates the main canvas with the "Trust Robot" / "End Session" /
        "Refresh" buttons, the ``n_clusters`` / ``n_samples`` entry boxes, the
        fps, mix and end-session-duration sliders, the live-feed image item and
        the low-level action canvas with six atomic-action buttons. All
        per-window state (static images, selection, live image) is cleared.
        """
        # Tear down the previous window, if any. Destroying the root (not just
        # the canvas) also removes the entries, sliders and labels that were
        # packed into it, so no orphan window or extra Tcl interpreter is left
        # behind. The window may already be gone (the "q" key destroys the
        # root), in which case Tk raises TclError.
        for stale_widget in (getattr(self, "canvas", None), getattr(self, "root", None)):
            if stale_widget is None:
                continue
            try:
                stale_widget.destroy()
            except tk.TclError:
                pass

        self.root = Tk()
        self.root.title("HRI UIUX")
        self.screenw = self.root.winfo_screenwidth() // 4
        self.screenh = self.root.winfo_screenheight() // 2

        self.canvas = Canvas(self.root, width=self.screenw, height=self.screenh)
        self.canvas.pack()
        self.root.update()

        def add_button(top, label, fill, text_fill, callback):
            # Draw a 200x50 rectangle with centred text; both react to a left click.
            rect = self.canvas.create_rectangle(0, top + 50, 200, top + 100, fill=fill)
            text = self.canvas.create_text(100, top + 75, text=label, anchor="center", fill=text_fill, font=("Purisa", 24))
            for item in (rect, text):
                self.canvas.tag_bind(item, "<Button-1>", callback)
            return rect, text

        self.trust_robot_button, self.trust_robot_text = add_button(0, "Trust Robot", "green", "black", self.trust_robot)
        self.end_session_button, self.end_session_text = add_button(75, "End Session", "red", "black", self.end_session)
        self.refresh_button, self.refresh_text = add_button(150, "Refresh", "blue", "white", self.refresh)

        # Text input boxes that display (and let the user edit) integer counts.
        self.canvas.create_text(0, 600, text="n_clusters", anchor="nw", fill="black", font=("Purisa", 24))
        self.ncluster_input = tk.Entry(self.root)
        self.ncluster_input.insert(0, "-1")
        self.ncluster_input.pack(side=tk.LEFT, fill=tk.BOTH, expand=1)

        self.canvas.create_text(300, 600, text="n_samples", anchor="nw", fill="black", font=("Purisa", 24))
        self.nsample_input = tk.Entry(self.root)
        self.nsample_input.insert(0, "-1")
        self.nsample_input.pack(side=tk.LEFT, fill=tk.BOTH, expand=1)

        # The live-feed item belongs to this canvas. Creating it here keeps the
        # id valid after a reset (an id from the previous canvas would silently
        # match nothing) and guarantees the attribute exists before the
        # selection loop runs. The old PhotoImage died with the old window.
        self.live_img = None
        self.live_img_label = self.canvas.create_image(self.screenw // 2, 10, anchor="nw")

        # fps slider
        self.fps = self.cluster_fps = 60  # if not hasattr(self, 'config') else self.config.human_fps
        self.fps_slider = tk.Scale(self.root, from_=0, to=100, orient=tk.HORIZONTAL, troughcolor="green")
        self.fps_slider.pack(side=tk.LEFT, fill=tk.BOTH, expand=1)
        self.fps_slider.set(self.cluster_fps)

        # mix slider
        self.mix_slider = tk.Scale(self.root, from_=0, to=100, orient=tk.HORIZONTAL, troughcolor="blue")
        self.mix_slider.pack(side=tk.LEFT, fill=tk.BOTH, expand=1)

        # How many sessions to skip once a session is ended (0 skips forever).
        self.end_session_duration_slider = tk.Scale(self.root, from_=0, to=50, orient=tk.HORIZONTAL, troughcolor="red")
        self.end_session_duration_slider.pack(side=tk.LEFT, fill=tk.BOTH, expand=1)
        self.end_session_duration_slider.set(3)

        self.static_images = defaultdict(list)
        self.rectangle_label = None
        self.selected_video_cluster = None
        self.update_clusters = True
        self.string_history_length = 50
        # Also resets the atomic-action state and enables the session buttons ...
        self._reset_session_state()

        # ... which stay grayed out until a session actually starts.
        self.canvas.itemconfig(self.trust_robot_button, state=tk.DISABLED)
        self.canvas.itemconfig(self.trust_robot_text, state=tk.DISABLED)
        self.canvas.itemconfig(self.end_session_button, state=tk.DISABLED)

        ### low level action canvas ###
        self.ll_canvas = Canvas(self.root, width=self.screenw, height=self.screenh)
        self.ll_canvas.pack()

        # One button per atomic action (discrete for now).
        n_buttons = 6
        button_width = self.screenw // n_buttons
        button_height = self.screenh // 2
        self.button_bounds = []
        for i in range(n_buttons):
            left, top = i * button_width, 0
            right, bottom = left + button_width, top + button_height
            button = self.ll_canvas.create_rectangle(left, top, right, bottom, fill="green")
            self.ll_canvas.create_text(left + button_width // 2, top + button_height // 2, text=f"{i}", anchor="center", fill="black", font=("Purisa", 24))
            self.button_bounds.append((left, top, right, bottom))
            self.ll_canvas.tag_bind(button, "<Button-1>", self.ll_action)

        # NOTE: embedding a pygame display in this window (via SDL_WINDOWID and
        # pygame.display.set_mode) was experimented with here and is disabled.

    def ll_action(self, event):
        """Handle a left click on one of the low-level (atomic) action buttons.

        The first button whose bounds contain the click position is flashed red
        for 100 ms, recorded as the chosen atomic action, and the selection loop
        is left with exit code -4.

        Args:
            event: Tk mouse event; ``event.x`` / ``event.y`` are compared against
                ``self.button_bounds``.
        """
        print(f"ll_action: {event}")
        # Canvas item currently under the pointer (empty when there is none).
        clicked_item = self.ll_canvas.find_withtag("current")
        if not clicked_item:
            return

        for button_idx, bounds in enumerate(self.button_bounds):
            left, top, right, bottom = bounds[0], bounds[1], bounds[2], bounds[3]
            inside = left <= event.x <= right and top <= event.y <= bottom
            if not inside:
                continue

            print(f"Button {button_idx} clicked. Position {bounds}")
            self.ll_canvas.itemconfig(clicked_item, fill="red")
            self.ll_canvas.after(100, lambda: self.ll_canvas.itemconfig(clicked_item, fill="green"))

            self.atomic_action_idx = button_idx
            self.atomic_action_mode = True
            self.exit_loop(-4)
            return

            
    ### Button Functions ###
    def exit_loop(self, ret_code):
        """Stop the selection loop and record why it stopped.

        Args:
            ret_code: Result to report for this session; the negative values are
                listed in ``EXIT_CODES``. -1 additionally sets the deferred
                flag, -2 starts the skip countdown read from the end-session
                slider (a slider value of 0 skips all future sessions).
        """
        if ret_code == -1:
            self.deferred = True
        elif ret_code == -2:
            skip_sessions = self.end_session_duration_slider.get()
            # skip all future sessions if slider is set to 0
            self.exited = skip_sessions if skip_sessions > 0 else 1e6
        self.update_clusters = False
        self.selected_video_cluster = ret_code

        # Buttons can fire from any `root.update()` call, including ones made
        # before the selection loop has ever started its timer. In that case
        # there is no elapsed loop time to account for.
        loop_start = getattr(self, "loop_start_time", None)
        if loop_start is not None:
            self.total_time_in_loop += time.time() - loop_start

    def trust_robot(self, event):
        """Button callback: leave the selection loop with the "defer" code (-1)."""
        self.exit_loop(-1)
    def end_session(self, event):
        """Button callback: leave the selection loop with the "exit" code (-2)."""
        self.exit_loop(-2)
    def refresh(self, event):
        """Button callback: leave the selection loop with the "refresh" code (-3)."""
        self.exit_loop(-3)
    ### end Button Functions ###

    ### text functions
    def update_text(self, n_clusters, n_samples):
        """Write the cluster and sample counts into their entry boxes.

        Negative counts are shown as 0 and the cluster count is capped at the
        sample count.
        """
        n_clusters = max(n_clusters, 0)
        n_samples = max(n_samples, 0)
        if n_clusters > n_samples:
            print(f"n_clusters {n_clusters} > n_samples {n_samples}, setting n_clusters to n_samples")
            n_clusters = n_samples

        # Replace the full contents of each entry with the new value.
        for entry, value in ((self.ncluster_input, n_clusters), (self.nsample_input, n_samples)):
            entry.delete(0, tk.END)
            entry.insert(0, str(value))
    ### end text functions

    def update_obs(self, img, obs_string):
        """Show the current observation in the Tk window and in an OpenCV window.

        Args:
            img: Observation image; passed to `update_image` (which unnormalizes
                it) and, unchanged, to ``cv2.imshow``.
            obs_string: Caption placed directly below the 256 px image.
        """
        size = 256
        left, top = 10, 200
        self.update_image(img, obs_string, "obs", (left, top), (left, top + size), size=size)

        cv2.imshow("obs", img)
        cv2.waitKey(1)

    def update_feed(self, live_img, scale=2):
        """Show the latest live frame on the main canvas and refresh the window.

        Args:
            live_img: Frame as a torch tensor (unnormalized to 0-255 here) or as
                an array that is displayed as-is after a cast to uint8.
                Singleton dimensions are squeezed away.
            scale: Nearest-neighbour upscaling factor applied before display.
        """
        # isinstance (not an exact type comparison) so Tensor subclasses such as
        # nn.Parameter also take the tensor path instead of reaching cv2 as-is.
        if isinstance(live_img, torch.Tensor):
            live_img = live_img.detach().cpu().numpy()
            # map the normalized values back to 0-255 pixel values
            live_img = unnormalize(live_img).clip(0, 255).astype(np.uint8)

        live_img = live_img.squeeze()
        live_img = cv2.resize(live_img, None, fx=scale, fy=scale, interpolation=cv2.INTER_NEAREST)
        photo = ImageTk.PhotoImage(Image.fromarray(live_img.astype(np.uint8)))
        if not hasattr(self, 'live_img_label'):
            self.live_img_label = self.canvas.create_image(self.screenw // 2, 10, anchor="nw")

        # Keep a reference on the instance: Tk shows nothing once the
        # PhotoImage object has been garbage collected.
        self.live_img = photo

        self.canvas.itemconfig(self.live_img_label, image=self.live_img)
        self.root.update()

    def update_goal(self, goal, stoch, goal_string):
        scale_size = 256
        image_pos = self.screenw - scale_size, 0
        string_pos = image_pos[0] - int(1.5*scale_size), image_pos[1]

        goal_stoch = self.world_model.dynamics.get_stoch(goal)
        goal_stoch = goal_stoch.reshape([*goal_stoch.shape[:-2], -1])
        inp = torch.cat([goal_stoch, goal], dim=-1)
        pure_goal_img = self.world_model.heads["decoder"](inp.unsqueeze(0))["image"].mode().detach().cpu().numpy()
        self.update_image(pure_goal_img.squeeze().squeeze(), goal_string, "goal_pure", image_pos, string_pos, size=scale_size)

        stoch = stoch.reshape([*stoch.shape[:-2], -1])
        inp = torch.cat([stoch, goal], dim=-1)
        goal_stoch_img = self.world_model.heads["decoder"](inp.unsqueeze(0))["image"].mode().detach().cpu().numpy()
        self.update_image(goal_stoch_img.squeeze().squeeze(), '', "goal_stoch", (image_pos[0], image_pos[1]+scale_size), string_pos, size=scale_size)

        # self.root.update()

    def update_image(self, image, image_string, image_key, image_pos, string_pos, size=256):
        """Place or refresh a captioned image that is tracked under ``image_key``.

        On first use of a key an image label and a text label are placed at the
        given positions. Later calls swap the image and append ``image_string``
        to the caption history (at most ``self.string_history_length`` lines,
        shown newline-separated).

        Args:
            image: Normalized image array; unnormalized, clipped to 0-255 and its
                last axis reversed before display.
            image_string: Caption text for this update.
            image_key: Key into ``self.static_images``.
            image_pos: (x, y) of the image label; only used on first placement.
            string_pos: (x, y) of the text label; only used on first placement.
            size: Edge length in pixels the image is resized to.
        """
        pixels = unnormalize(image).clip(0, 255).astype(np.uint8)
        pixels = pixels[..., ::-1]  # swap the red and blue channels
        resized = Image.fromarray(pixels).resize((size, size), Image.LANCZOS)
        photo = ImageTk.PhotoImage(resized)

        if image_key in self.static_images:
            entry = self.static_images[image_key]
            image_label, caption_label, _, history = entry
            image_label.configure(image=photo)

            history.append(image_string)
            if len(history) > self.string_history_length:
                del history[0]
            caption_label.configure(text="\n".join(history))

            # Keep the new PhotoImage referenced so Tk does not lose it.
            entry[2] = photo
            entry[3] = history
            return

        img_x, img_y = image_pos
        image_label = tk.Label(self.root, image=photo)
        image_label.place(x=img_x, y=img_y)

        caption_label = tk.Label(self.root, text=image_string)
        caption_label.config(font=("Purisa", 12))
        text_x, text_y = string_pos
        caption_label.place(x=text_x, y=text_y)

        self.static_images[image_key] = [image_label, caption_label, photo, [image_string]]

    def display_images_in_grid(self, imgs, n_rows=8, n_cols=8, title="images"):
        """Show a random subset of samples as a tiled grid in an OpenCV window.

        Each tile is the mean over the frames of one sample. Tiles are filled
        row by row; unused cells stay black.

        Args:
            imgs: Array with expected shape (n_samples, n_frames, height, width, channels).
            n_rows: Number of grid rows.
            n_cols: Number of grid columns.
            title: Name of the OpenCV window.
        """
        tile_h, tile_w = imgs.shape[2], imgs.shape[3]
        grid = np.zeros((n_rows * tile_h, n_cols * tile_w, 3), dtype=np.uint8)

        n_tiles = min(n_rows * n_cols, imgs.shape[0])
        chosen = np.random.choice(imgs.shape[0], size=n_tiles, replace=False)

        for slot, sample in enumerate(chosen):
            row, col = divmod(slot, n_cols)
            top, left = row * tile_h, col * tile_w
            # take the mean across all frames
            tile = unnormalize(np.mean(imgs[sample, :], axis=0))
            grid[top:top + tile_h, left:left + tile_w, :] = tile

        cv2.imshow(title, grid)
        cv2.waitKey(100)

    def update_image_grid(self, images, image_texts, image_key, image_pos, string_pos, n_rows=8, n_cols=8, title="images"):
        """Placeholder for an image-grid update; currently does nothing."""
        return None

    def interface(self, images, n_clusters=8, actions=None):
        """Let the human (or the expert model) choose among imagined trajectories.

        The trajectories are optionally grouped with k-means on coarse colour
        features. One preview per non-empty cluster (the frame-wise mean of its
        members), or one per trajectory when clustering is disabled, is handed
        to `display_videos`, and the choice is mapped back to a trajectory index.

        Args:
            images: Imagined trajectory frames, indexed as
                ``images[traj_idx, frame_idx]``. The clustering features assume a
                torch tensor shaped (n_traj, n_frames, H, W, 3) with
                H, W >= 64 -- TODO confirm against the caller.
            n_clusters: Number of k-means clusters; values <= 0 disable
                clustering and show every trajectory. Capped at the number of
                trajectories.
            actions: Imagined actions; only forwarded to `spoof_interface` when
                an expert model is set.

        Returns:
            tuple: ``(idx, 0)`` where ``idx`` is the chosen trajectory (the first
            member of the chosen cluster when clustering), or ``(None, code)``
            with a negative code from ``EXIT_CODES`` when no video was chosen.
            With an expert model set, whatever `spoof_interface` returns.
        """
        t0 = time.time()
        self._reset_session_state()

        # If we have an expert model we're faking the human interaction from this side.
        if self.expert_model is not None:
            return self.spoof_interface(images, actions)

        n_frames = images.shape[1]
        use_clusters = n_clusters > 0
        video_list = []
        video_texts = []
        # video_cluster_ids[i] is the cluster shown as video i. Empty clusters
        # get no video, so a video index is not the same as a cluster index.
        video_cluster_ids = []
        traj_cluster_map = defaultdict(list)  # maps from cluster_idx to traj_idx

        if use_clusters:
            n_samples = images.shape[0]
            if n_clusters > n_samples:
                # KMeans cannot fit more clusters than samples.
                print(f"n_clusters {n_clusters} > n_samples {n_samples}, setting n_clusters to n_samples")
                n_clusters = n_samples
            kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init="auto")

            # Features: mean r,g,b of the four 32x32 quadrants of the top-left
            # 64x64 region of the frame-averaged image -> 12 features per sample.
            # NOTE: clustering by action histogram or by global mean colour are
            # possible alternatives; only the patch features are in use.
            mean_images = images.mean(axis=1)
            split = 32
            patches = []
            for i in range(2):
                rows = slice(i * split, (i + 1) * split)
                patches.append(mean_images[:, rows, rows, :])  # upper left, lower right
                mirrored_cols = slice((split * 2) - (i + 1) * split, (split * 2) - i * split)
                patches.append(mean_images[:, rows, mirrored_cols, :])  # upper right, lower left
            mean_patches = np.stack([to_np(patch.mean(axis=(1, 2))) for patch in patches], axis=1)
            mean_patches = mean_patches.reshape(-1, 12)
            kmeans = kmeans.fit(mean_patches)
            labels = kmeans.labels_

            # print a measure of how good the clusters are
            print(f"cluster score: {kmeans.score(mean_patches)}")

            clusters = [[] for _ in range(n_clusters)]
            for traj_idx, label in enumerate(labels):
                label = int(label)
                traj_cluster_map[label].append(traj_idx)
                clusters[label].append(images[traj_idx])

            for cluster_idx in range(n_clusters):
                print(f"cluster_idx: {cluster_idx} n_trajs: {len(traj_cluster_map[cluster_idx])}, {traj_cluster_map[cluster_idx]}")
            print(f"interface cluster time {time.time() - t0:.2f}s"); t0 = time.time()

            for cluster_idx, members in enumerate(clusters):
                if not members:
                    print(f"WARN: cluster_idx: {cluster_idx} is empty")
                    continue

                # Represent the cluster by the frame-wise mean of its members.
                cluster_example = torch.mean(torch.stack(members), dim=0)
                video_list.append([cluster_example[frame_idx] for frame_idx in range(n_frames)])
                video_texts.append([""] * n_frames)
                video_cluster_ids.append(cluster_idx)
        else:
            # No clustering: every trajectory is its own video.
            for traj_idx in range(images.shape[0]):
                video_list.append([images[traj_idx, frame_idx] for frame_idx in range(n_frames)])
                video_texts.append([""] * n_frames)

        selected_index = self.display_videos(video_list, video_texts)
        print(f"interface runtime {time.time() - t0:.2f}s")

        if selected_index < 0:
            return None, selected_index

        if use_clusters:
            # Translate the video index back to its cluster before looking up
            # the trajectory; the two differ as soon as a cluster was empty.
            idx = traj_cluster_map[video_cluster_ids[selected_index]][0]
        else:
            idx = selected_index
        print(f"selected_index: {selected_index}, goal idx {idx}")
        return idx, 0

    def display_videos(self, video_lists, video_texts):
        """Lay the candidate videos out in a grid and block until a choice is made.

        Runs a manual Tk update loop (polling the joysticks when any are
        present) until a video is clicked, a button calls `exit_loop`, or a
        joystick action arrives.

        Args:
            video_lists: One list of frames per video. Frames need ``.shape``
                (height first) and are converted with ``to_np`` when playback
                is enabled; presumably (H, W, C) tensors in 0-255 -- TODO confirm.
            video_texts: One list of caption strings per video, parallel to the
                frames; the first caption of each video is drawn above it.

        Returns:
            Index of the clicked video (>= 0) or a negative code from ``EXIT_CODES``.
        """
        self.update_clusters = True

        # NOTE: playback of the videos themselves is switched off; only the
        # captions, the live feed and the input handling are active.
        show_videos = False

        # Grid dimensions.
        n_videos = len(video_lists)
        n_rows = int(np.sqrt(n_videos))
        n_cols = int(np.ceil(n_videos / n_rows))

        # Frames are indexed (row, column): shape[0] is the height, shape[1] the width.
        video_h, video_w = video_lists[0][0].shape[0:2]

        grid_w = n_cols * video_w
        grid_h = n_rows * video_h
        grid_x = (self.screenw - grid_w) // 2
        grid_y = (self.screenh - grid_h) // 3
        grid_spacing = 50

        video_positions = []
        video_bounds = []
        for i in range(n_videos):
            row, col = divmod(i, n_cols)
            x = grid_x + col * (video_w + grid_spacing)
            y = grid_y + row * (video_h + grid_spacing)
            video_positions.append((x, y))
            video_bounds.append((x, y, x + video_w, y + video_h))

        def on_canvas_click(event):
            print(f"Clicked at {event.x}, {event.y}")
            for i, pos in enumerate(video_bounds):
                if pos[0] <= event.x <= pos[2] and pos[1] <= event.y <= pos[3]:
                    print(f"Image {i} clicked. Position {video_positions[i]}")
                    self.selected_video_cluster = i
                    # Highlight the selected video; drop an earlier highlight so
                    # repeated clicks do not leave rectangles behind.
                    if self.rectangle_label is not None:
                        self.canvas.delete(self.rectangle_label)
                    self.rectangle_label = self.canvas.create_rectangle(pos[0], pos[1], pos[2], pos[3], outline="red", width=5)
                    self.update_clusters = False

        def run_selection_loop(root, canvas):
            # Canvas items created for this session; all removed again at the end.
            text_labels = []
            video_labels = []
            photo_images_list = []
            photo_texts_list = []
            for idx, video in enumerate(video_lists):
                x, y = video_positions[idx]
                text_labels.append(canvas.create_text(x - grid_spacing // 4, y - grid_spacing // 2, anchor="nw", text=video_texts[idx][0], font=("Arial", 16, "bold"), fill="black"))
                if show_videos:
                    # Only pay for the per-frame PhotoImages when they are shown.
                    video_labels.append(canvas.create_image(x, y, anchor="nw"))
                    photo_images_list.append([ImageTk.PhotoImage(Image.fromarray(to_np(img).astype(np.uint8))) for img in video])
                    photo_texts_list.append([video_texts[idx][img_idx] for img_idx in range(len(video))])

            loop_timer_started = False
            self.loop_start_time = time.time()
            while self.update_clusters:
                if show_videos:
                    for i, photo_images in enumerate(photo_images_list):
                        canvas.itemconfig(video_labels[i], image=photo_images[0])
                        canvas.itemconfig(text_labels[i], text=photo_texts_list[i][0])
                        # rotate so the next frame is shown on the next pass
                        photo_images.append(photo_images.pop(0))
                        photo_texts_list[i].append(photo_texts_list[i].pop(0))

                # Get pygame events if we have a joystick
                if len(self.joysticks) > 0:
                    pygame.event.pump()

                    joystick_action = None
                    if self.config.task == "pinpad_five":
                        joystick_action = get_pinpad_action_from_joystick(self.joysticks)
                    elif self.config.task == "MemoryMaze_9x9":
                        joystick_action = get_memorymaze_action_from_joystick(self.joysticks)
                    else:
                        # (a calvin joystick mapping exists but is not wired up)
                        print(f"ERROR: Joystick not implemented for this task {self.config.task}")

                    if joystick_action is not None:
                        if not loop_timer_started:
                            # Gate on first joystick action
                            print("Starting loop timer")
                            self.loop_start_time = time.time()
                            loop_timer_started = True
                        self.total_loop_actions += 1
                        self.atomic_action_idx = joystick_action
                        self.atomic_action_mode = True
                        self.exit_loop(-4)

                # The live-feed item only exists once a feed has been set up.
                if hasattr(self, 'live_img_label'):
                    canvas.itemconfig(self.live_img_label, image=self.live_img)
                root.update()
                if hasattr(self, 'pygame_display'): pygame.display.update()
                # The fps slider can be dragged to 0; treat that as 1 fps
                # instead of dividing by zero.
                root.after(int(1000 / max(self.cluster_fps, 1)))

            # Remove this session's items so they do not pile up on the canvas.
            for item in text_labels + video_labels:
                canvas.delete(item)

        self.canvas.bind("<Button-1>", on_canvas_click)
        self.root.bind("q", lambda e: self.root.destroy())

        run_selection_loop(self.root, self.canvas)
        if self.rectangle_label is not None:
            self.canvas.delete(self.rectangle_label)
            self.rectangle_label = None

        print(f"User selected {self.selected_video_cluster if self.selected_video_cluster >= 0 else EXIT_CODES[self.selected_video_cluster]}")
        return self.selected_video_cluster

    def show_data(self, data: dict):
        """Play back every image sequence in ``data["image"]`` in OpenCV windows.

        One window per batch entry; each frame is shown for 100 ms.
        """
        assert "image" in data

        for batch_idx, batch in enumerate(data["image"]):
            window_name = f"image batch {batch_idx}"
            for frame in batch:
                cv2.imshow(window_name, frame)
                cv2.waitKey(100)

    def test_goal_decoder(self, n=1000):
        # sample a ton of goals from the uniform prior
        skills = self.skill_prior.sample((n,))
        skills = skills.reshape(n, -1)
        goals = self.goal_dec(skills).mode()
        stoch = self.world_model.dynamics.get_stoch(goals)
        stoch = stoch.reshape(n, -1)
        inp = torch.concat([stoch, goals], dim=-1)
        imgs = self.world_model.heads["decoder"](inp.unsqueeze(0))["image"].mode().detach().cpu().permute(1, 0, 2, 3, 4).numpy()
        self.display_images_in_grid(imgs, 10, 20)
        

    def spoof_decision(self, start):
        """Spoof the decision (selection of a skill) instead of the interface.

        Unlike spoofing the interface (selection from a set of videos), this is
        meant to answer what the expert model would choose. Not implemented yet.
        """
        return None

    def spoof_interface(self, imagined_trajectory_images, imagined_actions):
            '''
            Evaluate the imagined trajectories based on the saved model, return the corresponding index
            NOTE: we're using a model in place of a human to select the skill/goal
            '''
            wm = self.expert_model
            traj_values = []
            imagined_actions = torch.stack(imagined_actions, dim=0)

            # reshape both so they're (timestep, sample, ...)
            imagined_actions = imagined_actions.permute(1, 0, 2)
            imagined_trajectory_images = imagined_trajectory_images.permute(1, 0, 2, 3, 4)

            for obs, action in zip(imagined_trajectory_images, imagined_actions):
                # Using the world model, encode the images into the world models latent space
                # obs = wm.preprocess(imagined_traj)
                embed = wm.encoder({"image":obs.to(self.config.device)})

                is_first = torch.zeros((action.shape[0], 1))
                is_first = is_first.unsqueeze(0)
                action = action.unsqueeze(0)
                embed = embed[:-1].unsqueeze(0)
                # action = action.permute(1, 0)
                # embed = embed.permute(1, 0)
                post, _ = wm.dynamics.observe(embed, action, is_first)

                # option1: Use the world models reward decoder to evaluate the trajectory
                feat = wm.dynamics.get_feat(post)
                traj_reward = wm.heads["reward"](feat).mean().sum().detach().cpu().numpy()
                traj_values.append(traj_reward)

                # option2: Just use the loaded model to pick a skill directly, this skips the sampling stage though so doesn't really simulate a human picking from sampled skills.

            return np.argmax(traj_values), "spoofed_ selected_cluster"

class CirclePattern:
    '''
    Place one thumbnail stack per cluster on a circle around the canvas center
    and record which cluster the user clicks.

    root: Tk root (or toplevel) the canvas is packed into.
    radius: distance in pixels from the canvas center to each cluster's thumbnails.
    clusters: sequence of clusters; each cluster is a sequence of image tensors
        (anything with .numpy() that PIL can read after a uint8 cast).

    After a click, selected_image_cluster holds the index (into `clusters`) of the
    clicked cluster. cluster_bounds holds one (left, top, right, bottom) box per
    non-empty cluster and cluster_ids holds the matching cluster index for each box.
    '''

    IMAGE_SIZE = 128  # side length in pixels every thumbnail is resized to

    def __init__(self, root, radius, clusters):
        self.root = root
        self.radius = radius
        self.clusters = clusters
        self.canvas = Canvas(self.root, width=self.root.winfo_screenwidth() // 2, height=self.root.winfo_screenheight() // 2, bg='white')
        self.canvas.pack()
        self.image_positions = []
        self.cluster_bounds = []
        self.cluster_ids = []  # cluster index for each entry of cluster_bounds
        self.images = []
        # Thumbnail size; also used to compute the clickable boxes
        self.img_width = self.IMAGE_SIZE
        self.img_height = self.IMAGE_SIZE

        self.selected_image_cluster = None

        # self.root.after_idle(self.load_images)
        self.root.after(10, self.load_images)

    def quit(self):
        '''Close the pygame display (if this object has one) and destroy the Tk root.'''
        if hasattr(self, 'pygame_display'):
            pygame.display.quit()

        self.root.destroy()

    def _canvas_center(self):
        '''Flush pending Tk work so the canvas size is current, then return its center (x, y).'''
        self.root.update()
        center = (self.canvas.winfo_width() / 2, self.canvas.winfo_height() / 2)  # Center of the canvas
        print(f"Center of canvas: {center}")
        return center

    def _cluster_position(self, cluster_idx, center):
        '''Return the (x, y) point on the circle assigned to cluster `cluster_idx`.'''
        angle = 2 * math.pi * cluster_idx / len(self.clusters)
        return (
            center[0] + self.radius * math.cos(angle),
            center[1] + self.radius * math.sin(angle),
        )

    def _record_bounds(self, cluster_idx, x, y):
        '''Store the clickable box centered on (x, y) together with the cluster it belongs to.'''
        half_w = self.img_width // 2
        half_h = self.img_height // 2
        self.cluster_bounds.append((x - half_w, y - half_h, x + half_w, y + half_h))
        self.cluster_ids.append(cluster_idx)

    def _bind_controls(self):
        '''Bind click-to-select on the canvas and q / Escape to quit.'''
        self.canvas.bind("<Button-1>", self.on_canvas_click)
        # Wrap in a lambda: bind() needs a callable, and calling quit() here
        # would destroy the window immediately and register None as the handler.
        self.root.bind("q", lambda e: self.quit())
        self.root.bind("<Escape>", lambda e: self.quit())

    def spatial_calculations(self):
        '''Recompute the clickable box of every non-empty cluster without drawing anything.'''
        center = self._canvas_center()
        self.cluster_bounds = []
        self.cluster_ids = []
        for cluster_idx, cluster in enumerate(self.clusters):
            if len(cluster) == 0:
                # Nothing is drawn for an empty cluster, so it gets no clickable box
                continue
            x, y = self._cluster_position(cluster_idx, center)
            self._record_bounds(cluster_idx, x, y)

        self._bind_controls()

        # self.selected_image_cluster = valid_clusters[0] # NOTE: temporary random selection
        # self.root.after(500, self.root.destroy())

    def load_images(self):
        '''Draw every cluster's images at its circle position and register its clickable box.'''
        center = self._canvas_center()
        self.cluster_bounds = []
        self.cluster_ids = []
        thumb_size = (self.img_width, self.img_height)
        for cluster_idx, cluster in enumerate(self.clusters):
            x, y = self._cluster_position(cluster_idx, center)
            drew_any = False
            for tensor in cluster:
                img_array = tensor.numpy().astype(np.uint8)
                image = Image.fromarray(img_array)  # Convert tensor to PIL Image
                # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
                image = image.resize(thumb_size, Image.LANCZOS)  # Resize image
                img = ImageTk.PhotoImage(image)
                # Hold on to the PhotoImage for as long as this object lives
                self.images.append(img)
                self.canvas.create_image(x, y, image=img)

                # Add text about cluster and position
                self.canvas.create_text(x, y-100, text=f"Cluster {cluster_idx}", font=("Arial", 16  ), fill="black")
                drew_any = True
            if not drew_any:
                print(f"Cluster {cluster_idx} is empty")
                continue
            self._record_bounds(cluster_idx, x, y)

        self._bind_controls()

        # self.selected_image_cluster = valid_clusters[0] # NOTE: temporary random selection
        # self.root.after(500, self.root.destroy())

    def on_canvas_click(self, event):
        '''Select and highlight the cluster whose box contains the click position.'''
        for slot, pos in enumerate(self.cluster_bounds):
            if pos[0] <= event.x <= pos[2] and pos[1] <= event.y <= pos[3]:
                # Boxes exist only for non-empty clusters, so translate the box
                # slot back to the real cluster index (fall back to the slot if
                # boxes were added without a recorded id).
                cluster_idx = self.cluster_ids[slot] if slot < len(self.cluster_ids) else slot
                print(f"Image {cluster_idx} clicked. Position in circle: {cluster_idx + 1}/{len(self.clusters)}")
                self.selected_image_cluster = cluster_idx
                # highlight the selected image
                self.canvas.create_rectangle(pos[0], pos[1], pos[2], pos[3], outline="red", width=5)
                # self.root.after(500, self.root.destroy)
                self.update_clusters = False

if __name__ == "__main__":
    # Manual smoke test: build an environment, open the UIUX interface on random
    # image batches, and step the environment with whatever atomic action the
    # user picks, accumulating the reward.
    task = "memorymaze"

    class SpoofConfig:
        '''Minimal stand-in for the real config; only carries device and task.'''
        def __init__(self):
            # Fall back to CPU so the smoke test also runs on machines without CUDA
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.task = "n/a"

    config = SpoofConfig()
    if task == "memorymaze":
        from envs.memorymaze import MemoryMaze
        env = MemoryMaze("9x9")
        config.task = "MemoryMaze_9x9"
    elif task == "pinpad":
        from envs.pinpad import PinPad
        env = PinPad("five")
        config.task = "pinpad_five"
    else:
        # Without this, an unknown task only surfaces later as a NameError on `env`
        raise ValueError(f"Unknown task: {task!r}")

    obs = env.reset()
    get_img = lambda obs: obs["image"]

    ## Calvin robot arm environment
    # import hydra
    # with hydra.initialize(config_path="../calvin/calvin_env/conf/"):
    #     cfg = hydra.compose(config_name="config_data_collection.yaml", overrides=["cameras=static_and_gripper"])
    #     cfg.env["use_egl"] = False
    #     cfg.env["show_gui"] = False
    #     cfg.env["use_vr"] = False
    #     cfg.env["use_scene_info"] = True
    #     print(cfg.env)

    # env = hydra.utils.instantiate(cfg.env)
    # obs = env.reset()
    # def get_img(obs):
    #     if "rgb_obs" not in obs: return 
    #     if "rgb_static" not in obs["rgb_obs"]: return
    #     return obs["rgb_obs"]["rgb_static"]
    
    # from nov20.second_wind.peripheral import get_calvin_action_from_joystick

    # env.step(np.zeros((7,)))
    # get_img = lambda obs: obs["rgb_obs"]["rgb_static"]

    # for k,v in obs.items():
    #     print(f"{k}")
        # print(f"{k}: {v.shape if hasattr(v, 'shape') else v}")
    
    uiux = UIUX(config)
    uiux.mix_slider.set(50)
    uiux.update_text(n_clusters=0, n_samples=16)
    s = 64  # side length of the random test images
    uiux.refresh_params()

    n = 1000
    total_R = 0
    for _ in range(n):
        uiux.update_feed(get_img(obs))
        # uiux.update_feed(obs["image"])
        images = torch.rand(17, 32, s, s, 3) * 255
        image_idx, exit_code = uiux.interface(images, n_clusters=0)

        if uiux.atomic_action_mode:
            action = uiux.atomic_action_idx
            if isinstance(action, list):
                print(f"atomic_action_idx: {[f'{entry:1.2f}' for entry in action]}")
            else:
                print(f"atomic_action_idx: {action}")
            uiux.atomic_action_mode = False
            uiux.exit_loop(-4)

            # calvin_action = [0, 0, 0.1, 0, 0, 0, 1] # dx, dy, dz, droll, dpitch, dyaw, gripper
            # obs, reward, done, *info = env.step(calvin_action)
            obs, reward, *_ = env.step(action)
            total_R += reward
            print(f"reward: {reward}")
        print(f"image_idx: {image_idx}, exit_code: {exit_code}")
        print(f"total_R: {total_R}")
        uiux._reset_session_state()
    print(f"total_R: {total_R}")