import numpy as np
import os
from datetime import datetime



# folder_path = 'human_eps_pinpad_feb20th/human_eps_3k_1'
folder_path = 'human_eps_02-23_14-53-24_mm'

# Per-episode arrays are collected in these lists and joined into single
# arrays once every file has been read.
data_in = []
data_out = []

# os.listdir returns entries in arbitrary order; sorting makes the order in
# which episodes are concatenated the same on every run.
for file in sorted(os.listdir(folder_path)):
    file_path = os.path.join(folder_path, file)
    # Skip hidden entries (e.g. macOS '.DS_Store') and sub-directories, which
    # are not episode archives and would make np.load fail.
    if file.startswith('.') or not os.path.isfile(file_path):
        continue
    # Open the archive as a context manager so its file handle is released as
    # soon as both arrays have been read.
    with np.load(file_path) as episode:
        data_in.append(episode['image'])
        data_out.append(episode['action'])

if not data_in:
    # Fail with a clear message instead of the generic np.vstack error.
    raise FileNotFoundError(f"No episode files found in '{folder_path}'")

# Join the per-episode arrays row-wise into one array each.
data_in = np.vstack(data_in)
data_out = np.vstack(data_out)

# folder_path = 'pinpad_human_demonstrations/human_eps_1'

# data_in = []
# data_out = []
# # result_array = np.concatenate((arr1, arr2), axis=0)


# for file in os.listdir(folder_path):
#     file_path = folder_path + '/' + file
#     a = np.load(file_path)
#     # print(a)
#     data_in.append(a['image'])
#     data_out.append(a['action'])
    
  
# data_in_array2 = np.concatenate((data_in[0], data_in[1], data_in[2]), axis = 0)  
# data_out_array2 = np.concatenate((data_out[0], data_out[1], data_out[2]), axis = 0)
    
# data_in = np.concatenate((data_in_array1, data_in_array2), axis = 0)    
# data_out = np.concatenate((data_out_array1, data_out_array2), axis = 0)    

# print('here')
# '''
# Imports
# '''
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

# import gym
# import pickle
# import TurtleBot_v0
from torch.distributions import Categorical
import time
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.utils.tensorboard import SummaryWriter

# TensorBoard logging: the log directory name embeds the launch timestamp
# (month/day and time of day), giving each run its own directory.
current_time = datetime.now().strftime('%b%d_%H-%M-%S')
log_dir = 'runs/behavior_cloning' + current_time
writer = SummaryWriter(log_dir=log_dir)

# writer = SummaryWriter('runs/behavior_cloning')



# # infile = open('/media/{redacted}/SSHD/Robotics/Current_Work/Curriculum_Behavior_Cloning/PPO-GAIL-cartpole/PPO/trajectories/turtlebot_GAN.pickle','rb')
# infile = open('trajectories/turtlebot_GAN.pickle','rb')

# demos = pickle.load(infile)

# # demos = np.load('expert_cartpole.npz', mmap_mode='r')
# data_in = demos['states']
# data_out = demos['actions']

# device = torch.device('cuda:0')
# device = torch.device('cpu')
# Prefer Apple-silicon MPS (the previous hard-coded choice) when it is usable,
# otherwise fall back to CUDA and finally to the CPU so the script also runs
# on machines without MPS.
_mps_backend = getattr(torch.backends, 'mps', None)  # absent on older torch builds
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# if(torch.cuda.is_available()):
#     # torch.cuda.set_device(0)     
#     device = torch.device('cuda') 
#     torch.cuda.empty_cache()
#     print("Device set to : " + str(torch.cuda.get_device_name(device)))
# else:
#     print("Device set to : cpu")
# print("============================================================================================")


'''
Define BC Model as NN

Specs:
NN: 2 conv layers (32 filters each, ReLU + max-pool + dropout), then 3 linear layers (512, 64, 6 outputs); raw logits on output (no softmax)
Loss: Cross Entropy (nn.CrossEntropyLoss)
'''

