# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_core.ipynb.

# %% auto 0
__all__ = ['act_gr', 'ImageDirectoryDataset', 'create_transform', 'create_dataset', 'conv', 'deconv', 'TrainAELearner', 'summary']

# %% ../nbs/00_core.ipynb 3
import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl,numpy as np,matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torcheval.metrics import MeanSquaredError

from PIL import Image
import random 
from functools import partial,reduce

# import torcheval.metrics as tmetrics
# print out the classes in the tmetrics
# print(dir(tmetrics))

import fastcore.all as fc

from fastprogress import progress_bar,master_bar
from miniai.datasets import *
from miniai.conv import *
from miniai.learner import *
from miniai.activations import *
from miniai.init import *
from miniai.sgd import *

# %% ../nbs/00_core.ipynb 6
class ImageDirectoryDataset(Dataset):
    """Dataset that yields ``(image, 0)`` pairs from a collection of image file paths.

    The second element of every item is the constant ``0``; no real label is
    attached to the images.

    Args:
        image_paths: Iterable of image file paths. It is copied, so the
            caller's sequence is never reordered or truncated.
        transform: Optional callable applied to the opened PIL image. When
            omitted, the PIL image itself is returned.
        debug: When true, only the first 3200 paths (after shuffling) are kept.
    """

    def __init__(self, image_paths, transform=None, debug=False):
        self.transform = transform
        # Private copy: shuffling below must not mutate the caller's list.
        self.image_paths = list(image_paths)

        # Randomise the order so the debug subset below is a random sample.
        random.shuffle(self.image_paths)
        if debug:
            samples = 3200
            print(f"Debug mode: using only {samples} images")
            self.image_paths = self.image_paths[:samples]

    def __getitem__(self, idx):
        """Open the image at position ``idx`` and return ``(image, 0)``."""
        img_path = self.image_paths[idx]

        with Image.open(img_path) as img:
            if self.transform:
                img = self.transform(img)
            else:
                # Image.open is lazy; read the pixel data now, because the
                # underlying file is closed when the ``with`` block exits.
                img.load()
        return (img, 0)

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

# %% ../nbs/00_core.ipynb 7
def create_transform(color=False, image_side_length=128):
    """Build the torchvision preprocessing pipeline for the image datasets.

    Both variants resize to a square of ``image_side_length`` pixels, apply a
    random horizontal and a random vertical flip (each with p=0.5) and convert
    to a tensor. The colour variant then normalizes three channels; the
    grayscale variant first reduces the tensor to one channel and then
    normalizes that single channel.

    Args:
        color: Select the three-channel pipeline when true, the
            single-channel grayscale pipeline otherwise.
        image_side_length: Side length of the square the image is resized to.

    Returns:
        A ``transforms.Compose`` holding the selected steps.
    """
    target_size = (image_side_length,) * 2

    # Steps common to the colour and grayscale pipelines.
    steps = [
        transforms.Resize(target_size),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor(),  # normalizes to [0, 1]
    ]

    if color:
        # Per-channel statistics from the AirSim dataset (same value for R, G, B).
        airsim_mean = 0.47469366222070025
        airsim_std = 0.18099716921052658
        steps.append(transforms.Normalize(mean=[airsim_mean] * 3, std=[airsim_std] * 3))
    else:
        steps.append(transforms.Grayscale(num_output_channels=1))
        # Stats for a uniform grayscale distribution.
        steps.append(transforms.Normalize(mean=[0.5], std=[0.288]))

    return transforms.Compose(steps)


