
from pyvirtualdisplay import Display
import numpy as np
import os
import pandas as pd
import argparse
#np.random.seed(111111)

#3141592, 314159



# Restrict CUDA to the device with index 1. This is assigned before TensorFlow
# is imported so the setting is already in place when TensorFlow looks for GPUs.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import tensorflow as tf

# Report which GPUs TensorFlow can see, as a sanity check on the setting above.
gpus = tf.config.experimental.list_physical_devices('GPU')
print(f"Available GPUs: {gpus}")

# Select the TensorFlow backend for Keras; assigned ahead of the `keras`
# imports further down in this file.
os.environ["KERAS_BACKEND"] = "tensorflow"

from optimization_problem import Program
from fittedq import *
from exponentiated_gradient import ExponentiatedGradient
from fitted_off_policy_evaluation import *
from exact_policy_evaluation import ExactPolicyEvaluator
from stochastic_policy import StochasticPolicy
from DQN import DeepQLearning
from print_policy import PrintPolicy
from keras.models import load_model
from keras import backend as K
from env_dqns import *
import h5py
import time
import os
from config_car import *
# Print floats in fixed-point notation instead of scientific notation.
np.set_printoptions(suppress=True)

# Column-name fragments for the per-constraint statistics.
# One row is built per constraint index (min, max, 1st quartile, mean,
# 3rd quartile) and then transposed, so labels[k] lists statistic k for every
# constraint index: labels[0] -> all '*_min' names, labels[1] -> all '*_max', ...
_label_templates = ('g%s_min', 'g%s_max', 'g%s_1quantile', 'g%s_mean', 'g%s_3quantile')
labels = np.array(
    [[template % idx for template in _label_templates]
     for idx in range(len(constraints) - 1)]
).T.tolist()

# Saved policies are read from models/rnd_<np_seed> under the current working
# directory; the CSVs for this (seed, eta, constraint) configuration are
# written to their own folder under results_robustness/.
model_dir = os.path.join(os.getcwd(), f'models/rnd_{np_seed}')
result_dir = os.path.join(os.getcwd(), f'results_robustness/rnd{np_seed}_eta{eta}_g1{constraints[0]}_g2{constraints[1]}')
# exist_ok=True replaces an exists()-then-makedirs() check: several sweeps
# (one process per -shift / -dir combination) share this directory, and the
# check-then-create form can fail if two of them start at the same moment.
os.makedirs(result_dir, exist_ok=True)

def _read_h5_dataset(path):
    """Return the full contents of the 'data' dataset stored in the HDF5 file at *path*."""
    with h5py.File(path, 'r') as hf:
        return hf['data'][:]


print('Loading Prebuilt Data')
tic = time.time()  # start-of-load timestamp; not read again in the visible code

# Each array is read completely into memory, in the same order as before.
action_data = _read_h5_dataset('./seed_2_data/car_data_actions_seed_2.h5')
frame_data = _read_h5_dataset('./seed_2_data/car_data_frames_seed_2.h5')
terminated_data = _read_h5_dataset('./seed_2_data/car_data_is_done_seed_2.h5')
# NOTE(review): truncated_data is read from the same is_done file as
# terminated_data -- confirm this is intended.
truncated_data = _read_h5_dataset('./seed_2_data/car_data_is_done_seed_2.h5')
next_state_data = _read_h5_dataset('./seed_2_data/car_data_next_states_seed_2.h5')
current_state_data = _read_h5_dataset('./seed_2_data/car_data_prev_states_seed_2.h5')
cost_data = _read_h5_dataset('./seed_2_data/car_data_rewards_seed_2.h5')

# Rewrite two cost columns in place using the project's transform helpers.
cost_data[:, 3] = transform_distance(cost_data[:, 3])
cost_data[:, -1] = transform_brake(cost_data[:, -1])  # For the fail state

# Gray-scale copy of every frame: pixel values are scaled by 1/255 and the
# last axis is collapsed with the weights below. Frames are converted one at a
# time so that only a single frame-sized float temporary exists at once.
_gray_weights = [0.299, 0.587, 0.114]
frame_gray_scale = np.zeros((len(frame_data), 96, 96), dtype='float32')
for frame_idx, frame in enumerate(frame_data):
    frame_gray_scale[frame_idx] = np.dot(frame / 255., _gray_weights)

