import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import matplotlib.ticker as mtick

# Load data

def read_brake():
    """Load the three brake-shift statistics files and stack them row-wise.

    Reads ``statistics_eta0.05_g2_4_brake_00.csv`` through ``..._02.csv``
    from the current working directory. The per-file row indices are kept
    (``pd.concat`` is called without ``ignore_index``).
    """
    cwd = os.getcwd()
    frames = [
        pd.read_csv(os.path.join(cwd, f'statistics_eta0.05_g2_4_brake_0{k}.csv'))
        for k in range(3)
    ]
    return pd.concat(frames)

def read_steer():
    """Load the five steering-shift statistics files and stack them row-wise.

    Reads ``statistics_eta0.05_g2_4_steer_plus0.csv`` through ``..._plus4.csv``
    from the current working directory; per-file row indices are kept.
    """
    paths = [
        os.path.join(os.getcwd(), f'statistics_eta0.05_g2_4_steer_plus{n}.csv')
        for n in range(5)
    ]
    return pd.concat(pd.read_csv(path) for path in paths)

def read_inertia():
    """Load the five inertia-shift statistics files and stack them row-wise.

    Reads ``statistics_eta0.05_g2_4_intertia_0.csv`` through ``..._4.csv``
    (sic: the file names spell it "intertia") from the current working
    directory; per-file row indices are kept.
    """
    names = (f'statistics_eta0.05_g2_4_intertia_{n}.csv' for n in range(5))
    return pd.concat([pd.read_csv(os.path.join(os.getcwd(), name)) for name in names])

def read_power():
    """Load the three power-shift statistics files and stack them row-wise.

    Reads ``statistics_eta0.05_g2_4_power_00.csv`` through ``..._02.csv``
    from the current working directory; per-file row indices are kept.
    """
    def _load(k):
        # Read the k-th power-shift file relative to the working directory.
        return pd.read_csv(os.path.join(os.getcwd(), f'statistics_eta0.05_g2_4_power_0{k}.csv'))

    return pd.concat(list(map(_load, range(3))))

# Which perturbation experiment to plot. Must be one of the keys of the
# per-experiment tables below: 'Brake_shift', 'Steer_shift',
# 'Inertia_shift' or 'Power_shift'.
index = 'Steer_shift'

# Loader for each experiment. Choosing the data through ``index`` keeps the
# loaded dataset and the plotted shift column in sync; previously they were
# selected by two separate sets of commented-out lines.
readers = {
    'Brake_shift': read_brake,
    'Steer_shift': read_steer,
    'Inertia_shift': read_inertia,
    'Power_shift': read_power,
}
data = readers[index]()

# Multiplier applied to the shift column in split_by_gamma (steering shifts
# are negated).
op = {'Brake_shift': 1, 'Steer_shift': -1, 'Inertia_shift': 1, 'Power_shift': 1}

# Number of trailing rows that split_by_gamma excludes as window start
# positions (author's note: align data to ~30% shifts).
dec = {'Brake_shift': 0, 'Steer_shift': 21, 'Inertia_shift': 40, 'Power_shift': 0}

def split_by_gamma(data, gamma, *, shift_col=None, sign=None, trim=None, window=3, perf_scale=335):
    """Smooth one gamma's statistics with a short forward moving average.

    Keeps the rows of ``data`` whose ``'Gamma'`` equals ``gamma`` (compared
    with ``np.isclose`` so values round-tripped through CSV still match).
    For every start position ``i`` in ``range(1, n - trim)`` (``n`` = number
    of kept rows), it averages the per-row vectors

        [sign * shift, c_mean, g0_mean, g1_mean, int(perf_mean * perf_scale)]

    over rows ``i .. i + window - 1``. The window is truncated at the end of
    the selection, so the last averages use fewer rows. Row 0 of the
    selection never contributes. Rows are used in the order they appear in
    ``data`` and nothing is sorted here (presumably the CSVs are already
    ordered by shift -- TODO confirm).

    Args:
        data: frame with a ``'Gamma'`` column, the shift column, and
            ``'c_mean'``, ``'g0_mean'``, ``'g1_mean'``, ``'perf_mean'``.
        gamma: discount factor to select.
        shift_col: shift column to use; defaults to the module-level ``index``.
        sign: multiplier for the shift values; defaults to ``op[shift_col]``.
        trim: number of trailing rows excluded as window start positions;
            defaults to ``dec[shift_col]``.
        window: maximum number of rows averaged per output row (>= 1).
        perf_scale: factor applied to ``perf_mean`` before it is truncated
            to ``int`` (NOTE(review): meaning of the default 335 is not
            documented here -- confirm).

    Returns:
        DataFrame with columns ``[shift_col, 'c_mean', 'g0_mean', 'g1_mean',
        'perf_mean']``, one row per start position (empty when there are too
        few rows).
    """
    if shift_col is None:
        shift_col = index
    if sign is None:
        sign = op[shift_col]
    if trim is None:
        trim = dec[shift_col]

    df = data[np.isclose(data['Gamma'], gamma)]
    n_rows = len(df)

    def _row_vector(pos):
        # One row as [signed shift, c, g0, g1, scaled-and-truncated perf].
        row = df.iloc[pos]
        return np.hstack([
            sign * row[shift_col],
            row[['c_mean', 'g0_mean', 'g1_mean']].values,
            int(row['perf_mean'] * perf_scale),
        ])

    out = []
    for start in range(1, n_rows - trim):
        total = _row_vector(start)
        count = 1
        # Extend the window forward, bounded by the number of selected rows
        # (not by the length of some specific column, which may not exist).
        for pos in range(start + 1, min(start + window, n_rows)):
            total = total + _row_vector(pos)
            count += 1
        out.append(total / count)

    return pd.DataFrame(out, columns=[shift_col, 'c_mean', 'g0_mean', 'g1_mean', 'perf_mean'])

