import numpy as np 
import scipy.stats as st
import warnings


'''
UCB/LCB (upper/lower confidence bound) calculations
'''

def clt_get_lcb(dataset, c):
    """Return a normal-approximation lower confidence bound for every arm.

    Args:
        dataset: mapping indexed by ``0 .. len(dataset) - 1``; each entry
            holds the rewards observed for that arm.
        c: confidence level handed to ``st.norm.interval``.

    Returns:
        NumPy array with one lower bound per arm, clipped to ``[0, 1]``.
        A bound that evaluates to NaN is replaced by 0.
    """
    num_arms = len(dataset)
    means = estimate_rewards(dataset)
    lower_bounds = []
    with warnings.catch_warnings():
        # Empty or single-sample arms trigger SciPy/NumPy warnings; the
        # resulting NaN is handled explicitly below.
        warnings.simplefilter("ignore")
        for arm in range(num_arms):
            samples = dataset[arm]
            std_err = st.sem(samples)
            if std_err == 0 and len(samples) >= 10:
                # No observed variance with at least 10 samples: use the
                # empirical mean itself as the bound.
                bound = means[arm]
            else:
                bound, _ = st.norm.interval(c, loc=means[arm], scale=std_err)
            lower_bounds.append(0 if np.isnan(bound) else bound)
    return np.clip(np.array(lower_bounds), 0, 1)

def clt_get_ucb(dataset, c):
    """Return a normal-approximation upper confidence bound for every arm.

    Args:
        dataset: mapping indexed by ``0 .. len(dataset) - 1``; each entry
            holds the rewards observed for that arm.
        c: confidence level handed to ``st.norm.interval``.

    Returns:
        NumPy array with one upper bound per arm, clipped to ``[0, 1]``.
        A bound that evaluates to NaN is replaced by 1.
    """
    num_arms = len(dataset)
    means = estimate_rewards(dataset)
    upper_bounds = []
    with warnings.catch_warnings():
        # Empty or single-sample arms trigger SciPy/NumPy warnings; the
        # resulting NaN is handled explicitly below.
        warnings.simplefilter("ignore")
        for arm in range(num_arms):
            samples = dataset[arm]
            std_err = st.sem(samples)
            if std_err == 0 and len(samples) >= 10:
                # No observed variance with at least 10 samples: use the
                # empirical mean itself as the bound.
                bound = means[arm]
            else:
                _, bound = st.norm.interval(c, loc=means[arm], scale=std_err)
            upper_bounds.append(1 if np.isnan(bound) else bound)
    return np.clip(np.array(upper_bounds), 0, 1)

############################################################################################################
'''
offline RL algorithms
'''

def clt_lcb(dataset, c):
    """Pick the arm with the largest lower confidence bound.

    Returns:
        Tuple ``(arm_index, lower_bound_of_that_arm)``.
    """
    bounds = clt_get_lcb(dataset, c)
    best_arm = bounds.argmax()
    return best_arm, bounds[best_arm]

def clt_armor(a_ref, dataset, c):
    """Pick the arm maximising ``lcb[a] - ucb[a_ref]``.

    The reference arm's own objective is fixed at 0, so another arm is
    selected only when its objective value is strictly greater than 0 or
    ties at 0 with a lower index.

    Args:
        a_ref: index of the reference arm.
        dataset: per-arm reward samples.
        c: confidence level for the bounds.

    Returns:
        Tuple ``(arm_index, objective_value_of_that_arm)``.
    """
    lower = clt_get_lcb(dataset, c)
    upper = clt_get_ucb(dataset, c)

    objective = lower - upper[a_ref]
    objective[a_ref] = 0
    chosen = objective.argmax()
    return chosen, objective[chosen]



############################################################################################################
'''
computing version space
'''

