# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/04_multiagent.ipynb.

# %% auto 0
__all__ = ['config', 'to_np', 'n_agents', 'n_timesteps', 'n_features', 'X', 'Y', 'agent', 'perturber', 'env', 'meta_agent',
           'compiled_gmm', 'reshape', 'gmm_loss', 'fit_gmm', 'round_robin', 'MetaAgent', 'Perturber', 'get_patch',
           'DroneImgAgent', 'DroneImgEnv', 'single_step', 'single_batch']

# %% ../nbs/04_multiagent.ipynb 1
import argparse
import datetime
import functools
import os
import pathlib
import sys
from collections.abc import Mapping
import matplotlib.pyplot as plt

os.environ["MUJOCO_GL"] = "osmesa"

import numpy as np
import ruamel.yaml as yaml
import pathlib
from functools import partial
import time
import random
import skimage.feature as skf


import fastcore.all as fc
import torch
from torch import nn
from torch.nn import functional as F
from torch import distributions as torchd
from torch.utils.tensorboard import SummaryWriter

# add the grandparent directory to sys.path so we can import tools.
sys.path.append(str(pathlib.Path(os.getcwd()).parent.parent.absolute()))
import envs.wrappers as wrappers
import torch_utils as tu
from parallel import Parallel, Damy
import director_models
import tools
import models
import networks
import mem_tools
from fastrl_utils import *
import sklearn.mixture as mixture
# Global run configuration, loaded once at import time.
# NOTE(review): `get_config` is not defined in this file; presumably it is provided by the
# `from fastrl_utils import *` star import above — confirm.
config = get_config()
def to_np(x):
    '''Detach tensor `x` from the autograd graph, move it to the CPU and return it as a NumPy array.'''
    detached = x.detach()
    return detached.cpu().numpy()


# %% ../nbs/04_multiagent.ipynb 2
@torch.compile
def compiled_gmm(X, n):
    '''
    Fit an `n`-component Gaussian mixture to `X` and return the fitted estimator.

    NOTE(review): the body only calls scikit-learn; it is unclear what `torch.compile` gains here — confirm it is intended.
    '''
    estimator = mixture.GaussianMixture(n_components=n)
    return estimator.fit(X)

# %% ../nbs/04_multiagent.ipynb 3
# Default dimensions for the synthetic GMM experiments: agents, timesteps per agent, features per timestep.
n_agents, n_timesteps, n_features = 4, 10, 144

def reshape(x, n_timesteps=10, n_features=144):
    '''
    Collapse the first two axes of `x` so that each row is one timestep sample.

    The output has `x.shape[0] * n_timesteps` rows and `n_features` columns.
    '''
    n_rows = x.shape[0] * n_timesteps
    return x.reshape((n_rows, n_features))

def gmm_loss(X, gmm):
    '''Return the negated `gmm.score(X)`, rounded to 4 decimal places.'''
    negated_score = -gmm.score(X)
    return round(negated_score, 4)

def fit_gmm(X, n=5, random_state=None):
    '''
    Fit an `n`-component Gaussian mixture to `X` and return its loss on that same data.

    The loss is computed by `gmm_loss`, i.e. the negated `GaussianMixture.score(X)` rounded to 4 decimals.

    X: samples to fit and score, one sample per row.
    n: number of mixture components.
    random_state: forwarded to `GaussianMixture`; pass an int for reproducible fits.
        The default `None` keeps the previous unseeded behaviour.
    '''
    gmm = mixture.GaussianMixture(n_components=n, random_state=random_state).fit(X)
    # Reuse the shared scoring helper instead of repeating the rounding logic here.
    return gmm_loss(X, gmm)

# Smoke test: draw two independent standard-normal batches of shape (n_agents, n_timesteps, n_features).
X = np.random.randn(n_agents, n_timesteps, n_features)
Y = np.random.randn(n_agents, n_timesteps, n_features)
# Flatten the agent and timestep axes so that every row is one sample.
Y = reshape(Y)
X = reshape(X)