# One smoothed curve per discount factor (same call order as before: 0.95, 0.90, 0.925).
df_95, df_90, df_925 = (split_by_gamma(data, g) for g in (0.95, 0.90, 0.925))

# Plot

# Y-axis limits per experiment (outer key: shift column) and per statistic.
lims = {
    'Brake_shift': {
        'g0_mean': [1.5, 3],
        'g1_mean': [1.5, 5],
        'perf_mean': [60, 130],
        'c_mean': [16, 46],
    },
    'Power_shift': {
        'g0_mean': [1.5, 3],
        'g1_mean': [2, 5],
        'perf_mean': [60, 130],
        'c_mean': [25, 40],
    },
    'Inertia_shift': {
        'g0_mean': [1.5, 3.5],
        'g1_mean': [1.5, 6],
        'perf_mean': [60, 130],
        'c_mean': [28, 38],
    },
    'Steer_shift': {
        'g0_mean': [1.5, 4],
        'g1_mean': [1, 8],
        'perf_mean': [60, 130],
        'c_mean': [15, 40],
    },
}
# Human-readable x-axis label for each experiment.
x_label = {'Power_shift': 'Shift of Power', 'Brake_shift': 'Shift of Brake', 'Steer_shift': 'Shift of Steering', 'Inertia_shift': 'Shift of Inertia'}

# Y-axis caption per statistic. The raw string keeps the LaTeX backslashes
# literal; '\h' in a plain string is an invalid escape and warns on 3.12+.
caption = {
    'c_mean': r'Value Function of $\hat{\pi}$',
    'g0_mean': 'Constraint (Slow Driving)',
    'g1_mean': 'Constraint (Edge Driving)',
    'perf_mean': 'Objective',
}
# Statistics drawn on the figure: label[0] on the objective panel,
# label[1] and label[2] on the two constraint panels.
label = ['c_mean', 'g0_mean', 'g1_mean']
# Thresholds for label[1] and label[2], drawn as dashed horizontal lines.
constraints = [2, 4]

fontsize = 30
fig = plt.figure(figsize=(57, 9))
# ``xmax`` argument for mtick.PercentFormatter: the shift value shown as 100%.
nax = {'Brake_shift': 0.2, 'Steer_shift': 1, 'Inertia_shift': 1, 'Power_shift': 1}
lw = 3  # line width for every curve

def _plot_gamma_curves(ax, column, markers=False):
    """Draw one line per discount factor for ``column`` against the shift column."""
    # (frame, legend label, marker). Raw strings keep '\gamma' literal;
    # '\g' in a plain string is an invalid escape and warns on 3.12+.
    curves = (
        (df_90, r'$\gamma$=0.90(ours)', 'o'),
        (df_925, r'$\gamma$=0.925(ours)', '^'),
        (df_95, r'$\gamma$=0.95(baseline)', 's'),
    )
    for frame, curve_label, marker in curves:
        style = {'linewidth': lw}
        if markers:
            style.update(marker=marker, markersize=7)
        ax.plot(frame[index], frame[column], label=curve_label, **style)


# Right-hand panel: label[0] (the value function), with markers and a legend.
i = 0
ax = fig.add_subplot(1, 3, 3)
_plot_gamma_curves(ax, label[i], markers=True)
ax.set_xlabel(x_label[index], fontsize=fontsize)
ax.set_ylabel(caption[label[i]], fontsize=fontsize * 1.5)
ax.set_title('Objective', fontsize=fontsize * 1.5)
ax.set_ylim(lims[index][label[i]])
ax.tick_params(axis='both', which='major', labelsize=fontsize)
ax.xaxis.set_major_formatter(mtick.PercentFormatter(nax[index]))
ax.legend(fontsize=fontsize)

# Left and middle panels: the two constraint statistics, each with its threshold.
for i in range(1, 3):
    ax = fig.add_subplot(1, 3, i)
    _plot_gamma_curves(ax, label[i])
    # Same human-readable x label as the objective panel (was the raw column name).
    ax.set_xlabel(x_label[index], fontsize=fontsize)
    ax.set_ylabel(caption[label[i]], fontsize=fontsize, fontweight='bold')
    ax.set_ylim(lims[index][label[i]])
    ax.tick_params(axis='both', which='major', labelsize=fontsize)
    # constraints is indexed from 0 while constraint panels start at i == 1.
    ax.axhline(y=constraints[i - 1], color='r', linestyle='--', label='threshold')
    ax.xaxis.set_major_formatter(mtick.PercentFormatter(nax[index]))
    if i == 2:
        ax.legend(fontsize=fontsize)

# savefig does not create missing directories.
os.makedirs('figs', exist_ok=True)
fig.savefig(f'figs/{index}_final_2.png')