from typing import Dict, Optional, SupportsFloat, Tuple

import gymnasium as gym
import numpy as np
from gymnasium.envs.mujoco.mujoco_env import MujocoEnv


class PDController(gym.Wrapper):
    """
    PD (proportional-derivative) controller wrapper for MuJoCo environments.

    The wrapped environment's action space is in torque space; this wrapper lets
    the caller command desired joint positions instead (control in joint space).
    On every step the commanded positions are turned into a torque command::

        torque = kp * (desired_qpos - qpos) + kd * (0 - qvel)

    where ``qpos`` / ``qvel`` are the joint state recorded after the previous
    ``reset`` or ``step``. The torque is clipped to ``[-1, 1]`` before being
    forwarded to the wrapped environment.

    The controlled joints are taken to be the last ``n_joints`` entries of the
    MuJoCo ``data.qpos`` / ``data.qvel`` arrays (subclasses may override
    ``current_joint_pos`` / ``current_joint_vel`` to read them elsewhere).

    See https://gymnasium.farama.org/environments/mujoco
    for env descriptions.

    :param env: Gym environment
    :param n_joints: Number of joints
    :param kp: Proportional gain
    :param kd: Derivative gain (velocity gain)
    """

    # Only set when ``env.unwrapped`` is a MujocoEnv (see ``__init__``).
    mujoco_env: MujocoEnv

    def __init__(self, env: gym.Env, n_joints: int, kp: float = 1, kd: float = 0.1):
        super().__init__(env)
        self.kp = kp
        self.kd = kd
        self.n_joints = n_joints
        # Joint state recorded after the last reset/step; it is the "current"
        # state used when computing the next PD command.
        self.last_qpos = np.zeros(self.n_joints)
        self.last_qvel = np.zeros(self.n_joints)
        # Keep a typed handle on the underlying MuJoCo env when there is one.
        # Non-MuJoCo envs are tolerated (subclasses may override the joint
        # accessors), so no error is raised otherwise. An explicit isinstance
        # check is used rather than `assert`, which `python -O` strips.
        mujoco_env = env.unwrapped
        if isinstance(mujoco_env, MujocoEnv):
            self.mujoco_env = mujoco_env

    @property
    def current_joint_pos(self) -> np.ndarray:
        """Positions of the controlled joints: the last ``n_joints`` entries of ``qpos``."""
        return self.mujoco_env.data.qpos[-self.n_joints :]

    @property
    def current_joint_vel(self) -> np.ndarray:
        """Velocities of the controlled joints: the last ``n_joints`` entries of ``qvel``."""
        return self.mujoco_env.data.qvel[-self.n_joints :]

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, dict]:
        """
        Reset the wrapped environment and record the initial joint state.

        :param seed: Forwarded to the wrapped environment's ``reset``.
        :param options: Forwarded to the wrapped environment's ``reset``.
        :return: The wrapped environment's ``(obs, info)``.
        """
        obs, info = self.env.reset(seed=seed, options=options)
        # Copy: the accessors may return views into simulator state that
        # changes on the next step.
        self.last_qpos = self.current_joint_pos.copy()
        self.last_qvel = self.current_joint_vel.copy()
        return obs, info

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, SupportsFloat, bool, bool, dict]:  # type: ignore[override]
        """
        Convert desired joint positions to a clipped torque command and step the env.

        :param action: Desired positions of the controlled joints.
        :return: The wrapped environment's ``(obs, reward, terminated, truncated, info)``.
        """
        desired_qpos = np.asarray(action)
        qpos_err = desired_qpos - self.last_qpos
        # The desired joint velocity is zero, so the velocity error is -qvel.
        qvel_err = -self.last_qvel
        torque = self.kp * qpos_err + self.kd * qvel_err
        # Clip to the action range.
        # NOTE(review): bounds are hard-coded to [-1, 1]; confirm they match the
        # wrapped env's action space.
        torque = np.clip(torque, -1.0, 1.0)
        obs, reward, terminated, truncated, info = self.env.step(torque)
        self.last_qpos = self.current_joint_pos.copy()
        self.last_qvel = self.current_joint_vel.copy()

        return obs, reward, terminated, truncated, info


class SkipPDController(PDController):
    """
    Pass-through variant of ``PDController``.

    Actions are forwarded to the wrapped environment unchanged: no PD
    conversion is applied and the recorded joint state is never updated.
    """

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, SupportsFloat, bool, bool, dict]:  # type: ignore[override]
        """Forward ``action`` to the wrapped environment as-is."""
        inner_env = self.env
        return inner_env.step(action)

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, dict]:
        """Reset the wrapped environment directly, with the given ``seed`` and ``options``."""
        inner_env = self.env
        return inner_env.reset(seed=seed, options=options)


class BipedalPDController(PDController):
    """
    ``PDController`` variant that reads joint state from ``env.unwrapped.joints``
    instead of MuJoCo ``data.qpos`` / ``data.qvel``.

    NOTE(review): presumably meant for Box2D BipedalWalker, whose unwrapped env
    exposes a ``joints`` list with ``angle`` / ``speed`` attributes; confirm
    against callers.
    """

    def _read_joints(self, attribute: str) -> np.ndarray:
        """Gather ``attribute`` from joints ``0 .. n_joints - 1`` of ``env.unwrapped.joints``."""
        values = [
            getattr(self.env.unwrapped.joints[idx], attribute)  # type: ignore[attr-defined]
            for idx in range(self.n_joints)
        ]
        return np.array(values)

    @property
    def current_joint_pos(self) -> np.ndarray:
        """Angles of the first ``n_joints`` joints."""
        return self._read_joints("angle")

    @property
    def current_joint_vel(self) -> np.ndarray:
        """Speeds of the first ``n_joints`` joints."""
        return self._read_joints("speed")
