import numpy as np
import pickle
from matplotlib import rc
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.colors as mcolors
from modular import extend_list
from opt_cost import OptimalWeights
import KLFunctions as kl

# Module-wide matplotlib font (sans-serif, normal weight, size 28), installed
# via rc() so it applies to every figure this module creates afterwards.
font = dict(family='sans-serif', weight='normal', size=28)

rc('font', **font)


def max_time(data):
    """Return the largest ``cut`` value among the results in ``data``.

    ``Visualization`` uses this value as the length of the time axis, so that
    every algorithm's curve fits.

    Args:
        data: mapping from algorithm key to a result object that exposes a
            ``cut`` attribute.

    Returns:
        The maximum ``cut`` over all values of ``data``.

    Raises:
        ValueError: if ``data`` is empty.
    """
    return max(result.cut for result in data.values())


class Visualization:
    """Load pickled bandit-experiment results and plot or print summaries.

    Each file in ``problems`` is unpickled into a dict that maps algorithm keys
    to result objects, plus one ``'lower_bound'`` entry. That entry is moved
    into ``self.lower_bounds``. The attributes read from result objects
    (``name``, ``cut``, ``cost``, ``costs``, ``pulls``, ``num_pulls``,
    ``means``, ``means_std``, ``run_time``, ``predicted_arms``) are defined by
    the experiment code, not here.
    """

    # Shorter or corrected labels used only in plot legends and tick labels.
    # The loaded result objects are never renamed.
    _DISPLAY_NAMES = {"ChernoffOverlap": "CO", "KL-ULCB": "KL-LUCB"}

    def __init__(self, problems, algos, K=3, num_samples=100):
        """Load every results file and compute the shared time horizon.

        Args:
            problems: iterable of paths to pickled result dictionaries.
            algos: keys of the algorithms to show, in plotting order.
            K: number of arms plotted per algorithm (default 3).
            num_samples: number of samples shown by ``plot_arm_pulls``
                (default 100).
        """
        self.algos = algos
        self.problem_map = {}
        for filename in problems:
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # result files produced by trusted experiment runs.
            with open(filename, 'rb') as data_file:
                self.problem_map[filename] = pickle.load(data_file)

        # Move each problem's lower bound out, so that problem_map holds only
        # algorithm results (max_time and the plots iterate over every key).
        self.lower_bounds = {
            problem: results.pop('lower_bound')
            for problem, results in self.problem_map.items()
        }

        # Iterate the loaded map rather than ``problems``: a one-shot iterator
        # was already consumed by the loading loop above.
        self.cut = max(max_time(results)
                       for results in self.problem_map.values())

        self.K = K
        self.num_samples = num_samples

    @classmethod
    def _display_name(cls, name):
        """Return the plot label for algorithm ``name``.

        The name is returned unchanged if it has no entry in the mapping.
        """
        return cls._DISPLAY_NAMES.get(name, name)

    def print_avg_runtime(self):
        """Print the mean of ``run_time`` for every problem/algorithm pair."""
        for key in self.problem_map:
            for algo in self.algos:
                data = self.problem_map[key][algo]
                avg_time = np.average(data.run_time)
                print("{}: {}s".format(data.name, avg_time))

    def get_ratios(self):
        """Print, per arm, the last entry of ``pulls[arm]`` divided by its length."""
        for key in self.problem_map:
            for algo in self.algos:
                data = self.problem_map[key][algo]
                print(data.name)
                for i, arm in enumerate(data.pulls):
                    ratio = arm[-1] / len(arm)
                    print("Arm {}: {}".format(i, ratio))

    def plot_cost_time(self, save=False):
        """Plot each algorithm's cost over time against the problem's lower bound.

        Creates one figure per problem. If ``save`` is true, writes it to
        ``figures/<k>-costs-time.png``, where ``k`` is the problem's index.
        """
        colors = ['red', 'blue', 'green', 'purple', 'orange']
        x = list(range(self.cut))
        for k, key in enumerate(self.problem_map):
            fig = plt.figure()
            ax = fig.add_subplot()
            fig.subplots_adjust(left=.15, right=.95, top=.9)
            ax.set_xlabel('Time')
            ax.set_ylabel('Cost')
            ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
            for i, algo in enumerate(self.algos):
                data = self.problem_map[key][algo]
                ax.plot(x, extend_list(data.cost.tolist(), self.cut),
                        color=colors[i % len(colors)],
                        label="{}".format(self._display_name(data.name)))
            ax.plot(x, [self.lower_bounds[key]] * self.cut, '--',
                    label="Lower bound")
            ax.legend()
            # Save before the (possibly blocking) show(), as matplotlib
            # recommends; show() closes the figure when it returns.
            if save:
                fig.savefig('figures/{}-costs-time.png'.format(k))
            plt.show()

    def plot_pulls_time(self, save=False):
        """Plot the per-arm pull counts over time for every algorithm.

        Creates one figure per problem. Color identifies the algorithm and
        line style identifies the arm. If ``save`` is true, writes the figure
        to ``figures/<k>-pulls-time.png``.
        """
        colors = ['darkblue', 'darkred', 'darkgreen', 'darkorange']
        # '.' is a marker, not a line style; use a dash tuple for the 5th arm.
        linestyles = ['-', '--', ':', '-.', (0, (3, 1, 1, 1, 1, 1))]
        x = list(range(self.cut))
        for k, key in enumerate(self.problem_map):
            fig = plt.figure()
            ax = fig.add_subplot()
            fig.subplots_adjust(left=.15, right=.95, top=.9)
            ax.set_ylabel('Number of Pulls')
            ax.set_xlabel('Time')
            for i, algo in enumerate(self.algos):
                data = self.problem_map[key][algo]
                name = self._display_name(data.name)
                color = colors[i % len(colors)]
                for j in range(self.K):
                    ax.plot(x, extend_list(data.pulls[j].tolist(), self.cut),
                            color=color,
                            ls=linestyles[j % len(linestyles)], lw=3,
                            label="{}, Arm {}".format(name, j + 1))
            ax.legend(fontsize=18, ncols=1)
            if save:
                fig.savefig('figures/{}-pulls-time.png'.format(k),
                            bbox_inches="tight")
            plt.show()

    def plot_means_time(self, save=True):
        """Plot each arm's mean estimate over time with a +/- one-std band.

        Creates one figure per problem. The band spans ``means[j] - means_std[j]``
        to ``means[j] + means_std[j]`` for the same algorithm. If ``save`` is
        true (the default, matching the original always-save behaviour),
        writes the figure to ``figures/<k>-mean-time.png``.
        """
        colors = ['red', 'blue', 'green', 'purple']
        linestyles = ['-', '--', ':', '-.']
        x = list(range(self.cut))
        for k, key in enumerate(self.problem_map):
            fig = plt.figure()
            ax = fig.add_subplot()
            fig.subplots_adjust(top=0.85)
            ax.set_xlabel('Time')
            ax.set_ylabel('Mean')
            ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
            for i, algo in enumerate(self.algos):
                data = self.problem_map[key][algo]
                color = colors[i % len(colors)]
                for j in range(self.K):
                    mean = data.means[j]
                    std = data.means_std[j]
                    ax.plot(x, extend_list(mean.tolist(), self.cut),
                            color=color, ls=linestyles[j % len(linestyles)],
                            label="{}, {}".format(
                                self._display_name(data.name), j + 1))
                    ax.fill_between(
                        x,
                        extend_list((mean - std).tolist(), self.cut),
                        extend_list((mean + std).tolist(), self.cut),
                        color=color, alpha=0.2)
            ax.legend()
            if save:
                fig.savefig('figures/{}-mean-time.png'.format(k))
            plt.show()

    def plot_arm_pulls(self, save=True):
        """Plot ``num_pulls[j][s]`` for the first ``num_samples`` samples.

        Creates one figure per problem. Color identifies the arm and line
        style identifies the algorithm, so that algorithms can be told apart.
        If ``save`` is true (the default), writes the figure to
        ``figures/<k>-arm_pulls.png``. As before, this method does not call
        ``plt.show()``.
        """
        colors = ['red', 'blue', 'green', 'purple']
        linestyles = ['-', '--', ':', '-.']
        x = list(range(self.num_samples))
        for k, key in enumerate(self.problem_map):
            fig = plt.figure()
            ax = fig.add_subplot()
            fig.subplots_adjust(top=0.85)
            ax.set_xlabel('Iteration')
            ax.set_ylabel('Number of pulls')
            ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
            for i, algo in enumerate(self.algos):
                data = self.problem_map[key][algo]
                name = self._display_name(data.name)
                for j in range(self.K):
                    num_pulls = [data.num_pulls[j][s]
                                 for s in range(self.num_samples)]
                    ax.plot(x, num_pulls,
                            ls=linestyles[i % len(linestyles)],
                            color=colors[j % len(colors)],
                            label="{} {}".format(name, j + 1))
            ax.legend()
            if save:
                fig.savefig('figures/{}-arm_pulls.png'.format(k))

    def plot_costs(self, save=False):
        """Draw a box plot of ``costs`` per algorithm, one figure per problem.

        If ``save`` is true, writes the figure to ``figures/<k>-costs.png``.
        """
        colors = ['rebeccapurple', 'darkseagreen', 'darkblue', 'darkred']
        for k, key in enumerate(self.problem_map):
            fig = plt.figure()
            ax = fig.add_subplot()
            fig.subplots_adjust(left=.15, right=.95, top=.9)
            ax.set_ylabel('Cost')
            ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
            results = [self.problem_map[key][algo] for algo in self.algos]
            # Plain lists (not dict views) keep one box per algorithm, even
            # when two algorithms share a display name.
            labels = [self._display_name(r.name) for r in results]
            costs = [r.costs for r in results]
            bplot = ax.boxplot(costs, patch_artist=True, showmeans=True,
                               labels=labels)
            # Fill the boxes; any boxes beyond len(colors) keep the default.
            for patch, color in zip(bplot['boxes'], colors):
                patch.set_facecolor(color)
            # No legend: box artists carry no labels (the tick labels name them).
            if save:
                fig.savefig("figures/{}-costs.png".format(k),
                            bbox_inches='tight')
            plt.show()

    def print_correct(self):
        """Print the first entry of ``predicted_arms`` for every algorithm."""
        for key in self.problem_map:
            for algo in self.algos:
                data = self.problem_map[key][algo]
                print("{} correct: {}".format(data.name,
                                              data.predicted_arms[0]))

    def plot_all(self):
        """Produce every plot with its default arguments."""
        self.plot_arm_pulls()
        self.plot_cost_time()
        self.plot_means_time()
        self.plot_costs()
        self.plot_pulls_time()


