import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import matplotlib.ticker as mtick

# Shift types to process. Each name is used as a key of the per-shift tables
# in this file and as the shift-magnitude column name in the statistics CSVs.
Shift_Type = [
    'Brake_shift',
    'Steer_shift',
    'Inertia_shift',
    'Power_shift',
]

# Factor the shift-magnitude column is multiplied by in read_gamma
# (-1 flips the sign of the x values for that shift type).
op = {
    'Brake_shift': 1,
    'Steer_shift': -1,
    'Inertia_shift': -1,
    'Power_shift': 1,
}

# Seed identifiers; read_gamma embeds each one in the `rnd<seed>_...` results
# sub-directory name and averages the statistics over all of them.
seed = [
    '31415',
    '314159',
    '3141592',
]

# Gamma settings. NOTE(review): the functions visible in this file do not read
# this list (read_gamma has a parameter of the same name that shadows it, and
# plot_shift passes the same three values as literals).
gamma = [
    0.95,
    0.925,
    0.90,
]

# `xmax` handed to matplotlib's PercentFormatter for the x axis, i.e. the data
# value that is rendered as 100%.
nax = {
    'Brake_shift': 0.2,
    'Steer_shift': 1,
    'Inertia_shift': 1,
    'Power_shift': 1,
}

# Plot settings.
# Y-axis limits as [lower, upper], looked up as lims[shift_type][column] and
# passed to Axes.set_ylim in plot_shift.
# NOTE(review): the 'perf_mean' entries are not read by the plotting code in
# this file, which only draws the columns listed in `label`.
lims = {
    'Brake_shift': dict(
        g0_mean=[1.5, 3], g1_mean=[1.5, 5], perf_mean=[60, 130], c_mean=[16, 46],
    ),
    'Power_shift': dict(
        g0_mean=[1.5, 3], g1_mean=[2, 5], perf_mean=[60, 130], c_mean=[25, 40],
    ),
    'Inertia_shift': dict(
        g0_mean=[1.5, 3.5], g1_mean=[1.5, 6], perf_mean=[60, 130], c_mean=[28, 38],
    ),
    'Steer_shift': dict(
        g0_mean=[1.5, 4], g1_mean=[1, 8], perf_mean=[60, 130], c_mean=[15, 40],
    ),
}

# Human-readable x-axis label for each shift type.
x_label = {
    'Power_shift': 'Shift of Power',
    'Brake_shift': 'Shift of Brake',
    'Steer_shift': 'Shift of Steering',
    'Inertia_shift': 'Shift of Inertia',
}

# Y-axis captions keyed by statistics column name.
caption = {
    # Raw string: "\h" and "\p" are not valid Python escape sequences, so the
    # non-raw form triggers an invalid-escape warning (SyntaxWarning on
    # Python 3.12+). The resulting text is unchanged.
    'c_mean': r'Value Function of $\hat{\pi}$',
    'g0_mean': 'Constraint (Slow Driving)',
    'g1_mean': 'Constraint (Edge Driving)',
    'perf_mean': 'Objective',
}
# Statistics columns drawn by plot_shift: label[0] goes on the objective panel,
# label[1] and label[2] on the two constraint panels.
label = [
    'c_mean',
    'g0_mean',
    'g1_mean',
]
# Horizontal threshold lines for the constraint panels; constraints[k] is
# drawn on the panel of label[k + 1].
constraints = [2, 4]
# Base font size for axis labels, tick labels and legends.
fontsize = 30

# Directory holding the per-seed result folders, resolved against the current
# working directory.
result_dir = os.path.join(os.getcwd(), 'results_robustness')
# Load data

def read_gamma(index, gamma):
    """Average the smoothed statistics for one gamma value over all seeds.

    For every entry of the module-level ``seed`` list the matching statistics
    CSV is read, the rows whose ``Gamma`` column equals ``gamma`` are kept, and
    the ``c_mean``/``g0_mean``/``g1_mean`` columns are smoothed with a 3-row
    rolling mean (``min_periods=1``). The smoothed frames are summed with
    pandas index alignment and divided by the number of seeds.

    Args:
        index: Shift type name. It is a key of ``op``, the name of the
            shift-magnitude column in the CSV, and its first character
            selects the CSV file (``statistics_<index[0]>_30%.csv``).
        gamma: Gamma value whose rows are selected.

    Returns:
        A DataFrame whose first column (named ``index``) is the shift
        magnitude multiplied by ``op[index]``, followed by the seed-averaged
        ``c_mean``, ``g0_mean`` and ``g1_mean`` columns.
    """
    columns = ['c_mean', 'g0_mean', 'g1_mean']
    total = None
    shift_mag = None

    for run_seed in seed:
        result = f'rnd{run_seed}_eta0.05_g12_g24/statistics_{index[0]}_30%.csv'
        data = pd.read_csv(os.path.join(result_dir, result))
        # Tolerant float comparison: an exact `==` silently selects nothing if
        # the stored gamma differs from the literal by a rounding error.
        data = data[np.isclose(data['Gamma'], gamma)]

        smoothed = data[columns].rolling(3, min_periods=1).mean()

        if total is None:
            total = smoothed
            # The x values are taken from the first seed only.
            # NOTE(review): assumes every seed's CSV lists the same shift
            # magnitudes on the same row labels - confirm against the data.
            shift_mag = data[index] * op[index]
        else:
            total = total + smoothed

    return pd.concat([shift_mag, total / len(seed)], axis=1)