# Notebook-style expression: both fits run, but the resulting tuple is discarded when imported as a module.
fit_gmm(Y), fit_gmm(X)


# %% ../nbs/04_multiagent.ipynb 4
def round_robin(X, gmm_n=10, baseline_n=3):
    '''
    Take a (n_agents x timesteps x features) array and return leave-one-out GMM losses.

    For each agent i, a GMM with `gmm_n` components is fit to the samples of all *other* agents
    and scored on those same samples; the i-th loss is that negated score.
    The baseline is the same quantity for a GMM with `baseline_n` components fit to all agents.

    Losses are per-sample averages: `GaussianMixture.score` already returns the mean
    log-likelihood per sample, so no further division by the number of points is applied
    (dividing again would scale the leave-one-out and baseline losses by different factors).

    Returns (losses, baseline_loss): a list with one float per agent, and a float.
    '''
    n_agents, n_timesteps, n_features = X.shape
    losses = []
    for i in range(n_agents):
        # Every agent except i.
        others = [j for j in range(n_agents) if j != i]
        X_ = reshape(X[others], n_timesteps=n_timesteps, n_features=n_features)
        losses.append(fit_gmm(X_, n=gmm_n))

    baseline_loss = fit_gmm(reshape(X, n_timesteps=n_timesteps, n_features=n_features), n=baseline_n)

    return losses, baseline_loss

# Smoke test on fresh random data; the returned losses are discarded when imported as a module.
X = np.random.randn(n_agents, n_timesteps, n_features)
round_robin(X)

# %% ../nbs/04_multiagent.ipynb 5
# Disabled experiment: the flag is bound to False right here, so nothing in this branch executes.
if USE_MVGAUSSIAN := False:
    def fit_mv_gaussian(X):
        '''Fit one multivariate normal to the rows of `X`; return its mean negative log-probability, rounded to 4 decimals.'''
        mu = torch.tensor(np.mean(X, axis=0))
        sigma = torch.tensor(np.cov(X.T))
        dist = torchd.multivariate_normal.MultivariateNormal(mu, sigma)
        log_probs = dist.log_prob(torch.tensor(X))
        nll = -log_probs.mean().item()
        return round(nll, 4)

    X = np.random.randn(n_agents, n_timesteps, n_features)
    X = reshape(X, n_timesteps=n_timesteps, n_features=n_features)
    fit_mv_gaussian(X), fit_gmm(X)


# %% ../nbs/04_multiagent.ipynb 9
class MetaAgent():
    '''Accumulates the per-agent losses and the baseline loss returned by `round_robin` across updates, and plots them.'''
    def __init__(self, n_agents=3) -> None:
        self.n_agents = n_agents
        # One entry per update: list of per-agent losses / scalar baseline loss.
        self.history = []
        self.baseline_history = []

    def update(self, X, **kwargs):
        '''Run `round_robin(X, **kwargs)`, append its results to the histories and return them.'''
        losses, baseline_loss = round_robin(X, **kwargs)
        # Sanity check: one loss per tracked agent.
        assert len(losses) == self.n_agents
        self.history.append(losses)
        self.baseline_history.append(baseline_loss)
        return losses, baseline_loss

    def plot(self, with_baseline=True, legend=None, special_idx=None):
        '''
        Plot each agent's loss over the recorded updates on the current matplotlib axes.

        with_baseline: also draw the baseline history as a black dashed line.
        legend: explicit legend labels; when None the per-line labels are used.
        special_idx: agent index whose line is labelled 'spcl <i>' instead of 'agent <i>'.
        '''
        for idx in range(self.n_agents):
            is_special = special_idx is not None and idx == special_idx
            tag = f'spcl {idx}' if is_special else f'agent {idx}'
            series = [step_losses[idx] for step_losses in self.history]
            plt.plot(series, label=tag)

        if with_baseline:
            plt.plot(self.baseline_history, color='black', linestyle='dashed')

        if legend is None:
            plt.legend()
        else:
            plt.legend(legend)