def clt_compute_mcal(dataset, c, a_ref, input_diff, step=0.05):
    """Enumerate candidate reward models in which some arm beats ``a_ref``.

    Candidate values for the reference arm are its clipped empirical mean
    plus an evenly spaced grid (spacing at most ``step``) between its lower
    and upper confidence bounds. For every other arm ``a`` and every
    candidate reference value ``v`` with ``v + diff <= ucb[a]``, one model
    row is produced with

    * ``model[a_ref] = v``
    * ``model[a] = max(v + diff, lcb[a])``
    * every remaining arm set to its lower confidence bound.

    ``diff`` is ``input_diff`` capped at the largest value of
    ``ucb[a] - lcb[a_ref]`` over the non-reference arms.

    Args:
        dataset: per-arm reward samples, indexed ``0 .. k - 1``.
        c: confidence level for the bounds.
        a_ref: index of the reference arm.
        input_diff: requested margin by which the competing arm exceeds
            the reference arm.
        step: maximum spacing of the reference-arm grid.

    Returns:
        Array of shape ``(num_models, k)``. When no candidate satisfies the
        constraints the result has shape ``(0, k)``.
    """
    k = len(dataset)
    lcb_list = clt_get_lcb(dataset, c)
    ucb_list = clt_get_ucb(dataset, c)
    rhat = np.clip(estimate_rewards(dataset), 0, 1)

    mask = np.ones(k, dtype=bool)
    mask[a_ref] = False
    gaps = ucb_list - lcb_list[a_ref]
    max_diff = np.max(gaps[mask])
    diff = min(max_diff, input_diff)

    num = np.ceil((ucb_list[a_ref] - lcb_list[a_ref]) / step).astype(int) + 1
    grid = np.linspace(lcb_list[a_ref], ucb_list[a_ref], num=num, endpoint=True)
    ref_list = np.hstack([rhat[a_ref], grid])

    mcal_list = []
    for a in range(k):
        if a == a_ref:
            continue
        ref_list_a = ref_list[ref_list + diff <= ucb_list[a]]
        if len(ref_list_a) == 0:
            continue
        models = np.zeros([len(ref_list_a), k])
        # Start every arm at its lower bound, then overwrite the reference
        # arm and the competing arm. This works for any number of arms
        # (previously exactly one "other" arm was assumed).
        models[:, :] = lcb_list
        models[:, a_ref] = ref_list_a
        models[:, a] = np.maximum(ref_list_a + diff, lcb_list[a])
        mcal_list.append(models)

    if not mcal_list:
        # np.vstack rejects an empty sequence; return an empty model set.
        return np.zeros([0, k])
    return np.vstack(mcal_list)


############################################################################################################
'''
other supporting functions
'''

def get_mu_on(k, step=0.1):
    """Enumerate allocation vectors on a grid over the probability simplex.

    Args:
        k: number of arms. ``k == 2`` yields two-component vectors; any
            other value yields three-component vectors.
        step: grid spacing for each component.

    Returns:
        2-D array whose rows each sum to 1, rounded to 3 decimals.
    """
    # Round rather than truncate so float error in 1 / step cannot drop the
    # last grid point, and keep 3 decimals (matching the final rounding) so
    # steps finer than 0.1 do not collapse onto each other.
    n_steps = int(np.round(1 / step))
    grid = np.round(np.linspace(0, 1, num=n_steps + 1), 3)
    if k == 2:
        mu_on_list = [[x, 1 - x] for x in grid]
    else:
        mu_on_list = [
            [x, y, 1 - x - y]
            for x in grid
            for y in np.linspace(0, 1 - x, num=int(np.round((1 - x) / step)) + 1)
        ]
    return np.round(np.array(mu_on_list), 3)


def combine_datasets(dataset_off, dataset_on):
    """Concatenate, arm by arm, the offline samples followed by the online ones.

    Arms are indexed ``0 .. len(dataset_off) - 1``; each entry of the
    returned dict is a NumPy array.
    """
    return {
        arm: np.array([*dataset_off[arm], *dataset_on[arm]])
        for arm in range(len(dataset_off))
    }

