import csv
import itertools

def main(output_dir="../../parallel_configs"):
    """Write the training and evaluation sweep configs for this experiment.

    One training row is written for every combination of (map, randomize
    starts, shield, epsilon-annealing schedule, unsafe-action punishment) and
    every seed in ``range(first_seed, num_runs)``. For combinations trained
    with ``shield == "none"``, one extra evaluation row per entry in
    ``shields`` is written. Each evaluation row has ``skip_training=True`` and
    uses that entry as ``evaluation_shield`` / ``evaluation_shield_specification``.

    Args:
        output_dir: Directory that receives ``<run_name>.csv`` (training
            configs) and ``<run_name>Eval.csv`` (evaluation configs). The
            default is the original hard-coded location. It is a relative
            path, so it is resolved against the current working directory.
    """
    run_name = "2023_04_25_NearbyObsRecurrent2MoreSeeds"

    map_names = ["Pentagon", "ISR", "MIT", "SUNY"]
    randomize_starts = [True, False]
    # (shield type, shield specification path); ("none", None) means no shield.
    shields = [
        ("pobs_label_cent", "shields/smv/gridworld_shield_nearby_obs"),
        ("pobs_label", "shields/smv/gridworld_shield_nearby_obs_sat"),
        ("pobs_label", "shields/smv/gridworld_shield_nearby_obs_naive"),
        ("none", None),
    ]

    learner_anneal_eps = [(1.0, 0.0)]  # (start, finish)
    punish_unsafe_orig_actions = [(True, -10)]  # (enabled, reward modifier)
    num_runs = 50
    # Seeds below this value are not emitted. The "MoreSeeds" run name
    # suggests they came from an earlier sweep (confirm). Indices below still
    # reserve num_runs slots per run type, so run names do not depend on it.
    first_seed = 10

    base_params = ["run_name", "shield", "shield_specification", "punish_unsafe_orig_action",
                   "punish_unsafe_orig_action_modifier", "randomize_starts", "map_type",
                   "grid_world_map_name", "grid_world_obs_type", "grid_world_nearby_obs_radius",
                   "learner_type", "learner_anneal_eps_start",
                   "learner_anneal_eps_finish", "max_total_steps",
                   "seed", "learner_evaluation_epsilon",
                   "learner_deep_network_model", "learner_transform_one_hot", "learner_clip_gradients",
                   "learner_discount"]
    eval_params = base_params + ["skip_training", "evaluation_run_name", "evaluation_shield",
                                 "evaluation_shield_specification"]

    # newline="" is required for files handed to the csv module. Otherwise
    # text-mode newline translation turns the writer's "\r\n" row terminators
    # into "\r\r\n" on Windows, which shows up as blank rows.
    with open(f"{output_dir}/{run_name}.csv", "w", newline="") as train_file, open(
            f"{output_dir}/{run_name}Eval.csv", "w", newline="") as eval_file:
        train_writer = csv.DictWriter(train_file, base_params)
        train_writer.writeheader()

        eval_writer = csv.DictWriter(eval_file, eval_params)
        eval_writer.writeheader()

        run_types = itertools.product(map_names, randomize_starts, shields, learner_anneal_eps,
                                      punish_unsafe_orig_actions)
        for run_type_idx, (map_name,
                           random_start, (shield, shield_specification),
                           (eps_anneal_start, eps_anneal_finish),
                           (punish_unsafe_action, unsafe_action_rew_modifier)) in enumerate(run_types):

            for run_num_of_same_type in range(first_seed, num_runs):
                global_run_idx = run_type_idx * num_runs + run_num_of_same_type
                concat_run_name = f"{run_name}/{global_run_idx}_{run_type_idx}_{run_num_of_same_type}"

                base_param_values = {
                    "run_name": concat_run_name,
                    "shield": shield,
                    "shield_specification": shield_specification,
                    "punish_unsafe_orig_action": punish_unsafe_action,
                    "punish_unsafe_orig_action_modifier": unsafe_action_rew_modifier,
                    "randomize_starts": random_start,
                    "map_type": "GridWorld",
                    "grid_world_map_name": map_name,
                    "grid_world_obs_type": "NearbyObsSimpleDiscrete",
                    "grid_world_nearby_obs_radius": 2,
                    "learner_type": "Individual_Recurrent_Q",
                    "learner_deep_network_model": "simple_gru",
                    "learner_transform_one_hot": True,
                    "learner_clip_gradients": True,
                    "learner_discount": 0.98,
                    "learner_anneal_eps_start": eps_anneal_start,
                    "learner_anneal_eps_finish": eps_anneal_finish,
                    "max_total_steps": int(2.5e6),
                    "seed": run_num_of_same_type,
                    # Evaluation runs at the final (annealed-to) epsilon.
                    "learner_evaluation_epsilon": eps_anneal_finish,
                }

                train_writer.writerow(base_param_values)

                # Only unshielded training runs get evaluation configs: one
                # per shield entry (including "none"), with training skipped.
                if shield == "none":
                    for eval_idx, (eval_shield, eval_shield_specification) in enumerate(shields):
                        eval_writer.writerow({
                            **base_param_values,
                            "skip_training": True,
                            "evaluation_run_name": f"{concat_run_name}_eval_{eval_idx}",
                            "evaluation_shield": eval_shield,
                            "evaluation_shield_specification": eval_shield_specification,
                        })


if __name__ == '__main__':
    main()
