

import numpy as np 
import pickle 
from helper_functions import *
import argparse 
import multiprocessing

# Problem instances, keyed by the integer passed on the command line as
# ``--example``.  Fields of each entry, as consumed by ``main``:
#   p      -- per-arm vector; its length fixes the number of arms ``k`` and it
#             is handed to ``inner_baseline`` / ``evaluate_mu``
#             (presumably the true mean rewards -- confirm in helper_functions).
#   n_off  -- scalar that the grid of ``n_on`` values is derived from.
#   mu_off -- per-arm weights; evaluated as one of the candidate ``mu_on``
#             vectors (presumably the offline sampling distribution).
#   name   -- label for the instance (not read in this file).
examples = {
    3: dict(
        p=[0.6, 0.3, 0.1],
        n_off=100,
        mu_off=[0.1, 0.9, 0.0],
        name='example3',
    ),
    4: dict(
        p=[0.6, 0.5, 0.2],
        n_off=200,
        mu_off=[0.01, 0.01, 0.98],
        name='example4',
    ),
    5: dict(
        p=[0.6, 0.3, 0.2],
        n_off=200,
        mu_off=[0.01, 0.495, 0.495],
        name='example5',
    ),
}

def parse_arguments(argv=None):
    """Parse the command-line options of the baseline experiment.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``.  ``None`` (the default) keeps the original
            behaviour of reading the process command line; passing a list
            makes the parser usable from tests and other scripts.

    Returns:
        argparse.Namespace with the attributes ``example``, ``draws``,
        ``repeats``, ``mu_on_step`` and ``cprime``.
    """
    parser = argparse.ArgumentParser(description='Parser')

    parser.add_argument('--example', dest='example', type=int, required=True,
                        help='key of the problem instance in `examples`')
    parser.add_argument('--draws', dest='draws', type=int, required=False,
                        default=200,
                        help='draws value forwarded to inner_baseline and '
                             'evaluate_mu (default: 200)')
    parser.add_argument('--repeats', dest='repeats', type=int, required=False,
                        default=100,
                        help='maximum number of stored offline datasets to '
                             'process (default: 100)')
    parser.add_argument('--muon-step', dest='mu_on_step', type=float,
                        default=0.1,
                        help='step forwarded to get_mu_on when building the '
                             'candidate mu_on grid (default: 0.1)')
    parser.add_argument('--cprime', dest='cprime', type=float, default=0.95,
                        help='value forwarded to clt_get_ucb and clt_get_lcb '
                             '(default: 0.95)')

    return parser.parse_args(argv)

def _best_mu_on(pool, model, mu_on_list, dataset_off, a_ref, c, n_on, draws):
    """Return the entry of ``mu_on_list`` with the highest ``inner_baseline`` score.

    Every candidate ``mu_on`` is scored in parallel on ``pool`` by calling
    ``inner_baseline(model, mu_on, dataset_off, a_ref, c, n_on, draws)``; the
    candidate at the argmax of the returned scores is selected.
    """
    inputs = [(model, mu_on, dataset_off, a_ref, c, n_on, draws)
              for mu_on in mu_on_list]
    scores = pool.starmap(inner_baseline, inputs)
    return mu_on_list[np.argmax(np.array(scores))]


