import numpy as np
from opt_cost import OptimalWeights
from algo_base import Algorithms
from random import randrange


class TrackAndStop(Algorithms):
    """Track-and-Stop sampling in which every arm is charged a unit cost.

    Arms are sampled to track the allocation ``w_vec`` returned by
    ``OptimalWeights`` (called with all costs equal to 1). Rarely pulled arms
    are force-sampled unless ``mudist == "Deterministic"``.
    """

    def __init__(self, mudist, arms, delta):
        # Deterministic rewards skip the forced-sampling rule in next_arm.
        self.det = mudist == "Deterministic"
        super().__init__(mudist, arms, delta)

    def stop(self):
        """Return True when the pairwise statistic clears the threshold.

        Let ``b`` be the arm with the highest empirical mean. For every other
        arm ``i``, with ``m_i`` the pull-weighted mean of arms ``b`` and ``i``,
        the statistic is ``N_b * d(mu_b, m_i) + N_i * d(mu_i, m_i)``.
        ``self.d`` is presumably the divergence of the reward family, defined
        in the base class (TODO confirm). Sampling stops once the minimum
        over ``i`` exceeds ``log((log(t) + 1) / delta)``.

        With fewer than two arms there is no challenger, so this returns True
        immediately; the original raised ``ValueError`` on an empty ``min``.
        """
        if len(self.arms) < 2:
            return True
        threshold = np.log((np.log(self.time) + 1) / self.delta)
        pulls = np.array([arm.num_pulls for arm in self.arms])
        means = np.array([arm.empirical_mean for arm in self.arms])
        best = np.argmax(means)

        mu_b = self.arms[best].empirical_mean
        n_b = self.arms[best].num_pulls
        # Elementwise N * mu; equivalent to the former np.diag(N) @ mu but O(K).
        weighted = pulls * means
        mid = (weighted[best] + weighted) / (n_b + pulls)
        score = min(n_b * self.d(mu_b, mid[i]) + pulls[i] * self.d(means[i], mid[i])
                    for i in range(len(self.arms)) if i != best)
        return bool(score > threshold)

    def next_arm(self):
        """Return a one-element list holding the arm to pull next.

        The rules are applied in order:

        1. If several arms tie for the highest empirical mean, one of them is
           chosen uniformly at random.
        2. Unless rewards are deterministic, the least-pulled arm is forced
           when its pull count is at most ``2 * sqrt(t)``.
        3. Otherwise, pull the arm maximising ``t * w_vec - num_pulls``.
        """
        mu = [arm.empirical_mean for arm in self.arms]
        num_pulls = np.array([arm.num_pulls for arm in self.arms])

        tied = np.flatnonzero(np.asarray(mu) == np.amax(mu))
        if len(tied) > 1:
            return [self.arms[tied[randrange(len(tied))]]]

        if not self.det and num_pulls.min() <= 2 * np.sqrt(self.time):
            return [self.arms[np.argmin(num_pulls)]]

        # The optimal weights are only needed for the tracking rule, so they
        # are computed after the shortcuts above. This assumes OptimalWeights
        # has no side effects (TODO confirm).
        _, w_vec = OptimalWeights(self.d, mu, np.ones(len(self.arms)))
        return [self.arms[np.argmax(self.time * w_vec - num_pulls)]]


