# Taken and modified from https://github.com/DLR-RM/rl-baselines3-zoo
# MIT License

import argparse
import itertools
import pickle

import matplotlib
import numpy as np
import pytablewriter
import seaborn
from matplotlib import pyplot as plt
from rliable import library as rly
from rliable import metrics, plot_utils
from score_normalization import normalize_score


def compute_parameter_numbers(algo: str, env_id: str) -> int:
    """Return the number of trainable policy parameters of ``algo`` on ``env_id``.

    For ``"Open Loop"`` the count comes from a hard-coded per-env table. For the
    other supported algorithms ("ARS", "PPO", "DDPG", "SAC") a fresh model is
    instantiated from ``env_id``, which can be slow. Its policy parameters with
    ``requires_grad`` are then summed.

    :param algo: algorithm label ("Open Loop", "ARS", "PPO", "DDPG" or "SAC").
    :param env_id: environment id, e.g. ``"HalfCheetah-v4"``.
    :return: number of trainable parameters.
    :raises KeyError: if ``algo`` (or, for "Open Loop", ``env_id``) is not supported.
    """
    open_loop_counts = {
        "Ant-v4": 19,
        "HalfCheetah-v4": 19,
        "Hopper-v4": 7,
        "Walker2d-v4": 19,
        "Swimmer-v4": 3,
    }
    if algo == "Open Loop":
        return open_loop_counts[env_id]

    # Imported lazily: only needed when a real model must be built
    from sb3_contrib import ARS
    from stable_baselines3 import DDPG, PPO, SAC

    # ARS uses a linear policy, the other algorithms an MLP
    policy_name = "LinearPolicy" if algo == "ARS" else "MlpPolicy"
    algo_to_class = {"ARS": ARS, "PPO": PPO, "DDPG": DDPG, "SAC": SAC}
    model_class = algo_to_class[algo]

    policy = model_class(policy_name, env_id).policy
    trainable_sizes = [param.numel() for param in policy.parameters() if param.requires_grad]
    return sum(trainable_sizes)


# Map from the environment keys used in the results file to the env ids
# expected by `normalize_score` and `compute_parameter_numbers`
# (keys not listed here are passed through unchanged).
_ENV_KEY_TO_ENV_ID = {
    "HalfCheetah": "HalfCheetah-v4",
    "Ant": "Ant-v4",
    "Hopper": "Hopper-v4",
    "Walker": "Walker2d-v4",
    "Swimmer": "Swimmer-v4",
}


def _parse_arguments() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser("Gather results, plot them and create table")
    parser.add_argument(
        "-i",
        "--input",
        help="Input filename (pickle file, the .pkl extension is optional)",
        type=str,
        required=True,
    )
    parser.add_argument("-skip", "--skip-envs", help="Environments to skip", nargs="+", default=[], type=str)
    parser.add_argument("--keep-envs", help="Envs to keep", nargs="+", default=[], type=str)
    parser.add_argument("--skip-keys", help="Keys to skip", nargs="+", default=[], type=str)
    parser.add_argument("--keep-keys", help="Keys to keep", nargs="+", default=[], type=str)
    parser.add_argument("--no-million", action="store_true", default=False, help="Do not convert x-axis to million")
    parser.add_argument("--skip-timesteps", action="store_true", default=False, help="Do not display learning curves")
    parser.add_argument("-o", "--output", help="Output filename (image)", type=str)
    parser.add_argument("--format", help="Output format", type=str, default="svg")
    parser.add_argument("-loc", "--legend-loc", help="The location of the legend.", type=str, default="best")
    # float (not int): the default and typical values are fractional inches
    parser.add_argument("--figsize", help="Figure size, width, height in inches.", nargs=2, type=float, default=[6.4, 4.8])
    parser.add_argument("--fontsize", help="Font size", type=int, default=14)
    parser.add_argument("-l", "--labels", help="Custom labels", type=str, nargs="+")
    parser.add_argument("-b", "--boxplot", help="Enable boxplot", action="store_true", default=False)
    parser.add_argument("-r", "--rliable", help="Enable rliable plots", action="store_true", default=False)
    parser.add_argument("-vs", "--versus", help="Enable probability of improvement plot", action="store_true", default=False)
    parser.add_argument("-iqm", "--iqm", help="Enable IQM sample efficiency plot", action="store_true", default=False)
    parser.add_argument("-ci", "--ci-size", help="Confidence interval size (for rliable)", type=float, default=0.95)
    parser.add_argument("-latex", "--latex", help="Enable latex support", action="store_true", default=False)
    parser.add_argument("--merge", help="Merge with other results files", nargs="+", default=[], type=str)
    parser.add_argument(
        "-count-param", "--count-parameters", help="Count the number of parameters", action="store_true", default=False
    )
    return parser.parse_args()


