import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import matplotlib.ticker as mtick
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.cm import ScalarMappable
from matplotlib.colors import LinearSegmentedColormap, Normalize
from matplotlib import rcParams

# Global default for x tick label size, applied to every axis created afterwards.
rcParams.update({'xtick.labelsize': 30})
Shift_Type = ['Brake_shift', 'Inertia_shift', 'Power_shift']
op = {'Brake_shift' : 1, 'Steer_shift' : -1, 'Inertia_shift' : -1, 'Power_shift' : 1}
seed = ['31415', '314159', '3141592']

# Load data

x_label = {'Power_shift' : 'Shift of Power', 'Brake_shift' : 'Shift of Braking', 'Steer_shift' : 'Shift of Steering', 'Inertia_shift' : 'Shift of Inertia'}

neg_global_min = 0.2
ab = [212 / 255, 205 / 255, 71 / 255] 
be = [125 / 255, 120 / 255, 140 / 255] 

result_dir = os.path.join(os.getcwd(), f'results_robustness')
# Load data

def read_gamma(index, gamma):
    """Load and seed-average the smoothed statistics for one shift type and discount factor.

    For every seed in `seed`, reads
    ``result_dir/rnd<seed>_eta0.05_g12_g24/statistics_<index[0]>_30%.csv``. It keeps the rows
    whose ``Gamma`` column equals `gamma` and smooths ``c_mean``, ``g0_mean`` and ``g1_mean`` with a
    3-row rolling mean (``min_periods=1``). The smoothed frames are then averaged over all seeds.

    Args:
        index: shift-type column name (e.g. 'Brake_shift'); its first character selects the CSV file.
        gamma: value matched exactly against the ``Gamma`` column.

    Returns:
        DataFrame whose first column is the shift column (taken from the first seed only, multiplied
        by ``op[index]``), followed by the seed-averaged c_mean / g0_mean / g1_mean columns.
        Columns are combined by pandas index alignment. This presumably assumes every seed's CSV has
        the same row layout — TODO confirm.
    """
    total = None
    magnitude = None
    for run_seed in seed:
        csv_name = f'rnd{run_seed}_eta0.05_g12_g24/statistics_{index[0]}_30%.csv'
        frame = pd.read_csv(os.path.join(result_dir, csv_name))
        frame = frame[frame['Gamma'] == gamma]
        smoothed = frame[['c_mean', 'g0_mean', 'g1_mean']].rolling(3, min_periods=1).mean()

        if total is None:
            # The first seed defines the shift grid that is returned.
            total = smoothed
            magnitude = frame[index] * op[index]
        else:
            total = total + smoothed

    averaged = total / len(seed)
    return pd.concat([magnitude, averaged], axis=1)

def set_c(val):
    """Map a relative constraint deviation to an RGB colour.

    Returns the module-level list `ab` when ``val >= 0``. Returns `be` when
    ``val <= -neg_global_min``. Otherwise returns a new ndarray that blends linearly from `be`
    (at ``-neg_global_min``) to `ab` (at 0).
    """
    if val >= 0:
        return ab
    if val <= -neg_global_min:
        return be
    ab_weight = (val + neg_global_min) / neg_global_min
    be_weight = -val / neg_global_min
    return ab_weight * np.array(ab) + be_weight * np.array(be)

# Base font size; titles and labels in plot_con use multiples of it.
fontsize = 25

# Column title for each plotted constraint statistic.
caption = {
    'g0_mean': 'Constraint I (Slow Driving)',
    'g1_mean': 'Constraint II (Edge Driving)',
}
# Constraint columns plotted, one subplot column each (same order as `caption`).
label = list(caption)
# Threshold for each entry of `label`; plot_con colours by (value - threshold) / threshold.
constraints = [2, 4]
# Per-shift-type scale for the x axis (limit is 0.6 * nax, ticks at 0/0.2/0.4/0.6 * nax).
nax = {'Brake_shift': 0.2, 'Steer_shift': 1, 'Inertia_shift': 1, 'Power_shift': 1}
# y positions of the three gamma strips (0.95, 0.925, 0.90) within each subplot.
y_position = [0.3, 0.15, 0]

