import argparse
import multiprocessing
import os
import pickle
from datetime import datetime

import numpy as np

from helper_functions import *


# Experiment configurations, selected by the integer passed as --example.
# In this file main() reads only two keys of an entry:
#   'p'     -- its length sets k (passed to get_mu_on and used as the trailing
#              axis of the saved mu_star array); the list itself is passed to
#              evaluate_mu.
#   'n_off' -- base value from which the n_on grid is derived.
# 'mu_off' and 'name' are not read anywhere in this file.
examples = {
    3: {
        'p': [0.6, 0.3, 0.1],
        'n_off': 100,
        'mu_off': [0.1, 0.9, 0.0],
        'name': 'example3',
    },
    4: {
        'p': [0.6, 0.5, 0.2],
        'n_off': 200,
        'mu_off': [0.01, 0.01, 0.98],
        'name': 'example4',
    },
    5: {
        'p': [0.6, 0.3, 0.2],
        'n_off': 200,
        'mu_off': [0.01, 0.495, 0.495],
        'name': 'example5',
    },
}


def parse_arguments(argv=None):
    """Parse the experiment's command-line options.

    Args:
        argv: Sequence of argument strings to parse. When None (the default)
            argparse reads ``sys.argv[1:]``, exactly as before; passing an
            explicit list lets the parser be reused and tested.

    Returns:
        argparse.Namespace with attributes ``example``, ``diff``, ``repeats``,
        ``draws``, ``mcal_step``, ``mu_on_step``, ``obj`` and ``start``.
    """
    parser = argparse.ArgumentParser(description='Parser')
    parser.add_argument('--example', dest='example', type=int, required=True,
                        help='key of the configuration in `examples`')
    parser.add_argument('--diff', dest='diff', type=float, required=False,
                        help='accepted but not read by main()')
    parser.add_argument('--repeats', dest='repeats', type=int, default=100,
                        help='number of consecutive datasets to run')
    parser.add_argument('--draws', dest='draws', type=int, default=200,
                        help='value passed as `draws` to inner and evaluate_mu')
    parser.add_argument('--mcal-step', dest='mcal_step', type=float, default=0.1,
                        help='step passed to clt_compute_mcal')
    parser.add_argument('--muon-step', dest='mu_on_step', type=float, default=0.1,
                        help='step passed to get_mu_on')
    parser.add_argument('--obj', dest='obj', default=False, action='store_true',
                        help='use the figure-1 grid and also save objective values')
    parser.add_argument('--start', dest='start', type=int, default=0,
                        help='index of the first dataset in the pickled data')

    return parser.parse_args(argv)

def _build_param_grid(obj, n_off):
    """Return the parameter grid swept by ``main``.

    Args:
        obj: If True, build the small grid for figure 1 (top), which sweeps
            only ``Delta``; otherwise build the grid for figures 2 + 3 + 4.
        n_off: ``n_off`` of the selected example; the ``n_on`` values are
            derived from it.

    Returns:
        dict mapping 'n_on', 'cprime', 'c' and 'Delta' to 1-D sequences;
        ``main`` visits every combination of their values.
    """
    if obj:  # for figure 1 (top)
        return {
            'n_on': [0.5 * n_off],
            'cprime': [0.7],
            'c': [0.7],
            'Delta': np.hstack([np.linspace(0.05, 0.5, 10), np.linspace(0.6, 0.9, 4)]),
        }
    # for figures 2 + 3 + 4
    return {
        # ceil of n_off/4, n_off/2, n_off
        'n_on': np.ceil(np.power(0.5, np.flip(np.linspace(0, 2, 3))) * n_off),
        # 1 - 0.5**j for j = 0..7, i.e. 0, 0.5, 0.75, ..., 0.9921875
        'cprime': 1. - np.power(0.5, np.linspace(0, 7, 8)),
        'c': 1. - np.power(0.5, np.linspace(0, 7, 8)),
        'Delta': [0.1, 0.3, 0.5],
    }


def _save_results(mu_star_results, mustar_filename, results, filename,
                  obj_results=None, obj_filename=None):
    """Pickle the result arrays to disk with ``ndarray.dump``.

    ``obj_results`` is written to ``obj_filename`` only when it is not None.
    """
    mu_star_results.dump(mustar_filename)
    results.dump(filename)
    if obj_results is not None:
        obj_results.dump(obj_filename)