def _load_results(filename: str, merge_filenames: list) -> dict:
    """Load the main results pickle, print its summary table and merge extra result files.

    :param filename: main results file; ``.pkl`` is appended when missing.
    :param merge_filenames: additional pickle files, opened exactly as given. For envs
        already present in the main file, their per-key entries are added or override.
    :return: the ``results[env][key]`` mapping, with the ``results_table`` entry removed.
    """
    if not filename.endswith(".pkl"):
        filename += ".pkl"

    # NOTE: pickle.load can execute arbitrary code; only load result files you trust.
    with open(filename, "rb") as file_handler:
        results = pickle.load(file_handler)

    # Print the summary table stored alongside the raw results
    writer = pytablewriter.MarkdownTableWriter(max_precision=3)
    writer.table_name = "results_table"
    writer.headers = results["results_table"]["headers"]
    writer.value_matrix = results["results_table"]["value_matrix"]
    writer.write_table()
    del results["results_table"]

    for merge_filename in merge_filenames:
        with open(merge_filename, "rb") as file_handler:
            other_results = pickle.load(file_handler)
        # The table is not needed for merging, and may be absent
        other_results.pop("results_table", None)
        # Envs only present in the merged file are ignored
        for env, env_results in results.items():
            if env in other_results:
                env_results.update(other_results[env])
    return results


def _collect_normalized_scores(results: dict, keys: list, envs: list, labels: dict, count_parameters: bool):
    """Gather, normalize and stack the evaluation scores of every algorithm.

    :return: ``(normalized_score_dict, all_eval_normalized_scores_dict, param_num_per_env)``.
        ``normalized_score_dict[label]`` has shape (n_runs, n_envs), with runs truncated to
        the smallest run count across envs. ``all_eval_normalized_scores_dict[label]`` has
        shape (n_runs, n_envs, n_eval) and is only filled when every env so far had a
        ``mean_per_eval`` entry. ``param_num_per_env[label]`` holds one parameter count per
        kept env, and is only filled when ``count_parameters`` is True. Algorithms without
        any usable run are left out of both score dicts.
    """
    normalized_score_dict, all_eval_normalized_scores_dict, param_num_per_env = {}, {}, {}
    # Backward compat: older result files have no "mean_per_eval" entry
    skip_all_algos_dict = False

    for key in keys:
        label = labels[key]
        algo_scores, all_algo_scores = [], []
        param_num_per_env[label] = []
        for env in envs:
            env_id = _ENV_KEY_TO_ENV_ID.get(env, env)
            env_results = results[env][key]
            last_evals = env_results["last_evals"]
            if np.ndim(last_evals) == 0:
                # Not enough timesteps: a single scalar was stored instead of per-run scores
                print(f"Skipping {env}-{key}")
                continue

            if count_parameters:
                # Same for every run, so compute it once per (algo, env)
                param_num_per_env[label].append(compute_parameter_numbers(label, env_id))

            # Normalize score, env key must match env_id
            algo_scores.append(normalize_score(last_evals, env_id))
            if "mean_per_eval" not in env_results:
                skip_all_algos_dict = True
            elif not skip_all_algos_dict:
                all_algo_scores.append(normalize_score(env_results["mean_per_eval"], env_id))

        # Truncate to the smallest run count so the scores form a matrix
        # (default=0 handles an algorithm whose envs were all skipped)
        min_runs = min((len(scores) for scores in algo_scores), default=0)
        if min_runs == 0:
            continue
        # shape: (n_envs, n_runs) -> (n_runs, n_envs)
        normalized_score_dict[label] = np.array([scores[:min_runs] for scores in algo_scores]).T
        if not skip_all_algos_dict:
            truncated = [scores[:, :min_runs] for scores in all_algo_scores]
            # (n_envs, n_eval, n_runs) -> (n_runs, n_envs, n_eval)
            all_eval_normalized_scores_dict[label] = np.array(truncated).transpose((2, 0, 1))

    return normalized_score_dict, all_eval_normalized_scores_dict, param_num_per_env