# Module-level accumulator of result rows. shift_calc() appends one row per
# gamma value on every call and rebuilds the output CSV from the whole list.
prev_stat = []


def calculate_stats(data):
    """Summarise *data* with five numbers.

    Args:
        data: Array-like of numeric values; all elements are considered.

    Returns:
        Tuple ``(minimum, maximum, 25th percentile, mean, 75th percentile)``.
    """
    values = np.asarray(data)
    lowest, highest = values.min(), values.max()
    first_quartile, third_quartile = np.percentile(values, (25, 75))
    return lowest, highest, first_quartile, values.mean(), third_quartile

def shift_calc(Type, dir, power, inertia, brake=0., steer=0.):
    """Evaluate the saved policies in one shifted environment and update the CSV.

    Builds an ``ExtendedCarRacing`` environment from the given dynamics
    parameters. For each gamma value it loads ``pi_1.keras`` .. ``pi_25.keras``
    from the matching ``cold_gamma_*`` folder under ``model_dir``, runs
    ``ExactPolicyEvaluator.run`` on every policy and summarises the 25 results
    with ``calculate_stats``. One row per gamma is appended to the module-level
    ``prev_stat`` list, and the CSV in ``result_dir`` is rewritten from the
    whole list after each row, so the file always holds every row computed so
    far (including rows from earlier calls in the same process).

    Args:
        Type: Shift code; used only in the output file name.
        dir: Sweep direction; used only in the output file name. (The name
            shadows the builtin but is kept for caller compatibility.)
        power: Passed to ``ExtendedCarRacing``; recorded as ``1 - power``.
        inertia: Passed to ``ExtendedCarRacing``; recorded as ``1 - inertia``.
        brake: Passed to ``ExtendedCarRacing``; recorded unchanged.
        steer: Passed to ``ExtendedCarRacing``; recorded unchanged.

    Returns:
        None. Results are appended to ``prev_stat`` and written to disk.
    """
    env = ExtendedCarRacing(init_seed, stochastic_env, max_pos_costs, power, inertia, brake, steer)

    # Only used as a container for the Q-network: the weights are replaced by
    # each loaded policy below.
    # NOTE(review): 0.95 is passed here and to ExactPolicyEvaluator regardless
    # of the gamma being evaluated -- presumably the discount; confirm.
    policy_old = CarDQN(
        env,
        0.95,
        action_space_map=action_space_map,
        action_space_dim=action_space_dim,
        model_type=model_type,
        max_time_spent_in_episode=max_time_spent_in_episode,
        num_iterations=num_iterations,
        sample_every_N_transitions=sample_every_N_transitions,
        batchsize=batchsize,
        copy_over_target_every_M_training_iterations=copy_over_target_every_M_training_iterations,
        buffer_size=buffer_size,
        min_epsilon=min_epsilon,
        initial_epsilon=initial_epsilon,
        epsilon_decay_steps=epsilon_decay_steps,
        num_frame_stack=num_frame_stack,
        min_buffer_size_to_train=min_buffer_size_to_train,
        frame_skip=frame_skip,
        pic_size=pic_size,
        models_path=os.path.join(model_dir, 'weights.{epoch:02d}-{loss:.2f}.keras'),
    )

    exact_policy_algorithm = ExactPolicyEvaluator(
        action_space_map,
        0.95,
        env=env,
        frame_skip=frame_skip,
        num_frame_stack=num_frame_stack,
        pic_size=pic_size,
        constraint_thresholds=constraint_thresholds,
        constraints_cared_about=constraints_cared_about,
    )

    num_policies = 25

    # Column order mirrors the row layout built below: the five shift/gamma
    # fields, then for each statistic the values for c, the g constraints, perf.
    column_names = np.hstack([
        'Gamma', 'Brake_shift', 'Power_shift', 'Inertia_shift', 'Steer_shift',
        'c_min', labels[0], 'perf_min',
        'c_max', labels[1], 'perf_max',
        'c_1quantile', labels[2], 'perf_1quantile',
        'c_mean', labels[3], 'perf_mean',
        'c_3quantile', labels[4], 'perf_3quantile',
    ])
    # NOTE(review): the original trailing comment on the to_csv line read
    # "nax = 0.2 only if Type='B' otherwise 1"; its meaning is unclear. The
    # '30%' in the file name is hard-coded for every Type.
    csv_path = os.path.join(result_dir, f'statistics_{dir}_{Type}_30%.csv')

    # A tuple rather than a set literal: sets have no defined iteration order,
    # so the order of the CSV rows would otherwise be an implementation detail.
    for ga in (0.95, 0.925, 0.9):
        current_path = os.path.join(model_dir, f'cold_gamma_{ga}_eta{eta}_g1{constraints[0]}_g2{constraints[1]}')

        res_c, res_g1, res_g2, perf = [], [], [], []
        for policy_id in range(1, num_policies + 1):
            policy_path = os.path.join(current_path, f'pi_{policy_id}.keras')

            # Drop the previous policy's Keras state before loading the next model.
            K.clear_session()
            policy_old.Q.model = load_model(policy_path)

            # Rebuild the all-actions head against the freshly loaded model.
            policy_old.Q.all_actions_func = KerasModel(
                inputs=policy_old.Q.model.get_layer('inp').output,
                outputs=policy_old.Q.model.get_layer('all_actions').output,
            )

            exact_c, exact_g, performance = exact_policy_algorithm.run(policy_old.Q, to_monitor=False)
            # Keep the last entry and the entry at index 2, in that order.
            exact_g = np.array(exact_g)[[-1, 2]]
            print(f'(Power{power}, Inertia{inertia}, Brake{brake}, Steer{steer})Exact R(gamma_{ga}_pi_{policy_id}): %s. G(gamma_{ga}_pi_{policy_id}): %s' % (exact_c, exact_g))
            print()
            res_c.append(exact_c)
            res_g1.append(exact_g[0])
            res_g2.append(exact_g[1])
            perf.append(performance)

        # Statistics for c, g_1, g_2 and perf; transposed so that each inner
        # list holds one statistic (min, max, ...) across the four quantities.
        stats = [calculate_stats(values) for values in (res_c, res_g1, res_g2, perf)]
        stats = np.array(stats).T.tolist()

        prev_stat.append(np.hstack([ga, brake, 1 - power, 1 - inertia, steer, np.hstack(stats)]))

        # Rewrite the whole CSV after every row so the file on disk is always
        # up to date with prev_stat.
        df = pd.DataFrame(prev_stat, columns=column_names)
        df.to_csv(csv_path, index=False)

