import numpy as np
from opt_cost import OptimalWeights, d
from algo_base import Algorithms, Arm


class RegretArm(Arm):
    """An ``Arm`` that calls ``pull()`` once as part of its construction."""

    def __init__(self, name, rewards):
        """Initialise the base arm and immediately pull it once.

        Args:
            name: Forwarded unchanged to ``Arm.__init__``.
            rewards: Forwarded unchanged to ``Arm.__init__``.
        """
        super(RegretArm, self).__init__(name, rewards)
        # Presumably seeds the arm with a first observation so that its
        # statistics are defined from the start -- confirm against Arm.pull.
        self.pull()


class TrackAndStopRegret(Algorithms):
    """Sampling/stopping strategy that tracks cost-weighted optimal weights.

    The arm to play is chosen from the weights returned by
    ``OptimalWeights`` scaled by the per-arm costs; the run stops once a
    pairwise statistic built from ``d`` exceeds a threshold that depends on
    ``self.time`` and ``self.delta``.
    """

    def __init__(self, arms, delta, costs):
        """Store the per-arm costs and initialise the running cost total.

        Args:
            arms: Arms, forwarded to the base class.
            delta: Forwarded to the base class; read back as ``self.delta``
                in the stopping threshold.
            costs: Per-arm costs, indexed in the same order as ``arms``.
        """
        super().__init__(arms, delta)
        self.costs = costs
        # Running cost total. It starts at the sum of all costs (presumably
        # because every arm has already been pulled once -- confirm).
        self.cost = np.sum(costs)

    def stop(self):
        """Return True when the stopping statistic exceeds the threshold.

        The threshold is ``log((log(time) + 1) / delta)``. The statistic is
        the minimum, over every arm ``i`` other than the empirically best
        arm ``b``, of ``N_b * d(mu_b, m_i) + N_i * d(mu_i, m_i)`` where
        ``m_i`` is the pull-weighted average of the two empirical means.

        Raises:
            ValueError: If there is only one arm (no challenger to compare).
        """
        threshold = np.log((np.log(self.time) + 1) / self.delta)

        pulls = np.array([arm.num_pulls for arm in self.arms])
        means = np.array([arm.empirical_mean for arm in self.arms])
        best = int(np.argmax(means))
        best_mean = self.arms[best].empirical_mean
        best_pulls = self.arms[best].num_pulls

        # Per-arm reward sums; element-wise product instead of a K x K
        # diagonal matrix multiplication.
        sums = pulls * means
        # Pull-weighted mean of the best arm and each other arm.
        pooled = (sums[best] + sums) / (best_pulls + pulls)

        score = min(
            best_pulls * d(best_mean, pooled[i]) + pulls[i] * d(means[i], pooled[i])
            for i in range(len(self.arms))
            if i != best
        )
        return bool(score > threshold)

    def next_arm(self):
        """Choose the arm to play next.

        Selection order:
            1. If several arms share the highest empirical mean, return the
               one with the smallest ``cost`` attribute (first one wins on
               equal cost).
            2. If the least-pulled arm has at most ``3 * sqrt(log(time))``
               pulls, return it (forced exploration).
            3. Otherwise return the arm maximising
               ``w_i / c_i - N_i / self.cost`` and add its cost to
               ``self.cost``.

        Returns:
            The selected element of ``self.arms``.
        """
        means = [arm.empirical_mean for arm in self.arms]
        num_pulls = np.array([arm.num_pulls for arm in self.arms])

        # flatnonzero gives plain 1-D integer positions. The (L, 1) rows
        # produced by argwhere cannot be used to index a Python list.
        mean_values = np.asarray(means)
        tied = np.flatnonzero(mean_values == mean_values.max())
        if tied.size > 1:
            # NOTE(review): relies on every arm exposing ``.cost``; that
            # attribute is not visible here -- confirm it matches self.costs.
            cheapest = min(tied, key=lambda i: self.arms[i].cost)
            return self.arms[cheapest]

        if num_pulls.min() <= 3 * np.sqrt(np.log(self.time)):
            return self.arms[np.argmin(num_pulls)]

        # Only the weight vector is needed; the first return value is unused.
        _, w_vec = OptimalWeights(means, self.costs)
        # Dividing by the costs is the same as multiplying by the inverse of
        # diag(costs), without building and inverting a K x K matrix.
        # Assumes w_vec is 1-D with one entry per arm.
        costs = np.asarray(self.costs, dtype=float)
        tracking_gap = np.asarray(w_vec) / costs - num_pulls / self.cost
        index = int(np.argmax(tracking_gap))
        # NOTE(review): self.cost only grows on this branch; arms returned by
        # the tie-break or forced-exploration branches are not added to it.
        # Confirm that this is intended.
        self.cost += self.costs[index]
        return self.arms[index]


class AlphaUCB(Algorithms):
    """Upper-confidence-bound strategy with a scaled exploration bonus.

    Arms are ranked by ``empirical_mean + alpha * bonus``. The run stops
    when the lower bound of the current best arm is strictly above the
    upper bound of every other arm, or when the time cap is exceeded.
    """

    def __init__(self, arms, delta, costs, alpha=1, max_time=10000):
        """Configure the strategy.

        Args:
            arms: Arms, forwarded to the base class.
            delta: Forwarded to the base class; read back as ``self.delta``
                inside the bonus.
            costs: Accepted for signature compatibility; not used here.
            alpha: Multiplier applied to the bonus when ranking arms.
            max_time: ``stop()`` returns True as soon as ``self.time``
                exceeds this value. Defaults to the previously hard-coded
                10000.
        """
        # Attributes are set before the base initialiser runs, matching the
        # original ordering for ``alpha``.
        self.alpha = alpha
        self.max_time = max_time
        super().__init__(arms, delta)

    def stop(self):
        """Return True when the best arm is separated or time has run out."""
        if self.time > self.max_time:
            return True
        best = self.best_idx
        lower = best.empirical_mean - self.bonus(best)
        for arm in self.arms:
            if arm == best:
                continue
            # Any competitor whose upper bound reaches the best arm's lower
            # bound means the arms are not yet separated.
            if arm.empirical_mean + self.bonus(arm) >= lower:
                return False
        return True

    def bonus(self, arm):
        """Return the confidence radius of ``arm``.

        The radius is ``sqrt(log(3 * log(N)^2 / delta) / N)`` with
        ``N = arm.num_pulls``. For ``N <= 1`` the expression is undefined:
        ``log(1) == 0`` makes the outer logarithm ``-inf`` and the square
        root NaN. An infinite radius is returned instead, so such an arm
        is always the first to be explored and never lets ``stop()``
        succeed through NaN comparisons.
        """
        n_pulls = arm.num_pulls
        if n_pulls <= 1:
            return np.inf
        return np.sqrt(np.log(3 * np.log(n_pulls) ** 2 / self.delta) / n_pulls)

    def next_arm(self):
        """Return the arm with the largest ``mean + alpha * bonus`` index."""
        scores = [
            arm.empirical_mean + self.alpha * self.bonus(arm)
            for arm in self.arms
        ]
        return self.arms[np.argmax(scores)]