class CTrackAndStop(Algorithms):
    """Cost-aware Track-and-Stop.

    Arms are sampled so that the cost spent on each one tracks its share
    ``w_vec`` (from ``OptimalWeights`` with empirical costs) of ``self.cost``.
    ``self.cost`` is presumably the total cost spent so far, maintained by the
    base class (TODO confirm).
    """

    def __init__(self, mudist, arms, delta):
        # Deterministic rewards skip the forced-sampling rule in next_arm.
        self.det = mudist == "Deterministic"
        super().__init__(mudist, arms, delta)

    def stop(self):
        """Return True when the pairwise statistic clears the threshold.

        Let ``b`` be the arm with the highest empirical mean. For every other
        arm ``i``, with ``m_i`` the pull-weighted mean of arms ``b`` and ``i``,
        the statistic is ``N_b * d(mu_b, m_i) + N_i * d(mu_i, m_i)``.
        Sampling stops once the minimum over ``i`` exceeds
        ``log((log(t) + 1) / delta)``.

        With fewer than two arms there is no challenger, so this returns True
        immediately; the original raised ``ValueError`` on an empty ``min``.
        """
        if len(self.arms) < 2:
            return True
        threshold = np.log((np.log(self.time) + 1) / self.delta)
        pulls = np.array([arm.num_pulls for arm in self.arms])
        means = np.array([arm.empirical_mean for arm in self.arms])
        best = np.argmax(means)

        mu_b = self.arms[best].empirical_mean
        n_b = self.arms[best].num_pulls
        # Elementwise N * mu; equivalent to the former np.diag(N) @ mu but O(K).
        weighted = pulls * means
        mid = (weighted[best] + weighted) / (n_b + pulls)
        score = min(n_b * self.d(mu_b, mid[i]) + pulls[i] * self.d(means[i], mid[i])
                    for i in range(len(self.arms)) if i != best)
        return bool(score > threshold)

    def next_arm(self):
        """Return a one-element list holding the arm to pull next.

        The rules are applied in order:

        1. If several arms tie for the highest empirical mean, pick the one
           with the lowest empirical cost (the first such arm on a cost tie).
        2. Unless rewards are deterministic, the least-pulled arm is forced
           when its pull count is at most ``4 * sqrt(t)``.
        3. Otherwise, pull the arm maximising
           ``self.cost * w_vec - costs * num_pulls``.
        """
        mu = [arm.empirical_mean for arm in self.arms]
        costs = [arm.empirical_cost for arm in self.arms]
        num_pulls = np.array([arm.num_pulls for arm in self.arms])

        tied = np.flatnonzero(np.asarray(mu) == np.amax(mu))
        if len(tied) > 1:
            # min() keeps the first minimal element, matching the strict '>'
            # comparison of the original loop.
            return [self.arms[min(tied, key=lambda i: costs[i])]]

        if not self.det and num_pulls.min() <= 4 * np.sqrt(self.time):
            return [self.arms[np.argmin(num_pulls)]]

        # Computed only when the tracking rule needs it. This assumes
        # OptimalWeights has no side effects (TODO confirm).
        _, w_vec = OptimalWeights(self.d, mu, costs)
        spent = np.asarray(costs) * num_pulls  # elementwise, was np.diag(costs) @ num_pulls
        return [self.arms[np.argmax(self.cost * w_vec - spent)]]


