from typing import Dict

import pandas as pd
from pandas.io.formats.style import Styler


def create_full_table(df, index_rename):
    """Pivot the ``eval_shield`` index level of *df* into columns.

    Args:
        df: Frame whose index contains levels named ``shield`` and
            ``eval_shield`` (plus any further levels).
        index_rename: Mapping from an index level name to the name it should
            carry in the result. Level names that are not keys of the mapping
            (other than ``shield``) are left unchanged.

    Returns:
        The unstacked frame. Its ``shield`` index level is renamed to
        ``Shield``, the levels listed in *index_rename* are renamed, and its
        columns are the bare ``eval_shield`` labels.
    """
    df_full = df.unstack(level="eval_shield")

    # Build the complete list of level names instead of handing a dict to
    # ``Index.rename``: the dict form is only accepted by a MultiIndex, so it
    # raised as soon as ``shield`` was the only index level left after the
    # unstack above.
    level_names = {"shield": "Shield", **index_rename}
    df_full.index = df_full.index.set_names(
        [level_names.get(level, level) for level in df_full.index.names]
    )

    # Collapse the column index down to the eval_shield labels only.
    df_full.columns = df_full.columns.get_level_values("eval_shield").values
    return df_full


def create_diag_table(df, index_rename, from_eval):
    """Reshape aggregated results into a table with one column per shield.

    Args:
        df: Frame whose index has a level named ``shield`` and, when
            *from_eval* is true, one named ``eval_shield``. Presumably the
            output of ``create_table_from_run_set`` -- confirm against callers.
        index_rename: Mapping from index level name to display name. Each key
            is also deleted from the temporary columns built from the index,
            so every key is expected to name an index level of *df*.
        from_eval: When true, keep only the rows whose ``shield`` and
            ``eval_shield`` values are equal and drop the last index level.

    Returns:
        The unstacked frame, with its columns replaced by level 1 of the
        unstacked column index.
    """
    # Expose every index level as an ordinary column next to the data so the
    # levels can be compared and removed by name below.
    df_diag = pd.concat((df.index.to_frame(), df), axis=1)

    if from_eval:
        # Keep only rows where both shield labels match.
        df_diag = df_diag[df_diag["shield"] == df_diag["eval_shield"]]
        del df_diag["eval_shield"]

    # Drop the temporary index-derived columns again; only the data columns of
    # the original frame remain.
    del df_diag["shield"]
    for k in index_rename.keys():
        del df_diag[k]

    # NOTE(review): pandas documents the dict form of ``rename`` for MultiIndex
    # only, so this presumably requires at least two index levels -- confirm.
    df_diag.index = df_diag.index.rename({"shield": "Shield", **index_rename})
    if from_eval:
        # Drops the last index level -- presumably ``eval_shield``, which is
        # redundant after the filter above; verify the level order with the caller.
        df_diag = df_diag.reset_index(level=-1, drop=True)

    # ``unstack()`` pivots the innermost remaining index level into columns;
    # the resulting column labels are then passed through ``name_map``.
    df_diag = df_diag.unstack().rename(columns=name_map)
    # Keep only level 1 of the column index -- presumably the shield labels.
    # NOTE(review): this relies on the column layout produced by the concat
    # above; confirm if the input's column structure changes.
    df_diag.columns = df_diag.columns.get_level_values(1)

    return df_diag


def _format_metric_cell(row):
    """Render one aggregated row as ``"<rew mean> $\\pm$ <rew sem> (<violations>)"``."""
    reward = row["rew"]
    violations = row["safety_violations"]["mean"]
    # Strictly positive violation means are shown scaled by 100 with one
    # decimal; everything else (zero, negative, NaN) is printed as "(0)".
    violations_text = f"({violations * 100:.1f})" if violations > 0 else "(0)"
    return f"{reward['mean']:.1f} $\\pm$ {reward['sem']:.1f} {violations_text}"


