Compare commits

...

6 Commits

Author SHA1 Message Date
Lukas Schneider 96393c41c5
Update runner.py to load models to correct device 2024-04-03 13:25:58 -07:00
Lukas Schneider c7e950439d
Update setup.py dependency list 2024-03-24 15:56:09 -07:00
Lukas Schneider 9c0fcdc677
Update setup.py dependency versions 2024-03-24 12:25:56 -07:00
Lukas Schneider 7ce11711d4
Update setup.py package versions 2024-03-24 12:12:03 -07:00
Lukas Schneider dc9f33a3c3
Added citation to README 2024-01-30 11:15:48 -08:00
Lukas Schneider 96dd4929c5 added rewrite of rsl_rl for supporting additional algorithms 2023-12-12 18:32:21 +01:00
79 changed files with 7856 additions and 1172 deletions

11
.gitignore vendored
View File

@@ -7,6 +7,17 @@
# cache
__pycache__
.pytest_cache
wandb/
# vs code
.vscode
# data
videos/
# secrets
examples/wandb_config.py
# docs
docs/_build
docs/source

View File

@@ -25,6 +25,7 @@ Please keep the lists sorted alphabetically.
* Eric Vollenweider
* Fabian Jenelten
* Lorenzo Terenzi
* Lukas Schneider
* Marko Bjelonic
* Matthijs van der Boon
* Mayank Mittal

View File

@@ -1,57 +1,78 @@
# RSL RL
Fast and simple implementation of RL algorithms, designed to run fully on GPU.
This code is an evolution of `rl-pytorch` provided with NVIDIA's Isaac GYM.
Only PPO is implemented for now. More algorithms will be added later.
Contributions are welcome.
Currently, the following algorithms are implemented:
- Distributed Distributional DDPG (D4PG)
- Deep Deterministic Policy Gradient (DDPG)
- Distributional PPO (DPPO)
- Distributional Soft Actor Critic (DSAC)
- Proximal Policy Optimization (PPO)
- Soft Actor Critic (SAC)
- Twin Delayed DDPG (TD3)
**Maintainer**: David Hoeller and Nikita Rudin <br/>
**Maintainer**: David Hoeller, Nikita Rudin <br/>
**Affiliation**: Robotic Systems Lab, ETH Zurich & NVIDIA <br/>
**Contact**: rudinn@ethz.ch
**Contact**: Nikita Rudin (rudinn@ethz.ch), Lukas Schneider (lukas@luschneider.com)
## Setup
## Citation
Following are the instructions to set up the repository for your workspace:
```bash
git clone https://github.com/leggedrobotics/rsl_rl
cd rsl_rl
pip install -e .
If you use our code in your research, please cite us:
```
@misc{schneider2023learning,
archivePrefix={arXiv},
author={Lukas Schneider and Jonas Frey and Takahiro Miki and Marco Hutter},
eprint={2309.14246},
primaryClass={cs.RO},
title={Learning Risk-Aware Quadrupedal Locomotion using Distributional Reinforcement Learning},
year={2023},
}
```
The framework supports the following logging frameworks which can be configured through `logger`:
## Installation
* Tensorboard: https://www.tensorflow.org/tensorboard/
* Weights & Biases: https://wandb.ai/site
* Neptune: https://docs.neptune.ai/
To install the package, run the following command in the root directory of the repository:
For a demo configuration of PPO, please check the [dummy_config.yaml](config/dummy_config.yaml) file.
```bash
$ pip3 install -e .
```
Examples can be run from the `examples/` directory.
The examples directory also includes hyperparameters tuned for some gym environments.
These are automatically loaded when running the example.
Videos of the trained policies are periodically saved to the `videos/` directory.
```bash
$ python3 examples/example.py
```
To run Gym MuJoCo environments, you need a working installation of the MuJoCo simulator and [mujoco_py](https://github.com/openai/mujoco-py).
## Tests
The repository contains a set of tests to ensure that the algorithms are working as expected.
To run the tests, simply execute:
```bash
$ cd tests/ && python -m unittest
```
## Documentation
To generate documentation, run the following command in the root directory of the repository:
```bash
$ pip3 install sphinx sphinx-rtd-theme
$ sphinx-apidoc -o docs/source . ./examples
$ cd docs/ && make html
```
## Contribution Guidelines
For documentation, we adopt the [Google Style Guide](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) for docstrings. We use [Sphinx](https://www.sphinx-doc.org/en/master/) for generating the documentation. Please make sure that your code is well-documented and follows the guidelines.
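As a brief, hedged illustration of the Google docstring style referenced above (the function below is invented for this example and is not part of the repository):
```python
def soft_update(online: float, target: float, polyak: float = 0.995) -> float:
    """Softly updates a target value towards an online value.

    Args:
        online (float): The current online value.
        target (float): The current target value.
        polyak (float): The weight kept on the target value.

    Returns:
        The updated target value.
    """
    return (1.0 - polyak) * online + polyak * target
```
The same `Args`/`Returns` layout is used throughout the `rsl_rl` modules added in this diff.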
We use the following tools for maintaining code quality:
- [pre-commit](https://pre-commit.com/): Runs a list of formatters and linters over the codebase.
- [black](https://black.readthedocs.io/en/stable/): The uncompromising code formatter.
- [flake8](https://flake8.pycqa.org/en/latest/): A wrapper around PyFlakes, pycodestyle, and McCabe complexity checker.
Please check [here](https://pre-commit.com/#install) for instructions to set these up. To run over the entire repository, please execute the following command in the terminal:
We use the [`black`](https://github.com/psf/black) formatter for formatting the Python code.
You can [configure `black` with VSCode](https://dev.to/adamlombard/how-to-use-the-black-python-code-formatter-in-vscode-3lo0) or format files manually with:
```bash
# for installation (only once)
pre-commit install
# for running
pre-commit run --all-files
$ pip install black
$ black --line-length 120 .
```
### Useful Links
Environment repositories using the framework:
* `Legged-Gym` (built on top of NVIDIA Isaac Gym): https://leggedrobotics.github.io/legged_gym/
* `Orbit` (built on top of NVIDIA Isaac Sim): https://isaac-orbit.github.io/

View File

@@ -1,48 +0,0 @@
algorithm:
class_name: PPO
# training parameters
# -- value function
value_loss_coef: 1.0
clip_param: 0.2
use_clipped_value_loss: true
# -- surrogate loss
desired_kl: 0.01
entropy_coef: 0.01
gamma: 0.99
lam: 0.95
max_grad_norm: 1.0
# -- training
learning_rate: 0.001
num_learning_epochs: 5
num_mini_batches: 4 # mini batch size = num_envs * num_steps / num_mini_batches
schedule: adaptive # adaptive, fixed
policy:
class_name: ActorCritic
# for MLP i.e. `ActorCritic`
activation: elu
actor_hidden_dims: [128, 128, 128]
critic_hidden_dims: [128, 128, 128]
init_noise_std: 1.0
# only needed for `ActorCriticRecurrent`
# rnn_type: 'lstm'
# rnn_hidden_size: 512
# rnn_num_layers: 1
runner:
num_steps_per_env: 24 # number of steps per environment per iteration
max_iterations: 1500 # number of policy updates
empirical_normalization: false
# -- logging parameters
save_interval: 50 # check for potential saves every `save_interval` iterations
experiment_name: walking_experiment
run_name: ""
# -- logging writer
logger: tensorboard # tensorboard, neptune, wandb
neptune_project: legged_gym
wandb_project: legged_gym
# -- load and resuming
resume: false
load_run: -1 # -1 means load latest run
resume_path: null # updated from load_run and checkpoint
checkpoint: -1 # -1 means load latest checkpoint
runner_class_name: OnPolicyRunner
seed: 1

20
docs/Makefile Normal file
View File

@@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

32
docs/conf.py Normal file
View File

@@ -0,0 +1,32 @@
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
import os
import sys
sys.path.insert(0, os.path.abspath(".."))
# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
project = "rsl_rl"
copyright = "2023, Lukas Schneider"
author = "Lukas Schneider"
release = "1.0.0"
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon"]
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = "sphinx_rtd_theme"
html_static_path = ["_static"]

20
docs/index.rst Normal file
View File

@@ -0,0 +1,20 @@
.. rsl_rl documentation master file, created by
sphinx-quickstart on Tue Jul 4 16:39:24 2023.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Welcome to rsl_rl's documentation!
==================================
.. toctree::
:maxdepth: 2
:caption: Contents:
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`

35
docs/make.bat Normal file
View File

@@ -0,0 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

0
examples/__init__.py Normal file
View File

92
examples/benchmark.py Normal file
View File

@@ -0,0 +1,92 @@
import numpy as np
import os
import torch
import wandb
from rsl_rl.algorithms import *
from rsl_rl.env.gym_env import GymEnv
from rsl_rl.runners.runner import Runner
from rsl_rl.runners.callbacks import make_wandb_cb
from hyperparams import hyperparams
from wandb_config import WANDB_API_KEY, WANDB_ENTITY
ALGORITHMS = [PPO, DPPO]
ENVIRONMENTS = ["BipedalWalker-v3"]
ENVIRONMENT_KWARGS = [{}]
EXPERIMENT_DIR = os.environ.get("EXPERIMENT_DIRECTORY", "./")
EXPERIMENT_NAME = os.environ.get("EXPERIMENT_NAME", "benchmark")
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
RENDER_VIDEO = False
RETURN_EPOCHS = 100 # Number of epochs to average return over
LOG_WANDB = True
RUNS = 3
TRAIN_TIMEOUT = 60 * 10 # Training time (in seconds)
TRAIN_ENV_STEPS = None # Number of training environment steps
os.environ["WANDB_API_KEY"] = WANDB_API_KEY
def run(alg_class, env_name, env_kwargs={}):
try:
hp = hyperparams[alg_class.__name__][env_name]
except KeyError:
print("No hyperparameters found. Using default values.")
hp = dict(agent_kwargs={}, env_kwargs={"environment_count": 1}, runner_kwargs={"num_steps_per_env": 1})
agent_kwargs = dict(device=DEVICE, **hp["agent_kwargs"])
env_kwargs = dict(name=env_name, gym_kwargs=env_kwargs, **hp["env_kwargs"])
runner_kwargs = dict(device=DEVICE, **hp["runner_kwargs"])
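# Convert the environment-step budget into learner iterations: each iteration collects
# environment_count * num_steps_per_env transitions across the vectorized environments.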
learn_steps = (
None
if TRAIN_ENV_STEPS is None
else int(np.ceil(TRAIN_ENV_STEPS / (env_kwargs["environment_count"] * runner_kwargs["num_steps_per_env"])))
)
learn_timeout = None if TRAIN_TIMEOUT is None else TRAIN_TIMEOUT
video_directory = f"{EXPERIMENT_DIR}/{EXPERIMENT_NAME}/videos/{env_name}/{alg_class.__name__}"
save_video_cb = (
lambda ep, file: wandb.log({f"video-{ep}": wandb.Video(file, fps=4, format="mp4")}) if LOG_WANDB else None
)
env = GymEnv(**env_kwargs, draw=RENDER_VIDEO, draw_cb=save_video_cb, draw_directory=video_directory)
agent = alg_class(env, **agent_kwargs)
config = dict(
agent_kwargs=agent_kwargs,
env_kwargs=env_kwargs,
learn_steps=learn_steps,
learn_timeout=learn_timeout,
runner_kwargs=runner_kwargs,
)
wandb_learn_config = dict(
config=config,
entity=WANDB_ENTITY,
group=f"{alg_class.__name__}_{env_name}",
project="rsl_rl-benchmark",
tags=[alg_class.__name__, env_name, "train"],
)
runner = Runner(env, agent, **runner_kwargs)
runner._learn_cb = [lambda *args, **kwargs: Runner._log(*args, prefix=f"{alg_class.__name__}_{env_name}", **kwargs)]
if LOG_WANDB:
runner._learn_cb.append(make_wandb_cb(wandb_learn_config))
runner.learn(iterations=learn_steps, timeout=learn_timeout, return_epochs=RETURN_EPOCHS)
env.close()
def main():
for algorithm in ALGORITHMS:
for i, env_name in enumerate(ENVIRONMENTS):
env_kwargs = ENVIRONMENT_KWARGS[i]
for _ in range(RUNS):
run(algorithm, env_name, env_kwargs=env_kwargs)
if __name__ == "__main__":
main()

26
examples/example.py Normal file
View File

@@ -0,0 +1,26 @@
import torch
from rsl_rl.algorithms import *
from rsl_rl.env.gym_env import GymEnv
from rsl_rl.runners.runner import Runner
from hyperparams import hyperparams
ALGORITHM = DPPO
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
TASK = "BipedalWalker-v3"
def main():
hp = hyperparams[ALGORITHM.__name__][TASK]
env = GymEnv(name=TASK, device=DEVICE, draw=True, **hp["env_kwargs"])
agent = ALGORITHM(env, benchmark=True, device=DEVICE, **hp["agent_kwargs"])
runner = Runner(env, agent, device=DEVICE, **hp["runner_kwargs"])
runner._learn_cb = [Runner._log]
runner.learn(5000)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,7 @@
from rsl_rl.algorithms import PPO, DPPO
from .dppo import dppo_hyperparams
from .ppo import ppo_hyperparams
hyperparams = {DPPO.__name__: dppo_hyperparams, PPO.__name__: ppo_hyperparams}
__all__ = ["hyperparams"]

View File

@@ -0,0 +1,202 @@
import copy
import numpy as np
from rsl_rl.algorithms import DPPO
from rsl_rl.modules import QuantileNetwork
default = dict()
default["env_kwargs"] = dict(environment_count=1)
default["runner_kwargs"] = dict(num_steps_per_env=2048)
default["agent_kwargs"] = dict(
actor_activations=["tanh", "tanh", "linear"],
actor_hidden_dims=[64, 64],
actor_input_normalization=False,
actor_noise_std=np.exp(0.0),
batch_count=(default["env_kwargs"]["environment_count"] * default["runner_kwargs"]["num_steps_per_env"] // 64),
clip_ratio=0.2,
critic_activations=["tanh", "tanh"],
critic_hidden_dims=[64, 64],
critic_input_normalization=False,
entropy_coeff=0.0,
gae_lambda=0.95,
gamma=0.99,
gradient_clip=0.5,
learning_rate=0.0003,
qrdqn_quantile_count=50,
schedule="adaptive",
target_kl=0.01,
value_coeff=0.5,
value_measure=QuantileNetwork.measure_neutral,
value_measure_kwargs={},
)
# Parameters optimized for PPO
ant_v4 = copy.deepcopy(default)
ant_v4["env_kwargs"]["environment_count"] = 128
ant_v4["runner_kwargs"]["num_steps_per_env"] = 64
ant_v4["agent_kwargs"]["actor_activations"] = ["tanh", "tanh", "linear"]
ant_v4["agent_kwargs"]["actor_hidden_dims"] = [64, 64]
ant_v4["agent_kwargs"]["actor_noise_std"] = 0.2611
ant_v4["agent_kwargs"]["batch_count"] = 12
ant_v4["agent_kwargs"]["clip_ratio"] = 0.4
ant_v4["agent_kwargs"]["critic_activations"] = ["tanh", "tanh"]
ant_v4["agent_kwargs"]["critic_hidden_dims"] = [64, 64]
ant_v4["agent_kwargs"]["entropy_coeff"] = 0.0102
ant_v4["agent_kwargs"]["gae_lambda"] = 0.92
ant_v4["agent_kwargs"]["gamma"] = 0.9731
ant_v4["agent_kwargs"]["gradient_clip"] = 5.0
ant_v4["agent_kwargs"]["learning_rate"] = 0.8755
ant_v4["agent_kwargs"]["target_kl"] = 0.1711
ant_v4["agent_kwargs"]["value_coeff"] = 0.6840
"""
Tuned for environment interactions:
[I 2023-01-03 03:11:29,212] Trial 19 finished with value: 0.5272218152693111 and parameters: {
'env_count': 16,
'actor_noise_std': 0.7304437880901905,
'batch_count': 10,
'clip_ratio': 0.3,
'entropy_coeff': 0.004236574285220795,
'gae_lambda': 0.95,
'gamma': 0.9890074826092162,
'gradient_clip': 0.9,
'learning_rate': 0.18594043324129061,
'steps_per_env': 256,
'target_kl': 0.05838576142010138,
'value_coeff': 0.14402022632575992,
'net_arch': 'small',
'net_activation': 'relu'
}. Best is trial 19 with value: 0.5272218152693111.
Tuned for training time:
[I 2023-01-08 21:09:06,069] Trial 407 finished with value: 7.497591958940029 and parameters: {
'actor_noise_std': 0.1907398121300662,
'batch_count': 3,
'clip_ratio': 0.1,
'entropy_coeff': 0.0053458057035692735,
'env_count': 16,
'gae_lambda': 0.8,
'gamma': 0.985000267068182,
'gradient_clip': 2.0,
'learning_rate': 0.605956844400053,
'steps_per_env': 512,
'target_kl': 0.17611450607281642,
'value_coeff': 0.46015664905111847,
'actor_net_arch': 'small',
'critic_net_arch': 'medium',
'actor_net_activation': 'relu',
'critic_net_activation': 'relu',
'qrdqn_quantile_count': 200,
'value_measure': 'neutral'
}. Best is trial 407 with value: 7.497591958940029.
"""
bipedal_walker_v3 = copy.deepcopy(default)
bipedal_walker_v3["env_kwargs"]["environment_count"] = 256
bipedal_walker_v3["runner_kwargs"]["num_steps_per_env"] = 16
bipedal_walker_v3["agent_kwargs"]["actor_activations"] = ["relu", "relu", "relu", "linear"]
bipedal_walker_v3["agent_kwargs"]["actor_hidden_dims"] = [512, 256, 128]
bipedal_walker_v3["agent_kwargs"]["actor_noise_std"] = 0.8505
bipedal_walker_v3["agent_kwargs"]["batch_count"] = 10
bipedal_walker_v3["agent_kwargs"]["clip_ratio"] = 0.1
bipedal_walker_v3["agent_kwargs"]["critic_activations"] = ["relu", "relu"]
bipedal_walker_v3["agent_kwargs"]["critic_hidden_dims"] = [256, 256]
bipedal_walker_v3["agent_kwargs"]["critic_network"] = DPPO.network_qrdqn
bipedal_walker_v3["agent_kwargs"]["entropy_coeff"] = 0.0917
bipedal_walker_v3["agent_kwargs"]["gae_lambda"] = 0.95
bipedal_walker_v3["agent_kwargs"]["gamma"] = 0.9553
bipedal_walker_v3["agent_kwargs"]["gradient_clip"] = 2.0
bipedal_walker_v3["agent_kwargs"]["iqn_action_samples"] = 32
bipedal_walker_v3["agent_kwargs"]["iqn_embedding_size"] = 64
bipedal_walker_v3["agent_kwargs"]["iqn_feature_layers"] = 1
bipedal_walker_v3["agent_kwargs"]["iqn_value_samples"] = 8
bipedal_walker_v3["agent_kwargs"]["learning_rate"] = 0.4762
bipedal_walker_v3["agent_kwargs"]["qrdqn_quantile_count"] = 200
bipedal_walker_v3["agent_kwargs"]["recurrent"] = False
bipedal_walker_v3["agent_kwargs"]["target_kl"] = 0.1999
bipedal_walker_v3["agent_kwargs"]["value_coeff"] = 0.4435
"""
[I 2023-01-12 08:01:35,514] Trial 476 finished with value: 5202.960759290059 and parameters: {
'actor_noise_std': 0.15412869066185989,
'batch_count': 11,
'clip_ratio': 0.3,
'entropy_coeff': 0.036031209302206955,
'env_count': 128,
'gae_lambda': 0.92,
'gamma': 0.973937576989299,
'gradient_clip': 5.0,
'learning_rate': 0.1621249118505433,
'steps_per_env': 128,
'target_kl': 0.05054738172852222,
'value_coeff': 0.1647632125820593,
'actor_net_arch': 'small',
'critic_net_arch': 'medium',
'actor_net_activation': 'tanh',
'critic_net_activation': 'relu',
'qrdqn_quantile_count': 50,
'value_measure': 'var-risk-averse'
}. Best is trial 476 with value: 5202.960759290059.
"""
half_cheetah_v4 = copy.deepcopy(default)
half_cheetah_v4["env_kwargs"]["environment_count"] = 128
half_cheetah_v4["runner_kwargs"]["num_steps_per_env"] = 128
half_cheetah_v4["agent_kwargs"]["actor_activations"] = ["tanh", "tanh", "linear"]
half_cheetah_v4["agent_kwargs"]["actor_hidden_dims"] = [64, 64]
half_cheetah_v4["agent_kwargs"]["actor_noise_std"] = 0.1541
half_cheetah_v4["agent_kwargs"]["batch_count"] = 11
half_cheetah_v4["agent_kwargs"]["clip_ratio"] = 0.3
half_cheetah_v4["agent_kwargs"]["critic_activations"] = ["relu", "relu"]
half_cheetah_v4["agent_kwargs"]["critic_hidden_dims"] = [256, 256]
half_cheetah_v4["agent_kwargs"]["entropy_coeff"] = 0.03603
half_cheetah_v4["agent_kwargs"]["gae_lambda"] = 0.92
half_cheetah_v4["agent_kwargs"]["gamma"] = 0.9739
half_cheetah_v4["agent_kwargs"]["gradient_clip"] = 5.0
half_cheetah_v4["agent_kwargs"]["learning_rate"] = 0.1621
half_cheetah_v4["agent_kwargs"]["qrdqn_quantile_count"] = 50
half_cheetah_v4["agent_kwargs"]["target_kl"] = 0.0505
half_cheetah_v4["agent_kwargs"]["value_coeff"] = 0.1648
half_cheetah_v4["agent_kwargs"]["value_measure"] = QuantileNetwork.measure_percentile
half_cheetah_v4["agent_kwargs"]["value_measure_kwargs"] = dict(confidence_level=0.25)
# Parameters optimized for PPO
hopper_v4 = copy.deepcopy(default)
hopper_v4["runner_kwargs"]["num_steps_per_env"] = 128
hopper_v4["agent_kwargs"]["actor_activations"] = ["relu", "relu", "linear"]
hopper_v4["agent_kwargs"]["actor_hidden_dims"] = [256, 256]
hopper_v4["agent_kwargs"]["actor_noise_std"] = 0.5590
hopper_v4["agent_kwargs"]["batch_count"] = 15
hopper_v4["agent_kwargs"]["clip_ratio"] = 0.2
hopper_v4["agent_kwargs"]["critic_activations"] = ["relu", "relu", "linear"]
hopper_v4["agent_kwargs"]["critic_hidden_dims"] = [32, 32]
hopper_v4["agent_kwargs"]["entropy_coeff"] = 0.03874
hopper_v4["agent_kwargs"]["gae_lambda"] = 0.98
hopper_v4["agent_kwargs"]["gamma"] = 0.9890
hopper_v4["agent_kwargs"]["gradient_clip"] = 1.0
hopper_v4["agent_kwargs"]["learning_rate"] = 0.3732
hopper_v4["agent_kwargs"]["value_coeff"] = 0.8163
swimmer_v4 = copy.deepcopy(default)
swimmer_v4["agent_kwargs"]["gamma"] = 0.9999
walker2d_v4 = copy.deepcopy(default)
walker2d_v4["runner_kwargs"]["num_steps_per_env"] = 512
walker2d_v4["agent_kwargs"]["batch_count"] = (
walker2d_v4["env_kwargs"]["environment_count"] * walker2d_v4["runner_kwargs"]["num_steps_per_env"] // 32
)
walker2d_v4["agent_kwargs"]["clip_ratio"] = 0.1
walker2d_v4["agent_kwargs"]["entropy_coeff"] = 0.000585045
walker2d_v4["agent_kwargs"]["gae_lambda"] = 0.95
walker2d_v4["agent_kwargs"]["gamma"] = 0.99
walker2d_v4["agent_kwargs"]["gradient_clip"] = 1.0
walker2d_v4["agent_kwargs"]["learning_rate"] = 5.05041e-05
walker2d_v4["agent_kwargs"]["value_coeff"] = 0.871923
dppo_hyperparams = {
"default": default,
"Ant-v4": ant_v4,
"BipedalWalker-v3": bipedal_walker_v3,
"HalfCheetah-v4": half_cheetah_v4,
"Hopper-v4": hopper_v4,
"Swimmer-v4": swimmer_v4,
"Walker2d-v4": walker2d_v4,
}

226
examples/hyperparams/ppo.py Normal file
View File

@@ -0,0 +1,226 @@
import copy
import numpy as np
default = dict()
default["env_kwargs"] = dict(environment_count=1)
default["runner_kwargs"] = dict(num_steps_per_env=2048)
default["agent_kwargs"] = dict(
actor_activations=["tanh", "tanh", "linear"],
actor_hidden_dims=[64, 64],
actor_input_normalization=False,
actor_noise_std=np.exp(0.0),
batch_count=(default["env_kwargs"]["environment_count"] * default["runner_kwargs"]["num_steps_per_env"] // 64),
clip_ratio=0.2,
critic_activations=["tanh", "tanh", "linear"],
critic_hidden_dims=[64, 64],
critic_input_normalization=False,
entropy_coeff=0.0,
gae_lambda=0.95,
gamma=0.99,
gradient_clip=0.5,
learning_rate=0.0003,
schedule="adaptive",
target_kl=0.01,
value_coeff=0.5,
)
"""
[I 2023-01-09 00:33:02,217] Trial 85 finished with value: 2191.0249068421276 and parameters: {
'actor_noise_std': 0.2611334861249876,
'batch_count': 12,
'clip_ratio': 0.4,
'entropy_coeff': 0.010204149626344796,
'env_count': 128,
'gae_lambda': 0.92,
'gamma': 0.9730549104215155,
'gradient_clip': 5.0,
'learning_rate': 0.8754540531090014,
'steps_per_env': 64,
'target_kl': 0.17110535070344035,
'value_coeff': 0.6840401569818934,
'actor_net_arch': 'small',
'critic_net_arch': 'small',
'actor_net_activation': 'tanh',
'critic_net_activation': 'tanh'
}. Best is trial 85 with value: 2191.0249068421276.
"""
ant_v3 = copy.deepcopy(default)
ant_v3["env_kwargs"]["environment_count"] = 128
ant_v3["runner_kwargs"]["num_steps_per_env"] = 64
ant_v3["agent_kwargs"]["actor_activations"] = ["tanh", "tanh", "linear"]
ant_v3["agent_kwargs"]["actor_hidden_dims"] = [64, 64]
ant_v3["agent_kwargs"]["actor_noise_std"] = 0.2611
ant_v3["agent_kwargs"]["batch_count"] = 12
ant_v3["agent_kwargs"]["clip_ratio"] = 0.4
ant_v3["agent_kwargs"]["critic_activations"] = ["tanh", "tanh", "linear"]
ant_v3["agent_kwargs"]["critic_hidden_dims"] = [64, 64]
ant_v3["agent_kwargs"]["entropy_coeff"] = 0.0102
ant_v3["agent_kwargs"]["gae_lambda"] = 0.92
ant_v3["agent_kwargs"]["gamma"] = 0.9731
ant_v3["agent_kwargs"]["gradient_clip"] = 5.0
ant_v3["agent_kwargs"]["learning_rate"] = 0.8755
ant_v3["agent_kwargs"]["target_kl"] = 0.1711
ant_v3["agent_kwargs"]["value_coeff"] = 0.6840
"""
Standard:
[I 2023-01-17 07:43:46,884] Trial 125 finished with value: 150.23491836690064 and parameters: {
'actor_net_activation': 'relu',
'actor_net_arch': 'large',
'actor_noise_std': 0.8504545432069994,
'batch_count': 10,
'clip_ratio': 0.1,
'critic_net_activation': 'relu',
'critic_net_arch': 'medium',
'entropy_coeff': 0.0916881539697197,
'env_count': 256,
'gae_lambda': 0.95,
'gamma': 0.955285858564339,
'gradient_clip': 2.0,
'learning_rate': 0.4762365866431558,
'steps_per_env': 16,
'recurrent': False,
'target_kl': 0.19991906392721126,
'value_coeff': 0.4434793554275927
}. Best is trial 125 with value: 150.23491836690064.
Hardcore:
[I 2023-01-09 06:25:44,000] Trial 262 finished with value: 2.290071208278338 and parameters: {
'actor_noise_std': 0.2710521003644249,
'batch_count': 6,
'clip_ratio': 0.1,
'entropy_coeff': 0.005105282891378981,
'env_count': 16,
'gae_lambda': 1.0,
'gamma': 0.9718119008688937,
'gradient_clip': 0.1,
'learning_rate': 0.4569184610431825,
'steps_per_env': 256,
'target_kl': 0.11068348002480229,
'value_coeff': 0.19453900570701116,
'actor_net_arch': 'small',
'critic_net_arch': 'medium',
'actor_net_activation': 'relu',
'critic_net_activation': 'relu'
}. Best is trial 262 with value: 2.290071208278338.
"""
bipedal_walker_v3 = copy.deepcopy(default)
bipedal_walker_v3["env_kwargs"]["environment_count"] = 256
bipedal_walker_v3["runner_kwargs"]["num_steps_per_env"] = 16
bipedal_walker_v3["agent_kwargs"]["actor_activations"] = ["relu", "relu", "relu", "linear"]
bipedal_walker_v3["agent_kwargs"]["actor_hidden_dims"] = [512, 256, 128]
bipedal_walker_v3["agent_kwargs"]["actor_noise_std"] = 0.8505
bipedal_walker_v3["agent_kwargs"]["batch_count"] = 10
bipedal_walker_v3["agent_kwargs"]["clip_ratio"] = 0.1
bipedal_walker_v3["agent_kwargs"]["critic_activations"] = ["relu", "relu", "linear"]
bipedal_walker_v3["agent_kwargs"]["critic_hidden_dims"] = [256, 256]
bipedal_walker_v3["agent_kwargs"]["entropy_coeff"] = 0.0917
bipedal_walker_v3["agent_kwargs"]["gae_lambda"] = 0.95
bipedal_walker_v3["agent_kwargs"]["gamma"] = 0.9553
bipedal_walker_v3["agent_kwargs"]["gradient_clip"] = 2.0
bipedal_walker_v3["agent_kwargs"]["learning_rate"] = 0.4762
bipedal_walker_v3["agent_kwargs"]["target_kl"] = 0.1999
bipedal_walker_v3["agent_kwargs"]["value_coeff"] = 0.4435
"""
[I 2023-01-04 05:57:20,749] Trial 1451 finished with value: 5260.338678148058 and parameters: {
'env_count': 32,
'actor_noise_std': 0.3397405098274084,
'batch_count': 6,
'clip_ratio': 0.3,
'entropy_coeff': 0.009392937880259133,
'gae_lambda': 0.8,
'gamma': 0.9683403243382301,
'gradient_clip': 5.0,
'learning_rate': 0.5985206877398142,
'steps_per_env': 16,
'target_kl': 0.027651917189297347,
'value_coeff': 0.26705235341068373,
'net_arch': 'medium',
'net_activation': 'tanh'
}. Best is trial 1451 with value: 5260.338678148058.
"""
half_cheetah_v3 = copy.deepcopy(default)
half_cheetah_v3["env_kwargs"]["environment_count"] = 32
half_cheetah_v3["runner_kwargs"]["num_steps_per_env"] = 16
half_cheetah_v3["agent_kwargs"]["actor_activations"] = ["tanh", "tanh", "linear"]
half_cheetah_v3["agent_kwargs"]["actor_hidden_dims"] = [256, 256]
half_cheetah_v3["agent_kwargs"]["actor_noise_std"] = 0.3397
half_cheetah_v3["agent_kwargs"]["batch_count"] = 6
half_cheetah_v3["agent_kwargs"]["clip_ratio"] = 0.3
half_cheetah_v3["agent_kwargs"]["critic_activations"] = ["tanh", "tanh", "linear"]
half_cheetah_v3["agent_kwargs"]["critic_hidden_dims"] = [256, 256]
half_cheetah_v3["agent_kwargs"]["entropy_coeff"] = 0.009393
half_cheetah_v3["agent_kwargs"]["gae_lambda"] = 0.8
half_cheetah_v3["agent_kwargs"]["gamma"] = 0.9683
half_cheetah_v3["agent_kwargs"]["gradient_clip"] = 5.0
half_cheetah_v3["agent_kwargs"]["learning_rate"] = 0.5985
half_cheetah_v3["agent_kwargs"]["target_kl"] = 0.02765
half_cheetah_v3["agent_kwargs"]["value_coeff"] = 0.2671
"""
[I 2023-01-08 18:38:51,481] Trial 25 finished with value: 2225.9547948810073 and parameters: {
'actor_noise_std': 0.5589708917145111,
'batch_count': 15,
'clip_ratio': 0.2,
'entropy_coeff': 0.03874027035272886,
'env_count': 128,
'gae_lambda': 0.98,
'gamma': 0.9879577396280973,
'gradient_clip': 1.0,
'learning_rate': 0.3732431793266761,
'steps_per_env': 128,
'target_kl': 0.12851506672519566,
'value_coeff': 0.8162548885723906,
'actor_net_arch': 'medium',
'critic_net_arch': 'small',
'actor_net_activation': 'relu',
'critic_net_activation': 'relu'
}. Best is trial 25 with value: 2225.9547948810073.
"""
hopper_v3 = copy.deepcopy(default)
half_cheetah_v3["env_kwargs"]["environment_count"] = 128
hopper_v3["runner_kwargs"]["num_steps_per_env"] = 128
hopper_v3["agent_kwargs"]["actor_activations"] = ["relu", "relu", "linear"]
hopper_v3["agent_kwargs"]["actor_hidden_dims"] = [256, 256]
hopper_v3["agent_kwargs"]["actor_noise_std"] = 0.5590
hopper_v3["agent_kwargs"]["batch_count"] = 15
hopper_v3["agent_kwargs"]["clip_ratio"] = 0.2
hopper_v3["agent_kwargs"]["critic_activations"] = ["relu", "relu", "linear"]
hopper_v3["agent_kwargs"]["critic_hidden_dims"] = [32, 32]
hopper_v3["agent_kwargs"]["entropy_coeff"] = 0.03874
hopper_v3["agent_kwargs"]["gae_lambda"] = 0.98
hopper_v3["agent_kwargs"]["gamma"] = 0.9890
hopper_v3["agent_kwargs"]["gradient_clip"] = 1.0
hopper_v3["agent_kwargs"]["learning_rate"] = 0.3732
hopper_v3["agent_kwargs"]["value_coeff"] = 0.8163
swimmer_v3 = copy.deepcopy(default)
swimmer_v3["agent_kwargs"]["gamma"] = 0.9999
walker2d_v3 = copy.deepcopy(default)
walker2d_v3["runner_kwargs"]["num_steps_per_env"] = 512
walker2d_v3["agent_kwargs"]["batch_count"] = (
walker2d_v3["env_kwargs"]["environment_count"] * walker2d_v3["runner_kwargs"]["num_steps_per_env"] // 32
)
walker2d_v3["agent_kwargs"]["clip_ratio"] = 0.1
walker2d_v3["agent_kwargs"]["entropy_coeff"] = 0.000585045
walker2d_v3["agent_kwargs"]["gae_lambda"] = 0.95
walker2d_v3["agent_kwargs"]["gamma"] = 0.99
walker2d_v3["agent_kwargs"]["gradient_clip"] = 1.0
walker2d_v3["agent_kwargs"]["learning_rate"] = 5.05041e-05
walker2d_v3["agent_kwargs"]["value_coeff"] = 0.871923
ppo_hyperparams = {
"default": default,
"Ant-v3": ant_v3,
"Ant-v4": ant_v3,
"BipedalWalker-v3": bipedal_walker_v3,
"HalfCheetah-v3": half_cheetah_v3,
"HalfCheetah-v4": half_cheetah_v3,
"Hopper-v3": hopper_v3,
"Hopper-v4": hopper_v3,
"Swimmer-v3": swimmer_v3,
"Swimmer-v4": swimmer_v3,
"Walker2d-v3": walker2d_v3,
"Walker2d-v4": walker2d_v3,
}

103
examples/tune.py Normal file
View File

@@ -0,0 +1,103 @@
from rsl_rl.algorithms import *
from rsl_rl.env.gym_env import GymEnv
from rsl_rl.runners.runner import Runner
import copy
from datetime import datetime
import numpy as np
import optuna
import os
import random
import torch
from tune_cfg import samplers
ALGORITHM = PPO
ENVIRONMENT = "BipedalWalker-v3"
ENVIRONMENT_KWARGS = {}
EVAL_AGENTS = 64
EVAL_RUNS = 10
EVAL_STEPS = 1000
EXPERIMENT_DIR = os.environ.get("EXPERIMENT_DIRECTORY", "./")
EXPERIMENT_NAME = os.environ.get("EXPERIMENT_NAME", f"tune-{ALGORITHM.__name__}-{ENVIRONMENT}")
TRAIN_ITERATIONS = None
TRAIN_TIMEOUT = 60 * 15 # 15 minutes
TRAIN_RUNS = 3
TRAIN_SEED = None
def tune():
assert TRAIN_RUNS == 1 or TRAIN_SEED is None, "If multiple runs are used, the seed must be None."
storage = optuna.storages.RDBStorage(url=f"sqlite:///{EXPERIMENT_DIR}/{EXPERIMENT_NAME}.db")
pruner = optuna.pruners.MedianPruner(n_startup_trials=10)
try:
study = optuna.create_study(direction="maximize", pruner=pruner, storage=storage, study_name=EXPERIMENT_NAME)
except Exception:
study = optuna.load_study(pruner=pruner, storage=storage, study_name=EXPERIMENT_NAME)
study.optimize(objective, n_trials=100)
def seed(s=None):
seed = int(datetime.now().timestamp() * 1e6) % 2**32 if s is None else s
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def objective(trial):
seed()
agent_kwargs, env_kwargs, runner_kwargs = samplers[ALGORITHM.__name__](trial)
evaluations = []
for instantiation in range(TRAIN_RUNS):
seed(TRAIN_SEED)
env = GymEnv(ENVIRONMENT, gym_kwargs=ENVIRONMENT_KWARGS, **env_kwargs)
agent = ALGORITHM(env, **agent_kwargs)
runner = Runner(env, agent, **runner_kwargs)
runner._learn_cb = [lambda _, stat: runner._log_progress(stat, prefix=f"learn {instantiation+1}/{TRAIN_RUNS}")]
eval_env_kwargs = copy.deepcopy(env_kwargs)
eval_env_kwargs["environment_count"] = EVAL_AGENTS
eval_runner = Runner(
GymEnv(ENVIRONMENT, gym_kwargs=ENVIRONMENT_KWARGS, **eval_env_kwargs),
agent,
**runner_kwargs,
)
eval_runner._eval_cb = [
lambda _, stat: runner._log_progress(stat, prefix=f"eval {instantiation+1}/{TRAIN_RUNS}")
]
try:
runner.learn(TRAIN_ITERATIONS, timeout=TRAIN_TIMEOUT)
except Exception:
raise optuna.TrialPruned()
intermediate_evaluations = []
for eval_run in range(EVAL_RUNS):
eval_runner._eval_cb = [lambda _, stat: runner._log_progress(stat, prefix=f"eval {eval_run+1}/{EVAL_RUNS}")]
seed()
eval_runner.env.reset()
intermediate_evaluations.append(eval_runner.evaluate(steps=EVAL_STEPS))
eval = np.mean(intermediate_evaluations)
trial.report(eval, instantiation)
if trial.should_prune():
raise optuna.TrialPruned()
evaluations.append(eval)
evaluation = np.mean(evaluations)
return evaluation
if __name__ == "__main__":
tune()

134
examples/tune_cfg.py Normal file
View File

@@ -0,0 +1,134 @@
import torch
from rsl_rl.algorithms import DPPO, PPO
from rsl_rl.modules import QuantileNetwork
NETWORKS = {"small": [64, 64], "medium": [256, 256], "large": [512, 256, 128]}
def sample_dppo_hyperparams(trial):
actor_net_activation = trial.suggest_categorical("actor_net_activation", ["relu", "tanh"])
actor_net_arch = trial.suggest_categorical("actor_net_arch", list(NETWORKS.keys()))
actor_noise_std = trial.suggest_float("actor_noise_std", 0.05, 1.0)
batch_count = trial.suggest_int("batch_count", 1, 20)
clip_ratio = trial.suggest_categorical("clip_ratio", [0.1, 0.2, 0.3, 0.4])
critic_net_activation = trial.suggest_categorical("critic_net_activation", ["relu", "tanh"])
critic_net_arch = trial.suggest_categorical("critic_net_arch", list(NETWORKS.keys()))
entropy_coeff = trial.suggest_float("entropy_coeff", 0.00000001, 0.1)
env_count = trial.suggest_categorical("env_count", [1, 8, 16, 32, 64, 128, 256, 512])
gae_lambda = trial.suggest_categorical("gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0])
gamma = trial.suggest_float("gamma", 0.95, 0.999)
gradient_clip = trial.suggest_categorical("gradient_clip", [0.1, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 2.0, 5.0])
learning_rate = trial.suggest_float("learning_rate", 1e-5, 1)
num_steps_per_env = trial.suggest_categorical("steps_per_env", [8, 16, 32, 64, 128, 256, 512, 1024, 2048])
quantile_count = trial.suggest_categorical("quantile_count", [20, 50, 100, 200])
recurrent = trial.suggest_categorical("recurrent", [True, False])
target_kl = trial.suggest_float("target_kl", 0.01, 0.3)
value_coeff = trial.suggest_float("value_coeff", 0.0, 1.0)
value_measure = trial.suggest_categorical(
"value_measure",
["neutral", "var-risk-averse", "var-risk-seeking", "var-super-risk-averse", "var-super-risk-seeking"],
)
actor_net_arch = NETWORKS[actor_net_arch]
critic_net_arch = NETWORKS[critic_net_arch]
value_measure_kwargs = {
"neutral": dict(),
"var-risk-averse": dict(confidence_level=0.25),
"var-risk-seeking": dict(confidence_level=0.75),
"var-super-risk-averse": dict(confidence_level=0.1),
"var-super-risk-seeking": dict(confidence_level=0.9),
}[value_measure]
value_measure = {
"neutral": QuantileNetwork.measure_neutral,
"var-risk-averse": QuantileNetwork.measure_percentile,
"var-risk-seeking": QuantileNetwork.measure_percentile,
"var-super-risk-averse": QuantileNetwork.measure_percentile,
"var-super-risk-seeking": QuantileNetwork.measure_percentile,
}[value_measure]
device = "cuda:0" if env_count * num_steps_per_env > 2048 and torch.cuda.is_available() else "cpu"
agent_kwargs = dict(
actor_activations=([actor_net_activation] * len(actor_net_arch)) + ["linear"],
actor_hidden_dims=actor_net_arch,
actor_input_normalization=False,
actor_noise_std=actor_noise_std,
batch_count=batch_count,
clip_ratio=clip_ratio,
critic_activations=([critic_net_activation] * len(critic_net_arch)),
critic_hidden_dims=critic_net_arch,
critic_input_normalization=False,
device=device,
entropy_coeff=entropy_coeff,
gae_lambda=gae_lambda,
gamma=gamma,
gradient_clip=gradient_clip,
learning_rate=learning_rate,
quantile_count=quantile_count,
recurrent=recurrent,
schedule="adaptive",
target_kl=target_kl,
value_coeff=value_coeff,
value_measure=value_measure,
value_measure_kwargs=value_measure_kwargs,
)
env_kwargs = dict(device=device, environment_count=env_count)
runner_kwargs = dict(device=device, num_steps_per_env=num_steps_per_env)
return agent_kwargs, env_kwargs, runner_kwargs
def sample_ppo_hyperparams(trial):
actor_net_activation = trial.suggest_categorical("actor_net_activation", ["relu", "tanh"])
actor_net_arch = trial.suggest_categorical("actor_net_arch", list(NETWORKS.keys()))
actor_noise_std = trial.suggest_float("actor_noise_std", 0.05, 1.0)
batch_count = trial.suggest_int("batch_count", 1, 20)
clip_ratio = trial.suggest_categorical("clip_ratio", [0.1, 0.2, 0.3, 0.4])
critic_net_activation = trial.suggest_categorical("critic_net_activation", ["relu", "tanh"])
critic_net_arch = trial.suggest_categorical("critic_net_arch", list(NETWORKS.keys()))
entropy_coeff = trial.suggest_float("entropy_coeff", 0.00000001, 0.1)
env_count = trial.suggest_categorical("env_count", [1, 8, 16, 32, 64, 128, 256, 512])
gae_lambda = trial.suggest_categorical("gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0])
gamma = trial.suggest_float("gamma", 0.95, 0.999)
gradient_clip = trial.suggest_categorical("gradient_clip", [0.1, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 2.0, 5.0])
learning_rate = trial.suggest_float("learning_rate", 1e-5, 1)
num_steps_per_env = trial.suggest_categorical("steps_per_env", [8, 16, 32, 64, 128, 256, 512, 1024, 2048])
recurrent = trial.suggest_categorical("recurrent", [True, False])
target_kl = trial.suggest_float("target_kl", 0.01, 0.3)
value_coeff = trial.suggest_float("value_coeff", 0.0, 1.0)
actor_net_arch = NETWORKS[actor_net_arch]
critic_net_arch = NETWORKS[critic_net_arch]
device = "cuda:0" if env_count * num_steps_per_env > 2048 and torch.cuda.is_available() else "cpu"
agent_kwargs = dict(
actor_activations=([actor_net_activation] * len(actor_net_arch)) + ["linear"],
actor_hidden_dims=actor_net_arch,
actor_input_normalization=False,
actor_noise_std=actor_noise_std,
batch_count=batch_count,
clip_ratio=clip_ratio,
critic_activations=([critic_net_activation] * len(critic_net_arch)) + ["linear"],
critic_hidden_dims=critic_net_arch,
critic_input_normalization=False,
device=device,
entropy_coeff=entropy_coeff,
gae_lambda=gae_lambda,
gamma=gamma,
gradient_clip=gradient_clip,
learning_rate=learning_rate,
recurrent=recurrent,
schedule="adaptive",
target_kl=target_kl,
value_coeff=value_coeff,
)
env_kwargs = dict(device=device, environment_count=env_count)
runner_kwargs = dict(device=device, num_steps_per_env=num_steps_per_env)
return agent_kwargs, env_kwargs, runner_kwargs
samplers = {
DPPO.__name__: sample_dppo_hyperparams,
PPO.__name__: sample_ppo_hyperparams,
}

View File

@@ -0,0 +1,4 @@
# To use this file, copy it to wandb_config.py and fill in the missing values.
WANDB_API_KEY = ""
WANDB_ENTITY = ""

View File

@@ -1,7 +1,2 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
"""Main module for the rsl_rl package."""
__version__ = "2.0.1"
__license__ = "BSD-3"

View File

@@ -1,8 +1,10 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
"""Implementation of different RL agents."""
from .agent import Agent
from .d4pg import D4PG
from .ddpg import DDPG
from .dppo import DPPO
from .dsac import DSAC
from .ppo import PPO
from .sac import SAC
from .td3 import TD3
__all__ = ["PPO"]
__all__ = ["Agent", "DDPG", "D4PG", "DPPO", "DSAC", "PPO", "SAC", "TD3"]

View File

@@ -0,0 +1,371 @@
from abc import abstractmethod
import torch
from typing import Any, Callable, Dict, List, Tuple, Union
from rsl_rl.algorithms.agent import Agent
from rsl_rl.env.vec_env import VecEnv
from rsl_rl.modules.network import Network
from rsl_rl.storage.storage import Dataset
from rsl_rl.utils.utils import environment_dimensions
from rsl_rl.utils.utils import squeeze_preserve_batch
class AbstractActorCritic(Agent):
_alg_features = dict(recurrent=False)
def __init__(
self,
env: VecEnv,
actor_activations: List[str] = ["relu", "relu", "relu", "linear"],
actor_hidden_dims: List[int] = [256, 256, 256],
actor_init_gain: float = 0.5,
actor_input_normalization: bool = False,
actor_recurrent_layers: int = 1,
actor_recurrent_module: str = Network.recurrent_module_lstm,
actor_recurrent_tf_context_length: int = 64,
actor_recurrent_tf_head_count: int = 8,
actor_shared_dims: int = None,
batch_count: int = 1,
batch_size: int = 1,
critic_activations: List[str] = ["relu", "relu", "relu", "linear"],
critic_hidden_dims: List[int] = [256, 256, 256],
critic_init_gain: float = 0.5,
critic_input_normalization: bool = False,
critic_recurrent_layers: int = 1,
critic_recurrent_module: str = Network.recurrent_module_lstm,
critic_recurrent_tf_context_length: int = 64,
critic_recurrent_tf_head_count: int = 8,
critic_shared_dims: int = None,
polyak: float = 0.995,
recurrent: bool = False,
return_steps: int = 1,
_actor_input_size_delta: int = 0,
_critic_input_size_delta: int = 0,
**kwargs,
):
"""Creates an actor critic agent.
Args:
env (VecEnv): A vectorized environment.
actor_activations (List[str]): A list of activation functions for the actor network.
actor_hidden_dims (List[int]): A list of layer sizes for the actor network.
actor_init_gain (float): Network initialization gain for actor.
actor_input_normalization (bool): Whether to empirically normalize inputs to the actor network.
actor_recurrent_layers (int): The number of recurrent layers to use for the actor network.
actor_recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules.
actor_shared_dims (int): The number of dimensions to share for an actor with multiple heads.
batch_count (int): The number of batches to process per update step.
batch_size (int): The size of each batch to process during the update step.
critic_activations (List[str]): A list of activation functions for the critic network.
critic_hidden_dims (List[int]): A list of layer sizes for the critic network.
critic_init_gain (float): Network initialization gain for critic.
critic_input_normalization (bool): Whether to empirically normalize inputs to the critic network.
critic_recurrent_layers (int): The number of recurrent layers to use for the critic network.
critic_recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules.
critic_shared_dims (int): The number of dimensions to share for a critic with multiple heads.
polyak (float): The actor-critic target network polyak factor.
recurrent (bool): Whether to use recurrent actor and critic networks.
actor_recurrent_tf_context_length (int): The context length of the actor's Transformer module.
actor_recurrent_tf_head_count (int): The head count of the actor's Transformer module.
critic_recurrent_tf_context_length (int): The context length of the critic's Transformer module.
critic_recurrent_tf_head_count (int): The head count of the critic's Transformer module.
return_steps (float): The number of steps over which to compute the returns (n-step return).
_actor_input_size_delta (int): The number of additional dimensions to add to the actor input.
_critic_input_size_delta (int): The number of additional dimensions to add to the critic input.
"""
assert (
self._alg_features["recurrent"] == True or not recurrent
), f"{self.__class__.__name__} does not support recurrent networks."
super().__init__(env, **kwargs)
self.actor: torch.nn.Module = None
self.actor_optimizer: torch.nn.Module = None
self.critic_optimizer: torch.nn.Module = None
self.critic: torch.nn.Module = None
self._batch_size = batch_size
self._batch_count = batch_count
self._polyak_factor = polyak
self._return_steps = return_steps
self._recurrent = recurrent
self._register_serializable(
"_batch_size", "_batch_count", "_discount_factor", "_polyak_factor", "_return_steps"
)
dimensions = environment_dimensions(self.env)
try:
actor_input_size = dimensions["actor_observations"]
critic_input_size = dimensions["critic_observations"]
except KeyError:
actor_input_size = dimensions["observations"]
critic_input_size = dimensions["observations"]
self._actor_input_size = actor_input_size + _actor_input_size_delta
self._critic_input_size = critic_input_size + self._action_size + _critic_input_size_delta
self._register_actor_network_kwargs(
activations=actor_activations,
hidden_dims=actor_hidden_dims,
init_gain=actor_init_gain,
input_normalization=actor_input_normalization,
recurrent=recurrent,
recurrent_layers=actor_recurrent_layers,
recurrent_module=actor_recurrent_module,
recurrent_tf_context_length=actor_recurrent_tf_context_length,
recurrent_tf_head_count=actor_recurrent_tf_head_count,
)
if actor_shared_dims is not None:
self._register_actor_network_kwargs(shared_dims=actor_shared_dims)
self._register_critic_network_kwargs(
activations=critic_activations,
hidden_dims=critic_hidden_dims,
init_gain=critic_init_gain,
input_normalization=critic_input_normalization,
recurrent=recurrent,
recurrent_layers=critic_recurrent_layers,
recurrent_module=critic_recurrent_module,
recurrent_tf_context_length=critic_recurrent_tf_context_length,
recurrent_tf_head_count=critic_recurrent_tf_head_count,
)
if critic_shared_dims is not None:
self._register_critic_network_kwargs(shared_dims=critic_shared_dims)
self._register_serializable(
"_actor_input_size", "_actor_network_kwargs", "_critic_input_size", "_critic_network_kwargs"
)
# For computing n-step returns using prior transitions.
self._stored_dataset = []
def export_onnx(self) -> Tuple[torch.nn.Module, torch.Tensor, Dict]:
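# The returned (model, args, kwargs) triple is intended to be forwarded to torch.onnx.export;
# a hypothetical call would be torch.onnx.export(model, args, "policy.onnx", **kwargs).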
self.eval_mode()
class ONNXActor(torch.nn.Module):
def __init__(self, model: torch.nn.Module):
super().__init__()
self.model = model
def forward(self, x: torch.Tensor, hidden_state: Tuple[torch.Tensor, torch.Tensor] = None):
if hidden_state is None:
return self.model(x)
data = self.model(x, hidden_state=hidden_state)
hidden_state = self.model.last_hidden_state
return data, hidden_state
model = ONNXActor(self.actor)
kwargs = dict(
export_params=True,
opset_version=11,
verbose=True,
dynamic_axes={},
)
kwargs["input_names"] = ["observations"]
kwargs["output_names"] = ["actions"]
args = torch.zeros(1, self._actor_input_size)
if self.actor.recurrent:
hidden_state = (
torch.zeros(self.actor._features[0].num_layers, 1, self.actor._features[0].hidden_size),
torch.zeros(self.actor._features[0].num_layers, 1, self.actor._features[0].hidden_size),
)
args = (args, {"hidden_state": hidden_state})
return model, args, kwargs
def draw_random_actions(self, obs: torch.Tensor, env_info: Dict[str, Any]) -> Tuple[torch.Tensor, Dict]:
actions, data = super().draw_random_actions(obs, env_info)
actor_obs, critic_obs = self._process_observations(obs, env_info)
data.update({"actor_observations": actor_obs.clone(), "critic_observations": critic_obs.clone()})
return actions, data
def get_inference_policy(self, device=None) -> Callable:
self.to(device)
self.eval_mode()
if self.actor.recurrent:
self.actor.reset_full_hidden_state(batch_size=self.env.num_envs)
if self.critic.recurrent:
self.critic.reset_full_hidden_state(batch_size=self.env.num_envs)
def policy(obs, env_info=None):
with torch.inference_mode():
obs, _ = self._process_observations(obs, env_info)
actions = self._process_actions(self.actor.forward(obs))
return actions
return policy
def process_transition(
self,
observations: torch.Tensor,
environment_info: Dict[str, Any],
actions: torch.Tensor,
rewards: torch.Tensor,
next_observations: torch.Tensor,
next_environment_info: Dict[str, Any],
dones: torch.Tensor,
data: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
if "actor_observations" in data and "critic_observations" in data:
actor_obs, critic_obs = data["actor_observations"], data["critic_observations"]
else:
actor_obs, critic_obs = self._process_observations(observations, environment_info)
if "next_actor_observations" in data and "next_critic_observations" in data:
next_actor_obs, next_critic_obs = data["next_actor_observations"], data["next_critic_observations"]
else:
next_actor_obs, next_critic_obs = self._process_observations(next_observations, next_environment_info)
transition = {
"actions": actions,
"actor_observations": actor_obs,
"critic_observations": critic_obs,
"dones": dones,
"next_actor_observations": next_actor_obs,
"next_critic_observations": next_critic_obs,
"rewards": squeeze_preserve_batch(rewards),
"timeouts": self._extract_timeouts(next_environment_info),
}
transition.update(data)
for key, value in transition.items():
transition[key] = value.detach().clone()
return transition
@property
def recurrent(self) -> bool:
return self._recurrent
def register_terminations(self, terminations: torch.Tensor) -> None:
pass
@abstractmethod
def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
with torch.inference_mode():
self.storage.append(self._process_dataset(dataset))
def _critic_input(self, observations, actions) -> torch.Tensor:
"""Combines observations and actions into a tensor that can be fed into the critic network.
Args:
observations (torch.Tensor): The critic observations.
actions (torch.Tensor): The actions computed by the actor.
Returns:
A torch.Tensor of inputs for the critic network.
"""
return torch.cat((observations, actions), dim=-1)
def _extract_timeouts(self, next_environment_info):
"""Extracts timeout information from the transition next state information dictionary.
Args:
next_environment_info (Dict[str, Any]): The transition next state information dictionary.
Returns:
A torch.Tensor vector of actor timeouts.
"""
if "time_outs" not in next_environment_info:
return torch.zeros(self.env.num_envs, device=self.device)
timeouts = squeeze_preserve_batch(next_environment_info["time_outs"].to(self.device))
return timeouts
def _process_dataset(self, dataset: Dataset) -> Dataset:
"""Processes a dataset before it is added to the replay memory.
Handles n-step returns and timeouts.
TODO: This function seems to be a bottleneck in the training pipeline - speed it up!
Args:
dataset (Dataset): The dataset to process.
Returns:
A Dataset object containing the processed data.
"""
assert len(dataset) >= self._return_steps
dataset = self._stored_dataset + dataset
length = len(dataset) - self._return_steps + 1
self._stored_dataset = dataset[length:]
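# Accumulate the discounted n-step return for each transition: rewards of environments that are
# still alive are summed with weight gamma^k, the critic value bootstraps the return when an
# environment times out, and environments that are done or timed out stop contributing afterwards.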
for idx in range(len(dataset) - self._return_steps + 1):
dead_idx = torch.zeros_like(dataset[idx]["dones"])
rewards = torch.zeros_like(dataset[idx]["rewards"])
for k in range(self._return_steps):
data = dataset[idx + k]
alive_idx = (dead_idx == 0).nonzero()
critic_predictions = self.critic.forward(
self._critic_input(
data["critic_observations"].clone().to(self.device),
data["actions"].clone().to(self.device),
)
)
rewards[alive_idx] += self._discount_factor**k * data["rewards"][alive_idx]
rewards[alive_idx] += (
self._discount_factor ** (k + 1) * data["timeouts"][alive_idx] * critic_predictions[alive_idx]
)
dead_idx += data["dones"]
dead_idx += data["timeouts"]
dataset[idx]["rewards"] = rewards
return dataset[:length]
def _process_observations(
self, observations: torch.Tensor, environment_info: Dict[str, Any] = None
) -> Tuple[torch.Tensor, ...]:
"""Processes observations returned by the environment to extract actor and critic observations.
Args:
observations (torch.Tensor): normal environment observations.
environment_info (Dict[str, Any]): A dictionary of additional environment information.
Returns:
A tuple containing two torch.Tensors with actor and critic observations, respectively.
"""
try:
critic_obs = environment_info["observations"]["critic"]
except (KeyError, TypeError):
critic_obs = observations
actor_obs, critic_obs = observations.to(self.device), critic_obs.to(self.device)
return actor_obs, critic_obs
def _register_actor_network_kwargs(self, **kwargs) -> None:
"""Function to configure actor network in child classes before calling super().__init__()."""
if not hasattr(self, "_actor_network_kwargs"):
self._actor_network_kwargs = dict()
self._actor_network_kwargs.update(**kwargs)
def _register_critic_network_kwargs(self, **kwargs) -> None:
"""Function to configure critic network in child classes before calling super().__init__()."""
if not hasattr(self, "_critic_network_kwargs"):
self._critic_network_kwargs = dict()
self._critic_network_kwargs.update(**kwargs)
def _update_target(self, online: torch.nn.Module, target: torch.nn.Module) -> None:
"""Updates the target network using the polyak factor.
Args:
online (torch.nn.Module): The online network.
target (torch.nn.Module): The target network.
"""
for op, tp in zip(online.parameters(), target.parameters()):
tp.data.copy_((1.0 - self._polyak_factor) * op.data + self._polyak_factor * tp.data)

197
rsl_rl/algorithms/agent.py Normal file
View File

@@ -0,0 +1,197 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import numpy as np
import torch
from typing import Any, Callable, Dict, Tuple, Union
from rsl_rl.env import VecEnv
from rsl_rl.storage.storage import Dataset
from rsl_rl.utils.benchmarkable import Benchmarkable
from rsl_rl.utils.serializable import Serializable
from rsl_rl.utils.utils import environment_dimensions
class Agent(ABC, Benchmarkable, Serializable):
def __init__(
self,
env: VecEnv,
action_max: float = np.inf,
action_min: float = -np.inf,
benchmark: bool = False,
device: str = "cpu",
gamma: float = 0.99,
):
"""Creates an agent.
Args:
env (VecEnv): The environment of the agent.
action_max (float): The maximum action value.
action_min (float): The minimum action value.
benchmark (bool): Whether to benchmark runtime.
device (str): The device to use for computation.
gamma (float): The environment discount factor.
"""
super().__init__()
self.env = env
self.device = device
self.storage = None
self._action_max = action_max
self._action_min = action_min
self._discount_factor = gamma
self._register_serializable("_action_max", "_action_min", "_discount_factor")
dimensions = environment_dimensions(self.env)
self._action_size = dimensions["actions"]
self._register_serializable("_action_size")
if self._action_min > -np.inf and self._action_max < np.inf:
self._rand_scale = self._action_max - self._action_min
self._rand_offset = self._action_min
else:
self._rand_scale = 2.0
self._rand_offset = -1.0
self._bm_toggle(benchmark)
@abstractmethod
def draw_actions(
self, obs: torch.Tensor, env_info: Dict[str, Any]
) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
"""Draws actions from the action space.
Args:
obs (torch.Tensor): The observations for which to draw actions.
env_info (Dict[str, Any]): The environment information for the observations.
Returns:
A tuple containing the actions and the data dictionary.
"""
pass
def draw_random_actions(
self, obs: torch.Tensor, env_info: Dict[str, Any]
) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
"""Draws random actions from the action space.
Args:
obs (torch.Tensor): The observations to include in the data dictionary.
env_info (Dict[str, Any]): The environment information to include in the data dictionary.
Returns:
A tuple containing the random actions and the data dictionary.
"""
actions = self._process_actions(
self._rand_offset + self._rand_scale * torch.rand(self.env.num_envs, self._action_size)
)
return actions, {}
@abstractmethod
def eval_mode(self) -> Agent:
"""Sets the agent to evaluation mode."""
return self
@abstractmethod
def export_onnx(self) -> Tuple[torch.nn.Module, torch.Tensor, Dict]:
"""Exports the agent's policy network to ONNX format.
Returns:
A tuple containing the ONNX model, the input arguments, and the keyword arguments.
"""
pass
@property
def gamma(self) -> float:
return self._discount_factor
@abstractmethod
def get_inference_policy(self, device: str = None) -> Callable:
"""Returns a function that computes actions from observations without storing gradients.
Args:
device (torch.device): The device to use for inference.
Returns:
A function that computes actions from observations.
"""
pass
@property
def initialized(self) -> bool:
"""Whether the agent has been initialized."""
return self.storage.initialized
@abstractmethod
def process_transition(
self,
observations: torch.Tensor,
environment_info: Dict[str, Any],
actions: torch.Tensor,
rewards: torch.Tensor,
next_observations: torch.Tensor,
next_environment_info: Dict[str, Any],
dones: torch.Tensor,
data: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
"""Processes a transition before it is added to the replay memory.
Args:
observations (torch.Tensor): The observations from the environment.
environment_info (Dict[str, Any]): The environment information.
actions (torch.Tensor): The actions computed by the actor.
rewards (torch.Tensor): The rewards from the environment.
next_observations (torch.Tensor): The next observations from the environment.
next_environment_info (Dict[str, Any]): The next environment information.
dones (torch.Tensor): The done flags from the environment.
data (Dict[str, torch.Tensor]): Additional data to include in the transition.
Returns:
A dictionary containing the processed transition.
"""
pass
@abstractmethod
def register_terminations(self, terminations: torch.Tensor) -> None:
"""Registers terminations with the actor critic agent.
Args:
terminations (torch.Tensor): A tensor of indicator values for each environment.
"""
pass
@abstractmethod
def to(self, device: str) -> Agent:
"""Transfers agent parameters to device."""
self.device = device
return self
@abstractmethod
def train_mode(self) -> Agent:
"""Sets the agent to training mode."""
return self
@abstractmethod
def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
"""Updates the agent's parameters.
Args:
dataset (Dataset): The dataset from which to update the agent.
Returns:
A dictionary containing the loss values.
"""
pass
def _process_actions(self, actions: torch.Tensor) -> torch.Tensor:
"""Processes actions produced by the agent.
Args:
actions (torch.Tensor): The raw actions.
Returns:
A torch.Tensor containing the processed actions.
"""
actions = actions.reshape(-1, self._action_size)
actions = actions.clamp(self._action_min, self._action_max)
actions = actions.to(self.device)
return actions

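The `Agent` base class above scales uniform random samples into the action bounds and clamps them in `_process_actions`. The following standalone sketch reproduces that bounded-sampling logic with placeholder shapes; it is only an illustration, not code from the repository.

```python
import torch

# Illustrative stand-in values; the real sizes come from the VecEnv.
num_envs, action_size = 4, 3
action_min, action_max = -1.5, 1.5

# Mirror of the constructor logic: uniform samples cover [action_min, action_max]
# when both bounds are finite, otherwise the fallback range [-1, 1].
if action_min > -float("inf") and action_max < float("inf"):
    rand_scale, rand_offset = action_max - action_min, action_min
else:
    rand_scale, rand_offset = 2.0, -1.0

actions = rand_offset + rand_scale * torch.rand(num_envs, action_size)
actions = actions.reshape(-1, action_size).clamp(action_min, action_max)
print(actions.shape)  # torch.Size([4, 3])
```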
168
rsl_rl/algorithms/d4pg.py Normal file

@@ -0,0 +1,168 @@
from __future__ import annotations
import torch
from typing import Dict, Union
from rsl_rl.algorithms.dpg import AbstractDPG
from rsl_rl.env import VecEnv
from rsl_rl.storage.storage import Dataset
from rsl_rl.modules import CategoricalNetwork, Network
class D4PG(AbstractDPG):
"""Distributed Distributional Deep Deterministic Policy Gradients algorithm.
This is an implementation of the D4PG algorithm by Barth-Maron et al. for vectorized environments.
Paper: https://arxiv.org/pdf/1804.08617.pdf
"""
def __init__(
self,
env: VecEnv,
actor_lr: float = 1e-4,
atom_count: int = 51,
critic_activations: list = ["relu", "relu", "relu"],
critic_lr: float = 1e-3,
target_update_delay: int = 2,
value_max: float = 10.0,
value_min: float = -10.0,
**kwargs,
) -> None:
"""
Args:
env (VecEnv): A vectorized environment.
actor_lr (float): The learning rate for the actor network.
atom_count (int): The number of atoms to use for the categorical distribution.
critic_activations (list): A list of activation functions to use for the critic network.
critic_lr (float): The learning rate for the critic network.
target_update_delay (int): The number of steps to wait before updating the target networks.
value_max (float): The maximum value for the categorical distribution.
value_min (float): The minimum value for the categorical distribution.
"""
kwargs["critic_activations"] = critic_activations
super().__init__(env, **kwargs)
self._atom_count = atom_count
self._target_update_delay = target_update_delay
self._value_max = value_max
self._value_min = value_min
self._register_serializable("_atom_count", "_target_update_delay", "_value_max", "_value_min")
self.actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
self.critic = CategoricalNetwork(
self._critic_input_size,
1,
atom_count=atom_count,
value_max=value_max,
value_min=value_min,
**self._critic_network_kwargs,
)
self.target_actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
self.target_critic = CategoricalNetwork(
self._critic_input_size,
1,
atom_count=atom_count,
value_max=value_max,
value_min=value_min,
**self._critic_network_kwargs,
)
self.target_actor.load_state_dict(self.actor.state_dict())
self.target_critic.load_state_dict(self.critic.state_dict())
self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
self._register_serializable(
"actor", "critic", "target_actor", "target_critic", "actor_optimizer", "critic_optimizer"
)
self._update_step = 0
self._register_serializable("_update_step")
self.to(self.device)
def eval_mode(self) -> D4PG:
super().eval_mode()
self.actor.eval()
self.critic.eval()
self.target_actor.eval()
self.target_critic.eval()
return self
def to(self, device: str) -> D4PG:
super().to(device)
self.actor.to(device)
self.critic.to(device)
self.target_actor.to(device)
self.target_critic.to(device)
return self
def train_mode(self) -> D4PG:
super().train_mode()
self.actor.train()
self.critic.train()
self.target_actor.train()
self.target_critic.train()
return self
def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
super().update(dataset)
if not self.initialized:
return {}
total_actor_loss = torch.zeros(self._batch_count)
total_critic_loss = torch.zeros(self._batch_count)
for idx, batch in enumerate(self.storage.batch_generator(self._batch_size, self._batch_count)):
actor_obs = batch["actor_observations"]
critic_obs = batch["critic_observations"]
actions = batch["actions"].reshape(self._batch_size, -1)
rewards = batch["rewards"]
actor_next_obs = batch["next_actor_observations"]
critic_next_obs = batch["next_critic_observations"]
dones = batch["dones"]
predictions = self.critic.forward(self._critic_input(critic_obs, actions), distribution=True).squeeze()
target_actor_prediction = self._process_actions(self.target_actor.forward(actor_next_obs))
target_probabilities = self.target_critic.forward(
self._critic_input(critic_next_obs, target_actor_prediction), distribution=True
).squeeze()
targets = self.target_critic.compute_targets(rewards, dones, self._discount_factor)
critic_loss = self.target_critic.categorical_loss(predictions, target_probabilities, targets)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
evaluation = self.critic.forward(
self._critic_input(critic_obs, self._process_actions(self.actor.forward(actor_obs)))
)
actor_loss = -evaluation.mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
if self._update_step % self._target_update_delay == 0:
self._update_target(self.actor, self.target_actor)
self._update_target(self.critic, self.target_critic)
self._update_step += 1
total_actor_loss[idx] = actor_loss.item()
total_critic_loss[idx] = critic_loss.item()
stats = {"actor": total_actor_loss.mean().item(), "critic": total_critic_loss.mean().item()}
return stats

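`D4PG` delegates the distributional TD target to `CategoricalNetwork.compute_targets` and `categorical_loss`, which are not part of this diff. As a hedged reference, the sketch below shows the generic C51-style projection of a categorical value distribution onto a fixed atom support; the repository's implementation may differ in its details.

```python
import torch

# Generic sketch of the categorical (C51-style) TD-target projection used by
# D4PG-style critics; this is NOT the repository's CategoricalNetwork code.
def project_categorical_targets(rewards, dones, probs_next, gamma,
                                value_min, value_max, atom_count):
    # Atom support and spacing.
    delta_z = (value_max - value_min) / (atom_count - 1)
    atoms = torch.linspace(value_min, value_max, atom_count)

    # Bellman-backed-up atom positions, clipped to the support.
    tz = rewards.unsqueeze(-1) + gamma * (1 - dones).unsqueeze(-1) * atoms
    tz = tz.clamp(value_min, value_max)

    # Distribute the probability mass of each backed-up atom onto its neighbors.
    b = (tz - value_min) / delta_z
    lower, upper = b.floor().long(), b.ceil().long()
    projected = torch.zeros_like(probs_next)
    projected.scatter_add_(-1, lower, probs_next * (upper.float() - b))
    projected.scatter_add_(-1, upper, probs_next * (b - lower.float()))
    # When b is integral, lower == upper and both terms above vanish; restore the mass.
    projected.scatter_add_(-1, lower, probs_next * (lower == upper).float())
    return projected  # shape: (batch, atom_count)

batch, atoms = 8, 51
probs_next = torch.softmax(torch.randn(batch, atoms), dim=-1)
targets = project_categorical_targets(
    torch.randn(batch), torch.zeros(batch), probs_next,
    gamma=0.99, value_min=-10.0, value_max=10.0, atom_count=atoms,
)
print(targets.sum(-1))  # each row sums to ~1
```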
125
rsl_rl/algorithms/ddpg.py Normal file

@@ -0,0 +1,125 @@
from __future__ import annotations
import torch
from torch import optim
from typing import Dict, Union
from rsl_rl.algorithms.dpg import AbstractDPG
from rsl_rl.env import VecEnv
from rsl_rl.modules.network import Network
from rsl_rl.storage.storage import Dataset
class DDPG(AbstractDPG):
"""Deep Deterministic Policy Gradients algorithm.
This is an implementation of the DDPG algorithm by Lillicrap et al. for vectorized environments.
Paper: https://arxiv.org/pdf/1509.02971.pdf
"""
def __init__(
self,
env: VecEnv,
actor_lr: float = 1e-4,
critic_lr: float = 1e-3,
**kwargs,
) -> None:
super().__init__(env, **kwargs)
self.actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
self.critic = Network(self._critic_input_size, 1, **self._critic_network_kwargs)
self.target_actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
self.target_critic = Network(self._critic_input_size, 1, **self._critic_network_kwargs)
self.target_actor.load_state_dict(self.actor.state_dict())
self.target_critic.load_state_dict(self.critic.state_dict())
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)
self._register_serializable(
"actor", "critic", "target_actor", "target_critic", "actor_optimizer", "critic_optimizer"
)
self.to(self.device)
def eval_mode(self) -> DDPG:
super().eval_mode()
self.actor.eval()
self.critic.eval()
self.target_actor.eval()
self.target_critic.eval()
return self
def to(self, device: str) -> DDPG:
"""Transfers agent parameters to device."""
super().to(device)
self.actor.to(device)
self.critic.to(device)
self.target_actor.to(device)
self.target_critic.to(device)
return self
def train_mode(self) -> DDPG:
super().train_mode()
self.actor.train()
self.critic.train()
self.target_actor.train()
self.target_critic.train()
return self
def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
super().update(dataset)
if not self.initialized:
return {}
total_actor_loss = torch.zeros(self._batch_count)
total_critic_loss = torch.zeros(self._batch_count)
for idx, batch in enumerate(self.storage.batch_generator(self._batch_size, self._batch_count)):
actor_obs = batch["actor_observations"]
critic_obs = batch["critic_observations"]
actions = batch["actions"]
rewards = batch["rewards"]
actor_next_obs = batch["next_actor_observations"]
critic_next_obs = batch["next_critic_observations"]
dones = batch["dones"]
target_actor_prediction = self._process_actions(self.target_actor.forward(actor_next_obs))
target_critic_prediction = self.target_critic.forward(
self._critic_input(critic_next_obs, target_actor_prediction)
)
target = rewards + self._discount_factor * (1 - dones) * target_critic_prediction
prediction = self.critic.forward(self._critic_input(critic_obs, actions))
critic_loss = (prediction - target).pow(2).mean()
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
evaluation = self.critic.forward(
self._critic_input(critic_obs, self._process_actions(self.actor.forward(actor_obs)))
)
actor_loss = -evaluation.mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
self._update_target(self.actor, self.target_actor)
self._update_target(self.critic, self.target_critic)
total_actor_loss[idx] = actor_loss.item()
total_critic_loss[idx] = critic_loss.item()
stats = {"actor": total_actor_loss.mean().item(), "critic": total_critic_loss.mean().item()}
return stats

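`DDPG.update` calls `self._update_target`, which is defined outside this diff. A common choice, assumed here purely for illustration, is a Polyak (soft) update of the target parameters:

```python
import torch
from torch import nn

# The _update_target helper is not shown in this diff; a common implementation
# (an assumption here) is Polyak averaging of online parameters into the target.
def soft_update(online: nn.Module, target: nn.Module, tau: float = 0.005) -> None:
    with torch.no_grad():
        for p, p_target in zip(online.parameters(), target.parameters()):
            p_target.lerp_(p, tau)  # p_target <- (1 - tau) * p_target + tau * p

critic = nn.Linear(8, 1)
target_critic = nn.Linear(8, 1)
target_critic.load_state_dict(critic.state_dict())
soft_update(critic, target_critic, tau=0.005)
```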
49
rsl_rl/algorithms/dpg.py Normal file

@@ -0,0 +1,49 @@
import torch
from typing import Any, Dict, Tuple, Union
from rsl_rl.algorithms.actor_critic import AbstractActorCritic
from rsl_rl.env import VecEnv
from rsl_rl.storage.replay_storage import ReplayStorage
from rsl_rl.storage.storage import Dataset
class AbstractDPG(AbstractActorCritic):
def __init__(
self, env: VecEnv, action_noise_scale: float = 0.1, storage_initial_size=0, storage_size=100000, **kwargs
):
"""
Args:
env (VecEnv): A vectorized environment.
action_noise_scale (float): The scale of the Gaussian action noise.
storage_initial_size (int): Initial size of the replay storage.
storage_size (int): Maximum size of the replay storage.
"""
assert action_noise_scale > 0
super().__init__(env, **kwargs)
self.storage = ReplayStorage(
self.env.num_envs, storage_size, device=self.device, initial_size=storage_initial_size
)
self._register_serializable("storage")
self._action_noise_scale = action_noise_scale
self._register_serializable("_action_noise_scale")
def draw_actions(
self, obs: torch.Tensor, env_info: Dict[str, Any]
) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
actor_obs, critic_obs = self._process_observations(obs, env_info)
actions = self.actor.forward(actor_obs)
noise = torch.normal(torch.zeros_like(actions), torch.ones_like(actions) * self._action_noise_scale)
noisy_actions = self._process_actions(actions + noise)
data = {"actor_observations": actor_obs.clone(), "critic_observations": critic_obs.clone()}
return noisy_actions, data
def register_terminations(self, terminations: torch.Tensor) -> None:
pass

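For reference, the exploration scheme in `AbstractDPG.draw_actions` amounts to adding zero-mean Gaussian noise to the deterministic actor output before clamping. A minimal standalone illustration with placeholder shapes and bounds:

```python
import torch

# Placeholder actor output and bounds; only the noise/clamp pattern matters here.
actions = torch.zeros(4, 3)
noise_scale = 0.1
noise = torch.normal(torch.zeros_like(actions), torch.ones_like(actions) * noise_scale)
noisy_actions = (actions + noise).clamp(-1.0, 1.0)
```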
327
rsl_rl/algorithms/dppo.py Normal file

@@ -0,0 +1,327 @@
import torch
from torch import nn
from typing import Dict, List, Tuple, Type, Union
from rsl_rl.algorithms.ppo import PPO
from rsl_rl.distributions import QuantileDistribution
from rsl_rl.env import VecEnv
from rsl_rl.utils.benchmarkable import Benchmarkable
from rsl_rl.utils.recurrency import trajectories_to_transitions, transitions_to_trajectories
from rsl_rl.modules import ImplicitQuantileNetwork, QuantileNetwork
from rsl_rl.storage.storage import Dataset
class DPPO(PPO):
"""Distributional Proximal Policy Optimization algorithm.
This algorithm is an extension of PPO that uses a distributional method (either QR-DQN or IQN) to estimate the
value function.
QR-DQN Paper: https://arxiv.org/pdf/1710.10044.pdf
IQN Paper: https://arxiv.org/pdf/1806.06923.pdf
The implementation works with recurrent neural networks. We further implement Sample-Replacement SR(lambda) for the
value target computation, as described by Nam et al. in https://arxiv.org/pdf/2105.11366.pdf.
"""
critic_network: Type[nn.Module] = QuantileNetwork
_alg_features = dict(recurrent=True)
value_loss_energy = "sample_energy"
value_loss_l1 = "quantile_l1"
value_loss_huber = "quantile_huber"
network_qrdqn = "qrdqn"
network_iqn = "iqn"
networks = {network_qrdqn: QuantileNetwork, network_iqn: ImplicitQuantileNetwork}
values_losses = {
network_qrdqn: {
value_loss_energy: QuantileNetwork.sample_energy_loss,
value_loss_l1: QuantileNetwork.quantile_l1_loss,
value_loss_huber: QuantileNetwork.quantile_huber_loss,
},
network_iqn: {
value_loss_energy: ImplicitQuantileNetwork.sample_energy_loss,
},
}
def __init__(
self,
env: VecEnv,
critic_activations: List[str] = ["relu", "relu", "relu"],
critic_network: str = network_qrdqn,
iqn_action_samples: int = 32,
iqn_embedding_size: int = 64,
iqn_feature_layers: int = 1,
iqn_value_samples: int = 8,
qrdqn_quantile_count: int = 200,
value_lambda: float = 0.95,
value_loss: str = value_loss_l1,
value_loss_kwargs: Dict = {},
value_measure: str = None,
value_measure_adaptation: Union[Tuple, None] = None,
value_measure_kwargs: Dict = {},
**kwargs,
):
"""
Args:
env (VecEnv): A vectorized environment.
critic_activations (List[str]): A list of activations to use for the critic network.
critic_network (str): The critic network to use.
iqn_action_samples (int): The number of samples to use for the critic IQN network when acting.
iqn_embedding_size (int): The embedding size to use for the critic IQN network.
iqn_feature_layers (int): The number of feature layers to use for the critic IQN network.
iqn_value_samples (int): The number of samples to use for the critic IQN network when computing the value.
qrdqn_quantile_count (int): The number of quantiles to use for the critic QR network.
value_lambda (float): The lambda parameter for the SR(lambda) value target computation.
value_loss (str): The loss function to use for the critic network.
value_loss_kwargs (Dict): Keyword arguments for computing the value loss.
value_measure (str): The probability measure to apply to the critic network output distribution when
updating the policy.
value_measure_adaptation (Union[Tuple, None]): Controls adaptation of the value measure. If None, no
adaptation is performed. If a tuple, the tuple specifies the observations that are passed to the value
measure.
value_measure_kwargs (Dict): The keyword arguments to pass to the value measure.
"""
self._register_critic_network_kwargs(measure=value_measure, measure_kwargs=value_measure_kwargs)
self._critic_network_name = critic_network
self.critic_network = self.networks[self._critic_network_name]
if self._critic_network_name == self.network_qrdqn:
self._register_critic_network_kwargs(quantile_count=qrdqn_quantile_count)
elif self._critic_network_name == self.network_iqn:
self._register_critic_network_kwargs(feature_layers=iqn_feature_layers, embedding_size=iqn_embedding_size)
kwargs["critic_activations"] = critic_activations
if value_measure_adaptation is not None:
# Value measure adaptation observations are not passed to the critic network.
kwargs["_critic_input_size_delta"] = (
kwargs["_critic_input_size_delta"] if "_critic_input_size_delta" in kwargs else 0
) - len(value_measure_adaptation)
super().__init__(env, **kwargs)
self._value_lambda = value_lambda
self._value_loss_name = value_loss
self._register_serializable("_value_lambda", "_value_loss_name")
assert (
self._value_loss_name in self.values_losses[self._critic_network_name]
), f"Value loss '{self._value_loss_name}' is not supported for network '{self._critic_network_name}'."
value_loss_func = self.values_losses[critic_network][self._value_loss_name]
self._value_loss = lambda *args, **kwargs: value_loss_func(self.critic, *args, **kwargs)
if value_loss == self.value_loss_energy:
value_loss_kwargs["sample_count"] = (
value_loss_kwargs["sample_count"] if "sample_count" in value_loss_kwargs else 100
)
self._value_loss_kwargs = value_loss_kwargs
self._register_serializable("_value_loss_kwargs")
self._value_measure_adaptation = value_measure_adaptation
self._register_serializable("_value_measure_adaptation")
if self._critic_network_name == self.network_iqn:
self._iqn_action_samples = iqn_action_samples
self._iqn_value_samples = iqn_value_samples
self._register_serializable("_iqn_action_samples", "_iqn_value_samples")
def _critic_input(self, observations, actions=None) -> torch.Tensor:
mask, shape = self._get_critic_obs_mask(observations)
processed_observations = observations[mask].reshape(*shape)
return processed_observations
def _get_critic_obs_mask(self, observations):
mask = torch.ones_like(observations).bool()
if self._value_measure_adaptation is not None:
mask[:, self._value_measure_adaptation] = False
shape = (observations.shape[0], self._critic_input_size)
return mask, shape
def _process_quants(self, x):
if self._value_loss_name == self.value_loss_energy:
quants, idx = QuantileDistribution(x).sample(self._value_loss_kwargs["sample_count"])
else:
quants, idx = x, None
return quants, idx
@Benchmarkable.register
def process_transition(self, *args) -> Dict[str, torch.Tensor]:
transition = super(PPO, self).process_transition(*args)
if self.recurrent:
transition["critic_state_h"] = self.critic.hidden_state[0].detach()
transition["critic_state_c"] = self.critic.hidden_state[1].detach()
transition["full_critic_observations"] = transition["critic_observations"].detach()
transition["full_next_critic_observations"] = transition["next_critic_observations"].detach()
mask, shape = self._get_critic_obs_mask(transition["critic_observations"])
transition["critic_observations"] = transition["critic_observations"][mask].reshape(*shape)
transition["next_critic_observations"] = transition["next_critic_observations"][mask].reshape(*shape)
critic_kwargs = (
{"sample_count": self._iqn_action_samples} if self._critic_network_name == self.network_iqn else {}
)
transition["values"] = self.critic.forward(
transition["critic_observations"],
measure_args=self._extract_value_measure_adaptation(transition["full_critic_observations"]),
**critic_kwargs,
).detach()
if self._critic_network_name == self.network_iqn:
# For IQN, we sample new (undistorted) quantiles for computing the value update
critic_kwargs = (
{"hidden_state": (transition["critic_state_h"], transition["critic_state_c"])} if self.recurrent else {}
)
self.critic.forward(
transition["critic_observations"],
sample_count=self._iqn_value_samples,
use_measure=False,
**critic_kwargs,
).detach()
transition["value_taus"] = self.critic.last_taus.detach().reshape(transition["values"].shape[0], -1)
transition["value_quants"] = self.critic.last_quantiles.detach().reshape(transition["values"].shape[0], -1)
if self.recurrent:
transition["critic_next_state_h"] = self.critic.hidden_state[0].detach()
transition["critic_next_state_c"] = self.critic.hidden_state[1].detach()
return transition
@Benchmarkable.register
def _compute_value_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
critic_kwargs = (
{"sample_count": self._iqn_value_samples, "taus": batch["value_target_taus"], "use_measure": False}
if self._critic_network_name == self.network_iqn
else {}
)
if self.recurrent:
observations, data = transitions_to_trajectories(batch["critic_observations"], batch["dones"])
hidden_state_h, _ = transitions_to_trajectories(batch["critic_state_h"], batch["dones"])
hidden_state_c, _ = transitions_to_trajectories(batch["critic_state_c"], batch["dones"])
hidden_states = (hidden_state_h[0].transpose(0, 1), hidden_state_c[0].transpose(0, 1))
if self._critic_network_name == self.network_iqn:
critic_kwargs["taus"], _ = transitions_to_trajectories(critic_kwargs["taus"], batch["dones"])
trajectory_evaluations = self.critic.forward(
observations, distribution=True, hidden_state=hidden_states, **critic_kwargs
)
trajectory_evaluations = trajectory_evaluations.reshape(*observations.shape[:-1], -1)
predictions = trajectories_to_transitions(trajectory_evaluations, data)
else:
predictions = self.critic.forward(batch["critic_observations"], distribution=True, **critic_kwargs)
value_loss = self._value_loss(self._process_quants(predictions)[0], batch["value_target_quants"])
return value_loss
def _extract_value_measure_adaptation(self, observations: torch.Tensor) -> Tuple[torch.Tensor]:
if self._value_measure_adaptation is None:
return tuple()
relevant_observations = observations[:, self._value_measure_adaptation]
measure_adaptations = torch.tensor_split(relevant_observations, relevant_observations.shape[1], dim=1)
return measure_adaptations
@Benchmarkable.register
def _process_dataset(self, dataset: Dataset) -> Dataset:
rewards = torch.stack([entry["rewards"] for entry in dataset])
dones = torch.stack([entry["dones"] for entry in dataset]).float()
timeouts = torch.stack([entry["timeouts"] for entry in dataset])
values = torch.stack([entry["values"] for entry in dataset])
value_quants_idx = [self._process_quants(entry["value_quants"]) for entry in dataset]
value_quants = torch.stack([entry[0] for entry in value_quants_idx])
critic_kwargs = (
{"hidden_state": (dataset[-1]["critic_state_h"], dataset[-1]["critic_state_c"])} if self.recurrent else {}
)
if self._critic_network_name == self.network_iqn:
critic_kwargs["sample_count"] = self._iqn_action_samples
measure_args = self._extract_value_measure_adaptation(dataset[-1]["full_next_critic_observations"])
next_values = self.critic.forward(
dataset[-1]["next_critic_observations"], measure_args=measure_args, **critic_kwargs
)
if self._critic_network_name == self.network_iqn:
# For IQN, we sample new (undistorted) quantiles for computing the value update
critic_kwargs["sample_count"] = self._iqn_value_samples
self.critic.forward(
dataset[-1]["next_critic_observations"],
use_measure=False,
**critic_kwargs,
)
final_value_taus = self.critic.last_taus
value_taus = torch.stack(
[
torch.take_along_dim(dataset[i]["value_taus"], value_quants_idx[i][1], -1)
for i in range(len(dataset))
]
)
final_value_quants = self.critic.last_quantiles
# Timeout bootstrapping for rewards.
rewards += self.gamma * timeouts * values
# Compute advantages and value target quantiles
next_values = torch.cat((values[1:], next_values.unsqueeze(0)), dim=0)
deltas = (rewards + (1 - dones) * self.gamma * next_values - values).reshape(-1, self.env.num_envs)
advantages = torch.zeros((len(dataset) + 1, self.env.num_envs), device=self.device)
next_value_quants, idx = self._process_quants(final_value_quants)
value_target_quants = torch.zeros(len(dataset), *next_value_quants.shape, device=self.device)
if self._critic_network_name == self.network_iqn:
value_target_taus = torch.zeros(len(dataset) + 1, *next_value_quants.shape, device=self.device)
value_target_taus[-1] = torch.take_along_dim(final_value_taus, idx, -1)
for step in reversed(range(len(dataset))):
not_terminal = 1.0 - dones[step]
not_terminal_ = not_terminal.unsqueeze(-1)
advantages[step] = deltas[step] + (1.0 - dones[step]) * self.gamma * self._gae_lambda * advantages[step + 1]
value_target_quants[step] = rewards[step].unsqueeze(-1) + not_terminal_ * self.gamma * next_value_quants
preserved_value_quants = not_terminal_.bool() * (
torch.rand(next_value_quants.shape, device=self.device) < self._value_lambda
)
next_value_quants = torch.where(preserved_value_quants, value_target_quants[step], value_quants[step])
if self._critic_network_name == self.network_iqn:
value_target_taus[step] = torch.where(
preserved_value_quants, value_target_taus[step + 1], value_taus[step]
)
advantages = advantages[:-1]
if self._critic_network_name == self.network_iqn:
value_target_taus = value_target_taus[:-1]
# Normalize advantages and pack into dataset structure.
amean, astd = advantages.mean(), torch.nan_to_num(advantages.std())
for step in range(len(dataset)):
dataset[step]["advantages"] = advantages[step]
dataset[step]["normalized_advantages"] = (advantages[step] - amean) / (astd + 1e-8)
dataset[step]["value_target_quants"] = value_target_quants[step]
if self._critic_network_name == self.network_iqn:
dataset[step]["value_target_taus"] = value_target_taus[step]
return dataset

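`DPPO` selects its value loss from `QuantileNetwork` / `ImplicitQuantileNetwork`, whose implementations are not shown in this diff. The sketch below is a generic quantile-regression L1 loss of the kind QR-DQN-style critics minimize, included only as a hedged reference point; the repository's `quantile_l1_loss` may differ in reduction or smoothing.

```python
import torch

# Generic quantile-regression L1 loss; NOT the repository's implementation.
def quantile_l1_loss(pred_quantiles: torch.Tensor, target_quantiles: torch.Tensor) -> torch.Tensor:
    n = pred_quantiles.shape[-1]
    taus = (torch.arange(n, dtype=pred_quantiles.dtype) + 0.5) / n  # quantile midpoints
    # Pairwise TD errors with shape (batch, n_target, n_pred).
    diff = target_quantiles.unsqueeze(-1) - pred_quantiles.unsqueeze(-2)
    # Asymmetric weighting |tau - 1{target < prediction}|.
    weight = (taus - (diff.detach() < 0).float()).abs()
    return (weight * diff.abs()).mean()

pred = torch.randn(8, 200)     # predicted value quantiles
target = torch.randn(8, 200)   # bootstrapped target quantiles
loss = quantile_l1_loss(pred, target)
```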
75
rsl_rl/algorithms/dsac.py Normal file

@@ -0,0 +1,75 @@
import torch
from torch import nn
from typing import Tuple, Type
from rsl_rl.algorithms.sac import SAC
from rsl_rl.env import VecEnv
from rsl_rl.modules.quantile_network import QuantileNetwork
class DSAC(SAC):
"""Deep Soft Actor Critic (DSAC) algorithm.
This is an implementation of the DSAC algorithm by Ma et. al. for vectorized environments.
Paper: https://arxiv.org/pdf/2004.14547.pdf
The implementation inherits automatic tuning of the temperature parameter (alpha) and tanh action scaling from
the SAC implementation.
"""
critic_network: Type[nn.Module] = QuantileNetwork
def __init__(self, env: VecEnv, critic_activations=["relu", "relu", "relu"], quantile_count=200, **kwargs):
"""
Args:
env (VecEnv): A vectorized environment.
critic_activations (list): A list of activation functions to use for the critic network.
quantile_count (int): The number of quantiles to use for the critic QR network.
"""
self._quantile_count = quantile_count
self._register_critic_network_kwargs(quantile_count=self._quantile_count)
kwargs["critic_activations"] = critic_activations
super().__init__(env, **kwargs)
def _update_critic(
self,
critic_obs: torch.Tensor,
actions: torch.Tensor,
rewards: torch.Tensor,
dones: torch.Tensor,
actor_next_obs: torch.Tensor,
critic_next_obs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
target_action, target_action_logp = self._sample_action(actor_next_obs)
target_critic_input = self._critic_input(critic_next_obs, target_action)
target_critic_prediction_1 = self.target_critic_1.forward(target_critic_input, distribution=True)
target_critic_prediction_2 = self.target_critic_2.forward(target_critic_input, distribution=True)
target_critic_prediction = torch.minimum(target_critic_prediction_1, target_critic_prediction_2)
next_soft_q = target_critic_prediction - self.alpha * target_action_logp.unsqueeze(-1).repeat(
1, self._quantile_count
)
target = (rewards.reshape(-1, 1) + self._discount_factor * (1 - dones).reshape(-1, 1) * next_soft_q).detach()
critic_input = self._critic_input(critic_obs, actions).detach()
critic_1_prediction = self.critic_1.forward(critic_input, distribution=True)
critic_1_loss = self.critic_1.quantile_huber_loss(critic_1_prediction, target)
self.critic_1_optimizer.zero_grad()
critic_1_loss.backward()
nn.utils.clip_grad_norm_(self.critic_1.parameters(), self._gradient_clip)
self.critic_1_optimizer.step()
critic_2_prediction = self.critic_2.forward(critic_input, distribution=True)
critic_2_loss = self.critic_2.quantile_huber_loss(critic_2_prediction, target)
self.critic_2_optimizer.zero_grad()
critic_2_loss.backward()
nn.utils.clip_grad_norm_(self.critic_2.parameters(), self._gradient_clip)
self.critic_2_optimizer.step()
return critic_1_loss, critic_2_loss

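The DSAC target above broadcasts the entropy penalty over all quantiles of the element-wise minimum of the two target critics. The snippet below restates that computation with placeholder tensors (using broadcasting instead of `repeat`), purely for illustration:

```python
import torch

# Placeholder shapes; the real tensors come from the target critics and actor.
batch, quantile_count = 8, 200
z1 = torch.randn(batch, quantile_count)   # target critic 1 quantiles
z2 = torch.randn(batch, quantile_count)   # target critic 2 quantiles
logp = torch.randn(batch)                 # log pi(a'|s')
alpha, gamma = 0.2, 0.99
rewards, dones = torch.randn(batch), torch.zeros(batch)

# Entropy-regularized distributional soft-Q target.
next_soft_q = torch.minimum(z1, z2) - alpha * logp.unsqueeze(-1)
target = rewards.unsqueeze(-1) + gamma * (1 - dones).unsqueeze(-1) * next_soft_q
print(target.shape)  # torch.Size([8, 200])
```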
57
rsl_rl/algorithms/dtd3.py Normal file

@@ -0,0 +1,57 @@
from __future__ import annotations
import torch
from torch import nn
from typing import Type
from rsl_rl.algorithms.td3 import TD3
from rsl_rl.env import VecEnv
from rsl_rl.modules import QuantileNetwork
class DTD3(TD3):
"""Distributional Twin-Delayed Deep Deterministic Policy Gradients algorithm.
This is an implementation of the TD3 algorithm by Fujimoto et al. for vectorized environments using a QR-DQN
critic.
"""
critic_network: Type[nn.Module] = QuantileNetwork
def __init__(
self,
env: VecEnv,
quantile_count: int = 200,
**kwargs,
) -> None:
self._quantile_count = quantile_count
self._register_critic_network_kwargs(quantile_count=self._quantile_count)
super().__init__(env, **kwargs)
def _update_critic(self, critic_obs, actions, rewards, dones, actor_next_obs, critic_next_obs):
target_action = self._apply_action_noise(self.target_actor.forward(actor_next_obs), clip=True)
target_critic_input = self._critic_input(critic_next_obs, target_action)
target_critic_prediction_1 = self.target_critic_1.forward(target_critic_input, distribution=True)
target_critic_prediction_2 = self.target_critic_2.forward(target_critic_input, distribution=True)
target_critic_prediction = torch.minimum(target_critic_prediction_1, target_critic_prediction_2)
target = (
rewards.reshape(-1, 1) + self._discount_factor * (1 - dones).reshape(-1, 1) * target_critic_prediction
).detach()
critic_input = self._critic_input(critic_obs, actions).detach()
critic_1_prediction = self.critic_1.forward(critic_input, distribution=True)
critic_1_loss = self.critic_1.quantile_huber_loss(critic_1_prediction, target)
self.critic_1_optimizer.zero_grad()
critic_1_loss.backward()
self.critic_1_optimizer.step()
critic_2_prediction = self.critic_2.forward(critic_input, distribution=True)
critic_2_loss = self.critic_2.quantile_huber_loss(critic_2_prediction, target)
self.critic_2_optimizer.zero_grad()
critic_2_loss.backward()
self.critic_2_optimizer.step()
return critic_1_loss, critic_2_loss

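`DTD3._update_critic` relies on `self._apply_action_noise(..., clip=True)`, which is defined elsewhere. In standard TD3, target policy smoothing adds clipped Gaussian noise to the target action before clamping it back into the action bounds; the sketch below assumes that behaviour and is not the repository's helper:

```python
import torch

# Assumed TD3-style target policy smoothing; bounds and scales are placeholders.
def smooth_target_action(action, noise_std=0.2, noise_clip=0.5, action_min=-1.0, action_max=1.0):
    noise = (torch.randn_like(action) * noise_std).clamp(-noise_clip, noise_clip)
    return (action + noise).clamp(action_min, action_max)

target_action = smooth_target_action(torch.zeros(8, 3))
```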
193
rsl_rl/algorithms/hybrid.py Normal file

@@ -0,0 +1,193 @@
from abc import ABC, abstractmethod
import torch
from typing import Callable, Dict, Tuple, Type, Union
from rsl_rl.algorithms import D4PG, DSAC
from rsl_rl.algorithms import TD3
from rsl_rl.algorithms.actor_critic import AbstractActorCritic
from rsl_rl.algorithms.agent import Agent
from rsl_rl.env import VecEnv
from rsl_rl.storage.storage import Dataset, Storage
class AbstractHybridAgent(Agent, ABC):
def __init__(
self,
env: VecEnv,
agent_class: Type[Agent],
agent_kwargs: dict,
pretrain_agent_class: Type[Agent],
pretrain_agent_kwargs: dict,
pretrain_steps: int,
freeze_steps: int = 0,
**general_kwargs,
):
"""
Args:
env (VecEnv): A vectorized environment.
agent_class (Type[Agent]): The class of the main agent.
agent_kwargs (dict): Keyword arguments for constructing the main agent.
pretrain_agent_class (Type[Agent]): The class of the pretraining agent.
pretrain_agent_kwargs (dict): Keyword arguments for constructing the pretraining agent.
pretrain_steps (int): The number of update steps for which the pretraining agent is active.
freeze_steps (int): The number of update steps for which the transferred weights stay frozen after pretraining.
"""
agent_kwargs["env"] = env
pretrain_agent_kwargs["env"] = env
self.agent = agent_class(**agent_kwargs, **general_kwargs)
self.pretrain_agent = pretrain_agent_class(**pretrain_agent_kwargs, **general_kwargs)
self._freeze_steps = freeze_steps
self._pretrain_steps = pretrain_steps
self._steps = 0
self._register_serializable("agent", "pretrain_agent", "_freeze_steps", "_pretrain_steps", "_steps")
@property
def active_agent(self):
agent = self.pretrain_agent if self.pretraining else self.agent
return agent
def draw_actions(self, *args, **kwargs) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
return self.active_agent.draw_actions(*args, **kwargs)
def draw_random_actions(self, *args, **kwargs) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
return self.active_agent.draw_random_actions(*args, **kwargs)
def eval_mode(self, *args, **kwargs) -> Agent:
self.agent.eval_mode(*args, **kwargs)
def get_inference_policy(self, *args, **kwargs) -> Callable:
return self.active_agent.get_inference_policy(*args, **kwargs)
@property
def initialized(self) -> bool:
return self.active_agent.initialized
@property
def pretraining(self):
return self._steps < self._pretrain_steps
def process_dataset(self, *args, **kwargs) -> Dataset:
return self.active_agent.process_dataset(*args, **kwargs)
def process_transition(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
return self.active_agent.process_transition(*args, **kwargs)
def register_terminations(self, *args, **kwargs) -> None:
return self.active_agent.register_terminations(*args, **kwargs)
@property
def storage(self) -> Storage:
return self.active_agent.storage
def to(self, *args, **kwargs) -> Agent:
self.agent.to(*args, **kwargs)
self.pretrain_agent.to(*args, **kwargs)
def train_mode(self, *args, **kwargs) -> Agent:
self.agent.train_mode(*args, **kwargs)
self.pretrain_agent.train_mode(*args, **kwargs)
def update(self, *args, **kwargs) -> Dict[str, Union[float, torch.Tensor]]:
result = self.active_agent.update(*args, **kwargs)
if not self.active_agent.initialized:
return
self._steps += 1
if self._steps == self._pretrain_steps:
self._transfer_weights()
self._freeze_weights(freeze=True)
if self._steps == self._pretrain_steps + self._freeze_steps:
self._transfer_weights()
self._freeze_weights(freeze=False)
return result
@abstractmethod
def _freeze_weights(self, freeze=True):
pass
@abstractmethod
def _transfer_weights(self):
pass
class HybridD4PG(AbstractHybridAgent):
def __init__(
self,
env: VecEnv,
d4pg_kwargs: dict,
pretrain_kwargs: dict,
pretrain_agent: Type[AbstractActorCritic] = TD3,
**kwargs,
):
assert d4pg_kwargs["action_max"] == pretrain_kwargs["action_max"]
assert d4pg_kwargs["action_min"] == pretrain_kwargs["action_min"]
assert d4pg_kwargs["actor_activations"] == pretrain_kwargs["actor_activations"]
assert d4pg_kwargs["actor_hidden_dims"] == pretrain_kwargs["actor_hidden_dims"]
assert d4pg_kwargs["actor_input_normalization"] == pretrain_kwargs["actor_input_normalization"]
super().__init__(
env,
agent_class=D4PG,
agent_kwargs=d4pg_kwargs,
pretrain_agent_class=pretrain_agent,
pretrain_agent_kwargs=pretrain_kwargs,
**kwargs,
)
def _freeze_weights(self, freeze=True):
for param in self.agent.actor.parameters():
param.requires_grad = not freeze
def _transfer_weights(self):
self.agent.actor.load_state_dict(self.pretrain_agent.actor.state_dict())
self.agent.actor_optimizer.load_state_dict(self.pretrain_agent.actor_optimizer.state_dict())
class HybridDSAC(AbstractHybridAgent):
def __init__(
self,
env: VecEnv,
dsac_kwargs: dict,
pretrain_kwargs: dict,
pretrain_agent: Type[AbstractActorCritic] = TD3,
**kwargs,
):
assert dsac_kwargs["action_max"] == pretrain_kwargs["action_max"]
assert dsac_kwargs["action_min"] == pretrain_kwargs["action_min"]
assert dsac_kwargs["actor_activations"] == pretrain_kwargs["actor_activations"]
assert dsac_kwargs["actor_hidden_dims"] == pretrain_kwargs["actor_hidden_dims"]
assert dsac_kwargs["actor_input_normalization"] == pretrain_kwargs["actor_input_normalization"]
super().__init__(
env,
agent_class=DSAC,
agent_kwargs=dsac_kwargs,
pretrain_agent_class=pretrain_agent,
pretrain_agent_kwargs=pretrain_kwargs,
**kwargs,
)
def _freeze_weights(self, freeze=True):
"""Freezes actor layers.
Freezes the feature encoding and mean computation layers of the Gaussian network. Leaves the log standard
deviation layer unfrozen.
"""
for param in self.agent.actor._layers.parameters():
param.requires_grad = not freeze
for param in self.agent.actor._mean_layer.parameters():
param.requires_grad = not freeze
def _transfer_weights(self):
"""Transfers actor layers.
Transfers only the feature encoding and mean computation layers of the Gaussian network.
"""
for i, layer in enumerate(self.agent.actor._layers):
layer.load_state_dict(self.pretrain_agent.actor._layers[i].state_dict())
for j, layer in enumerate(self.agent.actor._mean_layer):
layer.load_state_dict(self.pretrain_agent.actor._layers[i + j + 1].state_dict())

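The hybrid agents freeze and later unfreeze the transferred actor layers by toggling `requires_grad`. A minimal self-contained illustration of that mechanism (the network here is a placeholder, not the repository's actor):

```python
import torch
from torch import nn

# Toggling requires_grad stops the optimizer from updating those parameters.
actor = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 3))

def freeze(module: nn.Module, frozen: bool = True) -> None:
    for param in module.parameters():
        param.requires_grad = not frozen

freeze(actor, True)    # weights are transferred, then kept frozen for freeze_steps
assert all(not p.requires_grad for p in actor.parameters())
freeze(actor, False)   # unfreeze once training of the main agent resumes
```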
rsl_rl/algorithms/ppo.py

@@ -1,185 +1,384 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import torch
import torch.nn as nn
import torch.optim as optim
from torch import nn, optim
from typing import Any, Dict, Tuple, Type, Union
from rsl_rl.modules import ActorCritic
from rsl_rl.storage import RolloutStorage
from rsl_rl.algorithms.actor_critic import AbstractActorCritic
from rsl_rl.env import VecEnv
from rsl_rl.utils.benchmarkable import Benchmarkable
from rsl_rl.utils.recurrency import trajectories_to_transitions, transitions_to_trajectories
from rsl_rl.modules import GaussianNetwork, Network
from rsl_rl.storage.rollout_storage import RolloutStorage
from rsl_rl.storage.storage import Dataset
class PPO:
actor_critic: ActorCritic
class PPO(AbstractActorCritic):
"""Proximal Policy Optimization algorithm.
This is an implementation of the PPO algorithm by Schulman et al. for vectorized environments.
Paper: https://arxiv.org/pdf/1707.06347.pdf
The implementation works with recurrent neural networks. We implement an adaptive learning rate based on the
KL-divergence between the old and new policy, as described by Schulman et al. in
https://arxiv.org/pdf/1707.06347.pdf.
"""
critic_network: Type[nn.Module] = Network
_alg_features = dict(recurrent=True)
schedule_adaptive = "adaptive"
schedule_fixed = "fixed"
def __init__(
self,
actor_critic,
num_learning_epochs=1,
num_mini_batches=1,
clip_param=0.2,
gamma=0.998,
lam=0.95,
value_loss_coef=1.0,
entropy_coef=0.0,
learning_rate=1e-3,
max_grad_norm=1.0,
use_clipped_value_loss=True,
schedule="fixed",
desired_kl=0.01,
device="cpu",
env: VecEnv,
actor_noise_std: float = 1.0,
clip_ratio: float = 0.2,
entropy_coeff: float = 0.0,
gae_lambda: float = 0.97,
gradient_clip: float = 1.0,
learning_rate: float = 1e-3,
schedule: str = "fixed",
target_kl: float = 0.01,
value_coeff: float = 1.0,
**kwargs,
):
self.device = device
"""
Args:
env (VecEnv): A vectorized environment.
actor_noise_std (float): The standard deviation of the Gaussian noise to add to the actor network output.
clip_ratio (float): The clipping ratio for the PPO objective.
entropy_coeff (float): The coefficient for the entropy term in the PPO objective.
gae_lambda (float): The lambda parameter for the GAE computation.
gradient_clip (float): The gradient clipping value.
learning_rate (float): The learning rate for the actor and critic networks.
schedule (str): The learning rate schedule. Can be "fixed" or "adaptive". Defaults to "fixed".
target_kl (float): The target KL-divergence for the adaptive learning rate schedule.
value_coeff (float): The coefficient for the value function loss in the PPO objective.
"""
kwargs["batch_size"] = env.num_envs
kwargs["return_steps"] = 1
self.desired_kl = desired_kl
self.schedule = schedule
self.learning_rate = learning_rate
super().__init__(env, **kwargs)
self._critic_input_size = self._critic_input_size - self._action_size # We use a state-value function (not Q)
# PPO components
self.actor_critic = actor_critic
self.actor_critic.to(self.device)
self.storage = None # initialized later
self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
self.transition = RolloutStorage.Transition()
self.storage = RolloutStorage(self.env.num_envs, device=self.device)
# PPO parameters
self.clip_param = clip_param
self.num_learning_epochs = num_learning_epochs
self.num_mini_batches = num_mini_batches
self.value_loss_coef = value_loss_coef
self.entropy_coef = entropy_coef
self.gamma = gamma
self.lam = lam
self.max_grad_norm = max_grad_norm
self.use_clipped_value_loss = use_clipped_value_loss
self._register_serializable("storage")
def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape):
self.storage = RolloutStorage(
num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, self.device
self._clip_ratio = clip_ratio
self._entropy_coeff = entropy_coeff
self._gae_lambda = gae_lambda
self._gradient_clip = gradient_clip
self._schedule = schedule
self._target_kl = target_kl
self._value_coeff = value_coeff
self._register_serializable(
"_clip_ratio",
"_entropy_coeff",
"_gae_lambda",
"_gradient_clip",
"_schedule",
"_target_kl",
"_value_coeff",
)
def test_mode(self):
self.actor_critic.test()
self.actor = GaussianNetwork(
self._actor_input_size, self._action_size, std_init=actor_noise_std, **self._actor_network_kwargs
)
self.critic = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
def train_mode(self):
self.actor_critic.train()
if self.recurrent:
self.actor.reset_full_hidden_state(batch_size=self.env.num_envs)
self.critic.reset_full_hidden_state(batch_size=self.env.num_envs)
def act(self, obs, critic_obs):
if self.actor_critic.is_recurrent:
self.transition.hidden_states = self.actor_critic.get_hidden_states()
# Compute the actions and values
self.transition.actions = self.actor_critic.act(obs).detach()
self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
self.transition.action_mean = self.actor_critic.action_mean.detach()
self.transition.action_sigma = self.actor_critic.action_std.detach()
# need to record obs and critic_obs before env.step()
self.transition.observations = obs
self.transition.critic_observations = critic_obs
return self.transition.actions
tp = lambda v: v.transpose(0, 1) # for storing, transpose (num_layers, batch, F) to (batch, num_layers, F)
self.storage.register_processor("actor_state_h", tp)
self.storage.register_processor("actor_state_c", tp)
self.storage.register_processor("critic_state_h", tp)
self.storage.register_processor("critic_state_c", tp)
self.storage.register_processor("critic_next_state_h", tp)
self.storage.register_processor("critic_next_state_c", tp)
def process_env_step(self, rewards, dones, infos):
self.transition.rewards = rewards.clone()
self.transition.dones = dones
# Bootstrapping on time outs
if "time_outs" in infos:
self.transition.rewards += self.gamma * torch.squeeze(
self.transition.values * infos["time_outs"].unsqueeze(1).to(self.device), 1
)
self._bm_fuse(self.actor, prefix="actor.")
self._bm_fuse(self.critic, prefix="critic.")
# Record the transition
self.storage.add_transitions(self.transition)
self.transition.clear()
self.actor_critic.reset(dones)
self._register_serializable("actor", "critic")
def compute_returns(self, last_critic_obs):
last_values = self.actor_critic.evaluate(last_critic_obs).detach()
self.storage.compute_returns(last_values, self.gamma, self.lam)
self.learning_rate = learning_rate
self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
def update(self):
mean_value_loss = 0
mean_surrogate_loss = 0
if self.actor_critic.is_recurrent:
generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
else:
generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
for (
obs_batch,
critic_obs_batch,
actions_batch,
target_values_batch,
advantages_batch,
returns_batch,
old_actions_log_prob_batch,
old_mu_batch,
old_sigma_batch,
hid_states_batch,
masks_batch,
) in generator:
self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
value_batch = self.actor_critic.evaluate(
critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1]
)
mu_batch = self.actor_critic.action_mean
sigma_batch = self.actor_critic.action_std
entropy_batch = self.actor_critic.entropy
self._register_serializable("learning_rate", "optimizer")
# KL
if self.desired_kl is not None and self.schedule == "adaptive":
with torch.inference_mode():
kl = torch.sum(
torch.log(sigma_batch / old_sigma_batch + 1.0e-5)
+ (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch))
/ (2.0 * torch.square(sigma_batch))
- 0.5,
axis=-1,
)
kl_mean = torch.mean(kl)
def draw_random_actions(self, obs: torch.Tensor, env_info: Dict[str, Any]) -> torch.Tensor:
raise NotImplementedError("PPO does not support drawing random actions.")
if kl_mean > self.desired_kl * 2.0:
self.learning_rate = max(1e-5, self.learning_rate / 1.5)
elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
self.learning_rate = min(1e-2, self.learning_rate * 1.5)
@Benchmarkable.register
def draw_actions(
self, obs: torch.Tensor, env_info: Dict[str, Any]
) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
actor_obs, critic_obs = self._process_observations(obs, env_info)
for param_group in self.optimizer.param_groups:
param_group["lr"] = self.learning_rate
data = {}
# Surrogate loss
ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
surrogate = -torch.squeeze(advantages_batch) * ratio
surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(
ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
)
surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()
if self.recurrent:
data["actor_state_h"] = self.actor.hidden_state[0].detach()
data["actor_state_c"] = self.actor.hidden_state[1].detach()
# Value function loss
if self.use_clipped_value_loss:
value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(
-self.clip_param, self.clip_param
)
value_losses = (value_batch - returns_batch).pow(2)
value_losses_clipped = (value_clipped - returns_batch).pow(2)
value_loss = torch.max(value_losses, value_losses_clipped).mean()
mean, std = self.actor.forward(actor_obs, compute_std=True)
action_distribution = torch.distributions.Normal(mean, std)
actions = self._process_actions(action_distribution.rsample()).detach()
action_prediction_logp = action_distribution.log_prob(actions).sum(-1)
data["actor_observations"] = actor_obs
data["critic_observations"] = critic_obs
data["actions_logp"] = action_prediction_logp.detach()
data["actions_mean"] = action_distribution.mean.detach()
data["actions_std"] = action_distribution.stddev.detach()
return actions, data
def eval_mode(self) -> AbstractActorCritic:
super().eval_mode()
self.actor.eval()
self.critic.eval()
return self
@property
def initialized(self) -> bool:
return True
@Benchmarkable.register
def process_transition(self, *args) -> Dict[str, torch.Tensor]:
transition = super().process_transition(*args)
if self.recurrent:
transition["critic_state_h"] = self.critic.hidden_state[0].detach()
transition["critic_state_c"] = self.critic.hidden_state[1].detach()
transition["values"] = self.critic.forward(transition["critic_observations"]).detach()
if self.recurrent:
transition["critic_next_state_h"] = self.critic.hidden_state[0].detach()
transition["critic_next_state_c"] = self.critic.hidden_state[1].detach()
return transition
def parameters(self):
params = list(self.actor.parameters()) + list(self.critic.parameters())
return params
def register_terminations(self, terminations: torch.Tensor) -> None:
"""Registers terminations with the agent.
Args:
terminations (torch.Tensor): A 1-dimensional int tensor containing the indices of the terminated
environments.
"""
if terminations.shape[0] == 0:
return
if self.recurrent:
self.actor.reset_hidden_state(terminations)
self.critic.reset_hidden_state(terminations)
def to(self, device: str) -> AbstractActorCritic:
super().to(device)
self.actor.to(device)
self.critic.to(device)
return self
def train_mode(self) -> AbstractActorCritic:
super().train_mode()
self.actor.train()
self.critic.train()
return self
@Benchmarkable.register
def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
super().update(dataset)
assert self.storage.initialized
total_loss = torch.zeros(self._batch_count)
total_surrogate_loss = torch.zeros(self._batch_count)
total_value_loss = torch.zeros(self._batch_count)
for idx, batch in enumerate(self.storage.batch_generator(self._batch_count, trajectories=self.recurrent)):
if self.recurrent:
transition_obs = batch["actor_observations"].reshape(*batch["actor_observations"].shape[:2], -1)
observations, data = transitions_to_trajectories(transition_obs, batch["dones"])
hidden_state_h, _ = transitions_to_trajectories(batch["actor_state_h"], batch["dones"])
hidden_state_c, _ = transitions_to_trajectories(batch["actor_state_c"], batch["dones"])
# Init. sequence with each trajectory's first hidden state. Subsequent hidden states are produced by the
# network, depending on the previous hidden state and the current observation.
hidden_state = (hidden_state_h[0].transpose(0, 1), hidden_state_c[0].transpose(0, 1))
action_mean, action_std = self.actor.forward(observations, hidden_state=hidden_state, compute_std=True)
action_mean = action_mean.reshape(*observations.shape[:-1], self._action_size)
action_std = action_std.reshape(*observations.shape[:-1], self._action_size)
action_mean = trajectories_to_transitions(action_mean, data)
action_std = trajectories_to_transitions(action_std, data)
else:
value_loss = (returns_batch - value_batch).pow(2).mean()
action_mean, action_std = self.actor.forward(batch["actor_observations"], compute_std=True)
loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()
actions_dist = torch.distributions.Normal(action_mean, action_std)
if self._schedule == self.schedule_adaptive:
self._update_learning_rate(batch, actions_dist)
surrogate_loss = self._compute_actor_loss(batch, actions_dist)
value_loss = self._compute_value_loss(batch)
actions_entropy = actions_dist.entropy().sum(-1)
loss = surrogate_loss + self._value_coeff * value_loss - self._entropy_coeff * actions_entropy.mean()
# Gradient step
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
nn.utils.clip_grad_norm_(self.parameters(), self._gradient_clip)
self.optimizer.step()
mean_value_loss += value_loss.item()
mean_surrogate_loss += surrogate_loss.item()
total_loss[idx] = loss.detach()
total_surrogate_loss[idx] = surrogate_loss.detach()
total_value_loss[idx] = value_loss.detach()
num_updates = self.num_learning_epochs * self.num_mini_batches
mean_value_loss /= num_updates
mean_surrogate_loss /= num_updates
self.storage.clear()
stats = {
"total": total_loss.mean().item(),
"surrogate": total_surrogate_loss.mean().item(),
"value": total_value_loss.mean().item(),
}
return mean_value_loss, mean_surrogate_loss
return stats
@Benchmarkable.register
def _compute_actor_loss(
self, batch: Dict[str, torch.Tensor], actions_dist: torch.distributions.Normal
) -> torch.Tensor:
batch_actions_logp = actions_dist.log_prob(batch["actions"]).sum(-1)
ratio = (batch_actions_logp - batch["actions_logp"]).exp()
surrogate = batch["normalized_advantages"] * ratio
surrogate_clipped = batch["normalized_advantages"] * ratio.clamp(1.0 - self._clip_ratio, 1.0 + self._clip_ratio)
surrogate_loss = -torch.min(surrogate, surrogate_clipped).mean()
return surrogate_loss
@Benchmarkable.register
def _compute_value_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
if self.recurrent:
observations, data = transitions_to_trajectories(batch["critic_observations"], batch["dones"])
hidden_state_h, _ = transitions_to_trajectories(batch["critic_state_h"], batch["dones"])
hidden_state_c, _ = transitions_to_trajectories(batch["critic_state_c"], batch["dones"])
hidden_states = (hidden_state_h[0].transpose(0, 1), hidden_state_c[0].transpose(0, 1))
trajectory_evaluations = self.critic.forward(observations, hidden_state=hidden_states)
trajectory_evaluations = trajectory_evaluations.reshape(*observations.shape[:-1])
evaluation = trajectories_to_transitions(trajectory_evaluations, data)
else:
evaluation = self.critic.forward(batch["critic_observations"])
value_clipped = batch["values"] + (evaluation - batch["values"]).clamp(-self._clip_ratio, self._clip_ratio)
returns = batch["advantages"] + batch["values"]
value_losses = (evaluation - returns).pow(2)
value_losses_clipped = (value_clipped - returns).pow(2)
value_loss = torch.max(value_losses, value_losses_clipped).mean()
return value_loss
def _critic_input(self, observations, actions=None) -> torch.Tensor:
return observations
def _entry_to_hs(
self, entry: Dict[str, torch.Tensor], critic: bool = False, next: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Helper function to turn a dataset entry into a hidden state tuple.
Args:
entry (Dict[str, torch.Tensor]): The dataset entry.
critic (bool): Whether to extract the hidden state for the critic instead of the actor. Defaults to False.
next (bool): Whether the hidden state is for the next step or the current. Defaults to False.
Returns:
A tuple of hidden state tensors.
"""
key = ("critic" if critic else "actor") + "_" + ("next_state" if next else "state")
hidden_state = entry[f"{key}_h"], entry[f"{key}_c"]
return hidden_state
@Benchmarkable.register
def _process_dataset(self, dataset: Dataset) -> Dataset:
"""Processes a dataset before it is added to the replay memory.
Computes advantages and returns.
Args:
dataset (Dataset): The dataset to process.
Returns:
A Dataset object containing the processed data.
"""
rewards = torch.stack([entry["rewards"] for entry in dataset])
dones = torch.stack([entry["dones"] for entry in dataset])
timeouts = torch.stack([entry["timeouts"] for entry in dataset])
values = torch.stack([entry["values"] for entry in dataset])
# We could alternatively compute the next hidden state from the current state and hidden state. But this
# (storing the hidden state when evaluating the action in process_transition) is computationally more efficient
# and doesn't change the result as the network is not updated between storing the data and computing advantages.
critic_kwargs = (
{"hidden_state": (dataset[-1]["critic_state_h"], dataset[-1]["critic_state_c"])} if self.recurrent else {}
)
final_values = self.critic.forward(dataset[-1]["next_critic_observations"], **critic_kwargs)
next_values = torch.cat((values[1:], final_values.unsqueeze(0)), dim=0)
rewards += self.gamma * timeouts * values
deltas = (rewards + (1 - dones).float() * self.gamma * next_values - values).reshape(-1, self.env.num_envs)
advantages = torch.zeros((len(dataset) + 1, self.env.num_envs), device=self.device)
for step in reversed(range(len(dataset))):
advantages[step] = (
deltas[step] + (1 - dones[step]).float() * self.gamma * self._gae_lambda * advantages[step + 1]
)
advantages = advantages[:-1]
amean, astd = advantages.mean(), torch.nan_to_num(advantages.std())
for step in range(len(dataset)):
dataset[step]["advantages"] = advantages[step]
dataset[step]["normalized_advantages"] = (advantages[step] - amean) / (astd + 1e-8)
return dataset
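For reference, the advantage recursion in `_process_dataset` above is standard GAE(lambda). The standalone sketch below reproduces it with random placeholder tensors in place of rollout data:

```python
import torch

# Standalone GAE(lambda) recursion mirroring _process_dataset; data is random.
def compute_gae(rewards, dones, values, next_values, gamma=0.99, lam=0.97):
    steps, num_envs = rewards.shape
    advantages = torch.zeros(steps + 1, num_envs)
    deltas = rewards + (1 - dones) * gamma * next_values - values
    for step in reversed(range(steps)):
        advantages[step] = deltas[step] + (1 - dones[step]) * gamma * lam * advantages[step + 1]
    return advantages[:-1]

steps, num_envs = 16, 4
rewards, dones = torch.randn(steps, num_envs), torch.zeros(steps, num_envs)
values, next_values = torch.randn(steps, num_envs), torch.randn(steps, num_envs)
advantages = compute_gae(rewards, dones, values, next_values)
returns = advantages + values  # value targets, as used in _compute_value_loss
```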
@Benchmarkable.register
def _update_learning_rate(self, batch: Dict[str, torch.Tensor], actions_dist: torch.distributions.Normal) -> None:
with torch.inference_mode():
actions_mean = actions_dist.mean
actions_std = actions_dist.stddev
kl = torch.sum(
torch.log(actions_std / batch["actions_std"] + 1.0e-5)
+ (torch.square(batch["actions_std"]) + torch.square(batch["actions_mean"] - actions_mean))
/ (2.0 * torch.square(actions_std))
- 0.5,
axis=-1,
)
kl_mean = torch.mean(kl)
if kl_mean > self._target_kl * 2.0:
self.learning_rate = max(1e-5, self.learning_rate / 1.5)
elif kl_mean < self._target_kl / 2.0 and kl_mean > 0.0:
self.learning_rate = min(1e-2, self.learning_rate * 1.5)
for param_group in self.optimizer.param_groups:
param_group["lr"] = self.learning_rate

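The adaptive schedule uses the closed-form KL divergence between diagonal Gaussian policies. The snippet below checks that expression against `torch.distributions.kl_divergence` and applies the same step-size rule with placeholder numbers:

```python
import torch
from torch import distributions

# Old (rollout) and new policies with placeholder parameters.
old = distributions.Normal(torch.zeros(8, 3), torch.ones(8, 3))
new = distributions.Normal(0.1 * torch.randn(8, 3), 1.1 * torch.ones(8, 3))

# Closed form used above (without the 1e-5 numerical guard), summed over action dims.
kl = (
    torch.log(new.stddev / old.stddev)
    + (old.stddev.pow(2) + (old.mean - new.mean).pow(2)) / (2.0 * new.stddev.pow(2))
    - 0.5
).sum(-1)
assert torch.allclose(kl, distributions.kl_divergence(old, new).sum(-1), atol=1e-5)

# The same step-size rule as _update_learning_rate, with placeholder values.
learning_rate, target_kl = 1e-3, 0.01
kl_mean = kl.mean().item()
if kl_mean > 2.0 * target_kl:
    learning_rate = max(1e-5, learning_rate / 1.5)
elif 0.0 < kl_mean < 0.5 * target_kl:
    learning_rate = min(1e-2, learning_rate * 1.5)
```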
319
rsl_rl/algorithms/sac.py Normal file

@@ -0,0 +1,319 @@
import numpy as np
import torch
from torch import nn, optim
from typing import Any, Callable, Dict, Tuple, Type, Union
from rsl_rl.algorithms.actor_critic import AbstractActorCritic
from rsl_rl.env import VecEnv
from rsl_rl.modules import Network, GaussianChimeraNetwork, GaussianNetwork
from rsl_rl.storage.replay_storage import ReplayStorage
from rsl_rl.storage.storage import Dataset
class SAC(AbstractActorCritic):
"""Soft Actor Critic algorithm.
This is an implementation of the SAC algorithm by Haarnoja et al. for vectorized environments.
Paper: https://arxiv.org/pdf/1801.01290.pdf
We additionally implement automatic tuning of the temperature parameter (alpha) and tanh action scaling, as
introduced by Haarnoja et al. in https://arxiv.org/pdf/1812.05905.pdf.
"""
critic_network: Type[nn.Module] = Network
def __init__(
self,
env: VecEnv,
action_max: float = 100.0,
action_min: float = -100.0,
actor_lr: float = 1e-4,
actor_noise_std: float = 1.0,
alpha: float = 0.2,
alpha_lr: float = 1e-3,
chimera: bool = True,
critic_lr: float = 1e-3,
gradient_clip: float = 1.0,
log_std_max: float = 4.0,
log_std_min: float = -20.0,
storage_initial_size: int = 0,
storage_size: int = 100000,
target_entropy: float = None,
**kwargs
):
"""
Args:
env (VecEnv): A vectorized environment.
actor_lr (float): Learning rate for the actor.
alpha (float): Initial entropy regularization coefficient.
alpha_lr (float): Learning rate for entropy regularization coefficient.
chimera (bool): Whether to use separate heads for computing action mean and std (True) or treat the std as a
tunable parameter (False).
critic_lr (float): Learning rate for the critic.
gradient_clip (float): Gradient clip value.
log_std_max (float): Maximum log standard deviation.
log_std_min (float): Minimum log standard deviation.
storage_initial_size (int): Initial size of the replay storage.
storage_size (int): Maximum size of the replay storage.
target_entropy (float): Target entropy for the actor policy. Defaults to the negative action space dimensionality.
"""
super().__init__(env, action_max=action_max, action_min=action_min, **kwargs)
self.storage = ReplayStorage(
self.env.num_envs, storage_size, device=self.device, initial_size=storage_initial_size
)
self._register_serializable("storage")
assert self._action_max < np.inf, 'Parameter "action_max" needs to be set for SAC.'
assert self._action_min > -np.inf, 'Parameter "action_min" needs to be set for SAC.'
self._action_delta = 0.5 * (self._action_max - self._action_min)
self._action_offset = 0.5 * (self._action_max + self._action_min)
self.log_alpha = torch.tensor(np.log(alpha), dtype=torch.float32).requires_grad_()
self._gradient_clip = gradient_clip
self._target_entropy = target_entropy if target_entropy else -self._action_size
self._register_serializable("log_alpha", "_gradient_clip")
network_class = GaussianChimeraNetwork if chimera else GaussianNetwork
self.actor = network_class(
self._actor_input_size,
self._action_size,
log_std_max=log_std_max,
log_std_min=log_std_min,
std_init=actor_noise_std,
**self._actor_network_kwargs
)
self.critic_1 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
self.critic_2 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
self.target_critic_1 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
self.target_critic_2 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
self.target_critic_1.load_state_dict(self.critic_1.state_dict())
self.target_critic_2.load_state_dict(self.critic_2.state_dict())
self._register_serializable("actor", "critic_1", "critic_2", "target_critic_1", "target_critic_2")
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
self.log_alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)
self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(), lr=critic_lr)
self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=critic_lr)
self._register_serializable(
"actor_optimizer", "log_alpha_optimizer", "critic_1_optimizer", "critic_2_optimizer"
)
self.critic = self.critic_1
@property
def alpha(self):
return self.log_alpha.exp()
def draw_actions(
self, obs: torch.Tensor, env_info: Dict[str, Any]
) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
actor_obs, critic_obs = self._process_observations(obs, env_info)
action = self._sample_action(actor_obs, compute_logp=False)
data = {"actor_observations": actor_obs.clone(), "critic_observations": critic_obs.clone()}
return action, data
def eval_mode(self) -> AbstractActorCritic:
super().eval_mode()
self.actor.eval()
self.critic_1.eval()
self.critic_2.eval()
self.target_critic_1.eval()
self.target_critic_2.eval()
return self
def get_inference_policy(self, device=None) -> Callable:
self.to(device)
self.eval_mode()
def policy(obs, env_info=None):
obs, _ = self._process_observations(obs, env_info)
actions = self._scale_actions(self.actor.forward(obs))
# actions, _ = self.draw_actions(obs, env_info)
return actions
return policy
def register_terminations(self, terminations: torch.Tensor) -> None:
pass
def to(self, device: str) -> AbstractActorCritic:
"""Transfers agent parameters to device."""
super().to(device)
self.actor.to(device)
self.critic_1.to(device)
self.critic_2.to(device)
self.target_critic_1.to(device)
self.target_critic_2.to(device)
self.log_alpha.to(device)
return self
def train_mode(self) -> AbstractActorCritic:
super().train_mode()
self.actor.train()
self.critic_1.train()
self.critic_2.train()
self.target_critic_1.train()
self.target_critic_2.train()
return self
def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
super().update(dataset)
if not self.initialized:
return {}
total_actor_loss = torch.zeros(self._batch_count)
total_alpha_loss = torch.zeros(self._batch_count)
total_critic_1_loss = torch.zeros(self._batch_count)
total_critic_2_loss = torch.zeros(self._batch_count)
for idx, batch in enumerate(self.storage.batch_generator(self._batch_size, self._batch_count)):
actor_obs = batch["actor_observations"]
critic_obs = batch["critic_observations"]
actions = batch["actions"].reshape(-1, self._action_size)
rewards = batch["rewards"]
actor_next_obs = batch["next_actor_observations"]
critic_next_obs = batch["next_critic_observations"]
dones = batch["dones"]
critic_1_loss, critic_2_loss = self._update_critic(
critic_obs, actions, rewards, dones, actor_next_obs, critic_next_obs
)
actor_loss, alpha_loss = self._update_actor_and_alpha(actor_obs, critic_obs)
# Update Target Networks
self._update_target(self.critic_1, self.target_critic_1)
self._update_target(self.critic_2, self.target_critic_2)
total_actor_loss[idx] = actor_loss.item()
total_alpha_loss[idx] = alpha_loss.item()
total_critic_1_loss[idx] = critic_1_loss.item()
total_critic_2_loss[idx] = critic_2_loss.item()
stats = {
"actor": total_actor_loss.mean().item(),
"alpha": total_alpha_loss.mean().item(),
"critic1": total_critic_1_loss.mean().item(),
"critic2": total_critic_2_loss.mean().item(),
}
return stats
def _sample_action(
self, observation: torch.Tensor, compute_logp=True
) -> Union[torch.Tensor, Tuple[torch.Tensor, float]]:
"""Samples and action from the policy.
Args:
observation (torch.Tensor): The observation to sample an action for.
compute_logp (bool): Whether to compute and return the action log probability. Defaults to True.
Returns:
Either the action as a torch.Tensor or, if compute_logp is set to True, a tuple containing the action as a
torch.Tensor and the action log probability.
"""
mean, std = self.actor.forward(observation, compute_std=True)
dist = torch.distributions.Normal(mean, std)
actions = dist.rsample()
actions_normalized, actions_scaled = self._scale_actions(actions, intermediate=True)
if not compute_logp:
return actions_scaled
action_logp = dist.log_prob(actions).sum(-1) - torch.log(1.0 - actions_normalized.pow(2) + 1e-6).sum(-1)
return actions_scaled, action_logp
def _scale_actions(self, actions: torch.Tensor, intermediate=False) -> torch.Tensor:
actions = actions.reshape(-1, self._action_size)
action_normalized = torch.tanh(actions)
action_scaled = super()._process_actions(action_normalized * self._action_delta + self._action_offset)
if intermediate:
return action_normalized, action_scaled
return action_scaled
def _update_actor_and_alpha(
self, actor_obs: torch.Tensor, critic_obs: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
actor_prediction, actor_prediction_logp = self._sample_action(actor_obs)
# Update alpha (also called temperature / entropy coefficient)
alpha_loss = -(self.log_alpha * (actor_prediction_logp + self._target_entropy).detach()).mean()
self.log_alpha_optimizer.zero_grad()
alpha_loss.backward()
self.log_alpha_optimizer.step()
# Update actor
evaluation_input = self._critic_input(critic_obs, actor_prediction)
evaluation_1 = self.critic_1.forward(evaluation_input)
evaluation_2 = self.critic_2.forward(evaluation_input)
actor_loss = (self.alpha.detach() * actor_prediction_logp - torch.min(evaluation_1, evaluation_2)).mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
nn.utils.clip_grad_norm_(self.actor.parameters(), self._gradient_clip)
self.actor_optimizer.step()
return actor_loss, alpha_loss
def _update_critic(
self,
critic_obs: torch.Tensor,
actions: torch.Tensor,
rewards: torch.Tensor,
dones: torch.Tensor,
actor_next_obs: torch.Tensor,
critic_next_obs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
with torch.no_grad():
target_action, target_action_logp = self._sample_action(actor_next_obs)
target_critic_input = self._critic_input(critic_next_obs, target_action)
target_critic_prediction_1 = self.target_critic_1.forward(target_critic_input)
target_critic_prediction_2 = self.target_critic_2.forward(target_critic_input)
target_next = (
torch.min(target_critic_prediction_1, target_critic_prediction_2) - self.alpha * target_action_logp
)
target = rewards + self._discount_factor * (1 - dones) * target_next
critic_input = self._critic_input(critic_obs, actions)
critic_prediction_1 = self.critic_1.forward(critic_input)
critic_1_loss = nn.functional.mse_loss(critic_prediction_1, target)
self.critic_1_optimizer.zero_grad()
critic_1_loss.backward()
nn.utils.clip_grad_norm_(self.critic_1.parameters(), self._gradient_clip)
self.critic_1_optimizer.step()
critic_prediction_2 = self.critic_2.forward(critic_input)
critic_2_loss = nn.functional.mse_loss(critic_prediction_2, target)
self.critic_2_optimizer.zero_grad()
critic_2_loss.backward()
nn.utils.clip_grad_norm_(self.critic_2.parameters(), self._gradient_clip)
self.critic_2_optimizer.step()
return critic_1_loss, critic_2_loss
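
A compact, standalone illustration (separate from the class above) of the squashed-Gaussian sampling that `_sample_action` and `_scale_actions` implement: sample from N(mean, std), squash with tanh, rescale to the action bounds, and correct the log-probability with the tanh change-of-variables term:

```python
# Squashed-Gaussian sampling with log-probability correction; dummy shapes only.
import torch

def squashed_sample(mean, std, action_min=-1.0, action_max=1.0):
    dist = torch.distributions.Normal(mean, std)
    raw = dist.rsample()                 # reparameterized sample, keeps gradients
    normalized = torch.tanh(raw)         # squashed into (-1, 1)
    delta = 0.5 * (action_max - action_min)
    offset = 0.5 * (action_max + action_min)
    action = normalized * delta + offset
    # log pi(a|s) = log N(raw | mean, std) - sum_i log(1 - tanh(raw_i)^2)
    logp = dist.log_prob(raw).sum(-1) - torch.log(1.0 - normalized.pow(2) + 1e-6).sum(-1)
    return action, logp

action, logp = squashed_sample(torch.zeros(16, 6), torch.ones(16, 6))
print(action.shape, logp.shape)  # (16, 6) and (16,)
```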

198
rsl_rl/algorithms/td3.py Normal file
View File

@ -0,0 +1,198 @@
from __future__ import annotations
import torch
from torch import nn, optim
from typing import Dict, Type, Union
from rsl_rl.algorithms.dpg import AbstractDPG
from rsl_rl.env import VecEnv
from rsl_rl.modules.network import Network
from rsl_rl.storage.storage import Dataset
class TD3(AbstractDPG):
"""Twin-Delayed Deep Deterministic Policy Gradients algorithm.
This is an implementation of the TD3 algorithm by Fujimoto et al. for vectorized environments.
Paper: https://arxiv.org/pdf/1802.09477.pdf
"""
critic_network: Type[nn.Module] = Network
def __init__(
self,
env: VecEnv,
actor_lr: float = 1e-4,
critic_lr: float = 1e-3,
noise_clip: float = 0.5,
policy_delay: int = 2,
target_noise_scale: float = 0.2,
**kwargs,
) -> None:
super().__init__(env, **kwargs)
self._noise_clip = noise_clip
self._policy_delay = policy_delay
self._target_noise_scale = target_noise_scale
self._register_serializable("_noise_clip", "_policy_delay", "_target_noise_scale")
self.actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
self.critic_1 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
self.critic_2 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
self.target_actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
self.target_critic_1 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
self.target_critic_2 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
self.target_actor.load_state_dict(self.actor.state_dict())
self.target_critic_1.load_state_dict(self.critic_1.state_dict())
self.target_critic_2.load_state_dict(self.critic_2.state_dict())
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(), lr=critic_lr)
self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=critic_lr)
self._update_step = 0
self._register_serializable(
"actor",
"critic_1",
"critic_2",
"target_actor",
"target_critic_1",
"target_critic_2",
"actor_optimizer",
"critic_1_optimizer",
"critic_2_optimizer",
"_update_step",
)
self.critic = self.critic_1
self.to(self.device)
def eval_mode(self) -> TD3:
super().eval_mode()
self.actor.eval()
self.critic_1.eval()
self.critic_2.eval()
self.target_actor.eval()
self.target_critic_1.eval()
self.target_critic_2.eval()
return self
def to(self, device: str) -> TD3:
"""Transfers agent parameters to device."""
super().to(device)
self.actor.to(device)
self.critic_1.to(device)
self.critic_2.to(device)
self.target_actor.to(device)
self.target_critic_1.to(device)
self.target_critic_2.to(device)
return self
def train_mode(self) -> TD3:
super().train_mode()
self.actor.train()
self.critic_1.train()
self.critic_2.train()
self.target_actor.train()
self.target_critic_1.train()
self.target_critic_2.train()
return self
def _apply_action_noise(self, actions: torch.Tensor, clip=False) -> torch.Tensor:
noise = torch.normal(torch.zeros_like(actions), torch.ones_like(actions) * self._action_noise_scale)
if clip:
noise = noise.clamp(-self._noise_clip, self._noise_clip)
noisy_actions = self._process_actions(actions + noise)
return noisy_actions
def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
super().update(dataset)
if not self.initialized:
return {}
total_actor_loss = torch.zeros(self._batch_count)
total_critic_1_loss = torch.zeros(self._batch_count)
total_critic_2_loss = torch.zeros(self._batch_count)
for idx, batch in enumerate(self.storage.batch_generator(self._batch_size, self._batch_count)):
actor_obs = batch["actor_observations"]
critic_obs = batch["critic_observations"]
actions = batch["actions"].reshape(self._batch_size, -1)
rewards = batch["rewards"]
actor_next_obs = batch["next_actor_observations"]
critic_next_obs = batch["next_critic_observations"]
dones = batch["dones"]
critic_1_loss, critic_2_loss = self._update_critic(
critic_obs, actions, rewards, dones, actor_next_obs, critic_next_obs
)
if self._update_step % self._policy_delay == 0:
evaluation = self.critic_1.forward(
self._critic_input(critic_obs, self._process_actions(self.actor.forward(actor_obs)))
)
actor_loss = -evaluation.mean()
self.actor_optimizer.zero_grad()
actor_loss.backward()
self.actor_optimizer.step()
self._update_target(self.actor, self.target_actor)
self._update_target(self.critic_1, self.target_critic_1)
self._update_target(self.critic_2, self.target_critic_2)
total_actor_loss[idx] = actor_loss.item()
self._update_step = self._update_step + 1
total_critic_1_loss[idx] = critic_1_loss.item()
total_critic_2_loss[idx] = critic_2_loss.item()
stats = {
"actor": total_actor_loss.mean().item(),
"critic1": total_critic_1_loss.mean().item(),
"critic2": total_critic_2_loss.mean().item(),
}
return stats
def _update_critic(self, critic_obs, actions, rewards, dones, actor_next_obs, critic_next_obs):
target_actor_prediction = self._apply_action_noise(self.target_actor.forward(actor_next_obs), clip=True)
target_critic_1_prediction = self.target_critic_1.forward(
self._critic_input(critic_next_obs, target_actor_prediction)
)
target_critic_2_prediction = self.target_critic_2.forward(
self._critic_input(critic_next_obs, target_actor_prediction)
)
target_critic_prediction = torch.min(target_critic_1_prediction, target_critic_2_prediction)
target = (rewards + self._discount_factor * (1 - dones) * target_critic_prediction).detach()
prediction_1 = self.critic_1.forward(self._critic_input(critic_obs, actions))
critic_1_loss = (prediction_1 - target).pow(2).mean()
self.critic_1_optimizer.zero_grad()
critic_1_loss.backward()
self.critic_1_optimizer.step()
prediction_2 = self.critic_2.forward(self._critic_input(critic_obs, actions))
critic_2_loss = (prediction_2 - target).pow(2).mean()
self.critic_2_optimizer.zero_grad()
critic_2_loss.backward()
self.critic_2_optimizer.step()
return critic_1_loss, critic_2_loss
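
The two TD3-specific ingredients above, target policy smoothing and the clipped double-Q target, in a standalone sketch with placeholder tensors:

```python
import torch

def smooth_target_action(action, noise_scale=0.2, noise_clip=0.5, low=-1.0, high=1.0):
    # Add clipped Gaussian noise to the target policy's action (policy smoothing).
    noise = (torch.randn_like(action) * noise_scale).clamp(-noise_clip, noise_clip)
    return (action + noise).clamp(low, high)

def td3_target(rewards, dones, next_q1, next_q2, discount=0.99):
    # Bootstrap from the minimum of the two target critics (clipped double-Q learning).
    return rewards + discount * (1.0 - dones) * torch.min(next_q1, next_q2)

a = smooth_target_action(torch.zeros(4, 2))
q = td3_target(torch.ones(4), torch.zeros(4), torch.full((4,), 1.5), torch.full((4,), 1.2))
print(a.shape, q)
```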

View File

@ -0,0 +1,2 @@
from .distribution import Distribution
from .quantile_distribution import QuantileDistribution

View File

@ -0,0 +1,18 @@
from abc import ABC, abstractmethod
import torch
class Distribution(ABC):
def __init__(self, params: torch.Tensor) -> None:
self._params = params
@abstractmethod
def sample(self, sample_count: int = 1) -> torch.Tensor:
"""Sample from the distribution.
Args:
sample_count: The number of samples to draw.
Returns:
A tensor of shape (sample_count, *parameter_shape).
"""
pass

View File

@ -0,0 +1,13 @@
import torch
from .distribution import Distribution
class QuantileDistribution(Distribution):
def sample(self, sample_count: int = 1) -> torch.Tensor:
idx = torch.randint(
self._params.shape[-1], (*self._params.shape[:-1], sample_count), device=self._params.device
)
samples = torch.take_along_dim(self._params, idx, -1)
return samples, idx
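
A small standalone illustration of what `QuantileDistribution.sample` does: the last dimension of the parameter tensor is treated as a set of equally likely quantile values, and samples are drawn from it with `torch.take_along_dim`:

```python
import torch

params = torch.randn(3, 5)                       # 3 distributions with 5 quantiles each
sample_count = 4
idx = torch.randint(params.shape[-1], (*params.shape[:-1], sample_count))
samples = torch.take_along_dim(params, idx, -1)  # shape (3, 4)
print(samples.shape, idx.shape)
```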

View File

@ -1,6 +1,5 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
"""Submodule defining the environment definitions."""
from .vec_env import VecEnv

120
rsl_rl/env/gym_env.py vendored Normal file
View File

@ -0,0 +1,120 @@
from datetime import datetime
import gym
import torch
from typing import Any, Dict, Tuple, Union
from rsl_rl.env.vec_env import VecEnv
class GymEnv(VecEnv):
"""A vectorized environment wrapper for OpenAI Gym environments.
This class wraps a single OpenAI Gym environment into a vectorized environment. It is assumed that the environment
is a single agent environment. The environment is wrapped in a `gym.vector.SyncVectorEnv` environment, which
allows for parallel execution of multiple environments.
"""
def __init__(self, name, draw=False, draw_cb=None, draw_directory="videos/", gym_kwargs={}, **kwargs):
"""
Args:
name: The name of the OpenAI Gym environment.
draw: Whether to record videos of the environment.
draw_cb: A callback function that is called after each recorded episode. It is passed the episode
number and the path to the corresponding video file.
draw_directory: The directory in which to store the videos.
gym_kwargs: Keyword arguments that are passed to the OpenAI Gym environment.
**kwargs: Keyword arguments that are passed to the `VecEnv` constructor.
"""
self._gym_kwargs = gym_kwargs
env = gym.make(name, **self._gym_kwargs)
assert isinstance(env.observation_space, gym.spaces.Box)
assert len(env.observation_space.shape) == 1
assert isinstance(env.action_space, gym.spaces.Box)
assert len(env.action_space.shape) == 1
super().__init__(env.observation_space.shape[0], env.observation_space.shape[0], **kwargs)
self.name = name
self.draw_directory = draw_directory
self.num_actions = env.action_space.shape[0]
self._gym_venv = gym.vector.SyncVectorEnv(
[lambda: gym.make(self.name, **self._gym_kwargs) for _ in range(self.num_envs)]
)
self._draw = False
self._draw_cb = draw_cb if draw_cb is not None else lambda *args: True
self._draw_uuid = None
self.draw = draw
self.reset()
def close(self) -> None:
self._gym_venv.close()
def get_observations(self) -> Tuple[torch.Tensor, Dict[str, Any]]:
return self.obs_buf, self.extras
def get_privileged_observations(self) -> Union[torch.Tensor, None]:
return self.obs_buf, self.extras
@property
def draw(self) -> bool:
return self._draw
@draw.setter
def draw(self, value: bool) -> None:
if value != self._draw:
if value:
self._draw_uuid = datetime.now().strftime("%Y%m%d%H%M%S")
env = gym.make(self.name, render_mode="rgb_array", **self._gym_kwargs)
env = gym.wrappers.RecordVideo(
env,
f"{self.draw_directory}/{self._draw_uuid}/",
episode_trigger=lambda ep: (
self._draw_cb(ep - 1, f"{self.draw_directory}/{self._draw_uuid}/rl-video-episode-{ep-1}.mp4")
or True
)
if ep > 0
else False,
)
else:
env = gym.make(self.name, render_mode=None, **self._gym_kwargs)
self._gym_venv.envs[0] = env
self._draw = value
self.reset()
def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
self.obs_buf = torch.from_numpy(self._gym_venv.reset()[0]).float().to(self.device)
self.rew_buf = torch.zeros((self.num_envs,), device=self.device).float()
self.reset_buf = torch.zeros((self.num_envs,), device=self.device).float()
self.extras = {"observations": {}, "time_outs": torch.zeros((self.num_envs,), device=self.device).float()}
return self.obs_buf, self.extras
def step(
self, actions: torch.Tensor
) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, Dict[str, Any]]:
obs, rew, reset, term, _ = self._gym_venv.step(actions.cpu().numpy())
self.obs_buf = torch.from_numpy(obs).float().to(self.device)
self.rew_buf = torch.from_numpy(rew).float().to(self.device)
self.reset_buf = torch.from_numpy(reset).float().to(self.device)
self.extras = {
"observations": {},
"time_outs": torch.from_numpy(term).float().to(self.device).float().to(self.device),
}
return self.obs_buf, self.rew_buf, self.reset_buf, self.extras
def to(self, device: str) -> None:
self.device = device
self.obs_buf = self.obs_buf.to(device)
self.rew_buf = self.rew_buf.to(device)
self.reset_buf = self.reset_buf.to(device)
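
A hedged usage sketch for the wrapper above, assuming a gym installation (>= 0.26, with the classic-control environments) is available; four copies of Pendulum-v1 are stepped with zero actions:

```python
import torch
from rsl_rl.env.gym_env import GymEnv

env = GymEnv("Pendulum-v1", environment_count=4, device="cpu")
obs, extras = env.reset()
actions = torch.zeros(env.num_envs, env.num_actions)
obs, rewards, dones, extras = env.step(actions)
print(obs.shape, rewards.shape, dones.shape)
env.close()
```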

137
rsl_rl/env/pole_balancing.py vendored Normal file
View File

@ -0,0 +1,137 @@
import math
import numpy as np
import time
import matplotlib.pyplot as plt
import torch
from typing import Any, Dict, Tuple, Union
from rsl_rl.env.vec_env import VecEnv
class PoleBalancing(VecEnv):
"""Custom pole balancing environment.
This class implements a custom pole balancing environment. It demonstrates how to implement a custom `VecEnv`
environment.
"""
def __init__(self, **kwargs):
"""
Args:
**kwargs: Keyword arguments that are passed to the `VecEnv` constructor.
"""
super().__init__(2, 2, **kwargs)
self.num_actions = 1
self.gravity = 9.8
self.length = 2.0
self.dt = 0.1
# Angle at which to fail the episode (15 deg)
self.theta_threshold_radians = 15 * 2 * math.pi / 360
# Max. angle at which to initialize the episode (2 deg)
self.initial_max_position = 2 * 2 * math.pi / 360
# Max. angular velocity at which to initialize the episode (0.3 deg/s)
self.initial_max_velocity = 0.3 * 2 * math.pi / 360
self.initial_position_factor = self.initial_max_position / 0.5
self.initial_velocity_factor = self.initial_max_velocity / 0.5
self.draw = False
self.pushes = False
self.reset()
def get_observations(self) -> Tuple[torch.Tensor, Dict[str, Any]]:
return self.obs_buf, self.extras
def get_privileged_observations(self) -> Union[torch.Tensor, None]:
return self.obs_buf, self.extras
def step(
self, actions: torch.Tensor
) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, Dict[str, Any]]:
assert actions.size() == (self.num_envs, 1)
self.to(self.device)
actions = actions.to(self.device)
noise = torch.normal(torch.zeros_like(actions), torch.ones_like(actions) * 0.005).squeeze()
if self.pushes and np.random.rand() < 0.05:
noise *= 100.0
actions = actions.clamp(min=-0.2, max=0.2).squeeze()
gravity = torch.sin(self.state[:, 0]) * self.gravity / self.length
angular_acceleration = gravity + actions + noise
self.state[:, 1] = self.state[:, 1] + self.dt * angular_acceleration
self.state[:, 0] = self.state[:, 0] + self.dt * self.state[:, 1]
self.reset_buf = torch.zeros(self.num_envs)
self.reset_buf[(torch.abs(self.state[:, 0]) > self.theta_threshold_radians).nonzero()] = 1.0
reset_idx = self.reset_buf.nonzero()
self.state[reset_idx, 0] = (
torch.rand(reset_idx.size()[0], 1, device=self.device) - 0.5
) * self.initial_position_factor
self.state[reset_idx, 1] = (
torch.rand(reset_idx.size()[0], 1, device=self.device) - 0.5
) * self.initial_velocity_factor
self.rew_buf = torch.ones(self.num_envs, device=self.device)
self.rew_buf[reset_idx] = -1.0
self.rew_buf = self.rew_buf - actions.abs()
self.rew_buf = self.rew_buf - self.state[:, 0].abs()
self._update_obs()
if self.draw:
self._debug_draw(actions)
self.to(self.device)
return self.obs_buf, self.rew_buf, self.reset_buf, self.extras
def _update_obs(self):
self.obs_buf = self.state
self.extras = {"observations": {}, "time_outs": torch.zeros_like(self.rew_buf)}
def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
self.state = torch.zeros(self.num_envs, 2, device=self.device)
self.state[:, 0] = (torch.rand(self.num_envs) - 0.5) * self.initial_position_factor
self.state[:, 1] = (torch.rand(self.num_envs) - 0.5) * self.initial_velocity_factor
self.rew_buf = torch.zeros(self.num_envs)
self.reset_buf = torch.zeros(self.num_envs)
self.extras = {}
self._update_obs()
return self.obs_buf, self.extras
def to(self, device):
self.device = device
self.obs_buf = self.obs_buf.to(device)
self.rew_buf = self.rew_buf.to(device)
self.reset_buf = self.reset_buf.to(device)
self.state = self.state.to(device)
def _debug_draw(self, actions):
if not hasattr(self, "_visuals"):
self._visuals = {"x": [0], "pos": [], "act": [], "done": []}
plt.gca().figure.show()
else:
self._visuals["x"].append(self._visuals["x"][-1] + 1)
self._visuals["pos"].append(self.obs_buf[0, 0].cpu().item())
self._visuals["done"].append(self.reset_buf[0].cpu().item())
self._visuals["act"].append(actions.squeeze()[0].cpu().item())
plt.cla()
plt.plot(self._visuals["x"][-100:], self._visuals["act"][-100:], color="green")
plt.plot(self._visuals["x"][-100:], self._visuals["pos"][-100:], color="blue")
plt.plot(self._visuals["x"][-100:], self._visuals["done"][-100:], color="red")
plt.draw()
plt.gca().figure.canvas.flush_events()
time.sleep(0.0001)

97
rsl_rl/env/pomdp.py vendored Normal file
View File

@ -0,0 +1,97 @@
import torch
from typing import Any, Dict, Tuple, Union
from rsl_rl.env.gym_env import GymEnv
class GymPOMDP(GymEnv):
"""A vectorized POMDP environment wrapper for OpenAI Gym environments.
This environment allows for the modification of the observation space of an OpenAI Gym environment. The modified
observation space is a subset of the original observation space.
"""
_reduced_observation_count: int = None
def __init__(self, name: str, **kwargs):
assert self._reduced_observation_count is not None
super().__init__(name=name, **kwargs)
self.num_obs = self._reduced_observation_count
self.num_privileged_obs = self._reduced_observation_count
def _process_observations(self, obs: torch.Tensor) -> torch.Tensor:
"""Reduces observation space from original observation space to modified observation space.
Args:
obs (torch.Tensor): Original observations.
Returns:
The modified observations as a torch.Tensor of shape (obs.shape[0], self.num_obs).
"""
raise NotImplementedError
def reset(self, *args, **kwargs):
obs, _ = super().reset(*args, **kwargs)
self.obs_buf = self._process_observations(obs)
return self.obs_buf, self.extras
def step(
self, actions: torch.Tensor
) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, Dict[str, Any]]:
obs, _, _, _ = super().step(actions)
self.obs_buf = self._process_observations(obs)
return self.obs_buf, self.rew_buf, self.reset_buf, self.extras
class BipedalWalkerP(GymPOMDP):
"""
Original observation space (24 values):
[
hull angle,
hull angular velocity,
horizontal velocity,
vertical velocity,
joint 1 angle,
joint 1 speed,
joint 2 angle,
joint 2 speed,
leg 1 ground contact,
joint 3 angle,
joint 3 speed,
joint 4 angle,
joint 4 speed,
leg 2 ground contact,
lidar (10 values),
]
Modified observation space (15 values):
[
hull angle,
joint 1 angle,
joint 2 angle,
joint 3 angle,
joint 4 angle,
lidar (10 values),
]
"""
_reduced_observation_count: int = 15
def __init__(self, **kwargs):
super().__init__(name="BipedalWalker-v3", **kwargs)
def _process_observations(self, obs: torch.Tensor) -> torch.Tensor:
"""Reduces observation space from original observation space to modified observation space."""
reduced_obs = torch.zeros(obs.shape[0], self._reduced_observation_count, device=self.device)
reduced_obs[:, 0] = obs[:, 0]
reduced_obs[:, 1] = obs[:, 4]
reduced_obs[:, 2] = obs[:, 6]
reduced_obs[:, 3] = obs[:, 9]
reduced_obs[:, 4] = obs[:, 11]
reduced_obs[:, 5:] = obs[:, 14:]
return reduced_obs
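
A hypothetical further example of the same pattern, assuming gym is installed (the class name and observation selection below are made up for illustration): Pendulum-v1 observes (cos θ, sin θ, θ̇), and dropping the angular velocity yields a partially observable variant:

```python
import torch
from rsl_rl.env.pomdp import GymPOMDP

class PendulumP(GymPOMDP):
    """Pendulum-v1 without the angular velocity observation (hypothetical example)."""

    _reduced_observation_count: int = 2

    def __init__(self, **kwargs):
        super().__init__(name="Pendulum-v1", **kwargs)

    def _process_observations(self, obs: torch.Tensor) -> torch.Tensor:
        # Keep only cos(theta) and sin(theta); drop the angular velocity.
        return obs[:, :2]
```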

44
rsl_rl/env/rslgym_env.py vendored Normal file
View File

@ -0,0 +1,44 @@
import torch
from typing import Any, Dict, Tuple, Union
from rsl_rl.env.vec_env import VecEnv
class RSLGymEnv(VecEnv):
"""A wrapper for using rsl_rl with the rslgym library."""
def __init__(self, rslgym_env, **kwargs):
self._rslgym_env = rslgym_env
observation_count = self._rslgym_env.observation_space.shape[0]
super().__init__(observation_count, observation_count, **kwargs)
self.num_actions = self._rslgym_env.action_space.shape[0]
self.obs_buf = None
self.rew_buf = None
self.reset_buf = None
self.extras = None
self.reset()
def get_observations(self) -> Tuple[torch.Tensor, Dict[str, Any]]:
return self.obs_buf, self.extras
def get_privileged_observations(self) -> Union[torch.Tensor, None]:
return self.obs_buf
def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
obs = self._rslgym_env.reset()
self.obs_buf = torch.from_numpy(obs)
self.extras = {"observations": {}, "time_outs": torch.zeros((self.num_envs,), device=self.device).float()}
def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]:
obs, reward, dones, info = self._rslgym_env.step(actions, True)
self.obs_buf = torch.from_numpy(obs)
self.rew_buf = torch.from_numpy(reward)
self.reset_buf = torch.from_numpy(dones).float()
return self.obs_buf, self.rew_buf, self.reset_buf, self.extras

99
rsl_rl/env/vec_env.py vendored
View File

@ -1,85 +1,74 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import torch
from abc import ABC, abstractmethod
import torch
from typing import Any, Dict, Tuple, Union
# minimal interface of the environment
class VecEnv(ABC):
"""Abstract class for vectorized environment.
The vectorized environment is a collection of environments that are synchronized. This means that
the same action is applied to all environments and the same observation is returned from all environments.
All extra observations must be provided as a dictionary to "extras" in the step() method. Based on the
configuration, the extra observations are used for different purposes. The following keys are reserved
in the "observations" dictionary (if they are present):
- "critic": The observation is used as input to the critic network. Useful for asymmetric observation spaces.
"""
"""Abstract class for vectorized environment."""
num_envs: int
"""Number of environments."""
num_obs: int
"""Number of observations."""
num_privileged_obs: int
"""Number of privileged observations."""
num_actions: int
"""Number of actions."""
max_episode_length: int
"""Maximum episode length."""
privileged_obs_buf: torch.Tensor
"""Buffer for privileged observations."""
obs_buf: torch.Tensor
"""Buffer for observations."""
rew_buf: torch.Tensor
"""Buffer for rewards."""
reset_buf: torch.Tensor
"""Buffer for resets."""
episode_length_buf: torch.Tensor # current episode duration
"""Buffer for current episode lengths."""
extras: dict
"""Extra information (metrics).
Extra information is stored in a dictionary. This includes metrics such as the episode reward, episode length,
etc. Additional information can be stored in the dictionary such as observations for the critic network, etc.
"""
device: torch.device
"""Device to use."""
"""
Operations.
"""
@abstractmethod
def get_observations(self) -> tuple[torch.Tensor, dict]:
"""Return the current observations.
Returns:
Tuple[torch.Tensor, dict]: Tuple containing the observations and extras.
def __init__(
self, observation_count, privileged_observation_count, device="cpu", environment_count=1, max_episode_length=-1
):
"""
raise NotImplementedError
@abstractmethod
def reset(self) -> tuple[torch.Tensor, dict]:
"""Reset all environment instances.
Returns:
Tuple[torch.Tensor, dict]: Tuple containing the observations and extras.
Args:
observation_count (int): Number of observations per environment.
privileged_observation_count (int): Number of privileged observations per environment.
device (str): Device to use for the tensors.
environment_count (int): Number of environments to run in parallel.
max_episode_length (int): Maximum length of an episode. If -1, the episode length is not limited.
"""
raise NotImplementedError
self.num_obs = observation_count
self.num_privileged_obs = privileged_observation_count
self.num_envs = environment_count
self.max_episode_length = max_episode_length
self.device = device
@abstractmethod
def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
def get_observations(self) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Return observations and extra information."""
pass
@abstractmethod
def get_privileged_observations(self) -> Union[torch.Tensor, None]:
"""Return privileged observations."""
pass
@abstractmethod
def step(
self, actions: torch.Tensor
) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, Dict[str, Any]]:
"""Apply input action on the environment.
Args:
actions (torch.Tensor): Input actions to apply. Shape: (num_envs, num_actions)
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
A tuple containing the observations, rewards, dones and extra information (metrics).
Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, dict]:
A tuple containing the observations, privileged observations, rewards, dones and
extra information (metrics).
"""
raise NotImplementedError
@abstractmethod
def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
"""Reset all environment instances.
Returns:
Tuple[torch.Tensor, torch.Tensor | None]: Tuple containing the observations and privileged observations.
"""
raise NotImplementedError

View File

@ -1,10 +1,19 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
"""Definitions for neural-network components for RL-agents."""
from .actor_critic import ActorCritic
from .actor_critic_recurrent import ActorCriticRecurrent
from .categorical_network import CategoricalNetwork
from .gaussian_chimera_network import GaussianChimeraNetwork
from .gaussian_network import GaussianNetwork
from .implicit_quantile_network import ImplicitQuantileNetwork
from .network import Network
from .normalizer import EmpiricalNormalization
from .quantile_network import QuantileNetwork
from .transformer import Transformer
__all__ = ["ActorCritic", "ActorCriticRecurrent"]
__all__ = [
"CategoricalNetwork",
"EmpiricalNormalization",
"GaussianChimeraNetwork",
"GaussianNetwork",
"ImplicitQuantileNetwork",
"Network",
"QuantileNetwork",
"Transformer",
]

View File

@ -1,136 +0,0 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import torch
import torch.nn as nn
from torch.distributions import Normal
class ActorCritic(nn.Module):
is_recurrent = False
def __init__(
self,
num_actor_obs,
num_critic_obs,
num_actions,
actor_hidden_dims=[256, 256, 256],
critic_hidden_dims=[256, 256, 256],
activation="elu",
init_noise_std=1.0,
**kwargs,
):
if kwargs:
print(
"ActorCritic.__init__ got unexpected arguments, which will be ignored: "
+ str([key for key in kwargs.keys()])
)
super().__init__()
activation = get_activation(activation)
mlp_input_dim_a = num_actor_obs
mlp_input_dim_c = num_critic_obs
# Policy
actor_layers = []
actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
actor_layers.append(activation)
for layer_index in range(len(actor_hidden_dims)):
if layer_index == len(actor_hidden_dims) - 1:
actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], num_actions))
else:
actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], actor_hidden_dims[layer_index + 1]))
actor_layers.append(activation)
self.actor = nn.Sequential(*actor_layers)
# Value function
critic_layers = []
critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
critic_layers.append(activation)
for layer_index in range(len(critic_hidden_dims)):
if layer_index == len(critic_hidden_dims) - 1:
critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], 1))
else:
critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], critic_hidden_dims[layer_index + 1]))
critic_layers.append(activation)
self.critic = nn.Sequential(*critic_layers)
print(f"Actor MLP: {self.actor}")
print(f"Critic MLP: {self.critic}")
# Action noise
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
self.distribution = None
# disable args validation for speedup
Normal.set_default_validate_args = False
# seems that we get better performance without init
# self.init_memory_weights(self.memory_a, 0.001, 0.)
# self.init_memory_weights(self.memory_c, 0.001, 0.)
@staticmethod
# not used at the moment
def init_weights(sequential, scales):
[
torch.nn.init.orthogonal_(module.weight, gain=scales[idx])
for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))
]
def reset(self, dones=None):
pass
def forward(self):
raise NotImplementedError
@property
def action_mean(self):
return self.distribution.mean
@property
def action_std(self):
return self.distribution.stddev
@property
def entropy(self):
return self.distribution.entropy().sum(dim=-1)
def update_distribution(self, observations):
mean = self.actor(observations)
self.distribution = Normal(mean, mean * 0.0 + self.std)
def act(self, observations, **kwargs):
self.update_distribution(observations)
return self.distribution.sample()
def get_actions_log_prob(self, actions):
return self.distribution.log_prob(actions).sum(dim=-1)
def act_inference(self, observations):
actions_mean = self.actor(observations)
return actions_mean
def evaluate(self, critic_observations, **kwargs):
value = self.critic(critic_observations)
return value
def get_activation(act_name):
if act_name == "elu":
return nn.ELU()
elif act_name == "selu":
return nn.SELU()
elif act_name == "relu":
return nn.ReLU()
elif act_name == "crelu":
return nn.CReLU()
elif act_name == "lrelu":
return nn.LeakyReLU()
elif act_name == "tanh":
return nn.Tanh()
elif act_name == "sigmoid":
return nn.Sigmoid()
else:
print("invalid activation function!")
return None

View File

@ -1,97 +0,0 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import torch
import torch.nn as nn
from rsl_rl.modules.actor_critic import ActorCritic, get_activation
from rsl_rl.utils import unpad_trajectories
class ActorCriticRecurrent(ActorCritic):
is_recurrent = True
def __init__(
self,
num_actor_obs,
num_critic_obs,
num_actions,
actor_hidden_dims=[256, 256, 256],
critic_hidden_dims=[256, 256, 256],
activation="elu",
rnn_type="lstm",
rnn_hidden_size=256,
rnn_num_layers=1,
init_noise_std=1.0,
**kwargs,
):
if kwargs:
print(
"ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()),
)
super().__init__(
num_actor_obs=rnn_hidden_size,
num_critic_obs=rnn_hidden_size,
num_actions=num_actions,
actor_hidden_dims=actor_hidden_dims,
critic_hidden_dims=critic_hidden_dims,
activation=activation,
init_noise_std=init_noise_std,
)
activation = get_activation(activation)
self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
print(f"Actor RNN: {self.memory_a}")
print(f"Critic RNN: {self.memory_c}")
def reset(self, dones=None):
self.memory_a.reset(dones)
self.memory_c.reset(dones)
def act(self, observations, masks=None, hidden_states=None):
input_a = self.memory_a(observations, masks, hidden_states)
return super().act(input_a.squeeze(0))
def act_inference(self, observations):
input_a = self.memory_a(observations)
return super().act_inference(input_a.squeeze(0))
def evaluate(self, critic_observations, masks=None, hidden_states=None):
input_c = self.memory_c(critic_observations, masks, hidden_states)
return super().evaluate(input_c.squeeze(0))
def get_hidden_states(self):
return self.memory_a.hidden_states, self.memory_c.hidden_states
class Memory(torch.nn.Module):
def __init__(self, input_size, type="lstm", num_layers=1, hidden_size=256):
super().__init__()
# RNN
rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM
self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
self.hidden_states = None
def forward(self, input, masks=None, hidden_states=None):
batch_mode = masks is not None
if batch_mode:
# batch mode (policy update): need saved hidden states
if hidden_states is None:
raise ValueError("Hidden states not passed to memory module during policy update")
out, _ = self.rnn(input, hidden_states)
out = unpad_trajectories(out, masks)
else:
# inference mode (collection): use hidden states of last step
out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
return out
def reset(self, dones=None):
# When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
for hidden_state in self.hidden_states:
hidden_state[..., dones, :] = 0.0

View File

@ -0,0 +1,105 @@
import torch
import torch.nn as nn
from rsl_rl.modules.network import Network
from rsl_rl.utils.utils import squeeze_preserve_batch
eps = torch.finfo(torch.float32).eps
class CategoricalNetwork(Network):
def __init__(
self,
input_size,
output_size,
activations=["relu", "relu", "relu"],
atom_count=51,
hidden_dims=[256, 256, 256],
init_gain=1.0,
value_max=10.0,
value_min=-10.0,
**kwargs,
):
assert len(hidden_dims) == len(activations)
assert value_max > value_min
assert atom_count > 1
super().__init__(
input_size,
activations=activations,
hidden_dims=hidden_dims[:-1],
init_fade=False,
init_gain=init_gain,
output_size=hidden_dims[-1],
**kwargs,
)
self._value_max = value_max
self._value_min = value_min
self._atom_count = atom_count
self.value_delta = (self._value_max - self._value_min) / (self._atom_count - 1)
action_values = torch.arange(self._value_min, self._value_max + eps, self.value_delta)
self.register_buffer("action_values", action_values)
self._categorical_layers = nn.ModuleList([nn.Linear(hidden_dims[-1], atom_count) for _ in range(output_size)])
self._init(self._categorical_layers, fade=False, gain=init_gain)
def categorical_loss(
self, predictions: torch.Tensor, target_probabilities: torch.Tensor, targets: torch.Tensor
) -> torch.Tensor:
"""Computes the categorical loss between the prediction and target categorical distributions.
Projects the targets back onto the categorical distribution supports before computing KL divergence.
Args:
predictions (torch.Tensor): The network prediction.
target_probabilities (torch.Tensor): The next-state value probabilities.
targets (torch.Tensor): The targets to compute the loss from.
Returns:
A torch.Tensor of the cross-entropy loss between the projected targets and the prediction.
"""
b = (targets - self._value_min) / self.value_delta
l = b.floor().long().clamp(0, self._atom_count - 1)
u = b.ceil().long().clamp(0, self._atom_count - 1)
all_idx = torch.arange(b.shape[0])
projected_targets = torch.zeros((b.shape[0], self._atom_count), device=self.device)
for i in range(self._atom_count):
# Correct for when l == u
l[:, i][(l[:, i] == u[:, i]) * (l[:, i] > 0)] -= 1
u[:, i][(l[:, i] == u[:, i]) * (u[:, i] < self._atom_count - 1)] += 1
projected_targets[all_idx, l[:, i]] += (u[:, i] - b[:, i]) * target_probabilities[..., i]
projected_targets[all_idx, u[:, i]] += (b[:, i] - l[:, i]) * target_probabilities[..., i]
loss = torch.nn.functional.cross_entropy(
predictions.reshape(*projected_targets.shape), projected_targets.detach()
)
return loss
def compute_targets(self, rewards: torch.Tensor, dones: torch.Tensor, discount: float) -> torch.Tensor:
gamma = (discount * (1 - dones)).reshape(-1, 1)
gamma_z = gamma * self.action_values.repeat(dones.size()[0], 1)
targets = (rewards.reshape(-1, 1) + gamma_z).clamp(self._value_min, self._value_max)
return targets
def forward(self, x: torch.Tensor, distribution: bool = False) -> torch.Tensor:
features = super().forward(x)
probabilities = squeeze_preserve_batch(
torch.stack([layer(features).softmax(dim=-1) for layer in self._categorical_layers], dim=1)
)
if distribution:
return probabilities
values = self.probabilities_to_values(probabilities)
return values
def probabilities_to_values(self, probabilities: torch.Tensor) -> torch.Tensor:
values = probabilities @ self.action_values
return values
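
The projection performed inside `categorical_loss` is the usual C51 trick: spread the probability mass of the Bellman targets r + gamma * z over the two neighbouring support atoms before taking a cross-entropy loss. A standalone sketch with dummy values:

```python
import torch

value_min, value_max, atom_count = -10.0, 10.0, 51
delta = (value_max - value_min) / (atom_count - 1)
z = torch.linspace(value_min, value_max, atom_count)            # fixed support atoms

rewards, dones, discount = torch.tensor([1.0, 0.5]), torch.tensor([0.0, 1.0]), 0.99
next_probs = torch.softmax(torch.randn(2, atom_count), dim=-1)  # dummy next-state distribution

targets = (rewards[:, None] + discount * (1 - dones[:, None]) * z).clamp(value_min, value_max)
b = (targets - value_min) / delta
l, u = b.floor().long(), b.ceil().long()

projected = torch.zeros(2, atom_count)
# Lower atom receives (u - b) of the mass, upper atom (b - l); if l == u, all mass goes to l.
projected.scatter_add_(1, l, next_probs * (u.float() - b + (l == u).float()))
projected.scatter_add_(1, u, next_probs * (b - l.float()))
print(projected.sum(dim=-1))  # each row sums to ~1
```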

View File

@ -0,0 +1,86 @@
import numpy as np
import torch
import torch.nn as nn
from typing import List, Tuple, Union
from rsl_rl.modules.network import Network
from rsl_rl.modules.utils import get_activation
class GaussianChimeraNetwork(Network):
"""A network to predict mean and std of a gaussian distribution with separate heads for mean and std."""
def __init__(
self,
input_size: int,
output_size: int,
activations: List[str] = ["relu", "relu", "relu", "linear"],
hidden_dims: List[int] = [256, 256, 256],
init_fade: bool = True,
init_gain: float = 0.5,
log_std_max: float = 4.0,
log_std_min: float = -20.0,
std_init: float = 1.0,
shared_dims: int = 1,
**kwargs,
):
assert len(hidden_dims) + 1 == len(activations)
assert shared_dims > 0 and shared_dims <= len(hidden_dims)
super().__init__(
input_size,
hidden_dims[shared_dims],
activations=activations[: shared_dims + 1],
hidden_dims=hidden_dims[:shared_dims],
init_fade=False,
init_gain=init_gain,
**kwargs,
)
# Since the network predicts log_std ~= 0 after initialization, compute std = std_init * exp(log_std).
self._log_std_init = np.log(std_init)
self._log_std_max = log_std_max
self._log_std_min = log_std_min
separate_dims = len(hidden_dims) - shared_dims
mean_layers = []
for i in range(separate_dims):
isize = hidden_dims[shared_dims + i]
osize = output_size if i == separate_dims - 1 else hidden_dims[shared_dims + i + 1]
layer = nn.Linear(isize, osize)
activation = activations[shared_dims + i + 1]
mean_layers += [layer, get_activation(activation)]
self._mean_layer = nn.Sequential(*mean_layers)
self._init(self._mean_layer, fade=init_fade, gain=init_gain)
log_std_layers = []
for i in range(separate_dims):
isize = hidden_dims[shared_dims + i]
osize = output_size if i == separate_dims - 1 else hidden_dims[shared_dims + i + 1]
layer = nn.Linear(isize, osize)
activation = activations[shared_dims + i + 1]
log_std_layers += [layer, get_activation(activation)]
self._log_std_layer = nn.Sequential(*log_std_layers)
self._init(self._log_std_layer, fade=init_fade, gain=init_gain)
def forward(self, x: torch.Tensor, compute_std: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
features = super().forward(x)
mean = self._mean_layer(features)
if not compute_std:
return mean
# compute standard deviation as std = std_init * exp(log_std) = exp(log(std_init) + log(std)) since the network
# will predict log_std ~= 0 after initialization.
log_std = (self._log_std_init + self._log_std_layer(features)).clamp(self._log_std_min, self._log_std_max)
std = log_std.exp()
return mean, std
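
A hedged usage sketch of the network above with its default layer configuration; `compute_std=True` returns both the mean and the standard deviation head:

```python
import torch
from rsl_rl.modules import GaussianChimeraNetwork

net = GaussianChimeraNetwork(input_size=12, output_size=3, std_init=0.5)
obs = torch.randn(32, 12)
mean = net(obs)                         # mean only
mean, std = net(obs, compute_std=True)  # mean and standard deviation
print(mean.shape, std.shape)
```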

View File

@ -0,0 +1,37 @@
import numpy as np
import torch
import torch.nn as nn
from typing import Tuple, Union
from rsl_rl.modules.network import Network
class GaussianNetwork(Network):
"""A network to predict mean and std of a gaussian distribution where std is a tunable parameter."""
def __init__(
self,
input_size: int,
output_size: int,
log_std_max: float = 4.0,
log_std_min: float = -20.0,
std_init: float = 1.0,
**kwargs,
):
super().__init__(input_size, output_size, **kwargs)
self._log_std_max = log_std_max
self._log_std_min = log_std_min
self._log_std = nn.Parameter(torch.ones(output_size) * np.log(std_init))
def forward(self, x: torch.Tensor, compute_std: bool = False, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
mean = super().forward(x, **kwargs)
if not compute_std:
return mean
log_std = torch.ones_like(mean) * self._log_std.clamp(self._log_std_min, self._log_std_max)
std = log_std.exp()
return mean, std

View File

@ -0,0 +1,168 @@
import numpy as np
import torch
from torch.distributions import Normal
import torch.nn as nn
from typing import List, Union
from rsl_rl.modules.network import Network
from rsl_rl.modules.quantile_network import energy_loss
from rsl_rl.utils.benchmarkable import Benchmarkable
def reshape_measure_param(tau: torch.Tensor, param: torch.Tensor) -> torch.Tensor:
if not torch.is_tensor(param):
param = torch.tensor([param])
param = param.expand(tau.shape[0], -1).to(tau.device)
return param
def risk_measure_neutral(tau: torch.Tensor) -> torch.Tensor:
return tau
def risk_measure_wang(tau: torch.Tensor, beta: float = 0.0) -> torch.Tensor:
beta = reshape_measure_param(tau, beta)
distorted_tau = Normal(0, 1).cdf(Normal(0, 1).icdf(tau) + beta)
return distorted_tau
class ImplicitQuantileNetwork(Network):
measure_neutral = "neutral"
measure_wang = "wang"
measures = {
measure_neutral: risk_measure_neutral,
measure_wang: risk_measure_wang,
}
def __init__(
self,
input_size: int,
output_size: int,
activations: List[str] = ["relu", "relu", "relu"],
feature_layers: int = 1,
embedding_size: int = 64,
hidden_dims: List[int] = [256, 256, 256],
init_fade: bool = False,
init_gain: float = 0.5,
measure: str = None,
measure_kwargs: dict = {},
**kwargs,
):
assert len(hidden_dims) == len(activations), "hidden_dims and activations must have the same length."
assert feature_layers > 0, "feature_layers must be greater than 0."
assert feature_layers < len(hidden_dims), "feature_layers must be less than the number of hidden dimensions."
assert embedding_size > 0, "embedding_size must be greater than 0."
super().__init__(
input_size,
hidden_dims[feature_layers - 1],
activations=activations[:feature_layers],
hidden_dims=hidden_dims[: feature_layers - 1],
init_fade=init_fade,
init_gain=init_gain,
**kwargs,
)
self._last_taus = None
self._last_quantiles = None
self._embedding_size = embedding_size
self.register_buffer(
"_embedding_pis",
np.pi * (torch.arange(self._embedding_size, device=self.device).reshape(1, 1, self._embedding_size)),
)
self._embedding_layer = nn.Sequential(
nn.Linear(self._embedding_size, hidden_dims[feature_layers - 1]), nn.ReLU()
)
self._fusion_layers = Network(
hidden_dims[feature_layers - 1],
output_size,
activations=activations[feature_layers:] + ["linear"],
hidden_dims=hidden_dims[feature_layers:],
init_fade=init_fade,
init_gain=init_gain,
)
measure_func = risk_measure_neutral if measure is None else self.measures[measure]
self._measure_func = measure_func
self._measure_kwargs = measure_kwargs
@Benchmarkable.register
def _sample_taus(self, batch_size: int, sample_count: int, measure_args: list, use_measure: bool) -> torch.Tensor:
"""Sample quantiles and distort them according to the risk metric.
Args:
batch_size: The batch size.
sample_count: The number of samples to draw.
measure_args: The arguments to pass to the risk measure function.
use_measure: Whether to use the risk measure or not.
Returns:
A tensor of shape (batch_size, sample_count, 1).
"""
taus = torch.rand(batch_size, sample_count, device=self.device)
if not use_measure:
return taus
if measure_args:
taus = self._measure_func(taus, *measure_args)
else:
taus = self._measure_func(taus, **self._measure_kwargs)
return taus
@Benchmarkable.register
def forward(
self,
x: torch.Tensor,
distribution: bool = False,
measure_args: list = [],
sample_count: int = 8,
taus: Union[torch.Tensor, None] = None,
use_measure: bool = True,
**kwargs,
) -> torch.Tensor:
assert taus is None or not use_measure, "Cannot use taus and use_measure at the same time."
batch_size = x.shape[0]
features = super().forward(x, **kwargs)
taus = self._sample_taus(batch_size, sample_count, measure_args, use_measure) if taus is None else taus
# Compute quantile embeddings
singular_dims = [1] * taus.dim()
cos = torch.cos(taus.unsqueeze(-1) * self._embedding_pis.reshape(*singular_dims, self._embedding_size))
embeddings = self._embedding_layer(cos)
# Compute the fusion of the features and the embeddings
fused_features = features.unsqueeze(-2) * embeddings
quantiles = self._fusion_layers(fused_features)
self._last_quantiles = quantiles
self._last_taus = taus
if distribution:
return quantiles
values = quantiles.mean(-1)
return values
@property
def last_taus(self):
return self._last_taus
@property
def last_quantiles(self):
return self._last_quantiles
@Benchmarkable.register
def sample_energy_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
loss = energy_loss(predictions, targets)
return loss
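
A hedged usage sketch of the implicit quantile network with the Wang risk measure (a negative beta shifts sampling toward lower quantiles, i.e. a risk-averse value estimate); shapes are only illustrative:

```python
import torch
from rsl_rl.modules import ImplicitQuantileNetwork

iqn = ImplicitQuantileNetwork(input_size=10, output_size=1, measure="wang", measure_kwargs={"beta": -0.5})
x = torch.randn(32, 10)
value = iqn(x, sample_count=16)                          # risk-distorted value estimate
quantiles = iqn(x, sample_count=16, distribution=True)   # individual quantile samples
print(value.shape, quantiles.shape)
```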

210
rsl_rl/modules/network.py Normal file
View File

@ -0,0 +1,210 @@
import torch
import torch.nn as nn
from typing import List
from rsl_rl.modules.normalizer import EmpiricalNormalization
from rsl_rl.modules.utils import get_activation
from rsl_rl.modules.transformer import Transformer
from rsl_rl.utils.benchmarkable import Benchmarkable
from rsl_rl.utils.utils import squeeze_preserve_batch
class Network(Benchmarkable, nn.Module):
recurrent_module_lstm = "LSTM"
recurrent_module_transformer = "TF"
recurrent_modules = {recurrent_module_lstm: nn.LSTM, recurrent_module_transformer: Transformer}
def __init__(
self,
input_size: int,
output_size: int,
activations: List[str] = ["relu", "relu", "relu", "tanh"],
hidden_dims: List[int] = [256, 256, 256],
init_fade: bool = True,
init_gain: float = 1.0,
input_normalization: bool = False,
recurrent: bool = False,
recurrent_layers: int = 1,
recurrent_module: str = recurrent_module_lstm,
recurrent_tf_context_length: int = 64,
recurrent_tf_head_count: int = 8,
) -> None:
"""
Args:
input_size (int): The size of the input.
output_size (int): The size of the output.
activations (List[str]): The activation functions to use. If the network is recurrent, the first activation
function is used for the output of the recurrent layer.
hidden_dims (List[int]): The hidden dimensions. If the network is recurrent, the first hidden dimension is
used for the recurrent layer.
init_fade (bool): Whether to use the fade in initialization.
init_gain (float): The gain to use for the initialization.
input_normalization (bool): Whether to use input normalization.
recurrent (bool): Whether to use a recurrent network.
recurrent_layers (int): The number of recurrent layers (LSTM) / blocks (Transformer) to use.
recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules.
recurrent_tf_context_length (int): The context length of the Transformer.
recurrent_tf_head_count (int): The head count of the Transformer.
"""
assert len(hidden_dims) + 1 == len(activations)
super().__init__()
if input_normalization:
self._normalization = EmpiricalNormalization(shape=(input_size,))
else:
self._normalization = nn.Identity()
dims = [input_size] + hidden_dims + [output_size]
self._recurrent = recurrent
self._recurrent_module = recurrent_module
self.hidden_state = None
self._last_hidden_state = None
if self._recurrent:
recurrent_kwargs = dict()
if recurrent_module == self.recurrent_module_lstm:
recurrent_kwargs["hidden_size"] = dims[1]
recurrent_kwargs["input_size"] = dims[0]
recurrent_kwargs["num_layers"] = recurrent_layers
elif recurrent_module == self.recurrent_module_transformer:
recurrent_kwargs["block_count"] = recurrent_layers
recurrent_kwargs["context_length"] = recurrent_tf_context_length
recurrent_kwargs["head_count"] = recurrent_tf_head_count
recurrent_kwargs["hidden_size"] = dims[1]
recurrent_kwargs["input_size"] = dims[0]
recurrent_kwargs["output_size"] = dims[1]
rnn = self.recurrent_modules[recurrent_module](**recurrent_kwargs)
activation = get_activation(activations[0])
dims = dims[1:]
activations = activations[1:]
self._features = nn.Sequential(rnn, activation)
else:
self._features = nn.Identity()
layers = []
for i in range(len(activations)):
layer = nn.Linear(dims[i], dims[i + 1])
activation = get_activation(activations[i])
layers.append(layer)
layers.append(activation)
self._layers = nn.Sequential(*layers)
if len(layers) > 0:
self._init(self._layers, fade=init_fade, gain=init_gain)
@property
def device(self):
"""Returns the device of the network."""
return next(self.parameters()).device
def forward(self, x: torch.Tensor, hidden_state=None) -> torch.Tensor:
"""
Args:
x (torch.Tensor): The input data.
hidden_state (Tuple[torch.Tensor, torch.Tensor]): The hidden state of the network. If None, the hidden state
of the network is used. If provided, the hidden state of the neural network will not be updated. To
retrieve the new hidden state, use the last_hidden_state property. If the network is not recurrent,
this argument is ignored.
Returns:
The output of the network as a torch.Tensor.
"""
assert hidden_state is None or self._recurrent, "Cannot pass hidden state to non-recurrent network."
input = self._normalization(x.to(self.device))
if self._recurrent:
current_hidden_state = self.hidden_state if hidden_state is None else hidden_state
current_hidden_state = (current_hidden_state[0].to(self.device), current_hidden_state[1].to(self.device))
input = input.unsqueeze(0) if len(input.shape) == 2 else input
input, next_hidden_state = self._features[0](input, current_hidden_state)
input = self._features[1](input).squeeze(0)
if hidden_state is None:
self.hidden_state = next_hidden_state
self._last_hidden_state = next_hidden_state
output = squeeze_preserve_batch(self._layers(input))
return output
@property
def last_hidden_state(self):
"""Returns the hidden state of the last forward pass.
Does not differentiate whether the hidden state depends on the hidden state kept in the network or whether it
was passed into the forward pass.
Returns:
The hidden state of the last forward pass as Tuple[torch.Tensor, torch.Tensor].
"""
return self._last_hidden_state
def normalize(self, x: torch.Tensor) -> torch.Tensor:
"""Normalizes the given input.
Args:
x (torch.Tensor): The input to normalize.
Returns:
The normalized input as a torch.Tensor.
"""
output = self._normalization(x.to(self.device))
return output
@property
def recurrent(self) -> bool:
"""Returns whether the network is recurrent."""
return self._recurrent
def reset_hidden_state(self, indices: torch.Tensor) -> None:
"""Resets the hidden state of the neural network.
Throws an error if the network is not recurrent.
Args:
indices (torch.Tensor): A 1-dimensional int tensor containing the indices of the terminated
environments.
"""
assert self._recurrent
self.hidden_state[0][:, indices] = torch.zeros(len(indices), self._features[0].hidden_size, device=self.device)
self.hidden_state[1][:, indices] = torch.zeros(len(indices), self._features[0].hidden_size, device=self.device)
def reset_full_hidden_state(self, batch_size=None) -> None:
"""Resets the hidden state of the neural network.
Args:
batch_size (int): The batch size of the hidden state. If None, the hidden state is reset to None.
"""
assert self._recurrent
if batch_size is None:
self.hidden_state = None
else:
layer_count, hidden_size = self._features[0].num_layers, self._features[0].hidden_size
self.hidden_state = (
torch.zeros(layer_count, batch_size, hidden_size, device=self.device),
torch.zeros(layer_count, batch_size, hidden_size, device=self.device),
)
def _init(self, layers: List[nn.Module], fade: bool = True, gain: float = 1.0) -> List[nn.Module]:
"""Initializes neural network layers."""
last_layer_idx = len(layers) - 1 - next(i for i, l in enumerate(reversed(layers)) if isinstance(l, nn.Linear))
for idx, layer in enumerate(layers):
if not isinstance(layer, nn.Linear):
continue
current_gain = gain / 100.0 if fade and idx == last_layer_idx else gain
nn.init.xavier_normal_(layer.weight, gain=current_gain)
return layers
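
The following sketch shows how the hidden-state API above is typically driven from a rollout loop. The construction call is commented out because the full constructor signature is not reproduced here (the keyword arguments shown are assumptions); only the members defined above (`recurrent`, `hidden_state`, `reset_full_hidden_state`, `reset_hidden_state`) are exercised.

```python
import torch

# Hypothetical construction; the exact keyword arguments are assumptions.
# net = Network(input_size=48, output_size=12, recurrent=True,
#               hidden_dims=[256, 256], activations=["relu", "relu", "linear"])

def rollout_step(net, obs: torch.Tensor, dones: torch.Tensor) -> torch.Tensor:
    """One policy step with explicit hidden-state management for a recurrent Network."""
    if net.recurrent and net.hidden_state is None:
        # Lazily allocate a zeroed (h, c) state for the whole batch of environments.
        net.reset_full_hidden_state(batch_size=obs.shape[0])

    actions = net(obs)  # uses and updates the internally stored hidden state

    if net.recurrent:
        # Zero the state of environments that terminated at this step.
        net.reset_hidden_state(dones.nonzero().reshape(-1))

    return actions
```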

View File

@ -1,9 +1,3 @@
# Copyright (c) 2020 Preferred Networks, Inc.
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import torch
from torch import nn
@ -11,48 +5,58 @@ from torch import nn
class EmpiricalNormalization(nn.Module):
"""Normalize mean and variance of values based on empirical values."""
def __init__(self, shape, eps=1e-2, until=None):
"""Initialize EmpiricalNormalization module.
def __init__(self, shape, eps=1e-6, until=None) -> None:
"""
Args:
shape (int or tuple of int): Shape of input values except batch axis.
eps (float): Small value for stability.
until (int or None): If this arg is specified, the link learns input values until the sum of batch sizes
exceeds it.
exceeds it.
"""
super().__init__()
self.eps = eps
self.until = until
self.register_buffer("_mean", torch.zeros(shape).unsqueeze(0))
self.register_buffer("_var", torch.ones(shape).unsqueeze(0))
self.register_buffer("_std", torch.ones(shape).unsqueeze(0))
self.count = 0
@property
def mean(self):
return self._mean.squeeze(0).clone()
def mean(self) -> torch.Tensor:
"""Mean of input values."""
return self._mean.squeeze(0).detach().clone()
@property
def std(self):
return self._std.squeeze(0).clone()
def forward(self, x):
"""Normalize mean and variance of values based on empirical values.
def std(self) -> torch.Tensor:
"""Standard deviation of input values."""
return self._std.squeeze(0).detach().clone()
def forward(self, x) -> torch.Tensor:
"""Normalize mean and variance of values based on emprical values.
Args:
x (ndarray or Variable): Input values
Returns:
ndarray or Variable: Normalized output values
Normalized output values
"""
if self.training:
self.update(x)
return (x - self._mean) / (self._std + self.eps)
x_normalized = (x - self._mean.detach()) / (self._std.detach() + self.eps)
return x_normalized
@torch.jit.unused
def update(self, x):
"""Learn input values without computing the output values of them"""
def update(self, x: torch.Tensor) -> None:
"""Learn input values without computing the output values of them.
Args:
x (torch.Tensor): Input values.
"""
x = x.detach()
if self.until is not None and self.count >= self.until:
return
@ -69,5 +73,14 @@ class EmpiricalNormalization(nn.Module):
self._std = torch.sqrt(self._var)
@torch.jit.unused
def inverse(self, y):
return y * (self._std + self.eps) + self._mean
def inverse(self, y: torch.Tensor) -> torch.Tensor:
"""Inverse normalized values.
Args:
y (torch.Tensor): Normalized input values.
Returns:
Inverse normalized output values.
"""
inv = y * (self._std + self.eps) + self._mean
return inv
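
A minimal usage sketch of the normalization module above (shape and `until` values are arbitrary; the import path mirrors the one used elsewhere in this repository): statistics are only updated while the module is in training mode, and `inverse` undoes the normalization.

```python
import torch
from rsl_rl.modules import EmpiricalNormalization  # import path as used by the runners in this repository

norm = EmpiricalNormalization(shape=(3,), eps=1e-6, until=int(1e8))

norm.train()                       # running mean/std are only updated in training mode
batch = torch.randn(64, 3) * 5.0 + 2.0
normalized = norm(batch)           # updates the statistics, then normalizes

norm.eval()                        # freeze the statistics for evaluation / deployment
restored = norm.inverse(norm(batch))
print(torch.allclose(restored, batch, atol=1e-5))  # inverse() undoes the normalization (up to rounding)
```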

View File

@ -0,0 +1,333 @@
import torch
import torch.nn as nn
from torch.distributions import Normal
from typing import Callable, Tuple, Union
from rsl_rl.modules.network import Network
from rsl_rl.utils.benchmarkable import Benchmarkable
from rsl_rl.utils.utils import squeeze_preserve_batch
eps = torch.finfo(torch.float32).eps
def reshape_measure_parameters(
qn: Network, *params: Union[torch.Tensor, float]
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""Reshapes the parameters of a measure function to match the shape of the quantile network.
Args:
qn (Network): The quantile network.
*params (Union[torch.Tensor, float]): The parameters of the measure function.
Returns:
Union[torch.Tensor, Tuple[torch.Tensor, ...]]: The reshaped parameters.
"""
if not params:
return qn._tau.to(qn.device), *params
assert len([*set([torch.is_tensor(p) for p in params])]) == 1, "All parameters must be either tensors or scalars."
if torch.is_tensor(params[0]):
assert all([p.dim() == 1 for p in params]), "All parameters must have dimensionality 1."
assert len([*set([p.shape[0] for p in params])]) == 1, "All parameters must have the same size."
reshaped_params = [p.reshape(-1, 1).to(qn.device) for p in params]
tau = qn._tau.expand(params[0].shape[0], -1).to(qn.device)
else:
reshaped_params = params
tau = qn._tau.to(qn.device)
return tau, *reshaped_params
def make_distorted_measure(distorted_tau: torch.Tensor) -> Callable:
"""Creates a measure function for the distorted expectation under some distortion function.
The distorted expectation for some distortion function g(tau) is given by the integral w.r.t. tau
"int_0^1 g'(tau) * F_Z^{-1}(tau) dtau" where g'(tau) is the derivative of g w.r.t. tau and F_Z^{-1} is the inverse
cumulative distribution function of the value distribution.
See https://arxiv.org/pdf/2004.14547.pdf and https://arxiv.org/pdf/1806.06923.pdf for details.
"""
distorted_tau = distorted_tau.reshape(-1, distorted_tau.shape[-1])
distortion = (distorted_tau[:, 1:] - distorted_tau[:, :-1]).squeeze(0)
def distorted_measure(quantiles):
sorted_quantiles, _ = quantiles.sort(-1)
sorted_quantiles = sorted_quantiles.reshape(-1, sorted_quantiles.shape[-1])
# dtau = tau[1:] - tau[:-1] cancels the denominator of g'(tau) = (g(tau)[1:] - g(tau)[:-1]) / dtau.
values = squeeze_preserve_batch((distortion.to(sorted_quantiles.device) * sorted_quantiles).sum(-1))
return values
return distorted_measure
def risk_measure_cvar(qn: Network, confidence_level: float = 1.0) -> Callable:
"""Conditional value at risk measure.
TODO: Handle confidence_level being a tensor.
Args:
qn (QuantileNetwork): Quantile network to compute the risk measure for.
confidence_level (float): Confidence level of the risk measure. Must be between 0 and 1.
Returns:
A risk measure function.
"""
tau, confidence_level = reshape_measure_parameters(qn, confidence_level)
distorted_tau = torch.min(tau / confidence_level, torch.ones(*tau.shape).to(tau.device))
return make_distorted_measure(distorted_tau)
def risk_measure_neutral(_: Network) -> Callable:
"""Neutral risk measure (expected value).
Args:
_ (QuantileNetwork): Quantile network to compute the risk measure for.
Returns:
A risk measure function.
"""
def measure(quantiles):
values = squeeze_preserve_batch(quantiles.mean(-1))
return values
return measure
def risk_measure_percentile(_: Network, confidence_level: float = 1.0) -> Callable:
"""Value at risk measure.
Args:
_ (QuantileNetwork): Quantile network to compute the risk measure for.
confidence_level (float): Confidence level of the risk measure. Must be between 0 and 1.
Returns:
A risk measure function.
"""
def measure(quantiles):
sorted_quantiles, _ = quantiles.sort(-1)
sorted_quantiles = sorted_quantiles.reshape(-1, sorted_quantiles.shape[-1])
idx = min(int(confidence_level * quantiles.shape[-1]), quantiles.shape[-1] - 1)
values = squeeze_preserve_batch(sorted_quantiles[:, idx])
return values
return measure
def risk_measure_wang(qn: Network, beta: Union[float, torch.Tensor] = 0.0) -> Callable:
"""Wang's risk measure.
The risk measure computes the distorted expectation under Wang's risk distortion function
g(tau) = Phi(Phi^-1(tau) + beta) where Phi and Phi^-1 are the standard normal CDF and its inverse.
See https://arxiv.org/pdf/2004.14547.pdf for details.
Args:
qn (QuantileNetwork): Quantile network to compute the risk measure for.
beta (float): Parameter of the risk distortion function.
Returns:
A risk measure function.
"""
tau, beta = reshape_measure_parameters(qn, beta)
distorted_tau = Normal(0, 1).cdf(Normal(0, 1).icdf(tau) + beta)
return make_distorted_measure(distorted_tau)
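
To make the distorted-expectation construction concrete, this sketch reproduces the CVaR distortion weights by hand for a small quantile count; the numbers are purely illustrative and do not go through `QuantileNetwork`.

```python
import torch

quantile_count = 4
tau = torch.arange(quantile_count + 1) / quantile_count       # [0.00, 0.25, 0.50, 0.75, 1.00]

# CVaR at confidence level 0.5 distorts tau via g(tau) = min(tau / 0.5, 1).
confidence_level = 0.5
distorted_tau = torch.min(tau / confidence_level, torch.ones_like(tau))

# The measure weights each sorted quantile by g(tau)[i + 1] - g(tau)[i].
weights = distorted_tau[1:] - distorted_tau[:-1]               # [0.5, 0.5, 0.0, 0.0]

quantiles = torch.tensor([-2.0, -1.0, 1.0, 2.0])               # already sorted
cvar = (weights * quantiles).sum()                             # -1.5: mean of the worst half
print(weights, cvar)
```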
def energy_loss(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
"""Computes sample energy loss between predictions and targets.
The energy loss is computed as 2*E[||X - Y||_2] - E[||X - X'||_2] - E[||Y - Y'||_2], where X, X' and Y, Y' are
random variables and ||.||_2 is the L2-norm. X, X' are the predictions and Y, Y' are the targets.
Args:
predictions (torch.Tensor): Predictions to compute loss from.
targets (torch.Tensor): Targets to compare predictions against.
Returns:
A torch.Tensor of shape (1,) containing the loss.
"""
dims = [-1 for _ in range(predictions.dim())]
prediction_mat = predictions.unsqueeze(-1).expand(*dims, predictions.shape[-1])
target_mat = targets.unsqueeze(-1).expand(*dims, predictions.shape[-1])
delta_xx = (prediction_mat - prediction_mat.transpose(-1, -2)).abs().mean()
delta_yy = (target_mat - target_mat.transpose(-1, -2)).abs().mean()
delta_xy = (prediction_mat - target_mat.transpose(-1, -2)).abs().mean()
loss = 2 * delta_xy - delta_xx - delta_yy
return loss
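
A quick sanity check of the energy loss defined above (shapes are arbitrary): identical samples give a loss of zero, while a constant shift between predictions and targets gives a strictly positive loss.

```python
import torch

predictions = torch.randn(8, 16)                       # 8 distributions with 16 samples each
shifted_targets = predictions + 1.0

print(energy_loss(predictions, predictions))           # tensor(0.) -- identical samples
print(energy_loss(predictions, shifted_targets) > 0)   # tensor(True) -- shifted samples
```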
class QuantileNetwork(Network):
measure_cvar = "cvar"
measure_neutral = "neutral"
measure_percentile = "percentile"
measure_wang = "wang"
measures = {
measure_cvar: risk_measure_cvar,
measure_neutral: risk_measure_neutral,
measure_percentile: risk_measure_percentile,
measure_wang: risk_measure_wang,
}
def __init__(
self,
input_size,
output_size,
activations=["relu", "relu", "relu"],
hidden_dims=[256, 256, 256],
init_fade=False,
init_gain=0.5,
measure=None,
measure_kwargs={},
quantile_count=200,
**kwargs,
):
assert len(hidden_dims) == len(activations)
assert quantile_count > 0
super().__init__(
input_size,
activations=activations,
hidden_dims=hidden_dims[:-1],
init_fade=False,
init_gain=init_gain,
output_size=hidden_dims[-1],
**kwargs,
)
self._quantile_count = quantile_count
self._tau = torch.arange(self._quantile_count + 1) / self._quantile_count
self._tau_hat = torch.tensor([(self._tau[i] + self._tau[i + 1]) / 2 for i in range(self._quantile_count)])
self._tau_hat_mat = torch.empty((0,))
self._quantile_layers = nn.ModuleList([nn.Linear(hidden_dims[-1], quantile_count) for _ in range(output_size)])
self._init(self._quantile_layers, fade=init_fade, gain=init_gain)
measure_func = risk_measure_neutral if measure is None else self.measures[measure]
self._measure_func = measure_func
self._measure = measure_func(self, **measure_kwargs)
self._last_quantiles = None
@property
def last_quantiles(self) -> torch.Tensor:
return self._last_quantiles
def make_diracs(self, values: torch.Tensor) -> torch.Tensor:
"""Generates value distributions that have a single spike at the given values.
Args:
values (torch.Tensor): Values to generate dirac distributions for.
Returns:
A torch.Tensor of shape (*values.shape, quantile_count) containing the dirac distributions.
"""
dirac = values.unsqueeze(-1).expand(*[-1 for _ in range(values.dim())], self._quantile_count)
return dirac
@property
def quantile_count(self) -> int:
return self._quantile_count
@Benchmarkable.register
def quantiles_to_values(self, quantiles: torch.Tensor, *measure_args) -> torch.Tensor:
"""Computes values from quantiles.
Args:
quantiles (torch.Tensor): Quantiles to compute values from.
measure_args: Optional positional arguments for the risk measure function, overriding the arguments
passed when creating the network.
Returns:
A torch.Tensor containing the risk-adjusted values.
"""
if measure_args:
values = self._measure_func(self, *[squeeze_preserve_batch(m) for m in measure_args])(quantiles)
else:
values = self._measure(quantiles)
return values
@Benchmarkable.register
def quantile_l1_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
"""Computes quantile-wise l1 loss between predictions and targets.
TODO: This function is a bottleneck.
Args:
predictions (torch.Tensor): Predictions to compute loss from.
targets (torch.Tensor): Targets to compare predictions against.
Returns:
A torch.Tensor of shape (1,) containing the loss.
"""
assert (
predictions.dim() == 2 or predictions.dim() == 3
), f"Predictions must be 2D or 3D. Got {predictions.dim()}."
assert (
predictions.shape == targets.shape
), f"The shapes of predictions and targets must match. Got {predictions.shape} and {targets.shape}."
pre_dims = [-1] if predictions.dim() == 3 else []
prediction_mat = predictions.unsqueeze(-3).expand(*pre_dims, self._quantile_count, -1, -1)
target_mat = targets.transpose(-2, -1).unsqueeze(-1).expand(*pre_dims, -1, -1, self._quantile_count)
delta = target_mat - prediction_mat
tau_hat = self._tau_hat.expand(predictions.shape[-2], -1).to(self.device)
loss = (torch.where(delta < 0, (tau_hat - 1), tau_hat) * delta).abs().mean()
return loss
@Benchmarkable.register
def quantile_huber_loss(self, predictions: torch.Tensor, targets: torch.Tensor, kappa: float = 1.0) -> torch.Tensor:
"""Computes quantile huber loss between predictions and targets.
TODO: This function is a bottleneck.
Args:
predictions (torch.Tensor): Predictions to compute loss from.
targets (torch.Tensor): Targets to compare predictions against.
kappa (float): Defines the interval [-kappa, kappa] around zero where squared loss is used. Defaults to 1.
Returns:
A torch.Tensor of shape (1,) containing the loss.
"""
pre_dims = [-1] if predictions.dim() == 3 else []
prediction_mat = predictions.unsqueeze(-3).expand(*pre_dims, self._quantile_count, -1, -1)
target_mat = targets.transpose(-2, -1).unsqueeze(-1).expand(*pre_dims, -1, -1, self._quantile_count)
delta = target_mat - prediction_mat
delta_abs = delta.abs()
huber = torch.where(delta_abs <= kappa, 0.5 * delta.pow(2), kappa * (delta_abs - 0.5 * kappa))
tau_hat = self._tau_hat.expand(predictions.shape[-2], -1).to(self.device)
loss = (torch.where(delta < 0, (tau_hat - 1), tau_hat).abs() * huber).mean()
return loss
@Benchmarkable.register
def sample_energy_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
loss = energy_loss(predictions, targets)
return loss
@Benchmarkable.register
def forward(self, x: torch.Tensor, distribution: bool = False, measure_args: list = [], **kwargs) -> torch.Tensor:
features = super().forward(x, **kwargs)
quantiles = squeeze_preserve_batch(torch.stack([layer(features) for layer in self._quantile_layers], dim=1))
self._last_quantiles = quantiles
if distribution:
return quantiles
values = self.quantiles_to_values(quantiles, *measure_args)
return values
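
A usage sketch for the quantile network (the import path and hyperparameter values are assumptions, and a non-recurrent configuration is assumed by default): requesting `distribution=True` returns the raw quantiles, while the default path collapses them through the configured risk measure.

```python
import torch
# Assumed import path, based on the module imports shown above:
# from rsl_rl.modules.quantile_network import QuantileNetwork

qn = QuantileNetwork(
    input_size=32,
    output_size=1,
    quantile_count=50,
    measure=QuantileNetwork.measure_cvar,
    measure_kwargs={"confidence_level": 0.25},  # average over the worst 25% of outcomes
)

obs = torch.randn(16, 32)
values = qn(obs)                         # risk-adjusted values, one per batch entry
quantiles = qn(obs, distribution=True)   # full return distribution, shape (16, 50)
```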

View File

@ -0,0 +1,150 @@
import torch
from typing import Tuple
class Head(torch.nn.Module):
"""A single causal self-attention head."""
def __init__(self, hidden_size: int, head_size: int):
super().__init__()
self.query = torch.nn.Linear(hidden_size, head_size)
self.key = torch.nn.Linear(hidden_size, head_size)
self.value = torch.nn.Linear(hidden_size, head_size)
def forward(self, x: torch.Tensor):
x = x.transpose(0, 1)
_, S, F = x.shape # (Batch, Sequence, Features)
q = self.query(x)
k = self.key(x)
weight = q @ k.transpose(-1, -2) * F**-0.5 # shape: (B, S, S)
weight = weight.masked_fill(torch.tril(torch.ones(S, S, device=x.device)) == 0, float("-inf"))  # causal mask
weight = torch.nn.functional.softmax(weight, dim=-1)
v = self.value(x) # shape: (B, S, F)
out = (weight @ v).transpose(0, 1) # shape: (S, B, F)
return out
class MultiHead(torch.nn.Module):
def __init__(self, hidden_size: int, head_count: int):
super().__init__()
assert hidden_size % head_count == 0, "Multi-headed attention hidden size must be divisible by the head count."
self.heads = torch.nn.ModuleList([Head(hidden_size, hidden_size // head_count) for _ in range(head_count)])
self.proj = torch.nn.Linear(hidden_size, hidden_size)
def forward(self, x: torch.Tensor):
x = torch.cat([head(x) for head in self.heads], dim=-1)
out = self.proj(x)
return out
class Block(torch.nn.Module):
def __init__(self, hidden_size: int, head_count: int):
super().__init__()
self.sa = MultiHead(hidden_size, head_count)
self.ff = torch.nn.Sequential(
torch.nn.Linear(hidden_size, 4 * hidden_size),
torch.nn.ReLU(),
torch.nn.Linear(4 * hidden_size, hidden_size),
)
self.ln1 = torch.nn.LayerNorm(hidden_size)
self.ln2 = torch.nn.LayerNorm(hidden_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.sa(self.ln1(x)) + x
out = self.ff(self.ln2(x)) + x
return out
class Transformer(torch.nn.Module):
"""A Transformer-based recurrent module.
The Transformer module is a recurrent module that uses a Transformer architecture to process the input sequence. It
uses a hidden state to emulate RNN-like behavior.
"""
def __init__(
self, input_size, output_size, hidden_size, block_count: int = 6, context_length: int = 64, head_count: int = 8
):
"""
Args:
input_size (int): The size of the input.
output_size (int): The size of the output.
hidden_size (int): The size of the hidden layers.
block_count (int): The number of Transformer blocks.
context_length (int): The length of the context to consider when predicting the next token.
head_count (int): The number of attention heads per block.
"""
assert context_length % 2 == 0, "Context length must be even."
super().__init__()
self.context_length = context_length
self.hidden_size = hidden_size
self.feature_proj = torch.nn.Linear(input_size, hidden_size)
self.position_embedding = torch.nn.Embedding(context_length, hidden_size)
self.blocks = torch.nn.Sequential(
*[Block(hidden_size, head_count) for _ in range(block_count)],
torch.nn.LayerNorm(hidden_size),
)
self.head = torch.nn.Linear(hidden_size, output_size)
@property
def num_layers(self):
# Set num_layers to half the context length for simple torch.nn.LSTM compatibility. TODO: This is a bit hacky.
return self.context_length // 2
def step(self, x: torch.Tensor, context: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Computes Transformer output given the full context and the input.
Args:
x (torch.Tensor): The input tensor of shape (Sequence, Batch, Features).
context (torch.Tensor): The context tensor of shape (Context Length, Batch, Features).
Returns:
A tuple of the output tensor of shape (Sequence, Batch, Features) and the updated context with the input
features appended. The context has shape (Context Length, Batch, Features).
"""
S = x.shape[0]
# Project input to feature space and add to context.
ft_x = self.feature_proj(x)
context = torch.cat((context, ft_x), dim=0)[-self.context_length :]
# Add positional embedding to context.
ft_pos = self.position_embedding(torch.arange(self.context_length, device=x.device)).unsqueeze(1)
x = context + ft_pos
# Compute output from Transformer blocks.
x = self.blocks(x)
out = self.head(x)[-S:]
return out, context
def forward(self, x: torch.Tensor, hidden_state: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Computes Transformer output given the input and the hidden state which encapsulates the context."""
if hidden_state is None:
hidden_state = self.reset_hidden_state(x.shape[1], device=x.device)
context = torch.cat(hidden_state, dim=0)
out, context = self.step(x, context)
hidden_state = context[: self.num_layers], context[self.num_layers :]
return out, hidden_state
def reset_hidden_state(self, batch_size: int, device="cpu"):
hidden_state = torch.zeros((self.context_length, batch_size, self.hidden_size), device=device)
hidden_state = hidden_state[: self.num_layers], hidden_state[self.num_layers :]
return hidden_state
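
The sketch below shows the Transformer module being stepped like an RNN (the import path is an assumption; sizes are arbitrary). The hidden state is simply the running context, split in two so it can be carried around like an LSTM's `(h, c)` pair.

```python
import torch
# Assumed import path: from rsl_rl.modules.transformer import Transformer

model = Transformer(
    input_size=16, output_size=8, hidden_size=64,
    block_count=2, context_length=8, head_count=4,
)

x = torch.randn(1, 32, 16)                    # (Sequence, Batch, Features)
hidden = model.reset_hidden_state(batch_size=32)

for _ in range(5):                            # RNN-style stepping: feed the state back in
    out, hidden = model(x, hidden)            # out: (1, 32, 8); hidden: two (4, 32, 64) tensors
```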

32
rsl_rl/modules/utils.py Normal file
View File

@ -0,0 +1,32 @@
from torch import nn
def get_activation(act_name):
if act_name == "elu":
return nn.ELU()
elif act_name == "selu":
return nn.SELU()
elif act_name == "relu":
return nn.ReLU()
elif act_name == "crelu":
return nn.ReLU()
elif act_name == "lrelu":
return nn.LeakyReLU()
elif act_name == "tanh":
return nn.Tanh()
elif act_name == "sigmoid":
return nn.Sigmoid()
elif act_name == "linear":
return nn.Identity()
elif act_name == "softmax":
return nn.Softmax()
else:
print("invalid activation function!")
return None
def init_xavier_uniform(layer, activation):
try:
nn.init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain(activation))
except ValueError:
nn.init.xavier_uniform_(layer.weight)
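
A short sketch of how these helpers compose into an MLP, mirroring how the Network class above consumes its dimension and activation lists (the import path is an assumption; layer sizes are arbitrary).

```python
import torch.nn as nn
# Assumed import path: from rsl_rl.modules.utils import get_activation, init_xavier_uniform

dims = [48, 256, 256, 12]
activations = ["elu", "elu", "linear"]

layers = []
for i, act_name in enumerate(activations):
    layer = nn.Linear(dims[i], dims[i + 1])
    init_xavier_uniform(layer, act_name)  # falls back to gain=1 for names torch's calculate_gain rejects
    layers.append(layer)
    layers.append(get_activation(act_name))

mlp = nn.Sequential(*layers)
```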

View File

@ -3,6 +3,7 @@
"""Implementation of runners for environment-agent interaction."""
from .on_policy_runner import OnPolicyRunner
from .runner import Runner
from .legacy_runner import LeggedGymRunner
__all__ = ["OnPolicyRunner"]
__all__ = ["LeggedGymRunner", "Runner"]

View File

@ -0,0 +1,79 @@
import os
import random
import string
import wandb
def make_save_model_cb(directory):
def cb(runner, stat):
path = os.path.join(directory, "model_{}.pt".format(stat["current_iteration"]))
runner.save(path)
return cb
def make_save_model_onnx_cb(directory):
def cb(runner, stat):
path = os.path.join(directory, "model_{}.onnx".format(stat["current_iteration"]))
runner.export_onnx(path)
return cb
def make_interval_cb(callback, interval):
def cb(runner, stat):
if stat["current_iteration"] % interval != 0:
return
callback(runner, stat)
return cb
def make_final_cb(callback):
def cb(runner, stat):
if not runner._learning_should_terminate():
return
callback(runner, stat)
return cb
def make_first_cb(callback):
uuid = "".join(random.choices(string.ascii_letters + string.digits, k=8))
def cb(runner, stat):
if hasattr(runner, f"_first_cb_{uuid}"):
return
setattr(runner, f"_first_cb_{uuid}", True)
callback(runner, stat)
return cb
def make_wandb_cb(init_kwargs):
assert "project" in init_kwargs, "The project must be specified in the init_kwargs."
run = wandb.init(**init_kwargs)
check_complete = make_final_cb(lambda *_: run.finish())
def cb(runner, stat):
mean_reward = sum(stat["returns"]) / len(stat["returns"]) if len(stat["returns"]) > 0 else 0.0
mean_steps = sum(stat["lengths"]) / len(stat["lengths"]) if len(stat["lengths"]) > 0 else 0.0
total_steps = stat["current_iteration"] * runner.env.num_envs * runner._num_steps_per_env
training_time = stat["training_time"]
run.log(
{
"mean_rewards": mean_reward,
"mean_steps": mean_steps,
"training_steps": total_steps,
"training_time": training_time,
}
)
check_complete(runner, stat)
return cb
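
The callback factories above compose as plain functions of `(runner, stat)`. A hypothetical combination (directory name and interval are placeholders) could look like the sketch below; the commented-out lines show where the list would be plugged into a `Runner`.

```python
from rsl_rl.runners.callbacks import make_final_cb, make_interval_cb, make_save_model_cb

log_dir = "logs/my_run"  # hypothetical output directory

callbacks = [
    make_interval_cb(make_save_model_cb(log_dir), 50),  # checkpoint every 50 iterations
    make_final_cb(make_save_model_cb(log_dir)),         # and once more when training ends
    # make_wandb_cb({"project": "my_project"}),         # optional Weights & Biases logging
]

# runner = Runner(env, agent, learn_cb=callbacks)
# runner.learn(iterations=1000)
```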

View File

@ -0,0 +1,136 @@
import os
from rsl_rl.algorithms import *
from rsl_rl.env import VecEnv
from rsl_rl.runners.callbacks import (
make_final_cb,
make_first_cb,
make_interval_cb,
make_save_model_onnx_cb,
)
from rsl_rl.runners.runner import Runner
from rsl_rl.storage import *
def make_legacy_save_model_cb(directory):
def cb(runner, stat):
data = {}
if hasattr(runner.env, "_persistent_data"):
data["env_data"] = runner.env._persistent_data
path = os.path.join(directory, "model_{}.pt".format(stat["current_iteration"]))
runner.save(path, data=data)
return cb
class LeggedGymRunner(Runner):
"""Runner for legged_gym environments."""
mappings = [
("init_noise_std", "actor_noise_std"),
("clip_param", "clip_ratio"),
("desired_kl", "target_kl"),
("entropy_coef", "entropy_coeff"),
("lam", "gae_lambda"),
("max_grad_norm", "gradient_clip"),
("num_learning_epochs", None),
("num_mini_batches", "batch_count"),
("use_clipped_value_loss", None),
("value_loss_coef", "value_coeff"),
]
@staticmethod
def _hook_env(env: VecEnv):
old_step = env.step
def step_hook(*args, **kwargs):
result = old_step(*args, **kwargs)
if len(result) == 4:
obs, rewards, dones, env_info = result
elif len(result) == 5:
obs, _, rewards, dones, env_info = result
else:
raise ValueError("Invalid number of return values from env.step().")
return obs, rewards, dones.float(), env_info
env.step = step_hook
return env
def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
env = self._hook_env(env)
self.cfg = train_cfg["runner"]
alg_class = eval(self.cfg["algorithm_class_name"])
if "policy_class_name" in self.cfg:
print("WARNING: ignoring deprecated parameter 'runner.policy_class_name'.")
alg_cfg = train_cfg["algorithm"]
alg_cfg.update(train_cfg["policy"])
if "activation" in alg_cfg:
print(
"WARNING: using deprecated parameter 'activation'. Use 'actor_activations' and 'critic_activations' instead."
)
alg_cfg["actor_activations"] = [alg_cfg["activation"] for _ in range(len(alg_cfg["actor_hidden_dims"]))]
alg_cfg["actor_activations"] += ["linear"]
alg_cfg["critic_activations"] = [alg_cfg["activation"] for _ in range(len(alg_cfg["critic_hidden_dims"]))]
alg_cfg["critic_activations"] += ["linear"]
del alg_cfg["activation"]
for old, new in self.mappings:
if old not in alg_cfg:
continue
if new is None:
print(f"WARNING: ignoring deprecated parameter '{old}'.")
del alg_cfg[old]
continue
print(f"WARNING: using deprecated parameter '{old}'. Use '{new}' instead.")
alg_cfg[new] = alg_cfg[old]
del alg_cfg[old]
agent: Agent = alg_class(env, device=device, **train_cfg["algorithm"])
callbacks = []
evaluation_callbacks = []
evaluation_callbacks.append(lambda *args: Runner._log_progress(*args, prefix="eval"))
if log_dir and "save_interval" in self.cfg:
callbacks.append(make_first_cb(make_legacy_save_model_cb(log_dir)))
callbacks.append(make_interval_cb(make_legacy_save_model_cb(log_dir), self.cfg["save_interval"]))
if log_dir:
callbacks.append(Runner._log)
callbacks.append(make_final_cb(make_legacy_save_model_cb(log_dir)))
callbacks.append(make_final_cb(make_save_model_onnx_cb(log_dir)))
# callbacks.append(make_first_cb(lambda *_: store_code_state(log_dir, self._git_status_repos)))
else:
callbacks.append(Runner._log_progress)
super().__init__(
env,
agent,
learn_cb=callbacks,
evaluation_cb=evaluation_callbacks,
device=device,
num_steps_per_env=self.cfg["num_steps_per_env"],
)
self._iteration_time = 0.0
def learn(self, *args, num_learning_iterations=None, init_at_random_ep_len=None, **kwargs):
if num_learning_iterations is not None:
print("WARNING: using deprecated parameter 'num_learning_iterations'. Use 'iterations' instead.")
kwargs["iterations"] = num_learning_iterations
if init_at_random_ep_len is not None:
print("WARNING: ignoring deprecated parameter 'init_at_random_ep_len'.")
super().learn(*args, **kwargs)
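
For illustration, a legacy-style configuration might look as follows. Every key and value here is a placeholder, and the deprecated names are included only to show the renamings performed above; the runner construction is commented out because it needs a live `VecEnv` and a full algorithm configuration that is not reproduced here.

```python
train_cfg = {
    "runner": {
        "algorithm_class_name": "PPO",
        "num_steps_per_env": 24,
        "save_interval": 50,
    },
    "algorithm": {
        "clip_param": 0.2,      # deprecated name, remapped to "clip_ratio"
        "entropy_coef": 0.01,   # deprecated name, remapped to "entropy_coeff"
    },
    "policy": {
        "activation": "elu",    # deprecated, expanded into actor/critic activation lists
        "actor_hidden_dims": [256, 256],
        "critic_hidden_dims": [256, 256],
    },
}

# runner = LeggedGymRunner(env, train_cfg, log_dir="logs/legacy", device="cuda:0")
# runner.learn(num_learning_iterations=1500)  # deprecated kwarg, translated to iterations=1500
```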

View File

@ -1,304 +0,0 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import os
import statistics
import time
import torch
from collections import deque
from torch.utils.tensorboard import SummaryWriter as TensorboardSummaryWriter
import rsl_rl
from rsl_rl.algorithms import PPO
from rsl_rl.env import VecEnv
from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, EmpiricalNormalization
from rsl_rl.utils import store_code_state
class OnPolicyRunner:
"""On-policy runner for training and evaluation."""
def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
self.cfg = train_cfg
self.alg_cfg = train_cfg["algorithm"]
self.policy_cfg = train_cfg["policy"]
self.device = device
self.env = env
obs, extras = self.env.get_observations()
num_obs = obs.shape[1]
if "critic" in extras["observations"]:
num_critic_obs = extras["observations"]["critic"].shape[1]
else:
num_critic_obs = num_obs
actor_critic_class = eval(self.policy_cfg.pop("class_name")) # ActorCritic
actor_critic: ActorCritic | ActorCriticRecurrent = actor_critic_class(
num_obs, num_critic_obs, self.env.num_actions, **self.policy_cfg
).to(self.device)
alg_class = eval(self.alg_cfg.pop("class_name")) # PPO
self.alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg)
self.num_steps_per_env = self.cfg["num_steps_per_env"]
self.save_interval = self.cfg["save_interval"]
self.empirical_normalization = self.cfg["empirical_normalization"]
if self.empirical_normalization:
self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device)
self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device)
else:
self.obs_normalizer = torch.nn.Identity() # no normalization
self.critic_obs_normalizer = torch.nn.Identity() # no normalization
# init storage and model
self.alg.init_storage(
self.env.num_envs,
self.num_steps_per_env,
[num_obs],
[num_critic_obs],
[self.env.num_actions],
)
# Log
self.log_dir = log_dir
self.writer = None
self.tot_timesteps = 0
self.tot_time = 0
self.current_learning_iteration = 0
self.git_status_repos = [rsl_rl.__file__]
def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False):
# initialize writer
if self.log_dir is not None and self.writer is None:
# Launch either Tensorboard or Neptune & Tensorboard summary writer(s), default: Tensorboard.
self.logger_type = self.cfg.get("logger", "tensorboard")
self.logger_type = self.logger_type.lower()
if self.logger_type == "neptune":
from rsl_rl.utils.neptune_utils import NeptuneSummaryWriter
self.writer = NeptuneSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg)
self.writer.log_config(self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg)
elif self.logger_type == "wandb":
from rsl_rl.utils.wandb_utils import WandbSummaryWriter
self.writer = WandbSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg)
self.writer.log_config(self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg)
elif self.logger_type == "tensorboard":
self.writer = TensorboardSummaryWriter(log_dir=self.log_dir, flush_secs=10)
else:
raise AssertionError("logger type not found")
if init_at_random_ep_len:
self.env.episode_length_buf = torch.randint_like(
self.env.episode_length_buf, high=int(self.env.max_episode_length)
)
obs, extras = self.env.get_observations()
critic_obs = extras["observations"].get("critic", obs)
obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
self.train_mode() # switch to train mode (for dropout for example)
ep_infos = []
rewbuffer = deque(maxlen=100)
lenbuffer = deque(maxlen=100)
cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
start_iter = self.current_learning_iteration
tot_iter = start_iter + num_learning_iterations
for it in range(start_iter, tot_iter):
start = time.time()
# Rollout
with torch.inference_mode():
for i in range(self.num_steps_per_env):
actions = self.alg.act(obs, critic_obs)
obs, rewards, dones, infos = self.env.step(actions)
obs = self.obs_normalizer(obs)
if "critic" in infos["observations"]:
critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
else:
critic_obs = obs
obs, critic_obs, rewards, dones = (
obs.to(self.device),
critic_obs.to(self.device),
rewards.to(self.device),
dones.to(self.device),
)
self.alg.process_env_step(rewards, dones, infos)
if self.log_dir is not None:
# Book keeping
# note: we changed logging to use "log" instead of "episode" to avoid confusion with
# different types of logging data (rewards, curriculum, etc.)
if "episode" in infos:
ep_infos.append(infos["episode"])
elif "log" in infos:
ep_infos.append(infos["log"])
cur_reward_sum += rewards
cur_episode_length += 1
new_ids = (dones > 0).nonzero(as_tuple=False)
rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
cur_reward_sum[new_ids] = 0
cur_episode_length[new_ids] = 0
stop = time.time()
collection_time = stop - start
# Learning step
start = stop
self.alg.compute_returns(critic_obs)
mean_value_loss, mean_surrogate_loss = self.alg.update()
stop = time.time()
learn_time = stop - start
self.current_learning_iteration = it
if self.log_dir is not None:
self.log(locals())
if it % self.save_interval == 0:
self.save(os.path.join(self.log_dir, f"model_{it}.pt"))
ep_infos.clear()
if it == start_iter:
# obtain all the diff files
git_file_paths = store_code_state(self.log_dir, self.git_status_repos)
# if possible store them to wandb
if self.logger_type in ["wandb", "neptune"] and git_file_paths:
for path in git_file_paths:
self.writer.save_file(path)
self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt"))
def log(self, locs: dict, width: int = 80, pad: int = 35):
self.tot_timesteps += self.num_steps_per_env * self.env.num_envs
self.tot_time += locs["collection_time"] + locs["learn_time"]
iteration_time = locs["collection_time"] + locs["learn_time"]
ep_string = ""
if locs["ep_infos"]:
for key in locs["ep_infos"][0]:
infotensor = torch.tensor([], device=self.device)
for ep_info in locs["ep_infos"]:
# handle scalar and zero dimensional tensor infos
if key not in ep_info:
continue
if not isinstance(ep_info[key], torch.Tensor):
ep_info[key] = torch.Tensor([ep_info[key]])
if len(ep_info[key].shape) == 0:
ep_info[key] = ep_info[key].unsqueeze(0)
infotensor = torch.cat((infotensor, ep_info[key].to(self.device)))
value = torch.mean(infotensor)
# log to logger and terminal
if "/" in key:
self.writer.add_scalar(key, value, locs["it"])
ep_string += f"""{f'{key}:':>{pad}} {value:.4f}\n"""
else:
self.writer.add_scalar("Episode/" + key, value, locs["it"])
ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n"""
mean_std = self.alg.actor_critic.std.mean()
fps = int(self.num_steps_per_env * self.env.num_envs / (locs["collection_time"] + locs["learn_time"]))
self.writer.add_scalar("Loss/value_function", locs["mean_value_loss"], locs["it"])
self.writer.add_scalar("Loss/surrogate", locs["mean_surrogate_loss"], locs["it"])
self.writer.add_scalar("Loss/learning_rate", self.alg.learning_rate, locs["it"])
self.writer.add_scalar("Policy/mean_noise_std", mean_std.item(), locs["it"])
self.writer.add_scalar("Perf/total_fps", fps, locs["it"])
self.writer.add_scalar("Perf/collection time", locs["collection_time"], locs["it"])
self.writer.add_scalar("Perf/learning_time", locs["learn_time"], locs["it"])
if len(locs["rewbuffer"]) > 0:
self.writer.add_scalar("Train/mean_reward", statistics.mean(locs["rewbuffer"]), locs["it"])
self.writer.add_scalar("Train/mean_episode_length", statistics.mean(locs["lenbuffer"]), locs["it"])
if self.logger_type != "wandb": # wandb does not support non-integer x-axis logging
self.writer.add_scalar("Train/mean_reward/time", statistics.mean(locs["rewbuffer"]), self.tot_time)
self.writer.add_scalar(
"Train/mean_episode_length/time", statistics.mean(locs["lenbuffer"]), self.tot_time
)
str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m "
if len(locs["rewbuffer"]) > 0:
log_string = (
f"""{'#' * width}\n"""
f"""{str.center(width, ' ')}\n\n"""
f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n"""
)
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""")
else:
log_string = (
f"""{'#' * width}\n"""
f"""{str.center(width, ' ')}\n\n"""
f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
)
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""")
log_string += ep_string
log_string += (
f"""{'-' * width}\n"""
f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n"""
f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n"""
f"""{'Total time:':>{pad}} {self.tot_time:.2f}s\n"""
f"""{'ETA:':>{pad}} {self.tot_time / (locs['it'] + 1) * (
locs['num_learning_iterations'] - locs['it']):.1f}s\n"""
)
print(log_string)
def save(self, path, infos=None):
saved_dict = {
"model_state_dict": self.alg.actor_critic.state_dict(),
"optimizer_state_dict": self.alg.optimizer.state_dict(),
"iter": self.current_learning_iteration,
"infos": infos,
}
if self.empirical_normalization:
saved_dict["obs_norm_state_dict"] = self.obs_normalizer.state_dict()
saved_dict["critic_obs_norm_state_dict"] = self.critic_obs_normalizer.state_dict()
torch.save(saved_dict, path)
# Upload model to external logging service
if self.logger_type in ["neptune", "wandb"]:
self.writer.save_model(path, self.current_learning_iteration)
def load(self, path, load_optimizer=True):
loaded_dict = torch.load(path)
self.alg.actor_critic.load_state_dict(loaded_dict["model_state_dict"])
if self.empirical_normalization:
self.obs_normalizer.load_state_dict(loaded_dict["obs_norm_state_dict"])
self.critic_obs_normalizer.load_state_dict(loaded_dict["critic_obs_norm_state_dict"])
if load_optimizer:
self.alg.optimizer.load_state_dict(loaded_dict["optimizer_state_dict"])
self.current_learning_iteration = loaded_dict["iter"]
return loaded_dict["infos"]
def get_inference_policy(self, device=None):
self.eval_mode() # switch to evaluation mode (dropout for example)
if device is not None:
self.alg.actor_critic.to(device)
policy = self.alg.actor_critic.act_inference
if self.cfg["empirical_normalization"]:
if device is not None:
self.obs_normalizer.to(device)
policy = lambda x: self.alg.actor_critic.act_inference(self.obs_normalizer(x)) # noqa: E731
return policy
def train_mode(self):
self.alg.actor_critic.train()
if self.empirical_normalization:
self.obs_normalizer.train()
self.critic_obs_normalizer.train()
def eval_mode(self):
self.alg.actor_critic.eval()
if self.empirical_normalization:
self.obs_normalizer.eval()
self.critic_obs_normalizer.eval()
def add_git_repo_to_log(self, repo_file_path):
self.git_status_repos.append(repo_file_path)

498
rsl_rl/runners/runner.py Normal file
View File

@ -0,0 +1,498 @@
from __future__ import annotations
import copy
from datetime import timedelta
import numpy as np
import os
import time
import torch
from typing import Any, Callable, Dict, List, Tuple, TypedDict, Union
import rsl_rl
from rsl_rl.storage.storage import Dataset
from rsl_rl.algorithms import Agent
from rsl_rl.env import VecEnv
class EpisodeStatistics(TypedDict):
"""The statistics of an episode."""
# Time it took to collect samples for the current iteration.
collection_time: Union[int, None]
# The counter of the current iteration.
current_iteration: int
# The number of the final iteration of the current run.
final_iteration: int
# The number of the first iteration of the current run.
first_iteration: int
# Environment information about the current iteration.
info: list
# The lengths of the episodes.
lengths: Union[List[int], None]
# The loss of the current iteration.
loss: Union[dict, None]
# The returns of the episodes.
returns: Union[List[float], None]
# The total time it took to run the current iteration.
total_time: Union[int, None]
# The time it took to update the agent.
update_time: Union[int, None]
Callback = Callable[[EpisodeStatistics], None]
class Runner:
"""The runner class for running an agent in an environment.
This class is responsible for running an agent in an environment. It is responsible for collecting data from the
environment, updating the agent, and evaluating the agent. It also provides a number of callbacks that can be used
to log and visualize the training progress.
"""
_dataset: Dataset
_episode_statistics: EpisodeStatistics
_num_steps_per_env: int
def __init__(
self,
environment: VecEnv,
agent: Agent,
device: str = "cpu",
evaluation_cb: List[Callback] = None,
learn_cb: List[Callback] = None,
observation_history_length: int = 1,
**kwargs,
) -> None:
"""
Args:
environment (rsl_rl.env.VecEnv): The environment to run the agent in.
agent (rsl_rl.algorithms.agent): The RL agent to run.
device (str): The device to run on.
evaluation_cb (List[Callable[[dict], None]], optional): A list of callbacks that are called after each round
of evaluation.
learn_cb (List[Callable[[dict], None]], optional): A list of callbacks that are called after each round of
learning.
observation_history_length: The number of observations to concatenate into a single observation.
"""
self.env = environment
self.agent = agent
self.device = device
self._obs_hist_len = observation_history_length
self._learn_cb = learn_cb if learn_cb else []
self._eval_cb = evaluation_cb if evaluation_cb else []
self._set_kwarg(kwargs, "num_steps_per_env", default=1)
self._current_learning_iteration = 0
self._git_status_repos = [rsl_rl.__file__]
self.to(self.device)
self._stored_dataset = [] # For computing observation history over multiple steps.
def add_git_repo_to_log(self, repo_file_path):
self._git_status_repos.append(repo_file_path)
def eval_mode(self):
"""Sets the agent to evaluation mode."""
self.agent.eval_mode()
def evaluate(self, steps: int, return_epochs: int = 100) -> float:
"""Evaluates the agent for a number of steps.
Args:
steps (int): The number of steps to evaluate the agent for.
return_epochs (int): The number of epochs over which to aggregate the return. Defaults to 100.
Returns:
The mean return of the agent.
"""
cumulative_rewards = []
current_cumulative_rewards = torch.zeros(self.env.num_envs, dtype=torch.float)
current_episode_lengths = torch.zeros(self.env.num_envs, dtype=torch.int)
episode_lengths = []
self.eval_mode()
policy = self.get_inference_policy()
obs, env_info = self.env.get_observations()
with torch.inference_mode():
for step in range(steps):
actions = policy(obs.clone(), copy.deepcopy(env_info))
obs, rewards, dones, env_info, episode_statistics = self.evaluate_step(obs, env_info, actions)
dones_idx = dones.nonzero().cpu()
current_cumulative_rewards += rewards.clone().cpu()
current_episode_lengths += 1
cumulative_rewards.extend(current_cumulative_rewards[dones_idx].squeeze(1).cpu().tolist())
episode_lengths.extend(current_episode_lengths[dones_idx].squeeze(1).cpu().tolist())
current_cumulative_rewards[dones_idx] = 0.0
current_episode_lengths[dones_idx] = 0
episode_statistics["current_iteration"] = step
episode_statistics["final_iteration"] = steps
episode_statistics["lengths"] = episode_lengths[-return_epochs:]
episode_statistics["returns"] = cumulative_rewards[-return_epochs:]
for cb in self._eval_cb:
cb(self, episode_statistics)
cumulative_rewards.extend(current_cumulative_rewards.cpu().tolist())
mean_return = np.mean(cumulative_rewards)
return mean_return
def evaluate_step(
self, observations=None, environment_info=None, actions=None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict, Dict]:
"""Evaluates the agent for a single step.
Args:
observations (torch.Tensor): The observations to evaluate the agent for.
environment_info (Dict[str, Any]): The environment information for the observations.
actions (torch.Tensor): The actions to evaluate the agent for.
Returns:
A tuple containing the observations, rewards, dones, environment information, and episode statistics after
the evaluation step.
"""
episode_statistics = {
"current_actions": None,
"current_dones": None,
"current_iteration": 0,
"current_observations": None,
"current_rewards": None,
"final_iteration": 0,
"first_iteration": 0,
"info": [],
"lengths": [],
"returns": [],
"timeout": None,
"total_time": None,
}
self.eval_mode()
with torch.inference_mode():
obs, env_info = self.env.get_observations() if observations is None else (observations, environment_info)
with torch.inference_mode():
start = time.time()
actions = self.get_inference_policy()(obs.clone(), copy.deepcopy(env_info)) if actions is None else actions
obs, rewards, dones, env_info = self.env.step(actions.clone())
self.agent.register_terminations(dones.nonzero().reshape(-1))
end = time.time()
if "episode" in env_info:
episode_statistics["info"].append(env_info["episode"])
episode_statistics["current_actions"] = actions
episode_statistics["current_dones"] = dones
episode_statistics["current_observations"] = obs
episode_statistics["current_rewards"] = rewards
episode_statistics["total_time"] = end - start
return obs, rewards, dones, env_info, episode_statistics
def get_inference_policy(self, device=None):
self.eval_mode()
return self.agent.get_inference_policy(device)
def learn(
self, iterations: Union[int, None] = None, timeout: Union[int, None] = None, return_epochs: int = 100
) -> None:
"""Runs a number of learning iterations.
Args:
iterations (int): The number of iterations to run.
timeout (int): Optional number of seconds after which to terminate training. Defaults to None.
return_epochs (int): The number of epochs over which to aggregate the return. Defaults to 100.
"""
assert iterations is not None or timeout is not None
self._episode_statistics = {
"collection_time": None,
"current_actions": None,
"current_iteration": self._current_learning_iteration,
"current_observations": None,
"final_iteration": self._current_learning_iteration + iterations if iterations is not None else None,
"first_iteration": self._current_learning_iteration,
"info": [],
"lengths": [],
"loss": {},
"returns": [],
"storage_initialized": False,
"timeout": timeout,
"total_time": None,
"training_time": 0,
"update_time": None,
}
self._current_episode_lengths = torch.zeros(self.env.num_envs, dtype=torch.float)
self._current_cumulative_rewards = torch.zeros(self.env.num_envs, dtype=torch.float)
self.train_mode()
self._obs, self._env_info = self.env.get_observations()
while True:
if self._learning_should_terminate():
break
# Collect data
start = time.time()
with torch.inference_mode():
self._dataset = []
for _ in range(self._num_steps_per_env):
self._collect()
self._episode_statistics["lengths"] = self._episode_statistics["lengths"][-return_epochs:]
self._episode_statistics["returns"] = self._episode_statistics["returns"][-return_epochs:]
self._episode_statistics["collection_time"] = time.time() - start
# Update agent
start = time.time()
self._update()
self._episode_statistics["update_time"] = time.time() - start
# Housekeeping
self._episode_statistics["total_time"] = (
self._episode_statistics["collection_time"] + self._episode_statistics["update_time"]
)
self._episode_statistics["training_time"] += self._episode_statistics["total_time"]
if self.agent.initialized:
self._episode_statistics["current_iteration"] += 1
terminate = False
for cb in self._learn_cb:
terminate = (cb(self, self._episode_statistics) is False) or terminate
if terminate:
break
self._episode_statistics["info"].clear()
self._current_learning_iteration = self._episode_statistics["current_iteration"]
def _collect(self) -> None:
"""Runs a single step in the environment to collect a transition and stores it in the dataset.
This method runs a single step in the environment to collect a transition and stores it in the dataset. If the
agent is not initialized, random actions are drawn from the action space. Furthermore, the method gathers
statistics about the episode and stores them in the episode statistics dictionary of the runner.
"""
if self.agent.initialized:
actions, data = self.agent.draw_actions(self._obs, self._env_info)
else:
actions, data = self.agent.draw_random_actions(self._obs, self._env_info)
next_obs, rewards, dones, next_env_info = self.env.step(actions)
self._dataset.append(
self.agent.process_transition(
self._obs.clone(),
copy.deepcopy(self._env_info),
actions.clone(),
rewards.clone(),
next_obs.clone(),
copy.deepcopy(next_env_info),
dones.clone(),
copy.deepcopy(data),
)
)
self.agent.register_terminations(dones.nonzero().reshape(-1))
self._obs, self._env_info = next_obs, next_env_info
# Gather statistics
if "episode" in self._env_info:
self._episode_statistics["info"].append(self._env_info["episode"])
dones_idx = (dones + next_env_info["time_outs"]).nonzero().cpu()
self._current_episode_lengths += 1
self._current_cumulative_rewards += rewards.cpu()
completed_lengths = self._current_episode_lengths[dones_idx][:, 0].cpu()
completed_returns = self._current_cumulative_rewards[dones_idx][:, 0].cpu()
self._episode_statistics["lengths"].extend(completed_lengths.tolist())
self._episode_statistics["returns"].extend(completed_returns.tolist())
self._current_episode_lengths[dones_idx] = 0.0
self._current_cumulative_rewards[dones_idx] = 0.0
self._episode_statistics["current_actions"] = actions
self._episode_statistics["current_observations"] = self._obs
self._episode_statistics["sample_count"] = self.agent.storage.sample_count
def _learning_should_terminate(self):
"""Checks whether the learning should terminate.
Termination is triggered if the number of iterations or the timeout is reached.
Returns:
Whether the learning should terminate.
"""
if (
self._episode_statistics["final_iteration"] is not None
and self._episode_statistics["current_iteration"] >= self._episode_statistics["final_iteration"]
):
return True
if (
self._episode_statistics["timeout"] is not None
and self._episode_statistics["training_time"] >= self._episode_statistics["timeout"]
):
return True
return False
def _update(self) -> None:
"""Updates the agent using the collected data."""
loss = self.agent.update(self._dataset)
self._dataset = []
if not self.agent.initialized:
return
self._episode_statistics["loss"] = loss
self._episode_statistics["storage_initialized"] = True
def load(self, path: str) -> Any:
"""Restores the agent and runner state from a file."""
content = torch.load(path, map_location=self.device)
assert "agent" in content
assert "data" in content
assert "iteration" in content
self.agent.load_state_dict(content["agent"])
self._current_learning_iteration = content["iteration"]
return content["data"]
def save(self, path: str, data: Any = None) -> None:
"""Saves the agent and runner state to a file."""
content = {
"agent": self.agent.state_dict(),
"data": data,
"iteration": self._current_learning_iteration,
}
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save(content, path)
def export_onnx(self, path: str) -> None:
"""Exports the agent's policy network to ONNX format."""
model, args, kwargs = self.agent.export_onnx()
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.onnx.export(model, args, path, **kwargs)
def to(self, device) -> Runner:
"""Sets the device of the runner and its components."""
self.device = device
self.agent.to(device)
try:
self.env.to(device)
except AttributeError:
pass
return self
def train_mode(self):
"""Sets the agent to training mode."""
self.agent.train_mode()
def _set_kwarg(self, args, key, default=None, private=True):
setattr(self, f"_{key}" if private else key, args[key] if key in args else default)
def _log_progress(self, stat, clear_line=True, prefix=""):
"""Logs the progress of the runner."""
if not hasattr(self, "_iteration_times"):
self._iteration_times = []
self._iteration_times = (self._iteration_times + [stat["total_time"]])[-100:]
average_total_time = np.mean(self._iteration_times)
if stat["final_iteration"] is not None:
first_iteration = stat["first_iteration"]
final_iteration = stat["final_iteration"]
current_iteration = stat["current_iteration"]
final_run_iteration = final_iteration - first_iteration
remaining_iterations = final_iteration - current_iteration
remaining_iteration_time = remaining_iterations * average_total_time
iteration_completion_percentage = 100 * (current_iteration - first_iteration) / final_run_iteration
else:
remaining_iteration_time = np.inf
iteration_completion_percentage = 0
if stat["timeout"] is not None:
training_time = stat["training_time"]
timeout = stat["timeout"]
remaining_timeout_time = stat["timeout"] - stat["training_time"]
timeout_completion_percentage = 100 * stat["training_time"] / stat["timeout"]
else:
remaining_timeout_time = np.inf
timeout_completion_percentage = 0
if remaining_iteration_time > remaining_timeout_time:
completion_percentage = timeout_completion_percentage
remaining_time = remaining_timeout_time
step_string = f"({int(training_time)}s / {timeout}s)"
else:
completion_percentage = iteration_completion_percentage
remaining_time = remaining_iteration_time
step_string = f"({current_iteration} / {final_iteration})"
prefix = f"[{prefix}] " if prefix else ""
progress = "".join(["#" if i <= int(completion_percentage) else "_" for i in range(10, 101, 5)])
remaining_time_string = str(timedelta(seconds=int(np.ceil(remaining_time))))
print(
f"{prefix}{progress} {step_string} [{completion_percentage:.1f}%, {1/average_total_time:.2f}it/s, {remaining_time_string} ETA]",
end="\r" if clear_line else "\n",
)
def _log(self, stat, prefix=""):
"""Logs the progress and statistics of the runner."""
current_iteration = stat["current_iteration"]
collection_time = stat["collection_time"]
update_time = stat["update_time"]
total_time = stat["total_time"]
collection_percentage = 100 * collection_time / total_time
update_percentage = 100 * update_time / total_time
if prefix == "":
prefix = "learn" if stat["storage_initialized"] else "init"
self._log_progress(stat, clear_line=False, prefix=prefix)
print(
f"iteration time:\t{total_time:.4f}s (collection: {collection_time:.2f}s [{collection_percentage:.1f}%], update: {update_time:.2f}s [{update_percentage:.1f}%])"
)
mean_reward = sum(stat["returns"]) / len(stat["returns"]) if len(stat["returns"]) > 0 else 0.0
mean_steps = sum(stat["lengths"]) / len(stat["lengths"]) if len(stat["lengths"]) > 0 else 0.0
total_steps = current_iteration * self.env.num_envs * self._num_steps_per_env
sample_count = stat["sample_count"]
print(f"avg. reward:\t{mean_reward:.4f}")
print(f"avg. steps:\t{mean_steps:.4f}")
print(f"stored samples:\t{sample_count}")
print(f"total steps:\t{total_steps}")
for key, value in stat["loss"].items():
print(f"{key} loss:\t{value:.4f}")
for key, value in self.agent._bm_report().items():
mean, count = value
print(f"BM {key}:\t{mean/1000000.0:.4f}ms ({count} calls, total {mean*count/1000000.0:.4f}ms)")
self.agent._bm_flush()
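
Putting the pieces together, a minimal training script might look like the sketch below. The `PPO(env, device=...)` construction is an assumption (agent constructors are not shown in this diff); the `Runner` calls match the API defined above.

```python
from rsl_rl.algorithms import PPO
from rsl_rl.runners import Runner

# env = ...  # any rsl_rl.env.VecEnv instance

agent = PPO(env, device="cuda:0")            # assumed constructor signature
runner = Runner(
    env,
    agent,
    device="cuda:0",
    num_steps_per_env=24,
    learn_cb=[Runner._log_progress],         # print a progress bar every iteration
)

runner.learn(iterations=1000)                # or: runner.learn(timeout=3600)
mean_return = runner.evaluate(steps=200)
runner.save("logs/example/model_final.pt")
```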

View File

@ -4,5 +4,6 @@
"""Implementation of transitions storage for RL-agent."""
from .rollout_storage import RolloutStorage
from .replay_storage import ReplayStorage
__all__ = ["RolloutStorage"]
__all__ = ["RolloutStorage", "ReplayStorage"]

View File

@ -0,0 +1,147 @@
import torch
from typing import Callable, Dict, Generator, Tuple, Optional
from rsl_rl.storage.storage import Dataset, Storage, Transition
class ReplayStorage(Storage):
def __init__(self, environment_count: int, max_size: int, device: str = "cpu", initial_size: int = 0) -> None:
self._env_count = environment_count
self.initial_size = initial_size // environment_count
self.max_size = max_size
self.device = device
self._register_serializable("max_size", "initial_size")
self._idx = 0
self._full = False
self._initialized = initial_size == 0
self._data = {}
self._processors: Dict[Tuple[Callable, Callable]] = {}
@property
def max_size(self):
return self._size * self._env_count
@max_size.setter
def max_size(self, value):
self._size = value // self._env_count
assert self.initial_size <= self._size
def _add_item(self, name: str, value: torch.Tensor) -> None:
"""Adds a transition item to the storage.
Args:
name (str): The name of the item.
value (torch.Tensor): The value of the item.
"""
value = self._process(name, value.clone().to(self.device))
if name not in self._data:
if self._full or self._idx != 0:
raise ValueError(f'Tried to store invalid transition data for "{name}".')
self._data[name] = torch.empty(
self._size * self._env_count, *value.shape[1:], device=self.device, dtype=value.dtype
)
start_idx = self._idx * self._env_count
end_idx = (self._idx + 1) * self._env_count
self._data[name][start_idx:end_idx] = value
def _process(self, name: str, value: torch.Tensor) -> torch.Tensor:
if name not in self._processors:
return value
for process, _ in self._processors[name]:
if process is None:
continue
value = process(value)
return value
def _process_undo(self, name: str, value: torch.Tensor) -> torch.Tensor:
if name not in self._processors:
return value
for _, undo in reversed(self._processors[name]):
if undo is None:
continue
value = undo(value)
return value
def append(self, dataset: Dataset) -> None:
"""Appends a dataset of transitions to the storage.
Args:
dataset (Dataset): The dataset of transitions.
"""
for transition in dataset:
for name, value in transition.items():
self._add_item(name, value)
self._idx += 1
if self._idx >= self.initial_size:
self._initialized = True
if self._idx == self._size:
self._full = True
self._idx = 0
def batch_generator(self, batch_size: int, batch_count: int) -> Generator[Transition, None, None]:
"""Returns a generator that yields batches of transitions.
Args:
batch_size (int): The size of the batches.
batch_count (int): The number of batches to yield.
Returns:
A generator that yields batches of transitions.
"""
assert self._full or self._idx > 0
if not self._initialized:
return
max_idx = self._env_count * (self._size if self._full else self._idx)
for _ in range(batch_count):
batch_idx = torch.randint(high=max_idx, size=(batch_size,))
batch = {}
for key, value in self._data.items():
batch[key] = self._process_undo(key, value[batch_idx].clone())
yield batch
def register_processor(self, key: str, process: Callable, undo: Optional[Callable] = None) -> None:
"""Registers a processor for a transition item.
The process function is applied before the item is stored in the storage; the undo function is applied when the
item is retrieved. Undo functions are applied in reverse registration order so that chained processors are
inverted correctly.
Args:
key (str): The name of the transition item.
process (Callable): The function to process the item.
undo (Optional[Callable], optional): The function to undo the processing. Defaults to None.
"""
if key not in self._processors:
self._processors[key] = []
self._processors[key].append((process, undo))
@property
def initialized(self) -> bool:
return self._initialized
@property
def sample_count(self) -> int:
"""Returns the number of individual transitions stored in the storage."""
transition_count = self._size * self._env_count if self._full else self._idx * self._env_count
return transition_count
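
A usage sketch of the replay storage above (the observation-scaling processor is purely illustrative of the process/undo mechanism; all shapes and sizes are made up):

```python
import torch

from rsl_rl.storage import ReplayStorage

env_count = 4
storage = ReplayStorage(environment_count=env_count, max_size=1024, device="cpu")

# Store observations scaled down; the undo function restores the original scale when sampling.
storage.register_processor("observations", lambda x: x / 10.0, undo=lambda x: x * 10.0)

# A dataset is a list of transitions; each entry holds one tensor row per environment.
dataset = [
    {"observations": torch.randn(env_count, 3), "rewards": torch.randn(env_count)}
    for _ in range(8)
]
storage.append(dataset)

for batch in storage.batch_generator(batch_size=16, batch_count=2):
    print(batch["observations"].shape, batch["rewards"].shape)  # torch.Size([16, 3]) torch.Size([16])
```

With a non-zero `initial_size`, the generator yields nothing until at least that many samples have been stored.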

View File

@ -1,228 +1,71 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import torch
from typing import Generator
from rsl_rl.utils import split_and_pad_trajectories
from rsl_rl.storage.replay_storage import ReplayStorage
from rsl_rl.storage.storage import Dataset, Transition
class RolloutStorage:
class Transition:
def __init__(self):
self.observations = None
self.critic_observations = None
self.actions = None
self.rewards = None
self.dones = None
self.values = None
self.actions_log_prob = None
self.action_mean = None
self.action_sigma = None
self.hidden_states = None
class RolloutStorage(ReplayStorage):
"""Implementation of rollout storage for RL-agent."""
def clear(self):
self.__init__()
def __init__(self, environment_count: int, device: str = "cpu"):
"""
Args:
environment_count (int): Number of environments.
device (str, optional): Device to use. Defaults to "cpu".
"""
super().__init__(environment_count, environment_count, device=device, initial_size=0)
def __init__(self, num_envs, num_transitions_per_env, obs_shape, privileged_obs_shape, actions_shape, device="cpu"):
self.device = device
self._size_initialized = False
self.obs_shape = obs_shape
self.privileged_obs_shape = privileged_obs_shape
self.actions_shape = actions_shape
def append(self, dataset: Dataset) -> None:
"""Appends a dataset to the rollout storage.
# Core
self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device)
if privileged_obs_shape[0] is not None:
self.privileged_observations = torch.zeros(
num_transitions_per_env, num_envs, *privileged_obs_shape, device=self.device
)
Args:
dataset (Dataset): Dataset to append.
Raises:
AssertionError: If the dataset is not of the correct size.
"""
assert self._idx == 0
if not self._size_initialized:
self.max_size = len(dataset) * self._env_count
assert len(dataset) == self._size
super().append(dataset)
def batch_generator(self, batch_count: int, trajectories: bool = False) -> Generator[Transition, None, None]:
"""Yields batches of transitions or trajectories.
Args:
batch_count (int): Number of batches to yield.
trajectories (bool, optional): Whether to yield batches of trajectories. Defaults to False.
Raises:
AssertionError: If the rollout storage is not full.
Returns:
Generator yielding batches of transitions of shape (batch_size, *shape). If trajectories is True, yields
batches of trajectories of shape (env_count, steps_per_env, *shape).
"""
assert self._full and self._initialized, "Rollout storage must be full and initialized."
total_size = self._env_count if trajectories else self._size * self._env_count
batch_size = total_size // batch_count
indices = torch.randperm(total_size)
assert batch_size > 0, "Batch count is too large."
if trajectories:
# Reshape to (env_count, steps_per_env, *shape)
data = {k: v.reshape(-1, self._env_count, *v.shape[1:]).transpose(0, 1) for k, v in self._data.items()}
else:
self.privileged_observations = None
self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte()
data = self._data
# For PPO
self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
for i in range(batch_count):
batch_idx = indices[i * batch_size : (i + 1) * batch_size].detach().to(self.device)
self.num_transitions_per_env = num_transitions_per_env
self.num_envs = num_envs
batch = {}
for key, value in data.items():
batch[key] = self._process_undo(key, value[batch_idx].clone())
# rnn
self.saved_hidden_states_a = None
self.saved_hidden_states_c = None
self.step = 0
def add_transitions(self, transition: Transition):
if self.step >= self.num_transitions_per_env:
raise AssertionError("Rollout buffer overflow")
self.observations[self.step].copy_(transition.observations)
if self.privileged_observations is not None:
self.privileged_observations[self.step].copy_(transition.critic_observations)
self.actions[self.step].copy_(transition.actions)
self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
self.dones[self.step].copy_(transition.dones.view(-1, 1))
self.values[self.step].copy_(transition.values)
self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
self.mu[self.step].copy_(transition.action_mean)
self.sigma[self.step].copy_(transition.action_sigma)
self._save_hidden_states(transition.hidden_states)
self.step += 1
def _save_hidden_states(self, hidden_states):
if hidden_states is None or hidden_states == (None, None):
return
# make a tuple out of GRU hidden states to match the LSTM format
hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)
# initialize if needed
if self.saved_hidden_states_a is None:
self.saved_hidden_states_a = [
torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a))
]
self.saved_hidden_states_c = [
torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c))
]
# copy the states
for i in range(len(hid_a)):
self.saved_hidden_states_a[i][self.step].copy_(hid_a[i])
self.saved_hidden_states_c[i][self.step].copy_(hid_c[i])
def clear(self):
self.step = 0
def compute_returns(self, last_values, gamma, lam):
advantage = 0
for step in reversed(range(self.num_transitions_per_env)):
if step == self.num_transitions_per_env - 1:
next_values = last_values
else:
next_values = self.values[step + 1]
next_is_not_terminal = 1.0 - self.dones[step].float()
delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
advantage = delta + next_is_not_terminal * gamma * lam * advantage
self.returns[step] = advantage + self.values[step]
# Compute and normalize the advantages
self.advantages = self.returns - self.values
self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)
def get_statistics(self):
done = self.dones
done[-1] = 1
flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
done_indices = torch.cat(
(flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0])
)
trajectory_lengths = done_indices[1:] - done_indices[:-1]
return trajectory_lengths.float().mean(), self.rewards.mean()
def mini_batch_generator(self, num_mini_batches, num_epochs=8):
batch_size = self.num_envs * self.num_transitions_per_env
mini_batch_size = batch_size // num_mini_batches
indices = torch.randperm(num_mini_batches * mini_batch_size, requires_grad=False, device=self.device)
observations = self.observations.flatten(0, 1)
if self.privileged_observations is not None:
critic_observations = self.privileged_observations.flatten(0, 1)
else:
critic_observations = observations
actions = self.actions.flatten(0, 1)
values = self.values.flatten(0, 1)
returns = self.returns.flatten(0, 1)
old_actions_log_prob = self.actions_log_prob.flatten(0, 1)
advantages = self.advantages.flatten(0, 1)
old_mu = self.mu.flatten(0, 1)
old_sigma = self.sigma.flatten(0, 1)
for epoch in range(num_epochs):
for i in range(num_mini_batches):
start = i * mini_batch_size
end = (i + 1) * mini_batch_size
batch_idx = indices[start:end]
obs_batch = observations[batch_idx]
critic_observations_batch = critic_observations[batch_idx]
actions_batch = actions[batch_idx]
target_values_batch = values[batch_idx]
returns_batch = returns[batch_idx]
old_actions_log_prob_batch = old_actions_log_prob[batch_idx]
advantages_batch = advantages[batch_idx]
old_mu_batch = old_mu[batch_idx]
old_sigma_batch = old_sigma[batch_idx]
yield obs_batch, critic_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (
None,
None,
), None
# for RNNs only
def reccurent_mini_batch_generator(self, num_mini_batches, num_epochs=8):
padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories(self.observations, self.dones)
if self.privileged_observations is not None:
padded_critic_obs_trajectories, _ = split_and_pad_trajectories(self.privileged_observations, self.dones)
else:
padded_critic_obs_trajectories = padded_obs_trajectories
mini_batch_size = self.num_envs // num_mini_batches
for ep in range(num_epochs):
first_traj = 0
for i in range(num_mini_batches):
start = i * mini_batch_size
stop = (i + 1) * mini_batch_size
dones = self.dones.squeeze(-1)
last_was_done = torch.zeros_like(dones, dtype=torch.bool)
last_was_done[1:] = dones[:-1]
last_was_done[0] = True
trajectories_batch_size = torch.sum(last_was_done[:, start:stop])
last_traj = first_traj + trajectories_batch_size
masks_batch = trajectory_masks[:, first_traj:last_traj]
obs_batch = padded_obs_trajectories[:, first_traj:last_traj]
critic_obs_batch = padded_critic_obs_trajectories[:, first_traj:last_traj]
actions_batch = self.actions[:, start:stop]
old_mu_batch = self.mu[:, start:stop]
old_sigma_batch = self.sigma[:, start:stop]
returns_batch = self.returns[:, start:stop]
advantages_batch = self.advantages[:, start:stop]
values_batch = self.values[:, start:stop]
old_actions_log_prob_batch = self.actions_log_prob[:, start:stop]
# reshape to [num_envs, time, num layers, hidden dim] (original shape: [time, num_layers, num_envs, hidden_dim])
# then take only time steps after dones (flattens num envs and time dimensions),
# take a batch of trajectories and finally reshape back to [num_layers, batch, hidden_dim]
last_was_done = last_was_done.permute(1, 0)
hid_a_batch = [
saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj]
.transpose(1, 0)
.contiguous()
for saved_hidden_states in self.saved_hidden_states_a
]
hid_c_batch = [
saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj]
.transpose(1, 0)
.contiguous()
for saved_hidden_states in self.saved_hidden_states_c
]
# remove the tuple for GRU
hid_a_batch = hid_a_batch[0] if len(hid_a_batch) == 1 else hid_a_batch
hid_c_batch = hid_c_batch[0] if len(hid_c_batch) == 1 else hid_c_batch
yield obs_batch, critic_obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (
hid_a_batch,
hid_c_batch,
), masks_batch
first_traj = last_traj
yield batch
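
The interleaved diff above replaces the fixed-layout rollout buffer with a `ReplayStorage` subclass. A minimal sketch of how the new `RolloutStorage` is driven (shapes and sizes are made up):

```python
import torch

from rsl_rl.storage import RolloutStorage

env_count, steps = 4, 8
storage = RolloutStorage(environment_count=env_count, device="cpu")

# One on-policy rollout: `steps` transitions, each holding one tensor row per environment.
dataset = [
    {"observations": torch.randn(env_count, 3), "actions": torch.randn(env_count, 2)}
    for _ in range(steps)
]
storage.append(dataset)  # the first append sizes the storage to exactly steps * env_count samples

# Flat mini-batches of transitions ...
for batch in storage.batch_generator(batch_count=4):
    print(batch["observations"].shape)  # torch.Size([8, 3])

# ... or whole per-environment trajectories of shape (env_count, steps, *).
for batch in storage.batch_generator(batch_count=2, trajectories=True):
    print(batch["observations"].shape)  # torch.Size([2, 8, 3])
```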

47
rsl_rl/storage/storage.py Normal file
View File

@ -0,0 +1,47 @@
from abc import abstractmethod
import torch
from typing import Dict, Generator, List
from rsl_rl.utils.serializable import Serializable
# prev_obs, prev_obs_info, actions, rewards, next_obs, next_obs_info, dones, data
Transition = Dict[str, torch.Tensor]
Dataset = List[Transition]
class Storage(Serializable):
@abstractmethod
def append(self, dataset: Dataset) -> None:
"""Adds transitions to the storage.
Args:
dataset (Dataset): The transitions to add to the storage.
"""
pass
@abstractmethod
def batch_generator(self, batch_size: int, batch_count: int) -> Generator[Dict[str, torch.Tensor], None, None]:
"""Generates a batch of transitions.
Args:
batch_size (int): The size of each batch to generate.
batch_count (int): The number of batches to generate.
Returns:
A generator that yields transitions.
"""
pass
@property
def initialized(self) -> bool:
"""Returns whether the storage is initialized."""
return True
@abstractmethod
def sample_count(self) -> int:
"""Returns how many individual samples are stored in the storage.
Returns:
The number of stored samples.
"""
pass
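
To illustrate the contract that `ReplayStorage` and `RolloutStorage` implement, here is a deliberately naive in-memory storage. `ListStorage` is a hypothetical class written only for this note, not part of the package:

```python
import torch
from typing import Generator, List

from rsl_rl.storage.storage import Dataset, Storage, Transition


class ListStorage(Storage):
    """Toy storage that keeps individual samples in a Python list (illustration only)."""

    def __init__(self) -> None:
        self._samples: List[Transition] = []

    def append(self, dataset: Dataset) -> None:
        # Split each transition into one entry per environment so that sampling mixes environments.
        for transition in dataset:
            env_count = next(iter(transition.values())).shape[0]
            self._samples += [{k: v[e] for k, v in transition.items()} for e in range(env_count)]

    def batch_generator(self, batch_size: int, batch_count: int) -> Generator[Transition, None, None]:
        for _ in range(batch_count):
            idx = torch.randint(len(self._samples), (batch_size,)).tolist()
            yield {k: torch.stack([self._samples[i][k] for i in idx]) for k in self._samples[0]}

    @property
    def sample_count(self) -> int:
        return len(self._samples)
```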

View File

@ -1,6 +1,3 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
"""Helper functions."""
from .utils import split_and_pad_trajectories, store_code_state, unpad_trajectories
from .utils import split_and_pad_trajectories, unpad_trajectories, store_code_state

View File

@ -0,0 +1,111 @@
import numpy as np
import time
from typing import Callable, Dict
class Benchmark:
def __init__(self):
self.reset()
def __call__(self):
if self.running:
self.end()
else:
self.start()
def end(self):
now = time.process_time_ns()
assert self.running
difference = now - self._timer
self._timings.append(difference)
self._timer = None
def reset(self):
self._timer = None
self._timings = []
@property
def running(self):
return self._timer is not None
def start(self):
self._timer = time.process_time_ns()
@property
def timings(self):
return self._timings
class Benchmarkable:
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._benchmark = False
self._bm_data = dict()
self._bm_fusions = []
def _bm(self, name: str) -> None:
if not self._benchmark:
return
if name not in self._bm_data:
self._bm_data[name] = Benchmark()
self._bm_data[name]()
def _bm_flush(self) -> None:
# TODO: implement
for val in self._bm_data.values():
val.reset()
for fusion in self._bm_fusions:
fusion["target"]._bm_flush()
def _bm_fuse(self, target, prefix="") -> None:
assert isinstance(target, Benchmarkable)
assert target not in self._bm_fusions
target._bm_toggle(self._benchmark)
self._bm_fusions.append(dict(target=target, prefix=prefix))
def _bm_report(self) -> Dict:
data = dict()
if not self._benchmark:
return data
for key, val in self._bm_data.items():
data[key] = (np.mean(val.timings), len(val.timings))
for fusion in self._bm_fusions:
target = fusion["target"]
prefix = fusion["prefix"]
for key, val in target._bm_report().items():
data[f"{prefix}{key}"] = val
return data
def _bm_toggle(self, value: bool) -> None:
self._benchmark = value
for fusion in self._bm_fusions:
fusion["target"]._bm_toggle(value)
@staticmethod
def register(method: Callable, name=None) -> Callable:
benchmark_name = method.__name__ if name is None else name
def wrapper(self, *args, **kwargs):
assert isinstance(self, Benchmarkable)
self._bm(benchmark_name)
result = method(self, *args, **kwargs)
self._bm(benchmark_name)
return result
return wrapper
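
A short sketch of how the `Benchmarkable` mixin above is meant to be used. `Trainer` is a made-up class and the timings depend entirely on the machine:

```python
from rsl_rl.utils.benchmarkable import Benchmarkable


class Trainer(Benchmarkable):  # hypothetical class, used only to exercise the mixin
    @Benchmarkable.register
    def update(self):
        return sum(i * i for i in range(200_000))  # stand-in for CPU-bound work


trainer = Trainer()
trainer._bm_toggle(True)  # benchmarking is off by default

for _ in range(5):
    trainer.update()

for name, (mean_ns, calls) in trainer._bm_report().items():
    print(f"{name}: {mean_ns / 1e6:.3f}ms average over {calls} calls")
```

Nested components can be attached with `_bm_fuse`, which merges their reports under the given prefix.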

View File

@ -1,32 +1,26 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import os
from dataclasses import asdict
from torch.utils.tensorboard import SummaryWriter
from legged_gym.utils import class_to_dict
try:
import neptune
import neptune.new as neptune
except ModuleNotFoundError:
raise ModuleNotFoundError("neptune-client is required to log to Neptune.")
class NeptuneLogger:
def __init__(self, project, token):
self.run = neptune.init_run(project=project, api_token=token)
self.run = neptune.init(project=project, api_token=token)
def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
self.run["runner_cfg"] = runner_cfg
self.run["policy_cfg"] = policy_cfg
self.run["alg_cfg"] = alg_cfg
self.run["env_cfg"] = asdict(env_cfg)
self.run["env_cfg"] = class_to_dict(env_cfg)
class NeptuneSummaryWriter(SummaryWriter):
"""Summary writer for Neptune."""
def __init__(self, log_dir: str, flush_secs: int, cfg):
super().__init__(log_dir, flush_secs)
@ -86,7 +80,3 @@ class NeptuneSummaryWriter(SummaryWriter):
def save_model(self, model_path, iter):
self.neptune_logger.run["model/saved_model_" + str(iter)].upload(model_path)
def save_file(self, path, iter=None):
name = path.rsplit("/", 1)[-1].split(".")[0]
self.neptune_logger.run["git_diff/" + name].upload(path)

View File

@ -0,0 +1,69 @@
import torch
from typing import Tuple
def trajectories_to_transitions(trajectories: torch.Tensor, data: Tuple[torch.Tensor, int, bool]) -> torch.Tensor:
"""Unpacks a tensor of trajectories into a tensor of transitions.
Args:
trajectories (torch.Tensor): A tensor of trajectories.
data (Tuple[torch.Tensor, int, bool]): The (mask, batch_size, batch_first) tuple returned by
transitions_to_trajectories, used to undo the packing.
Returns:
A tensor of transitions of shape (batch_size, time, *).
"""
mask, batch_size, batch_first = data
if not batch_first:
trajectories, mask = trajectories.transpose(0, 1), mask.transpose(0, 1)
transitions = trajectories[mask == 1.0].reshape(batch_size, -1, *trajectories.shape[2:])
return transitions
def transitions_to_trajectories(
transitions: torch.Tensor, dones: torch.Tensor, batch_first: bool = False
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, int, bool]]:
"""Packs a tensor of transitions into a tensor of trajectories.
Example:
>>> transitions = torch.tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]])
>>> dones = torch.tensor([[0, 0, 1], [0, 1, 0]])
>>> transitions_to_trajectories(transitions, dones, batch_first=True)
(tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [0, 0]], [[11, 12], [0, 0], [0, 0]]]), (tensor([[1., 1., 1.], [1., 1., 0.], [1., 0., 0.]]), 2, True))
Args:
transitions (torch.Tensor): Tensor of transitions of shape (batch_size, time, *).
dones (torch.Tensor): Tensor of transition terminations of shape (batch_size, time).
batch_first (bool): Whether the first dimension of the output tensor should be the batch dimension. Defaults to
False.
Returns:
A torch.Tensor of trajectories of shape (time, trajectory_count, *) that is padded with zeros and data for
reverting the operation. If batch_first is True, the shape of the trajectories is (trajectory_count, time, *).
"""
batch_size = transitions.shape[0]
# Count the trajectory lengths by (1) padding dones with a 1 at the end to indicate the end of the trajectory,
# (2) stacking up the padded dones in a single column, and (3) counting the number of steps between each done by
# using the row index.
padded_dones = dones.clone()
padded_dones[:, -1] = 1
stacked_dones = torch.cat((padded_dones.new([-1]), padded_dones.reshape(-1, 1).nonzero()[:, 0]))
trajectory_lengths = stacked_dones[1:] - stacked_dones[:-1]
# Compute trajectories by splitting transitions according to previously computed trajectory lengths.
trajectory_list = torch.split(transitions.flatten(0, 1), trajectory_lengths.int().tolist())
trajectories = torch.nn.utils.rnn.pad_sequence(trajectory_list, batch_first=batch_first)
# The mask is generated by computing a 2d matrix of increasing counts in the 2nd dimension and comparing it to the
# trajectory lengths.
range = torch.arange(0, trajectory_lengths.max()).repeat(len(trajectory_lengths), 1)
range = range.cuda(dones.device) if dones.is_cuda else range
mask = (trajectory_lengths.unsqueeze(1) > range).float()
if not batch_first:
mask = mask.T
return trajectories, (mask, batch_size, batch_first)
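
A round-trip sketch of the two helpers above, using the tensors from the docstring example. The module path `rsl_rl.utils.recurrency` is an assumption, since this view does not show the file name:

```python
import torch

# Module path assumed from the package layout; the diff above does not name the file.
from rsl_rl.utils.recurrency import trajectories_to_transitions, transitions_to_trajectories

# Two environments, three steps; environment 0 terminates at the last step, environment 1 in the middle.
transitions = torch.tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]])
dones = torch.tensor([[0, 0, 1], [0, 1, 0]])

trajectories, data = transitions_to_trajectories(transitions, dones, batch_first=True)
print(trajectories.shape)  # torch.Size([3, 3, 2]): three zero-padded trajectories of up to three steps

# `data` carries (mask, batch_size, batch_first), which is all that is needed to undo the packing.
recovered = trajectories_to_transitions(trajectories, data)
print(torch.equal(recovered, transitions))  # True
```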

View File

@ -0,0 +1,43 @@
from typing import Any, Dict
class Serializable:
def load_state_dict(self, data: Dict[str, Any]) -> None:
"""Loads agent parameters from a dictionary."""
assert hasattr(self, "_serializable_objects")
for name in self._serializable_objects:
assert hasattr(self, name)
assert name in data, f'Object "{name}" was not found while loading "{self.__class__.__name__}".'
attr = getattr(self, name)
if hasattr(attr, "load_state_dict"):
print(f"Loading {name}")
attr.load_state_dict(data[name])
else:
print(f"Loading value {name}={data[name]}")
setattr(self, name, data[name])
def state_dict(self) -> Dict[str, Any]:
"""Returns a dictionary containing the agent parameters."""
assert hasattr(self, "_serializable_objects")
data = {}
for name in self._serializable_objects:
assert hasattr(self, name)
attr = getattr(self, name)
data[name] = attr.state_dict() if hasattr(attr, "state_dict") else attr
return data
def _register_serializable(self, *objects) -> None:
if not hasattr(self, "_serializable_objects"):
self._serializable_objects = []
for name in objects:
if name in self._serializable_objects:
continue
self._serializable_objects.append(name)
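
A minimal sketch of the `Serializable` mixin above. `Counter` is an invented example class, used only to show how registered plain values and nested `state_dict` objects are handled:

```python
import torch

from rsl_rl.utils.serializable import Serializable


class Counter(Serializable):
    def __init__(self):
        self.steps = 0
        self.net = torch.nn.Linear(2, 2)  # anything exposing state_dict()/load_state_dict() nests naturally
        self._register_serializable("steps", "net")


src, dst = Counter(), Counter()
src.steps = 42

dst.load_state_dict(src.state_dict())  # prints what it loads, then copies values and nested state dicts
print(dst.steps)  # 42
```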

View File

@ -1,16 +1,38 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
from datetime import datetime
import git
import os
import numpy as np
import pathlib
import random
import torch
def environment_dimensions(env):
obs = env.get_observations()
if isinstance(obs, tuple):
obs, env_info = obs
else:
env_info = {}
dims = {}
dims["observations"] = obs.shape[1]
if "observations" in env_info and "critic" in env_info["observations"]:
dims["actor_observations"] = dims["observations"]
dims["critic_observations"] = env_info["observations"]["critic"].shape[1]
else:
dims["actor_observations"] = dims["observations"]
dims["critic_observations"] = dims["observations"]
dims["actions"] = env.num_actions
return dims
def split_and_pad_trajectories(tensor, dones):
"""Splits trajectories at done indices. Then concatenates them and pads with zeros up to the length og the longest trajectory.
"""Splits trajectories at done indices. Then concatenates them and padds with zeros up to the length og the longest trajectory.
Returns masks corresponding to valid parts of the trajectories
Example:
Input: [ [a1, a2, a3, a4 | a5, a6],
@ -24,7 +46,7 @@ def split_and_pad_trajectories(tensor, dones):
[b6, 0, 0, 0] | [True, False, False, False],
] | ]
Assumes that the input has the following dimension order: [time, number of envs, additional dimensions]
Assumes that the inputy has the following dimension order: [time, number of envs, aditional dimensions]
"""
dones = dones.clone()
dones[-1] = 1
@ -37,12 +59,7 @@ def split_and_pad_trajectories(tensor, dones):
trajectory_lengths_list = trajectory_lengths.tolist()
# Extract the individual trajectories
trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1), trajectory_lengths_list)
# add at least one full length trajectory
trajectories = trajectories + (torch.zeros(tensor.shape[0], tensor.shape[-1], device=tensor.device),)
# pad the trajectories to the length of the longest trajectory
padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories)
# remove the added tensor
padded_trajectories = padded_trajectories[:, :-1]
trajectory_masks = trajectory_lengths > torch.arange(0, tensor.shape[0], device=tensor.device).unsqueeze(1)
return padded_trajectories, trajectory_masks
@ -58,29 +75,34 @@ def unpad_trajectories(trajectories, masks):
)
def store_code_state(logdir, repositories) -> list:
git_log_dir = os.path.join(logdir, "git")
os.makedirs(git_log_dir, exist_ok=True)
file_paths = []
def store_code_state(logdir, repositories):
for repository_file_path in repositories:
try:
repo = git.Repo(repository_file_path, search_parent_directories=True)
except Exception:
print(f"Could not find git repository in {repository_file_path}. Skipping.")
# skip if not a git repository
continue
# get the name of the repository
repo = git.Repo(repository_file_path, search_parent_directories=True)
repo_name = pathlib.Path(repo.working_dir).name
t = repo.head.commit.tree
diff_file_name = os.path.join(git_log_dir, f"{repo_name}.diff")
# check if the diff file already exists
if os.path.isfile(diff_file_name):
continue
# write the diff file
print(f"Storing git diff for '{repo_name}' in: {diff_file_name}")
with open(diff_file_name, "x") as f:
content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}"
content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}"
with open(os.path.join(logdir, f"{repo_name}_git.diff"), "x", encoding="utf-8") as f:
f.write(content)
# add the file path to the list of files to be uploaded
file_paths.append(diff_file_name)
return file_paths
def seed(s=None):
seed = int(datetime.now().timestamp() * 1e6) % 2**32 if s is None else s
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def squeeze_preserve_batch(tensor):
"""Squeezes a tensor, but preserves the batch dimension"""
single_batch = tensor.shape[0] == 1
squeezed_tensor = tensor.squeeze()
if single_batch:
squeezed_tensor = squeezed_tensor.unsqueeze(0)
return squeezed_tensor
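
Two of the new helpers above are easy to demonstrate in isolation; a small sketch (CPU only, shapes chosen arbitrarily):

```python
import torch

from rsl_rl.utils.utils import seed, squeeze_preserve_batch

seed(0)  # seeds random, numpy and torch; passing None derives a seed from the current time

x = torch.zeros(1, 1, 3)                # a single-sample batch
print(squeeze_preserve_batch(x).shape)  # torch.Size([1, 3]) -- the batch dimension survives
print(x.squeeze().shape)                # torch.Size([3])    -- a plain squeeze would drop it
```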

View File

@ -1,11 +1,7 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import os
from dataclasses import asdict
from torch.utils.tensorboard import SummaryWriter
from legged_gym.utils import class_to_dict
try:
import wandb
@ -14,8 +10,6 @@ except ModuleNotFoundError:
class WandbSummaryWriter(SummaryWriter):
"""Summary writer for Weights and Biases."""
def __init__(self, log_dir: str, flush_secs: int, cfg):
super().__init__(log_dir, flush_secs)
@ -49,7 +43,7 @@ class WandbSummaryWriter(SummaryWriter):
wandb.config.update({"runner_cfg": runner_cfg})
wandb.config.update({"policy_cfg": policy_cfg})
wandb.config.update({"alg_cfg": alg_cfg})
wandb.config.update({"env_cfg": asdict(env_cfg)})
wandb.config.update({"env_cfg": class_to_dict(env_cfg)})
def _map_path(self, path):
if path in self.name_map:
@ -74,7 +68,4 @@ class WandbSummaryWriter(SummaryWriter):
self.store_config(env_cfg, runner_cfg, alg_cfg, policy_cfg)
def save_model(self, model_path, iter):
wandb.save(model_path, base_path=os.path.dirname(model_path))
def save_file(self, path, iter=None):
wandb.save(path, base_path=os.path.dirname(path))
wandb.save(model_path)

View File

@ -1,24 +1,13 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
from setuptools import find_packages, setup
from setuptools import setup, find_packages
setup(
name="rsl_rl",
version="2.0.2",
version="1.0.2",
packages=find_packages(),
author="ETH Zurich, NVIDIA CORPORATION",
maintainer="Nikita Rudin, David Hoeller",
maintainer_email="rudinn@ethz.ch",
url="https://github.com/leggedrobotics/rsl_rl",
license="BSD-3",
description="Fast and simple RL algorithms implemented in pytorch",
python_requires=">=3.6",
install_requires=[
"torch>=1.10.0",
"torchvision>=0.5.0",
"numpy>=1.16.4",
"GitPython",
"onnx",
"GitPython", "gym", "numpy", "onnx", "tensorboard", "torch", "torchvision", "wandb",
],
)

169
tests/test_algorithms.py Normal file
View File

@ -0,0 +1,169 @@
import unittest
from rsl_rl.algorithms import D4PG, DDPG, DPPO, DSAC, PPO, SAC, TD3
from rsl_rl.env.gym_env import GymEnv
from rsl_rl.modules import Network
from rsl_rl.runners.runner import Runner
DEVICE = "cpu"
class AlgorithmTestCaseMixin:
algorithm_class = None
def _make_env(self, params={}):
my_params = dict(name="LunarLanderContinuous-v2", device=DEVICE, environment_count=4)
my_params.update(params)
return GymEnv(**my_params)
def _make_agent(self, env, agent_params={}):
return self.algorithm_class(env, device=DEVICE, **agent_params)
def _make_runner(self, env, agent, runner_params={}):
if not runner_params or "num_steps_per_env" not in runner_params:
runner_params["num_steps_per_env"] = 6
return Runner(env, agent, device=DEVICE, **runner_params)
def _learn(self, env, agent, runner_params={}):
runner = self._make_runner(env, agent, runner_params)
runner.learn(10)
def test_default(self):
env = self._make_env()
agent = self._make_agent(env)
self._learn(env, agent)
def test_single_env_single_step(self):
env = self._make_env(dict(environment_count=1))
agent = self._make_agent(env)
self._learn(env, agent, dict(num_steps_per_env=1))
class RecurrentAlgorithmTestCaseMixin(AlgorithmTestCaseMixin):
def test_recurrent(self):
env = self._make_env()
agent = self._make_agent(env, dict(recurrent=True))
self._learn(env, agent)
def test_single_env_single_step_recurrent(self):
env = self._make_env(dict(environment_count=1))
agent = self._make_agent(env, dict(recurrent=True))
self._learn(env, agent, dict(num_steps_per_env=1))
class D4PGTest(AlgorithmTestCaseMixin, unittest.TestCase):
algorithm_class = D4PG
class DDPGTest(AlgorithmTestCaseMixin, unittest.TestCase):
algorithm_class = DDPG
iqn_params = dict(
critic_network=DPPO.network_iqn,
iqn_action_samples=8,
iqn_embedding_size=16,
iqn_feature_layers=2,
iqn_value_samples=4,
value_loss=DPPO.value_loss_energy,
)
qrdqn_params = dict(
critic_network=DPPO.network_qrdqn,
qrdqn_quantile_count=16,
value_loss=DPPO.value_loss_l1,
)
class DPPOTest(RecurrentAlgorithmTestCaseMixin, unittest.TestCase):
algorithm_class = DPPO
def test_qrdqn(self):
env = self._make_env()
agent = self._make_agent(env, qrdqn_params)
self._learn(env, agent)
def test_qrdqn_sing_env_single_step(self):
env = self._make_env(dict(environment_count=1))
agent = self._make_agent(env, qrdqn_params)
self._learn(env, agent, dict(num_steps_per_env=1))
def test_qrdqn_energy_loss(self):
my_agent_params = qrdqn_params.copy()
my_agent_params["value_loss"] = DPPO.value_loss_energy
env = self._make_env()
agent = self._make_agent(env, my_agent_params)
self._learn(env, agent)
def test_qrdqn_huber_loss(self):
my_agent_params = qrdqn_params.copy()
my_agent_params["value_loss"] = DPPO.value_loss_huber
env = self._make_env()
agent = self._make_agent(env, my_agent_params)
self._learn(env, agent)
def test_qrdqn_transformer(self):
my_agent_params = qrdqn_params.copy()
my_agent_params["recurrent"] = True
my_agent_params["critic_recurrent_layers"] = 2
my_agent_params["critic_recurrent_module"] = Network.recurrent_module_transformer
my_agent_params["critic_recurrent_tf_context_length"] = 8
my_agent_params["critic_recurrent_tf_head_count"] = 2
env = self._make_env()
agent = self._make_agent(env, my_agent_params)
self._learn(env, agent)
def test_iqn(self):
env = self._make_env()
agent = self._make_agent(env, iqn_params)
self._learn(env, agent)
def test_iqn_single_step_single_env(self):
env = self._make_env(dict(environment_count=1))
agent = self._make_agent(env, iqn_params)
self._learn(env, agent, dict(num_steps_per_env=1))
def test_iqn_recurrent(self):
my_agent_params = iqn_params.copy()
my_agent_params["recurrent"] = True
env = self._make_env()
agent = self._make_agent(env, my_agent_params)
self._learn(env, agent)
class DSACTest(AlgorithmTestCaseMixin, unittest.TestCase):
algorithm_class = DSAC
class PPOTest(RecurrentAlgorithmTestCaseMixin, unittest.TestCase):
algorithm_class = PPO
class SACTest(AlgorithmTestCaseMixin, unittest.TestCase):
algorithm_class = SAC
class TD3Test(AlgorithmTestCaseMixin, unittest.TestCase):
algorithm_class = TD3
if __name__ == "__main__":
unittest.main()
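
The test mixin above doubles as a minimal recipe for wiring an environment, an algorithm, and the runner together. A standalone sketch, assuming a working gym installation with Box2D for `LunarLanderContinuous-v2`:

```python
from rsl_rl.algorithms import PPO
from rsl_rl.env.gym_env import GymEnv
from rsl_rl.runners.runner import Runner

env = GymEnv(name="LunarLanderContinuous-v2", device="cpu", environment_count=4)
agent = PPO(env, device="cpu")
runner = Runner(env, agent, device="cpu", num_steps_per_env=6)

runner.learn(10)  # ten runner iterations, as in the smoke tests above
```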

126
tests/test_dpg.py Normal file
View File

@ -0,0 +1,126 @@
import torch
import unittest
from rsl_rl.algorithms.dpg import AbstractDPG
from rsl_rl.env.pole_balancing import PoleBalancing
class DPG(AbstractDPG):
def draw_actions(self, obs, env_info):
pass
def eval_mode(self):
pass
def get_inference_policy(self, device=None):
pass
def process_transition(
self, observations, environment_info, actions, rewards, next_observations, next_environment_info, dones, data
):
pass
def register_terminations(self, terminations):
pass
def to(self, device):
pass
def train_mode(self):
pass
def update(self, storage):
pass
class FakeCritic(torch.nn.Module):
def __init__(self, values):
self.values = values
def forward(self, _):
return self.values
class DPGTest(unittest.TestCase):
def test_timeout_bootstrapping(self):
env = PoleBalancing(environment_count=4)
dpg = DPG(env, device="cpu", return_steps=3)
rewards = torch.tensor(
[
[0.1000, 0.4000, 0.6000, 0.2000, -0.6000, -0.2000],
[0.0000, 0.9000, 0.5000, -0.9000, -0.4000, 0.8000],
[-0.5000, 0.4000, 0.0000, -0.2000, 0.3000, 0.1000],
[-0.8000, 0.9000, -0.6000, 0.7000, 0.5000, 0.1000],
]
)
dones = torch.tensor(
[
[0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
]
)
timeouts = torch.tensor(
[
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0],
]
)
actions = torch.zeros((4, 6, 1))
observations = torch.zeros((4, 6, 2))
values = torch.tensor([-0.1000, -0.8000, 0.4000, 0.7000])
dpg.critic = FakeCritic(values)
dataset = [
{
"actions": actions[:, i],
"critic_observations": observations[:, i],
"dones": dones[:, i],
"rewards": rewards[:, i],
"timeouts": timeouts[:, i],
}
for i in range(3)
]
processed_dataset = dpg._process_dataset(dataset)
processed_rewards = torch.stack([processed_dataset[i]["rewards"] for i in range(1)], dim=-1)
expected_rewards = torch.tensor(
[
[1.08406],
[1.38105],
[-0.5],
[0.77707],
]
)
self.assertTrue(len(processed_dataset) == 1)
self.assertTrue(torch.isclose(processed_rewards, expected_rewards).all())
dataset = [
{
"actions": actions[:, i + 3],
"critic_observations": observations[:, i + 3],
"dones": dones[:, i + 3],
"rewards": rewards[:, i + 3],
"timeouts": timeouts[:, i + 3],
}
for i in range(3)
]
processed_dataset = dpg._process_dataset(dataset)
processed_rewards = torch.stack([processed_dataset[i]["rewards"] for i in range(3)], dim=-1)
expected_rewards = torch.tensor(
[
[0.994, 0.6, -0.59002],
[0.51291, -1.5592792, -2.08008],
[0.20398, 0.09603, 0.19501],
[1.593, 0.093, 0.7],
]
)
self.assertTrue(len(processed_dataset) == 3)
self.assertTrue(torch.isclose(processed_rewards, expected_rewards).all())
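
The first set of expected rewards in this test can be reproduced by hand. A small sketch of the 3-step return with timeout bootstrapping, assuming the algorithm's default discount factor of 0.99 (the default is not shown in this diff):

```python
gamma = 0.99  # assumed default discount factor

# Environment 0: terminates on the last step of the window, so the full 3-step discounted sum.
assert abs((0.1 + gamma * 0.4 + gamma**2 * 0.6) - 1.08406) < 1e-5
# Environment 1: no termination or timeout within the first three steps, again a plain 3-step return.
assert abs((0.0 + gamma * 0.9 + gamma**2 * 0.5) - 1.38105) < 1e-5
# Environment 2: a done at the very first step truncates the return immediately.
assert abs(-0.5 - (-0.5)) < 1e-5
# Environment 3: the timeout at step 1 is bootstrapped with that environment's critic value (0.7) and ends the window.
assert abs((-0.8 + gamma * 0.9 + gamma**2 * 0.7) - 0.77707) < 1e-5
```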

278
tests/test_dppo.py Normal file
View File

@ -0,0 +1,278 @@
import torch
import unittest
from rsl_rl.algorithms import DPPO
from rsl_rl.env.pole_balancing import PoleBalancing
class FakeCritic(torch.nn.Module):
def __init__(self, values, quantile_count=1):
self.quantile_count = quantile_count
self.recurrent = False
self.values = values
self.last_quantiles = values
def forward(self, _, distribution=False, measure_args=None):
if distribution:
return self.values
return self.values.mean(-1)
def quantiles_to_values(self, quantiles):
return quantiles.mean(-1)
class DPPOTest(unittest.TestCase):
def test_gae_computation(self):
# GT taken from old PPO implementation.
env = PoleBalancing(environment_count=4)
dppo = DPPO(env, device="cpu", gae_lambda=0.97, gamma=0.99, qrdqn_quantile_count=1)
rewards = torch.tensor(
[
[-1.0000e02, -1.4055e-01, -3.0476e-02, -2.7149e-01, -1.1157e-01, -2.3366e-01, -3.3658e-01, -1.6447e-01],
[
-1.7633e-01,
-2.6533e-01,
-3.0786e-01,
-2.6038e-01,
-2.7176e-01,
-2.1655e-01,
-1.5441e-01,
-2.9580e-01,
],
[-1.5952e-01, -1.5177e-01, -1.4296e-01, -1.6131e-01, -3.1395e-02, 2.8808e-03, -3.1242e-02, 4.8696e-03],
[1.1407e-02, -1.0000e02, -6.2290e-02, -3.7030e-01, -2.7648e-01, -3.6655e-01, -2.8456e-01, -2.3165e-01],
]
)
dones = torch.tensor(
[
[1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0],
]
)
observations = torch.zeros((dones.shape[0], dones.shape[1], 24))
timeouts = torch.zeros((dones.shape[0], dones.shape[1]))
values = torch.tensor(
[
[-4.6342, -7.6510, -7.0166, -7.6137, -7.4130, -7.7071, -7.7413, -7.8301],
[-7.0442, -7.0032, -6.9321, -6.7765, -6.5433, -6.3503, -6.2529, -5.9337],
[-7.5753, -7.8146, -7.6142, -7.8443, -7.8791, -7.7973, -7.7853, -7.7724],
[-6.4326, -6.1673, -7.6511, -7.7505, -8.0004, -7.8584, -7.5949, -7.9023],
]
)
value_quants = values.unsqueeze(-1)
last_values = torch.tensor([-7.9343, -5.8734, -7.8527, -8.1257])
dppo.critic = FakeCritic(last_values.unsqueeze(-1))
dataset = [
{
"dones": dones[:, i],
"full_next_critic_observations": observations[:, i].clone(),
"next_critic_observations": observations[:, i],
"rewards": rewards[:, i],
"timeouts": timeouts[:, i],
"values": values[:, i],
"value_quants": value_quants[:, i],
}
for i in range(dones.shape[1])
]
processed_dataset = dppo._process_dataset(dataset)
processed_returns = torch.stack(
[processed_dataset[i]["advantages"] + processed_dataset[i]["values"] for i in range(dones.shape[1])],
dim=-1,
)
processed_advantages = torch.stack(
[processed_dataset[i]["normalized_advantages"] for i in range(dones.shape[1])], dim=-1
)
expected_returns = torch.tensor(
[
[-100.0000, -8.4983, -8.4863, -8.5699, -8.4122, -8.4054, -8.2702, -8.0194],
[-7.2900, -7.1912, -6.9978, -6.7569, -6.5627, -6.3547, -6.1985, -6.1104],
[-7.9179, -7.8374, -7.7679, -7.6976, -7.6041, -7.6446, -7.7229, -7.7693],
[-96.2018, -100.0000, -9.0710, -9.1415, -8.8863, -8.7228, -8.4668, -8.2761],
]
)
expected_advantages = torch.tensor(
[
[-3.1452, 0.3006, 0.2779, 0.2966, 0.2951, 0.3060, 0.3122, 0.3246],
[0.3225, 0.3246, 0.3291, 0.3322, 0.3308, 0.3313, 0.3335, 0.3250],
[0.3190, 0.3307, 0.3259, 0.3368, 0.3415, 0.3371, 0.3338, 0.3316],
[-2.9412, -3.0893, 0.2797, 0.2808, 0.2992, 0.3000, 0.2997, 0.3179],
]
)
self.assertTrue(torch.isclose(processed_returns, expected_returns, atol=1e-4).all())
self.assertTrue(torch.isclose(processed_advantages, expected_advantages, atol=1e-4).all())
def test_target_computation_nstep(self):
# GT taken from old PPO implementation.
env = PoleBalancing(environment_count=2)
dppo = DPPO(env, device="cpu", gae_lambda=0.97, gamma=0.99, qrdqn_quantile_count=3, value_lambda=1.0)
rewards = torch.tensor(
[
[0.6, 1.0, 0.0, 0.6, 0.1, -0.2],
[0.1, -0.2, -0.4, 0.1, 1.0, 1.0],
]
)
dones = torch.tensor(
[
[1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 1.0, 0.0, 1.0, 0.0],
]
)
observations = torch.zeros((dones.shape[0], dones.shape[1], 24))
timeouts = torch.zeros((dones.shape[0], dones.shape[1]))
value_quants = torch.tensor(
[
[
[-0.4, 0.0, 0.1],
[0.9, 0.8, 0.7],
[0.7, 1.3, 0.0],
[1.2, 0.4, 1.2],
[1.3, 1.3, 1.1],
[0.0, 0.7, 0.5],
],
[
[1.3, 1.3, 0.9],
[0.4, -0.4, -0.1],
[0.4, 0.6, 0.1],
[0.7, 0.1, 0.3],
[0.2, 0.1, 0.3],
[1.4, 1.4, -0.3],
],
]
)
values = value_quants.mean(dim=-1)
last_values = torch.rand((dones.shape[0], 3))
dppo.critic = FakeCritic(last_values, quantile_count=3)
dataset = [
{
"dones": dones[:, i],
"full_next_critic_observations": observations[:, i].clone(),
"next_critic_observations": observations[:, i],
"rewards": rewards[:, i],
"timeouts": timeouts[:, i],
"values": values[:, i],
"value_quants": value_quants[:, i],
}
for i in range(dones.shape[1])
]
processed_dataset = dppo._process_dataset(dataset)
processed_value_target_quants = torch.stack(
[processed_dataset[i]["value_target_quants"] for i in range(dones.shape[1])],
dim=-2,
)
# N-step returns
# These exclude the reward received on the final step since it should not be added to the value target.
expected_returns = torch.tensor(
[
[0.6, 1.58806, 0.594, 0.6, 0.1, -0.2],
[-0.098, -0.2, -0.4, 0.1, 1.0, 1.0],
]
)
reset = lambda x: [0.0 for _ in x]
dscnt = lambda x, s=0: [x[v] * dppo.gamma**s for v in range(len(x))]
expected_value_target_quants = expected_returns.unsqueeze(-1) + torch.tensor(
[
[
reset([-0.4, 0.0, 0.1]),
dscnt([1.3, 1.3, 1.1], 3),
dscnt([1.3, 1.3, 1.1], 2),
dscnt([1.3, 1.3, 1.1], 1),
reset([1.3, 1.3, 1.1]),
dscnt(last_values[0], 1),
],
[
dscnt([0.4, 0.6, 0.1], 2),
dscnt([0.4, 0.6, 0.1], 1),
reset([0.4, 0.6, 0.1]),
dscnt([0.2, 0.1, 0.3], 1),
reset([0.2, 0.1, 0.3]),
dscnt(last_values[1], 1),
],
]
)
self.assertTrue(torch.isclose(processed_value_target_quants, expected_value_target_quants, atol=1e-4).all())
def test_target_computation_1step(self):
# GT taken from old PPO implementation.
env = PoleBalancing(environment_count=2)
dppo = DPPO(env, device="cpu", gae_lambda=0.97, gamma=0.99, qrdqn_quantile_count=3, value_lambda=0.0)
rewards = torch.tensor(
[
[0.6, 1.0, 0.0, 0.6, 0.1, -0.2],
[0.1, -0.2, -0.4, 0.1, 1.0, 1.0],
]
)
dones = torch.tensor(
[
[1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 1.0, 0.0, 1.0, 0.0],
]
)
observations = torch.zeros((dones.shape[0], dones.shape[1], 24))
timeouts = torch.zeros((dones.shape[0], dones.shape[1]))
value_quants = torch.tensor(
[
[
[-0.4, 0.0, 0.1],
[0.9, 0.8, 0.7],
[0.7, 1.3, 0.0],
[1.2, 0.4, 1.2],
[1.3, 1.3, 1.1],
[0.0, 0.7, 0.5],
],
[
[1.3, 1.3, 0.9],
[0.4, -0.4, -0.1],
[0.4, 0.6, 0.1],
[0.7, 0.1, 0.3],
[0.2, 0.1, 0.3],
[1.4, 1.4, -0.3],
],
]
)
values = value_quants.mean(dim=-1)
last_values = torch.rand((dones.shape[0], 3))
dppo.critic = FakeCritic(last_values, quantile_count=3)
dataset = [
{
"dones": dones[:, i],
"full_next_critic_observations": observations[:, i].clone(),
"next_critic_observations": observations[:, i],
"rewards": rewards[:, i],
"timeouts": timeouts[:, i],
"values": values[:, i],
"value_quants": value_quants[:, i],
}
for i in range(dones.shape[1])
]
processed_dataset = dppo._process_dataset(dataset)
processed_value_target_quants = torch.stack(
[processed_dataset[i]["value_target_quants"] for i in range(dones.shape[1])],
dim=-2,
)
# 1-step returns
expected_value_target_quants = rewards.unsqueeze(-1) + (
(1.0 - dones).float().unsqueeze(-1)
* dppo.gamma
* torch.cat((value_quants[:, 1:], last_values.unsqueeze(1)), dim=1)
)
self.assertTrue(torch.isclose(processed_value_target_quants, expected_value_target_quants, atol=1e-4).all())

171
tests/test_dppo_iqn.py Normal file
View File

@ -0,0 +1,171 @@
import torch
import unittest
from rsl_rl.algorithms import DPPO
from rsl_rl.env.vec_env import VecEnv
ACTION_SIZE = 3
ENV_COUNT = 3
OBS_SIZE = 24
class FakeEnv(VecEnv):
def __init__(self, rewards, dones, environment_count=1):
super().__init__(OBS_SIZE, OBS_SIZE, environment_count=environment_count)
self.num_actions = ACTION_SIZE
self.rewards = rewards
self.dones = dones
self._step = 0
def get_observations(self):
return torch.zeros((self.num_envs, self.num_obs)), {"observations": {}}
def get_privileged_observations(self):
return torch.zeros((self.num_envs, self.num_privileged_obs)), {"observations": {}}
def step(self, actions):
obs, _ = self.get_observations()
rewards = self.rewards[self._step]
dones = self.dones[self._step]
self._step += 1
return obs, rewards, dones, {"observations": {}}
def reset(self):
pass
class FakeCritic(torch.nn.Module):
def __init__(self, action_samples, value_samples, action_values, value_values, action_taus, value_taus):
self.recurrent = False
self.action_samples = action_samples
self.value_samples = value_samples
self.action_values = action_values
self.value_values = value_values
self.action_taus = action_taus
self.value_taus = value_taus
self.last_quantiles = None
self.last_taus = None
def forward(self, _, distribution=False, measure_args=None, sample_count=8, taus=None, use_measure=True):
if taus is not None:
sample_count = taus.shape[-1]
if sample_count == self.action_samples:
self.last_taus = self.action_taus
self.last_quantiles = self.action_values
elif sample_count == self.value_samples:
self.last_taus = self.value_taus
self.last_quantiles = self.value_values
else:
raise ValueError(f"Invalid sample count: {sample_count}")
if distribution:
return self.last_quantiles
return self.last_quantiles.mean(-1)
def fake_process_quants(self, x):
idx = torch.arange(0, x.shape[-1]).expand(*x.shape[:-1])
return x, idx
class DPPOTest(unittest.TestCase):
def test_value_target_computation(self):
rewards = torch.tensor(
[
[-1.0000e02, -1.4055e-01, -3.0476e-02],
[-1.7633e-01, -2.6533e-01, -3.0786e-01],
[-1.5952e-01, -1.5177e-01, -1.4296e-01],
[1.1407e-02, -1.0000e02, -6.2290e-02],
]
)
dones = torch.tensor(
[
[1, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 1, 0],
]
)
env = FakeEnv(rewards, dones, environment_count=ENV_COUNT)
dppo = DPPO(
env,
critic_network=DPPO.network_iqn,
device="cpu",
gae_lambda=0.97,
gamma=0.99,
iqn_action_samples=4,
iqn_value_samples=2,
value_lambda=1.0,
value_loss=DPPO.value_loss_energy,
)
# Generate fake dataset
action_taus = torch.tensor(
[
[[0.3, 0.5, 1.0, 0.2], [0.8, 0.9, 0.0, 0.9], [0.6, 0.1, 0.6, 0.5]],
[[0.7, 0.9, 0.3, 0.0], [1.0, 0.7, 0.7, 0.7], [0.3, 0.8, 0.8, 0.1]],
[[0.3, 0.8, 0.3, 0.2], [0.2, 0.9, 0.6, 0.4], [0.8, 0.4, 0.8, 1.0]],
[[0.6, 0.6, 0.8, 0.8], [0.8, 0.0, 0.9, 0.1], [0.2, 0.3, 0.6, 0.2]],
]
)
action_value_quants = torch.tensor(
[
[[0.2, 0.2, 0.6, 0.5], [0.5, 0.8, 0.1, 0.0], [1.0, 0.1, 0.8, 0.8]],
[[0.0, 0.6, 0.1, 0.9], [0.2, 1.0, 0.9, 1.0], [0.4, 0.1, 0.1, 0.8]],
[[0.7, 0.0, 0.6, 0.8], [0.7, 0.7, 0.7, 0.8], [0.0, 0.1, 0.5, 0.8]],
[[0.5, 0.8, 0.1, 0.1], [0.9, 0.4, 0.7, 0.6], [0.6, 0.3, 0.1, 0.4]],
]
)
value_taus = torch.tensor(
[
[[0.3, 0.5], [0.8, 0.9], [0.6, 0.1]],
[[0.7, 0.9], [1.0, 0.7], [0.3, 0.8]],
[[0.3, 0.8], [0.2, 0.9], [0.8, 0.4]],
[[0.6, 0.6], [0.8, 0.0], [0.2, 0.3]],
]
)
value_value_quants = torch.tensor(
[
[[0.9, 0.8], [0.1, 0.3], [0.3, 0.5]],
[[0.2, 0.1], [0.9, 0.3], [0.4, 0.2]],
[[0.7, 1.0], [0.6, 0.2], [0.2, 0.6]],
[[0.4, 1.0], [0.3, 0.6], [0.3, 0.1]],
]
)
actions = torch.zeros(ENV_COUNT, ACTION_SIZE)
env_info = {"observations": {}}
obs = torch.zeros(ENV_COUNT, OBS_SIZE)
dataset = []
for i in range(4):
dppo.critic = FakeCritic(4, 2, action_value_quants[i], value_value_quants[i], action_taus[i], value_taus[i])
dppo.critic._process_quants = fake_process_quants
_, data = dppo.draw_actions(obs, {})
_, rewards, dones, _ = env.step(actions)
dataset.append(
dppo.process_transition(
obs,
env_info,
actions,
rewards,
obs,
env_info,
dones,
data,
)
)
processed_dataset = dppo._process_dataset(dataset)
# TODO: Test that the value targets are correct.

View File

@ -0,0 +1,170 @@
import torch
import unittest
from rsl_rl.algorithms import DPPO
from rsl_rl.env.vec_env import VecEnv
from rsl_rl.runners.runner import Runner
from rsl_rl.utils.benchmarkable import Benchmarkable
class FakeNetwork(torch.nn.Module, Benchmarkable):
def __init__(self, values):
super().__init__()
self.hidden_state = None
self.quantile_count = 1
self.recurrent = True
self.values = values
self._hidden_size = 2
def forward(self, x, hidden_state=None):
if not hidden_state:
self.hidden_state = (self.hidden_state[0] + 1, self.hidden_state[1] - 1)
values = self.values.repeat((*x.shape[:-1], 1)).squeeze(-1)
values.requires_grad_(True)
return values
def reset_full_hidden_state(self, batch_size=None):
assert batch_size is None or batch_size == 4, f"batch_size={batch_size}"
self.hidden_state = (torch.zeros((1, 4, self._hidden_size)), torch.zeros((1, 4, self._hidden_size)))
def reset_hidden_state(self, indices):
self.hidden_state[0][:, indices] = torch.zeros((len(indices), self._hidden_size))
self.hidden_state[1][:, indices] = torch.zeros((len(indices), self._hidden_size))
class FakeActorNetwork(FakeNetwork):
def forward(self, x, compute_std=False, hidden_state=None):
values = super().forward(x, hidden_state=hidden_state)
if compute_std:
return values, torch.ones_like(values)
return values
class FakeCriticNetwork(FakeNetwork):
_quantile_count = 1
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, x, distribution=False, hidden_state=None, measure_args=None):
values = super().forward(x, hidden_state=hidden_state)
self.last_quantiles = values.reshape(*values.shape, 1)
if distribution:
return self.last_quantiles
return values
def quantile_l1_loss(self, *args, **kwargs):
return torch.tensor(0.0)
def quantiles_to_values(self, quantiles):
return quantiles.squeeze()
class FakeEnv(VecEnv):
def __init__(self, dones=None, **kwargs):
super().__init__(3, 3, **kwargs)
self.num_actions = 3
self._extra = {"observations": {}, "time_outs": torch.zeros((self.num_envs, 1))}
self._step = 0
self._dones = dones
self.reset()
def get_observations(self):
return self._state_buf, self._extra
def get_privileged_observations(self):
return self._state_buf, self._extra
def reset(self):
self._state_buf = torch.zeros((self.num_envs, self.num_obs))
return self._state_buf, self._extra
def step(self, actions):
assert actions.shape[0] == self.num_envs
assert actions.shape[1] == self.num_actions
self._state_buf += actions
rewards = torch.zeros((self.num_envs))
dones = torch.zeros((self.num_envs)) if self._dones is None else self._dones[self._step % self._dones.shape[0]]
self._step += 1
return self._state_buf, rewards, dones, self._extra
class DPPORecurrencyTest(unittest.TestCase):
def test_draw_action_produces_hidden_state(self):
"""Test that the hidden state is correctly added to the data dictionary when drawing actions."""
env = FakeEnv(environment_count=4)
dppo = DPPO(env, device="cpu", recurrent=True)
dppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
dppo.critic = FakeCriticNetwork(torch.zeros(1))
# Done during DPPO.__init__, however we need to reset the hidden state here again since we are using a fake
# network that is added after initialization.
dppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
dppo.critic.reset_full_hidden_state(batch_size=env.num_envs)
ones = torch.ones((1, env.num_envs, dppo.actor._hidden_size))
state, extra = env.reset()
for ctr in range(10):
_, data = dppo.draw_actions(state, extra)
# Actor state is changed every time an action is drawn.
self.assertTrue(torch.allclose(data["actor_state_h"], ones * ctr))
self.assertTrue(torch.allclose(data["actor_state_c"], -ones * ctr))
# Critic state is only changed and saved when processing the transition (evaluating the action) so we can't
# check it here.
def test_update_produces_hidden_state(self):
"""Test that the hidden state is correctly added to the data dictionary when updating."""
dones = torch.cat((torch.tensor([[0, 0, 0, 1]]), torch.zeros((4, 4)), torch.tensor([[1, 0, 0, 0]])), dim=0)
env = FakeEnv(dones=dones, environment_count=4)
dppo = DPPO(env, device="cpu", recurrent=True)
runner = Runner(env, dppo, num_steps_per_env=6)
dppo._value_loss = lambda *args, **kwargs: torch.tensor(0.0)
dppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
dppo.critic = FakeCriticNetwork(torch.zeros(1))
dppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
dppo.critic.reset_full_hidden_state(batch_size=env.num_envs)
runner.learn(1)
state_h_0 = torch.tensor([[0, 0], [0, 0], [0, 0], [0, 0]])
state_h_1 = torch.tensor([[1, 1], [1, 1], [1, 1], [0, 0]])
state_h_2 = state_h_1 + 1
state_h_3 = state_h_2 + 1
state_h_4 = state_h_3 + 1
state_h_5 = state_h_4 + 1
state_h_6 = torch.tensor([[0, 0], [6, 6], [6, 6], [5, 5]])
state_h = (
torch.cat((state_h_0, state_h_1, state_h_2, state_h_3, state_h_4, state_h_5), dim=0).float().unsqueeze(1)
)
next_state_h = (
torch.cat((state_h_1, state_h_2, state_h_3, state_h_4, state_h_5, state_h_6), dim=0).float().unsqueeze(1)
)
self.assertTrue(torch.allclose(dppo.storage._data["critic_state_h"], state_h))
self.assertTrue(torch.allclose(dppo.storage._data["critic_state_c"], -state_h))
self.assertTrue(torch.allclose(dppo.storage._data["critic_next_state_h"], next_state_h))
self.assertTrue(torch.allclose(dppo.storage._data["critic_next_state_c"], -next_state_h))
self.assertTrue(torch.allclose(dppo.storage._data["actor_state_h"], state_h))
self.assertTrue(torch.allclose(dppo.storage._data["actor_state_c"], -state_h))

99
tests/test_ppo.py Normal file
View File

@ -0,0 +1,99 @@
import torch
import unittest
from rsl_rl.algorithms import PPO
from rsl_rl.env.pole_balancing import PoleBalancing
class FakeCritic(torch.nn.Module):
def __init__(self, values):
self.recurrent = False
self.values = values
def forward(self, _):
return self.values
class PPOTest(unittest.TestCase):
def test_gae_computation(self):
# GT taken from old PPO implementation.
env = PoleBalancing(environment_count=4)
ppo = PPO(env, device="cpu", gae_lambda=0.97, gamma=0.99)
rewards = torch.tensor(
[
[-1.0000e02, -1.4055e-01, -3.0476e-02, -2.7149e-01, -1.1157e-01, -2.3366e-01, -3.3658e-01, -1.6447e-01],
[
-1.7633e-01,
-2.6533e-01,
-3.0786e-01,
-2.6038e-01,
-2.7176e-01,
-2.1655e-01,
-1.5441e-01,
-2.9580e-01,
],
[-1.5952e-01, -1.5177e-01, -1.4296e-01, -1.6131e-01, -3.1395e-02, 2.8808e-03, -3.1242e-02, 4.8696e-03],
[1.1407e-02, -1.0000e02, -6.2290e-02, -3.7030e-01, -2.7648e-01, -3.6655e-01, -2.8456e-01, -2.3165e-01],
]
)
dones = torch.tensor(
[
[1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0],
]
)
observations = torch.zeros((dones.shape[0], dones.shape[1], 24))
timeouts = torch.zeros((dones.shape[0], dones.shape[1]))
values = torch.tensor(
[
[-4.6342, -7.6510, -7.0166, -7.6137, -7.4130, -7.7071, -7.7413, -7.8301],
[-7.0442, -7.0032, -6.9321, -6.7765, -6.5433, -6.3503, -6.2529, -5.9337],
[-7.5753, -7.8146, -7.6142, -7.8443, -7.8791, -7.7973, -7.7853, -7.7724],
[-6.4326, -6.1673, -7.6511, -7.7505, -8.0004, -7.8584, -7.5949, -7.9023],
]
)
last_values = torch.tensor([-7.9343, -5.8734, -7.8527, -8.1257])
ppo.critic = FakeCritic(last_values)
dataset = [
{
"dones": dones[:, i],
"next_critic_observations": observations[:, i],
"rewards": rewards[:, i],
"timeouts": timeouts[:, i],
"values": values[:, i],
}
for i in range(dones.shape[1])
]
processed_dataset = ppo._process_dataset(dataset)
processed_returns = torch.stack(
[processed_dataset[i]["advantages"] + processed_dataset[i]["values"] for i in range(dones.shape[1])],
dim=-1,
)
processed_advantages = torch.stack(
[processed_dataset[i]["normalized_advantages"] for i in range(dones.shape[1])], dim=-1
)
expected_returns = torch.tensor(
[
[-100.0000, -8.4983, -8.4863, -8.5699, -8.4122, -8.4054, -8.2702, -8.0194],
[-7.2900, -7.1912, -6.9978, -6.7569, -6.5627, -6.3547, -6.1985, -6.1104],
[-7.9179, -7.8374, -7.7679, -7.6976, -7.6041, -7.6446, -7.7229, -7.7693],
[-96.2018, -100.0000, -9.0710, -9.1415, -8.8863, -8.7228, -8.4668, -8.2761],
]
)
expected_advantages = torch.tensor(
[
[-3.1452, 0.3006, 0.2779, 0.2966, 0.2951, 0.3060, 0.3122, 0.3246],
[0.3225, 0.3246, 0.3291, 0.3322, 0.3308, 0.3313, 0.3335, 0.3250],
[0.3190, 0.3307, 0.3259, 0.3368, 0.3415, 0.3371, 0.3338, 0.3316],
[-2.9412, -3.0893, 0.2797, 0.2808, 0.2992, 0.3000, 0.2997, 0.3179],
]
)
self.assertTrue(torch.isclose(processed_returns, expected_returns, atol=1e-4).all())
self.assertTrue(torch.isclose(processed_advantages, expected_advantages, atol=1e-4).all())
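
For reference, the generalized advantage estimate checked here follows the same recursion as the removed `compute_returns` method of the old rollout storage shown earlier. A compact standalone version, ignoring the timeout bootstrapping that the algorithms additionally perform:

```python
import torch


def gae(rewards, values, dones, last_values, gamma=0.99, lam=0.97):
    """Returns (advantages, returns) for tensors of shape (num_envs, num_steps)."""
    num_envs, num_steps = rewards.shape
    advantages = torch.zeros_like(rewards)
    advantage = torch.zeros(num_envs)
    for step in reversed(range(num_steps)):
        next_values = last_values if step == num_steps - 1 else values[:, step + 1]
        not_terminal = 1.0 - dones[:, step].float()
        delta = rewards[:, step] + not_terminal * gamma * next_values - values[:, step]
        advantage = delta + not_terminal * gamma * lam * advantage
        advantages[:, step] = advantage
    return advantages, advantages + values
```

The normalized advantages compared in the test additionally standardize this result to zero mean and unit variance, as in the removed `compute_returns`.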

View File

@ -0,0 +1,144 @@
import torch
import unittest
from rsl_rl.algorithms import PPO
from rsl_rl.env.vec_env import VecEnv
from rsl_rl.runners.runner import Runner
class FakeNetwork(torch.nn.Module):
def __init__(self, values):
super().__init__()
self.hidden_state = None
self.recurrent = True
self.values = values
self._hidden_size = 2
def forward(self, x, hidden_state=None):
if not hidden_state:
self.hidden_state = (self.hidden_state[0] + 1, self.hidden_state[1] - 1)
values = self.values.repeat((*x.shape[:-1], 1)).squeeze(-1)
values.requires_grad_(True)
return values
def reset_full_hidden_state(self, batch_size=None):
assert batch_size is None or batch_size == 4, f"batch_size={batch_size}"
self.hidden_state = (torch.zeros((1, 4, self._hidden_size)), torch.zeros((1, 4, self._hidden_size)))
def reset_hidden_state(self, indices):
self.hidden_state[0][:, indices] = torch.zeros((len(indices), self._hidden_size))
self.hidden_state[1][:, indices] = torch.zeros((len(indices), self._hidden_size))
class FakeActorNetwork(FakeNetwork):
def forward(self, x, compute_std=False, hidden_state=None):
values = super().forward(x, hidden_state=hidden_state)
if compute_std:
return values, torch.ones_like(values)
return values
class FakeEnv(VecEnv):
def __init__(self, dones=None, **kwargs):
super().__init__(3, 3, **kwargs)
self.num_actions = 3
self._extra = {"observations": {}, "time_outs": torch.zeros((self.num_envs, 1))}
self._step = 0
self._dones = dones
self.reset()
def get_observations(self):
return self._state_buf, self._extra
def get_privileged_observations(self):
return self._state_buf, self._extra
def reset(self):
self._state_buf = torch.zeros((self.num_envs, self.num_obs))
return self._state_buf, self._extra
def step(self, actions):
assert actions.shape[0] == self.num_envs
assert actions.shape[1] == self.num_actions
self._state_buf += actions
rewards = torch.zeros((self.num_envs))
dones = torch.zeros((self.num_envs)) if self._dones is None else self._dones[self._step % self._dones.shape[0]]
self._step += 1
return self._state_buf, rewards, dones, self._extra
class PPORecurrencyTest(unittest.TestCase):
def test_draw_action_produces_hidden_state(self):
"""Test that the hidden state is correctly added to the data dictionary when drawing actions."""
env = FakeEnv(environment_count=4)
ppo = PPO(env, device="cpu", recurrent=True)
ppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
ppo.critic = FakeNetwork(torch.zeros(1))
# Done during PPO.__init__, however we need to reset the hidden state here again since we are using a fake
# network that is added after initialization.
ppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
ppo.critic.reset_full_hidden_state(batch_size=env.num_envs)
ones = torch.ones((1, env.num_envs, ppo.actor._hidden_size))
state, extra = env.reset()
for ctr in range(10):
_, data = ppo.draw_actions(state, extra)
# Actor state is changed every time an action is drawn.
self.assertTrue(torch.allclose(data["actor_state_h"], ones * ctr))
self.assertTrue(torch.allclose(data["actor_state_c"], -ones * ctr))
# Critic state is only changed and saved when processing the transition (evaluating the action) so we can't
# check it here.
def test_update_produces_hidden_state(self):
"""Test that the hidden state is correctly added to the data dictionary when updating."""
dones = torch.cat((torch.tensor([[0, 0, 0, 1]]), torch.zeros((4, 4)), torch.tensor([[1, 0, 0, 0]])), dim=0)
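        # Environment 3 terminates after the first step and environment 0 after the last step, so their hidden
        # states are reset mid-rollout.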
env = FakeEnv(dones=dones, environment_count=4)
ppo = PPO(env, device="cpu", recurrent=True)
runner = Runner(env, ppo, num_steps_per_env=6)
ppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
ppo.critic = FakeNetwork(torch.zeros(1))
ppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
ppo.critic.reset_full_hidden_state(batch_size=env.num_envs)
runner.learn(1)
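        # The fake networks add +1 to h (and -1 to c) on every forward pass, and a done flag zeroes the
        # corresponding environment's state, which yields the step-by-step hidden states below.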
state_h_0 = torch.tensor([[0, 0], [0, 0], [0, 0], [0, 0]])
state_h_1 = torch.tensor([[1, 1], [1, 1], [1, 1], [0, 0]])
state_h_2 = state_h_1 + 1
state_h_3 = state_h_2 + 1
state_h_4 = state_h_3 + 1
state_h_5 = state_h_4 + 1
state_h_6 = torch.tensor([[0, 0], [6, 6], [6, 6], [5, 5]])
state_h = (
torch.cat((state_h_0, state_h_1, state_h_2, state_h_3, state_h_4, state_h_5), dim=0).float().unsqueeze(1)
)
next_state_h = (
torch.cat((state_h_1, state_h_2, state_h_3, state_h_4, state_h_5, state_h_6), dim=0).float().unsqueeze(1)
)
self.assertTrue(torch.allclose(ppo.storage._data["critic_state_h"], state_h))
self.assertTrue(torch.allclose(ppo.storage._data["critic_state_c"], -state_h))
self.assertTrue(torch.allclose(ppo.storage._data["critic_next_state_h"], next_state_h))
self.assertTrue(torch.allclose(ppo.storage._data["critic_next_state_c"], -next_state_h))
self.assertTrue(torch.allclose(ppo.storage._data["actor_state_h"], state_h))
self.assertTrue(torch.allclose(ppo.storage._data["actor_state_c"], -state_h))


@ -0,0 +1,286 @@
import torch
import unittest
from rsl_rl.modules.quantile_network import QuantileNetwork
class QuantileNetworkTest(unittest.TestCase):
def test_l1_loss(self):
qn = QuantileNetwork(10, 1, quantile_count=5)
prediction = torch.tensor(
[
[0.8510, 0.2329, 0.4244, 0.5241, 0.2144],
[0.7693, 0.2522, 0.3909, 0.0858, 0.7914],
[0.8701, 0.2144, 0.9661, 0.9975, 0.5043],
[0.2653, 0.6951, 0.9787, 0.2244, 0.0430],
[0.7907, 0.5209, 0.7276, 0.1735, 0.2757],
[0.1696, 0.7167, 0.6363, 0.2188, 0.7025],
[0.0445, 0.6008, 0.5334, 0.1838, 0.7387],
[0.4934, 0.5117, 0.4488, 0.0591, 0.6442],
]
)
target = torch.tensor(
[
[0.3918, 0.8979, 0.4347, 0.1076, 0.5303],
[0.5449, 0.9974, 0.3197, 0.8686, 0.0631],
[0.7397, 0.7734, 0.6559, 0.3020, 0.7229],
[0.9519, 0.8138, 0.1502, 0.3445, 0.3356],
[0.8970, 0.0910, 0.7536, 0.6069, 0.2556],
[0.1741, 0.6863, 0.7142, 0.2911, 0.3142],
[0.8835, 0.0215, 0.4774, 0.5362, 0.4998],
[0.8037, 0.8269, 0.5518, 0.4368, 0.5323],
]
)
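        # quantile_l1_loss is expected to implement the quantile regression (pinball) loss, averaged over all
        # prediction-target quantile pairs, with quantile midpoints 0.1, 0.3, 0.5, 0.7, 0.9 for quantile_count=5.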
loss = qn.quantile_l1_loss(prediction, target)
self.assertAlmostEqual(loss.item(), 0.16419549)
def test_l1_loss_3d(self):
qn = QuantileNetwork(10, 1, quantile_count=5)
prediction = torch.tensor(
[
[
[0.8510, 0.2329, 0.4244, 0.5241, 0.2144],
[0.7693, 0.2522, 0.3909, 0.0858, 0.7914],
[0.8701, 0.2144, 0.9661, 0.9975, 0.5043],
[0.2653, 0.6951, 0.9787, 0.2244, 0.0430],
[0.7907, 0.5209, 0.7276, 0.1735, 0.2757],
[0.1696, 0.7167, 0.6363, 0.2188, 0.7025],
[0.0445, 0.6008, 0.5334, 0.1838, 0.7387],
[0.4934, 0.5117, 0.4488, 0.0591, 0.6442],
],
[
[0.6874, 0.6214, 0.7796, 0.8148, 0.2070],
[0.0276, 0.5764, 0.5516, 0.9682, 0.6901],
[0.4020, 0.7084, 0.9965, 0.4311, 0.3789],
[0.5350, 0.9431, 0.1032, 0.6959, 0.4992],
[0.5059, 0.5479, 0.2302, 0.6753, 0.1593],
[0.6753, 0.4590, 0.9956, 0.6117, 0.1410],
[0.7464, 0.7184, 0.2972, 0.7694, 0.7999],
[0.3907, 0.2112, 0.6485, 0.0139, 0.6252],
],
]
)
target = torch.tensor(
[
[
[0.3918, 0.8979, 0.4347, 0.1076, 0.5303],
[0.5449, 0.9974, 0.3197, 0.8686, 0.0631],
[0.7397, 0.7734, 0.6559, 0.3020, 0.7229],
[0.9519, 0.8138, 0.1502, 0.3445, 0.3356],
[0.8970, 0.0910, 0.7536, 0.6069, 0.2556],
[0.1741, 0.6863, 0.7142, 0.2911, 0.3142],
[0.8835, 0.0215, 0.4774, 0.5362, 0.4998],
[0.8037, 0.8269, 0.5518, 0.4368, 0.5323],
],
[
[0.5120, 0.7683, 0.3579, 0.8640, 0.4374],
[0.2533, 0.3039, 0.2214, 0.7069, 0.3093],
[0.6993, 0.4288, 0.0827, 0.9156, 0.2043],
[0.6739, 0.2303, 0.3263, 0.6884, 0.3847],
[0.3990, 0.1458, 0.8918, 0.8036, 0.5012],
[0.9061, 0.2024, 0.7276, 0.8619, 0.1198],
[0.7379, 0.2005, 0.7634, 0.5691, 0.6132],
[0.4341, 0.5711, 0.1119, 0.4286, 0.7521],
],
]
)
loss = qn.quantile_l1_loss(prediction, target)
self.assertAlmostEqual(loss.item(), 0.15836075)
def test_l1_loss_multi_output(self):
qn = QuantileNetwork(10, 3, quantile_count=10)
prediction = torch.tensor(
[
[0.3003, 0.8692, 0.4608, 0.7158, 0.2640, 0.3928, 0.4557, 0.4620, 0.1331, 0.6356],
[0.8867, 0.1521, 0.5827, 0.0501, 0.4401, 0.7216, 0.6081, 0.5758, 0.2772, 0.6048],
[0.0763, 0.1609, 0.1860, 0.9173, 0.2121, 0.1920, 0.8509, 0.8588, 0.3321, 0.7202],
[0.8375, 0.5339, 0.4287, 0.9228, 0.8519, 0.0420, 0.5736, 0.9156, 0.4444, 0.2039],
[0.0704, 0.1833, 0.0839, 0.9573, 0.9852, 0.4191, 0.3562, 0.7225, 0.8481, 0.2096],
[0.4054, 0.8172, 0.8737, 0.2138, 0.4455, 0.7538, 0.1936, 0.9346, 0.8710, 0.0178],
[0.2139, 0.6619, 0.6889, 0.5726, 0.0595, 0.3278, 0.7673, 0.0803, 0.0374, 0.9011],
[0.2757, 0.0309, 0.8913, 0.0958, 0.1828, 0.9624, 0.6529, 0.7451, 0.9996, 0.8877],
[0.0722, 0.4240, 0.0716, 0.3199, 0.5570, 0.1056, 0.5950, 0.9926, 0.2991, 0.7334],
[0.0576, 0.6353, 0.5078, 0.4456, 0.9119, 0.6897, 0.1720, 0.5172, 0.9939, 0.5044],
[0.6300, 0.2304, 0.4064, 0.9195, 0.3299, 0.8631, 0.5842, 0.6751, 0.2964, 0.1215],
[0.7418, 0.5448, 0.7615, 0.6333, 0.9255, 0.1129, 0.0552, 0.4198, 0.9953, 0.7482],
[0.9910, 0.7644, 0.7047, 0.1395, 0.3688, 0.7688, 0.8574, 0.3494, 0.6153, 0.1286],
[0.2325, 0.7908, 0.3036, 0.4504, 0.3775, 0.6004, 0.0199, 0.9581, 0.8078, 0.8337],
[0.4038, 0.8313, 0.5441, 0.4778, 0.5777, 0.0580, 0.5314, 0.5336, 0.0740, 0.0094],
[0.9025, 0.5814, 0.4711, 0.2683, 0.4443, 0.5799, 0.6703, 0.2678, 0.7538, 0.1317],
[0.6755, 0.5696, 0.3334, 0.9146, 0.6203, 0.2080, 0.0799, 0.0059, 0.8347, 0.1874],
[0.0932, 0.0264, 0.9006, 0.3124, 0.3421, 0.8271, 0.3495, 0.2814, 0.9888, 0.5042],
[0.4893, 0.3514, 0.2564, 0.8117, 0.3738, 0.9085, 0.3055, 0.1456, 0.3624, 0.4095],
[0.0726, 0.2145, 0.6295, 0.7423, 0.1292, 0.7570, 0.4645, 0.0775, 0.1280, 0.7312],
[0.8763, 0.5302, 0.8627, 0.0429, 0.2833, 0.4745, 0.6308, 0.2245, 0.2755, 0.6823],
[0.9997, 0.3519, 0.0312, 0.1468, 0.5145, 0.0286, 0.6333, 0.1323, 0.2264, 0.9109],
[0.7742, 0.4857, 0.0413, 0.4523, 0.6847, 0.5774, 0.9478, 0.5861, 0.9834, 0.9437],
[0.7590, 0.5697, 0.7509, 0.3562, 0.9926, 0.3380, 0.0337, 0.7871, 0.1351, 0.9184],
[0.5701, 0.0234, 0.8088, 0.0681, 0.7090, 0.5925, 0.5266, 0.7198, 0.4121, 0.0268],
[0.5377, 0.1420, 0.2649, 0.0885, 0.1987, 0.1475, 0.1562, 0.2283, 0.9447, 0.4679],
[0.0306, 0.9763, 0.1234, 0.5009, 0.8800, 0.9409, 0.3525, 0.7264, 0.2209, 0.1436],
[0.2492, 0.4041, 0.9044, 0.3730, 0.3152, 0.7515, 0.2614, 0.9726, 0.6402, 0.5211],
[0.8626, 0.2828, 0.6946, 0.7066, 0.4395, 0.3015, 0.2643, 0.4421, 0.6036, 0.9009],
[0.7721, 0.1706, 0.7043, 0.4097, 0.7685, 0.3818, 0.1468, 0.6452, 0.1102, 0.1826],
[0.7156, 0.1795, 0.5574, 0.9478, 0.0058, 0.8037, 0.8712, 0.7730, 0.5638, 0.5843],
[0.8775, 0.6133, 0.4118, 0.3038, 0.2612, 0.2424, 0.8960, 0.8194, 0.3588, 0.3198],
]
)
target = torch.tensor(
[
[0.0986, 0.4029, 0.3110, 0.9976, 0.5668, 0.2658, 0.0660, 0.8492, 0.7872, 0.6368],
[0.3556, 0.9007, 0.0227, 0.7684, 0.0105, 0.9890, 0.7468, 0.0642, 0.5164, 0.1976],
[0.1331, 0.0998, 0.0959, 0.5596, 0.5984, 0.3880, 0.8050, 0.8320, 0.8977, 0.3486],
[0.3297, 0.8110, 0.2844, 0.4594, 0.0739, 0.2865, 0.2957, 0.9357, 0.9898, 0.4419],
[0.0495, 0.2826, 0.8306, 0.2968, 0.5690, 0.7251, 0.5947, 0.7526, 0.5076, 0.6480],
[0.0381, 0.8645, 0.7774, 0.9158, 0.9682, 0.5851, 0.0913, 0.8948, 0.1251, 0.1205],
[0.9059, 0.2758, 0.1948, 0.2694, 0.0946, 0.4381, 0.4667, 0.2176, 0.3494, 0.6073],
[0.1778, 0.8632, 0.3015, 0.2882, 0.4214, 0.2420, 0.8394, 0.1468, 0.9679, 0.6730],
[0.2400, 0.4344, 0.9765, 0.6544, 0.6338, 0.3434, 0.4776, 0.7981, 0.2008, 0.2267],
[0.5574, 0.8110, 0.0264, 0.4199, 0.8178, 0.8421, 0.8237, 0.2623, 0.8025, 0.9030],
[0.8652, 0.2872, 0.9463, 0.5543, 0.4866, 0.2842, 0.6692, 0.2306, 0.3136, 0.4570],
[0.0651, 0.8955, 0.7531, 0.9373, 0.0265, 0.0795, 0.7755, 0.1123, 0.1920, 0.3273],
[0.9824, 0.4177, 0.2729, 0.9447, 0.3987, 0.5495, 0.3674, 0.8067, 0.8668, 0.2394],
[0.4874, 0.3616, 0.7577, 0.6439, 0.2927, 0.8110, 0.6821, 0.0702, 0.5514, 0.7358],
[0.3627, 0.6392, 0.9085, 0.3646, 0.6051, 0.0586, 0.8763, 0.3899, 0.3242, 0.4598],
[0.0167, 0.0558, 0.3862, 0.7017, 0.0403, 0.6604, 0.9992, 0.2337, 0.5128, 0.1959],
[0.7774, 0.9201, 0.0405, 0.7894, 0.1406, 0.2458, 0.2616, 0.8787, 0.8158, 0.8591],
[0.3225, 0.9827, 0.4032, 0.2621, 0.7949, 0.9796, 0.9480, 0.3353, 0.1430, 0.5747],
[0.4734, 0.8714, 0.9320, 0.4265, 0.7765, 0.6980, 0.1587, 0.8784, 0.7119, 0.5141],
[0.7263, 0.4754, 0.8234, 0.0649, 0.4343, 0.5201, 0.8274, 0.9632, 0.3525, 0.8893],
[0.3324, 0.0142, 0.7222, 0.5026, 0.6011, 0.9275, 0.9351, 0.9236, 0.2621, 0.0768],
[0.8456, 0.1005, 0.5550, 0.0586, 0.3811, 0.0168, 0.9724, 0.9225, 0.7242, 0.0678],
[0.2167, 0.5423, 0.9059, 0.3320, 0.4026, 0.2128, 0.4562, 0.3564, 0.2573, 0.1076],
[0.8385, 0.2233, 0.0736, 0.3407, 0.4702, 0.1668, 0.5174, 0.4154, 0.4407, 0.1843],
[0.1828, 0.5321, 0.6651, 0.4108, 0.5736, 0.4012, 0.0434, 0.0034, 0.9282, 0.3111],
[0.1754, 0.8750, 0.6629, 0.7052, 0.9739, 0.7441, 0.8954, 0.9273, 0.3836, 0.5735],
[0.5586, 0.0381, 0.1493, 0.8575, 0.9351, 0.5222, 0.5600, 0.2369, 0.9217, 0.2545],
[0.1054, 0.8020, 0.8463, 0.6495, 0.3011, 0.3734, 0.7263, 0.8736, 0.9258, 0.5804],
[0.7614, 0.4748, 0.6588, 0.7717, 0.9811, 0.1659, 0.7851, 0.2135, 0.1767, 0.6724],
[0.7655, 0.8571, 0.4224, 0.9397, 0.1363, 0.9431, 0.9326, 0.3762, 0.1077, 0.9514],
[0.4115, 0.2169, 0.1340, 0.6564, 0.9989, 0.8068, 0.0387, 0.5064, 0.9964, 0.9427],
[0.5760, 0.2967, 0.3891, 0.6596, 0.8037, 0.1060, 0.0102, 0.8672, 0.5922, 0.6684],
]
)
loss = qn.quantile_l1_loss(prediction, target)
self.assertAlmostEqual(loss.item(), 0.17235948)
def test_quantile_huber_loss(self):
qn = QuantileNetwork(10, 1, quantile_count=5)
prediction = torch.tensor(
[
[0.8510, 0.2329, 0.4244, 0.5241, 0.2144],
[0.7693, 0.2522, 0.3909, 0.0858, 0.7914],
[0.8701, 0.2144, 0.9661, 0.9975, 0.5043],
[0.2653, 0.6951, 0.9787, 0.2244, 0.0430],
[0.7907, 0.5209, 0.7276, 0.1735, 0.2757],
[0.1696, 0.7167, 0.6363, 0.2188, 0.7025],
[0.0445, 0.6008, 0.5334, 0.1838, 0.7387],
[0.4934, 0.5117, 0.4488, 0.0591, 0.6442],
]
)
target = torch.tensor(
[
[0.3918, 0.8979, 0.4347, 0.1076, 0.5303],
[0.5449, 0.9974, 0.3197, 0.8686, 0.0631],
[0.7397, 0.7734, 0.6559, 0.3020, 0.7229],
[0.9519, 0.8138, 0.1502, 0.3445, 0.3356],
[0.8970, 0.0910, 0.7536, 0.6069, 0.2556],
[0.1741, 0.6863, 0.7142, 0.2911, 0.3142],
[0.8835, 0.0215, 0.4774, 0.5362, 0.4998],
[0.8037, 0.8269, 0.5518, 0.4368, 0.5323],
]
)
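        # The Huber variant is expected to replace the absolute error of the pinball loss with a Huber penalty
        # (as in QR-DQN), which is why the loss is smaller than the plain L1 value above for the same data.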
loss = qn.quantile_huber_loss(prediction, target)
self.assertAlmostEqual(loss.item(), 0.04035041)
def test_sample_energy_loss(self):
qn = QuantileNetwork(10, 1, quantile_count=5)
prediction = torch.tensor(
[
[0.9813, 0.5331, 0.3298, 0.2428, 0.0737],
[0.5442, 0.9623, 0.6070, 0.9360, 0.1145],
[0.3642, 0.0887, 0.1696, 0.8027, 0.7121],
[0.2005, 0.9889, 0.4350, 0.0301, 0.4546],
[0.8360, 0.6766, 0.2257, 0.7589, 0.3443],
[0.0835, 0.1747, 0.1734, 0.6668, 0.4522],
[0.0851, 0.3146, 0.0316, 0.2250, 0.5729],
[0.7725, 0.4596, 0.2495, 0.3633, 0.6340],
]
)
target = torch.tensor(
[
[0.5365, 0.1495, 0.8120, 0.2595, 0.1409],
[0.7784, 0.7070, 0.9066, 0.0123, 0.5587],
[0.9097, 0.0773, 0.9430, 0.2747, 0.1912],
[0.2307, 0.5068, 0.4624, 0.6708, 0.2844],
[0.3356, 0.5885, 0.2484, 0.8468, 0.1833],
[0.3354, 0.8831, 0.3489, 0.7165, 0.7953],
[0.7577, 0.8578, 0.2735, 0.1029, 0.5621],
[0.9124, 0.3476, 0.2012, 0.5830, 0.4615],
]
)
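        # sample_energy_loss is assumed to estimate the energy distance between the predicted and target samples,
        # i.e. 2*E|X - Y| - E|X - X'| - E|Y - Y'| computed from the quantile samples.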
loss = qn.sample_energy_loss(prediction, target)
self.assertAlmostEqual(loss.item(), 0.09165202)
def test_cvar(self):
qn = QuantileNetwork(10, 1, quantile_count=5)
measure = qn.measures[qn.measure_cvar](qn, 0.5)
# Quantiles for 3 agents
input = torch.tensor(
[
[0.1056, 0.0609, 0.3523, 0.3033, 0.1779],
[0.2049, 0.1425, 0.0767, 0.1868, 0.3891],
[0.1899, 0.1527, 0.2420, 0.2623, 0.1532],
]
)
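        # With 5 quantiles, each carries probability mass 0.2. CVaR at level 0.5 averages the lowest half of the
        # mass: full weight (0.2 / 0.5 = 0.4) on the two lowest quantiles and half weight (0.1 / 0.5 = 0.2) on the
        # third, matching the expected output below.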
correct_output = torch.tensor(
[
(0.4 * 0.0609 + 0.4 * 0.1056 + 0.2 * 0.1779),
(0.4 * 0.0767 + 0.4 * 0.1425 + 0.2 * 0.1868),
(0.4 * 0.1527 + 0.4 * 0.1532 + 0.2 * 0.1899),
]
)
computed_output = measure(input)
self.assertTrue(torch.isclose(computed_output, correct_output).all())
def test_cvar_adaptive(self):
qn = QuantileNetwork(10, 1, quantile_count=5)
input = torch.tensor(
[
[0.95, 0.21, 0.27, 0.26, 0.19],
[0.38, 0.34, 0.18, 0.32, 0.97],
[0.70, 0.24, 0.38, 0.89, 0.96],
]
)
confidence_levels = torch.tensor([0.1, 0.7, 0.9])
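        # One confidence level per sample: CVaR at level alpha averages the lowest alpha probability mass, so e.g.
        # alpha=0.7 takes the three lowest quantiles at weight 0.2 / 0.7 = 1 / 3.5 each plus half of the fourth at
        # 0.1 / 0.7 = 1 / 7, while alpha=0.1 reduces to the lowest quantile alone.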
correct_output = torch.tensor(
[
0.19,
(0.18 / 3.5 + 0.32 / 3.5 + 0.34 / 3.5 + 0.38 / 7.0),
(0.24 / 4.5 + 0.38 / 4.5 + 0.70 / 4.5 + 0.89 / 4.5 + 0.96 / 9.0),
]
)
measure = qn.measures[qn.measure_cvar](qn, confidence_levels)
computed_output = measure(input)
self.assertTrue(torch.isclose(computed_output, correct_output).all())


@ -0,0 +1,33 @@
import torch
import unittest
from rsl_rl.utils.recurrency import trajectories_to_transitions, transitions_to_trajectories
class TrajectoryConversionTest(unittest.TestCase):
def test_basic_conversion(self):
input = torch.rand(128, 24)
dones = (torch.rand(128, 24) > 0.8).float()
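        # transitions_to_trajectories splits the (steps, envs) batch into per-episode trajectories at the done
        # flags; converting back should reproduce the original transitions exactly.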
trajectories, data = transitions_to_trajectories(input, dones)
transitions = trajectories_to_transitions(trajectories, data)
self.assertTrue(torch.allclose(input, transitions))
def test_2d_observations(self):
input = torch.rand(128, 24, 32)
dones = (torch.rand(128, 24) > 0.8).float()
trajectories, data = transitions_to_trajectories(input, dones)
transitions = trajectories_to_transitions(trajectories, data)
self.assertTrue(torch.allclose(input, transitions))
def test_batch_first(self):
input = torch.rand(128, 24, 32)
dones = (torch.rand(128, 24) > 0.8).float()
trajectories, data = transitions_to_trajectories(input, dones, batch_first=True)
transitions = trajectories_to_transitions(trajectories, data)
self.assertTrue(torch.allclose(input, transitions))

58
tests/test_transformer.py Normal file

@ -0,0 +1,58 @@
import unittest
import torch
from rsl_rl.modules import Transformer
class TestTransformer(unittest.TestCase):
def setUp(self):
self.input_size = 9
self.output_size = 12
self.hidden_size = 64
self.block_count = 2
self.context_length = 32
self.head_count = 4
self.batch_size = 10
self.sequence_length = 16
self.transformer = Transformer(
self.input_size, self.output_size, self.hidden_size, self.block_count, self.context_length, self.head_count
)
def test_num_layers(self):
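        # The transformer exposes its context as an LSTM-style (h, c) pair, so each half presumably holds
        # context_length // 2 entries.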
self.assertEqual(self.transformer.num_layers, self.context_length // 2)
def test_reset_hidden_state(self):
hidden_state = self.transformer.reset_hidden_state(self.batch_size)
self.assertIsInstance(hidden_state, tuple)
self.assertEqual(len(hidden_state), 2)
self.assertTrue(
torch.equal(hidden_state[0], torch.zeros((self.transformer.num_layers, self.batch_size, self.hidden_size)))
)
self.assertTrue(
torch.equal(hidden_state[1], torch.zeros((self.transformer.num_layers, self.batch_size, self.hidden_size)))
)
def test_step(self):
x = torch.rand(self.sequence_length, self.batch_size, self.input_size)
context = torch.rand(self.context_length, self.batch_size, self.hidden_size)
out, new_context = self.transformer.step(x, context)
self.assertEqual(out.shape, (self.sequence_length, self.batch_size, self.output_size))
self.assertEqual(new_context.shape, (self.context_length, self.batch_size, self.hidden_size))
def test_forward(self):
x = torch.rand(self.sequence_length, self.batch_size, self.input_size)
hidden_state = self.transformer.reset_hidden_state(self.batch_size)
out, new_hidden_state = self.transformer.forward(x, hidden_state)
self.assertEqual(out.shape, (self.sequence_length, self.batch_size, self.output_size))
self.assertEqual(len(new_hidden_state), 2)
self.assertEqual(new_hidden_state[0].shape, (self.transformer.num_layers, self.batch_size, self.hidden_size))
self.assertEqual(new_hidden_state[1].shape, (self.transformer.num_layers, self.batch_size, self.hidden_size))
if __name__ == "__main__":
unittest.main()