class KellTrackAndStop(Algorithms):
    """Cost-aware Track-and-Stop driven by optimistic (upper-bound) means.

    The allocation is computed from ``self.dup(mean, beta() / N)`` instead of
    the raw empirical means.
    """

    def __init__(self, mudist, arms, delta):
        super().__init__(mudist, arms, delta)
        # Not read in this class; presumably used elsewhere (TODO confirm).
        self.max = 100000

    def beta(self, K=1):
        """Return the exploration rate ``log(K * log(t) + 1 / delta)``."""
        # NOTE(review): this differs from the stopping threshold
        # log((log(t) + 1) / delta) by operator precedence; confirm intended.
        return np.log(K*np.log(self.time) + 1 / self.delta)

    def stop(self):
        """Return True when the pairwise statistic clears the threshold.

        Let ``b`` be the arm with the highest empirical mean. For every other
        arm ``i``, with ``m_i`` the pull-weighted mean of arms ``b`` and ``i``,
        the statistic is ``N_b * d(mu_b, m_i) + N_i * d(mu_i, m_i)``.
        Sampling stops once the minimum over ``i`` exceeds
        ``log((log(t) + 1) / delta)``.

        With fewer than two arms there is no challenger, so this returns True
        immediately; the original raised ``ValueError`` on an empty ``min``.
        """
        if len(self.arms) < 2:
            return True
        threshold = np.log((np.log(self.time) + 1) / self.delta)
        pulls = np.array([arm.num_pulls for arm in self.arms])
        means = np.array([arm.empirical_mean for arm in self.arms])
        best = np.argmax(means)

        mu_b = self.arms[best].empirical_mean
        n_b = self.arms[best].num_pulls
        # Elementwise N * mu; equivalent to the former np.diag(N) @ mu but O(K).
        weighted = pulls * means
        mid = (weighted[best] + weighted) / (n_b + pulls)
        score = min(n_b * self.d(mu_b, mid[i]) + pulls[i] * self.d(means[i], mid[i])
                    for i in range(len(self.arms)) if i != best)
        return bool(score > threshold)

    def next_arm(self):
        """Return a one-element list holding the arm to pull next.

        If several arms tie for the highest optimistic mean, the cheapest
        (first on a cost tie) is returned. Otherwise the arm maximising
        ``self.cost * w_vec - costs * num_pulls`` is returned.
        """
        bonus = self.beta()  # depends only on self.time, so compute once
        mu = [self.dup(arm.empirical_mean, bonus / arm.num_pulls) for arm in self.arms]
        costs = [arm.empirical_cost for arm in self.arms]
        num_pulls = np.array([arm.num_pulls for arm in self.arms])

        tied = np.flatnonzero(np.asarray(mu) == np.amax(mu))
        if len(tied) > 1:
            # Wrapped in a list like every other return path; the original
            # returned the bare arm here.
            return [self.arms[min(tied, key=lambda i: costs[i])]]

        # Computed only when needed. This assumes OptimalWeights has no side
        # effects (TODO confirm).
        _, w_vec = OptimalWeights(self.d, mu, costs)
        spent = np.asarray(costs) * num_pulls  # elementwise, was np.diag(costs) @ num_pulls
        return [self.arms[np.argmax(self.cost * w_vec - spent)]]


class ChernoffOverlap(Algorithms):
    """Successive elimination using a pairwise statistic, one arm per round.

    An arm is eliminated when some remaining arm with a higher empirical mean
    beats it by more than ``beta(K)``. Among survivors, the arm with the
    highest ``ratio`` is sampled.
    """

    def __init__(self, mudist, arms, delta):
        super().__init__(mudist, arms, delta)
        # Indices into self.arms of the arms not yet eliminated.
        self.remaining_arms = list(range(len(arms)))

    def beta(self, K=1):
        """Return the elimination threshold ``log(K * log(t) + 1 / delta)``."""
        # return np.log((2*self.time*(K - 1) + 1) / self.delta)
        return np.log(K*np.log(self.time) + 1 / self.delta)

    def eliminate(self, arm, K=1):
        """Return True if ``arm`` should be eliminated.

        Only remaining arms with a strictly larger empirical mean are
        challengers. ``arm`` is eliminated when the largest statistic
        ``N_c * d(mu_c, m) + N_a * d(mu_a, m)`` against any challenger ``c``
        exceeds ``beta(K)``. Here ``m`` is the pull-weighted mean of the two
        arms.
        """
        threshold = self.beta(K)

        def statistic(left, right):
            # `right` is the challenger, `left` the arm under test.
            total = right.num_pulls + left.num_pulls
            mid = (right.num_pulls*right.empirical_mean
                   + left.num_pulls*left.empirical_mean) / total
            return (right.num_pulls*self.d(right.empirical_mean, mid)
                    + left.num_pulls*self.d(left.empirical_mean, mid))

        challengers = [self.arms[i] for i in self.remaining_arms
                       if self.arms[i].empirical_mean > arm.empirical_mean]
        if not challengers:
            return False
        score = max(statistic(arm, other) for other in challengers)
        return bool(score > threshold)

    def stop(self):
        """Return True once at most one arm remains."""
        return len(self.remaining_arms) <= 1

    def debug(self):
        """Print pull counts, means and confidence intervals of every arm."""
        # NOTE(review): lower/upper are not defined on this class; presumably
        # provided by Algorithms, otherwise this raises AttributeError.
        for arm in self.arms:
            print("Arm {}:".format(arm.name))
            print("Num pulls: {}".format(arm.num_pulls))
            print("Mean: {}".format(arm.empirical_mean))
            print("Confidence interval: ({}, {})".format(self.lower(arm), self.upper(arm)))

    def ratio(self, arm):
        """Return the sampling priority ``(1 / N) / sqrt(empirical_cost)``."""
        diff = 1/arm.num_pulls
        return diff/np.sqrt(arm.empirical_cost)

    def next_arm(self):
        """Eliminate dominated arms, then return ``[survivor with max ratio]``."""
        K = len(self.remaining_arms)
        # Iterate over a snapshot. Removing from the list being iterated
        # would skip the arm that follows each eliminated one.
        for i in list(self.remaining_arms):
            if self.eliminate(self.arms[i], K=K):
                self.remaining_arms.remove(i)
        ratios = [self.ratio(self.arms[i]) for i in self.remaining_arms]
        return [self.arms[self.remaining_arms[np.argmax(ratios)]]]