def draw_dataset(p, n, mu):
    """Sample a dataset of Bernoulli rewards.

    Args:
        p: per-arm success probabilities.
        n: total number of pulls to distribute.
        mu: per-arm pull probabilities (multinomial parameter).

    Returns:
        Dict mapping arm index to an array of 0/1 rewards, one per pull of
        that arm.
    """
    pull_counts = np.random.multinomial(n, mu)
    return {
        arm: np.random.binomial(1, p[arm], size=count)
        for arm, count in enumerate(pull_counts)
    }

def estimate_rewards(dataset):
    """Return the empirical mean reward of each arm.

    Arms without any samples get the sentinel value -1.
    """
    means = []
    for arm in range(len(dataset)):
        samples = dataset[arm]
        if len(samples) > 0:
            means.append(np.mean(samples))
        else:
            means.append(-1)
    return np.array(means)

def stochastic_equal_pulls(dataset_off, a_ref, n_on, c):
    """Allocate ``n_on`` pulls so candidate arms end up with even sample counts.

    Candidate arms are those whose upper confidence bound is at least the
    reference arm's lower confidence bound. Pulls are handed out one at a
    time to the candidate with the fewest samples so far (offline samples
    plus pulls already assigned).

    Returns:
        Array over all arms holding each arm's share of the ``n_on`` pulls;
        non-candidate arms get 0.
    """
    lower = clt_get_lcb(dataset_off, c)
    upper = clt_get_ucb(dataset_off, c)
    candidates = np.where(upper >= lower[a_ref])[0]

    offline_counts = [len(dataset_off[arm]) for arm in candidates]
    assigned = np.zeros_like(candidates)
    remaining = n_on
    while remaining > 0:
        neediest = np.argmin(offline_counts + assigned)
        assigned[neediest] += 1
        remaining -= 1

    mu = np.zeros(len(dataset_off))
    mu[candidates] = assigned / n_on
    return mu

############################################################################################################

'''
functions for parallelization 
'''

def inner(model, mu_on, dataset_off, a_ref, c, n_on, draws):
    """Monte-Carlo estimate of the armor objective and true gain under ``model``.

    Each draw samples an online dataset from ``model`` with allocation
    ``mu_on``, merges it with ``dataset_off`` and runs ``clt_armor``.

    Returns:
        Length-2 array: ``[mean armor objective, mean of
        model[chosen arm] - model[a_ref]]`` over ``draws`` repetitions.
    """
    averages = np.zeros([2])
    for _ in range(draws):
        merged = combine_datasets(dataset_off, draw_dataset(model, n_on, mu_on))
        chosen, objective = clt_armor(a_ref, merged, c)
        gain = model[chosen] - model[a_ref]
        averages[0] += objective / draws
        averages[1] += gain / draws
    return averages

def inner_baseline(model, mu_on, dataset_off, a_ref, c, n_on, draws):
    """Monte-Carlo estimate of the true gain over ``a_ref`` under ``model``.

    Each draw samples an online dataset from ``model`` with allocation
    ``mu_on``, merges it with ``dataset_off`` and runs ``clt_armor``.

    Returns:
        Mean of ``model[chosen arm] - model[a_ref]`` over ``draws``
        repetitions.
    """
    average_gain = 0
    for _ in range(draws):
        merged = combine_datasets(dataset_off, draw_dataset(model, n_on, mu_on))
        chosen, _ = clt_armor(a_ref, merged, c)
        gain = model[chosen] - model[a_ref]
        average_gain += gain / draws
    return average_gain

def evaluate_mu(mu, dataset_off, a_ref, n_on, c, p, draws):
    """Monte-Carlo estimate of the gain over ``a_ref`` achieved by allocation ``mu``.

    Each draw samples an online dataset from reward probabilities ``p``
    with allocation ``mu``, merges it with ``dataset_off`` and runs
    ``clt_armor``.

    Returns:
        Mean of ``p[chosen arm] - p[a_ref]`` over ``draws`` repetitions.
    """
    average_gain = 0
    for _ in range(draws):
        merged = combine_datasets(dataset_off, draw_dataset(p, n_on, mu))
        chosen, _ = clt_armor(a_ref, merged, c)
        gain = p[chosen] - p[a_ref]
        average_gain += gain / draws
    return average_gain
    