def _aggregate_parameter_counts(param_num_per_env: dict, labels: dict) -> dict:
    """Print and return ``label -> [mean, standard error]`` of the per-env parameter counts."""
    reference_num = 13  # Open Loop (rounded average parameter count)
    aggregated_count = {}
    for label in labels.values():
        counts = param_num_per_env[label]
        mean_number = np.mean(counts)
        factor = mean_number / reference_num
        std_error = np.std(counts) / np.sqrt(len(counts))
        print(f"{label:9}: n_params = {mean_number:7.0f} +/- {std_error:6.0f} factor: {factor:5.0f}x")
        aggregated_count[label] = np.array([mean_number, std_error])
    return aggregated_count


def _plot_parameter_count(aggregate_scores: dict, aggregate_interval_estimates: dict, aggregated_count: dict, algorithms: list):
    """Plot the IQM normalized score of each algorithm against its mean parameter count (log x-axis).

    Error bars show the bootstrap confidence interval of the IQM (metric index 1).
    """
    seaborn.set(style="whitegrid", font_scale=1.5)
    markers = {
        "ARS": "o",
        "Open Loop": "h",
        "PPO": "X",
        "DDPG": "s",
        "SAC": "p",
    }

    plt.figure("Parameter Numbers", figsize=(9, 6))
    for label in algorithms:
        n_params = aggregated_count[label][0]
        iqm = aggregate_scores[label][1]
        # Interval estimates have shape (2, n_metrics): row 0 = lower bound, row 1 = upper bound
        lower, upper = aggregate_interval_estimates[label][:, 1]
        # errorbar expects distances from the point ([[below], [above]]), not the full CI width.
        # Clip at 0 in case the point estimate falls outside the percentile interval.
        yerr = np.clip([[iqm - lower], [upper - iqm]], 0.0, None)
        points = plt.scatter(n_params, iqm, label=label, s=250, marker=markers.get(label, "o"))
        # Draw the error bar only, in the same color as the marker
        plt.errorbar(n_params, iqm, yerr=yerr, fmt="none", color=points.get_facecolor()[0], capsize=10)

    plt.xlabel("Number of Parameters (log)", fontsize="large")
    plt.ylabel("Normalized Score", fontsize="large")
    plt.xscale("log")
    plt.legend()
    plt.tight_layout()
    plt.show()
    # Reset to matplotlib defaults for the other plots
    matplotlib.rc_file_defaults()


