import numpy as np
import random
import KLFunctions as kl


def dicoSolve(f, xMin, xMax, pre=1e-11):
    """Approximate a root of ``f`` on ``[xMin, xMax]`` by bisection.

    The bracket is halved until its width is at most ``pre``.  The left
    end always holds points whose ``f`` value has the same (strict) sign
    as ``f(xMin)``; any other midpoint replaces the right end.

    Args:
        f: Function of one real variable, expected to change sign
            between ``xMin`` and ``xMax``.
        xMin: Left end of the search interval.
        xMax: Right end of the search interval.
        pre: Stopping width of the bracket.

    Returns:
        The midpoint of the final bracket.
    """
    left, right = xMin, xMax
    anchor = f(xMin)  # sign reference taken once, at the left end
    while right - left > pre:
        mid = (right + left) / 2
        # Same strict sign as the left end -> root lies to the right of mid.
        left, right = (mid, right) if f(mid) * anchor > 0 else (left, mid)
    return (right + left) / 2


def dist_funcs(rewards_distr):
    """Select the KL-divergence helpers for a reward distribution.

    Args:
        rewards_distr: Distribution name.  Only ``"Exponential"`` is
            special-cased; every other value (including ``"Poisson"`` and
            ``"Gaussian"``) falls back to the Bernoulli helpers.

    Returns:
        Tuple ``(d, dup, dlow)`` of functions taken from ``kl``.
    """
    # NOTE(review): Poisson/Gaussian variants exist only as disabled code in
    # earlier revisions; they silently map to Bernoulli here.
    if rewards_distr != "Exponential":
        return kl.dBernoulli, kl.dupBernoulli, kl.dlowBernoulli
    return kl.dExpo, kl.dupExpo, kl.dlowExpo


def sample(rewards_distr, mu, N, sigma=1):
    """Draw ``N`` independent rewards from the named distribution.

    Args:
        rewards_distr: One of ``"Bernoulli"``, ``"Poisson"``,
            ``"Exponential"`` or ``"Gaussian"``.  Any other value yields
            ``N`` copies of ``mu``.
        mu: Distribution parameter.  It is the success probability for
            Bernoulli, ``lam`` for Poisson, the mean (scale) for Exponential
            and the location for Gaussian.
        N: Number of samples to draw.
        sigma: Standard deviation, used only by ``"Gaussian"``.

    Returns:
        A numpy array of ``N`` samples.
    """
    if rewards_distr == "Bernoulli":
        return np.random.binomial(size=N, n=1, p=mu)
    elif rewards_distr == "Poisson":
        return np.random.poisson(lam=mu, size=N)
    elif rewards_distr == "Exponential":
        # Exponential with mean mu (scale = 1/rate = mu), i.e. the
        # distribution of -mu*log(U) for U ~ Uniform(0, 1); N draws, like
        # the other branches.
        return np.random.exponential(scale=mu, size=N)
    elif rewards_distr == "Gaussian":
        return np.random.normal(loc=mu, scale=sigma, size=N)
    else:
        return mu * np.ones(N)


def paths(distr, means, num_paths, num_samples=500000, std=1):
    """Pre-draw reward sample paths for every arm.

    Args:
        distr: Distribution name passed to ``sample``.
        means: Sequence of per-arm means.
        num_paths: Number of independent paths per arm.
        num_samples: Length of each path.
        std: Standard deviation forwarded to ``sample`` as ``sigma``.
            ``sample`` uses it only for ``"Gaussian"``.

    Returns:
        A list indexed as ``result[arm][path]``.  Each entry is the array
        returned by ``sample(distr, means[arm], num_samples, sigma=std)``.
    """
    # Arm-major then path order, so the RNG stream is consumed exactly as
    # in the original nested loops.
    return [
        [sample(distr, mu, num_samples, sigma=std) for _ in range(num_paths)]
        for mu in means
    ]
