from helpers import dicoSolve, dist_funcs
import numpy as np


def Ifonc(d, alpha, mu1, mu2):
    """Return the alpha-weighted sum of d-values from mu1 and mu2 to their mix.

    With ``mid = alpha*mu1 + (1-alpha)*mu2`` the result is
    ``alpha*d(mu1, mid) + (1-alpha)*d(mu2, mid)``.

    Args:
        d: Two-argument callable ``d(a, b)``; presumably a divergence
            between means (e.g. a KL divergence) -- confirm against callers.
        alpha: Weight put on ``mu1``.
        mu1: First value.
        mu2: Second value.

    Returns:
        ``0`` when ``alpha`` equals 0 or 1 (``d`` is not called in that
        case), otherwise the weighted sum described above.
    """
    # Degenerate weights: the mix coincides with one endpoint, skip d entirely.
    if alpha == 0 or alpha == 1:
        return 0
    beta = 1-alpha
    mid = alpha*mu1 + beta*mu2
    return alpha*d(mu1, mid) + beta*d(mu2, mid)


def cost(d, mu1, mu2, nu1, nu2):
    """Return ``(nu1 + nu2) * Ifonc(d, nu1/(nu1 + nu2), mu1, mu2)``.

    Args:
        d: Two-argument callable forwarded to ``Ifonc``.
        mu1: Value associated with the weight ``nu1``.
        mu2: Value associated with the weight ``nu2``.
        nu1: Weight of the first value.
        nu2: Weight of the second value.

    Returns:
        ``0`` when both weights are zero, otherwise the total weight times
        ``Ifonc`` evaluated at the normalised first weight.
    """
    # Nothing allocated at all: avoid the 0/0 in the normalisation below.
    if nu1 == 0 and nu2 == 0:
        return 0
    total = nu1 + nu2
    return total*Ifonc(d, nu1/total, mu1, mu2)


def xkofy(d, y, k, mu, delta=1e-11):
    """Return x_k(y), i.e. find x such that g_k(x) = y.

    Here ``g_k(x) = (1+x) * cost(d, mu[0], mu[k], 1/(1+x), x/(1+x))``.

    Args:
        d: Two-argument callable forwarded to ``cost``.
        y: Target value for ``g_k``.
        k: Index into ``mu`` of the entry compared with ``mu[0]``.
        mu: Indexable collection of values; ``mu[0]`` is the reference entry.
        delta: Tolerance handed to ``dicoSolve`` (same default as the
            previously hard-coded value).

    Returns:
        Whatever ``dicoSolve`` returns for the bracket ``[0, xMax]``;
        presumably an approximate root of ``g_k(x) - y`` -- see ``helpers``
        for the solver's exact contract.

    Note:
        The bracketing loop only stops once ``g_k(xMax) >= y``; if ``y`` is
        never reached by ``g_k`` the loop does not terminate.
    """
    def g(x):
        return (1+x)*cost(d, mu[0], mu[k], 1/(1+x), x/(1+x)) - y

    # Grow the upper end of the bracket geometrically until g changes sign.
    xMax = 1
    while g(xMax) < 0:
        xMax = 2*xMax
    # Honour the caller-supplied tolerance instead of a fixed 1e-11.
    return dicoSolve(g, 0, xMax, delta)


def muddle(mu1, mu2, nu1, nu2):
    """Return the weighted average of ``mu1`` and ``mu2``.

    Args:
        mu1: First value.
        mu2: Second value.
        nu1: Weight of ``mu1``.
        nu2: Weight of ``mu2``.

    Returns:
        ``(nu1*mu1 + nu2*mu2) / (nu1 + nu2)``.
    """
    weighted_sum = nu1*mu1 + nu2*mu2
    total_weight = nu1+nu2
    return weighted_sum/total_weight