def plot_con(index, fig, axx):
    """Draw the constraint colour strips for one shift type onto its subplot row.

    Loads seed-averaged statistics at gamma = 0.95, 0.925 and 0.90 via read_gamma. For each
    constraint column in `label`, it draws one horizontal strip per gamma on row
    ``Shift_Type.index(index)`` of `axx`. Each strip consists of one bar per shift magnitude of
    the gamma=0.95 data. Each bar is coloured by set_c applied to the relative deviation
    ``(value - threshold) / threshold``.

    Args:
        index: shift-type name; must be an element of Shift_Type (it selects the subplot row).
        fig: not used by this function; kept for interface compatibility.
        axx: 2-D array of Axes indexed as ``axx[row, column]``. The code assumes shape
            (len(Shift_Type), 2), as created in __main__.
    """
    row = Shift_Type.index(index)
    df_95 = read_gamma(index, 0.95)
    df_90 = read_gamma(index, 0.90)
    df_925 = read_gamma(index, 0.925)
    # All three strips are drawn on the gamma=0.95 shift grid. This assumes the three frames share it.
    shifts = list(df_95[index])
    scale = nax[index]

    for j, key in enumerate(label):
        ax = axx[row, j]
        if row == 0:
            # Column titles appear only on the top row.
            ax.set_title(caption[key], fontsize=fontsize * 2, pad=30)
        ax.set_ylim(-0.1, 0.4)

        # Row order (0.95, 0.925, 0.90) matches y_position and the y tick labels below.
        c_data = np.vstack([df_95[key], df_925[key], df_90[key]])
        threshold = constraints[j]
        for y, values in zip(y_position, c_data):
            for shift, val in zip(shifts, values):
                # NOTE(review): width == left, so each bar spans [shift, 2*shift]. Bars added later
                # overdraw earlier ones, which looks deliberate to fill gaps between grid points — confirm.
                ax.barh(y, width=shift, left=shift, height=0.1,
                        color=set_c((val - threshold) / threshold))

        ax.set_xlim(0, scale * 0.6)
        ax.set_xticks([0, scale * 0.2, scale * 0.4, scale * 0.6])
        # PercentFormatter(xmax) labels x as 100 * x / xmax, so these ticks read 0%, 10%, 20%, 30%.
        ax.xaxis.set_major_formatter(mtick.PercentFormatter(scale * 2))
        ax.set_yticks(y_position)
        if j == 0:
            # Raw strings: '\g' in a plain literal is an invalid escape sequence.
            ax.set_yticklabels(
                [r'Baseline: $\gamma$=0.95', r'Ours: $\gamma$=0.925', r'Ours: $\gamma$=0.90'],
                fontsize=fontsize * 1.5,
            )
        else:
            ax.set_yticklabels([])
        ax.set_xlabel(x_label[index], fontsize=fontsize * 1.5)



if __name__ == '__main__':
    fig, axx = plt.subplots(len(Shift_Type), 2, figsize=(35, 18))
    fig.subplots_adjust(left=0.2, right=0.9, hspace=0.4, wspace=0.2)
    for index in Shift_Type:
        plot_con(index, fig, axx)

    # Colour-bar range in relative-deviation units. The colormap must be sampled over exactly the
    # same range as the clim. Otherwise the colour shown at each tick differs from the colour set_c
    # gives a bar with that deviation (e.g. 'Threshold' would not be `ab`).
    cbar_lo, cbar_hi = -0.3, 0.1
    vals = np.linspace(cbar_lo, cbar_hi, 500)
    colors = np.array([set_c(val) for val in vals])
    custom_cmap = LinearSegmentedColormap.from_list("custom_cmap", colors)

    cbar = ScalarMappable(cmap=custom_cmap)
    cbar.set_array([])
    cbar.set_clim(cbar_lo, cbar_hi)

    cb = plt.colorbar(cbar, ax=axx, orientation='vertical')
    cb.set_ticks([-0.3, -0.2, -0.1, 0, 0.1])
    cb.set_ticklabels(['-30%', '-20%', '-10%', 'Threshold', '10%'])
    cb.ax.tick_params(labelsize=40)

    out_path = 'figs/3ave/constraints_no_steer.png'
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path)