from experiments.analysis.utils.format import create_table_from_run_set, create_full_table, create_diag_table


def run_table_particle(run_set):
    """Build the full and diagonal summary tables for the particle runs.

    Runs are grouped by ``observe_momentum`` and ``random_start``. The
    string-valued index levels are then relabelled for display:
    "False"/"True" -> "Partial"/"Full" (level 0) and "Fixed"/"Random" (level 1).

    :param run_set: the runs, passed straight through to ``create_table_from_run_set``.
    :return: tuple ``(df_full, df_diag)`` as produced by ``create_full_table`` and
        ``create_diag_table``.
    """
    df = create_table_from_run_set(run_set, {
        "observe_momentum": "particle_agents_observe_momentum",
        "random_start": "randomize_starts"
    }, True, ["observe_momentum", "random_start"])

    # MultiIndex.set_levels lost its ``inplace`` argument (and positional ``level``)
    # in pandas 2.0, so build the new index and assign it back.
    # Level values not present in a mapping would become NaN.
    # NOTE(review): set_levels re-applies the index's existing level names, so the
    # ``.rename(...)`` calls likely leave the index names unchanged (the helpers below
    # are still keyed by the original field names) — confirm before relying on them.
    index = df.index
    index = index.set_levels(
        index.levels[0].map({"False": "Partial", "True": "Full"}).rename("Observability"), level=0)
    index = index.set_levels(
        index.levels[1].map({"False": "Fixed", "True": "Random"}).rename("Start Type"), level=1)
    index = index.set_levels(index.levels[2].rename("Shield"), level=2)
    df.index = index

    df_full = create_full_table(df, {"observe_momentum": "Observe Momentum", "random_start": "Random Start"})
    df_diag = create_diag_table(df, {"observe_momentum": "Observe Momentum", "random_start": "Random Start"}, True)

    return df_full, df_diag


def change_shield_name(shield_field_name, spec_field_name, df):
    """Disambiguate "Decentralized Partial Obs" shields by their specification.

    Every row whose ``shield_field_name`` value equals ``"Decentralized Partial Obs"``
    gets a suffix based on the ``spec_field_name`` value: ``" (Naive)"`` when the
    specification ends with ``"naive"``, ``" (SAT)"`` when it ends with ``"sat"``.
    Other rows are copied unchanged.

    :param shield_field_name: column holding the shield display name.
    :param spec_field_name: column holding the shield specification string.
    :param df: input DataFrame; it is not modified.
    :return: a modified copy of ``df``.
    :raises ValueError: if a decentralized row's specification matches no known suffix.
    """
    suffixes = (("naive", "Naive"), ("sat", "SAT"))
    new_df = df.copy()
    # Positional access keeps this correct even when the index has duplicate labels.
    shield_col = new_df.columns.get_loc(shield_field_name)
    for pos, (shield, spec) in enumerate(zip(df[shield_field_name], df[spec_field_name])):
        if shield != "Decentralized Partial Obs":
            continue
        for ending, label in suffixes:
            if spec.endswith(ending):
                new_df.iat[pos, shield_col] = "Decentralized Partial Obs ({})".format(label)
                break
        else:
            # Report the column actually being inspected, not a hard-coded one.
            raise ValueError("Unknown shield specification: {}".format(spec))
    return new_df


def change_shield_name_flashlight(shield_field_name, spec_field_name, df):
    """Disambiguate "Decentralized Partial Obs" shields by their specification's last digit.

    Every row whose ``shield_field_name`` value equals ``"Decentralized Partial Obs"``
    gets a suffix ``" (0)"``, ``" (1)"`` or ``" (2)"`` matching the final character of
    the ``spec_field_name`` value. Other rows are copied unchanged.

    :param shield_field_name: column holding the shield display name.
    :param spec_field_name: column holding the shield specification string.
    :param df: input DataFrame; it is not modified.
    :return: a modified copy of ``df``.
    :raises ValueError: if a decentralized row's specification ends in none of "0", "1", "2".
    """
    suffixes = ("0", "1", "2")
    new_df = df.copy()
    # Positional access keeps this correct even when the index has duplicate labels.
    shield_col = new_df.columns.get_loc(shield_field_name)
    for pos, (shield, spec) in enumerate(zip(df[shield_field_name], df[spec_field_name])):
        if shield != "Decentralized Partial Obs":
            continue
        for ending in suffixes:
            if spec.endswith(ending):
                new_df.iat[pos, shield_col] = "Decentralized Partial Obs ({})".format(ending)
                break
        else:
            # Report the column actually being inspected, not a hard-coded one.
            raise ValueError("Unknown shield specification: {}".format(spec))
    return new_df

