From 96dd4929c5724de5c95de25021ba384cce09d634 Mon Sep 17 00:00:00 2001 From: Lukas Schneider Date: Tue, 12 Dec 2023 18:32:21 +0100 Subject: [PATCH] added rewrite of rsl_rl for supporting additional algorithms --- .gitignore | 11 + CONTRIBUTORS.md | 1 + README.md | 77 +-- config/dummy_config.yaml | 48 -- docs/Makefile | 20 + docs/conf.py | 32 ++ docs/index.rst | 20 + docs/make.bat | 35 ++ examples/__init__.py | 0 examples/benchmark.py | 92 ++++ examples/example.py | 26 + examples/hyperparams/__init__.py | 7 + examples/hyperparams/dppo.py | 202 +++++++ examples/hyperparams/ppo.py | 226 ++++++++ examples/tune.py | 103 ++++ examples/tune_cfg.py | 134 +++++ examples/wandb_config.example.py | 4 + rsl_rl/__init__.py | 5 - rsl_rl/algorithms/__init__.py | 14 +- rsl_rl/algorithms/actor_critic.py | 371 +++++++++++++ rsl_rl/algorithms/agent.py | 197 +++++++ rsl_rl/algorithms/d4pg.py | 168 ++++++ rsl_rl/algorithms/ddpg.py | 125 +++++ rsl_rl/algorithms/dpg.py | 49 ++ rsl_rl/algorithms/dppo.py | 327 ++++++++++++ rsl_rl/algorithms/dsac.py | 75 +++ rsl_rl/algorithms/dtd3.py | 57 ++ rsl_rl/algorithms/hybrid.py | 193 +++++++ rsl_rl/algorithms/ppo.py | 501 ++++++++++++------ rsl_rl/algorithms/sac.py | 319 +++++++++++ rsl_rl/algorithms/td3.py | 198 +++++++ rsl_rl/distributions/__init__.py | 2 + rsl_rl/distributions/distribution.py | 18 + rsl_rl/distributions/quantile_distribution.py | 13 + rsl_rl/env/__init__.py | 1 - rsl_rl/env/gym_env.py | 120 +++++ rsl_rl/env/pole_balancing.py | 137 +++++ rsl_rl/env/pomdp.py | 97 ++++ rsl_rl/env/rslgym_env.py | 44 ++ rsl_rl/env/vec_env.py | 99 ++-- rsl_rl/modules/__init__.py | 25 +- rsl_rl/modules/actor_critic.py | 136 ----- rsl_rl/modules/actor_critic_recurrent.py | 97 ---- rsl_rl/modules/categorical_network.py | 105 ++++ rsl_rl/modules/gaussian_chimera_network.py | 86 +++ rsl_rl/modules/gaussian_network.py | 37 ++ rsl_rl/modules/implicit_quantile_network.py | 168 ++++++ rsl_rl/modules/network.py | 210 ++++++++ rsl_rl/modules/normalizer.py | 61 ++- rsl_rl/modules/quantile_network.py | 333 ++++++++++++ rsl_rl/modules/transformer.py | 150 ++++++ rsl_rl/modules/utils.py | 32 ++ rsl_rl/runners/__init__.py | 5 +- rsl_rl/runners/callbacks.py | 79 +++ rsl_rl/runners/legacy_runner.py | 136 +++++ rsl_rl/runners/on_policy_runner.py | 304 ----------- rsl_rl/runners/runner.py | 498 +++++++++++++++++ rsl_rl/storage/__init__.py | 3 +- rsl_rl/storage/replay_storage.py | 147 +++++ rsl_rl/storage/rollout_storage.py | 275 +++------- rsl_rl/storage/storage.py | 47 ++ rsl_rl/utils/__init__.py | 5 +- rsl_rl/utils/benchmarkable.py | 111 ++++ rsl_rl/utils/neptune_utils.py | 20 +- rsl_rl/utils/recurrency.py | 69 +++ rsl_rl/utils/serializable.py | 43 ++ rsl_rl/utils/utils.py | 90 ++-- rsl_rl/utils/wandb_utils.py | 17 +- setup.py | 20 +- tests/test_algorithms.py | 169 ++++++ tests/test_dpg.py | 126 +++++ tests/test_dppo.py | 278 ++++++++++ tests/test_dppo_iqn.py | 171 ++++++ tests/test_dppo_recurrency.py | 170 ++++++ tests/test_ppo.py | 99 ++++ tests/test_ppo_recurrency.py | 144 +++++ tests/test_quantile_network.py | 286 ++++++++++ tests/test_trajectory_conversion.py | 33 ++ tests/test_transformer.py | 58 ++ 79 files changed, 7844 insertions(+), 1167 deletions(-) delete mode 100644 config/dummy_config.yaml create mode 100644 docs/Makefile create mode 100644 docs/conf.py create mode 100644 docs/index.rst create mode 100644 docs/make.bat create mode 100644 examples/__init__.py create mode 100644 examples/benchmark.py create mode 100644 examples/example.py create mode 100644 
examples/hyperparams/__init__.py create mode 100644 examples/hyperparams/dppo.py create mode 100644 examples/hyperparams/ppo.py create mode 100644 examples/tune.py create mode 100644 examples/tune_cfg.py create mode 100644 examples/wandb_config.example.py create mode 100644 rsl_rl/algorithms/actor_critic.py create mode 100644 rsl_rl/algorithms/agent.py create mode 100644 rsl_rl/algorithms/d4pg.py create mode 100644 rsl_rl/algorithms/ddpg.py create mode 100644 rsl_rl/algorithms/dpg.py create mode 100644 rsl_rl/algorithms/dppo.py create mode 100644 rsl_rl/algorithms/dsac.py create mode 100644 rsl_rl/algorithms/dtd3.py create mode 100644 rsl_rl/algorithms/hybrid.py create mode 100644 rsl_rl/algorithms/sac.py create mode 100644 rsl_rl/algorithms/td3.py create mode 100644 rsl_rl/distributions/__init__.py create mode 100644 rsl_rl/distributions/distribution.py create mode 100644 rsl_rl/distributions/quantile_distribution.py create mode 100644 rsl_rl/env/gym_env.py create mode 100644 rsl_rl/env/pole_balancing.py create mode 100644 rsl_rl/env/pomdp.py create mode 100644 rsl_rl/env/rslgym_env.py delete mode 100644 rsl_rl/modules/actor_critic.py delete mode 100644 rsl_rl/modules/actor_critic_recurrent.py create mode 100644 rsl_rl/modules/categorical_network.py create mode 100644 rsl_rl/modules/gaussian_chimera_network.py create mode 100644 rsl_rl/modules/gaussian_network.py create mode 100644 rsl_rl/modules/implicit_quantile_network.py create mode 100644 rsl_rl/modules/network.py create mode 100644 rsl_rl/modules/quantile_network.py create mode 100644 rsl_rl/modules/transformer.py create mode 100644 rsl_rl/modules/utils.py create mode 100644 rsl_rl/runners/callbacks.py create mode 100644 rsl_rl/runners/legacy_runner.py delete mode 100644 rsl_rl/runners/on_policy_runner.py create mode 100644 rsl_rl/runners/runner.py create mode 100644 rsl_rl/storage/replay_storage.py create mode 100644 rsl_rl/storage/storage.py create mode 100644 rsl_rl/utils/benchmarkable.py create mode 100644 rsl_rl/utils/recurrency.py create mode 100644 rsl_rl/utils/serializable.py create mode 100644 tests/test_algorithms.py create mode 100644 tests/test_dpg.py create mode 100644 tests/test_dppo.py create mode 100644 tests/test_dppo_iqn.py create mode 100644 tests/test_dppo_recurrency.py create mode 100644 tests/test_ppo.py create mode 100644 tests/test_ppo_recurrency.py create mode 100644 tests/test_quantile_network.py create mode 100644 tests/test_trajectory_conversion.py create mode 100644 tests/test_transformer.py diff --git a/.gitignore b/.gitignore index 6009f71..e47297e 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,17 @@ # cache __pycache__ .pytest_cache +wandb/ # vs code .vscode + +# data +videos/ + +# secrets +examples/wandb_config.py + +# docs +docs/_build +docs/source diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index e3153d9..54f9094 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -25,6 +25,7 @@ Please keep the lists sorted alphabetically. * Eric Vollenweider * Fabian Jenelten * Lorenzo Terenzi +* Lukas Schneider * Marko Bjelonic * Matthijs van der Boon * Mayank Mittal diff --git a/README.md b/README.md index 6b5c2d8..734376a 100644 --- a/README.md +++ b/README.md @@ -1,57 +1,64 @@ # RSL RL Fast and simple implementation of RL algorithms, designed to run fully on GPU. -This code is an evolution of `rl-pytorch` provided with NVIDIA's Isaac GYM. -Only PPO is implemented for now. More algorithms will be added later. -Contributions are welcome. 
+Currently, the following algorithms are implemented:
+- Distributed Distributional DDPG (D4PG)
+- Deep Deterministic Policy Gradient (DDPG)
+- Distributional PPO (DPPO)
+- Distributional Soft Actor Critic (DSAC)
+- Proximal Policy Optimization (PPO)
+- Soft Actor Critic (SAC)
+- Twin Delayed DDPG (TD3)

-**Maintainer**: David Hoeller and Nikita Rudin
+**Maintainers**: David Hoeller, Nikita Rudin
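The algorithms listed above all share the agent/runner interface introduced by this patch, so training any of them follows the same pattern. As a quick orientation, the following is a minimal sketch mirroring the `examples/example.py` script added later in this patch; the task name, environment count, and iteration count are illustrative only, and the algorithm defaults stand in for the tuned hyperparameter dictionaries shipped under `examples/hyperparams/`.

```python
# Minimal sketch of the new agent/runner API (adapted from examples/example.py in
# this patch). Task name, environment count, and iteration count are illustrative.
import torch

from rsl_rl.algorithms import PPO
from rsl_rl.env.gym_env import GymEnv
from rsl_rl.runners.runner import Runner

device = "cuda:0" if torch.cuda.is_available() else "cpu"

env = GymEnv(name="BipedalWalker-v3", device=device, environment_count=16)  # vectorized gym wrapper
agent = PPO(env, device=device)  # any algorithm from the list above
runner = Runner(env, agent, device=device, num_steps_per_env=64)
runner.learn(1000)  # number of learning iterations
```

In practice, `examples/example.py` loads the tuned values from `examples/hyperparams/` instead of relying on the defaults shown here.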
**Affiliation**: Robotic Systems Lab, ETH Zurich & NVIDIA
**Contact**: rudinn@ethz.ch

-## Setup
+## Installation

-Following are the instructions to setup the repository for your workspace:
+To install the package, run the following command in the root directory of the repository:

```bash
-git clone https://github.com/leggedrobotics/rsl_rl
-cd rsl_rl
-pip install -e .
+$ pip3 install -e .
```

-The framework supports the following logging frameworks which can be configured through `logger`:
+Examples can be run from the `examples/` directory.
+The examples directory also includes hyperparameters tuned for some gym environments.
+These are automatically loaded when running the example.
+Videos of the trained policies are periodically saved to the `videos/` directory.

-* Tensorboard: https://www.tensorflow.org/tensorboard/
-* Weights & Biases: https://wandb.ai/site
-* Neptune: https://docs.neptune.ai/
+```bash
+$ python3 examples/example.py
+```

-For a demo configuration of the PPO, please check: [dummy_config.yaml](config/dummy_config.yaml) file.
+To run the gym MuJoCo environments, you need a working installation of the MuJoCo simulator and [mujoco_py](https://github.com/openai/mujoco-py).
+
+## Tests
+
+The repository contains a set of tests to ensure that the algorithms are working as expected.
+To run the tests, execute:
+
+```bash
+$ cd tests/ && python -m unittest
+```
+
+## Documentation
+
+To generate the documentation, run the following commands in the root directory of the repository:
+
+```bash
+$ pip3 install sphinx sphinx-rtd-theme
+$ sphinx-apidoc -o docs/source . ./examples
+$ cd docs/ && make html
+```

## Contribution Guidelines

-For documentation, we adopt the [Google Style Guide](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) for docstrings. We use [Sphinx](https://www.sphinx-doc.org/en/master/) for generating the documentation. Please make sure that your code is well-documented and follows the guidelines.
-
-We use the following tools for maintaining code quality:
-
-- [pre-commit](https://pre-commit.com/): Runs a list of formatters and linters over the codebase.
-- [black](https://black.readthedocs.io/en/stable/): The uncompromising code formatter.
-- [flake8](https://flake8.pycqa.org/en/latest/): A wrapper around PyFlakes, pycodestyle, and McCabe complexity checker.
-
-Please check [here](https://pre-commit.com/#install) for instructions to set these up. To run over the entire repository, please execute the following command in the terminal:
-
+We use the [`black`](https://github.com/psf/black) formatter to format the Python code.
+You can [configure `black` in VS Code](https://dev.to/adamlombard/how-to-use-the-black-python-code-formatter-in-vscode-3lo0), or format files manually with:
```bash
-# for installation (only once)
-pre-commit install
-# for running
-pre-commit run --all-files
+$ pip install black
+$ black --line-length 120 .
``` - -### Useful Links - -Environment repositories using the framework: - -* `Legged-Gym` (built on top of NVIDIA Isaac Gym): https://leggedrobotics.github.io/legged_gym/ -* `Orbit` (built on top of NVIDIA Isaac Sim): https://isaac-orbit.github.io/ diff --git a/config/dummy_config.yaml b/config/dummy_config.yaml deleted file mode 100644 index aaf5d21..0000000 --- a/config/dummy_config.yaml +++ /dev/null @@ -1,48 +0,0 @@ -algorithm: - class_name: PPO - # training parameters - # -- value function - value_loss_coef: 1.0 - clip_param: 0.2 - use_clipped_value_loss: true - # -- surrogate loss - desired_kl: 0.01 - entropy_coef: 0.01 - gamma: 0.99 - lam: 0.95 - max_grad_norm: 1.0 - # -- training - learning_rate: 0.001 - num_learning_epochs: 5 - num_mini_batches: 4 # mini batch size = num_envs * num_steps / num_mini_batches - schedule: adaptive # adaptive, fixed -policy: - class_name: ActorCritic - # for MLP i.e. `ActorCritic` - activation: elu - actor_hidden_dims: [128, 128, 128] - critic_hidden_dims: [128, 128, 128] - init_noise_std: 1.0 - # only needed for `ActorCriticRecurrent` - # rnn_type: 'lstm' - # rnn_hidden_size: 512 - # rnn_num_layers: 1 -runner: - num_steps_per_env: 24 # number of steps per environment per iteration - max_iterations: 1500 # number of policy updates - empirical_normalization: false - # -- logging parameters - save_interval: 50 # check for potential saves every `save_interval` iterations - experiment_name: walking_experiment - run_name: "" - # -- logging writer - logger: tensorboard # tensorboard, neptune, wandb - neptune_project: legged_gym - wandb_project: legged_gym - # -- load and resuming - resume: false - load_run: -1 # -1 means load latest run - resume_path: null # updated from load_run and checkpoint - checkpoint: -1 # -1 means load latest checkpoint -runner_class_name: OnPolicyRunner -seed: 1 diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..d4bb2cb --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..4dbe906 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,32 @@ +# Configuration file for the Sphinx documentation builder. 
+# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +import os +import sys + +sys.path.insert(0, os.path.abspath("..")) + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = "rsl_rl" +copyright = "2023, Lukas Schneider" +author = "Lukas Schneider" +release = "1.0.0" + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon"] + +templates_path = ["_templates"] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] + + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = "sphinx_rtd_theme" +html_static_path = ["_static"] diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 0000000..4af7326 --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,20 @@ +.. rsl_rl documentation master file, created by + sphinx-quickstart on Tue Jul 4 16:39:24 2023. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to rsl_rl's documentation! +================================== + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..32bb245 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=. +set BUILDDIR=_build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. 
+ echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/benchmark.py b/examples/benchmark.py new file mode 100644 index 0000000..3bba830 --- /dev/null +++ b/examples/benchmark.py @@ -0,0 +1,92 @@ +import numpy as np +import os +import torch +import wandb + +from rsl_rl.algorithms import * +from rsl_rl.env.gym_env import GymEnv +from rsl_rl.runners.runner import Runner +from rsl_rl.runners.callbacks import make_wandb_cb + +from hyperparams import hyperparams +from wandb_config import WANDB_API_KEY, WANDB_ENTITY + + +ALGORITHMS = [PPO, DPPO] +ENVIRONMENTS = ["BipedalWalker-v3"] +ENVIRONMENT_KWARGS = [{}] +EXPERIMENT_DIR = os.environ.get("EXPERIMENT_DIRECTORY", "./") +EXPERIMENT_NAME = os.environ.get("EXPERIMENT_NAME", "benchmark") +DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" +RENDER_VIDEO = False +RETURN_EPOCHS = 100 # Number of epochs to average return over +LOG_WANDB = True +RUNS = 3 +TRAIN_TIMEOUT = 60 * 10 # Training time (in seconds) +TRAIN_ENV_STEPS = None # Number of training environment steps + + +os.environ["WANDB_API_KEY"] = WANDB_API_KEY + + +def run(alg_class, env_name, env_kwargs={}): + try: + hp = hyperparams[alg_class.__name__][env_name] + except KeyError: + print("No hyperparameters found. Using default values.") + hp = dict(agent_kwargs={}, env_kwargs={"environment_count": 1}, runner_kwargs={"num_steps_per_env": 1}) + + agent_kwargs = dict(device=DEVICE, **hp["agent_kwargs"]) + env_kwargs = dict(name=env_name, gym_kwargs=env_kwargs, **hp["env_kwargs"]) + runner_kwargs = dict(device=DEVICE, **hp["runner_kwargs"]) + + learn_steps = ( + None + if TRAIN_ENV_STEPS is None + else int(np.ceil(TRAIN_ENV_STEPS / (env_kwargs["environment_count"] * runner_kwargs["num_steps_per_env"]))) + ) + learn_timeout = None if TRAIN_TIMEOUT is None else TRAIN_TIMEOUT + + video_directory = f"{EXPERIMENT_DIR}/{EXPERIMENT_NAME}/videos/{env_name}/{alg_class.__name__}" + save_video_cb = ( + lambda ep, file: wandb.log({f"video-{ep}": wandb.Video(file, fps=4, format="mp4")}) if LOG_WANDB else None + ) + env = GymEnv(**env_kwargs, draw=RENDER_VIDEO, draw_cb=save_video_cb, draw_directory=video_directory) + agent = alg_class(env, **agent_kwargs) + + config = dict( + agent_kwargs=agent_kwargs, + env_kwargs=env_kwargs, + learn_steps=learn_steps, + learn_timeout=learn_timeout, + runner_kwargs=runner_kwargs, + ) + wandb_learn_config = dict( + config=config, + entity=WANDB_ENTITY, + group=f"{alg_class.__name__}_{env_name}", + project="rsl_rl-benchmark", + tags=[alg_class.__name__, env_name, "train"], + ) + + runner = Runner(env, agent, **runner_kwargs) + runner._learn_cb = [lambda *args, **kwargs: Runner._log(*args, prefix=f"{alg_class.__name__}_{env_name}", **kwargs)] + if LOG_WANDB: + runner._learn_cb.append(make_wandb_cb(wandb_learn_config)) + + runner.learn(iterations=learn_steps, timeout=learn_timeout, return_epochs=RETURN_EPOCHS) + + env.close() + + +def main(): + for algorithm in ALGORITHMS: + for i, env_name in enumerate(ENVIRONMENTS): + env_kwargs = ENVIRONMENT_KWARGS[i] + + for _ in range(RUNS): + run(algorithm, env_name, env_kwargs=env_kwargs) + + +if __name__ == "__main__": + main() diff --git a/examples/example.py 
b/examples/example.py new file mode 100644 index 0000000..0392d89 --- /dev/null +++ b/examples/example.py @@ -0,0 +1,26 @@ +import torch + +from rsl_rl.algorithms import * +from rsl_rl.env.gym_env import GymEnv +from rsl_rl.runners.runner import Runner +from hyperparams import hyperparams + + +ALGORITHM = DPPO +DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" +TASK = "BipedalWalker-v3" + + +def main(): + hp = hyperparams[ALGORITHM.__name__][TASK] + + env = GymEnv(name=TASK, device=DEVICE, draw=True, **hp["env_kwargs"]) + agent = ALGORITHM(env, benchmark=True, device=DEVICE, **hp["agent_kwargs"]) + runner = Runner(env, agent, device=DEVICE, **hp["runner_kwargs"]) + runner._learn_cb = [Runner._log] + + runner.learn(5000) + + +if __name__ == "__main__": + main() diff --git a/examples/hyperparams/__init__.py b/examples/hyperparams/__init__.py new file mode 100644 index 0000000..8d9fb1e --- /dev/null +++ b/examples/hyperparams/__init__.py @@ -0,0 +1,7 @@ +from rsl_rl.algorithms import PPO, DPPO +from .dppo import dppo_hyperparams +from .ppo import ppo_hyperparams + +hyperparams = {DPPO.__name__: dppo_hyperparams, PPO.__name__: ppo_hyperparams} + +__all__ = ["hyperparams"] diff --git a/examples/hyperparams/dppo.py b/examples/hyperparams/dppo.py new file mode 100644 index 0000000..5f8ecad --- /dev/null +++ b/examples/hyperparams/dppo.py @@ -0,0 +1,202 @@ +import copy +import numpy as np + +from rsl_rl.algorithms import DPPO +from rsl_rl.modules import QuantileNetwork + +default = dict() +default["env_kwargs"] = dict(environment_count=1) +default["runner_kwargs"] = dict(num_steps_per_env=2048) +default["agent_kwargs"] = dict( + actor_activations=["tanh", "tanh", "linear"], + actor_hidden_dims=[64, 64], + actor_input_normalization=False, + actor_noise_std=np.exp(0.0), + batch_count=(default["env_kwargs"]["environment_count"] * default["runner_kwargs"]["num_steps_per_env"] // 64), + clip_ratio=0.2, + critic_activations=["tanh", "tanh"], + critic_hidden_dims=[64, 64], + critic_input_normalization=False, + entropy_coeff=0.0, + gae_lambda=0.95, + gamma=0.99, + gradient_clip=0.5, + learning_rate=0.0003, + qrdqn_quantile_count=50, + schedule="adaptive", + target_kl=0.01, + value_coeff=0.5, + value_measure=QuantileNetwork.measure_neutral, + value_measure_kwargs={}, +) + +# Parameters optimized for PPO +ant_v4 = copy.deepcopy(default) +ant_v4["env_kwargs"]["environment_count"] = 128 +ant_v4["runner_kwargs"]["num_steps_per_env"] = 64 +ant_v4["agent_kwargs"]["actor_activations"] = ["tanh", "tanh", "linear"] +ant_v4["agent_kwargs"]["actor_hidden_dims"] = [64, 64] +ant_v4["agent_kwargs"]["actor_noise_std"] = 0.2611 +ant_v4["agent_kwargs"]["batch_count"] = 12 +ant_v4["agent_kwargs"]["clip_ratio"] = 0.4 +ant_v4["agent_kwargs"]["critic_activations"] = ["tanh", "tanh"] +ant_v4["agent_kwargs"]["critic_hidden_dims"] = [64, 64] +ant_v4["agent_kwargs"]["entropy_coeff"] = 0.0102 +ant_v4["agent_kwargs"]["gae_lambda"] = 0.92 +ant_v4["agent_kwargs"]["gamma"] = 0.9731 +ant_v4["agent_kwargs"]["gradient_clip"] = 5.0 +ant_v4["agent_kwargs"]["learning_rate"] = 0.8755 +ant_v4["agent_kwargs"]["target_kl"] = 0.1711 +ant_v4["agent_kwargs"]["value_coeff"] = 0.6840 + +""" +Tuned for environment interactions: +[I 2023-01-03 03:11:29,212] Trial 19 finished with value: 0.5272218152693111 and parameters: { + 'env_count': 16, + 'actor_noise_std': 0.7304437880901905, + 'batch_count': 10, + 'clip_ratio': 0.3, + 'entropy_coeff': 0.004236574285220795, + 'gae_lambda': 0.95, + 'gamma': 0.9890074826092162, + 'gradient_clip': 0.9, + 
'learning_rate': 0.18594043324129061, + 'steps_per_env': 256, + 'target_kl': 0.05838576142010138, + 'value_coeff': 0.14402022632575992, + 'net_arch': 'small', + 'net_activation': 'relu' +}. Best is trial 19 with value: 0.5272218152693111. +Tuned for training time: +[I 2023-01-08 21:09:06,069] Trial 407 finished with value: 7.497591958940029 and parameters: { + 'actor_noise_std': 0.1907398121300662, + 'batch_count': 3, + 'clip_ratio': 0.1, + 'entropy_coeff': 0.0053458057035692735, + 'env_count': 16, + 'gae_lambda': 0.8, + 'gamma': 0.985000267068182, + 'gradient_clip': 2.0, + 'learning_rate': 0.605956844400053, + 'steps_per_env': 512, + 'target_kl': 0.17611450607281642, + 'value_coeff': 0.46015664905111847, + 'actor_net_arch': 'small', + 'critic_net_arch': 'medium', + 'actor_net_activation': 'relu', + 'critic_net_activation': 'relu', + 'qrdqn_quantile_count': 200, + 'value_measure': 'neutral' +}. Best is trial 407 with value: 7.497591958940029. +""" +bipedal_walker_v3 = copy.deepcopy(default) +bipedal_walker_v3["env_kwargs"]["environment_count"] = 256 +bipedal_walker_v3["runner_kwargs"]["num_steps_per_env"] = 16 +bipedal_walker_v3["agent_kwargs"]["actor_activations"] = ["relu", "relu", "relu", "linear"] +bipedal_walker_v3["agent_kwargs"]["actor_hidden_dims"] = [512, 256, 128] +bipedal_walker_v3["agent_kwargs"]["actor_noise_std"] = 0.8505 +bipedal_walker_v3["agent_kwargs"]["batch_count"] = 10 +bipedal_walker_v3["agent_kwargs"]["clip_ratio"] = 0.1 +bipedal_walker_v3["agent_kwargs"]["critic_activations"] = ["relu", "relu"] +bipedal_walker_v3["agent_kwargs"]["critic_hidden_dims"] = [256, 256] +bipedal_walker_v3["agent_kwargs"]["critic_network"] = DPPO.network_qrdqn +bipedal_walker_v3["agent_kwargs"]["entropy_coeff"] = 0.0917 +bipedal_walker_v3["agent_kwargs"]["gae_lambda"] = 0.95 +bipedal_walker_v3["agent_kwargs"]["gamma"] = 0.9553 +bipedal_walker_v3["agent_kwargs"]["gradient_clip"] = 2.0 +bipedal_walker_v3["agent_kwargs"]["iqn_action_samples"] = 32 +bipedal_walker_v3["agent_kwargs"]["iqn_embedding_size"] = 64 +bipedal_walker_v3["agent_kwargs"]["iqn_feature_layers"] = 1 +bipedal_walker_v3["agent_kwargs"]["iqn_value_samples"] = 8 +bipedal_walker_v3["agent_kwargs"]["learning_rate"] = 0.4762 +bipedal_walker_v3["agent_kwargs"]["qrdqn_quantile_count"] = 200 +bipedal_walker_v3["agent_kwargs"]["recurrent"] = False +bipedal_walker_v3["agent_kwargs"]["target_kl"] = 0.1999 +bipedal_walker_v3["agent_kwargs"]["value_coeff"] = 0.4435 + +""" +[I 2023-01-12 08:01:35,514] Trial 476 finished with value: 5202.960759290059 and parameters: { + 'actor_noise_std': 0.15412869066185989, + 'batch_count': 11, + 'clip_ratio': 0.3, + 'entropy_coeff': 0.036031209302206955, + 'env_count': 128, + 'gae_lambda': 0.92, + 'gamma': 0.973937576989299, + 'gradient_clip': 5.0, + 'learning_rate': 0.1621249118505433, + 'steps_per_env': 128, + 'target_kl': 0.05054738172852222, + 'value_coeff': 0.1647632125820593, + 'actor_net_arch': 'small', + 'critic_net_arch': 'medium', + 'actor_net_activation': 'tanh', + 'critic_net_activation': 'relu', + 'qrdqn_quantile_count': 50, + 'value_measure': 'var-risk-averse' +}. Best is trial 476 with value: 5202.960759290059. 
+""" +half_cheetah_v4 = copy.deepcopy(default) +half_cheetah_v4["env_kwargs"]["environment_count"] = 128 +half_cheetah_v4["runner_kwargs"]["num_steps_per_env"] = 128 +half_cheetah_v4["agent_kwargs"]["actor_activations"] = ["tanh", "tanh", "linear"] +half_cheetah_v4["agent_kwargs"]["actor_hidden_dims"] = [64, 64] +half_cheetah_v4["agent_kwargs"]["actor_noise_std"] = 0.1541 +half_cheetah_v4["agent_kwargs"]["batch_count"] = 11 +half_cheetah_v4["agent_kwargs"]["clip_ratio"] = 0.3 +half_cheetah_v4["agent_kwargs"]["critic_activations"] = ["relu", "relu"] +half_cheetah_v4["agent_kwargs"]["critic_hidden_dims"] = [256, 256] +half_cheetah_v4["agent_kwargs"]["entropy_coeff"] = 0.03603 +half_cheetah_v4["agent_kwargs"]["gae_lambda"] = 0.92 +half_cheetah_v4["agent_kwargs"]["gamma"] = 0.9739 +half_cheetah_v4["agent_kwargs"]["gradient_clip"] = 5.0 +half_cheetah_v4["agent_kwargs"]["learning_rate"] = 0.1621 +half_cheetah_v4["agent_kwargs"]["qrdqn_quantile_count"] = 50 +half_cheetah_v4["agent_kwargs"]["target_kl"] = 0.0505 +half_cheetah_v4["agent_kwargs"]["value_coeff"] = 0.1648 +half_cheetah_v4["agent_kwargs"]["value_measure"] = QuantileNetwork.measure_percentile +half_cheetah_v4["agent_kwargs"]["value_measure_kwargs"] = dict(confidence_level=0.25) + + +# Parameters optimized for PPO +hopper_v4 = copy.deepcopy(default) +hopper_v4["runner_kwargs"]["num_steps_per_env"] = 128 +hopper_v4["agent_kwargs"]["actor_activations"] = ["relu", "relu", "linear"] +hopper_v4["agent_kwargs"]["actor_hidden_dims"] = [256, 256] +hopper_v4["agent_kwargs"]["actor_noise_std"] = 0.5590 +hopper_v4["agent_kwargs"]["batch_count"] = 15 +hopper_v4["agent_kwargs"]["clip_ratio"] = 0.2 +hopper_v4["agent_kwargs"]["critic_activations"] = ["relu", "relu", "linear"] +hopper_v4["agent_kwargs"]["critic_hidden_dims"] = [32, 32] +hopper_v4["agent_kwargs"]["entropy_coeff"] = 0.03874 +hopper_v4["agent_kwargs"]["gae_lambda"] = 0.98 +hopper_v4["agent_kwargs"]["gamma"] = 0.9890 +hopper_v4["agent_kwargs"]["gradient_clip"] = 1.0 +hopper_v4["agent_kwargs"]["learning_rate"] = 0.3732 +hopper_v4["agent_kwargs"]["value_coeff"] = 0.8163 + +swimmer_v4 = copy.deepcopy(default) +swimmer_v4["agent_kwargs"]["gamma"] = 0.9999 + +walker2d_v4 = copy.deepcopy(default) +walker2d_v4["runner_kwargs"]["num_steps_per_env"] = 512 +walker2d_v4["agent_kwargs"]["batch_count"] = ( + walker2d_v4["env_kwargs"]["environment_count"] * walker2d_v4["runner_kwargs"]["num_steps_per_env"] // 32 +) +walker2d_v4["agent_kwargs"]["clip_ratio"] = 0.1 +walker2d_v4["agent_kwargs"]["entropy_coeff"] = 0.000585045 +walker2d_v4["agent_kwargs"]["gae_lambda"] = 0.95 +walker2d_v4["agent_kwargs"]["gamma"] = 0.99 +walker2d_v4["agent_kwargs"]["gradient_clip"] = 1.0 +walker2d_v4["agent_kwargs"]["learning_rate"] = 5.05041e-05 +walker2d_v4["agent_kwargs"]["value_coeff"] = 0.871923 + +dppo_hyperparams = { + "default": default, + "Ant-v4": ant_v4, + "BipedalWalker-v3": bipedal_walker_v3, + "HalfCheetah-v4": half_cheetah_v4, + "Hopper-v4": hopper_v4, + "Swimmer-v4": swimmer_v4, + "Walker2d-v4": walker2d_v4, +} diff --git a/examples/hyperparams/ppo.py b/examples/hyperparams/ppo.py new file mode 100644 index 0000000..d45dd8e --- /dev/null +++ b/examples/hyperparams/ppo.py @@ -0,0 +1,226 @@ +import copy +import numpy as np + +default = dict() +default["env_kwargs"] = dict(environment_count=1) +default["runner_kwargs"] = dict(num_steps_per_env=2048) +default["agent_kwargs"] = dict( + actor_activations=["tanh", "tanh", "linear"], + actor_hidden_dims=[64, 64], + actor_input_normalization=False, + 
actor_noise_std=np.exp(0.0), + batch_count=(default["env_kwargs"]["environment_count"] * default["runner_kwargs"]["num_steps_per_env"] // 64), + clip_ratio=0.2, + critic_activations=["tanh", "tanh", "linear"], + critic_hidden_dims=[64, 64], + critic_input_normalization=False, + entropy_coeff=0.0, + gae_lambda=0.95, + gamma=0.99, + gradient_clip=0.5, + learning_rate=0.0003, + schedule="adaptive", + target_kl=0.01, + value_coeff=0.5, +) + +""" +[I 2023-01-09 00:33:02,217] Trial 85 finished with value: 2191.0249068421276 and parameters: { + 'actor_noise_std': 0.2611334861249876, + 'batch_count': 12, + 'clip_ratio': 0.4, + 'entropy_coeff': 0.010204149626344796, + 'env_count': 128, + 'gae_lambda': 0.92, + 'gamma': 0.9730549104215155, + 'gradient_clip': 5.0, + 'learning_rate': 0.8754540531090014, + 'steps_per_env': 64, + 'target_kl': 0.17110535070344035, + 'value_coeff': 0.6840401569818934, + 'actor_net_arch': 'small', + 'critic_net_arch': 'small', + 'actor_net_activation': 'tanh', + 'critic_net_activation': 'tanh' +}. Best is trial 85 with value: 2191.0249068421276. +""" +ant_v3 = copy.deepcopy(default) +ant_v3["env_kwargs"]["environment_count"] = 128 +ant_v3["runner_kwargs"]["num_steps_per_env"] = 64 +ant_v3["agent_kwargs"]["actor_activations"] = ["tanh", "tanh", "linear"] +ant_v3["agent_kwargs"]["actor_hidden_dims"] = [64, 64] +ant_v3["agent_kwargs"]["actor_noise_std"] = 0.2611 +ant_v3["agent_kwargs"]["batch_count"] = 12 +ant_v3["agent_kwargs"]["clip_ratio"] = 0.4 +ant_v3["agent_kwargs"]["critic_activations"] = ["tanh", "tanh", "linear"] +ant_v3["agent_kwargs"]["critic_hidden_dims"] = [64, 64] +ant_v3["agent_kwargs"]["entropy_coeff"] = 0.0102 +ant_v3["agent_kwargs"]["gae_lambda"] = 0.92 +ant_v3["agent_kwargs"]["gamma"] = 0.9731 +ant_v3["agent_kwargs"]["gradient_clip"] = 5.0 +ant_v3["agent_kwargs"]["learning_rate"] = 0.8755 +ant_v3["agent_kwargs"]["target_kl"] = 0.1711 +ant_v3["agent_kwargs"]["value_coeff"] = 0.6840 + +""" +Standard: +[I 2023-01-17 07:43:46,884] Trial 125 finished with value: 150.23491836690064 and parameters: { + 'actor_net_activation': 'relu', + 'actor_net_arch': 'large', + 'actor_noise_std': 0.8504545432069994, + 'batch_count': 10, + 'clip_ratio': 0.1, + 'critic_net_activation': 'relu', + 'critic_net_arch': 'medium', + 'entropy_coeff': 0.0916881539697197, + 'env_count': 256, + 'gae_lambda': 0.95, + 'gamma': 0.955285858564339, + 'gradient_clip': 2.0, + 'learning_rate': 0.4762365866431558, + 'steps_per_env': 16, + 'recurrent': False, + 'target_kl': 0.19991906392721126, + 'value_coeff': 0.4434793554275927 +}. Best is trial 125 with value: 150.23491836690064. +Hardcore: +[I 2023-01-09 06:25:44,000] Trial 262 finished with value: 2.290071208278338 and parameters: { + 'actor_noise_std': 0.2710521003644249, + 'batch_count': 6, + 'clip_ratio': 0.1, + 'entropy_coeff': 0.005105282891378981, + 'env_count': 16, + 'gae_lambda': 1.0, + 'gamma': 0.9718119008688937, + 'gradient_clip': 0.1, + 'learning_rate': 0.4569184610431825, + 'steps_per_env': 256, + 'target_kl': 0.11068348002480229, + 'value_coeff': 0.19453900570701116, + 'actor_net_arch': 'small', + 'critic_net_arch': 'medium', + 'actor_net_activation': 'relu', + 'critic_net_activation': 'relu' +}. Best is trial 262 with value: 2.290071208278338. 
+""" +bipedal_walker_v3 = copy.deepcopy(default) +bipedal_walker_v3["env_kwargs"]["environment_count"] = 256 +bipedal_walker_v3["runner_kwargs"]["num_steps_per_env"] = 16 +bipedal_walker_v3["agent_kwargs"]["actor_activations"] = ["relu", "relu", "relu", "linear"] +bipedal_walker_v3["agent_kwargs"]["actor_hidden_dims"] = [512, 256, 128] +bipedal_walker_v3["agent_kwargs"]["actor_noise_std"] = 0.8505 +bipedal_walker_v3["agent_kwargs"]["batch_count"] = 10 +bipedal_walker_v3["agent_kwargs"]["clip_ratio"] = 0.1 +bipedal_walker_v3["agent_kwargs"]["critic_activations"] = ["relu", "relu", "linear"] +bipedal_walker_v3["agent_kwargs"]["critic_hidden_dims"] = [256, 256] +bipedal_walker_v3["agent_kwargs"]["entropy_coeff"] = 0.0917 +bipedal_walker_v3["agent_kwargs"]["gae_lambda"] = 0.95 +bipedal_walker_v3["agent_kwargs"]["gamma"] = 0.9553 +bipedal_walker_v3["agent_kwargs"]["gradient_clip"] = 2.0 +bipedal_walker_v3["agent_kwargs"]["learning_rate"] = 0.4762 +bipedal_walker_v3["agent_kwargs"]["target_kl"] = 0.1999 +bipedal_walker_v3["agent_kwargs"]["value_coeff"] = 0.4435 + +""" +[I 2023-01-04 05:57:20,749] Trial 1451 finished with value: 5260.338678148058 and parameters: { + 'env_count': 32, + 'actor_noise_std': 0.3397405098274084, + 'batch_count': 6, + 'clip_ratio': 0.3, + 'entropy_coeff': 0.009392937880259133, + 'gae_lambda': 0.8, + 'gamma': 0.9683403243382301, + 'gradient_clip': 5.0, + 'learning_rate': 0.5985206877398142, + 'steps_per_env': 16, + 'target_kl': 0.027651917189297347, + 'value_coeff': 0.26705235341068373, + 'net_arch': 'medium', + 'net_activation': 'tanh' +}. Best is trial 1451 with value: 5260.338678148058. +""" +half_cheetah_v3 = copy.deepcopy(default) +half_cheetah_v3["env_kwargs"]["environment_count"] = 32 +half_cheetah_v3["runner_kwargs"]["num_steps_per_env"] = 16 +half_cheetah_v3["agent_kwargs"]["actor_activations"] = ["tanh", "tanh", "linear"] +half_cheetah_v3["agent_kwargs"]["actor_hidden_dims"] = [256, 256] +half_cheetah_v3["agent_kwargs"]["actor_noise_std"] = 0.3397 +half_cheetah_v3["agent_kwargs"]["batch_count"] = 6 +half_cheetah_v3["agent_kwargs"]["clip_ratio"] = 0.3 +half_cheetah_v3["agent_kwargs"]["critic_activations"] = ["tanh", "tanh", "linear"] +half_cheetah_v3["agent_kwargs"]["critic_hidden_dims"] = [256, 256] +half_cheetah_v3["agent_kwargs"]["entropy_coeff"] = 0.009393 +half_cheetah_v3["agent_kwargs"]["gae_lambda"] = 0.8 +half_cheetah_v3["agent_kwargs"]["gamma"] = 0.9683 +half_cheetah_v3["agent_kwargs"]["gradient_clip"] = 5.0 +half_cheetah_v3["agent_kwargs"]["learning_rate"] = 0.5985 +half_cheetah_v3["agent_kwargs"]["target_kl"] = 0.02765 +half_cheetah_v3["agent_kwargs"]["value_coeff"] = 0.2671 + +""" +[I 2023-01-08 18:38:51,481] Trial 25 finished with value: 2225.9547948810073 and parameters: { + 'actor_noise_std': 0.5589708917145111, + 'batch_count': 15, + 'clip_ratio': 0.2, + 'entropy_coeff': 0.03874027035272886, + 'env_count': 128, + 'gae_lambda': 0.98, + 'gamma': 0.9879577396280973, + 'gradient_clip': 1.0, + 'learning_rate': 0.3732431793266761, + 'steps_per_env': 128, + 'target_kl': 0.12851506672519566, + 'value_coeff': 0.8162548885723906, + 'actor_net_arch': 'medium', + 'critic_net_arch': 'small', + 'actor_net_activation': 'relu', + 'critic_net_activation': 'relu' +}. Best is trial 25 with value: 2225.9547948810073. 
+""" +hopper_v3 = copy.deepcopy(default) +half_cheetah_v3["env_kwargs"]["environment_count"] = 128 +hopper_v3["runner_kwargs"]["num_steps_per_env"] = 128 +hopper_v3["agent_kwargs"]["actor_activations"] = ["relu", "relu", "linear"] +hopper_v3["agent_kwargs"]["actor_hidden_dims"] = [256, 256] +hopper_v3["agent_kwargs"]["actor_noise_std"] = 0.5590 +hopper_v3["agent_kwargs"]["batch_count"] = 15 +hopper_v3["agent_kwargs"]["clip_ratio"] = 0.2 +hopper_v3["agent_kwargs"]["critic_activations"] = ["relu", "relu", "linear"] +hopper_v3["agent_kwargs"]["critic_hidden_dims"] = [32, 32] +hopper_v3["agent_kwargs"]["entropy_coeff"] = 0.03874 +hopper_v3["agent_kwargs"]["gae_lambda"] = 0.98 +hopper_v3["agent_kwargs"]["gamma"] = 0.9890 +hopper_v3["agent_kwargs"]["gradient_clip"] = 1.0 +hopper_v3["agent_kwargs"]["learning_rate"] = 0.3732 +hopper_v3["agent_kwargs"]["value_coeff"] = 0.8163 + +swimmer_v3 = copy.deepcopy(default) +swimmer_v3["agent_kwargs"]["gamma"] = 0.9999 + +walker2d_v3 = copy.deepcopy(default) +walker2d_v3["runner_kwargs"]["num_steps_per_env"] = 512 +walker2d_v3["agent_kwargs"]["batch_count"] = ( + walker2d_v3["env_kwargs"]["environment_count"] * walker2d_v3["runner_kwargs"]["num_steps_per_env"] // 32 +) +walker2d_v3["agent_kwargs"]["clip_ratio"] = 0.1 +walker2d_v3["agent_kwargs"]["entropy_coeff"] = 0.000585045 +walker2d_v3["agent_kwargs"]["gae_lambda"] = 0.95 +walker2d_v3["agent_kwargs"]["gamma"] = 0.99 +walker2d_v3["agent_kwargs"]["gradient_clip"] = 1.0 +walker2d_v3["agent_kwargs"]["learning_rate"] = 5.05041e-05 +walker2d_v3["agent_kwargs"]["value_coeff"] = 0.871923 + +ppo_hyperparams = { + "default": default, + "Ant-v3": ant_v3, + "Ant-v4": ant_v3, + "BipedalWalker-v3": bipedal_walker_v3, + "HalfCheetah-v3": half_cheetah_v3, + "HalfCheetah-v4": half_cheetah_v3, + "Hopper-v3": hopper_v3, + "Hopper-v4": hopper_v3, + "Swimmer-v3": swimmer_v3, + "Swimmer-v4": swimmer_v3, + "Walker2d-v3": walker2d_v3, + "Walker2d-v4": walker2d_v3, +} diff --git a/examples/tune.py b/examples/tune.py new file mode 100644 index 0000000..b11c058 --- /dev/null +++ b/examples/tune.py @@ -0,0 +1,103 @@ +from rsl_rl.algorithms import * +from rsl_rl.env.gym_env import GymEnv +from rsl_rl.runners.runner import Runner + +import copy +from datetime import datetime +import numpy as np +import optuna +import os +import random +import torch +from tune_cfg import samplers + + +ALGORITHM = PPO +ENVIRONMENT = "BipedalWalker-v3" +ENVIRONMENT_KWARGS = {} +EVAL_AGENTS = 64 +EVAL_RUNS = 10 +EVAL_STEPS = 1000 +EXPERIMENT_DIR = os.environ.get("EXPERIMENT_DIRECTORY", "./") +EXPERIMENT_NAME = os.environ.get("EXPERIMENT_NAME", f"tune-{ALGORITHM.__name__}-{ENVIRONMENT}") +TRAIN_ITERATIONS = None +TRAIN_TIMEOUT = 60 * 15 # 10 minutes +TRAIN_RUNS = 3 +TRAIN_SEED = None + + +def tune(): + assert TRAIN_RUNS == 1 or TRAIN_SEED is None, "If multiple runs are used, the seed must be None." 
+ + storage = optuna.storages.RDBStorage(url=f"sqlite:///{EXPERIMENT_DIR}/{EXPERIMENT_NAME}.db") + pruner = optuna.pruners.MedianPruner(n_startup_trials=10) + try: + study = optuna.create_study(direction="maximize", pruner=pruner, storage=storage, study_name=EXPERIMENT_NAME) + except Exception: + study = optuna.load_study(pruner=pruner, storage=storage, study_name=EXPERIMENT_NAME) + + study.optimize(objective, n_trials=100) + + +def seed(s=None): + seed = int(datetime.now().timestamp() * 1e6) % 2**32 if s is None else s + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def objective(trial): + seed() + agent_kwargs, env_kwargs, runner_kwargs = samplers[ALGORITHM.__name__](trial) + + evaluations = [] + for instantiation in range(TRAIN_RUNS): + seed(TRAIN_SEED) + + env = GymEnv(ENVIRONMENT, gym_kwargs=ENVIRONMENT_KWARGS, **env_kwargs) + agent = ALGORITHM(env, **agent_kwargs) + runner = Runner(env, agent, **runner_kwargs) + runner._learn_cb = [lambda _, stat: runner._log_progress(stat, prefix=f"learn {instantiation+1}/{TRAIN_RUNS}")] + + eval_env_kwargs = copy.deepcopy(env_kwargs) + eval_env_kwargs["environment_count"] = EVAL_AGENTS + eval_runner = Runner( + GymEnv(ENVIRONMENT, gym_kwargs=ENVIRONMENT_KWARGS, **env_kwargs), + agent, + **runner_kwargs, + ) + eval_runner._eval_cb = [ + lambda _, stat: runner._log_progress(stat, prefix=f"eval {instantiation+1}/{TRAIN_RUNS}") + ] + + try: + runner.learn(TRAIN_ITERATIONS, timeout=TRAIN_TIMEOUT) + except Exception: + raise optuna.TrialPruned() + + intermediate_evaluations = [] + for eval_run in range(EVAL_RUNS): + eval_runner._eval_cb = [lambda _, stat: runner._log_progress(stat, prefix=f"eval {eval_run+1}/{EVAL_RUNS}")] + + seed() + eval_runner.env.reset() + intermediate_evaluations.append(eval_runner.evaluate(steps=EVAL_STEPS)) + eval = np.mean(intermediate_evaluations) + + trial.report(eval, instantiation) + if trial.should_prune(): + raise optuna.TrialPruned() + + evaluations.append(eval) + + evaluation = np.mean(evaluations) + + return evaluation + + +if __name__ == "__main__": + tune() diff --git a/examples/tune_cfg.py b/examples/tune_cfg.py new file mode 100644 index 0000000..1ca78c2 --- /dev/null +++ b/examples/tune_cfg.py @@ -0,0 +1,134 @@ +import torch + +from rsl_rl.algorithms import DPPO, PPO +from rsl_rl.modules import QuantileNetwork + +NETWORKS = {"small": [64, 64], "medium": [256, 256], "large": [512, 256, 128]} + + +def sample_dppo_hyperparams(trial): + actor_net_activation = trial.suggest_categorical("actor_net_activation", ["relu", "tanh"]) + actor_net_arch = trial.suggest_categorical("actor_net_arch", list(NETWORKS.keys())) + actor_noise_std = trial.suggest_float("actor_noise_std", 0.05, 1.0) + batch_count = trial.suggest_int("batch_count", 1, 20) + clip_ratio = trial.suggest_categorical("clip_ratio", [0.1, 0.2, 0.3, 0.4]) + critic_net_activation = trial.suggest_categorical("critic_net_activation", ["relu", "tanh"]) + critic_net_arch = trial.suggest_categorical("critic_net_arch", list(NETWORKS.keys())) + entropy_coeff = trial.suggest_float("entropy_coeff", 0.00000001, 0.1) + env_count = trial.suggest_categorical("env_count", [1, 8, 16, 32, 64, 128, 256, 512]) + gae_lambda = trial.suggest_categorical("gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0]) + gamma = trial.suggest_float("gamma", 0.95, 0.999) + gradient_clip = trial.suggest_categorical("gradient_clip", [0.1, 0.3, 0.5, 0.6, 0.7, 0.8, 
0.9, 1.0, 2.0, 5.0]) + learning_rate = trial.suggest_float("learning_rate", 1e-5, 1) + num_steps_per_env = trial.suggest_categorical("steps_per_env", [8, 16, 32, 64, 128, 256, 512, 1024, 2048]) + quantile_count = trial.suggest_categorical("quantile_count", [20, 50, 100, 200]) + recurrent = trial.suggest_categorical("recurrent", [True, False]) + target_kl = trial.suggest_float("target_kl", 0.01, 0.3) + value_coeff = trial.suggest_float("value_coeff", 0.0, 1.0) + value_measure = trial.suggest_categorical( + "value_measure", + ["neutral", "var-risk-averse", "var-risk-seeking", "var-super-risk-averse", "var-super-risk-seeking"], + ) + + actor_net_arch = NETWORKS[actor_net_arch] + critic_net_arch = NETWORKS[critic_net_arch] + value_measure_kwargs = { + "neutral": dict(), + "var-risk-averse": dict(confidence_level=0.25), + "var-risk-seeking": dict(confidence_level=0.75), + "var-super-risk-averse": dict(confidence_level=0.1), + "var-super-risk-seeking": dict(confidence_level=0.9), + }[value_measure] + value_measure = { + "neutral": QuantileNetwork.measure_neutral, + "var-risk-averse": QuantileNetwork.measure_percentile, + "var-risk-seeking": QuantileNetwork.measure_percentile, + "var-super-risk-averse": QuantileNetwork.measure_percentile, + "var-super-risk-seeking": QuantileNetwork.measure_percentile, + }[value_measure] + device = "cuda:0" if env_count * num_steps_per_env > 2048 and torch.cuda.is_available() else "cpu" + + agent_kwargs = dict( + actor_activations=([actor_net_activation] * len(actor_net_arch)) + ["linear"], + actor_hidden_dims=actor_net_arch, + actor_input_normalization=False, + actor_noise_std=actor_noise_std, + batch_count=batch_count, + clip_ratio=clip_ratio, + critic_activations=([critic_net_activation] * len(critic_net_arch)), + critic_hidden_dims=critic_net_arch, + critic_input_normalization=False, + device=device, + entropy_coeff=entropy_coeff, + gae_lambda=gae_lambda, + gamma=gamma, + gradient_clip=gradient_clip, + learning_rate=learning_rate, + quantile_count=quantile_count, + recurrent=recurrent, + schedule="adaptive", + target_kl=target_kl, + value_coeff=value_coeff, + value_measure=value_measure, + value_measure_kwargs=value_measure_kwargs, + ) + env_kwargs = dict(device=device, environment_count=env_count) + runner_kwargs = dict(device=device, num_steps_per_env=num_steps_per_env) + + return agent_kwargs, env_kwargs, runner_kwargs + + +def sample_ppo_hyperparams(trial): + actor_net_activation = trial.suggest_categorical("actor_net_activation", ["relu", "tanh"]) + actor_net_arch = trial.suggest_categorical("actor_net_arch", list(NETWORKS.keys())) + actor_noise_std = trial.suggest_float("actor_noise_std", 0.05, 1.0) + batch_count = trial.suggest_int("batch_count", 1, 20) + clip_ratio = trial.suggest_categorical("clip_ratio", [0.1, 0.2, 0.3, 0.4]) + critic_net_activation = trial.suggest_categorical("critic_net_activation", ["relu", "tanh"]) + critic_net_arch = trial.suggest_categorical("critic_net_arch", list(NETWORKS.keys())) + entropy_coeff = trial.suggest_float("entropy_coeff", 0.00000001, 0.1) + env_count = trial.suggest_categorical("env_count", [1, 8, 16, 32, 64, 128, 256, 512]) + gae_lambda = trial.suggest_categorical("gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0]) + gamma = trial.suggest_float("gamma", 0.95, 0.999) + gradient_clip = trial.suggest_categorical("gradient_clip", [0.1, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 2.0, 5.0]) + learning_rate = trial.suggest_float("learning_rate", 1e-5, 1) + num_steps_per_env = trial.suggest_categorical("steps_per_env", [8, 
16, 32, 64, 128, 256, 512, 1024, 2048]) + recurrent = trial.suggest_categorical("recurrent", [True, False]) + target_kl = trial.suggest_float("target_kl", 0.01, 0.3) + value_coeff = trial.suggest_float("value_coeff", 0.0, 1.0) + + actor_net_arch = NETWORKS[actor_net_arch] + critic_net_arch = NETWORKS[critic_net_arch] + device = "cuda:0" if env_count * num_steps_per_env > 2048 and torch.cuda.is_available() else "cpu" + + agent_kwargs = dict( + actor_activations=([actor_net_activation] * len(actor_net_arch)) + ["linear"], + actor_hidden_dims=actor_net_arch, + actor_input_normalization=False, + actor_noise_std=actor_noise_std, + batch_count=batch_count, + clip_ratio=clip_ratio, + critic_activations=([critic_net_activation] * len(critic_net_arch)) + ["linear"], + critic_hidden_dims=critic_net_arch, + critic_input_normalization=False, + device=device, + entropy_coeff=entropy_coeff, + gae_lambda=gae_lambda, + gamma=gamma, + gradient_clip=gradient_clip, + learning_rate=learning_rate, + recurrent=recurrent, + schedule="adaptive", + target_kl=target_kl, + value_coeff=value_coeff, + ) + env_kwargs = dict(device=device, environment_count=env_count) + runner_kwargs = dict(device=device, num_steps_per_env=num_steps_per_env) + + return agent_kwargs, env_kwargs, runner_kwargs + + +samplers = { + DPPO.__name__: sample_dppo_hyperparams, + PPO.__name__: sample_ppo_hyperparams, +} diff --git a/examples/wandb_config.example.py b/examples/wandb_config.example.py new file mode 100644 index 0000000..051710b --- /dev/null +++ b/examples/wandb_config.example.py @@ -0,0 +1,4 @@ +# To use this file, copy it to wandb_config.py and fill in the missing values. + +WANDB_API_KEY = "" +WANDB_ENTITY = "" diff --git a/rsl_rl/__init__.py b/rsl_rl/__init__.py index 8b1a072..89f4d9a 100644 --- a/rsl_rl/__init__.py +++ b/rsl_rl/__init__.py @@ -1,7 +1,2 @@ # Copyright 2021 ETH Zurich, NVIDIA CORPORATION # SPDX-License-Identifier: BSD-3-Clause - -"""Main module for the rsl_rl package.""" - -__version__ = "2.0.1" -__license__ = "BSD-3" diff --git a/rsl_rl/algorithms/__init__.py b/rsl_rl/algorithms/__init__.py index 776258c..202faf2 100644 --- a/rsl_rl/algorithms/__init__.py +++ b/rsl_rl/algorithms/__init__.py @@ -1,8 +1,10 @@ -# Copyright 2021 ETH Zurich, NVIDIA CORPORATION -# SPDX-License-Identifier: BSD-3-Clause - -"""Implementation of different RL agents.""" - +from .agent import Agent +from .d4pg import D4PG +from .ddpg import DDPG +from .dppo import DPPO +from .dsac import DSAC from .ppo import PPO +from .sac import SAC +from .td3 import TD3 -__all__ = ["PPO"] +__all__ = ["Agent", "DDPG", "D4PG", "DPPO", "DSAC", "PPO", "SAC", "TD3"] diff --git a/rsl_rl/algorithms/actor_critic.py b/rsl_rl/algorithms/actor_critic.py new file mode 100644 index 0000000..6e3fe8b --- /dev/null +++ b/rsl_rl/algorithms/actor_critic.py @@ -0,0 +1,371 @@ +from abc import abstractmethod +import torch +from typing import Any, Callable, Dict, List, Tuple, Union + +from rsl_rl.algorithms.agent import Agent +from rsl_rl.env.vec_env import VecEnv +from rsl_rl.modules.network import Network +from rsl_rl.storage.storage import Dataset +from rsl_rl.utils.utils import environment_dimensions +from rsl_rl.utils.utils import squeeze_preserve_batch + + +class AbstractActorCritic(Agent): + _alg_features = dict(recurrent=False) + + def __init__( + self, + env: VecEnv, + actor_activations: List[str] = ["relu", "relu", "relu", "linear"], + actor_hidden_dims: List[int] = [256, 256, 256], + actor_init_gain: float = 0.5, + actor_input_normalization: bool = False, + 
actor_recurrent_layers: int = 1, + actor_recurrent_module: str = Network.recurrent_module_lstm, + actor_recurrent_tf_context_length: int = 64, + actor_recurrent_tf_head_count: int = 8, + actor_shared_dims: int = None, + batch_count: int = 1, + batch_size: int = 1, + critic_activations: List[str] = ["relu", "relu", "relu", "linear"], + critic_hidden_dims: List[int] = [256, 256, 256], + critic_init_gain: float = 0.5, + critic_input_normalization: bool = False, + critic_recurrent_layers: int = 1, + critic_recurrent_module: str = Network.recurrent_module_lstm, + critic_recurrent_tf_context_length: int = 64, + critic_recurrent_tf_head_count: int = 8, + critic_shared_dims: int = None, + polyak: float = 0.995, + recurrent: bool = False, + return_steps: int = 1, + _actor_input_size_delta: int = 0, + _critic_input_size_delta: int = 0, + **kwargs, + ): + """Creates an actor critic agent. + + Args: + env (VecEnv): A vectorized environment. + actor_activations (List[str]): A list of activation functions for the actor network. + actor_hidden_dims (List[str]): A list of layer sizes for the actor network. + actor_init_gain (float): Network initialization gain for actor. + actor_input_normalization (bool): Whether to empirically normalize inputs to the actor network. + actor_recurrent_layers (int): The number of recurrent layers to use for the actor network. + actor_recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules. + actor_shared_dims (int): The number of dimensions to share for an actor with multiple heads. + batch_count (int): The number of batches to process per update step. + batch_size (int): The size of each batch to process during the update step. + critic_activations (List[str]): A list of activation functions for the critic network. + critic_hidden_dims: (List[str]): A list of layer sizes for the critic network. + critic_init_gain (float): Network initialization gain for critic. + critic_input_normalization (bool): Whether to empirically normalize inputs to the critic network. + critic_recurrent_layers (int): The number of recurrent layers to use for the critic network. + critic_recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules. + critic_shared_dims (int): The number of dimensions to share for a critic with multiple heads. + polyak (float): The actor-critic target network polyak factor. + recurrent (bool): Whether to use recurrent actor and critic networks. + recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules. + recurrent_tf_context_length (int): The context length of the Transformer. + recurrent_tf_head_count (int): The head count of the Transformer. + return_steps (float): The number of steps over which to compute the returns (n-step return). + _actor_input_size_delta (int): The number of additional dimensions to add to the actor input. + _critic_input_size_delta (int): The number of additional dimensions to add to the critic input. + """ + assert ( + self._alg_features["recurrent"] == True or not recurrent + ), f"{self.__class__.__name__} does not support recurrent networks." 
+ + super().__init__(env, **kwargs) + + self.actor: torch.nn.Module = None + self.actor_optimizer: torch.nn.Module = None + self.critic_optimizer: torch.nn.Module = None + self.critic: torch.nn.Module = None + + self._batch_size = batch_size + self._batch_count = batch_count + self._polyak_factor = polyak + self._return_steps = return_steps + self._recurrent = recurrent + + self._register_serializable( + "_batch_size", "_batch_count", "_discount_factor", "_polyak_factor", "_return_steps" + ) + + dimensions = environment_dimensions(self.env) + try: + actor_input_size = dimensions["actor_observations"] + critic_input_size = dimensions["critic_observations"] + except KeyError: + actor_input_size = dimensions["observations"] + critic_input_size = dimensions["observations"] + self._actor_input_size = actor_input_size + _actor_input_size_delta + self._critic_input_size = critic_input_size + self._action_size + _critic_input_size_delta + + self._register_actor_network_kwargs( + activations=actor_activations, + hidden_dims=actor_hidden_dims, + init_gain=actor_init_gain, + input_normalization=actor_input_normalization, + recurrent=recurrent, + recurrent_layers=actor_recurrent_layers, + recurrent_module=actor_recurrent_module, + recurrent_tf_context_length=actor_recurrent_tf_context_length, + recurrent_tf_head_count=actor_recurrent_tf_head_count, + ) + + if actor_shared_dims is not None: + self._register_actor_network_kwargs(shared_dims=actor_shared_dims) + + self._register_critic_network_kwargs( + activations=critic_activations, + hidden_dims=critic_hidden_dims, + init_gain=critic_init_gain, + input_normalization=critic_input_normalization, + recurrent=recurrent, + recurrent_layers=critic_recurrent_layers, + recurrent_module=critic_recurrent_module, + recurrent_tf_context_length=critic_recurrent_tf_context_length, + recurrent_tf_head_count=critic_recurrent_tf_head_count, + ) + + if critic_shared_dims is not None: + self._register_critic_network_kwargs(shared_dims=critic_shared_dims) + + self._register_serializable( + "_actor_input_size", "_actor_network_kwargs", "_critic_input_size", "_critic_network_kwargs" + ) + + # For computing n-step returns using prior transitions. 
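The `_stored_dataset` buffer set up below carries transitions across calls so that `_process_dataset` (further down) can fold the next `return_steps` rewards into each transition's reward. For reference, the quantity accumulated per environment is a truncated n-step return; the standalone sketch below is not part of the patch, omits the critic bootstrap on timeouts that `_process_dataset` adds, and assumes a `[steps, envs]` tensor layout.

```python
import torch


def n_step_return(rewards: torch.Tensor, dones: torch.Tensor, gamma: float, n: int) -> torch.Tensor:
    """Discounted n-step return for the first step of a [steps, envs] batch.

    Accumulation stops per environment once a done flag is seen, matching the
    dead-index bookkeeping in AbstractActorCritic._process_dataset; the
    gamma^(k+1) * critic bootstrap on timeouts is omitted here for brevity.
    """
    returns = torch.zeros_like(rewards[0])
    alive = torch.ones_like(rewards[0], dtype=torch.bool)
    for k in range(n):
        returns = returns + alive * (gamma**k) * rewards[k]  # only environments still alive contribute
        alive = alive & (dones[k] == 0)  # stop accumulating after the first done
    return returns
```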
+ self._stored_dataset = [] + + def export_onnx(self) -> Tuple[torch.nn.Module, torch.Tensor, Dict]: + self.eval_mode() + + class ONNXActor(torch.nn.Module): + def __init__(self, model: torch.nn.Module): + super().__init__() + + self.model = model + + def forward(self, x: torch.Tensor, hidden_state: Tuple[torch.Tensor, torch.Tensor] = None): + if hidden_state is None: + return self.model(x) + + data = self.model(x, hidden_state=hidden_state) + hidden_state = self.model.last_hidden_state + + return data, hidden_state + + model = ONNXActor(self.actor) + kwargs = dict( + export_params=True, + opset_version=11, + verbose=True, + dynamic_axes={}, + ) + + kwargs["input_names"] = ["observations"] + kwargs["output_names"] = ["actions"] + + args = torch.zeros(1, self._actor_input_size) + + if self.actor.recurrent: + hidden_state = ( + torch.zeros(self.actor._features[0].num_layers, 1, self.actor._features[0].hidden_size), + torch.zeros(self.actor._features[0].num_layers, 1, self.actor._features[0].hidden_size), + ) + args = (args, {"hidden_state": hidden_state}) + + return model, args, kwargs + + def draw_random_actions(self, obs: torch.Tensor, env_info: Dict[str, Any]) -> Tuple[torch.Tensor, Dict]: + actions, data = super().draw_random_actions(obs, env_info) + + actor_obs, critic_obs = self._process_observations(obs, env_info) + data.update({"actor_observations": actor_obs.clone(), "critic_observations": critic_obs.clone()}) + + return actions, data + + def get_inference_policy(self, device=None) -> Callable: + self.to(device) + self.eval_mode() + + if self.actor.recurrent: + self.actor.reset_full_hidden_state(batch_size=self.env.num_envs) + + if self.critic.recurrent: + self.critic.reset_full_hidden_state(batch_size=self.env.num_envs) + + def policy(obs, env_info=None): + with torch.inference_mode(): + obs, _ = self._process_observations(obs, env_info) + + actions = self._process_actions(self.actor.forward(obs)) + + return actions + + return policy + + def process_transition( + self, + observations: torch.Tensor, + environment_info: Dict[str, Any], + actions: torch.Tensor, + rewards: torch.Tensor, + next_observations: torch.Tensor, + next_environment_info: torch.Tensor, + dones: torch.Tensor, + data: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: + if "actor_observations" in data and "critic_observations" in data: + actor_obs, critic_obs = data["actor_observations"], data["critic_observations"] + else: + actor_obs, critic_obs = self._process_observations(observations, environment_info) + + if "next_actor_observations" in data and "next_critic_observations" in data: + next_actor_obs, next_critic_obs = data["next_actor_observations"], data["next_critic_observations"] + else: + next_actor_obs, next_critic_obs = self._process_observations(next_observations, next_environment_info) + + transition = { + "actions": actions, + "actor_observations": actor_obs, + "critic_observations": critic_obs, + "dones": dones, + "next_actor_observations": next_actor_obs, + "next_critic_observations": next_critic_obs, + "rewards": squeeze_preserve_batch(rewards), + "timeouts": self._extract_timeouts(next_environment_info), + } + transition.update(data) + + for key, value in transition.items(): + transition[key] = value.detach().clone() + + return transition + + @property + def recurrent(self) -> bool: + return self._recurrent + + def register_terminations(self, terminations: torch.Tensor) -> None: + pass + + @abstractmethod + def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]: + with 
torch.inference_mode(): + self.storage.append(self._process_dataset(dataset)) + + def _critic_input(self, observations, actions) -> torch.Tensor: + """Combines observations and actions into a tensor that can be fed into the critic network. + + Args: + observations (torch.Tensor): The critic observations. + actions (torch.Tensor): The actions computed by the actor. + Returns: + A torch.Tensor of inputs for the critic network. + """ + return torch.cat((observations, actions), dim=-1) + + def _extract_timeouts(self, next_environment_info): + """Extracts timeout information from the transition next state information dictionary. + + Args: + next_environment_info (Dict[str, Any]): The transition next state information dictionary. + Returns: + A torch.Tensor vector of actor timeouts. + """ + if "time_outs" not in next_environment_info: + return torch.zeros(self.env.num_envs, device=self.device) + + timeouts = squeeze_preserve_batch(next_environment_info["time_outs"].to(self.device)) + + return timeouts + + def _process_dataset(self, dataset: Dataset) -> Dataset: + """Processes a dataset before it is added to the replay memory. + + Handles n-step returns and timeouts. + + TODO: This function seems to be a bottleneck in the training pipeline - speed it up! + + Args: + dataset (Dataset): The dataset to process. + Returns: + A Dataset object containing the processed data. + """ + assert len(dataset) >= self._return_steps + + dataset = self._stored_dataset + dataset + length = len(dataset) - self._return_steps + 1 + self._stored_dataset = dataset[length:] + + for idx in range(len(dataset) - self._return_steps + 1): + dead_idx = torch.zeros_like(dataset[idx]["dones"]) + rewards = torch.zeros_like(dataset[idx]["rewards"]) + + for k in range(self._return_steps): + data = dataset[idx + k] + alive_idx = (dead_idx == 0).nonzero() + critic_predictions = self.critic.forward( + self._critic_input( + data["critic_observations"].clone().to(self.device), + data["actions"].clone().to(self.device), + ) + ) + rewards[alive_idx] += self._discount_factor**k * data["rewards"][alive_idx] + rewards[alive_idx] += ( + self._discount_factor ** (k + 1) * data["timeouts"][alive_idx] * critic_predictions[alive_idx] + ) + dead_idx += data["dones"] + dead_idx += data["timeouts"] + + dataset[idx]["rewards"] = rewards + + return dataset[:length] + + def _process_observations( + self, observations: torch.Tensor, environment_info: Dict[str, Any] = None + ) -> Tuple[torch.Tensor, ...]: + """Processes observations returned by the environment to extract actor and critic observations. + + Args: + observations (torch.Tensor): normal environment observations. + environment_info (Dict[str, Any]): A dictionary of additional environment information. + Returns: + A tuple containing two torch.Tensors with actor and critic observations, respectively. 
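+ If the environment does not provide separate critic observations, the actor observations are reused for the critic.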
+ """ + try: + critic_obs = environment_info["observations"]["critic"] + except (KeyError, TypeError): + critic_obs = observations + + actor_obs, critic_obs = observations.to(self.device), critic_obs.to(self.device) + + return actor_obs, critic_obs + + def _register_actor_network_kwargs(self, **kwargs) -> None: + """Function to configure actor network in child classes before calling super().__init__().""" + if not hasattr(self, "_actor_network_kwargs"): + self._actor_network_kwargs = dict() + + self._actor_network_kwargs.update(**kwargs) + + def _register_critic_network_kwargs(self, **kwargs) -> None: + """Function to configure critic network in child classes before calling super().__init__().""" + if not hasattr(self, "_critic_network_kwargs"): + self._critic_network_kwargs = dict() + + self._critic_network_kwargs.update(**kwargs) + + def _update_target(self, online: torch.nn.Module, target: torch.nn.Module) -> None: + """Updates the target network using the polyak factor. + + Args: + online (torch.nn.Module): The online network. + target (torch.nn.Module): The target network. + """ + for op, tp in zip(online.parameters(), target.parameters()): + tp.data.copy_((1.0 - self._polyak_factor) * op.data + self._polyak_factor * tp.data) diff --git a/rsl_rl/algorithms/agent.py b/rsl_rl/algorithms/agent.py new file mode 100644 index 0000000..90afd04 --- /dev/null +++ b/rsl_rl/algorithms/agent.py @@ -0,0 +1,197 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +import numpy as np +import torch +from typing import Any, Callable, Dict, Tuple, Union + +from rsl_rl.env import VecEnv +from rsl_rl.storage.storage import Dataset +from rsl_rl.utils.benchmarkable import Benchmarkable +from rsl_rl.utils.serializable import Serializable +from rsl_rl.utils.utils import environment_dimensions + + +class Agent(ABC, Benchmarkable, Serializable): + def __init__( + self, + env: VecEnv, + action_max: float = np.inf, + action_min: float = -np.inf, + benchmark: bool = False, + device: str = "cpu", + gamma: float = 0.99, + ): + """Creates an agent. + + Args: + env (VecEnv): The envrionment of the agent. + action_max (float): The maximum action value. + action_min (float): The minimum action value. + bechmark (bool): Whether to benchmark runtime. + device (str): The device to use for computation. + gamma (float): The environment discount factor. + """ + super().__init__() + + self.env = env + self.device = device + self.storage = None + + self._action_max = action_max + self._action_min = action_min + self._discount_factor = gamma + + self._register_serializable("_action_max", "_action_min", "_discount_factor") + + dimensions = environment_dimensions(self.env) + self._action_size = dimensions["actions"] + + self._register_serializable("_action_size") + + if self._action_min > -np.inf and self._action_max < np.inf: + self._rand_scale = self._action_max - self._action_min + self._rand_offset = self._action_min + else: + self._rand_scale = 2.0 + self._rand_offset = -1.0 + + self._bm_toggle(benchmark) + + @abstractmethod + def draw_actions( + self, obs: torch.Tensor, env_info: Dict[str, Any] + ) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]: + """Draws actions from the action space. + + Args: + obs (torch.Tensor): The observations for which to draw actions. + env_info (Dict[str, Any]): The environment information for the observations. + Returns: + A tuple containing the actions and the data dictionary. 
+ """ + pass + + def draw_random_actions( + self, obs: torch.Tensor, env_info: Dict[str, Any] + ) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]: + """Draws random actions from the action space. + + Args: + obs (torch.Tensor): The observations to include in the data dictionary. + env_info (Dict[str, Any]): The environment information to include in the data dictionary. + Returns: + A tuple containing the random actions and the data dictionary. + """ + actions = self._process_actions( + self._rand_offset + self._rand_scale * torch.rand(self.env.num_envs, self._action_size) + ) + + return actions, {} + + @abstractmethod + def eval_mode(self) -> Agent: + """Sets the agent to evaluation mode.""" + return self + + @abstractmethod + def export_onnx(self) -> Tuple[torch.nn.Module, torch.Tensor, Dict]: + """Exports the agent's policy network to ONNX format. + + Returns: + A tuple containing the ONNX model, the input arguments, and the keyword arguments. + """ + pass + + @property + def gamma(self) -> float: + return self._discount_factor + + @abstractmethod + def get_inference_policy(self, device: str = None) -> Callable: + """Returns a function that computes actions from observations without storing gradients. + + Args: + device (torch.device): The device to use for inference. + Returns: + A function that computes actions from observations. + """ + pass + + @property + def initialized(self) -> bool: + """Whether the agent has been initialized.""" + return self.storage.initialized + + @abstractmethod + def process_transition( + self, + observations: torch.Tensor, + environement_info: Dict[str, Any], + actions: torch.Tensor, + rewards: torch.Tensor, + next_observations: torch.Tensor, + next_environment_info: torch.Tensor, + dones: torch.Tensor, + data: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: + """Processes a transition before it is added to the replay memory. + + Args: + observations (torch.Tensor): The observations from the environment. + environment_info (Dict[str, Any]): The environment information. + actions (torch.Tensor): The actions computed by the actor. + rewards (torch.Tensor): The rewards from the environment. + next_observations (torch.Tensor): The next observations from the environment. + next_environment_info (Dict[str, Any]): The next environment information. + dones (torch.Tensor): The done flags from the environment. + data (Dict[str, torch.Tensor]): Additional data to include in the transition. + Returns: + A dictionary containing the processed transition. + """ + pass + + @abstractmethod + def register_terminations(self, terminations: torch.Tensor) -> None: + """Registers terminations with the actor critic agent. + + Args: + terminations (torch.Tensor): A tensor of indicator values for each environment. + """ + pass + + @abstractmethod + def to(self, device: str) -> Agent: + """Transfers agent parameters to device.""" + self.device = device + + return self + + @abstractmethod + def train_mode(self) -> Agent: + """Sets the agent to training mode.""" + return self + + @abstractmethod + def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]: + """Updates the agent's parameters. + + Args: + dataset (Dataset): The dataset from which to update the agent. + Returns: + A dictionary containing the loss values. + """ + pass + + def _process_actions(self, actions: torch.Tensor) -> torch.Tensor: + """Processes actions produced by the agent. + + Args: + actions (torch.Tensor): The raw actions. 
+ Returns: + A torch.Tensor containing the processed actions. + """ + actions = actions.reshape(-1, self._action_size) + actions = actions.clamp(self._action_min, self._action_max) + actions = actions.to(self.device) + + return actions diff --git a/rsl_rl/algorithms/d4pg.py b/rsl_rl/algorithms/d4pg.py new file mode 100644 index 0000000..0dd5535 --- /dev/null +++ b/rsl_rl/algorithms/d4pg.py @@ -0,0 +1,168 @@ +from __future__ import annotations +import torch +from typing import Dict, Union +from rsl_rl.algorithms.dpg import AbstractDPG +from rsl_rl.env import VecEnv +from rsl_rl.storage.storage import Dataset + +from rsl_rl.modules import CategoricalNetwork, Network + + +class D4PG(AbstractDPG): + """Distributed Distributional Deep Deterministic Policy Gradients algorithm. + + This is an implementation of the D4PG algorithm by Barth-Maron et. al. for vectorized environments. + + Paper: https://arxiv.org/pdf/1804.08617.pdf + """ + + def __init__( + self, + env: VecEnv, + actor_lr: float = 1e-4, + atom_count: int = 51, + critic_activations: list = ["relu", "relu", "relu"], + critic_lr: float = 1e-3, + target_update_delay: int = 2, + value_max: float = 10.0, + value_min: float = -10.0, + **kwargs, + ) -> None: + """ + Args: + env (VecEnv): A vectorized environment. + actor_lr (float): The learning rate for the actor network. + atom_count (int): The number of atoms to use for the categorical distribution. + critic_activations (list): A list of activation functions to use for the critic network. + critic_lr (float): The learning rate for the critic network. + target_update_delay (int): The number of steps to wait before updating the target networks. + value_max (float): The maximum value for the categorical distribution. + value_min (float): The minimum value for the categorical distribution. 
+ """ + kwargs["critic_activations"] = critic_activations + + super().__init__(env, **kwargs) + + self._atom_count = atom_count + self._target_update_delay = target_update_delay + self._value_max = value_max + self._value_min = value_min + + self._register_serializable("_atom_count", "_target_update_delay", "_value_max", "_value_min") + + self.actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs) + self.critic = CategoricalNetwork( + self._critic_input_size, + 1, + atom_count=atom_count, + value_max=value_max, + value_min=value_min, + **self._critic_network_kwargs, + ) + + self.target_actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs) + self.target_critic = CategoricalNetwork( + self._critic_input_size, + 1, + atom_count=atom_count, + value_max=value_max, + value_min=value_min, + **self._critic_network_kwargs, + ) + self.target_actor.load_state_dict(self.actor.state_dict()) + self.target_critic.load_state_dict(self.critic.state_dict()) + + self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr) + self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr) + + self._register_serializable( + "actor", "critic", "target_actor", "target_critic", "actor_optimizer", "critic_optimizer" + ) + + self._update_step = 0 + + self._register_serializable("_update_step") + + self.to(self.device) + + def eval_mode(self) -> D4PG: + super().eval_mode() + + self.actor.eval() + self.critic.eval() + self.target_actor.eval() + self.target_critic.eval() + + return self + + def to(self, device: str) -> D4PG: + super().to(device) + + self.actor.to(device) + self.critic.to(device) + self.target_actor.to(device) + self.target_critic.to(device) + + return self + + def train_mode(self) -> D4PG: + super().train_mode() + + self.actor.train() + self.critic.train() + self.target_actor.train() + self.target_critic.train() + + return self + + def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]: + super().update(dataset) + + if not self.initialized: + return {} + + total_actor_loss = torch.zeros(self._batch_count) + total_critic_loss = torch.zeros(self._batch_count) + + for idx, batch in enumerate(self.storage.batch_generator(self._batch_size, self._batch_count)): + actor_obs = batch["actor_observations"] + critic_obs = batch["critic_observations"] + actions = batch["actions"].reshape(self._batch_size, -1) + rewards = batch["rewards"] + actor_next_obs = batch["next_actor_observations"] + critic_next_obs = batch["next_critic_observations"] + dones = batch["dones"] + + predictions = self.critic.forward(self._critic_input(critic_obs, actions), distribution=True).squeeze() + target_actor_prediction = self._process_actions(self.target_actor.forward(actor_next_obs)) + target_probabilities = self.target_critic.forward( + self._critic_input(critic_next_obs, target_actor_prediction), distribution=True + ).squeeze() + targets = self.target_critic.compute_targets(rewards, dones, self._discount_factor) + critic_loss = self.target_critic.categorical_loss(predictions, target_probabilities, targets) + + self.critic_optimizer.zero_grad() + critic_loss.backward() + self.critic_optimizer.step() + + evaluation = self.critic.forward( + self._critic_input(critic_obs, self._process_actions(self.actor.forward(actor_obs))) + ) + actor_loss = -evaluation.mean() + + self.actor_optimizer.zero_grad() + actor_loss.backward() + self.actor_optimizer.step() + + if self._update_step % self._target_update_delay == 0: + 
self._update_target(self.actor, self.target_actor) + self._update_target(self.critic, self.target_critic) + + self._update_step += 1 + + total_actor_loss[idx] = actor_loss.item() + total_critic_loss[idx] = critic_loss.item() + + stats = {"actor": total_actor_loss.mean().item(), "critic": total_critic_loss.mean().item()} + + return stats diff --git a/rsl_rl/algorithms/ddpg.py b/rsl_rl/algorithms/ddpg.py new file mode 100644 index 0000000..ec91793 --- /dev/null +++ b/rsl_rl/algorithms/ddpg.py @@ -0,0 +1,125 @@ +from __future__ import annotations +import torch +from torch import optim +from typing import Dict, Union + +from rsl_rl.algorithms.dpg import AbstractDPG +from rsl_rl.env import VecEnv +from rsl_rl.modules.network import Network +from rsl_rl.storage.storage import Dataset + + +class DDPG(AbstractDPG): + """Deep Deterministic Policy Gradients algorithm. + + This is an implementation of the DDPG algorithm by Lillicrap et. al. for vectorized environments. + + Paper: https://arxiv.org/pdf/1509.02971.pdf + """ + + def __init__( + self, + env: VecEnv, + actor_lr: float = 1e-4, + critic_lr: float = 1e-3, + **kwargs, + ) -> None: + super().__init__(env, **kwargs) + + self.actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs) + self.critic = Network(self._critic_input_size, 1, **self._critic_network_kwargs) + + self.target_actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs) + self.target_critic = Network(self._critic_input_size, 1, **self._critic_network_kwargs) + self.target_actor.load_state_dict(self.actor.state_dict()) + self.target_critic.load_state_dict(self.critic.state_dict()) + + self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr) + self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr) + + self._register_serializable( + "actor", "critic", "target_actor", "target_critic", "actor_optimizer", "critic_optimizer" + ) + + self.to(self.device) + + def eval_mode(self) -> DDPG: + super().eval_mode() + + self.actor.eval() + self.critic.eval() + self.target_actor.eval() + self.target_critic.eval() + + return self + + def to(self, device: str) -> DDPG: + """Transfers agent parameters to device.""" + super().to(device) + + self.actor.to(device) + self.critic.to(device) + self.target_actor.to(device) + self.target_critic.to(device) + + return self + + def train_mode(self) -> DDPG: + super().train_mode() + + self.actor.train() + self.critic.train() + self.target_actor.train() + self.target_critic.train() + + return self + + def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]: + super().update(dataset) + + if not self.initialized: + return {} + + total_actor_loss = torch.zeros(self._batch_count) + total_critic_loss = torch.zeros(self._batch_count) + + for idx, batch in enumerate(self.storage.batch_generator(self._batch_size, self._batch_count)): + actor_obs = batch["actor_observations"] + critic_obs = batch["critic_observations"] + actions = batch["actions"] + rewards = batch["rewards"] + actor_next_obs = batch["next_actor_observations"] + critic_next_obs = batch["next_critic_observations"] + dones = batch["dones"] + + target_actor_prediction = self._process_actions(self.target_actor.forward(actor_next_obs)) + target_critic_prediction = self.target_critic.forward( + self._critic_input(critic_next_obs, target_actor_prediction) + ) + + target = rewards + self._discount_factor * (1 - dones) * target_critic_prediction + prediction = 
self.critic.forward(self._critic_input(critic_obs, actions)) + critic_loss = (prediction - target).pow(2).mean() + + self.critic_optimizer.zero_grad() + critic_loss.backward() + self.critic_optimizer.step() + + evaluation = self.critic.forward( + self._critic_input(critic_obs, self._process_actions(self.actor.forward(actor_obs))) + ) + actor_loss = -evaluation.mean() + + self.actor_optimizer.zero_grad() + actor_loss.backward() + self.actor_optimizer.step() + + self._update_target(self.actor, self.target_actor) + self._update_target(self.critic, self.target_critic) + + total_actor_loss[idx] = actor_loss.item() + total_critic_loss[idx] = critic_loss.item() + + stats = {"actor": total_actor_loss.mean().item(), "critic": total_critic_loss.mean().item()} + + return stats diff --git a/rsl_rl/algorithms/dpg.py b/rsl_rl/algorithms/dpg.py new file mode 100644 index 0000000..8173d0b --- /dev/null +++ b/rsl_rl/algorithms/dpg.py @@ -0,0 +1,49 @@ +import torch +from typing import Any, Dict, Tuple, Union + +from rsl_rl.algorithms.actor_critic import AbstractActorCritic +from rsl_rl.env import VecEnv +from rsl_rl.storage.replay_storage import ReplayStorage +from rsl_rl.storage.storage import Dataset + + +class AbstractDPG(AbstractActorCritic): + def __init__( + self, env: VecEnv, action_noise_scale: float = 0.1, storage_initial_size=0, storage_size=100000, **kwargs + ): + """ + Args: + env (VecEnv): A vectorized environment. + action_noise_scale (float): The scale of the gaussian action noise. + storage_initial_size (int): Initial size of the replay storage. + storage_size (int): Maximum size of the replay storage. + """ + assert action_noise_scale > 0 + + super().__init__(env, **kwargs) + + self.storage = ReplayStorage( + self.env.num_envs, storage_size, device=self.device, initial_size=storage_initial_size + ) + + self._register_serializable("storage") + + self._action_noise_scale = action_noise_scale + + self._register_serializable("_action_noise_scale") + + def draw_actions( + self, obs: torch.Tensor, env_info: Dict[str, Any] + ) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]: + actor_obs, critic_obs = self._process_observations(obs, env_info) + + actions = self.actor.forward(actor_obs) + noise = torch.normal(torch.zeros_like(actions), torch.ones_like(actions) * self._action_noise_scale) + noisy_actions = self._process_actions(actions + noise) + + data = {"actor_observations": actor_obs.clone(), "critic_observations": critic_obs.clone()} + + return noisy_actions, data + + def register_terminations(self, terminations: torch.Tensor) -> None: + pass diff --git a/rsl_rl/algorithms/dppo.py b/rsl_rl/algorithms/dppo.py new file mode 100644 index 0000000..b919816 --- /dev/null +++ b/rsl_rl/algorithms/dppo.py @@ -0,0 +1,327 @@ +import torch +from torch import nn +from typing import Dict, List, Tuple, Type, Union + +from rsl_rl.algorithms.ppo import PPO +from rsl_rl.distributions import QuantileDistribution +from rsl_rl.env import VecEnv +from rsl_rl.utils.benchmarkable import Benchmarkable +from rsl_rl.utils.recurrency import trajectories_to_transitions, transitions_to_trajectories +from rsl_rl.modules import ImplicitQuantileNetwork, QuantileNetwork +from rsl_rl.storage.storage import Dataset + + +class DPPO(PPO): + """Distributional Proximal Policy Optimization algorithm. + + This algorithm is an extension of PPO that uses a distributional method (either QR-DQN or IQN) to estimate the + value function. 
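+ The critic therefore outputs quantiles of the return distribution rather than a single value estimate.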
+ + QR-DQN Paper: https://arxiv.org/pdf/1710.10044.pdf + IQN Paper: https://arxiv.org/pdf/1806.06923.pdf + + The implementation works with recurrent neural networks. We further implement Sample-Replacement SR(lambda) for the + value target computation, as described by Nam et. al. in https://arxiv.org/pdf/2105.11366.pdf. + """ + + critic_network: Type[nn.Module] = QuantileNetwork + _alg_features = dict(recurrent=True) + + value_loss_energy = "sample_energy" + value_loss_l1 = "quantile_l1" + value_loss_huber = "quantile_huber" + + network_qrdqn = "qrdqn" + network_iqn = "iqn" + + networks = {network_qrdqn: QuantileNetwork, network_iqn: ImplicitQuantileNetwork} + + values_losses = { + network_qrdqn: { + value_loss_energy: QuantileNetwork.sample_energy_loss, + value_loss_l1: QuantileNetwork.quantile_l1_loss, + value_loss_huber: QuantileNetwork.quantile_huber_loss, + }, + network_iqn: { + value_loss_energy: ImplicitQuantileNetwork.sample_energy_loss, + }, + } + + def __init__( + self, + env: VecEnv, + critic_activations: List[str] = ["relu", "relu", "relu"], + critic_network: str = network_qrdqn, + iqn_action_samples: int = 32, + iqn_embedding_size: int = 64, + iqn_feature_layers: int = 1, + iqn_value_samples: int = 8, + qrdqn_quantile_count: int = 200, + value_lambda: float = 0.95, + value_loss: str = value_loss_l1, + value_loss_kwargs: Dict = {}, + value_measure: str = None, + value_measure_adaptation: Union[Tuple, None] = None, + value_measure_kwargs: Dict = {}, + **kwargs, + ): + """ + Args: + env (VecEnv): A vectorized environment. + critic_activations (List[str]): A list of activations to use for the critic network. + critic_network (str): The critic network to use. + iqn_action_samples (int): The number of samples to use for the critic IQN network when acting. + iqn_embedding_size (int): The embedding size to use for the critic IQN network. + iqn_feature_layers (int): The number of feature layers to use for the critic IQN network. + iqn_value_samples (int): The number of samples to use for the critic IQN network when computing the value. + qrdqn_quantile_count (int): The number of quantiles to use for the critic QR network. + value_lambda (float): The lambda parameter for the SR(lambda) value target computation. + value_loss (str): The loss function to use for the critic network. + value_loss_kwargs (Dict): Keyword arguments for computing the value loss. + value_measure (str): The probability measure to apply to the critic network output distribution when + updating the policy. + value_measure_adaptation (Union[Tuple, None]): Controls adaptation of the value measure. If None, no + adaptation is performed. If a tuple, the tuple specifies the observations that are passed to the value + measure. + value_measure_kwargs (Dict): The keyword arguments to pass to the value measure. + """ + self._register_critic_network_kwargs(measure=value_measure, measure_kwargs=value_measure_kwargs) + + self._critic_network_name = critic_network + self.critic_network = self.networks[self._critic_network_name] + if self._critic_network_name == self.network_qrdqn: + self._register_critic_network_kwargs(quantile_count=qrdqn_quantile_count) + elif self._critic_network_name == self.network_iqn: + self._register_critic_network_kwargs(feature_layers=iqn_feature_layers, embedding_size=iqn_embedding_size) + + kwargs["critic_activations"] = critic_activations + + if value_measure_adaptation is not None: + # Value measure adaptation observations are not passed to the critic network. 
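+ # Example: with value_measure_adaptation=(0, 1), the first two observation entries are routed
+ # to the value measure as measure_args (see _extract_value_measure_adaptation) and are masked
+ # out of the critic input, so the critic input size shrinks by two.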
+ kwargs["_critic_input_size_delta"] = ( + kwargs["_critic_input_size_delta"] if "_critic_input_size_delta" in kwargs else 0 + ) - len(value_measure_adaptation) + + super().__init__(env, **kwargs) + + self._value_lambda = value_lambda + self._value_loss_name = value_loss + self._register_serializable("_value_lambda", "_value_loss_name") + + assert ( + self._value_loss_name in self.values_losses[self._critic_network_name] + ), f"Value loss '{self._value_loss_name}' is not supported for network '{self._critic_network_name}'." + value_loss_func = self.values_losses[critic_network][self._value_loss_name] + self._value_loss = lambda *args, **kwargs: value_loss_func(self.critic, *args, **kwargs) + + if value_loss == self.value_loss_energy: + value_loss_kwargs["sample_count"] = ( + value_loss_kwargs["sample_count"] if "sample_count" in value_loss_kwargs else 100 + ) + + self._value_loss_kwargs = value_loss_kwargs + self._register_serializable("_value_loss_kwargs") + + self._value_measure_adaptation = value_measure_adaptation + self._register_serializable("_value_measure_adaptation") + + if self._critic_network_name == self.network_iqn: + self._iqn_action_samples = iqn_action_samples + self._iqn_value_samples = iqn_value_samples + self._register_serializable("_iqn_action_samples", "_iqn_value_samples") + + def _critic_input(self, observations, actions=None) -> torch.Tensor: + mask, shape = self._get_critic_obs_mask(observations) + + processed_observations = observations[mask].reshape(*shape) + + return processed_observations + + def _get_critic_obs_mask(self, observations): + mask = torch.ones_like(observations).bool() + + if self._value_measure_adaptation is not None: + mask[:, self._value_measure_adaptation] = False + + shape = (observations.shape[0], self._critic_input_size) + + return mask, shape + + def _process_quants(self, x): + if self._value_loss_name == self.value_loss_energy: + quants, idx = QuantileDistribution(x).sample(self._value_loss_kwargs["sample_count"]) + else: + quants, idx = x, None + + return quants, idx + + @Benchmarkable.register + def process_transition(self, *args) -> Dict[str, torch.Tensor]: + transition = super(PPO, self).process_transition(*args) + + if self.recurrent: + transition["critic_state_h"] = self.critic.hidden_state[0].detach() + transition["critic_state_c"] = self.critic.hidden_state[1].detach() + + transition["full_critic_observations"] = transition["critic_observations"].detach() + transition["full_next_critic_observations"] = transition["next_critic_observations"].detach() + mask, shape = self._get_critic_obs_mask(transition["critic_observations"]) + transition["critic_observations"] = transition["critic_observations"][mask].reshape(*shape) + transition["next_critic_observations"] = transition["next_critic_observations"][mask].reshape(*shape) + + critic_kwargs = ( + {"sample_count": self._iqn_action_samples} if self._critic_network_name == self.network_iqn else {} + ) + transition["values"] = self.critic.forward( + transition["critic_observations"], + measure_args=self._extract_value_measure_adaptation(transition["full_critic_observations"]), + **critic_kwargs, + ).detach() + + if self._critic_network_name == self.network_iqn: + # For IQN, we sample new (undistorted) quantiles for computing the value update + critic_kwargs = ( + {"hidden_state": (transition["critic_state_h"], transition["critic_state_c"])} if self.recurrent else {} + ) + self.critic.forward( + transition["critic_observations"], + sample_count=self._iqn_value_samples, + use_measure=False, 
+ **critic_kwargs, + ).detach() + transition["value_taus"] = self.critic.last_taus.detach().reshape(transition["values"].shape[0], -1) + + transition["value_quants"] = self.critic.last_quantiles.detach().reshape(transition["values"].shape[0], -1) + + if self.recurrent: + transition["critic_next_state_h"] = self.critic.hidden_state[0].detach() + transition["critic_next_state_c"] = self.critic.hidden_state[1].detach() + + return transition + + @Benchmarkable.register + def _compute_value_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + critic_kwargs = ( + {"sample_count": self._iqn_value_samples, "taus": batch["value_target_taus"], "use_measure": False} + if self._critic_network_name == self.network_iqn + else {} + ) + + if self.recurrent: + observations, data = transitions_to_trajectories(batch["critic_observations"], batch["dones"]) + hidden_state_h, _ = transitions_to_trajectories(batch["critic_state_h"], batch["dones"]) + hidden_state_c, _ = transitions_to_trajectories(batch["critic_state_c"], batch["dones"]) + hidden_states = (hidden_state_h[0].transpose(0, 1), hidden_state_c[0].transpose(0, 1)) + + if self._critic_network_name == self.network_iqn: + critic_kwargs["taus"], _ = transitions_to_trajectories(critic_kwargs["taus"], batch["dones"]) + + trajectory_evaluations = self.critic.forward( + observations, distribution=True, hidden_state=hidden_states, **critic_kwargs + ) + trajectory_evaluations = trajectory_evaluations.reshape(*observations.shape[:-1], -1) + + predictions = trajectories_to_transitions(trajectory_evaluations, data) + else: + predictions = self.critic.forward(batch["critic_observations"], distribution=True, **critic_kwargs) + + value_loss = self._value_loss(self._process_quants(predictions)[0], batch["value_target_quants"]) + + return value_loss + + def _extract_value_measure_adaptation(self, observations: torch.Tensor) -> Tuple[torch.Tensor]: + if self._value_measure_adaptation is None: + return tuple() + + relevant_observations = observations[:, self._value_measure_adaptation] + measure_adaptations = torch.tensor_split(relevant_observations, relevant_observations.shape[1], dim=1) + + return measure_adaptations + + @Benchmarkable.register + def _process_dataset(self, dataset: Dataset) -> Dataset: + rewards = torch.stack([entry["rewards"] for entry in dataset]) + dones = torch.stack([entry["dones"] for entry in dataset]).float() + timeouts = torch.stack([entry["timeouts"] for entry in dataset]) + values = torch.stack([entry["values"] for entry in dataset]) + + value_quants_idx = [self._process_quants(entry["value_quants"]) for entry in dataset] + value_quants = torch.stack([entry[0] for entry in value_quants_idx]) + + critic_kwargs = ( + {"hidden_state": (dataset[-1]["critic_state_h"], dataset[-1]["critic_state_c"])} if self.recurrent else {} + ) + if self._critic_network_name == self.network_iqn: + critic_kwargs["sample_count"] = self._iqn_action_samples + + measure_args = self._extract_value_measure_adaptation(dataset[-1]["full_next_critic_observations"]) + next_values = self.critic.forward( + dataset[-1]["next_critic_observations"], measure_args=measure_args, **critic_kwargs + ) + + if self._critic_network_name == self.network_iqn: + # For IQN, we sample new (undistorted) quantiles for computing the value update + critic_kwargs["sample_count"] = self._iqn_value_samples + self.critic.forward( + dataset[-1]["next_critic_observations"], + use_measure=False, + **critic_kwargs, + ) + + final_value_taus = self.critic.last_taus + value_taus = torch.stack( + [ 
+ torch.take_along_dim(dataset[i]["value_taus"], value_quants_idx[i][1], -1) + for i in range(len(dataset)) + ] + ) + + final_value_quants = self.critic.last_quantiles + + # Timeout bootstrapping for rewards. + rewards += self.gamma * timeouts * values + + # Compute advantages and value target quantiles + next_values = torch.cat((values[1:], next_values.unsqueeze(0)), dim=0) + deltas = (rewards + (1 - dones) * self.gamma * next_values - values).reshape(-1, self.env.num_envs) + advantages = torch.zeros((len(dataset) + 1, self.env.num_envs), device=self.device) + + next_value_quants, idx = self._process_quants(final_value_quants) + value_target_quants = torch.zeros(len(dataset), *next_value_quants.shape, device=self.device) + + if self._critic_network_name == self.network_iqn: + value_target_taus = torch.zeros(len(dataset) + 1, *next_value_quants.shape, device=self.device) + value_target_taus[-1] = torch.take_along_dim(final_value_taus, idx, -1) + + for step in reversed(range(len(dataset))): + not_terminal = 1.0 - dones[step] + not_terminal_ = not_terminal.unsqueeze(-1) + + advantages[step] = deltas[step] + (1.0 - dones[step]) * self.gamma * self._gae_lambda * advantages[step + 1] + value_target_quants[step] = rewards[step].unsqueeze(-1) + not_terminal_ * self.gamma * next_value_quants + + preserved_value_quants = not_terminal_.bool() * ( + torch.rand(next_value_quants.shape, device=self.device) < self._value_lambda + ) + next_value_quants = torch.where(preserved_value_quants, value_target_quants[step], value_quants[step]) + + if self._critic_network_name == self.network_iqn: + value_target_taus[step] = torch.where( + preserved_value_quants, value_target_taus[step + 1], value_taus[step] + ) + + advantages = advantages[:-1] + if self._critic_network_name == self.network_iqn: + value_target_taus = value_target_taus[:-1] + + # Normalize advantages and pack into dataset structure. + amean, astd = advantages.mean(), torch.nan_to_num(advantages.std()) + for step in range(len(dataset)): + dataset[step]["advantages"] = advantages[step] + dataset[step]["normalized_advantages"] = (advantages[step] - amean) / (astd + 1e-8) + dataset[step]["value_target_quants"] = value_target_quants[step] + + if self._critic_network_name == self.network_iqn: + dataset[step]["value_target_taus"] = value_target_taus[step] + + return dataset diff --git a/rsl_rl/algorithms/dsac.py b/rsl_rl/algorithms/dsac.py new file mode 100644 index 0000000..94d13da --- /dev/null +++ b/rsl_rl/algorithms/dsac.py @@ -0,0 +1,75 @@ +import torch +from torch import nn +from typing import Tuple, Type + + +from rsl_rl.algorithms.sac import SAC +from rsl_rl.env import VecEnv +from rsl_rl.modules.quantile_network import QuantileNetwork + + +class DSAC(SAC): + """Deep Soft Actor Critic (DSAC) algorithm. + + This is an implementation of the DSAC algorithm by Ma et. al. for vectorized environments. + + Paper: https://arxiv.org/pdf/2004.14547.pdf + + The implementation inherits automatic tuning of the temperature parameter (alpha) and tanh action scaling from + the SAC implementation. + """ + + critic_network: Type[nn.Module] = QuantileNetwork + + def __init__(self, env: VecEnv, critic_activations=["relu", "relu", "relu"], quantile_count=200, **kwargs): + """ + Args: + env (VecEnv): A vectorized environment. + critic_activations (list): A list of activation functions to use for the critic network. + quantile_count (int): The number of quantiles to use for the critic QR network. 
+ """ + self._quantile_count = quantile_count + self._register_critic_network_kwargs(quantile_count=self._quantile_count) + + kwargs["critic_activations"] = critic_activations + + super().__init__(env, **kwargs) + + def _update_critic( + self, + critic_obs: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + dones: torch.Tensor, + actor_next_obs: torch.Tensor, + critic_next_obs: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + target_action, target_action_logp = self._sample_action(actor_next_obs) + target_critic_input = self._critic_input(critic_next_obs, target_action) + target_critic_prediction_1 = self.target_critic_1.forward(target_critic_input, distribution=True) + target_critic_prediction_2 = self.target_critic_2.forward(target_critic_input, distribution=True) + target_critic_prediction = torch.minimum(target_critic_prediction_1, target_critic_prediction_2) + + next_soft_q = target_critic_prediction - self.alpha * target_action_logp.unsqueeze(-1).repeat( + 1, self._quantile_count + ) + target = (rewards.reshape(-1, 1) + self._discount_factor * (1 - dones).reshape(-1, 1) * next_soft_q).detach() + + critic_input = self._critic_input(critic_obs, actions).detach() + critic_1_prediction = self.critic_1.forward(critic_input, distribution=True) + critic_1_loss = self.critic_1.quantile_huber_loss(critic_1_prediction, target) + + self.critic_1_optimizer.zero_grad() + critic_1_loss.backward() + nn.utils.clip_grad_norm_(self.critic_1.parameters(), self._gradient_clip) + self.critic_1_optimizer.step() + + critic_2_prediction = self.critic_2.forward(critic_input, distribution=True) + critic_2_loss = self.critic_2.quantile_huber_loss(critic_2_prediction, target) + + self.critic_2_optimizer.zero_grad() + critic_2_loss.backward() + nn.utils.clip_grad_norm_(self.critic_2.parameters(), self._gradient_clip) + self.critic_2_optimizer.step() + + return critic_1_loss, critic_2_loss diff --git a/rsl_rl/algorithms/dtd3.py b/rsl_rl/algorithms/dtd3.py new file mode 100644 index 0000000..e4ed595 --- /dev/null +++ b/rsl_rl/algorithms/dtd3.py @@ -0,0 +1,57 @@ +from __future__ import annotations +import torch +from torch import nn +from typing import Type + +from rsl_rl.algorithms.td3 import TD3 +from rsl_rl.env import VecEnv +from rsl_rl.modules import QuantileNetwork + + +class DTD3(TD3): + """Distributional Twin-Delayed Deep Deterministic Policy Gradients algorithm. + + This is an implementation of the TD3 algorithm by Fujimoto et. al. for vectorized environments using a QR-DQN + critic. 
+ """ + + critic_network: Type[nn.Module] = QuantileNetwork + + def __init__( + self, + env: VecEnv, + quantile_count: int = 200, + **kwargs, + ) -> None: + self._quantile_count = quantile_count + self._register_critic_network_kwargs(quantile_count=self._quantile_count) + + super().__init__(env, **kwargs) + + def _update_critic(self, critic_obs, actions, rewards, dones, actor_next_obs, critic_next_obs): + target_action = self._apply_action_noise(self.target_actor.forward(actor_next_obs), clip=True) + target_critic_input = self._critic_input(critic_next_obs, target_action) + target_critic_prediction_1 = self.target_critic_1.forward(target_critic_input, distribution=True) + target_critic_prediction_2 = self.target_critic_2.forward(target_critic_input, distribution=True) + target_critic_prediction = torch.minimum(target_critic_prediction_1, target_critic_prediction_2) + + target = ( + rewards.reshape(-1, 1) + self._discount_factor * (1 - dones).reshape(-1, 1) * target_critic_prediction + ).detach() + + critic_input = self._critic_input(critic_obs, actions).detach() + critic_1_prediction = self.critic_1.forward(critic_input, distribution=True) + critic_1_loss = self.critic_1.quantile_huber_loss(critic_1_prediction, target) + + self.critic_1_optimizer.zero_grad() + critic_1_loss.backward() + self.critic_1_optimizer.step() + + critic_2_prediction = self.critic_2.forward(critic_input, distribution=True) + critic_2_loss = self.critic_2.quantile_huber_loss(critic_2_prediction, target) + + self.critic_2_optimizer.zero_grad() + critic_2_loss.backward() + self.critic_2_optimizer.step() + + return critic_1_loss, critic_2_loss diff --git a/rsl_rl/algorithms/hybrid.py b/rsl_rl/algorithms/hybrid.py new file mode 100644 index 0000000..11e223c --- /dev/null +++ b/rsl_rl/algorithms/hybrid.py @@ -0,0 +1,193 @@ +from abc import ABC, abstractmethod +import torch +from typing import Callable, Dict, Tuple, Type, Union + +from rsl_rl.algorithms import D4PG, DSAC +from rsl_rl.algorithms import TD3 +from rsl_rl.algorithms.actor_critic import AbstractActorCritic +from rsl_rl.algorithms.agent import Agent +from rsl_rl.env import VecEnv +from rsl_rl.storage.storage import Dataset, Storage + + +class AbstractHybridAgent(Agent, ABC): + def __init__( + self, + env: VecEnv, + agent_class: Type[Agent], + agent_kwargs: dict, + pretrain_agent_class: Type[Agent], + pretrain_agent_kwargs: dict, + pretrain_steps: int, + freeze_steps: int = 0, + **general_kwargs, + ): + """ + Args: + env (VecEnv): A vectorized environment. 
+ """ + agent_kwargs["env"] = env + pretrain_agent_kwargs["env"] = env + + self.agent = agent_class(**agent_kwargs, **general_kwargs) + self.pretrain_agent = pretrain_agent_class(**pretrain_agent_kwargs, **general_kwargs) + + self._freeze_steps = freeze_steps + self._pretrain_steps = pretrain_steps + self._steps = 0 + + self._register_serializable("agent", "pretrain_agent", "_freeze_steps", "_pretrain_steps", "_steps") + + @property + def active_agent(self): + agent = self.pretrain_agent if self.pretraining else self.agent + + return agent + + def draw_actions(self, *args, **kwargs) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]: + return self.active_agent.draw_actions(*args, **kwargs) + + def draw_random_actions(self, *args, **kwargs) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]: + return self.active_agent.draw_random_actions(*args, **kwargs) + + def eval_mode(self, *args, **kwargs) -> Agent: + self.agent.eval_mode(*args, **kwargs) + + def get_inference_policy(self, *args, **kwargs) -> Callable: + return self.active_agent.get_inference_policy(*args, **kwargs) + + @property + def initialized(self) -> bool: + return self.active_agent.initialized + + @property + def pretraining(self): + return self._steps < self._pretrain_steps + + def process_dataset(self, *args, **kwargs) -> Dataset: + return self.active_agent.process_dataset(*args, **kwargs) + + def process_transition(self, *args, **kwargs) -> Dict[str, torch.Tensor]: + return self.active_agent.process_transition(*args, **kwargs) + + def register_terminations(self, *args, **kwargs) -> None: + return self.active_agent.register_terminations(*args, **kwargs) + + @property + def storage(self) -> Storage: + return self.active_agent.storage + + def to(self, *args, **kwargs) -> Agent: + self.agent.to(*args, **kwargs) + self.pretrain_agent.to(*args, **kwargs) + + def train_mode(self, *args, **kwargs) -> Agent: + self.agent.train_mode(*args, **kwargs) + self.pretrain_agent.train_mode(*args, **kwargs) + + def update(self, *args, **kwargs) -> Dict[str, Union[float, torch.Tensor]]: + result = self.active_agent.update(*args, **kwargs) + + if not self.active_agent.initialized: + return + + self._steps += 1 + + if self._steps == self._pretrain_steps: + self._transfer_weights() + self._freeze_weights(freeze=True) + + if self._steps == self._pretrain_steps + self._freeze_steps: + self._transfer_weights() + self._freeze_weights(freeze=False) + + return result + + @abstractmethod + def _freeze_weights(self, freeze=True): + pass + + @abstractmethod + def _transfer_weights(self): + pass + + +class HybridD4PG(AbstractHybridAgent): + def __init__( + self, + env: VecEnv, + d4pg_kwargs: dict, + pretrain_kwargs: dict, + pretrain_agent: Type[AbstractActorCritic] = TD3, + **kwargs, + ): + assert d4pg_kwargs["action_max"] == pretrain_kwargs["action_max"] + assert d4pg_kwargs["action_min"] == pretrain_kwargs["action_min"] + assert d4pg_kwargs["actor_activations"] == pretrain_kwargs["actor_activations"] + assert d4pg_kwargs["actor_hidden_dims"] == pretrain_kwargs["actor_hidden_dims"] + assert d4pg_kwargs["actor_input_normalization"] == pretrain_kwargs["actor_input_normalization"] + + super().__init__( + env, + agent_class=D4PG, + agent_kwargs=d4pg_kwargs, + pretrain_agent_class=pretrain_agent, + pretrain_agent_kwargs=pretrain_kwargs, + **kwargs, + ) + + def _freeze_weights(self, freeze=True): + for param in self.agent.actor.parameters(): + param.requires_grad = not freeze + + def _transfer_weights(self): + 
self.agent.actor.load_state_dict(self.pretrain_agent.actor.state_dict()) + self.agent.actor_optimizer.load_state_dict(self.pretrain_agent.actor_optimizer.state_dict()) + + +class HybridDSAC(AbstractHybridAgent): + def __init__( + self, + env: VecEnv, + dsac_kwargs: dict, + pretrain_kwargs: dict, + pretrain_agent: Type[AbstractActorCritic] = TD3, + **kwargs, + ): + assert dsac_kwargs["action_max"] == pretrain_kwargs["action_max"] + assert dsac_kwargs["action_min"] == pretrain_kwargs["action_min"] + assert dsac_kwargs["actor_activations"] == pretrain_kwargs["actor_activations"] + assert dsac_kwargs["actor_hidden_dims"] == pretrain_kwargs["actor_hidden_dims"] + assert dsac_kwargs["actor_input_normalization"] == pretrain_kwargs["actor_input_normalization"] + + super().__init__( + env, + agent_class=DSAC, + agent_kwargs=dsac_kwargs, + pretrain_agent_class=pretrain_agent, + pretrain_agent_kwargs=pretrain_kwargs, + **kwargs, + ) + + def _freeze_weights(self, freeze=True): + """Freezes actor layers. + + Freezes feature encoding and mean computation layers for gaussian network. Leaves log standard deviation layer + unfreezed. + """ + for param in self.agent.actor._layers.parameters(): + param.requires_grad = not freeze + + for param in self.agent.actor._mean_layer.parameters(): + param.requires_grad = not freeze + + def _transfer_weights(self): + """Transfers actor layers. + + Transfers only feature encoding and mean computation layers for gaussian network. + """ + for i, layer in enumerate(self.agent.actor._layers): + layer.load_state_dict(self.pretrain_agent.actor._layers[i].state_dict()) + + for j, layer in enumerate(self.agent.actor._mean_layer): + layer.load_state_dict(self.pretrain_agent.actor._layers[i + j + 1].state_dict()) diff --git a/rsl_rl/algorithms/ppo.py b/rsl_rl/algorithms/ppo.py index ffe5814..6a1e9b9 100644 --- a/rsl_rl/algorithms/ppo.py +++ b/rsl_rl/algorithms/ppo.py @@ -1,185 +1,384 @@ -# Copyright 2021 ETH Zurich, NVIDIA CORPORATION -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - import torch -import torch.nn as nn -import torch.optim as optim +from torch import nn, optim +from typing import Any, Dict, Tuple, Type, Union -from rsl_rl.modules import ActorCritic -from rsl_rl.storage import RolloutStorage +from rsl_rl.algorithms.actor_critic import AbstractActorCritic +from rsl_rl.env import VecEnv +from rsl_rl.utils.benchmarkable import Benchmarkable +from rsl_rl.utils.recurrency import trajectories_to_transitions, transitions_to_trajectories +from rsl_rl.modules import GaussianNetwork, Network +from rsl_rl.storage.rollout_storage import RolloutStorage +from rsl_rl.storage.storage import Dataset -class PPO: - actor_critic: ActorCritic +class PPO(AbstractActorCritic): + """Proximal Policy Optimization algorithm. + + This is an implementation of the PPO algorithm by Schulman et. al. for vectorized environments. + + Paper: https://arxiv.org/pdf/1707.06347.pdf + + The implementation works with recurrent neural networks. We implement adaptive learning rate based on the + KL-divergence between the old and new policy, as described by Schulman et. al. in + https://arxiv.org/pdf/1707.06347.pdf. 
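+
+ Example (sketch; assumes env is a VecEnv instance and obs holds a batch of observations):
+ agent = PPO(env, learning_rate=1e-3, schedule=PPO.schedule_adaptive)
+ actions, data = agent.draw_actions(obs, {})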
+ """ + + critic_network: Type[nn.Module] = Network + _alg_features = dict(recurrent=True) + + schedule_adaptive = "adaptive" + schedule_fixed = "fixed" def __init__( self, - actor_critic, - num_learning_epochs=1, - num_mini_batches=1, - clip_param=0.2, - gamma=0.998, - lam=0.95, - value_loss_coef=1.0, - entropy_coef=0.0, - learning_rate=1e-3, - max_grad_norm=1.0, - use_clipped_value_loss=True, - schedule="fixed", - desired_kl=0.01, - device="cpu", + env: VecEnv, + actor_noise_std: float = 1.0, + clip_ratio: float = 0.2, + entropy_coeff: float = 0.0, + gae_lambda: float = 0.97, + gradient_clip: float = 1.0, + learning_rate: float = 1e-3, + schedule: str = "fixed", + target_kl: float = 0.01, + value_coeff: float = 1.0, + **kwargs, ): - self.device = device + """ + Args: + env (VecEnv): A vectorized environment. + actor_noise_std (float): The standard deviation of the Gaussian noise to add to the actor network output. + clip_ratio (float): The clipping ratio for the PPO objective. + entropy_coeff (float): The coefficient for the entropy term in the PPO objective. + gae_lambda (float): The lambda parameter for the GAE computation. + gradient_clip (float): The gradient clipping value. + learning_rate (float): The learning rate for the actor and critic networks. + schedule (str): The learning rate schedule. Can be "fixed" or "adaptive". Defaults to "fixed". + target_kl (float): The target KL-divergence for the adaptive learning rate schedule. + value_coeff (float): The coefficient for the value function loss in the PPO objective. + """ + kwargs["batch_size"] = env.num_envs + kwargs["return_steps"] = 1 - self.desired_kl = desired_kl - self.schedule = schedule - self.learning_rate = learning_rate + super().__init__(env, **kwargs) + self._critic_input_size = self._critic_input_size - self._action_size # We use a state-value function (not Q) - # PPO components - self.actor_critic = actor_critic - self.actor_critic.to(self.device) - self.storage = None # initialized later - self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate) - self.transition = RolloutStorage.Transition() + self.storage = RolloutStorage(self.env.num_envs, device=self.device) - # PPO parameters - self.clip_param = clip_param - self.num_learning_epochs = num_learning_epochs - self.num_mini_batches = num_mini_batches - self.value_loss_coef = value_loss_coef - self.entropy_coef = entropy_coef - self.gamma = gamma - self.lam = lam - self.max_grad_norm = max_grad_norm - self.use_clipped_value_loss = use_clipped_value_loss + self._register_serializable("storage") - def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape): - self.storage = RolloutStorage( - num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, self.device + self._clip_ratio = clip_ratio + self._entropy_coeff = entropy_coeff + self._gae_lambda = gae_lambda + self._gradient_clip = gradient_clip + self._schedule = schedule + self._target_kl = target_kl + self._value_coeff = value_coeff + + self._register_serializable( + "_clip_ratio", + "_entropy_coeff", + "_gae_lambda", + "_gradient_clip", + "_schedule", + "_target_kl", + "_value_coeff", ) - def test_mode(self): - self.actor_critic.test() + self.actor = GaussianNetwork( + self._actor_input_size, self._action_size, std_init=actor_noise_std, **self._actor_network_kwargs + ) + self.critic = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs) - def train_mode(self): - self.actor_critic.train() + if 
self.recurrent: + self.actor.reset_full_hidden_state(batch_size=self.env.num_envs) + self.critic.reset_full_hidden_state(batch_size=self.env.num_envs) - def act(self, obs, critic_obs): - if self.actor_critic.is_recurrent: - self.transition.hidden_states = self.actor_critic.get_hidden_states() - # Compute the actions and values - self.transition.actions = self.actor_critic.act(obs).detach() - self.transition.values = self.actor_critic.evaluate(critic_obs).detach() - self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach() - self.transition.action_mean = self.actor_critic.action_mean.detach() - self.transition.action_sigma = self.actor_critic.action_std.detach() - # need to record obs and critic_obs before env.step() - self.transition.observations = obs - self.transition.critic_observations = critic_obs - return self.transition.actions + tp = lambda v: v.transpose(0, 1) # for storing, transpose (num_layers, batch, F) to (batch, num_layers, F) + self.storage.register_processor("actor_state_h", tp) + self.storage.register_processor("actor_state_c", tp) + self.storage.register_processor("critic_state_h", tp) + self.storage.register_processor("critic_state_c", tp) + self.storage.register_processor("critic_next_state_h", tp) + self.storage.register_processor("critic_next_state_c", tp) - def process_env_step(self, rewards, dones, infos): - self.transition.rewards = rewards.clone() - self.transition.dones = dones - # Bootstrapping on time outs - if "time_outs" in infos: - self.transition.rewards += self.gamma * torch.squeeze( - self.transition.values * infos["time_outs"].unsqueeze(1).to(self.device), 1 - ) + self._bm_fuse(self.actor, prefix="actor.") + self._bm_fuse(self.critic, prefix="critic.") - # Record the transition - self.storage.add_transitions(self.transition) - self.transition.clear() - self.actor_critic.reset(dones) + self._register_serializable("actor", "critic") - def compute_returns(self, last_critic_obs): - last_values = self.actor_critic.evaluate(last_critic_obs).detach() - self.storage.compute_returns(last_values, self.gamma, self.lam) + self.learning_rate = learning_rate + self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) - def update(self): - mean_value_loss = 0 - mean_surrogate_loss = 0 - if self.actor_critic.is_recurrent: - generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs) - else: - generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs) - for ( - obs_batch, - critic_obs_batch, - actions_batch, - target_values_batch, - advantages_batch, - returns_batch, - old_actions_log_prob_batch, - old_mu_batch, - old_sigma_batch, - hid_states_batch, - masks_batch, - ) in generator: - self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0]) - actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch) - value_batch = self.actor_critic.evaluate( - critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1] - ) - mu_batch = self.actor_critic.action_mean - sigma_batch = self.actor_critic.action_std - entropy_batch = self.actor_critic.entropy + self._register_serializable("learning_rate", "optimizer") - # KL - if self.desired_kl is not None and self.schedule == "adaptive": - with torch.inference_mode(): - kl = torch.sum( - torch.log(sigma_batch / old_sigma_batch + 1.0e-5) - + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) - / (2.0 * 
torch.square(sigma_batch)) - - 0.5, - axis=-1, - ) - kl_mean = torch.mean(kl) + def draw_random_actions(self, obs: torch.Tensor, env_info: Dict[str, Any]) -> torch.Tensor: + raise NotImplementedError("PPO does not support drawing random actions.") - if kl_mean > self.desired_kl * 2.0: - self.learning_rate = max(1e-5, self.learning_rate / 1.5) - elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0: - self.learning_rate = min(1e-2, self.learning_rate * 1.5) + @Benchmarkable.register + def draw_actions( + self, obs: torch.Tensor, env_info: Dict[str, Any] + ) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]: + actor_obs, critic_obs = self._process_observations(obs, env_info) - for param_group in self.optimizer.param_groups: - param_group["lr"] = self.learning_rate + data = {} - # Surrogate loss - ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch)) - surrogate = -torch.squeeze(advantages_batch) * ratio - surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp( - ratio, 1.0 - self.clip_param, 1.0 + self.clip_param - ) - surrogate_loss = torch.max(surrogate, surrogate_clipped).mean() + if self.recurrent: + data["actor_state_h"] = self.actor.hidden_state[0].detach() + data["actor_state_c"] = self.actor.hidden_state[1].detach() - # Value function loss - if self.use_clipped_value_loss: - value_clipped = target_values_batch + (value_batch - target_values_batch).clamp( - -self.clip_param, self.clip_param - ) - value_losses = (value_batch - returns_batch).pow(2) - value_losses_clipped = (value_clipped - returns_batch).pow(2) - value_loss = torch.max(value_losses, value_losses_clipped).mean() + mean, std = self.actor.forward(actor_obs, compute_std=True) + action_distribution = torch.distributions.Normal(mean, std) + actions = self._process_actions(action_distribution.rsample()).detach() + action_prediction_logp = action_distribution.log_prob(actions).sum(-1) + + data["actor_observations"] = actor_obs + data["critic_observations"] = critic_obs + data["actions_logp"] = action_prediction_logp.detach() + data["actions_mean"] = action_distribution.mean.detach() + data["actions_std"] = action_distribution.stddev.detach() + + return actions, data + + def eval_mode(self) -> AbstractActorCritic: + super().eval_mode() + + self.actor.eval() + self.critic.eval() + + return self + + @property + def initialized(self) -> bool: + return True + + @Benchmarkable.register + def process_transition(self, *args) -> Dict[str, torch.Tensor]: + transition = super().process_transition(*args) + + if self.recurrent: + transition["critic_state_h"] = self.critic.hidden_state[0].detach() + transition["critic_state_c"] = self.critic.hidden_state[1].detach() + + transition["values"] = self.critic.forward(transition["critic_observations"]).detach() + + if self.recurrent: + transition["critic_next_state_h"] = self.critic.hidden_state[0].detach() + transition["critic_next_state_c"] = self.critic.hidden_state[1].detach() + + return transition + + def parameters(self): + params = list(self.actor.parameters()) + list(self.critic.parameters()) + + return params + + def register_terminations(self, terminations: torch.Tensor) -> None: + """Registers terminations with the agent. + + Args: + terminations (torch.Tensor): A 1-dimensional int tensor containing the indices of the terminated + environments. 
+ """ + if terminations.shape[0] == 0: + return + + if self.recurrent: + self.actor.reset_hidden_state(terminations) + self.critic.reset_hidden_state(terminations) + + def to(self, device: str) -> AbstractActorCritic: + super().to(device) + + self.actor.to(device) + self.critic.to(device) + + return self + + def train_mode(self) -> AbstractActorCritic: + super().train_mode() + + self.actor.train() + self.critic.train() + + return self + + @Benchmarkable.register + def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]: + super().update(dataset) + + assert self.storage.initialized + + total_loss = torch.zeros(self._batch_count) + total_surrogate_loss = torch.zeros(self._batch_count) + total_value_loss = torch.zeros(self._batch_count) + + for idx, batch in enumerate(self.storage.batch_generator(self._batch_count, trajectories=self.recurrent)): + if self.recurrent: + transition_obs = batch["actor_observations"].reshape(*batch["actor_observations"].shape[:2], -1) + observations, data = transitions_to_trajectories(transition_obs, batch["dones"]) + hidden_state_h, _ = transitions_to_trajectories(batch["actor_state_h"], batch["dones"]) + hidden_state_c, _ = transitions_to_trajectories(batch["actor_state_c"], batch["dones"]) + # Init. sequence with each trajectory's first hidden state. Subsequent hidden states are produced by the + # network, depending on the previous hidden state and the current observation. + hidden_state = (hidden_state_h[0].transpose(0, 1), hidden_state_c[0].transpose(0, 1)) + + action_mean, action_std = self.actor.forward(observations, hidden_state=hidden_state, compute_std=True) + + action_mean = action_mean.reshape(*observations.shape[:-1], self._action_size) + action_std = action_std.reshape(*observations.shape[:-1], self._action_size) + + action_mean = trajectories_to_transitions(action_mean, data) + action_std = trajectories_to_transitions(action_std, data) else: - value_loss = (returns_batch - value_batch).pow(2).mean() + action_mean, action_std = self.actor.forward(batch["actor_observations"], compute_std=True) - loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean() + actions_dist = torch.distributions.Normal(action_mean, action_std) + + if self._schedule == self.schedule_adaptive: + self._update_learning_rate(batch, actions_dist) + + surrogate_loss = self._compute_actor_loss(batch, actions_dist) + value_loss = self._compute_value_loss(batch) + actions_entropy = actions_dist.entropy().sum(-1) + + loss = surrogate_loss + self._value_coeff * value_loss - self._entropy_coeff * actions_entropy.mean() - # Gradient step self.optimizer.zero_grad() loss.backward() - nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm) + nn.utils.clip_grad_norm_(self.parameters(), self._gradient_clip) self.optimizer.step() - mean_value_loss += value_loss.item() - mean_surrogate_loss += surrogate_loss.item() + total_loss[idx] = loss.detach() + total_surrogate_loss[idx] = surrogate_loss.detach() + total_value_loss[idx] = value_loss.detach() - num_updates = self.num_learning_epochs * self.num_mini_batches - mean_value_loss /= num_updates - mean_surrogate_loss /= num_updates - self.storage.clear() + stats = { + "total": total_loss.mean().item(), + "surrogate": total_surrogate_loss.mean().item(), + "value": total_value_loss.mean().item(), + } - return mean_value_loss, mean_surrogate_loss + return stats + + @Benchmarkable.register + def _compute_actor_loss( + self, batch: Dict[str, torch.Tensor], actions_dist: 
torch.distributions.Normal + ) -> torch.Tensor: + batch_actions_logp = actions_dist.log_prob(batch["actions"]).sum(-1) + + ratio = (batch_actions_logp - batch["actions_logp"]).exp() + surrogate = batch["normalized_advantages"] * ratio + surrogate_clipped = batch["normalized_advantages"] * ratio.clamp(1.0 - self._clip_ratio, 1.0 + self._clip_ratio) + surrogate_loss = -torch.min(surrogate, surrogate_clipped).mean() + + return surrogate_loss + + @Benchmarkable.register + def _compute_value_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + if self.recurrent: + observations, data = transitions_to_trajectories(batch["critic_observations"], batch["dones"]) + hidden_state_h, _ = transitions_to_trajectories(batch["critic_state_h"], batch["dones"]) + hidden_state_c, _ = transitions_to_trajectories(batch["critic_state_c"], batch["dones"]) + hidden_states = (hidden_state_h[0].transpose(0, 1), hidden_state_c[0].transpose(0, 1)) + + trajectory_evaluations = self.critic.forward(observations, hidden_state=hidden_states) + trajectory_evaluations = trajectory_evaluations.reshape(*observations.shape[:-1]) + + evaluation = trajectories_to_transitions(trajectory_evaluations, data) + else: + evaluation = self.critic.forward(batch["critic_observations"]) + + value_clipped = batch["values"] + (evaluation - batch["values"]).clamp(-self._clip_ratio, self._clip_ratio) + returns = batch["advantages"] + batch["values"] + value_losses = (evaluation - returns).pow(2) + value_losses_clipped = (value_clipped - returns).pow(2) + + value_loss = torch.max(value_losses, value_losses_clipped).mean() + + return value_loss + + def _critic_input(self, observations, actions=None) -> torch.Tensor: + return observations + + def _entry_to_hs( + self, entry: Dict[str, torch.Tensor], critic: bool = False, next: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Helper function to turn a dataset entry into a hidden state tuple. + + Args: + entry (Dict[str, torch.Tensor]): The dataset entry. + critic (bool): Whether to extract the hidden state for the critic instead of the actor. Defaults to False. + next (bool): Whether the hidden state is for the next step or the current. Defaults to False. + Returns: + A tuple of hidden state tensors. + """ + key = ("critic" if critic else "actor") + "_" + ("next_state" if next else "state") + hidden_state = entry[f"{key}_h"], entry[f"{key}_c"] + + return hidden_state + + @Benchmarkable.register + def _process_dataset(self, dataset: Dataset) -> Dataset: + """Processes a dataset before it is added to the replay memory. + + Computes advantages and returns. + + Args: + dataset (Dataset): The dataset to process. + Returns: + A Dataset object containing the processed data. + """ + rewards = torch.stack([entry["rewards"] for entry in dataset]) + dones = torch.stack([entry["dones"] for entry in dataset]) + timeouts = torch.stack([entry["timeouts"] for entry in dataset]) + values = torch.stack([entry["values"] for entry in dataset]) + + # We could alternatively compute the next hidden state from the current state and hidden state. But this + # (storing the hidden state when evaluating the action in process_transition) is computationally more efficient + # and doesn't change the result as the network is not updated between storing the data and computing advantages. 
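+        # The block below implements generalized advantage estimation (GAE): it bootstraps the value of the final
+        # transition, computes the TD errors delta_t = r_t + gamma * (1 - d_t) * V(s_{t+1}) - V(s_t), and accumulates
+        # them backwards via A_t = delta_t + gamma * lambda * (1 - d_t) * A_{t+1}. Timed-out steps are additionally
+        # bootstrapped by adding gamma * V(s_t) to the reward.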
+ + critic_kwargs = ( + {"hidden_state": (dataset[-1]["critic_state_h"], dataset[-1]["critic_state_c"])} if self.recurrent else {} + ) + final_values = self.critic.forward(dataset[-1]["next_critic_observations"], **critic_kwargs) + next_values = torch.cat((values[1:], final_values.unsqueeze(0)), dim=0) + + rewards += self.gamma * timeouts * values + deltas = (rewards + (1 - dones).float() * self.gamma * next_values - values).reshape(-1, self.env.num_envs) + + advantages = torch.zeros((len(dataset) + 1, self.env.num_envs), device=self.device) + for step in reversed(range(len(dataset))): + advantages[step] = ( + deltas[step] + (1 - dones[step]).float() * self.gamma * self._gae_lambda * advantages[step + 1] + ) + advantages = advantages[:-1] + + amean, astd = advantages.mean(), torch.nan_to_num(advantages.std()) + for step in range(len(dataset)): + dataset[step]["advantages"] = advantages[step] + dataset[step]["normalized_advantages"] = (advantages[step] - amean) / (astd + 1e-8) + + return dataset + + @Benchmarkable.register + def _update_learning_rate(self, batch: Dict[str, torch.Tensor], actions_dist: torch.distributions.Normal) -> None: + with torch.inference_mode(): + actions_mean = actions_dist.mean + actions_std = actions_dist.stddev + + kl = torch.sum( + torch.log(actions_std / batch["actions_std"] + 1.0e-5) + + (torch.square(batch["actions_std"]) + torch.square(batch["actions_mean"] - actions_mean)) + / (2.0 * torch.square(actions_std)) + - 0.5, + axis=-1, + ) + kl_mean = torch.mean(kl) + + if kl_mean > self._target_kl * 2.0: + self.learning_rate = max(1e-5, self.learning_rate / 1.5) + elif kl_mean < self._target_kl / 2.0 and kl_mean > 0.0: + self.learning_rate = min(1e-2, self.learning_rate * 1.5) + + for param_group in self.optimizer.param_groups: + param_group["lr"] = self.learning_rate diff --git a/rsl_rl/algorithms/sac.py b/rsl_rl/algorithms/sac.py new file mode 100644 index 0000000..2f86f94 --- /dev/null +++ b/rsl_rl/algorithms/sac.py @@ -0,0 +1,319 @@ +import numpy as np +import torch +from torch import nn, optim +from typing import Any, Callable, Dict, Tuple, Type, Union + +from rsl_rl.algorithms.actor_critic import AbstractActorCritic +from rsl_rl.env import VecEnv +from rsl_rl.modules import Network, GaussianChimeraNetwork, GaussianNetwork +from rsl_rl.storage.replay_storage import ReplayStorage +from rsl_rl.storage.storage import Dataset + + +class SAC(AbstractActorCritic): + """Soft Actor Critic algorithm. + + This is an implementation of the SAC algorithm by Haarnoja et. al. for vectorized environments. + + Paper: https://arxiv.org/pdf/1801.01290.pdf + + We additionally implement automatic tuning of the temperature parameter (alpha) and tanh action scaling, as + introduced by Haarnoja et. al. in https://arxiv.org/pdf/1812.05905.pdf. + """ + + critic_network: Type[nn.Module] = Network + + def __init__( + self, + env: VecEnv, + action_max: float = 100.0, + action_min: float = -100.0, + actor_lr: float = 1e-4, + actor_noise_std: float = 1.0, + alpha: float = 0.2, + alpha_lr: float = 1e-3, + chimera: bool = True, + critic_lr: float = 1e-3, + gradient_clip: float = 1.0, + log_std_max: float = 4.0, + log_std_min: float = -20.0, + storage_initial_size: int = 0, + storage_size: int = 100000, + target_entropy: float = None, + **kwargs + ): + """ + Args: + env (VecEnv): A vectorized environment. + actor_lr (float): Learning rate for the actor. + alpha (float): Initial entropy regularization coefficient. + alpha_lr (float): Learning rate for entropy regularization coefficient. 
+            chimera (bool): Whether to use separate heads for computing action mean and std (True) or treat the std as a
+                tunable parameter (False).
+            critic_lr (float): Learning rate for the critic.
+            gradient_clip (float): Gradient clip value.
+            log_std_max (float): Maximum log standard deviation.
+            log_std_min (float): Minimum log standard deviation.
+            storage_initial_size (int): Initial size of the replay storage.
+            storage_size (int): Maximum size of the replay storage.
+            target_entropy (float): Target entropy for the actor policy. Defaults to the negative action space dimensionality.
+        """
+        super().__init__(env, action_max=action_max, action_min=action_min, **kwargs)
+
+        self.storage = ReplayStorage(
+            self.env.num_envs, storage_size, device=self.device, initial_size=storage_initial_size
+        )
+
+        self._register_serializable("storage")
+
+        assert self._action_max < np.inf, 'Parameter "action_max" needs to be set for SAC.'
+        assert self._action_min > -np.inf, 'Parameter "action_min" needs to be set for SAC.'
+
+        self._action_delta = 0.5 * (self._action_max - self._action_min)
+        self._action_offset = 0.5 * (self._action_max + self._action_min)
+
+        self.log_alpha = torch.tensor(np.log(alpha), dtype=torch.float32).requires_grad_()
+        self._gradient_clip = gradient_clip
+        self._target_entropy = target_entropy if target_entropy else -self._action_size
+
+        self._register_serializable("log_alpha", "_gradient_clip")
+
+        network_class = GaussianChimeraNetwork if chimera else GaussianNetwork
+        self.actor = network_class(
+            self._actor_input_size,
+            self._action_size,
+            log_std_max=log_std_max,
+            log_std_min=log_std_min,
+            std_init=actor_noise_std,
+            **self._actor_network_kwargs
+        )
+
+        self.critic_1 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
+        self.critic_2 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
+
+        self.target_critic_1 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
+        self.target_critic_2 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
+        self.target_critic_1.load_state_dict(self.critic_1.state_dict())
+        self.target_critic_2.load_state_dict(self.critic_2.state_dict())
+
+        self._register_serializable("actor", "critic_1", "critic_2", "target_critic_1", "target_critic_2")
+
+        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
+        self.log_alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)
+        self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(), lr=critic_lr)
+        self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=critic_lr)
+
+        self._register_serializable(
+            "actor_optimizer", "log_alpha_optimizer", "critic_1_optimizer", "critic_2_optimizer"
+        )
+
+        self.critic = self.critic_1
+
+    @property
+    def alpha(self):
+        return self.log_alpha.exp()
+
+    def draw_actions(
+        self, obs: torch.Tensor, env_info: Dict[str, Any]
+    ) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
+        actor_obs, critic_obs = self._process_observations(obs, env_info)
+
+        action = self._sample_action(actor_obs, compute_logp=False)
+        data = {"actor_observations": actor_obs.clone(), "critic_observations": critic_obs.clone()}
+
+        return action, data
+
+    def eval_mode(self) -> AbstractActorCritic:
+        super().eval_mode()
+
+        self.actor.eval()
+        self.critic_1.eval()
+        self.critic_2.eval()
+        self.target_critic_1.eval()
+        self.target_critic_2.eval()
+
+        return self
+
+    def get_inference_policy(self, device=None) -> Callable:
+        self.to(device)
+        self.eval_mode()
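+
+        # Note: the policy returned below is deterministic at inference time. It takes the mean of the actor's
+        # output and applies the same tanh scaling as _scale_actions, rather than sampling from the distribution.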
+
+        def policy(obs, env_info=None):
+            obs, _ = self._process_observations(obs, env_info)
+            actions = self._scale_actions(self.actor.forward(obs))
+
+            # actions, _ = self.draw_actions(obs, env_info)
+
+            return actions
+
+        return policy
+
+    def register_terminations(self, terminations: torch.Tensor) -> None:
+        pass
+
+    def to(self, device: str) -> AbstractActorCritic:
+        """Transfers agent parameters to device."""
+        super().to(device)
+
+        self.actor.to(device)
+        self.critic_1.to(device)
+        self.critic_2.to(device)
+        self.target_critic_1.to(device)
+        self.target_critic_2.to(device)
+        self.log_alpha.to(device)
+
+        return self
+
+    def train_mode(self) -> AbstractActorCritic:
+        super().train_mode()
+
+        self.actor.train()
+        self.critic_1.train()
+        self.critic_2.train()
+        self.target_critic_1.train()
+        self.target_critic_2.train()
+
+        return self
+
+    def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
+        super().update(dataset)
+
+        if not self.initialized:
+            return {}
+
+        total_actor_loss = torch.zeros(self._batch_count)
+        total_alpha_loss = torch.zeros(self._batch_count)
+        total_critic_1_loss = torch.zeros(self._batch_count)
+        total_critic_2_loss = torch.zeros(self._batch_count)
+
+        for idx, batch in enumerate(self.storage.batch_generator(self._batch_size, self._batch_count)):
+            actor_obs = batch["actor_observations"]
+            critic_obs = batch["critic_observations"]
+            actions = batch["actions"].reshape(-1, self._action_size)
+            rewards = batch["rewards"]
+            actor_next_obs = batch["next_actor_observations"]
+            critic_next_obs = batch["next_critic_observations"]
+            dones = batch["dones"]
+
+            critic_1_loss, critic_2_loss = self._update_critic(
+                critic_obs, actions, rewards, dones, actor_next_obs, critic_next_obs
+            )
+            actor_loss, alpha_loss = self._update_actor_and_alpha(actor_obs, critic_obs)
+
+            # Update Target Networks
+
+            self._update_target(self.critic_1, self.target_critic_1)
+            self._update_target(self.critic_2, self.target_critic_2)
+
+            total_actor_loss[idx] = actor_loss.item()
+            total_alpha_loss[idx] = alpha_loss.item()
+            total_critic_1_loss[idx] = critic_1_loss.item()
+            total_critic_2_loss[idx] = critic_2_loss.item()
+
+        stats = {
+            "actor": total_actor_loss.mean().item(),
+            "alpha": total_alpha_loss.mean().item(),
+            "critic1": total_critic_1_loss.mean().item(),
+            "critic2": total_critic_2_loss.mean().item(),
+        }
+
+        return stats
+
+    def _sample_action(
+        self, observation: torch.Tensor, compute_logp=True
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, float]]:
+        """Samples an action from the policy.
+
+        Args:
+            observation (torch.Tensor): The observation to sample an action for.
+            compute_logp (bool): Whether to compute and return the action log probability. Defaults to True.
+        Returns:
+            Either the action as a torch.Tensor or, if compute_logp is set to True, a tuple containing the actions as a
+            torch.Tensor and the action log probability.
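+
+            Note that the log probability includes the tanh change-of-variables correction,
+            -log(1 - tanh(a)^2 + 1e-6) summed over action dimensions, to account for the action squashing.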
+ """ + mean, std = self.actor.forward(observation, compute_std=True) + dist = torch.distributions.Normal(mean, std) + + actions = dist.rsample() + actions_normalized, actions_scaled = self._scale_actions(actions, intermediate=True) + + if not compute_logp: + return actions_scaled + + action_logp = dist.log_prob(actions).sum(-1) - torch.log(1.0 - actions_normalized.pow(2) + 1e-6).sum(-1) + + return actions_scaled, action_logp + + def _scale_actions(self, actions: torch.Tensor, intermediate=False) -> torch.Tensor: + actions = actions.reshape(-1, self._action_size) + action_normalized = torch.tanh(actions) + action_scaled = super()._process_actions(action_normalized * self._action_delta + self._action_offset) + + if intermediate: + return action_normalized, action_scaled + + return action_scaled + + def _update_actor_and_alpha( + self, actor_obs: torch.Tensor, critic_obs: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + actor_prediction, actor_prediction_logp = self._sample_action(actor_obs) + + # Update alpha (also called temperature / entropy coefficient) + alpha_loss = -(self.log_alpha * (actor_prediction_logp + self._target_entropy).detach()).mean() + + self.log_alpha_optimizer.zero_grad() + alpha_loss.backward() + self.log_alpha_optimizer.step() + + # Update actor + evaluation_input = self._critic_input(critic_obs, actor_prediction) + evaluation_1 = self.critic_1.forward(evaluation_input) + evaluation_2 = self.critic_2.forward(evaluation_input) + actor_loss = (self.alpha.detach() * actor_prediction_logp - torch.min(evaluation_1, evaluation_2)).mean() + + self.actor_optimizer.zero_grad() + actor_loss.backward() + nn.utils.clip_grad_norm_(self.actor.parameters(), self._gradient_clip) + self.actor_optimizer.step() + + return actor_loss, alpha_loss + + def _update_critic( + self, + critic_obs: torch.Tensor, + actions: torch.Tensor, + rewards: torch.Tensor, + dones: torch.Tensor, + actor_next_obs: torch.Tensor, + critic_next_obs: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + with torch.no_grad(): + target_action, target_action_logp = self._sample_action(actor_next_obs) + target_critic_input = self._critic_input(critic_next_obs, target_action) + target_critic_prediction_1 = self.target_critic_1.forward(target_critic_input) + target_critic_prediction_2 = self.target_critic_2.forward(target_critic_input) + + target_next = ( + torch.min(target_critic_prediction_1, target_critic_prediction_2) - self.alpha * target_action_logp + ) + target = rewards + self._discount_factor * (1 - dones) * target_next + + critic_input = self._critic_input(critic_obs, actions) + critic_prediction_1 = self.critic_1.forward(critic_input) + critic_1_loss = nn.functional.mse_loss(critic_prediction_1, target) + + self.critic_1_optimizer.zero_grad() + critic_1_loss.backward() + nn.utils.clip_grad_norm_(self.critic_1.parameters(), self._gradient_clip) + self.critic_1_optimizer.step() + + critic_prediction_2 = self.critic_2.forward(critic_input) + critic_2_loss = nn.functional.mse_loss(critic_prediction_2, target) + + self.critic_2_optimizer.zero_grad() + critic_2_loss.backward() + nn.utils.clip_grad_norm_(self.critic_2.parameters(), self._gradient_clip) + self.critic_2_optimizer.step() + + return critic_1_loss, critic_2_loss diff --git a/rsl_rl/algorithms/td3.py b/rsl_rl/algorithms/td3.py new file mode 100644 index 0000000..7e5850f --- /dev/null +++ b/rsl_rl/algorithms/td3.py @@ -0,0 +1,198 @@ +from __future__ import annotations +import torch +from torch import nn, optim +from typing import Dict, Type, 
Union + +from rsl_rl.algorithms.dpg import AbstractDPG +from rsl_rl.env import VecEnv +from rsl_rl.modules.network import Network +from rsl_rl.storage.storage import Dataset + + +class TD3(AbstractDPG): + """Twin-Delayed Deep Deterministic Policy Gradients algorithm. + + This is an implementation of the TD3 algorithm by Fujimoto et. al. for vectorized environments. + + Paper: https://arxiv.org/pdf/1802.09477.pdf + """ + + critic_network: Type[nn.Module] = Network + + def __init__( + self, + env: VecEnv, + actor_lr: float = 1e-4, + critic_lr: float = 1e-3, + noise_clip: float = 0.5, + policy_delay: int = 2, + target_noise_scale: float = 0.2, + **kwargs, + ) -> None: + super().__init__(env, **kwargs) + + self._noise_clip = noise_clip + self._policy_delay = policy_delay + self._target_noise_scale = target_noise_scale + + self._register_serializable("_noise_clip", "_policy_delay", "_target_noise_scale") + + self.actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs) + self.critic_1 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs) + self.critic_2 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs) + + self.target_actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs) + self.target_critic_1 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs) + self.target_critic_2 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs) + self.target_actor.load_state_dict(self.actor.state_dict()) + self.target_critic_1.load_state_dict(self.critic_1.state_dict()) + self.target_critic_2.load_state_dict(self.critic_2.state_dict()) + + self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr) + self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(), lr=critic_lr) + self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=critic_lr) + + self._update_step = 0 + + self._register_serializable( + "actor", + "critic_1", + "critic_2", + "target_actor", + "target_critic_1", + "target_critic_2", + "actor_optimizer", + "critic_1_optimizer", + "critic_2_optimizer", + "_update_step", + ) + + self.critic = self.critic_1 + self.to(self.device) + + def eval_mode(self) -> TD3: + super().eval_mode() + + self.actor.eval() + self.critic_1.eval() + self.critic_2.eval() + self.target_actor.eval() + self.target_critic_1.eval() + self.target_critic_2.eval() + + return self + + def to(self, device: str) -> TD3: + """Transfers agent parameters to device.""" + super().to(device) + + self.actor.to(device) + self.critic_1.to(device) + self.critic_2.to(device) + self.target_actor.to(device) + self.target_critic_1.to(device) + self.target_critic_2.to(device) + + return self + + def train_mode(self) -> TD3: + super().train_mode() + + self.actor.train() + self.critic_1.train() + self.critic_2.train() + self.target_actor.train() + self.target_critic_1.train() + self.target_critic_2.train() + + return self + + def _apply_action_noise(self, actions: torch.Tensor, clip=False) -> torch.Tensor: + noise = torch.normal(torch.zeros_like(actions), torch.ones_like(actions) * self._action_noise_scale) + + if clip: + noise = noise.clamp(-self._noise_clip, self._noise_clip) + + noisy_actions = self._process_actions(actions + noise) + + return noisy_actions + + def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]: + super().update(dataset) + + if not self.initialized: + return {} + + total_actor_loss = 
torch.zeros(self._batch_count) + total_critic_1_loss = torch.zeros(self._batch_count) + total_critic_2_loss = torch.zeros(self._batch_count) + + for idx, batch in enumerate(self.storage.batch_generator(self._batch_size, self._batch_count)): + actor_obs = batch["actor_observations"] + critic_obs = batch["critic_observations"] + actions = batch["actions"].reshape(self._batch_size, -1) + rewards = batch["rewards"] + actor_next_obs = batch["next_actor_observations"] + critic_next_obs = batch["next_critic_observations"] + dones = batch["dones"] + + critic_1_loss, critic_2_loss = self._update_critic( + critic_obs, actions, rewards, dones, actor_next_obs, critic_next_obs + ) + + if self._update_step % self._policy_delay == 0: + evaluation = self.critic_1.forward( + self._critic_input(critic_obs, self._process_actions(self.actor.forward(actor_obs))) + ) + actor_loss = -evaluation.mean() + + self.actor_optimizer.zero_grad() + actor_loss.backward() + self.actor_optimizer.step() + + self._update_target(self.actor, self.target_actor) + self._update_target(self.critic_1, self.target_critic_1) + self._update_target(self.critic_2, self.target_critic_2) + + total_actor_loss[idx] = actor_loss.item() + + self._update_step = self._update_step + 1 + + total_critic_1_loss[idx] = critic_1_loss.item() + total_critic_2_loss[idx] = critic_2_loss.item() + + stats = { + "actor": total_actor_loss.mean().item(), + "critic1": total_critic_1_loss.mean().item(), + "critic2": total_critic_2_loss.mean().item(), + } + + return stats + + def _update_critic(self, critic_obs, actions, rewards, dones, actor_next_obs, critic_next_obs): + target_actor_prediction = self._apply_action_noise(self.target_actor.forward(actor_next_obs), clip=True) + target_critic_1_prediction = self.target_critic_1.forward( + self._critic_input(critic_next_obs, target_actor_prediction) + ) + target_critic_2_prediction = self.target_critic_2.forward( + self._critic_input(critic_next_obs, target_actor_prediction) + ) + target_critic_prediction = torch.min(target_critic_1_prediction, target_critic_2_prediction) + + target = (rewards + self._discount_factor * (1 - dones) * target_critic_prediction).detach() + + prediction_1 = self.critic_1.forward(self._critic_input(critic_obs, actions)) + critic_1_loss = (prediction_1 - target).pow(2).mean() + + self.critic_1_optimizer.zero_grad() + critic_1_loss.backward() + self.critic_1_optimizer.step() + + prediction_2 = self.critic_2.forward(self._critic_input(critic_obs, actions)) + critic_2_loss = (prediction_2 - target).pow(2).mean() + + self.critic_2_optimizer.zero_grad() + critic_2_loss.backward() + self.critic_2_optimizer.step() + + return critic_1_loss, critic_2_loss diff --git a/rsl_rl/distributions/__init__.py b/rsl_rl/distributions/__init__.py new file mode 100644 index 0000000..ac65504 --- /dev/null +++ b/rsl_rl/distributions/__init__.py @@ -0,0 +1,2 @@ +from .distribution import Distribution +from .quantile_distribution import QuantileDistribution diff --git a/rsl_rl/distributions/distribution.py b/rsl_rl/distributions/distribution.py new file mode 100644 index 0000000..01f9184 --- /dev/null +++ b/rsl_rl/distributions/distribution.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod +import torch + + +class Distribution(ABC): + def __init__(self, params: torch.Tensor) -> None: + self._params = params + + @abstractmethod + def sample(self, sample_count: int = 1) -> torch.Tensor: + """Sample from the distribution. + + Args: + sample_count: The number of samples to draw. 
+ Returns: + A tensor of shape (sample_count, *parameter_shape). + """ + pass diff --git a/rsl_rl/distributions/quantile_distribution.py b/rsl_rl/distributions/quantile_distribution.py new file mode 100644 index 0000000..31aee23 --- /dev/null +++ b/rsl_rl/distributions/quantile_distribution.py @@ -0,0 +1,13 @@ +import torch + +from .distribution import Distribution + + +class QuantileDistribution(Distribution): + def sample(self, sample_count: int = 1) -> torch.Tensor: + idx = torch.randint( + self._params.shape[-1], (*self._params.shape[:-1], sample_count), device=self._params.device + ) + samples = torch.take_along_dim(self._params, idx, -1) + + return samples, idx diff --git a/rsl_rl/env/__init__.py b/rsl_rl/env/__init__.py index 54c6491..5c26f6e 100644 --- a/rsl_rl/env/__init__.py +++ b/rsl_rl/env/__init__.py @@ -1,6 +1,5 @@ # Copyright 2021 ETH Zurich, NVIDIA CORPORATION # SPDX-License-Identifier: BSD-3-Clause - """Submodule defining the environment definitions.""" from .vec_env import VecEnv diff --git a/rsl_rl/env/gym_env.py b/rsl_rl/env/gym_env.py new file mode 100644 index 0000000..289c2c7 --- /dev/null +++ b/rsl_rl/env/gym_env.py @@ -0,0 +1,120 @@ +from datetime import datetime +import gym +import torch +from typing import Any, Dict, Tuple, Union + +from rsl_rl.env.vec_env import VecEnv + + +class GymEnv(VecEnv): + """A vectorized environment wrapper for OpenAI Gym environments. + + This class wraps a single OpenAI Gym environment into a vectorized environment. It is assumed that the environment + is a single agent environment. The environment is wrapped in a `gym.vector.SyncVectorEnv` environment, which + allows for parallel execution of multiple environments. + """ + + def __init__(self, name, draw=False, draw_cb=None, draw_directory="videos/", gym_kwargs={}, **kwargs): + """ + Args: + name: The name of the OpenAI Gym environment. + draw: Whether to record videos of the environment. + draw_cb: A callback function that is called after each episode. The callback function is passed the episode + number and the path to the video file. The callback function should return `True` if the video should + be recorded and `False` otherwise. + draw_directory: The directory in which to store the videos. + gym_kwargs: Keyword arguments that are passed to the OpenAI Gym environment. + **kwargs: Keyword arguments that are passed to the `VecEnv` constructor. 
+ """ + self._gym_kwargs = gym_kwargs + + env = gym.make(name, **self._gym_kwargs) + + assert isinstance(env.observation_space, gym.spaces.Box) + assert len(env.observation_space.shape) == 1 + assert isinstance(env.action_space, gym.spaces.Box) + assert len(env.action_space.shape) == 1 + + super().__init__(env.observation_space.shape[0], env.observation_space.shape[0], **kwargs) + + self.name = name + self.draw_directory = draw_directory + + self.num_actions = env.action_space.shape[0] + self._gym_venv = gym.vector.SyncVectorEnv( + [lambda: gym.make(self.name, **self._gym_kwargs) for _ in range(self.num_envs)] + ) + + self._draw = False + self._draw_cb = draw_cb if draw_cb is not None else lambda *args: True + self._draw_uuid = None + self.draw = draw + + self.reset() + + def close(self) -> None: + self._gym_venv.close() + + def get_observations(self) -> Tuple[torch.Tensor, Dict[str, Any]]: + return self.obs_buf, self.extras + + def get_privileged_observations(self) -> Union[torch.Tensor, None]: + return self.obs_buf, self.extras + + @property + def draw(self) -> bool: + return self._draw + + @draw.setter + def draw(self, value: bool) -> None: + if value != self._draw: + if value: + self._draw_uuid = datetime.now().strftime("%Y%m%d%H%M%S") + env = gym.make(self.name, render_mode="rgb_array", **self._gym_kwargs) + env = gym.wrappers.RecordVideo( + env, + f"{self.draw_directory}/{self._draw_uuid}/", + episode_trigger=lambda ep: ( + self._draw_cb(ep - 1, f"{self.draw_directory}/{self._draw_uuid}/rl-video-episode-{ep-1}.mp4") + or True + ) + if ep > 0 + else False, + ) + else: + env = gym.make(self.name, render_mode=None, **self._gym_kwargs) + + self._gym_venv.envs[0] = env + self._draw = value + + self.reset() + + def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: + self.obs_buf = torch.from_numpy(self._gym_venv.reset()[0]).float().to(self.device) + self.rew_buf = torch.zeros((self.num_envs,), device=self.device).float() + self.reset_buf = torch.zeros((self.num_envs,), device=self.device).float() + self.extras = {"observations": {}, "time_outs": torch.zeros((self.num_envs,), device=self.device).float()} + + return self.obs_buf, self.extras + + def step( + self, actions: torch.Tensor + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, Dict[str, Any]]: + obs, rew, reset, term, _ = self._gym_venv.step(actions.cpu().numpy()) + + self.obs_buf = torch.from_numpy(obs).float().to(self.device) + self.rew_buf = torch.from_numpy(rew).float().to(self.device) + self.reset_buf = torch.from_numpy(reset).float().to(self.device) + self.extras = { + "observations": {}, + "time_outs": torch.from_numpy(term).float().to(self.device).float().to(self.device), + } + + return self.obs_buf, self.rew_buf, self.reset_buf, self.extras + + def to(self, device: str) -> None: + self.device = device + + self.obs_buf = self.obs_buf.to(device) + self.rew_buf = self.rew_buf.to(device) + self.reset_buf = self.reset_buf.to(device) diff --git a/rsl_rl/env/pole_balancing.py b/rsl_rl/env/pole_balancing.py new file mode 100644 index 0000000..28c6dc0 --- /dev/null +++ b/rsl_rl/env/pole_balancing.py @@ -0,0 +1,137 @@ +import math +import numpy as np +import time +import matplotlib.pyplot as plt +import torch +from typing import Any, Dict, Tuple, Union + +from rsl_rl.env.vec_env import VecEnv + + +class PoleBalancing(VecEnv): + """Custom pole balancing environment. + + This class implements a custom pole balancing environment. It demonstrates how to implement a custom `VecEnv` + environment. 
+ """ + + def __init__(self, **kwargs): + """ + Args: + **kwargs: Keyword arguments that are passed to the `VecEnv` constructor. + """ + super().__init__(2, 2, **kwargs) + + self.num_actions = 1 + + self.gravity = 9.8 + self.length = 2.0 + self.dt = 0.1 + + # Angle at which to fail the episode (15 deg) + self.theta_threshold_radians = 15 * 2 * math.pi / 360 + + # Max. angle at which to initialize the episode (5 deg) + self.initial_max_position = 2 * 2 * math.pi / 360 + # Max. angular velocity at which to initialize the episode (1 deg/s) + self.initial_max_velocity = 0.3 * 2 * math.pi / 360 + + self.initial_position_factor = self.initial_max_position / 0.5 + self.initial_velocity_factor = self.initial_max_velocity / 0.5 + + self.draw = False + self.pushes = False + + self.reset() + + def get_observations(self) -> Tuple[torch.Tensor, Dict[str, Any]]: + return self.obs_buf, self.extras + + def get_privileged_observations(self) -> Union[torch.Tensor, None]: + return self.obs_buf, self.extras + + def step( + self, actions: torch.Tensor + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, Dict[str, Any]]: + assert actions.size() == (self.num_envs, 1) + + self.to(self.device) + actions = actions.to(self.device) + + noise = torch.normal(torch.zeros_like(actions), torch.ones_like(actions) * 0.005).squeeze() + if self.pushes and np.random.rand() < 0.05: + noise *= 100.0 + actions = actions.clamp(min=-0.2, max=0.2).squeeze() + gravity = torch.sin(self.state[:, 0]) * self.gravity / self.length + angular_acceleration = gravity + actions + noise + + self.state[:, 1] = self.state[:, 1] + self.dt * angular_acceleration + self.state[:, 0] = self.state[:, 0] + self.dt * self.state[:, 1] + + self.reset_buf = torch.zeros(self.num_envs) + self.reset_buf[(torch.abs(self.state[:, 0]) > self.theta_threshold_radians).nonzero()] = 1.0 + reset_idx = self.reset_buf.nonzero() + + self.state[reset_idx, 0] = ( + torch.rand(reset_idx.size()[0], 1, device=self.device) - 0.5 + ) * self.initial_position_factor + self.state[reset_idx, 1] = ( + torch.rand(reset_idx.size()[0], 1, device=self.device) - 0.5 + ) * self.initial_velocity_factor + + self.rew_buf = torch.ones(self.num_envs, device=self.device) + self.rew_buf[reset_idx] = -1.0 + self.rew_buf = self.rew_buf - actions.abs() + self.rew_buf = self.rew_buf - self.state[:, 0].abs() + + self._update_obs() + + if self.draw: + self._debug_draw(actions) + + self.to(self.device) + + return self.obs_buf, self.rew_buf, self.reset_buf, self.extras + + def _update_obs(self): + self.obs_buf = self.state + self.extras = {"observations": {}, "time_outs": torch.zeros_like(self.rew_buf)} + + def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: + self.state = torch.zeros(self.num_envs, 2, device=self.device) + self.state[:, 0] = (torch.rand(self.num_envs) - 0.5) * self.initial_position_factor + self.state[:, 1] = (torch.rand(self.num_envs) - 0.5) * self.initial_velocity_factor + self.rew_buf = torch.zeros(self.num_envs) + self.reset_buf = torch.zeros(self.num_envs) + self.extras = {} + + self._update_obs() + + return self.obs_buf, self.extras + + def to(self, device): + self.device = device + + self.obs_buf = self.obs_buf.to(device) + self.rew_buf = self.rew_buf.to(device) + self.reset_buf = self.reset_buf.to(device) + self.state = self.state.to(device) + + def _debug_draw(self, actions): + if not hasattr(self, "_visuals"): + self._visuals = {"x": [0], "pos": [], "act": [], "done": []} + plt.gca().figure.show() + else: + 
self._visuals["x"].append(self._visuals["x"][-1] + 1) + + self._visuals["pos"].append(self.obs_buf[0, 0].cpu().item()) + self._visuals["done"].append(self.reset_buf[0].cpu().item()) + self._visuals["act"].append(actions.squeeze()[0].cpu().item()) + + plt.cla() + plt.plot(self._visuals["x"][-100:], self._visuals["act"][-100:], color="green") + plt.plot(self._visuals["x"][-100:], self._visuals["pos"][-100:], color="blue") + plt.plot(self._visuals["x"][-100:], self._visuals["done"][-100:], color="red") + plt.draw() + plt.gca().figure.canvas.flush_events() + time.sleep(0.0001) diff --git a/rsl_rl/env/pomdp.py b/rsl_rl/env/pomdp.py new file mode 100644 index 0000000..750a464 --- /dev/null +++ b/rsl_rl/env/pomdp.py @@ -0,0 +1,97 @@ +import torch +from typing import Any, Dict, Tuple, Union + +from rsl_rl.env.gym_env import GymEnv + + +class GymPOMDP(GymEnv): + """A vectorized POMDP environment wrapper for OpenAI Gym environments. + + This environment allows for the modification of the observation space of an OpenAI Gym environment. The modified + observation space is a subset of the original observation space. + """ + + _reduced_observation_count: int = None + + def __init__(self, name: str, **kwargs): + assert self._reduced_observation_count is not None + + super().__init__(name=name, **kwargs) + + self.num_obs = self._reduced_observation_count + self.num_privileged_obs = self._reduced_observation_count + + def _process_observations(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Reduces observation space from original observation space to modified observation space. + + Args: + obs (torch.Tensor): Original observations. + Returns: + The modified observations as a torch.Tensor of shape (obs.shape[0], self.num_obs). + """ + raise NotImplementedError + + def reset(self, *args, **kwargs): + obs, _ = super().reset(*args, **kwargs) + + self.obs_buf = self._process_observations(obs) + + return self.obs_buf, self.extras + + def step( + self, actions: torch.Tensor + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, Dict[str, Any]]: + obs, _, _, _ = super().step(actions) + + self.obs_buf = self._process_observations(obs) + + return self.obs_buf, self.rew_buf, self.reset_buf, self.extras + + +class BipedalWalkerP(GymPOMDP): + """ + Original observation space (24 values): + [ + hull angle, + hull angular velocity, + horizontal velocity, + vertical velocity, + joint 1 angle, + joint 1 speed, + joint 2 angle, + joint 2 speed, + leg 1 ground contact, + joint 3 angle, + joint 3 speed, + joint 4 angle, + joint 4 speed, + leg 2 ground contact, + lidar (10 values), + ] + Modified observation space (15 values): + [ + hull angle, + joint 1 angle, + joint 2 angle, + joint 3 angle, + joint 4 angle, + lidar (10 values), + ] + """ + + _reduced_observation_count: int = 15 + + def __init__(self, **kwargs): + super().__init__(name="BipedalWalker-v3", **kwargs) + + def _process_observations(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Reduces observation space from original observation space to modified observation space.""" + reduced_obs = torch.zeros(obs.shape[0], self._reduced_observation_count, device=self.device) + reduced_obs[:, 0] = obs[:, 0] + reduced_obs[:, 1] = obs[:, 4] + reduced_obs[:, 2] = obs[:, 6] + reduced_obs[:, 3] = obs[:, 9] + reduced_obs[:, 4] = obs[:, 11] + reduced_obs[:, 5:] = obs[:, 14:] + + return reduced_obs diff --git a/rsl_rl/env/rslgym_env.py b/rsl_rl/env/rslgym_env.py new file mode 100644 index 0000000..a6c0d91 --- 
/dev/null +++ b/rsl_rl/env/rslgym_env.py @@ -0,0 +1,44 @@ +import torch +from typing import Any, Dict, Tuple, Union + +from rsl_rl.env.vec_env import VecEnv + + +class RSLGymEnv(VecEnv): + """A wrapper for using rsl_rl with the rslgym library.""" + + def __init__(self, rslgym_env, **kwargs): + self._rslgym_env = rslgym_env + + observation_count = self._rslgym_env.observation_space.shape[0] + super().__init__(observation_count, observation_count, **kwargs) + + self.num_actions = self._rslgym_env.action_space.shape[0] + + self.obs_buf = None + self.rew_buf = None + self.reset_buf = None + self.extras = None + + self.reset() + + def get_observations(self) -> Tuple[torch.Tensor, Dict[str, Any]]: + return self.obs_buf, self.extras + + def get_privileged_observations(self) -> Union[torch.Tensor, None]: + return self.obs_buf + + def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: + obs = self._rslgym_env.reset() + + self.obs_buf = torch.from_numpy(obs) + self.extras = {"observations": {}, "time_outs": torch.zeros((self.num_envs,), device=self.device).float()} + + def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]: + obs, reward, dones, info = self._rslgym_env.step(actions, True) + + self.obs_buf = torch.from_numpy(obs) + self.rew_buf = torch.from_numpy(reward) + self.reset_buf = torch.from_numpy(dones).float() + + return self.obs_buf, self.rew_buf, self.reset_buf, self.extras diff --git a/rsl_rl/env/vec_env.py b/rsl_rl/env/vec_env.py index a7af015..70c9af8 100644 --- a/rsl_rl/env/vec_env.py +++ b/rsl_rl/env/vec_env.py @@ -1,85 +1,74 @@ -# Copyright 2021 ETH Zurich, NVIDIA CORPORATION -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import torch from abc import ABC, abstractmethod +import torch +from typing import Any, Dict, Tuple, Union +# minimal interface of the environment class VecEnv(ABC): - """Abstract class for vectorized environment. - - The vectorized environment is a collection of environments that are synchronized. This means that - the same action is applied to all environments and the same observation is returned from all environments. - - All extra observations must be provided as a dictionary to "extras" in the step() method. Based on the - configuration, the extra observations are used for different purposes. The following keys are reserved - in the "observations" dictionary (if they are present): - - - "critic": The observation is used as input to the critic network. Useful for asymmetric observation spaces. - """ + """Abstract class for vectorized environment.""" num_envs: int - """Number of environments.""" num_obs: int - """Number of observations.""" num_privileged_obs: int - """Number of privileged observations.""" num_actions: int - """Number of actions.""" max_episode_length: int - """Maximum episode length.""" privileged_obs_buf: torch.Tensor - """Buffer for privileged observations.""" obs_buf: torch.Tensor - """Buffer for observations.""" rew_buf: torch.Tensor - """Buffer for rewards.""" reset_buf: torch.Tensor - """Buffer for resets.""" episode_length_buf: torch.Tensor # current episode duration - """Buffer for current episode lengths.""" extras: dict - """Extra information (metrics). - - Extra information is stored in a dictionary. This includes metrics such as the episode reward, episode length, - etc. Additional information can be stored in the dictionary such as observations for the critic network, etc. 
- """ device: torch.device - """Device to use.""" - """ - Operations. - """ - - @abstractmethod - def get_observations(self) -> tuple[torch.Tensor, dict]: - """Return the current observations. - - Returns: - Tuple[torch.Tensor, dict]: Tuple containing the observations and extras. + def __init__( + self, observation_count, privileged_observation_count, device="cpu", environment_count=1, max_episode_length=-1 + ): """ - raise NotImplementedError - - @abstractmethod - def reset(self) -> tuple[torch.Tensor, dict]: - """Reset all environment instances. - - Returns: - Tuple[torch.Tensor, dict]: Tuple containing the observations and extras. + Args: + observation_count (int): Number of observations per environment. + privileged_observation_count (int): Number of privileged observations per environment. + device (str): Device to use for the tensors. + environment_count (int): Number of environments to run in parallel. + max_episode_length (int): Maximum length of an episode. If -1, the episode length is not limited. """ - raise NotImplementedError + self.num_obs = observation_count + self.num_privileged_obs = privileged_observation_count + + self.num_envs = environment_count + self.max_episode_length = max_episode_length + self.device = device @abstractmethod - def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]: + def get_observations(self) -> Tuple[torch.Tensor, Dict[str, Any]]: + """Return observations and extra information.""" + pass + + @abstractmethod + def get_privileged_observations(self) -> Union[torch.Tensor, None]: + """Return privileged observations.""" + pass + + @abstractmethod + def step( + self, actions: torch.Tensor + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, Dict[str, Any]]: """Apply input action on the environment. Args: actions (torch.Tensor): Input actions to apply. Shape: (num_envs, num_actions) Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]: - A tuple containing the observations, rewards, dones and extra information (metrics). + Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, dict]: + A tuple containing the observations, privileged observations, rewards, dones and + extra information (metrics). + """ + raise NotImplementedError + + @abstractmethod + def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: + """Reset all environment instances. + + Returns: + Tuple[torch.Tensor, torch.Tensor | None]: Tuple containing the observations and privileged observations. 
""" raise NotImplementedError diff --git a/rsl_rl/modules/__init__.py b/rsl_rl/modules/__init__.py index b96383c..b1b4d43 100644 --- a/rsl_rl/modules/__init__.py +++ b/rsl_rl/modules/__init__.py @@ -1,10 +1,19 @@ -# Copyright 2021 ETH Zurich, NVIDIA CORPORATION -# SPDX-License-Identifier: BSD-3-Clause - -"""Definitions for neural-network components for RL-agents.""" - -from .actor_critic import ActorCritic -from .actor_critic_recurrent import ActorCriticRecurrent +from .categorical_network import CategoricalNetwork +from .gaussian_chimera_network import GaussianChimeraNetwork +from .gaussian_network import GaussianNetwork +from .implicit_quantile_network import ImplicitQuantileNetwork +from .network import Network from .normalizer import EmpiricalNormalization +from .quantile_network import QuantileNetwork +from .transformer import Transformer -__all__ = ["ActorCritic", "ActorCriticRecurrent"] +__all__ = [ + "CategoricalNetwork", + "EmpiricalNormalization", + "GaussianChimeraNetwork", + "GaussianNetwork", + "ImplicitQuantileNetwork", + "Network", + "QuantileNetwork", + "Transformer", +] diff --git a/rsl_rl/modules/actor_critic.py b/rsl_rl/modules/actor_critic.py deleted file mode 100644 index cdb77e1..0000000 --- a/rsl_rl/modules/actor_critic.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright 2021 ETH Zurich, NVIDIA CORPORATION -# SPDX-License-Identifier: BSD-3-Clause - - -from __future__ import annotations - -import torch -import torch.nn as nn -from torch.distributions import Normal - - -class ActorCritic(nn.Module): - is_recurrent = False - - def __init__( - self, - num_actor_obs, - num_critic_obs, - num_actions, - actor_hidden_dims=[256, 256, 256], - critic_hidden_dims=[256, 256, 256], - activation="elu", - init_noise_std=1.0, - **kwargs, - ): - if kwargs: - print( - "ActorCritic.__init__ got unexpected arguments, which will be ignored: " - + str([key for key in kwargs.keys()]) - ) - super().__init__() - activation = get_activation(activation) - - mlp_input_dim_a = num_actor_obs - mlp_input_dim_c = num_critic_obs - # Policy - actor_layers = [] - actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0])) - actor_layers.append(activation) - for layer_index in range(len(actor_hidden_dims)): - if layer_index == len(actor_hidden_dims) - 1: - actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], num_actions)) - else: - actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], actor_hidden_dims[layer_index + 1])) - actor_layers.append(activation) - self.actor = nn.Sequential(*actor_layers) - - # Value function - critic_layers = [] - critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0])) - critic_layers.append(activation) - for layer_index in range(len(critic_hidden_dims)): - if layer_index == len(critic_hidden_dims) - 1: - critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], 1)) - else: - critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], critic_hidden_dims[layer_index + 1])) - critic_layers.append(activation) - self.critic = nn.Sequential(*critic_layers) - - print(f"Actor MLP: {self.actor}") - print(f"Critic MLP: {self.critic}") - - # Action noise - self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) - self.distribution = None - # disable args validation for speedup - Normal.set_default_validate_args = False - - # seems that we get better performance without init - # self.init_memory_weights(self.memory_a, 0.001, 0.) - # self.init_memory_weights(self.memory_c, 0.001, 0.) 
- - @staticmethod - # not used at the moment - def init_weights(sequential, scales): - [ - torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) - for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Linear)) - ] - - def reset(self, dones=None): - pass - - def forward(self): - raise NotImplementedError - - @property - def action_mean(self): - return self.distribution.mean - - @property - def action_std(self): - return self.distribution.stddev - - @property - def entropy(self): - return self.distribution.entropy().sum(dim=-1) - - def update_distribution(self, observations): - mean = self.actor(observations) - self.distribution = Normal(mean, mean * 0.0 + self.std) - - def act(self, observations, **kwargs): - self.update_distribution(observations) - return self.distribution.sample() - - def get_actions_log_prob(self, actions): - return self.distribution.log_prob(actions).sum(dim=-1) - - def act_inference(self, observations): - actions_mean = self.actor(observations) - return actions_mean - - def evaluate(self, critic_observations, **kwargs): - value = self.critic(critic_observations) - return value - - -def get_activation(act_name): - if act_name == "elu": - return nn.ELU() - elif act_name == "selu": - return nn.SELU() - elif act_name == "relu": - return nn.ReLU() - elif act_name == "crelu": - return nn.CReLU() - elif act_name == "lrelu": - return nn.LeakyReLU() - elif act_name == "tanh": - return nn.Tanh() - elif act_name == "sigmoid": - return nn.Sigmoid() - else: - print("invalid activation function!") - return None diff --git a/rsl_rl/modules/actor_critic_recurrent.py b/rsl_rl/modules/actor_critic_recurrent.py deleted file mode 100644 index 6321ec5..0000000 --- a/rsl_rl/modules/actor_critic_recurrent.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2021 ETH Zurich, NVIDIA CORPORATION -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import torch -import torch.nn as nn - -from rsl_rl.modules.actor_critic import ActorCritic, get_activation -from rsl_rl.utils import unpad_trajectories - - -class ActorCriticRecurrent(ActorCritic): - is_recurrent = True - - def __init__( - self, - num_actor_obs, - num_critic_obs, - num_actions, - actor_hidden_dims=[256, 256, 256], - critic_hidden_dims=[256, 256, 256], - activation="elu", - rnn_type="lstm", - rnn_hidden_size=256, - rnn_num_layers=1, - init_noise_std=1.0, - **kwargs, - ): - if kwargs: - print( - "ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()), - ) - - super().__init__( - num_actor_obs=rnn_hidden_size, - num_critic_obs=rnn_hidden_size, - num_actions=num_actions, - actor_hidden_dims=actor_hidden_dims, - critic_hidden_dims=critic_hidden_dims, - activation=activation, - init_noise_std=init_noise_std, - ) - - activation = get_activation(activation) - - self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size) - self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size) - - print(f"Actor RNN: {self.memory_a}") - print(f"Critic RNN: {self.memory_c}") - - def reset(self, dones=None): - self.memory_a.reset(dones) - self.memory_c.reset(dones) - - def act(self, observations, masks=None, hidden_states=None): - input_a = self.memory_a(observations, masks, hidden_states) - return super().act(input_a.squeeze(0)) - - def act_inference(self, observations): - input_a = self.memory_a(observations) - return super().act_inference(input_a.squeeze(0)) - - 
def evaluate(self, critic_observations, masks=None, hidden_states=None): - input_c = self.memory_c(critic_observations, masks, hidden_states) - return super().evaluate(input_c.squeeze(0)) - - def get_hidden_states(self): - return self.memory_a.hidden_states, self.memory_c.hidden_states - - -class Memory(torch.nn.Module): - def __init__(self, input_size, type="lstm", num_layers=1, hidden_size=256): - super().__init__() - # RNN - rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM - self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers) - self.hidden_states = None - - def forward(self, input, masks=None, hidden_states=None): - batch_mode = masks is not None - if batch_mode: - # batch mode (policy update): need saved hidden states - if hidden_states is None: - raise ValueError("Hidden states not passed to memory module during policy update") - out, _ = self.rnn(input, hidden_states) - out = unpad_trajectories(out, masks) - else: - # inference mode (collection): use hidden states of last step - out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states) - return out - - def reset(self, dones=None): - # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state - for hidden_state in self.hidden_states: - hidden_state[..., dones, :] = 0.0 diff --git a/rsl_rl/modules/categorical_network.py b/rsl_rl/modules/categorical_network.py new file mode 100644 index 0000000..bc4fb40 --- /dev/null +++ b/rsl_rl/modules/categorical_network.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn +from rsl_rl.modules.network import Network +from rsl_rl.utils.utils import squeeze_preserve_batch + +eps = torch.finfo(torch.float32).eps + + +class CategoricalNetwork(Network): + def __init__( + self, + input_size, + output_size, + activations=["relu", "relu", "relu"], + atom_count=51, + hidden_dims=[256, 256, 256], + init_gain=1.0, + value_max=10.0, + value_min=-10.0, + **kwargs, + ): + assert len(hidden_dims) == len(activations) + assert value_max > value_min + assert atom_count > 1 + + super().__init__( + input_size, + activations=activations, + hidden_dims=hidden_dims[:-1], + init_fade=False, + init_gain=init_gain, + output_size=hidden_dims[-1], + **kwargs, + ) + + self._value_max = value_max + self._value_min = value_min + self._atom_count = atom_count + + self.value_delta = (self._value_max - self._value_min) / (self._atom_count - 1) + action_values = torch.arange(self._value_min, self._value_max + eps, self.value_delta) + self.register_buffer("action_values", action_values) + + self._categorical_layers = nn.ModuleList([nn.Linear(hidden_dims[-1], atom_count) for _ in range(output_size)]) + + self._init(self._categorical_layers, fade=False, gain=init_gain) + + def categorical_loss( + self, predictions: torch.Tensor, target_probabilities: torch.Tensor, targets: torch.Tensor + ) -> torch.Tensor: + """Computes the categorical loss between the prediction and target categorical distributions. + + Projects the targets back onto the categorical distribution supports before computing KL divergence. + + Args: + predictions (torch.Tensor): The network prediction. + target_probabilities (torch.Tensor): The next-state value probabilities. + targets (torch.Tensor): The targets to compute the loss from. + Returns: + A torch.Tensor of the cross-entropy loss between the projected targets and the prediction. 
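+
+            Each target value y is mapped to the fractional bin index b = (y - value_min) / value_delta, and its
+            probability mass is split between the neighbouring atoms floor(b) and ceil(b) with weights
+            (ceil(b) - b) and (b - floor(b)), respectively (the C51 target projection).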
+ """ + b = (targets - self._value_min) / self.value_delta + l = b.floor().long().clamp(0, self._atom_count - 1) + u = b.ceil().long().clamp(0, self._atom_count - 1) + + all_idx = torch.arange(b.shape[0]) + projected_targets = torch.zeros((b.shape[0], self._atom_count), device=self.device) + for i in range(self._atom_count): + # Correct for when l == u + l[:, i][(l[:, i] == u[:, i]) * (l[:, i] > 0)] -= 1 + u[:, i][(l[:, i] == u[:, i]) * (u[:, i] < self._atom_count - 1)] += 1 + + projected_targets[all_idx, l[:, i]] += (u[:, i] - b[:, i]) * target_probabilities[..., i] + projected_targets[all_idx, u[:, i]] += (b[:, i] - l[:, i]) * target_probabilities[..., i] + + loss = torch.nn.functional.cross_entropy( + predictions.reshape(*projected_targets.shape), projected_targets.detach() + ) + + return loss + + def compute_targets(self, rewards: torch.Tensor, dones: torch.Tensor, discount: float) -> torch.Tensor: + gamma = (discount * (1 - dones)).reshape(-1, 1) + gamma_z = gamma * self.action_values.repeat(dones.size()[0], 1) + targets = (rewards.reshape(-1, 1) + gamma_z).clamp(self._value_min, self._value_max) + + return targets + + def forward(self, x: torch.Tensor, distribution: bool = False) -> torch.Tensor: + features = super().forward(x) + probabilities = squeeze_preserve_batch( + torch.stack([layer(features).softmax(dim=-1) for layer in self._categorical_layers], dim=1) + ) + + if distribution: + return probabilities + + values = self.probabilities_to_values(probabilities) + + return values + + def probabilities_to_values(self, probabilities: torch.Tensor) -> torch.Tensor: + values = probabilities @ self.action_values + + return values diff --git a/rsl_rl/modules/gaussian_chimera_network.py b/rsl_rl/modules/gaussian_chimera_network.py new file mode 100644 index 0000000..8b13e53 --- /dev/null +++ b/rsl_rl/modules/gaussian_chimera_network.py @@ -0,0 +1,86 @@ +import numpy as np +import torch +import torch.nn as nn +from typing import List, Tuple, Union + +from rsl_rl.modules.network import Network +from rsl_rl.modules.utils import get_activation + + +class GaussianChimeraNetwork(Network): + """A network to predict mean and std of a gaussian distribution with separate heads for mean and std.""" + + def __init__( + self, + input_size: int, + output_size: int, + activations: List[str] = ["relu", "relu", "relu", "linear"], + hidden_dims: List[int] = [256, 256, 256], + init_fade: bool = True, + init_gain: float = 0.5, + log_std_max: float = 4.0, + log_std_min: float = -20.0, + std_init: float = 1.0, + shared_dims: int = 1, + **kwargs, + ): + assert len(hidden_dims) + 1 == len(activations) + assert shared_dims > 0 and shared_dims <= len(hidden_dims) + + super().__init__( + input_size, + hidden_dims[shared_dims], + activations=activations[: shared_dims + 1], + hidden_dims=hidden_dims[:shared_dims], + init_fade=False, + init_gain=init_gain, + **kwargs, + ) + + # Since the network predicts log_std ~= 0 after initialization, compute std = std_init * exp(log_std). 
+ self._log_std_init = np.log(std_init) + self._log_std_max = log_std_max + self._log_std_min = log_std_min + + separate_dims = len(hidden_dims) - shared_dims + + mean_layers = [] + for i in range(separate_dims): + isize = hidden_dims[shared_dims + i] + osize = output_size if i == separate_dims - 1 else hidden_dims[shared_dims + i + 1] + + layer = nn.Linear(isize, osize) + activation = activations[shared_dims + i + 1] + + mean_layers += [layer, get_activation(activation)] + self._mean_layer = nn.Sequential(*mean_layers) + + self._init(self._mean_layer, fade=init_fade, gain=init_gain) + + log_std_layers = [] + for i in range(separate_dims): + isize = hidden_dims[shared_dims + i] + osize = output_size if i == separate_dims - 1 else hidden_dims[shared_dims + i + 1] + + layer = nn.Linear(isize, osize) + activation = activations[shared_dims + i + 1] + + log_std_layers += [layer, get_activation(activation)] + self._log_std_layer = nn.Sequential(*log_std_layers) + + self._init(self._log_std_layer, fade=init_fade, gain=init_gain) + + def forward(self, x: torch.Tensor, compute_std: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + features = super().forward(x) + + mean = self._mean_layer(features) + + if not compute_std: + return mean + + # compute standard deviation as std = std_init * exp(log_std) = exp(log(std_init) + log(std)) since the network + # will predict log_std ~= 0 after initialization. + log_std = (self._log_std_init + self._log_std_layer(features)).clamp(self._log_std_min, self._log_std_max) + std = log_std.exp() + + return mean, std diff --git a/rsl_rl/modules/gaussian_network.py b/rsl_rl/modules/gaussian_network.py new file mode 100644 index 0000000..1aec8d2 --- /dev/null +++ b/rsl_rl/modules/gaussian_network.py @@ -0,0 +1,37 @@ +import numpy as np +import torch +import torch.nn as nn +from typing import Tuple, Union + +from rsl_rl.modules.network import Network + + +class GaussianNetwork(Network): + """A network to predict mean and std of a gaussian distribution where std is a tunable parameter.""" + + def __init__( + self, + input_size: int, + output_size: int, + log_std_max: float = 4.0, + log_std_min: float = -20.0, + std_init: float = 1.0, + **kwargs, + ): + super().__init__(input_size, output_size, **kwargs) + + self._log_std_max = log_std_max + self._log_std_min = log_std_min + + self._log_std = nn.Parameter(torch.ones(output_size) * np.log(std_init)) + + def forward(self, x: torch.Tensor, compute_std: bool = False, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]: + mean = super().forward(x, **kwargs) + + if not compute_std: + return mean + + log_std = torch.ones_like(mean) * self._log_std.clamp(self._log_std_min, self._log_std_max) + std = log_std.exp() + + return mean, std diff --git a/rsl_rl/modules/implicit_quantile_network.py b/rsl_rl/modules/implicit_quantile_network.py new file mode 100644 index 0000000..898ee67 --- /dev/null +++ b/rsl_rl/modules/implicit_quantile_network.py @@ -0,0 +1,168 @@ +import numpy as np +import torch +from torch.distributions import Normal +import torch.nn as nn +from typing import List, Union + +from rsl_rl.modules.network import Network +from rsl_rl.modules.quantile_network import energy_loss +from rsl_rl.utils.benchmarkable import Benchmarkable + + +def reshape_measure_param(tau: torch.Tensor, param: torch.Tensor) -> torch.Tensor: + if not torch.is_tensor(param): + param = torch.tensor([param]) + + param = param.expand(tau.shape[0], -1).to(tau.device) + + return param + + +def risk_measure_neutral(tau: torch.Tensor) -> 
torch.Tensor: + return tau + + +def risk_measure_wang(tau: torch.Tensor, beta: float = 0.0) -> torch.Tensor: + beta = reshape_measure_param(tau, beta) + + distorted_tau = Normal(0, 1).cdf(Normal(0, 1).icdf(tau) + beta) + + return distorted_tau + + +class ImplicitQuantileNetwork(Network): + measure_neutral = "neutral" + measure_wang = "wang" + + measures = { + measure_neutral: risk_measure_neutral, + measure_wang: risk_measure_wang, + } + + def __init__( + self, + input_size: int, + output_size: int, + activations: List[str] = ["relu", "relu", "relu"], + feature_layers: int = 1, + embedding_size: int = 64, + hidden_dims: List[int] = [256, 256, 256], + init_fade: bool = False, + init_gain: float = 0.5, + measure: str = None, + measure_kwargs: dict = {}, + **kwargs, + ): + assert len(hidden_dims) == len(activations), "hidden_dims and activations must have the same length." + assert feature_layers > 0, "feature_layers must be greater than 0." + assert feature_layers < len(hidden_dims), "feature_layers must be less than the number of hidden dimensions." + assert embedding_size > 0, "embedding_size must be greater than 0." + + super().__init__( + input_size, + hidden_dims[feature_layers - 1], + activations=activations[:feature_layers], + hidden_dims=hidden_dims[: feature_layers - 1], + init_fade=init_fade, + init_gain=init_gain, + **kwargs, + ) + + self._last_taus = None + self._last_quantiles = None + self._embedding_size = embedding_size + self.register_buffer( + "_embedding_pis", + np.pi * (torch.arange(self._embedding_size, device=self.device).reshape(1, 1, self._embedding_size)), + ) + self._embedding_layer = nn.Sequential( + nn.Linear(self._embedding_size, hidden_dims[feature_layers - 1]), nn.ReLU() + ) + + self._fusion_layers = Network( + hidden_dims[feature_layers - 1], + output_size, + activations=activations[feature_layers:] + ["linear"], + hidden_dims=hidden_dims[feature_layers:], + init_fade=init_fade, + init_gain=init_gain, + ) + + measure_func = risk_measure_neutral if measure is None else self.measures[measure] + self._measure_func = measure_func + self._measure_kwargs = measure_kwargs + + @Benchmarkable.register + def _sample_taus(self, batch_size: int, sample_count: int, measure_args: list, use_measure: bool) -> torch.Tensor: + """Sample quantiles and distort them according to the risk metric. + + Args: + batch_size: The batch size. + sample_count: The number of samples to draw. + measure_args: The arguments to pass to the risk measure function. + use_measure: Whether to use the risk measure or not. + Returns: + A tensor of shape (batch_size, sample_count, 1). + """ + taus = torch.rand(batch_size, sample_count, device=self.device) + + if not use_measure: + return taus + + if measure_args: + taus = self._measure_func(taus, *measure_args) + else: + taus = self._measure_func(taus, **self._measure_kwargs) + + return taus + + @Benchmarkable.register + def forward( + self, + x: torch.Tensor, + distribution: bool = False, + measure_args: list = [], + sample_count: int = 8, + taus: Union[torch.Tensor, None] = None, + use_measure: bool = True, + **kwargs, + ) -> torch.Tensor: + assert taus is None or not use_measure, "Cannot use taus and use_measure at the same time." 
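For intuition, a standalone sketch (arbitrary sizes, not part of the module) of the cosine embedding that this forward pass builds from the sampled taus before fusing them with the state features:

import math
import torch

batch_size, sample_count, embedding_size = 4, 8, 64
taus = torch.rand(batch_size, sample_count)                     # as produced by _sample_taus
pis = math.pi * torch.arange(embedding_size).reshape(1, 1, -1)  # 0, pi, 2*pi, ...
cos_features = torch.cos(taus.unsqueeze(-1) * pis)              # (batch, samples, embedding)
# Each tau becomes the vector [cos(0), cos(pi*tau), cos(2*pi*tau), ...]; a linear layer plus
# ReLU then maps it to the feature width so it can be combined element-wise with the state
# features (features.unsqueeze(-2) * embeddings in forward).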
+ + batch_size = x.shape[0] + + features = super().forward(x, **kwargs) + taus = self._sample_taus(batch_size, sample_count, measure_args, use_measure) if taus is None else taus + + # Compute quantile embeddings + singular_dims = [1] * taus.dim() + cos = torch.cos(taus.unsqueeze(-1) * self._embedding_pis.reshape(*singular_dims, self._embedding_size)) + embeddings = self._embedding_layer(cos) + + # Compute the fusion of the features and the embeddings + fused_features = features.unsqueeze(-2) * embeddings + quantiles = self._fusion_layers(fused_features) + + self._last_quantiles = quantiles + self._last_taus = taus + + if distribution: + return quantiles + + values = quantiles.mean(-1) + + return values + + @property + def last_taus(self): + return self._last_taus + + @property + def last_quantiles(self): + return self._last_quantiles + + @Benchmarkable.register + def sample_energy_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + loss = energy_loss(predictions, targets) + + return loss diff --git a/rsl_rl/modules/network.py b/rsl_rl/modules/network.py new file mode 100644 index 0000000..33e4513 --- /dev/null +++ b/rsl_rl/modules/network.py @@ -0,0 +1,210 @@ +import torch +import torch.nn as nn +from typing import List + +from rsl_rl.modules.normalizer import EmpiricalNormalization +from rsl_rl.modules.utils import get_activation +from rsl_rl.modules.transformer import Transformer +from rsl_rl.utils.benchmarkable import Benchmarkable +from rsl_rl.utils.utils import squeeze_preserve_batch + + +class Network(Benchmarkable, nn.Module): + recurrent_module_lstm = "LSTM" + recurrent_module_transformer = "TF" + + recurrent_modules = {recurrent_module_lstm: nn.LSTM, recurrent_module_transformer: Transformer} + + def __init__( + self, + input_size: int, + output_size: int, + activations: List[str] = ["relu", "relu", "relu", "tanh"], + hidden_dims: List[int] = [256, 256, 256], + init_fade: bool = True, + init_gain: float = 1.0, + input_normalization: bool = False, + recurrent: bool = False, + recurrent_layers: int = 1, + recurrent_module: str = recurrent_module_lstm, + recurrent_tf_context_length: int = 64, + recurrent_tf_head_count: int = 8, + ) -> None: + """ + + Args: + input_size (int): The size of the input. + output_size (int): The size of the output. + activations (List[str]): The activation functions to use. If the network is recurrent, the first activation + function is used for the output of the recurrent layer. + hidden_dims (List[int]): The hidden dimensions. If the network is recurrent, the first hidden dimension is + used for the recurrent layer. + init_fade (bool): Whether to use the fade in initialization. + init_gain (float): The gain to use for the initialization. + input_normalization (bool): Whether to use input normalization. + recurrent (bool): Whether to use a recurrent network. + recurrent_layers (int): The number of recurrent layers (LSTM) / blocks (Transformer) to use. + recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules. + recurrent_tf_context_length (int): The context length of the Transformer. + recurrent_tf_head_count (int): The head count of the Transformer. 
+ """ + assert len(hidden_dims) + 1 == len(activations) + + super().__init__() + + if input_normalization: + self._normalization = EmpiricalNormalization(shape=(input_size,)) + else: + self._normalization = nn.Identity() + + dims = [input_size] + hidden_dims + [output_size] + + self._recurrent = recurrent + self._recurrent_module = recurrent_module + self.hidden_state = None + self._last_hidden_state = None + if self._recurrent: + recurrent_kwargs = dict() + + if recurrent_module == self.recurrent_module_lstm: + recurrent_kwargs["hidden_size"] = dims[1] + recurrent_kwargs["input_size"] = dims[0] + recurrent_kwargs["num_layers"] = recurrent_layers + elif recurrent_module == self.recurrent_module_transformer: + recurrent_kwargs["block_count"] = recurrent_layers + recurrent_kwargs["context_length"] = recurrent_tf_context_length + recurrent_kwargs["head_count"] = recurrent_tf_head_count + recurrent_kwargs["hidden_size"] = dims[1] + recurrent_kwargs["input_size"] = dims[0] + recurrent_kwargs["output_size"] = dims[1] + + rnn = self.recurrent_modules[recurrent_module](**recurrent_kwargs) + activation = get_activation(activations[0]) + dims = dims[1:] + activations = activations[1:] + + self._features = nn.Sequential(rnn, activation) + else: + self._features = nn.Identity() + + layers = [] + for i in range(len(activations)): + layer = nn.Linear(dims[i], dims[i + 1]) + activation = get_activation(activations[i]) + + layers.append(layer) + layers.append(activation) + + self._layers = nn.Sequential(*layers) + + if len(layers) > 0: + self._init(self._layers, fade=init_fade, gain=init_gain) + + @property + def device(self): + """Returns the device of the network.""" + return next(self.parameters()).device + + def forward(self, x: torch.Tensor, hidden_state=None) -> torch.Tensor: + """ + Args: + x (torch.Tensor): The input data. + hidden_state (Tuple[torch.Tensor, torch.Tensor]): The hidden state of the network. If None, the hidden state + of the network is used. If provided, the hidden state of the neural network will not be updated. To + retrieve the new hidden state, use the last_hidden_state property. If the network is not recurrent, + this argument is ignored. + Returns: + The output of the network as a torch.Tensor. + """ + assert hidden_state is None or self._recurrent, "Cannot pass hidden state to non-recurrent network." + + input = self._normalization(x.to(self.device)) + + if self._recurrent: + current_hidden_state = self.hidden_state if hidden_state is None else hidden_state + current_hidden_state = (current_hidden_state[0].to(self.device), current_hidden_state[1].to(self.device)) + + input = input.unsqueeze(0) if len(input.shape) == 2 else input + input, next_hidden_state = self._features[0](input, current_hidden_state) + input = self._features[1](input).squeeze(0) + + if hidden_state is None: + self.hidden_state = next_hidden_state + self._last_hidden_state = next_hidden_state + + output = squeeze_preserve_batch(self._layers(input)) + + return output + + @property + def last_hidden_state(self): + """Returns the hidden state of the last forward pass. + + Does not differentiate whether the hidden state depends on the hidden state kept in the network or whether it + was passed into the forward pass. + + Returns: + The hidden state of the last forward pass as Tuple[torch.Tensor, torch.Tensor]. + """ + return self._last_hidden_state + + def normalize(self, x: torch.Tensor) -> torch.Tensor: + """Normalizes the given input. + + Args: + x (torch.Tensor): The input to normalize. 
+ Returns: + The normalized input as a torch.Tensor. + """ + output = self._normalization(x.to(self.device)) + + return output + + @property + def recurrent(self) -> bool: + """Returns whether the network is recurrent.""" + return self._recurrent + + def reset_hidden_state(self, indices: torch.Tensor) -> None: + """Resets the hidden state of the neural network. + + Throws an error if the network is not recurrent. + + Args: + indices (torch.Tensor): A 1-dimensional int tensor containing the indices of the terminated + environments. + """ + assert self._recurrent + + self.hidden_state[0][:, indices] = torch.zeros(len(indices), self._features[0].hidden_size, device=self.device) + self.hidden_state[1][:, indices] = torch.zeros(len(indices), self._features[0].hidden_size, device=self.device) + + def reset_full_hidden_state(self, batch_size=None) -> None: + """Resets the hidden state of the neural network. + + Args: + batch_size (int): The batch size of the hidden state. If None, the hidden state is reset to None. + """ + assert self._recurrent + + if batch_size is None: + self.hidden_state = None + else: + layer_count, hidden_size = self._features[0].num_layers, self._features[0].hidden_size + self.hidden_state = ( + torch.zeros(layer_count, batch_size, hidden_size, device=self.device), + torch.zeros(layer_count, batch_size, hidden_size, device=self.device), + ) + + def _init(self, layers: List[nn.Module], fade: bool = True, gain: float = 1.0) -> List[nn.Module]: + """Initializes neural network layers.""" + last_layer_idx = len(layers) - 1 - next(i for i, l in enumerate(reversed(layers)) if isinstance(l, nn.Linear)) + + for idx, layer in enumerate(layers): + if not isinstance(layer, nn.Linear): + continue + + current_gain = gain / 100.0 if fade and idx == last_layer_idx else gain + nn.init.xavier_normal_(layer.weight, gain=current_gain) + + return layers diff --git a/rsl_rl/modules/normalizer.py b/rsl_rl/modules/normalizer.py index 771efcf..4de90a3 100644 --- a/rsl_rl/modules/normalizer.py +++ b/rsl_rl/modules/normalizer.py @@ -1,9 +1,3 @@ -# Copyright (c) 2020 Preferred Networks, Inc. -# Copyright 2021 ETH Zurich, NVIDIA CORPORATION -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - import torch from torch import nn @@ -11,48 +5,58 @@ from torch import nn class EmpiricalNormalization(nn.Module): """Normalize mean and variance of values based on empirical values.""" - def __init__(self, shape, eps=1e-2, until=None): - """Initialize EmpiricalNormalization module. - + def __init__(self, shape, eps=1e-6, until=None) -> None: + """ Args: shape (int or tuple of int): Shape of input values except batch axis. eps (float): Small value for stability. until (int or None): If this arg is specified, the link learns input values until the sum of batch sizes - exceeds it. + exceeds it. """ super().__init__() + self.eps = eps self.until = until + self.register_buffer("_mean", torch.zeros(shape).unsqueeze(0)) self.register_buffer("_var", torch.ones(shape).unsqueeze(0)) self.register_buffer("_std", torch.ones(shape).unsqueeze(0)) + self.count = 0 @property - def mean(self): - return self._mean.squeeze(0).clone() + def mean(self) -> torch.Tensor: + """Mean of input values.""" + return self._mean.squeeze(0).detach().clone() @property - def std(self): - return self._std.squeeze(0).clone() - - def forward(self, x): - """Normalize mean and variance of values based on empirical values. 
+ def std(self) -> torch.Tensor: + """Standard deviation of input values.""" + return self._std.squeeze(0).detach().clone() + def forward(self, x) -> torch.Tensor: + """Normalize mean and variance of values based on emprical values. Args: x (ndarray or Variable): Input values - Returns: - ndarray or Variable: Normalized output values + Normalized output values """ if self.training: self.update(x) - return (x - self._mean) / (self._std + self.eps) + + x_normalized = (x - self._mean.detach()) / (self._std.detach() + self.eps) + + return x_normalized @torch.jit.unused - def update(self, x): - """Learn input values without computing the output values of them""" + def update(self, x: torch.Tensor) -> None: + """Learn input values without computing the output values of them. + + Args: + x (torch.Tensor): Input values. + """ + x = x.detach() if self.until is not None and self.count >= self.until: return @@ -69,5 +73,14 @@ class EmpiricalNormalization(nn.Module): self._std = torch.sqrt(self._var) @torch.jit.unused - def inverse(self, y): - return y * (self._std + self.eps) + self._mean + def inverse(self, y: torch.Tensor) -> torch.Tensor: + """Inverse normalized values. + + Args: + y (torch.Tensor): Normalized input values. + Returns: + Inverse normalized output values. + """ + inv = y * (self._std + self.eps) + self._mean + + return inv diff --git a/rsl_rl/modules/quantile_network.py b/rsl_rl/modules/quantile_network.py new file mode 100644 index 0000000..a3f283e --- /dev/null +++ b/rsl_rl/modules/quantile_network.py @@ -0,0 +1,333 @@ +import torch +import torch.nn as nn +from torch.distributions import Normal +from typing import Callable, Tuple, Union + +from rsl_rl.modules.network import Network +from rsl_rl.utils.benchmarkable import Benchmarkable +from rsl_rl.utils.utils import squeeze_preserve_batch + +eps = torch.finfo(torch.float32).eps + + +def reshape_measure_parameters( + qn: Network, *params: Union[torch.Tensor, float] +) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: + """Reshapes the parameters of a measure function to match the shape of the quantile network. + + Args: + qn (Network): The quantile network. + *params (Union[torch.Tensor, float]): The parameters of the measure function. + Returns: + Union[torch.Tensor, Tuple[torch.Tensor, ...]]: The reshaped parameters. + """ + if not params: + return qn._tau.to(qn.device), *params + + assert len([*set([torch.is_tensor(p) for p in params])]) == 1, "All parameters must be either tensors or scalars." + + if torch.is_tensor(params[0]): + assert all([p.dim() == 1 for p in params]), "All parameters must have dimensionality 1." + assert len([*set([p.shape[0] for p in params])]) == 1, "All parameters must have the same size." + + reshaped_params = [p.reshape(-1, 1).to(qn.device) for p in params] + tau = qn._tau.expand(params[0].shape[0], -1).to(qn.device) + else: + reshaped_params = params + tau = qn._tau.to(qn.device) + + return tau, *reshaped_params + + +def make_distorted_measure(distorted_tau: torch.Tensor) -> Callable: + """Creates a measure function for the distorted expectation under some distortion function. + + The distorted expectation for some distortion function g(tau) is given by the integral w.r.t. tau + "int_0^1 g'(tau) * F_Z^{-1}(tau) dtau" where g'(tau) is the derivative of g w.r.t. tau and F_Z^{-1} is the inverse + cumulative distribution function of the value distribution. + See https://arxiv.org/pdf/2004.14547.pdf and https://arxiv.org/pdf/1806.06923.pdf for details. 
+ """ + distorted_tau = distorted_tau.reshape(-1, distorted_tau.shape[-1]) + distortion = (distorted_tau[:, 1:] - distorted_tau[:, :-1]).squeeze(0) + + def distorted_measure(quantiles): + sorted_quantiles, _ = quantiles.sort(-1) + sorted_quantiles = sorted_quantiles.reshape(-1, sorted_quantiles.shape[-1]) + + # dtau = tau[1:] - tau[:-1] cancels the denominator of g'(tau) = g(tau)[1:] - g(tau)[:-1] / dtau. + values = squeeze_preserve_batch((distortion.to(sorted_quantiles.device) * sorted_quantiles).sum(-1)) + + return values + + return distorted_measure + + +def risk_measure_cvar(qn: Network, confidence_level: float = 1.0) -> Callable: + """Conditional value at risk measure. + + TODO: Handle confidence_level being a tensor. + + Args: + qn (QuantileNetwork): Quantile network to compute the risk measure for. + confidence_level (float): Confidence level of the risk measure. Must be between 0 and 1. + Returns: + A risk measure function. + """ + tau, confidence_level = reshape_measure_parameters(qn, confidence_level) + distorted_tau = torch.min(tau / confidence_level, torch.ones(*tau.shape).to(tau.device)) + + return make_distorted_measure(distorted_tau) + + +def risk_measure_neutral(_: Network) -> Callable: + """Neutral risk measure (expected value). + + Args: + _ (QuantileNetwork): Quantile network to compute the risk measure for. + Returns: + A risk measure function. + """ + + def measure(quantiles): + values = squeeze_preserve_batch(quantiles.mean(-1)) + + return values + + return measure + + +def risk_measure_percentile(_: Network, confidence_level: float = 1.0) -> Callable: + """Value at risk measure. + + Args: + _ (QuantileNetwork): Quantile network to compute the risk measure for. + confidence_level (float): Confidence level of the risk measure. Must be between 0 and 1. + Returns: + A risk measure function. + """ + + def measure(quantiles): + sorted_quantiles, _ = quantiles.sort(-1) + sorted_quantiles = sorted_quantiles.reshape(-1, sorted_quantiles.shape[-1]) + idx = min(int(confidence_level * quantiles.shape[-1]), quantiles.shape[-1] - 1) + + values = squeeze_preserve_batch(sorted_quantiles[:, idx]) + + return values + + return measure + + +def risk_measure_wang(qn: Network, beta: Union[float, torch.Tensor] = 0.0) -> Callable: + """Wang's risk measure. + + The risk measure computes the distorted expectation under Wang's risk distortion function + g(tau) = Phi(Phi^-1(tau) + beta) where Phi and Phi^-1 are the standard normal CDF and its inverse. + See https://arxiv.org/pdf/2004.14547.pdf for details. + + Args: + qn (QuantileNetwork): Quantile network to compute the risk measure for. + beta (float): Parameter of the risk distortion function. + Returns: + A risk measure function. + """ + tau, beta = reshape_measure_parameters(qn, beta) + + distorted_tau = Normal(0, 1).cdf(Normal(0, 1).icdf(tau) + beta) + + return make_distorted_measure(distorted_tau) + + +def energy_loss(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """Computes sample energy loss between predictions and targets. + + The energy loss is computed as 2*E[||X - Y||_2] - E[||X - X'||_2] - E[||Y - Y'||_2], where X, X' and Y, Y' are + random variables and ||.||_2 is the L2-norm. X, X' are the predictions and Y, Y' are the targets. + + Args: + predictions (torch.Tensor): Predictions to compute loss from. + targets (torch.Tensor): Targets to compare predictions against. + Returns: + A torch.Tensor of shape (1,) containing the loss. 
+ """ + dims = [-1 for _ in range(predictions.dim())] + prediction_mat = predictions.unsqueeze(-1).expand(*dims, predictions.shape[-1]) + target_mat = targets.unsqueeze(-1).expand(*dims, predictions.shape[-1]) + + delta_xx = (prediction_mat - prediction_mat.transpose(-1, -2)).abs().mean() + delta_yy = (target_mat - target_mat.transpose(-1, -2)).abs().mean() + delta_xy = (prediction_mat - target_mat.transpose(-1, -2)).abs().mean() + + loss = 2 * delta_xy - delta_xx - delta_yy + + return loss + + +class QuantileNetwork(Network): + measure_cvar = "cvar" + measure_neutral = "neutral" + measure_percentile = "percentile" + measure_wang = "wang" + + measures = { + measure_cvar: risk_measure_cvar, + measure_neutral: risk_measure_neutral, + measure_percentile: risk_measure_percentile, + measure_wang: risk_measure_wang, + } + + def __init__( + self, + input_size, + output_size, + activations=["relu", "relu", "relu"], + hidden_dims=[256, 256, 256], + init_fade=False, + init_gain=0.5, + measure=None, + measure_kwargs={}, + quantile_count=200, + **kwargs, + ): + assert len(hidden_dims) == len(activations) + assert quantile_count > 0 + + super().__init__( + input_size, + activations=activations, + hidden_dims=hidden_dims[:-1], + init_fade=False, + init_gain=init_gain, + output_size=hidden_dims[-1], + **kwargs, + ) + + self._quantile_count = quantile_count + self._tau = torch.arange(self._quantile_count + 1) / self._quantile_count + self._tau_hat = torch.tensor([(self._tau[i] + self._tau[i + 1]) / 2 for i in range(self._quantile_count)]) + self._tau_hat_mat = torch.empty((0,)) + + self._quantile_layers = nn.ModuleList([nn.Linear(hidden_dims[-1], quantile_count) for _ in range(output_size)]) + + self._init(self._quantile_layers, fade=init_fade, gain=init_gain) + + measure_func = risk_measure_neutral if measure is None else self.measures[measure] + self._measure_func = measure_func + self._measure = measure_func(self, **measure_kwargs) + + self._last_quantiles = None + + @property + def last_quantiles(self) -> torch.Tensor: + return self._last_quantiles + + def make_diracs(self, values: torch.Tensor) -> torch.Tensor: + """Generates value distributions that have a single spike at the given values. + + Args: + values (torch.Tensor): Values to generate dirac distributions for. + Returns: + A torch.Tensor of shape (*values.shape, quantile_count) containing the dirac distributions. + """ + dirac = values.unsqueeze(-1).expand(*[-1 for _ in range(values.dim())], self._quantile_count) + + return dirac + + @property + def quantile_count(self) -> int: + return self._quantile_count + + @Benchmarkable.register + def quantiles_to_values(self, quantiles: torch.Tensor, *measure_args) -> torch.Tensor: + """Computes values from quantiles. + + Args: + quantiles (torch.Tensor): Quantiles to compute values from. + measure_kwargs (dict): Keyword arguments to pass to the risk measure function instead of the arguments + passed when creating the network. + Returns: + A torch.Tensor of shape (1,) containing the values. + """ + if measure_args: + values = self._measure_func(self, *[squeeze_preserve_batch(m) for m in measure_args])(quantiles) + else: + values = self._measure(quantiles) + + return values + + @Benchmarkable.register + def quantile_l1_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """Computes quantile-wise l1 loss between predictions and targets. + + TODO: This function is a bottleneck. + + Args: + predictions (torch.Tensor): Predictions to compute loss from. 
+ targets (torch.Tensor): Targets to compare predictions against. + Returns: + A torch.Tensor of shape (1,) containing the loss. + """ + assert ( + predictions.dim() == 2 or predictions.dim() == 3 + ), f"Predictions must be 2D or 3D. Got {predictions.dim()}." + assert ( + predictions.shape == targets.shape + ), f"The shapes of predictions and targets must match. Got {predictions.shape} and {targets.shape}." + + pre_dims = [-1] if predictions.dim() == 3 else [] + + prediction_mat = predictions.unsqueeze(-3).expand(*pre_dims, self._quantile_count, -1, -1) + target_mat = targets.transpose(-2, -1).unsqueeze(-1).expand(*pre_dims, -1, -1, self._quantile_count) + delta = target_mat - prediction_mat + + tau_hat = self._tau_hat.expand(predictions.shape[-2], -1).to(self.device) + loss = (torch.where(delta < 0, (tau_hat - 1), tau_hat) * delta).abs().mean() + + return loss + + @Benchmarkable.register + def quantile_huber_loss(self, predictions: torch.Tensor, targets: torch.Tensor, kappa: float = 1.0) -> torch.Tensor: + """Computes quantile huber loss between predictions and targets. + + TODO: This function is a bottleneck. + + Args: + predictions (torch.Tensor): Predictions to compute loss from. + targets (torch.Tensor): Targets to compare predictions against. + kappa (float): Defines the interval [-kappa, kappa] around zero where squared loss is used. Defaults to 1. + Returns: + A torch.tensor of shape (1,) containing the loss. + """ + pre_dims = [-1] if predictions.dim() == 3 else [] + + prediction_mat = predictions.unsqueeze(-3).expand(*pre_dims, self._quantile_count, -1, -1) + target_mat = targets.transpose(-2, -1).unsqueeze(-1).expand(*pre_dims, -1, -1, self._quantile_count) + delta = target_mat - prediction_mat + delta_abs = delta.abs() + + huber = torch.where(delta_abs <= kappa, 0.5 * delta.pow(2), kappa * (delta_abs - 0.5 * kappa)) + + tau_hat = self._tau_hat.expand(predictions.shape[-2], -1).to(self.device) + loss = (torch.where(delta < 0, (tau_hat - 1), tau_hat).abs() * huber).mean() + + return loss + + @Benchmarkable.register + def sample_energy_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + loss = energy_loss(predictions, targets) + + return loss + + @Benchmarkable.register + def forward(self, x: torch.Tensor, distribution: bool = False, measure_args: list = [], **kwargs) -> torch.Tensor: + features = super().forward(x, **kwargs) + quantiles = squeeze_preserve_batch(torch.stack([layer(features) for layer in self._quantile_layers], dim=1)) + + self._last_quantiles = quantiles + + if distribution: + return quantiles + + values = self.quantiles_to_values(quantiles, *measure_args) + + return values diff --git a/rsl_rl/modules/transformer.py b/rsl_rl/modules/transformer.py new file mode 100644 index 0000000..7e27b0f --- /dev/null +++ b/rsl_rl/modules/transformer.py @@ -0,0 +1,150 @@ +import torch +from typing import Tuple + + +class Head(torch.nn.Module): + """A single causal self-attention head.""" + + def __init__(self, hidden_size: int, head_size: int): + super().__init__() + + self.query = torch.nn.Linear(hidden_size, head_size) + self.key = torch.nn.Linear(hidden_size, head_size) + self.value = torch.nn.Linear(hidden_size, head_size) + + def forward(self, x: torch.Tensor): + x = x.transpose(0, 1) + _, S, F = x.shape # (Batch, Sequence, Features) + + q = self.query(x) + k = self.key(x) + + weight = q @ k.transpose(-1, -2) * F**-0.5 # shape: (B, S, S) + weight = weight.masked_fill(torch.tril(torch.ones(S, S, device=x.device)) == 0, float("-inf")) + weight = 
torch.nn.functional.softmax(weight, dim=-1) + + v = self.value(x) # shape: (B, S, F) + out = (weight @ v).transpose(0, 1) # shape: (S, B, F) + + return out + + +class MultiHead(torch.nn.Module): + def __init__(self, hidden_size: int, head_count: int): + super().__init__() + + assert hidden_size % head_count == 0, f"Multi-headed attention head size must be a multiple of the head count." + + self.heads = torch.nn.ModuleList([Head(hidden_size, hidden_size // head_count) for _ in range(head_count)]) + self.proj = torch.nn.Linear(hidden_size, hidden_size) + + def forward(self, x: torch.Tensor): + x = torch.cat([head(x) for head in self.heads], dim=-1) + out = self.proj(x) + + return out + + +class Block(torch.nn.Module): + def __init__(self, hidden_size: int, head_count: int): + super().__init__() + + self.sa = MultiHead(hidden_size, head_count) + self.ff = torch.nn.Sequential( + torch.nn.Linear(hidden_size, 4 * hidden_size), + torch.nn.ReLU(), + torch.nn.Linear(4 * hidden_size, hidden_size), + ) + self.ln1 = torch.nn.LayerNorm(hidden_size) + self.ln2 = torch.nn.LayerNorm(hidden_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.sa(self.ln1(x)) + x + out = self.ff(self.ln2(x)) + x + + return out + + +class Transformer(torch.nn.Module): + """A Transformer-based recurrent module. + + The Transformer module is a recurrent module that uses a Transformer architecture to process the input sequence. It + uses a hidden state to emulate RNN-like behavior. + """ + + def __init__( + self, input_size, output_size, hidden_size, block_count: int = 6, context_length: int = 64, head_count: int = 8 + ): + """ + Args: + input_size (int): The size of the input. + output_size (int): The size of the output. + hidden_size (int): The size of the hidden layers. + block_count (int): The number of Transformer blocks. + context_length (int): The length of the context to consider when predicting the next token. + head_count (int): The number of attention heads per block. + """ + + assert context_length % 2 == 0, f"Context length must be even." + + super().__init__() + + self.context_length = context_length + self.hidden_size = hidden_size + + self.feature_proj = torch.nn.Linear(input_size, hidden_size) + self.position_embedding = torch.nn.Embedding(context_length, hidden_size) + self.blocks = torch.nn.Sequential( + *[Block(hidden_size, head_count) for _ in range(block_count)], + torch.nn.LayerNorm(hidden_size), + ) + self.head = torch.nn.Linear(hidden_size, output_size) + + @property + def num_layers(self): + # Set num_layers to half the context length for simple torch.nn.LSTM compatibility. TODO: This is a bit hacky. + return self.context_length // 2 + + def step(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor: + """Computes Transformer output given the full context and the input. + + Args: + x (torch.Tensor): The input tensor of shape (Sequence, Batch, Features). + context (torch.Tensor): The context tensor of shape (Context Length, Batch, Features). + Returns: + A tuple of the output tensor of shape (Sequence, Batch, Features) and the updated context with the input + features appended. The context has shape (Context Length, Batch, Features). + """ + S = x.shape[0] + + # Project input to feature space and add to context. + ft_x = self.feature_proj(x) + context = torch.cat((context, ft_x), dim=0)[-self.context_length :] + + # Add positional embedding to context. 
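As a usage sketch of this module as a whole (illustrative sizes only; Transformer is the class defined in this file), it can be driven like an nn.LSTM replacement with an explicit hidden state:

import torch

tf = Transformer(input_size=32, output_size=16, hidden_size=64, block_count=2, context_length=8, head_count=4)
hidden = tf.reset_hidden_state(batch_size=5)  # two halves of a zeroed (8, 5, 64) rolling context
x = torch.randn(1, 5, 32)                     # (sequence, batch, features), like an nn.LSTM input
out, hidden = tf(x, hidden)                   # out: (1, 5, 16); hidden carries the updated context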
+ ft_pos = self.position_embedding(torch.arange(self.context_length, device=x.device)).unsqueeze(1) + x = context + ft_pos + + # Compute output from Transformer blocks. + x = self.blocks(x) + out = self.head(x)[-S:] + + return out, context + + def forward(self, x: torch.Tensor, hidden_state: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + """Computes Transformer output given the input and the hidden state which encapsulates the context.""" + if hidden_state is None: + hidden_state = self.reset_hidden_state(x.shape[1], device=x.device) + context = torch.cat(hidden_state, dim=0) + + out, context = self.step(x, context) + + hidden_state = context[: self.num_layers], context[self.num_layers :] + + return out, hidden_state + + def reset_hidden_state(self, batch_size: int, device="cpu"): + hidden_state = torch.zeros((self.context_length, batch_size, self.hidden_size), device=device) + hidden_state = hidden_state[: self.num_layers], hidden_state[self.num_layers :] + + return hidden_state diff --git a/rsl_rl/modules/utils.py b/rsl_rl/modules/utils.py new file mode 100644 index 0000000..48d2b30 --- /dev/null +++ b/rsl_rl/modules/utils.py @@ -0,0 +1,32 @@ +from torch import nn + + +def get_activation(act_name): + if act_name == "elu": + return nn.ELU() + elif act_name == "selu": + return nn.SELU() + elif act_name == "relu": + return nn.ReLU() + elif act_name == "crelu": + return nn.ReLU() + elif act_name == "lrelu": + return nn.LeakyReLU() + elif act_name == "tanh": + return nn.Tanh() + elif act_name == "sigmoid": + return nn.Sigmoid() + elif act_name == "linear": + return nn.Identity() + elif act_name == "softmax": + return nn.Softmax() + else: + print("invalid activation function!") + return None + + +def init_xavier_uniform(layer, activation): + try: + nn.init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain(activation)) + except ValueError: + nn.init.xavier_uniform_(layer.weight) diff --git a/rsl_rl/runners/__init__.py b/rsl_rl/runners/__init__.py index cebf22b..0097668 100644 --- a/rsl_rl/runners/__init__.py +++ b/rsl_rl/runners/__init__.py @@ -3,6 +3,7 @@ """Implementation of runners for environment-agent interaction.""" -from .on_policy_runner import OnPolicyRunner +from .runner import Runner +from .legacy_runner import LeggedGymRunner -__all__ = ["OnPolicyRunner"] +__all__ = ["LeggedGymRunner", "Runner"] diff --git a/rsl_rl/runners/callbacks.py b/rsl_rl/runners/callbacks.py new file mode 100644 index 0000000..df99d5a --- /dev/null +++ b/rsl_rl/runners/callbacks.py @@ -0,0 +1,79 @@ +import os +import random +import string +import wandb + + +def make_save_model_cb(directory): + def cb(runner, stat): + path = os.path.join(directory, "model_{}.pt".format(stat["current_iteration"])) + runner.save(path) + + return cb + + +def make_save_model_onnx_cb(directory): + def cb(runner, stat): + path = os.path.join(directory, "model_{}.onnx".format(stat["current_iteration"])) + runner.export_onnx(path) + + return cb + + +def make_interval_cb(callback, interval): + def cb(runner, stat): + if stat["current_iteration"] % interval != 0: + return + + callback(runner, stat) + + return cb + + +def make_final_cb(callback): + def cb(runner, stat): + if not runner._learning_should_terminate(): + return + + callback(runner, stat) + + return cb + + +def make_first_cb(callback): + uuid = "".join(random.choices(string.ascii_letters + string.digits, k=8)) + + def cb(runner, stat): + if hasattr(runner, f"_first_cb_{uuid}"): + return + + setattr(runner, f"_first_cb_{uuid}", True) + callback(runner, stat) + + 
return cb + + +def make_wandb_cb(init_kwargs): + assert "project" in init_kwargs, "The project must be specified in the init_kwargs." + + run = wandb.init(**init_kwargs) + check_complete = make_final_cb(lambda *_: run.finish()) + + def cb(runner, stat): + mean_reward = sum(stat["returns"]) / len(stat["returns"]) if len(stat["returns"]) > 0 else 0.0 + mean_steps = sum(stat["lengths"]) / len(stat["lengths"]) if len(stat["lengths"]) > 0 else 0.0 + total_steps = stat["current_iteration"] * runner.env.num_envs * runner._num_steps_per_env + training_time = stat["training_time"] + + run.log( + { + "mean_rewards": mean_reward, + "mean_steps": mean_steps, + "training_steps": total_steps, + "training_time": training_time, + } + ) + + check_complete(runner, stat) + + return cb diff --git a/rsl_rl/runners/legacy_runner.py b/rsl_rl/runners/legacy_runner.py new file mode 100644 index 0000000..4c98ab3 --- /dev/null +++ b/rsl_rl/runners/legacy_runner.py @@ -0,0 +1,136 @@ +import os + +from rsl_rl.algorithms import * +from rsl_rl.env import VecEnv +from rsl_rl.runners.callbacks import ( + make_final_cb, + make_first_cb, + make_interval_cb, + make_save_model_onnx_cb, +) +from rsl_rl.runners.runner import Runner +from rsl_rl.storage import * + + +def make_legacy_save_model_cb(directory): + def cb(runner, stat): + data = {} + + if hasattr(runner.env, "_persistent_data"): + data["env_data"] = runner.env._persistent_data + + path = os.path.join(directory, "model_{}.pt".format(stat["current_iteration"])) + runner.save(path, data=data) + + return cb + + +class LeggedGymRunner(Runner): + """Runner for legged_gym environments.""" + + mappings = [ + ("init_noise_std", "actor_noise_std"), + ("clip_param", "clip_ratio"), + ("desired_kl", "target_kl"), + ("entropy_coef", "entropy_coeff"), + ("lam", "gae_lambda"), + ("max_grad_norm", "gradient_clip"), + ("num_learning_epochs", None), + ("num_mini_batches", "batch_count"), + ("use_clipped_value_loss", None), + ("value_loss_coef", "value_coeff"), + ] + + @staticmethod + def _hook_env(env: VecEnv): + old_step = env.step + + def step_hook(*args, **kwargs): + result = old_step(*args, **kwargs) + + if len(result) == 4: + obs, rewards, dones, env_info = result + elif len(result) == 5: + obs, _, rewards, dones, env_info = result + else: + raise ValueError("Invalid number of return values from env.step().") + + return obs, rewards, dones.float(), env_info + + env.step = step_hook + + return env + + def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"): + env = self._hook_env(env) + self.cfg = train_cfg["runner"] + + alg_class = eval(self.cfg["algorithm_class_name"]) + if "policy_class_name" in self.cfg: + print("WARNING: ignoring deprecated parameter 'runner.policy_class_name'.") + + alg_cfg = train_cfg["algorithm"] + alg_cfg.update(train_cfg["policy"]) + + if "activation" in alg_cfg: + print( + "WARNING: using deprecated parameter 'activation'. Use 'actor_activations' and 'critic_activations' instead." 
+ ) + alg_cfg["actor_activations"] = [alg_cfg["activation"] for _ in range(len(alg_cfg["actor_hidden_dims"]))] + alg_cfg["actor_activations"] += ["linear"] + alg_cfg["critic_activations"] = [alg_cfg["activation"] for _ in range(len(alg_cfg["critic_hidden_dims"]))] + alg_cfg["critic_activations"] += ["linear"] + del alg_cfg["activation"] + + for old, new in self.mappings: + if old not in alg_cfg: + continue + + if new is None: + print(f"WARNING: ignoring deprecated parameter '{old}'.") + del alg_cfg[old] + continue + + print(f"WARNING: using deprecated parameter '{old}'. Use '{new}' instead.") + alg_cfg[new] = alg_cfg[old] + del alg_cfg[old] + + agent: Agent = alg_class(env, device=device, **train_cfg["algorithm"]) + + callbacks = [] + evaluation_callbacks = [] + + evaluation_callbacks.append(lambda *args: Runner._log_progress(*args, prefix="eval")) + + if log_dir and "save_interval" in self.cfg: + callbacks.append(make_first_cb(make_legacy_save_model_cb(log_dir))) + callbacks.append(make_interval_cb(make_legacy_save_model_cb(log_dir), self.cfg["save_interval"])) + + if log_dir: + callbacks.append(Runner._log) + callbacks.append(make_final_cb(make_legacy_save_model_cb(log_dir))) + callbacks.append(make_final_cb(make_save_model_onnx_cb(log_dir))) + # callbacks.append(make_first_cb(lambda *_: store_code_state(log_dir, self._git_status_repos))) + else: + callbacks.append(Runner._log_progress) + + super().__init__( + env, + agent, + learn_cb=callbacks, + evaluation_cb=evaluation_callbacks, + device=device, + num_steps_per_env=self.cfg["num_steps_per_env"], + ) + + self._iteration_time = 0.0 + + def learn(self, *args, num_learning_iterations=None, init_at_random_ep_len=None, **kwargs): + if num_learning_iterations is not None: + print("WARNING: using deprecated parameter 'num_learning_iterations'. 
Use 'iterations' instead.") + kwargs["iterations"] = num_learning_iterations + + if init_at_random_ep_len is not None: + print("WARNING: ignoring deprecated parameter 'init_at_random_ep_len'.") + + super().learn(*args, **kwargs) diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py deleted file mode 100644 index 9e0a459..0000000 --- a/rsl_rl/runners/on_policy_runner.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright 2021 ETH Zurich, NVIDIA CORPORATION -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import os -import statistics -import time -import torch -from collections import deque -from torch.utils.tensorboard import SummaryWriter as TensorboardSummaryWriter - -import rsl_rl -from rsl_rl.algorithms import PPO -from rsl_rl.env import VecEnv -from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, EmpiricalNormalization -from rsl_rl.utils import store_code_state - - -class OnPolicyRunner: - """On-policy runner for training and evaluation.""" - - def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"): - self.cfg = train_cfg - self.alg_cfg = train_cfg["algorithm"] - self.policy_cfg = train_cfg["policy"] - self.device = device - self.env = env - obs, extras = self.env.get_observations() - num_obs = obs.shape[1] - if "critic" in extras["observations"]: - num_critic_obs = extras["observations"]["critic"].shape[1] - else: - num_critic_obs = num_obs - actor_critic_class = eval(self.policy_cfg.pop("class_name")) # ActorCritic - actor_critic: ActorCritic | ActorCriticRecurrent = actor_critic_class( - num_obs, num_critic_obs, self.env.num_actions, **self.policy_cfg - ).to(self.device) - alg_class = eval(self.alg_cfg.pop("class_name")) # PPO - self.alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg) - self.num_steps_per_env = self.cfg["num_steps_per_env"] - self.save_interval = self.cfg["save_interval"] - self.empirical_normalization = self.cfg["empirical_normalization"] - if self.empirical_normalization: - self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device) - self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device) - else: - self.obs_normalizer = torch.nn.Identity() # no normalization - self.critic_obs_normalizer = torch.nn.Identity() # no normalization - # init storage and model - self.alg.init_storage( - self.env.num_envs, - self.num_steps_per_env, - [num_obs], - [num_critic_obs], - [self.env.num_actions], - ) - - # Log - self.log_dir = log_dir - self.writer = None - self.tot_timesteps = 0 - self.tot_time = 0 - self.current_learning_iteration = 0 - self.git_status_repos = [rsl_rl.__file__] - - def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False): - # initialize writer - if self.log_dir is not None and self.writer is None: - # Launch either Tensorboard or Neptune & Tensorboard summary writer(s), default: Tensorboard. 
- self.logger_type = self.cfg.get("logger", "tensorboard") - self.logger_type = self.logger_type.lower() - - if self.logger_type == "neptune": - from rsl_rl.utils.neptune_utils import NeptuneSummaryWriter - - self.writer = NeptuneSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg) - self.writer.log_config(self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg) - elif self.logger_type == "wandb": - from rsl_rl.utils.wandb_utils import WandbSummaryWriter - - self.writer = WandbSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg) - self.writer.log_config(self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg) - elif self.logger_type == "tensorboard": - self.writer = TensorboardSummaryWriter(log_dir=self.log_dir, flush_secs=10) - else: - raise AssertionError("logger type not found") - - if init_at_random_ep_len: - self.env.episode_length_buf = torch.randint_like( - self.env.episode_length_buf, high=int(self.env.max_episode_length) - ) - obs, extras = self.env.get_observations() - critic_obs = extras["observations"].get("critic", obs) - obs, critic_obs = obs.to(self.device), critic_obs.to(self.device) - self.train_mode() # switch to train mode (for dropout for example) - - ep_infos = [] - rewbuffer = deque(maxlen=100) - lenbuffer = deque(maxlen=100) - cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device) - cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device) - - start_iter = self.current_learning_iteration - tot_iter = start_iter + num_learning_iterations - for it in range(start_iter, tot_iter): - start = time.time() - # Rollout - with torch.inference_mode(): - for i in range(self.num_steps_per_env): - actions = self.alg.act(obs, critic_obs) - obs, rewards, dones, infos = self.env.step(actions) - obs = self.obs_normalizer(obs) - if "critic" in infos["observations"]: - critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"]) - else: - critic_obs = obs - obs, critic_obs, rewards, dones = ( - obs.to(self.device), - critic_obs.to(self.device), - rewards.to(self.device), - dones.to(self.device), - ) - self.alg.process_env_step(rewards, dones, infos) - - if self.log_dir is not None: - # Book keeping - # note: we changed logging to use "log" instead of "episode" to avoid confusion with - # different types of logging data (rewards, curriculum, etc.) 
- if "episode" in infos: - ep_infos.append(infos["episode"]) - elif "log" in infos: - ep_infos.append(infos["log"]) - cur_reward_sum += rewards - cur_episode_length += 1 - new_ids = (dones > 0).nonzero(as_tuple=False) - rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist()) - lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist()) - cur_reward_sum[new_ids] = 0 - cur_episode_length[new_ids] = 0 - - stop = time.time() - collection_time = stop - start - - # Learning step - start = stop - self.alg.compute_returns(critic_obs) - - mean_value_loss, mean_surrogate_loss = self.alg.update() - stop = time.time() - learn_time = stop - start - self.current_learning_iteration = it - if self.log_dir is not None: - self.log(locals()) - if it % self.save_interval == 0: - self.save(os.path.join(self.log_dir, f"model_{it}.pt")) - ep_infos.clear() - if it == start_iter: - # obtain all the diff files - git_file_paths = store_code_state(self.log_dir, self.git_status_repos) - # if possible store them to wandb - if self.logger_type in ["wandb", "neptune"] and git_file_paths: - for path in git_file_paths: - self.writer.save_file(path) - - self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt")) - - def log(self, locs: dict, width: int = 80, pad: int = 35): - self.tot_timesteps += self.num_steps_per_env * self.env.num_envs - self.tot_time += locs["collection_time"] + locs["learn_time"] - iteration_time = locs["collection_time"] + locs["learn_time"] - - ep_string = "" - if locs["ep_infos"]: - for key in locs["ep_infos"][0]: - infotensor = torch.tensor([], device=self.device) - for ep_info in locs["ep_infos"]: - # handle scalar and zero dimensional tensor infos - if key not in ep_info: - continue - if not isinstance(ep_info[key], torch.Tensor): - ep_info[key] = torch.Tensor([ep_info[key]]) - if len(ep_info[key].shape) == 0: - ep_info[key] = ep_info[key].unsqueeze(0) - infotensor = torch.cat((infotensor, ep_info[key].to(self.device))) - value = torch.mean(infotensor) - # log to logger and terminal - if "/" in key: - self.writer.add_scalar(key, value, locs["it"]) - ep_string += f"""{f'{key}:':>{pad}} {value:.4f}\n""" - else: - self.writer.add_scalar("Episode/" + key, value, locs["it"]) - ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n""" - mean_std = self.alg.actor_critic.std.mean() - fps = int(self.num_steps_per_env * self.env.num_envs / (locs["collection_time"] + locs["learn_time"])) - - self.writer.add_scalar("Loss/value_function", locs["mean_value_loss"], locs["it"]) - self.writer.add_scalar("Loss/surrogate", locs["mean_surrogate_loss"], locs["it"]) - self.writer.add_scalar("Loss/learning_rate", self.alg.learning_rate, locs["it"]) - self.writer.add_scalar("Policy/mean_noise_std", mean_std.item(), locs["it"]) - self.writer.add_scalar("Perf/total_fps", fps, locs["it"]) - self.writer.add_scalar("Perf/collection time", locs["collection_time"], locs["it"]) - self.writer.add_scalar("Perf/learning_time", locs["learn_time"], locs["it"]) - if len(locs["rewbuffer"]) > 0: - self.writer.add_scalar("Train/mean_reward", statistics.mean(locs["rewbuffer"]), locs["it"]) - self.writer.add_scalar("Train/mean_episode_length", statistics.mean(locs["lenbuffer"]), locs["it"]) - if self.logger_type != "wandb": # wandb does not support non-integer x-axis logging - self.writer.add_scalar("Train/mean_reward/time", statistics.mean(locs["rewbuffer"]), self.tot_time) - self.writer.add_scalar( - "Train/mean_episode_length/time", statistics.mean(locs["lenbuffer"]), 
self.tot_time - ) - - str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m " - - if len(locs["rewbuffer"]) > 0: - log_string = ( - f"""{'#' * width}\n""" - f"""{str.center(width, ' ')}\n\n""" - f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[ - 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n""" - f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n""" - f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n""" - f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""" - f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n""" - f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""" - ) - # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n""" - # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""") - else: - log_string = ( - f"""{'#' * width}\n""" - f"""{str.center(width, ' ')}\n\n""" - f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[ - 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n""" - f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n""" - f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n""" - f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""" - ) - # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n""" - # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""") - - log_string += ep_string - log_string += ( - f"""{'-' * width}\n""" - f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n""" - f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n""" - f"""{'Total time:':>{pad}} {self.tot_time:.2f}s\n""" - f"""{'ETA:':>{pad}} {self.tot_time / (locs['it'] + 1) * ( - locs['num_learning_iterations'] - locs['it']):.1f}s\n""" - ) - print(log_string) - - def save(self, path, infos=None): - saved_dict = { - "model_state_dict": self.alg.actor_critic.state_dict(), - "optimizer_state_dict": self.alg.optimizer.state_dict(), - "iter": self.current_learning_iteration, - "infos": infos, - } - if self.empirical_normalization: - saved_dict["obs_norm_state_dict"] = self.obs_normalizer.state_dict() - saved_dict["critic_obs_norm_state_dict"] = self.critic_obs_normalizer.state_dict() - torch.save(saved_dict, path) - - # Upload model to external logging service - if self.logger_type in ["neptune", "wandb"]: - self.writer.save_model(path, self.current_learning_iteration) - - def load(self, path, load_optimizer=True): - loaded_dict = torch.load(path) - self.alg.actor_critic.load_state_dict(loaded_dict["model_state_dict"]) - if self.empirical_normalization: - self.obs_normalizer.load_state_dict(loaded_dict["obs_norm_state_dict"]) - self.critic_obs_normalizer.load_state_dict(loaded_dict["critic_obs_norm_state_dict"]) - if load_optimizer: - self.alg.optimizer.load_state_dict(loaded_dict["optimizer_state_dict"]) - self.current_learning_iteration = loaded_dict["iter"] - return loaded_dict["infos"] - - def get_inference_policy(self, device=None): - self.eval_mode() # switch to evaluation mode (dropout for example) - if device is not None: - self.alg.actor_critic.to(device) - policy = self.alg.actor_critic.act_inference - if self.cfg["empirical_normalization"]: - if device is not None: - self.obs_normalizer.to(device) - policy = lambda x: self.alg.actor_critic.act_inference(self.obs_normalizer(x)) # noqa: E731 - return policy - - def train_mode(self): - self.alg.actor_critic.train() - if self.empirical_normalization: - 
self.obs_normalizer.train()
-            self.critic_obs_normalizer.train()
-
-    def eval_mode(self):
-        self.alg.actor_critic.eval()
-        if self.empirical_normalization:
-            self.obs_normalizer.eval()
-            self.critic_obs_normalizer.eval()
-
-    def add_git_repo_to_log(self, repo_file_path):
-        self.git_status_repos.append(repo_file_path)
diff --git a/rsl_rl/runners/runner.py b/rsl_rl/runners/runner.py
new file mode 100644
index 0000000..b461a88
--- /dev/null
+++ b/rsl_rl/runners/runner.py
@@ -0,0 +1,498 @@
+from __future__ import annotations
+import copy
+from datetime import timedelta
+import numpy as np
+import os
+import time
+import torch
+from typing import Any, Callable, Dict, List, Tuple, TypedDict, Union
+
+import rsl_rl
+from rsl_rl.storage.storage import Dataset
+from rsl_rl.algorithms import Agent
+from rsl_rl.env import VecEnv
+
+
+class EpisodeStatistics(TypedDict):
+    """The statistics of an episode."""
+
+    # Time it took to collect samples for the current iteration.
+    collection_time: Union[int, None]
+    # The counter of the current iteration.
+    current_iteration: int
+    # The number of the final iteration of the current run.
+    final_iteration: int
+    # The number of the first iteration of the current run.
+    first_iteration: int
+    # Environment information about the current iteration.
+    info: list
+    # The lengths of the episodes.
+    lengths: Union[List[int], None]
+    # The loss of the current iteration.
+    loss: Union[dict, None]
+    # The returns of the episodes.
+    returns: Union[List[float], None]
+    # The total time it took to run the current iteration.
+    total_time: Union[int, None]
+    # The time it took to update the agent.
+    update_time: Union[int, None]
+
+
+Callback = Callable[[EpisodeStatistics], None]
+
+
+class Runner:
+    """The runner class for running an agent in an environment.
+
+    This class runs an agent in an environment: it collects data from the environment, updates the agent, and
+    evaluates the agent. It also provides a number of callbacks that can be used to log and visualize the training
+    progress.
+    """
+
+    _dataset: Dataset
+    _episode_statistics: EpisodeStatistics
+    _num_steps_per_env: int
+
+    def __init__(
+        self,
+        environment: VecEnv,
+        agent: Agent,
+        device: str = "cpu",
+        evaluation_cb: List[Callback] = None,
+        learn_cb: List[Callback] = None,
+        observation_history_length: int = 1,
+        **kwargs,
+    ) -> None:
+        """
+        Args:
+            environment (rsl_rl.env.VecEnv): The environment to run the agent in.
+            agent (rsl_rl.algorithms.Agent): The RL agent to run.
+            device (str): The device to run on.
+            evaluation_cb (List[Callable[[dict], None]], optional): A list of callbacks that are called after each
+                round of evaluation.
+            learn_cb (List[Callable[[dict], None]], optional): A list of callbacks that are called after each round of
+                learning.
+            observation_history_length (int): The number of observations to concatenate into a single observation.
+        """
+        self.env = environment
+        self.agent = agent
+        self.device = device
+        self._obs_hist_len = observation_history_length
+        self._learn_cb = learn_cb if learn_cb else []
+        self._eval_cb = evaluation_cb if evaluation_cb else []
+
+        self._set_kwarg(kwargs, "num_steps_per_env", default=1)
+
+        self._current_learning_iteration = 0
+        self._git_status_repos = [rsl_rl.__file__]
+
+        self.to(self.device)
+
+        self._stored_dataset = []  # For computing observation history over multiple steps.
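+
+        # Illustrative usage sketch (comment only, mirrors tests/test_algorithms.py); any VecEnv/Agent pair works
+        # the same way as the bundled GymEnv and PPO used here:
+        #
+        #   env = GymEnv(name="LunarLanderContinuous-v2", device="cpu", environment_count=4)
+        #   agent = PPO(env, device="cpu")
+        #   runner = Runner(env, agent, device="cpu", num_steps_per_env=6)
+        #   runner.learn(iterations=10)          # train for 10 iterations
+        #   mean_return = runner.evaluate(1000)  # roll out the trained policy for 1000 steps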
+ + def add_git_repo_to_log(self, repo_file_path): + self._git_status_repos.append(repo_file_path) + + def eval_mode(self): + """Sets the agent to evaluation mode.""" + self.agent.eval_mode() + + def evaluate(self, steps: int, return_epochs: int = 100) -> float: + """Evaluates the agent for a number of steps. + + Args: + steps (int): The number of steps to evaluate the agent for. + return_epochs (int): The number of epochs over which to aggregate the return. Defaults to 100. + Returns: + The mean return of the agent. + """ + cumulative_rewards = [] + current_cumulative_rewards = torch.zeros(self.env.num_envs, dtype=torch.float) + current_episode_lengths = torch.zeros(self.env.num_envs, dtype=torch.int) + episode_lengths = [] + + self.eval_mode() + + policy = self.get_inference_policy() + obs, env_info = self.env.get_observations() + + with torch.inference_mode(): + for step in range(steps): + actions = policy(obs.clone(), copy.deepcopy(env_info)) + obs, rewards, dones, env_info, episode_statistics = self.evaluate_step(obs, env_info, actions) + + dones_idx = dones.nonzero().cpu() + current_cumulative_rewards += rewards.clone().cpu() + current_episode_lengths += 1 + cumulative_rewards.extend(current_cumulative_rewards[dones_idx].squeeze(1).cpu().tolist()) + episode_lengths.extend(current_episode_lengths[dones_idx].squeeze(1).cpu().tolist()) + current_cumulative_rewards[dones_idx] = 0.0 + current_episode_lengths[dones_idx] = 0 + + episode_statistics["current_iteration"] = step + episode_statistics["final_iteration"] = steps + episode_statistics["lengths"] = episode_lengths[-return_epochs:] + episode_statistics["returns"] = cumulative_rewards[-return_epochs:] + + for cb in self._eval_cb: + cb(self, episode_statistics) + + cumulative_rewards.extend(current_cumulative_rewards.cpu().tolist()) + mean_return = np.mean(cumulative_rewards) + + return mean_return + + def evaluate_step( + self, observations=None, environment_info=None, actions=None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict, Dict]: + """Evaluates the agent for a single step. + + Args: + observations (torch.Tensor): The observations to evaluate the agent for. + environment_info (Dict[str, Any]): The environment information for the observations. + actions (torch.Tensor): The actions to evaluate the agent for. + Returns: + A tuple containing the observations, rewards, dones, environment information, and episode statistics after + the evaluation step. 
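+
+        Example (illustrative; assumes `runner` is a Runner wrapping an already initialized agent):
+            >>> obs, rewards, dones, env_info, stats = runner.evaluate_step()
+            >>> stats["total_time"]  # wall-clock seconds spent on this step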
+ """ + episode_statistics = { + "current_actions": None, + "current_dones": None, + "current_iteration": 0, + "current_observations": None, + "current_rewards": None, + "final_iteration": 0, + "first_iteration": 0, + "info": [], + "lengths": [], + "returns": [], + "timeout": None, + "total_time": None, + } + + self.eval_mode() + + with torch.inference_mode(): + obs, env_info = self.env.get_observations() if observations is None else (observations, environment_info) + + with torch.inference_mode(): + start = time.time() + + actions = self.get_inference_policy()(obs.clone(), copy.deepcopy(env_info)) if actions is None else actions + obs, rewards, dones, env_info = self.env.step(actions.clone()) + + self.agent.register_terminations(dones.nonzero().reshape(-1)) + + end = time.time() + + if "episode" in env_info: + episode_statistics["info"].append(env_info["episode"]) + episode_statistics["current_actions"] = actions + episode_statistics["current_dones"] = dones + episode_statistics["current_observations"] = obs + episode_statistics["current_rewards"] = rewards + episode_statistics["total_time"] = end - start + + return obs, rewards, dones, env_info, episode_statistics + + def get_inference_policy(self, device=None): + self.eval_mode() + + return self.agent.get_inference_policy(device) + + def learn( + self, iterations: Union[int, None] = None, timeout: Union[int, None] = None, return_epochs: int = 100 + ) -> None: + """Runs a number of learning iterations. + + Args: + iterations (int): The number of iterations to run. + timeout (int): Optional number of seconds after which to terminate training. Defaults to None. + return_epochs (int): The number of epochs over which to aggregate the return. Defaults to 100. + """ + assert iterations is not None or timeout is not None + + self._episode_statistics = { + "collection_time": None, + "current_actions": None, + "current_iteration": self._current_learning_iteration, + "current_observations": None, + "final_iteration": self._current_learning_iteration + iterations if iterations is not None else None, + "first_iteration": self._current_learning_iteration, + "info": [], + "lengths": [], + "loss": {}, + "returns": [], + "storage_initialized": False, + "timeout": timeout, + "total_time": None, + "training_time": 0, + "update_time": None, + } + self._current_episode_lengths = torch.zeros(self.env.num_envs, dtype=torch.float) + self._current_cumulative_rewards = torch.zeros(self.env.num_envs, dtype=torch.float) + + self.train_mode() + + self._obs, self._env_info = self.env.get_observations() + while True: + if self._learning_should_terminate(): + break + + # Collect data + start = time.time() + + with torch.inference_mode(): + self._dataset = [] + + for _ in range(self._num_steps_per_env): + self._collect() + + self._episode_statistics["lengths"] = self._episode_statistics["lengths"][-return_epochs:] + self._episode_statistics["returns"] = self._episode_statistics["returns"][-return_epochs:] + + self._episode_statistics["collection_time"] = time.time() - start + + # Update agent + + start = time.time() + + self._update() + + self._episode_statistics["update_time"] = time.time() - start + + # Housekeeping + + self._episode_statistics["total_time"] = ( + self._episode_statistics["collection_time"] + self._episode_statistics["update_time"] + ) + self._episode_statistics["training_time"] += self._episode_statistics["total_time"] + + if self.agent.initialized: + self._episode_statistics["current_iteration"] += 1 + + terminate = False + for cb in 
self._learn_cb: + terminate = (cb(self, self._episode_statistics) == False) or terminate + + if terminate: + break + + self._episode_statistics["info"].clear() + self._current_learning_iteration = self._episode_statistics["current_iteration"] + + def _collect(self) -> None: + """Runs a single step in the environment to collect a transition and stores it in the dataset. + + This method runs a single step in the environment to collect a transition and stores it in the dataset. If the + agent is not initialized, random actions are drawn from the action space. Furthermore, the method gathers + statistics about the episode and stores them in the episode statistics dictionary of the runner. + """ + if self.agent.initialized: + actions, data = self.agent.draw_actions(self._obs, self._env_info) + else: + actions, data = self.agent.draw_random_actions(self._obs, self._env_info) + + next_obs, rewards, dones, next_env_info = self.env.step(actions) + + self._dataset.append( + self.agent.process_transition( + self._obs.clone(), + copy.deepcopy(self._env_info), + actions.clone(), + rewards.clone(), + next_obs.clone(), + copy.deepcopy(next_env_info), + dones.clone(), + copy.deepcopy(data), + ) + ) + + self.agent.register_terminations(dones.nonzero().reshape(-1)) + + self._obs, self._env_info = next_obs, next_env_info + + # Gather statistics + if "episode" in self._env_info: + self._episode_statistics["info"].append(self._env_info["episode"]) + dones_idx = (dones + next_env_info["time_outs"]).nonzero().cpu() + self._current_episode_lengths += 1 + self._current_cumulative_rewards += rewards.cpu() + + completed_lengths = self._current_episode_lengths[dones_idx][:, 0].cpu() + completed_returns = self._current_cumulative_rewards[dones_idx][:, 0].cpu() + self._episode_statistics["lengths"].extend(completed_lengths.tolist()) + self._episode_statistics["returns"].extend(completed_returns.tolist()) + self._current_episode_lengths[dones_idx] = 0.0 + self._current_cumulative_rewards[dones_idx] = 0.0 + + self._episode_statistics["current_actions"] = actions + self._episode_statistics["current_observations"] = self._obs + self._episode_statistics["sample_count"] = self.agent.storage.sample_count + + def _learning_should_terminate(self): + """Checks whether the learning should terminate. + + Termination is triggered if the number of iterations or the timeout is reached. + + Returns: + Whether the learning should terminate. 
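+
+        Note:
+            The iteration counter only advances once the agent reports itself as initialized (see `learn`), so any
+            initial random-exploration steps taken before that do not count towards the iteration budget.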
+ """ + if ( + self._episode_statistics["final_iteration"] is not None + and self._episode_statistics["current_iteration"] >= self._episode_statistics["final_iteration"] + ): + return True + + if ( + self._episode_statistics["timeout"] is not None + and self._episode_statistics["training_time"] >= self._episode_statistics["timeout"] + ): + return True + + return False + + def _update(self) -> None: + """Updates the agent using the collected data.""" + loss = self.agent.update(self._dataset) + self._dataset = [] + + if not self.agent.initialized: + return + + self._episode_statistics["loss"] = loss + self._episode_statistics["storage_initialized"] = True + + def load(self, path: str) -> Any: + """Restores the agent and runner state from a file.""" + content = torch.load(path) + + assert "agent" in content + assert "data" in content + assert "iteration" in content + + self.agent.load_state_dict(content["agent"]) + self._current_learning_iteration = content["iteration"] + + return content["data"] + + def save(self, path: str, data: Any = None) -> None: + """Saves the agent and runner state to a file.""" + content = { + "agent": self.agent.state_dict(), + "data": data, + "iteration": self._current_learning_iteration, + } + + os.makedirs(os.path.dirname(path), exist_ok=True) + torch.save(content, path) + + def export_onnx(self, path: str) -> None: + """Exports the agent's policy network to ONNX format.""" + model, args, kwargs = self.agent.export_onnx() + + os.makedirs(os.path.dirname(path), exist_ok=True) + torch.onnx.export(model, args, path, **kwargs) + + def to(self, device) -> Runner: + """Sets the device of the runner and its components.""" + self.device = device + + self.agent.to(device) + + try: + self.env.to(device) + except AttributeError: + pass + + return self + + def train_mode(self): + """Sets the agent to training mode.""" + self.agent.train_mode() + + def _set_kwarg(self, args, key, default=None, private=True): + setattr(self, f"_{key}" if private else key, args[key] if key in args else default) + + def _log_progress(self, stat, clear_line=True, prefix=""): + """Logs the progress of the runner.""" + if not hasattr(self, "_iteration_times"): + self._iteration_times = [] + + self._iteration_times = (self._iteration_times + [stat["total_time"]])[-100:] + average_total_time = np.mean(self._iteration_times) + + if stat["final_iteration"] is not None: + first_iteration = stat["first_iteration"] + final_iteration = stat["final_iteration"] + current_iteration = stat["current_iteration"] + final_run_iteration = final_iteration - first_iteration + remaining_iterations = final_iteration - current_iteration + + remaining_iteration_time = remaining_iterations * average_total_time + iteration_completion_percentage = 100 * (current_iteration - first_iteration) / final_run_iteration + else: + remaining_iteration_time = np.inf + iteration_completion_percentage = 0 + + if stat["timeout"] is not None: + training_time = stat["training_time"] + timeout = stat["timeout"] + + remaining_timeout_time = stat["timeout"] - stat["training_time"] + timeout_completion_percentage = 100 * stat["training_time"] / stat["timeout"] + else: + remaining_timeout_time = np.inf + timeout_completion_percentage = 0 + + if remaining_iteration_time > remaining_timeout_time: + completion_percentage = timeout_completion_percentage + remaining_time = remaining_timeout_time + step_string = f"({int(training_time)}s / {timeout}s)" + else: + completion_percentage = iteration_completion_percentage + remaining_time = 
remaining_iteration_time + step_string = f"({current_iteration} / {final_iteration})" + + prefix = f"[{prefix}] " if prefix else "" + progress = "".join(["#" if i <= int(completion_percentage) else "_" for i in range(10, 101, 5)]) + remaining_time_string = str(timedelta(seconds=int(np.ceil(remaining_time)))) + print( + f"{prefix}{progress} {step_string} [{completion_percentage:.1f}%, {1/average_total_time:.2f}it/s, {remaining_time_string} ETA]", + end="\r" if clear_line else "\n", + ) + + def _log(self, stat, prefix=""): + """Logs the progress and statistics of the runner.""" + current_iteration = stat["current_iteration"] + + collection_time = stat["collection_time"] + update_time = stat["update_time"] + total_time = stat["total_time"] + collection_percentage = 100 * collection_time / total_time + update_percentage = 100 * update_time / total_time + + if prefix == "": + prefix = "learn" if stat["storage_initialized"] else "init" + self._log_progress(stat, clear_line=False, prefix=prefix) + print( + f"iteration time:\t{total_time:.4f}s (collection: {collection_time:.2f}s [{collection_percentage:.1f}%], update: {update_time:.2f}s [{update_percentage:.1f}%])" + ) + + mean_reward = sum(stat["returns"]) / len(stat["returns"]) if len(stat["returns"]) > 0 else 0.0 + mean_steps = sum(stat["lengths"]) / len(stat["lengths"]) if len(stat["lengths"]) > 0 else 0.0 + total_steps = current_iteration * self.env.num_envs * self._num_steps_per_env + sample_count = stat["sample_count"] + print(f"avg. reward:\t{mean_reward:.4f}") + print(f"avg. steps:\t{mean_steps:.4f}") + print(f"stored samples:\t{sample_count}") + print(f"total steps:\t{total_steps}") + + for key, value in stat["loss"].items(): + print(f"{key} loss:\t{value:.4f}") + + for key, value in self.agent._bm_report().items(): + mean, count = value + print(f"BM {key}:\t{mean/1000000.0:.4f}ms ({count} calls, total {mean*count/1000000.0:.4f}ms)") + + self.agent._bm_flush() diff --git a/rsl_rl/storage/__init__.py b/rsl_rl/storage/__init__.py index 91032a0..e22f507 100644 --- a/rsl_rl/storage/__init__.py +++ b/rsl_rl/storage/__init__.py @@ -4,5 +4,6 @@ """Implementation of transitions storage for RL-agent.""" from .rollout_storage import RolloutStorage +from .replay_storage import ReplayStorage -__all__ = ["RolloutStorage"] +__all__ = ["RolloutStorage", "ReplayStorage"] diff --git a/rsl_rl/storage/replay_storage.py b/rsl_rl/storage/replay_storage.py new file mode 100644 index 0000000..955f31f --- /dev/null +++ b/rsl_rl/storage/replay_storage.py @@ -0,0 +1,147 @@ +import torch +from typing import Callable, Dict, Generator, Tuple, Optional + +from rsl_rl.storage.storage import Dataset, Storage, Transition + + +class ReplayStorage(Storage): + def __init__(self, environment_count: int, max_size: int, device: str = "cpu", initial_size: int = 0) -> None: + self._env_count = environment_count + self.initial_size = initial_size // environment_count + self.max_size = max_size + self.device = device + + self._register_serializable("max_size", "initial_size") + + self._idx = 0 + self._full = False + self._initialized = initial_size == 0 + self._data = {} + + self._processors: Dict[Tuple[Callable, Callable]] = {} + + @property + def max_size(self): + return self._size * self._env_count + + @max_size.setter + def max_size(self, value): + self._size = value // self._env_count + + assert self.initial_size <= self._size + + def _add_item(self, name: str, value: torch.Tensor) -> None: + """Adds a transition item to the storage. 
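+
+        The value is first passed through any processors registered for `name` (see `register_processor`) and then
+        written at the current ring-buffer index; the underlying buffer for `name` is allocated lazily on first use.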
+ + Args: + name (str): The name of the item. + value (torch.Tensor): The value of the item. + """ + value = self._process(name, value.clone().to(self.device)) + + if name not in self._data: + if self._full or self._idx != 0: + raise ValueError(f'Tried to store invalid transition data for "{name}".') + self._data[name] = torch.empty( + self._size * self._env_count, *value.shape[1:], device=self.device, dtype=value.dtype + ) + + start_idx = self._idx * self._env_count + end_idx = (self._idx + 1) * self._env_count + self._data[name][start_idx:end_idx] = value + + def _process(self, name: str, value: torch.Tensor) -> torch.Tensor: + if name not in self._processors: + return value + + for process, _ in self._processors[name]: + if process is None: + continue + + value = process(value) + + return value + + def _process_undo(self, name: str, value: torch.Tensor) -> torch.Tensor: + if name not in self._processors: + return value + + for _, undo in reversed(self._processors[name]): + if undo is None: + continue + + value = undo(value) + + return value + + def append(self, dataset: Dataset) -> None: + """Appends a dataset of transitions to the storage. + + Args: + dataset (Dataset): The dataset of transitions. + """ + for transition in dataset: + for name, value in transition.items(): + self._add_item(name, value) + + self._idx += 1 + + if self._idx >= self.initial_size: + self._initialized = True + + if self._idx == self._size: + self._full = True + self._idx = 0 + + def batch_generator(self, batch_size: int, batch_count: int) -> Generator[Transition, None, None]: + """Returns a generator that yields batches of transitions. + + Args: + batch_size (int): The size of the batches. + batch_count (int): The number of batches to yield. + Returns: + A generator that yields batches of transitions. + """ + assert self._full or self._idx > 0 + + if not self._initialized: + return + + max_idx = self._env_count * (self._size if self._full else self._idx) + + for _ in range(batch_count): + batch_idx = torch.randint(high=max_idx, size=(batch_size,)) + + batch = {} + for key, value in self._data.items(): + batch[key] = self._process_undo(key, value[batch_idx].clone()) + + yield batch + + def register_processor(self, key: str, process: Callable, undo: Optional[Callable] = None) -> None: + """Registers a processor for a transition item. + + The processor is called before the item is stored in the storage. The undo function is called when the item is + retrieved from the storage. The undo function is called in reverse order of the processors so that the order of + the processors does not matter. + + Args: + key (str): The name of the transition item. + process (Callable): The function to process the item. + undo (Optional[Callable], optional): The function to undo the processing. Defaults to None. 
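+
+        Example (illustrative sketch; `normalize` and `denormalize` are hypothetical inverse callables):
+            >>> storage.register_processor("observations", normalize, undo=denormalize)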
+ """ + if key not in self._processors: + self._processors[key] = [] + + self._processors[key].append((process, undo)) + + @property + def initialized(self) -> bool: + return self._initialized + + @property + def sample_count(self) -> int: + """Returns the number of individual transitions stored in the storage.""" + transition_count = self._size * self._env_count if self._full else self._idx * self._env_count + + return transition_count diff --git a/rsl_rl/storage/rollout_storage.py b/rsl_rl/storage/rollout_storage.py index 7aec039..34871a8 100644 --- a/rsl_rl/storage/rollout_storage.py +++ b/rsl_rl/storage/rollout_storage.py @@ -1,228 +1,71 @@ -# Copyright 2021 ETH Zurich, NVIDIA CORPORATION -# SPDX-License-Identifier: BSD-3-Clause - - -from __future__ import annotations - import torch +from typing import Generator -from rsl_rl.utils import split_and_pad_trajectories +from rsl_rl.storage.replay_storage import ReplayStorage +from rsl_rl.storage.storage import Dataset, Transition -class RolloutStorage: - class Transition: - def __init__(self): - self.observations = None - self.critic_observations = None - self.actions = None - self.rewards = None - self.dones = None - self.values = None - self.actions_log_prob = None - self.action_mean = None - self.action_sigma = None - self.hidden_states = None +class RolloutStorage(ReplayStorage): + """Implementation of rollout storage for RL-agent.""" - def clear(self): - self.__init__() + def __init__(self, environment_count: int, device: str = "cpu"): + """ + Args: + environment_count (int): Number of environments. + device (str, optional): Device to use. Defaults to "cpu". + """ + super().__init__(environment_count, environment_count, device=device, initial_size=0) - def __init__(self, num_envs, num_transitions_per_env, obs_shape, privileged_obs_shape, actions_shape, device="cpu"): - self.device = device + self._size_initialized = False - self.obs_shape = obs_shape - self.privileged_obs_shape = privileged_obs_shape - self.actions_shape = actions_shape + def append(self, dataset: Dataset) -> None: + """Appends a dataset to the rollout storage. - # Core - self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device) - if privileged_obs_shape[0] is not None: - self.privileged_observations = torch.zeros( - num_transitions_per_env, num_envs, *privileged_obs_shape, device=self.device - ) + Args: + dataset (Dataset): Dataset to append. + Raises: + AssertionError: If the dataset is not of the correct size. + """ + assert self._idx == 0 + + if not self._size_initialized: + self.max_size = len(dataset) * self._env_count + + assert len(dataset) == self._size + + super().append(dataset) + + def batch_generator(self, batch_count: int, trajectories: bool = False) -> Generator[Transition, None, None]: + """Yields batches of transitions or trajectories. + + Args: + batch_count (int): Number of batches to yield. + trajectories (bool, optional): Whether to yield batches of trajectories. Defaults to False. + Raises: + AssertionError: If the rollout storage is not full. + Returns: + Generator yielding batches of transitions of shape (batch_size, *shape). If trajectories is True, yields + batches of trajectories of shape (env_count, steps_per_env, *shape). + """ + assert self._full and self._initialized, "Rollout storage must be full and initialized." 
+ + total_size = self._env_count if trajectories else self._size * self._env_count + batch_size = total_size // batch_count + indices = torch.randperm(total_size) + + assert batch_size > 0, "Batch count is too large." + + if trajectories: + # Reshape to (env_count, steps_per_env, *shape) + data = {k: v.reshape(-1, self._env_count, *v.shape[1:]).transpose(0, 1) for k, v in self._data.items()} else: - self.privileged_observations = None - self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) - self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) - self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte() + data = self._data - # For PPO - self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) - self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) - self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) - self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) - self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) - self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) + for i in range(batch_count): + batch_idx = indices[i * batch_size : (i + 1) * batch_size].detach().to(self.device) - self.num_transitions_per_env = num_transitions_per_env - self.num_envs = num_envs + batch = {} + for key, value in data.items(): + batch[key] = self._process_undo(key, value[batch_idx].clone()) - # rnn - self.saved_hidden_states_a = None - self.saved_hidden_states_c = None - - self.step = 0 - - def add_transitions(self, transition: Transition): - if self.step >= self.num_transitions_per_env: - raise AssertionError("Rollout buffer overflow") - self.observations[self.step].copy_(transition.observations) - if self.privileged_observations is not None: - self.privileged_observations[self.step].copy_(transition.critic_observations) - self.actions[self.step].copy_(transition.actions) - self.rewards[self.step].copy_(transition.rewards.view(-1, 1)) - self.dones[self.step].copy_(transition.dones.view(-1, 1)) - self.values[self.step].copy_(transition.values) - self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1)) - self.mu[self.step].copy_(transition.action_mean) - self.sigma[self.step].copy_(transition.action_sigma) - self._save_hidden_states(transition.hidden_states) - self.step += 1 - - def _save_hidden_states(self, hidden_states): - if hidden_states is None or hidden_states == (None, None): - return - # make a tuple out of GRU hidden state sto match the LSTM format - hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],) - hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],) - - # initialize if needed - if self.saved_hidden_states_a is None: - self.saved_hidden_states_a = [ - torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a)) - ] - self.saved_hidden_states_c = [ - torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c)) - ] - # copy the states - for i in range(len(hid_a)): - self.saved_hidden_states_a[i][self.step].copy_(hid_a[i]) - self.saved_hidden_states_c[i][self.step].copy_(hid_c[i]) - - def clear(self): - self.step = 0 - - def compute_returns(self, last_values, gamma, lam): - advantage = 0 - for step in 
reversed(range(self.num_transitions_per_env)): - if step == self.num_transitions_per_env - 1: - next_values = last_values - else: - next_values = self.values[step + 1] - next_is_not_terminal = 1.0 - self.dones[step].float() - delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step] - advantage = delta + next_is_not_terminal * gamma * lam * advantage - self.returns[step] = advantage + self.values[step] - - # Compute and normalize the advantages - self.advantages = self.returns - self.values - self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8) - - def get_statistics(self): - done = self.dones - done[-1] = 1 - flat_dones = done.permute(1, 0, 2).reshape(-1, 1) - done_indices = torch.cat( - (flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0]) - ) - trajectory_lengths = done_indices[1:] - done_indices[:-1] - return trajectory_lengths.float().mean(), self.rewards.mean() - - def mini_batch_generator(self, num_mini_batches, num_epochs=8): - batch_size = self.num_envs * self.num_transitions_per_env - mini_batch_size = batch_size // num_mini_batches - indices = torch.randperm(num_mini_batches * mini_batch_size, requires_grad=False, device=self.device) - - observations = self.observations.flatten(0, 1) - if self.privileged_observations is not None: - critic_observations = self.privileged_observations.flatten(0, 1) - else: - critic_observations = observations - - actions = self.actions.flatten(0, 1) - values = self.values.flatten(0, 1) - returns = self.returns.flatten(0, 1) - old_actions_log_prob = self.actions_log_prob.flatten(0, 1) - advantages = self.advantages.flatten(0, 1) - old_mu = self.mu.flatten(0, 1) - old_sigma = self.sigma.flatten(0, 1) - - for epoch in range(num_epochs): - for i in range(num_mini_batches): - start = i * mini_batch_size - end = (i + 1) * mini_batch_size - batch_idx = indices[start:end] - - obs_batch = observations[batch_idx] - critic_observations_batch = critic_observations[batch_idx] - actions_batch = actions[batch_idx] - target_values_batch = values[batch_idx] - returns_batch = returns[batch_idx] - old_actions_log_prob_batch = old_actions_log_prob[batch_idx] - advantages_batch = advantages[batch_idx] - old_mu_batch = old_mu[batch_idx] - old_sigma_batch = old_sigma[batch_idx] - yield obs_batch, critic_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, ( - None, - None, - ), None - - # for RNNs only - def reccurent_mini_batch_generator(self, num_mini_batches, num_epochs=8): - padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories(self.observations, self.dones) - if self.privileged_observations is not None: - padded_critic_obs_trajectories, _ = split_and_pad_trajectories(self.privileged_observations, self.dones) - else: - padded_critic_obs_trajectories = padded_obs_trajectories - - mini_batch_size = self.num_envs // num_mini_batches - for ep in range(num_epochs): - first_traj = 0 - for i in range(num_mini_batches): - start = i * mini_batch_size - stop = (i + 1) * mini_batch_size - - dones = self.dones.squeeze(-1) - last_was_done = torch.zeros_like(dones, dtype=torch.bool) - last_was_done[1:] = dones[:-1] - last_was_done[0] = True - trajectories_batch_size = torch.sum(last_was_done[:, start:stop]) - last_traj = first_traj + trajectories_batch_size - - masks_batch = trajectory_masks[:, first_traj:last_traj] - obs_batch = padded_obs_trajectories[:, 
first_traj:last_traj] - critic_obs_batch = padded_critic_obs_trajectories[:, first_traj:last_traj] - - actions_batch = self.actions[:, start:stop] - old_mu_batch = self.mu[:, start:stop] - old_sigma_batch = self.sigma[:, start:stop] - returns_batch = self.returns[:, start:stop] - advantages_batch = self.advantages[:, start:stop] - values_batch = self.values[:, start:stop] - old_actions_log_prob_batch = self.actions_log_prob[:, start:stop] - - # reshape to [num_envs, time, num layers, hidden dim] (original shape: [time, num_layers, num_envs, hidden_dim]) - # then take only time steps after dones (flattens num envs and time dimensions), - # take a batch of trajectories and finally reshape back to [num_layers, batch, hidden_dim] - last_was_done = last_was_done.permute(1, 0) - hid_a_batch = [ - saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj] - .transpose(1, 0) - .contiguous() - for saved_hidden_states in self.saved_hidden_states_a - ] - hid_c_batch = [ - saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj] - .transpose(1, 0) - .contiguous() - for saved_hidden_states in self.saved_hidden_states_c - ] - # remove the tuple for GRU - hid_a_batch = hid_a_batch[0] if len(hid_a_batch) == 1 else hid_a_batch - hid_c_batch = hid_c_batch[0] if len(hid_c_batch) == 1 else hid_c_batch - - yield obs_batch, critic_obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, ( - hid_a_batch, - hid_c_batch, - ), masks_batch - - first_traj = last_traj + yield batch diff --git a/rsl_rl/storage/storage.py b/rsl_rl/storage/storage.py new file mode 100644 index 0000000..308cc02 --- /dev/null +++ b/rsl_rl/storage/storage.py @@ -0,0 +1,47 @@ +from abc import abstractmethod +import torch +from typing import Dict, Generator, List + +from rsl_rl.utils.serializable import Serializable + + +# prev_obs, prev_obs_info, actions, rewards, next_obs, next_obs_info, dones, data +Transition = Dict[str, torch.Tensor] +Dataset = List[Transition] + + +class Storage(Serializable): + @abstractmethod + def append(self, dataset: Dataset) -> None: + """Adds transitions to the storage. + + Args: + dataset (Dataset): The transitions to add to the storage. + """ + pass + + @abstractmethod + def batch_generator(self, batch_size: int, batch_count: int) -> Generator[Dict[str, torch.Tensor], None, None]: + """Generates a batch of transitions. + + Args: + batch_size (int): The size of each batch to generate. + batch_count (int): The number of batches to generate. + Returns: + A generator that yields transitions. + """ + pass + + @property + def initialized(self) -> bool: + """Returns whether the storage is initialized.""" + return True + + @abstractmethod + def sample_count(self) -> int: + """Returns how many individual samples are stored in the storage. + + Returns: + The number of stored samples. 
+ """ + pass diff --git a/rsl_rl/utils/__init__.py b/rsl_rl/utils/__init__.py index 46b365a..19bdd7b 100644 --- a/rsl_rl/utils/__init__.py +++ b/rsl_rl/utils/__init__.py @@ -1,6 +1,3 @@ -# Copyright 2021 ETH Zurich, NVIDIA CORPORATION -# SPDX-License-Identifier: BSD-3-Clause - """Helper functions.""" -from .utils import split_and_pad_trajectories, store_code_state, unpad_trajectories +from .utils import split_and_pad_trajectories, unpad_trajectories, store_code_state diff --git a/rsl_rl/utils/benchmarkable.py b/rsl_rl/utils/benchmarkable.py new file mode 100644 index 0000000..7cb95a2 --- /dev/null +++ b/rsl_rl/utils/benchmarkable.py @@ -0,0 +1,111 @@ +import numpy as np +import time +from typing import Callable, Dict + + +class Benchmark: + def __init__(self): + self.reset() + + def __call__(self): + if self.running: + self.end() + else: + self.start() + + def end(self): + now = time.process_time_ns() + + assert self.running + + difference = now - self._timer + self._timings.append(difference) + + self._timer = None + + def reset(self): + self._timer = None + self._timings = [] + + @property + def running(self): + return self._timer is not None + + def start(self): + self._timer = time.process_time_ns() + + @property + def timings(self): + return self._timings + + +class Benchmarkable: + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self._benchmark = False + self._bm_data = dict() + self._bm_fusions = [] + + def _bm(self, name: str) -> None: + if not self._benchmark: + return + + if name not in self._bm_data: + self._bm_data[name] = Benchmark() + + self._bm_data[name]() + + def _bm_flush(self) -> None: + # TODO: implement + for val in self._bm_data.values(): + val.reset() + + for fusion in self._bm_fusions: + fusion["target"]._bm_flush() + + def _bm_fuse(self, target, prefix="") -> None: + assert isinstance(target, Benchmarkable) + assert target not in self._bm_fusions + + target._bm_toggle(self._benchmark) + self._bm_fusions.append(dict(target=target, prefix=prefix)) + + def _bm_report(self) -> Dict: + data = dict() + + if not self._benchmark: + return data + + for key, val in self._bm_data.items(): + data[key] = (np.mean(val.timings), len(val.timings)) + + for fusion in self._bm_fusions: + target = fusion["target"] + prefix = fusion["prefix"] + + for key, val in target._bm_report().items(): + data[f"{prefix}{key}"] = val + + return data + + def _bm_toggle(self, value: bool) -> None: + self._benchmark = value + + for fusion in self._bm_fusions: + fusion["target"]._bm_toggle(value) + + @staticmethod + def register(method: Callable, name=None) -> Callable: + benchmark_name = method.__name__ if name is None else name + + def wrapper(self, *args, **kwargs): + assert isinstance(self, Benchmarkable) + + self._bm(benchmark_name) + result = method(self, *args, **kwargs) + self._bm(benchmark_name) + + return result + + return wrapper diff --git a/rsl_rl/utils/neptune_utils.py b/rsl_rl/utils/neptune_utils.py index f06cc62..9e0cbb7 100644 --- a/rsl_rl/utils/neptune_utils.py +++ b/rsl_rl/utils/neptune_utils.py @@ -1,32 +1,26 @@ -# Copyright 2021 ETH Zurich, NVIDIA CORPORATION -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - import os -from dataclasses import asdict + from torch.utils.tensorboard import SummaryWriter +from legged_gym.utils import class_to_dict try: - import neptune + import neptune.new as neptune except ModuleNotFoundError: raise ModuleNotFoundError("neptune-client is required to log to Neptune.") class NeptuneLogger: def 
__init__(self, project, token):
-        self.run = neptune.init_run(project=project, api_token=token)
+        self.run = neptune.init(project=project, api_token=token)
 
     def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
         self.run["runner_cfg"] = runner_cfg
         self.run["policy_cfg"] = policy_cfg
         self.run["alg_cfg"] = alg_cfg
-        self.run["env_cfg"] = asdict(env_cfg)
+        self.run["env_cfg"] = class_to_dict(env_cfg)
 
 
 class NeptuneSummaryWriter(SummaryWriter):
-    """Summary writer for Neptune."""
-
     def __init__(self, log_dir: str, flush_secs: int, cfg):
         super().__init__(log_dir, flush_secs)
@@ -86,7 +80,3 @@ class NeptuneSummaryWriter(SummaryWriter):
 
     def save_model(self, model_path, iter):
         self.neptune_logger.run["model/saved_model_" + str(iter)].upload(model_path)
-
-    def save_file(self, path, iter=None):
-        name = path.rsplit("/", 1)[-1].split(".")[0]
-        self.neptune_logger.run["git_diff/" + name].upload(path)
diff --git a/rsl_rl/utils/recurrency.py b/rsl_rl/utils/recurrency.py
new file mode 100644
index 0000000..3e798bd
--- /dev/null
+++ b/rsl_rl/utils/recurrency.py
@@ -0,0 +1,69 @@
+import torch
+from typing import Tuple
+
+
+def trajectories_to_transitions(trajectories: torch.Tensor, data: Tuple[torch.Tensor, int, bool]) -> torch.Tensor:
+    """Unpacks a tensor of trajectories into a tensor of transitions.
+
+    Args:
+        trajectories (torch.Tensor): A tensor of trajectories.
+        data (Tuple[torch.Tensor, int, bool]): The (mask, batch_size, batch_first) tuple returned by
+            transitions_to_trajectories, used to revert the conversion.
+    Returns:
+        A tensor of transitions of shape (batch_size, time, *).
+    """
+    mask, batch_size, batch_first = data
+
+    if not batch_first:
+        trajectories, mask = trajectories.transpose(0, 1), mask.transpose(0, 1)
+
+    transitions = trajectories[mask == 1.0].reshape(batch_size, -1, *trajectories.shape[2:])
+
+    return transitions
+
+
+def transitions_to_trajectories(
+    transitions: torch.Tensor, dones: torch.Tensor, batch_first: bool = False
+) -> Tuple[torch.Tensor, Tuple[torch.Tensor, int, bool]]:
+    """Packs a tensor of transitions into a tensor of trajectories.
+
+    Example:
+        >>> transitions = torch.tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]])
+        >>> dones = torch.tensor([[0, 0, 1], [0, 1, 0]])
+        >>> trajectories, (mask, batch_size, batch_first) = transitions_to_trajectories(transitions, dones, batch_first=True)
+        >>> trajectories
+        tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [0, 0]], [[11, 12], [0, 0], [0, 0]]])
+        >>> mask
+        tensor([[1., 1., 1.], [1., 1., 0.], [1., 0., 0.]])
+
+    Args:
+        transitions (torch.Tensor): Tensor of transitions of shape (batch_size, time, *).
+        dones (torch.Tensor): Tensor of transition terminations of shape (batch_size, time).
+        batch_first (bool): Whether the first dimension of the output tensor should be the batch dimension. Defaults to
+            False.
+    Returns:
+        A torch.Tensor of trajectories of shape (time, trajectory_count, *) that is padded with zeros and data for
+        reverting the operation. If batch_first is True, the shape of the trajectories is (trajectory_count, time, *).
+    """
+    batch_size = transitions.shape[0]
+
+    # Count the trajectory lengths by (1) padding dones with a 1 at the end to indicate the end of the trajectory,
+    # (2) stacking up the padded dones in a single column, and (3) counting the number of steps between each done by
+    # using the row index.
+ padded_dones = dones.clone() + padded_dones[:, -1] = 1 + stacked_dones = torch.cat((padded_dones.new([-1]), padded_dones.reshape(-1, 1).nonzero()[:, 0])) + trajectory_lengths = stacked_dones[1:] - stacked_dones[:-1] + + # Compute trajectories by splitting transitions according to previously computed trajectory lengths. + trajectory_list = torch.split(transitions.flatten(0, 1), trajectory_lengths.int().tolist()) + trajectories = torch.nn.utils.rnn.pad_sequence(trajectory_list, batch_first=batch_first) + + # The mask is generated by computing a 2d matrix of increasing counts in the 2nd dimension and comparing it to the + # trajectory lengths. + range = torch.arange(0, trajectory_lengths.max()).repeat(len(trajectory_lengths), 1) + range = range.cuda(dones.device) if dones.is_cuda else range + mask = (trajectory_lengths.unsqueeze(1) > range).float() + + if not batch_first: + mask = mask.T + + return trajectories, (mask, batch_size, batch_first) diff --git a/rsl_rl/utils/serializable.py b/rsl_rl/utils/serializable.py new file mode 100644 index 0000000..d36e002 --- /dev/null +++ b/rsl_rl/utils/serializable.py @@ -0,0 +1,43 @@ +from typing import Any, Dict + + +class Serializable: + def load_state_dict(self, data: Dict[str, Any]) -> None: + """Loads agent parameters from a dictionary.""" + assert hasattr(self, "_serializable_objects") + + for name in self._serializable_objects: + assert hasattr(self, name) + assert name in data, f'Object "{name}" was not found while loading "{self.__class__.__name__}".' + + attr = getattr(self, name) + if hasattr(attr, "load_state_dict"): + print(f"Loading {name}") + attr.load_state_dict(data[name]) + else: + print(f"Loading value {name}={data[name]}") + setattr(self, name, data[name]) + + def state_dict(self) -> Dict[str, Any]: + """Returns a dictionary containing the agent parameters.""" + assert hasattr(self, "_serializable_objects") + + data = {} + + for name in self._serializable_objects: + assert hasattr(self, name) + + attr = getattr(self, name) + data[name] = attr.state_dict() if hasattr(attr, "state_dict") else attr + + return data + + def _register_serializable(self, *objects) -> None: + if not hasattr(self, "_serializable_objects"): + self._serializable_objects = [] + + for name in objects: + if name in self._serializable_objects: + continue + + self._serializable_objects.append(name) diff --git a/rsl_rl/utils/utils.py b/rsl_rl/utils/utils.py index f8d3103..f4d6b6a 100644 --- a/rsl_rl/utils/utils.py +++ b/rsl_rl/utils/utils.py @@ -1,16 +1,38 @@ -# Copyright 2021 ETH Zurich, NVIDIA CORPORATION -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - +from datetime import datetime import git import os +import numpy as np import pathlib +import random import torch +def environment_dimensions(env): + obs = env.get_observations() + + if isinstance(obs, tuple): + obs, env_info = obs + else: + env_info = {} + + dims = {} + + dims["observations"] = obs.shape[1] + + if "observations" in env_info and "critic" in env_info["observations"]: + dims["actor_observations"] = dims["observations"] + dims["critic_observations"] = env_info["observations"]["critic"].shape[1] + else: + dims["actor_observations"] = dims["observations"] + dims["critic_observations"] = dims["observations"] + + dims["actions"] = env.num_actions + + return dims + + def split_and_pad_trajectories(tensor, dones): - """Splits trajectories at done indices. Then concatenates them and pads with zeros up to the length og the longest trajectory. 
+ """Splits trajectories at done indices. Then concatenates them and padds with zeros up to the length og the longest trajectory. Returns masks corresponding to valid parts of the trajectories Example: Input: [ [a1, a2, a3, a4 | a5, a6], @@ -24,7 +46,7 @@ def split_and_pad_trajectories(tensor, dones): [b6, 0, 0, 0] | [True, False, False, False], ] | ] - Assumes that the inputy has the following dimension order: [time, number of envs, additional dimensions] + Assumes that the inputy has the following dimension order: [time, number of envs, aditional dimensions] """ dones = dones.clone() dones[-1] = 1 @@ -37,12 +59,7 @@ def split_and_pad_trajectories(tensor, dones): trajectory_lengths_list = trajectory_lengths.tolist() # Extract the individual trajectories trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1), trajectory_lengths_list) - # add at least one full length trajectory - trajectories = trajectories + (torch.zeros(tensor.shape[0], tensor.shape[-1], device=tensor.device),) - # pad the trajectories to the length of the longest trajectory padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories) - # remove the added tensor - padded_trajectories = padded_trajectories[:, :-1] trajectory_masks = trajectory_lengths > torch.arange(0, tensor.shape[0], device=tensor.device).unsqueeze(1) return padded_trajectories, trajectory_masks @@ -58,29 +75,34 @@ def unpad_trajectories(trajectories, masks): ) -def store_code_state(logdir, repositories) -> list: - git_log_dir = os.path.join(logdir, "git") - os.makedirs(git_log_dir, exist_ok=True) - file_paths = [] +def store_code_state(logdir, repositories): for repository_file_path in repositories: - try: - repo = git.Repo(repository_file_path, search_parent_directories=True) - except Exception: - print(f"Could not find git repository in {repository_file_path}. 
Skipping.") - # skip if not a git repository - continue - # get the name of the repository + repo = git.Repo(repository_file_path, search_parent_directories=True) repo_name = pathlib.Path(repo.working_dir).name t = repo.head.commit.tree - diff_file_name = os.path.join(git_log_dir, f"{repo_name}.diff") - # check if the diff file already exists - if os.path.isfile(diff_file_name): - continue - # write the diff file - print(f"Storing git diff for '{repo_name}' in: {diff_file_name}") - with open(diff_file_name, "x") as f: - content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}" + content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}" + with open(os.path.join(logdir, f"{repo_name}_git.diff"), "x", encoding="utf-8") as f: f.write(content) - # add the file path to the list of files to be uploaded - file_paths.append(diff_file_name) - return file_paths + + +def seed(s=None): + seed = int(datetime.now().timestamp() * 1e6) % 2**32 if s is None else s + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def squeeze_preserve_batch(tensor): + """Squeezes a tensor, but preserves the batch dimension""" + single_batch = tensor.shape[0] == 1 + + squeezed_tensor = tensor.squeeze() + + if single_batch: + squeezed_tensor = squeezed_tensor.unsqueeze(0) + + return squeezed_tensor diff --git a/rsl_rl/utils/wandb_utils.py b/rsl_rl/utils/wandb_utils.py index 2868ce9..686c6dc 100644 --- a/rsl_rl/utils/wandb_utils.py +++ b/rsl_rl/utils/wandb_utils.py @@ -1,11 +1,7 @@ -# Copyright 2021 ETH Zurich, NVIDIA CORPORATION -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - import os -from dataclasses import asdict + from torch.utils.tensorboard import SummaryWriter +from legged_gym.utils import class_to_dict try: import wandb @@ -14,8 +10,6 @@ except ModuleNotFoundError: class WandbSummaryWriter(SummaryWriter): - """Summary writer for Weights and Biases.""" - def __init__(self, log_dir: str, flush_secs: int, cfg): super().__init__(log_dir, flush_secs) @@ -49,7 +43,7 @@ class WandbSummaryWriter(SummaryWriter): wandb.config.update({"runner_cfg": runner_cfg}) wandb.config.update({"policy_cfg": policy_cfg}) wandb.config.update({"alg_cfg": alg_cfg}) - wandb.config.update({"env_cfg": asdict(env_cfg)}) + wandb.config.update({"env_cfg": class_to_dict(env_cfg)}) def _map_path(self, path): if path in self.name_map: @@ -74,7 +68,4 @@ class WandbSummaryWriter(SummaryWriter): self.store_config(env_cfg, runner_cfg, alg_cfg, policy_cfg) def save_model(self, model_path, iter): - wandb.save(model_path, base_path=os.path.dirname(model_path)) - - def save_file(self, path, iter=None): - wandb.save(path, base_path=os.path.dirname(path)) + wandb.save(model_path) diff --git a/setup.py b/setup.py index b933467..b30fa39 100644 --- a/setup.py +++ b/setup.py @@ -1,24 +1,20 @@ -# Copyright 2021 ETH Zurich, NVIDIA CORPORATION -# SPDX-License-Identifier: BSD-3-Clause - -from setuptools import find_packages, setup +from setuptools import setup, find_packages setup( name="rsl_rl", - version="2.0.2", + version="1.0.2", packages=find_packages(), - author="ETH Zurich, NVIDIA CORPORATION", - maintainer="Nikita Rudin, David Hoeller", - maintainer_email="rudinn@ethz.ch", - url="https://github.com/leggedrobotics/rsl_rl", license="BSD-3", description="Fast and simple RL algorithms implemented in pytorch", 
python_requires=">=3.6", install_requires=[ + "GitPython", + "gym[all]>=0.26.0", + "numpy>=1.24.4", + "onnx>=1.14.0", + "tensorboard>=2.13.0", "torch>=1.10.0", "torchvision>=0.5.0", - "numpy>=1.16.4", - "GitPython", - "onnx", + "wandb>=0.15.4", ], ) diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py new file mode 100644 index 0000000..bab6be4 --- /dev/null +++ b/tests/test_algorithms.py @@ -0,0 +1,169 @@ +import unittest + +from rsl_rl.algorithms import D4PG, DDPG, DPPO, DSAC, PPO, SAC, TD3 +from rsl_rl.env.gym_env import GymEnv +from rsl_rl.modules import Network +from rsl_rl.runners.runner import Runner + +DEVICE = "cpu" + + +class AlgorithmTestCaseMixin: + algorithm_class = None + + def _make_env(self, params={}): + my_params = dict(name="LunarLanderContinuous-v2", device=DEVICE, environment_count=4) + my_params.update(params) + + return GymEnv(**my_params) + + def _make_agent(self, env, agent_params={}): + return self.algorithm_class(env, device=DEVICE, **agent_params) + + def _make_runner(self, env, agent, runner_params={}): + if not runner_params or "num_steps_per_env" not in runner_params: + runner_params["num_steps_per_env"] = 6 + + return Runner(env, agent, device=DEVICE, **runner_params) + + def _learn(self, env, agent, runner_params={}): + runner = self._make_runner(env, agent, runner_params) + runner.learn(10) + + def test_default(self): + env = self._make_env() + agent = self._make_agent(env) + + self._learn(env, agent) + + def test_single_env_single_step(self): + env = self._make_env(dict(environment_count=1)) + agent = self._make_agent(env) + + self._learn(env, agent, dict(num_steps_per_env=1)) + + +class RecurrentAlgorithmTestCaseMixin(AlgorithmTestCaseMixin): + def test_recurrent(self): + env = self._make_env() + agent = self._make_agent(env, dict(recurrent=True)) + + self._learn(env, agent) + + def test_single_env_single_step_recurrent(self): + env = self._make_env(dict(environment_count=1)) + agent = self._make_agent(env, dict(recurrent=True)) + + self._learn(env, agent, dict(num_steps_per_env=1)) + + +class D4PGTest(AlgorithmTestCaseMixin, unittest.TestCase): + algorithm_class = D4PG + + +class DDPGTest(AlgorithmTestCaseMixin, unittest.TestCase): + algorithm_class = DDPG + + +iqn_params = dict( + critic_network=DPPO.network_iqn, + iqn_action_samples=8, + iqn_embedding_size=16, + iqn_feature_layers=2, + iqn_value_samples=4, + value_loss=DPPO.value_loss_energy, +) + +qrdqn_params = dict( + critic_network=DPPO.network_qrdqn, + qrdqn_quantile_count=16, + value_loss=DPPO.value_loss_l1, +) + + +class DPPOTest(RecurrentAlgorithmTestCaseMixin, unittest.TestCase): + algorithm_class = DPPO + + def test_qrdqn(self): + env = self._make_env() + agent = self._make_agent(env, qrdqn_params) + + self._learn(env, agent) + + def test_qrdqn_sing_env_single_step(self): + env = self._make_env(dict(environment_count=1)) + agent = self._make_agent(env, qrdqn_params) + + self._learn(env, agent, dict(num_steps_per_env=1)) + + def test_qrdqn_energy_loss(self): + my_agent_params = qrdqn_params.copy() + my_agent_params["value_loss"] = DPPO.value_loss_energy + + env = self._make_env() + agent = self._make_agent(env, my_agent_params) + + self._learn(env, agent) + + def test_qrdqn_huber_loss(self): + my_agent_params = qrdqn_params.copy() + my_agent_params["value_loss"] = DPPO.value_loss_huber + + env = self._make_env() + agent = self._make_agent(env, my_agent_params) + + self._learn(env, agent) + + def test_qrdqn_transformer(self): + my_agent_params = qrdqn_params.copy() + 
my_agent_params["recurrent"] = True + my_agent_params["critic_recurrent_layers"] = 2 + my_agent_params["critic_recurrent_module"] = Network.recurrent_module_transformer + my_agent_params["critic_recurrent_tf_context_length"] = 8 + my_agent_params["critic_recurrent_tf_head_count"] = 2 + + env = self._make_env() + agent = self._make_agent(env, my_agent_params) + + self._learn(env, agent) + + def test_iqn(self): + env = self._make_env() + agent = self._make_agent(env, iqn_params) + + self._learn(env, agent) + + def test_iqn_single_step_single_env(self): + env = self._make_env(dict(environment_count=1)) + agent = self._make_agent(env, iqn_params) + + self._learn(env, agent, dict(num_steps_per_env=1)) + + def test_iqn_recurrent(self): + my_agent_params = iqn_params.copy() + my_agent_params["recurrent"] = True + + env = self._make_env() + agent = self._make_agent(env, my_agent_params) + + self._learn(env, agent) + + +class DSACTest(AlgorithmTestCaseMixin, unittest.TestCase): + algorithm_class = DSAC + + +class PPOTest(RecurrentAlgorithmTestCaseMixin, unittest.TestCase): + algorithm_class = PPO + + +class SACTest(AlgorithmTestCaseMixin, unittest.TestCase): + algorithm_class = SAC + + +class TD3Test(AlgorithmTestCaseMixin, unittest.TestCase): + algorithm_class = TD3 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dpg.py b/tests/test_dpg.py new file mode 100644 index 0000000..b2f7d3c --- /dev/null +++ b/tests/test_dpg.py @@ -0,0 +1,126 @@ +import torch +import unittest +from rsl_rl.algorithms.dpg import AbstractDPG +from rsl_rl.env.pole_balancing import PoleBalancing + + +class DPG(AbstractDPG): + def draw_actions(self, obs, env_info): + pass + + def eval_mode(self): + pass + + def get_inference_policy(self, device=None): + pass + + def process_transition( + self, observations, environment_info, actions, rewards, next_observations, next_environment_info, dones, data + ): + pass + + def register_terminations(self, terminations): + pass + + def to(self, device): + pass + + def train_mode(self): + pass + + def update(self, storage): + pass + + +class FakeCritic(torch.nn.Module): + def __init__(self, values): + self.values = values + + def forward(self, _): + return self.values + + +class DPGTest(unittest.TestCase): + def test_timeout_bootstrapping(self): + env = PoleBalancing(environment_count=4) + dpg = DPG(env, device="cpu", return_steps=3) + + rewards = torch.tensor( + [ + [0.1000, 0.4000, 0.6000, 0.2000, -0.6000, -0.2000], + [0.0000, 0.9000, 0.5000, -0.9000, -0.4000, 0.8000], + [-0.5000, 0.4000, 0.0000, -0.2000, 0.3000, 0.1000], + [-0.8000, 0.9000, -0.6000, 0.7000, 0.5000, 0.1000], + ] + ) + dones = torch.tensor( + [ + [0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + ] + ) + timeouts = torch.tensor( + [ + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + ] + ) + actions = torch.zeros((4, 6, 1)) + observations = torch.zeros((4, 6, 2)) + values = torch.tensor([-0.1000, -0.8000, 0.4000, 0.7000]) + + dpg.critic = FakeCritic(values) + dataset = [ + { + "actions": actions[:, i], + "critic_observations": observations[:, i], + "dones": dones[:, i], + "rewards": rewards[:, i], + "timeouts": timeouts[:, i], + } + for i in range(3) + ] + + processed_dataset = dpg._process_dataset(dataset) + processed_rewards = torch.stack([processed_dataset[i]["rewards"] for i in range(1)], dim=-1) + + expected_rewards = torch.tensor( + [ + [1.08406], + [1.38105], + [-0.5], + [0.77707], + ] + ) + 
self.assertTrue(len(processed_dataset) == 1) + self.assertTrue(torch.isclose(processed_rewards, expected_rewards).all()) + + dataset = [ + { + "actions": actions[:, i + 3], + "critic_observations": observations[:, i + 3], + "dones": dones[:, i + 3], + "rewards": rewards[:, i + 3], + "timeouts": timeouts[:, i + 3], + } + for i in range(3) + ] + + processed_dataset = dpg._process_dataset(dataset) + processed_rewards = torch.stack([processed_dataset[i]["rewards"] for i in range(3)], dim=-1) + + expected_rewards = torch.tensor( + [ + [0.994, 0.6, -0.59002], + [0.51291, -1.5592792, -2.08008], + [0.20398, 0.09603, 0.19501], + [1.593, 0.093, 0.7], + ] + ) + + self.assertTrue(len(processed_dataset) == 3) + self.assertTrue(torch.isclose(processed_rewards, expected_rewards).all()) diff --git a/tests/test_dppo.py b/tests/test_dppo.py new file mode 100644 index 0000000..9be2316 --- /dev/null +++ b/tests/test_dppo.py @@ -0,0 +1,278 @@ +import torch +import unittest +from rsl_rl.algorithms import DPPO +from rsl_rl.env.pole_balancing import PoleBalancing + + +class FakeCritic(torch.nn.Module): + def __init__(self, values, quantile_count=1): + self.quantile_count = quantile_count + self.recurrent = False + self.values = values + self.last_quantiles = values + + def forward(self, _, distribution=False, measure_args=None): + if distribution: + return self.values + + return self.values.mean(-1) + + def quantiles_to_values(self, quantiles): + return quantiles.mean(-1) + + +class DPPOTest(unittest.TestCase): + def test_gae_computation(self): + # GT taken from old PPO implementation. + + env = PoleBalancing(environment_count=4) + dppo = DPPO(env, device="cpu", gae_lambda=0.97, gamma=0.99, qrdqn_quantile_count=1) + + rewards = torch.tensor( + [ + [-1.0000e02, -1.4055e-01, -3.0476e-02, -2.7149e-01, -1.1157e-01, -2.3366e-01, -3.3658e-01, -1.6447e-01], + [ + -1.7633e-01, + -2.6533e-01, + -3.0786e-01, + -2.6038e-01, + -2.7176e-01, + -2.1655e-01, + -1.5441e-01, + -2.9580e-01, + ], + [-1.5952e-01, -1.5177e-01, -1.4296e-01, -1.6131e-01, -3.1395e-02, 2.8808e-03, -3.1242e-02, 4.8696e-03], + [1.1407e-02, -1.0000e02, -6.2290e-02, -3.7030e-01, -2.7648e-01, -3.6655e-01, -2.8456e-01, -2.3165e-01], + ] + ) + dones = torch.tensor( + [ + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0], + ] + ) + observations = torch.zeros((dones.shape[0], dones.shape[1], 24)) + timeouts = torch.zeros((dones.shape[0], dones.shape[1])) + values = torch.tensor( + [ + [-4.6342, -7.6510, -7.0166, -7.6137, -7.4130, -7.7071, -7.7413, -7.8301], + [-7.0442, -7.0032, -6.9321, -6.7765, -6.5433, -6.3503, -6.2529, -5.9337], + [-7.5753, -7.8146, -7.6142, -7.8443, -7.8791, -7.7973, -7.7853, -7.7724], + [-6.4326, -6.1673, -7.6511, -7.7505, -8.0004, -7.8584, -7.5949, -7.9023], + ] + ) + value_quants = values.unsqueeze(-1) + last_values = torch.tensor([-7.9343, -5.8734, -7.8527, -8.1257]) + + dppo.critic = FakeCritic(last_values.unsqueeze(-1)) + dataset = [ + { + "dones": dones[:, i], + "full_next_critic_observations": observations[:, i].clone(), + "next_critic_observations": observations[:, i], + "rewards": rewards[:, i], + "timeouts": timeouts[:, i], + "values": values[:, i], + "value_quants": value_quants[:, i], + } + for i in range(dones.shape[1]) + ] + + processed_dataset = dppo._process_dataset(dataset) + processed_returns = torch.stack( + [processed_dataset[i]["advantages"] + processed_dataset[i]["values"] for i in range(dones.shape[1])], + dim=-1, + ) + processed_advantages = torch.stack( + 
[processed_dataset[i]["normalized_advantages"] for i in range(dones.shape[1])], dim=-1 + ) + + expected_returns = torch.tensor( + [ + [-100.0000, -8.4983, -8.4863, -8.5699, -8.4122, -8.4054, -8.2702, -8.0194], + [-7.2900, -7.1912, -6.9978, -6.7569, -6.5627, -6.3547, -6.1985, -6.1104], + [-7.9179, -7.8374, -7.7679, -7.6976, -7.6041, -7.6446, -7.7229, -7.7693], + [-96.2018, -100.0000, -9.0710, -9.1415, -8.8863, -8.7228, -8.4668, -8.2761], + ] + ) + expected_advantages = torch.tensor( + [ + [-3.1452, 0.3006, 0.2779, 0.2966, 0.2951, 0.3060, 0.3122, 0.3246], + [0.3225, 0.3246, 0.3291, 0.3322, 0.3308, 0.3313, 0.3335, 0.3250], + [0.3190, 0.3307, 0.3259, 0.3368, 0.3415, 0.3371, 0.3338, 0.3316], + [-2.9412, -3.0893, 0.2797, 0.2808, 0.2992, 0.3000, 0.2997, 0.3179], + ] + ) + + self.assertTrue(torch.isclose(processed_returns, expected_returns, atol=1e-4).all()) + self.assertTrue(torch.isclose(processed_advantages, expected_advantages, atol=1e-4).all()) + + def test_target_computation_nstep(self): + # GT taken from old PPO implementation. + + env = PoleBalancing(environment_count=2) + dppo = DPPO(env, device="cpu", gae_lambda=0.97, gamma=0.99, qrdqn_quantile_count=3, value_lambda=1.0) + + rewards = torch.tensor( + [ + [0.6, 1.0, 0.0, 0.6, 0.1, -0.2], + [0.1, -0.2, -0.4, 0.1, 1.0, 1.0], + ] + ) + dones = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 1.0, 0.0], + ] + ) + observations = torch.zeros((dones.shape[0], dones.shape[1], 24)) + timeouts = torch.zeros((dones.shape[0], dones.shape[1])) + value_quants = torch.tensor( + [ + [ + [-0.4, 0.0, 0.1], + [0.9, 0.8, 0.7], + [0.7, 1.3, 0.0], + [1.2, 0.4, 1.2], + [1.3, 1.3, 1.1], + [0.0, 0.7, 0.5], + ], + [ + [1.3, 1.3, 0.9], + [0.4, -0.4, -0.1], + [0.4, 0.6, 0.1], + [0.7, 0.1, 0.3], + [0.2, 0.1, 0.3], + [1.4, 1.4, -0.3], + ], + ] + ) + values = value_quants.mean(dim=-1) + last_values = torch.rand((dones.shape[0], 3)) + + dppo.critic = FakeCritic(last_values, quantile_count=3) + dataset = [ + { + "dones": dones[:, i], + "full_next_critic_observations": observations[:, i].clone(), + "next_critic_observations": observations[:, i], + "rewards": rewards[:, i], + "timeouts": timeouts[:, i], + "values": values[:, i], + "value_quants": value_quants[:, i], + } + for i in range(dones.shape[1]) + ] + + processed_dataset = dppo._process_dataset(dataset) + processed_value_target_quants = torch.stack( + [processed_dataset[i]["value_target_quants"] for i in range(dones.shape[1])], + dim=-2, + ) + + # N-step returns + # These exclude the reward received on the final step since it should not be added to the value target. + expected_returns = torch.tensor( + [ + [0.6, 1.58806, 0.594, 0.6, 0.1, -0.2], + [-0.098, -0.2, -0.4, 0.1, 1.0, 1.0], + ] + ) + reset = lambda x: [0.0 for _ in x] + dscnt = lambda x, s=0: [x[v] * dppo.gamma**s for v in range(len(x))] + expected_value_target_quants = expected_returns.unsqueeze(-1) + torch.tensor( + [ + [ + reset([-0.4, 0.0, 0.1]), + dscnt([1.3, 1.3, 1.1], 3), + dscnt([1.3, 1.3, 1.1], 2), + dscnt([1.3, 1.3, 1.1], 1), + reset([1.3, 1.3, 1.1]), + dscnt(last_values[0], 1), + ], + [ + dscnt([0.4, 0.6, 0.1], 2), + dscnt([0.4, 0.6, 0.1], 1), + reset([0.4, 0.6, 0.1]), + dscnt([0.2, 0.1, 0.3], 1), + reset([0.2, 0.1, 0.3]), + dscnt(last_values[1], 1), + ], + ] + ) + + self.assertTrue(torch.isclose(processed_value_target_quants, expected_value_target_quants, atol=1e-4).all()) + + def test_target_computation_1step(self): + # GT taken from old PPO implementation. 
+ + env = PoleBalancing(environment_count=2) + dppo = DPPO(env, device="cpu", gae_lambda=0.97, gamma=0.99, qrdqn_quantile_count=3, value_lambda=0.0) + + rewards = torch.tensor( + [ + [0.6, 1.0, 0.0, 0.6, 0.1, -0.2], + [0.1, -0.2, -0.4, 0.1, 1.0, 1.0], + ] + ) + dones = torch.tensor( + [ + [1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 1.0, 0.0], + ] + ) + observations = torch.zeros((dones.shape[0], dones.shape[1], 24)) + timeouts = torch.zeros((dones.shape[0], dones.shape[1])) + value_quants = torch.tensor( + [ + [ + [-0.4, 0.0, 0.1], + [0.9, 0.8, 0.7], + [0.7, 1.3, 0.0], + [1.2, 0.4, 1.2], + [1.3, 1.3, 1.1], + [0.0, 0.7, 0.5], + ], + [ + [1.3, 1.3, 0.9], + [0.4, -0.4, -0.1], + [0.4, 0.6, 0.1], + [0.7, 0.1, 0.3], + [0.2, 0.1, 0.3], + [1.4, 1.4, -0.3], + ], + ] + ) + values = value_quants.mean(dim=-1) + last_values = torch.rand((dones.shape[0], 3)) + + dppo.critic = FakeCritic(last_values, quantile_count=3) + dataset = [ + { + "dones": dones[:, i], + "full_next_critic_observations": observations[:, i].clone(), + "next_critic_observations": observations[:, i], + "rewards": rewards[:, i], + "timeouts": timeouts[:, i], + "values": values[:, i], + "value_quants": value_quants[:, i], + } + for i in range(dones.shape[1]) + ] + + processed_dataset = dppo._process_dataset(dataset) + processed_value_target_quants = torch.stack( + [processed_dataset[i]["value_target_quants"] for i in range(dones.shape[1])], + dim=-2, + ) + + # 1-step returns + expected_value_target_quants = rewards.unsqueeze(-1) + ( + (1.0 - dones).float().unsqueeze(-1) + * dppo.gamma + * torch.cat((value_quants[:, 1:], last_values.unsqueeze(1)), dim=1) + ) + + self.assertTrue(torch.isclose(processed_value_target_quants, expected_value_target_quants, atol=1e-4).all()) diff --git a/tests/test_dppo_iqn.py b/tests/test_dppo_iqn.py new file mode 100644 index 0000000..dbe498a --- /dev/null +++ b/tests/test_dppo_iqn.py @@ -0,0 +1,171 @@ +import torch +import unittest +from rsl_rl.algorithms import DPPO +from rsl_rl.env.vec_env import VecEnv + +ACTION_SIZE = 3 +ENV_COUNT = 3 +OBS_SIZE = 24 + + +class FakeEnv(VecEnv): + def __init__(self, rewards, dones, environment_count=1): + super().__init__(OBS_SIZE, OBS_SIZE, environment_count=environment_count) + + self.num_actions = ACTION_SIZE + self.rewards = rewards + self.dones = dones + + self._step = 0 + + def get_observations(self): + return torch.zeros((self.num_envs, self.num_obs)), {"observations": {}} + + def get_privileged_observations(self): + return torch.zeros((self.num_envs, self.num_privileged_obs)), {"observations": {}} + + def step(self, actions): + obs, _ = self.get_observations() + rewards = self.rewards[self._step] + dones = self.dones[self._step] + + self._step += 1 + + return obs, rewards, dones, {"observations": {}} + + def reset(self): + pass + + +class FakeCritic(torch.nn.Module): + def __init__(self, action_samples, value_samples, action_values, value_values, action_taus, value_taus): + self.recurrent = False + self.action_samples = action_samples + self.value_samples = value_samples + self.action_values = action_values + self.value_values = value_values + self.action_taus = action_taus + self.value_taus = value_taus + + self.last_quantiles = None + self.last_taus = None + + def forward(self, _, distribution=False, measure_args=None, sample_count=8, taus=None, use_measure=True): + if taus is not None: + sample_count = taus.shape[-1] + + if sample_count == self.action_samples: + self.last_taus = self.action_taus + self.last_quantiles = self.action_values + elif 
sample_count == self.value_samples: + self.last_taus = self.value_taus + self.last_quantiles = self.value_values + else: + raise ValueError(f"Invalid sample count: {sample_count}") + + if distribution: + return self.last_quantiles + + return self.last_quantiles.mean(-1) + + +def fake_process_quants(self, x): + idx = torch.arange(0, x.shape[-1]).expand(*x.shape[:-1]) + + return x, idx + + +class DPPOTest(unittest.TestCase): + def test_value_target_computation(self): + rewards = torch.tensor( + [ + [-1.0000e02, -1.4055e-01, -3.0476e-02], + [-1.7633e-01, -2.6533e-01, -3.0786e-01], + [-1.5952e-01, -1.5177e-01, -1.4296e-01], + [1.1407e-02, -1.0000e02, -6.2290e-02], + ] + ) + dones = torch.tensor( + [ + [1, 0, 0], + [0, 0, 0], + [0, 0, 0], + [0, 1, 0], + ] + ) + + env = FakeEnv(rewards, dones, environment_count=ENV_COUNT) + dppo = DPPO( + env, + critic_network=DPPO.network_iqn, + device="cpu", + gae_lambda=0.97, + gamma=0.99, + iqn_action_samples=4, + iqn_value_samples=2, + value_lambda=1.0, + value_loss=DPPO.value_loss_energy, + ) + + # Generate fake dataset + + action_taus = torch.tensor( + [ + [[0.3, 0.5, 1.0, 0.2], [0.8, 0.9, 0.0, 0.9], [0.6, 0.1, 0.6, 0.5]], + [[0.7, 0.9, 0.3, 0.0], [1.0, 0.7, 0.7, 0.7], [0.3, 0.8, 0.8, 0.1]], + [[0.3, 0.8, 0.3, 0.2], [0.2, 0.9, 0.6, 0.4], [0.8, 0.4, 0.8, 1.0]], + [[0.6, 0.6, 0.8, 0.8], [0.8, 0.0, 0.9, 0.1], [0.2, 0.3, 0.6, 0.2]], + ] + ) + action_value_quants = torch.tensor( + [ + [[0.2, 0.2, 0.6, 0.5], [0.5, 0.8, 0.1, 0.0], [1.0, 0.1, 0.8, 0.8]], + [[0.0, 0.6, 0.1, 0.9], [0.2, 1.0, 0.9, 1.0], [0.4, 0.1, 0.1, 0.8]], + [[0.7, 0.0, 0.6, 0.8], [0.7, 0.7, 0.7, 0.8], [0.0, 0.1, 0.5, 0.8]], + [[0.5, 0.8, 0.1, 0.1], [0.9, 0.4, 0.7, 0.6], [0.6, 0.3, 0.1, 0.4]], + ] + ) + value_taus = torch.tensor( + [ + [[0.3, 0.5], [0.8, 0.9], [0.6, 0.1]], + [[0.7, 0.9], [1.0, 0.7], [0.3, 0.8]], + [[0.3, 0.8], [0.2, 0.9], [0.8, 0.4]], + [[0.6, 0.6], [0.8, 0.0], [0.2, 0.3]], + ] + ) + value_value_quants = torch.tensor( + [ + [[0.9, 0.8], [0.1, 0.3], [0.3, 0.5]], + [[0.2, 0.1], [0.9, 0.3], [0.4, 0.2]], + [[0.7, 1.0], [0.6, 0.2], [0.2, 0.6]], + [[0.4, 1.0], [0.3, 0.6], [0.3, 0.1]], + ] + ) + + actions = torch.zeros(ENV_COUNT, ACTION_SIZE) + env_info = {"observations": {}} + obs = torch.zeros(ENV_COUNT, OBS_SIZE) + dataset = [] + for i in range(4): + dppo.critic = FakeCritic(4, 2, action_value_quants[i], value_value_quants[i], action_taus[i], value_taus[i]) + dppo.critic._process_quants = fake_process_quants + + _, data = dppo.draw_actions(obs, {}) + _, rewards, dones, _ = env.step(actions) + + dataset.append( + dppo.process_transition( + obs, + env_info, + actions, + rewards, + obs, + env_info, + dones, + data, + ) + ) + + processed_dataset = dppo._process_dataset(dataset) + + # TODO: Test that the value targets are correct. 
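+        # Minimal structural check until the TODO above is addressed. This is a sketch and assumes that the
+        # processed on-policy dataset keeps all transitions and stores per-transition "value_target_quants",
+        # as in test_dppo.py.
+        self.assertEqual(len(processed_dataset), len(dataset))
+        self.assertTrue(all("value_target_quants" in transition for transition in processed_dataset))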
diff --git a/tests/test_dppo_recurrency.py b/tests/test_dppo_recurrency.py new file mode 100644 index 0000000..dd860e8 --- /dev/null +++ b/tests/test_dppo_recurrency.py @@ -0,0 +1,170 @@ +import torch +import unittest +from rsl_rl.algorithms import DPPO +from rsl_rl.env.vec_env import VecEnv +from rsl_rl.runners.runner import Runner +from rsl_rl.utils.benchmarkable import Benchmarkable + + +class FakeNetwork(torch.nn.Module, Benchmarkable): + def __init__(self, values): + super().__init__() + + self.hidden_state = None + self.quantile_count = 1 + self.recurrent = True + self.values = values + + self._hidden_size = 2 + + def forward(self, x, hidden_state=None): + if not hidden_state: + self.hidden_state = (self.hidden_state[0] + 1, self.hidden_state[1] - 1) + + values = self.values.repeat((*x.shape[:-1], 1)).squeeze(-1) + values.requires_grad_(True) + + return values + + def reset_full_hidden_state(self, batch_size=None): + assert batch_size is None or batch_size == 4, f"batch_size={batch_size}" + + self.hidden_state = (torch.zeros((1, 4, self._hidden_size)), torch.zeros((1, 4, self._hidden_size))) + + def reset_hidden_state(self, indices): + self.hidden_state[0][:, indices] = torch.zeros((len(indices), self._hidden_size)) + self.hidden_state[1][:, indices] = torch.zeros((len(indices), self._hidden_size)) + + +class FakeActorNetwork(FakeNetwork): + def forward(self, x, compute_std=False, hidden_state=None): + values = super().forward(x, hidden_state=hidden_state) + + if compute_std: + return values, torch.ones_like(values) + + return values + + +class FakeCriticNetwork(FakeNetwork): + _quantile_count = 1 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward(self, x, distribution=False, hidden_state=None, measure_args=None): + values = super().forward(x, hidden_state=hidden_state) + self.last_quantiles = values.reshape(*values.shape, 1) + + if distribution: + return self.last_quantiles + + return values + + def quantile_l1_loss(self, *args, **kwargs): + return torch.tensor(0.0) + + def quantiles_to_values(self, quantiles): + return quantiles.squeeze() + + +class FakeEnv(VecEnv): + def __init__(self, dones=None, **kwargs): + super().__init__(3, 3, **kwargs) + + self.num_actions = 3 + self._extra = {"observations": {}, "time_outs": torch.zeros((self.num_envs, 1))} + + self._step = 0 + self._dones = dones + + self.reset() + + def get_observations(self): + return self._state_buf, self._extra + + def get_privileged_observations(self): + return self._state_buf, self._extra + + def reset(self): + self._state_buf = torch.zeros((self.num_envs, self.num_obs)) + + return self._state_buf, self._extra + + def step(self, actions): + assert actions.shape[0] == self.num_envs + assert actions.shape[1] == self.num_actions + + self._state_buf += actions + + rewards = torch.zeros((self.num_envs)) + dones = torch.zeros((self.num_envs)) if self._dones is None else self._dones[self._step % self._dones.shape[0]] + + self._step += 1 + + return self._state_buf, rewards, dones, self._extra + + +class DPPORecurrencyTest(unittest.TestCase): + def test_draw_action_produces_hidden_state(self): + """Test that the hidden state is correctly added to the data dictionary when drawing actions.""" + env = FakeEnv(environment_count=4) + dppo = DPPO(env, device="cpu", recurrent=True) + + dppo.actor = FakeActorNetwork(torch.ones(env.num_actions)) + dppo.critic = FakeCriticNetwork(torch.zeros(1)) + + # Done during DPPO.__init__, however we need to reset the hidden state here again since we are 
using a fake + # network that is added after initialization. + dppo.actor.reset_full_hidden_state(batch_size=env.num_envs) + dppo.critic.reset_full_hidden_state(batch_size=env.num_envs) + + ones = torch.ones((1, env.num_envs, dppo.actor._hidden_size)) + state, extra = env.reset() + for ctr in range(10): + _, data = dppo.draw_actions(state, extra) + + # Actor state is changed every time an action is drawn. + self.assertTrue(torch.allclose(data["actor_state_h"], ones * ctr)) + self.assertTrue(torch.allclose(data["actor_state_c"], -ones * ctr)) + # Critic state is only changed and saved when processing the transition (evaluating the action) so we can't + # check it here. + + def test_update_produces_hidden_state(self): + """Test that the hidden state is correctly added to the data dictionary when updating.""" + dones = torch.cat((torch.tensor([[0, 0, 0, 1]]), torch.zeros((4, 4)), torch.tensor([[1, 0, 0, 0]])), dim=0) + + env = FakeEnv(dones=dones, environment_count=4) + dppo = DPPO(env, device="cpu", recurrent=True) + runner = Runner(env, dppo, num_steps_per_env=6) + + dppo._value_loss = lambda *args, **kwargs: torch.tensor(0.0) + + dppo.actor = FakeActorNetwork(torch.ones(env.num_actions)) + dppo.critic = FakeCriticNetwork(torch.zeros(1)) + + dppo.actor.reset_full_hidden_state(batch_size=env.num_envs) + dppo.critic.reset_full_hidden_state(batch_size=env.num_envs) + + runner.learn(1) + + state_h_0 = torch.tensor([[0, 0], [0, 0], [0, 0], [0, 0]]) + state_h_1 = torch.tensor([[1, 1], [1, 1], [1, 1], [0, 0]]) + state_h_2 = state_h_1 + 1 + state_h_3 = state_h_2 + 1 + state_h_4 = state_h_3 + 1 + state_h_5 = state_h_4 + 1 + state_h_6 = torch.tensor([[0, 0], [6, 6], [6, 6], [5, 5]]) + state_h = ( + torch.cat((state_h_0, state_h_1, state_h_2, state_h_3, state_h_4, state_h_5), dim=0).float().unsqueeze(1) + ) + next_state_h = ( + torch.cat((state_h_1, state_h_2, state_h_3, state_h_4, state_h_5, state_h_6), dim=0).float().unsqueeze(1) + ) + + self.assertTrue(torch.allclose(dppo.storage._data["critic_state_h"], state_h)) + self.assertTrue(torch.allclose(dppo.storage._data["critic_state_c"], -state_h)) + self.assertTrue(torch.allclose(dppo.storage._data["critic_next_state_h"], next_state_h)) + self.assertTrue(torch.allclose(dppo.storage._data["critic_next_state_c"], -next_state_h)) + self.assertTrue(torch.allclose(dppo.storage._data["actor_state_h"], state_h)) + self.assertTrue(torch.allclose(dppo.storage._data["actor_state_c"], -state_h)) diff --git a/tests/test_ppo.py b/tests/test_ppo.py new file mode 100644 index 0000000..4f8ef0f --- /dev/null +++ b/tests/test_ppo.py @@ -0,0 +1,99 @@ +import torch +import unittest +from rsl_rl.algorithms import PPO +from rsl_rl.env.pole_balancing import PoleBalancing + + +class FakeCritic(torch.nn.Module): + def __init__(self, values): + self.recurrent = False + self.values = values + + def forward(self, _): + return self.values + + +class PPOTest(unittest.TestCase): + def test_gae_computation(self): + # GT taken from old PPO implementation. 
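+        # Checks generalized advantage estimation, A_t = delta_t + gamma * lambda * A_{t+1} with
+        # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), where the recursion is cut at episode ends;
+        # here gamma=0.99 and lambda=0.97. Returns are recovered below as advantages + values, and the
+        # advantages the algorithm stores are normalized.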
+ + env = PoleBalancing(environment_count=4) + ppo = PPO(env, device="cpu", gae_lambda=0.97, gamma=0.99) + + rewards = torch.tensor( + [ + [-1.0000e02, -1.4055e-01, -3.0476e-02, -2.7149e-01, -1.1157e-01, -2.3366e-01, -3.3658e-01, -1.6447e-01], + [ + -1.7633e-01, + -2.6533e-01, + -3.0786e-01, + -2.6038e-01, + -2.7176e-01, + -2.1655e-01, + -1.5441e-01, + -2.9580e-01, + ], + [-1.5952e-01, -1.5177e-01, -1.4296e-01, -1.6131e-01, -3.1395e-02, 2.8808e-03, -3.1242e-02, 4.8696e-03], + [1.1407e-02, -1.0000e02, -6.2290e-02, -3.7030e-01, -2.7648e-01, -3.6655e-01, -2.8456e-01, -2.3165e-01], + ] + ) + dones = torch.tensor( + [ + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0], + ] + ) + observations = torch.zeros((dones.shape[0], dones.shape[1], 24)) + timeouts = torch.zeros((dones.shape[0], dones.shape[1])) + values = torch.tensor( + [ + [-4.6342, -7.6510, -7.0166, -7.6137, -7.4130, -7.7071, -7.7413, -7.8301], + [-7.0442, -7.0032, -6.9321, -6.7765, -6.5433, -6.3503, -6.2529, -5.9337], + [-7.5753, -7.8146, -7.6142, -7.8443, -7.8791, -7.7973, -7.7853, -7.7724], + [-6.4326, -6.1673, -7.6511, -7.7505, -8.0004, -7.8584, -7.5949, -7.9023], + ] + ) + last_values = torch.tensor([-7.9343, -5.8734, -7.8527, -8.1257]) + + ppo.critic = FakeCritic(last_values) + dataset = [ + { + "dones": dones[:, i], + "next_critic_observations": observations[:, i], + "rewards": rewards[:, i], + "timeouts": timeouts[:, i], + "values": values[:, i], + } + for i in range(dones.shape[1]) + ] + + processed_dataset = ppo._process_dataset(dataset) + processed_returns = torch.stack( + [processed_dataset[i]["advantages"] + processed_dataset[i]["values"] for i in range(dones.shape[1])], + dim=-1, + ) + processed_advantages = torch.stack( + [processed_dataset[i]["normalized_advantages"] for i in range(dones.shape[1])], dim=-1 + ) + + expected_returns = torch.tensor( + [ + [-100.0000, -8.4983, -8.4863, -8.5699, -8.4122, -8.4054, -8.2702, -8.0194], + [-7.2900, -7.1912, -6.9978, -6.7569, -6.5627, -6.3547, -6.1985, -6.1104], + [-7.9179, -7.8374, -7.7679, -7.6976, -7.6041, -7.6446, -7.7229, -7.7693], + [-96.2018, -100.0000, -9.0710, -9.1415, -8.8863, -8.7228, -8.4668, -8.2761], + ] + ) + expected_advantages = torch.tensor( + [ + [-3.1452, 0.3006, 0.2779, 0.2966, 0.2951, 0.3060, 0.3122, 0.3246], + [0.3225, 0.3246, 0.3291, 0.3322, 0.3308, 0.3313, 0.3335, 0.3250], + [0.3190, 0.3307, 0.3259, 0.3368, 0.3415, 0.3371, 0.3338, 0.3316], + [-2.9412, -3.0893, 0.2797, 0.2808, 0.2992, 0.3000, 0.2997, 0.3179], + ] + ) + + self.assertTrue(torch.isclose(processed_returns, expected_returns, atol=1e-4).all()) + self.assertTrue(torch.isclose(processed_advantages, expected_advantages, atol=1e-4).all()) diff --git a/tests/test_ppo_recurrency.py b/tests/test_ppo_recurrency.py new file mode 100644 index 0000000..702b19d --- /dev/null +++ b/tests/test_ppo_recurrency.py @@ -0,0 +1,144 @@ +import torch +import unittest +from rsl_rl.algorithms import PPO +from rsl_rl.env.vec_env import VecEnv +from rsl_rl.runners.runner import Runner + + +class FakeNetwork(torch.nn.Module): + def __init__(self, values): + super().__init__() + + self.hidden_state = None + self.recurrent = True + self.values = values + + self._hidden_size = 2 + + def forward(self, x, hidden_state=None): + if not hidden_state: + self.hidden_state = (self.hidden_state[0] + 1, self.hidden_state[1] - 1) + + values = self.values.repeat((*x.shape[:-1], 1)).squeeze(-1) + values.requires_grad_(True) + + return values + + def reset_full_hidden_state(self, 
batch_size=None): + assert batch_size is None or batch_size == 4, f"batch_size={batch_size}" + + self.hidden_state = (torch.zeros((1, 4, self._hidden_size)), torch.zeros((1, 4, self._hidden_size))) + + def reset_hidden_state(self, indices): + self.hidden_state[0][:, indices] = torch.zeros((len(indices), self._hidden_size)) + self.hidden_state[1][:, indices] = torch.zeros((len(indices), self._hidden_size)) + + +class FakeActorNetwork(FakeNetwork): + def forward(self, x, compute_std=False, hidden_state=None): + values = super().forward(x, hidden_state=hidden_state) + + if compute_std: + return values, torch.ones_like(values) + + return values + + +class FakeEnv(VecEnv): + def __init__(self, dones=None, **kwargs): + super().__init__(3, 3, **kwargs) + + self.num_actions = 3 + self._extra = {"observations": {}, "time_outs": torch.zeros((self.num_envs, 1))} + + self._step = 0 + self._dones = dones + + self.reset() + + def get_observations(self): + return self._state_buf, self._extra + + def get_privileged_observations(self): + return self._state_buf, self._extra + + def reset(self): + self._state_buf = torch.zeros((self.num_envs, self.num_obs)) + + return self._state_buf, self._extra + + def step(self, actions): + assert actions.shape[0] == self.num_envs + assert actions.shape[1] == self.num_actions + + self._state_buf += actions + + rewards = torch.zeros((self.num_envs)) + dones = torch.zeros((self.num_envs)) if self._dones is None else self._dones[self._step % self._dones.shape[0]] + + self._step += 1 + + return self._state_buf, rewards, dones, self._extra + + +class PPORecurrencyTest(unittest.TestCase): + def test_draw_action_produces_hidden_state(self): + """Test that the hidden state is correctly added to the data dictionary when drawing actions.""" + env = FakeEnv(environment_count=4) + ppo = PPO(env, device="cpu", recurrent=True) + + ppo.actor = FakeActorNetwork(torch.ones(env.num_actions)) + ppo.critic = FakeNetwork(torch.zeros(1)) + + # Done during PPO.__init__, however we need to reset the hidden state here again since we are using a fake + # network that is added after initialization. + ppo.actor.reset_full_hidden_state(batch_size=env.num_envs) + ppo.critic.reset_full_hidden_state(batch_size=env.num_envs) + + ones = torch.ones((1, env.num_envs, ppo.actor._hidden_size)) + state, extra = env.reset() + for ctr in range(10): + _, data = ppo.draw_actions(state, extra) + + # Actor state is changed every time an action is drawn. + self.assertTrue(torch.allclose(data["actor_state_h"], ones * ctr)) + self.assertTrue(torch.allclose(data["actor_state_c"], -ones * ctr)) + # Critic state is only changed and saved when processing the transition (evaluating the action) so we can't + # check it here. 
+ + def test_update_produces_hidden_state(self): + """Test that the hidden state is correctly added to the data dictionary when updating.""" + dones = torch.cat((torch.tensor([[0, 0, 0, 1]]), torch.zeros((4, 4)), torch.tensor([[1, 0, 0, 0]])), dim=0) + + env = FakeEnv(dones=dones, environment_count=4) + ppo = PPO(env, device="cpu", recurrent=True) + runner = Runner(env, ppo, num_steps_per_env=6) + + ppo.actor = FakeActorNetwork(torch.ones(env.num_actions)) + ppo.critic = FakeNetwork(torch.zeros(1)) + + ppo.actor.reset_full_hidden_state(batch_size=env.num_envs) + ppo.critic.reset_full_hidden_state(batch_size=env.num_envs) + + runner.learn(1) + + state_h_0 = torch.tensor([[0, 0], [0, 0], [0, 0], [0, 0]]) + state_h_1 = torch.tensor([[1, 1], [1, 1], [1, 1], [0, 0]]) + state_h_2 = state_h_1 + 1 + state_h_3 = state_h_2 + 1 + state_h_4 = state_h_3 + 1 + state_h_5 = state_h_4 + 1 + state_h_6 = torch.tensor([[0, 0], [6, 6], [6, 6], [5, 5]]) + state_h = ( + torch.cat((state_h_0, state_h_1, state_h_2, state_h_3, state_h_4, state_h_5), dim=0).float().unsqueeze(1) + ) + next_state_h = ( + torch.cat((state_h_1, state_h_2, state_h_3, state_h_4, state_h_5, state_h_6), dim=0).float().unsqueeze(1) + ) + + self.assertTrue(torch.allclose(ppo.storage._data["critic_state_h"], state_h)) + self.assertTrue(torch.allclose(ppo.storage._data["critic_state_c"], -state_h)) + self.assertTrue(torch.allclose(ppo.storage._data["critic_next_state_h"], next_state_h)) + self.assertTrue(torch.allclose(ppo.storage._data["critic_next_state_c"], -next_state_h)) + self.assertTrue(torch.allclose(ppo.storage._data["actor_state_h"], state_h)) + self.assertTrue(torch.allclose(ppo.storage._data["actor_state_c"], -state_h)) diff --git a/tests/test_quantile_network.py b/tests/test_quantile_network.py new file mode 100644 index 0000000..eaea109 --- /dev/null +++ b/tests/test_quantile_network.py @@ -0,0 +1,286 @@ +import torch +import unittest +from rsl_rl.modules.quantile_network import QuantileNetwork + + +class QuantileNetworkTest(unittest.TestCase): + def test_l1_loss(self): + qn = QuantileNetwork(10, 1, quantile_count=5) + + prediction = torch.tensor( + [ + [0.8510, 0.2329, 0.4244, 0.5241, 0.2144], + [0.7693, 0.2522, 0.3909, 0.0858, 0.7914], + [0.8701, 0.2144, 0.9661, 0.9975, 0.5043], + [0.2653, 0.6951, 0.9787, 0.2244, 0.0430], + [0.7907, 0.5209, 0.7276, 0.1735, 0.2757], + [0.1696, 0.7167, 0.6363, 0.2188, 0.7025], + [0.0445, 0.6008, 0.5334, 0.1838, 0.7387], + [0.4934, 0.5117, 0.4488, 0.0591, 0.6442], + ] + ) + target = torch.tensor( + [ + [0.3918, 0.8979, 0.4347, 0.1076, 0.5303], + [0.5449, 0.9974, 0.3197, 0.8686, 0.0631], + [0.7397, 0.7734, 0.6559, 0.3020, 0.7229], + [0.9519, 0.8138, 0.1502, 0.3445, 0.3356], + [0.8970, 0.0910, 0.7536, 0.6069, 0.2556], + [0.1741, 0.6863, 0.7142, 0.2911, 0.3142], + [0.8835, 0.0215, 0.4774, 0.5362, 0.4998], + [0.8037, 0.8269, 0.5518, 0.4368, 0.5323], + ] + ) + + loss = qn.quantile_l1_loss(prediction, target) + + self.assertAlmostEqual(loss.item(), 0.16419549) + + def test_l1_loss_3d(self): + qn = QuantileNetwork(10, 1, quantile_count=5) + + prediction = torch.tensor( + [ + [ + [0.8510, 0.2329, 0.4244, 0.5241, 0.2144], + [0.7693, 0.2522, 0.3909, 0.0858, 0.7914], + [0.8701, 0.2144, 0.9661, 0.9975, 0.5043], + [0.2653, 0.6951, 0.9787, 0.2244, 0.0430], + [0.7907, 0.5209, 0.7276, 0.1735, 0.2757], + [0.1696, 0.7167, 0.6363, 0.2188, 0.7025], + [0.0445, 0.6008, 0.5334, 0.1838, 0.7387], + [0.4934, 0.5117, 0.4488, 0.0591, 0.6442], + ], + [ + [0.6874, 0.6214, 0.7796, 0.8148, 0.2070], + [0.0276, 0.5764, 0.5516, 
0.9682, 0.6901], + [0.4020, 0.7084, 0.9965, 0.4311, 0.3789], + [0.5350, 0.9431, 0.1032, 0.6959, 0.4992], + [0.5059, 0.5479, 0.2302, 0.6753, 0.1593], + [0.6753, 0.4590, 0.9956, 0.6117, 0.1410], + [0.7464, 0.7184, 0.2972, 0.7694, 0.7999], + [0.3907, 0.2112, 0.6485, 0.0139, 0.6252], + ], + ] + ) + target = torch.tensor( + [ + [ + [0.3918, 0.8979, 0.4347, 0.1076, 0.5303], + [0.5449, 0.9974, 0.3197, 0.8686, 0.0631], + [0.7397, 0.7734, 0.6559, 0.3020, 0.7229], + [0.9519, 0.8138, 0.1502, 0.3445, 0.3356], + [0.8970, 0.0910, 0.7536, 0.6069, 0.2556], + [0.1741, 0.6863, 0.7142, 0.2911, 0.3142], + [0.8835, 0.0215, 0.4774, 0.5362, 0.4998], + [0.8037, 0.8269, 0.5518, 0.4368, 0.5323], + ], + [ + [0.5120, 0.7683, 0.3579, 0.8640, 0.4374], + [0.2533, 0.3039, 0.2214, 0.7069, 0.3093], + [0.6993, 0.4288, 0.0827, 0.9156, 0.2043], + [0.6739, 0.2303, 0.3263, 0.6884, 0.3847], + [0.3990, 0.1458, 0.8918, 0.8036, 0.5012], + [0.9061, 0.2024, 0.7276, 0.8619, 0.1198], + [0.7379, 0.2005, 0.7634, 0.5691, 0.6132], + [0.4341, 0.5711, 0.1119, 0.4286, 0.7521], + ], + ] + ) + + loss = qn.quantile_l1_loss(prediction, target) + + self.assertAlmostEqual(loss.item(), 0.15836075) + + def test_l1_loss_multi_output(self): + qn = QuantileNetwork(10, 3, quantile_count=10) + + prediction = torch.tensor( + [ + [0.3003, 0.8692, 0.4608, 0.7158, 0.2640, 0.3928, 0.4557, 0.4620, 0.1331, 0.6356], + [0.8867, 0.1521, 0.5827, 0.0501, 0.4401, 0.7216, 0.6081, 0.5758, 0.2772, 0.6048], + [0.0763, 0.1609, 0.1860, 0.9173, 0.2121, 0.1920, 0.8509, 0.8588, 0.3321, 0.7202], + [0.8375, 0.5339, 0.4287, 0.9228, 0.8519, 0.0420, 0.5736, 0.9156, 0.4444, 0.2039], + [0.0704, 0.1833, 0.0839, 0.9573, 0.9852, 0.4191, 0.3562, 0.7225, 0.8481, 0.2096], + [0.4054, 0.8172, 0.8737, 0.2138, 0.4455, 0.7538, 0.1936, 0.9346, 0.8710, 0.0178], + [0.2139, 0.6619, 0.6889, 0.5726, 0.0595, 0.3278, 0.7673, 0.0803, 0.0374, 0.9011], + [0.2757, 0.0309, 0.8913, 0.0958, 0.1828, 0.9624, 0.6529, 0.7451, 0.9996, 0.8877], + [0.0722, 0.4240, 0.0716, 0.3199, 0.5570, 0.1056, 0.5950, 0.9926, 0.2991, 0.7334], + [0.0576, 0.6353, 0.5078, 0.4456, 0.9119, 0.6897, 0.1720, 0.5172, 0.9939, 0.5044], + [0.6300, 0.2304, 0.4064, 0.9195, 0.3299, 0.8631, 0.5842, 0.6751, 0.2964, 0.1215], + [0.7418, 0.5448, 0.7615, 0.6333, 0.9255, 0.1129, 0.0552, 0.4198, 0.9953, 0.7482], + [0.9910, 0.7644, 0.7047, 0.1395, 0.3688, 0.7688, 0.8574, 0.3494, 0.6153, 0.1286], + [0.2325, 0.7908, 0.3036, 0.4504, 0.3775, 0.6004, 0.0199, 0.9581, 0.8078, 0.8337], + [0.4038, 0.8313, 0.5441, 0.4778, 0.5777, 0.0580, 0.5314, 0.5336, 0.0740, 0.0094], + [0.9025, 0.5814, 0.4711, 0.2683, 0.4443, 0.5799, 0.6703, 0.2678, 0.7538, 0.1317], + [0.6755, 0.5696, 0.3334, 0.9146, 0.6203, 0.2080, 0.0799, 0.0059, 0.8347, 0.1874], + [0.0932, 0.0264, 0.9006, 0.3124, 0.3421, 0.8271, 0.3495, 0.2814, 0.9888, 0.5042], + [0.4893, 0.3514, 0.2564, 0.8117, 0.3738, 0.9085, 0.3055, 0.1456, 0.3624, 0.4095], + [0.0726, 0.2145, 0.6295, 0.7423, 0.1292, 0.7570, 0.4645, 0.0775, 0.1280, 0.7312], + [0.8763, 0.5302, 0.8627, 0.0429, 0.2833, 0.4745, 0.6308, 0.2245, 0.2755, 0.6823], + [0.9997, 0.3519, 0.0312, 0.1468, 0.5145, 0.0286, 0.6333, 0.1323, 0.2264, 0.9109], + [0.7742, 0.4857, 0.0413, 0.4523, 0.6847, 0.5774, 0.9478, 0.5861, 0.9834, 0.9437], + [0.7590, 0.5697, 0.7509, 0.3562, 0.9926, 0.3380, 0.0337, 0.7871, 0.1351, 0.9184], + [0.5701, 0.0234, 0.8088, 0.0681, 0.7090, 0.5925, 0.5266, 0.7198, 0.4121, 0.0268], + [0.5377, 0.1420, 0.2649, 0.0885, 0.1987, 0.1475, 0.1562, 0.2283, 0.9447, 0.4679], + [0.0306, 0.9763, 0.1234, 0.5009, 0.8800, 0.9409, 0.3525, 0.7264, 0.2209, 0.1436], + 
[0.2492, 0.4041, 0.9044, 0.3730, 0.3152, 0.7515, 0.2614, 0.9726, 0.6402, 0.5211], + [0.8626, 0.2828, 0.6946, 0.7066, 0.4395, 0.3015, 0.2643, 0.4421, 0.6036, 0.9009], + [0.7721, 0.1706, 0.7043, 0.4097, 0.7685, 0.3818, 0.1468, 0.6452, 0.1102, 0.1826], + [0.7156, 0.1795, 0.5574, 0.9478, 0.0058, 0.8037, 0.8712, 0.7730, 0.5638, 0.5843], + [0.8775, 0.6133, 0.4118, 0.3038, 0.2612, 0.2424, 0.8960, 0.8194, 0.3588, 0.3198], + ] + ) + + target = torch.tensor( + [ + [0.0986, 0.4029, 0.3110, 0.9976, 0.5668, 0.2658, 0.0660, 0.8492, 0.7872, 0.6368], + [0.3556, 0.9007, 0.0227, 0.7684, 0.0105, 0.9890, 0.7468, 0.0642, 0.5164, 0.1976], + [0.1331, 0.0998, 0.0959, 0.5596, 0.5984, 0.3880, 0.8050, 0.8320, 0.8977, 0.3486], + [0.3297, 0.8110, 0.2844, 0.4594, 0.0739, 0.2865, 0.2957, 0.9357, 0.9898, 0.4419], + [0.0495, 0.2826, 0.8306, 0.2968, 0.5690, 0.7251, 0.5947, 0.7526, 0.5076, 0.6480], + [0.0381, 0.8645, 0.7774, 0.9158, 0.9682, 0.5851, 0.0913, 0.8948, 0.1251, 0.1205], + [0.9059, 0.2758, 0.1948, 0.2694, 0.0946, 0.4381, 0.4667, 0.2176, 0.3494, 0.6073], + [0.1778, 0.8632, 0.3015, 0.2882, 0.4214, 0.2420, 0.8394, 0.1468, 0.9679, 0.6730], + [0.2400, 0.4344, 0.9765, 0.6544, 0.6338, 0.3434, 0.4776, 0.7981, 0.2008, 0.2267], + [0.5574, 0.8110, 0.0264, 0.4199, 0.8178, 0.8421, 0.8237, 0.2623, 0.8025, 0.9030], + [0.8652, 0.2872, 0.9463, 0.5543, 0.4866, 0.2842, 0.6692, 0.2306, 0.3136, 0.4570], + [0.0651, 0.8955, 0.7531, 0.9373, 0.0265, 0.0795, 0.7755, 0.1123, 0.1920, 0.3273], + [0.9824, 0.4177, 0.2729, 0.9447, 0.3987, 0.5495, 0.3674, 0.8067, 0.8668, 0.2394], + [0.4874, 0.3616, 0.7577, 0.6439, 0.2927, 0.8110, 0.6821, 0.0702, 0.5514, 0.7358], + [0.3627, 0.6392, 0.9085, 0.3646, 0.6051, 0.0586, 0.8763, 0.3899, 0.3242, 0.4598], + [0.0167, 0.0558, 0.3862, 0.7017, 0.0403, 0.6604, 0.9992, 0.2337, 0.5128, 0.1959], + [0.7774, 0.9201, 0.0405, 0.7894, 0.1406, 0.2458, 0.2616, 0.8787, 0.8158, 0.8591], + [0.3225, 0.9827, 0.4032, 0.2621, 0.7949, 0.9796, 0.9480, 0.3353, 0.1430, 0.5747], + [0.4734, 0.8714, 0.9320, 0.4265, 0.7765, 0.6980, 0.1587, 0.8784, 0.7119, 0.5141], + [0.7263, 0.4754, 0.8234, 0.0649, 0.4343, 0.5201, 0.8274, 0.9632, 0.3525, 0.8893], + [0.3324, 0.0142, 0.7222, 0.5026, 0.6011, 0.9275, 0.9351, 0.9236, 0.2621, 0.0768], + [0.8456, 0.1005, 0.5550, 0.0586, 0.3811, 0.0168, 0.9724, 0.9225, 0.7242, 0.0678], + [0.2167, 0.5423, 0.9059, 0.3320, 0.4026, 0.2128, 0.4562, 0.3564, 0.2573, 0.1076], + [0.8385, 0.2233, 0.0736, 0.3407, 0.4702, 0.1668, 0.5174, 0.4154, 0.4407, 0.1843], + [0.1828, 0.5321, 0.6651, 0.4108, 0.5736, 0.4012, 0.0434, 0.0034, 0.9282, 0.3111], + [0.1754, 0.8750, 0.6629, 0.7052, 0.9739, 0.7441, 0.8954, 0.9273, 0.3836, 0.5735], + [0.5586, 0.0381, 0.1493, 0.8575, 0.9351, 0.5222, 0.5600, 0.2369, 0.9217, 0.2545], + [0.1054, 0.8020, 0.8463, 0.6495, 0.3011, 0.3734, 0.7263, 0.8736, 0.9258, 0.5804], + [0.7614, 0.4748, 0.6588, 0.7717, 0.9811, 0.1659, 0.7851, 0.2135, 0.1767, 0.6724], + [0.7655, 0.8571, 0.4224, 0.9397, 0.1363, 0.9431, 0.9326, 0.3762, 0.1077, 0.9514], + [0.4115, 0.2169, 0.1340, 0.6564, 0.9989, 0.8068, 0.0387, 0.5064, 0.9964, 0.9427], + [0.5760, 0.2967, 0.3891, 0.6596, 0.8037, 0.1060, 0.0102, 0.8672, 0.5922, 0.6684], + ] + ) + + loss = qn.quantile_l1_loss(prediction, target) + + self.assertAlmostEqual(loss.item(), 0.17235948) + + def test_quantile_huber_loss(self): + qn = QuantileNetwork(10, 1, quantile_count=5) + + prediction = torch.tensor( + [ + [0.8510, 0.2329, 0.4244, 0.5241, 0.2144], + [0.7693, 0.2522, 0.3909, 0.0858, 0.7914], + [0.8701, 0.2144, 0.9661, 0.9975, 0.5043], + [0.2653, 0.6951, 0.9787, 0.2244, 0.0430], 
+ [0.7907, 0.5209, 0.7276, 0.1735, 0.2757], + [0.1696, 0.7167, 0.6363, 0.2188, 0.7025], + [0.0445, 0.6008, 0.5334, 0.1838, 0.7387], + [0.4934, 0.5117, 0.4488, 0.0591, 0.6442], + ] + ) + target = torch.tensor( + [ + [0.3918, 0.8979, 0.4347, 0.1076, 0.5303], + [0.5449, 0.9974, 0.3197, 0.8686, 0.0631], + [0.7397, 0.7734, 0.6559, 0.3020, 0.7229], + [0.9519, 0.8138, 0.1502, 0.3445, 0.3356], + [0.8970, 0.0910, 0.7536, 0.6069, 0.2556], + [0.1741, 0.6863, 0.7142, 0.2911, 0.3142], + [0.8835, 0.0215, 0.4774, 0.5362, 0.4998], + [0.8037, 0.8269, 0.5518, 0.4368, 0.5323], + ] + ) + + loss = qn.quantile_huber_loss(prediction, target) + + self.assertAlmostEqual(loss.item(), 0.04035041) + + def test_sample_energy_loss(self): + qn = QuantileNetwork(10, 1, quantile_count=5) + + prediction = torch.tensor( + [ + [0.9813, 0.5331, 0.3298, 0.2428, 0.0737], + [0.5442, 0.9623, 0.6070, 0.9360, 0.1145], + [0.3642, 0.0887, 0.1696, 0.8027, 0.7121], + [0.2005, 0.9889, 0.4350, 0.0301, 0.4546], + [0.8360, 0.6766, 0.2257, 0.7589, 0.3443], + [0.0835, 0.1747, 0.1734, 0.6668, 0.4522], + [0.0851, 0.3146, 0.0316, 0.2250, 0.5729], + [0.7725, 0.4596, 0.2495, 0.3633, 0.6340], + ] + ) + target = torch.tensor( + [ + [0.5365, 0.1495, 0.8120, 0.2595, 0.1409], + [0.7784, 0.7070, 0.9066, 0.0123, 0.5587], + [0.9097, 0.0773, 0.9430, 0.2747, 0.1912], + [0.2307, 0.5068, 0.4624, 0.6708, 0.2844], + [0.3356, 0.5885, 0.2484, 0.8468, 0.1833], + [0.3354, 0.8831, 0.3489, 0.7165, 0.7953], + [0.7577, 0.8578, 0.2735, 0.1029, 0.5621], + [0.9124, 0.3476, 0.2012, 0.5830, 0.4615], + ] + ) + + loss = qn.sample_energy_loss(prediction, target) + + self.assertAlmostEqual(loss.item(), 0.09165202) + + def test_cvar(self): + qn = QuantileNetwork(10, 1, quantile_count=5) + measure = qn.measures[qn.measure_cvar](qn, 0.5) + + # Quantiles for 3 agents + input = torch.tensor( + [ + [0.1056, 0.0609, 0.3523, 0.3033, 0.1779], + [0.2049, 0.1425, 0.0767, 0.1868, 0.3891], + [0.1899, 0.1527, 0.2420, 0.2623, 0.1532], + ] + ) + correct_output = torch.tensor( + [ + (0.4 * 0.0609 + 0.4 * 0.1056 + 0.2 * 0.1779), + (0.4 * 0.0767 + 0.4 * 0.1425 + 0.2 * 0.1868), + (0.4 * 0.1527 + 0.4 * 0.1532 + 0.2 * 0.1899), + ] + ) + + computed_output = measure(input) + + self.assertTrue(torch.isclose(computed_output, correct_output).all()) + + def test_cvar_adaptive(self): + qn = QuantileNetwork(10, 1, quantile_count=5) + + input = torch.tensor( + [ + [0.95, 0.21, 0.27, 0.26, 0.19], + [0.38, 0.34, 0.18, 0.32, 0.97], + [0.70, 0.24, 0.38, 0.89, 0.96], + ] + ) + confidence_levels = torch.tensor([0.1, 0.7, 0.9]) + correct_output = torch.tensor( + [ + 0.19, + (0.18 / 3.5 + 0.32 / 3.5 + 0.34 / 3.5 + 0.38 / 7.0), + (0.24 / 4.5 + 0.38 / 4.5 + 0.70 / 4.5 + 0.89 / 4.5 + 0.96 / 9.0), + ] + ) + + measure = qn.measures[qn.measure_cvar](qn, confidence_levels) + computed_output = measure(input) + + self.assertTrue(torch.isclose(computed_output, correct_output).all()) diff --git a/tests/test_trajectory_conversion.py b/tests/test_trajectory_conversion.py new file mode 100644 index 0000000..796407c --- /dev/null +++ b/tests/test_trajectory_conversion.py @@ -0,0 +1,33 @@ +import torch +import unittest + +from rsl_rl.utils.recurrency import trajectories_to_transitions, transitions_to_trajectories + + +class TrajectoryConversionTest(unittest.TestCase): + def test_basic_conversion(self): + input = torch.rand(128, 24) + dones = (torch.rand(128, 24) > 0.8).float() + + trajectories, data = transitions_to_trajectories(input, dones) + transitions = trajectories_to_transitions(trajectories, data) + + 
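+        # Splitting the transitions into per-episode trajectories (cut at dones) and converting back
+        # should be lossless; `data` carries whatever bookkeeping the inverse transform needs.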
+        self.assertTrue(torch.allclose(input, transitions))
+
+    def test_2d_observations(self):
+        input = torch.rand(128, 24, 32)
+        dones = (torch.rand(128, 24) > 0.8).float()
+
+        trajectories, data = transitions_to_trajectories(input, dones)
+        transitions = trajectories_to_transitions(trajectories, data)
+
+        self.assertTrue(torch.allclose(input, transitions))
+
+    def test_batch_first(self):
+        input = torch.rand(128, 24, 32)
+        dones = (torch.rand(128, 24) > 0.8).float()
+
+        trajectories, data = transitions_to_trajectories(input, dones, batch_first=True)
+        transitions = trajectories_to_transitions(trajectories, data)
+
+        self.assertTrue(torch.allclose(input, transitions))
diff --git a/tests/test_transformer.py b/tests/test_transformer.py
new file mode 100644
index 0000000..1a7ebcc
--- /dev/null
+++ b/tests/test_transformer.py
@@ -0,0 +1,58 @@
+import unittest
+import torch
+
+from rsl_rl.modules import Transformer
+
+
+class TestTransformer(unittest.TestCase):
+    def setUp(self):
+        self.input_size = 9
+        self.output_size = 12
+        self.hidden_size = 64
+        self.block_count = 2
+        self.context_length = 32
+        self.head_count = 4
+        self.batch_size = 10
+        self.sequence_length = 16
+
+        self.transformer = Transformer(
+            self.input_size, self.output_size, self.hidden_size, self.block_count, self.context_length, self.head_count
+        )
+
+    def test_num_layers(self):
+        self.assertEqual(self.transformer.num_layers, self.context_length // 2)
+
+    def test_reset_hidden_state(self):
+        hidden_state = self.transformer.reset_hidden_state(self.batch_size)
+        self.assertIsInstance(hidden_state, tuple)
+        self.assertEqual(len(hidden_state), 2)
+        self.assertTrue(
+            torch.equal(hidden_state[0], torch.zeros((self.transformer.num_layers, self.batch_size, self.hidden_size)))
+        )
+        self.assertTrue(
+            torch.equal(hidden_state[1], torch.zeros((self.transformer.num_layers, self.batch_size, self.hidden_size)))
+        )
+
+    def test_step(self):
+        x = torch.rand(self.sequence_length, self.batch_size, self.input_size)
+        context = torch.rand(self.context_length, self.batch_size, self.hidden_size)
+
+        out, new_context = self.transformer.step(x, context)
+
+        self.assertEqual(out.shape, (self.sequence_length, self.batch_size, self.output_size))
+        self.assertEqual(new_context.shape, (self.context_length, self.batch_size, self.hidden_size))
+
+    def test_forward(self):
+        x = torch.rand(self.sequence_length, self.batch_size, self.input_size)
+        hidden_state = self.transformer.reset_hidden_state(self.batch_size)
+
+        out, new_hidden_state = self.transformer.forward(x, hidden_state)
+
+        self.assertEqual(out.shape, (self.sequence_length, self.batch_size, self.output_size))
+        self.assertEqual(len(new_hidden_state), 2)
+        self.assertEqual(new_hidden_state[0].shape, (self.transformer.num_layers, self.batch_size, self.hidden_size))
+        self.assertEqual(new_hidden_state[1].shape, (self.transformer.num_layers, self.batch_size, self.hidden_size))
+
+
+if __name__ == "__main__":
+    unittest.main()