from centralized_verification.MultiAgentAPEnv import MultiAgentSafetyEnv
from centralized_verification.agents.decentralized_training.independent_agents.deep_q_learner import DeepQLearner
from centralized_verification.agents.decentralized_training.independent_agents.deep_recurrent_q_learner import \
    DeepRecurrentQLearner
from centralized_verification.agents.decentralized_training.independent_agents.ppo_learner import PPOAgent
from centralized_verification.agents.decentralized_training.independent_agents.sac_learner import SACAgent
from centralized_verification.agents.decentralized_training.independent_agents.tabular_q_learner import TabularQLearner
from centralized_verification.agents.decentralized_training.multi_agent_wrapper import MultiAgentLearnerWrapper
from centralized_verification.agents.utils import linear_epsilon_anneal_episodes, linear_epsilon_anneal_steps
from centralized_verification.models.larger_mlp import LargerMLP
from centralized_verification.models.norm_simple_mlp import NormSimpleMLP
from centralized_verification.models.simple_cnn import SimpleCNN
from centralized_verification.models.simple_gru import SimpleGRU
from centralized_verification.models.simple_mlp import SimpleMLP
from centralized_verification.models.tabular_gradient import TabularGradient
from experiments.utils.configuration.utils import max_steps_and_episodes, is_true, update_all_func

# Registry of non-recurrent network models. Keys are the names accepted for
# "learner_deep_network_model" and the "learner_ppo_*_model_class" settings.
NN_MODEL_CLASSES = dict(
    simple_mlp=SimpleMLP,
    norm_simple_mlp=NormSimpleMLP,
    simple_cnn=SimpleCNN,
    larger_mlp=LargerMLP,
    tabular_gradient=TabularGradient,
)

# Registry of recurrent network models, looked up by
# "learner_deep_network_model" when building a recurrent Q learner.
RECURRENT_NN_MODEL_CLASSES = dict(
    simple_gru=SimpleGRU,
)


def get_epsilon_scheduler_learner_params(params):
    """Build the epsilon-greedy exploration keyword arguments for a learner.

    If the run is bounded by a total step count, the scheduler comes from
    ``linear_epsilon_anneal_steps``. Its horizon is taken from
    ``learner_anneal_eps_finish_steps`` when given, else the total step
    budget. Otherwise the scheduler comes from
    ``linear_epsilon_anneal_episodes`` over the episode budget.

    Args:
        params: configuration mapping; must contain ``learner_anneal_eps_start``
            and ``learner_anneal_eps_finish``.

    Returns:
        dict with ``epsilon_scheduler`` and ``evaluation_epsilon``
        (``learner_evaluation_epsilon``, default 0).
    """
    max_total_steps, max_num_episodes = max_steps_and_episodes(params)
    eps_start = float(params["learner_anneal_eps_start"])
    eps_finish = float(params["learner_anneal_eps_finish"])

    if max_total_steps is not None:
        anneal_horizon = int(params.get("learner_anneal_eps_finish_steps", max_total_steps))
        scheduler = linear_epsilon_anneal_steps(eps_start, eps_finish, anneal_horizon)
    else:
        scheduler = linear_epsilon_anneal_episodes(eps_start, eps_finish, max_num_episodes)

    return dict(
        epsilon_scheduler=scheduler,
        evaluation_epsilon=float(params.get("learner_evaluation_epsilon", 0)),
    )


def one_hot_from_params(params):
    """Return the ``make_multidiscrete_one_hot`` flag, read from ``learner_transform_one_hot``."""
    raw_flag = params.get("learner_transform_one_hot")
    return dict(make_multidiscrete_one_hot=is_true(raw_flag))


def buffer_size_from_params(params):
    """Return the replay-buffer size keyword argument for a learner.

    Reads ``learner_replay_buffer_size`` (default ``1e5``). Ints, floats and
    integer strings are converted with ``int()`` as before. Strings written in
    scientific notation such as ``"1e5"`` are now accepted too; plain
    ``int("1e5")`` rejects them. Such strings are common in configs, e.g.
    YAML 1.1 loaders like PyYAML load an unquoted ``1e5`` as a string.

    Args:
        params: configuration mapping.

    Returns:
        dict with a single ``buffer_size`` int entry.

    Raises:
        ValueError: if the value is a string that is not a number, or a
            string whose numeric value is not a whole number (e.g. ``"1.5"``).
    """
    raw = params.get("learner_replay_buffer_size", 1e5)
    try:
        size = int(raw)
    except ValueError:
        # int() rejects strings like "1e5"; parse them as float instead, but
        # refuse to silently truncate a fractional buffer size.
        as_float = float(raw)
        if not as_float.is_integer():
            raise ValueError(
                f"learner_replay_buffer_size must be a whole number, got {raw!r}")
        size = int(as_float)
    return {"buffer_size": size}


