import argparse

import optuna

from open_loop.ant import Ant
from open_loop.base_model import BaseModel, StoreDict
from open_loop.cheetah import HalfCheetah
from open_loop.hopper import Hopper
from open_loop.swimmer import Swimmer
from open_loop.walker import Walker

def build_parser(env_ids):
    """Build the command-line parser for evaluating an open-loop model.

    Args:
        env_ids: Accepted values for ``--env``. Any other value is rejected by
            argparse with a usage message instead of surfacing later as an
            uncaught ``KeyError``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", type=str, help="Env id", default="Ant-v4", choices=list(env_ids))
    parser.add_argument("--render", action="store_true", default=False, help="Render")
    parser.add_argument(
        "-disturbances",
        "--apply-external-force",
        action="store_true",
        default=False,
        help="Apply external force impulse on the model.",
    )
    parser.add_argument("-n", "--n-eval-episodes", type=int, default=20, help="Number of evaluation episodes.")
    parser.add_argument("--seed", type=int, help="Seed for the pseudo random generator")
    parser.add_argument("-o", "--output", type=str, help="Log path to save return at test time")
    parser.add_argument("-name", "--study-name", help="Study name when loading Optuna results", type=str)
    parser.add_argument("-id", "--trial-id", help="Trial id to load, otherwise loading best trial", type=int)
    parser.add_argument("--storage", help="Database storage path if distributed optimization should be used", type=str)
    parser.add_argument(
        "--env-kwargs", type=str, nargs="+", action=StoreDict, help="Optional keyword argument to pass to the env constructor"
    )
    return parser


def load_params(model, storage, study_name, trial_id):
    """Return the controller parameters to evaluate.

    Args:
        model: The open-loop model; used only as the fallback parameter source.
        storage: Optuna storage (URL string or storage object), or ``None``.
        study_name: Name of the Optuna study to load, or ``None``.
        trial_id: Index into ``study.trials``, or ``None`` to use the best trial.

    Returns:
        Parameters taken from the Optuna study when both ``storage`` and
        ``study_name`` are given, otherwise ``model.read_params()``.
    """
    if storage is None or study_name is None:
        return model.read_params()
    study = optuna.load_study(storage=storage, study_name=study_name)
    if trial_id is not None:
        return study.trials[trial_id].params
    return study.best_trial.params


def fill_default_params(params):
    """Add values for parameters the search space may have left out (in place).

    * Phase shifts are relative, so ``phase_shift_0`` is pinned to ``0.0``
      when ``phase_shift_1`` exists but ``phase_shift_0`` does not. When no
      phase shifts are present, nothing is added. The original note says
      this means "all zeros, all joints synchronized"; presumably the model
      supplies that default (not visible here).
    * ``omega_stance`` falls back to ``omega_swing``. This raises ``KeyError``
      if neither key is present.
    * ``desired_step_len`` falls back to ``0.0``.

    Args:
        params: Parameter mapping, mutated in place.

    Returns:
        The same ``params`` object, for convenience.
    """
    if "phase_shift_1" in params and "phase_shift_0" not in params:
        params["phase_shift_0"] = 0.0

    if "omega_stance" not in params:
        params["omega_stance"] = params["omega_swing"]

    if "desired_step_len" not in params:
        params["desired_step_len"] = 0.0

    return params


def main(argv=None):
    """Parse the CLI, load parameters and evaluate the selected open-loop model.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.
    """
    models = {
        "Ant-v4": Ant,
        "Hopper-v4": Hopper,
        "HalfCheetah-v4": HalfCheetah,
        "Walker2d-v4": Walker,
        "Swimmer-v4": Swimmer,
    }
    args = build_parser(models).parse_args(argv)

    model = models[args.env]()
    assert isinstance(model, BaseModel)

    storage = args.storage
    if storage is not None and storage.endswith(".log"):
        # A ".log" path selects Optuna's file-based journal storage
        # instead of a database URL.
        storage = optuna.storages.JournalStorage(
            optuna.storages.JournalFileStorage(storage),
        )

    env = model.make_env(
        args.render,
        additional_env_kwargs=args.env_kwargs,
        apply_external_force=args.apply_external_force,
    )
    try:
        if args.seed is not None:
            env.reset(seed=args.seed)

        params = load_params(model, storage, args.study_name, args.trial_id)
        # Print the parameters as loaded, before any defaults are filled in.
        for name, value in params.items():
            print(f"{name}:{value}")
        fill_default_params(params)

        try:
            model.evaluate(params, env, args.n_eval_episodes, render=args.render, verbose=2, log_path=args.output)
        except KeyboardInterrupt:
            # Ctrl-C deliberately ends evaluation early without a traceback.
            pass
    finally:
        # Release the environment's resources explicitly; `del env` only
        # dropped a reference and did not guarantee cleanup.
        env.close()


if __name__ == "__main__":
    main()
