import pandas as pd
import wandb

from experiments.analysis.utils.format import format_to_latex
from experiments.analysis.utils.gridworld import run_table_experiments_2, \
    run_table_gridworld_1a
from experiments.analysis.utils.particle import run_table_flashlight, run_table_particle_smv


def run():
    """Regenerate every LaTeX results table from the W&B runs.

    One ``wandb.Api`` client is created and handed to each table generator,
    which are called in the order 1a, 1b, 2a, 2b, 3a, 3b.
    """
    client = wandb.Api()

    table_generators = (
        generate_1a,
        generate_1b,
        generate_2a,
        generate_2b,
        generate_3a,
        generate_3b,
    )
    for generate in table_generators:
        generate(client)


def generate_1a(api):
    """Write the gridworld SMV-replication tables (experiment 1a).

    Runs from the ``Centralized-Verification-SMV-Replicate`` project are
    queried twice: with ``config.randomize_starts`` equal to ``"False"``
    and then to ``"True"``. Each query yields a full and a diagonal table
    from ``run_table_gridworld_1a``, and both are written. Finally the two
    diagonal tables are stacked under a ``Start Type`` index level
    (``"Fixed"`` / ``"Random"``) and written as one combined table.

    Args:
        api: The W&B API client whose ``runs`` method is queried.
    """
    project = "[username]/Centralized-Verification-SMV-Replicate"

    # (filter value, full-table output name, diagonal-table output name)
    splits = (
        ("False", "smv_replicate_no_random_start", "smv_replicate_no_random_start_diag"),
        ("True", "smv_replicate_random_start", "smv_replicate_random_start_diag"),
    )

    diagonal_tables = []
    for randomize_flag, full_name, diag_name in splits:
        selected_runs = api.runs(project, filters={"config.randomize_starts": randomize_flag})
        full_table, diag_table = run_table_gridworld_1a(selected_runs)
        format_to_latex(full_table, full_name)
        format_to_latex(diag_table, diag_name)
        diagonal_tables.append(diag_table)

    # Order of ``diagonal_tables`` matches the keys: fixed first, then random.
    combined = pd.concat(diagonal_tables, keys=("Fixed", "Random"), names=["Start Type"])
    format_to_latex(combined, "smv_replicate_diag")


def generate_1b(api):
    """Write the particle-momentum replication table (experiment 1b).

    Every run in ``Centralized-Verification-Particle-Replicate`` (no filter)
    is summarised by ``run_table_particle_smv`` and written out.

    Args:
        api: The W&B API client whose ``runs`` method is queried.
    """
    particle_runs = api.runs("[username]/Centralized-Verification-Particle-Replicate")
    format_to_latex(run_table_particle_smv(particle_runs), "particle_momentum_replicate")


def generate_2a(api):
    """Write the nearby-observation recurrent tables, variant 2 (experiment 2a).

    Runs from ``Centralized-Verification-Nearby-Obs-Recurrent-2`` are queried
    with ``config.randomize_starts`` set to ``"False"`` and then ``"True"``.
    Each query is reduced by ``run_table_experiments_2`` and written. The two
    results are then stacked under a ``Start Type`` index level
    (``"Fixed"`` / ``"Random"``) and written as a combined table.

    Args:
        api: The W&B API client whose ``runs`` method is queried.
    """
    project = "[username]/Centralized-Verification-Nearby-Obs-Recurrent-2"

    # (filter value, per-split output name), fixed starts first.
    splits = (
        ("False", "nearby_obs_recurrent_2_no_random_start_diag"),
        ("True", "nearby_obs_recurrent_2_random_start_diag"),
    )

    per_split = []
    for randomize_flag, output_name in splits:
        table = run_table_experiments_2(
            api.runs(project, filters={"config.randomize_starts": randomize_flag}))
        format_to_latex(table, output_name)
        per_split.append(table)

    combined = pd.concat(per_split, keys=("Fixed", "Random"), names=["Start Type"])
    format_to_latex(combined, "nearby_obs_2_recurrent_diag")


def generate_2b(api):
    """Write the nearby-observation recurrent tables, variant 1 (experiment 2b).

    Runs from ``Centralized-Verification-Nearby-Obs-Recurrent`` are queried
    with ``config.randomize_starts`` set to ``"False"`` and then ``"True"``.
    Each query is reduced by ``run_table_experiments_2`` and written. The two
    results are then stacked under a ``Start Type`` index level
    (``"Fixed"`` / ``"Random"``) and written as a combined table.

    Args:
        api: The W&B API client whose ``runs`` method is queried.
    """
    project = "[username]/Centralized-Verification-Nearby-Obs-Recurrent"

    # (filter value, per-split output name), fixed starts first.
    splits = (
        ("False", "nearby_obs_recurrent_1_no_random_start_diag"),
        ("True", "nearby_obs_recurrent_1_random_start_diag"),
    )

    per_split = []
    for randomize_flag, output_name in splits:
        table = run_table_experiments_2(
            api.runs(project, filters={"config.randomize_starts": randomize_flag}))
        format_to_latex(table, output_name)
        per_split.append(table)

    combined = pd.concat(per_split, keys=("Fixed", "Random"), names=["Start Type"])
    format_to_latex(combined, "nearby_obs_1_recurrent_diag")


def generate_3a(api):
    """Write the small-flashlight table (experiment 3a).

    Queries ``Centralized-Verification-Flashlight-Small`` with
    ``config.skip_training`` set to ``None`` and ``State`` set to
    ``"finished"``. The result is summarised with ``run_table_flashlight``
    and written out.

    Args:
        api: The W&B API client whose ``runs`` method is queried.
    """
    # NOTE(review): W&B filter examples use a lowercase ``state`` key for run
    # state; confirm that ``"State"`` actually restricts to finished runs.
    # Also, generate_3b applies no state filter at all; confirm whether
    # that difference is intended.
    run_filters = {"config.skip_training": None, "State": "finished"}
    flashlight_runs = api.runs("[username]/Centralized-Verification-Flashlight-Small",
                               filters=run_filters)
    format_to_latex(run_table_flashlight(flashlight_runs), "flashlight_small")


def generate_3b(api):
    """Write the large-flashlight table (experiment 3b).

    Queries ``Centralized-Verification-Flashlight`` with
    ``config.skip_training`` set to ``None``. The result is summarised with
    ``run_table_flashlight`` and written out.

    Args:
        api: The W&B API client whose ``runs`` method is queried.
    """
    run_filters = {"config.skip_training": None}
    flashlight_runs = api.runs("[username]/Centralized-Verification-Flashlight",
                               filters=run_filters)
    format_to_latex(run_table_flashlight(flashlight_runs), "flashlight_large")


# Script entry point: regenerate all tables only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    run()