def _draw_gamma_curves(ax, index, column, curves, linewidth, with_markers=False):
    """Plot ``column`` against the shift magnitude, one line per gamma setting.

    ``curves`` is a sequence of ``(frame, legend_text, marker)`` tuples; the
    marker (size 7) is only applied when ``with_markers`` is true.
    """
    for frame, legend_text, marker in curves:
        marker_style = {'marker': marker, 'markersize': 7} if with_markers else {}
        ax.plot(frame[index], frame[column], label=legend_text,
                linewidth=linewidth, **marker_style)


def _style_axis(ax, index, column):
    """Apply the y-limits, tick font size and percent x-formatter shared by all panels."""
    ax.set_ylim(lims[index][column])
    ax.tick_params(axis='both', which='major', labelsize=fontsize)
    ax.xaxis.set_major_formatter(mtick.PercentFormatter(nax[index]))


def plot_shift(index):
    """Render and save the three-panel figure for one shift type.

    The two constraint columns (``label[1]``, ``label[2]``) are drawn in
    panels 1 and 2 together with their threshold from ``constraints``; the
    objective column (``label[0]``) is drawn in panel 3. Every panel shows one
    curve per gamma setting (0.90, 0.925, 0.95), loaded through ``read_gamma``.

    Args:
        index: Shift type name; a key of ``lims``, ``nax`` and ``x_label``.

    Side effects:
        Reads the statistics CSVs via ``read_gamma`` and writes
        ``figs/3ave/<index>_3ave.png`` relative to the current working
        directory, creating the directory if needed. The figure is closed
        after saving.
    """
    df_95 = read_gamma(index, 0.95)
    df_925 = read_gamma(index, 0.925)
    df_90 = read_gamma(index, 0.90)

    # Raw strings: "\g" is not a valid Python escape sequence, so the non-raw
    # form raises an invalid-escape warning. The rendered text is unchanged.
    curves = [
        (df_90, r'$\gamma$=0.90(ours)', 'o'),
        (df_925, r'$\gamma$=0.925(ours)', '^'),
        (df_95, r'$\gamma$=0.95(baseline)', 's'),
    ]

    fig = plt.figure(figsize=(57, 9))
    lw = 3

    # Objective panel (right-most), drawn with markers and its own legend.
    objective = label[0]
    ax = fig.add_subplot(1, 3, 3)
    _draw_gamma_curves(ax, index, objective, curves, lw, with_markers=True)
    ax.set_xlabel(x_label[index], fontsize=fontsize)
    ax.set_ylabel(caption[objective], fontsize=fontsize * 1.5)
    ax.set_title('Objective', fontsize=fontsize * 1.5)
    _style_axis(ax, index, objective)
    ax.legend(fontsize=fontsize)

    # Constraint panels (positions 1 and 2), each with its threshold line.
    for position, (column, threshold) in enumerate(zip(label[1:], constraints), start=1):
        ax = fig.add_subplot(1, 3, position)
        _draw_gamma_curves(ax, index, column, curves, lw)
        # Use the readable axis label, consistent with the objective panel,
        # instead of the raw shift-type key.
        ax.set_xlabel(x_label[index], fontsize=fontsize)
        ax.set_ylabel(caption[column], fontsize=fontsize, fontweight='bold')
        _style_axis(ax, index, column)
        ax.axhline(y=threshold, color='r', linestyle='--', label='threshold')
        # Only the second constraint panel carries a legend.
        if position == 2:
            ax.legend(fontsize=fontsize)

    out_path = f'figs/3ave/{index}_3ave.png'
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# Script entry point: render and save one figure per configured shift type.
if __name__ == '__main__':
    for shift_name in Shift_Type:
        plot_shift(shift_name)
