import argparse
from functools import partial
from pathlib import Path
from typing import Any, Dict, Optional

import gymnasium as gym
import numpy as np
import optuna
from optuna.samplers import CmaEsSampler, RandomSampler, TPESampler

try:
    from stable_baselines3.common.monitor import Monitor

    MONITOR_AVAILABLE = True
except ImportError:
    MONITOR_AVAILABLE = False

from open_loop.ant import Ant
from open_loop.base_model import BaseModel, StoreDict
from open_loop.cheetah import HalfCheetah
from open_loop.hopper import Hopper
from open_loop.swimmer import Swimmer
from open_loop.walker import Walker


class StopOnMaxTimesteps:
    """
    Track the number of timesteps consumed so far and request a stop
    once an optional budget is reached.

    The stop is signalled by raising ``KeyboardInterrupt`` from ``check()``.

    :param max_timesteps: Timestep budget. ``None`` disables the limit.
    """

    def __init__(self, max_timesteps: Optional[int] = None):
        super().__init__()
        # Running total of timesteps reported through ``update()``
        self.total_timesteps = 0
        # Flag set by ``update()`` and consumed by ``check()``
        self._should_stop = False
        self.max_timesteps = max_timesteps

    def check(self) -> None:
        """
        Raise ``KeyboardInterrupt`` if the budget was reached
        during a previous call to ``update()``, do nothing otherwise.
        """
        if not self._should_stop:
            return
        print(f"Max number of timesteps reached: {self.total_timesteps} >= {self.max_timesteps}")
        raise KeyboardInterrupt("Max number of timesteps reached")

    def update(self, n_timesteps: int) -> None:
        """
        Add ``n_timesteps`` to the running total and refresh the stop flag.

        :param n_timesteps: Number of timesteps to add to the total
        """
        self.total_timesteps += n_timesteps
        if self.max_timesteps is None:
            # No budget: the flag keeps its current value
            return
        self._should_stop = self.total_timesteps >= self.max_timesteps


def objective(
    trial: optuna.Trial,
    env_id: str,
    n_eval_episodes: int,
    callback: StopOnMaxTimesteps,
    additional_env_kwargs: Dict[str, Any],
    render_mode: Optional[str] = None,
    sample_coupling: bool = False,
    log_path: Optional[str] = None,
) -> float:
    """
    Objective function used by Optuna to evaluate
    one configuration (i.e., one set of parameters).

    Given a trial object, it samples parameters, evaluates them
    in the environment and reports the result (mean episodic reward).

    :param trial: Optuna trial object
    :param env_id: Environment unique identifier, must be one of
        ``Ant-v4``, ``Hopper-v4``, ``HalfCheetah-v4``, ``Walker2d-v4``, ``Swimmer-v4``
    :param n_eval_episodes: How many episodes to evaluate each trial
    :param callback: Callback to stop early, checked at the start of the trial
        and updated with the number of timesteps used by the evaluation
    :param additional_env_kwargs: Extra keyword arguments for ``gym.make``,
        merged on top of the model's own ``env_kwargs``
    :param render_mode: Render mode passed to ``gym.make``
    :param sample_coupling: Whether to sample phase shifts or not
    :param log_path: Where to save episode rewards
        (only used when the SB3 ``Monitor`` wrapper is available)
    :return: Mean episodic reward, or NaN if the evaluation failed
    :raises KeyError: If ``env_id`` is not a supported environment
    :raises KeyboardInterrupt: If the timestep budget of ``callback`` is exhausted
    """
    # Raises KeyboardInterrupt once the timestep budget is spent.
    # It is deliberately not caught here so that it propagates to the caller.
    callback.check()

    model = {
        "Ant-v4": Ant,
        "Hopper-v4": Hopper,
        "HalfCheetah-v4": HalfCheetah,
        "Walker2d-v4": Walker,
        "Swimmer-v4": Swimmer,
    }[env_id]()
    assert isinstance(model, BaseModel)

    # NOTE(review): this updates the dict returned by `model.env_kwargs` in place;
    # confirm it is not shared between model instances.
    env_kwargs = model.env_kwargs
    env_kwargs.update(additional_env_kwargs)
    env = gym.make(
        env_id,
        render_mode=render_mode,
        **env_kwargs,
    )  # type: ignore[arg-type]

    # From here on, the env must be closed whatever happens
    # (including a user interrupt in the middle of an evaluation).
    try:
        if MONITOR_AVAILABLE and log_path is not None:
            # Only the first trial is allowed to overwrite an existing log file
            env = Monitor(env, log_path, override_existing=trial.number == 0)

        env = model.init_model(env)

        # Sample the parameters to evaluate
        params = model.sample_params(trial, sample_coupling)

        # Run the environment
        try:
            mean_episode_reward, n_timesteps = model.evaluate(params, env, n_eval_episodes)
            # Send report to Optuna
            trial.report(mean_episode_reward, trial.number)
            callback.update(n_timesteps)
        except (AssertionError, SystemError) as e:
            # Sometimes, random hyperparams can generate NaN
            print(e)
            # Tell the optimizer that the trial failed
            return float("nan")
    finally:
        env.close()

    return mean_episode_reward