class ChernoffModified(ChernoffOverlap):
    """ChernoffOverlap variant that up-weights the current best arm."""

    def __init__(self, mudist, arms, delta):
        super().__init__(mudist, arms, delta)

    def ratio(self, arm):
        """Return the sampling priority of ``arm``; higher is pulled first.

        The arm at ``self.best_idx`` has its inverse pull count multiplied by
        the number of remaining arms; any other arm uses the plain inverse
        pull count. Either value is divided by ``sqrt(empirical_cost)``.
        """
        scale = len(self.remaining_arms) if arm == self.arms[self.best_idx] else 1
        inverse_pulls = scale / arm.num_pulls
        return inverse_pulls / np.sqrt(arm.empirical_cost)


class ChernoffRacing(ChernoffOverlap):
    """Racing variant of ChernoffOverlap: every surviving arm is pulled each round."""

    def __init__(self, mudist, arms, delta):
        super().__init__(mudist, arms, delta)

    def next_arm(self):
        """Eliminate dominated arms (with ``K=1``) and return all survivors."""
        # Iterate over a snapshot. Removing from the list being iterated
        # would skip the arm that follows each eliminated one.
        for i in list(self.remaining_arms):
            if self.eliminate(self.arms[i]):
                self.remaining_arms.remove(i)
        return [self.arms[i] for i in self.remaining_arms]


class KLOverlap(Algorithms):
    """Confidence-interval elimination with cost-aware single-arm sampling.

    An arm is dropped when its upper bound falls below the largest lower
    bound among the remaining arms.
    """

    def __init__(self, mudist, arms, delta):
        super().__init__(mudist, arms, delta)
        # Indices into self.arms of the arms not yet eliminated.
        self.remaining_arms = list(range(len(arms)))

    def stop(self):
        """Return True when the remaining arms are separated.

        If only one arm remains, elimination has already identified it, so
        this returns True. The original raised ``ValueError`` there. Otherwise
        it applies the pairwise statistic
        ``N_b * d(mu_b, m_i) + N_i * d(mu_i, m_i)`` over the remaining arms
        and compares its minimum with ``beta()``.
        """
        arms = [self.arms[i] for i in self.remaining_arms]
        if len(arms) < 2:
            return True
        threshold = self.beta()
        pulls = np.array([arm.num_pulls for arm in arms])
        means = np.array([arm.empirical_mean for arm in arms])
        best = np.argmax(means)

        mu_b = arms[best].empirical_mean
        n_b = arms[best].num_pulls
        # Elementwise N * mu; equivalent to the former np.diag(N) @ mu but O(K).
        weighted = pulls * means
        mid = (weighted[best] + weighted) / (n_b + pulls)
        score = min(n_b * self.d(mu_b, mid[i]) + pulls[i] * self.d(means[i], mid[i])
                    for i in range(len(arms)) if i != best)
        return bool(score > threshold)

    def debug(self):
        """Print pull counts, means and confidence intervals of every arm."""
        for arm in self.arms:
            print("Arm {}:".format(arm.name))
            print("Num pulls: {}".format(arm.num_pulls))
            print("Mean: {}".format(arm.empirical_mean))
            print("Confidence interval: ({}, {})".format(self.lower(arm), self.upper(arm)))

    def beta(self, K=1):
        """Return the confidence level ``log(K * log(t) + 1 / delta)``."""
        return np.log(K*np.log(self.time) + 1 / self.delta)

    def upper(self, arm):
        """Return the upper confidence bound ``dup(mean, beta() / N)``."""
        return self.dup(arm.empirical_mean, self.beta()/arm.num_pulls)

    def lower(self, arm):
        """Return the lower confidence bound ``dlow(mean, beta() / N)``."""
        return self.dlow(arm.empirical_mean, self.beta()/arm.num_pulls)

    def ratio(self, arm):
        """Return the sampling priority ``(1 / N) / sqrt(empirical_cost)``."""
        diff = 1/arm.num_pulls
        return diff/np.sqrt(arm.empirical_cost)

    def next_arm(self):
        """Eliminate dominated arms, then return ``[survivor with max ratio]``."""
        max_lower = max(self.lower(self.arms[i]) for i in self.remaining_arms)
        # Iterate over a snapshot. Removing from the list being iterated
        # would skip the arm that follows each eliminated one.
        for i in list(self.remaining_arms):
            if self.upper(self.arms[i]) < max_lower:
                self.remaining_arms.remove(i)
        ratios = [self.ratio(self.arms[i]) for i in self.remaining_arms]
        return [self.arms[self.remaining_arms[np.argmax(ratios)]]]