def create_dataset(split_ratio=0.1, batch_size=32, color=False, debug=False, image_side_length=128,
                   train_root="~/Documents/AirSim/aero_conf_training/",
                   test_root="~/workspace/nasa/april/OnAIR/experiment/test_anomaly_detection"):
    """Create the train, validation and novelty-test ``DataLoader`` objects.

    Training/validation images are every ``.ppm`` file found in
    ``<train_root>/<subdir>/images/``. The collected paths are shuffled and the
    first ``split_ratio`` fraction becomes the validation set; the rest is the
    training set. The test set is every regular file directly inside
    ``test_root``.

    The same transform (including its random flips) is used for all three
    datasets, and all three loaders use ``shuffle=True`` and ``num_workers=4``.

    Args:
        split_ratio: Fraction of the collected images used for validation.
        batch_size: Batch size of all three loaders.
        color: Forwarded to ``create_transform``.
        debug: Forwarded to ``ImageDirectoryDataset`` (limits dataset size).
        image_side_length: Forwarded to ``create_transform``.
        train_root: Directory holding the training sub-directories; ``~`` is expanded.
        test_root: Directory holding the novelty test images; ``~`` is expanded.

    Returns:
        Tuple ``(train_loader, valid_loader, test_loader)``.
    """
    transform = create_transform(color=color, image_side_length=image_side_length)

    # Gather the .ppm files of every <train_root>/<subdir>/images/ directory.
    root_dir = os.path.expanduser(train_root)
    image_paths = []
    for class_dir in os.listdir(root_dir):
        class_path = os.path.join(root_dir, class_dir, "images")
        if not os.path.isdir(class_path): continue
        image_paths.extend(
            os.path.join(class_path, image_name)
            for image_name in os.listdir(class_path)
            if image_name.endswith(".ppm")
        )

    # Shuffle the image paths to create a random split between training and validation sets.
    random.shuffle(image_paths)
    n_valid = int(len(image_paths) * split_ratio)
    valid_image_paths = image_paths[:n_valid]
    train_image_paths = image_paths[n_valid:]

    # Create the custom datasets and DataLoaders.
    train_image_dataset = ImageDirectoryDataset(train_image_paths, transform=transform, debug=debug)
    train_loader = DataLoader(train_image_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    valid_image_dataset = ImageDirectoryDataset(valid_image_paths, transform=transform, debug=debug)
    validat_loader = DataLoader(valid_image_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    # Novelty test set: keep regular files only, since a sub-directory would
    # make Image.open fail when the item is loaded.
    test_dir = os.path.expanduser(test_root)
    test_image_paths = []
    for image_name in os.listdir(test_dir):
        image_path = os.path.join(test_dir, image_name)
        if os.path.isfile(image_path):
            test_image_paths.append(image_path)

    test_image_dataset = ImageDirectoryDataset(test_image_paths, transform=transform, debug=debug)
    test_loader = DataLoader(test_image_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    print(f"Train, Valid, Test  {len(train_image_dataset)}, {len(valid_image_dataset)}, {len(test_image_dataset)}")

    return train_loader, validat_loader, test_loader

# %% ../nbs/00_core.ipynb 10
# Activation factory: calling `act_gr()` builds a `GeneralRelu` with leak=0.1 and sub=0.4.
# `GeneralRelu` is not defined in this file; it presumably comes from the miniai star imports above.
# NOTE(review): `leak` is presumably the negative-side slope and `sub` a constant subtracted from the output; confirm against miniai.
act_gr = partial(GeneralRelu, leak=0.1, sub=0.4)

def conv(ni, nf, ks=3, stride=2, act=nn.ReLU, norm=None, bias=None):
    """Build a ``Conv2d -> [norm] -> [activation]`` block as an ``nn.Sequential``.

    Args:
        ni: Number of input channels.
        nf: Number of output channels.
        ks: Kernel size; the padding is ``ks // 2``.
        stride: Convolution stride.
        act: Zero-argument factory for the activation module, or ``None`` for no activation.
        norm: Factory called as ``norm(nf)`` to build the normalization
            module (e.g. ``nn.BatchNorm2d`` or a ``functools.partial`` of it),
            or ``None`` for no normalization.
        bias: Whether the convolution has a bias. ``None`` means: no bias when
            ``norm`` is a batch-norm layer, a bias otherwise.

    Returns:
        ``nn.Sequential`` with the convolution, then the optional norm, then
        the optional activation.

    Note:
        Earlier versions tested ``isinstance(norm, BatchNorm...)`` on the norm
        *class*, which is always false, so the bias was never dropped. Pass
        ``bias=True`` explicitly to reproduce that layout (e.g. to load old
        checkpoints).
    """
    if bias is None:
        # `norm` is a factory, not a module: unwrap partials and compare the class itself.
        norm_cls = norm
        while isinstance(norm_cls, partial): norm_cls = norm_cls.func
        bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
        is_batchnorm = isinstance(norm_cls, bn_types) or (
            isinstance(norm_cls, type) and issubclass(norm_cls, bn_types))
        bias = not is_batchnorm
    layers = [nn.Conv2d(ni, nf, stride=stride, kernel_size=ks, padding=ks//2, bias=bias)]
    if norm: layers.append(norm(nf))
    if act: layers.append(act())
    return nn.Sequential(*layers)

def deconv(ni, nf, ks=3, stride=2, act=nn.ReLU, norm=None, bias=None, output_padding=0):
    """Build a ``ConvTranspose2d -> [norm] -> [activation]`` block as an ``nn.Sequential``.

    Args:
        ni: Number of input channels.
        nf: Number of output channels.
        ks: Kernel size; the padding is ``ks // 2``.
        stride: Transposed-convolution stride.
        act: Zero-argument factory for the activation module, or ``None`` for no activation.
        norm: Factory called as ``norm(nf)`` to build the normalization
            module (e.g. ``nn.BatchNorm2d`` or a ``functools.partial`` of it),
            or ``None`` for no normalization.
        bias: Whether the transposed convolution has a bias. ``None`` means:
            no bias when ``norm`` is a batch-norm layer, a bias otherwise.
        output_padding: Forwarded to ``nn.ConvTranspose2d``. With the default
            ``0`` and ``ks=3, stride=2`` an input of side H gives 2H-1; pass
            ``1`` to get exactly 2H.

    Returns:
        ``nn.Sequential`` with the transposed convolution, then the optional
        norm, then the optional activation.

    Note:
        Earlier versions tested ``isinstance(norm, BatchNorm...)`` on the norm
        *class*, which is always false, so the bias was never dropped. Pass
        ``bias=True`` explicitly to reproduce that layout (e.g. to load old
        checkpoints).
    """
    if bias is None:
        # `norm` is a factory, not a module: unwrap partials and compare the class itself.
        norm_cls = norm
        while isinstance(norm_cls, partial): norm_cls = norm_cls.func
        bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
        is_batchnorm = isinstance(norm_cls, bn_types) or (
            isinstance(norm_cls, type) and issubclass(norm_cls, bn_types))
        bias = not is_batchnorm
    layers = [nn.ConvTranspose2d(ni, nf, stride=stride, kernel_size=ks, padding=ks//2,
                                 output_padding=output_padding, bias=bias)]
    if norm: layers.append(norm(nf))
    if act: layers.append(act())
    return nn.Sequential(*layers)

class TrainAELearner(Learner):
    """Learner whose loss compares the model output with the model input.

    Only the first element of each batch is used: it is fed to the model and
    is also the target handed to the loss function. Any other batch element
    is ignored by these methods.
    """

    def predict(self):
        """Run the model on the batch inputs and store the result in ``self.preds``."""
        inputs = self.batch[0]
        self.preds = self.model(inputs)

    def get_loss(self):
        """Store ``loss_func(preds, inputs)`` in ``self.loss``; the input is the target."""
        target = self.batch[0]
        self.loss = self.loss_func(self.preds, target)

    def backward(self):
        """Back-propagate the stored loss."""
        self.loss.backward()

    def step(self):
        """Advance the optimizer by one step."""
        self.opt.step()

    def zero_grad(self):
        """Reset the gradients held by the optimizer."""
        self.opt.zero_grad()

# %% ../nbs/00_core.ipynb 13
@fc.patch
def summary(self:Learner):
    """Print the total parameter count and show a per-module Markdown table.

    Hooks are attached to ``self.model`` and a single non-training batch is
    run through ``fit`` (via ``SingleBatchCB``). For each hooked module the
    table lists its class name, the shape of its first input, the shape of its
    output and its parameter count. In a notebook the table is returned as an
    ``IPython.display.Markdown`` object; otherwise it is printed and ``None``
    is returned.
    """
    header = '|Module|Input|Output|Num params|\n|--|--|--|--|\n'
    rows = []
    tot = 0

    def _record(hook, mod, inp, outp):
        # Called by each hook: add the module's parameters to the total and add one table row.
        nonlocal tot
        nparms = sum(p.numel() for p in mod.parameters())
        tot += nparms
        rows.append(f'|{type(mod).__name__}|{tuple(inp[0].shape)}|{tuple(outp.shape)}|{nparms}|\n')

    with Hooks(self.model, _record):
        self.fit(1, lr=1, train=False, cbs=SingleBatchCB())

    print("Tot params: ", tot)
    res = header + ''.join(rows)
    if not fc.IN_NOTEBOOK:
        print(res)
        return
    from IPython.display import Markdown
    return Markdown(res)