class Net(nn.Module):
    """Convolutional behaviour-cloning policy: image batch in, action logits out.

    Architecture (held in ``self.actor``): two ``Conv2d(3x3, stride 1)`` ->
    ``ReLU`` -> ``MaxPool2d(2, 2)`` -> ``Dropout(0.2)`` stages with 32 filters
    each, a ``Flatten``, then linear layers of width 512, 64 and
    ``num_actions``. No softmax is applied, so the output is unnormalised
    logits.

    Args:
        in_channels: Number of channels of the input images.
        flat_features: Size of the flattened convolutional output fed to the
            first linear layer. It must match the input resolution: each
            unpadded 3x3 convolution shrinks a spatial side by 2 and each
            pooling halves it (rounding down), and there are 32 output
            channels. For example a 64x64 input yields 32 * 14 * 14 = 6272.
        num_actions: Number of output logits (one per discrete action).

    The defaults reproduce the original fixed architecture.
    """

    def __init__(self, in_channels: int = 3, flat_features: int = 12544,
                 num_actions: int = 6) -> None:
        super().__init__()

        # NOTE: the position of each layer inside the Sequential determines
        # the state_dict keys (actor.0, actor.4, actor.9, ...). Keep the order
        # unchanged so previously saved checkpoints still load.
        self.actor = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(p=0.2),
            nn.Conv2d(32, 32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(p=0.2),
            nn.Flatten(),
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Linear(64, num_actions),
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return the action logits for a batch of channel-first images.

        Args:
            state: Image batch with ``in_channels`` on axis 1.

        Returns:
            Tensor with one row per image and ``num_actions`` columns of raw
            (pre-softmax) scores.
        """
        action_logits = self.actor(state)
        return action_logits
        # return action.detach()
        # return action_logprob.detach()

'''
Train BC Model
'''
# Tag that is embedded in the file name of the final checkpoint.
env_name = "memory_maze"

# Build the network on the selected device; `net` and `model` are two names
# for the same object.
model = net = Net().to(device)

# Every per-batch training loss is appended to this list.
loss_arr = []

# Early-stopping state: `counter` counts consecutive epochs without a new best
# validation loss, and training stops once it reaches `patience`.
best_val_loss, counter, patience = float('inf'), 0, 20

# Optimisation: Adam on a cross-entropy objective.
learnr = 1e-4
criterion = nn.CrossEntropyLoss()
# Alternatives that were tried:
# optimizer = torch.optim.SGD(model.parameters(), lr=learnr, weight_decay=0.01)
# optimizer = torch.optim.Adam(model.parameters(), lr=learnr, weight_decay=0.01)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learnr)

# learning_rate = [0.01]

# data_in_squeezed = np.squeeze(np.array(data_in),1)
# Scale the pixel values by 1/255, then exchange axes 1 and 3 so that the
# original last axis (the 3 channels the first Conv2d expects) lands on axis 1.
# NOTE(review): assuming data_in is (N, H, W, C), this gives (N, C, W, H),
# i.e. height and width are exchanged as well - confirm this is intended and
# that inference code applies the same transform.
data_in = (data_in / 255).swapaxes(1, 3)

# Images become float32 tensors on the training device.
images_cpu = torch.from_numpy(data_in)
data_in_torch = images_cpu.float().to(device)

# Actions keep their stored dtype; np.array makes a copy before wrapping.
actions_cpu = torch.from_numpy(np.array(data_out))
data_out_torch = actions_cpu.to(device)

# data_out_oh = F.one_hot(torch.from_numpy(np.array(data_out)), num_classes=4)

# Pair each image with its action so both are split and batched together.
dataset = TensorDataset(data_in_torch, data_out_torch)

# 80% of the samples go to training; the remainder is halved between test and
# validation, with validation receiving the extra sample when it is odd.
total_size = len(dataset)
train_size = int(0.8 * total_size)
test_size = (total_size - train_size) // 2
val_size = total_size - (train_size + test_size)

# Random, non-overlapping subsets in the order train / test / validation.
split_lengths = [train_size, test_size, val_size]
train_dataset, test_dataset, val_dataset = random_split(dataset, split_lengths)

# Only the training loader shuffles; evaluation order does not matter.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=128)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=128)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=128)

# Draw a single training batch (not referenced again in the visible script).
train_features, train_labels = next(iter(train_dataloader))

# Upper bound on epochs; `cutoff_epoch` is overwritten if early stopping fires.
num_epochs = cutoff_epoch = 8000


# Main optimisation loop: one training pass and one validation pass per epoch,
# periodic checkpoints, and early stopping on the validation loss.
for epoch in range(num_epochs):
    # Training phase
    model.train()
    epoch_train_losses = []  # losses of the current epoch only
    for xs, ys in train_dataloader:
        y_pred = model(xs)
        loss = criterion(y_pred, ys)
        batch_loss = loss.item()
        print(f"Epoch {epoch}, Training Loss: {batch_loss}")

        loss_arr.append(batch_loss)  # full history across all epochs
        epoch_train_losses.append(batch_loss)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Average over this epoch's batches only. `loss_arr` is never reset, so
    # averaging it here would report a running mean over the whole training
    # history and hide the current epoch's actual loss.
    avg_train_loss = np.mean(epoch_train_losses)
    print (f"Epoch {epoch}, Training Loss: {avg_train_loss}")
    writer.add_scalar('training loss', avg_train_loss, epoch)

    # Validation phase (dropout disabled, no gradients tracked)
    model.eval()
    with torch.no_grad():
        val_loss_arr = []
        for val_xs, val_ys in val_dataloader:
            val_y_pred = model(val_xs)
            val_loss = criterion(val_y_pred, val_ys)
            val_loss_arr.append(val_loss.item())

        avg_val_loss = np.mean(val_loss_arr)
        print(f"Epoch {epoch}, Validation Loss: {avg_val_loss}")
        writer.add_scalar('validation loss', avg_val_loss, epoch)

    # Periodic checkpoint every 100 epochs (including epoch 0)
    if epoch % 100 == 0:
        checkpoint_path = "BC_lr={}_epochs={}.pth".format(learnr, epoch)
        torch.save(model.state_dict(), checkpoint_path)

    #  Check if the current validation loss is lower than the best one seen so far
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        counter = 0
        # Save the best model; the epoch is part of the file name, so every
        # improvement writes a new file.
        best_model_path = "BC_best_model_lr={}_epoch={}.pth".format(learnr, epoch)
        torch.save(model.state_dict(), best_model_path)
    else:
        counter += 1
        if counter >= patience:
            print(f"Early stopping triggered at epoch {epoch} with validation loss {avg_val_loss}")
            cutoff_epoch = epoch
            break

# Restore the weights saved at the lowest validation loss.
best_state = torch.load(best_model_path)
model.load_state_dict(best_state)

# Held-out evaluation: mean cross-entropy and top-1 accuracy on the test split.
model.eval()
test_loss_arr = []
correct = 0
total = 0

with torch.no_grad():
    for test_xs, test_ys in test_dataloader:
        # Targets with more than one column are treated as one-hot and reduced
        # to class indices before computing the loss and the accuracy.
        is_one_hot = test_ys.dim() > 1 and test_ys.shape[1] > 1
        if is_one_hot:
            test_ys = test_ys.argmax(dim=1)

        test_y_pred = model(test_xs)
        test_loss = criterion(test_y_pred, test_ys)
        test_loss_arr.append(test_loss.item())

        # Predicted class = index of the largest logit in each row.
        predicted = test_y_pred.max(dim=1).indices
        total += test_ys.shape[0]
        correct += (predicted == test_ys).sum().item()


avg_test_loss = np.mean(test_loss_arr)
test_accuracy = 100 * correct / total
print(f"Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")

# Log the test metrics at the last epoch index reached by the training loop.
writer.add_scalar('Loss/test', avg_test_loss, epoch)
writer.add_scalar('Accuracy/test', test_accuracy, epoch)

# Save the final (best-validation) weights. The name is appended directly to
# `folder_path` with no path separator in between.
final_model_name = "BC_final_model_{}_lr={}_epoch={}.pth".format(env_name, learnr, cutoff_epoch)
final_model_path = folder_path + final_model_name
torch.save(model.state_dict(), final_model_path)


    # # Early stopping
    # if len(loss_arr) > 50 and np.mean(val_loss_arr[-50:]) < 0.05:
    #     print("Early stopping at epoch:", epoch)
    #     break


# checkpoint_path = "BC_lr={}_epochs={}_BN.pth".format(learnr, cutoff_epoch)
# torch.save(model.state_dict(), checkpoint_path)
# All scalars have been logged at this point; close the TensorBoard writer.
writer.close()









# for epoch in range(num_epochs):
#     for xs, ys in iter(train_dataloader):
#         y_pred = model(xs)
#         loss = criterion(y_pred, ys)
#         print(epoch, loss.item())
#         loss_arr.append(loss.item())
#         # model.zero_grad()
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#         # with torch.no_grad:
#     if len(loss_arr)> 50 and np.mean(loss_arr[-50:]) < 0.05:
#         print("Early stopping at epoch: ", epoch)
#         break




'''
Render BC Agent and Generate Gifs
'''
# env = gym.make('TurtleBot-v2')
# obs = env.reset()

# for t in range(200):
#     obs = torch.from_numpy(obs).to(torch.float32).to(device)
#     action = model.forward(obs).detach()
#     if action[0][0].item() > 0.5:
#         action_to_take = 0    
#     elif action[0][1].item() > 0.5:
#         action_to_take = 1
#     elif action[0][2].item() > 0.5:
#         action_to_take = 2
#     elif action[0][3].item() > 0.5:
#         action_to_take = 3
#     else:
#         action_to_take = env.action_space.sample()

#     obs, reward, done, info = env.step(action_to_take)
#     time.sleep(0.5)