from centralized_verification.configuration import TestingLimits, Configuration, TrainingLimits, TestConfiguration
from experiments.utils.configuration.env_config import construct_env
from experiments.utils.configuration.learner_config import construct_agent
from experiments.utils.configuration.utils import max_steps_and_episodes, set_all_seeds


def get_config_from_params(shield_constructor, params):
    """Build the training and evaluation configurations for one experiment run.

    Calls ``set_all_seeds`` with the integer seed before building the
    environment, learner and shields. Those objects are then packaged into a
    ``Configuration`` for training and a ``TestConfiguration`` for
    evaluation. Both configurations receive the same environment and learner
    objects.

    Args:
        shield_constructor: Callable invoked as
            ``shield_constructor(params, env)``. It must return a
            ``(training_shield, evaluation_shield)`` pair.
        params: Mapping of experiment parameters. ``"seed"`` and
            ``"run_name"`` are looked up with ``[]``, so they are required.
            ``"learner_discount"`` is optional and defaults to 0.9.
            ``"evaluation_run_name"`` is optional and defaults to the training
            run name. The whole mapping is also forwarded to
            ``max_steps_and_episodes``, ``construct_env`` and
            ``construct_agent``.

    Returns:
        A ``(config, evaluation_config)`` pair: the training
        ``Configuration`` and the evaluation ``TestConfiguration``.
    """
    # Seed before constructing anything that may consume randomness.
    set_all_seeds(int(params["seed"]))

    training_run = params["run_name"]
    step_budget, episode_budget = max_steps_and_episodes(params)

    environment = construct_env(params)
    learner = construct_agent(params, environment)
    training_shield, eval_shield = shield_constructor(params, environment)

    gamma = float(params.get("learner_discount", 0.9))

    # Training and evaluation episodes are capped at the same length.
    episode_cap = 500

    training_limits = TrainingLimits(
        max_episode_len=episode_cap,
        max_total_steps=step_budget,
        max_num_episodes=episode_budget,
    )
    config = Configuration(
        shield=training_shield,
        env=environment,
        learner=learner,
        run_name=training_run,
        limits=training_limits,
        num_log_entries=200,
        num_checkpoints=10,
        discount=gamma,
    )

    eval_run = params.get("evaluation_run_name", training_run)
    evaluation_config = TestConfiguration(
        shield=eval_shield,
        env=environment,
        agent=learner,
        run_name=eval_run,
        limits=TestingLimits(max_episode_len=episode_cap, num_episodes=100),
        discount=gamma,
    )

    return config, evaluation_config