def plot_ratios(d, mu, n):
    """Plot how each arm's optimal ratio changes as its own cost grows.

    For every arm ``a`` and every integer cost ``c`` in ``1 .. n-1``, all arm
    costs are set to 1 except arm ``a``, whose cost is set to ``c``. The entry
    ``a`` of the ratio vector returned by ``OptimalWeights`` is recorded. The
    plot has one curve per arm, labelled with that arm's mean.

    Args:
        d: forwarded unchanged to ``OptimalWeights``.
        mu: sequence of arm means; its length sets the number of arms.
        n: exclusive upper bound on the swept cost.
    """
    fig = plt.figure()
    ax = fig.add_subplot()
    fig.subplots_adjust(top=0.85)
    ax.set_xlabel('Cost')
    ax.set_ylabel('Ratio')
    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))

    num_arms = len(mu)
    swept_costs = range(1, n)
    curves = np.zeros((num_arms, n - 1))
    for arm in range(num_arms):
        for col, cost in enumerate(swept_costs):
            arm_costs = np.ones(num_arms)
            arm_costs[arm] = cost
            _, weights = OptimalWeights(d, mu, arm_costs)
            curves[arm, col] = weights[arm]

    cost_axis = list(swept_costs)
    for arm, arm_mean in enumerate(mu):
        ax.plot(cost_axis, curves[arm], label="mu = {}".format(arm_mean))
    plt.legend()
    plt.show()


