from typing import Any, Dict

import numpy as np
import optuna

from open_loop.base_model import BaseModel


class Ant(BaseModel):
    """Model settings and Optuna search space for the ``Ant-v4`` environment."""

    env_id: str = "Ant-v4"
    # NOTE(review): kp / kd are presumably controller gains consumed by BaseModel - confirm.
    kp: float = 1.0
    kd: float = 0.05
    n_joints: int = 8
    # Number of per-dimension parameters (offsets, phase shifts) sampled below.
    n_dim: int = 8
    xml_name: str = "ant.xml"  # xml with correct motor order

    def sample_params(self, trial: optuna.Trial, sample_coupling: bool = False) -> Dict[str, Any]:
        """
        Draw one candidate parameter set from ``trial``.

        Values are requested from the trial in a fixed order: the two
        frequencies, then (optionally) the phase shifts, then the offsets,
        and finally the two amplitudes.

        :param trial: Optuna trial used to suggest every value.
        :param sample_coupling: When True, also sample ``phase_shift_{i}`` for
            ``i`` in ``1..n_dim-1``; ``phase_shift_0`` is reported as zero.
        :return: Mapping from parameter name to sampled value, containing
            (when enabled) ``phase_shift_{i}``, then ``offset_{i}``, then
            ``omega_swing``, ``omega_stance``, ``amplitude_first`` and
            ``amplitude_second``.
        """
        frequencies = {
            name: trial.suggest_float(name, 0.4, 2)
            for name in ("omega_swing", "omega_stance")
        }

        per_dim: Dict[str, Any] = {}
        if sample_coupling:
            # Phase shifts are relative: dimension 0 is the reference and stays at zero,
            # only the remaining dimensions are sampled.
            shifts = np.zeros(self.n_dim)
            shifts[0] = 0.0
            for dim in range(1, self.n_dim):
                shifts[dim] = trial.suggest_float(f"phase_shift_{dim}", 0.0, 1.0)
            for dim, shift in enumerate(shifts):
                per_dim[f"phase_shift_{dim}"] = shift

        # One offset per dimension.
        for dim in range(self.n_dim):
            key = f"offset_{dim}"
            per_dim[key] = trial.suggest_float(key, -1.0, 1.0)

        # Use symmetry for amplitudes: only two values are sampled
        # (one for the 1st joint, one for the 2nd) instead of one per dimension.
        amplitudes = {
            name: trial.suggest_float(name, -1.0, 1.0)
            for name in ("amplitude_first", "amplitude_second")
        }

        return {**per_dim, **frequencies, **amplitudes}