if __name__ == "__main__":
    # Command line entry point: build the sampler and the study from the CLI arguments,
    # run the optimization and print a summary of the best trial.
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", type=str, help="Env id", default="Ant-v4")
    parser.add_argument(
        "--sampler",
        help="Sampler to use when optimizing hyperparameters",
        type=str,
        default="cmaes",
        choices=["random", "tpe", "cmaes"],
    )
    parser.add_argument(
        "--storage", help="Database storage path if distributed optimization should be used", type=str, default=None
    )
    parser.add_argument("-name", "--study-name", help="Study name for distributed optimization", type=str, default=None)
    parser.add_argument("--n-eval-episodes", help="Number of episode to evaluate a set of parameters", type=int, default=1)
    parser.add_argument("--render", action="store_true", default=False, help="Render")
    parser.add_argument("--sample-coupling", action="store_true", default=False, help="Also sample the coupling params")
    parser.add_argument("--seed", help="Random generator seed", type=int, default=-1)
    parser.add_argument("--pop-size", help="Initial population size for CMAES", type=int, default=30)
    parser.add_argument("--n-trials", help="Max number of trials for this process", type=int, default=1000)
    parser.add_argument("-f", "--log-filename", help="Log filename, careful when using multiple processes", type=str)
    parser.add_argument("--timeout", help="Timeout (in seconds)", type=int)
    parser.add_argument("-n", "--n-timesteps", help="Max number of timesteps", type=int)
    parser.add_argument(
        "--env-kwargs", type=str, nargs="+", action=StoreDict, help="Optional keyword argument to pass to the env constructor"
    )
    args = parser.parse_args()

    N_TRIALS = args.n_trials  # Maximum number of trials for this process
    N_STARTUP_TRIALS = args.pop_size  # Stop random sampling after N_STARTUP_TRIALS
    N_JOBS = 1  # Number of jobs to run in parallel

    render_mode = "human" if args.render else None
    # Stops the optimization (via KeyboardInterrupt) once the timestep budget is spent
    callback = StopOnMaxTimesteps(args.n_timesteps)

    # TODO: seed also the env
    if args.seed < 0:
        # Seed but with a random one
        args.seed = np.random.randint(2**32 - 1, dtype="int64").item()  # type: ignore[attr-defined]

    print(f"Seed: {args.seed}")

    # Select the sampler, can be random, TPESampler, CMAES, ...
    # Factories are used so that only the requested sampler is instantiated:
    # warnings or errors from the unused ones cannot affect this run.
    sampler_factories = {
        "tpe": lambda: TPESampler(n_startup_trials=N_STARTUP_TRIALS, multivariate=True, seed=args.seed),
        "cmaes": lambda: CmaEsSampler(
            seed=args.seed,
            restart_strategy="bipop",
            popsize=args.pop_size,
            n_startup_trials=N_STARTUP_TRIALS,
        ),
        "random": lambda: RandomSampler(seed=args.seed),
    }
    sampler = sampler_factories[args.sampler]()

    if args.log_filename is not None:
        # Create folder if it doesn't exist
        Path(args.log_filename).parent.mkdir(parents=True, exist_ok=True)

    storage = args.storage
    if storage is not None and storage.endswith(".log"):
        # A path ending with ".log" selects the file-based journal storage
        # Create folder if it doesn't exist
        Path(storage).parent.mkdir(parents=True, exist_ok=True)
        storage = optuna.storages.JournalStorage(
            optuna.storages.JournalFileStorage(args.storage),
        )

    # Create the study and start the hyperparameter optimization
    study = optuna.create_study(
        sampler=sampler,
        storage=storage,
        study_name=args.study_name,
        load_if_exists=True,
        direction="maximize",
    )

    objective_fn = partial(
        objective,
        env_id=args.env,
        render_mode=render_mode,
        n_eval_episodes=args.n_eval_episodes,
        callback=callback,
        sample_coupling=args.sample_coupling,
        log_path=args.log_filename,
        additional_env_kwargs=args.env_kwargs or {},
    )

    try:
        study.optimize(objective_fn, n_trials=N_TRIALS, n_jobs=N_JOBS, timeout=args.timeout)
    except KeyboardInterrupt:
        # Raised either by the user (Ctrl-C) or by the callback
        # when the max number of timesteps is reached
        pass

    print(f"Number of finished trials: {len(study.trials)}")

    try:
        trial = study.best_trial
    except ValueError:
        # Optuna raises ValueError when no trial has completed
        # (e.g. every trial failed or the run was interrupted too early)
        print("No trial completed successfully, there is no best trial to report.")
    else:
        print("Best trial:")

        print(f"  Value: {trial.value}")

        print("  Params: ")
        for key, value in trial.params.items():
            print(f"    {key}: {value}")

        if len(trial.user_attrs) > 0:
            print("  User attrs:")
            for key, value in trial.user_attrs.items():
                print(f"    {key}: {value}")
