diff --git a/.gitignore b/.gitignore
index 6009f71..e47297e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,17 @@
# cache
__pycache__
.pytest_cache
+wandb/
# vs code
.vscode
+
+# data
+videos/
+
+# secrets
+examples/wandb_config.py
+
+# docs
+docs/_build
+docs/source
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index e3153d9..54f9094 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -25,6 +25,7 @@ Please keep the lists sorted alphabetically.
* Eric Vollenweider
* Fabian Jenelten
* Lorenzo Terenzi
+* Lukas Schneider
* Marko Bjelonic
* Matthijs van der Boon
* Mayank Mittal
diff --git a/README.md b/README.md
index 6b5c2d8..734376a 100644
--- a/README.md
+++ b/README.md
@@ -1,57 +1,64 @@
# RSL RL
Fast and simple implementation of RL algorithms, designed to run fully on GPU.
-This code is an evolution of `rl-pytorch` provided with NVIDIA's Isaac GYM.
-Only PPO is implemented for now. More algorithms will be added later.
-Contributions are welcome.
+Currently, the following algorithms are implemented:
+- Distributed Distributional DDPG (D4PG)
+- Deep Deterministic Policy Gradient (DDPG)
+- Distributional PPO (DPPO)
+- Distributional Soft Actor Critic (DSAC)
+- Proximal Policy Optimization (PPO)
+- Soft Actor Critic (SAC)
+- Twin Delayed DDPG (TD3)
-**Maintainer**: David Hoeller and Nikita Rudin
+**Maintainers**: David Hoeller, Nikita Rudin
**Affiliation**: Robotic Systems Lab, ETH Zurich & NVIDIA
**Contact**: rudinn@ethz.ch
-## Setup
+## Installation
-Following are the instructions to setup the repository for your workspace:
+To install the package, run the following command in the root directory of the repository:
```bash
-git clone https://github.com/leggedrobotics/rsl_rl
-cd rsl_rl
-pip install -e .
+$ pip3 install -e .
```
-The framework supports the following logging frameworks which can be configured through `logger`:
+Examples can be run from the `examples/` directory.
+The example directory also includes hyperparameters tuned for some Gym environments.
+These are automatically loaded when running the example.
+Videos of the trained policies are periodically saved to the `videos/` directory.
-* Tensorboard: https://www.tensorflow.org/tensorboard/
-* Weights & Biases: https://wandb.ai/site
-* Neptune: https://docs.neptune.ai/
+```bash
+$ python3 examples/example.py
+```
-For a demo configuration of the PPO, please check: [dummy_config.yaml](config/dummy_config.yaml) file.
+To run Gym MuJoCo environments, you need a working installation of the MuJoCo simulator and [mujoco_py](https://github.com/openai/mujoco-py).
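+
+Under the hood, the example script builds a `GymEnv`, an agent, and a `Runner`. Below is a minimal sketch of this training loop, condensed from `examples/example.py` (run it from the `examples/` directory so the bundled hyperparameters can be imported):
+
+```python
+import torch
+
+from rsl_rl.algorithms import PPO
+from rsl_rl.env.gym_env import GymEnv
+from rsl_rl.runners.runner import Runner
+
+from hyperparams import hyperparams  # tuned settings shipped with the examples
+
+DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
+TASK = "BipedalWalker-v3"
+
+hp = hyperparams[PPO.__name__][TASK]
+env = GymEnv(name=TASK, device=DEVICE, **hp["env_kwargs"])
+agent = PPO(env, device=DEVICE, **hp["agent_kwargs"])
+runner = Runner(env, agent, device=DEVICE, **hp["runner_kwargs"])
+
+runner.learn(5000)  # number of learning iterations
+```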
+## Tests
+
+The repository contains a set of tests to ensure that the algorithms are working as expected.
+To run the tests, simply execute:
+
+```bash
+$ cd tests/ && python -m unittest
+```
+
+## Documentation
+
+To generate documentation, run the following commands in the root directory of the repository:
+
+```bash
+$ pip3 install sphinx sphinx-rtd-theme
+$ sphinx-apidoc -o docs/source . ./examples
+$ cd docs/ && make html
+```
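+
+The generated HTML documentation can then be found in `docs/_build/html/`.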
## Contribution Guidelines
-For documentation, we adopt the [Google Style Guide](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) for docstrings. We use [Sphinx](https://www.sphinx-doc.org/en/master/) for generating the documentation. Please make sure that your code is well-documented and follows the guidelines.
-
-We use the following tools for maintaining code quality:
-
-- [pre-commit](https://pre-commit.com/): Runs a list of formatters and linters over the codebase.
-- [black](https://black.readthedocs.io/en/stable/): The uncompromising code formatter.
-- [flake8](https://flake8.pycqa.org/en/latest/): A wrapper around PyFlakes, pycodestyle, and McCabe complexity checker.
-
-Please check [here](https://pre-commit.com/#install) for instructions to set these up. To run over the entire repository, please execute the following command in the terminal:
-
+We use the [`black`](https://github.com/psf/black) formatter to format the Python code.
+You can [configure `black` for VSCode](https://dev.to/adamlombard/how-to-use-the-black-python-code-formatter-in-vscode-3lo0) or format files manually with:
```bash
-# for installation (only once)
-pre-commit install
-# for running
-pre-commit run --all-files
+$ pip install black
+$ black --line-length 120 .
```
-
-### Useful Links
-
-Environment repositories using the framework:
-
-* `Legged-Gym` (built on top of NVIDIA Isaac Gym): https://leggedrobotics.github.io/legged_gym/
-* `Orbit` (built on top of NVIDIA Isaac Sim): https://isaac-orbit.github.io/
diff --git a/config/dummy_config.yaml b/config/dummy_config.yaml
deleted file mode 100644
index aaf5d21..0000000
--- a/config/dummy_config.yaml
+++ /dev/null
@@ -1,48 +0,0 @@
-algorithm:
- class_name: PPO
- # training parameters
- # -- value function
- value_loss_coef: 1.0
- clip_param: 0.2
- use_clipped_value_loss: true
- # -- surrogate loss
- desired_kl: 0.01
- entropy_coef: 0.01
- gamma: 0.99
- lam: 0.95
- max_grad_norm: 1.0
- # -- training
- learning_rate: 0.001
- num_learning_epochs: 5
- num_mini_batches: 4 # mini batch size = num_envs * num_steps / num_mini_batches
- schedule: adaptive # adaptive, fixed
-policy:
- class_name: ActorCritic
- # for MLP i.e. `ActorCritic`
- activation: elu
- actor_hidden_dims: [128, 128, 128]
- critic_hidden_dims: [128, 128, 128]
- init_noise_std: 1.0
- # only needed for `ActorCriticRecurrent`
- # rnn_type: 'lstm'
- # rnn_hidden_size: 512
- # rnn_num_layers: 1
-runner:
- num_steps_per_env: 24 # number of steps per environment per iteration
- max_iterations: 1500 # number of policy updates
- empirical_normalization: false
- # -- logging parameters
- save_interval: 50 # check for potential saves every `save_interval` iterations
- experiment_name: walking_experiment
- run_name: ""
- # -- logging writer
- logger: tensorboard # tensorboard, neptune, wandb
- neptune_project: legged_gym
- wandb_project: legged_gym
- # -- load and resuming
- resume: false
- load_run: -1 # -1 means load latest run
- resume_path: null # updated from load_run and checkpoint
- checkpoint: -1 # -1 means load latest checkpoint
-runner_class_name: OnPolicyRunner
-seed: 1
diff --git a/docs/Makefile b/docs/Makefile
new file mode 100644
index 0000000..d4bb2cb
--- /dev/null
+++ b/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = .
+BUILDDIR = _build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/conf.py b/docs/conf.py
new file mode 100644
index 0000000..4dbe906
--- /dev/null
+++ b/docs/conf.py
@@ -0,0 +1,32 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# For the full list of built-in configuration values, see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+import os
+import sys
+
+sys.path.insert(0, os.path.abspath(".."))
+
+# -- Project information -----------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
+
+project = "rsl_rl"
+copyright = "2023, Lukas Schneider"
+author = "Lukas Schneider"
+release = "1.0.0"
+
+# -- General configuration ---------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
+
+extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon"]
+
+templates_path = ["_templates"]
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
+
+
+# -- Options for HTML output -------------------------------------------------
+# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
+
+html_theme = "sphinx_rtd_theme"
+html_static_path = ["_static"]
diff --git a/docs/index.rst b/docs/index.rst
new file mode 100644
index 0000000..4af7326
--- /dev/null
+++ b/docs/index.rst
@@ -0,0 +1,20 @@
+.. rsl_rl documentation master file, created by
+ sphinx-quickstart on Tue Jul 4 16:39:24 2023.
+ You can adapt this file completely to your liking, but it should at least
+ contain the root `toctree` directive.
+
+Welcome to rsl_rl's documentation!
+==================================
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Contents:
+
+
+
+Indices and tables
+==================
+
+* :ref:`genindex`
+* :ref:`modindex`
+* :ref:`search`
diff --git a/docs/make.bat b/docs/make.bat
new file mode 100644
index 0000000..32bb245
--- /dev/null
+++ b/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=.
+set BUILDDIR=_build
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.https://www.sphinx-doc.org/
+ exit /b 1
+)
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/examples/__init__.py b/examples/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/examples/benchmark.py b/examples/benchmark.py
new file mode 100644
index 0000000..3bba830
--- /dev/null
+++ b/examples/benchmark.py
@@ -0,0 +1,92 @@
+import numpy as np
+import os
+import torch
+import wandb
+
+from rsl_rl.algorithms import *
+from rsl_rl.env.gym_env import GymEnv
+from rsl_rl.runners.runner import Runner
+from rsl_rl.runners.callbacks import make_wandb_cb
+
+from hyperparams import hyperparams
+from wandb_config import WANDB_API_KEY, WANDB_ENTITY
+
+
+ALGORITHMS = [PPO, DPPO]
+ENVIRONMENTS = ["BipedalWalker-v3"]
+ENVIRONMENT_KWARGS = [{}]
+EXPERIMENT_DIR = os.environ.get("EXPERIMENT_DIRECTORY", "./")
+EXPERIMENT_NAME = os.environ.get("EXPERIMENT_NAME", "benchmark")
+DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
+RENDER_VIDEO = False
+RETURN_EPOCHS = 100 # Number of epochs to average return over
+LOG_WANDB = True
+RUNS = 3
+TRAIN_TIMEOUT = 60 * 10 # Training time (in seconds)
+TRAIN_ENV_STEPS = None # Number of training environment steps
+
+
+os.environ["WANDB_API_KEY"] = WANDB_API_KEY
+
+
+def run(alg_class, env_name, env_kwargs={}):
+ try:
+ hp = hyperparams[alg_class.__name__][env_name]
+ except KeyError:
+ print("No hyperparameters found. Using default values.")
+ hp = dict(agent_kwargs={}, env_kwargs={"environment_count": 1}, runner_kwargs={"num_steps_per_env": 1})
+
+ agent_kwargs = dict(device=DEVICE, **hp["agent_kwargs"])
+ env_kwargs = dict(name=env_name, gym_kwargs=env_kwargs, **hp["env_kwargs"])
+ runner_kwargs = dict(device=DEVICE, **hp["runner_kwargs"])
+
+ learn_steps = (
+ None
+ if TRAIN_ENV_STEPS is None
+ else int(np.ceil(TRAIN_ENV_STEPS / (env_kwargs["environment_count"] * runner_kwargs["num_steps_per_env"])))
+ )
+ learn_timeout = None if TRAIN_TIMEOUT is None else TRAIN_TIMEOUT
+
+ video_directory = f"{EXPERIMENT_DIR}/{EXPERIMENT_NAME}/videos/{env_name}/{alg_class.__name__}"
+ save_video_cb = (
+ lambda ep, file: wandb.log({f"video-{ep}": wandb.Video(file, fps=4, format="mp4")}) if LOG_WANDB else None
+ )
+ env = GymEnv(**env_kwargs, draw=RENDER_VIDEO, draw_cb=save_video_cb, draw_directory=video_directory)
+ agent = alg_class(env, **agent_kwargs)
+
+ config = dict(
+ agent_kwargs=agent_kwargs,
+ env_kwargs=env_kwargs,
+ learn_steps=learn_steps,
+ learn_timeout=learn_timeout,
+ runner_kwargs=runner_kwargs,
+ )
+ wandb_learn_config = dict(
+ config=config,
+ entity=WANDB_ENTITY,
+ group=f"{alg_class.__name__}_{env_name}",
+ project="rsl_rl-benchmark",
+ tags=[alg_class.__name__, env_name, "train"],
+ )
+
+ runner = Runner(env, agent, **runner_kwargs)
+ runner._learn_cb = [lambda *args, **kwargs: Runner._log(*args, prefix=f"{alg_class.__name__}_{env_name}", **kwargs)]
+ if LOG_WANDB:
+ runner._learn_cb.append(make_wandb_cb(wandb_learn_config))
+
+ runner.learn(iterations=learn_steps, timeout=learn_timeout, return_epochs=RETURN_EPOCHS)
+
+ env.close()
+
+
+def main():
+ for algorithm in ALGORITHMS:
+ for i, env_name in enumerate(ENVIRONMENTS):
+ env_kwargs = ENVIRONMENT_KWARGS[i]
+
+ for _ in range(RUNS):
+ run(algorithm, env_name, env_kwargs=env_kwargs)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/example.py b/examples/example.py
new file mode 100644
index 0000000..0392d89
--- /dev/null
+++ b/examples/example.py
@@ -0,0 +1,26 @@
+import torch
+
+from rsl_rl.algorithms import *
+from rsl_rl.env.gym_env import GymEnv
+from rsl_rl.runners.runner import Runner
+from hyperparams import hyperparams
+
+
+ALGORITHM = DPPO
+DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
+TASK = "BipedalWalker-v3"
+
+
+def main():
+ hp = hyperparams[ALGORITHM.__name__][TASK]
+
+ env = GymEnv(name=TASK, device=DEVICE, draw=True, **hp["env_kwargs"])
+ agent = ALGORITHM(env, benchmark=True, device=DEVICE, **hp["agent_kwargs"])
+ runner = Runner(env, agent, device=DEVICE, **hp["runner_kwargs"])
+ runner._learn_cb = [Runner._log]
+
+ runner.learn(5000)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/hyperparams/__init__.py b/examples/hyperparams/__init__.py
new file mode 100644
index 0000000..8d9fb1e
--- /dev/null
+++ b/examples/hyperparams/__init__.py
@@ -0,0 +1,7 @@
+from rsl_rl.algorithms import PPO, DPPO
+from .dppo import dppo_hyperparams
+from .ppo import ppo_hyperparams
+
+hyperparams = {DPPO.__name__: dppo_hyperparams, PPO.__name__: ppo_hyperparams}
+
+__all__ = ["hyperparams"]
diff --git a/examples/hyperparams/dppo.py b/examples/hyperparams/dppo.py
new file mode 100644
index 0000000..5f8ecad
--- /dev/null
+++ b/examples/hyperparams/dppo.py
@@ -0,0 +1,202 @@
+import copy
+import numpy as np
+
+from rsl_rl.algorithms import DPPO
+from rsl_rl.modules import QuantileNetwork
+
+default = dict()
+default["env_kwargs"] = dict(environment_count=1)
+default["runner_kwargs"] = dict(num_steps_per_env=2048)
+default["agent_kwargs"] = dict(
+ actor_activations=["tanh", "tanh", "linear"],
+ actor_hidden_dims=[64, 64],
+ actor_input_normalization=False,
+ actor_noise_std=np.exp(0.0),
+ batch_count=(default["env_kwargs"]["environment_count"] * default["runner_kwargs"]["num_steps_per_env"] // 64),
+ clip_ratio=0.2,
+ critic_activations=["tanh", "tanh"],
+ critic_hidden_dims=[64, 64],
+ critic_input_normalization=False,
+ entropy_coeff=0.0,
+ gae_lambda=0.95,
+ gamma=0.99,
+ gradient_clip=0.5,
+ learning_rate=0.0003,
+ qrdqn_quantile_count=50,
+ schedule="adaptive",
+ target_kl=0.01,
+ value_coeff=0.5,
+ value_measure=QuantileNetwork.measure_neutral,
+ value_measure_kwargs={},
+)
+
+# Parameters optimized for PPO
+ant_v4 = copy.deepcopy(default)
+ant_v4["env_kwargs"]["environment_count"] = 128
+ant_v4["runner_kwargs"]["num_steps_per_env"] = 64
+ant_v4["agent_kwargs"]["actor_activations"] = ["tanh", "tanh", "linear"]
+ant_v4["agent_kwargs"]["actor_hidden_dims"] = [64, 64]
+ant_v4["agent_kwargs"]["actor_noise_std"] = 0.2611
+ant_v4["agent_kwargs"]["batch_count"] = 12
+ant_v4["agent_kwargs"]["clip_ratio"] = 0.4
+ant_v4["agent_kwargs"]["critic_activations"] = ["tanh", "tanh"]
+ant_v4["agent_kwargs"]["critic_hidden_dims"] = [64, 64]
+ant_v4["agent_kwargs"]["entropy_coeff"] = 0.0102
+ant_v4["agent_kwargs"]["gae_lambda"] = 0.92
+ant_v4["agent_kwargs"]["gamma"] = 0.9731
+ant_v4["agent_kwargs"]["gradient_clip"] = 5.0
+ant_v4["agent_kwargs"]["learning_rate"] = 0.8755
+ant_v4["agent_kwargs"]["target_kl"] = 0.1711
+ant_v4["agent_kwargs"]["value_coeff"] = 0.6840
+
+"""
+Tuned for environment interactions:
+[I 2023-01-03 03:11:29,212] Trial 19 finished with value: 0.5272218152693111 and parameters: {
+ 'env_count': 16,
+ 'actor_noise_std': 0.7304437880901905,
+ 'batch_count': 10,
+ 'clip_ratio': 0.3,
+ 'entropy_coeff': 0.004236574285220795,
+ 'gae_lambda': 0.95,
+ 'gamma': 0.9890074826092162,
+ 'gradient_clip': 0.9,
+ 'learning_rate': 0.18594043324129061,
+ 'steps_per_env': 256,
+ 'target_kl': 0.05838576142010138,
+ 'value_coeff': 0.14402022632575992,
+ 'net_arch': 'small',
+ 'net_activation': 'relu'
+}. Best is trial 19 with value: 0.5272218152693111.
+Tuned for training time:
+[I 2023-01-08 21:09:06,069] Trial 407 finished with value: 7.497591958940029 and parameters: {
+ 'actor_noise_std': 0.1907398121300662,
+ 'batch_count': 3,
+ 'clip_ratio': 0.1,
+ 'entropy_coeff': 0.0053458057035692735,
+ 'env_count': 16,
+ 'gae_lambda': 0.8,
+ 'gamma': 0.985000267068182,
+ 'gradient_clip': 2.0,
+ 'learning_rate': 0.605956844400053,
+ 'steps_per_env': 512,
+ 'target_kl': 0.17611450607281642,
+ 'value_coeff': 0.46015664905111847,
+ 'actor_net_arch': 'small',
+ 'critic_net_arch': 'medium',
+ 'actor_net_activation': 'relu',
+ 'critic_net_activation': 'relu',
+ 'qrdqn_quantile_count': 200,
+ 'value_measure': 'neutral'
+}. Best is trial 407 with value: 7.497591958940029.
+"""
+bipedal_walker_v3 = copy.deepcopy(default)
+bipedal_walker_v3["env_kwargs"]["environment_count"] = 256
+bipedal_walker_v3["runner_kwargs"]["num_steps_per_env"] = 16
+bipedal_walker_v3["agent_kwargs"]["actor_activations"] = ["relu", "relu", "relu", "linear"]
+bipedal_walker_v3["agent_kwargs"]["actor_hidden_dims"] = [512, 256, 128]
+bipedal_walker_v3["agent_kwargs"]["actor_noise_std"] = 0.8505
+bipedal_walker_v3["agent_kwargs"]["batch_count"] = 10
+bipedal_walker_v3["agent_kwargs"]["clip_ratio"] = 0.1
+bipedal_walker_v3["agent_kwargs"]["critic_activations"] = ["relu", "relu"]
+bipedal_walker_v3["agent_kwargs"]["critic_hidden_dims"] = [256, 256]
+bipedal_walker_v3["agent_kwargs"]["critic_network"] = DPPO.network_qrdqn
+bipedal_walker_v3["agent_kwargs"]["entropy_coeff"] = 0.0917
+bipedal_walker_v3["agent_kwargs"]["gae_lambda"] = 0.95
+bipedal_walker_v3["agent_kwargs"]["gamma"] = 0.9553
+bipedal_walker_v3["agent_kwargs"]["gradient_clip"] = 2.0
+bipedal_walker_v3["agent_kwargs"]["iqn_action_samples"] = 32
+bipedal_walker_v3["agent_kwargs"]["iqn_embedding_size"] = 64
+bipedal_walker_v3["agent_kwargs"]["iqn_feature_layers"] = 1
+bipedal_walker_v3["agent_kwargs"]["iqn_value_samples"] = 8
+bipedal_walker_v3["agent_kwargs"]["learning_rate"] = 0.4762
+bipedal_walker_v3["agent_kwargs"]["qrdqn_quantile_count"] = 200
+bipedal_walker_v3["agent_kwargs"]["recurrent"] = False
+bipedal_walker_v3["agent_kwargs"]["target_kl"] = 0.1999
+bipedal_walker_v3["agent_kwargs"]["value_coeff"] = 0.4435
+
+"""
+[I 2023-01-12 08:01:35,514] Trial 476 finished with value: 5202.960759290059 and parameters: {
+ 'actor_noise_std': 0.15412869066185989,
+ 'batch_count': 11,
+ 'clip_ratio': 0.3,
+ 'entropy_coeff': 0.036031209302206955,
+ 'env_count': 128,
+ 'gae_lambda': 0.92,
+ 'gamma': 0.973937576989299,
+ 'gradient_clip': 5.0,
+ 'learning_rate': 0.1621249118505433,
+ 'steps_per_env': 128,
+ 'target_kl': 0.05054738172852222,
+ 'value_coeff': 0.1647632125820593,
+ 'actor_net_arch': 'small',
+ 'critic_net_arch': 'medium',
+ 'actor_net_activation': 'tanh',
+ 'critic_net_activation': 'relu',
+ 'qrdqn_quantile_count': 50,
+ 'value_measure': 'var-risk-averse'
+}. Best is trial 476 with value: 5202.960759290059.
+"""
+half_cheetah_v4 = copy.deepcopy(default)
+half_cheetah_v4["env_kwargs"]["environment_count"] = 128
+half_cheetah_v4["runner_kwargs"]["num_steps_per_env"] = 128
+half_cheetah_v4["agent_kwargs"]["actor_activations"] = ["tanh", "tanh", "linear"]
+half_cheetah_v4["agent_kwargs"]["actor_hidden_dims"] = [64, 64]
+half_cheetah_v4["agent_kwargs"]["actor_noise_std"] = 0.1541
+half_cheetah_v4["agent_kwargs"]["batch_count"] = 11
+half_cheetah_v4["agent_kwargs"]["clip_ratio"] = 0.3
+half_cheetah_v4["agent_kwargs"]["critic_activations"] = ["relu", "relu"]
+half_cheetah_v4["agent_kwargs"]["critic_hidden_dims"] = [256, 256]
+half_cheetah_v4["agent_kwargs"]["entropy_coeff"] = 0.03603
+half_cheetah_v4["agent_kwargs"]["gae_lambda"] = 0.92
+half_cheetah_v4["agent_kwargs"]["gamma"] = 0.9739
+half_cheetah_v4["agent_kwargs"]["gradient_clip"] = 5.0
+half_cheetah_v4["agent_kwargs"]["learning_rate"] = 0.1621
+half_cheetah_v4["agent_kwargs"]["qrdqn_quantile_count"] = 50
+half_cheetah_v4["agent_kwargs"]["target_kl"] = 0.0505
+half_cheetah_v4["agent_kwargs"]["value_coeff"] = 0.1648
+half_cheetah_v4["agent_kwargs"]["value_measure"] = QuantileNetwork.measure_percentile
+half_cheetah_v4["agent_kwargs"]["value_measure_kwargs"] = dict(confidence_level=0.25)
+
+
+# Parameters optimized for PPO
+hopper_v4 = copy.deepcopy(default)
+hopper_v4["runner_kwargs"]["num_steps_per_env"] = 128
+hopper_v4["agent_kwargs"]["actor_activations"] = ["relu", "relu", "linear"]
+hopper_v4["agent_kwargs"]["actor_hidden_dims"] = [256, 256]
+hopper_v4["agent_kwargs"]["actor_noise_std"] = 0.5590
+hopper_v4["agent_kwargs"]["batch_count"] = 15
+hopper_v4["agent_kwargs"]["clip_ratio"] = 0.2
+hopper_v4["agent_kwargs"]["critic_activations"] = ["relu", "relu", "linear"]
+hopper_v4["agent_kwargs"]["critic_hidden_dims"] = [32, 32]
+hopper_v4["agent_kwargs"]["entropy_coeff"] = 0.03874
+hopper_v4["agent_kwargs"]["gae_lambda"] = 0.98
+hopper_v4["agent_kwargs"]["gamma"] = 0.9890
+hopper_v4["agent_kwargs"]["gradient_clip"] = 1.0
+hopper_v4["agent_kwargs"]["learning_rate"] = 0.3732
+hopper_v4["agent_kwargs"]["value_coeff"] = 0.8163
+
+swimmer_v4 = copy.deepcopy(default)
+swimmer_v4["agent_kwargs"]["gamma"] = 0.9999
+
+walker2d_v4 = copy.deepcopy(default)
+walker2d_v4["runner_kwargs"]["num_steps_per_env"] = 512
+walker2d_v4["agent_kwargs"]["batch_count"] = (
+ walker2d_v4["env_kwargs"]["environment_count"] * walker2d_v4["runner_kwargs"]["num_steps_per_env"] // 32
+)
+walker2d_v4["agent_kwargs"]["clip_ratio"] = 0.1
+walker2d_v4["agent_kwargs"]["entropy_coeff"] = 0.000585045
+walker2d_v4["agent_kwargs"]["gae_lambda"] = 0.95
+walker2d_v4["agent_kwargs"]["gamma"] = 0.99
+walker2d_v4["agent_kwargs"]["gradient_clip"] = 1.0
+walker2d_v4["agent_kwargs"]["learning_rate"] = 5.05041e-05
+walker2d_v4["agent_kwargs"]["value_coeff"] = 0.871923
+
+dppo_hyperparams = {
+ "default": default,
+ "Ant-v4": ant_v4,
+ "BipedalWalker-v3": bipedal_walker_v3,
+ "HalfCheetah-v4": half_cheetah_v4,
+ "Hopper-v4": hopper_v4,
+ "Swimmer-v4": swimmer_v4,
+ "Walker2d-v4": walker2d_v4,
+}
diff --git a/examples/hyperparams/ppo.py b/examples/hyperparams/ppo.py
new file mode 100644
index 0000000..d45dd8e
--- /dev/null
+++ b/examples/hyperparams/ppo.py
@@ -0,0 +1,226 @@
+import copy
+import numpy as np
+
+default = dict()
+default["env_kwargs"] = dict(environment_count=1)
+default["runner_kwargs"] = dict(num_steps_per_env=2048)
+default["agent_kwargs"] = dict(
+ actor_activations=["tanh", "tanh", "linear"],
+ actor_hidden_dims=[64, 64],
+ actor_input_normalization=False,
+ actor_noise_std=np.exp(0.0),
+ batch_count=(default["env_kwargs"]["environment_count"] * default["runner_kwargs"]["num_steps_per_env"] // 64),
+ clip_ratio=0.2,
+ critic_activations=["tanh", "tanh", "linear"],
+ critic_hidden_dims=[64, 64],
+ critic_input_normalization=False,
+ entropy_coeff=0.0,
+ gae_lambda=0.95,
+ gamma=0.99,
+ gradient_clip=0.5,
+ learning_rate=0.0003,
+ schedule="adaptive",
+ target_kl=0.01,
+ value_coeff=0.5,
+)
+
+"""
+[I 2023-01-09 00:33:02,217] Trial 85 finished with value: 2191.0249068421276 and parameters: {
+ 'actor_noise_std': 0.2611334861249876,
+ 'batch_count': 12,
+ 'clip_ratio': 0.4,
+ 'entropy_coeff': 0.010204149626344796,
+ 'env_count': 128,
+ 'gae_lambda': 0.92,
+ 'gamma': 0.9730549104215155,
+ 'gradient_clip': 5.0,
+ 'learning_rate': 0.8754540531090014,
+ 'steps_per_env': 64,
+ 'target_kl': 0.17110535070344035,
+ 'value_coeff': 0.6840401569818934,
+ 'actor_net_arch': 'small',
+ 'critic_net_arch': 'small',
+ 'actor_net_activation': 'tanh',
+ 'critic_net_activation': 'tanh'
+}. Best is trial 85 with value: 2191.0249068421276.
+"""
+ant_v3 = copy.deepcopy(default)
+ant_v3["env_kwargs"]["environment_count"] = 128
+ant_v3["runner_kwargs"]["num_steps_per_env"] = 64
+ant_v3["agent_kwargs"]["actor_activations"] = ["tanh", "tanh", "linear"]
+ant_v3["agent_kwargs"]["actor_hidden_dims"] = [64, 64]
+ant_v3["agent_kwargs"]["actor_noise_std"] = 0.2611
+ant_v3["agent_kwargs"]["batch_count"] = 12
+ant_v3["agent_kwargs"]["clip_ratio"] = 0.4
+ant_v3["agent_kwargs"]["critic_activations"] = ["tanh", "tanh", "linear"]
+ant_v3["agent_kwargs"]["critic_hidden_dims"] = [64, 64]
+ant_v3["agent_kwargs"]["entropy_coeff"] = 0.0102
+ant_v3["agent_kwargs"]["gae_lambda"] = 0.92
+ant_v3["agent_kwargs"]["gamma"] = 0.9731
+ant_v3["agent_kwargs"]["gradient_clip"] = 5.0
+ant_v3["agent_kwargs"]["learning_rate"] = 0.8755
+ant_v3["agent_kwargs"]["target_kl"] = 0.1711
+ant_v3["agent_kwargs"]["value_coeff"] = 0.6840
+
+"""
+Standard:
+[I 2023-01-17 07:43:46,884] Trial 125 finished with value: 150.23491836690064 and parameters: {
+ 'actor_net_activation': 'relu',
+ 'actor_net_arch': 'large',
+ 'actor_noise_std': 0.8504545432069994,
+ 'batch_count': 10,
+ 'clip_ratio': 0.1,
+ 'critic_net_activation': 'relu',
+ 'critic_net_arch': 'medium',
+ 'entropy_coeff': 0.0916881539697197,
+ 'env_count': 256,
+ 'gae_lambda': 0.95,
+ 'gamma': 0.955285858564339,
+ 'gradient_clip': 2.0,
+ 'learning_rate': 0.4762365866431558,
+ 'steps_per_env': 16,
+ 'recurrent': False,
+ 'target_kl': 0.19991906392721126,
+ 'value_coeff': 0.4434793554275927
+}. Best is trial 125 with value: 150.23491836690064.
+Hardcore:
+[I 2023-01-09 06:25:44,000] Trial 262 finished with value: 2.290071208278338 and parameters: {
+ 'actor_noise_std': 0.2710521003644249,
+ 'batch_count': 6,
+ 'clip_ratio': 0.1,
+ 'entropy_coeff': 0.005105282891378981,
+ 'env_count': 16,
+ 'gae_lambda': 1.0,
+ 'gamma': 0.9718119008688937,
+ 'gradient_clip': 0.1,
+ 'learning_rate': 0.4569184610431825,
+ 'steps_per_env': 256,
+ 'target_kl': 0.11068348002480229,
+ 'value_coeff': 0.19453900570701116,
+ 'actor_net_arch': 'small',
+ 'critic_net_arch': 'medium',
+ 'actor_net_activation': 'relu',
+ 'critic_net_activation': 'relu'
+}. Best is trial 262 with value: 2.290071208278338.
+"""
+bipedal_walker_v3 = copy.deepcopy(default)
+bipedal_walker_v3["env_kwargs"]["environment_count"] = 256
+bipedal_walker_v3["runner_kwargs"]["num_steps_per_env"] = 16
+bipedal_walker_v3["agent_kwargs"]["actor_activations"] = ["relu", "relu", "relu", "linear"]
+bipedal_walker_v3["agent_kwargs"]["actor_hidden_dims"] = [512, 256, 128]
+bipedal_walker_v3["agent_kwargs"]["actor_noise_std"] = 0.8505
+bipedal_walker_v3["agent_kwargs"]["batch_count"] = 10
+bipedal_walker_v3["agent_kwargs"]["clip_ratio"] = 0.1
+bipedal_walker_v3["agent_kwargs"]["critic_activations"] = ["relu", "relu", "linear"]
+bipedal_walker_v3["agent_kwargs"]["critic_hidden_dims"] = [256, 256]
+bipedal_walker_v3["agent_kwargs"]["entropy_coeff"] = 0.0917
+bipedal_walker_v3["agent_kwargs"]["gae_lambda"] = 0.95
+bipedal_walker_v3["agent_kwargs"]["gamma"] = 0.9553
+bipedal_walker_v3["agent_kwargs"]["gradient_clip"] = 2.0
+bipedal_walker_v3["agent_kwargs"]["learning_rate"] = 0.4762
+bipedal_walker_v3["agent_kwargs"]["target_kl"] = 0.1999
+bipedal_walker_v3["agent_kwargs"]["value_coeff"] = 0.4435
+
+"""
+[I 2023-01-04 05:57:20,749] Trial 1451 finished with value: 5260.338678148058 and parameters: {
+ 'env_count': 32,
+ 'actor_noise_std': 0.3397405098274084,
+ 'batch_count': 6,
+ 'clip_ratio': 0.3,
+ 'entropy_coeff': 0.009392937880259133,
+ 'gae_lambda': 0.8,
+ 'gamma': 0.9683403243382301,
+ 'gradient_clip': 5.0,
+ 'learning_rate': 0.5985206877398142,
+ 'steps_per_env': 16,
+ 'target_kl': 0.027651917189297347,
+ 'value_coeff': 0.26705235341068373,
+ 'net_arch': 'medium',
+ 'net_activation': 'tanh'
+}. Best is trial 1451 with value: 5260.338678148058.
+"""
+half_cheetah_v3 = copy.deepcopy(default)
+half_cheetah_v3["env_kwargs"]["environment_count"] = 32
+half_cheetah_v3["runner_kwargs"]["num_steps_per_env"] = 16
+half_cheetah_v3["agent_kwargs"]["actor_activations"] = ["tanh", "tanh", "linear"]
+half_cheetah_v3["agent_kwargs"]["actor_hidden_dims"] = [256, 256]
+half_cheetah_v3["agent_kwargs"]["actor_noise_std"] = 0.3397
+half_cheetah_v3["agent_kwargs"]["batch_count"] = 6
+half_cheetah_v3["agent_kwargs"]["clip_ratio"] = 0.3
+half_cheetah_v3["agent_kwargs"]["critic_activations"] = ["tanh", "tanh", "linear"]
+half_cheetah_v3["agent_kwargs"]["critic_hidden_dims"] = [256, 256]
+half_cheetah_v3["agent_kwargs"]["entropy_coeff"] = 0.009393
+half_cheetah_v3["agent_kwargs"]["gae_lambda"] = 0.8
+half_cheetah_v3["agent_kwargs"]["gamma"] = 0.9683
+half_cheetah_v3["agent_kwargs"]["gradient_clip"] = 5.0
+half_cheetah_v3["agent_kwargs"]["learning_rate"] = 0.5985
+half_cheetah_v3["agent_kwargs"]["target_kl"] = 0.02765
+half_cheetah_v3["agent_kwargs"]["value_coeff"] = 0.2671
+
+"""
+[I 2023-01-08 18:38:51,481] Trial 25 finished with value: 2225.9547948810073 and parameters: {
+ 'actor_noise_std': 0.5589708917145111,
+ 'batch_count': 15,
+ 'clip_ratio': 0.2,
+ 'entropy_coeff': 0.03874027035272886,
+ 'env_count': 128,
+ 'gae_lambda': 0.98,
+ 'gamma': 0.9879577396280973,
+ 'gradient_clip': 1.0,
+ 'learning_rate': 0.3732431793266761,
+ 'steps_per_env': 128,
+ 'target_kl': 0.12851506672519566,
+ 'value_coeff': 0.8162548885723906,
+ 'actor_net_arch': 'medium',
+ 'critic_net_arch': 'small',
+ 'actor_net_activation': 'relu',
+ 'critic_net_activation': 'relu'
+}. Best is trial 25 with value: 2225.9547948810073.
+"""
+hopper_v3 = copy.deepcopy(default)
+hopper_v3["env_kwargs"]["environment_count"] = 128
+hopper_v3["runner_kwargs"]["num_steps_per_env"] = 128
+hopper_v3["agent_kwargs"]["actor_activations"] = ["relu", "relu", "linear"]
+hopper_v3["agent_kwargs"]["actor_hidden_dims"] = [256, 256]
+hopper_v3["agent_kwargs"]["actor_noise_std"] = 0.5590
+hopper_v3["agent_kwargs"]["batch_count"] = 15
+hopper_v3["agent_kwargs"]["clip_ratio"] = 0.2
+hopper_v3["agent_kwargs"]["critic_activations"] = ["relu", "relu", "linear"]
+hopper_v3["agent_kwargs"]["critic_hidden_dims"] = [32, 32]
+hopper_v3["agent_kwargs"]["entropy_coeff"] = 0.03874
+hopper_v3["agent_kwargs"]["gae_lambda"] = 0.98
+hopper_v3["agent_kwargs"]["gamma"] = 0.9890
+hopper_v3["agent_kwargs"]["gradient_clip"] = 1.0
+hopper_v3["agent_kwargs"]["learning_rate"] = 0.3732
+hopper_v3["agent_kwargs"]["value_coeff"] = 0.8163
+
+swimmer_v3 = copy.deepcopy(default)
+swimmer_v3["agent_kwargs"]["gamma"] = 0.9999
+
+walker2d_v3 = copy.deepcopy(default)
+walker2d_v3["runner_kwargs"]["num_steps_per_env"] = 512
+walker2d_v3["agent_kwargs"]["batch_count"] = (
+ walker2d_v3["env_kwargs"]["environment_count"] * walker2d_v3["runner_kwargs"]["num_steps_per_env"] // 32
+)
+walker2d_v3["agent_kwargs"]["clip_ratio"] = 0.1
+walker2d_v3["agent_kwargs"]["entropy_coeff"] = 0.000585045
+walker2d_v3["agent_kwargs"]["gae_lambda"] = 0.95
+walker2d_v3["agent_kwargs"]["gamma"] = 0.99
+walker2d_v3["agent_kwargs"]["gradient_clip"] = 1.0
+walker2d_v3["agent_kwargs"]["learning_rate"] = 5.05041e-05
+walker2d_v3["agent_kwargs"]["value_coeff"] = 0.871923
+
+ppo_hyperparams = {
+ "default": default,
+ "Ant-v3": ant_v3,
+ "Ant-v4": ant_v3,
+ "BipedalWalker-v3": bipedal_walker_v3,
+ "HalfCheetah-v3": half_cheetah_v3,
+ "HalfCheetah-v4": half_cheetah_v3,
+ "Hopper-v3": hopper_v3,
+ "Hopper-v4": hopper_v3,
+ "Swimmer-v3": swimmer_v3,
+ "Swimmer-v4": swimmer_v3,
+ "Walker2d-v3": walker2d_v3,
+ "Walker2d-v4": walker2d_v3,
+}
diff --git a/examples/tune.py b/examples/tune.py
new file mode 100644
index 0000000..b11c058
--- /dev/null
+++ b/examples/tune.py
@@ -0,0 +1,103 @@
+from rsl_rl.algorithms import *
+from rsl_rl.env.gym_env import GymEnv
+from rsl_rl.runners.runner import Runner
+
+import copy
+from datetime import datetime
+import numpy as np
+import optuna
+import os
+import random
+import torch
+from tune_cfg import samplers
+
+
+ALGORITHM = PPO
+ENVIRONMENT = "BipedalWalker-v3"
+ENVIRONMENT_KWARGS = {}
+EVAL_AGENTS = 64
+EVAL_RUNS = 10
+EVAL_STEPS = 1000
+EXPERIMENT_DIR = os.environ.get("EXPERIMENT_DIRECTORY", "./")
+EXPERIMENT_NAME = os.environ.get("EXPERIMENT_NAME", f"tune-{ALGORITHM.__name__}-{ENVIRONMENT}")
+TRAIN_ITERATIONS = None
+TRAIN_TIMEOUT = 60 * 15  # 15 minutes
+TRAIN_RUNS = 3
+TRAIN_SEED = None
+
+
+def tune():
+ assert TRAIN_RUNS == 1 or TRAIN_SEED is None, "If multiple runs are used, the seed must be None."
+
+ storage = optuna.storages.RDBStorage(url=f"sqlite:///{EXPERIMENT_DIR}/{EXPERIMENT_NAME}.db")
+ pruner = optuna.pruners.MedianPruner(n_startup_trials=10)
+ try:
+ study = optuna.create_study(direction="maximize", pruner=pruner, storage=storage, study_name=EXPERIMENT_NAME)
+ except Exception:
+ study = optuna.load_study(pruner=pruner, storage=storage, study_name=EXPERIMENT_NAME)
+
+ study.optimize(objective, n_trials=100)
+
+
+def seed(s=None):
+ seed = int(datetime.now().timestamp() * 1e6) % 2**32 if s is None else s
+
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ os.environ["PYTHONHASHSEED"] = str(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def objective(trial):
+ seed()
+ agent_kwargs, env_kwargs, runner_kwargs = samplers[ALGORITHM.__name__](trial)
+
+ evaluations = []
+ for instantiation in range(TRAIN_RUNS):
+ seed(TRAIN_SEED)
+
+ env = GymEnv(ENVIRONMENT, gym_kwargs=ENVIRONMENT_KWARGS, **env_kwargs)
+ agent = ALGORITHM(env, **agent_kwargs)
+ runner = Runner(env, agent, **runner_kwargs)
+ runner._learn_cb = [lambda _, stat: runner._log_progress(stat, prefix=f"learn {instantiation+1}/{TRAIN_RUNS}")]
+
+ eval_env_kwargs = copy.deepcopy(env_kwargs)
+ eval_env_kwargs["environment_count"] = EVAL_AGENTS
+ eval_runner = Runner(
+            GymEnv(ENVIRONMENT, gym_kwargs=ENVIRONMENT_KWARGS, **eval_env_kwargs),
+ agent,
+ **runner_kwargs,
+ )
+ eval_runner._eval_cb = [
+ lambda _, stat: runner._log_progress(stat, prefix=f"eval {instantiation+1}/{TRAIN_RUNS}")
+ ]
+
+ try:
+ runner.learn(TRAIN_ITERATIONS, timeout=TRAIN_TIMEOUT)
+ except Exception:
+ raise optuna.TrialPruned()
+
+ intermediate_evaluations = []
+ for eval_run in range(EVAL_RUNS):
+ eval_runner._eval_cb = [lambda _, stat: runner._log_progress(stat, prefix=f"eval {eval_run+1}/{EVAL_RUNS}")]
+
+ seed()
+ eval_runner.env.reset()
+ intermediate_evaluations.append(eval_runner.evaluate(steps=EVAL_STEPS))
+ eval = np.mean(intermediate_evaluations)
+
+ trial.report(eval, instantiation)
+ if trial.should_prune():
+ raise optuna.TrialPruned()
+
+ evaluations.append(eval)
+
+ evaluation = np.mean(evaluations)
+
+ return evaluation
+
+
+if __name__ == "__main__":
+ tune()
diff --git a/examples/tune_cfg.py b/examples/tune_cfg.py
new file mode 100644
index 0000000..1ca78c2
--- /dev/null
+++ b/examples/tune_cfg.py
@@ -0,0 +1,134 @@
+import torch
+
+from rsl_rl.algorithms import DPPO, PPO
+from rsl_rl.modules import QuantileNetwork
+
+NETWORKS = {"small": [64, 64], "medium": [256, 256], "large": [512, 256, 128]}
+
+
+def sample_dppo_hyperparams(trial):
+ actor_net_activation = trial.suggest_categorical("actor_net_activation", ["relu", "tanh"])
+ actor_net_arch = trial.suggest_categorical("actor_net_arch", list(NETWORKS.keys()))
+ actor_noise_std = trial.suggest_float("actor_noise_std", 0.05, 1.0)
+ batch_count = trial.suggest_int("batch_count", 1, 20)
+ clip_ratio = trial.suggest_categorical("clip_ratio", [0.1, 0.2, 0.3, 0.4])
+ critic_net_activation = trial.suggest_categorical("critic_net_activation", ["relu", "tanh"])
+ critic_net_arch = trial.suggest_categorical("critic_net_arch", list(NETWORKS.keys()))
+ entropy_coeff = trial.suggest_float("entropy_coeff", 0.00000001, 0.1)
+ env_count = trial.suggest_categorical("env_count", [1, 8, 16, 32, 64, 128, 256, 512])
+ gae_lambda = trial.suggest_categorical("gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0])
+ gamma = trial.suggest_float("gamma", 0.95, 0.999)
+ gradient_clip = trial.suggest_categorical("gradient_clip", [0.1, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 2.0, 5.0])
+ learning_rate = trial.suggest_float("learning_rate", 1e-5, 1)
+ num_steps_per_env = trial.suggest_categorical("steps_per_env", [8, 16, 32, 64, 128, 256, 512, 1024, 2048])
+ quantile_count = trial.suggest_categorical("quantile_count", [20, 50, 100, 200])
+ recurrent = trial.suggest_categorical("recurrent", [True, False])
+ target_kl = trial.suggest_float("target_kl", 0.01, 0.3)
+ value_coeff = trial.suggest_float("value_coeff", 0.0, 1.0)
+ value_measure = trial.suggest_categorical(
+ "value_measure",
+ ["neutral", "var-risk-averse", "var-risk-seeking", "var-super-risk-averse", "var-super-risk-seeking"],
+ )
+
+ actor_net_arch = NETWORKS[actor_net_arch]
+ critic_net_arch = NETWORKS[critic_net_arch]
+ value_measure_kwargs = {
+ "neutral": dict(),
+ "var-risk-averse": dict(confidence_level=0.25),
+ "var-risk-seeking": dict(confidence_level=0.75),
+ "var-super-risk-averse": dict(confidence_level=0.1),
+ "var-super-risk-seeking": dict(confidence_level=0.9),
+ }[value_measure]
+ value_measure = {
+ "neutral": QuantileNetwork.measure_neutral,
+ "var-risk-averse": QuantileNetwork.measure_percentile,
+ "var-risk-seeking": QuantileNetwork.measure_percentile,
+ "var-super-risk-averse": QuantileNetwork.measure_percentile,
+ "var-super-risk-seeking": QuantileNetwork.measure_percentile,
+ }[value_measure]
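+    # Run on the GPU only when the per-iteration rollout (env_count * num_steps_per_env samples) exceeds 2048 and CUDA is available.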
+ device = "cuda:0" if env_count * num_steps_per_env > 2048 and torch.cuda.is_available() else "cpu"
+
+ agent_kwargs = dict(
+ actor_activations=([actor_net_activation] * len(actor_net_arch)) + ["linear"],
+ actor_hidden_dims=actor_net_arch,
+ actor_input_normalization=False,
+ actor_noise_std=actor_noise_std,
+ batch_count=batch_count,
+ clip_ratio=clip_ratio,
+ critic_activations=([critic_net_activation] * len(critic_net_arch)),
+ critic_hidden_dims=critic_net_arch,
+ critic_input_normalization=False,
+ device=device,
+ entropy_coeff=entropy_coeff,
+ gae_lambda=gae_lambda,
+ gamma=gamma,
+ gradient_clip=gradient_clip,
+ learning_rate=learning_rate,
+ quantile_count=quantile_count,
+ recurrent=recurrent,
+ schedule="adaptive",
+ target_kl=target_kl,
+ value_coeff=value_coeff,
+ value_measure=value_measure,
+ value_measure_kwargs=value_measure_kwargs,
+ )
+ env_kwargs = dict(device=device, environment_count=env_count)
+ runner_kwargs = dict(device=device, num_steps_per_env=num_steps_per_env)
+
+ return agent_kwargs, env_kwargs, runner_kwargs
+
+
+def sample_ppo_hyperparams(trial):
+ actor_net_activation = trial.suggest_categorical("actor_net_activation", ["relu", "tanh"])
+ actor_net_arch = trial.suggest_categorical("actor_net_arch", list(NETWORKS.keys()))
+ actor_noise_std = trial.suggest_float("actor_noise_std", 0.05, 1.0)
+ batch_count = trial.suggest_int("batch_count", 1, 20)
+ clip_ratio = trial.suggest_categorical("clip_ratio", [0.1, 0.2, 0.3, 0.4])
+ critic_net_activation = trial.suggest_categorical("critic_net_activation", ["relu", "tanh"])
+ critic_net_arch = trial.suggest_categorical("critic_net_arch", list(NETWORKS.keys()))
+ entropy_coeff = trial.suggest_float("entropy_coeff", 0.00000001, 0.1)
+ env_count = trial.suggest_categorical("env_count", [1, 8, 16, 32, 64, 128, 256, 512])
+ gae_lambda = trial.suggest_categorical("gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0])
+ gamma = trial.suggest_float("gamma", 0.95, 0.999)
+ gradient_clip = trial.suggest_categorical("gradient_clip", [0.1, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 2.0, 5.0])
+ learning_rate = trial.suggest_float("learning_rate", 1e-5, 1)
+ num_steps_per_env = trial.suggest_categorical("steps_per_env", [8, 16, 32, 64, 128, 256, 512, 1024, 2048])
+ recurrent = trial.suggest_categorical("recurrent", [True, False])
+ target_kl = trial.suggest_float("target_kl", 0.01, 0.3)
+ value_coeff = trial.suggest_float("value_coeff", 0.0, 1.0)
+
+ actor_net_arch = NETWORKS[actor_net_arch]
+ critic_net_arch = NETWORKS[critic_net_arch]
+ device = "cuda:0" if env_count * num_steps_per_env > 2048 and torch.cuda.is_available() else "cpu"
+
+ agent_kwargs = dict(
+ actor_activations=([actor_net_activation] * len(actor_net_arch)) + ["linear"],
+ actor_hidden_dims=actor_net_arch,
+ actor_input_normalization=False,
+ actor_noise_std=actor_noise_std,
+ batch_count=batch_count,
+ clip_ratio=clip_ratio,
+ critic_activations=([critic_net_activation] * len(critic_net_arch)) + ["linear"],
+ critic_hidden_dims=critic_net_arch,
+ critic_input_normalization=False,
+ device=device,
+ entropy_coeff=entropy_coeff,
+ gae_lambda=gae_lambda,
+ gamma=gamma,
+ gradient_clip=gradient_clip,
+ learning_rate=learning_rate,
+ recurrent=recurrent,
+ schedule="adaptive",
+ target_kl=target_kl,
+ value_coeff=value_coeff,
+ )
+ env_kwargs = dict(device=device, environment_count=env_count)
+ runner_kwargs = dict(device=device, num_steps_per_env=num_steps_per_env)
+
+ return agent_kwargs, env_kwargs, runner_kwargs
+
+
+samplers = {
+ DPPO.__name__: sample_dppo_hyperparams,
+ PPO.__name__: sample_ppo_hyperparams,
+}
diff --git a/examples/wandb_config.example.py b/examples/wandb_config.example.py
new file mode 100644
index 0000000..051710b
--- /dev/null
+++ b/examples/wandb_config.example.py
@@ -0,0 +1,4 @@
+# To use this file, copy it to wandb_config.py and fill in the missing values.
+
+WANDB_API_KEY = ""
+WANDB_ENTITY = ""
diff --git a/rsl_rl/__init__.py b/rsl_rl/__init__.py
index 8b1a072..89f4d9a 100644
--- a/rsl_rl/__init__.py
+++ b/rsl_rl/__init__.py
@@ -1,7 +1,2 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
-
-"""Main module for the rsl_rl package."""
-
-__version__ = "2.0.1"
-__license__ = "BSD-3"
diff --git a/rsl_rl/algorithms/__init__.py b/rsl_rl/algorithms/__init__.py
index 776258c..202faf2 100644
--- a/rsl_rl/algorithms/__init__.py
+++ b/rsl_rl/algorithms/__init__.py
@@ -1,8 +1,10 @@
-# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
-# SPDX-License-Identifier: BSD-3-Clause
-
-"""Implementation of different RL agents."""
-
+from .agent import Agent
+from .d4pg import D4PG
+from .ddpg import DDPG
+from .dppo import DPPO
+from .dsac import DSAC
from .ppo import PPO
+from .sac import SAC
+from .td3 import TD3
-__all__ = ["PPO"]
+__all__ = ["Agent", "DDPG", "D4PG", "DPPO", "DSAC", "PPO", "SAC", "TD3"]
diff --git a/rsl_rl/algorithms/actor_critic.py b/rsl_rl/algorithms/actor_critic.py
new file mode 100644
index 0000000..6e3fe8b
--- /dev/null
+++ b/rsl_rl/algorithms/actor_critic.py
@@ -0,0 +1,371 @@
+from abc import abstractmethod
+import torch
+from typing import Any, Callable, Dict, List, Tuple, Union
+
+from rsl_rl.algorithms.agent import Agent
+from rsl_rl.env.vec_env import VecEnv
+from rsl_rl.modules.network import Network
+from rsl_rl.storage.storage import Dataset
+from rsl_rl.utils.utils import environment_dimensions
+from rsl_rl.utils.utils import squeeze_preserve_batch
+
+
+class AbstractActorCritic(Agent):
+ _alg_features = dict(recurrent=False)
+
+ def __init__(
+ self,
+ env: VecEnv,
+ actor_activations: List[str] = ["relu", "relu", "relu", "linear"],
+ actor_hidden_dims: List[int] = [256, 256, 256],
+ actor_init_gain: float = 0.5,
+ actor_input_normalization: bool = False,
+ actor_recurrent_layers: int = 1,
+ actor_recurrent_module: str = Network.recurrent_module_lstm,
+ actor_recurrent_tf_context_length: int = 64,
+ actor_recurrent_tf_head_count: int = 8,
+ actor_shared_dims: int = None,
+ batch_count: int = 1,
+ batch_size: int = 1,
+ critic_activations: List[str] = ["relu", "relu", "relu", "linear"],
+ critic_hidden_dims: List[int] = [256, 256, 256],
+ critic_init_gain: float = 0.5,
+ critic_input_normalization: bool = False,
+ critic_recurrent_layers: int = 1,
+ critic_recurrent_module: str = Network.recurrent_module_lstm,
+ critic_recurrent_tf_context_length: int = 64,
+ critic_recurrent_tf_head_count: int = 8,
+ critic_shared_dims: int = None,
+ polyak: float = 0.995,
+ recurrent: bool = False,
+ return_steps: int = 1,
+ _actor_input_size_delta: int = 0,
+ _critic_input_size_delta: int = 0,
+ **kwargs,
+ ):
+ """Creates an actor critic agent.
+
+ Args:
+ env (VecEnv): A vectorized environment.
+ actor_activations (List[str]): A list of activation functions for the actor network.
+            actor_hidden_dims (List[int]): A list of layer sizes for the actor network.
+ actor_init_gain (float): Network initialization gain for actor.
+ actor_input_normalization (bool): Whether to empirically normalize inputs to the actor network.
+ actor_recurrent_layers (int): The number of recurrent layers to use for the actor network.
+ actor_recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules.
+ actor_shared_dims (int): The number of dimensions to share for an actor with multiple heads.
+ batch_count (int): The number of batches to process per update step.
+ batch_size (int): The size of each batch to process during the update step.
+ critic_activations (List[str]): A list of activation functions for the critic network.
+            critic_hidden_dims (List[int]): A list of layer sizes for the critic network.
+ critic_init_gain (float): Network initialization gain for critic.
+ critic_input_normalization (bool): Whether to empirically normalize inputs to the critic network.
+ critic_recurrent_layers (int): The number of recurrent layers to use for the critic network.
+ critic_recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules.
+ critic_shared_dims (int): The number of dimensions to share for a critic with multiple heads.
+ polyak (float): The actor-critic target network polyak factor.
+ recurrent (bool): Whether to use recurrent actor and critic networks.
+ recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules.
+ recurrent_tf_context_length (int): The context length of the Transformer.
+ recurrent_tf_head_count (int): The head count of the Transformer.
+            return_steps (int): The number of steps over which to compute the returns (n-step return).
+ _actor_input_size_delta (int): The number of additional dimensions to add to the actor input.
+ _critic_input_size_delta (int): The number of additional dimensions to add to the critic input.
+ """
+ assert (
+            self._alg_features["recurrent"] or not recurrent
+ ), f"{self.__class__.__name__} does not support recurrent networks."
+
+ super().__init__(env, **kwargs)
+
+ self.actor: torch.nn.Module = None
+        self.actor_optimizer: torch.optim.Optimizer = None
+        self.critic_optimizer: torch.optim.Optimizer = None
+ self.critic: torch.nn.Module = None
+
+ self._batch_size = batch_size
+ self._batch_count = batch_count
+ self._polyak_factor = polyak
+ self._return_steps = return_steps
+ self._recurrent = recurrent
+
+ self._register_serializable(
+ "_batch_size", "_batch_count", "_discount_factor", "_polyak_factor", "_return_steps"
+ )
+
+ dimensions = environment_dimensions(self.env)
+ try:
+ actor_input_size = dimensions["actor_observations"]
+ critic_input_size = dimensions["critic_observations"]
+ except KeyError:
+ actor_input_size = dimensions["observations"]
+ critic_input_size = dimensions["observations"]
+ self._actor_input_size = actor_input_size + _actor_input_size_delta
+ self._critic_input_size = critic_input_size + self._action_size + _critic_input_size_delta
+
+ self._register_actor_network_kwargs(
+ activations=actor_activations,
+ hidden_dims=actor_hidden_dims,
+ init_gain=actor_init_gain,
+ input_normalization=actor_input_normalization,
+ recurrent=recurrent,
+ recurrent_layers=actor_recurrent_layers,
+ recurrent_module=actor_recurrent_module,
+ recurrent_tf_context_length=actor_recurrent_tf_context_length,
+ recurrent_tf_head_count=actor_recurrent_tf_head_count,
+ )
+
+ if actor_shared_dims is not None:
+ self._register_actor_network_kwargs(shared_dims=actor_shared_dims)
+
+ self._register_critic_network_kwargs(
+ activations=critic_activations,
+ hidden_dims=critic_hidden_dims,
+ init_gain=critic_init_gain,
+ input_normalization=critic_input_normalization,
+ recurrent=recurrent,
+ recurrent_layers=critic_recurrent_layers,
+ recurrent_module=critic_recurrent_module,
+ recurrent_tf_context_length=critic_recurrent_tf_context_length,
+ recurrent_tf_head_count=critic_recurrent_tf_head_count,
+ )
+
+ if critic_shared_dims is not None:
+ self._register_critic_network_kwargs(shared_dims=critic_shared_dims)
+
+ self._register_serializable(
+ "_actor_input_size", "_actor_network_kwargs", "_critic_input_size", "_critic_network_kwargs"
+ )
+
+ # For computing n-step returns using prior transitions.
+ self._stored_dataset = []
+
+ def export_onnx(self) -> Tuple[torch.nn.Module, torch.Tensor, Dict]:
+ self.eval_mode()
+
+ class ONNXActor(torch.nn.Module):
+ def __init__(self, model: torch.nn.Module):
+ super().__init__()
+
+ self.model = model
+
+ def forward(self, x: torch.Tensor, hidden_state: Tuple[torch.Tensor, torch.Tensor] = None):
+ if hidden_state is None:
+ return self.model(x)
+
+ data = self.model(x, hidden_state=hidden_state)
+ hidden_state = self.model.last_hidden_state
+
+ return data, hidden_state
+
+ model = ONNXActor(self.actor)
+ kwargs = dict(
+ export_params=True,
+ opset_version=11,
+ verbose=True,
+ dynamic_axes={},
+ )
+
+ kwargs["input_names"] = ["observations"]
+ kwargs["output_names"] = ["actions"]
+
+ args = torch.zeros(1, self._actor_input_size)
+
+ if self.actor.recurrent:
+ hidden_state = (
+ torch.zeros(self.actor._features[0].num_layers, 1, self.actor._features[0].hidden_size),
+ torch.zeros(self.actor._features[0].num_layers, 1, self.actor._features[0].hidden_size),
+ )
+ args = (args, {"hidden_state": hidden_state})
+
+ return model, args, kwargs
+
+ def draw_random_actions(self, obs: torch.Tensor, env_info: Dict[str, Any]) -> Tuple[torch.Tensor, Dict]:
+ actions, data = super().draw_random_actions(obs, env_info)
+
+ actor_obs, critic_obs = self._process_observations(obs, env_info)
+ data.update({"actor_observations": actor_obs.clone(), "critic_observations": critic_obs.clone()})
+
+ return actions, data
+
+ def get_inference_policy(self, device=None) -> Callable:
+ self.to(device)
+ self.eval_mode()
+
+ if self.actor.recurrent:
+ self.actor.reset_full_hidden_state(batch_size=self.env.num_envs)
+
+ if self.critic.recurrent:
+ self.critic.reset_full_hidden_state(batch_size=self.env.num_envs)
+
+ def policy(obs, env_info=None):
+ with torch.inference_mode():
+ obs, _ = self._process_observations(obs, env_info)
+
+ actions = self._process_actions(self.actor.forward(obs))
+
+ return actions
+
+ return policy
+
+ def process_transition(
+ self,
+ observations: torch.Tensor,
+ environment_info: Dict[str, Any],
+ actions: torch.Tensor,
+ rewards: torch.Tensor,
+ next_observations: torch.Tensor,
+        next_environment_info: Dict[str, Any],
+ dones: torch.Tensor,
+ data: Dict[str, torch.Tensor],
+ ) -> Dict[str, torch.Tensor]:
+ if "actor_observations" in data and "critic_observations" in data:
+ actor_obs, critic_obs = data["actor_observations"], data["critic_observations"]
+ else:
+ actor_obs, critic_obs = self._process_observations(observations, environment_info)
+
+ if "next_actor_observations" in data and "next_critic_observations" in data:
+ next_actor_obs, next_critic_obs = data["next_actor_observations"], data["next_critic_observations"]
+ else:
+ next_actor_obs, next_critic_obs = self._process_observations(next_observations, next_environment_info)
+
+ transition = {
+ "actions": actions,
+ "actor_observations": actor_obs,
+ "critic_observations": critic_obs,
+ "dones": dones,
+ "next_actor_observations": next_actor_obs,
+ "next_critic_observations": next_critic_obs,
+ "rewards": squeeze_preserve_batch(rewards),
+ "timeouts": self._extract_timeouts(next_environment_info),
+ }
+ transition.update(data)
+
+ for key, value in transition.items():
+ transition[key] = value.detach().clone()
+
+ return transition
+
+ @property
+ def recurrent(self) -> bool:
+ return self._recurrent
+
+ def register_terminations(self, terminations: torch.Tensor) -> None:
+ pass
+
+ @abstractmethod
+ def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
+ with torch.inference_mode():
+ self.storage.append(self._process_dataset(dataset))
+
+ def _critic_input(self, observations, actions) -> torch.Tensor:
+ """Combines observations and actions into a tensor that can be fed into the critic network.
+
+ Args:
+ observations (torch.Tensor): The critic observations.
+ actions (torch.Tensor): The actions computed by the actor.
+ Returns:
+ A torch.Tensor of inputs for the critic network.
+ """
+ return torch.cat((observations, actions), dim=-1)
+
+ def _extract_timeouts(self, next_environment_info):
+ """Extracts timeout information from the transition next state information dictionary.
+
+ Args:
+ next_environment_info (Dict[str, Any]): The transition next state information dictionary.
+ Returns:
+ A torch.Tensor vector of actor timeouts.
+ """
+ if "time_outs" not in next_environment_info:
+ return torch.zeros(self.env.num_envs, device=self.device)
+
+ timeouts = squeeze_preserve_batch(next_environment_info["time_outs"].to(self.device))
+
+ return timeouts
+
+ def _process_dataset(self, dataset: Dataset) -> Dataset:
+ """Processes a dataset before it is added to the replay memory.
+
+ Handles n-step returns and timeouts.
+
+ TODO: This function seems to be a bottleneck in the training pipeline - speed it up!
+
+ Args:
+ dataset (Dataset): The dataset to process.
+ Returns:
+ A Dataset object containing the processed data.
+ """
+ assert len(dataset) >= self._return_steps
+
+ dataset = self._stored_dataset + dataset
+ length = len(dataset) - self._return_steps + 1
+ self._stored_dataset = dataset[length:]
+
+ for idx in range(len(dataset) - self._return_steps + 1):
+ dead_idx = torch.zeros_like(dataset[idx]["dones"])
+ rewards = torch.zeros_like(dataset[idx]["rewards"])
+
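+            # Accumulate discounted rewards over the next `_return_steps` transitions for environments that
+            # are still alive, bootstrapping with the critic's value estimate when an environment times out.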
+ for k in range(self._return_steps):
+ data = dataset[idx + k]
+ alive_idx = (dead_idx == 0).nonzero()
+ critic_predictions = self.critic.forward(
+ self._critic_input(
+ data["critic_observations"].clone().to(self.device),
+ data["actions"].clone().to(self.device),
+ )
+ )
+ rewards[alive_idx] += self._discount_factor**k * data["rewards"][alive_idx]
+ rewards[alive_idx] += (
+ self._discount_factor ** (k + 1) * data["timeouts"][alive_idx] * critic_predictions[alive_idx]
+ )
+ dead_idx += data["dones"]
+ dead_idx += data["timeouts"]
+
+ dataset[idx]["rewards"] = rewards
+
+ return dataset[:length]
+
+ def _process_observations(
+ self, observations: torch.Tensor, environment_info: Dict[str, Any] = None
+ ) -> Tuple[torch.Tensor, ...]:
+ """Processes observations returned by the environment to extract actor and critic observations.
+
+ Args:
+ observations (torch.Tensor): normal environment observations.
+ environment_info (Dict[str, Any]): A dictionary of additional environment information.
+ Returns:
+ A tuple containing two torch.Tensors with actor and critic observations, respectively.
+ """
+ try:
+ critic_obs = environment_info["observations"]["critic"]
+ except (KeyError, TypeError):
+ critic_obs = observations
+
+ actor_obs, critic_obs = observations.to(self.device), critic_obs.to(self.device)
+
+ return actor_obs, critic_obs
+
+ def _register_actor_network_kwargs(self, **kwargs) -> None:
+ """Function to configure actor network in child classes before calling super().__init__()."""
+ if not hasattr(self, "_actor_network_kwargs"):
+ self._actor_network_kwargs = dict()
+
+ self._actor_network_kwargs.update(**kwargs)
+
+ def _register_critic_network_kwargs(self, **kwargs) -> None:
+ """Function to configure critic network in child classes before calling super().__init__()."""
+ if not hasattr(self, "_critic_network_kwargs"):
+ self._critic_network_kwargs = dict()
+
+ self._critic_network_kwargs.update(**kwargs)
+
+ def _update_target(self, online: torch.nn.Module, target: torch.nn.Module) -> None:
+ """Updates the target network using the polyak factor.
+
+ Args:
+ online (torch.nn.Module): The online network.
+ target (torch.nn.Module): The target network.
+ """
+ for op, tp in zip(online.parameters(), target.parameters()):
+ tp.data.copy_((1.0 - self._polyak_factor) * op.data + self._polyak_factor * tp.data)
diff --git a/rsl_rl/algorithms/agent.py b/rsl_rl/algorithms/agent.py
new file mode 100644
index 0000000..90afd04
--- /dev/null
+++ b/rsl_rl/algorithms/agent.py
@@ -0,0 +1,197 @@
+from __future__ import annotations
+from abc import ABC, abstractmethod
+import numpy as np
+import torch
+from typing import Any, Callable, Dict, Tuple, Union
+
+from rsl_rl.env import VecEnv
+from rsl_rl.storage.storage import Dataset
+from rsl_rl.utils.benchmarkable import Benchmarkable
+from rsl_rl.utils.serializable import Serializable
+from rsl_rl.utils.utils import environment_dimensions
+
+
+class Agent(ABC, Benchmarkable, Serializable):
+ def __init__(
+ self,
+ env: VecEnv,
+ action_max: float = np.inf,
+ action_min: float = -np.inf,
+ benchmark: bool = False,
+ device: str = "cpu",
+ gamma: float = 0.99,
+ ):
+ """Creates an agent.
+
+ Args:
+ env (VecEnv): The environment of the agent.
+ action_max (float): The maximum action value.
+ action_min (float): The minimum action value.
+ benchmark (bool): Whether to benchmark runtime.
+ device (str): The device to use for computation.
+ gamma (float): The environment discount factor.
+ """
+ super().__init__()
+
+ self.env = env
+ self.device = device
+ self.storage = None
+
+ self._action_max = action_max
+ self._action_min = action_min
+ self._discount_factor = gamma
+
+ self._register_serializable("_action_max", "_action_min", "_discount_factor")
+
+ dimensions = environment_dimensions(self.env)
+ self._action_size = dimensions["actions"]
+
+ self._register_serializable("_action_size")
+
+ if self._action_min > -np.inf and self._action_max < np.inf:
+ self._rand_scale = self._action_max - self._action_min
+ self._rand_offset = self._action_min
+ else:
+ self._rand_scale = 2.0
+ self._rand_offset = -1.0
+
+ self._bm_toggle(benchmark)
+
+ @abstractmethod
+ def draw_actions(
+ self, obs: torch.Tensor, env_info: Dict[str, Any]
+ ) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
+ """Draws actions from the action space.
+
+ Args:
+ obs (torch.Tensor): The observations for which to draw actions.
+ env_info (Dict[str, Any]): The environment information for the observations.
+ Returns:
+ A tuple containing the actions and the data dictionary.
+ """
+ pass
+
+ def draw_random_actions(
+ self, obs: torch.Tensor, env_info: Dict[str, Any]
+ ) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
+ """Draws random actions from the action space.
+
+ Args:
+ obs (torch.Tensor): The observations to include in the data dictionary.
+ env_info (Dict[str, Any]): The environment information to include in the data dictionary.
+ Returns:
+ A tuple containing the random actions and the data dictionary.
+ """
+ actions = self._process_actions(
+ self._rand_offset + self._rand_scale * torch.rand(self.env.num_envs, self._action_size)
+ )
+
+ return actions, {}
+
+ @abstractmethod
+ def eval_mode(self) -> Agent:
+ """Sets the agent to evaluation mode."""
+ return self
+
+ @abstractmethod
+ def export_onnx(self) -> Tuple[torch.nn.Module, torch.Tensor, Dict]:
+ """Exports the agent's policy network to ONNX format.
+
+ Returns:
+ A tuple containing the ONNX model, the input arguments, and the keyword arguments.
+ """
+ pass
+
+ @property
+ def gamma(self) -> float:
+ return self._discount_factor
+
+ @abstractmethod
+ def get_inference_policy(self, device: str = None) -> Callable:
+ """Returns a function that computes actions from observations without storing gradients.
+
+ Args:
+ device (str): The device to use for inference.
+ Returns:
+ A function that computes actions from observations.
+ """
+ pass
+
+ @property
+ def initialized(self) -> bool:
+ """Whether the agent has been initialized."""
+ return self.storage.initialized
+
+ @abstractmethod
+ def process_transition(
+ self,
+ observations: torch.Tensor,
+ environment_info: Dict[str, Any],
+ actions: torch.Tensor,
+ rewards: torch.Tensor,
+ next_observations: torch.Tensor,
+ next_environment_info: Dict[str, Any],
+ dones: torch.Tensor,
+ data: Dict[str, torch.Tensor],
+ ) -> Dict[str, torch.Tensor]:
+ """Processes a transition before it is added to the replay memory.
+
+ Args:
+ observations (torch.Tensor): The observations from the environment.
+ environment_info (Dict[str, Any]): The environment information.
+ actions (torch.Tensor): The actions computed by the actor.
+ rewards (torch.Tensor): The rewards from the environment.
+ next_observations (torch.Tensor): The next observations from the environment.
+ next_environment_info (Dict[str, Any]): The next environment information.
+ dones (torch.Tensor): The done flags from the environment.
+ data (Dict[str, torch.Tensor]): Additional data to include in the transition.
+ Returns:
+ A dictionary containing the processed transition.
+ """
+ pass
+
+ @abstractmethod
+ def register_terminations(self, terminations: torch.Tensor) -> None:
+ """Registers terminations with the agent.
+
+ Args:
+ terminations (torch.Tensor): A tensor of indicator values for each environment.
+ """
+ pass
+
+ @abstractmethod
+ def to(self, device: str) -> Agent:
+ """Transfers agent parameters to device."""
+ self.device = device
+
+ return self
+
+ @abstractmethod
+ def train_mode(self) -> Agent:
+ """Sets the agent to training mode."""
+ return self
+
+ @abstractmethod
+ def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
+ """Updates the agent's parameters.
+
+ Args:
+ dataset (Dataset): The dataset from which to update the agent.
+ Returns:
+ A dictionary containing the loss values.
+ """
+ pass
+
+ def _process_actions(self, actions: torch.Tensor) -> torch.Tensor:
+ """Processes actions produced by the agent.
+
+ Args:
+ actions (torch.Tensor): The raw actions.
+ Returns:
+ A torch.Tensor containing the processed actions.
+ """
+ actions = actions.reshape(-1, self._action_size)
+ actions = actions.clamp(self._action_min, self._action_max)
+ actions = actions.to(self.device)
+
+ return actions
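+
+ # Illustrative sketch of how a concrete Agent subclass is driven by a training loop.
+ # This is a comment-only example; it assumes `env` is a VecEnv, `MyAgent` implements
+ # the abstract methods above, `num_steps` is the desired number of environment steps,
+ # and the environment exposes the reset/step interface used below. A real runner
+ # typically collects several transitions before each update.
+ #
+ # agent = MyAgent(env, device="cuda", gamma=0.99)
+ # obs, env_info = env.reset()
+ # for _ in range(num_steps):
+ #     actions, data = agent.draw_actions(obs, env_info)
+ #     next_obs, rewards, dones, next_env_info = env.step(actions)
+ #     transition = agent.process_transition(
+ #         obs, env_info, actions, rewards, next_obs, next_env_info, dones, data
+ #     )
+ #     stats = agent.update([transition])
+ #     obs, env_info = next_obs, next_env_info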
diff --git a/rsl_rl/algorithms/d4pg.py b/rsl_rl/algorithms/d4pg.py
new file mode 100644
index 0000000..0dd5535
--- /dev/null
+++ b/rsl_rl/algorithms/d4pg.py
@@ -0,0 +1,168 @@
+from __future__ import annotations
+import torch
+from typing import Dict, Union
+from rsl_rl.algorithms.dpg import AbstractDPG
+from rsl_rl.env import VecEnv
+from rsl_rl.storage.storage import Dataset
+
+from rsl_rl.modules import CategoricalNetwork, Network
+
+
+class D4PG(AbstractDPG):
+ """Distributed Distributional Deep Deterministic Policy Gradients algorithm.
+
+ This is an implementation of the D4PG algorithm by Barth-Maron et al. for vectorized environments.
+
+ Paper: https://arxiv.org/pdf/1804.08617.pdf
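+
+ Example (illustrative sketch; assumes `env` is a VecEnv instance):
+
+ agent = D4PG(env, atom_count=51, value_min=-10.0, value_max=10.0, device="cuda")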
+ """
+
+ def __init__(
+ self,
+ env: VecEnv,
+ actor_lr: float = 1e-4,
+ atom_count: int = 51,
+ critic_activations: list = ["relu", "relu", "relu"],
+ critic_lr: float = 1e-3,
+ target_update_delay: int = 2,
+ value_max: float = 10.0,
+ value_min: float = -10.0,
+ **kwargs,
+ ) -> None:
+ """
+ Args:
+ env (VecEnv): A vectorized environment.
+ actor_lr (float): The learning rate for the actor network.
+ atom_count (int): The number of atoms to use for the categorical distribution.
+ critic_activations (list): A list of activation functions to use for the critic network.
+ critic_lr (float): The learning rate for the critic network.
+ target_update_delay (int): The number of steps to wait before updating the target networks.
+ value_max (float): The maximum value for the categorical distribution.
+ value_min (float): The minimum value for the categorical distribution.
+ """
+ kwargs["critic_activations"] = critic_activations
+
+ super().__init__(env, **kwargs)
+
+ self._atom_count = atom_count
+ self._target_update_delay = target_update_delay
+ self._value_max = value_max
+ self._value_min = value_min
+
+ self._register_serializable("_atom_count", "_target_update_delay", "_value_max", "_value_min")
+
+ self.actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
+ self.critic = CategoricalNetwork(
+ self._critic_input_size,
+ 1,
+ atom_count=atom_count,
+ value_max=value_max,
+ value_min=value_min,
+ **self._critic_network_kwargs,
+ )
+
+ self.target_actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
+ self.target_critic = CategoricalNetwork(
+ self._critic_input_size,
+ 1,
+ atom_count=atom_count,
+ value_max=value_max,
+ value_min=value_min,
+ **self._critic_network_kwargs,
+ )
+ self.target_actor.load_state_dict(self.actor.state_dict())
+ self.target_critic.load_state_dict(self.critic.state_dict())
+
+ self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
+ self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
+
+ self._register_serializable(
+ "actor", "critic", "target_actor", "target_critic", "actor_optimizer", "critic_optimizer"
+ )
+
+ self._update_step = 0
+
+ self._register_serializable("_update_step")
+
+ self.to(self.device)
+
+ def eval_mode(self) -> D4PG:
+ super().eval_mode()
+
+ self.actor.eval()
+ self.critic.eval()
+ self.target_actor.eval()
+ self.target_critic.eval()
+
+ return self
+
+ def to(self, device: str) -> D4PG:
+ super().to(device)
+
+ self.actor.to(device)
+ self.critic.to(device)
+ self.target_actor.to(device)
+ self.target_critic.to(device)
+
+ return self
+
+ def train_mode(self) -> D4PG:
+ super().train_mode()
+
+ self.actor.train()
+ self.critic.train()
+ self.target_actor.train()
+ self.target_critic.train()
+
+ return self
+
+ def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
+ super().update(dataset)
+
+ if not self.initialized:
+ return {}
+
+ total_actor_loss = torch.zeros(self._batch_count)
+ total_critic_loss = torch.zeros(self._batch_count)
+
+ for idx, batch in enumerate(self.storage.batch_generator(self._batch_size, self._batch_count)):
+ actor_obs = batch["actor_observations"]
+ critic_obs = batch["critic_observations"]
+ actions = batch["actions"].reshape(self._batch_size, -1)
+ rewards = batch["rewards"]
+ actor_next_obs = batch["next_actor_observations"]
+ critic_next_obs = batch["next_critic_observations"]
+ dones = batch["dones"]
+
+ predictions = self.critic.forward(self._critic_input(critic_obs, actions), distribution=True).squeeze()
+ target_actor_prediction = self._process_actions(self.target_actor.forward(actor_next_obs))
+ target_probabilities = self.target_critic.forward(
+ self._critic_input(critic_next_obs, target_actor_prediction), distribution=True
+ ).squeeze()
+ targets = self.target_critic.compute_targets(rewards, dones, self._discount_factor)
+ critic_loss = self.target_critic.categorical_loss(predictions, target_probabilities, targets)
+
+ self.critic_optimizer.zero_grad()
+ critic_loss.backward()
+ self.critic_optimizer.step()
+
+ evaluation = self.critic.forward(
+ self._critic_input(critic_obs, self._process_actions(self.actor.forward(actor_obs)))
+ )
+ actor_loss = -evaluation.mean()
+
+ self.actor_optimizer.zero_grad()
+ actor_loss.backward()
+ self.actor_optimizer.step()
+
+ if self._update_step % self._target_update_delay == 0:
+ self._update_target(self.actor, self.target_actor)
+ self._update_target(self.critic, self.target_critic)
+
+ self._update_step += 1
+
+ total_actor_loss[idx] = actor_loss.item()
+ total_critic_loss[idx] = critic_loss.item()
+
+ stats = {"actor": total_actor_loss.mean().item(), "critic": total_critic_loss.mean().item()}
+
+ return stats
diff --git a/rsl_rl/algorithms/ddpg.py b/rsl_rl/algorithms/ddpg.py
new file mode 100644
index 0000000..ec91793
--- /dev/null
+++ b/rsl_rl/algorithms/ddpg.py
@@ -0,0 +1,125 @@
+from __future__ import annotations
+import torch
+from torch import optim
+from typing import Dict, Union
+
+from rsl_rl.algorithms.dpg import AbstractDPG
+from rsl_rl.env import VecEnv
+from rsl_rl.modules.network import Network
+from rsl_rl.storage.storage import Dataset
+
+
+class DDPG(AbstractDPG):
+ """Deep Deterministic Policy Gradients algorithm.
+
+ This is an implementation of the DDPG algorithm by Lillicrap et al. for vectorized environments.
+
+ Paper: https://arxiv.org/pdf/1509.02971.pdf
+ """
+
+ def __init__(
+ self,
+ env: VecEnv,
+ actor_lr: float = 1e-4,
+ critic_lr: float = 1e-3,
+ **kwargs,
+ ) -> None:
+ super().__init__(env, **kwargs)
+
+ self.actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
+ self.critic = Network(self._critic_input_size, 1, **self._critic_network_kwargs)
+
+ self.target_actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
+ self.target_critic = Network(self._critic_input_size, 1, **self._critic_network_kwargs)
+ self.target_actor.load_state_dict(self.actor.state_dict())
+ self.target_critic.load_state_dict(self.critic.state_dict())
+
+ self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
+ self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)
+
+ self._register_serializable(
+ "actor", "critic", "target_actor", "target_critic", "actor_optimizer", "critic_optimizer"
+ )
+
+ self.to(self.device)
+
+ def eval_mode(self) -> DDPG:
+ super().eval_mode()
+
+ self.actor.eval()
+ self.critic.eval()
+ self.target_actor.eval()
+ self.target_critic.eval()
+
+ return self
+
+ def to(self, device: str) -> DDPG:
+ """Transfers agent parameters to device."""
+ super().to(device)
+
+ self.actor.to(device)
+ self.critic.to(device)
+ self.target_actor.to(device)
+ self.target_critic.to(device)
+
+ return self
+
+ def train_mode(self) -> DDPG:
+ super().train_mode()
+
+ self.actor.train()
+ self.critic.train()
+ self.target_actor.train()
+ self.target_critic.train()
+
+ return self
+
+ def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
+ super().update(dataset)
+
+ if not self.initialized:
+ return {}
+
+ total_actor_loss = torch.zeros(self._batch_count)
+ total_critic_loss = torch.zeros(self._batch_count)
+
+ for idx, batch in enumerate(self.storage.batch_generator(self._batch_size, self._batch_count)):
+ actor_obs = batch["actor_observations"]
+ critic_obs = batch["critic_observations"]
+ actions = batch["actions"]
+ rewards = batch["rewards"]
+ actor_next_obs = batch["next_actor_observations"]
+ critic_next_obs = batch["next_critic_observations"]
+ dones = batch["dones"]
+
+ target_actor_prediction = self._process_actions(self.target_actor.forward(actor_next_obs))
+ target_critic_prediction = self.target_critic.forward(
+ self._critic_input(critic_next_obs, target_actor_prediction)
+ )
+
+ target = rewards + self._discount_factor * (1 - dones) * target_critic_prediction
+ prediction = self.critic.forward(self._critic_input(critic_obs, actions))
+ critic_loss = (prediction - target).pow(2).mean()
+
+ self.critic_optimizer.zero_grad()
+ critic_loss.backward()
+ self.critic_optimizer.step()
+
+ evaluation = self.critic.forward(
+ self._critic_input(critic_obs, self._process_actions(self.actor.forward(actor_obs)))
+ )
+ actor_loss = -evaluation.mean()
+
+ self.actor_optimizer.zero_grad()
+ actor_loss.backward()
+ self.actor_optimizer.step()
+
+ self._update_target(self.actor, self.target_actor)
+ self._update_target(self.critic, self.target_critic)
+
+ total_actor_loss[idx] = actor_loss.item()
+ total_critic_loss[idx] = critic_loss.item()
+
+ stats = {"actor": total_actor_loss.mean().item(), "critic": total_critic_loss.mean().item()}
+
+ return stats
diff --git a/rsl_rl/algorithms/dpg.py b/rsl_rl/algorithms/dpg.py
new file mode 100644
index 0000000..8173d0b
--- /dev/null
+++ b/rsl_rl/algorithms/dpg.py
@@ -0,0 +1,49 @@
+import torch
+from typing import Any, Dict, Tuple, Union
+
+from rsl_rl.algorithms.actor_critic import AbstractActorCritic
+from rsl_rl.env import VecEnv
+from rsl_rl.storage.replay_storage import ReplayStorage
+from rsl_rl.storage.storage import Dataset
+
+
+class AbstractDPG(AbstractActorCritic):
+ def __init__(
+ self, env: VecEnv, action_noise_scale: float = 0.1, storage_initial_size=0, storage_size=100000, **kwargs
+ ):
+ """
+ Args:
+ env (VecEnv): A vectorized environment.
+ action_noise_scale (float): The scale of the Gaussian action noise.
+ storage_initial_size (int): Initial size of the replay storage.
+ storage_size (int): Maximum size of the replay storage.
+ """
+ assert action_noise_scale > 0
+
+ super().__init__(env, **kwargs)
+
+ self.storage = ReplayStorage(
+ self.env.num_envs, storage_size, device=self.device, initial_size=storage_initial_size
+ )
+
+ self._register_serializable("storage")
+
+ self._action_noise_scale = action_noise_scale
+
+ self._register_serializable("_action_noise_scale")
+
+ def draw_actions(
+ self, obs: torch.Tensor, env_info: Dict[str, Any]
+ ) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
+ actor_obs, critic_obs = self._process_observations(obs, env_info)
+
+ actions = self.actor.forward(actor_obs)
+ noise = torch.normal(torch.zeros_like(actions), torch.ones_like(actions) * self._action_noise_scale)
+ noisy_actions = self._process_actions(actions + noise)
+
+ data = {"actor_observations": actor_obs.clone(), "critic_observations": critic_obs.clone()}
+
+ return noisy_actions, data
+
+ def register_terminations(self, terminations: torch.Tensor) -> None:
+ pass
diff --git a/rsl_rl/algorithms/dppo.py b/rsl_rl/algorithms/dppo.py
new file mode 100644
index 0000000..b919816
--- /dev/null
+++ b/rsl_rl/algorithms/dppo.py
@@ -0,0 +1,327 @@
+import torch
+from torch import nn
+from typing import Dict, List, Tuple, Type, Union
+
+from rsl_rl.algorithms.ppo import PPO
+from rsl_rl.distributions import QuantileDistribution
+from rsl_rl.env import VecEnv
+from rsl_rl.utils.benchmarkable import Benchmarkable
+from rsl_rl.utils.recurrency import trajectories_to_transitions, transitions_to_trajectories
+from rsl_rl.modules import ImplicitQuantileNetwork, QuantileNetwork
+from rsl_rl.storage.storage import Dataset
+
+
+class DPPO(PPO):
+ """Distributional Proximal Policy Optimization algorithm.
+
+ This algorithm is an extension of PPO that uses a distributional method (either QR-DQN or IQN) to estimate the
+ value function.
+
+ QR-DQN Paper: https://arxiv.org/pdf/1710.10044.pdf
+ IQN Paper: https://arxiv.org/pdf/1806.06923.pdf
+
+ The implementation works with recurrent neural networks. We further implement Sample-Replacement SR(lambda) for the
+ value target computation, as described by Nam et al. in https://arxiv.org/pdf/2105.11366.pdf.
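+
+ Example (illustrative sketch; assumes `env` is a VecEnv instance):
+
+ agent = DPPO(
+     env,
+     critic_network=DPPO.network_iqn,
+     iqn_action_samples=32,
+     iqn_value_samples=8,
+     value_loss=DPPO.value_loss_energy,
+     value_lambda=0.95,
+ )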
+ """
+
+ critic_network: Type[nn.Module] = QuantileNetwork
+ _alg_features = dict(recurrent=True)
+
+ value_loss_energy = "sample_energy"
+ value_loss_l1 = "quantile_l1"
+ value_loss_huber = "quantile_huber"
+
+ network_qrdqn = "qrdqn"
+ network_iqn = "iqn"
+
+ networks = {network_qrdqn: QuantileNetwork, network_iqn: ImplicitQuantileNetwork}
+
+ values_losses = {
+ network_qrdqn: {
+ value_loss_energy: QuantileNetwork.sample_energy_loss,
+ value_loss_l1: QuantileNetwork.quantile_l1_loss,
+ value_loss_huber: QuantileNetwork.quantile_huber_loss,
+ },
+ network_iqn: {
+ value_loss_energy: ImplicitQuantileNetwork.sample_energy_loss,
+ },
+ }
+
+ def __init__(
+ self,
+ env: VecEnv,
+ critic_activations: List[str] = ["relu", "relu", "relu"],
+ critic_network: str = network_qrdqn,
+ iqn_action_samples: int = 32,
+ iqn_embedding_size: int = 64,
+ iqn_feature_layers: int = 1,
+ iqn_value_samples: int = 8,
+ qrdqn_quantile_count: int = 200,
+ value_lambda: float = 0.95,
+ value_loss: str = value_loss_l1,
+ value_loss_kwargs: Dict = {},
+ value_measure: str = None,
+ value_measure_adaptation: Union[Tuple, None] = None,
+ value_measure_kwargs: Dict = {},
+ **kwargs,
+ ):
+ """
+ Args:
+ env (VecEnv): A vectorized environment.
+ critic_activations (List[str]): A list of activations to use for the critic network.
+ critic_network (str): The critic network to use.
+ iqn_action_samples (int): The number of samples to use for the critic IQN network when acting.
+ iqn_embedding_size (int): The embedding size to use for the critic IQN network.
+ iqn_feature_layers (int): The number of feature layers to use for the critic IQN network.
+ iqn_value_samples (int): The number of samples to use for the critic IQN network when computing the value.
+ qrdqn_quantile_count (int): The number of quantiles to use for the critic QR network.
+ value_lambda (float): The lambda parameter for the SR(lambda) value target computation.
+ value_loss (str): The loss function to use for the critic network.
+ value_loss_kwargs (Dict): Keyword arguments for computing the value loss.
+ value_measure (str): The probability measure to apply to the critic network output distribution when
+ updating the policy.
+ value_measure_adaptation (Union[Tuple, None]): Controls adaptation of the value measure. If None, no
+ adaptation is performed. If a tuple, the tuple specifies the observations that are passed to the value
+ measure.
+ value_measure_kwargs (Dict): The keyword arguments to pass to the value measure.
+ """
+ self._register_critic_network_kwargs(measure=value_measure, measure_kwargs=value_measure_kwargs)
+
+ self._critic_network_name = critic_network
+ self.critic_network = self.networks[self._critic_network_name]
+ if self._critic_network_name == self.network_qrdqn:
+ self._register_critic_network_kwargs(quantile_count=qrdqn_quantile_count)
+ elif self._critic_network_name == self.network_iqn:
+ self._register_critic_network_kwargs(feature_layers=iqn_feature_layers, embedding_size=iqn_embedding_size)
+
+ kwargs["critic_activations"] = critic_activations
+
+ if value_measure_adaptation is not None:
+ # Value measure adaptation observations are not passed to the critic network.
+ kwargs["_critic_input_size_delta"] = (
+ kwargs["_critic_input_size_delta"] if "_critic_input_size_delta" in kwargs else 0
+ ) - len(value_measure_adaptation)
+
+ super().__init__(env, **kwargs)
+
+ self._value_lambda = value_lambda
+ self._value_loss_name = value_loss
+ self._register_serializable("_value_lambda", "_value_loss_name")
+
+ assert (
+ self._value_loss_name in self.values_losses[self._critic_network_name]
+ ), f"Value loss '{self._value_loss_name}' is not supported for network '{self._critic_network_name}'."
+ value_loss_func = self.values_losses[critic_network][self._value_loss_name]
+ self._value_loss = lambda *args, **kwargs: value_loss_func(self.critic, *args, **kwargs)
+
+ if value_loss == self.value_loss_energy:
+ value_loss_kwargs["sample_count"] = (
+ value_loss_kwargs["sample_count"] if "sample_count" in value_loss_kwargs else 100
+ )
+
+ self._value_loss_kwargs = value_loss_kwargs
+ self._register_serializable("_value_loss_kwargs")
+
+ self._value_measure_adaptation = value_measure_adaptation
+ self._register_serializable("_value_measure_adaptation")
+
+ if self._critic_network_name == self.network_iqn:
+ self._iqn_action_samples = iqn_action_samples
+ self._iqn_value_samples = iqn_value_samples
+ self._register_serializable("_iqn_action_samples", "_iqn_value_samples")
+
+ def _critic_input(self, observations, actions=None) -> torch.Tensor:
+ mask, shape = self._get_critic_obs_mask(observations)
+
+ processed_observations = observations[mask].reshape(*shape)
+
+ return processed_observations
+
+ def _get_critic_obs_mask(self, observations):
+ mask = torch.ones_like(observations).bool()
+
+ if self._value_measure_adaptation is not None:
+ mask[:, self._value_measure_adaptation] = False
+
+ shape = (observations.shape[0], self._critic_input_size)
+
+ return mask, shape
+
+ def _process_quants(self, x):
+ if self._value_loss_name == self.value_loss_energy:
+ quants, idx = QuantileDistribution(x).sample(self._value_loss_kwargs["sample_count"])
+ else:
+ quants, idx = x, None
+
+ return quants, idx
+
+ @Benchmarkable.register
+ def process_transition(self, *args) -> Dict[str, torch.Tensor]:
+ transition = super(PPO, self).process_transition(*args)
+
+ if self.recurrent:
+ transition["critic_state_h"] = self.critic.hidden_state[0].detach()
+ transition["critic_state_c"] = self.critic.hidden_state[1].detach()
+
+ transition["full_critic_observations"] = transition["critic_observations"].detach()
+ transition["full_next_critic_observations"] = transition["next_critic_observations"].detach()
+ mask, shape = self._get_critic_obs_mask(transition["critic_observations"])
+ transition["critic_observations"] = transition["critic_observations"][mask].reshape(*shape)
+ transition["next_critic_observations"] = transition["next_critic_observations"][mask].reshape(*shape)
+
+ critic_kwargs = (
+ {"sample_count": self._iqn_action_samples} if self._critic_network_name == self.network_iqn else {}
+ )
+ transition["values"] = self.critic.forward(
+ transition["critic_observations"],
+ measure_args=self._extract_value_measure_adaptation(transition["full_critic_observations"]),
+ **critic_kwargs,
+ ).detach()
+
+ if self._critic_network_name == self.network_iqn:
+ # For IQN, we sample new (undistorted) quantiles for computing the value update
+ critic_kwargs = (
+ {"hidden_state": (transition["critic_state_h"], transition["critic_state_c"])} if self.recurrent else {}
+ )
+ self.critic.forward(
+ transition["critic_observations"],
+ sample_count=self._iqn_value_samples,
+ use_measure=False,
+ **critic_kwargs,
+ ).detach()
+ transition["value_taus"] = self.critic.last_taus.detach().reshape(transition["values"].shape[0], -1)
+
+ transition["value_quants"] = self.critic.last_quantiles.detach().reshape(transition["values"].shape[0], -1)
+
+ if self.recurrent:
+ transition["critic_next_state_h"] = self.critic.hidden_state[0].detach()
+ transition["critic_next_state_c"] = self.critic.hidden_state[1].detach()
+
+ return transition
+
+ @Benchmarkable.register
+ def _compute_value_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+ critic_kwargs = (
+ {"sample_count": self._iqn_value_samples, "taus": batch["value_target_taus"], "use_measure": False}
+ if self._critic_network_name == self.network_iqn
+ else {}
+ )
+
+ if self.recurrent:
+ observations, data = transitions_to_trajectories(batch["critic_observations"], batch["dones"])
+ hidden_state_h, _ = transitions_to_trajectories(batch["critic_state_h"], batch["dones"])
+ hidden_state_c, _ = transitions_to_trajectories(batch["critic_state_c"], batch["dones"])
+ hidden_states = (hidden_state_h[0].transpose(0, 1), hidden_state_c[0].transpose(0, 1))
+
+ if self._critic_network_name == self.network_iqn:
+ critic_kwargs["taus"], _ = transitions_to_trajectories(critic_kwargs["taus"], batch["dones"])
+
+ trajectory_evaluations = self.critic.forward(
+ observations, distribution=True, hidden_state=hidden_states, **critic_kwargs
+ )
+ trajectory_evaluations = trajectory_evaluations.reshape(*observations.shape[:-1], -1)
+
+ predictions = trajectories_to_transitions(trajectory_evaluations, data)
+ else:
+ predictions = self.critic.forward(batch["critic_observations"], distribution=True, **critic_kwargs)
+
+ value_loss = self._value_loss(self._process_quants(predictions)[0], batch["value_target_quants"])
+
+ return value_loss
+
+ def _extract_value_measure_adaptation(self, observations: torch.Tensor) -> Tuple[torch.Tensor]:
+ if self._value_measure_adaptation is None:
+ return tuple()
+
+ relevant_observations = observations[:, self._value_measure_adaptation]
+ measure_adaptations = torch.tensor_split(relevant_observations, relevant_observations.shape[1], dim=1)
+
+ return measure_adaptations
+
+ @Benchmarkable.register
+ def _process_dataset(self, dataset: Dataset) -> Dataset:
+ rewards = torch.stack([entry["rewards"] for entry in dataset])
+ dones = torch.stack([entry["dones"] for entry in dataset]).float()
+ timeouts = torch.stack([entry["timeouts"] for entry in dataset])
+ values = torch.stack([entry["values"] for entry in dataset])
+
+ value_quants_idx = [self._process_quants(entry["value_quants"]) for entry in dataset]
+ value_quants = torch.stack([entry[0] for entry in value_quants_idx])
+
+ critic_kwargs = (
+ {"hidden_state": (dataset[-1]["critic_state_h"], dataset[-1]["critic_state_c"])} if self.recurrent else {}
+ )
+ if self._critic_network_name == self.network_iqn:
+ critic_kwargs["sample_count"] = self._iqn_action_samples
+
+ measure_args = self._extract_value_measure_adaptation(dataset[-1]["full_next_critic_observations"])
+ next_values = self.critic.forward(
+ dataset[-1]["next_critic_observations"], measure_args=measure_args, **critic_kwargs
+ )
+
+ if self._critic_network_name == self.network_iqn:
+ # For IQN, we sample new (undistorted) quantiles for computing the value update
+ critic_kwargs["sample_count"] = self._iqn_value_samples
+ self.critic.forward(
+ dataset[-1]["next_critic_observations"],
+ use_measure=False,
+ **critic_kwargs,
+ )
+
+ final_value_taus = self.critic.last_taus
+ value_taus = torch.stack(
+ [
+ torch.take_along_dim(dataset[i]["value_taus"], value_quants_idx[i][1], -1)
+ for i in range(len(dataset))
+ ]
+ )
+
+ final_value_quants = self.critic.last_quantiles
+
+ # Timeout bootstrapping for rewards.
+ rewards += self.gamma * timeouts * values
+
+ # Compute advantages and value target quantiles
+ next_values = torch.cat((values[1:], next_values.unsqueeze(0)), dim=0)
+ deltas = (rewards + (1 - dones) * self.gamma * next_values - values).reshape(-1, self.env.num_envs)
+ advantages = torch.zeros((len(dataset) + 1, self.env.num_envs), device=self.device)
+
+ next_value_quants, idx = self._process_quants(final_value_quants)
+ value_target_quants = torch.zeros(len(dataset), *next_value_quants.shape, device=self.device)
+
+ if self._critic_network_name == self.network_iqn:
+ value_target_taus = torch.zeros(len(dataset) + 1, *next_value_quants.shape, device=self.device)
+ value_target_taus[-1] = torch.take_along_dim(final_value_taus, idx, -1)
+
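+ # Backward recursion: accumulate GAE advantages and build the distributional value
+ # targets. With probability `value_lambda` a bootstrapped target quantile is carried
+ # over to the preceding step (SR(lambda) sample replacement); otherwise it is
+ # replaced by the stored value quantile of that step.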
+ for step in reversed(range(len(dataset))):
+ not_terminal = 1.0 - dones[step]
+ not_terminal_ = not_terminal.unsqueeze(-1)
+
+ advantages[step] = deltas[step] + not_terminal * self.gamma * self._gae_lambda * advantages[step + 1]
+ value_target_quants[step] = rewards[step].unsqueeze(-1) + not_terminal_ * self.gamma * next_value_quants
+
+ preserved_value_quants = not_terminal_.bool() * (
+ torch.rand(next_value_quants.shape, device=self.device) < self._value_lambda
+ )
+ next_value_quants = torch.where(preserved_value_quants, value_target_quants[step], value_quants[step])
+
+ if self._critic_network_name == self.network_iqn:
+ value_target_taus[step] = torch.where(
+ preserved_value_quants, value_target_taus[step + 1], value_taus[step]
+ )
+
+ advantages = advantages[:-1]
+ if self._critic_network_name == self.network_iqn:
+ value_target_taus = value_target_taus[:-1]
+
+ # Normalize advantages and pack into dataset structure.
+ amean, astd = advantages.mean(), torch.nan_to_num(advantages.std())
+ for step in range(len(dataset)):
+ dataset[step]["advantages"] = advantages[step]
+ dataset[step]["normalized_advantages"] = (advantages[step] - amean) / (astd + 1e-8)
+ dataset[step]["value_target_quants"] = value_target_quants[step]
+
+ if self._critic_network_name == self.network_iqn:
+ dataset[step]["value_target_taus"] = value_target_taus[step]
+
+ return dataset
diff --git a/rsl_rl/algorithms/dsac.py b/rsl_rl/algorithms/dsac.py
new file mode 100644
index 0000000..94d13da
--- /dev/null
+++ b/rsl_rl/algorithms/dsac.py
@@ -0,0 +1,75 @@
+import torch
+from torch import nn
+from typing import Tuple, Type
+
+
+from rsl_rl.algorithms.sac import SAC
+from rsl_rl.env import VecEnv
+from rsl_rl.modules.quantile_network import QuantileNetwork
+
+
+class DSAC(SAC):
+ """Distributional Soft Actor Critic (DSAC) algorithm.
+
+ This is an implementation of the DSAC algorithm by Ma et al. for vectorized environments.
+
+ Paper: https://arxiv.org/pdf/2004.14547.pdf
+
+ The implementation inherits automatic tuning of the temperature parameter (alpha) and tanh action scaling from
+ the SAC implementation.
+ """
+
+ critic_network: Type[nn.Module] = QuantileNetwork
+
+ def __init__(self, env: VecEnv, critic_activations=["relu", "relu", "relu"], quantile_count=200, **kwargs):
+ """
+ Args:
+ env (VecEnv): A vectorized environment.
+ critic_activations (list): A list of activation functions to use for the critic network.
+ quantile_count (int): The number of quantiles to use for the critic QR network.
+ """
+ self._quantile_count = quantile_count
+ self._register_critic_network_kwargs(quantile_count=self._quantile_count)
+
+ kwargs["critic_activations"] = critic_activations
+
+ super().__init__(env, **kwargs)
+
+ def _update_critic(
+ self,
+ critic_obs: torch.Tensor,
+ actions: torch.Tensor,
+ rewards: torch.Tensor,
+ dones: torch.Tensor,
+ actor_next_obs: torch.Tensor,
+ critic_next_obs: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ target_action, target_action_logp = self._sample_action(actor_next_obs)
+ target_critic_input = self._critic_input(critic_next_obs, target_action)
+ target_critic_prediction_1 = self.target_critic_1.forward(target_critic_input, distribution=True)
+ target_critic_prediction_2 = self.target_critic_2.forward(target_critic_input, distribution=True)
+ target_critic_prediction = torch.minimum(target_critic_prediction_1, target_critic_prediction_2)
+
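+ # Distributional soft Q target: shift each target quantile by the entropy term
+ # -alpha * log pi(a'|s') before discounting and adding the reward.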
+ next_soft_q = target_critic_prediction - self.alpha * target_action_logp.unsqueeze(-1).repeat(
+ 1, self._quantile_count
+ )
+ target = (rewards.reshape(-1, 1) + self._discount_factor * (1 - dones).reshape(-1, 1) * next_soft_q).detach()
+
+ critic_input = self._critic_input(critic_obs, actions).detach()
+ critic_1_prediction = self.critic_1.forward(critic_input, distribution=True)
+ critic_1_loss = self.critic_1.quantile_huber_loss(critic_1_prediction, target)
+
+ self.critic_1_optimizer.zero_grad()
+ critic_1_loss.backward()
+ nn.utils.clip_grad_norm_(self.critic_1.parameters(), self._gradient_clip)
+ self.critic_1_optimizer.step()
+
+ critic_2_prediction = self.critic_2.forward(critic_input, distribution=True)
+ critic_2_loss = self.critic_2.quantile_huber_loss(critic_2_prediction, target)
+
+ self.critic_2_optimizer.zero_grad()
+ critic_2_loss.backward()
+ nn.utils.clip_grad_norm_(self.critic_2.parameters(), self._gradient_clip)
+ self.critic_2_optimizer.step()
+
+ return critic_1_loss, critic_2_loss
diff --git a/rsl_rl/algorithms/dtd3.py b/rsl_rl/algorithms/dtd3.py
new file mode 100644
index 0000000..e4ed595
--- /dev/null
+++ b/rsl_rl/algorithms/dtd3.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+import torch
+from torch import nn
+from typing import Type
+
+from rsl_rl.algorithms.td3 import TD3
+from rsl_rl.env import VecEnv
+from rsl_rl.modules import QuantileNetwork
+
+
+class DTD3(TD3):
+ """Distributional Twin-Delayed Deep Deterministic Policy Gradients algorithm.
+
+ This is an implementation of the TD3 algorithm by Fujimoto et al. for vectorized environments using a QR-DQN
+ critic.
+ """
+
+ critic_network: Type[nn.Module] = QuantileNetwork
+
+ def __init__(
+ self,
+ env: VecEnv,
+ quantile_count: int = 200,
+ **kwargs,
+ ) -> None:
+ self._quantile_count = quantile_count
+ self._register_critic_network_kwargs(quantile_count=self._quantile_count)
+
+ super().__init__(env, **kwargs)
+
+ def _update_critic(self, critic_obs, actions, rewards, dones, actor_next_obs, critic_next_obs):
+ target_action = self._apply_action_noise(self.target_actor.forward(actor_next_obs), clip=True)
+ target_critic_input = self._critic_input(critic_next_obs, target_action)
+ target_critic_prediction_1 = self.target_critic_1.forward(target_critic_input, distribution=True)
+ target_critic_prediction_2 = self.target_critic_2.forward(target_critic_input, distribution=True)
+ target_critic_prediction = torch.minimum(target_critic_prediction_1, target_critic_prediction_2)
+
+ target = (
+ rewards.reshape(-1, 1) + self._discount_factor * (1 - dones).reshape(-1, 1) * target_critic_prediction
+ ).detach()
+
+ critic_input = self._critic_input(critic_obs, actions).detach()
+ critic_1_prediction = self.critic_1.forward(critic_input, distribution=True)
+ critic_1_loss = self.critic_1.quantile_huber_loss(critic_1_prediction, target)
+
+ self.critic_1_optimizer.zero_grad()
+ critic_1_loss.backward()
+ self.critic_1_optimizer.step()
+
+ critic_2_prediction = self.critic_2.forward(critic_input, distribution=True)
+ critic_2_loss = self.critic_2.quantile_huber_loss(critic_2_prediction, target)
+
+ self.critic_2_optimizer.zero_grad()
+ critic_2_loss.backward()
+ self.critic_2_optimizer.step()
+
+ return critic_1_loss, critic_2_loss
diff --git a/rsl_rl/algorithms/hybrid.py b/rsl_rl/algorithms/hybrid.py
new file mode 100644
index 0000000..11e223c
--- /dev/null
+++ b/rsl_rl/algorithms/hybrid.py
@@ -0,0 +1,193 @@
+from abc import ABC, abstractmethod
+import torch
+from typing import Callable, Dict, Tuple, Type, Union
+
+from rsl_rl.algorithms import D4PG, DSAC
+from rsl_rl.algorithms import TD3
+from rsl_rl.algorithms.actor_critic import AbstractActorCritic
+from rsl_rl.algorithms.agent import Agent
+from rsl_rl.env import VecEnv
+from rsl_rl.storage.storage import Dataset, Storage
+
+
+class AbstractHybridAgent(Agent, ABC):
+ def __init__(
+ self,
+ env: VecEnv,
+ agent_class: Type[Agent],
+ agent_kwargs: dict,
+ pretrain_agent_class: Type[Agent],
+ pretrain_agent_kwargs: dict,
+ pretrain_steps: int,
+ freeze_steps: int = 0,
+ **general_kwargs,
+ ):
+ """
+ Args:
+ env (VecEnv): A vectorized environment.
+ agent_class (Type[Agent]): The class of the main agent, used after pretraining.
+ agent_kwargs (dict): Keyword arguments for constructing the main agent.
+ pretrain_agent_class (Type[Agent]): The class of the agent used during pretraining.
+ pretrain_agent_kwargs (dict): Keyword arguments for constructing the pretraining agent.
+ pretrain_steps (int): The number of update steps during which the pretraining agent is active.
+ freeze_steps (int): The number of update steps for which the transferred weights remain frozen after pretraining.
+ """
+ agent_kwargs["env"] = env
+ pretrain_agent_kwargs["env"] = env
+
+ self.agent = agent_class(**agent_kwargs, **general_kwargs)
+ self.pretrain_agent = pretrain_agent_class(**pretrain_agent_kwargs, **general_kwargs)
+
+ self._freeze_steps = freeze_steps
+ self._pretrain_steps = pretrain_steps
+ self._steps = 0
+
+ self._register_serializable("agent", "pretrain_agent", "_freeze_steps", "_pretrain_steps", "_steps")
+
+ @property
+ def active_agent(self):
+ agent = self.pretrain_agent if self.pretraining else self.agent
+
+ return agent
+
+ def draw_actions(self, *args, **kwargs) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
+ return self.active_agent.draw_actions(*args, **kwargs)
+
+ def draw_random_actions(self, *args, **kwargs) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
+ return self.active_agent.draw_random_actions(*args, **kwargs)
+
+ def eval_mode(self, *args, **kwargs) -> Agent:
+ self.agent.eval_mode(*args, **kwargs)
+
+ return self
+
+ def get_inference_policy(self, *args, **kwargs) -> Callable:
+ return self.active_agent.get_inference_policy(*args, **kwargs)
+
+ @property
+ def initialized(self) -> bool:
+ return self.active_agent.initialized
+
+ @property
+ def pretraining(self):
+ return self._steps < self._pretrain_steps
+
+ def process_dataset(self, *args, **kwargs) -> Dataset:
+ return self.active_agent.process_dataset(*args, **kwargs)
+
+ def process_transition(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
+ return self.active_agent.process_transition(*args, **kwargs)
+
+ def register_terminations(self, *args, **kwargs) -> None:
+ return self.active_agent.register_terminations(*args, **kwargs)
+
+ @property
+ def storage(self) -> Storage:
+ return self.active_agent.storage
+
+ def to(self, *args, **kwargs) -> Agent:
+ self.agent.to(*args, **kwargs)
+ self.pretrain_agent.to(*args, **kwargs)
+
+ return self
+
+ def train_mode(self, *args, **kwargs) -> Agent:
+ self.agent.train_mode(*args, **kwargs)
+ self.pretrain_agent.train_mode(*args, **kwargs)
+
+ return self
+
+ def update(self, *args, **kwargs) -> Dict[str, Union[float, torch.Tensor]]:
+ result = self.active_agent.update(*args, **kwargs)
+
+ if not self.active_agent.initialized:
+ return result
+
+ self._steps += 1
+
+ if self._steps == self._pretrain_steps:
+ self._transfer_weights()
+ self._freeze_weights(freeze=True)
+
+ if self._steps == self._pretrain_steps + self._freeze_steps:
+ self._transfer_weights()
+ self._freeze_weights(freeze=False)
+
+ return result
+
+ @abstractmethod
+ def _freeze_weights(self, freeze=True):
+ pass
+
+ @abstractmethod
+ def _transfer_weights(self):
+ pass
+
+
+class HybridD4PG(AbstractHybridAgent):
+ def __init__(
+ self,
+ env: VecEnv,
+ d4pg_kwargs: dict,
+ pretrain_kwargs: dict,
+ pretrain_agent: Type[AbstractActorCritic] = TD3,
+ **kwargs,
+ ):
+ assert d4pg_kwargs["action_max"] == pretrain_kwargs["action_max"]
+ assert d4pg_kwargs["action_min"] == pretrain_kwargs["action_min"]
+ assert d4pg_kwargs["actor_activations"] == pretrain_kwargs["actor_activations"]
+ assert d4pg_kwargs["actor_hidden_dims"] == pretrain_kwargs["actor_hidden_dims"]
+ assert d4pg_kwargs["actor_input_normalization"] == pretrain_kwargs["actor_input_normalization"]
+
+ super().__init__(
+ env,
+ agent_class=D4PG,
+ agent_kwargs=d4pg_kwargs,
+ pretrain_agent_class=pretrain_agent,
+ pretrain_agent_kwargs=pretrain_kwargs,
+ **kwargs,
+ )
+
+ def _freeze_weights(self, freeze=True):
+ for param in self.agent.actor.parameters():
+ param.requires_grad = not freeze
+
+ def _transfer_weights(self):
+ self.agent.actor.load_state_dict(self.pretrain_agent.actor.state_dict())
+ self.agent.actor_optimizer.load_state_dict(self.pretrain_agent.actor_optimizer.state_dict())
+
+
+class HybridDSAC(AbstractHybridAgent):
+ def __init__(
+ self,
+ env: VecEnv,
+ dsac_kwargs: dict,
+ pretrain_kwargs: dict,
+ pretrain_agent: Type[AbstractActorCritic] = TD3,
+ **kwargs,
+ ):
+ assert dsac_kwargs["action_max"] == pretrain_kwargs["action_max"]
+ assert dsac_kwargs["action_min"] == pretrain_kwargs["action_min"]
+ assert dsac_kwargs["actor_activations"] == pretrain_kwargs["actor_activations"]
+ assert dsac_kwargs["actor_hidden_dims"] == pretrain_kwargs["actor_hidden_dims"]
+ assert dsac_kwargs["actor_input_normalization"] == pretrain_kwargs["actor_input_normalization"]
+
+ super().__init__(
+ env,
+ agent_class=DSAC,
+ agent_kwargs=dsac_kwargs,
+ pretrain_agent_class=pretrain_agent,
+ pretrain_agent_kwargs=pretrain_kwargs,
+ **kwargs,
+ )
+
+ def _freeze_weights(self, freeze=True):
+ """Freezes actor layers.
+
+ Freezes the feature encoding and mean computation layers of the Gaussian network, leaving the log standard
+ deviation layer unfrozen.
+ """
+ for param in self.agent.actor._layers.parameters():
+ param.requires_grad = not freeze
+
+ for param in self.agent.actor._mean_layer.parameters():
+ param.requires_grad = not freeze
+
+ def _transfer_weights(self):
+ """Transfers actor layers.
+
+ Transfers only the feature encoding and mean computation layers of the Gaussian network.
+ """
+ for i, layer in enumerate(self.agent.actor._layers):
+ layer.load_state_dict(self.pretrain_agent.actor._layers[i].state_dict())
+
+ for j, layer in enumerate(self.agent.actor._mean_layer):
+ layer.load_state_dict(self.pretrain_agent.actor._layers[i + j + 1].state_dict())
diff --git a/rsl_rl/algorithms/ppo.py b/rsl_rl/algorithms/ppo.py
index ffe5814..6a1e9b9 100644
--- a/rsl_rl/algorithms/ppo.py
+++ b/rsl_rl/algorithms/ppo.py
@@ -1,185 +1,384 @@
-# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
-# SPDX-License-Identifier: BSD-3-Clause
-
-from __future__ import annotations
-
import torch
-import torch.nn as nn
-import torch.optim as optim
+from torch import nn, optim
+from typing import Any, Dict, Tuple, Type, Union
-from rsl_rl.modules import ActorCritic
-from rsl_rl.storage import RolloutStorage
+from rsl_rl.algorithms.actor_critic import AbstractActorCritic
+from rsl_rl.env import VecEnv
+from rsl_rl.utils.benchmarkable import Benchmarkable
+from rsl_rl.utils.recurrency import trajectories_to_transitions, transitions_to_trajectories
+from rsl_rl.modules import GaussianNetwork, Network
+from rsl_rl.storage.rollout_storage import RolloutStorage
+from rsl_rl.storage.storage import Dataset
-class PPO:
- actor_critic: ActorCritic
+class PPO(AbstractActorCritic):
+ """Proximal Policy Optimization algorithm.
+
+ This is an implementation of the PPO algorithm by Schulman et al. for vectorized environments.
+
+ Paper: https://arxiv.org/pdf/1707.06347.pdf
+
+ The implementation works with recurrent neural networks. We implement an adaptive learning rate based on the
+ KL-divergence between the old and new policy, as described by Schulman et al. in
+ https://arxiv.org/pdf/1707.06347.pdf.
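+
+ Example (illustrative sketch; assumes `env` is a VecEnv instance):
+
+ agent = PPO(env, learning_rate=1e-3, schedule="adaptive", target_kl=0.01)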
+ """
+
+ critic_network: Type[nn.Module] = Network
+ _alg_features = dict(recurrent=True)
+
+ schedule_adaptive = "adaptive"
+ schedule_fixed = "fixed"
def __init__(
self,
- actor_critic,
- num_learning_epochs=1,
- num_mini_batches=1,
- clip_param=0.2,
- gamma=0.998,
- lam=0.95,
- value_loss_coef=1.0,
- entropy_coef=0.0,
- learning_rate=1e-3,
- max_grad_norm=1.0,
- use_clipped_value_loss=True,
- schedule="fixed",
- desired_kl=0.01,
- device="cpu",
+ env: VecEnv,
+ actor_noise_std: float = 1.0,
+ clip_ratio: float = 0.2,
+ entropy_coeff: float = 0.0,
+ gae_lambda: float = 0.97,
+ gradient_clip: float = 1.0,
+ learning_rate: float = 1e-3,
+ schedule: str = "fixed",
+ target_kl: float = 0.01,
+ value_coeff: float = 1.0,
+ **kwargs,
):
- self.device = device
+ """
+ Args:
+ env (VecEnv): A vectorized environment.
+ actor_noise_std (float): The standard deviation of the Gaussian noise to add to the actor network output.
+ clip_ratio (float): The clipping ratio for the PPO objective.
+ entropy_coeff (float): The coefficient for the entropy term in the PPO objective.
+ gae_lambda (float): The lambda parameter for the GAE computation.
+ gradient_clip (float): The gradient clipping value.
+ learning_rate (float): The learning rate for the actor and critic networks.
+ schedule (str): The learning rate schedule. Can be "fixed" or "adaptive". Defaults to "fixed".
+ target_kl (float): The target KL-divergence for the adaptive learning rate schedule.
+ value_coeff (float): The coefficient for the value function loss in the PPO objective.
+ """
+ kwargs["batch_size"] = env.num_envs
+ kwargs["return_steps"] = 1
- self.desired_kl = desired_kl
- self.schedule = schedule
- self.learning_rate = learning_rate
+ super().__init__(env, **kwargs)
+ self._critic_input_size = self._critic_input_size - self._action_size # We use a state-value function (not Q)
- # PPO components
- self.actor_critic = actor_critic
- self.actor_critic.to(self.device)
- self.storage = None # initialized later
- self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
- self.transition = RolloutStorage.Transition()
+ self.storage = RolloutStorage(self.env.num_envs, device=self.device)
- # PPO parameters
- self.clip_param = clip_param
- self.num_learning_epochs = num_learning_epochs
- self.num_mini_batches = num_mini_batches
- self.value_loss_coef = value_loss_coef
- self.entropy_coef = entropy_coef
- self.gamma = gamma
- self.lam = lam
- self.max_grad_norm = max_grad_norm
- self.use_clipped_value_loss = use_clipped_value_loss
+ self._register_serializable("storage")
- def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape):
- self.storage = RolloutStorage(
- num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, self.device
+ self._clip_ratio = clip_ratio
+ self._entropy_coeff = entropy_coeff
+ self._gae_lambda = gae_lambda
+ self._gradient_clip = gradient_clip
+ self._schedule = schedule
+ self._target_kl = target_kl
+ self._value_coeff = value_coeff
+
+ self._register_serializable(
+ "_clip_ratio",
+ "_entropy_coeff",
+ "_gae_lambda",
+ "_gradient_clip",
+ "_schedule",
+ "_target_kl",
+ "_value_coeff",
)
- def test_mode(self):
- self.actor_critic.test()
+ self.actor = GaussianNetwork(
+ self._actor_input_size, self._action_size, std_init=actor_noise_std, **self._actor_network_kwargs
+ )
+ self.critic = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
- def train_mode(self):
- self.actor_critic.train()
+ if self.recurrent:
+ self.actor.reset_full_hidden_state(batch_size=self.env.num_envs)
+ self.critic.reset_full_hidden_state(batch_size=self.env.num_envs)
- def act(self, obs, critic_obs):
- if self.actor_critic.is_recurrent:
- self.transition.hidden_states = self.actor_critic.get_hidden_states()
- # Compute the actions and values
- self.transition.actions = self.actor_critic.act(obs).detach()
- self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
- self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
- self.transition.action_mean = self.actor_critic.action_mean.detach()
- self.transition.action_sigma = self.actor_critic.action_std.detach()
- # need to record obs and critic_obs before env.step()
- self.transition.observations = obs
- self.transition.critic_observations = critic_obs
- return self.transition.actions
+ tp = lambda v: v.transpose(0, 1) # for storing, transpose (num_layers, batch, F) to (batch, num_layers, F)
+ self.storage.register_processor("actor_state_h", tp)
+ self.storage.register_processor("actor_state_c", tp)
+ self.storage.register_processor("critic_state_h", tp)
+ self.storage.register_processor("critic_state_c", tp)
+ self.storage.register_processor("critic_next_state_h", tp)
+ self.storage.register_processor("critic_next_state_c", tp)
- def process_env_step(self, rewards, dones, infos):
- self.transition.rewards = rewards.clone()
- self.transition.dones = dones
- # Bootstrapping on time outs
- if "time_outs" in infos:
- self.transition.rewards += self.gamma * torch.squeeze(
- self.transition.values * infos["time_outs"].unsqueeze(1).to(self.device), 1
- )
+ self._bm_fuse(self.actor, prefix="actor.")
+ self._bm_fuse(self.critic, prefix="critic.")
- # Record the transition
- self.storage.add_transitions(self.transition)
- self.transition.clear()
- self.actor_critic.reset(dones)
+ self._register_serializable("actor", "critic")
- def compute_returns(self, last_critic_obs):
- last_values = self.actor_critic.evaluate(last_critic_obs).detach()
- self.storage.compute_returns(last_values, self.gamma, self.lam)
+ self.learning_rate = learning_rate
+ self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
- def update(self):
- mean_value_loss = 0
- mean_surrogate_loss = 0
- if self.actor_critic.is_recurrent:
- generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
- else:
- generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
- for (
- obs_batch,
- critic_obs_batch,
- actions_batch,
- target_values_batch,
- advantages_batch,
- returns_batch,
- old_actions_log_prob_batch,
- old_mu_batch,
- old_sigma_batch,
- hid_states_batch,
- masks_batch,
- ) in generator:
- self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
- actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
- value_batch = self.actor_critic.evaluate(
- critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1]
- )
- mu_batch = self.actor_critic.action_mean
- sigma_batch = self.actor_critic.action_std
- entropy_batch = self.actor_critic.entropy
+ self._register_serializable("learning_rate", "optimizer")
- # KL
- if self.desired_kl is not None and self.schedule == "adaptive":
- with torch.inference_mode():
- kl = torch.sum(
- torch.log(sigma_batch / old_sigma_batch + 1.0e-5)
- + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch))
- / (2.0 * torch.square(sigma_batch))
- - 0.5,
- axis=-1,
- )
- kl_mean = torch.mean(kl)
+ def draw_random_actions(self, obs: torch.Tensor, env_info: Dict[str, Any]) -> torch.Tensor:
+ raise NotImplementedError("PPO does not support drawing random actions.")
- if kl_mean > self.desired_kl * 2.0:
- self.learning_rate = max(1e-5, self.learning_rate / 1.5)
- elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
- self.learning_rate = min(1e-2, self.learning_rate * 1.5)
+ @Benchmarkable.register
+ def draw_actions(
+ self, obs: torch.Tensor, env_info: Dict[str, Any]
+ ) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
+ actor_obs, critic_obs = self._process_observations(obs, env_info)
- for param_group in self.optimizer.param_groups:
- param_group["lr"] = self.learning_rate
+ data = {}
- # Surrogate loss
- ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
- surrogate = -torch.squeeze(advantages_batch) * ratio
- surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(
- ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
- )
- surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()
+ if self.recurrent:
+ data["actor_state_h"] = self.actor.hidden_state[0].detach()
+ data["actor_state_c"] = self.actor.hidden_state[1].detach()
- # Value function loss
- if self.use_clipped_value_loss:
- value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(
- -self.clip_param, self.clip_param
- )
- value_losses = (value_batch - returns_batch).pow(2)
- value_losses_clipped = (value_clipped - returns_batch).pow(2)
- value_loss = torch.max(value_losses, value_losses_clipped).mean()
+ mean, std = self.actor.forward(actor_obs, compute_std=True)
+ action_distribution = torch.distributions.Normal(mean, std)
+ actions = self._process_actions(action_distribution.rsample()).detach()
+ action_prediction_logp = action_distribution.log_prob(actions).sum(-1)
+
+ data["actor_observations"] = actor_obs
+ data["critic_observations"] = critic_obs
+ data["actions_logp"] = action_prediction_logp.detach()
+ data["actions_mean"] = action_distribution.mean.detach()
+ data["actions_std"] = action_distribution.stddev.detach()
+
+ return actions, data
+
+ def eval_mode(self) -> AbstractActorCritic:
+ super().eval_mode()
+
+ self.actor.eval()
+ self.critic.eval()
+
+ return self
+
+ @property
+ def initialized(self) -> bool:
+ return True
+
+ @Benchmarkable.register
+ def process_transition(self, *args) -> Dict[str, torch.Tensor]:
+ transition = super().process_transition(*args)
+
+ if self.recurrent:
+ transition["critic_state_h"] = self.critic.hidden_state[0].detach()
+ transition["critic_state_c"] = self.critic.hidden_state[1].detach()
+
+ transition["values"] = self.critic.forward(transition["critic_observations"]).detach()
+
+ if self.recurrent:
+ transition["critic_next_state_h"] = self.critic.hidden_state[0].detach()
+ transition["critic_next_state_c"] = self.critic.hidden_state[1].detach()
+
+ return transition
+
+ def parameters(self):
+ params = list(self.actor.parameters()) + list(self.critic.parameters())
+
+ return params
+
+ def register_terminations(self, terminations: torch.Tensor) -> None:
+ """Registers terminations with the agent.
+
+ Args:
+ terminations (torch.Tensor): A 1-dimensional int tensor containing the indices of the terminated
+ environments.
+ """
+ if terminations.shape[0] == 0:
+ return
+
+ if self.recurrent:
+ self.actor.reset_hidden_state(terminations)
+ self.critic.reset_hidden_state(terminations)
+
+ def to(self, device: str) -> AbstractActorCritic:
+ super().to(device)
+
+ self.actor.to(device)
+ self.critic.to(device)
+
+ return self
+
+ def train_mode(self) -> AbstractActorCritic:
+ super().train_mode()
+
+ self.actor.train()
+ self.critic.train()
+
+ return self
+
+ @Benchmarkable.register
+ def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
+ super().update(dataset)
+
+ assert self.storage.initialized
+
+ total_loss = torch.zeros(self._batch_count)
+ total_surrogate_loss = torch.zeros(self._batch_count)
+ total_value_loss = torch.zeros(self._batch_count)
+
+ for idx, batch in enumerate(self.storage.batch_generator(self._batch_count, trajectories=self.recurrent)):
+ if self.recurrent:
+ transition_obs = batch["actor_observations"].reshape(*batch["actor_observations"].shape[:2], -1)
+ observations, data = transitions_to_trajectories(transition_obs, batch["dones"])
+ hidden_state_h, _ = transitions_to_trajectories(batch["actor_state_h"], batch["dones"])
+ hidden_state_c, _ = transitions_to_trajectories(batch["actor_state_c"], batch["dones"])
+ # Init. sequence with each trajectory's first hidden state. Subsequent hidden states are produced by the
+ # network, depending on the previous hidden state and the current observation.
+ hidden_state = (hidden_state_h[0].transpose(0, 1), hidden_state_c[0].transpose(0, 1))
+
+ action_mean, action_std = self.actor.forward(observations, hidden_state=hidden_state, compute_std=True)
+
+ action_mean = action_mean.reshape(*observations.shape[:-1], self._action_size)
+ action_std = action_std.reshape(*observations.shape[:-1], self._action_size)
+
+ action_mean = trajectories_to_transitions(action_mean, data)
+ action_std = trajectories_to_transitions(action_std, data)
else:
- value_loss = (returns_batch - value_batch).pow(2).mean()
+ action_mean, action_std = self.actor.forward(batch["actor_observations"], compute_std=True)
- loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()
+ actions_dist = torch.distributions.Normal(action_mean, action_std)
+
+ if self._schedule == self.schedule_adaptive:
+ self._update_learning_rate(batch, actions_dist)
+
+ surrogate_loss = self._compute_actor_loss(batch, actions_dist)
+ value_loss = self._compute_value_loss(batch)
+ actions_entropy = actions_dist.entropy().sum(-1)
+
+ loss = surrogate_loss + self._value_coeff * value_loss - self._entropy_coeff * actions_entropy.mean()
- # Gradient step
self.optimizer.zero_grad()
loss.backward()
- nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
+ nn.utils.clip_grad_norm_(self.parameters(), self._gradient_clip)
self.optimizer.step()
- mean_value_loss += value_loss.item()
- mean_surrogate_loss += surrogate_loss.item()
+ total_loss[idx] = loss.detach()
+ total_surrogate_loss[idx] = surrogate_loss.detach()
+ total_value_loss[idx] = value_loss.detach()
- num_updates = self.num_learning_epochs * self.num_mini_batches
- mean_value_loss /= num_updates
- mean_surrogate_loss /= num_updates
- self.storage.clear()
+ stats = {
+ "total": total_loss.mean().item(),
+ "surrogate": total_surrogate_loss.mean().item(),
+ "value": total_value_loss.mean().item(),
+ }
- return mean_value_loss, mean_surrogate_loss
+ return stats
+
+ @Benchmarkable.register
+ def _compute_actor_loss(
+ self, batch: Dict[str, torch.Tensor], actions_dist: torch.distributions.Normal
+ ) -> torch.Tensor:
+ batch_actions_logp = actions_dist.log_prob(batch["actions"]).sum(-1)
+
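+ # PPO clipped surrogate objective: the importance ratio between the new and the old policy is clipped to
+ # [1 - clip_ratio, 1 + clip_ratio] to discourage excessively large policy updates.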
+ ratio = (batch_actions_logp - batch["actions_logp"]).exp()
+ surrogate = batch["normalized_advantages"] * ratio
+ surrogate_clipped = batch["normalized_advantages"] * ratio.clamp(1.0 - self._clip_ratio, 1.0 + self._clip_ratio)
+ surrogate_loss = -torch.min(surrogate, surrogate_clipped).mean()
+
+ return surrogate_loss
+
+ @Benchmarkable.register
+ def _compute_value_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+ if self.recurrent:
+ observations, data = transitions_to_trajectories(batch["critic_observations"], batch["dones"])
+ hidden_state_h, _ = transitions_to_trajectories(batch["critic_state_h"], batch["dones"])
+ hidden_state_c, _ = transitions_to_trajectories(batch["critic_state_c"], batch["dones"])
+ hidden_states = (hidden_state_h[0].transpose(0, 1), hidden_state_c[0].transpose(0, 1))
+
+ trajectory_evaluations = self.critic.forward(observations, hidden_state=hidden_states)
+ trajectory_evaluations = trajectory_evaluations.reshape(*observations.shape[:-1])
+
+ evaluation = trajectories_to_transitions(trajectory_evaluations, data)
+ else:
+ evaluation = self.critic.forward(batch["critic_observations"])
+
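+ # Clipped value loss: the new value estimate may move at most clip_ratio away from the stored value estimate,
+ # mirroring the clipping of the policy objective.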
+ value_clipped = batch["values"] + (evaluation - batch["values"]).clamp(-self._clip_ratio, self._clip_ratio)
+ returns = batch["advantages"] + batch["values"]
+ value_losses = (evaluation - returns).pow(2)
+ value_losses_clipped = (value_clipped - returns).pow(2)
+
+ value_loss = torch.max(value_losses, value_losses_clipped).mean()
+
+ return value_loss
+
+ def _critic_input(self, observations, actions=None) -> torch.Tensor:
+ return observations
+
+ def _entry_to_hs(
+ self, entry: Dict[str, torch.Tensor], critic: bool = False, next: bool = False
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Helper function to turn a dataset entry into a hidden state tuple.
+
+ Args:
+ entry (Dict[str, torch.Tensor]): The dataset entry.
+ critic (bool): Whether to extract the hidden state for the critic instead of the actor. Defaults to False.
+ next (bool): Whether the hidden state is for the next step or the current. Defaults to False.
+ Returns:
+ A tuple of hidden state tensors.
+ """
+ key = ("critic" if critic else "actor") + "_" + ("next_state" if next else "state")
+ hidden_state = entry[f"{key}_h"], entry[f"{key}_c"]
+
+ return hidden_state
+
+ @Benchmarkable.register
+ def _process_dataset(self, dataset: Dataset) -> Dataset:
+ """Processes a dataset before it is added to the replay memory.
+
+ Computes advantages and returns.
+
+ Args:
+ dataset (Dataset): The dataset to process.
+ Returns:
+ A Dataset object containing the processed data.
+ """
+ rewards = torch.stack([entry["rewards"] for entry in dataset])
+ dones = torch.stack([entry["dones"] for entry in dataset])
+ timeouts = torch.stack([entry["timeouts"] for entry in dataset])
+ values = torch.stack([entry["values"] for entry in dataset])
+
+ # We could alternatively compute the next hidden state from the current state and hidden state. However, storing
+ # the hidden state when evaluating the action in process_transition is computationally more efficient and yields
+ # the same result, since the network is not updated between storing the data and computing advantages.
+
+ critic_kwargs = (
+ {"hidden_state": (dataset[-1]["critic_state_h"], dataset[-1]["critic_state_c"])} if self.recurrent else {}
+ )
+ final_values = self.critic.forward(dataset[-1]["next_critic_observations"], **critic_kwargs)
+ next_values = torch.cat((values[1:], final_values.unsqueeze(0)), dim=0)
+
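+ # Bootstrap rewards at timeouts with the discounted value estimate, since a timeout is not a true termination.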
+ rewards += self.gamma * timeouts * values
+ deltas = (rewards + (1 - dones).float() * self.gamma * next_values - values).reshape(-1, self.env.num_envs)
+
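+ # Generalized Advantage Estimation (GAE): A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}, accumulated
+ # backwards over the rollout.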
+ advantages = torch.zeros((len(dataset) + 1, self.env.num_envs), device=self.device)
+ for step in reversed(range(len(dataset))):
+ advantages[step] = (
+ deltas[step] + (1 - dones[step]).float() * self.gamma * self._gae_lambda * advantages[step + 1]
+ )
+ advantages = advantages[:-1]
+
+ amean, astd = advantages.mean(), torch.nan_to_num(advantages.std())
+ for step in range(len(dataset)):
+ dataset[step]["advantages"] = advantages[step]
+ dataset[step]["normalized_advantages"] = (advantages[step] - amean) / (astd + 1e-8)
+
+ return dataset
+
+ @Benchmarkable.register
+ def _update_learning_rate(self, batch: Dict[str, torch.Tensor], actions_dist: torch.distributions.Normal) -> None:
+ with torch.inference_mode():
+ actions_mean = actions_dist.mean
+ actions_std = actions_dist.stddev
+
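+ # Closed-form KL divergence between the old (rollout) and the current diagonal Gaussian action distributions,
+ # summed over action dimensions.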
+ kl = torch.sum(
+ torch.log(actions_std / batch["actions_std"] + 1.0e-5)
+ + (torch.square(batch["actions_std"]) + torch.square(batch["actions_mean"] - actions_mean))
+ / (2.0 * torch.square(actions_std))
+ - 0.5,
+ axis=-1,
+ )
+ kl_mean = torch.mean(kl)
+
+ if kl_mean > self._target_kl * 2.0:
+ self.learning_rate = max(1e-5, self.learning_rate / 1.5)
+ elif kl_mean < self._target_kl / 2.0 and kl_mean > 0.0:
+ self.learning_rate = min(1e-2, self.learning_rate * 1.5)
+
+ for param_group in self.optimizer.param_groups:
+ param_group["lr"] = self.learning_rate
diff --git a/rsl_rl/algorithms/sac.py b/rsl_rl/algorithms/sac.py
new file mode 100644
index 0000000..2f86f94
--- /dev/null
+++ b/rsl_rl/algorithms/sac.py
@@ -0,0 +1,319 @@
+import numpy as np
+import torch
+from torch import nn, optim
+from typing import Any, Callable, Dict, Tuple, Type, Union
+
+from rsl_rl.algorithms.actor_critic import AbstractActorCritic
+from rsl_rl.env import VecEnv
+from rsl_rl.modules import Network, GaussianChimeraNetwork, GaussianNetwork
+from rsl_rl.storage.replay_storage import ReplayStorage
+from rsl_rl.storage.storage import Dataset
+
+
+class SAC(AbstractActorCritic):
+ """Soft Actor Critic algorithm.
+
+ This is an implementation of the SAC algorithm by Haarnoja et al. for vectorized environments.
+
+ Paper: https://arxiv.org/pdf/1801.01290.pdf
+
+ We additionally implement automatic tuning of the temperature parameter (alpha) and tanh action scaling, as
+ introduced by Haarnoja et al. in https://arxiv.org/pdf/1812.05905.pdf.
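+
+ Example (illustrative sketch; assumes the `GymEnv` wrapper and a working gym installation):
+ >>> env = GymEnv("Pendulum-v1", environment_count=4)
+ >>> agent = SAC(env, action_max=2.0, action_min=-2.0)
+ >>> policy = agent.get_inference_policy()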
+ """
+
+ critic_network: Type[nn.Module] = Network
+
+ def __init__(
+ self,
+ env: VecEnv,
+ action_max: float = 100.0,
+ action_min: float = -100.0,
+ actor_lr: float = 1e-4,
+ actor_noise_std: float = 1.0,
+ alpha: float = 0.2,
+ alpha_lr: float = 1e-3,
+ chimera: bool = True,
+ critic_lr: float = 1e-3,
+ gradient_clip: float = 1.0,
+ log_std_max: float = 4.0,
+ log_std_min: float = -20.0,
+ storage_initial_size: int = 0,
+ storage_size: int = 100000,
+ target_entropy: float = None,
+ **kwargs
+ ):
+ """
+ Args:
+ env (VecEnv): A vectorized environment.
+ actor_lr (float): Learning rate for the actor.
+ alpha (float): Initial entropy regularization coefficient.
+ alpha_lr (float): Learning rate for entropy regularization coefficient.
+ chimera (bool): Whether to use separate heads for computing action mean and std (True) or treat the std as a
+ tunable parameter (False).
+ critic_lr (float): Learning rate for the critic.
+ gradient_clip (float): Gradient clip value.
+ log_std_max (float): Maximum log standard deviation.
+ log_std_min (float): Minimum log standard deviation.
+ storage_initial_size (int): Initial size of the replay storage.
+ storage_size (int): Maximum size of the replay storage.
+ target_entropy (float): Target entropy for the actor policy. Defaults to action space dimensionality.
+ """
+ super().__init__(env, action_max=action_max, action_min=action_min, **kwargs)
+
+ self.storage = ReplayStorage(
+ self.env.num_envs, storage_size, device=self.device, initial_size=storage_initial_size
+ )
+
+ self._register_serializable("storage")
+
+ assert self._action_max < np.inf, 'Parameter "action_max" needs to be set for SAC.'
+ assert self._action_min > -np.inf, 'Parameter "action_min" needs to be set for SAC.'
+
+ self._action_delta = 0.5 * (self._action_max - self._action_min)
+ self._action_offset = 0.5 * (self._action_max + self._action_min)
+
+ self.log_alpha = torch.tensor(np.log(alpha), dtype=torch.float32).requires_grad_()
+ self._gradient_clip = gradient_clip
+ self._target_entropy = target_entropy if target_entropy else -self._action_size
+
+ self._register_serializable("log_alpha", "_gradient_clip")
+
+ network_class = GaussianChimeraNetwork if chimera else GaussianNetwork
+ self.actor = network_class(
+ self._actor_input_size,
+ self._action_size,
+ log_std_max=log_std_max,
+ log_std_min=log_std_min,
+ std_init=actor_noise_std,
+ **self._actor_network_kwargs
+ )
+
+ self.critic_1 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
+ self.critic_2 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
+
+ self.target_critic_1 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
+ self.target_critic_2 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
+ self.target_critic_1.load_state_dict(self.critic_1.state_dict())
+ self.target_critic_2.load_state_dict(self.critic_2.state_dict())
+
+ self._register_serializable("actor", "critic_1", "critic_2", "target_critic_1", "target_critic_2")
+
+ self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
+ self.log_alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)
+ self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(), lr=critic_lr)
+ self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=critic_lr)
+
+ self._register_serializable(
+ "actor_optimizer", "log_alpha_optimizer", "critic_1_optimizer", "critic_2_optimizer"
+ )
+
+ self.critic = self.critic_1
+
+ @property
+ def alpha(self):
+ return self.log_alpha.exp()
+
+ def draw_actions(
+ self, obs: torch.Tensor, env_info: Dict[str, Any]
+ ) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
+ actor_obs, critic_obs = self._process_observations(obs, env_info)
+
+ action = self._sample_action(actor_obs, compute_logp=False)
+ data = {"actor_observations": actor_obs.clone(), "critic_observations": critic_obs.clone()}
+
+ return action, data
+
+ def eval_mode(self) -> AbstractActorCritic:
+ super().eval_mode()
+
+ self.actor.eval()
+ self.critic_1.eval()
+ self.critic_2.eval()
+ self.target_critic_1.eval()
+ self.target_critic_2.eval()
+
+ return self
+
+ def get_inference_policy(self, device=None) -> Callable:
+ self.to(device)
+ self.eval_mode()
+
+ def policy(obs, env_info=None):
+ obs, _ = self._process_observations(obs, env_info)
+ actions = self._scale_actions(self.actor.forward(obs))
+
+ # actions, _ = self.draw_actions(obs, env_info)
+
+ return actions
+
+ return policy
+
+ def register_terminations(self, terminations: torch.Tensor) -> None:
+ pass
+
+ def to(self, device: str) -> AbstractActorCritic:
+ """Transfers agent parameters to device."""
+ super().to(device)
+
+ self.actor.to(device)
+ self.critic_1.to(device)
+ self.critic_2.to(device)
+ self.target_critic_1.to(device)
+ self.target_critic_2.to(device)
+ # Tensor.to() is not in-place; reassign the data to actually move log_alpha to the target device.
+ self.log_alpha.data = self.log_alpha.data.to(device)
+
+ return self
+
+ def train_mode(self) -> AbstractActorCritic:
+ super().train_mode()
+
+ self.actor.train()
+ self.critic_1.train()
+ self.critic_2.train()
+ self.target_critic_1.train()
+ self.target_critic_2.train()
+
+ return self
+
+ def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
+ super().update(dataset)
+
+ if not self.initialized:
+ return {}
+
+ total_actor_loss = torch.zeros(self._batch_count)
+ total_alpha_loss = torch.zeros(self._batch_count)
+ total_critic_1_loss = torch.zeros(self._batch_count)
+ total_critic_2_loss = torch.zeros(self._batch_count)
+
+ for idx, batch in enumerate(self.storage.batch_generator(self._batch_size, self._batch_count)):
+ actor_obs = batch["actor_observations"]
+ critic_obs = batch["critic_observations"]
+ actions = batch["actions"].reshape(-1, self._action_size)
+ rewards = batch["rewards"]
+ actor_next_obs = batch["next_actor_observations"]
+ critic_next_obs = batch["next_critic_observations"]
+ dones = batch["dones"]
+
+ critic_1_loss, critic_2_loss = self._update_critic(
+ critic_obs, actions, rewards, dones, actor_next_obs, critic_next_obs
+ )
+ actor_loss, alpha_loss = self._update_actor_and_alpha(actor_obs, critic_obs)
+
+ # Update Target Networks
+
+ self._update_target(self.critic_1, self.target_critic_1)
+ self._update_target(self.critic_2, self.target_critic_2)
+
+ total_actor_loss[idx] = actor_loss.item()
+ total_alpha_loss[idx] = alpha_loss.item()
+ total_critic_1_loss[idx] = critic_1_loss.item()
+ total_critic_2_loss[idx] = critic_2_loss.item()
+
+ stats = {
+ "actor": total_actor_loss.mean().item(),
+ "alpha": total_alpha_loss.mean().item(),
+ "critic1": total_critic_1_loss.mean().item(),
+ "critic2": total_critic_2_loss.mean().item(),
+ }
+
+ return stats
+
+ def _sample_action(
+ self, observation: torch.Tensor, compute_logp=True
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, float]]:
+ """Samples and action from the policy.
+
+ Args:
+ observation (torch.Tensor): The observation to sample an action for.
+ compute_logp (bool): Whether to compute and return the action log probability. Defaults to True.
+ Returns:
+ Either the action as a torch.Tensor or, if compute_logp is set to True, a tuple containing the actions as a
+ torch.Tensor and the action log probability.
+ """
+ mean, std = self.actor.forward(observation, compute_std=True)
+ dist = torch.distributions.Normal(mean, std)
+
+ actions = dist.rsample()
+ actions_normalized, actions_scaled = self._scale_actions(actions, intermediate=True)
+
+ if not compute_logp:
+ return actions_scaled
+
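+ # Correct the Gaussian log-probability for the tanh squashing (change of variables), as in the SAC paper.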
+ action_logp = dist.log_prob(actions).sum(-1) - torch.log(1.0 - actions_normalized.pow(2) + 1e-6).sum(-1)
+
+ return actions_scaled, action_logp
+
+ def _scale_actions(self, actions: torch.Tensor, intermediate=False) -> torch.Tensor:
+ actions = actions.reshape(-1, self._action_size)
+ action_normalized = torch.tanh(actions)
+ action_scaled = super()._process_actions(action_normalized * self._action_delta + self._action_offset)
+
+ if intermediate:
+ return action_normalized, action_scaled
+
+ return action_scaled
+
+ def _update_actor_and_alpha(
+ self, actor_obs: torch.Tensor, critic_obs: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ actor_prediction, actor_prediction_logp = self._sample_action(actor_obs)
+
+ # Update alpha (also called temperature / entropy coefficient)
+ alpha_loss = -(self.log_alpha * (actor_prediction_logp + self._target_entropy).detach()).mean()
+
+ self.log_alpha_optimizer.zero_grad()
+ alpha_loss.backward()
+ self.log_alpha_optimizer.step()
+
+ # Update actor
+ evaluation_input = self._critic_input(critic_obs, actor_prediction)
+ evaluation_1 = self.critic_1.forward(evaluation_input)
+ evaluation_2 = self.critic_2.forward(evaluation_input)
+ actor_loss = (self.alpha.detach() * actor_prediction_logp - torch.min(evaluation_1, evaluation_2)).mean()
+
+ self.actor_optimizer.zero_grad()
+ actor_loss.backward()
+ nn.utils.clip_grad_norm_(self.actor.parameters(), self._gradient_clip)
+ self.actor_optimizer.step()
+
+ return actor_loss, alpha_loss
+
+ def _update_critic(
+ self,
+ critic_obs: torch.Tensor,
+ actions: torch.Tensor,
+ rewards: torch.Tensor,
+ dones: torch.Tensor,
+ actor_next_obs: torch.Tensor,
+ critic_next_obs: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ with torch.no_grad():
+ target_action, target_action_logp = self._sample_action(actor_next_obs)
+ target_critic_input = self._critic_input(critic_next_obs, target_action)
+ target_critic_prediction_1 = self.target_critic_1.forward(target_critic_input)
+ target_critic_prediction_2 = self.target_critic_2.forward(target_critic_input)
+
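+ # Entropy-regularized TD target: the minimum of both target critics minus the alpha-weighted log-probability
+ # of the sampled next action.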
+ target_next = (
+ torch.min(target_critic_prediction_1, target_critic_prediction_2) - self.alpha * target_action_logp
+ )
+ target = rewards + self._discount_factor * (1 - dones) * target_next
+
+ critic_input = self._critic_input(critic_obs, actions)
+ critic_prediction_1 = self.critic_1.forward(critic_input)
+ critic_1_loss = nn.functional.mse_loss(critic_prediction_1, target)
+
+ self.critic_1_optimizer.zero_grad()
+ critic_1_loss.backward()
+ nn.utils.clip_grad_norm_(self.critic_1.parameters(), self._gradient_clip)
+ self.critic_1_optimizer.step()
+
+ critic_prediction_2 = self.critic_2.forward(critic_input)
+ critic_2_loss = nn.functional.mse_loss(critic_prediction_2, target)
+
+ self.critic_2_optimizer.zero_grad()
+ critic_2_loss.backward()
+ nn.utils.clip_grad_norm_(self.critic_2.parameters(), self._gradient_clip)
+ self.critic_2_optimizer.step()
+
+ return critic_1_loss, critic_2_loss
diff --git a/rsl_rl/algorithms/td3.py b/rsl_rl/algorithms/td3.py
new file mode 100644
index 0000000..7e5850f
--- /dev/null
+++ b/rsl_rl/algorithms/td3.py
@@ -0,0 +1,198 @@
+from __future__ import annotations
+import torch
+from torch import nn, optim
+from typing import Dict, Type, Union
+
+from rsl_rl.algorithms.dpg import AbstractDPG
+from rsl_rl.env import VecEnv
+from rsl_rl.modules.network import Network
+from rsl_rl.storage.storage import Dataset
+
+
+class TD3(AbstractDPG):
+ """Twin-Delayed Deep Deterministic Policy Gradients algorithm.
+
+ This is an implementation of the TD3 algorithm by Fujimoto et al. for vectorized environments.
+
+ Paper: https://arxiv.org/pdf/1802.09477.pdf
+ """
+
+ critic_network: Type[nn.Module] = Network
+
+ def __init__(
+ self,
+ env: VecEnv,
+ actor_lr: float = 1e-4,
+ critic_lr: float = 1e-3,
+ noise_clip: float = 0.5,
+ policy_delay: int = 2,
+ target_noise_scale: float = 0.2,
+ **kwargs,
+ ) -> None:
+ super().__init__(env, **kwargs)
+
+ self._noise_clip = noise_clip
+ self._policy_delay = policy_delay
+ self._target_noise_scale = target_noise_scale
+
+ self._register_serializable("_noise_clip", "_policy_delay", "_target_noise_scale")
+
+ self.actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
+ self.critic_1 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
+ self.critic_2 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
+
+ self.target_actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
+ self.target_critic_1 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
+ self.target_critic_2 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
+ self.target_actor.load_state_dict(self.actor.state_dict())
+ self.target_critic_1.load_state_dict(self.critic_1.state_dict())
+ self.target_critic_2.load_state_dict(self.critic_2.state_dict())
+
+ self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
+ self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(), lr=critic_lr)
+ self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=critic_lr)
+
+ self._update_step = 0
+
+ self._register_serializable(
+ "actor",
+ "critic_1",
+ "critic_2",
+ "target_actor",
+ "target_critic_1",
+ "target_critic_2",
+ "actor_optimizer",
+ "critic_1_optimizer",
+ "critic_2_optimizer",
+ "_update_step",
+ )
+
+ self.critic = self.critic_1
+ self.to(self.device)
+
+ def eval_mode(self) -> TD3:
+ super().eval_mode()
+
+ self.actor.eval()
+ self.critic_1.eval()
+ self.critic_2.eval()
+ self.target_actor.eval()
+ self.target_critic_1.eval()
+ self.target_critic_2.eval()
+
+ return self
+
+ def to(self, device: str) -> TD3:
+ """Transfers agent parameters to device."""
+ super().to(device)
+
+ self.actor.to(device)
+ self.critic_1.to(device)
+ self.critic_2.to(device)
+ self.target_actor.to(device)
+ self.target_critic_1.to(device)
+ self.target_critic_2.to(device)
+
+ return self
+
+ def train_mode(self) -> TD3:
+ super().train_mode()
+
+ self.actor.train()
+ self.critic_1.train()
+ self.critic_2.train()
+ self.target_actor.train()
+ self.target_critic_1.train()
+ self.target_critic_2.train()
+
+ return self
+
+ def _apply_action_noise(self, actions: torch.Tensor, clip=False) -> torch.Tensor:
+ noise = torch.normal(torch.zeros_like(actions), torch.ones_like(actions) * self._action_noise_scale)
+
+ if clip:
+ noise = noise.clamp(-self._noise_clip, self._noise_clip)
+
+ noisy_actions = self._process_actions(actions + noise)
+
+ return noisy_actions
+
+ def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
+ super().update(dataset)
+
+ if not self.initialized:
+ return {}
+
+ total_actor_loss = torch.zeros(self._batch_count)
+ total_critic_1_loss = torch.zeros(self._batch_count)
+ total_critic_2_loss = torch.zeros(self._batch_count)
+
+ for idx, batch in enumerate(self.storage.batch_generator(self._batch_size, self._batch_count)):
+ actor_obs = batch["actor_observations"]
+ critic_obs = batch["critic_observations"]
+ actions = batch["actions"].reshape(self._batch_size, -1)
+ rewards = batch["rewards"]
+ actor_next_obs = batch["next_actor_observations"]
+ critic_next_obs = batch["next_critic_observations"]
+ dones = batch["dones"]
+
+ critic_1_loss, critic_2_loss = self._update_critic(
+ critic_obs, actions, rewards, dones, actor_next_obs, critic_next_obs
+ )
+
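+ # Delayed policy updates: the actor and the target networks are updated only every policy_delay critic updates.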
+ if self._update_step % self._policy_delay == 0:
+ evaluation = self.critic_1.forward(
+ self._critic_input(critic_obs, self._process_actions(self.actor.forward(actor_obs)))
+ )
+ actor_loss = -evaluation.mean()
+
+ self.actor_optimizer.zero_grad()
+ actor_loss.backward()
+ self.actor_optimizer.step()
+
+ self._update_target(self.actor, self.target_actor)
+ self._update_target(self.critic_1, self.target_critic_1)
+ self._update_target(self.critic_2, self.target_critic_2)
+
+ total_actor_loss[idx] = actor_loss.item()
+
+ self._update_step = self._update_step + 1
+
+ total_critic_1_loss[idx] = critic_1_loss.item()
+ total_critic_2_loss[idx] = critic_2_loss.item()
+
+ stats = {
+ "actor": total_actor_loss.mean().item(),
+ "critic1": total_critic_1_loss.mean().item(),
+ "critic2": total_critic_2_loss.mean().item(),
+ }
+
+ return stats
+
+ def _update_critic(self, critic_obs, actions, rewards, dones, actor_next_obs, critic_next_obs):
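+ # Target policy smoothing: clipped noise is added to the target action before evaluating the target critics.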
+ target_actor_prediction = self._apply_action_noise(self.target_actor.forward(actor_next_obs), clip=True)
+ target_critic_1_prediction = self.target_critic_1.forward(
+ self._critic_input(critic_next_obs, target_actor_prediction)
+ )
+ target_critic_2_prediction = self.target_critic_2.forward(
+ self._critic_input(critic_next_obs, target_actor_prediction)
+ )
+ target_critic_prediction = torch.min(target_critic_1_prediction, target_critic_2_prediction)
+
+ target = (rewards + self._discount_factor * (1 - dones) * target_critic_prediction).detach()
+
+ prediction_1 = self.critic_1.forward(self._critic_input(critic_obs, actions))
+ critic_1_loss = (prediction_1 - target).pow(2).mean()
+
+ self.critic_1_optimizer.zero_grad()
+ critic_1_loss.backward()
+ self.critic_1_optimizer.step()
+
+ prediction_2 = self.critic_2.forward(self._critic_input(critic_obs, actions))
+ critic_2_loss = (prediction_2 - target).pow(2).mean()
+
+ self.critic_2_optimizer.zero_grad()
+ critic_2_loss.backward()
+ self.critic_2_optimizer.step()
+
+ return critic_1_loss, critic_2_loss
diff --git a/rsl_rl/distributions/__init__.py b/rsl_rl/distributions/__init__.py
new file mode 100644
index 0000000..ac65504
--- /dev/null
+++ b/rsl_rl/distributions/__init__.py
@@ -0,0 +1,2 @@
+from .distribution import Distribution
+from .quantile_distribution import QuantileDistribution
diff --git a/rsl_rl/distributions/distribution.py b/rsl_rl/distributions/distribution.py
new file mode 100644
index 0000000..01f9184
--- /dev/null
+++ b/rsl_rl/distributions/distribution.py
@@ -0,0 +1,18 @@
+from abc import ABC, abstractmethod
+import torch
+
+
+class Distribution(ABC):
+ def __init__(self, params: torch.Tensor) -> None:
+ self._params = params
+
+ @abstractmethod
+ def sample(self, sample_count: int = 1) -> torch.Tensor:
+ """Sample from the distribution.
+
+ Args:
+ sample_count: The number of samples to draw.
+ Returns:
+ A tensor of shape (sample_count, *parameter_shape).
+ """
+ pass
diff --git a/rsl_rl/distributions/quantile_distribution.py b/rsl_rl/distributions/quantile_distribution.py
new file mode 100644
index 0000000..31aee23
--- /dev/null
+++ b/rsl_rl/distributions/quantile_distribution.py
@@ -0,0 +1,13 @@
+import torch
+from typing import Tuple
+
+from .distribution import Distribution
+
+
+class QuantileDistribution(Distribution):
+ def sample(self, sample_count: int = 1) -> Tuple[torch.Tensor, torch.Tensor]:
+ idx = torch.randint(
+ self._params.shape[-1], (*self._params.shape[:-1], sample_count), device=self._params.device
+ )
+ samples = torch.take_along_dim(self._params, idx, -1)
+
+ return samples, idx
diff --git a/rsl_rl/env/__init__.py b/rsl_rl/env/__init__.py
index 54c6491..5c26f6e 100644
--- a/rsl_rl/env/__init__.py
+++ b/rsl_rl/env/__init__.py
@@ -1,6 +1,5 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
-
"""Submodule defining the environment definitions."""
from .vec_env import VecEnv
diff --git a/rsl_rl/env/gym_env.py b/rsl_rl/env/gym_env.py
new file mode 100644
index 0000000..289c2c7
--- /dev/null
+++ b/rsl_rl/env/gym_env.py
@@ -0,0 +1,120 @@
+from datetime import datetime
+import gym
+import torch
+from typing import Any, Dict, Tuple, Union
+
+from rsl_rl.env.vec_env import VecEnv
+
+
+class GymEnv(VecEnv):
+ """A vectorized environment wrapper for OpenAI Gym environments.
+
+ This class wraps a single OpenAI Gym environment into a vectorized environment. It is assumed that the environment
+ is a single agent environment. The environment is wrapped in a `gym.vector.SyncVectorEnv` environment, which
+ allows for parallel execution of multiple environments.
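+
+ Example (illustrative sketch; assumes a working gym installation that provides the Pendulum environment):
+ >>> env = GymEnv("Pendulum-v1", environment_count=4, device="cpu")
+ >>> obs, extras = env.get_observations()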
+ """
+
+ def __init__(self, name, draw=False, draw_cb=None, draw_directory="videos/", gym_kwargs={}, **kwargs):
+ """
+ Args:
+ name: The name of the OpenAI Gym environment.
+ draw: Whether to record videos of the environment.
+ draw_cb: A callback function that is called after each recorded episode. It is passed the episode number
+ and the path to the corresponding video file.
+ draw_directory: The directory in which to store the videos.
+ gym_kwargs: Keyword arguments that are passed to the OpenAI Gym environment.
+ **kwargs: Keyword arguments that are passed to the `VecEnv` constructor.
+ """
+ self._gym_kwargs = gym_kwargs
+
+ env = gym.make(name, **self._gym_kwargs)
+
+ assert isinstance(env.observation_space, gym.spaces.Box)
+ assert len(env.observation_space.shape) == 1
+ assert isinstance(env.action_space, gym.spaces.Box)
+ assert len(env.action_space.shape) == 1
+
+ super().__init__(env.observation_space.shape[0], env.observation_space.shape[0], **kwargs)
+
+ self.name = name
+ self.draw_directory = draw_directory
+
+ self.num_actions = env.action_space.shape[0]
+ self._gym_venv = gym.vector.SyncVectorEnv(
+ [lambda: gym.make(self.name, **self._gym_kwargs) for _ in range(self.num_envs)]
+ )
+
+ self._draw = False
+ self._draw_cb = draw_cb if draw_cb is not None else lambda *args: True
+ self._draw_uuid = None
+ self.draw = draw
+
+ self.reset()
+
+ def close(self) -> None:
+ self._gym_venv.close()
+
+ def get_observations(self) -> Tuple[torch.Tensor, Dict[str, Any]]:
+ return self.obs_buf, self.extras
+
+ def get_privileged_observations(self) -> Union[torch.Tensor, None]:
+ return self.obs_buf, self.extras
+
+ @property
+ def draw(self) -> bool:
+ return self._draw
+
+ @draw.setter
+ def draw(self, value: bool) -> None:
+ if value != self._draw:
+ if value:
+ self._draw_uuid = datetime.now().strftime("%Y%m%d%H%M%S")
+ env = gym.make(self.name, render_mode="rgb_array", **self._gym_kwargs)
+ env = gym.wrappers.RecordVideo(
+ env,
+ f"{self.draw_directory}/{self._draw_uuid}/",
+ episode_trigger=lambda ep: (
+ self._draw_cb(ep - 1, f"{self.draw_directory}/{self._draw_uuid}/rl-video-episode-{ep-1}.mp4")
+ or True
+ )
+ if ep > 0
+ else False,
+ )
+ else:
+ env = gym.make(self.name, render_mode=None, **self._gym_kwargs)
+
+ self._gym_venv.envs[0] = env
+ self._draw = value
+
+ self.reset()
+
+ def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
+ self.obs_buf = torch.from_numpy(self._gym_venv.reset()[0]).float().to(self.device)
+ self.rew_buf = torch.zeros((self.num_envs,), device=self.device).float()
+ self.reset_buf = torch.zeros((self.num_envs,), device=self.device).float()
+ self.extras = {"observations": {}, "time_outs": torch.zeros((self.num_envs,), device=self.device).float()}
+
+ return self.obs_buf, self.extras
+
+ def step(
+ self, actions: torch.Tensor
+ ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, Dict[str, Any]]:
+ obs, rew, reset, term, _ = self._gym_venv.step(actions.cpu().numpy())
+
+ self.obs_buf = torch.from_numpy(obs).float().to(self.device)
+ self.rew_buf = torch.from_numpy(rew).float().to(self.device)
+ self.reset_buf = torch.from_numpy(reset).float().to(self.device)
+ self.extras = {
+ "observations": {},
+ "time_outs": torch.from_numpy(term).float().to(self.device).float().to(self.device),
+ }
+
+ return self.obs_buf, self.rew_buf, self.reset_buf, self.extras
+
+ def to(self, device: str) -> None:
+ self.device = device
+
+ self.obs_buf = self.obs_buf.to(device)
+ self.rew_buf = self.rew_buf.to(device)
+ self.reset_buf = self.reset_buf.to(device)
diff --git a/rsl_rl/env/pole_balancing.py b/rsl_rl/env/pole_balancing.py
new file mode 100644
index 0000000..28c6dc0
--- /dev/null
+++ b/rsl_rl/env/pole_balancing.py
@@ -0,0 +1,137 @@
+import math
+import numpy as np
+import time
+import matplotlib.pyplot as plt
+import torch
+from typing import Any, Dict, Tuple, Union
+
+from rsl_rl.env.vec_env import VecEnv
+
+
+class PoleBalancing(VecEnv):
+ """Custom pole balancing environment.
+
+ This class implements a custom pole balancing environment. It demonstrates how to implement a custom `VecEnv`
+ environment.
+ """
+
+ def __init__(self, **kwargs):
+ """
+ Args:
+ **kwargs: Keyword arguments that are passed to the `VecEnv` constructor.
+ """
+ super().__init__(2, 2, **kwargs)
+
+ self.num_actions = 1
+
+ self.gravity = 9.8
+ self.length = 2.0
+ self.dt = 0.1
+
+ # Angle at which to fail the episode (15 deg)
+ self.theta_threshold_radians = 15 * 2 * math.pi / 360
+
+ # Max. angle at which to initialize the episode (2 deg)
+ self.initial_max_position = 2 * 2 * math.pi / 360
+ # Max. angular velocity at which to initialize the episode (0.3 deg/s)
+ self.initial_max_velocity = 0.3 * 2 * math.pi / 360
+
+ self.initial_position_factor = self.initial_max_position / 0.5
+ self.initial_velocity_factor = self.initial_max_velocity / 0.5
+
+ self.draw = False
+ self.pushes = False
+
+ self.reset()
+
+ def get_observations(self) -> Tuple[torch.Tensor, Dict[str, Any]]:
+ return self.obs_buf, self.extras
+
+ def get_privileged_observations(self) -> Union[torch.Tensor, None]:
+ return self.obs_buf, self.extras
+
+ def step(
+ self, actions: torch.Tensor
+ ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, Dict[str, Any]]:
+ assert actions.size() == (self.num_envs, 1)
+
+ self.to(self.device)
+ actions = actions.to(self.device)
+
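+ # Simple pendulum dynamics: gravitational torque plus the (clamped) action and Gaussian noise, integrated with
+ # explicit Euler steps.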
+ noise = torch.normal(torch.zeros_like(actions), torch.ones_like(actions) * 0.005).squeeze()
+ if self.pushes and np.random.rand() < 0.05:
+ noise *= 100.0
+ actions = actions.clamp(min=-0.2, max=0.2).squeeze()
+ gravity = torch.sin(self.state[:, 0]) * self.gravity / self.length
+ angular_acceleration = gravity + actions + noise
+
+ self.state[:, 1] = self.state[:, 1] + self.dt * angular_acceleration
+ self.state[:, 0] = self.state[:, 0] + self.dt * self.state[:, 1]
+
+ self.reset_buf = torch.zeros(self.num_envs)
+ self.reset_buf[(torch.abs(self.state[:, 0]) > self.theta_threshold_radians).nonzero()] = 1.0
+ reset_idx = self.reset_buf.nonzero()
+
+ self.state[reset_idx, 0] = (
+ torch.rand(reset_idx.size()[0], 1, device=self.device) - 0.5
+ ) * self.initial_position_factor
+ self.state[reset_idx, 1] = (
+ torch.rand(reset_idx.size()[0], 1, device=self.device) - 0.5
+ ) * self.initial_velocity_factor
+
+ self.rew_buf = torch.ones(self.num_envs, device=self.device)
+ self.rew_buf[reset_idx] = -1.0
+ self.rew_buf = self.rew_buf - actions.abs()
+ self.rew_buf = self.rew_buf - self.state[:, 0].abs()
+
+ self._update_obs()
+
+ if self.draw:
+ self._debug_draw(actions)
+
+ self.to(self.device)
+
+ return self.obs_buf, self.rew_buf, self.reset_buf, self.extras
+
+ def _update_obs(self):
+ self.obs_buf = self.state
+ self.extras = {"observations": {}, "time_outs": torch.zeros_like(self.rew_buf)}
+
+ def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
+ self.state = torch.zeros(self.num_envs, 2, device=self.device)
+ self.state[:, 0] = (torch.rand(self.num_envs) - 0.5) * self.initial_position_factor
+ self.state[:, 1] = (torch.rand(self.num_envs) - 0.5) * self.initial_velocity_factor
+ self.rew_buf = torch.zeros(self.num_envs)
+ self.reset_buf = torch.zeros(self.num_envs)
+ self.extras = {}
+
+ self._update_obs()
+
+ return self.obs_buf, self.extras
+
+ def to(self, device):
+ self.device = device
+
+ self.obs_buf = self.obs_buf.to(device)
+ self.rew_buf = self.rew_buf.to(device)
+ self.reset_buf = self.reset_buf.to(device)
+ self.state = self.state.to(device)
+
+ def _debug_draw(self, actions):
+ if not hasattr(self, "_visuals"):
+ self._visuals = {"x": [0], "pos": [], "act": [], "done": []}
+ plt.gca().figure.show()
+ else:
+ self._visuals["x"].append(self._visuals["x"][-1] + 1)
+
+ self._visuals["pos"].append(self.obs_buf[0, 0].cpu().item())
+ self._visuals["done"].append(self.reset_buf[0].cpu().item())
+ self._visuals["act"].append(actions.squeeze()[0].cpu().item())
+
+ plt.cla()
+ plt.plot(self._visuals["x"][-100:], self._visuals["act"][-100:], color="green")
+ plt.plot(self._visuals["x"][-100:], self._visuals["pos"][-100:], color="blue")
+ plt.plot(self._visuals["x"][-100:], self._visuals["done"][-100:], color="red")
+ plt.draw()
+ plt.gca().figure.canvas.flush_events()
+ time.sleep(0.0001)
diff --git a/rsl_rl/env/pomdp.py b/rsl_rl/env/pomdp.py
new file mode 100644
index 0000000..750a464
--- /dev/null
+++ b/rsl_rl/env/pomdp.py
@@ -0,0 +1,97 @@
+import torch
+from typing import Any, Dict, Tuple, Union
+
+from rsl_rl.env.gym_env import GymEnv
+
+
+class GymPOMDP(GymEnv):
+ """A vectorized POMDP environment wrapper for OpenAI Gym environments.
+
+ This environment allows for the modification of the observation space of an OpenAI Gym environment. The modified
+ observation space is a subset of the original observation space.
+ """
+
+ _reduced_observation_count: int = None
+
+ def __init__(self, name: str, **kwargs):
+ assert self._reduced_observation_count is not None
+
+ super().__init__(name=name, **kwargs)
+
+ self.num_obs = self._reduced_observation_count
+ self.num_privileged_obs = self._reduced_observation_count
+
+ def _process_observations(self, obs: torch.Tensor) -> torch.Tensor:
+ """Reduces the original observation space to the modified observation space.
+
+ Args:
+ obs (torch.Tensor): Original observations.
+ Returns:
+ The modified observations as a torch.Tensor of shape (obs.shape[0], self.num_obs).
+ """
+ raise NotImplementedError
+
+ def reset(self, *args, **kwargs):
+ obs, _ = super().reset(*args, **kwargs)
+
+ self.obs_buf = self._process_observations(obs)
+
+ return self.obs_buf, self.extras
+
+ def step(
+ self, actions: torch.Tensor
+ ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, Dict[str, Any]]:
+ obs, _, _, _ = super().step(actions)
+
+ self.obs_buf = self._process_observations(obs)
+
+ return self.obs_buf, self.rew_buf, self.reset_buf, self.extras
+
+
+class BipedalWalkerP(GymPOMDP):
+ """
+ Original observation space (24 values):
+ [
+ hull angle,
+ hull angular velocity,
+ horizontal velocity,
+ vertical velocity,
+ joint 1 angle,
+ joint 1 speed,
+ joint 2 angle,
+ joint 2 speed,
+ leg 1 ground contact,
+ joint 3 angle,
+ joint 3 speed,
+ joint 4 angle,
+ joint 4 speed,
+ leg 2 ground contact,
+ lidar (10 values),
+ ]
+ Modified observation space (15 values):
+ [
+ hull angle,
+ joint 1 angle,
+ joint 2 angle,
+ joint 3 angle,
+ joint 4 angle,
+ lidar (10 values),
+ ]
+ """
+
+ _reduced_observation_count: int = 15
+
+ def __init__(self, **kwargs):
+ super().__init__(name="BipedalWalker-v3", **kwargs)
+
+ def _process_observations(self, obs: torch.Tensor) -> torch.Tensor:
+ """Reduces the original observation space to the modified observation space."""
+ reduced_obs = torch.zeros(obs.shape[0], self._reduced_observation_count, device=self.device)
+ reduced_obs[:, 0] = obs[:, 0]
+ reduced_obs[:, 1] = obs[:, 4]
+ reduced_obs[:, 2] = obs[:, 6]
+ reduced_obs[:, 3] = obs[:, 9]
+ reduced_obs[:, 4] = obs[:, 11]
+ reduced_obs[:, 5:] = obs[:, 14:]
+
+ return reduced_obs
diff --git a/rsl_rl/env/rslgym_env.py b/rsl_rl/env/rslgym_env.py
new file mode 100644
index 0000000..a6c0d91
--- /dev/null
+++ b/rsl_rl/env/rslgym_env.py
@@ -0,0 +1,44 @@
+import torch
+from typing import Any, Dict, Tuple, Union
+
+from rsl_rl.env.vec_env import VecEnv
+
+
+class RSLGymEnv(VecEnv):
+ """A wrapper for using rsl_rl with the rslgym library."""
+
+ def __init__(self, rslgym_env, **kwargs):
+ self._rslgym_env = rslgym_env
+
+ observation_count = self._rslgym_env.observation_space.shape[0]
+ super().__init__(observation_count, observation_count, **kwargs)
+
+ self.num_actions = self._rslgym_env.action_space.shape[0]
+
+ self.obs_buf = None
+ self.rew_buf = None
+ self.reset_buf = None
+ self.extras = None
+
+ self.reset()
+
+ def get_observations(self) -> Tuple[torch.Tensor, Dict[str, Any]]:
+ return self.obs_buf, self.extras
+
+ def get_privileged_observations(self) -> Union[torch.Tensor, None]:
+ return self.obs_buf
+
+ def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
+ obs = self._rslgym_env.reset()
+
+ self.obs_buf = torch.from_numpy(obs)
+ self.extras = {"observations": {}, "time_outs": torch.zeros((self.num_envs,), device=self.device).float()}
+
+ return self.obs_buf, self.extras
+
+ def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]:
+ obs, reward, dones, info = self._rslgym_env.step(actions, True)
+
+ self.obs_buf = torch.from_numpy(obs)
+ self.rew_buf = torch.from_numpy(reward)
+ self.reset_buf = torch.from_numpy(dones).float()
+
+ return self.obs_buf, self.rew_buf, self.reset_buf, self.extras
diff --git a/rsl_rl/env/vec_env.py b/rsl_rl/env/vec_env.py
index a7af015..70c9af8 100644
--- a/rsl_rl/env/vec_env.py
+++ b/rsl_rl/env/vec_env.py
@@ -1,85 +1,74 @@
-# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
-# SPDX-License-Identifier: BSD-3-Clause
-
-from __future__ import annotations
-
-import torch
from abc import ABC, abstractmethod
+import torch
+from typing import Any, Dict, Tuple, Union
+# minimal interface of the environment
class VecEnv(ABC):
- """Abstract class for vectorized environment.
-
- The vectorized environment is a collection of environments that are synchronized. This means that
- the same action is applied to all environments and the same observation is returned from all environments.
-
- All extra observations must be provided as a dictionary to "extras" in the step() method. Based on the
- configuration, the extra observations are used for different purposes. The following keys are reserved
- in the "observations" dictionary (if they are present):
-
- - "critic": The observation is used as input to the critic network. Useful for asymmetric observation spaces.
- """
+ """Abstract class for vectorized environment."""
num_envs: int
- """Number of environments."""
num_obs: int
- """Number of observations."""
num_privileged_obs: int
- """Number of privileged observations."""
num_actions: int
- """Number of actions."""
max_episode_length: int
- """Maximum episode length."""
privileged_obs_buf: torch.Tensor
- """Buffer for privileged observations."""
obs_buf: torch.Tensor
- """Buffer for observations."""
rew_buf: torch.Tensor
- """Buffer for rewards."""
reset_buf: torch.Tensor
- """Buffer for resets."""
episode_length_buf: torch.Tensor # current episode duration
- """Buffer for current episode lengths."""
extras: dict
- """Extra information (metrics).
-
- Extra information is stored in a dictionary. This includes metrics such as the episode reward, episode length,
- etc. Additional information can be stored in the dictionary such as observations for the critic network, etc.
- """
device: torch.device
- """Device to use."""
- """
- Operations.
- """
-
- @abstractmethod
- def get_observations(self) -> tuple[torch.Tensor, dict]:
- """Return the current observations.
-
- Returns:
- Tuple[torch.Tensor, dict]: Tuple containing the observations and extras.
+ def __init__(
+ self, observation_count, privileged_observation_count, device="cpu", environment_count=1, max_episode_length=-1
+ ):
"""
- raise NotImplementedError
-
- @abstractmethod
- def reset(self) -> tuple[torch.Tensor, dict]:
- """Reset all environment instances.
-
- Returns:
- Tuple[torch.Tensor, dict]: Tuple containing the observations and extras.
+ Args:
+ observation_count (int): Number of observations per environment.
+ privileged_observation_count (int): Number of privileged observations per environment.
+ device (str): Device to use for the tensors.
+ environment_count (int): Number of environments to run in parallel.
+ max_episode_length (int): Maximum length of an episode. If -1, the episode length is not limited.
"""
- raise NotImplementedError
+ self.num_obs = observation_count
+ self.num_privileged_obs = privileged_observation_count
+
+ self.num_envs = environment_count
+ self.max_episode_length = max_episode_length
+ self.device = device
@abstractmethod
- def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
+ def get_observations(self) -> Tuple[torch.Tensor, Dict[str, Any]]:
+ """Return observations and extra information."""
+ pass
+
+ @abstractmethod
+ def get_privileged_observations(self) -> Union[torch.Tensor, None]:
+ """Return privileged observations."""
+ pass
+
+ @abstractmethod
+ def step(
+ self, actions: torch.Tensor
+ ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, Dict[str, Any]]:
"""Apply input action on the environment.
Args:
actions (torch.Tensor): Input actions to apply. Shape: (num_envs, num_actions)
Returns:
- Tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
- A tuple containing the observations, rewards, dones and extra information (metrics).
+ Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, dict]:
+ A tuple containing the observations, privileged observations, rewards, dones and
+ extra information (metrics).
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
+ """Reset all environment instances.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor | None]: Tuple containing the observations and privileged observations.
"""
raise NotImplementedError
diff --git a/rsl_rl/modules/__init__.py b/rsl_rl/modules/__init__.py
index b96383c..b1b4d43 100644
--- a/rsl_rl/modules/__init__.py
+++ b/rsl_rl/modules/__init__.py
@@ -1,10 +1,19 @@
-# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
-# SPDX-License-Identifier: BSD-3-Clause
-
-"""Definitions for neural-network components for RL-agents."""
-
-from .actor_critic import ActorCritic
-from .actor_critic_recurrent import ActorCriticRecurrent
+from .categorical_network import CategoricalNetwork
+from .gaussian_chimera_network import GaussianChimeraNetwork
+from .gaussian_network import GaussianNetwork
+from .implicit_quantile_network import ImplicitQuantileNetwork
+from .network import Network
from .normalizer import EmpiricalNormalization
+from .quantile_network import QuantileNetwork
+from .transformer import Transformer
-__all__ = ["ActorCritic", "ActorCriticRecurrent"]
+__all__ = [
+ "CategoricalNetwork",
+ "EmpiricalNormalization",
+ "GaussianChimeraNetwork",
+ "GaussianNetwork",
+ "ImplicitQuantileNetwork",
+ "Network",
+ "QuantileNetwork",
+ "Transformer",
+]
diff --git a/rsl_rl/modules/actor_critic.py b/rsl_rl/modules/actor_critic.py
deleted file mode 100644
index cdb77e1..0000000
--- a/rsl_rl/modules/actor_critic.py
+++ /dev/null
@@ -1,136 +0,0 @@
-# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
-# SPDX-License-Identifier: BSD-3-Clause
-
-
-from __future__ import annotations
-
-import torch
-import torch.nn as nn
-from torch.distributions import Normal
-
-
-class ActorCritic(nn.Module):
- is_recurrent = False
-
- def __init__(
- self,
- num_actor_obs,
- num_critic_obs,
- num_actions,
- actor_hidden_dims=[256, 256, 256],
- critic_hidden_dims=[256, 256, 256],
- activation="elu",
- init_noise_std=1.0,
- **kwargs,
- ):
- if kwargs:
- print(
- "ActorCritic.__init__ got unexpected arguments, which will be ignored: "
- + str([key for key in kwargs.keys()])
- )
- super().__init__()
- activation = get_activation(activation)
-
- mlp_input_dim_a = num_actor_obs
- mlp_input_dim_c = num_critic_obs
- # Policy
- actor_layers = []
- actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
- actor_layers.append(activation)
- for layer_index in range(len(actor_hidden_dims)):
- if layer_index == len(actor_hidden_dims) - 1:
- actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], num_actions))
- else:
- actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], actor_hidden_dims[layer_index + 1]))
- actor_layers.append(activation)
- self.actor = nn.Sequential(*actor_layers)
-
- # Value function
- critic_layers = []
- critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
- critic_layers.append(activation)
- for layer_index in range(len(critic_hidden_dims)):
- if layer_index == len(critic_hidden_dims) - 1:
- critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], 1))
- else:
- critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], critic_hidden_dims[layer_index + 1]))
- critic_layers.append(activation)
- self.critic = nn.Sequential(*critic_layers)
-
- print(f"Actor MLP: {self.actor}")
- print(f"Critic MLP: {self.critic}")
-
- # Action noise
- self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
- self.distribution = None
- # disable args validation for speedup
- Normal.set_default_validate_args = False
-
- # seems that we get better performance without init
- # self.init_memory_weights(self.memory_a, 0.001, 0.)
- # self.init_memory_weights(self.memory_c, 0.001, 0.)
-
- @staticmethod
- # not used at the moment
- def init_weights(sequential, scales):
- [
- torch.nn.init.orthogonal_(module.weight, gain=scales[idx])
- for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))
- ]
-
- def reset(self, dones=None):
- pass
-
- def forward(self):
- raise NotImplementedError
-
- @property
- def action_mean(self):
- return self.distribution.mean
-
- @property
- def action_std(self):
- return self.distribution.stddev
-
- @property
- def entropy(self):
- return self.distribution.entropy().sum(dim=-1)
-
- def update_distribution(self, observations):
- mean = self.actor(observations)
- self.distribution = Normal(mean, mean * 0.0 + self.std)
-
- def act(self, observations, **kwargs):
- self.update_distribution(observations)
- return self.distribution.sample()
-
- def get_actions_log_prob(self, actions):
- return self.distribution.log_prob(actions).sum(dim=-1)
-
- def act_inference(self, observations):
- actions_mean = self.actor(observations)
- return actions_mean
-
- def evaluate(self, critic_observations, **kwargs):
- value = self.critic(critic_observations)
- return value
-
-
-def get_activation(act_name):
- if act_name == "elu":
- return nn.ELU()
- elif act_name == "selu":
- return nn.SELU()
- elif act_name == "relu":
- return nn.ReLU()
- elif act_name == "crelu":
- return nn.CReLU()
- elif act_name == "lrelu":
- return nn.LeakyReLU()
- elif act_name == "tanh":
- return nn.Tanh()
- elif act_name == "sigmoid":
- return nn.Sigmoid()
- else:
- print("invalid activation function!")
- return None
diff --git a/rsl_rl/modules/actor_critic_recurrent.py b/rsl_rl/modules/actor_critic_recurrent.py
deleted file mode 100644
index 6321ec5..0000000
--- a/rsl_rl/modules/actor_critic_recurrent.py
+++ /dev/null
@@ -1,97 +0,0 @@
-# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
-# SPDX-License-Identifier: BSD-3-Clause
-
-from __future__ import annotations
-
-import torch
-import torch.nn as nn
-
-from rsl_rl.modules.actor_critic import ActorCritic, get_activation
-from rsl_rl.utils import unpad_trajectories
-
-
-class ActorCriticRecurrent(ActorCritic):
- is_recurrent = True
-
- def __init__(
- self,
- num_actor_obs,
- num_critic_obs,
- num_actions,
- actor_hidden_dims=[256, 256, 256],
- critic_hidden_dims=[256, 256, 256],
- activation="elu",
- rnn_type="lstm",
- rnn_hidden_size=256,
- rnn_num_layers=1,
- init_noise_std=1.0,
- **kwargs,
- ):
- if kwargs:
- print(
- "ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()),
- )
-
- super().__init__(
- num_actor_obs=rnn_hidden_size,
- num_critic_obs=rnn_hidden_size,
- num_actions=num_actions,
- actor_hidden_dims=actor_hidden_dims,
- critic_hidden_dims=critic_hidden_dims,
- activation=activation,
- init_noise_std=init_noise_std,
- )
-
- activation = get_activation(activation)
-
- self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
- self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
-
- print(f"Actor RNN: {self.memory_a}")
- print(f"Critic RNN: {self.memory_c}")
-
- def reset(self, dones=None):
- self.memory_a.reset(dones)
- self.memory_c.reset(dones)
-
- def act(self, observations, masks=None, hidden_states=None):
- input_a = self.memory_a(observations, masks, hidden_states)
- return super().act(input_a.squeeze(0))
-
- def act_inference(self, observations):
- input_a = self.memory_a(observations)
- return super().act_inference(input_a.squeeze(0))
-
- def evaluate(self, critic_observations, masks=None, hidden_states=None):
- input_c = self.memory_c(critic_observations, masks, hidden_states)
- return super().evaluate(input_c.squeeze(0))
-
- def get_hidden_states(self):
- return self.memory_a.hidden_states, self.memory_c.hidden_states
-
-
-class Memory(torch.nn.Module):
- def __init__(self, input_size, type="lstm", num_layers=1, hidden_size=256):
- super().__init__()
- # RNN
- rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM
- self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
- self.hidden_states = None
-
- def forward(self, input, masks=None, hidden_states=None):
- batch_mode = masks is not None
- if batch_mode:
- # batch mode (policy update): need saved hidden states
- if hidden_states is None:
- raise ValueError("Hidden states not passed to memory module during policy update")
- out, _ = self.rnn(input, hidden_states)
- out = unpad_trajectories(out, masks)
- else:
- # inference mode (collection): use hidden states of last step
- out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
- return out
-
- def reset(self, dones=None):
- # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
- for hidden_state in self.hidden_states:
- hidden_state[..., dones, :] = 0.0
diff --git a/rsl_rl/modules/categorical_network.py b/rsl_rl/modules/categorical_network.py
new file mode 100644
index 0000000..bc4fb40
--- /dev/null
+++ b/rsl_rl/modules/categorical_network.py
@@ -0,0 +1,105 @@
+import torch
+import torch.nn as nn
+from rsl_rl.modules.network import Network
+from rsl_rl.utils.utils import squeeze_preserve_batch
+
+eps = torch.finfo(torch.float32).eps
+
+
+class CategoricalNetwork(Network):
+ def __init__(
+ self,
+ input_size,
+ output_size,
+ activations=["relu", "relu", "relu"],
+ atom_count=51,
+ hidden_dims=[256, 256, 256],
+ init_gain=1.0,
+ value_max=10.0,
+ value_min=-10.0,
+ **kwargs,
+ ):
+ assert len(hidden_dims) == len(activations)
+ assert value_max > value_min
+ assert atom_count > 1
+
+ super().__init__(
+ input_size,
+ activations=activations,
+ hidden_dims=hidden_dims[:-1],
+ init_fade=False,
+ init_gain=init_gain,
+ output_size=hidden_dims[-1],
+ **kwargs,
+ )
+
+ self._value_max = value_max
+ self._value_min = value_min
+ self._atom_count = atom_count
+
+ self.value_delta = (self._value_max - self._value_min) / (self._atom_count - 1)
+ action_values = torch.arange(self._value_min, self._value_max + eps, self.value_delta)
+ self.register_buffer("action_values", action_values)
+
+ self._categorical_layers = nn.ModuleList([nn.Linear(hidden_dims[-1], atom_count) for _ in range(output_size)])
+
+ self._init(self._categorical_layers, fade=False, gain=init_gain)
+
+ def categorical_loss(
+ self, predictions: torch.Tensor, target_probabilities: torch.Tensor, targets: torch.Tensor
+ ) -> torch.Tensor:
+ """Computes the categorical loss between the prediction and target categorical distributions.
+
+ Projects the targets back onto the categorical distribution supports before computing KL divergence.
+
+ Args:
+ predictions (torch.Tensor): The network prediction.
+ target_probabilities (torch.Tensor): The next-state value probabilities.
+ targets (torch.Tensor): The targets to compute the loss from.
+ Returns:
+ A torch.Tensor of the cross-entropy loss between the projected targets and the prediction.
+ """
+ b = (targets - self._value_min) / self.value_delta
+ l = b.floor().long().clamp(0, self._atom_count - 1)
+ u = b.ceil().long().clamp(0, self._atom_count - 1)
+
+ all_idx = torch.arange(b.shape[0])
+ projected_targets = torch.zeros((b.shape[0], self._atom_count), device=self.device)
+ for i in range(self._atom_count):
+ # Correct for when l == u
+ l[:, i][(l[:, i] == u[:, i]) * (l[:, i] > 0)] -= 1
+ u[:, i][(l[:, i] == u[:, i]) * (u[:, i] < self._atom_count - 1)] += 1
+
+ projected_targets[all_idx, l[:, i]] += (u[:, i] - b[:, i]) * target_probabilities[..., i]
+ projected_targets[all_idx, u[:, i]] += (b[:, i] - l[:, i]) * target_probabilities[..., i]
+
+ loss = torch.nn.functional.cross_entropy(
+ predictions.reshape(*projected_targets.shape), projected_targets.detach()
+ )
+
+ return loss
+
+ def compute_targets(self, rewards: torch.Tensor, dones: torch.Tensor, discount: float) -> torch.Tensor:
+ gamma = (discount * (1 - dones)).reshape(-1, 1)
+ gamma_z = gamma * self.action_values.repeat(dones.size()[0], 1)
+ targets = (rewards.reshape(-1, 1) + gamma_z).clamp(self._value_min, self._value_max)
+
+ return targets
+
+ def forward(self, x: torch.Tensor, distribution: bool = False) -> torch.Tensor:
+ features = super().forward(x)
+ probabilities = squeeze_preserve_batch(
+ torch.stack([layer(features).softmax(dim=-1) for layer in self._categorical_layers], dim=1)
+ )
+
+ if distribution:
+ return probabilities
+
+ values = self.probabilities_to_values(probabilities)
+
+ return values
+
+ def probabilities_to_values(self, probabilities: torch.Tensor) -> torch.Tensor:
+ values = probabilities @ self.action_values
+
+ return values
diff --git a/rsl_rl/modules/gaussian_chimera_network.py b/rsl_rl/modules/gaussian_chimera_network.py
new file mode 100644
index 0000000..8b13e53
--- /dev/null
+++ b/rsl_rl/modules/gaussian_chimera_network.py
@@ -0,0 +1,86 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from typing import List, Tuple, Union
+
+from rsl_rl.modules.network import Network
+from rsl_rl.modules.utils import get_activation
+
+
+class GaussianChimeraNetwork(Network):
+ """A network to predict mean and std of a gaussian distribution with separate heads for mean and std."""
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ activations: List[str] = ["relu", "relu", "relu", "linear"],
+ hidden_dims: List[int] = [256, 256, 256],
+ init_fade: bool = True,
+ init_gain: float = 0.5,
+ log_std_max: float = 4.0,
+ log_std_min: float = -20.0,
+ std_init: float = 1.0,
+ shared_dims: int = 1,
+ **kwargs,
+ ):
+ assert len(hidden_dims) + 1 == len(activations)
+ assert shared_dims > 0 and shared_dims <= len(hidden_dims)
+
+ super().__init__(
+ input_size,
+ hidden_dims[shared_dims],
+ activations=activations[: shared_dims + 1],
+ hidden_dims=hidden_dims[:shared_dims],
+ init_fade=False,
+ init_gain=init_gain,
+ **kwargs,
+ )
+
+ # Since the network predicts log_std ~= 0 after initialization, compute std = std_init * exp(log_std).
+ self._log_std_init = np.log(std_init)
+ self._log_std_max = log_std_max
+ self._log_std_min = log_std_min
+
+ separate_dims = len(hidden_dims) - shared_dims
+
+ mean_layers = []
+ for i in range(separate_dims):
+ isize = hidden_dims[shared_dims + i]
+ osize = output_size if i == separate_dims - 1 else hidden_dims[shared_dims + i + 1]
+
+ layer = nn.Linear(isize, osize)
+ activation = activations[shared_dims + i + 1]
+
+ mean_layers += [layer, get_activation(activation)]
+ self._mean_layer = nn.Sequential(*mean_layers)
+
+ self._init(self._mean_layer, fade=init_fade, gain=init_gain)
+
+ log_std_layers = []
+ for i in range(separate_dims):
+ isize = hidden_dims[shared_dims + i]
+ osize = output_size if i == separate_dims - 1 else hidden_dims[shared_dims + i + 1]
+
+ layer = nn.Linear(isize, osize)
+ activation = activations[shared_dims + i + 1]
+
+ log_std_layers += [layer, get_activation(activation)]
+ self._log_std_layer = nn.Sequential(*log_std_layers)
+
+ self._init(self._log_std_layer, fade=init_fade, gain=init_gain)
+
+ def forward(self, x: torch.Tensor, compute_std: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
+ features = super().forward(x)
+
+ mean = self._mean_layer(features)
+
+ if not compute_std:
+ return mean
+
+ # compute standard deviation as std = std_init * exp(log_std) = exp(log(std_init) + log(std)) since the network
+ # will predict log_std ~= 0 after initialization.
+ log_std = (self._log_std_init + self._log_std_layer(features)).clamp(self._log_std_min, self._log_std_max)
+ std = log_std.exp()
+
+ return mean, std
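
A minimal usage sketch of `GaussianChimeraNetwork`, assuming only the constructor arguments visible in this diff; the input/output sizes below are arbitrary example values.

```python
import torch

from rsl_rl.modules.gaussian_chimera_network import GaussianChimeraNetwork

# Example sizes only; any sizes satisfying the asserts above would do.
net = GaussianChimeraNetwork(input_size=48, output_size=12, shared_dims=1)

obs = torch.randn(32, 48)
mean = net(obs)                         # mean head only
mean, std = net(obs, compute_std=True)  # mean and state-dependent std from separate heads

# std starts out around std_init after initialization and is always positive.
dist = torch.distributions.Normal(mean, std)
actions = dist.rsample()
print(actions.shape)  # torch.Size([32, 12])
```
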
diff --git a/rsl_rl/modules/gaussian_network.py b/rsl_rl/modules/gaussian_network.py
new file mode 100644
index 0000000..1aec8d2
--- /dev/null
+++ b/rsl_rl/modules/gaussian_network.py
@@ -0,0 +1,37 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from typing import Tuple, Union
+
+from rsl_rl.modules.network import Network
+
+
+class GaussianNetwork(Network):
+ """A network to predict mean and std of a gaussian distribution where std is a tunable parameter."""
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ log_std_max: float = 4.0,
+ log_std_min: float = -20.0,
+ std_init: float = 1.0,
+ **kwargs,
+ ):
+ super().__init__(input_size, output_size, **kwargs)
+
+ self._log_std_max = log_std_max
+ self._log_std_min = log_std_min
+
+ self._log_std = nn.Parameter(torch.ones(output_size) * np.log(std_init))
+
+ def forward(self, x: torch.Tensor, compute_std: bool = False, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
+ mean = super().forward(x, **kwargs)
+
+ if not compute_std:
+ return mean
+
+ log_std = torch.ones_like(mean) * self._log_std.clamp(self._log_std_min, self._log_std_max)
+ std = log_std.exp()
+
+ return mean, std
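
For contrast, `GaussianNetwork` keeps a single learnable log-std parameter per output dimension instead of predicting it from the observation. A short sketch with arbitrary sizes:

```python
import torch

from rsl_rl.modules.gaussian_network import GaussianNetwork

net = GaussianNetwork(input_size=48, output_size=12, std_init=1.0)

obs = torch.randn(32, 48)
mean, std = net(obs, compute_std=True)

# The std is state-independent: it is broadcast from one nn.Parameter,
# so every row of the batch sees the same standard deviation.
print(std.shape)                       # torch.Size([32, 12])
print(bool((std[0] == std[1]).all()))  # True
```
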
diff --git a/rsl_rl/modules/implicit_quantile_network.py b/rsl_rl/modules/implicit_quantile_network.py
new file mode 100644
index 0000000..898ee67
--- /dev/null
+++ b/rsl_rl/modules/implicit_quantile_network.py
@@ -0,0 +1,168 @@
+import numpy as np
+import torch
+from torch.distributions import Normal
+import torch.nn as nn
+from typing import List, Union
+
+from rsl_rl.modules.network import Network
+from rsl_rl.modules.quantile_network import energy_loss
+from rsl_rl.utils.benchmarkable import Benchmarkable
+
+
+def reshape_measure_param(tau: torch.Tensor, param: torch.Tensor) -> torch.Tensor:
+ if not torch.is_tensor(param):
+ param = torch.tensor([param])
+
+ param = param.expand(tau.shape[0], -1).to(tau.device)
+
+ return param
+
+
+def risk_measure_neutral(tau: torch.Tensor) -> torch.Tensor:
+ return tau
+
+
+def risk_measure_wang(tau: torch.Tensor, beta: float = 0.0) -> torch.Tensor:
+ beta = reshape_measure_param(tau, beta)
+
+ distorted_tau = Normal(0, 1).cdf(Normal(0, 1).icdf(tau) + beta)
+
+ return distorted_tau
+
+
+class ImplicitQuantileNetwork(Network):
+ measure_neutral = "neutral"
+ measure_wang = "wang"
+
+ measures = {
+ measure_neutral: risk_measure_neutral,
+ measure_wang: risk_measure_wang,
+ }
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ activations: List[str] = ["relu", "relu", "relu"],
+ feature_layers: int = 1,
+ embedding_size: int = 64,
+ hidden_dims: List[int] = [256, 256, 256],
+ init_fade: bool = False,
+ init_gain: float = 0.5,
+ measure: str = None,
+ measure_kwargs: dict = {},
+ **kwargs,
+ ):
+ assert len(hidden_dims) == len(activations), "hidden_dims and activations must have the same length."
+ assert feature_layers > 0, "feature_layers must be greater than 0."
+ assert feature_layers < len(hidden_dims), "feature_layers must be less than the number of hidden dimensions."
+ assert embedding_size > 0, "embedding_size must be greater than 0."
+
+ super().__init__(
+ input_size,
+ hidden_dims[feature_layers - 1],
+ activations=activations[:feature_layers],
+ hidden_dims=hidden_dims[: feature_layers - 1],
+ init_fade=init_fade,
+ init_gain=init_gain,
+ **kwargs,
+ )
+
+ self._last_taus = None
+ self._last_quantiles = None
+ self._embedding_size = embedding_size
+ self.register_buffer(
+ "_embedding_pis",
+ np.pi * (torch.arange(self._embedding_size, device=self.device).reshape(1, 1, self._embedding_size)),
+ )
+ self._embedding_layer = nn.Sequential(
+ nn.Linear(self._embedding_size, hidden_dims[feature_layers - 1]), nn.ReLU()
+ )
+
+ self._fusion_layers = Network(
+ hidden_dims[feature_layers - 1],
+ output_size,
+ activations=activations[feature_layers:] + ["linear"],
+ hidden_dims=hidden_dims[feature_layers:],
+ init_fade=init_fade,
+ init_gain=init_gain,
+ )
+
+ measure_func = risk_measure_neutral if measure is None else self.measures[measure]
+ self._measure_func = measure_func
+ self._measure_kwargs = measure_kwargs
+
+ @Benchmarkable.register
+ def _sample_taus(self, batch_size: int, sample_count: int, measure_args: list, use_measure: bool) -> torch.Tensor:
+ """Sample quantiles and distort them according to the risk metric.
+
+ Args:
+ batch_size: The batch size.
+ sample_count: The number of samples to draw.
+ measure_args: The arguments to pass to the risk measure function.
+ use_measure: Whether to use the risk measure or not.
+ Returns:
+ A tensor of distorted quantile fractions of shape (batch_size, sample_count).
+ """
+ taus = torch.rand(batch_size, sample_count, device=self.device)
+
+ if not use_measure:
+ return taus
+
+ if measure_args:
+ taus = self._measure_func(taus, *measure_args)
+ else:
+ taus = self._measure_func(taus, **self._measure_kwargs)
+
+ return taus
+
+ @Benchmarkable.register
+ def forward(
+ self,
+ x: torch.Tensor,
+ distribution: bool = False,
+ measure_args: list = [],
+ sample_count: int = 8,
+ taus: Union[torch.Tensor, None] = None,
+ use_measure: bool = True,
+ **kwargs,
+ ) -> torch.Tensor:
+ assert taus is None or not use_measure, "Cannot use taus and use_measure at the same time."
+
+ batch_size = x.shape[0]
+
+ features = super().forward(x, **kwargs)
+ taus = self._sample_taus(batch_size, sample_count, measure_args, use_measure) if taus is None else taus
+
+ # Compute quantile embeddings
+ singular_dims = [1] * taus.dim()
+ cos = torch.cos(taus.unsqueeze(-1) * self._embedding_pis.reshape(*singular_dims, self._embedding_size))
+ embeddings = self._embedding_layer(cos)
+
+ # Compute the fusion of the features and the embeddings
+ fused_features = features.unsqueeze(-2) * embeddings
+ quantiles = self._fusion_layers(fused_features)
+
+ self._last_quantiles = quantiles
+ self._last_taus = taus
+
+ if distribution:
+ return quantiles
+
+ values = quantiles.mean(-1)
+
+ return values
+
+ @property
+ def last_taus(self):
+ return self._last_taus
+
+ @property
+ def last_quantiles(self):
+ return self._last_quantiles
+
+ @Benchmarkable.register
+ def sample_energy_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ loss = energy_loss(predictions, targets)
+
+ return loss
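
A sketch of querying the implicit quantile head under a risk measure, assuming the interfaces shown in this diff; the sizes, `beta`, and `sample_count` below are illustrative only.

```python
import torch

from rsl_rl.modules.implicit_quantile_network import ImplicitQuantileNetwork, risk_measure_wang

iqn = ImplicitQuantileNetwork(
    input_size=48,
    output_size=1,
    measure=ImplicitQuantileNetwork.measure_wang,
    measure_kwargs={"beta": -0.5},  # a negative beta distorts towards lower quantiles (risk-averse)
)

obs = torch.randn(32, 48)
values = iqn(obs, sample_count=16)       # risk-distorted expectation, shape (32,)
quantiles = iqn(obs, distribution=True)  # quantile estimates at the sampled fractions
taus = iqn.last_taus                     # the (distorted) fractions used in the last call

# The distortion on its own: Wang's g shifts uniformly sampled fractions.
tau = torch.rand(4, 8)
print(risk_measure_wang(tau, beta=-0.5).shape)  # torch.Size([4, 8])
```
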
diff --git a/rsl_rl/modules/network.py b/rsl_rl/modules/network.py
new file mode 100644
index 0000000..33e4513
--- /dev/null
+++ b/rsl_rl/modules/network.py
@@ -0,0 +1,210 @@
+import torch
+import torch.nn as nn
+from typing import List
+
+from rsl_rl.modules.normalizer import EmpiricalNormalization
+from rsl_rl.modules.utils import get_activation
+from rsl_rl.modules.transformer import Transformer
+from rsl_rl.utils.benchmarkable import Benchmarkable
+from rsl_rl.utils.utils import squeeze_preserve_batch
+
+
+class Network(Benchmarkable, nn.Module):
+ recurrent_module_lstm = "LSTM"
+ recurrent_module_transformer = "TF"
+
+ recurrent_modules = {recurrent_module_lstm: nn.LSTM, recurrent_module_transformer: Transformer}
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ activations: List[str] = ["relu", "relu", "relu", "tanh"],
+ hidden_dims: List[int] = [256, 256, 256],
+ init_fade: bool = True,
+ init_gain: float = 1.0,
+ input_normalization: bool = False,
+ recurrent: bool = False,
+ recurrent_layers: int = 1,
+ recurrent_module: str = recurrent_module_lstm,
+ recurrent_tf_context_length: int = 64,
+ recurrent_tf_head_count: int = 8,
+ ) -> None:
+ """
+
+ Args:
+ input_size (int): The size of the input.
+ output_size (int): The size of the output.
+ activations (List[str]): The activation functions to use. If the network is recurrent, the first activation
+ function is used for the output of the recurrent layer.
+ hidden_dims (List[int]): The hidden dimensions. If the network is recurrent, the first hidden dimension is
+ used for the recurrent layer.
+ init_fade (bool): Whether to use the fade in initialization.
+ init_gain (float): The gain to use for the initialization.
+ input_normalization (bool): Whether to use input normalization.
+ recurrent (bool): Whether to use a recurrent network.
+ recurrent_layers (int): The number of recurrent layers (LSTM) / blocks (Transformer) to use.
+ recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules.
+ recurrent_tf_context_length (int): The context length of the Transformer.
+ recurrent_tf_head_count (int): The head count of the Transformer.
+ """
+ assert len(hidden_dims) + 1 == len(activations)
+
+ super().__init__()
+
+ if input_normalization:
+ self._normalization = EmpiricalNormalization(shape=(input_size,))
+ else:
+ self._normalization = nn.Identity()
+
+ dims = [input_size] + hidden_dims + [output_size]
+
+ self._recurrent = recurrent
+ self._recurrent_module = recurrent_module
+ self.hidden_state = None
+ self._last_hidden_state = None
+ if self._recurrent:
+ recurrent_kwargs = dict()
+
+ if recurrent_module == self.recurrent_module_lstm:
+ recurrent_kwargs["hidden_size"] = dims[1]
+ recurrent_kwargs["input_size"] = dims[0]
+ recurrent_kwargs["num_layers"] = recurrent_layers
+ elif recurrent_module == self.recurrent_module_transformer:
+ recurrent_kwargs["block_count"] = recurrent_layers
+ recurrent_kwargs["context_length"] = recurrent_tf_context_length
+ recurrent_kwargs["head_count"] = recurrent_tf_head_count
+ recurrent_kwargs["hidden_size"] = dims[1]
+ recurrent_kwargs["input_size"] = dims[0]
+ recurrent_kwargs["output_size"] = dims[1]
+
+ rnn = self.recurrent_modules[recurrent_module](**recurrent_kwargs)
+ activation = get_activation(activations[0])
+ dims = dims[1:]
+ activations = activations[1:]
+
+ self._features = nn.Sequential(rnn, activation)
+ else:
+ self._features = nn.Identity()
+
+ layers = []
+ for i in range(len(activations)):
+ layer = nn.Linear(dims[i], dims[i + 1])
+ activation = get_activation(activations[i])
+
+ layers.append(layer)
+ layers.append(activation)
+
+ self._layers = nn.Sequential(*layers)
+
+ if len(layers) > 0:
+ self._init(self._layers, fade=init_fade, gain=init_gain)
+
+ @property
+ def device(self):
+ """Returns the device of the network."""
+ return next(self.parameters()).device
+
+ def forward(self, x: torch.Tensor, hidden_state=None) -> torch.Tensor:
+ """
+ Args:
+ x (torch.Tensor): The input data.
+ hidden_state (Tuple[torch.Tensor, torch.Tensor]): The hidden state of the network. If None, the hidden state
+ of the network is used. If provided, the hidden state of the neural network will not be updated. To
+ retrieve the new hidden state, use the last_hidden_state property. If the network is not recurrent,
+ this argument is ignored.
+ Returns:
+ The output of the network as a torch.Tensor.
+ """
+ assert hidden_state is None or self._recurrent, "Cannot pass hidden state to non-recurrent network."
+
+ input = self._normalization(x.to(self.device))
+
+ if self._recurrent:
+ current_hidden_state = self.hidden_state if hidden_state is None else hidden_state
+ current_hidden_state = (current_hidden_state[0].to(self.device), current_hidden_state[1].to(self.device))
+
+ input = input.unsqueeze(0) if len(input.shape) == 2 else input
+ input, next_hidden_state = self._features[0](input, current_hidden_state)
+ input = self._features[1](input).squeeze(0)
+
+ if hidden_state is None:
+ self.hidden_state = next_hidden_state
+ self._last_hidden_state = next_hidden_state
+
+ output = squeeze_preserve_batch(self._layers(input))
+
+ return output
+
+ @property
+ def last_hidden_state(self):
+ """Returns the hidden state of the last forward pass.
+
+ Does not differentiate whether the hidden state depends on the hidden state kept in the network or whether it
+ was passed into the forward pass.
+
+ Returns:
+ The hidden state of the last forward pass as Tuple[torch.Tensor, torch.Tensor].
+ """
+ return self._last_hidden_state
+
+ def normalize(self, x: torch.Tensor) -> torch.Tensor:
+ """Normalizes the given input.
+
+ Args:
+ x (torch.Tensor): The input to normalize.
+ Returns:
+ The normalized input as a torch.Tensor.
+ """
+ output = self._normalization(x.to(self.device))
+
+ return output
+
+ @property
+ def recurrent(self) -> bool:
+ """Returns whether the network is recurrent."""
+ return self._recurrent
+
+ def reset_hidden_state(self, indices: torch.Tensor) -> None:
+ """Resets the hidden state of the neural network.
+
+ Throws an error if the network is not recurrent.
+
+ Args:
+ indices (torch.Tensor): A 1-dimensional int tensor containing the indices of the terminated
+ environments.
+ """
+ assert self._recurrent
+
+ self.hidden_state[0][:, indices] = torch.zeros(len(indices), self._features[0].hidden_size, device=self.device)
+ self.hidden_state[1][:, indices] = torch.zeros(len(indices), self._features[0].hidden_size, device=self.device)
+
+ def reset_full_hidden_state(self, batch_size=None) -> None:
+ """Resets the hidden state of the neural network.
+
+ Args:
+ batch_size (int): The batch size of the hidden state. If None, the hidden state is reset to None.
+ """
+ assert self._recurrent
+
+ if batch_size is None:
+ self.hidden_state = None
+ else:
+ layer_count, hidden_size = self._features[0].num_layers, self._features[0].hidden_size
+ self.hidden_state = (
+ torch.zeros(layer_count, batch_size, hidden_size, device=self.device),
+ torch.zeros(layer_count, batch_size, hidden_size, device=self.device),
+ )
+
+ def _init(self, layers: List[nn.Module], fade: bool = True, gain: float = 1.0) -> List[nn.Module]:
+ """Initializes neural network layers."""
+ last_layer_idx = len(layers) - 1 - next(i for i, l in enumerate(reversed(layers)) if isinstance(l, nn.Linear))
+
+ for idx, layer in enumerate(layers):
+ if not isinstance(layer, nn.Linear):
+ continue
+
+ current_gain = gain / 100.0 if fade and idx == last_layer_idx else gain
+ nn.init.xavier_normal_(layer.weight, gain=current_gain)
+
+ return layers
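
A sketch of the recurrent configuration of `Network`, illustrating the hidden-state handling described in the `forward` docstring; the sizes below are arbitrary.

```python
import torch

from rsl_rl.modules.network import Network

net = Network(
    input_size=16,
    output_size=4,
    activations=["relu", "relu", "tanh"],
    hidden_dims=[64, 64],
    recurrent=True,
)

# Allocate a zeroed LSTM hidden state for 8 parallel environments.
net.reset_full_hidden_state(batch_size=8)

obs = torch.randn(8, 16)
out = net(obs)    # updates net.hidden_state as a side effect
print(out.shape)  # torch.Size([8, 4])

# On environment resets, clear only the affected rows of the hidden state.
net.reset_hidden_state(torch.tensor([0, 3]))
```
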
diff --git a/rsl_rl/modules/normalizer.py b/rsl_rl/modules/normalizer.py
index 771efcf..4de90a3 100644
--- a/rsl_rl/modules/normalizer.py
+++ b/rsl_rl/modules/normalizer.py
@@ -1,9 +1,3 @@
-# Copyright (c) 2020 Preferred Networks, Inc.
-# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
-# SPDX-License-Identifier: BSD-3-Clause
-
-from __future__ import annotations
-
import torch
from torch import nn
@@ -11,48 +5,58 @@ from torch import nn
class EmpiricalNormalization(nn.Module):
"""Normalize mean and variance of values based on empirical values."""
- def __init__(self, shape, eps=1e-2, until=None):
- """Initialize EmpiricalNormalization module.
-
+ def __init__(self, shape, eps=1e-6, until=None) -> None:
+ """
Args:
shape (int or tuple of int): Shape of input values except batch axis.
eps (float): Small value for stability.
until (int or None): If this arg is specified, the link learns input values until the sum of batch sizes
- exceeds it.
+ exceeds it.
"""
super().__init__()
+
self.eps = eps
self.until = until
+
self.register_buffer("_mean", torch.zeros(shape).unsqueeze(0))
self.register_buffer("_var", torch.ones(shape).unsqueeze(0))
self.register_buffer("_std", torch.ones(shape).unsqueeze(0))
+
self.count = 0
@property
- def mean(self):
- return self._mean.squeeze(0).clone()
+ def mean(self) -> torch.Tensor:
+ """Mean of input values."""
+ return self._mean.squeeze(0).detach().clone()
@property
- def std(self):
- return self._std.squeeze(0).clone()
-
- def forward(self, x):
- """Normalize mean and variance of values based on empirical values.
+ def std(self) -> torch.Tensor:
+ """Standard deviation of input values."""
+ return self._std.squeeze(0).detach().clone()
+ def forward(self, x) -> torch.Tensor:
+ """Normalize mean and variance of values based on emprical values.
Args:
x (ndarray or Variable): Input values
-
Returns:
- ndarray or Variable: Normalized output values
+ Normalized output values
"""
if self.training:
self.update(x)
- return (x - self._mean) / (self._std + self.eps)
+
+ x_normalized = (x - self._mean.detach()) / (self._std.detach() + self.eps)
+
+ return x_normalized
@torch.jit.unused
- def update(self, x):
- """Learn input values without computing the output values of them"""
+ def update(self, x: torch.Tensor) -> None:
+ """Learn input values without computing the output values of them.
+
+ Args:
+ x (torch.Tensor): Input values.
+ """
+ x = x.detach()
if self.until is not None and self.count >= self.until:
return
@@ -69,5 +73,14 @@ class EmpiricalNormalization(nn.Module):
self._std = torch.sqrt(self._var)
@torch.jit.unused
- def inverse(self, y):
- return y * (self._std + self.eps) + self._mean
+ def inverse(self, y: torch.Tensor) -> torch.Tensor:
+ """Inverse normalized values.
+
+ Args:
+ y (torch.Tensor): Normalized input values.
+ Returns:
+ Inverse normalized output values.
+ """
+ inv = y * (self._std + self.eps) + self._mean
+
+ return inv
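
A short usage sketch of `EmpiricalNormalization`, showing that statistics are only updated in training mode and that `inverse` undoes the normalization; the numbers are arbitrary.

```python
import torch

from rsl_rl.modules.normalizer import EmpiricalNormalization

norm = EmpiricalNormalization(shape=(3,), until=10_000)

# In training mode, forward() also updates the running mean/std.
norm.train()
x = torch.randn(256, 3) * 5.0 + 2.0
y = norm(x)

# In eval mode the statistics are frozen; inverse() maps back to the original scale.
norm.eval()
z = norm(x)
x_reconstructed = norm.inverse(z)
print(torch.allclose(x_reconstructed, x, atol=1e-4))  # True (up to eps)
```
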
diff --git a/rsl_rl/modules/quantile_network.py b/rsl_rl/modules/quantile_network.py
new file mode 100644
index 0000000..a3f283e
--- /dev/null
+++ b/rsl_rl/modules/quantile_network.py
@@ -0,0 +1,333 @@
+import torch
+import torch.nn as nn
+from torch.distributions import Normal
+from typing import Callable, Tuple, Union
+
+from rsl_rl.modules.network import Network
+from rsl_rl.utils.benchmarkable import Benchmarkable
+from rsl_rl.utils.utils import squeeze_preserve_batch
+
+eps = torch.finfo(torch.float32).eps
+
+
+def reshape_measure_parameters(
+ qn: Network, *params: Union[torch.Tensor, float]
+) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
+ """Reshapes the parameters of a measure function to match the shape of the quantile network.
+
+ Args:
+ qn (Network): The quantile network.
+ *params (Union[torch.Tensor, float]): The parameters of the measure function.
+ Returns:
+ Union[torch.Tensor, Tuple[torch.Tensor, ...]]: The reshaped parameters.
+ """
+ if not params:
+ return qn._tau.to(qn.device), *params
+
+ assert len([*set([torch.is_tensor(p) for p in params])]) == 1, "All parameters must be either tensors or scalars."
+
+ if torch.is_tensor(params[0]):
+ assert all([p.dim() == 1 for p in params]), "All parameters must have dimensionality 1."
+ assert len([*set([p.shape[0] for p in params])]) == 1, "All parameters must have the same size."
+
+ reshaped_params = [p.reshape(-1, 1).to(qn.device) for p in params]
+ tau = qn._tau.expand(params[0].shape[0], -1).to(qn.device)
+ else:
+ reshaped_params = params
+ tau = qn._tau.to(qn.device)
+
+ return tau, *reshaped_params
+
+
+def make_distorted_measure(distorted_tau: torch.Tensor) -> Callable:
+ """Creates a measure function for the distorted expectation under some distortion function.
+
+ The distorted expectation for some distortion function g(tau) is given by the integral w.r.t. tau
+ "int_0^1 g'(tau) * F_Z^{-1}(tau) dtau" where g'(tau) is the derivative of g w.r.t. tau and F_Z^{-1} is the inverse
+ cumulative distribution function of the value distribution.
+ See https://arxiv.org/pdf/2004.14547.pdf and https://arxiv.org/pdf/1806.06923.pdf for details.
+ """
+ distorted_tau = distorted_tau.reshape(-1, distorted_tau.shape[-1])
+ distortion = (distorted_tau[:, 1:] - distorted_tau[:, :-1]).squeeze(0)
+
+ def distorted_measure(quantiles):
+ sorted_quantiles, _ = quantiles.sort(-1)
+ sorted_quantiles = sorted_quantiles.reshape(-1, sorted_quantiles.shape[-1])
+
+ # dtau = tau[1:] - tau[:-1] cancels the denominator of g'(tau) = (g(tau)[1:] - g(tau)[:-1]) / dtau.
+ values = squeeze_preserve_batch((distortion.to(sorted_quantiles.device) * sorted_quantiles).sum(-1))
+
+ return values
+
+ return distorted_measure
+
+
+def risk_measure_cvar(qn: Network, confidence_level: float = 1.0) -> Callable:
+ """Conditional value at risk measure.
+
+ TODO: Handle confidence_level being a tensor.
+
+ Args:
+ qn (QuantileNetwork): Quantile network to compute the risk measure for.
+ confidence_level (float): Confidence level of the risk measure. Must be between 0 and 1.
+ Returns:
+ A risk measure function.
+ """
+ tau, confidence_level = reshape_measure_parameters(qn, confidence_level)
+ distorted_tau = torch.min(tau / confidence_level, torch.ones(*tau.shape).to(tau.device))
+
+ return make_distorted_measure(distorted_tau)
+
+
+def risk_measure_neutral(_: Network) -> Callable:
+ """Neutral risk measure (expected value).
+
+ Args:
+ _ (QuantileNetwork): Quantile network to compute the risk measure for.
+ Returns:
+ A risk measure function.
+ """
+
+ def measure(quantiles):
+ values = squeeze_preserve_batch(quantiles.mean(-1))
+
+ return values
+
+ return measure
+
+
+def risk_measure_percentile(_: Network, confidence_level: float = 1.0) -> Callable:
+ """Value at risk measure.
+
+ Args:
+ _ (QuantileNetwork): Quantile network to compute the risk measure for.
+ confidence_level (float): Confidence level of the risk measure. Must be between 0 and 1.
+ Returns:
+ A risk measure function.
+ """
+
+ def measure(quantiles):
+ sorted_quantiles, _ = quantiles.sort(-1)
+ sorted_quantiles = sorted_quantiles.reshape(-1, sorted_quantiles.shape[-1])
+ idx = min(int(confidence_level * quantiles.shape[-1]), quantiles.shape[-1] - 1)
+
+ values = squeeze_preserve_batch(sorted_quantiles[:, idx])
+
+ return values
+
+ return measure
+
+
+def risk_measure_wang(qn: Network, beta: Union[float, torch.Tensor] = 0.0) -> Callable:
+ """Wang's risk measure.
+
+ The risk measure computes the distorted expectation under Wang's risk distortion function
+ g(tau) = Phi(Phi^-1(tau) + beta) where Phi and Phi^-1 are the standard normal CDF and its inverse.
+ See https://arxiv.org/pdf/2004.14547.pdf for details.
+
+ Args:
+ qn (QuantileNetwork): Quantile network to compute the risk measure for.
+ beta (float): Parameter of the risk distortion function.
+ Returns:
+ A risk measure function.
+ """
+ tau, beta = reshape_measure_parameters(qn, beta)
+
+ distorted_tau = Normal(0, 1).cdf(Normal(0, 1).icdf(tau) + beta)
+
+ return make_distorted_measure(distorted_tau)
+
+
+def energy_loss(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ """Computes sample energy loss between predictions and targets.
+
+ The energy loss is computed as 2*E[||X - Y||_2] - E[||X - X'||_2] - E[||Y - Y'||_2], where X, X' and Y, Y' are
+ random variables and ||.||_2 is the L2-norm. X, X' are the predictions and Y, Y' are the targets.
+
+ Args:
+ predictions (torch.Tensor): Predictions to compute loss from.
+ targets (torch.Tensor): Targets to compare predictions against.
+ Returns:
+ A torch.Tensor of shape (1,) containing the loss.
+ """
+ dims = [-1 for _ in range(predictions.dim())]
+ prediction_mat = predictions.unsqueeze(-1).expand(*dims, predictions.shape[-1])
+ target_mat = targets.unsqueeze(-1).expand(*dims, predictions.shape[-1])
+
+ delta_xx = (prediction_mat - prediction_mat.transpose(-1, -2)).abs().mean()
+ delta_yy = (target_mat - target_mat.transpose(-1, -2)).abs().mean()
+ delta_xy = (prediction_mat - target_mat.transpose(-1, -2)).abs().mean()
+
+ loss = 2 * delta_xy - delta_xx - delta_yy
+
+ return loss
+
+
+class QuantileNetwork(Network):
+ measure_cvar = "cvar"
+ measure_neutral = "neutral"
+ measure_percentile = "percentile"
+ measure_wang = "wang"
+
+ measures = {
+ measure_cvar: risk_measure_cvar,
+ measure_neutral: risk_measure_neutral,
+ measure_percentile: risk_measure_percentile,
+ measure_wang: risk_measure_wang,
+ }
+
+ def __init__(
+ self,
+ input_size,
+ output_size,
+ activations=["relu", "relu", "relu"],
+ hidden_dims=[256, 256, 256],
+ init_fade=False,
+ init_gain=0.5,
+ measure=None,
+ measure_kwargs={},
+ quantile_count=200,
+ **kwargs,
+ ):
+ assert len(hidden_dims) == len(activations)
+ assert quantile_count > 0
+
+ super().__init__(
+ input_size,
+ activations=activations,
+ hidden_dims=hidden_dims[:-1],
+ init_fade=False,
+ init_gain=init_gain,
+ output_size=hidden_dims[-1],
+ **kwargs,
+ )
+
+ self._quantile_count = quantile_count
+ self._tau = torch.arange(self._quantile_count + 1) / self._quantile_count
+ self._tau_hat = torch.tensor([(self._tau[i] + self._tau[i + 1]) / 2 for i in range(self._quantile_count)])
+ self._tau_hat_mat = torch.empty((0,))
+
+ self._quantile_layers = nn.ModuleList([nn.Linear(hidden_dims[-1], quantile_count) for _ in range(output_size)])
+
+ self._init(self._quantile_layers, fade=init_fade, gain=init_gain)
+
+ measure_func = risk_measure_neutral if measure is None else self.measures[measure]
+ self._measure_func = measure_func
+ self._measure = measure_func(self, **measure_kwargs)
+
+ self._last_quantiles = None
+
+ @property
+ def last_quantiles(self) -> torch.Tensor:
+ return self._last_quantiles
+
+ def make_diracs(self, values: torch.Tensor) -> torch.Tensor:
+ """Generates value distributions that have a single spike at the given values.
+
+ Args:
+ values (torch.Tensor): Values to generate dirac distributions for.
+ Returns:
+ A torch.Tensor of shape (*values.shape, quantile_count) containing the dirac distributions.
+ """
+ dirac = values.unsqueeze(-1).expand(*[-1 for _ in range(values.dim())], self._quantile_count)
+
+ return dirac
+
+ @property
+ def quantile_count(self) -> int:
+ return self._quantile_count
+
+ @Benchmarkable.register
+ def quantiles_to_values(self, quantiles: torch.Tensor, *measure_args) -> torch.Tensor:
+ """Computes values from quantiles.
+
+ Args:
+ quantiles (torch.Tensor): Quantiles to compute values from.
+ *measure_args: Positional arguments to pass to the risk measure function instead of the arguments
+ passed when creating the network.
+ Returns:
+ A torch.Tensor of shape (1,) containing the values.
+ """
+ if measure_args:
+ values = self._measure_func(self, *[squeeze_preserve_batch(m) for m in measure_args])(quantiles)
+ else:
+ values = self._measure(quantiles)
+
+ return values
+
+ @Benchmarkable.register
+ def quantile_l1_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ """Computes quantile-wise l1 loss between predictions and targets.
+
+ TODO: This function is a bottleneck.
+
+ Args:
+ predictions (torch.Tensor): Predictions to compute loss from.
+ targets (torch.Tensor): Targets to compare predictions against.
+ Returns:
+ A torch.Tensor of shape (1,) containing the loss.
+ """
+ assert (
+ predictions.dim() == 2 or predictions.dim() == 3
+ ), f"Predictions must be 2D or 3D. Got {predictions.dim()}."
+ assert (
+ predictions.shape == targets.shape
+ ), f"The shapes of predictions and targets must match. Got {predictions.shape} and {targets.shape}."
+
+ pre_dims = [-1] if predictions.dim() == 3 else []
+
+ prediction_mat = predictions.unsqueeze(-3).expand(*pre_dims, self._quantile_count, -1, -1)
+ target_mat = targets.transpose(-2, -1).unsqueeze(-1).expand(*pre_dims, -1, -1, self._quantile_count)
+ delta = target_mat - prediction_mat
+
+ tau_hat = self._tau_hat.expand(predictions.shape[-2], -1).to(self.device)
+ loss = (torch.where(delta < 0, (tau_hat - 1), tau_hat) * delta).abs().mean()
+
+ return loss
+
+ @Benchmarkable.register
+ def quantile_huber_loss(self, predictions: torch.Tensor, targets: torch.Tensor, kappa: float = 1.0) -> torch.Tensor:
+ """Computes quantile huber loss between predictions and targets.
+
+ TODO: This function is a bottleneck.
+
+ Args:
+ predictions (torch.Tensor): Predictions to compute loss from.
+ targets (torch.Tensor): Targets to compare predictions against.
+ kappa (float): Defines the interval [-kappa, kappa] around zero where squared loss is used. Defaults to 1.
+ Returns:
+ A torch.tensor of shape (1,) containing the loss.
+ """
+ pre_dims = [-1] if predictions.dim() == 3 else []
+
+ prediction_mat = predictions.unsqueeze(-3).expand(*pre_dims, self._quantile_count, -1, -1)
+ target_mat = targets.transpose(-2, -1).unsqueeze(-1).expand(*pre_dims, -1, -1, self._quantile_count)
+ delta = target_mat - prediction_mat
+ delta_abs = delta.abs()
+
+ huber = torch.where(delta_abs <= kappa, 0.5 * delta.pow(2), kappa * (delta_abs - 0.5 * kappa))
+
+ tau_hat = self._tau_hat.expand(predictions.shape[-2], -1).to(self.device)
+ loss = (torch.where(delta < 0, (tau_hat - 1), tau_hat).abs() * huber).mean()
+
+ return loss
+
+ @Benchmarkable.register
+ def sample_energy_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ loss = energy_loss(predictions, targets)
+
+ return loss
+
+ @Benchmarkable.register
+ def forward(self, x: torch.Tensor, distribution: bool = False, measure_args: list = [], **kwargs) -> torch.Tensor:
+ features = super().forward(x, **kwargs)
+ quantiles = squeeze_preserve_batch(torch.stack([layer(features) for layer in self._quantile_layers], dim=1))
+
+ self._last_quantiles = quantiles
+
+ if distribution:
+ return quantiles
+
+ values = self.quantiles_to_values(quantiles, *measure_args)
+
+ return values
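
A sketch of using `QuantileNetwork` as a distributional critic, queried both for its full quantile set and for a risk-averse value; the constructor arguments follow the signature above, while the sizes and confidence level are illustrative.

```python
import torch

from rsl_rl.modules.quantile_network import QuantileNetwork

qn = QuantileNetwork(
    input_size=48,
    output_size=1,
    quantile_count=200,
    measure=QuantileNetwork.measure_cvar,
    measure_kwargs={"confidence_level": 0.25},  # average over the worst 25% of outcomes
)

obs = torch.randn(32, 48)
quantiles = qn(obs, distribution=True)  # (batch, quantile_count) after squeezing the singleton output dim
risk_averse_value = qn(obs)             # CVaR_0.25 of the predicted value distribution, shape (32,)

# Distributional losses compare full quantile sets rather than scalar values.
targets = quantiles.detach() + 0.1 * torch.randn_like(quantiles)
loss = qn.quantile_huber_loss(quantiles, targets)
print(quantiles.shape, risk_averse_value.shape, float(loss))
```
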
diff --git a/rsl_rl/modules/transformer.py b/rsl_rl/modules/transformer.py
new file mode 100644
index 0000000..7e27b0f
--- /dev/null
+++ b/rsl_rl/modules/transformer.py
@@ -0,0 +1,150 @@
+import torch
+from typing import Tuple
+
+
+class Head(torch.nn.Module):
+ """A single causal self-attention head."""
+
+ def __init__(self, hidden_size: int, head_size: int):
+ super().__init__()
+
+ self.query = torch.nn.Linear(hidden_size, head_size)
+ self.key = torch.nn.Linear(hidden_size, head_size)
+ self.value = torch.nn.Linear(hidden_size, head_size)
+
+ def forward(self, x: torch.Tensor):
+ x = x.transpose(0, 1)
+ _, S, F = x.shape # (Batch, Sequence, Features)
+
+ q = self.query(x)
+ k = self.key(x)
+
+ weight = q @ k.transpose(-1, -2) * F**-0.5 # shape: (B, S, S)
+ weight = weight.masked_fill(torch.tril(torch.ones(S, S, device=x.device)) == 0, float("-inf"))
+ weight = torch.nn.functional.softmax(weight, dim=-1)
+
+ v = self.value(x) # shape: (B, S, F)
+ out = (weight @ v).transpose(0, 1) # shape: (S, B, F)
+
+ return out
+
+
+class MultiHead(torch.nn.Module):
+ def __init__(self, hidden_size: int, head_count: int):
+ super().__init__()
+
+ assert hidden_size % head_count == 0, "hidden_size must be divisible by head_count for multi-headed attention."
+
+ self.heads = torch.nn.ModuleList([Head(hidden_size, hidden_size // head_count) for _ in range(head_count)])
+ self.proj = torch.nn.Linear(hidden_size, hidden_size)
+
+ def forward(self, x: torch.Tensor):
+ x = torch.cat([head(x) for head in self.heads], dim=-1)
+ out = self.proj(x)
+
+ return out
+
+
+class Block(torch.nn.Module):
+ def __init__(self, hidden_size: int, head_count: int):
+ super().__init__()
+
+ self.sa = MultiHead(hidden_size, head_count)
+ self.ff = torch.nn.Sequential(
+ torch.nn.Linear(hidden_size, 4 * hidden_size),
+ torch.nn.ReLU(),
+ torch.nn.Linear(4 * hidden_size, hidden_size),
+ )
+ self.ln1 = torch.nn.LayerNorm(hidden_size)
+ self.ln2 = torch.nn.LayerNorm(hidden_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.sa(self.ln1(x)) + x
+ out = self.ff(self.ln2(x)) + x
+
+ return out
+
+
+class Transformer(torch.nn.Module):
+ """A Transformer-based recurrent module.
+
+ The Transformer module is a recurrent module that uses a Transformer architecture to process the input sequence. It
+ uses a hidden state to emulate RNN-like behavior.
+ """
+
+ def __init__(
+ self, input_size, output_size, hidden_size, block_count: int = 6, context_length: int = 64, head_count: int = 8
+ ):
+ """
+ Args:
+ input_size (int): The size of the input.
+ output_size (int): The size of the output.
+ hidden_size (int): The size of the hidden layers.
+ block_count (int): The number of Transformer blocks.
+ context_length (int): The length of the context to consider when predicting the next token.
+ head_count (int): The number of attention heads per block.
+ """
+
+ assert context_length % 2 == 0, "Context length must be even."
+
+ super().__init__()
+
+ self.context_length = context_length
+ self.hidden_size = hidden_size
+
+ self.feature_proj = torch.nn.Linear(input_size, hidden_size)
+ self.position_embedding = torch.nn.Embedding(context_length, hidden_size)
+ self.blocks = torch.nn.Sequential(
+ *[Block(hidden_size, head_count) for _ in range(block_count)],
+ torch.nn.LayerNorm(hidden_size),
+ )
+ self.head = torch.nn.Linear(hidden_size, output_size)
+
+ @property
+ def num_layers(self):
+ # Set num_layers to half the context length for simple torch.nn.LSTM compatibility. TODO: This is a bit hacky.
+ return self.context_length // 2
+
+ def step(self, x: torch.Tensor, context: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Computes Transformer output given the full context and the input.
+
+ Args:
+ x (torch.Tensor): The input tensor of shape (Sequence, Batch, Features).
+ context (torch.Tensor): The context tensor of shape (Context Length, Batch, Features).
+ Returns:
+ A tuple of the output tensor of shape (Sequence, Batch, Features) and the updated context with the input
+ features appended. The context has shape (Context Length, Batch, Features).
+ """
+ S = x.shape[0]
+
+ # Project input to feature space and add to context.
+ ft_x = self.feature_proj(x)
+ context = torch.cat((context, ft_x), dim=0)[-self.context_length :]
+
+ # Add positional embedding to context.
+ ft_pos = self.position_embedding(torch.arange(self.context_length, device=x.device)).unsqueeze(1)
+ x = context + ft_pos
+
+ # Compute output from Transformer blocks.
+ x = self.blocks(x)
+ out = self.head(x)[-S:]
+
+ return out, context
+
+ def forward(self, x: torch.Tensor, hidden_state: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+ """Computes Transformer output given the input and the hidden state which encapsulates the context."""
+ if hidden_state is None:
+ hidden_state = self.reset_hidden_state(x.shape[1], device=x.device)
+ context = torch.cat(hidden_state, dim=0)
+
+ out, context = self.step(x, context)
+
+ hidden_state = context[: self.num_layers], context[self.num_layers :]
+
+ return out, hidden_state
+
+ def reset_hidden_state(self, batch_size: int, device="cpu"):
+ hidden_state = torch.zeros((self.context_length, batch_size, self.hidden_size), device=device)
+ hidden_state = hidden_state[: self.num_layers], hidden_state[self.num_layers :]
+
+ return hidden_state
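
A sketch of driving the `Transformer` module like an RNN, one timestep at a time, using the (Sequence, Batch, Features) layout and the two-part hidden state produced by `reset_hidden_state`; the sizes are arbitrary.

```python
import torch

from rsl_rl.modules.transformer import Transformer

tf = Transformer(input_size=16, output_size=8, hidden_size=32, block_count=2, context_length=8, head_count=4)

batch_size = 4
hidden = tf.reset_hidden_state(batch_size)  # two zeroed halves of the rolling context

# Feed one timestep at a time, analogous to stepping an LSTM with (h, c).
for _ in range(5):
    x = torch.randn(1, batch_size, 16)  # (Sequence, Batch, Features)
    out, hidden = tf(x, hidden)

print(out.shape)  # torch.Size([1, 4, 8])
```
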
diff --git a/rsl_rl/modules/utils.py b/rsl_rl/modules/utils.py
new file mode 100644
index 0000000..48d2b30
--- /dev/null
+++ b/rsl_rl/modules/utils.py
@@ -0,0 +1,32 @@
+from torch import nn
+
+
+def get_activation(act_name):
+ if act_name == "elu":
+ return nn.ELU()
+ elif act_name == "selu":
+ return nn.SELU()
+ elif act_name == "relu":
+ return nn.ReLU()
+ elif act_name == "crelu":
+ return nn.ReLU()
+ elif act_name == "lrelu":
+ return nn.LeakyReLU()
+ elif act_name == "tanh":
+ return nn.Tanh()
+ elif act_name == "sigmoid":
+ return nn.Sigmoid()
+ elif act_name == "linear":
+ return nn.Identity()
+ elif act_name == "softmax":
+ return nn.Softmax()
+ else:
+ print("invalid activation function!")
+ return None
+
+
+def init_xavier_uniform(layer, activation):
+ try:
+ nn.init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain(activation))
+ except ValueError:
+ nn.init.xavier_uniform_(layer.weight)
diff --git a/rsl_rl/runners/__init__.py b/rsl_rl/runners/__init__.py
index cebf22b..0097668 100644
--- a/rsl_rl/runners/__init__.py
+++ b/rsl_rl/runners/__init__.py
@@ -3,6 +3,7 @@
"""Implementation of runners for environment-agent interaction."""
-from .on_policy_runner import OnPolicyRunner
+from .runner import Runner
+from .legacy_runner import LeggedGymRunner
-__all__ = ["OnPolicyRunner"]
+__all__ = ["LeggedGymRunner", "Runner"]
diff --git a/rsl_rl/runners/callbacks.py b/rsl_rl/runners/callbacks.py
new file mode 100644
index 0000000..df99d5a
--- /dev/null
+++ b/rsl_rl/runners/callbacks.py
@@ -0,0 +1,79 @@
+import os
+import random
+import string
+import wandb
+
+
+def make_save_model_cb(directory):
+ def cb(runner, stat):
+ path = os.path.join(directory, "model_{}.pt".format(stat["current_iteration"]))
+ runner.save(path)
+
+ return cb
+
+
+def make_save_model_onnx_cb(directory):
+ def cb(runner, stat):
+ path = os.path.join(directory, "model_{}.onnx".format(stat["current_iteration"]))
+ runner.export_onnx(path)
+
+ return cb
+
+
+def make_interval_cb(callback, interval):
+ def cb(runner, stat):
+ if stat["current_iteration"] % interval != 0:
+ return
+
+ callback(runner, stat)
+
+ return cb
+
+
+def make_final_cb(callback):
+ def cb(runner, stat):
+ if not runner._learning_should_terminate():
+ return
+
+ callback(runner, stat)
+
+ return cb
+
+
+def make_first_cb(callback):
+ uuid = "".join(random.choices(string.ascii_letters + string.digits, k=8))
+
+ def cb(runner, stat):
+ if hasattr(runner, f"_first_cb_{uuid}"):
+ return
+
+ setattr(runner, f"_first_cb_{uuid}", True)
+ callback(runner, stat)
+
+ return cb
+
+
+def make_wandb_cb(init_kwargs):
+ assert "project" in init_kwargs, "The project must be specified in the init_kwargs."
+
+ run = wandb.init(**init_kwargs)
+ check_complete = make_final_cb(lambda *_: run.finish())
+
+ def cb(runner, stat):
+ mean_reward = sum(stat["returns"]) / len(stat["returns"]) if len(stat["returns"]) > 0 else 0.0
+ mean_steps = sum(stat["lengths"]) / len(stat["lengths"]) if len(stat["lengths"]) > 0 else 0.0
+ total_steps = stat["current_iteration"] * runner.env.num_envs * runner._num_steps_per_env
+ training_time = stat["training_time"]
+
+ run.log(
+ {
+ "mean_rewards": mean_reward,
+ "mean_steps": mean_steps,
+ "training_steps": total_steps,
+ "training_time": training_time,
+ }
+ )
+
+ check_complete(runner, stat)
+
+ return cb
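
A sketch of composing the callback factories above, e.g. periodic checkpointing plus a one-off ONNX export at the end of training. The log directory is a placeholder, and importing this module requires wandb to be installed since it is imported at module level.

```python
from rsl_rl.runners.callbacks import (
    make_final_cb,
    make_interval_cb,
    make_save_model_cb,
    make_save_model_onnx_cb,
)

log_dir = "logs/example_run"  # placeholder path

learn_callbacks = [
    make_interval_cb(make_save_model_cb(log_dir), 50),  # save a .pt checkpoint every 50 iterations
    make_final_cb(make_save_model_onnx_cb(log_dir)),    # export the final policy to ONNX when training terminates
]

# The list is then passed to the runner, e.g. Runner(env, agent, learn_cb=learn_callbacks).
```
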
diff --git a/rsl_rl/runners/legacy_runner.py b/rsl_rl/runners/legacy_runner.py
new file mode 100644
index 0000000..4c98ab3
--- /dev/null
+++ b/rsl_rl/runners/legacy_runner.py
@@ -0,0 +1,136 @@
+import os
+
+from rsl_rl.algorithms import *
+from rsl_rl.env import VecEnv
+from rsl_rl.runners.callbacks import (
+ make_final_cb,
+ make_first_cb,
+ make_interval_cb,
+ make_save_model_onnx_cb,
+)
+from rsl_rl.runners.runner import Runner
+from rsl_rl.storage import *
+
+
+def make_legacy_save_model_cb(directory):
+ def cb(runner, stat):
+ data = {}
+
+ if hasattr(runner.env, "_persistent_data"):
+ data["env_data"] = runner.env._persistent_data
+
+ path = os.path.join(directory, "model_{}.pt".format(stat["current_iteration"]))
+ runner.save(path, data=data)
+
+ return cb
+
+
+class LeggedGymRunner(Runner):
+ """Runner for legged_gym environments."""
+
+ mappings = [
+ ("init_noise_std", "actor_noise_std"),
+ ("clip_param", "clip_ratio"),
+ ("desired_kl", "target_kl"),
+ ("entropy_coef", "entropy_coeff"),
+ ("lam", "gae_lambda"),
+ ("max_grad_norm", "gradient_clip"),
+ ("num_learning_epochs", None),
+ ("num_mini_batches", "batch_count"),
+ ("use_clipped_value_loss", None),
+ ("value_loss_coef", "value_coeff"),
+ ]
+
+ @staticmethod
+ def _hook_env(env: VecEnv):
+ old_step = env.step
+
+ def step_hook(*args, **kwargs):
+ result = old_step(*args, **kwargs)
+
+ if len(result) == 4:
+ obs, rewards, dones, env_info = result
+ elif len(result) == 5:
+ obs, _, rewards, dones, env_info = result
+ else:
+ raise ValueError("Invalid number of return values from env.step().")
+
+ return obs, rewards, dones.float(), env_info
+
+ env.step = step_hook
+
+ return env
+
+ def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
+ env = self._hook_env(env)
+ self.cfg = train_cfg["runner"]
+
+ alg_class = eval(self.cfg["algorithm_class_name"])
+ if "policy_class_name" in self.cfg:
+ print("WARNING: ignoring deprecated parameter 'runner.policy_class_name'.")
+
+ alg_cfg = train_cfg["algorithm"]
+ alg_cfg.update(train_cfg["policy"])
+
+ if "activation" in alg_cfg:
+ print(
+ "WARNING: using deprecated parameter 'activation'. Use 'actor_activations' and 'critic_activations' instead."
+ )
+ alg_cfg["actor_activations"] = [alg_cfg["activation"] for _ in range(len(alg_cfg["actor_hidden_dims"]))]
+ alg_cfg["actor_activations"] += ["linear"]
+ alg_cfg["critic_activations"] = [alg_cfg["activation"] for _ in range(len(alg_cfg["critic_hidden_dims"]))]
+ alg_cfg["critic_activations"] += ["linear"]
+ del alg_cfg["activation"]
+
+ for old, new in self.mappings:
+ if old not in alg_cfg:
+ continue
+
+ if new is None:
+ print(f"WARNING: ignoring deprecated parameter '{old}'.")
+ del alg_cfg[old]
+ continue
+
+ print(f"WARNING: using deprecated parameter '{old}'. Use '{new}' instead.")
+ alg_cfg[new] = alg_cfg[old]
+ del alg_cfg[old]
+
+ agent: Agent = alg_class(env, device=device, **train_cfg["algorithm"])
+
+ callbacks = []
+ evaluation_callbacks = []
+
+ evaluation_callbacks.append(lambda *args: Runner._log_progress(*args, prefix="eval"))
+
+ if log_dir and "save_interval" in self.cfg:
+ callbacks.append(make_first_cb(make_legacy_save_model_cb(log_dir)))
+ callbacks.append(make_interval_cb(make_legacy_save_model_cb(log_dir), self.cfg["save_interval"]))
+
+ if log_dir:
+ callbacks.append(Runner._log)
+ callbacks.append(make_final_cb(make_legacy_save_model_cb(log_dir)))
+ callbacks.append(make_final_cb(make_save_model_onnx_cb(log_dir)))
+ # callbacks.append(make_first_cb(lambda *_: store_code_state(log_dir, self._git_status_repos)))
+ else:
+ callbacks.append(Runner._log_progress)
+
+ super().__init__(
+ env,
+ agent,
+ learn_cb=callbacks,
+ evaluation_cb=evaluation_callbacks,
+ device=device,
+ num_steps_per_env=self.cfg["num_steps_per_env"],
+ )
+
+ self._iteration_time = 0.0
+
+ def learn(self, *args, num_learning_iterations=None, init_at_random_ep_len=None, **kwargs):
+ if num_learning_iterations is not None:
+ print("WARNING: using deprecated parameter 'num_learning_iterations'. Use 'iterations' instead.")
+ kwargs["iterations"] = num_learning_iterations
+
+ if init_at_random_ep_len is not None:
+ print("WARNING: ignoring deprecated parameter 'init_at_random_ep_len'.")
+
+ super().learn(*args, **kwargs)
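
A sketch of the legacy configuration translation performed by `LeggedGymRunner`, using a made-up minimal `train_cfg`; real legged_gym configs carry many more keys, and the environment is assumed to be an existing legged_gym-style `VecEnv` instance, so the runner call is only indicated in comments.

```python
from rsl_rl.runners import LeggedGymRunner

# Made-up, minimal legacy-style configuration for illustration only.
train_cfg = {
    "runner": {
        "algorithm_class_name": "PPO",
        "num_steps_per_env": 24,
        "save_interval": 50,
    },
    "algorithm": {
        "clip_param": 0.2,      # remapped to clip_ratio (with a deprecation warning)
        "desired_kl": 0.01,     # remapped to target_kl
        "entropy_coef": 0.005,  # remapped to entropy_coeff
    },
    "policy": {
        "activation": "elu",    # expanded into actor_activations / critic_activations
        "actor_hidden_dims": [256, 128],
        "critic_hidden_dims": [256, 128],
    },
}

# env = ...  # an existing legged_gym-style VecEnv (not constructed here)
# runner = LeggedGymRunner(env, train_cfg, log_dir="logs/example_run", device="cuda:0")
# runner.learn(iterations=1500)
```
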
diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py
deleted file mode 100644
index 9e0a459..0000000
--- a/rsl_rl/runners/on_policy_runner.py
+++ /dev/null
@@ -1,304 +0,0 @@
-# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
-# SPDX-License-Identifier: BSD-3-Clause
-
-from __future__ import annotations
-
-import os
-import statistics
-import time
-import torch
-from collections import deque
-from torch.utils.tensorboard import SummaryWriter as TensorboardSummaryWriter
-
-import rsl_rl
-from rsl_rl.algorithms import PPO
-from rsl_rl.env import VecEnv
-from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, EmpiricalNormalization
-from rsl_rl.utils import store_code_state
-
-
-class OnPolicyRunner:
- """On-policy runner for training and evaluation."""
-
- def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
- self.cfg = train_cfg
- self.alg_cfg = train_cfg["algorithm"]
- self.policy_cfg = train_cfg["policy"]
- self.device = device
- self.env = env
- obs, extras = self.env.get_observations()
- num_obs = obs.shape[1]
- if "critic" in extras["observations"]:
- num_critic_obs = extras["observations"]["critic"].shape[1]
- else:
- num_critic_obs = num_obs
- actor_critic_class = eval(self.policy_cfg.pop("class_name")) # ActorCritic
- actor_critic: ActorCritic | ActorCriticRecurrent = actor_critic_class(
- num_obs, num_critic_obs, self.env.num_actions, **self.policy_cfg
- ).to(self.device)
- alg_class = eval(self.alg_cfg.pop("class_name")) # PPO
- self.alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg)
- self.num_steps_per_env = self.cfg["num_steps_per_env"]
- self.save_interval = self.cfg["save_interval"]
- self.empirical_normalization = self.cfg["empirical_normalization"]
- if self.empirical_normalization:
- self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device)
- self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device)
- else:
- self.obs_normalizer = torch.nn.Identity() # no normalization
- self.critic_obs_normalizer = torch.nn.Identity() # no normalization
- # init storage and model
- self.alg.init_storage(
- self.env.num_envs,
- self.num_steps_per_env,
- [num_obs],
- [num_critic_obs],
- [self.env.num_actions],
- )
-
- # Log
- self.log_dir = log_dir
- self.writer = None
- self.tot_timesteps = 0
- self.tot_time = 0
- self.current_learning_iteration = 0
- self.git_status_repos = [rsl_rl.__file__]
-
- def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False):
- # initialize writer
- if self.log_dir is not None and self.writer is None:
- # Launch either Tensorboard or Neptune & Tensorboard summary writer(s), default: Tensorboard.
- self.logger_type = self.cfg.get("logger", "tensorboard")
- self.logger_type = self.logger_type.lower()
-
- if self.logger_type == "neptune":
- from rsl_rl.utils.neptune_utils import NeptuneSummaryWriter
-
- self.writer = NeptuneSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg)
- self.writer.log_config(self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg)
- elif self.logger_type == "wandb":
- from rsl_rl.utils.wandb_utils import WandbSummaryWriter
-
- self.writer = WandbSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg)
- self.writer.log_config(self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg)
- elif self.logger_type == "tensorboard":
- self.writer = TensorboardSummaryWriter(log_dir=self.log_dir, flush_secs=10)
- else:
- raise AssertionError("logger type not found")
-
- if init_at_random_ep_len:
- self.env.episode_length_buf = torch.randint_like(
- self.env.episode_length_buf, high=int(self.env.max_episode_length)
- )
- obs, extras = self.env.get_observations()
- critic_obs = extras["observations"].get("critic", obs)
- obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
- self.train_mode() # switch to train mode (for dropout for example)
-
- ep_infos = []
- rewbuffer = deque(maxlen=100)
- lenbuffer = deque(maxlen=100)
- cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
- cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
-
- start_iter = self.current_learning_iteration
- tot_iter = start_iter + num_learning_iterations
- for it in range(start_iter, tot_iter):
- start = time.time()
- # Rollout
- with torch.inference_mode():
- for i in range(self.num_steps_per_env):
- actions = self.alg.act(obs, critic_obs)
- obs, rewards, dones, infos = self.env.step(actions)
- obs = self.obs_normalizer(obs)
- if "critic" in infos["observations"]:
- critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
- else:
- critic_obs = obs
- obs, critic_obs, rewards, dones = (
- obs.to(self.device),
- critic_obs.to(self.device),
- rewards.to(self.device),
- dones.to(self.device),
- )
- self.alg.process_env_step(rewards, dones, infos)
-
- if self.log_dir is not None:
- # Book keeping
- # note: we changed logging to use "log" instead of "episode" to avoid confusion with
- # different types of logging data (rewards, curriculum, etc.)
- if "episode" in infos:
- ep_infos.append(infos["episode"])
- elif "log" in infos:
- ep_infos.append(infos["log"])
- cur_reward_sum += rewards
- cur_episode_length += 1
- new_ids = (dones > 0).nonzero(as_tuple=False)
- rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
- lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
- cur_reward_sum[new_ids] = 0
- cur_episode_length[new_ids] = 0
-
- stop = time.time()
- collection_time = stop - start
-
- # Learning step
- start = stop
- self.alg.compute_returns(critic_obs)
-
- mean_value_loss, mean_surrogate_loss = self.alg.update()
- stop = time.time()
- learn_time = stop - start
- self.current_learning_iteration = it
- if self.log_dir is not None:
- self.log(locals())
- if it % self.save_interval == 0:
- self.save(os.path.join(self.log_dir, f"model_{it}.pt"))
- ep_infos.clear()
- if it == start_iter:
- # obtain all the diff files
- git_file_paths = store_code_state(self.log_dir, self.git_status_repos)
- # if possible store them to wandb
- if self.logger_type in ["wandb", "neptune"] and git_file_paths:
- for path in git_file_paths:
- self.writer.save_file(path)
-
- self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt"))
-
- def log(self, locs: dict, width: int = 80, pad: int = 35):
- self.tot_timesteps += self.num_steps_per_env * self.env.num_envs
- self.tot_time += locs["collection_time"] + locs["learn_time"]
- iteration_time = locs["collection_time"] + locs["learn_time"]
-
- ep_string = ""
- if locs["ep_infos"]:
- for key in locs["ep_infos"][0]:
- infotensor = torch.tensor([], device=self.device)
- for ep_info in locs["ep_infos"]:
- # handle scalar and zero dimensional tensor infos
- if key not in ep_info:
- continue
- if not isinstance(ep_info[key], torch.Tensor):
- ep_info[key] = torch.Tensor([ep_info[key]])
- if len(ep_info[key].shape) == 0:
- ep_info[key] = ep_info[key].unsqueeze(0)
- infotensor = torch.cat((infotensor, ep_info[key].to(self.device)))
- value = torch.mean(infotensor)
- # log to logger and terminal
- if "/" in key:
- self.writer.add_scalar(key, value, locs["it"])
- ep_string += f"""{f'{key}:':>{pad}} {value:.4f}\n"""
- else:
- self.writer.add_scalar("Episode/" + key, value, locs["it"])
- ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n"""
- mean_std = self.alg.actor_critic.std.mean()
- fps = int(self.num_steps_per_env * self.env.num_envs / (locs["collection_time"] + locs["learn_time"]))
-
- self.writer.add_scalar("Loss/value_function", locs["mean_value_loss"], locs["it"])
- self.writer.add_scalar("Loss/surrogate", locs["mean_surrogate_loss"], locs["it"])
- self.writer.add_scalar("Loss/learning_rate", self.alg.learning_rate, locs["it"])
- self.writer.add_scalar("Policy/mean_noise_std", mean_std.item(), locs["it"])
- self.writer.add_scalar("Perf/total_fps", fps, locs["it"])
- self.writer.add_scalar("Perf/collection time", locs["collection_time"], locs["it"])
- self.writer.add_scalar("Perf/learning_time", locs["learn_time"], locs["it"])
- if len(locs["rewbuffer"]) > 0:
- self.writer.add_scalar("Train/mean_reward", statistics.mean(locs["rewbuffer"]), locs["it"])
- self.writer.add_scalar("Train/mean_episode_length", statistics.mean(locs["lenbuffer"]), locs["it"])
- if self.logger_type != "wandb": # wandb does not support non-integer x-axis logging
- self.writer.add_scalar("Train/mean_reward/time", statistics.mean(locs["rewbuffer"]), self.tot_time)
- self.writer.add_scalar(
- "Train/mean_episode_length/time", statistics.mean(locs["lenbuffer"]), self.tot_time
- )
-
- str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m "
-
- if len(locs["rewbuffer"]) > 0:
- log_string = (
- f"""{'#' * width}\n"""
- f"""{str.center(width, ' ')}\n\n"""
- f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
- 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
- f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
- f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
- f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
- f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
- f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n"""
- )
- # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
- # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""")
- else:
- log_string = (
- f"""{'#' * width}\n"""
- f"""{str.center(width, ' ')}\n\n"""
- f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
- 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
- f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
- f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
- f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
- )
- # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
- # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""")
-
- log_string += ep_string
- log_string += (
- f"""{'-' * width}\n"""
- f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n"""
- f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n"""
- f"""{'Total time:':>{pad}} {self.tot_time:.2f}s\n"""
- f"""{'ETA:':>{pad}} {self.tot_time / (locs['it'] + 1) * (
- locs['num_learning_iterations'] - locs['it']):.1f}s\n"""
- )
- print(log_string)
-
- def save(self, path, infos=None):
- saved_dict = {
- "model_state_dict": self.alg.actor_critic.state_dict(),
- "optimizer_state_dict": self.alg.optimizer.state_dict(),
- "iter": self.current_learning_iteration,
- "infos": infos,
- }
- if self.empirical_normalization:
- saved_dict["obs_norm_state_dict"] = self.obs_normalizer.state_dict()
- saved_dict["critic_obs_norm_state_dict"] = self.critic_obs_normalizer.state_dict()
- torch.save(saved_dict, path)
-
- # Upload model to external logging service
- if self.logger_type in ["neptune", "wandb"]:
- self.writer.save_model(path, self.current_learning_iteration)
-
- def load(self, path, load_optimizer=True):
- loaded_dict = torch.load(path)
- self.alg.actor_critic.load_state_dict(loaded_dict["model_state_dict"])
- if self.empirical_normalization:
- self.obs_normalizer.load_state_dict(loaded_dict["obs_norm_state_dict"])
- self.critic_obs_normalizer.load_state_dict(loaded_dict["critic_obs_norm_state_dict"])
- if load_optimizer:
- self.alg.optimizer.load_state_dict(loaded_dict["optimizer_state_dict"])
- self.current_learning_iteration = loaded_dict["iter"]
- return loaded_dict["infos"]
-
- def get_inference_policy(self, device=None):
- self.eval_mode() # switch to evaluation mode (dropout for example)
- if device is not None:
- self.alg.actor_critic.to(device)
- policy = self.alg.actor_critic.act_inference
- if self.cfg["empirical_normalization"]:
- if device is not None:
- self.obs_normalizer.to(device)
- policy = lambda x: self.alg.actor_critic.act_inference(self.obs_normalizer(x)) # noqa: E731
- return policy
-
- def train_mode(self):
- self.alg.actor_critic.train()
- if self.empirical_normalization:
- self.obs_normalizer.train()
- self.critic_obs_normalizer.train()
-
- def eval_mode(self):
- self.alg.actor_critic.eval()
- if self.empirical_normalization:
- self.obs_normalizer.eval()
- self.critic_obs_normalizer.eval()
-
- def add_git_repo_to_log(self, repo_file_path):
- self.git_status_repos.append(repo_file_path)
diff --git a/rsl_rl/runners/runner.py b/rsl_rl/runners/runner.py
new file mode 100644
index 0000000..b461a88
--- /dev/null
+++ b/rsl_rl/runners/runner.py
@@ -0,0 +1,498 @@
+from __future__ import annotations
+import copy
+from datetime import timedelta
+import numpy as np
+import os
+import time
+import torch
+from typing import Any, Callable, Dict, List, Tuple, TypedDict, Union
+
+import rsl_rl
+from rsl_rl.storage.storage import Dataset
+from rsl_rl.algorithms import Agent
+from rsl_rl.env import VecEnv
+
+
+class EpisodeStatistics(TypedDict):
+ """The statistics of an episode."""
+
+ # Time it took to collect samples for the current iteration.
+ collection_time: Union[int, None]
+ # The counter of the current iteration.
+ current_iteration: int
+ # The number of the final iteration of the current run.
+ final_iteration: int
+ # The number of the first iteration of the current run.
+ first_iteration: int
+ # Environment information about the current iteration.
+ info: list
+ # The lengths of the episodes.
+ lengths: Union[List[int], None]
+ # The loss of the current iteration.
+ loss: Union[dict, None]
+ # The returns of the episodes.
+ returns: Union[List[float], None]
+ # The total time it took to run the current iteration.
+ total_time: Union[int, None]
+ # The time it took to update the agent.
+ update_time: Union[int, None]
+
+
+Callback = Callable[["Runner", EpisodeStatistics], Union[bool, None]]
+
+
+class Runner:
+ """The runner class for running an agent in an environment.
+
+ The runner collects data from the environment, updates the agent, and evaluates it. It also accepts callbacks
+ that can be used to log and visualize the training progress.
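+
+ Minimal usage sketch (mirroring the unit tests in tests/test_algorithms.py; the environment and algorithm are
+ examples, any VecEnv and Agent combination works):
+ env = GymEnv(name="LunarLanderContinuous-v2", device="cpu", environment_count=4)
+ agent = PPO(env, device="cpu")
+ runner = Runner(env, agent, device="cpu", num_steps_per_env=6)
+ runner.learn(10)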
+ """
+
+ _dataset: Dataset
+ _episode_statistics: EpisodeStatistics
+ _num_steps_per_env: int
+
+ def __init__(
+ self,
+ environment: VecEnv,
+ agent: Agent,
+ device: str = "cpu",
+ evaluation_cb: List[Callback] = None,
+ learn_cb: List[Callback] = None,
+ observation_history_length: int = 1,
+ **kwargs,
+ ) -> None:
+ """
+ Args:
+ environment (rsl_rl.env.VecEnv): The environment to run the agent in.
+ agent (rsl_rl.algorithms.agent): The RL agent to run.
+ device (str): The device to run on.
+ evaluation_cb (List[Callable[[dict], None]], optional): A list of callbacks that are called after each round
+ of evaluation.
+ learn_cb (List[Callable[[dict], None]], optional): A list of callbacks that are called after each round of
+ learning.
+ observation_history_length (int, optional): The number of observations to concatenate into a single
+ observation. Defaults to 1.
+ """
+ self.env = environment
+ self.agent = agent
+ self.device = device
+ self._obs_hist_len = observation_history_length
+ self._learn_cb = learn_cb if learn_cb else []
+ self._eval_cb = evaluation_cb if evaluation_cb else []
+
+ self._set_kwarg(kwargs, "num_steps_per_env", default=1)
+
+ self._current_learning_iteration = 0
+ self._git_status_repos = [rsl_rl.__file__]
+
+ self.to(self.device)
+
+ self._stored_dataset = [] # For computing observation history over multiple steps.
+
+ def add_git_repo_to_log(self, repo_file_path):
+ self._git_status_repos.append(repo_file_path)
+
+ def eval_mode(self):
+ """Sets the agent to evaluation mode."""
+ self.agent.eval_mode()
+
+ def evaluate(self, steps: int, return_epochs: int = 100) -> float:
+ """Evaluates the agent for a number of steps.
+
+ Args:
+ steps (int): The number of steps to evaluate the agent for.
+ return_epochs (int): The number of epochs over which to aggregate the return. Defaults to 100.
+ Returns:
+ The mean return of the agent.
+ """
+ cumulative_rewards = []
+ current_cumulative_rewards = torch.zeros(self.env.num_envs, dtype=torch.float)
+ current_episode_lengths = torch.zeros(self.env.num_envs, dtype=torch.int)
+ episode_lengths = []
+
+ self.eval_mode()
+
+ policy = self.get_inference_policy()
+ obs, env_info = self.env.get_observations()
+
+ with torch.inference_mode():
+ for step in range(steps):
+ actions = policy(obs.clone(), copy.deepcopy(env_info))
+ obs, rewards, dones, env_info, episode_statistics = self.evaluate_step(obs, env_info, actions)
+
+ dones_idx = dones.nonzero().cpu()
+ current_cumulative_rewards += rewards.clone().cpu()
+ current_episode_lengths += 1
+ cumulative_rewards.extend(current_cumulative_rewards[dones_idx].squeeze(1).cpu().tolist())
+ episode_lengths.extend(current_episode_lengths[dones_idx].squeeze(1).cpu().tolist())
+ current_cumulative_rewards[dones_idx] = 0.0
+ current_episode_lengths[dones_idx] = 0
+
+ episode_statistics["current_iteration"] = step
+ episode_statistics["final_iteration"] = steps
+ episode_statistics["lengths"] = episode_lengths[-return_epochs:]
+ episode_statistics["returns"] = cumulative_rewards[-return_epochs:]
+
+ for cb in self._eval_cb:
+ cb(self, episode_statistics)
+
+ cumulative_rewards.extend(current_cumulative_rewards.cpu().tolist())
+ mean_return = np.mean(cumulative_rewards)
+
+ return mean_return
+
+ def evaluate_step(
+ self, observations=None, environment_info=None, actions=None
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict, Dict]:
+ """Evaluates the agent for a single step.
+
+ Args:
+ observations (torch.Tensor): The observations to evaluate the agent for.
+ environment_info (Dict[str, Any]): The environment information for the observations.
+ actions (torch.Tensor): The actions to evaluate the agent for.
+ Returns:
+ A tuple containing the observations, rewards, dones, environment information, and episode statistics after
+ the evaluation step.
+ """
+ episode_statistics = {
+ "current_actions": None,
+ "current_dones": None,
+ "current_iteration": 0,
+ "current_observations": None,
+ "current_rewards": None,
+ "final_iteration": 0,
+ "first_iteration": 0,
+ "info": [],
+ "lengths": [],
+ "returns": [],
+ "timeout": None,
+ "total_time": None,
+ }
+
+ self.eval_mode()
+
+ with torch.inference_mode():
+ obs, env_info = self.env.get_observations() if observations is None else (observations, environment_info)
+
+ with torch.inference_mode():
+ start = time.time()
+
+ actions = self.get_inference_policy()(obs.clone(), copy.deepcopy(env_info)) if actions is None else actions
+ obs, rewards, dones, env_info = self.env.step(actions.clone())
+
+ self.agent.register_terminations(dones.nonzero().reshape(-1))
+
+ end = time.time()
+
+ if "episode" in env_info:
+ episode_statistics["info"].append(env_info["episode"])
+ episode_statistics["current_actions"] = actions
+ episode_statistics["current_dones"] = dones
+ episode_statistics["current_observations"] = obs
+ episode_statistics["current_rewards"] = rewards
+ episode_statistics["total_time"] = end - start
+
+ return obs, rewards, dones, env_info, episode_statistics
+
+ def get_inference_policy(self, device=None):
+ self.eval_mode()
+
+ return self.agent.get_inference_policy(device)
+
+ def learn(
+ self, iterations: Union[int, None] = None, timeout: Union[int, None] = None, return_epochs: int = 100
+ ) -> None:
+ """Runs a number of learning iterations.
+
+ Args:
+ iterations (int): The number of iterations to run.
+ timeout (int): Optional number of seconds after which to terminate training. Defaults to None.
+ return_epochs (int): The number of epochs over which to aggregate the return. Defaults to 100.
+ """
+ assert iterations is not None or timeout is not None
+
+ self._episode_statistics = {
+ "collection_time": None,
+ "current_actions": None,
+ "current_iteration": self._current_learning_iteration,
+ "current_observations": None,
+ "final_iteration": self._current_learning_iteration + iterations if iterations is not None else None,
+ "first_iteration": self._current_learning_iteration,
+ "info": [],
+ "lengths": [],
+ "loss": {},
+ "returns": [],
+ "storage_initialized": False,
+ "timeout": timeout,
+ "total_time": None,
+ "training_time": 0,
+ "update_time": None,
+ }
+ self._current_episode_lengths = torch.zeros(self.env.num_envs, dtype=torch.float)
+ self._current_cumulative_rewards = torch.zeros(self.env.num_envs, dtype=torch.float)
+
+ self.train_mode()
+
+ self._obs, self._env_info = self.env.get_observations()
+ while True:
+ if self._learning_should_terminate():
+ break
+
+ # Collect data
+ start = time.time()
+
+ with torch.inference_mode():
+ self._dataset = []
+
+ for _ in range(self._num_steps_per_env):
+ self._collect()
+
+ self._episode_statistics["lengths"] = self._episode_statistics["lengths"][-return_epochs:]
+ self._episode_statistics["returns"] = self._episode_statistics["returns"][-return_epochs:]
+
+ self._episode_statistics["collection_time"] = time.time() - start
+
+ # Update agent
+
+ start = time.time()
+
+ self._update()
+
+ self._episode_statistics["update_time"] = time.time() - start
+
+ # Housekeeping
+
+ self._episode_statistics["total_time"] = (
+ self._episode_statistics["collection_time"] + self._episode_statistics["update_time"]
+ )
+ self._episode_statistics["training_time"] += self._episode_statistics["total_time"]
+
+ if self.agent.initialized:
+ self._episode_statistics["current_iteration"] += 1
+
+ terminate = False
+ for cb in self._learn_cb:
+ terminate = (cb(self, self._episode_statistics) is False) or terminate
+
+ if terminate:
+ break
+
+ self._episode_statistics["info"].clear()
+ self._current_learning_iteration = self._episode_statistics["current_iteration"]
+
+ def _collect(self) -> None:
+ """Runs a single step in the environment to collect a transition and stores it in the dataset.
+
+ If the agent is not initialized, random actions are drawn from the action space. The method also gathers
+ statistics about the episode and stores them in the runner's episode statistics dictionary.
+ """
+ if self.agent.initialized:
+ actions, data = self.agent.draw_actions(self._obs, self._env_info)
+ else:
+ actions, data = self.agent.draw_random_actions(self._obs, self._env_info)
+
+ next_obs, rewards, dones, next_env_info = self.env.step(actions)
+
+ self._dataset.append(
+ self.agent.process_transition(
+ self._obs.clone(),
+ copy.deepcopy(self._env_info),
+ actions.clone(),
+ rewards.clone(),
+ next_obs.clone(),
+ copy.deepcopy(next_env_info),
+ dones.clone(),
+ copy.deepcopy(data),
+ )
+ )
+
+ self.agent.register_terminations(dones.nonzero().reshape(-1))
+
+ self._obs, self._env_info = next_obs, next_env_info
+
+ # Gather statistics
+ if "episode" in self._env_info:
+ self._episode_statistics["info"].append(self._env_info["episode"])
+ dones_idx = (dones + next_env_info["time_outs"]).nonzero().cpu()
+ self._current_episode_lengths += 1
+ self._current_cumulative_rewards += rewards.cpu()
+
+ completed_lengths = self._current_episode_lengths[dones_idx][:, 0].cpu()
+ completed_returns = self._current_cumulative_rewards[dones_idx][:, 0].cpu()
+ self._episode_statistics["lengths"].extend(completed_lengths.tolist())
+ self._episode_statistics["returns"].extend(completed_returns.tolist())
+ self._current_episode_lengths[dones_idx] = 0.0
+ self._current_cumulative_rewards[dones_idx] = 0.0
+
+ self._episode_statistics["current_actions"] = actions
+ self._episode_statistics["current_observations"] = self._obs
+ self._episode_statistics["sample_count"] = self.agent.storage.sample_count
+
+ def _learning_should_terminate(self):
+ """Checks whether the learning should terminate.
+
+ Termination is triggered if the number of iterations or the timeout is reached.
+
+ Returns:
+ Whether the learning should terminate.
+ """
+ if (
+ self._episode_statistics["final_iteration"] is not None
+ and self._episode_statistics["current_iteration"] >= self._episode_statistics["final_iteration"]
+ ):
+ return True
+
+ if (
+ self._episode_statistics["timeout"] is not None
+ and self._episode_statistics["training_time"] >= self._episode_statistics["timeout"]
+ ):
+ return True
+
+ return False
+
+ def _update(self) -> None:
+ """Updates the agent using the collected data."""
+ loss = self.agent.update(self._dataset)
+ self._dataset = []
+
+ if not self.agent.initialized:
+ return
+
+ self._episode_statistics["loss"] = loss
+ self._episode_statistics["storage_initialized"] = True
+
+ def load(self, path: str) -> Any:
+ """Restores the agent and runner state from a file."""
+ content = torch.load(path)
+
+ assert "agent" in content
+ assert "data" in content
+ assert "iteration" in content
+
+ self.agent.load_state_dict(content["agent"])
+ self._current_learning_iteration = content["iteration"]
+
+ return content["data"]
+
+ def save(self, path: str, data: Any = None) -> None:
+ """Saves the agent and runner state to a file."""
+ content = {
+ "agent": self.agent.state_dict(),
+ "data": data,
+ "iteration": self._current_learning_iteration,
+ }
+
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ torch.save(content, path)
+
+ def export_onnx(self, path: str) -> None:
+ """Exports the agent's policy network to ONNX format."""
+ model, args, kwargs = self.agent.export_onnx()
+
+ os.makedirs(os.path.dirname(path), exist_ok=True)
+ torch.onnx.export(model, args, path, **kwargs)
+
+ def to(self, device) -> Runner:
+ """Sets the device of the runner and its components."""
+ self.device = device
+
+ self.agent.to(device)
+
+ try:
+ self.env.to(device)
+ except AttributeError:
+ pass
+
+ return self
+
+ def train_mode(self):
+ """Sets the agent to training mode."""
+ self.agent.train_mode()
+
+ def _set_kwarg(self, args, key, default=None, private=True):
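+ # E.g., _set_kwarg(kwargs, "num_steps_per_env", default=1) sets self._num_steps_per_env (see __init__).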
+ setattr(self, f"_{key}" if private else key, args[key] if key in args else default)
+
+ def _log_progress(self, stat, clear_line=True, prefix=""):
+ """Logs the progress of the runner."""
+ if not hasattr(self, "_iteration_times"):
+ self._iteration_times = []
+
+ self._iteration_times = (self._iteration_times + [stat["total_time"]])[-100:]
+ average_total_time = np.mean(self._iteration_times)
+
+ if stat["final_iteration"] is not None:
+ first_iteration = stat["first_iteration"]
+ final_iteration = stat["final_iteration"]
+ current_iteration = stat["current_iteration"]
+ final_run_iteration = final_iteration - first_iteration
+ remaining_iterations = final_iteration - current_iteration
+
+ remaining_iteration_time = remaining_iterations * average_total_time
+ iteration_completion_percentage = 100 * (current_iteration - first_iteration) / final_run_iteration
+ else:
+ remaining_iteration_time = np.inf
+ iteration_completion_percentage = 0
+
+ if stat["timeout"] is not None:
+ training_time = stat["training_time"]
+ timeout = stat["timeout"]
+
+ remaining_timeout_time = stat["timeout"] - stat["training_time"]
+ timeout_completion_percentage = 100 * stat["training_time"] / stat["timeout"]
+ else:
+ remaining_timeout_time = np.inf
+ timeout_completion_percentage = 0
+
+ if remaining_iteration_time > remaining_timeout_time:
+ completion_percentage = timeout_completion_percentage
+ remaining_time = remaining_timeout_time
+ step_string = f"({int(training_time)}s / {timeout}s)"
+ else:
+ completion_percentage = iteration_completion_percentage
+ remaining_time = remaining_iteration_time
+ step_string = f"({current_iteration} / {final_iteration})"
+
+ prefix = f"[{prefix}] " if prefix else ""
+ progress = "".join(["#" if i <= int(completion_percentage) else "_" for i in range(10, 101, 5)])
+ remaining_time_string = str(timedelta(seconds=int(np.ceil(remaining_time))))
+ print(
+ f"{prefix}{progress} {step_string} [{completion_percentage:.1f}%, {1/average_total_time:.2f}it/s, {remaining_time_string} ETA]",
+ end="\r" if clear_line else "\n",
+ )
+
+ def _log(self, stat, prefix=""):
+ """Logs the progress and statistics of the runner."""
+ current_iteration = stat["current_iteration"]
+
+ collection_time = stat["collection_time"]
+ update_time = stat["update_time"]
+ total_time = stat["total_time"]
+ collection_percentage = 100 * collection_time / total_time
+ update_percentage = 100 * update_time / total_time
+
+ if prefix == "":
+ prefix = "learn" if stat["storage_initialized"] else "init"
+ self._log_progress(stat, clear_line=False, prefix=prefix)
+ print(
+ f"iteration time:\t{total_time:.4f}s (collection: {collection_time:.2f}s [{collection_percentage:.1f}%], update: {update_time:.2f}s [{update_percentage:.1f}%])"
+ )
+
+ mean_reward = sum(stat["returns"]) / len(stat["returns"]) if len(stat["returns"]) > 0 else 0.0
+ mean_steps = sum(stat["lengths"]) / len(stat["lengths"]) if len(stat["lengths"]) > 0 else 0.0
+ total_steps = current_iteration * self.env.num_envs * self._num_steps_per_env
+ sample_count = stat["sample_count"]
+ print(f"avg. reward:\t{mean_reward:.4f}")
+ print(f"avg. steps:\t{mean_steps:.4f}")
+ print(f"stored samples:\t{sample_count}")
+ print(f"total steps:\t{total_steps}")
+
+ for key, value in stat["loss"].items():
+ print(f"{key} loss:\t{value:.4f}")
+
+ for key, value in self.agent._bm_report().items():
+ mean, count = value
+ print(f"BM {key}:\t{mean/1000000.0:.4f}ms ({count} calls, total {mean*count/1000000.0:.4f}ms)")
+
+ self.agent._bm_flush()
diff --git a/rsl_rl/storage/__init__.py b/rsl_rl/storage/__init__.py
index 91032a0..e22f507 100644
--- a/rsl_rl/storage/__init__.py
+++ b/rsl_rl/storage/__init__.py
@@ -4,5 +4,6 @@
"""Implementation of transitions storage for RL-agent."""
from .rollout_storage import RolloutStorage
+from .replay_storage import ReplayStorage
-__all__ = ["RolloutStorage"]
+__all__ = ["RolloutStorage", "ReplayStorage"]
diff --git a/rsl_rl/storage/replay_storage.py b/rsl_rl/storage/replay_storage.py
new file mode 100644
index 0000000..955f31f
--- /dev/null
+++ b/rsl_rl/storage/replay_storage.py
@@ -0,0 +1,147 @@
+import torch
+from typing import Callable, Dict, Generator, List, Optional, Tuple
+
+from rsl_rl.storage.storage import Dataset, Storage, Transition
+
+
+class ReplayStorage(Storage):
+ def __init__(self, environment_count: int, max_size: int, device: str = "cpu", initial_size: int = 0) -> None:
+ self._env_count = environment_count
+ self.initial_size = initial_size // environment_count
+ self.max_size = max_size
+ self.device = device
+
+ self._register_serializable("max_size", "initial_size")
+
+ self._idx = 0
+ self._full = False
+ self._initialized = initial_size == 0
+ self._data = {}
+
+ self._processors: Dict[str, List[Tuple[Callable, Callable]]] = {}
+
+ @property
+ def max_size(self):
+ return self._size * self._env_count
+
+ @max_size.setter
+ def max_size(self, value):
+ self._size = value // self._env_count
+
+ assert self.initial_size <= self._size
+
+ def _add_item(self, name: str, value: torch.Tensor) -> None:
+ """Adds a transition item to the storage.
+
+ Args:
+ name (str): The name of the item.
+ value (torch.Tensor): The value of the item.
+ """
+ value = self._process(name, value.clone().to(self.device))
+
+ if name not in self._data:
+ if self._full or self._idx != 0:
+ raise ValueError(f'Tried to store invalid transition data for "{name}".')
+ self._data[name] = torch.empty(
+ self._size * self._env_count, *value.shape[1:], device=self.device, dtype=value.dtype
+ )
+
+ start_idx = self._idx * self._env_count
+ end_idx = (self._idx + 1) * self._env_count
+ self._data[name][start_idx:end_idx] = value
+
+ def _process(self, name: str, value: torch.Tensor) -> torch.Tensor:
+ if name not in self._processors:
+ return value
+
+ for process, _ in self._processors[name]:
+ if process is None:
+ continue
+
+ value = process(value)
+
+ return value
+
+ def _process_undo(self, name: str, value: torch.Tensor) -> torch.Tensor:
+ if name not in self._processors:
+ return value
+
+ for _, undo in reversed(self._processors[name]):
+ if undo is None:
+ continue
+
+ value = undo(value)
+
+ return value
+
+ def append(self, dataset: Dataset) -> None:
+ """Appends a dataset of transitions to the storage.
+
+ Args:
+ dataset (Dataset): The dataset of transitions.
+ """
+ for transition in dataset:
+ for name, value in transition.items():
+ self._add_item(name, value)
+
+ self._idx += 1
+
+ if self._idx >= self.initial_size:
+ self._initialized = True
+
+ if self._idx == self._size:
+ self._full = True
+ self._idx = 0
+
+ def batch_generator(self, batch_size: int, batch_count: int) -> Generator[Transition, None, None]:
+ """Returns a generator that yields batches of transitions.
+
+ Args:
+ batch_size (int): The size of the batches.
+ batch_count (int): The number of batches to yield.
+ Returns:
+ A generator that yields batches of transitions.
+ """
+ assert self._full or self._idx > 0
+
+ if not self._initialized:
+ return
+
+ max_idx = self._env_count * (self._size if self._full else self._idx)
+
+ for _ in range(batch_count):
+ batch_idx = torch.randint(high=max_idx, size=(batch_size,))
+
+ batch = {}
+ for key, value in self._data.items():
+ batch[key] = self._process_undo(key, value[batch_idx].clone())
+
+ yield batch
+
+ def register_processor(self, key: str, process: Callable, undo: Optional[Callable] = None) -> None:
+ """Registers a processor for a transition item.
+
+ The processor is called before the item is stored in the storage. The undo function is called when the item is
+ retrieved from the storage. Undo functions are applied in reverse registration order so that chained processors
+ are inverted correctly.
+
+ Args:
+ key (str): The name of the transition item.
+ process (Callable): The function to process the item.
+ undo (Optional[Callable], optional): The function to undo the processing. Defaults to None.
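+
+ Example (illustrative only; the scaling factor is arbitrary):
+ storage.register_processor("rewards", lambda x: x / 10.0, lambda x: x * 10.0)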
+ """
+ if key not in self._processors:
+ self._processors[key] = []
+
+ self._processors[key].append((process, undo))
+
+ @property
+ def initialized(self) -> bool:
+ return self._initialized
+
+ @property
+ def sample_count(self) -> int:
+ """Returns the number of individual transitions stored in the storage."""
+ transition_count = self._size * self._env_count if self._full else self._idx * self._env_count
+
+ return transition_count
diff --git a/rsl_rl/storage/rollout_storage.py b/rsl_rl/storage/rollout_storage.py
index 7aec039..34871a8 100644
--- a/rsl_rl/storage/rollout_storage.py
+++ b/rsl_rl/storage/rollout_storage.py
@@ -1,228 +1,71 @@
-# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
-# SPDX-License-Identifier: BSD-3-Clause
-
-
-from __future__ import annotations
-
import torch
+from typing import Generator
-from rsl_rl.utils import split_and_pad_trajectories
+from rsl_rl.storage.replay_storage import ReplayStorage
+from rsl_rl.storage.storage import Dataset, Transition
-class RolloutStorage:
- class Transition:
- def __init__(self):
- self.observations = None
- self.critic_observations = None
- self.actions = None
- self.rewards = None
- self.dones = None
- self.values = None
- self.actions_log_prob = None
- self.action_mean = None
- self.action_sigma = None
- self.hidden_states = None
+class RolloutStorage(ReplayStorage):
+ """Implementation of rollout storage for RL-agent."""
- def clear(self):
- self.__init__()
+ def __init__(self, environment_count: int, device: str = "cpu"):
+ """
+ Args:
+ environment_count (int): Number of environments.
+ device (str, optional): Device to use. Defaults to "cpu".
+ """
+ super().__init__(environment_count, environment_count, device=device, initial_size=0)
- def __init__(self, num_envs, num_transitions_per_env, obs_shape, privileged_obs_shape, actions_shape, device="cpu"):
- self.device = device
+ self._size_initialized = False
- self.obs_shape = obs_shape
- self.privileged_obs_shape = privileged_obs_shape
- self.actions_shape = actions_shape
+ def append(self, dataset: Dataset) -> None:
+ """Appends a dataset to the rollout storage.
- # Core
- self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device)
- if privileged_obs_shape[0] is not None:
- self.privileged_observations = torch.zeros(
- num_transitions_per_env, num_envs, *privileged_obs_shape, device=self.device
- )
+ Args:
+ dataset (Dataset): Dataset to append.
+ Raises:
+ AssertionError: If the dataset is not of the correct size.
+ """
+ assert self._idx == 0
+
+ if not self._size_initialized:
+ self.max_size = len(dataset) * self._env_count
+ self._size_initialized = True
+
+ assert len(dataset) == self._size
+
+ super().append(dataset)
+
+ def batch_generator(self, batch_count: int, trajectories: bool = False) -> Generator[Transition, None, None]:
+ """Yields batches of transitions or trajectories.
+
+ Args:
+ batch_count (int): Number of batches to yield.
+ trajectories (bool, optional): Whether to yield batches of trajectories. Defaults to False.
+ Raises:
+ AssertionError: If the rollout storage is not full.
+ Returns:
+ Generator yielding batches of transitions of shape (batch_size, *shape). If trajectories is True, yields
+ batches of trajectories of shape (env_count, steps_per_env, *shape).
+ """
+ assert self._full and self._initialized, "Rollout storage must be full and initialized."
+
+ total_size = self._env_count if trajectories else self._size * self._env_count
+ batch_size = total_size // batch_count
+ indices = torch.randperm(total_size)
+
+ assert batch_size > 0, "Batch count is too large."
+
+ if trajectories:
+ # Reshape to (env_count, steps_per_env, *shape)
+ data = {k: v.reshape(-1, self._env_count, *v.shape[1:]).transpose(0, 1) for k, v in self._data.items()}
else:
- self.privileged_observations = None
- self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
- self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
- self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte()
+ data = self._data
- # For PPO
- self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
- self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
- self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
- self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
- self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
- self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
+ for i in range(batch_count):
+ batch_idx = indices[i * batch_size : (i + 1) * batch_size].detach().to(self.device)
- self.num_transitions_per_env = num_transitions_per_env
- self.num_envs = num_envs
+ batch = {}
+ for key, value in data.items():
+ batch[key] = self._process_undo(key, value[batch_idx].clone())
- # rnn
- self.saved_hidden_states_a = None
- self.saved_hidden_states_c = None
-
- self.step = 0
-
- def add_transitions(self, transition: Transition):
- if self.step >= self.num_transitions_per_env:
- raise AssertionError("Rollout buffer overflow")
- self.observations[self.step].copy_(transition.observations)
- if self.privileged_observations is not None:
- self.privileged_observations[self.step].copy_(transition.critic_observations)
- self.actions[self.step].copy_(transition.actions)
- self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
- self.dones[self.step].copy_(transition.dones.view(-1, 1))
- self.values[self.step].copy_(transition.values)
- self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
- self.mu[self.step].copy_(transition.action_mean)
- self.sigma[self.step].copy_(transition.action_sigma)
- self._save_hidden_states(transition.hidden_states)
- self.step += 1
-
- def _save_hidden_states(self, hidden_states):
- if hidden_states is None or hidden_states == (None, None):
- return
- # make a tuple out of GRU hidden state sto match the LSTM format
- hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
- hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)
-
- # initialize if needed
- if self.saved_hidden_states_a is None:
- self.saved_hidden_states_a = [
- torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a))
- ]
- self.saved_hidden_states_c = [
- torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c))
- ]
- # copy the states
- for i in range(len(hid_a)):
- self.saved_hidden_states_a[i][self.step].copy_(hid_a[i])
- self.saved_hidden_states_c[i][self.step].copy_(hid_c[i])
-
- def clear(self):
- self.step = 0
-
- def compute_returns(self, last_values, gamma, lam):
- advantage = 0
- for step in reversed(range(self.num_transitions_per_env)):
- if step == self.num_transitions_per_env - 1:
- next_values = last_values
- else:
- next_values = self.values[step + 1]
- next_is_not_terminal = 1.0 - self.dones[step].float()
- delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
- advantage = delta + next_is_not_terminal * gamma * lam * advantage
- self.returns[step] = advantage + self.values[step]
-
- # Compute and normalize the advantages
- self.advantages = self.returns - self.values
- self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)
-
- def get_statistics(self):
- done = self.dones
- done[-1] = 1
- flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
- done_indices = torch.cat(
- (flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0])
- )
- trajectory_lengths = done_indices[1:] - done_indices[:-1]
- return trajectory_lengths.float().mean(), self.rewards.mean()
-
- def mini_batch_generator(self, num_mini_batches, num_epochs=8):
- batch_size = self.num_envs * self.num_transitions_per_env
- mini_batch_size = batch_size // num_mini_batches
- indices = torch.randperm(num_mini_batches * mini_batch_size, requires_grad=False, device=self.device)
-
- observations = self.observations.flatten(0, 1)
- if self.privileged_observations is not None:
- critic_observations = self.privileged_observations.flatten(0, 1)
- else:
- critic_observations = observations
-
- actions = self.actions.flatten(0, 1)
- values = self.values.flatten(0, 1)
- returns = self.returns.flatten(0, 1)
- old_actions_log_prob = self.actions_log_prob.flatten(0, 1)
- advantages = self.advantages.flatten(0, 1)
- old_mu = self.mu.flatten(0, 1)
- old_sigma = self.sigma.flatten(0, 1)
-
- for epoch in range(num_epochs):
- for i in range(num_mini_batches):
- start = i * mini_batch_size
- end = (i + 1) * mini_batch_size
- batch_idx = indices[start:end]
-
- obs_batch = observations[batch_idx]
- critic_observations_batch = critic_observations[batch_idx]
- actions_batch = actions[batch_idx]
- target_values_batch = values[batch_idx]
- returns_batch = returns[batch_idx]
- old_actions_log_prob_batch = old_actions_log_prob[batch_idx]
- advantages_batch = advantages[batch_idx]
- old_mu_batch = old_mu[batch_idx]
- old_sigma_batch = old_sigma[batch_idx]
- yield obs_batch, critic_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (
- None,
- None,
- ), None
-
- # for RNNs only
- def reccurent_mini_batch_generator(self, num_mini_batches, num_epochs=8):
- padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories(self.observations, self.dones)
- if self.privileged_observations is not None:
- padded_critic_obs_trajectories, _ = split_and_pad_trajectories(self.privileged_observations, self.dones)
- else:
- padded_critic_obs_trajectories = padded_obs_trajectories
-
- mini_batch_size = self.num_envs // num_mini_batches
- for ep in range(num_epochs):
- first_traj = 0
- for i in range(num_mini_batches):
- start = i * mini_batch_size
- stop = (i + 1) * mini_batch_size
-
- dones = self.dones.squeeze(-1)
- last_was_done = torch.zeros_like(dones, dtype=torch.bool)
- last_was_done[1:] = dones[:-1]
- last_was_done[0] = True
- trajectories_batch_size = torch.sum(last_was_done[:, start:stop])
- last_traj = first_traj + trajectories_batch_size
-
- masks_batch = trajectory_masks[:, first_traj:last_traj]
- obs_batch = padded_obs_trajectories[:, first_traj:last_traj]
- critic_obs_batch = padded_critic_obs_trajectories[:, first_traj:last_traj]
-
- actions_batch = self.actions[:, start:stop]
- old_mu_batch = self.mu[:, start:stop]
- old_sigma_batch = self.sigma[:, start:stop]
- returns_batch = self.returns[:, start:stop]
- advantages_batch = self.advantages[:, start:stop]
- values_batch = self.values[:, start:stop]
- old_actions_log_prob_batch = self.actions_log_prob[:, start:stop]
-
- # reshape to [num_envs, time, num layers, hidden dim] (original shape: [time, num_layers, num_envs, hidden_dim])
- # then take only time steps after dones (flattens num envs and time dimensions),
- # take a batch of trajectories and finally reshape back to [num_layers, batch, hidden_dim]
- last_was_done = last_was_done.permute(1, 0)
- hid_a_batch = [
- saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj]
- .transpose(1, 0)
- .contiguous()
- for saved_hidden_states in self.saved_hidden_states_a
- ]
- hid_c_batch = [
- saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj]
- .transpose(1, 0)
- .contiguous()
- for saved_hidden_states in self.saved_hidden_states_c
- ]
- # remove the tuple for GRU
- hid_a_batch = hid_a_batch[0] if len(hid_a_batch) == 1 else hid_a_batch
- hid_c_batch = hid_c_batch[0] if len(hid_c_batch) == 1 else hid_c_batch
-
- yield obs_batch, critic_obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (
- hid_a_batch,
- hid_c_batch,
- ), masks_batch
-
- first_traj = last_traj
+ yield batch
diff --git a/rsl_rl/storage/storage.py b/rsl_rl/storage/storage.py
new file mode 100644
index 0000000..308cc02
--- /dev/null
+++ b/rsl_rl/storage/storage.py
@@ -0,0 +1,47 @@
+from abc import abstractmethod
+import torch
+from typing import Dict, Generator, List
+
+from rsl_rl.utils.serializable import Serializable
+
+
+# prev_obs, prev_obs_info, actions, rewards, next_obs, next_obs_info, dones, data
+Transition = Dict[str, torch.Tensor]
+Dataset = List[Transition]
+
+
+class Storage(Serializable):
+ @abstractmethod
+ def append(self, dataset: Dataset) -> None:
+ """Adds transitions to the storage.
+
+ Args:
+ dataset (Dataset): The transitions to add to the storage.
+ """
+ pass
+
+ @abstractmethod
+ def batch_generator(self, batch_size: int, batch_count: int) -> Generator[Dict[str, torch.Tensor], None, None]:
+ """Generates a batch of transitions.
+
+ Args:
+ batch_size (int): The size of each batch to generate.
+ batch_count (int): The number of batches to generate.
+ Returns:
+ A generator that yields transitions.
+ """
+ pass
+
+ @property
+ def initialized(self) -> bool:
+ """Returns whether the storage is initialized."""
+ return True
+
+ @property
+ @abstractmethod
+ def sample_count(self) -> int:
+ """Returns how many individual samples are stored in the storage.
+
+ Returns:
+ The number of stored samples.
+ """
+ pass
diff --git a/rsl_rl/utils/__init__.py b/rsl_rl/utils/__init__.py
index 46b365a..19bdd7b 100644
--- a/rsl_rl/utils/__init__.py
+++ b/rsl_rl/utils/__init__.py
@@ -1,6 +1,3 @@
-# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
-# SPDX-License-Identifier: BSD-3-Clause
-
"""Helper functions."""
-from .utils import split_and_pad_trajectories, store_code_state, unpad_trajectories
+from .utils import split_and_pad_trajectories, unpad_trajectories, store_code_state
diff --git a/rsl_rl/utils/benchmarkable.py b/rsl_rl/utils/benchmarkable.py
new file mode 100644
index 0000000..7cb95a2
--- /dev/null
+++ b/rsl_rl/utils/benchmarkable.py
@@ -0,0 +1,111 @@
+import numpy as np
+import time
+from typing import Callable, Dict
+
+
+class Benchmark:
+ def __init__(self):
+ self.reset()
+
+ def __call__(self):
+ if self.running:
+ self.end()
+ else:
+ self.start()
+
+ def end(self):
+ now = time.process_time_ns()
+
+ assert self.running
+
+ difference = now - self._timer
+ self._timings.append(difference)
+
+ self._timer = None
+
+ def reset(self):
+ self._timer = None
+ self._timings = []
+
+ @property
+ def running(self):
+ return self._timer is not None
+
+ def start(self):
+ self._timer = time.process_time_ns()
+
+ @property
+ def timings(self):
+ return self._timings
+
+
+class Benchmarkable:
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ self._benchmark = False
+ self._bm_data = dict()
+ self._bm_fusions = []
+
+ def _bm(self, name: str) -> None:
+ if not self._benchmark:
+ return
+
+ if name not in self._bm_data:
+ self._bm_data[name] = Benchmark()
+
+ self._bm_data[name]()
+
+ def _bm_flush(self) -> None:
+ # TODO: implement
+ for val in self._bm_data.values():
+ val.reset()
+
+ for fusion in self._bm_fusions:
+ fusion["target"]._bm_flush()
+
+ def _bm_fuse(self, target, prefix="") -> None:
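+ """Merges the benchmark report of another Benchmarkable into this one, prefixing its keys with `prefix`.
+
+ Usage sketch (names are illustrative): self._bm_fuse(some_benchmarkable, prefix="storage.").
+ """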
+ assert isinstance(target, Benchmarkable)
+ assert target not in self._bm_fusions
+
+ target._bm_toggle(self._benchmark)
+ self._bm_fusions.append(dict(target=target, prefix=prefix))
+
+ def _bm_report(self) -> Dict:
+ data = dict()
+
+ if not self._benchmark:
+ return data
+
+ for key, val in self._bm_data.items():
+ data[key] = (np.mean(val.timings), len(val.timings))
+
+ for fusion in self._bm_fusions:
+ target = fusion["target"]
+ prefix = fusion["prefix"]
+
+ for key, val in target._bm_report().items():
+ data[f"{prefix}{key}"] = val
+
+ return data
+
+ def _bm_toggle(self, value: bool) -> None:
+ self._benchmark = value
+
+ for fusion in self._bm_fusions:
+ fusion["target"]._bm_toggle(value)
+
+ @staticmethod
+ def register(method: Callable, name=None) -> Callable:
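+ """Wraps `method` so that each call is timed under `name` (or the method's own name).
+
+ Usage sketch (assuming the wrapped method belongs to a Benchmarkable subclass):
+ draw_actions = Benchmarkable.register(draw_actions)
+ """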
+ benchmark_name = method.__name__ if name is None else name
+
+ def wrapper(self, *args, **kwargs):
+ assert isinstance(self, Benchmarkable)
+
+ self._bm(benchmark_name)
+ result = method(self, *args, **kwargs)
+ self._bm(benchmark_name)
+
+ return result
+
+ return wrapper
diff --git a/rsl_rl/utils/neptune_utils.py b/rsl_rl/utils/neptune_utils.py
index f06cc62..9e0cbb7 100644
--- a/rsl_rl/utils/neptune_utils.py
+++ b/rsl_rl/utils/neptune_utils.py
@@ -1,32 +1,26 @@
-# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
-# SPDX-License-Identifier: BSD-3-Clause
-
-from __future__ import annotations
-
import os
-from dataclasses import asdict
+
from torch.utils.tensorboard import SummaryWriter
+from legged_gym.utils import class_to_dict
try:
- import neptune
+ import neptune.new as neptune
except ModuleNotFoundError:
raise ModuleNotFoundError("neptune-client is required to log to Neptune.")
class NeptuneLogger:
def __init__(self, project, token):
- self.run = neptune.init_run(project=project, api_token=token)
+ self.run = neptune.init(project=project, api_token=token)
def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
self.run["runner_cfg"] = runner_cfg
self.run["policy_cfg"] = policy_cfg
self.run["alg_cfg"] = alg_cfg
- self.run["env_cfg"] = asdict(env_cfg)
+ self.run["env_cfg"] = class_to_dict(env_cfg)
class NeptuneSummaryWriter(SummaryWriter):
- """Summary writer for Neptune."""
-
def __init__(self, log_dir: str, flush_secs: int, cfg):
super().__init__(log_dir, flush_secs)
@@ -86,7 +80,3 @@ class NeptuneSummaryWriter(SummaryWriter):
def save_model(self, model_path, iter):
self.neptune_logger.run["model/saved_model_" + str(iter)].upload(model_path)
-
- def save_file(self, path, iter=None):
- name = path.rsplit("/", 1)[-1].split(".")[0]
- self.neptune_logger.run["git_diff/" + name].upload(path)
diff --git a/rsl_rl/utils/recurrency.py b/rsl_rl/utils/recurrency.py
new file mode 100644
index 0000000..3e798bd
--- /dev/null
+++ b/rsl_rl/utils/recurrency.py
@@ -0,0 +1,69 @@
+import torch
+from typing import Tuple
+
+
+def trajectories_to_transitions(trajectories: torch.Tensor, data: Tuple[torch.Tensor, int, bool]) -> torch.Tensor:
+ """Unpacks a tensor of trajectories into a tensor of transitions.
+
+ Args:
+ trajectories (torch.Tensor): A tensor of trajectories.
+ data (Tuple[torch.Tensor, int, bool]): The (mask, batch_size, batch_first) tuple returned by
+ transitions_to_trajectories, used to invert the conversion.
+ Returns:
+ A tensor of transitions of shape (batch_size, time, *).
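+
+ Note:
+ This inverts transitions_to_trajectories; a round-trip sketch:
+ >>> trajectories, data = transitions_to_trajectories(transitions, dones, batch_first=True)
+ >>> torch.equal(trajectories_to_transitions(trajectories, data), transitions)
+ True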
+ """
+ mask, batch_size, batch_first = data
+
+ if not batch_first:
+ trajectories, mask = trajectories.transpose(0, 1), mask.transpose(0, 1)
+
+ transitions = trajectories[mask == 1.0].reshape(batch_size, -1, *trajectories.shape[2:])
+
+ return transitions
+
+
+def transitions_to_trajectories(
+ transitions: torch.Tensor, dones: torch.Tensor, batch_first: bool = False
+) -> Tuple[torch.Tensor, Tuple[torch.Tensor, int, bool]]:
+ """Packs a tensor of transitions into a tensor of trajectories.
+
+ Example:
+ >>> transitions = torch.tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]])
+ >>> dones = torch.tensor([[0, 0, 1], [0, 1, 0]])
+ >>> trajectories, (mask, _, _) = transitions_to_trajectories(transitions, dones, batch_first=True)
+ >>> trajectories
+ tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [0, 0]], [[11, 12], [0, 0], [0, 0]]])
+ >>> mask
+ tensor([[1., 1., 1.], [1., 1., 0.], [1., 0., 0.]])
+
+ Args:
+ transitions (torch.Tensor): Tensor of transitions of shape (batch_size, time, *).
+ dones (torch.Tensor): Tensor of transition terminations of shape (batch_size, time).
+ batch_first (bool): Whether the first dimension of the output tensor should be the batch dimension. Defaults to
+ False.
+ Returns:
+ A torch.Tensor of trajectories of shape (time, trajectory_count, *) that is padded with zeros and data for
+ reverting the operation. If batch_first is True, the shape of the trajectories is (trajectory_count, time, *).
+ """
+ batch_size = transitions.shape[0]
+
+ # Count the trajectory lengths by (1) padding dones with a 1 at the end to indicate the end of the trajectory,
+ # (2) stacking up the padded dones in a single column, and (3) counting the number of steps between each done by
+ # using the row index.
+ padded_dones = dones.clone()
+ padded_dones[:, -1] = 1
+ stacked_dones = torch.cat((padded_dones.new([-1]), padded_dones.reshape(-1, 1).nonzero()[:, 0]))
+ trajectory_lengths = stacked_dones[1:] - stacked_dones[:-1]
+
+ # Compute trajectories by splitting transitions according to previously computed trajectory lengths.
+ trajectory_list = torch.split(transitions.flatten(0, 1), trajectory_lengths.int().tolist())
+ trajectories = torch.nn.utils.rnn.pad_sequence(trajectory_list, batch_first=batch_first)
+
+ # The mask is generated by computing a 2d matrix of increasing counts in the 2nd dimension and comparing it to the
+ # trajectory lengths.
+ steps = torch.arange(0, trajectory_lengths.max()).repeat(len(trajectory_lengths), 1)
+ steps = steps.to(dones.device)
+ mask = (trajectory_lengths.unsqueeze(1) > steps).float()
+
+ if not batch_first:
+ mask = mask.T
+
+ return trajectories, (mask, batch_size, batch_first)
diff --git a/rsl_rl/utils/serializable.py b/rsl_rl/utils/serializable.py
new file mode 100644
index 0000000..d36e002
--- /dev/null
+++ b/rsl_rl/utils/serializable.py
@@ -0,0 +1,43 @@
+from typing import Any, Dict
+
+
+class Serializable:
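+ """Mixin providing state_dict()/load_state_dict() over attributes registered via _register_serializable().
+
+ Usage sketch (attribute names are illustrative; see ReplayStorage for a concrete use):
+ self._register_serializable("max_size", "initial_size")
+ """
+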
+ def load_state_dict(self, data: Dict[str, Any]) -> None:
+ """Loads agent parameters from a dictionary."""
+ assert hasattr(self, "_serializable_objects")
+
+ for name in self._serializable_objects:
+ assert hasattr(self, name)
+ assert name in data, f'Object "{name}" was not found while loading "{self.__class__.__name__}".'
+
+ attr = getattr(self, name)
+ if hasattr(attr, "load_state_dict"):
+ print(f"Loading {name}")
+ attr.load_state_dict(data[name])
+ else:
+ print(f"Loading value {name}={data[name]}")
+ setattr(self, name, data[name])
+
+ def state_dict(self) -> Dict[str, Any]:
+ """Returns a dictionary containing the agent parameters."""
+ assert hasattr(self, "_serializable_objects")
+
+ data = {}
+
+ for name in self._serializable_objects:
+ assert hasattr(self, name)
+
+ attr = getattr(self, name)
+ data[name] = attr.state_dict() if hasattr(attr, "state_dict") else attr
+
+ return data
+
+ def _register_serializable(self, *objects) -> None:
+ if not hasattr(self, "_serializable_objects"):
+ self._serializable_objects = []
+
+ for name in objects:
+ if name in self._serializable_objects:
+ continue
+
+ self._serializable_objects.append(name)
diff --git a/rsl_rl/utils/utils.py b/rsl_rl/utils/utils.py
index f8d3103..f4d6b6a 100644
--- a/rsl_rl/utils/utils.py
+++ b/rsl_rl/utils/utils.py
@@ -1,16 +1,38 @@
-# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
-# SPDX-License-Identifier: BSD-3-Clause
-
-from __future__ import annotations
-
+from datetime import datetime
import git
import os
+import numpy as np
import pathlib
+import random
import torch
+def environment_dimensions(env):
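+ """Returns a dict with the environment's observation, actor/critic observation, and action dimensions."""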
+ obs = env.get_observations()
+
+ if isinstance(obs, tuple):
+ obs, env_info = obs
+ else:
+ env_info = {}
+
+ dims = {}
+
+ dims["observations"] = obs.shape[1]
+
+ if "observations" in env_info and "critic" in env_info["observations"]:
+ dims["actor_observations"] = dims["observations"]
+ dims["critic_observations"] = env_info["observations"]["critic"].shape[1]
+ else:
+ dims["actor_observations"] = dims["observations"]
+ dims["critic_observations"] = dims["observations"]
+
+ dims["actions"] = env.num_actions
+
+ return dims
+
+
def split_and_pad_trajectories(tensor, dones):
- """Splits trajectories at done indices. Then concatenates them and pads with zeros up to the length og the longest trajectory.
+ """Splits trajectories at done indices. Then concatenates them and padds with zeros up to the length og the longest trajectory.
Returns masks corresponding to valid parts of the trajectories
Example:
Input: [ [a1, a2, a3, a4 | a5, a6],
@@ -24,7 +46,7 @@ def split_and_pad_trajectories(tensor, dones):
[b6, 0, 0, 0] | [True, False, False, False],
] | ]
- Assumes that the inputy has the following dimension order: [time, number of envs, additional dimensions]
+ Assumes that the input has the following dimension order: [time, number of envs, additional dimensions]
"""
dones = dones.clone()
dones[-1] = 1
@@ -37,12 +59,7 @@ def split_and_pad_trajectories(tensor, dones):
trajectory_lengths_list = trajectory_lengths.tolist()
# Extract the individual trajectories
trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1), trajectory_lengths_list)
- # add at least one full length trajectory
- trajectories = trajectories + (torch.zeros(tensor.shape[0], tensor.shape[-1], device=tensor.device),)
- # pad the trajectories to the length of the longest trajectory
padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories)
- # remove the added tensor
- padded_trajectories = padded_trajectories[:, :-1]
trajectory_masks = trajectory_lengths > torch.arange(0, tensor.shape[0], device=tensor.device).unsqueeze(1)
return padded_trajectories, trajectory_masks
@@ -58,29 +75,34 @@ def unpad_trajectories(trajectories, masks):
)
-def store_code_state(logdir, repositories) -> list:
- git_log_dir = os.path.join(logdir, "git")
- os.makedirs(git_log_dir, exist_ok=True)
- file_paths = []
+def store_code_state(logdir, repositories):
for repository_file_path in repositories:
- try:
- repo = git.Repo(repository_file_path, search_parent_directories=True)
- except Exception:
- print(f"Could not find git repository in {repository_file_path}. Skipping.")
- # skip if not a git repository
- continue
- # get the name of the repository
+ repo = git.Repo(repository_file_path, search_parent_directories=True)
repo_name = pathlib.Path(repo.working_dir).name
t = repo.head.commit.tree
- diff_file_name = os.path.join(git_log_dir, f"{repo_name}.diff")
- # check if the diff file already exists
- if os.path.isfile(diff_file_name):
- continue
- # write the diff file
- print(f"Storing git diff for '{repo_name}' in: {diff_file_name}")
- with open(diff_file_name, "x") as f:
- content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}"
+ content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}"
+ with open(os.path.join(logdir, f"{repo_name}_git.diff"), "x", encoding="utf-8") as f:
f.write(content)
- # add the file path to the list of files to be uploaded
- file_paths.append(diff_file_name)
- return file_paths
+
+
+def seed(s=None):
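+ """Seeds the python, numpy, and torch random number generators; uses a time-based seed if `s` is None."""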
+ seed = int(datetime.now().timestamp() * 1e6) % 2**32 if s is None else s
+
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ os.environ["PYTHONHASHSEED"] = str(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def squeeze_preserve_batch(tensor):
+ """Squeezes a tensor, but preserves the batch dimension"""
+ single_batch = tensor.shape[0] == 1
+
+ squeezed_tensor = tensor.squeeze()
+
+ if single_batch:
+ squeezed_tensor = squeezed_tensor.unsqueeze(0)
+
+ return squeezed_tensor
diff --git a/rsl_rl/utils/wandb_utils.py b/rsl_rl/utils/wandb_utils.py
index 2868ce9..686c6dc 100644
--- a/rsl_rl/utils/wandb_utils.py
+++ b/rsl_rl/utils/wandb_utils.py
@@ -1,11 +1,7 @@
-# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
-# SPDX-License-Identifier: BSD-3-Clause
-
-from __future__ import annotations
-
import os
-from dataclasses import asdict
+
from torch.utils.tensorboard import SummaryWriter
+from legged_gym.utils import class_to_dict
try:
import wandb
@@ -14,8 +10,6 @@ except ModuleNotFoundError:
class WandbSummaryWriter(SummaryWriter):
- """Summary writer for Weights and Biases."""
-
def __init__(self, log_dir: str, flush_secs: int, cfg):
super().__init__(log_dir, flush_secs)
@@ -49,7 +43,7 @@ class WandbSummaryWriter(SummaryWriter):
wandb.config.update({"runner_cfg": runner_cfg})
wandb.config.update({"policy_cfg": policy_cfg})
wandb.config.update({"alg_cfg": alg_cfg})
- wandb.config.update({"env_cfg": asdict(env_cfg)})
+ wandb.config.update({"env_cfg": class_to_dict(env_cfg)})
def _map_path(self, path):
if path in self.name_map:
@@ -74,7 +68,4 @@ class WandbSummaryWriter(SummaryWriter):
self.store_config(env_cfg, runner_cfg, alg_cfg, policy_cfg)
def save_model(self, model_path, iter):
- wandb.save(model_path, base_path=os.path.dirname(model_path))
-
- def save_file(self, path, iter=None):
- wandb.save(path, base_path=os.path.dirname(path))
+ wandb.save(model_path)
diff --git a/setup.py b/setup.py
index b933467..b30fa39 100644
--- a/setup.py
+++ b/setup.py
@@ -1,24 +1,20 @@
-# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
-# SPDX-License-Identifier: BSD-3-Clause
-
-from setuptools import find_packages, setup
+from setuptools import setup, find_packages
setup(
name="rsl_rl",
- version="2.0.2",
+ version="1.0.2",
packages=find_packages(),
- author="ETH Zurich, NVIDIA CORPORATION",
- maintainer="Nikita Rudin, David Hoeller",
- maintainer_email="rudinn@ethz.ch",
- url="https://github.com/leggedrobotics/rsl_rl",
license="BSD-3",
description="Fast and simple RL algorithms implemented in pytorch",
python_requires=">=3.6",
install_requires=[
+ "GitPython",
+ "gym[all]>=0.26.0",
+ "numpy>=1.24.4",
+ "onnx>=1.14.0",
+ "tensorboard>=2.13.0",
"torch>=1.10.0",
"torchvision>=0.5.0",
- "numpy>=1.16.4",
- "GitPython",
- "onnx",
+ "wandb>=0.15.4",
],
)
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
new file mode 100644
index 0000000..bab6be4
--- /dev/null
+++ b/tests/test_algorithms.py
@@ -0,0 +1,169 @@
+import unittest
+
+from rsl_rl.algorithms import D4PG, DDPG, DPPO, DSAC, PPO, SAC, TD3
+from rsl_rl.env.gym_env import GymEnv
+from rsl_rl.modules import Network
+from rsl_rl.runners.runner import Runner
+
+DEVICE = "cpu"
+
+
+class AlgorithmTestCaseMixin:
+ algorithm_class = None
+
+ def _make_env(self, params={}):
+ my_params = dict(name="LunarLanderContinuous-v2", device=DEVICE, environment_count=4)
+ my_params.update(params)
+
+ return GymEnv(**my_params)
+
+ def _make_agent(self, env, agent_params={}):
+ return self.algorithm_class(env, device=DEVICE, **agent_params)
+
+ def _make_runner(self, env, agent, runner_params={}):
+ if not runner_params or "num_steps_per_env" not in runner_params:
+ runner_params["num_steps_per_env"] = 6
+
+ return Runner(env, agent, device=DEVICE, **runner_params)
+
+ def _learn(self, env, agent, runner_params={}):
+ runner = self._make_runner(env, agent, runner_params)
+ runner.learn(10)
+
+ def test_default(self):
+ env = self._make_env()
+ agent = self._make_agent(env)
+
+ self._learn(env, agent)
+
+ def test_single_env_single_step(self):
+ env = self._make_env(dict(environment_count=1))
+ agent = self._make_agent(env)
+
+ self._learn(env, agent, dict(num_steps_per_env=1))
+
+
+class RecurrentAlgorithmTestCaseMixin(AlgorithmTestCaseMixin):
+ def test_recurrent(self):
+ env = self._make_env()
+ agent = self._make_agent(env, dict(recurrent=True))
+
+ self._learn(env, agent)
+
+ def test_single_env_single_step_recurrent(self):
+ env = self._make_env(dict(environment_count=1))
+ agent = self._make_agent(env, dict(recurrent=True))
+
+ self._learn(env, agent, dict(num_steps_per_env=1))
+
+
+class D4PGTest(AlgorithmTestCaseMixin, unittest.TestCase):
+ algorithm_class = D4PG
+
+
+class DDPGTest(AlgorithmTestCaseMixin, unittest.TestCase):
+ algorithm_class = DDPG
+
+
+iqn_params = dict(
+ critic_network=DPPO.network_iqn,
+ iqn_action_samples=8,
+ iqn_embedding_size=16,
+ iqn_feature_layers=2,
+ iqn_value_samples=4,
+ value_loss=DPPO.value_loss_energy,
+)
+
+qrdqn_params = dict(
+ critic_network=DPPO.network_qrdqn,
+ qrdqn_quantile_count=16,
+ value_loss=DPPO.value_loss_l1,
+)
+
+
+class DPPOTest(RecurrentAlgorithmTestCaseMixin, unittest.TestCase):
+ algorithm_class = DPPO
+
+ def test_qrdqn(self):
+ env = self._make_env()
+ agent = self._make_agent(env, qrdqn_params)
+
+ self._learn(env, agent)
+
+ def test_qrdqn_single_env_single_step(self):
+ env = self._make_env(dict(environment_count=1))
+ agent = self._make_agent(env, qrdqn_params)
+
+ self._learn(env, agent, dict(num_steps_per_env=1))
+
+ def test_qrdqn_energy_loss(self):
+ my_agent_params = qrdqn_params.copy()
+ my_agent_params["value_loss"] = DPPO.value_loss_energy
+
+ env = self._make_env()
+ agent = self._make_agent(env, my_agent_params)
+
+ self._learn(env, agent)
+
+ def test_qrdqn_huber_loss(self):
+ my_agent_params = qrdqn_params.copy()
+ my_agent_params["value_loss"] = DPPO.value_loss_huber
+
+ env = self._make_env()
+ agent = self._make_agent(env, my_agent_params)
+
+ self._learn(env, agent)
+
+ def test_qrdqn_transformer(self):
+ my_agent_params = qrdqn_params.copy()
+ my_agent_params["recurrent"] = True
+ my_agent_params["critic_recurrent_layers"] = 2
+ my_agent_params["critic_recurrent_module"] = Network.recurrent_module_transformer
+ my_agent_params["critic_recurrent_tf_context_length"] = 8
+ my_agent_params["critic_recurrent_tf_head_count"] = 2
+
+ env = self._make_env()
+ agent = self._make_agent(env, my_agent_params)
+
+ self._learn(env, agent)
+
+ def test_iqn(self):
+ env = self._make_env()
+ agent = self._make_agent(env, iqn_params)
+
+ self._learn(env, agent)
+
+ def test_iqn_single_step_single_env(self):
+ env = self._make_env(dict(environment_count=1))
+ agent = self._make_agent(env, iqn_params)
+
+ self._learn(env, agent, dict(num_steps_per_env=1))
+
+ def test_iqn_recurrent(self):
+ my_agent_params = iqn_params.copy()
+ my_agent_params["recurrent"] = True
+
+ env = self._make_env()
+ agent = self._make_agent(env, my_agent_params)
+
+ self._learn(env, agent)
+
+
+class DSACTest(AlgorithmTestCaseMixin, unittest.TestCase):
+ algorithm_class = DSAC
+
+
+class PPOTest(RecurrentAlgorithmTestCaseMixin, unittest.TestCase):
+ algorithm_class = PPO
+
+
+class SACTest(AlgorithmTestCaseMixin, unittest.TestCase):
+ algorithm_class = SAC
+
+
+class TD3Test(AlgorithmTestCaseMixin, unittest.TestCase):
+ algorithm_class = TD3
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_dpg.py b/tests/test_dpg.py
new file mode 100644
index 0000000..b2f7d3c
--- /dev/null
+++ b/tests/test_dpg.py
@@ -0,0 +1,126 @@
+import torch
+import unittest
+from rsl_rl.algorithms.dpg import AbstractDPG
+from rsl_rl.env.pole_balancing import PoleBalancing
+
+
+class DPG(AbstractDPG):
+ def draw_actions(self, obs, env_info):
+ pass
+
+ def eval_mode(self):
+ pass
+
+ def get_inference_policy(self, device=None):
+ pass
+
+ def process_transition(
+ self, observations, environment_info, actions, rewards, next_observations, next_environment_info, dones, data
+ ):
+ pass
+
+ def register_terminations(self, terminations):
+ pass
+
+ def to(self, device):
+ pass
+
+ def train_mode(self):
+ pass
+
+ def update(self, storage):
+ pass
+
+
+class FakeCritic(torch.nn.Module):
+ def __init__(self, values):
+ super().__init__()
+ self.values = values
+
+ def forward(self, _):
+ return self.values
+
+
+class DPGTest(unittest.TestCase):
+ def test_timeout_bootstrapping(self):
+ env = PoleBalancing(environment_count=4)
+ dpg = DPG(env, device="cpu", return_steps=3)
+
+ rewards = torch.tensor(
+ [
+ [0.1000, 0.4000, 0.6000, 0.2000, -0.6000, -0.2000],
+ [0.0000, 0.9000, 0.5000, -0.9000, -0.4000, 0.8000],
+ [-0.5000, 0.4000, 0.0000, -0.2000, 0.3000, 0.1000],
+ [-0.8000, 0.9000, -0.6000, 0.7000, 0.5000, 0.1000],
+ ]
+ )
+ dones = torch.tensor(
+ [
+ [0, 0, 1, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0],
+ [1, 0, 0, 0, 0, 0],
+ [0, 0, 0, 1, 0, 0],
+ ]
+ )
+ timeouts = torch.tensor(
+ [
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 1, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0],
+ ]
+ )
+ actions = torch.zeros((4, 6, 1))
+ observations = torch.zeros((4, 6, 2))
+ values = torch.tensor([-0.1000, -0.8000, 0.4000, 0.7000])
+
+ dpg.critic = FakeCritic(values)
+ dataset = [
+ {
+ "actions": actions[:, i],
+ "critic_observations": observations[:, i],
+ "dones": dones[:, i],
+ "rewards": rewards[:, i],
+ "timeouts": timeouts[:, i],
+ }
+ for i in range(3)
+ ]
+
+ processed_dataset = dpg._process_dataset(dataset)
+ processed_rewards = torch.stack([processed_dataset[i]["rewards"] for i in range(1)], dim=-1)
+
+ expected_rewards = torch.tensor(
+ [
+ [1.08406],
+ [1.38105],
+ [-0.5],
+ [0.77707],
+ ]
+ )
+ self.assertTrue(len(processed_dataset) == 1)
+ self.assertTrue(torch.isclose(processed_rewards, expected_rewards).all())
+
+ dataset = [
+ {
+ "actions": actions[:, i + 3],
+ "critic_observations": observations[:, i + 3],
+ "dones": dones[:, i + 3],
+ "rewards": rewards[:, i + 3],
+ "timeouts": timeouts[:, i + 3],
+ }
+ for i in range(3)
+ ]
+
+ processed_dataset = dpg._process_dataset(dataset)
+ processed_rewards = torch.stack([processed_dataset[i]["rewards"] for i in range(3)], dim=-1)
+
+ expected_rewards = torch.tensor(
+ [
+ [0.994, 0.6, -0.59002],
+ [0.51291, -1.5592792, -2.08008],
+ [0.20398, 0.09603, 0.19501],
+ [1.593, 0.093, 0.7],
+ ]
+ )
+
+ self.assertTrue(len(processed_dataset) == 3)
+ self.assertTrue(torch.isclose(processed_rewards, expected_rewards).all())
diff --git a/tests/test_dppo.py b/tests/test_dppo.py
new file mode 100644
index 0000000..9be2316
--- /dev/null
+++ b/tests/test_dppo.py
@@ -0,0 +1,278 @@
+import torch
+import unittest
+from rsl_rl.algorithms import DPPO
+from rsl_rl.env.pole_balancing import PoleBalancing
+
+
+class FakeCritic(torch.nn.Module):
+    def __init__(self, values, quantile_count=1):
+        super().__init__()
+
+        self.quantile_count = quantile_count
+ self.recurrent = False
+ self.values = values
+ self.last_quantiles = values
+
+ def forward(self, _, distribution=False, measure_args=None):
+ if distribution:
+ return self.values
+
+ return self.values.mean(-1)
+
+ def quantiles_to_values(self, quantiles):
+ return quantiles.mean(-1)
+
+
+class DPPOTest(unittest.TestCase):
+ def test_gae_computation(self):
+ # GT taken from old PPO implementation.
+
+ env = PoleBalancing(environment_count=4)
+ dppo = DPPO(env, device="cpu", gae_lambda=0.97, gamma=0.99, qrdqn_quantile_count=1)
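+        # With a single quantile the distributional critic reduces to a scalar value estimate, so the GAE targets
+        # should match the non-distributional PPO reference (see test_ppo.py).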
+
+ rewards = torch.tensor(
+ [
+ [-1.0000e02, -1.4055e-01, -3.0476e-02, -2.7149e-01, -1.1157e-01, -2.3366e-01, -3.3658e-01, -1.6447e-01],
+ [
+ -1.7633e-01,
+ -2.6533e-01,
+ -3.0786e-01,
+ -2.6038e-01,
+ -2.7176e-01,
+ -2.1655e-01,
+ -1.5441e-01,
+ -2.9580e-01,
+ ],
+ [-1.5952e-01, -1.5177e-01, -1.4296e-01, -1.6131e-01, -3.1395e-02, 2.8808e-03, -3.1242e-02, 4.8696e-03],
+ [1.1407e-02, -1.0000e02, -6.2290e-02, -3.7030e-01, -2.7648e-01, -3.6655e-01, -2.8456e-01, -2.3165e-01],
+ ]
+ )
+ dones = torch.tensor(
+ [
+ [1, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0, 0],
+ ]
+ )
+ observations = torch.zeros((dones.shape[0], dones.shape[1], 24))
+ timeouts = torch.zeros((dones.shape[0], dones.shape[1]))
+ values = torch.tensor(
+ [
+ [-4.6342, -7.6510, -7.0166, -7.6137, -7.4130, -7.7071, -7.7413, -7.8301],
+ [-7.0442, -7.0032, -6.9321, -6.7765, -6.5433, -6.3503, -6.2529, -5.9337],
+ [-7.5753, -7.8146, -7.6142, -7.8443, -7.8791, -7.7973, -7.7853, -7.7724],
+ [-6.4326, -6.1673, -7.6511, -7.7505, -8.0004, -7.8584, -7.5949, -7.9023],
+ ]
+ )
+ value_quants = values.unsqueeze(-1)
+ last_values = torch.tensor([-7.9343, -5.8734, -7.8527, -8.1257])
+
+ dppo.critic = FakeCritic(last_values.unsqueeze(-1))
+ dataset = [
+ {
+ "dones": dones[:, i],
+ "full_next_critic_observations": observations[:, i].clone(),
+ "next_critic_observations": observations[:, i],
+ "rewards": rewards[:, i],
+ "timeouts": timeouts[:, i],
+ "values": values[:, i],
+ "value_quants": value_quants[:, i],
+ }
+ for i in range(dones.shape[1])
+ ]
+
+ processed_dataset = dppo._process_dataset(dataset)
+ processed_returns = torch.stack(
+ [processed_dataset[i]["advantages"] + processed_dataset[i]["values"] for i in range(dones.shape[1])],
+ dim=-1,
+ )
+ processed_advantages = torch.stack(
+ [processed_dataset[i]["normalized_advantages"] for i in range(dones.shape[1])], dim=-1
+ )
+
+ expected_returns = torch.tensor(
+ [
+ [-100.0000, -8.4983, -8.4863, -8.5699, -8.4122, -8.4054, -8.2702, -8.0194],
+ [-7.2900, -7.1912, -6.9978, -6.7569, -6.5627, -6.3547, -6.1985, -6.1104],
+ [-7.9179, -7.8374, -7.7679, -7.6976, -7.6041, -7.6446, -7.7229, -7.7693],
+ [-96.2018, -100.0000, -9.0710, -9.1415, -8.8863, -8.7228, -8.4668, -8.2761],
+ ]
+ )
+ expected_advantages = torch.tensor(
+ [
+ [-3.1452, 0.3006, 0.2779, 0.2966, 0.2951, 0.3060, 0.3122, 0.3246],
+ [0.3225, 0.3246, 0.3291, 0.3322, 0.3308, 0.3313, 0.3335, 0.3250],
+ [0.3190, 0.3307, 0.3259, 0.3368, 0.3415, 0.3371, 0.3338, 0.3316],
+ [-2.9412, -3.0893, 0.2797, 0.2808, 0.2992, 0.3000, 0.2997, 0.3179],
+ ]
+ )
+
+ self.assertTrue(torch.isclose(processed_returns, expected_returns, atol=1e-4).all())
+ self.assertTrue(torch.isclose(processed_advantages, expected_advantages, atol=1e-4).all())
+
+ def test_target_computation_nstep(self):
+ # GT taken from old PPO implementation.
+
+ env = PoleBalancing(environment_count=2)
+ dppo = DPPO(env, device="cpu", gae_lambda=0.97, gamma=0.99, qrdqn_quantile_count=3, value_lambda=1.0)
+
+ rewards = torch.tensor(
+ [
+ [0.6, 1.0, 0.0, 0.6, 0.1, -0.2],
+ [0.1, -0.2, -0.4, 0.1, 1.0, 1.0],
+ ]
+ )
+ dones = torch.tensor(
+ [
+ [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
+ [0.0, 0.0, 1.0, 0.0, 1.0, 0.0],
+ ]
+ )
+ observations = torch.zeros((dones.shape[0], dones.shape[1], 24))
+ timeouts = torch.zeros((dones.shape[0], dones.shape[1]))
+ value_quants = torch.tensor(
+ [
+ [
+ [-0.4, 0.0, 0.1],
+ [0.9, 0.8, 0.7],
+ [0.7, 1.3, 0.0],
+ [1.2, 0.4, 1.2],
+ [1.3, 1.3, 1.1],
+ [0.0, 0.7, 0.5],
+ ],
+ [
+ [1.3, 1.3, 0.9],
+ [0.4, -0.4, -0.1],
+ [0.4, 0.6, 0.1],
+ [0.7, 0.1, 0.3],
+ [0.2, 0.1, 0.3],
+ [1.4, 1.4, -0.3],
+ ],
+ ]
+ )
+ values = value_quants.mean(dim=-1)
+ last_values = torch.rand((dones.shape[0], 3))
+
+ dppo.critic = FakeCritic(last_values, quantile_count=3)
+ dataset = [
+ {
+ "dones": dones[:, i],
+ "full_next_critic_observations": observations[:, i].clone(),
+ "next_critic_observations": observations[:, i],
+ "rewards": rewards[:, i],
+ "timeouts": timeouts[:, i],
+ "values": values[:, i],
+ "value_quants": value_quants[:, i],
+ }
+ for i in range(dones.shape[1])
+ ]
+
+ processed_dataset = dppo._process_dataset(dataset)
+ processed_value_target_quants = torch.stack(
+ [processed_dataset[i]["value_target_quants"] for i in range(dones.shape[1])],
+ dim=-2,
+ )
+
+ # N-step returns
+ # These exclude the reward received on the final step since it should not be added to the value target.
+ expected_returns = torch.tensor(
+ [
+ [0.6, 1.58806, 0.594, 0.6, 0.1, -0.2],
+ [-0.098, -0.2, -0.4, 0.1, 1.0, 1.0],
+ ]
+ )
+ reset = lambda x: [0.0 for _ in x]
+ dscnt = lambda x, s=0: [x[v] * dppo.gamma**s for v in range(len(x))]
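+        # `reset` drops the bootstrap quantiles entirely (the episode ended at that step), while `dscnt` discounts the
+        # bootstrap quantiles by gamma**s for an s-step lookahead.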
+ expected_value_target_quants = expected_returns.unsqueeze(-1) + torch.tensor(
+ [
+ [
+ reset([-0.4, 0.0, 0.1]),
+ dscnt([1.3, 1.3, 1.1], 3),
+ dscnt([1.3, 1.3, 1.1], 2),
+ dscnt([1.3, 1.3, 1.1], 1),
+ reset([1.3, 1.3, 1.1]),
+ dscnt(last_values[0], 1),
+ ],
+ [
+ dscnt([0.4, 0.6, 0.1], 2),
+ dscnt([0.4, 0.6, 0.1], 1),
+ reset([0.4, 0.6, 0.1]),
+ dscnt([0.2, 0.1, 0.3], 1),
+ reset([0.2, 0.1, 0.3]),
+ dscnt(last_values[1], 1),
+ ],
+ ]
+ )
+
+ self.assertTrue(torch.isclose(processed_value_target_quants, expected_value_target_quants, atol=1e-4).all())
+
+ def test_target_computation_1step(self):
+ # GT taken from old PPO implementation.
+
+ env = PoleBalancing(environment_count=2)
+ dppo = DPPO(env, device="cpu", gae_lambda=0.97, gamma=0.99, qrdqn_quantile_count=3, value_lambda=0.0)
+
+ rewards = torch.tensor(
+ [
+ [0.6, 1.0, 0.0, 0.6, 0.1, -0.2],
+ [0.1, -0.2, -0.4, 0.1, 1.0, 1.0],
+ ]
+ )
+ dones = torch.tensor(
+ [
+ [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
+ [0.0, 0.0, 1.0, 0.0, 1.0, 0.0],
+ ]
+ )
+ observations = torch.zeros((dones.shape[0], dones.shape[1], 24))
+ timeouts = torch.zeros((dones.shape[0], dones.shape[1]))
+ value_quants = torch.tensor(
+ [
+ [
+ [-0.4, 0.0, 0.1],
+ [0.9, 0.8, 0.7],
+ [0.7, 1.3, 0.0],
+ [1.2, 0.4, 1.2],
+ [1.3, 1.3, 1.1],
+ [0.0, 0.7, 0.5],
+ ],
+ [
+ [1.3, 1.3, 0.9],
+ [0.4, -0.4, -0.1],
+ [0.4, 0.6, 0.1],
+ [0.7, 0.1, 0.3],
+ [0.2, 0.1, 0.3],
+ [1.4, 1.4, -0.3],
+ ],
+ ]
+ )
+ values = value_quants.mean(dim=-1)
+ last_values = torch.rand((dones.shape[0], 3))
+
+ dppo.critic = FakeCritic(last_values, quantile_count=3)
+ dataset = [
+ {
+ "dones": dones[:, i],
+ "full_next_critic_observations": observations[:, i].clone(),
+ "next_critic_observations": observations[:, i],
+ "rewards": rewards[:, i],
+ "timeouts": timeouts[:, i],
+ "values": values[:, i],
+ "value_quants": value_quants[:, i],
+ }
+ for i in range(dones.shape[1])
+ ]
+
+ processed_dataset = dppo._process_dataset(dataset)
+ processed_value_target_quants = torch.stack(
+ [processed_dataset[i]["value_target_quants"] for i in range(dones.shape[1])],
+ dim=-2,
+ )
+
+ # 1-step returns
+ expected_value_target_quants = rewards.unsqueeze(-1) + (
+ (1.0 - dones).float().unsqueeze(-1)
+ * dppo.gamma
+ * torch.cat((value_quants[:, 1:], last_values.unsqueeze(1)), dim=1)
+ )
+
+ self.assertTrue(torch.isclose(processed_value_target_quants, expected_value_target_quants, atol=1e-4).all())
diff --git a/tests/test_dppo_iqn.py b/tests/test_dppo_iqn.py
new file mode 100644
index 0000000..dbe498a
--- /dev/null
+++ b/tests/test_dppo_iqn.py
@@ -0,0 +1,171 @@
+import torch
+import unittest
+from rsl_rl.algorithms import DPPO
+from rsl_rl.env.vec_env import VecEnv
+
+ACTION_SIZE = 3
+ENV_COUNT = 3
+OBS_SIZE = 24
+
+
+class FakeEnv(VecEnv):
+ def __init__(self, rewards, dones, environment_count=1):
+ super().__init__(OBS_SIZE, OBS_SIZE, environment_count=environment_count)
+
+ self.num_actions = ACTION_SIZE
+ self.rewards = rewards
+ self.dones = dones
+
+ self._step = 0
+
+ def get_observations(self):
+ return torch.zeros((self.num_envs, self.num_obs)), {"observations": {}}
+
+ def get_privileged_observations(self):
+ return torch.zeros((self.num_envs, self.num_privileged_obs)), {"observations": {}}
+
+ def step(self, actions):
+ obs, _ = self.get_observations()
+ rewards = self.rewards[self._step]
+ dones = self.dones[self._step]
+
+ self._step += 1
+
+ return obs, rewards, dones, {"observations": {}}
+
+ def reset(self):
+ pass
+
+
+class FakeCritic(torch.nn.Module):
+    def __init__(self, action_samples, value_samples, action_values, value_values, action_taus, value_taus):
+        super().__init__()
+
+        self.recurrent = False
+ self.action_samples = action_samples
+ self.value_samples = value_samples
+ self.action_values = action_values
+ self.value_values = value_values
+ self.action_taus = action_taus
+ self.value_taus = value_taus
+
+ self.last_quantiles = None
+ self.last_taus = None
+
+ def forward(self, _, distribution=False, measure_args=None, sample_count=8, taus=None, use_measure=True):
+ if taus is not None:
+ sample_count = taus.shape[-1]
+
+ if sample_count == self.action_samples:
+ self.last_taus = self.action_taus
+ self.last_quantiles = self.action_values
+ elif sample_count == self.value_samples:
+ self.last_taus = self.value_taus
+ self.last_quantiles = self.value_values
+ else:
+ raise ValueError(f"Invalid sample count: {sample_count}")
+
+ if distribution:
+ return self.last_quantiles
+
+ return self.last_quantiles.mean(-1)
+
+
+def fake_process_quants(self, x):
+    idx = torch.arange(0, x.shape[-1]).expand(*x.shape[:-1], -1)
+
+ return x, idx
+
+
+class DPPOTest(unittest.TestCase):
+ def test_value_target_computation(self):
+ rewards = torch.tensor(
+ [
+ [-1.0000e02, -1.4055e-01, -3.0476e-02],
+ [-1.7633e-01, -2.6533e-01, -3.0786e-01],
+ [-1.5952e-01, -1.5177e-01, -1.4296e-01],
+ [1.1407e-02, -1.0000e02, -6.2290e-02],
+ ]
+ )
+ dones = torch.tensor(
+ [
+ [1, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 1, 0],
+ ]
+ )
+
+ env = FakeEnv(rewards, dones, environment_count=ENV_COUNT)
+ dppo = DPPO(
+ env,
+ critic_network=DPPO.network_iqn,
+ device="cpu",
+ gae_lambda=0.97,
+ gamma=0.99,
+ iqn_action_samples=4,
+ iqn_value_samples=2,
+ value_lambda=1.0,
+ value_loss=DPPO.value_loss_energy,
+ )
+
+ # Generate fake dataset
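+        # FakeCritic dispatches on the requested sample count: 4 samples return the precomputed "action" quantiles and
+        # taus, 2 samples return the "value" quantiles and taus, mirroring iqn_action_samples / iqn_value_samples.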
+
+ action_taus = torch.tensor(
+ [
+ [[0.3, 0.5, 1.0, 0.2], [0.8, 0.9, 0.0, 0.9], [0.6, 0.1, 0.6, 0.5]],
+ [[0.7, 0.9, 0.3, 0.0], [1.0, 0.7, 0.7, 0.7], [0.3, 0.8, 0.8, 0.1]],
+ [[0.3, 0.8, 0.3, 0.2], [0.2, 0.9, 0.6, 0.4], [0.8, 0.4, 0.8, 1.0]],
+ [[0.6, 0.6, 0.8, 0.8], [0.8, 0.0, 0.9, 0.1], [0.2, 0.3, 0.6, 0.2]],
+ ]
+ )
+ action_value_quants = torch.tensor(
+ [
+ [[0.2, 0.2, 0.6, 0.5], [0.5, 0.8, 0.1, 0.0], [1.0, 0.1, 0.8, 0.8]],
+ [[0.0, 0.6, 0.1, 0.9], [0.2, 1.0, 0.9, 1.0], [0.4, 0.1, 0.1, 0.8]],
+ [[0.7, 0.0, 0.6, 0.8], [0.7, 0.7, 0.7, 0.8], [0.0, 0.1, 0.5, 0.8]],
+ [[0.5, 0.8, 0.1, 0.1], [0.9, 0.4, 0.7, 0.6], [0.6, 0.3, 0.1, 0.4]],
+ ]
+ )
+ value_taus = torch.tensor(
+ [
+ [[0.3, 0.5], [0.8, 0.9], [0.6, 0.1]],
+ [[0.7, 0.9], [1.0, 0.7], [0.3, 0.8]],
+ [[0.3, 0.8], [0.2, 0.9], [0.8, 0.4]],
+ [[0.6, 0.6], [0.8, 0.0], [0.2, 0.3]],
+ ]
+ )
+ value_value_quants = torch.tensor(
+ [
+ [[0.9, 0.8], [0.1, 0.3], [0.3, 0.5]],
+ [[0.2, 0.1], [0.9, 0.3], [0.4, 0.2]],
+ [[0.7, 1.0], [0.6, 0.2], [0.2, 0.6]],
+ [[0.4, 1.0], [0.3, 0.6], [0.3, 0.1]],
+ ]
+ )
+
+ actions = torch.zeros(ENV_COUNT, ACTION_SIZE)
+ env_info = {"observations": {}}
+ obs = torch.zeros(ENV_COUNT, OBS_SIZE)
+ dataset = []
+ for i in range(4):
+ dppo.critic = FakeCritic(4, 2, action_value_quants[i], value_value_quants[i], action_taus[i], value_taus[i])
+ dppo.critic._process_quants = fake_process_quants
+
+ _, data = dppo.draw_actions(obs, {})
+ _, rewards, dones, _ = env.step(actions)
+
+ dataset.append(
+ dppo.process_transition(
+ obs,
+ env_info,
+ actions,
+ rewards,
+ obs,
+ env_info,
+ dones,
+ data,
+ )
+ )
+
+ processed_dataset = dppo._process_dataset(dataset)
+
+ # TODO: Test that the value targets are correct.
diff --git a/tests/test_dppo_recurrency.py b/tests/test_dppo_recurrency.py
new file mode 100644
index 0000000..dd860e8
--- /dev/null
+++ b/tests/test_dppo_recurrency.py
@@ -0,0 +1,170 @@
+import torch
+import unittest
+from rsl_rl.algorithms import DPPO
+from rsl_rl.env.vec_env import VecEnv
+from rsl_rl.runners.runner import Runner
+from rsl_rl.utils.benchmarkable import Benchmarkable
+
+
+class FakeNetwork(torch.nn.Module, Benchmarkable):
+ def __init__(self, values):
+ super().__init__()
+
+ self.hidden_state = None
+ self.quantile_count = 1
+ self.recurrent = True
+ self.values = values
+
+ self._hidden_size = 2
+
+ def forward(self, x, hidden_state=None):
+ if not hidden_state:
+ self.hidden_state = (self.hidden_state[0] + 1, self.hidden_state[1] - 1)
+
+ values = self.values.repeat((*x.shape[:-1], 1)).squeeze(-1)
+ values.requires_grad_(True)
+
+ return values
+
+ def reset_full_hidden_state(self, batch_size=None):
+ assert batch_size is None or batch_size == 4, f"batch_size={batch_size}"
+
+ self.hidden_state = (torch.zeros((1, 4, self._hidden_size)), torch.zeros((1, 4, self._hidden_size)))
+
+ def reset_hidden_state(self, indices):
+ self.hidden_state[0][:, indices] = torch.zeros((len(indices), self._hidden_size))
+ self.hidden_state[1][:, indices] = torch.zeros((len(indices), self._hidden_size))
+
+
+class FakeActorNetwork(FakeNetwork):
+ def forward(self, x, compute_std=False, hidden_state=None):
+ values = super().forward(x, hidden_state=hidden_state)
+
+ if compute_std:
+ return values, torch.ones_like(values)
+
+ return values
+
+
+class FakeCriticNetwork(FakeNetwork):
+ _quantile_count = 1
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x, distribution=False, hidden_state=None, measure_args=None):
+ values = super().forward(x, hidden_state=hidden_state)
+ self.last_quantiles = values.reshape(*values.shape, 1)
+
+ if distribution:
+ return self.last_quantiles
+
+ return values
+
+ def quantile_l1_loss(self, *args, **kwargs):
+ return torch.tensor(0.0)
+
+ def quantiles_to_values(self, quantiles):
+ return quantiles.squeeze()
+
+
+class FakeEnv(VecEnv):
+ def __init__(self, dones=None, **kwargs):
+ super().__init__(3, 3, **kwargs)
+
+ self.num_actions = 3
+ self._extra = {"observations": {}, "time_outs": torch.zeros((self.num_envs, 1))}
+
+ self._step = 0
+ self._dones = dones
+
+ self.reset()
+
+ def get_observations(self):
+ return self._state_buf, self._extra
+
+ def get_privileged_observations(self):
+ return self._state_buf, self._extra
+
+ def reset(self):
+ self._state_buf = torch.zeros((self.num_envs, self.num_obs))
+
+ return self._state_buf, self._extra
+
+ def step(self, actions):
+ assert actions.shape[0] == self.num_envs
+ assert actions.shape[1] == self.num_actions
+
+ self._state_buf += actions
+
+ rewards = torch.zeros((self.num_envs))
+ dones = torch.zeros((self.num_envs)) if self._dones is None else self._dones[self._step % self._dones.shape[0]]
+
+ self._step += 1
+
+ return self._state_buf, rewards, dones, self._extra
+
+
+class DPPORecurrencyTest(unittest.TestCase):
+ def test_draw_action_produces_hidden_state(self):
+ """Test that the hidden state is correctly added to the data dictionary when drawing actions."""
+ env = FakeEnv(environment_count=4)
+ dppo = DPPO(env, device="cpu", recurrent=True)
+
+ dppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
+ dppo.critic = FakeCriticNetwork(torch.zeros(1))
+
+        # This is done in DPPO.__init__ already; however, we must reset the hidden state again here because the fake
+        # networks are assigned after initialization.
+ dppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
+ dppo.critic.reset_full_hidden_state(batch_size=env.num_envs)
+
+ ones = torch.ones((1, env.num_envs, dppo.actor._hidden_size))
+ state, extra = env.reset()
+ for ctr in range(10):
+ _, data = dppo.draw_actions(state, extra)
+
+ # Actor state is changed every time an action is drawn.
+ self.assertTrue(torch.allclose(data["actor_state_h"], ones * ctr))
+ self.assertTrue(torch.allclose(data["actor_state_c"], -ones * ctr))
+ # Critic state is only changed and saved when processing the transition (evaluating the action) so we can't
+ # check it here.
+
+ def test_update_produces_hidden_state(self):
+ """Test that the hidden state is correctly added to the data dictionary when updating."""
+ dones = torch.cat((torch.tensor([[0, 0, 0, 1]]), torch.zeros((4, 4)), torch.tensor([[1, 0, 0, 0]])), dim=0)
+
+ env = FakeEnv(dones=dones, environment_count=4)
+ dppo = DPPO(env, device="cpu", recurrent=True)
+ runner = Runner(env, dppo, num_steps_per_env=6)
+
+ dppo._value_loss = lambda *args, **kwargs: torch.tensor(0.0)
+
+ dppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
+ dppo.critic = FakeCriticNetwork(torch.zeros(1))
+
+ dppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
+ dppo.critic.reset_full_hidden_state(batch_size=env.num_envs)
+
+ runner.learn(1)
+
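+        # FakeNetwork increments its h-state by one (and decrements its c-state by one) on every forward pass, and
+        # environments that report done have their hidden state reset to zero. The tensors below trace that evolution
+        # across the six collected steps for all four environments.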
+ state_h_0 = torch.tensor([[0, 0], [0, 0], [0, 0], [0, 0]])
+ state_h_1 = torch.tensor([[1, 1], [1, 1], [1, 1], [0, 0]])
+ state_h_2 = state_h_1 + 1
+ state_h_3 = state_h_2 + 1
+ state_h_4 = state_h_3 + 1
+ state_h_5 = state_h_4 + 1
+ state_h_6 = torch.tensor([[0, 0], [6, 6], [6, 6], [5, 5]])
+ state_h = (
+ torch.cat((state_h_0, state_h_1, state_h_2, state_h_3, state_h_4, state_h_5), dim=0).float().unsqueeze(1)
+ )
+ next_state_h = (
+ torch.cat((state_h_1, state_h_2, state_h_3, state_h_4, state_h_5, state_h_6), dim=0).float().unsqueeze(1)
+ )
+
+ self.assertTrue(torch.allclose(dppo.storage._data["critic_state_h"], state_h))
+ self.assertTrue(torch.allclose(dppo.storage._data["critic_state_c"], -state_h))
+ self.assertTrue(torch.allclose(dppo.storage._data["critic_next_state_h"], next_state_h))
+ self.assertTrue(torch.allclose(dppo.storage._data["critic_next_state_c"], -next_state_h))
+ self.assertTrue(torch.allclose(dppo.storage._data["actor_state_h"], state_h))
+ self.assertTrue(torch.allclose(dppo.storage._data["actor_state_c"], -state_h))
diff --git a/tests/test_ppo.py b/tests/test_ppo.py
new file mode 100644
index 0000000..4f8ef0f
--- /dev/null
+++ b/tests/test_ppo.py
@@ -0,0 +1,99 @@
+import torch
+import unittest
+from rsl_rl.algorithms import PPO
+from rsl_rl.env.pole_balancing import PoleBalancing
+
+
+class FakeCritic(torch.nn.Module):
+    def __init__(self, values):
+        super().__init__()
+
+        self.recurrent = False
+ self.values = values
+
+ def forward(self, _):
+ return self.values
+
+
+class PPOTest(unittest.TestCase):
+ def test_gae_computation(self):
+ # GT taken from old PPO implementation.
+
+ env = PoleBalancing(environment_count=4)
+ ppo = PPO(env, device="cpu", gae_lambda=0.97, gamma=0.99)
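+        # The expected values below follow the standard GAE recursion,
+        #   delta_t = r_t + gamma * (1 - done_t) * V_{t+1} - V_t,
+        #   A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1},
+        # with returns reconstructed as advantages + values.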
+
+ rewards = torch.tensor(
+ [
+ [-1.0000e02, -1.4055e-01, -3.0476e-02, -2.7149e-01, -1.1157e-01, -2.3366e-01, -3.3658e-01, -1.6447e-01],
+ [
+ -1.7633e-01,
+ -2.6533e-01,
+ -3.0786e-01,
+ -2.6038e-01,
+ -2.7176e-01,
+ -2.1655e-01,
+ -1.5441e-01,
+ -2.9580e-01,
+ ],
+ [-1.5952e-01, -1.5177e-01, -1.4296e-01, -1.6131e-01, -3.1395e-02, 2.8808e-03, -3.1242e-02, 4.8696e-03],
+ [1.1407e-02, -1.0000e02, -6.2290e-02, -3.7030e-01, -2.7648e-01, -3.6655e-01, -2.8456e-01, -2.3165e-01],
+ ]
+ )
+ dones = torch.tensor(
+ [
+ [1, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0, 0, 0],
+ [0, 1, 0, 0, 0, 0, 0, 0],
+ ]
+ )
+ observations = torch.zeros((dones.shape[0], dones.shape[1], 24))
+ timeouts = torch.zeros((dones.shape[0], dones.shape[1]))
+ values = torch.tensor(
+ [
+ [-4.6342, -7.6510, -7.0166, -7.6137, -7.4130, -7.7071, -7.7413, -7.8301],
+ [-7.0442, -7.0032, -6.9321, -6.7765, -6.5433, -6.3503, -6.2529, -5.9337],
+ [-7.5753, -7.8146, -7.6142, -7.8443, -7.8791, -7.7973, -7.7853, -7.7724],
+ [-6.4326, -6.1673, -7.6511, -7.7505, -8.0004, -7.8584, -7.5949, -7.9023],
+ ]
+ )
+ last_values = torch.tensor([-7.9343, -5.8734, -7.8527, -8.1257])
+
+ ppo.critic = FakeCritic(last_values)
+ dataset = [
+ {
+ "dones": dones[:, i],
+ "next_critic_observations": observations[:, i],
+ "rewards": rewards[:, i],
+ "timeouts": timeouts[:, i],
+ "values": values[:, i],
+ }
+ for i in range(dones.shape[1])
+ ]
+
+ processed_dataset = ppo._process_dataset(dataset)
+ processed_returns = torch.stack(
+ [processed_dataset[i]["advantages"] + processed_dataset[i]["values"] for i in range(dones.shape[1])],
+ dim=-1,
+ )
+ processed_advantages = torch.stack(
+ [processed_dataset[i]["normalized_advantages"] for i in range(dones.shape[1])], dim=-1
+ )
+
+ expected_returns = torch.tensor(
+ [
+ [-100.0000, -8.4983, -8.4863, -8.5699, -8.4122, -8.4054, -8.2702, -8.0194],
+ [-7.2900, -7.1912, -6.9978, -6.7569, -6.5627, -6.3547, -6.1985, -6.1104],
+ [-7.9179, -7.8374, -7.7679, -7.6976, -7.6041, -7.6446, -7.7229, -7.7693],
+ [-96.2018, -100.0000, -9.0710, -9.1415, -8.8863, -8.7228, -8.4668, -8.2761],
+ ]
+ )
+ expected_advantages = torch.tensor(
+ [
+ [-3.1452, 0.3006, 0.2779, 0.2966, 0.2951, 0.3060, 0.3122, 0.3246],
+ [0.3225, 0.3246, 0.3291, 0.3322, 0.3308, 0.3313, 0.3335, 0.3250],
+ [0.3190, 0.3307, 0.3259, 0.3368, 0.3415, 0.3371, 0.3338, 0.3316],
+ [-2.9412, -3.0893, 0.2797, 0.2808, 0.2992, 0.3000, 0.2997, 0.3179],
+ ]
+ )
+
+ self.assertTrue(torch.isclose(processed_returns, expected_returns, atol=1e-4).all())
+ self.assertTrue(torch.isclose(processed_advantages, expected_advantages, atol=1e-4).all())
diff --git a/tests/test_ppo_recurrency.py b/tests/test_ppo_recurrency.py
new file mode 100644
index 0000000..702b19d
--- /dev/null
+++ b/tests/test_ppo_recurrency.py
@@ -0,0 +1,144 @@
+import torch
+import unittest
+from rsl_rl.algorithms import PPO
+from rsl_rl.env.vec_env import VecEnv
+from rsl_rl.runners.runner import Runner
+
+
+class FakeNetwork(torch.nn.Module):
+ def __init__(self, values):
+ super().__init__()
+
+ self.hidden_state = None
+ self.recurrent = True
+ self.values = values
+
+ self._hidden_size = 2
+
+ def forward(self, x, hidden_state=None):
+ if not hidden_state:
+ self.hidden_state = (self.hidden_state[0] + 1, self.hidden_state[1] - 1)
+
+ values = self.values.repeat((*x.shape[:-1], 1)).squeeze(-1)
+ values.requires_grad_(True)
+
+ return values
+
+ def reset_full_hidden_state(self, batch_size=None):
+ assert batch_size is None or batch_size == 4, f"batch_size={batch_size}"
+
+ self.hidden_state = (torch.zeros((1, 4, self._hidden_size)), torch.zeros((1, 4, self._hidden_size)))
+
+ def reset_hidden_state(self, indices):
+ self.hidden_state[0][:, indices] = torch.zeros((len(indices), self._hidden_size))
+ self.hidden_state[1][:, indices] = torch.zeros((len(indices), self._hidden_size))
+
+
+class FakeActorNetwork(FakeNetwork):
+ def forward(self, x, compute_std=False, hidden_state=None):
+ values = super().forward(x, hidden_state=hidden_state)
+
+ if compute_std:
+ return values, torch.ones_like(values)
+
+ return values
+
+
+class FakeEnv(VecEnv):
+ def __init__(self, dones=None, **kwargs):
+ super().__init__(3, 3, **kwargs)
+
+ self.num_actions = 3
+ self._extra = {"observations": {}, "time_outs": torch.zeros((self.num_envs, 1))}
+
+ self._step = 0
+ self._dones = dones
+
+ self.reset()
+
+ def get_observations(self):
+ return self._state_buf, self._extra
+
+ def get_privileged_observations(self):
+ return self._state_buf, self._extra
+
+ def reset(self):
+ self._state_buf = torch.zeros((self.num_envs, self.num_obs))
+
+ return self._state_buf, self._extra
+
+ def step(self, actions):
+ assert actions.shape[0] == self.num_envs
+ assert actions.shape[1] == self.num_actions
+
+ self._state_buf += actions
+
+ rewards = torch.zeros((self.num_envs))
+ dones = torch.zeros((self.num_envs)) if self._dones is None else self._dones[self._step % self._dones.shape[0]]
+
+ self._step += 1
+
+ return self._state_buf, rewards, dones, self._extra
+
+
+class PPORecurrencyTest(unittest.TestCase):
+ def test_draw_action_produces_hidden_state(self):
+ """Test that the hidden state is correctly added to the data dictionary when drawing actions."""
+ env = FakeEnv(environment_count=4)
+ ppo = PPO(env, device="cpu", recurrent=True)
+
+ ppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
+ ppo.critic = FakeNetwork(torch.zeros(1))
+
+        # This is done in PPO.__init__ already; however, we must reset the hidden state again here because the fake
+        # networks are assigned after initialization.
+ ppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
+ ppo.critic.reset_full_hidden_state(batch_size=env.num_envs)
+
+ ones = torch.ones((1, env.num_envs, ppo.actor._hidden_size))
+ state, extra = env.reset()
+ for ctr in range(10):
+ _, data = ppo.draw_actions(state, extra)
+
+ # Actor state is changed every time an action is drawn.
+ self.assertTrue(torch.allclose(data["actor_state_h"], ones * ctr))
+ self.assertTrue(torch.allclose(data["actor_state_c"], -ones * ctr))
+ # Critic state is only changed and saved when processing the transition (evaluating the action) so we can't
+ # check it here.
+
+ def test_update_produces_hidden_state(self):
+ """Test that the hidden state is correctly added to the data dictionary when updating."""
+ dones = torch.cat((torch.tensor([[0, 0, 0, 1]]), torch.zeros((4, 4)), torch.tensor([[1, 0, 0, 0]])), dim=0)
+
+ env = FakeEnv(dones=dones, environment_count=4)
+ ppo = PPO(env, device="cpu", recurrent=True)
+ runner = Runner(env, ppo, num_steps_per_env=6)
+
+ ppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
+ ppo.critic = FakeNetwork(torch.zeros(1))
+
+ ppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
+ ppo.critic.reset_full_hidden_state(batch_size=env.num_envs)
+
+ runner.learn(1)
+
+ state_h_0 = torch.tensor([[0, 0], [0, 0], [0, 0], [0, 0]])
+ state_h_1 = torch.tensor([[1, 1], [1, 1], [1, 1], [0, 0]])
+ state_h_2 = state_h_1 + 1
+ state_h_3 = state_h_2 + 1
+ state_h_4 = state_h_3 + 1
+ state_h_5 = state_h_4 + 1
+ state_h_6 = torch.tensor([[0, 0], [6, 6], [6, 6], [5, 5]])
+ state_h = (
+ torch.cat((state_h_0, state_h_1, state_h_2, state_h_3, state_h_4, state_h_5), dim=0).float().unsqueeze(1)
+ )
+ next_state_h = (
+ torch.cat((state_h_1, state_h_2, state_h_3, state_h_4, state_h_5, state_h_6), dim=0).float().unsqueeze(1)
+ )
+
+ self.assertTrue(torch.allclose(ppo.storage._data["critic_state_h"], state_h))
+ self.assertTrue(torch.allclose(ppo.storage._data["critic_state_c"], -state_h))
+ self.assertTrue(torch.allclose(ppo.storage._data["critic_next_state_h"], next_state_h))
+ self.assertTrue(torch.allclose(ppo.storage._data["critic_next_state_c"], -next_state_h))
+ self.assertTrue(torch.allclose(ppo.storage._data["actor_state_h"], state_h))
+ self.assertTrue(torch.allclose(ppo.storage._data["actor_state_c"], -state_h))
diff --git a/tests/test_quantile_network.py b/tests/test_quantile_network.py
new file mode 100644
index 0000000..eaea109
--- /dev/null
+++ b/tests/test_quantile_network.py
@@ -0,0 +1,286 @@
+import torch
+import unittest
+from rsl_rl.modules.quantile_network import QuantileNetwork
+
+
+class QuantileNetworkTest(unittest.TestCase):
+ def test_l1_loss(self):
+ qn = QuantileNetwork(10, 1, quantile_count=5)
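+        # quantile_l1_loss is the pinball (quantile regression) loss used for QR-DQN-style critics; the reference
+        # values in these tests were precomputed with the repository implementation.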
+
+ prediction = torch.tensor(
+ [
+ [0.8510, 0.2329, 0.4244, 0.5241, 0.2144],
+ [0.7693, 0.2522, 0.3909, 0.0858, 0.7914],
+ [0.8701, 0.2144, 0.9661, 0.9975, 0.5043],
+ [0.2653, 0.6951, 0.9787, 0.2244, 0.0430],
+ [0.7907, 0.5209, 0.7276, 0.1735, 0.2757],
+ [0.1696, 0.7167, 0.6363, 0.2188, 0.7025],
+ [0.0445, 0.6008, 0.5334, 0.1838, 0.7387],
+ [0.4934, 0.5117, 0.4488, 0.0591, 0.6442],
+ ]
+ )
+ target = torch.tensor(
+ [
+ [0.3918, 0.8979, 0.4347, 0.1076, 0.5303],
+ [0.5449, 0.9974, 0.3197, 0.8686, 0.0631],
+ [0.7397, 0.7734, 0.6559, 0.3020, 0.7229],
+ [0.9519, 0.8138, 0.1502, 0.3445, 0.3356],
+ [0.8970, 0.0910, 0.7536, 0.6069, 0.2556],
+ [0.1741, 0.6863, 0.7142, 0.2911, 0.3142],
+ [0.8835, 0.0215, 0.4774, 0.5362, 0.4998],
+ [0.8037, 0.8269, 0.5518, 0.4368, 0.5323],
+ ]
+ )
+
+ loss = qn.quantile_l1_loss(prediction, target)
+
+ self.assertAlmostEqual(loss.item(), 0.16419549)
+
+ def test_l1_loss_3d(self):
+ qn = QuantileNetwork(10, 1, quantile_count=5)
+
+ prediction = torch.tensor(
+ [
+ [
+ [0.8510, 0.2329, 0.4244, 0.5241, 0.2144],
+ [0.7693, 0.2522, 0.3909, 0.0858, 0.7914],
+ [0.8701, 0.2144, 0.9661, 0.9975, 0.5043],
+ [0.2653, 0.6951, 0.9787, 0.2244, 0.0430],
+ [0.7907, 0.5209, 0.7276, 0.1735, 0.2757],
+ [0.1696, 0.7167, 0.6363, 0.2188, 0.7025],
+ [0.0445, 0.6008, 0.5334, 0.1838, 0.7387],
+ [0.4934, 0.5117, 0.4488, 0.0591, 0.6442],
+ ],
+ [
+ [0.6874, 0.6214, 0.7796, 0.8148, 0.2070],
+ [0.0276, 0.5764, 0.5516, 0.9682, 0.6901],
+ [0.4020, 0.7084, 0.9965, 0.4311, 0.3789],
+ [0.5350, 0.9431, 0.1032, 0.6959, 0.4992],
+ [0.5059, 0.5479, 0.2302, 0.6753, 0.1593],
+ [0.6753, 0.4590, 0.9956, 0.6117, 0.1410],
+ [0.7464, 0.7184, 0.2972, 0.7694, 0.7999],
+ [0.3907, 0.2112, 0.6485, 0.0139, 0.6252],
+ ],
+ ]
+ )
+ target = torch.tensor(
+ [
+ [
+ [0.3918, 0.8979, 0.4347, 0.1076, 0.5303],
+ [0.5449, 0.9974, 0.3197, 0.8686, 0.0631],
+ [0.7397, 0.7734, 0.6559, 0.3020, 0.7229],
+ [0.9519, 0.8138, 0.1502, 0.3445, 0.3356],
+ [0.8970, 0.0910, 0.7536, 0.6069, 0.2556],
+ [0.1741, 0.6863, 0.7142, 0.2911, 0.3142],
+ [0.8835, 0.0215, 0.4774, 0.5362, 0.4998],
+ [0.8037, 0.8269, 0.5518, 0.4368, 0.5323],
+ ],
+ [
+ [0.5120, 0.7683, 0.3579, 0.8640, 0.4374],
+ [0.2533, 0.3039, 0.2214, 0.7069, 0.3093],
+ [0.6993, 0.4288, 0.0827, 0.9156, 0.2043],
+ [0.6739, 0.2303, 0.3263, 0.6884, 0.3847],
+ [0.3990, 0.1458, 0.8918, 0.8036, 0.5012],
+ [0.9061, 0.2024, 0.7276, 0.8619, 0.1198],
+ [0.7379, 0.2005, 0.7634, 0.5691, 0.6132],
+ [0.4341, 0.5711, 0.1119, 0.4286, 0.7521],
+ ],
+ ]
+ )
+
+ loss = qn.quantile_l1_loss(prediction, target)
+
+ self.assertAlmostEqual(loss.item(), 0.15836075)
+
+ def test_l1_loss_multi_output(self):
+ qn = QuantileNetwork(10, 3, quantile_count=10)
+
+ prediction = torch.tensor(
+ [
+ [0.3003, 0.8692, 0.4608, 0.7158, 0.2640, 0.3928, 0.4557, 0.4620, 0.1331, 0.6356],
+ [0.8867, 0.1521, 0.5827, 0.0501, 0.4401, 0.7216, 0.6081, 0.5758, 0.2772, 0.6048],
+ [0.0763, 0.1609, 0.1860, 0.9173, 0.2121, 0.1920, 0.8509, 0.8588, 0.3321, 0.7202],
+ [0.8375, 0.5339, 0.4287, 0.9228, 0.8519, 0.0420, 0.5736, 0.9156, 0.4444, 0.2039],
+ [0.0704, 0.1833, 0.0839, 0.9573, 0.9852, 0.4191, 0.3562, 0.7225, 0.8481, 0.2096],
+ [0.4054, 0.8172, 0.8737, 0.2138, 0.4455, 0.7538, 0.1936, 0.9346, 0.8710, 0.0178],
+ [0.2139, 0.6619, 0.6889, 0.5726, 0.0595, 0.3278, 0.7673, 0.0803, 0.0374, 0.9011],
+ [0.2757, 0.0309, 0.8913, 0.0958, 0.1828, 0.9624, 0.6529, 0.7451, 0.9996, 0.8877],
+ [0.0722, 0.4240, 0.0716, 0.3199, 0.5570, 0.1056, 0.5950, 0.9926, 0.2991, 0.7334],
+ [0.0576, 0.6353, 0.5078, 0.4456, 0.9119, 0.6897, 0.1720, 0.5172, 0.9939, 0.5044],
+ [0.6300, 0.2304, 0.4064, 0.9195, 0.3299, 0.8631, 0.5842, 0.6751, 0.2964, 0.1215],
+ [0.7418, 0.5448, 0.7615, 0.6333, 0.9255, 0.1129, 0.0552, 0.4198, 0.9953, 0.7482],
+ [0.9910, 0.7644, 0.7047, 0.1395, 0.3688, 0.7688, 0.8574, 0.3494, 0.6153, 0.1286],
+ [0.2325, 0.7908, 0.3036, 0.4504, 0.3775, 0.6004, 0.0199, 0.9581, 0.8078, 0.8337],
+ [0.4038, 0.8313, 0.5441, 0.4778, 0.5777, 0.0580, 0.5314, 0.5336, 0.0740, 0.0094],
+ [0.9025, 0.5814, 0.4711, 0.2683, 0.4443, 0.5799, 0.6703, 0.2678, 0.7538, 0.1317],
+ [0.6755, 0.5696, 0.3334, 0.9146, 0.6203, 0.2080, 0.0799, 0.0059, 0.8347, 0.1874],
+ [0.0932, 0.0264, 0.9006, 0.3124, 0.3421, 0.8271, 0.3495, 0.2814, 0.9888, 0.5042],
+ [0.4893, 0.3514, 0.2564, 0.8117, 0.3738, 0.9085, 0.3055, 0.1456, 0.3624, 0.4095],
+ [0.0726, 0.2145, 0.6295, 0.7423, 0.1292, 0.7570, 0.4645, 0.0775, 0.1280, 0.7312],
+ [0.8763, 0.5302, 0.8627, 0.0429, 0.2833, 0.4745, 0.6308, 0.2245, 0.2755, 0.6823],
+ [0.9997, 0.3519, 0.0312, 0.1468, 0.5145, 0.0286, 0.6333, 0.1323, 0.2264, 0.9109],
+ [0.7742, 0.4857, 0.0413, 0.4523, 0.6847, 0.5774, 0.9478, 0.5861, 0.9834, 0.9437],
+ [0.7590, 0.5697, 0.7509, 0.3562, 0.9926, 0.3380, 0.0337, 0.7871, 0.1351, 0.9184],
+ [0.5701, 0.0234, 0.8088, 0.0681, 0.7090, 0.5925, 0.5266, 0.7198, 0.4121, 0.0268],
+ [0.5377, 0.1420, 0.2649, 0.0885, 0.1987, 0.1475, 0.1562, 0.2283, 0.9447, 0.4679],
+ [0.0306, 0.9763, 0.1234, 0.5009, 0.8800, 0.9409, 0.3525, 0.7264, 0.2209, 0.1436],
+ [0.2492, 0.4041, 0.9044, 0.3730, 0.3152, 0.7515, 0.2614, 0.9726, 0.6402, 0.5211],
+ [0.8626, 0.2828, 0.6946, 0.7066, 0.4395, 0.3015, 0.2643, 0.4421, 0.6036, 0.9009],
+ [0.7721, 0.1706, 0.7043, 0.4097, 0.7685, 0.3818, 0.1468, 0.6452, 0.1102, 0.1826],
+ [0.7156, 0.1795, 0.5574, 0.9478, 0.0058, 0.8037, 0.8712, 0.7730, 0.5638, 0.5843],
+ [0.8775, 0.6133, 0.4118, 0.3038, 0.2612, 0.2424, 0.8960, 0.8194, 0.3588, 0.3198],
+ ]
+ )
+
+ target = torch.tensor(
+ [
+ [0.0986, 0.4029, 0.3110, 0.9976, 0.5668, 0.2658, 0.0660, 0.8492, 0.7872, 0.6368],
+ [0.3556, 0.9007, 0.0227, 0.7684, 0.0105, 0.9890, 0.7468, 0.0642, 0.5164, 0.1976],
+ [0.1331, 0.0998, 0.0959, 0.5596, 0.5984, 0.3880, 0.8050, 0.8320, 0.8977, 0.3486],
+ [0.3297, 0.8110, 0.2844, 0.4594, 0.0739, 0.2865, 0.2957, 0.9357, 0.9898, 0.4419],
+ [0.0495, 0.2826, 0.8306, 0.2968, 0.5690, 0.7251, 0.5947, 0.7526, 0.5076, 0.6480],
+ [0.0381, 0.8645, 0.7774, 0.9158, 0.9682, 0.5851, 0.0913, 0.8948, 0.1251, 0.1205],
+ [0.9059, 0.2758, 0.1948, 0.2694, 0.0946, 0.4381, 0.4667, 0.2176, 0.3494, 0.6073],
+ [0.1778, 0.8632, 0.3015, 0.2882, 0.4214, 0.2420, 0.8394, 0.1468, 0.9679, 0.6730],
+ [0.2400, 0.4344, 0.9765, 0.6544, 0.6338, 0.3434, 0.4776, 0.7981, 0.2008, 0.2267],
+ [0.5574, 0.8110, 0.0264, 0.4199, 0.8178, 0.8421, 0.8237, 0.2623, 0.8025, 0.9030],
+ [0.8652, 0.2872, 0.9463, 0.5543, 0.4866, 0.2842, 0.6692, 0.2306, 0.3136, 0.4570],
+ [0.0651, 0.8955, 0.7531, 0.9373, 0.0265, 0.0795, 0.7755, 0.1123, 0.1920, 0.3273],
+ [0.9824, 0.4177, 0.2729, 0.9447, 0.3987, 0.5495, 0.3674, 0.8067, 0.8668, 0.2394],
+ [0.4874, 0.3616, 0.7577, 0.6439, 0.2927, 0.8110, 0.6821, 0.0702, 0.5514, 0.7358],
+ [0.3627, 0.6392, 0.9085, 0.3646, 0.6051, 0.0586, 0.8763, 0.3899, 0.3242, 0.4598],
+ [0.0167, 0.0558, 0.3862, 0.7017, 0.0403, 0.6604, 0.9992, 0.2337, 0.5128, 0.1959],
+ [0.7774, 0.9201, 0.0405, 0.7894, 0.1406, 0.2458, 0.2616, 0.8787, 0.8158, 0.8591],
+ [0.3225, 0.9827, 0.4032, 0.2621, 0.7949, 0.9796, 0.9480, 0.3353, 0.1430, 0.5747],
+ [0.4734, 0.8714, 0.9320, 0.4265, 0.7765, 0.6980, 0.1587, 0.8784, 0.7119, 0.5141],
+ [0.7263, 0.4754, 0.8234, 0.0649, 0.4343, 0.5201, 0.8274, 0.9632, 0.3525, 0.8893],
+ [0.3324, 0.0142, 0.7222, 0.5026, 0.6011, 0.9275, 0.9351, 0.9236, 0.2621, 0.0768],
+ [0.8456, 0.1005, 0.5550, 0.0586, 0.3811, 0.0168, 0.9724, 0.9225, 0.7242, 0.0678],
+ [0.2167, 0.5423, 0.9059, 0.3320, 0.4026, 0.2128, 0.4562, 0.3564, 0.2573, 0.1076],
+ [0.8385, 0.2233, 0.0736, 0.3407, 0.4702, 0.1668, 0.5174, 0.4154, 0.4407, 0.1843],
+ [0.1828, 0.5321, 0.6651, 0.4108, 0.5736, 0.4012, 0.0434, 0.0034, 0.9282, 0.3111],
+ [0.1754, 0.8750, 0.6629, 0.7052, 0.9739, 0.7441, 0.8954, 0.9273, 0.3836, 0.5735],
+ [0.5586, 0.0381, 0.1493, 0.8575, 0.9351, 0.5222, 0.5600, 0.2369, 0.9217, 0.2545],
+ [0.1054, 0.8020, 0.8463, 0.6495, 0.3011, 0.3734, 0.7263, 0.8736, 0.9258, 0.5804],
+ [0.7614, 0.4748, 0.6588, 0.7717, 0.9811, 0.1659, 0.7851, 0.2135, 0.1767, 0.6724],
+ [0.7655, 0.8571, 0.4224, 0.9397, 0.1363, 0.9431, 0.9326, 0.3762, 0.1077, 0.9514],
+ [0.4115, 0.2169, 0.1340, 0.6564, 0.9989, 0.8068, 0.0387, 0.5064, 0.9964, 0.9427],
+ [0.5760, 0.2967, 0.3891, 0.6596, 0.8037, 0.1060, 0.0102, 0.8672, 0.5922, 0.6684],
+ ]
+ )
+
+ loss = qn.quantile_l1_loss(prediction, target)
+
+ self.assertAlmostEqual(loss.item(), 0.17235948)
+
+ def test_quantile_huber_loss(self):
+ qn = QuantileNetwork(10, 1, quantile_count=5)
+
+ prediction = torch.tensor(
+ [
+ [0.8510, 0.2329, 0.4244, 0.5241, 0.2144],
+ [0.7693, 0.2522, 0.3909, 0.0858, 0.7914],
+ [0.8701, 0.2144, 0.9661, 0.9975, 0.5043],
+ [0.2653, 0.6951, 0.9787, 0.2244, 0.0430],
+ [0.7907, 0.5209, 0.7276, 0.1735, 0.2757],
+ [0.1696, 0.7167, 0.6363, 0.2188, 0.7025],
+ [0.0445, 0.6008, 0.5334, 0.1838, 0.7387],
+ [0.4934, 0.5117, 0.4488, 0.0591, 0.6442],
+ ]
+ )
+ target = torch.tensor(
+ [
+ [0.3918, 0.8979, 0.4347, 0.1076, 0.5303],
+ [0.5449, 0.9974, 0.3197, 0.8686, 0.0631],
+ [0.7397, 0.7734, 0.6559, 0.3020, 0.7229],
+ [0.9519, 0.8138, 0.1502, 0.3445, 0.3356],
+ [0.8970, 0.0910, 0.7536, 0.6069, 0.2556],
+ [0.1741, 0.6863, 0.7142, 0.2911, 0.3142],
+ [0.8835, 0.0215, 0.4774, 0.5362, 0.4998],
+ [0.8037, 0.8269, 0.5518, 0.4368, 0.5323],
+ ]
+ )
+
+ loss = qn.quantile_huber_loss(prediction, target)
+
+ self.assertAlmostEqual(loss.item(), 0.04035041)
+
+ def test_sample_energy_loss(self):
+ qn = QuantileNetwork(10, 1, quantile_count=5)
+
+ prediction = torch.tensor(
+ [
+ [0.9813, 0.5331, 0.3298, 0.2428, 0.0737],
+ [0.5442, 0.9623, 0.6070, 0.9360, 0.1145],
+ [0.3642, 0.0887, 0.1696, 0.8027, 0.7121],
+ [0.2005, 0.9889, 0.4350, 0.0301, 0.4546],
+ [0.8360, 0.6766, 0.2257, 0.7589, 0.3443],
+ [0.0835, 0.1747, 0.1734, 0.6668, 0.4522],
+ [0.0851, 0.3146, 0.0316, 0.2250, 0.5729],
+ [0.7725, 0.4596, 0.2495, 0.3633, 0.6340],
+ ]
+ )
+ target = torch.tensor(
+ [
+ [0.5365, 0.1495, 0.8120, 0.2595, 0.1409],
+ [0.7784, 0.7070, 0.9066, 0.0123, 0.5587],
+ [0.9097, 0.0773, 0.9430, 0.2747, 0.1912],
+ [0.2307, 0.5068, 0.4624, 0.6708, 0.2844],
+ [0.3356, 0.5885, 0.2484, 0.8468, 0.1833],
+ [0.3354, 0.8831, 0.3489, 0.7165, 0.7953],
+ [0.7577, 0.8578, 0.2735, 0.1029, 0.5621],
+ [0.9124, 0.3476, 0.2012, 0.5830, 0.4615],
+ ]
+ )
+
+ loss = qn.sample_energy_loss(prediction, target)
+
+ self.assertAlmostEqual(loss.item(), 0.09165202)
+
+ def test_cvar(self):
+ qn = QuantileNetwork(10, 1, quantile_count=5)
+ measure = qn.measures[qn.measure_cvar](qn, 0.5)
+
+ # Quantiles for 3 agents
+ input = torch.tensor(
+ [
+ [0.1056, 0.0609, 0.3523, 0.3033, 0.1779],
+ [0.2049, 0.1425, 0.0767, 0.1868, 0.3891],
+ [0.1899, 0.1527, 0.2420, 0.2623, 0.1532],
+ ]
+ )
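+        # With 5 quantiles each carrying probability mass 0.2, CVaR at confidence 0.5 averages the lower half of the
+        # sorted quantiles: the two smallest receive weight 0.2 / 0.5 = 0.4 and half of the third contributes
+        # 0.1 / 0.5 = 0.2.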
+ correct_output = torch.tensor(
+ [
+ (0.4 * 0.0609 + 0.4 * 0.1056 + 0.2 * 0.1779),
+ (0.4 * 0.0767 + 0.4 * 0.1425 + 0.2 * 0.1868),
+ (0.4 * 0.1527 + 0.4 * 0.1532 + 0.2 * 0.1899),
+ ]
+ )
+
+ computed_output = measure(input)
+
+ self.assertTrue(torch.isclose(computed_output, correct_output).all())
+
+ def test_cvar_adaptive(self):
+ qn = QuantileNetwork(10, 1, quantile_count=5)
+
+ input = torch.tensor(
+ [
+ [0.95, 0.21, 0.27, 0.26, 0.19],
+ [0.38, 0.34, 0.18, 0.32, 0.97],
+ [0.70, 0.24, 0.38, 0.89, 0.96],
+ ]
+ )
+ confidence_levels = torch.tensor([0.1, 0.7, 0.9])
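+        # Per-row confidence levels: alpha = 0.1 covers only (half of) the smallest quantile, alpha = 0.7 covers 3.5
+        # quantiles (three at weight 0.2 / 0.7, half of the fourth at 0.1 / 0.7), and alpha = 0.9 covers 4.5 quantiles.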
+ correct_output = torch.tensor(
+ [
+ 0.19,
+ (0.18 / 3.5 + 0.32 / 3.5 + 0.34 / 3.5 + 0.38 / 7.0),
+ (0.24 / 4.5 + 0.38 / 4.5 + 0.70 / 4.5 + 0.89 / 4.5 + 0.96 / 9.0),
+ ]
+ )
+
+ measure = qn.measures[qn.measure_cvar](qn, confidence_levels)
+ computed_output = measure(input)
+
+ self.assertTrue(torch.isclose(computed_output, correct_output).all())
diff --git a/tests/test_trajectory_conversion.py b/tests/test_trajectory_conversion.py
new file mode 100644
index 0000000..796407c
--- /dev/null
+++ b/tests/test_trajectory_conversion.py
@@ -0,0 +1,33 @@
+import torch
+import unittest
+
+from rsl_rl.utils.recurrency import trajectories_to_transitions, transitions_to_trajectories
+
+
+class TrajectoryConversionTest(unittest.TestCase):
+ def test_basic_conversion(self):
+ input = torch.rand(128, 24)
+ dones = (torch.rand(128, 24) > 0.8).float()
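+        # The two helpers should form an exact inverse pair: transitions are grouped into per-episode trajectories
+        # using the done flags and then unpacked again, so the round trip must reproduce the original input.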
+
+ trajectories, data = transitions_to_trajectories(input, dones)
+ transitions = trajectories_to_transitions(trajectories, data)
+
+ self.assertTrue(torch.allclose(input, transitions))
+
+ def test_2d_observations(self):
+ input = torch.rand(128, 24, 32)
+ dones = (torch.rand(128, 24) > 0.8).float()
+
+ trajectories, data = transitions_to_trajectories(input, dones)
+ transitions = trajectories_to_transitions(trajectories, data)
+
+ self.assertTrue(torch.allclose(input, transitions))
+
+ def test_batch_first(self):
+ input = torch.rand(128, 24, 32)
+ dones = (torch.rand(128, 24) > 0.8).float()
+
+ trajectories, data = transitions_to_trajectories(input, dones, batch_first=True)
+ transitions = trajectories_to_transitions(trajectories, data)
+
+ self.assertTrue(torch.allclose(input, transitions))
diff --git a/tests/test_transformer.py b/tests/test_transformer.py
new file mode 100644
index 0000000..1a7ebcc
--- /dev/null
+++ b/tests/test_transformer.py
@@ -0,0 +1,58 @@
+import unittest
+import torch
+
+from rsl_rl.modules import Transformer
+
+
+class TestTransformer(unittest.TestCase):
+ def setUp(self):
+ self.input_size = 9
+ self.output_size = 12
+ self.hidden_size = 64
+ self.block_count = 2
+ self.context_length = 32
+ self.head_count = 4
+ self.batch_size = 10
+ self.sequence_length = 16
+
+ self.transformer = Transformer(
+ self.input_size, self.output_size, self.hidden_size, self.block_count, self.context_length, self.head_count
+ )
+
+ def test_num_layers(self):
+ self.assertEqual(self.transformer.num_layers, self.context_length // 2)
+
+ def test_reset_hidden_state(self):
+ hidden_state = self.transformer.reset_hidden_state(self.batch_size)
+ self.assertIsInstance(hidden_state, tuple)
+ self.assertEqual(len(hidden_state), 2)
+ self.assertTrue(
+ torch.equal(hidden_state[0], torch.zeros((self.transformer.num_layers, self.batch_size, self.hidden_size)))
+ )
+ self.assertTrue(
+ torch.equal(hidden_state[1], torch.zeros((self.transformer.num_layers, self.batch_size, self.hidden_size)))
+ )
+
+ def test_step(self):
+ x = torch.rand(self.sequence_length, self.batch_size, self.input_size)
+ context = torch.rand(self.context_length, self.batch_size, self.hidden_size)
+
+ out, new_context = self.transformer.step(x, context)
+
+ self.assertEqual(out.shape, (self.sequence_length, self.batch_size, self.output_size))
+ self.assertEqual(new_context.shape, (self.context_length, self.batch_size, self.hidden_size))
+
+ def test_forward(self):
+ x = torch.rand(self.sequence_length, self.batch_size, self.input_size)
+ hidden_state = self.transformer.reset_hidden_state(self.batch_size)
+
+ out, new_hidden_state = self.transformer.forward(x, hidden_state)
+
+ self.assertEqual(out.shape, (self.sequence_length, self.batch_size, self.output_size))
+ self.assertEqual(len(new_hidden_state), 2)
+ self.assertEqual(new_hidden_state[0].shape, (self.transformer.num_layers, self.batch_size, self.hidden_size))
+ self.assertEqual(new_hidden_state[1].shape, (self.transformer.num_layers, self.batch_size, self.hidden_size))
+
+
+if __name__ == "__main__":
+ unittest.main()