def plot_ratios_2(d, mu, n, arm):
    """Plot every arm's optimal ratio as the cost of one arm is swept.

    All costs are 1 except arm ``arm``, whose cost takes each integer in
    ``1 .. n-1``. The ratios that ``OptimalWeights`` returns for all arms are
    plotted against that cost, one curve per arm, labelled with the arm's mean.

    Args:
        d: forwarded unchanged to ``OptimalWeights``.
        mu: sequence of arm means; its length sets the number of arms.
        n: exclusive upper bound on the swept cost.
        arm: index of the arm whose cost is varied.

    Raises:
        IndexError: if ``arm`` is not a valid index into ``mu``.
    """
    fig = plt.figure()
    ax = fig.add_subplot()
    fig.subplots_adjust(top=0.85)
    ax.set_xlabel('Cost')
    ax.set_ylabel('Ratio')
    ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    K = len(mu)
    ratios = np.zeros((K, n - 1))
    for i in range(1, n):
        costs = np.ones(K)
        # Vary the requested arm (this was previously hard-coded to index 2).
        costs[arm] = i
        _, ratio = OptimalWeights(d, mu, costs)
        for j in range(K):
            ratios[j][i - 1] = ratio[j]
    x = list(range(1, n))
    for i in range(K):
        ax.plot(x, ratios[i], label="mu = {}".format(mu[i]))
    plt.legend()
    plt.show()


def regret(algo, means):
    """Return the regret of ``algo`` with respect to the true arm ``means``.

    The regret is ``algo.time * max(means)`` minus the dot product of
    ``means`` with each arm's ``num_pulls`` (taken from ``algo.arms`` in order).
    """
    best_mean = max(means)
    pulls_per_arm = [a.num_pulls for a in algo.arms]
    earned = np.dot(means, pulls_per_arm)
    return algo.time * best_mean - earned