if __name__ == '__main__':
    # Command-line entry point: sweep one dynamics parameter over 61 evenly
    # spaced magnitudes, calling shift_calc() once per magnitude.
    parser = argparse.ArgumentParser(description='Process shift argument')
    parser.add_argument('-shift', choices=['B', 'I', 'S', 'P'], required=True, help='Shift value must be one of {B, I, S, P}')
    parser.add_argument('-dir', type=int, choices=[1, -1], required=True, help='Direction value must be either 1 or -1')

    args = parser.parse_args()

    # shift code -> (largest magnitude of the sweep, builder of the shift_calc
    # keyword arguments for one magnitude). argparse `choices` guarantees that
    # args.shift is one of these keys, so no fallback branch is needed.
    sweeps = {
        'B': (0.06, lambda m: dict(power=1.0, inertia=1.0, brake=m * args.dir, steer=0)),
        'I': (0.3, lambda m: dict(power=1.0, inertia=1 - m * args.dir, brake=0, steer=0)),
        'S': (0.3, lambda m: dict(power=1.0, inertia=1.0, brake=0, steer=-m * args.dir)),
        'P': (0.3, lambda m: dict(power=1 - m * args.dir, inertia=1.0, brake=0, steer=0)),
    }

    max_magnitude, build_kwargs = sweeps[args.shift]
    for magnitude in np.linspace(0, max_magnitude, 61):
        shift_calc(args.shift, args.dir, **build_kwargs(magnitude))
    
    
    
    