def _plot_rliable(
    normalized_score_dict: dict, algorithms: list, aggregated_count: dict, args: argparse.Namespace, repetitions: int
):
    """Draw the rliable plots, see https://github.com/google-research/rliable.

    The plots are: aggregate metrics (median, IQM), performance profiles, the
    parameter-count plot (with ``--count-parameters``) and the probability of
    improvement of Open Loop variants (with ``--versus``).
    """
    print("Computing bootstrap CI ...")

    def aggregate_func(scores):
        # scores: (num_runs x num_envs) normalized score matrix of one algorithm
        return np.array(
            [
                metrics.aggregate_median(scores),
                metrics.aggregate_iqm(scores),
            ]
        )

    aggregate_scores, aggregate_interval_estimates = rly.get_interval_estimates(
        normalized_score_dict,
        aggregate_func,
        reps=repetitions,  # Number of bootstrap replications
        confidence_interval_size=args.ci_size,
    )

    if args.count_parameters:
        _plot_parameter_count(aggregate_scores, aggregate_interval_estimates, aggregated_count, algorithms)

    fig, _ = plot_utils.plot_interval_estimates(
        aggregate_scores,
        aggregate_interval_estimates,
        metric_names=["Median", "IQM"],
        algorithms=algorithms,
        xlabel="",
        subfigure_width=5,
        row_height=0.5,
    )
    fig.canvas.manager.set_window_title("Rliable metrics")
    # The x label is added manually and the bottom margin enlarged so it stays visible
    fig.text(0.5, 0.02, "Normalized Score", ha="center", fontsize="xx-large")
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.3)

    # Performance profiles over normalized score thresholds tau
    normalized_score_thresholds = np.linspace(0.0, 1.5, 50)
    score_distributions, score_distributions_cis = rly.create_performance_profile(
        normalized_score_dict,
        normalized_score_thresholds,
        reps=repetitions,
        confidence_interval_size=args.ci_size,
    )
    fig, ax = plt.subplots(ncols=1, figsize=(8, 6))
    plot_utils.plot_performance_profiles(
        score_distributions,
        normalized_score_thresholds,
        performance_profile_cis=score_distributions_cis,
        colors=dict(zip(algorithms, seaborn.color_palette("colorblind"))),
        xlabel=r"Normalized Score $(\tau)$",
        ax=ax,
    )
    fig.canvas.manager.set_window_title("Performance profiles")
    plt.legend(fontsize=18)
    ax.set_xlabel(r"Normalized Score $(\tau)$", fontsize=18)
    ax.set_ylabel(r"Fraction of runs with score $> \tau$", fontsize=18)

    if args.versus:
        # Probability of improvement: only pairs whose first algorithm is an Open Loop variant
        algorithm_pairs = {
            f"{algo1}, {algo2}": (normalized_score_dict[algo1], normalized_score_dict[algo2])
            for algo1, algo2 in itertools.combinations(algorithms, 2)
            if "open" in algo1.lower()
        }
        average_probabilities, average_prob_cis = rly.get_interval_estimates(
            algorithm_pairs,
            metrics.probability_of_improvement,
            reps=1000,
            confidence_interval_size=args.ci_size,
        )
        plot_utils.plot_probability_of_improvement(average_probabilities, average_prob_cis)
        plt.gcf().canvas.manager.set_window_title("Probability of Improvement")
        plt.tight_layout()


def plot_from_file():
    """Command-line entry point: print the results table and draw the requested plots.

    Reads the pickle given by ``--input``, optionally merged with the ``--merge`` files,
    and prints the stored markdown table. It then normalizes the final evaluation
    scores per env. Depending on the flags, it reports parameter counts
    (``--count-parameters``) and draws rliable plots (``--rliable``, ``--versus``).
    """
    args = _parse_arguments()
    # Number of bootstrap repetitions for the confidence intervals (kept low for speed)
    repetitions = 200

    if args.latex:
        plt.rc("text", usetex=True)
        plt.rc("text.latex", preamble=r"\usepackage{amsmath}")  # for \text command

    results = _load_results(args.input, args.merge)

    # Algorithm keys are read from the first env, then filtered from the CLI
    keys = [key for key in results[next(iter(results))] if key not in args.skip_keys]
    print(f"keys: {keys}")
    if args.keep_keys:
        keys = [key for key in keys if key in args.keep_keys]
    envs = [env for env in results if env not in args.skip_envs]
    if args.keep_envs:
        envs = [env for env in envs if env in args.keep_envs]

    # Custom labels (if given) are assigned to the keys in order
    labels = {key: key for key in keys}
    if args.labels is not None:
        labels.update(zip(keys, args.labels))

    # all_eval_normalized_scores_dict: algo -> (n_runs, n_envs, n_eval); currently not plotted here
    normalized_score_dict, all_eval_normalized_scores_dict, param_num_per_env = _collect_normalized_scores(
        results, keys, envs, labels, args.count_parameters
    )

    aggregated_count = {}
    if args.count_parameters:
        aggregated_count = _aggregate_parameter_counts(param_num_per_env, labels)

    if args.rliable:
        # Algorithms whose envs were all skipped have no scores and cannot be plotted
        algorithms = [label for label in labels.values() if label in normalized_score_dict]
        _plot_rliable(normalized_score_dict, algorithms, aggregated_count, args, repetitions)

    plt.show()


if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script (not on import)
    plot_from_file()