def main():
    """Run the experiment selected on the command line and save its results.

    For every repeat ``r`` (dataset ``args.start + r`` of
    ``./data/example_<example>.pickle``) and every grid point built by
    ``_build_param_grid``:

    1. ``clt_compute_mcal`` is computed once per (cprime, Delta) pair.
    2. For each (n_on, c), ``inner`` is evaluated in parallel for every
       (model in mcal, mu_on in mu_on_list) pair. For each of the
       ``num_keys`` outputs, ``mu_star`` is the ``mu_on`` whose minimum over
       models is largest.
    3. ``evaluate_mu`` scores every chosen ``mu_star`` in parallel.

    Arrays are written under ``./results/`` with a MMDD_HHMM timestamp
    prefix, every 10 repeats and once at the end:
    ``<root>.npy`` (evaluation results), ``<root>_MUSTAR.npy`` (chosen mu_star
    vectors) and, with --obj, ``<root>_OBJ.npy`` (max-min objective values).
    """
    now = datetime.now().strftime("%m%d_%H%M")
    args = parse_arguments()

    problem = examples[args.example]
    p = problem['p']
    k = len(p)
    n_off = problem['n_off']

    repeats = args.repeats
    draws = args.draws
    mcal_step = args.mcal_step
    mu_on_list = get_mu_on(k, step=args.mu_on_step)

    dataset_filename = './data/example_' + str(args.example) + '.pickle'
    with open(dataset_filename, 'rb') as handle:
        all_data_off = pickle.load(handle)
    # Look up every dataset now, so that an out-of-range --start/--repeats
    # fails immediately rather than part-way through a long run.
    runs = [all_data_off[args.start + r] for r in range(repeats)]

    # Create the output directory up front; otherwise the first checkpoint
    # dump fails with FileNotFoundError after 10 repeats of work.
    results_dir = './results'
    os.makedirs(results_dir, exist_ok=True)
    root = results_dir + '/' + now + '_example_' + str(args.example)
    filename = root + '.npy'
    mustar_filename = root + '_MUSTAR.npy'
    obj_filename = root + '_OBJ.npy'

    param_list = _build_param_grid(args.obj, n_off)
    grid_shape = tuple(len(param_list[key]) for key in ('n_on', 'cprime', 'c', 'Delta'))

    # Number of columns taken from each `inner` result (see `out[:, i]` below).
    num_keys = 2
    results = np.zeros((repeats, *grid_shape, num_keys))
    mu_star_results = np.zeros((repeats, *grid_shape, num_keys, k))
    obj_results = np.zeros((repeats, *grid_shape, num_keys)) if args.obj else None

    # Leave one core free, but never ask Pool for 0 processes (ValueError on
    # single-core machines).
    threads = max(1, multiprocessing.cpu_count() - 1)

    # One pool for the whole run instead of one per grid point, which started
    # worker processes tens of thousands of times.
    # NOTE(review): workers now persist across tasks, so any unseeded global
    # RNG state inside inner/evaluate_mu carries over between tasks; confirm
    # the helpers seed as intended.
    with multiprocessing.Pool(threads) as pool:
        for r, run in enumerate(runs):
            dataset_off = run['data']
            a_ref = run['aref']

            for cprime_idx, cprime in enumerate(param_list['cprime']):
                for diff_idx, diff in enumerate(param_list['Delta']):
                    # mcal depends only on (cprime, Delta), not on n_on or c.
                    mcal = clt_compute_mcal(dataset_off, cprime, a_ref, diff, step=mcal_step)

                    for n_on_idx, n_on in enumerate(param_list['n_on']):
                        for c_idx, c in enumerate(param_list['c']):
                            print('r: {repeat}\taref: {a_ref}\tn_on: {n_on}\tcprime: {cprime}\tc: {c}\tdiff: {diff}'.format(
                                repeat=r, a_ref=a_ref, n_on=n_on, cprime=cprime, c=c, diff=diff,))

                            # mu_on varies fastest, matching the
                            # (len(mcal), len(mu_on_list)) reshape below.
                            inputs = [(model, mu_on, dataset_off, a_ref, c, n_on, draws)
                                      for model in mcal for mu_on in mu_on_list]
                            out = np.array(pool.starmap(inner, inputs))
                            for i in range(num_keys):
                                ave_result = out[:, i].reshape((len(mcal), len(mu_on_list)))
                                # Minimum over models for each candidate mu_on;
                                # choose the candidate with the largest minimum.
                                worst_case = np.min(ave_result, axis=0)
                                mu_idx = np.argmax(worst_case)
                                grid_idx = (r, n_on_idx, cprime_idx, c_idx, diff_idx, i)
                                mu_star_results[grid_idx] = mu_on_list[mu_idx]
                                if obj_results is not None:
                                    # Equals np.max(worst_case).
                                    obj_results[grid_idx] = worst_case[mu_idx]

            # Evaluate every chosen mu_star for this repeat. np.ndindex walks
            # the grid in C order (last axis fastest), matching the reshape.
            eval_inputs = [
                (mu_star_results[r][idx], dataset_off, a_ref,
                 param_list['n_on'][idx[0]], param_list['c'][idx[2]], p, draws)
                for idx in np.ndindex(*grid_shape, num_keys)
            ]
            out = pool.starmap(evaluate_mu, eval_inputs)
            results[r] = np.array(out).reshape(grid_shape + (num_keys,))

            if (r + 1) % 10 == 0:  # periodic checkpoint
                _save_results(mu_star_results, mustar_filename, results, filename,
                              obj_results, obj_filename)

    _save_results(mu_star_results, mustar_filename, results, filename,
                  obj_results, obj_filename)
    
                    
# Script entry point. The guard stops multiprocessing worker processes that
# re-import this module (e.g. under the 'spawn' start method) from running
# main() again.
if __name__ == '__main__': 
    main()