import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
from algo_base import Arm
import BAIalgos as algos


def extend_list(data, desired_length):
    """Pad ``data`` in place up to ``desired_length`` by repeating its last item.

    The list is mutated and the very same object is returned, so callers may
    either ignore the result or reassign it. Lists that are already at least
    ``desired_length`` long are returned untouched (never truncated).

    Raises:
        IndexError: if ``data`` is empty but padding is required, because
            there is no last element to repeat.
    """
    shortfall = desired_length - len(data)
    if shortfall > 0:
        padding = [data[-1]] * shortfall
        data.extend(padding)
    return data


def get_name(Algorithm):
    """Return the short display label for a ``BAIalgos`` algorithm class.

    ``Algorithm`` is compared with ``==`` against each known class in turn;
    the first match wins. Anything not listed yields ``"Unknown"``.
    """
    # Attribute names are resolved on ``algos`` one at a time, in this order,
    # so only the classes checked before a match are ever looked up.
    labelled_algorithms = (
        ("TrackAndStop", "TAS"),
        ("CTrackAndStop", "CTAS"),
        ("ChernoffOverlap", "CO"),
        ("KLOverlap", "KLOverlap"),
        ("KLULCB", "KL-ULCB"),
        ("KellTrackAndStop", "Kellen"),
        ("ChernoffRacing", "ChernoffRacing"),
        ("ChernoffModified", "ChernoffModified"),
    )
    for attr_name, label in labelled_algorithms:
        if Algorithm == getattr(algos, attr_name):
            return label
    return "Unknown"


class Data:
    """Run an algorithm over many trials and aggregate its statistics.

    For every trial ``i`` one ``Arm`` per entry of ``rewards``/``costs`` is
    created (ids start at 1) from ``rewards[j][i]`` and ``costs[j][i]``. A
    fresh ``Algorithm(dist, arms, delta)`` is then run, and its results are
    recorded. The per-step traces ``means``, ``pulls`` and ``cost_time`` from
    every trial are padded with ``extend_list`` to a shared length
    ``self.cut``. They are then averaged, with their standard deviation
    computed, across trials.

    Attributes:
        name: short label returned by ``get_name(Algorithm)``.
        num_arms: ``len(rewards)``.
        num_samples: ``len(rewards[0])``, the number of trials.
        costs: ``algorithm.cost`` of each trial.
        run_time: ``algorithm.run_time`` of each trial.
        num_pulls: array of shape (num_samples, num_arms) holding each arm's
            ``num_pulls`` per trial.
        predicted_arms: length-num_arms array; entry ``k`` counts how often
            ``run()`` returned arm id ``k + 1``.
        cut: padded trace length, at least 1, otherwise the largest
            ``algorithm.time - 1`` seen.
        avg_cost, average_pulls: averages over trials of ``costs`` and
            ``num_pulls``.
        means, means_std, pulls, pulls_std, cost, cost_std: across-trial
            average and std of the padded traces.
    """

    def __init__(self, Algorithm, costs, rewards, dist, delta=1e-4):
        self.costs = []

        num_arms = len(rewards)
        num_samples = len(rewards[0])
        self.num_arms = num_arms
        self.num_samples = num_samples

        self.num_pulls = np.zeros((num_samples, num_arms))
        self.predicted_arms = np.zeros(num_arms)
        self.run_time = []

        cost_traces, pull_traces, mean_traces = [], [], []

        self.name = get_name(Algorithm)
        self.cut = 1

        for trial in range(num_samples):
            trial_arms = [
                Arm(arm_idx + 1, rewards[arm_idx][trial], costs[arm_idx][trial])
                for arm_idx in range(num_arms)
            ]

            runner = Algorithm(dist, trial_arms, delta)
            best_arm = runner.run()
            self.predicted_arms[best_arm - 1] += 1
            self.run_time.append(runner.run_time)
            self.costs.append(runner.cost)

            # The row starts at zero, so this records each arm's pull count.
            self.num_pulls[trial] += [
                runner.arms[arm_idx].num_pulls for arm_idx in range(num_arms)
            ]

            # Longest trace seen so far; shorter traces are padded up to it.
            self.cut = max(self.cut, runner.time - 1)

            mean_traces.append(runner.means)
            pull_traces.append(runner.pulls)
            cost_traces.append(runner.cost_time)

        # Make every trace the same length so numpy can average across trials.
        for trial in range(num_samples):
            cost_traces[trial] = extend_list(cost_traces[trial], self.cut)
            pulls_row = pull_traces[trial]
            means_row = mean_traces[trial]
            for arm_idx in range(num_arms):
                pulls_row[arm_idx] = extend_list(pulls_row[arm_idx], self.cut)
                means_row[arm_idx] = extend_list(means_row[arm_idx], self.cut)

        self.avg_cost = np.average(self.costs)
        self.average_pulls = np.average(self.num_pulls, axis=0)

        self.means = np.average(mean_traces, axis=0)
        self.means_std = np.std(mean_traces, axis=0)

        self.pulls = np.average(pull_traces, axis=0)
        self.pulls_std = np.std(pull_traces, axis=0)

        self.cost = np.average(cost_traces, axis=0)
        self.cost_std = np.std(cost_traces, axis=0)