def individual_deep_q_specific_params(params):
    """Return the model class and gradient-clipping flag for a feed-forward deep Q learner.

    ``learner_deep_network_model`` must name an entry of ``NN_MODEL_CLASSES``;
    an unknown name raises ``KeyError``.
    """
    model_class = NN_MODEL_CLASSES[params["learner_deep_network_model"]]
    clip = is_true(params.get("learner_clip_gradients"))
    return dict(model_class=model_class, clip_gradients=clip)


def recurrent_deep_q_specific_params(params):
    """Return the model class and gradient-clipping flag for a recurrent deep Q learner.

    ``learner_deep_network_model`` must name an entry of
    ``RECURRENT_NN_MODEL_CLASSES``; an unknown name raises ``KeyError``.
    """
    model_class = RECURRENT_NN_MODEL_CLASSES[params["learner_deep_network_model"]]
    clip = is_true(params.get("learner_clip_gradients"))
    return dict(model_class=model_class, clip_gradients=clip)


def individual_ppo_specific_params(params):
    """Return the actor (policy) and critic model classes for a PPO learner.

    Both ``learner_ppo_actor_model_class`` and ``learner_ppo_critic_model_class``
    must name entries of ``NN_MODEL_CLASSES``; an unknown name raises ``KeyError``.
    """
    actor_cls = NN_MODEL_CLASSES[params["learner_ppo_actor_model_class"]]
    critic_cls = NN_MODEL_CLASSES[params["learner_ppo_critic_model_class"]]
    return dict(policy_model_class=actor_cls, critic_model_class=critic_cls)


# Registry of independent learner types, keyed by params["learner_type"].
# Each value is (learner_class, param_processors). Each processor is called
# with the config params and returns a dict of extra constructor keyword
# arguments; construct_agent combines them via update_all_func.
INDIVIDUAL_LEARNER_TYPES = dict(
    Individual_Q=(
        TabularQLearner,
        [get_epsilon_scheduler_learner_params],
    ),
    Individual_Deep_Q=(
        DeepQLearner,
        [get_epsilon_scheduler_learner_params, one_hot_from_params,
         buffer_size_from_params, individual_deep_q_specific_params],
    ),
    Individual_SAC=(
        SACAgent,
        [one_hot_from_params, buffer_size_from_params],
    ),
    Individual_PPO=(
        PPOAgent,
        [one_hot_from_params, buffer_size_from_params, individual_ppo_specific_params],
    ),
    Individual_Recurrent_Q=(
        DeepRecurrentQLearner,
        [get_epsilon_scheduler_learner_params, one_hot_from_params,
         buffer_size_from_params, recurrent_deep_q_specific_params],
    ),
)


def construct_agent(params, env: MultiAgentSafetyEnv):
    """Build one independent learner per agent and wrap them together.

    ``params["learner_type"]`` selects a learner class and its parameter
    processors from ``INDIVIDUAL_LEARNER_TYPES``. The learner keyword arguments
    start with ``discount`` (``learner_discount``, default 0.9). The dict
    produced by ``update_all_func`` from the processors is then applied on top,
    so a processor key overrides the default. The resulting keyword arguments
    are passed to every learner.

    Args:
        params: flat configuration mapping.
        env: environment providing per-agent observation and action spaces.

    Returns:
        A ``MultiAgentLearnerWrapper`` over the learners, one per
        (observation space, action space) pair. They are ordered as returned
        by ``env.agent_obs_spaces()`` / ``env.agent_actions_spaces()``; ``zip``
        stops at the shorter of the two.

    Raises:
        KeyError: if ``learner_type`` is missing or is not a registered learner type.
    """
    learner_params = {
        "discount": float(params.get("learner_discount", 0.9))
    }

    learner_type = params["learner_type"]
    try:
        agent_cls, param_processors = INDIVIDUAL_LEARNER_TYPES[learner_type]
    except KeyError:
        # Keep the KeyError type callers may rely on, but name the valid choices.
        raise KeyError(f"Unknown learner_type {learner_type!r}; expected one of "
                       f"{sorted(INDIVIDUAL_LEARNER_TYPES)}") from None
    learner_params.update(update_all_func(param_processors, [params]))

    agents = [agent_cls(obs_space, action_space, **learner_params)
              for obs_space, action_space in zip(env.agent_obs_spaces(), env.agent_actions_spaces())]

    return MultiAgentLearnerWrapper(agents)