class KLULCB(Algorithms):
    """Pull the current best arm together with its most optimistic challenger."""

    def __init__(self, mudist, arms, delta):
        super().__init__(mudist, arms, delta)

    def stop(self):
        """Return True when the pairwise statistic clears the threshold.

        The arm at ``self.best_idx`` is compared with every other arm ``i``
        via ``N_b * d(mu_b, m_i) + N_i * d(mu_i, m_i)``, where ``m_i`` is
        their pull-weighted mean. Sampling stops once the minimum exceeds
        ``log((log(t + 1) + 1) / delta)``. With fewer than two arms this
        returns True; the original raised ``ValueError``.
        """
        if len(self.arms) < 2:
            return True
        threshold = np.log((np.log(self.time + 1) + 1) / self.delta)
        best = self.best_idx
        pulls = np.array([arm.num_pulls for arm in self.arms])
        means = np.array([arm.empirical_mean for arm in self.arms])

        mu_b = self.arms[best].empirical_mean
        n_b = self.arms[best].num_pulls
        # Elementwise N * mu; equivalent to the former np.diag(N) @ mu but O(K).
        weighted = pulls * means
        mid = (weighted[best] + weighted) / (n_b + pulls)
        score = min(n_b * self.d(mu_b, mid[i]) + pulls[i] * self.d(means[i], mid[i])
                    for i in range(len(self.arms)) if i != best)
        return bool(score > threshold)

    def beta(self):
        """Return the confidence level ``log(log(t) + 1 / delta)``."""
        return np.log(np.log(self.time) + 1 / self.delta)

    def upper(self, arm):
        """Return the upper confidence bound ``dup(mean, beta() / N)``."""
        return self.dup(arm.empirical_mean, self.beta()/arm.num_pulls)

    def lower(self, arm):
        """Return the lower confidence bound ``dlow(mean, beta() / N)``."""
        return self.dlow(arm.empirical_mean, self.beta()/arm.num_pulls)

    def next_arm(self):
        """Return ``[best arm, challenger with the highest upper bound]``."""
        best = self.arms[self.best_idx]
        # -inf (rather than 0) keeps the best arm from winning argmax when
        # every challenger's upper bound is negative.
        uppers = np.full(len(self.arms), -np.inf)
        for i, arm in enumerate(self.arms):
            if i != self.best_idx:
                uppers[i] = self.upper(arm)
        return [best, self.arms[np.argmax(uppers)]]