def aux(d, y, mu, c):
    """Return F_mu(y) - 1.

    For every index ``k >= 1`` this computes ``x_k = xkofy(d, y, k, mu)``,
    the mixture ``m_k = muddle(mu[0], mu[k], 1, x_k)``, and sums
    ``c[k]*d(mu[0], m_k) / (c[0]*d(mu[k], m_k))`` over ``k``.

    Args:
        d: Two-argument callable used both directly and through ``xkofy``.
        y: Value at which the function is evaluated.
        mu: Indexable collection of values; ``mu[0]`` is the reference entry.
        c: Indexable collection of per-entry coefficients, aligned with
            ``mu`` (presumably sampling costs -- confirm against callers).

    Returns:
        The sum described above minus 1.
    """
    others = range(1, len(mu))
    # Same three phases as before: all x_k first, then mixtures, then terms.
    ratios = [xkofy(d, y, k, mu) for k in others]
    mixtures = [muddle(mu[0], mu[k], 1, r) for k, r in zip(others, ratios)]
    total = 0
    for k, mix in zip(others, mixtures):
        total += c[k]*d(mu[0], mix)/(c[0]*d(mu[k], mix))
    return total - 1


def oneStepOpt(d, mu, c, delta=1e-11):
    """Solve ``aux(d, y, mu, c) = 0`` for y and derive the normalised weights.

    Args:
        d: Two-argument callable forwarded to ``aux`` and ``xkofy``.
        mu: Indexable collection of values with ``mu[0]`` as the reference
            entry; callers in this file pass it sorted in decreasing order.
        c: Per-entry coefficients aligned with ``mu``.
        delta: Tolerance handed to ``dicoSolve`` and ``xkofy``.

    Returns:
        A pair ``(y / s, w)`` where ``w`` is the vector
        ``c * [1, x_1(y), ..., x_{K-1}(y)]`` divided by its sum ``s``.
    """
    def objective(y):
        return aux(d, y, mu, c)

    gap = d(mu[0], mu[1])
    if gap == float('inf'):
        # No finite bound available: double until the objective is >= 0.
        upper = 0.5
        while objective(upper) < 0:
            upper = upper*2
    else:
        upper = gap

    y_star = dicoSolve(objective, 0, upper, delta)

    ratios = [1]
    ratios.extend(xkofy(d, y_star, k, mu, delta) for k in range(1, len(mu)))
    # Multiplying by diag(c) scales each ratio by its coefficient.
    scaled = np.diag(c) @ ratios
    total = sum(scaled)
    return y_star/total, scaled/total


def OptimalWeights(d, mu, c, delta=1e-11):
    """Return T*(mu) and w*(mu).

    Args:
        d: Two-argument callable forwarded to ``oneStepOpt``.
        mu: One-dimensional sequence or array of values, one per arm.
        c: Per-arm coefficients aligned with ``mu``.
        delta: Tolerance forwarded to ``oneStepOpt``.

    Returns:
        A pair ``(T, w)``. When several entries of ``mu`` share the maximum,
        ``T`` is ``inf`` and ``w`` puts equal weight on the maximisers and
        zero elsewhere. Otherwise ``T`` is the reciprocal of the first value
        returned by ``oneStepOpt`` and ``w`` holds the weights it returned,
        indexed like the input ``mu``.
    """
    # Convert up front so comparisons and fancy indexing behave the same
    # for lists and arrays.
    mu = np.asarray(mu)
    c = np.asarray(c)
    K = len(mu)

    best = np.flatnonzero(mu == np.amax(mu))
    if len(best) > 1:
        # multiple optimal arms
        vOpt = np.zeros(K)
        vOpt[best] = 1/len(best)
        return float('inf'), vOpt

    # oneStepOpt expects the largest entry first.
    order = np.argsort(mu)[::-1]
    vOpt, sorted_weights = oneStepOpt(d, mu[order], c[order], delta)

    # Back to the caller's ordering: sorted position i holds the arm whose
    # original index is order[i], so scatter (not gather) through `order`.
    nuOpt = np.zeros(K)
    nuOpt[order] = sorted_weights
    return 1/vOpt, nuOpt
