import numpy as np
import time
import six
from abc import ABCMeta, abstractmethod
from helpers import dist_funcs


class Arm:
    """A single bandit arm backed by pre-generated reward and cost sequences.

    The k-th pull (k = 1, 2, ...) reads ``rewards[k]`` and ``costs[k]``;
    index 0 of either sequence is never read by this class.

    Attributes:
        name: Identifier of the arm; equality and hashing are based on it.
        rewards: Indexable sequence of reward samples.
        costs: Indexable sequence of cost samples.
        empirical_mean: Running average of the rewards observed so far.
        empirical_cost: Running average of the costs observed so far.
        num_pulls: Number of completed pulls.
    """

    def __init__(self, name, rewards, costs):
        self.name = name
        self.rewards = rewards
        self.costs = costs
        self.empirical_mean = 0
        self.empirical_cost = 0
        self.num_pulls = 0

    def _fold_sample(self, average, sample):
        """Return ``average`` updated with one more ``sample``.

        Uses the incremental form avg * n / (n + 1) + x / (n + 1), where n
        is the number of pulls completed before this sample.
        """
        seen = self.num_pulls + 1
        return average * (self.num_pulls / seen) + sample / seen

    def update_mean(self):
        """Fold the next reward sample into ``empirical_mean``.

        Does not advance ``num_pulls``; ``pull`` does that after both
        running averages have been updated.
        """
        self.empirical_mean = self._fold_sample(
            self.empirical_mean, self.rewards[self.num_pulls + 1])

    def update_cost(self):
        """Fold the next cost sample into ``empirical_cost``.

        Does not advance ``num_pulls``; ``pull`` does that after both
        running averages have been updated.
        """
        self.empirical_cost = self._fold_sample(
            self.empirical_cost, self.costs[self.num_pulls + 1])

    def pull(self):
        """Consume the next sample pair and return it as ``(cost, reward)``.

        Raises whatever the underlying sequences raise when indexed past
        their end (``IndexError`` for lists).
        """
        # Both updates must see the old pull count, so increment afterwards.
        self.update_mean()
        self.update_cost()
        self.num_pulls += 1
        return self.costs[self.num_pulls], self.rewards[self.num_pulls]

    def copy(self):
        """Return a new ``Arm`` carrying the same statistics.

        The reward and cost sequences are shared with the original, not
        duplicated.
        """
        clone = Arm(self.name, self.rewards, self.costs)
        clone.empirical_mean = self.empirical_mean
        clone.empirical_cost = self.empirical_cost
        clone.num_pulls = self.num_pulls
        return clone

    def __eq__(self, other):
        """Two arms are equal when their names are equal."""
        if not hasattr(other, "name"):
            # Let Python fall back to its default comparison (e.g. arm == None).
            return NotImplemented
        return self.name == other.name

    def __hash__(self):
        # Defined alongside __eq__ so arms remain usable in sets and as
        # dict keys; kept consistent with name-based equality.
        return hash(self.name)


@six.add_metaclass(ABCMeta)
class Algorithms:
    """Abstract driver that repeatedly pulls arms until a stop rule fires.

    Subclasses supply ``stop`` (termination test) and ``next_arm`` (which
    arms to pull next). This base class handles cost accounting, tracks the
    arm with the highest empirical mean, and records per-step history.

    Args:
        rewards_distr: Passed unchanged to ``dist_funcs``; the three values
            it returns are stored as ``d``, ``dup`` and ``dlow``.
        arms: Sequence of arm objects exposing ``pull()``, ``name``,
            ``empirical_mean``, ``empirical_cost`` and ``num_pulls``.
        delta: Stored as ``self.delta``; not used by this base class.
    """

    def __init__(self, rewards_distr, arms, delta):
        self.delta = delta
        self.arms = arms
        self.best_idx = 0
        self.time = 1       # advanced by one on every pull
        self.cost = 0       # cumulative cost of all pulls
        self.xaxis = []     # value of self.time after each pull
        self.cost_time = []  # cumulative cost after each pull
        self.max = 10000    # run() stops pulling once self.time reaches this
        # One history list per arm, each appended to on every pull.
        self.bonuses = [[] for _ in range(len(arms))]
        self.pulls = [[] for _ in range(len(arms))]
        self.means = [[] for _ in range(len(arms))]
        self.d, self.dup, self.dlow = dist_funcs(rewards_distr)

    @abstractmethod
    def stop(self):
        """Return a truthy value when ``run`` should stop pulling arms."""
        pass

    def pull(self, arm):
        """Pull ``arm`` once and record the resulting state of every arm.

        Adds the pull's cost to ``self.cost``, recomputes ``best_idx`` as
        the index of the largest empirical mean, advances ``self.time``
        and appends one entry to each history list.
        """
        cost, _reward = arm.pull()
        self.cost += cost
        self.best_idx = np.argmax(
            [candidate.empirical_mean for candidate in self.arms])
        self.time += 1

        self.cost_time.append(self.cost)
        self.xaxis.append(self.time)
        # self.bonuses is intentionally not filled here: this base class
        # defines no bonus computation.
        for idx, tracked in enumerate(self.arms):
            self.pulls[idx].append(tracked.num_pulls)
            self.means[idx].append(tracked.empirical_mean)

    def run(self):
        """Run the algorithm and return the name of the best arm found.

        First pulls every arm until both its empirical mean and empirical
        cost are strictly positive, then alternates ``next_arm`` / ``pull``
        until ``stop()`` is truthy or ``self.time`` reaches ``self.max``.
        On return ``self.run_time`` holds the CPU time, in seconds, that
        the run consumed.
        """
        # Holds the start timestamp while running; replaced by the elapsed
        # time just before returning.
        self.run_time = time.process_time()
        for arm in self.arms:
            # Warm-up is bounded by the same budget as the main loop so an
            # arm whose statistics never become positive cannot pull
            # indefinitely.
            while ((arm.empirical_mean <= 0 or arm.empirical_cost <= 0)
                   and self.time < self.max):
                self.pull(arm)
        while not self.stop() and self.time < self.max:
            for chosen in self.next_arm():
                self.pull(chosen)

        self.run_time = time.process_time() - self.run_time
        return self.arms[self.best_idx].name

    @abstractmethod
    def next_arm(self):
        """Return an iterable of the arms to pull in the next round."""
        pass
