from typing import Dict, Optional, SupportsFloat, Tuple

import gymnasium as gym
import numpy as np
from gymnasium.envs.mujoco.mujoco_env import MujocoEnv


class ExternalForce(gym.Wrapper):
    """
    Apply external force impulse to a MuJoCo model.

    At every call to ``step()``, a force is applied to one body with
    probability ``proba``. The sign of each component of ``force`` is
    randomized independently, so only the magnitudes are fixed.
    On steps where the force is not applied, the linear external force
    on the body is reset to zero so that a previous impulse does not keep acting.

    Random numbers are drawn from the global NumPy random state,
    which ``reset()`` re-seeds when a seed is given.

    :param env: Gym environment, its unwrapped version must be a ``MujocoEnv``
    :param force: Force to apply in N (x, y, z)
    :param proba: Probability to apply the force
    :param body_name: Name of the MuJoCo body the force is applied to
    """

    mujoco_env: MujocoEnv

    def __init__(self, env: gym.Env, force: np.ndarray, proba: float = 1.0, body_name: str = "torso"):
        super().__init__(env)

        mujoco_env = env.unwrapped
        # Also narrows the type for static checkers
        assert isinstance(mujoco_env, MujocoEnv), "ExternalForce requires a MujocoEnv"
        self.mujoco_env = mujoco_env
        self.proba = proba
        self.force = force
        self.body_name = body_name

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, dict]:
        """
        Reset the wrapped environment.

        :param seed: If given, also used to seed the global NumPy random state
            that drives the force sampling in ``step()``
        :param options: Forwarded unchanged to the wrapped environment
        :return: Whatever the wrapped environment's ``reset()`` returns
        """
        if seed is not None:
            np.random.seed(seed)
        return self.env.reset(seed=seed, options=options)

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, SupportsFloat, bool, bool, dict]:  # type: ignore[override]
        """
        Maybe apply the external force, then step the wrapped environment.

        :param action: Action forwarded unchanged to the wrapped environment
        :return: Whatever the wrapped environment's ``step()`` returns
        """
        # force + torque (6D), only the first three entries (force) are written
        external_wrench = self.mujoco_env.data.body(self.body_name).xfrc_applied

        if np.random.rand() < self.proba:
            # Randomize sign, independently for each axis
            sign = (np.random.uniform(size=3) > 0.5) * 2 - 1.0
            external_wrench[:3] = self.force * sign
        else:
            # xfrc_applied is a user input that MuJoCo does not clear after a step:
            # without this reset, the last applied force would persist
            # on every following step instead of acting as an impulse.
            external_wrench[:3] = 0.0

        return self.env.step(action)