def run_table_particle_smv(run_set):
    """Build the diagonal summary table for the particle runs, split by shield specification.

    Before grouping, ``change_shield_name`` relabels decentralized shields using the
    ``shield_specification`` column. The string-valued index levels are then relabelled
    for display: "False"/"True" -> "Partial"/"Full" (level 0) and "Fixed"/"Random" (level 1).

    :param run_set: the runs, passed straight through to ``create_table_from_run_set``.
    :return: the table produced by ``create_diag_table``.
    """
    df = create_table_from_run_set(run_set, {
        "observe_momentum": "particle_agents_observe_momentum",
        "random_start": "randomize_starts",
        "shield_specification": "shield_specification"
    }, False, ["observe_momentum", "random_start"],
                                   pregroup_hook=lambda d: change_shield_name("shield", "shield_specification", d))

    # MultiIndex.set_levels lost its ``inplace`` argument (and positional ``level``)
    # in pandas 2.0, so build the new index and assign it back.
    # NOTE(review): set_levels re-applies the index's existing level names, so the
    # ``.rename(...)`` calls likely leave the index names unchanged — confirm.
    index = df.index
    index = index.set_levels(
        index.levels[0].map({"False": "Partial", "True": "Full"}).rename("Observability"), level=0)
    index = index.set_levels(
        index.levels[1].map({"False": "Fixed", "True": "Random"}).rename("Start Type"), level=1)
    index = index.set_levels(index.levels[2].rename("Shield"), level=2)
    df.index = index

    df_diag = create_diag_table(df, {"observe_momentum": "Observe Momentum", "random_start": "Random Start"}, False)

    return df_diag


def run_table_flashlight(run_set):
    """Build the diagonal summary table for the flashlight grid-world runs.

    Runs are grouped by ``random_start`` and ``flashlight_recharge_time``, ranked by
    ``test/avg_discounted_rew_sum_0``. Before grouping, ``change_shield_name_flashlight``
    relabels decentralized shields. Index level 0 is relabelled "False"/"True" ->
    "Fixed"/"Random".

    :param run_set: the runs, passed straight through to ``create_table_from_run_set``.
    :return: the table produced by ``create_diag_table``.
    """
    df = create_table_from_run_set(run_set, {
        "flashlight_recharge_time": "grid_world_flashlight_obs_recharge_time",
        "random_start": "randomize_starts",
        "shield_specification": "shield_specification"
    }, False, ["random_start", "flashlight_recharge_time"],
                                   pregroup_hook=lambda d: change_shield_name_flashlight("shield",
                                                                                         "shield_specification", d),
                                   primary_metric_name="test/avg_discounted_rew_sum_0")

    # MultiIndex.set_levels lost its ``inplace`` argument (and positional ``level``)
    # in pandas 2.0, so build the new index and assign it back.
    # NOTE(review): set_levels re-applies the index's existing level names, so the
    # ``.rename(...)`` calls likely leave the index names unchanged — confirm.
    index = df.index
    index = index.set_levels(
        index.levels[0].map({"False": "Fixed", "True": "Random"}).rename("Start Type"), level=0)
    index = index.set_levels(index.levels[1].rename("Recharge Time"), level=1)
    index = index.set_levels(index.levels[2].rename("Shield"), level=2)
    df.index = index

    df_diag = create_diag_table(df, {"flashlight_recharge_time": "Flashlight Recharge Time",
                                     "random_start": "Random Start"}, False)
    return df_diag


def run_table_no_shield_all_evals(run_set):
    """Build the diagonal summary table keyed by evaluation shield for the particle runs.

    Before grouping, ``change_shield_name`` relabels decentralized evaluation shields
    using the ``eval_shield_specification`` column. After relabelling index levels 0/1
    ("Partial"/"Full", "Fixed"/"Random"), index level 2 is dropped and the
    ``eval_shield`` level is renamed to ``shield``.

    :param run_set: the runs, passed straight through to ``create_table_from_run_set``.
    :return: the table produced by ``create_diag_table``.
    """
    df = create_table_from_run_set(run_set, {
        "observe_momentum": "particle_agents_observe_momentum",
        "random_start": "randomize_starts",
        "shield_specification": "shield_specification",
        "eval_shield_specification": "evaluation_shield_specification"
    }, True, ["observe_momentum", "random_start"], pregroup_hook=lambda d: change_shield_name("eval_shield",
                                                                                              "eval_shield_specification",
                                                                                              d))

    # MultiIndex.set_levels lost its ``inplace`` argument (and positional ``level``)
    # in pandas 2.0, so build the new index and assign it back.
    # NOTE(review): set_levels re-applies the index's existing level names, so the
    # ``.rename(...)`` calls likely leave the index names unchanged (the rename of
    # "eval_shield" below still relies on original names) — confirm.
    index = df.index
    index = index.set_levels(
        index.levels[0].map({"False": "Partial", "True": "Full"}).rename("Observability"), level=0)
    index = index.set_levels(
        index.levels[1].map({"False": "Fixed", "True": "Random"}).rename("Start Type"), level=1)
    df.index = index

    # presumably level 2 is the training-time shield, which this table collapses — verify.
    df = df.reset_index(level=2, drop=True)
    df.index = df.index.rename({"eval_shield": "shield"})

    df_diag = create_diag_table(df, {"observe_momentum": "Observe Momentum", "random_start": "Random Start"}, False)
    return df_diag
