from typing import Any, Dict

import numpy as np
import optuna

from open_loop.base_model import BaseModel


class Hopper(BaseModel):
    """Model settings and Optuna search space for the ``Hopper-v4`` environment."""

    env_id: str = "Hopper-v4"
    # presumably PD controller gains consumed by BaseModel -- TODO confirm
    kp: float = 10.0
    kd: float = 0.5
    n_joints: int = 3
    # Number of per-index parameters (phase shift, amplitude) sampled below.
    n_dim: int = 3

    @property
    def env_kwargs(self) -> Dict[str, Any]:
        """Return the keyword arguments overriding the environment defaults.

        The stock ``healthy_angle_range`` is ``(-0.2, 0.2)``, which seems too
        conservative for the hopper: with those limits the score reaches
        ~1960 but the episode terminates early about half of the time.
        The range is therefore widened by two degrees (0.04 rad) on each side.
        """
        return {"healthy_angle_range": (-0.24, 0.24)}

    def sample_params(self, trial: optuna.Trial, sample_coupling: bool = False) -> Dict[str, Any]:
        """Sample one set of parameters from ``trial``.

        :param trial: Optuna trial used to draw every value.
        :param sample_coupling: When true, also sample the phase shifts
            ``phase_shift_1`` .. ``phase_shift_{n_dim - 1}`` in [0, 1];
            ``phase_shift_0`` is always reported as 0.
        :return: Mapping with the optional ``phase_shift_{idx}`` entries,
            one ``amplitude_{idx}`` per dimension, and finally
            ``omega_swing`` and ``omega_stance``.
        """
        # The two frequencies are always drawn first, before anything else.
        omega_swing = trial.suggest_float("omega_swing", 0.4, 5)
        omega_stance = trial.suggest_float("omega_stance", 0.4, 5)

        sampled: Dict[str, Any] = {}

        if sample_coupling:
            # Phase shifts are relative: index 0 is the reference and stays at zero.
            shifts = np.zeros(self.n_dim)
            shifts[0] = 0.0
            shifts[1:] = [
                trial.suggest_float(f"phase_shift_{dim}", 0.0, 1.0)
                for dim in range(1, self.n_dim)
            ]
            sampled = {f"phase_shift_{dim}": shift for dim, shift in enumerate(shifts)}

        # One amplitude per dimension.
        # Offsets (``offset_{idx}`` in [-1, 1]) are currently not sampled.
        for dim in range(self.n_dim):
            key = f"amplitude_{dim}"
            sampled[key] = trial.suggest_float(key, -1.0, 1.0)

        sampled["omega_swing"] = omega_swing
        sampled["omega_stance"] = omega_stance
        return sampled