class Perturber():
    '''Adds scaled standard-normal noise to arrays.'''
    def __init__(self, scale=1) -> None:
        # Default noise multiplier, used when `perturb` is called without an explicit scale.
        self.scale = scale

    def perturb(self, X, scale=None):
        '''Return `X` plus standard-normal noise of the same shape multiplied by `scale` (or by `self.scale` when `scale` is None).'''
        if scale is None:
            scale = self.scale
        noise = np.random.randn(*X.shape)
        return X + noise * scale

# Demo: three agents where agent 0's observations are heavily perturbed on every round.
# Note that this rebinds the module-level n_agents / n_timesteps / n_features.
n_agents, n_timesteps, n_features = 3, 10, 20
agent = MetaAgent(n_agents=3)
perturber = Perturber()

for _ in range(10):
    X = np.random.randn(n_agents, n_timesteps, n_features)
    # Add scale-10 Gaussian noise to agent 0 only.
    X[0] = perturber.perturb(X[0], scale=10)
    agent.update(X)

# Explicit legend: one entry per agent line plus one for the dashed baseline.
agent.plot(legend=['agent 0 (perturbed)', 'agent 1', 'agent 2', 'baseline'])


# %% ../nbs/04_multiagent.ipynb 10
# print the current directory
# make a path to the data directory

import random
import gymnasium as gym
import cv2
def get_patch(img, x, y, theta=0, patch_size=64):
    '''
    Get a patch centered at x,y from an image.

    img: array indexed as img[row, col, ...].
    x, y: integer centre of the patch along the first and second axis of `img`.
    theta: accepted for interface compatibility; rotation is not implemented, so it is ignored.
    patch_size: requested side length; the patch spans `patch_size // 2` pixels on each side of the centre.

    Returns (patch, patch_bounds). `patch_bounds` is the requested, unclipped
    (x_min, x_max, y_min, y_max). `patch` is a view into `img` and is smaller than
    requested when the bounds extend past an image edge.
    '''
    ps = patch_size // 2
    patch_bounds = (x-ps, x+ps, y-ps, y+ps)
    # A negative slice index would wrap around to the far side of the image and yield an
    # empty or unrelated patch, so clip at 0. Indices past the end are already truncated by numpy.
    x0, x1, y0, y1 = (max(b, 0) for b in patch_bounds)
    patch = img[x0:x1, y0:y1]
    return patch, patch_bounds

class DroneImgAgent():
    '''
    Point agent that walks one pixel per step toward randomly drawn goal positions.

    xlim, ylim: inclusive (min, max) integer bounds, unpacked into `random.randint`.
    x, y, theta: accepted by the constructor but overwritten straight away by `reset()`.
    id: identifier for this agent.

    `history` holds one (x, y, theta) tuple per state; `goal_history` holds every goal drawn so far.
    '''
    def __init__(self, xlim, ylim, x=0, y=0, theta=0, id=0) -> None: 
        fc.store_attr(); self.reset()

    def reset(self):
        '''Draw a fresh random pose and goal, restart both histories, and return (x, y, theta).'''
        self.x = random.randint(*self.xlim)
        self.y = random.randint(*self.ylim)
        # step() works in radians (arctan2 heading, 0.1 rad increments), so convert the
        # whole-degree draw; keeping randint leaves the RNG consumption unchanged.
        self.theta = float(np.deg2rad(random.randint(0, 360)))
        self.history = [(self.x, self.y, self.theta)]
        self.goal_history = []
        self._set_goal()

        return self.x, self.y, self.theta

    def _set_goal(self):
        '''Draw a new goal uniformly within the limits and record it.'''
        self.goal = random.randint(*self.xlim), random.randint(*self.ylim)
        self.goal_history.append(self.goal)

    def step(self):
        '''Move at most one pixel per axis toward the goal; draw a new goal once within 10 pixels of it.'''
        vel = 1 # pixels
        angvel = 0.1 # radians
        # move towards the goal
        dx, dy = self.goal[0] - self.x, self.goal[1] - self.y
        dtheta = np.arctan2(dy, dx)

        # NOTE: Theta is not currently used. Angles are compared without wrap-around handling.
        if np.abs(dtheta - self.theta) > 0.1:
            self.theta += angvel if self.theta < dtheta else -angvel

        # Clamp each axis move to `vel` so the agent never overshoots the goal.
        self.x += min(dx, vel) if self.x < self.goal[0] else max(dx, -vel)
        self.y += min(dy, vel) if self.y < self.goal[1] else max(dy, -vel)

        self.history.append((self.x, self.y, self.theta))

        # Distance is the one measured before this step's move.
        if np.linalg.norm([dx, dy]) < 10:
            self._set_goal()

