# RL Baselines

We use the [Stable Baselines Jax implementations (SBX)](https://github.com/araffin/sbx) for faster runtime (and ARS from SB3-Contrib),
together with the [RL Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) to manage experiments.
SBX implementations match the performance of the SB3 PyTorch implementations but run much faster thanks to Jax JIT compilation.


## Installation

1. First, install Jax by following the [instructions](https://github.com/google/jax) in its repository.

2. Install SBX: `pip install sbx-rl==0.8.0`

3. Install RL Zoo: `pip install 'rl_zoo3[plots]'`
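
To check that the packages, including the pinned SBX version, ended up in the active environment, a minimal sketch using only the standard library's `importlib.metadata` could look like this. The helper name `installed_versions` is hypothetical, and the distribution names are assumed to match the `pip install` names above.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Distribution names assumed to match the pip install commands above.
    versions = installed_versions(["jax", "sbx-rl", "rl_zoo3"])
    for name, ver in versions.items():
        print(f"{name}: {ver or 'NOT INSTALLED'}")
    # SBX is pinned to 0.8.0 in the installation step.
    if versions["sbx-rl"] not in (None, "0.8.0"):
        print("Warning: expected sbx-rl==0.8.0")
```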


## Run the Experiments

SAC:
```
./scripts/run_sac.sh
```

PPO:
```
./scripts/run_ppo.sh
```

Note: we use the default hyperparameters from the original papers, except for Swimmer-v4, where we use tuned hyperparameters (see the shell scripts); otherwise, the RL algorithms never reach a score of 300.
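
One way to check whether a run crosses the 300 score barrier is to look for the first evaluation whose mean episodic return reaches it. The sketch below assumes the evaluation data is already loaded as two sequences, for instance the `timesteps` and `results` arrays that SB3's `EvalCallback` stores in `evaluations.npz`; the function name `first_timestep_reaching` is hypothetical.

```python
from statistics import mean


def first_timestep_reaching(timesteps, results, threshold=300.0):
    """Return the first timestep whose mean evaluation return is >= threshold.

    timesteps: one entry per evaluation.
    results: one list of episodic returns per evaluation.
    Returns None if the threshold is never reached.
    """
    for step, returns in zip(timesteps, results):
        if mean(returns) >= threshold:
            return step
    return None


# Synthetic example (made-up numbers, not experiment results).
print(first_timestep_reaching([10_000, 20_000, 30_000], [[50, 70], [250, 290], [310, 330]]))  # 30000
```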

Note: if you want to use the PyTorch backend instead of Jax, replace `python train.py` with `python -m rl_zoo3.train` in the shell scripts.
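
The substitution from the note above can be scripted. A minimal sketch, assuming the shell scripts invoke the Jax backend literally as `python train.py`; the helper name `to_pytorch_backend` and the script filename in the usage comment are hypothetical:

```python
import sys
from pathlib import Path

JAX_CMD = "python train.py"  # SBX (Jax) entry point used in the shell scripts
TORCH_CMD = "python -m rl_zoo3.train"  # RL Zoo entry point (PyTorch backend)


def to_pytorch_backend(script_text: str) -> str:
    """Swap the Jax entry point for the PyTorch one everywhere in the script."""
    return script_text.replace(JAX_CMD, TORCH_CMD)


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python to_pytorch.py scripts/run_sac.sh > scripts/run_sac_torch.sh
    print(to_pytorch_backend(Path(sys.argv[1]).read_text()), end="")
```

Writing the result to a new file, rather than editing in place, keeps the original Jax scripts available.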

Robustness experiments:
```
./scripts/evaluate_robustness.sh
```

## Plots

Generate all plots and save the aggregated results to `logs/sac_ppo_ddpg_ars.pkl`:

```
python -m rl_zoo3.cli all_plots -f logs/ -a sac ppo ddpg ars -e Ant-v4 HalfCheetah-v4 Hopper-v4 Swimmer-v4 Walker2d-v4 -o logs/sac_ppo_ddpg_ars
```

Plot from the saved file, using either the RL Zoo CLI or the local script in `plots/`:
```
python -m rl_zoo3.cli plot_from_file -i logs/sac_ppo_ddpg_ars.pkl
```

```
python plots/plot_from_file.py -i logs/sac_ppo_ddpg_ars.pkl
```

Plot the number of parameters (an empty `CUDA_VISIBLE_DEVICES` hides all GPUs, so the command runs on CPU):
```
CUDA_VISIBLE_DEVICES= python plots/plot_from_file.py -i ../open-loop-mujoco/logs/results_open_loop.pkl --skip-timesteps -r --merge logs/sac_ppo_ddpg_ars.pkl -l "Open Loop" "SAC" "PPO" "DDPG" "ARS" -count-param
```
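
For intuition on what is being counted: the number of trainable parameters of a fully connected network is the sum, over layers, of (inputs + 1) × outputs, i.e. weights plus biases. A minimal sketch; the layer sizes in the example are illustrative (17 inputs and 6 outputs, as in HalfCheetah-v4, with two hidden layers of 256 units), and this is not necessarily how the `-count-param` option computes its numbers:

```python
def count_mlp_params(layer_sizes):
    """Count weights and biases of a fully connected network.

    layer_sizes: [n_inputs, hidden_1, ..., n_outputs]
    """
    return sum(
        (n_in + 1) * n_out  # weight matrix plus one bias per output unit
        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
    )


# Illustrative sizes: 17 observations -> 256 -> 256 -> 6 actions.
print(count_mlp_params([17, 256, 256, 6]))  # 71942
```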

Robustness plots:
```
python plots/plot_robustness.py -a sac ppo ddpg -e Ant-v4 HalfCheetah-v4 Hopper-v4 Swimmer-v4 Walker2d-v4 -f logs/robustness/no_noise/ logs/robustness/zero_value/ logs/robustness/constant_value/ -l "No Failure" "Type I \\\\ Zero Value" "Type II \\\\ Constant Value" -ci 95 --latex
```

```
python plots/plot_robustness.py -a sac open_loop ppo ars ddpg -e Ant-v4 HalfCheetah-v4 Hopper-v4 Swimmer-v4 Walker2d-v4 -f logs/robustness/no_noise/ logs/robustness/noise_std_025/ logs/robustness/noise_std_05/ logs/robustness/zero_value/ logs/robustness/constant_value/ logs/robustness/external_force/ -l "No Noise or Failure" "$\\sigma = 0.25$" "$\\sigma = 0.5$" "Type I \\\\ Zero Value" "Type II \\\\ Constant Value" "External Perturbation" -ci 95 --latex
```
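
The `-ci 95` option requests 95% confidence intervals. How `plot_robustness.py` computes them is not shown here; as an illustration of the general idea only, a percentile bootstrap confidence interval of the mean over runs can be sketched with the standard library (the function name `bootstrap_ci` and its defaults are assumptions):

```python
import random
from statistics import mean


def bootstrap_ci(samples, confidence=0.95, n_resamples=1000, seed=0):
    """Percentile bootstrap confidence interval for the mean of `samples`."""
    rng = random.Random(seed)  # fixed seed for reproducible intervals
    n = len(samples)
    # Mean of each resample drawn with replacement, sorted for percentile lookup.
    means = sorted(mean(rng.choices(samples, k=n)) for _ in range(n_resamples))
    alpha = 1.0 - confidence
    lower = means[int(alpha / 2 * n_resamples)]
    upper = means[min(n_resamples - 1, int((1 - alpha / 2) * n_resamples))]
    return lower, upper


# Made-up final returns from 5 seeds (not experiment results).
print(bootstrap_ci([2950.0, 3100.0, 3020.0, 2880.0, 3210.0]))
```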

Performance profiles:
```
python plots/plot_from_file.py -i ../open-loop-mujoco/logs/results_open_loop.pkl --skip-timesteps -r --merge logs/sac_ppo_ddpg_ars.pkl -l "Open Loop" "SAC" "PPO" "DDPG" "ARS" -vs -ci 0.95
```
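
A performance profile, as popularized by the rliable library, shows for each threshold τ the fraction of runs whose normalized score is strictly greater than τ. A minimal sketch of that quantity, assuming the scores are already normalized (the normalization used by the plotting script is not described here) and using the hypothetical name `performance_profile`:

```python
def performance_profile(scores, taus):
    """Fraction of runs with a normalized score strictly above each threshold tau."""
    n = len(scores)
    return [sum(score > tau for score in scores) / n for tau in taus]


# Made-up normalized scores for 4 runs (not experiment results).
print(performance_profile([0.2, 0.5, 0.9, 1.2], [0.0, 0.5, 1.0]))  # [1.0, 0.5, 0.25]
```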

Ablation study:
```
python plots/plot_from_file.py -i ../open-loop-mujoco/logs/open_loop_ablation.pkl --skip-timesteps -r -l "Open Loop Full" "No Omega" "No Phi" "No Phi No Omega" -ci 0.95 -latex
```