class UGapE(Algorithms):
    """Sampling that shrinks the overlap between the best arm's confidence
    interval and those of the other arms, per unit cost."""

    def __init__(self, mudist, arms, delta):
        super().__init__(mudist, arms, delta)

    def beta(self):
        """Return the confidence level ``log(log(t) + 1 / delta)``."""
        return np.log(np.log(self.time) + 1 / self.delta)

    def upper(self, arm):
        """Return the upper confidence bound ``dup(mean, beta() / N)``."""
        return self.dup(arm.empirical_mean, self.beta()/arm.num_pulls)

    def lower(self, arm):
        """Return the lower confidence bound ``dlow(mean, beta() / N)``."""
        return self.dlow(arm.empirical_mean, self.beta()/arm.num_pulls)

    def debug(self):
        """Print per-arm statistics and the current total overlap."""
        for arm in self.arms:
            print("Arm {}:".format(arm.name))
            print("Num pulls: {}".format(arm.num_pulls))
            print("Mean: {}".format(arm.empirical_mean))
            print("Confidence interval: ({}, {})".format(self.lower(arm), self.upper(arm)))
            print("Overlap: {}".format(self.calculate_overlap()))

    def contains(self, left, right):
        """Return True if left's interval strictly contains right's."""
        return bool(self.lower(left) < self.lower(right)
                    and self.upper(left) > self.upper(right))

    def calculate_overlap(self):
        """Sum, over arms other than the best, of their interval overlap with
        the best arm's interval (zero when disjoint)."""
        best = self.arms[self.best_idx]
        best_low, best_up = self.lower(best), self.upper(best)
        overlap = 0
        for arm in self.arms:
            if arm == best:
                continue
            right = min(self.upper(arm), best_up)
            left = max(self.lower(arm), best_low)
            overlap += max(right - left, 0)
        return overlap

    def stop(self):
        """Return True when the pairwise statistic clears the threshold or no
        interval overlaps the best arm's.

        The statistic is ``N_b * d(mu_b, m_i) + N_i * d(mu_i, m_i)``, minimised
        over ``i != b``. Here ``b`` is the empirical best arm and ``m_i`` the
        pull-weighted mean of ``b`` and ``i``. The threshold is
        ``log((log(t) + 1) / delta)``.
        """
        if len(self.arms) < 2:
            return True
        threshold = np.log((np.log(self.time) + 1) / self.delta)
        pulls = np.array([arm.num_pulls for arm in self.arms])
        means = np.array([arm.empirical_mean for arm in self.arms])
        best = np.argmax(means)

        mu_b = self.arms[best].empirical_mean
        n_b = self.arms[best].num_pulls
        # Elementwise N * mu; equivalent to the former np.diag(N) @ mu but O(K).
        weighted = pulls * means
        mid = (weighted[best] + weighted) / (n_b + pulls)
        # self.d: the original called an undefined module-level `d` (NameError).
        score = min(n_b * self.d(mu_b, mid[i]) + pulls[i] * self.d(means[i], mid[i])
                    for i in range(len(self.arms)) if i != best)

        if score > threshold:
            return True
        return self.calculate_overlap() == 0

    def overlap_virtual(self, arm):
        """Return the overlap as if ``arm`` had been pulled once more.

        ``arm.num_pulls`` and ``self.time`` are temporarily incremented and
        always restored, even if ``calculate_overlap`` raises.
        """
        arm.num_pulls += 1
        self.time += 1
        try:
            return self.calculate_overlap()
        finally:
            self.time -= 1
            arm.num_pulls -= 1

    def ratio(self, arm):
        """Return the overlap reduction from one more pull of ``arm``, per unit cost."""
        # NOTE(review): other classes use arm.empirical_cost; confirm that
        # arm.cost is the intended attribute here.
        overlap = self.calculate_overlap()
        virtual = self.overlap_virtual(arm)
        return (overlap - virtual)/(arm.cost)

    def next_arm(self):
        """Return a one-element list holding the arm to pull next.

        If any interval strictly contains the best arm's (or vice versa), the
        wider arm is pulled. Otherwise the arm with the largest ``ratio`` is
        pulled. Every path returns a list; the original returned bare arms in
        the containment branches.
        """
        best = self.arms[self.best_idx]
        for arm in self.arms:
            if self.contains(arm, best):
                return [arm]
            if self.contains(best, arm):
                return [best]
        ratios = [self.ratio(arm) for arm in self.arms]
        return [self.arms[np.argmax(ratios)]]