class DroneImgEnv():
    '''
    Environment of `n` DroneImgAgents moving over an image loaded from the `../drone_pics` directory.

    Agent positions are in unpadded image coordinates. Images are zero-padded by
    `patch_size // 2` on both spatial axes so that a patch centred on any agent position is
    fully inside the padded image. The agent whose id equals `replacement_agent_id` gets its
    patches from a second ("replacement") image instead.
    '''
    def __init__(self, n, patch_size=64) -> None:
        fc.store_attr()
        self.ps = self.patch_size // 2
        self.data_dir = pathlib.Path(os.getcwd()) / '../drone_pics'
        self.img_fns = os.listdir(self.data_dir)
        self.img = None
        self.replacement_agent_id = 2

    def _pad(self, img):
        '''Zero-pad the two spatial axes of an (H, W, C) image by `self.ps` on each side.'''
        return np.pad(img, ((self.ps, self.ps), (self.ps, self.ps), (0, 0)), mode='constant')

    def _primary_center(self, x, y):
        '''Map an agent position in image coordinates to coordinates in the padded main image.'''
        return x + self.ps, y + self.ps

    def _replacement_center(self, x, y):
        '''Map an agent position on the main image to the proportional position in the padded replacement image.'''
        # Scale against the unpadded replacement size, then shift into padded coordinates,
        # so the patch around the result never leaves the padded array.
        rep_h = self.padded_replacement_img.shape[0] - 2 * self.ps
        rep_w = self.padded_replacement_img.shape[1] - 2 * self.ps
        rx = int(x / self.img.shape[0] * rep_h) + self.ps
        ry = int(y / self.img.shape[1] * rep_w) + self.ps
        return rx, ry

    def reset(self):
        '''
        Pick a new main image and replacement image, create fresh agents and return the main image.

        The chosen main image is removed from `img_fns`, so it is not drawn again;
        `random.choice` raises IndexError once the list is empty.
        '''
        self.img_fn = random.choice(self.img_fns)
        self.img_fns.remove(self.img_fn)
        self.img = plt.imread(self.data_dir / self.img_fn)
        self.padded_img = self._pad(self.img)

        # Drawn after the removal above, so it always differs from the main image file.
        self.replacement_img_fn = random.choice(self.img_fns)
        self.padded_replacement_img = self._pad(plt.imread(self.data_dir / self.replacement_img_fn))

        # Per-channel statistics of both padded images, for debugging.
        print("img ", self.img_fn, np.mean(self.padded_img, axis=(0, 1)), np.std(self.padded_img, axis=(0, 1)))
        print("replimage", self.replacement_img_fn, np.mean(self.padded_replacement_img, axis=(0, 1)), np.std(self.padded_replacement_img, axis=(0, 1)))

        minx, maxx = (0, self.img.shape[0]); miny, maxy = (0, self.img.shape[1])
        self.agents = [DroneImgAgent(xlim=(minx, maxx), ylim=(miny, maxy), id=id) for id in range(self.n)]
        return self.img

    def step(self):
        '''Advance every agent by one step.'''
        for agent in self.agents:
            agent.step()

    def get_stream_patch(self, agent, img):
        '''
        Return a copy of the patch around `agent`'s latest position.

        `img` is the padded main image. The agent with id `replacement_agent_id`
        (NOTE: hardcoded for now) is served from the padded replacement image instead.
        '''
        x, y, theta = agent.history[-1]
        if agent.id == self.replacement_agent_id:
            source = self.padded_replacement_img
            cx, cy = self._replacement_center(x, y)
        else:
            source = img
            cx, cy = self._primary_center(x, y)
        return get_patch(source, cx, cy, theta, patch_size=self.patch_size)[0].copy()

    def render(self, show=False):
        '''
        Return a dict mapping agent id to that agent's current patch.

        With `show=True`, also draw each agent's patch outline and centre on a copy of the
        padded image and display it, plus each patch, in OpenCV windows.
        '''
        img = self.padded_img.copy()
        # Patches are extracted before anything is drawn on `img`.
        patches = {agent.id: self.get_stream_patch(agent, img) for agent in self.agents}
        if show:
            for agent in self.agents:
                x, y, _ = agent.history[-1]
                cx, cy = self._primary_center(x, y)
                cx, cy, half = int(cx), int(cy), int(self.ps)
                # OpenCV points are (column, row); outline exactly the extracted patch.
                cv2.rectangle(img, (cy - half, cx - half), (cy + half, cx + half), (125, 255, 0), 10)
                cv2.circle(img, (cy, cx), 1, (0, 0, 255), 10)
                cv2.imshow(f'patch{agent.id}', patches[agent.id])
            # scale down to a quarter of the size
            wimage = cv2.resize(img, (0, 0), fx=0.25, fy=0.25)
            cv2.imshow('wimage', wimage)
            cv2.waitKey(10)
        return patches