def main():
    """Run the baseline comparison for one example and dump the results.

    For each stored offline dataset (at most ``--repeats`` of them) and each
    ``(n_on, c)`` pair of the parameter grid, three ``mu_on`` vectors are
    selected by maximising ``inner_baseline`` under three different reward
    vectors (clipped estimates, the example's ``p``, and an optimistic
    UCB/LCB mix).  These, together with three fixed vectors, are scored with
    ``evaluate_mu`` and stored in an array of shape
    ``(repeats, len(n_on grid), len(c grid), 6)`` that is written to
    ``./data/example_<N>_BASELINE.npy``.

    Raises:
        KeyError: if ``--example`` is not a key of ``examples``.
    """
    args = parse_arguments()

    if args.example not in examples:
        raise KeyError('unknown example {}; available: {}'.format(
            args.example, sorted(examples)))
    problem = examples[args.example]
    p = problem['p']
    k = len(p)
    n_off = problem['n_off']
    mu_off = problem['mu_off']

    draws = args.draws
    mu_on_list = get_mu_on(k, step=args.mu_on_step)
    mu_uniform = np.ones(k) / k

    root = './data/example_' + str(args.example)
    filename = root + '_BASELINE.npy'
    dataset_filename = root + '.pickle'

    param_list = {
        # ceil([0.25, 0.5, 1.0] * n_off)
        'n_on': np.ceil(np.power(0.5, np.flip(np.linspace(0, 2, 3))) * n_off),
        # 1 - 0.5**e with e evenly spaced so that c runs from 0.0 up to 0.95
        'c': 1. - np.power(0.5, (np.linspace(0, np.emath.logn(0.5, .05), 4))),
    }

    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # run this on dataset files produced by a trusted generator.
    with open(dataset_filename, 'rb') as handle:
        all_data_off = pickle.load(handle)
    repeats = min(args.repeats, len(all_data_off))

    # Order of the last axis of `results`; kept next to the allocation so the
    # axis length cannot drift from the list of evaluated vectors below.
    policy_labels = ('mhat', 'mstar', 'ucb', 'a_ucb', 'uniform', 'off')
    num_keys = len(policy_labels)
    threads = 4

    results = np.zeros([repeats,
                        len(param_list['n_on']),
                        len(param_list['c']),
                        num_keys])

    # One worker pool for the whole run: previously a pool was created and
    # torn down three times per (r, n_on, c) cell, paying process start-up
    # cost on every search.
    # NOTE(review): if inner_baseline draws from NumPy's global RNG, workers
    # forked from the same parent start from the same RNG state -- confirm
    # how helper_functions seeds its random draws.
    with multiprocessing.Pool(threads) as pool:
        for r in range(repeats):
            dataset_off = all_data_off[r]['data']
            a_ref = all_data_off[r]['aref']

            mhat = np.clip(estimate_rewards(dataset_off), 0, 1)

            ucb = clt_get_ucb(dataset_off, args.cprime)
            lcb = clt_get_lcb(dataset_off, args.cprime)
            a_ucb = np.argmax(ucb)

            # Optimistic reward vector: lower bounds everywhere except the
            # arm with the largest upper bound, which keeps its upper bound.
            m_ucb = np.copy(lcb)
            m_ucb[a_ucb] = ucb[a_ucb]

            # One-hot vector on the arm with the largest upper bound.
            mu_a_ucb = np.zeros_like(mu_off)
            mu_a_ucb[a_ucb] = 1

            for n_on_idx, n_on in enumerate(param_list['n_on']):
                for c_idx, c in enumerate(param_list['c']):
                    print('example: {num}\tr: {repeat}\t\tn_on: {n_on}\tc: {c}'.format(
                        num=args.example, repeat=r, n_on=n_on, c=c,))

                    search_args = (mu_on_list, dataset_off, a_ref, c, n_on, draws)
                    mu_mhat = _best_mu_on(pool, mhat, *search_args)
                    mu_mstar = _best_mu_on(pool, p, *search_args)
                    mu_ucb = _best_mu_on(pool, m_ucb, *search_args)

                    # Same order as `policy_labels`.
                    policies = (mu_mhat, mu_mstar, mu_ucb,
                                mu_a_ucb, mu_uniform, mu_off)
                    for key_idx, mu_on in enumerate(policies):
                        results[r, n_on_idx, c_idx, key_idx] = evaluate_mu(
                            mu_on, dataset_off, a_ref, n_on, c, p, draws)

    # NOTE(review): ndarray.dump writes a pickle, not the .npy format, despite
    # the file extension; read it back with pickle.load or
    # np.load(filename, allow_pickle=True).
    results.dump(filename)

# Script entry point.  The guard keeps `main()` from running when the module
# is imported, which matters here because `main` starts a multiprocessing
# pool: under the "spawn" start method each worker re-imports this module.
if __name__ == '__main__': 
    main()