def create_table_from_run_set(run_set, extra_columns: Dict[str, str], separate_eval, extra_agg, pregroup_hook=None,
                              primary_metric_name="test/avg_rew_sum_0",
                              secondary_metric_name="test/avg_unsafe_actions"):
    """Aggregate a set of runs into one formatted result string per group.

    Args:
        run_set: Iterable of run objects exposing a ``config`` mapping and a
            ``summary._json_dict`` mapping (presumably W&B runs -- confirm).
        extra_columns: Additional ``{column name: config/summary key}`` pairs
            to collect for every run. Entries override the built-in columns
            on a name clash.
        separate_eval: When true, also collect ``evaluation_shield`` as the
            ``eval_shield`` column, group by it, and skip every run whose
            config has no ``skip_training`` key.
        extra_agg: Column names to group by in addition to ``shield`` (and
            ``eval_shield`` when *separate_eval* is true).
        pregroup_hook: Optional callable that receives the collected frame and
            returns the frame to aggregate.
        primary_metric_name: Config/summary key read into the ``rew`` column.
        secondary_metric_name: Config/summary key read into the
            ``safety_violations`` column.

    Returns:
        A frame indexed by the grouping keys whose only remaining column,
        ``formatted_output``, holds the reward mean, a LaTeX plus-minus sign,
        the reward standard error, and the bracketed violation mean.

    Raises:
        KeyError: If a collected run lacks one of the requested keys.
    """
    columns = {
        "shield": "shield",
        "safety_violations": secondary_metric_name,
        "rew": primary_metric_name,
        **extra_columns,
    }
    agg = [*extra_agg, "shield"]
    if separate_eval:
        columns["eval_shield"] = "evaluation_shield"
        agg.append("eval_shield")

    data = {dest_name: [] for dest_name in columns}
    for r in run_set:
        config = r.config
        if separate_eval and "skip_training" not in config:
            continue

        # Summary values win over config values on a key clash.
        config_and_summary = {**config, **r.summary._json_dict}
        for dest_name, src_name in columns.items():
            data[dest_name].append(config_and_summary[src_name])

    df = pd.DataFrame(data)

    # NOTE: ``Series.map`` turns shield names that are missing from
    # ``name_map`` into NaN, and groupby drops NaN keys by default, so runs
    # with an unknown shield name do not appear in the result.
    df["shield"] = df["shield"].map(name_map)
    if separate_eval:
        df["eval_shield"] = df["eval_shield"].map(name_map)

    if pregroup_hook is not None:
        df = pregroup_hook(df)

    # No explicit sort is needed: groupby orders the groups by key. (The
    # previous ``df.sort_values(agg)`` call discarded its result anyway.)
    df = df.groupby(agg).agg({'safety_violations': ["mean"], 'rew': ["mean", "sem"]})
    df["formatted_output"] = df.apply(_format_metric_cell, axis=1)
    del df["rew"]
    del df["safety_violations"]

    return df


def format_to_latex(table, name, **override):
    """Write *table* as LaTeX to the file ``name + ".tex"``.

    Args:
        table: Data handed to ``Styler`` for rendering.
        name: Output path without the ``.tex`` suffix.
        **override: Keyword arguments for ``Styler.to_latex``; they replace
            the defaults below on a name clash.
    """
    defaults = dict(
        multirow_align="c",
        column_format="lllll",
        hrules=True,
        clines="skip-last;data",
    )
    options = {**defaults, **override}

    target = name + ".tex"
    styler = Styler(table)
    styler.to_latex(target, **options)


# Display names for the raw shield identifiers found in the run data.
# Used to relabel the ``shield`` / ``eval_shield`` values and table columns.
name_map: Dict[str, str] = {
    "centralized": "Centralized",
    "slugs_centralized": "Centralized Full Obs",
    "decentralized": "Decentralized",
    "slugs_decentralized": "Decentralized Full Obs",
    "pobs_label": "Decentralized Partial Obs",
    "pobs_label_cent": "Centralized Partial Obs",
    "none": "No Shield",
}