# cv2.destroyAllWindows()
# env = DroneImgEnv(n=4); env.reset()
# for _ in range(100):
#     env.step()
#     env.render(show=True)

# time.sleep(500)
# cv2.destroyAllWindows()

# %% ../nbs/04_multiagent.ipynb 16
# Build the shared environment and meta-agent at import time (reads images from disk).
env = DroneImgEnv(n=5, patch_size=32); env.reset()
meta_agent = MetaAgent(n_agents=env.n)

# patch_feature = lambda x: skf.hog(x, pixels_per_cell=(16, 16), cells_per_block=(1,1), visualize=False, channel_axis=-1)
# get the color histogram of the patch
# patch_feature = lambda x: cv2.calcHist([x], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256]).reshape(-1)

# NOTE(review): neither USE_VGG nor vgg is defined in this file; presumably they come from an
# unexported notebook cell or the fastrl_utils star import — confirm. Fall back to False when
# the flag is absent so importing the module does not fail with a NameError.
USE_VGG = globals().get('USE_VGG', False)

if USE_VGG:
    patch_feature = lambda x: vgg.features(torch.tensor(x).permute(2, 0, 1).float().unsqueeze(0).cuda()).squeeze().detach().cpu().numpy()
else:
    # Mean over the two spatial axes: one value per channel.
    patch_feature = lambda x: np.mean(x, axis=(0, 1))


def single_step(show=True):
    '''
    Advance the global `env` by one step and return the per-agent patch features.

    show: forwarded to `env.render`; pass False to skip the OpenCV windows (e.g. when headless).
        The default True keeps the previous behaviour.

    Returns an array with one row per agent and a singleton axis inserted at position 1.
    '''
    env.step()
    patches = env.render(show=show)
    assert len(patches) == env.n, f'Number of patches ({len(patches)}) does not match number of agents ({env.n})'
    assert all([p.shape == (env.patch_size, env.patch_size, 3) for p in patches.values()]), f'Patch shapes are {[p.shape for p in patches.values()]}'
    # The dict preserves insertion order, so rows follow the order in which render() added agents.
    features = np.array([patch_feature(patch) for patch in patches.values()])
    # Insert the axis that single_batch concatenates along.
    return features[:, None, :]

def single_batch(ts=10, show=True):
    '''
    Run `ts` consecutive steps and concatenate their features along axis 1.

    show: forwarded to `single_step`.
    '''
    return np.concatenate([single_step(show=show) for _ in range(ts)], axis=1)

# Notebook-style sanity check of the pixel range and feature array shapes. The tuple is discarded
# when imported as a module, but the single_step / single_batch calls still advance `env`.
env.img.min(), env.img.max(), single_step().shape, single_batch().shape

