Compare commits
6 Commits
master
...
algorithms
Author | SHA1 | Date |
---|---|---|
Lukas Schneider | 96393c41c5 | |
Lukas Schneider | c7e950439d | |
Lukas Schneider | 9c0fcdc677 | |
Lukas Schneider | 7ce11711d4 | |
Lukas Schneider | dc9f33a3c3 | |
Lukas Schneider | 96dd4929c5 |
|
@ -7,6 +7,17 @@
|
|||
# cache
|
||||
__pycache__
|
||||
.pytest_cache
|
||||
wandb/
|
||||
|
||||
# vs code
|
||||
.vscode
|
||||
|
||||
# data
|
||||
videos/
|
||||
|
||||
# secrets
|
||||
examples/wandb_config.py
|
||||
|
||||
# docs
|
||||
docs/_build
|
||||
docs/source
|
||||
|
|
|
@ -25,6 +25,7 @@ Please keep the lists sorted alphabetically.
|
|||
* Eric Vollenweider
|
||||
* Fabian Jenelten
|
||||
* Lorenzo Terenzi
|
||||
* Lukas Schneider
|
||||
* Marko Bjelonic
|
||||
* Matthijs van der Boon
|
||||
* Mayank Mittal
|
||||
|
|
97
README.md
97
README.md
|
@ -1,57 +1,78 @@
|
|||
# RSL RL
|
||||
|
||||
Fast and simple implementation of RL algorithms, designed to run fully on GPU.
|
||||
This code is an evolution of `rl-pytorch` provided with NVIDIA's Isaac GYM.
|
||||
|
||||
Only PPO is implemented for now. More algorithms will be added later.
|
||||
Contributions are welcome.
|
||||
Currently, the following algorithms are implemented:
|
||||
- Distributed Distributional DDPG (D4PG)
|
||||
- Deep Deterministic Policy Gradient (DDPG)
|
||||
- Distributional PPO (DPPO)
|
||||
- Distributional Soft Actor Critic (DSAC)
|
||||
- Proximal Policy Optimization (PPO)
|
||||
- Soft Actor Critic (SAC)
|
||||
- Twin Delayed DDPG (TD3)
|
||||
|
||||
**Maintainer**: David Hoeller and Nikita Rudin <br/>
|
||||
**Maintainer**: David Hoeller, Nikita Rudin <br/>
|
||||
**Affiliation**: Robotic Systems Lab, ETH Zurich & NVIDIA <br/>
|
||||
**Contact**: rudinn@ethz.ch
|
||||
**Contact**: Nikita Rudin (rudinn@ethz.ch), Lukas Schneider (lukas@luschneider.com)
|
||||
|
||||
## Setup
|
||||
## Citation
|
||||
|
||||
Following are the instructions to setup the repository for your workspace:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/leggedrobotics/rsl_rl
|
||||
cd rsl_rl
|
||||
pip install -e .
|
||||
If you use our code in your research, please cite us:
|
||||
```
|
||||
@misc{schneider2023learning,
|
||||
archivePrefix={arXiv},
|
||||
author={Lukas Schneider and Jonas Frey and Takahiro Miki and Marco Hutter},
|
||||
eprint={2309.14246},
|
||||
primaryClass={cs.RO}
|
||||
title={Learning Risk-Aware Quadrupedal Locomotion using Distributional Reinforcement Learning},
|
||||
year={2023},
|
||||
}
|
||||
```
|
||||
|
||||
The framework supports the following logging frameworks which can be configured through `logger`:
|
||||
## Installation
|
||||
|
||||
* Tensorboard: https://www.tensorflow.org/tensorboard/
|
||||
* Weights & Biases: https://wandb.ai/site
|
||||
* Neptune: https://docs.neptune.ai/
|
||||
To install the package, run the following command in the root directory of the repository:
|
||||
|
||||
For a demo configuration of the PPO, please check: [dummy_config.yaml](config/dummy_config.yaml) file.
|
||||
```bash
|
||||
$ pip3 install -e .
|
||||
```
|
||||
|
||||
Examples can be run from the `examples/` directory.
|
||||
The example directory also include hyperparameters tuned for some gym environments.
|
||||
These are automatically loaded when running the example.
|
||||
Videos of the trained policies are periodically saved to the `videos/` directory.
|
||||
|
||||
```bash
|
||||
$ python3 examples/example.py
|
||||
```
|
||||
|
||||
To run gym mujoco environments, you need a working installation of the mujoco simulator and [mujoco_py](https://github.com/openai/mujoco-py).
|
||||
|
||||
## Tests
|
||||
|
||||
The repository contains a set of tests to ensure that the algorithms are working as expected.
|
||||
To run the tests, simply execute:
|
||||
|
||||
```bash
|
||||
$ cd tests/ && python -m unittest
|
||||
```
|
||||
|
||||
## Documentation
|
||||
|
||||
To generate documentation, run the following command in the root directory of the repository:
|
||||
|
||||
```bash
|
||||
$ pip3 install sphinx sphinx-rtd-theme
|
||||
$ sphinx-apidoc -o docs/source . ./examples
|
||||
$ cd docs/ && make html
|
||||
```
|
||||
|
||||
## Contribution Guidelines
|
||||
|
||||
For documentation, we adopt the [Google Style Guide](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) for docstrings. We use [Sphinx](https://www.sphinx-doc.org/en/master/) for generating the documentation. Please make sure that your code is well-documented and follows the guidelines.
|
||||
|
||||
We use the following tools for maintaining code quality:
|
||||
|
||||
- [pre-commit](https://pre-commit.com/): Runs a list of formatters and linters over the codebase.
|
||||
- [black](https://black.readthedocs.io/en/stable/): The uncompromising code formatter.
|
||||
- [flake8](https://flake8.pycqa.org/en/latest/): A wrapper around PyFlakes, pycodestyle, and McCabe complexity checker.
|
||||
|
||||
Please check [here](https://pre-commit.com/#install) for instructions to set these up. To run over the entire repository, please execute the following command in the terminal:
|
||||
|
||||
We use [`black`](https://github.com/psf/black) formatter for formatting the python code.
|
||||
You should [configure `black` with VSCode](https://dev.to/adamlombard/how-to-use-the-black-python-code-formatter-in-vscode-3lo0) or you can manually format files with:
|
||||
|
||||
```bash
|
||||
# for installation (only once)
|
||||
pre-commit install
|
||||
# for running
|
||||
pre-commit run --all-files
|
||||
$ pip install black
|
||||
$ black --line-length 120 .
|
||||
```
|
||||
|
||||
### Useful Links
|
||||
|
||||
Environment repositories using the framework:
|
||||
|
||||
* `Legged-Gym` (built on top of NVIDIA Isaac Gym): https://leggedrobotics.github.io/legged_gym/
|
||||
* `Orbit` (built on top of NVIDIA Isaac Sim): https://isaac-orbit.github.io/
|
||||
|
|
|
@ -1,48 +0,0 @@
|
|||
algorithm:
|
||||
class_name: PPO
|
||||
# training parameters
|
||||
# -- value function
|
||||
value_loss_coef: 1.0
|
||||
clip_param: 0.2
|
||||
use_clipped_value_loss: true
|
||||
# -- surrogate loss
|
||||
desired_kl: 0.01
|
||||
entropy_coef: 0.01
|
||||
gamma: 0.99
|
||||
lam: 0.95
|
||||
max_grad_norm: 1.0
|
||||
# -- training
|
||||
learning_rate: 0.001
|
||||
num_learning_epochs: 5
|
||||
num_mini_batches: 4 # mini batch size = num_envs * num_steps / num_mini_batches
|
||||
schedule: adaptive # adaptive, fixed
|
||||
policy:
|
||||
class_name: ActorCritic
|
||||
# for MLP i.e. `ActorCritic`
|
||||
activation: elu
|
||||
actor_hidden_dims: [128, 128, 128]
|
||||
critic_hidden_dims: [128, 128, 128]
|
||||
init_noise_std: 1.0
|
||||
# only needed for `ActorCriticRecurrent`
|
||||
# rnn_type: 'lstm'
|
||||
# rnn_hidden_size: 512
|
||||
# rnn_num_layers: 1
|
||||
runner:
|
||||
num_steps_per_env: 24 # number of steps per environment per iteration
|
||||
max_iterations: 1500 # number of policy updates
|
||||
empirical_normalization: false
|
||||
# -- logging parameters
|
||||
save_interval: 50 # check for potential saves every `save_interval` iterations
|
||||
experiment_name: walking_experiment
|
||||
run_name: ""
|
||||
# -- logging writer
|
||||
logger: tensorboard # tensorboard, neptune, wandb
|
||||
neptune_project: legged_gym
|
||||
wandb_project: legged_gym
|
||||
# -- load and resuming
|
||||
resume: false
|
||||
load_run: -1 # -1 means load latest run
|
||||
resume_path: null # updated from load_run and checkpoint
|
||||
checkpoint: -1 # -1 means load latest checkpoint
|
||||
runner_class_name: OnPolicyRunner
|
||||
seed: 1
|
|
@ -0,0 +1,20 @@
|
|||
# Minimal makefile for Sphinx documentation
|
||||
#
|
||||
|
||||
# You can set these variables from the command line, and also
|
||||
# from the environment for the first two.
|
||||
SPHINXOPTS ?=
|
||||
SPHINXBUILD ?= sphinx-build
|
||||
SOURCEDIR = .
|
||||
BUILDDIR = _build
|
||||
|
||||
# Put it first so that "make" without argument is like "make help".
|
||||
help:
|
||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
.PHONY: help Makefile
|
||||
|
||||
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||
%: Makefile
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
|
@ -0,0 +1,32 @@
|
|||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# For the full list of built-in configuration values, see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.abspath(".."))
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
||||
|
||||
project = "rsl_rl"
|
||||
copyright = "2023, Lukas Schneider"
|
||||
author = "Lukas Schneider"
|
||||
release = "1.0.0"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
||||
extensions = ["sphinx.ext.autodoc", "sphinx.ext.napoleon"]
|
||||
|
||||
templates_path = ["_templates"]
|
||||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
||||
|
||||
html_theme = "sphinx_rtd_theme"
|
||||
html_static_path = ["_static"]
|
|
@ -0,0 +1,20 @@
|
|||
.. rsl_rl documentation master file, created by
|
||||
sphinx-quickstart on Tue Jul 4 16:39:24 2023.
|
||||
You can adapt this file completely to your liking, but it should at least
|
||||
contain the root `toctree` directive.
|
||||
|
||||
Welcome to rsl_rl's documentation!
|
||||
==================================
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: Contents:
|
||||
|
||||
|
||||
|
||||
Indices and tables
|
||||
==================
|
||||
|
||||
* :ref:`genindex`
|
||||
* :ref:`modindex`
|
||||
* :ref:`search`
|
|
@ -0,0 +1,35 @@
|
|||
@ECHO OFF
|
||||
|
||||
pushd %~dp0
|
||||
|
||||
REM Command file for Sphinx documentation
|
||||
|
||||
if "%SPHINXBUILD%" == "" (
|
||||
set SPHINXBUILD=sphinx-build
|
||||
)
|
||||
set SOURCEDIR=.
|
||||
set BUILDDIR=_build
|
||||
|
||||
%SPHINXBUILD% >NUL 2>NUL
|
||||
if errorlevel 9009 (
|
||||
echo.
|
||||
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||
echo.may add the Sphinx directory to PATH.
|
||||
echo.
|
||||
echo.If you don't have Sphinx installed, grab it from
|
||||
echo.https://www.sphinx-doc.org/
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
if "%1" == "" goto help
|
||||
|
||||
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
goto end
|
||||
|
||||
:help
|
||||
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
|
||||
:end
|
||||
popd
|
|
@ -0,0 +1,92 @@
|
|||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
import wandb
|
||||
|
||||
from rsl_rl.algorithms import *
|
||||
from rsl_rl.env.gym_env import GymEnv
|
||||
from rsl_rl.runners.runner import Runner
|
||||
from rsl_rl.runners.callbacks import make_wandb_cb
|
||||
|
||||
from hyperparams import hyperparams
|
||||
from wandb_config import WANDB_API_KEY, WANDB_ENTITY
|
||||
|
||||
|
||||
ALGORITHMS = [PPO, DPPO]
|
||||
ENVIRONMENTS = ["BipedalWalker-v3"]
|
||||
ENVIRONMENT_KWARGS = [{}]
|
||||
EXPERIMENT_DIR = os.environ.get("EXPERIMENT_DIRECTORY", "./")
|
||||
EXPERIMENT_NAME = os.environ.get("EXPERIMENT_NAME", "benchmark")
|
||||
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
RENDER_VIDEO = False
|
||||
RETURN_EPOCHS = 100 # Number of epochs to average return over
|
||||
LOG_WANDB = True
|
||||
RUNS = 3
|
||||
TRAIN_TIMEOUT = 60 * 10 # Training time (in seconds)
|
||||
TRAIN_ENV_STEPS = None # Number of training environment steps
|
||||
|
||||
|
||||
os.environ["WANDB_API_KEY"] = WANDB_API_KEY
|
||||
|
||||
|
||||
def run(alg_class, env_name, env_kwargs={}):
|
||||
try:
|
||||
hp = hyperparams[alg_class.__name__][env_name]
|
||||
except KeyError:
|
||||
print("No hyperparameters found. Using default values.")
|
||||
hp = dict(agent_kwargs={}, env_kwargs={"environment_count": 1}, runner_kwargs={"num_steps_per_env": 1})
|
||||
|
||||
agent_kwargs = dict(device=DEVICE, **hp["agent_kwargs"])
|
||||
env_kwargs = dict(name=env_name, gym_kwargs=env_kwargs, **hp["env_kwargs"])
|
||||
runner_kwargs = dict(device=DEVICE, **hp["runner_kwargs"])
|
||||
|
||||
learn_steps = (
|
||||
None
|
||||
if TRAIN_ENV_STEPS is None
|
||||
else int(np.ceil(TRAIN_ENV_STEPS / (env_kwargs["environment_count"] * runner_kwargs["num_steps_per_env"])))
|
||||
)
|
||||
learn_timeout = None if TRAIN_TIMEOUT is None else TRAIN_TIMEOUT
|
||||
|
||||
video_directory = f"{EXPERIMENT_DIR}/{EXPERIMENT_NAME}/videos/{env_name}/{alg_class.__name__}"
|
||||
save_video_cb = (
|
||||
lambda ep, file: wandb.log({f"video-{ep}": wandb.Video(file, fps=4, format="mp4")}) if LOG_WANDB else None
|
||||
)
|
||||
env = GymEnv(**env_kwargs, draw=RENDER_VIDEO, draw_cb=save_video_cb, draw_directory=video_directory)
|
||||
agent = alg_class(env, **agent_kwargs)
|
||||
|
||||
config = dict(
|
||||
agent_kwargs=agent_kwargs,
|
||||
env_kwargs=env_kwargs,
|
||||
learn_steps=learn_steps,
|
||||
learn_timeout=learn_timeout,
|
||||
runner_kwargs=runner_kwargs,
|
||||
)
|
||||
wandb_learn_config = dict(
|
||||
config=config,
|
||||
entity=WANDB_ENTITY,
|
||||
group=f"{alg_class.__name__}_{env_name}",
|
||||
project="rsl_rl-benchmark",
|
||||
tags=[alg_class.__name__, env_name, "train"],
|
||||
)
|
||||
|
||||
runner = Runner(env, agent, **runner_kwargs)
|
||||
runner._learn_cb = [lambda *args, **kwargs: Runner._log(*args, prefix=f"{alg_class.__name__}_{env_name}", **kwargs)]
|
||||
if LOG_WANDB:
|
||||
runner._learn_cb.append(make_wandb_cb(wandb_learn_config))
|
||||
|
||||
runner.learn(iterations=learn_steps, timeout=learn_timeout, return_epochs=RETURN_EPOCHS)
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
def main():
|
||||
for algorithm in ALGORITHMS:
|
||||
for i, env_name in enumerate(ENVIRONMENTS):
|
||||
env_kwargs = ENVIRONMENT_KWARGS[i]
|
||||
|
||||
for _ in range(RUNS):
|
||||
run(algorithm, env_name, env_kwargs=env_kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,26 @@
|
|||
import torch
|
||||
|
||||
from rsl_rl.algorithms import *
|
||||
from rsl_rl.env.gym_env import GymEnv
|
||||
from rsl_rl.runners.runner import Runner
|
||||
from hyperparams import hyperparams
|
||||
|
||||
|
||||
ALGORITHM = DPPO
|
||||
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
TASK = "BipedalWalker-v3"
|
||||
|
||||
|
||||
def main():
|
||||
hp = hyperparams[ALGORITHM.__name__][TASK]
|
||||
|
||||
env = GymEnv(name=TASK, device=DEVICE, draw=True, **hp["env_kwargs"])
|
||||
agent = ALGORITHM(env, benchmark=True, device=DEVICE, **hp["agent_kwargs"])
|
||||
runner = Runner(env, agent, device=DEVICE, **hp["runner_kwargs"])
|
||||
runner._learn_cb = [Runner._log]
|
||||
|
||||
runner.learn(5000)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,7 @@
|
|||
from rsl_rl.algorithms import PPO, DPPO
|
||||
from .dppo import dppo_hyperparams
|
||||
from .ppo import ppo_hyperparams
|
||||
|
||||
hyperparams = {DPPO.__name__: dppo_hyperparams, PPO.__name__: ppo_hyperparams}
|
||||
|
||||
__all__ = ["hyperparams"]
|
|
@ -0,0 +1,202 @@
|
|||
import copy
|
||||
import numpy as np
|
||||
|
||||
from rsl_rl.algorithms import DPPO
|
||||
from rsl_rl.modules import QuantileNetwork
|
||||
|
||||
default = dict()
|
||||
default["env_kwargs"] = dict(environment_count=1)
|
||||
default["runner_kwargs"] = dict(num_steps_per_env=2048)
|
||||
default["agent_kwargs"] = dict(
|
||||
actor_activations=["tanh", "tanh", "linear"],
|
||||
actor_hidden_dims=[64, 64],
|
||||
actor_input_normalization=False,
|
||||
actor_noise_std=np.exp(0.0),
|
||||
batch_count=(default["env_kwargs"]["environment_count"] * default["runner_kwargs"]["num_steps_per_env"] // 64),
|
||||
clip_ratio=0.2,
|
||||
critic_activations=["tanh", "tanh"],
|
||||
critic_hidden_dims=[64, 64],
|
||||
critic_input_normalization=False,
|
||||
entropy_coeff=0.0,
|
||||
gae_lambda=0.95,
|
||||
gamma=0.99,
|
||||
gradient_clip=0.5,
|
||||
learning_rate=0.0003,
|
||||
qrdqn_quantile_count=50,
|
||||
schedule="adaptive",
|
||||
target_kl=0.01,
|
||||
value_coeff=0.5,
|
||||
value_measure=QuantileNetwork.measure_neutral,
|
||||
value_measure_kwargs={},
|
||||
)
|
||||
|
||||
# Parameters optimized for PPO
|
||||
ant_v4 = copy.deepcopy(default)
|
||||
ant_v4["env_kwargs"]["environment_count"] = 128
|
||||
ant_v4["runner_kwargs"]["num_steps_per_env"] = 64
|
||||
ant_v4["agent_kwargs"]["actor_activations"] = ["tanh", "tanh", "linear"]
|
||||
ant_v4["agent_kwargs"]["actor_hidden_dims"] = [64, 64]
|
||||
ant_v4["agent_kwargs"]["actor_noise_std"] = 0.2611
|
||||
ant_v4["agent_kwargs"]["batch_count"] = 12
|
||||
ant_v4["agent_kwargs"]["clip_ratio"] = 0.4
|
||||
ant_v4["agent_kwargs"]["critic_activations"] = ["tanh", "tanh"]
|
||||
ant_v4["agent_kwargs"]["critic_hidden_dims"] = [64, 64]
|
||||
ant_v4["agent_kwargs"]["entropy_coeff"] = 0.0102
|
||||
ant_v4["agent_kwargs"]["gae_lambda"] = 0.92
|
||||
ant_v4["agent_kwargs"]["gamma"] = 0.9731
|
||||
ant_v4["agent_kwargs"]["gradient_clip"] = 5.0
|
||||
ant_v4["agent_kwargs"]["learning_rate"] = 0.8755
|
||||
ant_v4["agent_kwargs"]["target_kl"] = 0.1711
|
||||
ant_v4["agent_kwargs"]["value_coeff"] = 0.6840
|
||||
|
||||
"""
|
||||
Tuned for environment interactions:
|
||||
[I 2023-01-03 03:11:29,212] Trial 19 finished with value: 0.5272218152693111 and parameters: {
|
||||
'env_count': 16,
|
||||
'actor_noise_std': 0.7304437880901905,
|
||||
'batch_count': 10,
|
||||
'clip_ratio': 0.3,
|
||||
'entropy_coeff': 0.004236574285220795,
|
||||
'gae_lambda': 0.95,
|
||||
'gamma': 0.9890074826092162,
|
||||
'gradient_clip': 0.9,
|
||||
'learning_rate': 0.18594043324129061,
|
||||
'steps_per_env': 256,
|
||||
'target_kl': 0.05838576142010138,
|
||||
'value_coeff': 0.14402022632575992,
|
||||
'net_arch': 'small',
|
||||
'net_activation': 'relu'
|
||||
}. Best is trial 19 with value: 0.5272218152693111.
|
||||
Tuned for training time:
|
||||
[I 2023-01-08 21:09:06,069] Trial 407 finished with value: 7.497591958940029 and parameters: {
|
||||
'actor_noise_std': 0.1907398121300662,
|
||||
'batch_count': 3,
|
||||
'clip_ratio': 0.1,
|
||||
'entropy_coeff': 0.0053458057035692735,
|
||||
'env_count': 16,
|
||||
'gae_lambda': 0.8,
|
||||
'gamma': 0.985000267068182,
|
||||
'gradient_clip': 2.0,
|
||||
'learning_rate': 0.605956844400053,
|
||||
'steps_per_env': 512,
|
||||
'target_kl': 0.17611450607281642,
|
||||
'value_coeff': 0.46015664905111847,
|
||||
'actor_net_arch': 'small',
|
||||
'critic_net_arch': 'medium',
|
||||
'actor_net_activation': 'relu',
|
||||
'critic_net_activation': 'relu',
|
||||
'qrdqn_quantile_count': 200,
|
||||
'value_measure': 'neutral'
|
||||
}. Best is trial 407 with value: 7.497591958940029.
|
||||
"""
|
||||
bipedal_walker_v3 = copy.deepcopy(default)
|
||||
bipedal_walker_v3["env_kwargs"]["environment_count"] = 256
|
||||
bipedal_walker_v3["runner_kwargs"]["num_steps_per_env"] = 16
|
||||
bipedal_walker_v3["agent_kwargs"]["actor_activations"] = ["relu", "relu", "relu", "linear"]
|
||||
bipedal_walker_v3["agent_kwargs"]["actor_hidden_dims"] = [512, 256, 128]
|
||||
bipedal_walker_v3["agent_kwargs"]["actor_noise_std"] = 0.8505
|
||||
bipedal_walker_v3["agent_kwargs"]["batch_count"] = 10
|
||||
bipedal_walker_v3["agent_kwargs"]["clip_ratio"] = 0.1
|
||||
bipedal_walker_v3["agent_kwargs"]["critic_activations"] = ["relu", "relu"]
|
||||
bipedal_walker_v3["agent_kwargs"]["critic_hidden_dims"] = [256, 256]
|
||||
bipedal_walker_v3["agent_kwargs"]["critic_network"] = DPPO.network_qrdqn
|
||||
bipedal_walker_v3["agent_kwargs"]["entropy_coeff"] = 0.0917
|
||||
bipedal_walker_v3["agent_kwargs"]["gae_lambda"] = 0.95
|
||||
bipedal_walker_v3["agent_kwargs"]["gamma"] = 0.9553
|
||||
bipedal_walker_v3["agent_kwargs"]["gradient_clip"] = 2.0
|
||||
bipedal_walker_v3["agent_kwargs"]["iqn_action_samples"] = 32
|
||||
bipedal_walker_v3["agent_kwargs"]["iqn_embedding_size"] = 64
|
||||
bipedal_walker_v3["agent_kwargs"]["iqn_feature_layers"] = 1
|
||||
bipedal_walker_v3["agent_kwargs"]["iqn_value_samples"] = 8
|
||||
bipedal_walker_v3["agent_kwargs"]["learning_rate"] = 0.4762
|
||||
bipedal_walker_v3["agent_kwargs"]["qrdqn_quantile_count"] = 200
|
||||
bipedal_walker_v3["agent_kwargs"]["recurrent"] = False
|
||||
bipedal_walker_v3["agent_kwargs"]["target_kl"] = 0.1999
|
||||
bipedal_walker_v3["agent_kwargs"]["value_coeff"] = 0.4435
|
||||
|
||||
"""
|
||||
[I 2023-01-12 08:01:35,514] Trial 476 finished with value: 5202.960759290059 and parameters: {
|
||||
'actor_noise_std': 0.15412869066185989,
|
||||
'batch_count': 11,
|
||||
'clip_ratio': 0.3,
|
||||
'entropy_coeff': 0.036031209302206955,
|
||||
'env_count': 128,
|
||||
'gae_lambda': 0.92,
|
||||
'gamma': 0.973937576989299,
|
||||
'gradient_clip': 5.0,
|
||||
'learning_rate': 0.1621249118505433,
|
||||
'steps_per_env': 128,
|
||||
'target_kl': 0.05054738172852222,
|
||||
'value_coeff': 0.1647632125820593,
|
||||
'actor_net_arch': 'small',
|
||||
'critic_net_arch': 'medium',
|
||||
'actor_net_activation': 'tanh',
|
||||
'critic_net_activation': 'relu',
|
||||
'qrdqn_quantile_count': 50,
|
||||
'value_measure': 'var-risk-averse'
|
||||
}. Best is trial 476 with value: 5202.960759290059.
|
||||
"""
|
||||
half_cheetah_v4 = copy.deepcopy(default)
|
||||
half_cheetah_v4["env_kwargs"]["environment_count"] = 128
|
||||
half_cheetah_v4["runner_kwargs"]["num_steps_per_env"] = 128
|
||||
half_cheetah_v4["agent_kwargs"]["actor_activations"] = ["tanh", "tanh", "linear"]
|
||||
half_cheetah_v4["agent_kwargs"]["actor_hidden_dims"] = [64, 64]
|
||||
half_cheetah_v4["agent_kwargs"]["actor_noise_std"] = 0.1541
|
||||
half_cheetah_v4["agent_kwargs"]["batch_count"] = 11
|
||||
half_cheetah_v4["agent_kwargs"]["clip_ratio"] = 0.3
|
||||
half_cheetah_v4["agent_kwargs"]["critic_activations"] = ["relu", "relu"]
|
||||
half_cheetah_v4["agent_kwargs"]["critic_hidden_dims"] = [256, 256]
|
||||
half_cheetah_v4["agent_kwargs"]["entropy_coeff"] = 0.03603
|
||||
half_cheetah_v4["agent_kwargs"]["gae_lambda"] = 0.92
|
||||
half_cheetah_v4["agent_kwargs"]["gamma"] = 0.9739
|
||||
half_cheetah_v4["agent_kwargs"]["gradient_clip"] = 5.0
|
||||
half_cheetah_v4["agent_kwargs"]["learning_rate"] = 0.1621
|
||||
half_cheetah_v4["agent_kwargs"]["qrdqn_quantile_count"] = 50
|
||||
half_cheetah_v4["agent_kwargs"]["target_kl"] = 0.0505
|
||||
half_cheetah_v4["agent_kwargs"]["value_coeff"] = 0.1648
|
||||
half_cheetah_v4["agent_kwargs"]["value_measure"] = QuantileNetwork.measure_percentile
|
||||
half_cheetah_v4["agent_kwargs"]["value_measure_kwargs"] = dict(confidence_level=0.25)
|
||||
|
||||
|
||||
# Parameters optimized for PPO
|
||||
hopper_v4 = copy.deepcopy(default)
|
||||
hopper_v4["runner_kwargs"]["num_steps_per_env"] = 128
|
||||
hopper_v4["agent_kwargs"]["actor_activations"] = ["relu", "relu", "linear"]
|
||||
hopper_v4["agent_kwargs"]["actor_hidden_dims"] = [256, 256]
|
||||
hopper_v4["agent_kwargs"]["actor_noise_std"] = 0.5590
|
||||
hopper_v4["agent_kwargs"]["batch_count"] = 15
|
||||
hopper_v4["agent_kwargs"]["clip_ratio"] = 0.2
|
||||
hopper_v4["agent_kwargs"]["critic_activations"] = ["relu", "relu", "linear"]
|
||||
hopper_v4["agent_kwargs"]["critic_hidden_dims"] = [32, 32]
|
||||
hopper_v4["agent_kwargs"]["entropy_coeff"] = 0.03874
|
||||
hopper_v4["agent_kwargs"]["gae_lambda"] = 0.98
|
||||
hopper_v4["agent_kwargs"]["gamma"] = 0.9890
|
||||
hopper_v4["agent_kwargs"]["gradient_clip"] = 1.0
|
||||
hopper_v4["agent_kwargs"]["learning_rate"] = 0.3732
|
||||
hopper_v4["agent_kwargs"]["value_coeff"] = 0.8163
|
||||
|
||||
swimmer_v4 = copy.deepcopy(default)
|
||||
swimmer_v4["agent_kwargs"]["gamma"] = 0.9999
|
||||
|
||||
walker2d_v4 = copy.deepcopy(default)
|
||||
walker2d_v4["runner_kwargs"]["num_steps_per_env"] = 512
|
||||
walker2d_v4["agent_kwargs"]["batch_count"] = (
|
||||
walker2d_v4["env_kwargs"]["environment_count"] * walker2d_v4["runner_kwargs"]["num_steps_per_env"] // 32
|
||||
)
|
||||
walker2d_v4["agent_kwargs"]["clip_ratio"] = 0.1
|
||||
walker2d_v4["agent_kwargs"]["entropy_coeff"] = 0.000585045
|
||||
walker2d_v4["agent_kwargs"]["gae_lambda"] = 0.95
|
||||
walker2d_v4["agent_kwargs"]["gamma"] = 0.99
|
||||
walker2d_v4["agent_kwargs"]["gradient_clip"] = 1.0
|
||||
walker2d_v4["agent_kwargs"]["learning_rate"] = 5.05041e-05
|
||||
walker2d_v4["agent_kwargs"]["value_coeff"] = 0.871923
|
||||
|
||||
dppo_hyperparams = {
|
||||
"default": default,
|
||||
"Ant-v4": ant_v4,
|
||||
"BipedalWalker-v3": bipedal_walker_v3,
|
||||
"HalfCheetah-v4": half_cheetah_v4,
|
||||
"Hopper-v4": hopper_v4,
|
||||
"Swimmer-v4": swimmer_v4,
|
||||
"Walker2d-v4": walker2d_v4,
|
||||
}
|
|
@ -0,0 +1,226 @@
|
|||
import copy
|
||||
import numpy as np
|
||||
|
||||
default = dict()
|
||||
default["env_kwargs"] = dict(environment_count=1)
|
||||
default["runner_kwargs"] = dict(num_steps_per_env=2048)
|
||||
default["agent_kwargs"] = dict(
|
||||
actor_activations=["tanh", "tanh", "linear"],
|
||||
actor_hidden_dims=[64, 64],
|
||||
actor_input_normalization=False,
|
||||
actor_noise_std=np.exp(0.0),
|
||||
batch_count=(default["env_kwargs"]["environment_count"] * default["runner_kwargs"]["num_steps_per_env"] // 64),
|
||||
clip_ratio=0.2,
|
||||
critic_activations=["tanh", "tanh", "linear"],
|
||||
critic_hidden_dims=[64, 64],
|
||||
critic_input_normalization=False,
|
||||
entropy_coeff=0.0,
|
||||
gae_lambda=0.95,
|
||||
gamma=0.99,
|
||||
gradient_clip=0.5,
|
||||
learning_rate=0.0003,
|
||||
schedule="adaptive",
|
||||
target_kl=0.01,
|
||||
value_coeff=0.5,
|
||||
)
|
||||
|
||||
"""
|
||||
[I 2023-01-09 00:33:02,217] Trial 85 finished with value: 2191.0249068421276 and parameters: {
|
||||
'actor_noise_std': 0.2611334861249876,
|
||||
'batch_count': 12,
|
||||
'clip_ratio': 0.4,
|
||||
'entropy_coeff': 0.010204149626344796,
|
||||
'env_count': 128,
|
||||
'gae_lambda': 0.92,
|
||||
'gamma': 0.9730549104215155,
|
||||
'gradient_clip': 5.0,
|
||||
'learning_rate': 0.8754540531090014,
|
||||
'steps_per_env': 64,
|
||||
'target_kl': 0.17110535070344035,
|
||||
'value_coeff': 0.6840401569818934,
|
||||
'actor_net_arch': 'small',
|
||||
'critic_net_arch': 'small',
|
||||
'actor_net_activation': 'tanh',
|
||||
'critic_net_activation': 'tanh'
|
||||
}. Best is trial 85 with value: 2191.0249068421276.
|
||||
"""
|
||||
ant_v3 = copy.deepcopy(default)
|
||||
ant_v3["env_kwargs"]["environment_count"] = 128
|
||||
ant_v3["runner_kwargs"]["num_steps_per_env"] = 64
|
||||
ant_v3["agent_kwargs"]["actor_activations"] = ["tanh", "tanh", "linear"]
|
||||
ant_v3["agent_kwargs"]["actor_hidden_dims"] = [64, 64]
|
||||
ant_v3["agent_kwargs"]["actor_noise_std"] = 0.2611
|
||||
ant_v3["agent_kwargs"]["batch_count"] = 12
|
||||
ant_v3["agent_kwargs"]["clip_ratio"] = 0.4
|
||||
ant_v3["agent_kwargs"]["critic_activations"] = ["tanh", "tanh", "linear"]
|
||||
ant_v3["agent_kwargs"]["critic_hidden_dims"] = [64, 64]
|
||||
ant_v3["agent_kwargs"]["entropy_coeff"] = 0.0102
|
||||
ant_v3["agent_kwargs"]["gae_lambda"] = 0.92
|
||||
ant_v3["agent_kwargs"]["gamma"] = 0.9731
|
||||
ant_v3["agent_kwargs"]["gradient_clip"] = 5.0
|
||||
ant_v3["agent_kwargs"]["learning_rate"] = 0.8755
|
||||
ant_v3["agent_kwargs"]["target_kl"] = 0.1711
|
||||
ant_v3["agent_kwargs"]["value_coeff"] = 0.6840
|
||||
|
||||
"""
|
||||
Standard:
|
||||
[I 2023-01-17 07:43:46,884] Trial 125 finished with value: 150.23491836690064 and parameters: {
|
||||
'actor_net_activation': 'relu',
|
||||
'actor_net_arch': 'large',
|
||||
'actor_noise_std': 0.8504545432069994,
|
||||
'batch_count': 10,
|
||||
'clip_ratio': 0.1,
|
||||
'critic_net_activation': 'relu',
|
||||
'critic_net_arch': 'medium',
|
||||
'entropy_coeff': 0.0916881539697197,
|
||||
'env_count': 256,
|
||||
'gae_lambda': 0.95,
|
||||
'gamma': 0.955285858564339,
|
||||
'gradient_clip': 2.0,
|
||||
'learning_rate': 0.4762365866431558,
|
||||
'steps_per_env': 16,
|
||||
'recurrent': False,
|
||||
'target_kl': 0.19991906392721126,
|
||||
'value_coeff': 0.4434793554275927
|
||||
}. Best is trial 125 with value: 150.23491836690064.
|
||||
Hardcore:
|
||||
[I 2023-01-09 06:25:44,000] Trial 262 finished with value: 2.290071208278338 and parameters: {
|
||||
'actor_noise_std': 0.2710521003644249,
|
||||
'batch_count': 6,
|
||||
'clip_ratio': 0.1,
|
||||
'entropy_coeff': 0.005105282891378981,
|
||||
'env_count': 16,
|
||||
'gae_lambda': 1.0,
|
||||
'gamma': 0.9718119008688937,
|
||||
'gradient_clip': 0.1,
|
||||
'learning_rate': 0.4569184610431825,
|
||||
'steps_per_env': 256,
|
||||
'target_kl': 0.11068348002480229,
|
||||
'value_coeff': 0.19453900570701116,
|
||||
'actor_net_arch': 'small',
|
||||
'critic_net_arch': 'medium',
|
||||
'actor_net_activation': 'relu',
|
||||
'critic_net_activation': 'relu'
|
||||
}. Best is trial 262 with value: 2.290071208278338.
|
||||
"""
|
||||
bipedal_walker_v3 = copy.deepcopy(default)
|
||||
bipedal_walker_v3["env_kwargs"]["environment_count"] = 256
|
||||
bipedal_walker_v3["runner_kwargs"]["num_steps_per_env"] = 16
|
||||
bipedal_walker_v3["agent_kwargs"]["actor_activations"] = ["relu", "relu", "relu", "linear"]
|
||||
bipedal_walker_v3["agent_kwargs"]["actor_hidden_dims"] = [512, 256, 128]
|
||||
bipedal_walker_v3["agent_kwargs"]["actor_noise_std"] = 0.8505
|
||||
bipedal_walker_v3["agent_kwargs"]["batch_count"] = 10
|
||||
bipedal_walker_v3["agent_kwargs"]["clip_ratio"] = 0.1
|
||||
bipedal_walker_v3["agent_kwargs"]["critic_activations"] = ["relu", "relu", "linear"]
|
||||
bipedal_walker_v3["agent_kwargs"]["critic_hidden_dims"] = [256, 256]
|
||||
bipedal_walker_v3["agent_kwargs"]["entropy_coeff"] = 0.0917
|
||||
bipedal_walker_v3["agent_kwargs"]["gae_lambda"] = 0.95
|
||||
bipedal_walker_v3["agent_kwargs"]["gamma"] = 0.9553
|
||||
bipedal_walker_v3["agent_kwargs"]["gradient_clip"] = 2.0
|
||||
bipedal_walker_v3["agent_kwargs"]["learning_rate"] = 0.4762
|
||||
bipedal_walker_v3["agent_kwargs"]["target_kl"] = 0.1999
|
||||
bipedal_walker_v3["agent_kwargs"]["value_coeff"] = 0.4435
|
||||
|
||||
"""
|
||||
[I 2023-01-04 05:57:20,749] Trial 1451 finished with value: 5260.338678148058 and parameters: {
|
||||
'env_count': 32,
|
||||
'actor_noise_std': 0.3397405098274084,
|
||||
'batch_count': 6,
|
||||
'clip_ratio': 0.3,
|
||||
'entropy_coeff': 0.009392937880259133,
|
||||
'gae_lambda': 0.8,
|
||||
'gamma': 0.9683403243382301,
|
||||
'gradient_clip': 5.0,
|
||||
'learning_rate': 0.5985206877398142,
|
||||
'steps_per_env': 16,
|
||||
'target_kl': 0.027651917189297347,
|
||||
'value_coeff': 0.26705235341068373,
|
||||
'net_arch': 'medium',
|
||||
'net_activation': 'tanh'
|
||||
}. Best is trial 1451 with value: 5260.338678148058.
|
||||
"""
|
||||
half_cheetah_v3 = copy.deepcopy(default)
|
||||
half_cheetah_v3["env_kwargs"]["environment_count"] = 32
|
||||
half_cheetah_v3["runner_kwargs"]["num_steps_per_env"] = 16
|
||||
half_cheetah_v3["agent_kwargs"]["actor_activations"] = ["tanh", "tanh", "linear"]
|
||||
half_cheetah_v3["agent_kwargs"]["actor_hidden_dims"] = [256, 256]
|
||||
half_cheetah_v3["agent_kwargs"]["actor_noise_std"] = 0.3397
|
||||
half_cheetah_v3["agent_kwargs"]["batch_count"] = 6
|
||||
half_cheetah_v3["agent_kwargs"]["clip_ratio"] = 0.3
|
||||
half_cheetah_v3["agent_kwargs"]["critic_activations"] = ["tanh", "tanh", "linear"]
|
||||
half_cheetah_v3["agent_kwargs"]["critic_hidden_dims"] = [256, 256]
|
||||
half_cheetah_v3["agent_kwargs"]["entropy_coeff"] = 0.009393
|
||||
half_cheetah_v3["agent_kwargs"]["gae_lambda"] = 0.8
|
||||
half_cheetah_v3["agent_kwargs"]["gamma"] = 0.9683
|
||||
half_cheetah_v3["agent_kwargs"]["gradient_clip"] = 5.0
|
||||
half_cheetah_v3["agent_kwargs"]["learning_rate"] = 0.5985
|
||||
half_cheetah_v3["agent_kwargs"]["target_kl"] = 0.02765
|
||||
half_cheetah_v3["agent_kwargs"]["value_coeff"] = 0.2671
|
||||
|
||||
"""
|
||||
[I 2023-01-08 18:38:51,481] Trial 25 finished with value: 2225.9547948810073 and parameters: {
|
||||
'actor_noise_std': 0.5589708917145111,
|
||||
'batch_count': 15,
|
||||
'clip_ratio': 0.2,
|
||||
'entropy_coeff': 0.03874027035272886,
|
||||
'env_count': 128,
|
||||
'gae_lambda': 0.98,
|
||||
'gamma': 0.9879577396280973,
|
||||
'gradient_clip': 1.0,
|
||||
'learning_rate': 0.3732431793266761,
|
||||
'steps_per_env': 128,
|
||||
'target_kl': 0.12851506672519566,
|
||||
'value_coeff': 0.8162548885723906,
|
||||
'actor_net_arch': 'medium',
|
||||
'critic_net_arch': 'small',
|
||||
'actor_net_activation': 'relu',
|
||||
'critic_net_activation': 'relu'
|
||||
}. Best is trial 25 with value: 2225.9547948810073.
|
||||
"""
|
||||
hopper_v3 = copy.deepcopy(default)
|
||||
half_cheetah_v3["env_kwargs"]["environment_count"] = 128
|
||||
hopper_v3["runner_kwargs"]["num_steps_per_env"] = 128
|
||||
hopper_v3["agent_kwargs"]["actor_activations"] = ["relu", "relu", "linear"]
|
||||
hopper_v3["agent_kwargs"]["actor_hidden_dims"] = [256, 256]
|
||||
hopper_v3["agent_kwargs"]["actor_noise_std"] = 0.5590
|
||||
hopper_v3["agent_kwargs"]["batch_count"] = 15
|
||||
hopper_v3["agent_kwargs"]["clip_ratio"] = 0.2
|
||||
hopper_v3["agent_kwargs"]["critic_activations"] = ["relu", "relu", "linear"]
|
||||
hopper_v3["agent_kwargs"]["critic_hidden_dims"] = [32, 32]
|
||||
hopper_v3["agent_kwargs"]["entropy_coeff"] = 0.03874
|
||||
hopper_v3["agent_kwargs"]["gae_lambda"] = 0.98
|
||||
hopper_v3["agent_kwargs"]["gamma"] = 0.9890
|
||||
hopper_v3["agent_kwargs"]["gradient_clip"] = 1.0
|
||||
hopper_v3["agent_kwargs"]["learning_rate"] = 0.3732
|
||||
hopper_v3["agent_kwargs"]["value_coeff"] = 0.8163
|
||||
|
||||
swimmer_v3 = copy.deepcopy(default)
|
||||
swimmer_v3["agent_kwargs"]["gamma"] = 0.9999
|
||||
|
||||
walker2d_v3 = copy.deepcopy(default)
|
||||
walker2d_v3["runner_kwargs"]["num_steps_per_env"] = 512
|
||||
walker2d_v3["agent_kwargs"]["batch_count"] = (
|
||||
walker2d_v3["env_kwargs"]["environment_count"] * walker2d_v3["runner_kwargs"]["num_steps_per_env"] // 32
|
||||
)
|
||||
walker2d_v3["agent_kwargs"]["clip_ratio"] = 0.1
|
||||
walker2d_v3["agent_kwargs"]["entropy_coeff"] = 0.000585045
|
||||
walker2d_v3["agent_kwargs"]["gae_lambda"] = 0.95
|
||||
walker2d_v3["agent_kwargs"]["gamma"] = 0.99
|
||||
walker2d_v3["agent_kwargs"]["gradient_clip"] = 1.0
|
||||
walker2d_v3["agent_kwargs"]["learning_rate"] = 5.05041e-05
|
||||
walker2d_v3["agent_kwargs"]["value_coeff"] = 0.871923
|
||||
|
||||
ppo_hyperparams = {
|
||||
"default": default,
|
||||
"Ant-v3": ant_v3,
|
||||
"Ant-v4": ant_v3,
|
||||
"BipedalWalker-v3": bipedal_walker_v3,
|
||||
"HalfCheetah-v3": half_cheetah_v3,
|
||||
"HalfCheetah-v4": half_cheetah_v3,
|
||||
"Hopper-v3": hopper_v3,
|
||||
"Hopper-v4": hopper_v3,
|
||||
"Swimmer-v3": swimmer_v3,
|
||||
"Swimmer-v4": swimmer_v3,
|
||||
"Walker2d-v3": walker2d_v3,
|
||||
"Walker2d-v4": walker2d_v3,
|
||||
}
|
|
@ -0,0 +1,103 @@
|
|||
from rsl_rl.algorithms import *
|
||||
from rsl_rl.env.gym_env import GymEnv
|
||||
from rsl_rl.runners.runner import Runner
|
||||
|
||||
import copy
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
import optuna
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
from tune_cfg import samplers
|
||||
|
||||
|
||||
ALGORITHM = PPO
|
||||
ENVIRONMENT = "BipedalWalker-v3"
|
||||
ENVIRONMENT_KWARGS = {}
|
||||
EVAL_AGENTS = 64
|
||||
EVAL_RUNS = 10
|
||||
EVAL_STEPS = 1000
|
||||
EXPERIMENT_DIR = os.environ.get("EXPERIMENT_DIRECTORY", "./")
|
||||
EXPERIMENT_NAME = os.environ.get("EXPERIMENT_NAME", f"tune-{ALGORITHM.__name__}-{ENVIRONMENT}")
|
||||
TRAIN_ITERATIONS = None
|
||||
TRAIN_TIMEOUT = 60 * 15 # 10 minutes
|
||||
TRAIN_RUNS = 3
|
||||
TRAIN_SEED = None
|
||||
|
||||
|
||||
def tune():
|
||||
assert TRAIN_RUNS == 1 or TRAIN_SEED is None, "If multiple runs are used, the seed must be None."
|
||||
|
||||
storage = optuna.storages.RDBStorage(url=f"sqlite:///{EXPERIMENT_DIR}/{EXPERIMENT_NAME}.db")
|
||||
pruner = optuna.pruners.MedianPruner(n_startup_trials=10)
|
||||
try:
|
||||
study = optuna.create_study(direction="maximize", pruner=pruner, storage=storage, study_name=EXPERIMENT_NAME)
|
||||
except Exception:
|
||||
study = optuna.load_study(pruner=pruner, storage=storage, study_name=EXPERIMENT_NAME)
|
||||
|
||||
study.optimize(objective, n_trials=100)
|
||||
|
||||
|
||||
def seed(s=None):
|
||||
seed = int(datetime.now().timestamp() * 1e6) % 2**32 if s is None else s
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def objective(trial):
|
||||
seed()
|
||||
agent_kwargs, env_kwargs, runner_kwargs = samplers[ALGORITHM.__name__](trial)
|
||||
|
||||
evaluations = []
|
||||
for instantiation in range(TRAIN_RUNS):
|
||||
seed(TRAIN_SEED)
|
||||
|
||||
env = GymEnv(ENVIRONMENT, gym_kwargs=ENVIRONMENT_KWARGS, **env_kwargs)
|
||||
agent = ALGORITHM(env, **agent_kwargs)
|
||||
runner = Runner(env, agent, **runner_kwargs)
|
||||
runner._learn_cb = [lambda _, stat: runner._log_progress(stat, prefix=f"learn {instantiation+1}/{TRAIN_RUNS}")]
|
||||
|
||||
eval_env_kwargs = copy.deepcopy(env_kwargs)
|
||||
eval_env_kwargs["environment_count"] = EVAL_AGENTS
|
||||
eval_runner = Runner(
|
||||
GymEnv(ENVIRONMENT, gym_kwargs=ENVIRONMENT_KWARGS, **env_kwargs),
|
||||
agent,
|
||||
**runner_kwargs,
|
||||
)
|
||||
eval_runner._eval_cb = [
|
||||
lambda _, stat: runner._log_progress(stat, prefix=f"eval {instantiation+1}/{TRAIN_RUNS}")
|
||||
]
|
||||
|
||||
try:
|
||||
runner.learn(TRAIN_ITERATIONS, timeout=TRAIN_TIMEOUT)
|
||||
except Exception:
|
||||
raise optuna.TrialPruned()
|
||||
|
||||
intermediate_evaluations = []
|
||||
for eval_run in range(EVAL_RUNS):
|
||||
eval_runner._eval_cb = [lambda _, stat: runner._log_progress(stat, prefix=f"eval {eval_run+1}/{EVAL_RUNS}")]
|
||||
|
||||
seed()
|
||||
eval_runner.env.reset()
|
||||
intermediate_evaluations.append(eval_runner.evaluate(steps=EVAL_STEPS))
|
||||
eval = np.mean(intermediate_evaluations)
|
||||
|
||||
trial.report(eval, instantiation)
|
||||
if trial.should_prune():
|
||||
raise optuna.TrialPruned()
|
||||
|
||||
evaluations.append(eval)
|
||||
|
||||
evaluation = np.mean(evaluations)
|
||||
|
||||
return evaluation
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tune()
|
|
@ -0,0 +1,134 @@
|
|||
import torch
|
||||
|
||||
from rsl_rl.algorithms import DPPO, PPO
|
||||
from rsl_rl.modules import QuantileNetwork
|
||||
|
||||
NETWORKS = {"small": [64, 64], "medium": [256, 256], "large": [512, 256, 128]}
|
||||
|
||||
|
||||
def sample_dppo_hyperparams(trial):
|
||||
actor_net_activation = trial.suggest_categorical("actor_net_activation", ["relu", "tanh"])
|
||||
actor_net_arch = trial.suggest_categorical("actor_net_arch", list(NETWORKS.keys()))
|
||||
actor_noise_std = trial.suggest_float("actor_noise_std", 0.05, 1.0)
|
||||
batch_count = trial.suggest_int("batch_count", 1, 20)
|
||||
clip_ratio = trial.suggest_categorical("clip_ratio", [0.1, 0.2, 0.3, 0.4])
|
||||
critic_net_activation = trial.suggest_categorical("critic_net_activation", ["relu", "tanh"])
|
||||
critic_net_arch = trial.suggest_categorical("critic_net_arch", list(NETWORKS.keys()))
|
||||
entropy_coeff = trial.suggest_float("entropy_coeff", 0.00000001, 0.1)
|
||||
env_count = trial.suggest_categorical("env_count", [1, 8, 16, 32, 64, 128, 256, 512])
|
||||
gae_lambda = trial.suggest_categorical("gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0])
|
||||
gamma = trial.suggest_float("gamma", 0.95, 0.999)
|
||||
gradient_clip = trial.suggest_categorical("gradient_clip", [0.1, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 2.0, 5.0])
|
||||
learning_rate = trial.suggest_float("learning_rate", 1e-5, 1)
|
||||
num_steps_per_env = trial.suggest_categorical("steps_per_env", [8, 16, 32, 64, 128, 256, 512, 1024, 2048])
|
||||
quantile_count = trial.suggest_categorical("quantile_count", [20, 50, 100, 200])
|
||||
recurrent = trial.suggest_categorical("recurrent", [True, False])
|
||||
target_kl = trial.suggest_float("target_kl", 0.01, 0.3)
|
||||
value_coeff = trial.suggest_float("value_coeff", 0.0, 1.0)
|
||||
value_measure = trial.suggest_categorical(
|
||||
"value_measure",
|
||||
["neutral", "var-risk-averse", "var-risk-seeking", "var-super-risk-averse", "var-super-risk-seeking"],
|
||||
)
|
||||
|
||||
actor_net_arch = NETWORKS[actor_net_arch]
|
||||
critic_net_arch = NETWORKS[critic_net_arch]
|
||||
value_measure_kwargs = {
|
||||
"neutral": dict(),
|
||||
"var-risk-averse": dict(confidence_level=0.25),
|
||||
"var-risk-seeking": dict(confidence_level=0.75),
|
||||
"var-super-risk-averse": dict(confidence_level=0.1),
|
||||
"var-super-risk-seeking": dict(confidence_level=0.9),
|
||||
}[value_measure]
|
||||
value_measure = {
|
||||
"neutral": QuantileNetwork.measure_neutral,
|
||||
"var-risk-averse": QuantileNetwork.measure_percentile,
|
||||
"var-risk-seeking": QuantileNetwork.measure_percentile,
|
||||
"var-super-risk-averse": QuantileNetwork.measure_percentile,
|
||||
"var-super-risk-seeking": QuantileNetwork.measure_percentile,
|
||||
}[value_measure]
|
||||
device = "cuda:0" if env_count * num_steps_per_env > 2048 and torch.cuda.is_available() else "cpu"
|
||||
|
||||
agent_kwargs = dict(
|
||||
actor_activations=([actor_net_activation] * len(actor_net_arch)) + ["linear"],
|
||||
actor_hidden_dims=actor_net_arch,
|
||||
actor_input_normalization=False,
|
||||
actor_noise_std=actor_noise_std,
|
||||
batch_count=batch_count,
|
||||
clip_ratio=clip_ratio,
|
||||
critic_activations=([critic_net_activation] * len(critic_net_arch)),
|
||||
critic_hidden_dims=critic_net_arch,
|
||||
critic_input_normalization=False,
|
||||
device=device,
|
||||
entropy_coeff=entropy_coeff,
|
||||
gae_lambda=gae_lambda,
|
||||
gamma=gamma,
|
||||
gradient_clip=gradient_clip,
|
||||
learning_rate=learning_rate,
|
||||
quantile_count=quantile_count,
|
||||
recurrent=recurrent,
|
||||
schedule="adaptive",
|
||||
target_kl=target_kl,
|
||||
value_coeff=value_coeff,
|
||||
value_measure=value_measure,
|
||||
value_measure_kwargs=value_measure_kwargs,
|
||||
)
|
||||
env_kwargs = dict(device=device, environment_count=env_count)
|
||||
runner_kwargs = dict(device=device, num_steps_per_env=num_steps_per_env)
|
||||
|
||||
return agent_kwargs, env_kwargs, runner_kwargs
|
||||
|
||||
|
||||
def sample_ppo_hyperparams(trial):
|
||||
actor_net_activation = trial.suggest_categorical("actor_net_activation", ["relu", "tanh"])
|
||||
actor_net_arch = trial.suggest_categorical("actor_net_arch", list(NETWORKS.keys()))
|
||||
actor_noise_std = trial.suggest_float("actor_noise_std", 0.05, 1.0)
|
||||
batch_count = trial.suggest_int("batch_count", 1, 20)
|
||||
clip_ratio = trial.suggest_categorical("clip_ratio", [0.1, 0.2, 0.3, 0.4])
|
||||
critic_net_activation = trial.suggest_categorical("critic_net_activation", ["relu", "tanh"])
|
||||
critic_net_arch = trial.suggest_categorical("critic_net_arch", list(NETWORKS.keys()))
|
||||
entropy_coeff = trial.suggest_float("entropy_coeff", 0.00000001, 0.1)
|
||||
env_count = trial.suggest_categorical("env_count", [1, 8, 16, 32, 64, 128, 256, 512])
|
||||
gae_lambda = trial.suggest_categorical("gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0])
|
||||
gamma = trial.suggest_float("gamma", 0.95, 0.999)
|
||||
gradient_clip = trial.suggest_categorical("gradient_clip", [0.1, 0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 2.0, 5.0])
|
||||
learning_rate = trial.suggest_float("learning_rate", 1e-5, 1)
|
||||
num_steps_per_env = trial.suggest_categorical("steps_per_env", [8, 16, 32, 64, 128, 256, 512, 1024, 2048])
|
||||
recurrent = trial.suggest_categorical("recurrent", [True, False])
|
||||
target_kl = trial.suggest_float("target_kl", 0.01, 0.3)
|
||||
value_coeff = trial.suggest_float("value_coeff", 0.0, 1.0)
|
||||
|
||||
actor_net_arch = NETWORKS[actor_net_arch]
|
||||
critic_net_arch = NETWORKS[critic_net_arch]
|
||||
device = "cuda:0" if env_count * num_steps_per_env > 2048 and torch.cuda.is_available() else "cpu"
|
||||
|
||||
agent_kwargs = dict(
|
||||
actor_activations=([actor_net_activation] * len(actor_net_arch)) + ["linear"],
|
||||
actor_hidden_dims=actor_net_arch,
|
||||
actor_input_normalization=False,
|
||||
actor_noise_std=actor_noise_std,
|
||||
batch_count=batch_count,
|
||||
clip_ratio=clip_ratio,
|
||||
critic_activations=([critic_net_activation] * len(critic_net_arch)) + ["linear"],
|
||||
critic_hidden_dims=critic_net_arch,
|
||||
critic_input_normalization=False,
|
||||
device=device,
|
||||
entropy_coeff=entropy_coeff,
|
||||
gae_lambda=gae_lambda,
|
||||
gamma=gamma,
|
||||
gradient_clip=gradient_clip,
|
||||
learning_rate=learning_rate,
|
||||
recurrent=recurrent,
|
||||
schedule="adaptive",
|
||||
target_kl=target_kl,
|
||||
value_coeff=value_coeff,
|
||||
)
|
||||
env_kwargs = dict(device=device, environment_count=env_count)
|
||||
runner_kwargs = dict(device=device, num_steps_per_env=num_steps_per_env)
|
||||
|
||||
return agent_kwargs, env_kwargs, runner_kwargs
|
||||
|
||||
|
||||
samplers = {
|
||||
DPPO.__name__: sample_dppo_hyperparams,
|
||||
PPO.__name__: sample_ppo_hyperparams,
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
# To use this file, copy it to wandb_config.py and fill in the missing values.
|
||||
|
||||
WANDB_API_KEY = ""
|
||||
WANDB_ENTITY = ""
|
|
@ -1,7 +1,2 @@
|
|||
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
"""Main module for the rsl_rl package."""
|
||||
|
||||
__version__ = "2.0.1"
|
||||
__license__ = "BSD-3"
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
"""Implementation of different RL agents."""
|
||||
|
||||
from .agent import Agent
|
||||
from .d4pg import D4PG
|
||||
from .ddpg import DDPG
|
||||
from .dppo import DPPO
|
||||
from .dsac import DSAC
|
||||
from .ppo import PPO
|
||||
from .sac import SAC
|
||||
from .td3 import TD3
|
||||
|
||||
__all__ = ["PPO"]
|
||||
__all__ = ["Agent", "DDPG", "D4PG", "DPPO", "DSAC", "PPO", "SAC", "TD3"]
|
||||
|
|
|
@ -0,0 +1,371 @@
|
|||
from abc import abstractmethod
|
||||
import torch
|
||||
from typing import Any, Callable, Dict, List, Tuple, Union
|
||||
|
||||
from rsl_rl.algorithms.agent import Agent
|
||||
from rsl_rl.env.vec_env import VecEnv
|
||||
from rsl_rl.modules.network import Network
|
||||
from rsl_rl.storage.storage import Dataset
|
||||
from rsl_rl.utils.utils import environment_dimensions
|
||||
from rsl_rl.utils.utils import squeeze_preserve_batch
|
||||
|
||||
|
||||
class AbstractActorCritic(Agent):
|
||||
_alg_features = dict(recurrent=False)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: VecEnv,
|
||||
actor_activations: List[str] = ["relu", "relu", "relu", "linear"],
|
||||
actor_hidden_dims: List[int] = [256, 256, 256],
|
||||
actor_init_gain: float = 0.5,
|
||||
actor_input_normalization: bool = False,
|
||||
actor_recurrent_layers: int = 1,
|
||||
actor_recurrent_module: str = Network.recurrent_module_lstm,
|
||||
actor_recurrent_tf_context_length: int = 64,
|
||||
actor_recurrent_tf_head_count: int = 8,
|
||||
actor_shared_dims: int = None,
|
||||
batch_count: int = 1,
|
||||
batch_size: int = 1,
|
||||
critic_activations: List[str] = ["relu", "relu", "relu", "linear"],
|
||||
critic_hidden_dims: List[int] = [256, 256, 256],
|
||||
critic_init_gain: float = 0.5,
|
||||
critic_input_normalization: bool = False,
|
||||
critic_recurrent_layers: int = 1,
|
||||
critic_recurrent_module: str = Network.recurrent_module_lstm,
|
||||
critic_recurrent_tf_context_length: int = 64,
|
||||
critic_recurrent_tf_head_count: int = 8,
|
||||
critic_shared_dims: int = None,
|
||||
polyak: float = 0.995,
|
||||
recurrent: bool = False,
|
||||
return_steps: int = 1,
|
||||
_actor_input_size_delta: int = 0,
|
||||
_critic_input_size_delta: int = 0,
|
||||
**kwargs,
|
||||
):
|
||||
"""Creates an actor critic agent.
|
||||
|
||||
Args:
|
||||
env (VecEnv): A vectorized environment.
|
||||
actor_activations (List[str]): A list of activation functions for the actor network.
|
||||
actor_hidden_dims (List[str]): A list of layer sizes for the actor network.
|
||||
actor_init_gain (float): Network initialization gain for actor.
|
||||
actor_input_normalization (bool): Whether to empirically normalize inputs to the actor network.
|
||||
actor_recurrent_layers (int): The number of recurrent layers to use for the actor network.
|
||||
actor_recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules.
|
||||
actor_shared_dims (int): The number of dimensions to share for an actor with multiple heads.
|
||||
batch_count (int): The number of batches to process per update step.
|
||||
batch_size (int): The size of each batch to process during the update step.
|
||||
critic_activations (List[str]): A list of activation functions for the critic network.
|
||||
critic_hidden_dims: (List[str]): A list of layer sizes for the critic network.
|
||||
critic_init_gain (float): Network initialization gain for critic.
|
||||
critic_input_normalization (bool): Whether to empirically normalize inputs to the critic network.
|
||||
critic_recurrent_layers (int): The number of recurrent layers to use for the critic network.
|
||||
critic_recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules.
|
||||
critic_shared_dims (int): The number of dimensions to share for a critic with multiple heads.
|
||||
polyak (float): The actor-critic target network polyak factor.
|
||||
recurrent (bool): Whether to use recurrent actor and critic networks.
|
||||
recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules.
|
||||
recurrent_tf_context_length (int): The context length of the Transformer.
|
||||
recurrent_tf_head_count (int): The head count of the Transformer.
|
||||
return_steps (float): The number of steps over which to compute the returns (n-step return).
|
||||
_actor_input_size_delta (int): The number of additional dimensions to add to the actor input.
|
||||
_critic_input_size_delta (int): The number of additional dimensions to add to the critic input.
|
||||
"""
|
||||
assert (
|
||||
self._alg_features["recurrent"] == True or not recurrent
|
||||
), f"{self.__class__.__name__} does not support recurrent networks."
|
||||
|
||||
super().__init__(env, **kwargs)
|
||||
|
||||
self.actor: torch.nn.Module = None
|
||||
self.actor_optimizer: torch.nn.Module = None
|
||||
self.critic_optimizer: torch.nn.Module = None
|
||||
self.critic: torch.nn.Module = None
|
||||
|
||||
self._batch_size = batch_size
|
||||
self._batch_count = batch_count
|
||||
self._polyak_factor = polyak
|
||||
self._return_steps = return_steps
|
||||
self._recurrent = recurrent
|
||||
|
||||
self._register_serializable(
|
||||
"_batch_size", "_batch_count", "_discount_factor", "_polyak_factor", "_return_steps"
|
||||
)
|
||||
|
||||
dimensions = environment_dimensions(self.env)
|
||||
try:
|
||||
actor_input_size = dimensions["actor_observations"]
|
||||
critic_input_size = dimensions["critic_observations"]
|
||||
except KeyError:
|
||||
actor_input_size = dimensions["observations"]
|
||||
critic_input_size = dimensions["observations"]
|
||||
self._actor_input_size = actor_input_size + _actor_input_size_delta
|
||||
self._critic_input_size = critic_input_size + self._action_size + _critic_input_size_delta
|
||||
|
||||
self._register_actor_network_kwargs(
|
||||
activations=actor_activations,
|
||||
hidden_dims=actor_hidden_dims,
|
||||
init_gain=actor_init_gain,
|
||||
input_normalization=actor_input_normalization,
|
||||
recurrent=recurrent,
|
||||
recurrent_layers=actor_recurrent_layers,
|
||||
recurrent_module=actor_recurrent_module,
|
||||
recurrent_tf_context_length=actor_recurrent_tf_context_length,
|
||||
recurrent_tf_head_count=actor_recurrent_tf_head_count,
|
||||
)
|
||||
|
||||
if actor_shared_dims is not None:
|
||||
self._register_actor_network_kwargs(shared_dims=actor_shared_dims)
|
||||
|
||||
self._register_critic_network_kwargs(
|
||||
activations=critic_activations,
|
||||
hidden_dims=critic_hidden_dims,
|
||||
init_gain=critic_init_gain,
|
||||
input_normalization=critic_input_normalization,
|
||||
recurrent=recurrent,
|
||||
recurrent_layers=critic_recurrent_layers,
|
||||
recurrent_module=critic_recurrent_module,
|
||||
recurrent_tf_context_length=critic_recurrent_tf_context_length,
|
||||
recurrent_tf_head_count=critic_recurrent_tf_head_count,
|
||||
)
|
||||
|
||||
if critic_shared_dims is not None:
|
||||
self._register_critic_network_kwargs(shared_dims=critic_shared_dims)
|
||||
|
||||
self._register_serializable(
|
||||
"_actor_input_size", "_actor_network_kwargs", "_critic_input_size", "_critic_network_kwargs"
|
||||
)
|
||||
|
||||
# For computing n-step returns using prior transitions.
|
||||
self._stored_dataset = []
|
||||
|
||||
def export_onnx(self) -> Tuple[torch.nn.Module, torch.Tensor, Dict]:
|
||||
self.eval_mode()
|
||||
|
||||
class ONNXActor(torch.nn.Module):
|
||||
def __init__(self, model: torch.nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.model = model
|
||||
|
||||
def forward(self, x: torch.Tensor, hidden_state: Tuple[torch.Tensor, torch.Tensor] = None):
|
||||
if hidden_state is None:
|
||||
return self.model(x)
|
||||
|
||||
data = self.model(x, hidden_state=hidden_state)
|
||||
hidden_state = self.model.last_hidden_state
|
||||
|
||||
return data, hidden_state
|
||||
|
||||
model = ONNXActor(self.actor)
|
||||
kwargs = dict(
|
||||
export_params=True,
|
||||
opset_version=11,
|
||||
verbose=True,
|
||||
dynamic_axes={},
|
||||
)
|
||||
|
||||
kwargs["input_names"] = ["observations"]
|
||||
kwargs["output_names"] = ["actions"]
|
||||
|
||||
args = torch.zeros(1, self._actor_input_size)
|
||||
|
||||
if self.actor.recurrent:
|
||||
hidden_state = (
|
||||
torch.zeros(self.actor._features[0].num_layers, 1, self.actor._features[0].hidden_size),
|
||||
torch.zeros(self.actor._features[0].num_layers, 1, self.actor._features[0].hidden_size),
|
||||
)
|
||||
args = (args, {"hidden_state": hidden_state})
|
||||
|
||||
return model, args, kwargs
|
||||
|
||||
def draw_random_actions(self, obs: torch.Tensor, env_info: Dict[str, Any]) -> Tuple[torch.Tensor, Dict]:
|
||||
actions, data = super().draw_random_actions(obs, env_info)
|
||||
|
||||
actor_obs, critic_obs = self._process_observations(obs, env_info)
|
||||
data.update({"actor_observations": actor_obs.clone(), "critic_observations": critic_obs.clone()})
|
||||
|
||||
return actions, data
|
||||
|
||||
def get_inference_policy(self, device=None) -> Callable:
|
||||
self.to(device)
|
||||
self.eval_mode()
|
||||
|
||||
if self.actor.recurrent:
|
||||
self.actor.reset_full_hidden_state(batch_size=self.env.num_envs)
|
||||
|
||||
if self.critic.recurrent:
|
||||
self.critic.reset_full_hidden_state(batch_size=self.env.num_envs)
|
||||
|
||||
def policy(obs, env_info=None):
|
||||
with torch.inference_mode():
|
||||
obs, _ = self._process_observations(obs, env_info)
|
||||
|
||||
actions = self._process_actions(self.actor.forward(obs))
|
||||
|
||||
return actions
|
||||
|
||||
return policy
|
||||
|
||||
def process_transition(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
environment_info: Dict[str, Any],
|
||||
actions: torch.Tensor,
|
||||
rewards: torch.Tensor,
|
||||
next_observations: torch.Tensor,
|
||||
next_environment_info: torch.Tensor,
|
||||
dones: torch.Tensor,
|
||||
data: Dict[str, torch.Tensor],
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if "actor_observations" in data and "critic_observations" in data:
|
||||
actor_obs, critic_obs = data["actor_observations"], data["critic_observations"]
|
||||
else:
|
||||
actor_obs, critic_obs = self._process_observations(observations, environment_info)
|
||||
|
||||
if "next_actor_observations" in data and "next_critic_observations" in data:
|
||||
next_actor_obs, next_critic_obs = data["next_actor_observations"], data["next_critic_observations"]
|
||||
else:
|
||||
next_actor_obs, next_critic_obs = self._process_observations(next_observations, next_environment_info)
|
||||
|
||||
transition = {
|
||||
"actions": actions,
|
||||
"actor_observations": actor_obs,
|
||||
"critic_observations": critic_obs,
|
||||
"dones": dones,
|
||||
"next_actor_observations": next_actor_obs,
|
||||
"next_critic_observations": next_critic_obs,
|
||||
"rewards": squeeze_preserve_batch(rewards),
|
||||
"timeouts": self._extract_timeouts(next_environment_info),
|
||||
}
|
||||
transition.update(data)
|
||||
|
||||
for key, value in transition.items():
|
||||
transition[key] = value.detach().clone()
|
||||
|
||||
return transition
|
||||
|
||||
@property
|
||||
def recurrent(self) -> bool:
|
||||
return self._recurrent
|
||||
|
||||
def register_terminations(self, terminations: torch.Tensor) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
|
||||
with torch.inference_mode():
|
||||
self.storage.append(self._process_dataset(dataset))
|
||||
|
||||
def _critic_input(self, observations, actions) -> torch.Tensor:
|
||||
"""Combines observations and actions into a tensor that can be fed into the critic network.
|
||||
|
||||
Args:
|
||||
observations (torch.Tensor): The critic observations.
|
||||
actions (torch.Tensor): The actions computed by the actor.
|
||||
Returns:
|
||||
A torch.Tensor of inputs for the critic network.
|
||||
"""
|
||||
return torch.cat((observations, actions), dim=-1)
|
||||
|
||||
def _extract_timeouts(self, next_environment_info):
|
||||
"""Extracts timeout information from the transition next state information dictionary.
|
||||
|
||||
Args:
|
||||
next_environment_info (Dict[str, Any]): The transition next state information dictionary.
|
||||
Returns:
|
||||
A torch.Tensor vector of actor timeouts.
|
||||
"""
|
||||
if "time_outs" not in next_environment_info:
|
||||
return torch.zeros(self.env.num_envs, device=self.device)
|
||||
|
||||
timeouts = squeeze_preserve_batch(next_environment_info["time_outs"].to(self.device))
|
||||
|
||||
return timeouts
|
||||
|
||||
def _process_dataset(self, dataset: Dataset) -> Dataset:
|
||||
"""Processes a dataset before it is added to the replay memory.
|
||||
|
||||
Handles n-step returns and timeouts.
|
||||
|
||||
TODO: This function seems to be a bottleneck in the training pipeline - speed it up!
|
||||
|
||||
Args:
|
||||
dataset (Dataset): The dataset to process.
|
||||
Returns:
|
||||
A Dataset object containing the processed data.
|
||||
"""
|
||||
assert len(dataset) >= self._return_steps
|
||||
|
||||
dataset = self._stored_dataset + dataset
|
||||
length = len(dataset) - self._return_steps + 1
|
||||
self._stored_dataset = dataset[length:]
|
||||
|
||||
for idx in range(len(dataset) - self._return_steps + 1):
|
||||
dead_idx = torch.zeros_like(dataset[idx]["dones"])
|
||||
rewards = torch.zeros_like(dataset[idx]["rewards"])
|
||||
|
||||
for k in range(self._return_steps):
|
||||
data = dataset[idx + k]
|
||||
alive_idx = (dead_idx == 0).nonzero()
|
||||
critic_predictions = self.critic.forward(
|
||||
self._critic_input(
|
||||
data["critic_observations"].clone().to(self.device),
|
||||
data["actions"].clone().to(self.device),
|
||||
)
|
||||
)
|
||||
rewards[alive_idx] += self._discount_factor**k * data["rewards"][alive_idx]
|
||||
rewards[alive_idx] += (
|
||||
self._discount_factor ** (k + 1) * data["timeouts"][alive_idx] * critic_predictions[alive_idx]
|
||||
)
|
||||
dead_idx += data["dones"]
|
||||
dead_idx += data["timeouts"]
|
||||
|
||||
dataset[idx]["rewards"] = rewards
|
||||
|
||||
return dataset[:length]
|
||||
|
||||
def _process_observations(
|
||||
self, observations: torch.Tensor, environment_info: Dict[str, Any] = None
|
||||
) -> Tuple[torch.Tensor, ...]:
|
||||
"""Processes observations returned by the environment to extract actor and critic observations.
|
||||
|
||||
Args:
|
||||
observations (torch.Tensor): normal environment observations.
|
||||
environment_info (Dict[str, Any]): A dictionary of additional environment information.
|
||||
Returns:
|
||||
A tuple containing two torch.Tensors with actor and critic observations, respectively.
|
||||
"""
|
||||
try:
|
||||
critic_obs = environment_info["observations"]["critic"]
|
||||
except (KeyError, TypeError):
|
||||
critic_obs = observations
|
||||
|
||||
actor_obs, critic_obs = observations.to(self.device), critic_obs.to(self.device)
|
||||
|
||||
return actor_obs, critic_obs
|
||||
|
||||
def _register_actor_network_kwargs(self, **kwargs) -> None:
|
||||
"""Function to configure actor network in child classes before calling super().__init__()."""
|
||||
if not hasattr(self, "_actor_network_kwargs"):
|
||||
self._actor_network_kwargs = dict()
|
||||
|
||||
self._actor_network_kwargs.update(**kwargs)
|
||||
|
||||
def _register_critic_network_kwargs(self, **kwargs) -> None:
|
||||
"""Function to configure critic network in child classes before calling super().__init__()."""
|
||||
if not hasattr(self, "_critic_network_kwargs"):
|
||||
self._critic_network_kwargs = dict()
|
||||
|
||||
self._critic_network_kwargs.update(**kwargs)
|
||||
|
||||
def _update_target(self, online: torch.nn.Module, target: torch.nn.Module) -> None:
|
||||
"""Updates the target network using the polyak factor.
|
||||
|
||||
Args:
|
||||
online (torch.nn.Module): The online network.
|
||||
target (torch.nn.Module): The target network.
|
||||
"""
|
||||
for op, tp in zip(online.parameters(), target.parameters()):
|
||||
tp.data.copy_((1.0 - self._polyak_factor) * op.data + self._polyak_factor * tp.data)
|
|
@ -0,0 +1,197 @@
|
|||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Any, Callable, Dict, Tuple, Union
|
||||
|
||||
from rsl_rl.env import VecEnv
|
||||
from rsl_rl.storage.storage import Dataset
|
||||
from rsl_rl.utils.benchmarkable import Benchmarkable
|
||||
from rsl_rl.utils.serializable import Serializable
|
||||
from rsl_rl.utils.utils import environment_dimensions
|
||||
|
||||
|
||||
class Agent(ABC, Benchmarkable, Serializable):
|
||||
def __init__(
|
||||
self,
|
||||
env: VecEnv,
|
||||
action_max: float = np.inf,
|
||||
action_min: float = -np.inf,
|
||||
benchmark: bool = False,
|
||||
device: str = "cpu",
|
||||
gamma: float = 0.99,
|
||||
):
|
||||
"""Creates an agent.
|
||||
|
||||
Args:
|
||||
env (VecEnv): The envrionment of the agent.
|
||||
action_max (float): The maximum action value.
|
||||
action_min (float): The minimum action value.
|
||||
bechmark (bool): Whether to benchmark runtime.
|
||||
device (str): The device to use for computation.
|
||||
gamma (float): The environment discount factor.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.env = env
|
||||
self.device = device
|
||||
self.storage = None
|
||||
|
||||
self._action_max = action_max
|
||||
self._action_min = action_min
|
||||
self._discount_factor = gamma
|
||||
|
||||
self._register_serializable("_action_max", "_action_min", "_discount_factor")
|
||||
|
||||
dimensions = environment_dimensions(self.env)
|
||||
self._action_size = dimensions["actions"]
|
||||
|
||||
self._register_serializable("_action_size")
|
||||
|
||||
if self._action_min > -np.inf and self._action_max < np.inf:
|
||||
self._rand_scale = self._action_max - self._action_min
|
||||
self._rand_offset = self._action_min
|
||||
else:
|
||||
self._rand_scale = 2.0
|
||||
self._rand_offset = -1.0
|
||||
|
||||
self._bm_toggle(benchmark)
|
||||
|
||||
@abstractmethod
|
||||
def draw_actions(
|
||||
self, obs: torch.Tensor, env_info: Dict[str, Any]
|
||||
) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
|
||||
"""Draws actions from the action space.
|
||||
|
||||
Args:
|
||||
obs (torch.Tensor): The observations for which to draw actions.
|
||||
env_info (Dict[str, Any]): The environment information for the observations.
|
||||
Returns:
|
||||
A tuple containing the actions and the data dictionary.
|
||||
"""
|
||||
pass
|
||||
|
||||
def draw_random_actions(
|
||||
self, obs: torch.Tensor, env_info: Dict[str, Any]
|
||||
) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
|
||||
"""Draws random actions from the action space.
|
||||
|
||||
Args:
|
||||
obs (torch.Tensor): The observations to include in the data dictionary.
|
||||
env_info (Dict[str, Any]): The environment information to include in the data dictionary.
|
||||
Returns:
|
||||
A tuple containing the random actions and the data dictionary.
|
||||
"""
|
||||
actions = self._process_actions(
|
||||
self._rand_offset + self._rand_scale * torch.rand(self.env.num_envs, self._action_size)
|
||||
)
|
||||
|
||||
return actions, {}
|
||||
|
||||
@abstractmethod
|
||||
def eval_mode(self) -> Agent:
|
||||
"""Sets the agent to evaluation mode."""
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
def export_onnx(self) -> Tuple[torch.nn.Module, torch.Tensor, Dict]:
|
||||
"""Exports the agent's policy network to ONNX format.
|
||||
|
||||
Returns:
|
||||
A tuple containing the ONNX model, the input arguments, and the keyword arguments.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def gamma(self) -> float:
|
||||
return self._discount_factor
|
||||
|
||||
@abstractmethod
|
||||
def get_inference_policy(self, device: str = None) -> Callable:
|
||||
"""Returns a function that computes actions from observations without storing gradients.
|
||||
|
||||
Args:
|
||||
device (torch.device): The device to use for inference.
|
||||
Returns:
|
||||
A function that computes actions from observations.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def initialized(self) -> bool:
|
||||
"""Whether the agent has been initialized."""
|
||||
return self.storage.initialized
|
||||
|
||||
@abstractmethod
|
||||
def process_transition(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
environement_info: Dict[str, Any],
|
||||
actions: torch.Tensor,
|
||||
rewards: torch.Tensor,
|
||||
next_observations: torch.Tensor,
|
||||
next_environment_info: torch.Tensor,
|
||||
dones: torch.Tensor,
|
||||
data: Dict[str, torch.Tensor],
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Processes a transition before it is added to the replay memory.
|
||||
|
||||
Args:
|
||||
observations (torch.Tensor): The observations from the environment.
|
||||
environment_info (Dict[str, Any]): The environment information.
|
||||
actions (torch.Tensor): The actions computed by the actor.
|
||||
rewards (torch.Tensor): The rewards from the environment.
|
||||
next_observations (torch.Tensor): The next observations from the environment.
|
||||
next_environment_info (Dict[str, Any]): The next environment information.
|
||||
dones (torch.Tensor): The done flags from the environment.
|
||||
data (Dict[str, torch.Tensor]): Additional data to include in the transition.
|
||||
Returns:
|
||||
A dictionary containing the processed transition.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def register_terminations(self, terminations: torch.Tensor) -> None:
|
||||
"""Registers terminations with the actor critic agent.
|
||||
|
||||
Args:
|
||||
terminations (torch.Tensor): A tensor of indicator values for each environment.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def to(self, device: str) -> Agent:
|
||||
"""Transfers agent parameters to device."""
|
||||
self.device = device
|
||||
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
def train_mode(self) -> Agent:
|
||||
"""Sets the agent to training mode."""
|
||||
return self
|
||||
|
||||
@abstractmethod
|
||||
def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
|
||||
"""Updates the agent's parameters.
|
||||
|
||||
Args:
|
||||
dataset (Dataset): The dataset from which to update the agent.
|
||||
Returns:
|
||||
A dictionary containing the loss values.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _process_actions(self, actions: torch.Tensor) -> torch.Tensor:
|
||||
"""Processes actions produced by the agent.
|
||||
|
||||
Args:
|
||||
actions (torch.Tensor): The raw actions.
|
||||
Returns:
|
||||
A torch.Tensor containing the processed actions.
|
||||
"""
|
||||
actions = actions.reshape(-1, self._action_size)
|
||||
actions = actions.clamp(self._action_min, self._action_max)
|
||||
actions = actions.to(self.device)
|
||||
|
||||
return actions
|
|
@ -0,0 +1,168 @@
|
|||
from __future__ import annotations
|
||||
import torch
|
||||
from typing import Dict, Union
|
||||
from rsl_rl.algorithms.dpg import AbstractDPG
|
||||
from rsl_rl.env import VecEnv
|
||||
from rsl_rl.storage.storage import Dataset
|
||||
|
||||
from rsl_rl.modules import CategoricalNetwork, Network
|
||||
|
||||
|
||||
class D4PG(AbstractDPG):
|
||||
"""Distributed Distributional Deep Deterministic Policy Gradients algorithm.
|
||||
|
||||
This is an implementation of the D4PG algorithm by Barth-Maron et. al. for vectorized environments.
|
||||
|
||||
Paper: https://arxiv.org/pdf/1804.08617.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: VecEnv,
|
||||
actor_lr: float = 1e-4,
|
||||
atom_count: int = 51,
|
||||
critic_activations: list = ["relu", "relu", "relu"],
|
||||
critic_lr: float = 1e-3,
|
||||
target_update_delay: int = 2,
|
||||
value_max: float = 10.0,
|
||||
value_min: float = -10.0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
env (VecEnv): A vectorized environment.
|
||||
actor_lr (float): The learning rate for the actor network.
|
||||
atom_count (int): The number of atoms to use for the categorical distribution.
|
||||
critic_activations (list): A list of activation functions to use for the critic network.
|
||||
critic_lr (float): The learning rate for the critic network.
|
||||
target_update_delay (int): The number of steps to wait before updating the target networks.
|
||||
value_max (float): The maximum value for the categorical distribution.
|
||||
value_min (float): The minimum value for the categorical distribution.
|
||||
"""
|
||||
kwargs["critic_activations"] = critic_activations
|
||||
|
||||
super().__init__(env, **kwargs)
|
||||
|
||||
self._atom_count = atom_count
|
||||
self._target_update_delay = target_update_delay
|
||||
self._value_max = value_max
|
||||
self._value_min = value_min
|
||||
|
||||
self._register_serializable("_atom_count", "_target_update_delay", "_value_max", "_value_min")
|
||||
|
||||
self.actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
|
||||
self.critic = CategoricalNetwork(
|
||||
self._critic_input_size,
|
||||
1,
|
||||
atom_count=atom_count,
|
||||
value_max=value_max,
|
||||
value_min=value_min,
|
||||
**self._critic_network_kwargs,
|
||||
)
|
||||
|
||||
self.target_actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
|
||||
self.target_critic = CategoricalNetwork(
|
||||
self._critic_input_size,
|
||||
1,
|
||||
atom_count=atom_count,
|
||||
value_max=value_max,
|
||||
value_min=value_min,
|
||||
**self._critic_network_kwargs,
|
||||
)
|
||||
self.target_actor.load_state_dict(self.actor.state_dict())
|
||||
self.target_critic.load_state_dict(self.critic.state_dict())
|
||||
|
||||
self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
|
||||
self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
|
||||
|
||||
self._register_serializable(
|
||||
"actor", "critic", "target_actor", "target_critic", "actor_optimizer", "critic_optimizer"
|
||||
)
|
||||
|
||||
self._update_step = 0
|
||||
|
||||
self._register_serializable("_update_step")
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def eval_mode(self) -> D4PG:
|
||||
super().eval_mode()
|
||||
|
||||
self.actor.eval()
|
||||
self.critic.eval()
|
||||
self.target_actor.eval()
|
||||
self.target_critic.eval()
|
||||
|
||||
return self
|
||||
|
||||
def to(self, device: str) -> D4PG:
|
||||
super().to(device)
|
||||
|
||||
self.actor.to(device)
|
||||
self.critic.to(device)
|
||||
self.target_actor.to(device)
|
||||
self.target_critic.to(device)
|
||||
|
||||
return self
|
||||
|
||||
def train_mode(self) -> D4PG:
|
||||
super().train_mode()
|
||||
|
||||
self.actor.train()
|
||||
self.critic.train()
|
||||
self.target_actor.train()
|
||||
self.target_critic.train()
|
||||
|
||||
return self
|
||||
|
||||
def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
|
||||
super().update(dataset)
|
||||
|
||||
if not self.initialized:
|
||||
return {}
|
||||
|
||||
total_actor_loss = torch.zeros(self._batch_count)
|
||||
total_critic_loss = torch.zeros(self._batch_count)
|
||||
|
||||
for idx, batch in enumerate(self.storage.batch_generator(self._batch_size, self._batch_count)):
|
||||
actor_obs = batch["actor_observations"]
|
||||
critic_obs = batch["critic_observations"]
|
||||
actions = batch["actions"].reshape(self._batch_size, -1)
|
||||
rewards = batch["rewards"]
|
||||
actor_next_obs = batch["next_actor_observations"]
|
||||
critic_next_obs = batch["next_critic_observations"]
|
||||
dones = batch["dones"]
|
||||
|
||||
predictions = self.critic.forward(self._critic_input(critic_obs, actions), distribution=True).squeeze()
|
||||
target_actor_prediction = self._process_actions(self.target_actor.forward(actor_next_obs))
|
||||
target_probabilities = self.target_critic.forward(
|
||||
self._critic_input(critic_next_obs, target_actor_prediction), distribution=True
|
||||
).squeeze()
|
||||
targets = self.target_critic.compute_targets(rewards, dones, self._discount_factor)
|
||||
critic_loss = self.target_critic.categorical_loss(predictions, target_probabilities, targets)
|
||||
|
||||
self.critic_optimizer.zero_grad()
|
||||
critic_loss.backward()
|
||||
self.critic_optimizer.step()
|
||||
|
||||
evaluation = self.critic.forward(
|
||||
self._critic_input(critic_obs, self._process_actions(self.actor.forward(actor_obs)))
|
||||
)
|
||||
actor_loss = -evaluation.mean()
|
||||
|
||||
self.actor_optimizer.zero_grad()
|
||||
actor_loss.backward()
|
||||
self.actor_optimizer.step()
|
||||
|
||||
if self._update_step % self._target_update_delay == 0:
|
||||
self._update_target(self.actor, self.target_actor)
|
||||
self._update_target(self.critic, self.target_critic)
|
||||
|
||||
self._update_step += 1
|
||||
|
||||
total_actor_loss[idx] = actor_loss.item()
|
||||
total_critic_loss[idx] = critic_loss.item()
|
||||
|
||||
stats = {"actor": total_actor_loss.mean().item(), "critic": total_critic_loss.mean().item()}
|
||||
|
||||
return stats
|
|
@ -0,0 +1,125 @@
|
|||
from __future__ import annotations
|
||||
import torch
|
||||
from torch import optim
|
||||
from typing import Dict, Union
|
||||
|
||||
from rsl_rl.algorithms.dpg import AbstractDPG
|
||||
from rsl_rl.env import VecEnv
|
||||
from rsl_rl.modules.network import Network
|
||||
from rsl_rl.storage.storage import Dataset
|
||||
|
||||
|
||||
class DDPG(AbstractDPG):
|
||||
"""Deep Deterministic Policy Gradients algorithm.
|
||||
|
||||
This is an implementation of the DDPG algorithm by Lillicrap et. al. for vectorized environments.
|
||||
|
||||
Paper: https://arxiv.org/pdf/1509.02971.pdf
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: VecEnv,
|
||||
actor_lr: float = 1e-4,
|
||||
critic_lr: float = 1e-3,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(env, **kwargs)
|
||||
|
||||
self.actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
|
||||
self.critic = Network(self._critic_input_size, 1, **self._critic_network_kwargs)
|
||||
|
||||
self.target_actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
|
||||
self.target_critic = Network(self._critic_input_size, 1, **self._critic_network_kwargs)
|
||||
self.target_actor.load_state_dict(self.actor.state_dict())
|
||||
self.target_critic.load_state_dict(self.critic.state_dict())
|
||||
|
||||
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
|
||||
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)
|
||||
|
||||
self._register_serializable(
|
||||
"actor", "critic", "target_actor", "target_critic", "actor_optimizer", "critic_optimizer"
|
||||
)
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def eval_mode(self) -> DDPG:
|
||||
super().eval_mode()
|
||||
|
||||
self.actor.eval()
|
||||
self.critic.eval()
|
||||
self.target_actor.eval()
|
||||
self.target_critic.eval()
|
||||
|
||||
return self
|
||||
|
||||
def to(self, device: str) -> DDPG:
|
||||
"""Transfers agent parameters to device."""
|
||||
super().to(device)
|
||||
|
||||
self.actor.to(device)
|
||||
self.critic.to(device)
|
||||
self.target_actor.to(device)
|
||||
self.target_critic.to(device)
|
||||
|
||||
return self
|
||||
|
||||
def train_mode(self) -> DDPG:
|
||||
super().train_mode()
|
||||
|
||||
self.actor.train()
|
||||
self.critic.train()
|
||||
self.target_actor.train()
|
||||
self.target_critic.train()
|
||||
|
||||
return self
|
||||
|
||||
def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
|
||||
super().update(dataset)
|
||||
|
||||
if not self.initialized:
|
||||
return {}
|
||||
|
||||
total_actor_loss = torch.zeros(self._batch_count)
|
||||
total_critic_loss = torch.zeros(self._batch_count)
|
||||
|
||||
for idx, batch in enumerate(self.storage.batch_generator(self._batch_size, self._batch_count)):
|
||||
actor_obs = batch["actor_observations"]
|
||||
critic_obs = batch["critic_observations"]
|
||||
actions = batch["actions"]
|
||||
rewards = batch["rewards"]
|
||||
actor_next_obs = batch["next_actor_observations"]
|
||||
critic_next_obs = batch["next_critic_observations"]
|
||||
dones = batch["dones"]
|
||||
|
||||
target_actor_prediction = self._process_actions(self.target_actor.forward(actor_next_obs))
|
||||
target_critic_prediction = self.target_critic.forward(
|
||||
self._critic_input(critic_next_obs, target_actor_prediction)
|
||||
)
|
||||
|
||||
target = rewards + self._discount_factor * (1 - dones) * target_critic_prediction
|
||||
prediction = self.critic.forward(self._critic_input(critic_obs, actions))
|
||||
critic_loss = (prediction - target).pow(2).mean()
|
||||
|
||||
self.critic_optimizer.zero_grad()
|
||||
critic_loss.backward()
|
||||
self.critic_optimizer.step()
|
||||
|
||||
evaluation = self.critic.forward(
|
||||
self._critic_input(critic_obs, self._process_actions(self.actor.forward(actor_obs)))
|
||||
)
|
||||
actor_loss = -evaluation.mean()
|
||||
|
||||
self.actor_optimizer.zero_grad()
|
||||
actor_loss.backward()
|
||||
self.actor_optimizer.step()
|
||||
|
||||
self._update_target(self.actor, self.target_actor)
|
||||
self._update_target(self.critic, self.target_critic)
|
||||
|
||||
total_actor_loss[idx] = actor_loss.item()
|
||||
total_critic_loss[idx] = critic_loss.item()
|
||||
|
||||
stats = {"actor": total_actor_loss.mean().item(), "critic": total_critic_loss.mean().item()}
|
||||
|
||||
return stats
|
|
@ -0,0 +1,49 @@
|
|||
import torch
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from rsl_rl.algorithms.actor_critic import AbstractActorCritic
|
||||
from rsl_rl.env import VecEnv
|
||||
from rsl_rl.storage.replay_storage import ReplayStorage
|
||||
from rsl_rl.storage.storage import Dataset
|
||||
|
||||
|
||||
class AbstractDPG(AbstractActorCritic):
|
||||
def __init__(
|
||||
self, env: VecEnv, action_noise_scale: float = 0.1, storage_initial_size=0, storage_size=100000, **kwargs
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
env (VecEnv): A vectorized environment.
|
||||
action_noise_scale (float): The scale of the gaussian action noise.
|
||||
storage_initial_size (int): Initial size of the replay storage.
|
||||
storage_size (int): Maximum size of the replay storage.
|
||||
"""
|
||||
assert action_noise_scale > 0
|
||||
|
||||
super().__init__(env, **kwargs)
|
||||
|
||||
self.storage = ReplayStorage(
|
||||
self.env.num_envs, storage_size, device=self.device, initial_size=storage_initial_size
|
||||
)
|
||||
|
||||
self._register_serializable("storage")
|
||||
|
||||
self._action_noise_scale = action_noise_scale
|
||||
|
||||
self._register_serializable("_action_noise_scale")
|
||||
|
||||
def draw_actions(
|
||||
self, obs: torch.Tensor, env_info: Dict[str, Any]
|
||||
) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
|
||||
actor_obs, critic_obs = self._process_observations(obs, env_info)
|
||||
|
||||
actions = self.actor.forward(actor_obs)
|
||||
noise = torch.normal(torch.zeros_like(actions), torch.ones_like(actions) * self._action_noise_scale)
|
||||
noisy_actions = self._process_actions(actions + noise)
|
||||
|
||||
data = {"actor_observations": actor_obs.clone(), "critic_observations": critic_obs.clone()}
|
||||
|
||||
return noisy_actions, data
|
||||
|
||||
def register_terminations(self, terminations: torch.Tensor) -> None:
|
||||
pass
|
|
@ -0,0 +1,327 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from typing import Dict, List, Tuple, Type, Union
|
||||
|
||||
from rsl_rl.algorithms.ppo import PPO
|
||||
from rsl_rl.distributions import QuantileDistribution
|
||||
from rsl_rl.env import VecEnv
|
||||
from rsl_rl.utils.benchmarkable import Benchmarkable
|
||||
from rsl_rl.utils.recurrency import trajectories_to_transitions, transitions_to_trajectories
|
||||
from rsl_rl.modules import ImplicitQuantileNetwork, QuantileNetwork
|
||||
from rsl_rl.storage.storage import Dataset
|
||||
|
||||
|
||||
class DPPO(PPO):
|
||||
"""Distributional Proximal Policy Optimization algorithm.
|
||||
|
||||
This algorithm is an extension of PPO that uses a distributional method (either QR-DQN or IQN) to estimate the
|
||||
value function.
|
||||
|
||||
QR-DQN Paper: https://arxiv.org/pdf/1710.10044.pdf
|
||||
IQN Paper: https://arxiv.org/pdf/1806.06923.pdf
|
||||
|
||||
The implementation works with recurrent neural networks. We further implement Sample-Replacement SR(lambda) for the
|
||||
value target computation, as described by Nam et. al. in https://arxiv.org/pdf/2105.11366.pdf.
|
||||
"""
|
||||
|
||||
critic_network: Type[nn.Module] = QuantileNetwork
|
||||
_alg_features = dict(recurrent=True)
|
||||
|
||||
value_loss_energy = "sample_energy"
|
||||
value_loss_l1 = "quantile_l1"
|
||||
value_loss_huber = "quantile_huber"
|
||||
|
||||
network_qrdqn = "qrdqn"
|
||||
network_iqn = "iqn"
|
||||
|
||||
networks = {network_qrdqn: QuantileNetwork, network_iqn: ImplicitQuantileNetwork}
|
||||
|
||||
values_losses = {
|
||||
network_qrdqn: {
|
||||
value_loss_energy: QuantileNetwork.sample_energy_loss,
|
||||
value_loss_l1: QuantileNetwork.quantile_l1_loss,
|
||||
value_loss_huber: QuantileNetwork.quantile_huber_loss,
|
||||
},
|
||||
network_iqn: {
|
||||
value_loss_energy: ImplicitQuantileNetwork.sample_energy_loss,
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: VecEnv,
|
||||
critic_activations: List[str] = ["relu", "relu", "relu"],
|
||||
critic_network: str = network_qrdqn,
|
||||
iqn_action_samples: int = 32,
|
||||
iqn_embedding_size: int = 64,
|
||||
iqn_feature_layers: int = 1,
|
||||
iqn_value_samples: int = 8,
|
||||
qrdqn_quantile_count: int = 200,
|
||||
value_lambda: float = 0.95,
|
||||
value_loss: str = value_loss_l1,
|
||||
value_loss_kwargs: Dict = {},
|
||||
value_measure: str = None,
|
||||
value_measure_adaptation: Union[Tuple, None] = None,
|
||||
value_measure_kwargs: Dict = {},
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
env (VecEnv): A vectorized environment.
|
||||
critic_activations (List[str]): A list of activations to use for the critic network.
|
||||
critic_network (str): The critic network to use.
|
||||
iqn_action_samples (int): The number of samples to use for the critic IQN network when acting.
|
||||
iqn_embedding_size (int): The embedding size to use for the critic IQN network.
|
||||
iqn_feature_layers (int): The number of feature layers to use for the critic IQN network.
|
||||
iqn_value_samples (int): The number of samples to use for the critic IQN network when computing the value.
|
||||
qrdqn_quantile_count (int): The number of quantiles to use for the critic QR network.
|
||||
value_lambda (float): The lambda parameter for the SR(lambda) value target computation.
|
||||
value_loss (str): The loss function to use for the critic network.
|
||||
value_loss_kwargs (Dict): Keyword arguments for computing the value loss.
|
||||
value_measure (str): The probability measure to apply to the critic network output distribution when
|
||||
updating the policy.
|
||||
value_measure_adaptation (Union[Tuple, None]): Controls adaptation of the value measure. If None, no
|
||||
adaptation is performed. If a tuple, the tuple specifies the observations that are passed to the value
|
||||
measure.
|
||||
value_measure_kwargs (Dict): The keyword arguments to pass to the value measure.
|
||||
"""
|
||||
self._register_critic_network_kwargs(measure=value_measure, measure_kwargs=value_measure_kwargs)
|
||||
|
||||
self._critic_network_name = critic_network
|
||||
self.critic_network = self.networks[self._critic_network_name]
|
||||
if self._critic_network_name == self.network_qrdqn:
|
||||
self._register_critic_network_kwargs(quantile_count=qrdqn_quantile_count)
|
||||
elif self._critic_network_name == self.network_iqn:
|
||||
self._register_critic_network_kwargs(feature_layers=iqn_feature_layers, embedding_size=iqn_embedding_size)
|
||||
|
||||
kwargs["critic_activations"] = critic_activations
|
||||
|
||||
if value_measure_adaptation is not None:
|
||||
# Value measure adaptation observations are not passed to the critic network.
|
||||
kwargs["_critic_input_size_delta"] = (
|
||||
kwargs["_critic_input_size_delta"] if "_critic_input_size_delta" in kwargs else 0
|
||||
) - len(value_measure_adaptation)
|
||||
|
||||
super().__init__(env, **kwargs)
|
||||
|
||||
self._value_lambda = value_lambda
|
||||
self._value_loss_name = value_loss
|
||||
self._register_serializable("_value_lambda", "_value_loss_name")
|
||||
|
||||
assert (
|
||||
self._value_loss_name in self.values_losses[self._critic_network_name]
|
||||
), f"Value loss '{self._value_loss_name}' is not supported for network '{self._critic_network_name}'."
|
||||
value_loss_func = self.values_losses[critic_network][self._value_loss_name]
|
||||
self._value_loss = lambda *args, **kwargs: value_loss_func(self.critic, *args, **kwargs)
|
||||
|
||||
if value_loss == self.value_loss_energy:
|
||||
value_loss_kwargs["sample_count"] = (
|
||||
value_loss_kwargs["sample_count"] if "sample_count" in value_loss_kwargs else 100
|
||||
)
|
||||
|
||||
self._value_loss_kwargs = value_loss_kwargs
|
||||
self._register_serializable("_value_loss_kwargs")
|
||||
|
||||
self._value_measure_adaptation = value_measure_adaptation
|
||||
self._register_serializable("_value_measure_adaptation")
|
||||
|
||||
if self._critic_network_name == self.network_iqn:
|
||||
self._iqn_action_samples = iqn_action_samples
|
||||
self._iqn_value_samples = iqn_value_samples
|
||||
self._register_serializable("_iqn_action_samples", "_iqn_value_samples")
|
||||
|
||||
def _critic_input(self, observations, actions=None) -> torch.Tensor:
|
||||
mask, shape = self._get_critic_obs_mask(observations)
|
||||
|
||||
processed_observations = observations[mask].reshape(*shape)
|
||||
|
||||
return processed_observations
|
||||
|
||||
def _get_critic_obs_mask(self, observations):
|
||||
mask = torch.ones_like(observations).bool()
|
||||
|
||||
if self._value_measure_adaptation is not None:
|
||||
mask[:, self._value_measure_adaptation] = False
|
||||
|
||||
shape = (observations.shape[0], self._critic_input_size)
|
||||
|
||||
return mask, shape
|
||||
|
||||
def _process_quants(self, x):
|
||||
if self._value_loss_name == self.value_loss_energy:
|
||||
quants, idx = QuantileDistribution(x).sample(self._value_loss_kwargs["sample_count"])
|
||||
else:
|
||||
quants, idx = x, None
|
||||
|
||||
return quants, idx
|
||||
|
||||
@Benchmarkable.register
|
||||
def process_transition(self, *args) -> Dict[str, torch.Tensor]:
|
||||
transition = super(PPO, self).process_transition(*args)
|
||||
|
||||
if self.recurrent:
|
||||
transition["critic_state_h"] = self.critic.hidden_state[0].detach()
|
||||
transition["critic_state_c"] = self.critic.hidden_state[1].detach()
|
||||
|
||||
transition["full_critic_observations"] = transition["critic_observations"].detach()
|
||||
transition["full_next_critic_observations"] = transition["next_critic_observations"].detach()
|
||||
mask, shape = self._get_critic_obs_mask(transition["critic_observations"])
|
||||
transition["critic_observations"] = transition["critic_observations"][mask].reshape(*shape)
|
||||
transition["next_critic_observations"] = transition["next_critic_observations"][mask].reshape(*shape)
|
||||
|
||||
critic_kwargs = (
|
||||
{"sample_count": self._iqn_action_samples} if self._critic_network_name == self.network_iqn else {}
|
||||
)
|
||||
transition["values"] = self.critic.forward(
|
||||
transition["critic_observations"],
|
||||
measure_args=self._extract_value_measure_adaptation(transition["full_critic_observations"]),
|
||||
**critic_kwargs,
|
||||
).detach()
|
||||
|
||||
if self._critic_network_name == self.network_iqn:
|
||||
# For IQN, we sample new (undistorted) quantiles for computing the value update
|
||||
critic_kwargs = (
|
||||
{"hidden_state": (transition["critic_state_h"], transition["critic_state_c"])} if self.recurrent else {}
|
||||
)
|
||||
self.critic.forward(
|
||||
transition["critic_observations"],
|
||||
sample_count=self._iqn_value_samples,
|
||||
use_measure=False,
|
||||
**critic_kwargs,
|
||||
).detach()
|
||||
transition["value_taus"] = self.critic.last_taus.detach().reshape(transition["values"].shape[0], -1)
|
||||
|
||||
transition["value_quants"] = self.critic.last_quantiles.detach().reshape(transition["values"].shape[0], -1)
|
||||
|
||||
if self.recurrent:
|
||||
transition["critic_next_state_h"] = self.critic.hidden_state[0].detach()
|
||||
transition["critic_next_state_c"] = self.critic.hidden_state[1].detach()
|
||||
|
||||
return transition
|
||||
|
||||
@Benchmarkable.register
|
||||
def _compute_value_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
critic_kwargs = (
|
||||
{"sample_count": self._iqn_value_samples, "taus": batch["value_target_taus"], "use_measure": False}
|
||||
if self._critic_network_name == self.network_iqn
|
||||
else {}
|
||||
)
|
||||
|
||||
if self.recurrent:
|
||||
observations, data = transitions_to_trajectories(batch["critic_observations"], batch["dones"])
|
||||
hidden_state_h, _ = transitions_to_trajectories(batch["critic_state_h"], batch["dones"])
|
||||
hidden_state_c, _ = transitions_to_trajectories(batch["critic_state_c"], batch["dones"])
|
||||
hidden_states = (hidden_state_h[0].transpose(0, 1), hidden_state_c[0].transpose(0, 1))
|
||||
|
||||
if self._critic_network_name == self.network_iqn:
|
||||
critic_kwargs["taus"], _ = transitions_to_trajectories(critic_kwargs["taus"], batch["dones"])
|
||||
|
||||
trajectory_evaluations = self.critic.forward(
|
||||
observations, distribution=True, hidden_state=hidden_states, **critic_kwargs
|
||||
)
|
||||
trajectory_evaluations = trajectory_evaluations.reshape(*observations.shape[:-1], -1)
|
||||
|
||||
predictions = trajectories_to_transitions(trajectory_evaluations, data)
|
||||
else:
|
||||
predictions = self.critic.forward(batch["critic_observations"], distribution=True, **critic_kwargs)
|
||||
|
||||
value_loss = self._value_loss(self._process_quants(predictions)[0], batch["value_target_quants"])
|
||||
|
||||
return value_loss
|
||||
|
||||
def _extract_value_measure_adaptation(self, observations: torch.Tensor) -> Tuple[torch.Tensor]:
|
||||
if self._value_measure_adaptation is None:
|
||||
return tuple()
|
||||
|
||||
relevant_observations = observations[:, self._value_measure_adaptation]
|
||||
measure_adaptations = torch.tensor_split(relevant_observations, relevant_observations.shape[1], dim=1)
|
||||
|
||||
return measure_adaptations
|
||||
|
||||
@Benchmarkable.register
|
||||
def _process_dataset(self, dataset: Dataset) -> Dataset:
|
||||
rewards = torch.stack([entry["rewards"] for entry in dataset])
|
||||
dones = torch.stack([entry["dones"] for entry in dataset]).float()
|
||||
timeouts = torch.stack([entry["timeouts"] for entry in dataset])
|
||||
values = torch.stack([entry["values"] for entry in dataset])
|
||||
|
||||
value_quants_idx = [self._process_quants(entry["value_quants"]) for entry in dataset]
|
||||
value_quants = torch.stack([entry[0] for entry in value_quants_idx])
|
||||
|
||||
critic_kwargs = (
|
||||
{"hidden_state": (dataset[-1]["critic_state_h"], dataset[-1]["critic_state_c"])} if self.recurrent else {}
|
||||
)
|
||||
if self._critic_network_name == self.network_iqn:
|
||||
critic_kwargs["sample_count"] = self._iqn_action_samples
|
||||
|
||||
measure_args = self._extract_value_measure_adaptation(dataset[-1]["full_next_critic_observations"])
|
||||
next_values = self.critic.forward(
|
||||
dataset[-1]["next_critic_observations"], measure_args=measure_args, **critic_kwargs
|
||||
)
|
||||
|
||||
if self._critic_network_name == self.network_iqn:
|
||||
# For IQN, we sample new (undistorted) quantiles for computing the value update
|
||||
critic_kwargs["sample_count"] = self._iqn_value_samples
|
||||
self.critic.forward(
|
||||
dataset[-1]["next_critic_observations"],
|
||||
use_measure=False,
|
||||
**critic_kwargs,
|
||||
)
|
||||
|
||||
final_value_taus = self.critic.last_taus
|
||||
value_taus = torch.stack(
|
||||
[
|
||||
torch.take_along_dim(dataset[i]["value_taus"], value_quants_idx[i][1], -1)
|
||||
for i in range(len(dataset))
|
||||
]
|
||||
)
|
||||
|
||||
final_value_quants = self.critic.last_quantiles
|
||||
|
||||
# Timeout bootstrapping for rewards.
|
||||
rewards += self.gamma * timeouts * values
|
||||
|
||||
# Compute advantages and value target quantiles
|
||||
next_values = torch.cat((values[1:], next_values.unsqueeze(0)), dim=0)
|
||||
deltas = (rewards + (1 - dones) * self.gamma * next_values - values).reshape(-1, self.env.num_envs)
|
||||
advantages = torch.zeros((len(dataset) + 1, self.env.num_envs), device=self.device)
|
||||
|
||||
next_value_quants, idx = self._process_quants(final_value_quants)
|
||||
value_target_quants = torch.zeros(len(dataset), *next_value_quants.shape, device=self.device)
|
||||
|
||||
if self._critic_network_name == self.network_iqn:
|
||||
value_target_taus = torch.zeros(len(dataset) + 1, *next_value_quants.shape, device=self.device)
|
||||
value_target_taus[-1] = torch.take_along_dim(final_value_taus, idx, -1)
|
||||
|
||||
for step in reversed(range(len(dataset))):
|
||||
not_terminal = 1.0 - dones[step]
|
||||
not_terminal_ = not_terminal.unsqueeze(-1)
|
||||
|
||||
advantages[step] = deltas[step] + (1.0 - dones[step]) * self.gamma * self._gae_lambda * advantages[step + 1]
|
||||
value_target_quants[step] = rewards[step].unsqueeze(-1) + not_terminal_ * self.gamma * next_value_quants
|
||||
|
||||
preserved_value_quants = not_terminal_.bool() * (
|
||||
torch.rand(next_value_quants.shape, device=self.device) < self._value_lambda
|
||||
)
|
||||
next_value_quants = torch.where(preserved_value_quants, value_target_quants[step], value_quants[step])
|
||||
|
||||
if self._critic_network_name == self.network_iqn:
|
||||
value_target_taus[step] = torch.where(
|
||||
preserved_value_quants, value_target_taus[step + 1], value_taus[step]
|
||||
)
|
||||
|
||||
advantages = advantages[:-1]
|
||||
if self._critic_network_name == self.network_iqn:
|
||||
value_target_taus = value_target_taus[:-1]
|
||||
|
||||
# Normalize advantages and pack into dataset structure.
|
||||
amean, astd = advantages.mean(), torch.nan_to_num(advantages.std())
|
||||
for step in range(len(dataset)):
|
||||
dataset[step]["advantages"] = advantages[step]
|
||||
dataset[step]["normalized_advantages"] = (advantages[step] - amean) / (astd + 1e-8)
|
||||
dataset[step]["value_target_quants"] = value_target_quants[step]
|
||||
|
||||
if self._critic_network_name == self.network_iqn:
|
||||
dataset[step]["value_target_taus"] = value_target_taus[step]
|
||||
|
||||
return dataset
|
|
@ -0,0 +1,75 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from typing import Tuple, Type
|
||||
|
||||
|
||||
from rsl_rl.algorithms.sac import SAC
|
||||
from rsl_rl.env import VecEnv
|
||||
from rsl_rl.modules.quantile_network import QuantileNetwork
|
||||
|
||||
|
||||
class DSAC(SAC):
|
||||
"""Deep Soft Actor Critic (DSAC) algorithm.
|
||||
|
||||
This is an implementation of the DSAC algorithm by Ma et. al. for vectorized environments.
|
||||
|
||||
Paper: https://arxiv.org/pdf/2004.14547.pdf
|
||||
|
||||
The implementation inherits automatic tuning of the temperature parameter (alpha) and tanh action scaling from
|
||||
the SAC implementation.
|
||||
"""
|
||||
|
||||
critic_network: Type[nn.Module] = QuantileNetwork
|
||||
|
||||
def __init__(self, env: VecEnv, critic_activations=["relu", "relu", "relu"], quantile_count=200, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
env (VecEnv): A vectorized environment.
|
||||
critic_activations (list): A list of activation functions to use for the critic network.
|
||||
quantile_count (int): The number of quantiles to use for the critic QR network.
|
||||
"""
|
||||
self._quantile_count = quantile_count
|
||||
self._register_critic_network_kwargs(quantile_count=self._quantile_count)
|
||||
|
||||
kwargs["critic_activations"] = critic_activations
|
||||
|
||||
super().__init__(env, **kwargs)
|
||||
|
||||
def _update_critic(
|
||||
self,
|
||||
critic_obs: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
rewards: torch.Tensor,
|
||||
dones: torch.Tensor,
|
||||
actor_next_obs: torch.Tensor,
|
||||
critic_next_obs: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
target_action, target_action_logp = self._sample_action(actor_next_obs)
|
||||
target_critic_input = self._critic_input(critic_next_obs, target_action)
|
||||
target_critic_prediction_1 = self.target_critic_1.forward(target_critic_input, distribution=True)
|
||||
target_critic_prediction_2 = self.target_critic_2.forward(target_critic_input, distribution=True)
|
||||
target_critic_prediction = torch.minimum(target_critic_prediction_1, target_critic_prediction_2)
|
||||
|
||||
next_soft_q = target_critic_prediction - self.alpha * target_action_logp.unsqueeze(-1).repeat(
|
||||
1, self._quantile_count
|
||||
)
|
||||
target = (rewards.reshape(-1, 1) + self._discount_factor * (1 - dones).reshape(-1, 1) * next_soft_q).detach()
|
||||
|
||||
critic_input = self._critic_input(critic_obs, actions).detach()
|
||||
critic_1_prediction = self.critic_1.forward(critic_input, distribution=True)
|
||||
critic_1_loss = self.critic_1.quantile_huber_loss(critic_1_prediction, target)
|
||||
|
||||
self.critic_1_optimizer.zero_grad()
|
||||
critic_1_loss.backward()
|
||||
nn.utils.clip_grad_norm_(self.critic_1.parameters(), self._gradient_clip)
|
||||
self.critic_1_optimizer.step()
|
||||
|
||||
critic_2_prediction = self.critic_2.forward(critic_input, distribution=True)
|
||||
critic_2_loss = self.critic_2.quantile_huber_loss(critic_2_prediction, target)
|
||||
|
||||
self.critic_2_optimizer.zero_grad()
|
||||
critic_2_loss.backward()
|
||||
nn.utils.clip_grad_norm_(self.critic_2.parameters(), self._gradient_clip)
|
||||
self.critic_2_optimizer.step()
|
||||
|
||||
return critic_1_loss, critic_2_loss
|
|
@ -0,0 +1,57 @@
|
|||
from __future__ import annotations
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing import Type
|
||||
|
||||
from rsl_rl.algorithms.td3 import TD3
|
||||
from rsl_rl.env import VecEnv
|
||||
from rsl_rl.modules import QuantileNetwork
|
||||
|
||||
|
||||
class DTD3(TD3):
|
||||
"""Distributional Twin-Delayed Deep Deterministic Policy Gradients algorithm.
|
||||
|
||||
This is an implementation of the TD3 algorithm by Fujimoto et. al. for vectorized environments using a QR-DQN
|
||||
critic.
|
||||
"""
|
||||
|
||||
critic_network: Type[nn.Module] = QuantileNetwork
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: VecEnv,
|
||||
quantile_count: int = 200,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self._quantile_count = quantile_count
|
||||
self._register_critic_network_kwargs(quantile_count=self._quantile_count)
|
||||
|
||||
super().__init__(env, **kwargs)
|
||||
|
||||
def _update_critic(self, critic_obs, actions, rewards, dones, actor_next_obs, critic_next_obs):
|
||||
target_action = self._apply_action_noise(self.target_actor.forward(actor_next_obs), clip=True)
|
||||
target_critic_input = self._critic_input(critic_next_obs, target_action)
|
||||
target_critic_prediction_1 = self.target_critic_1.forward(target_critic_input, distribution=True)
|
||||
target_critic_prediction_2 = self.target_critic_2.forward(target_critic_input, distribution=True)
|
||||
target_critic_prediction = torch.minimum(target_critic_prediction_1, target_critic_prediction_2)
|
||||
|
||||
target = (
|
||||
rewards.reshape(-1, 1) + self._discount_factor * (1 - dones).reshape(-1, 1) * target_critic_prediction
|
||||
).detach()
|
||||
|
||||
critic_input = self._critic_input(critic_obs, actions).detach()
|
||||
critic_1_prediction = self.critic_1.forward(critic_input, distribution=True)
|
||||
critic_1_loss = self.critic_1.quantile_huber_loss(critic_1_prediction, target)
|
||||
|
||||
self.critic_1_optimizer.zero_grad()
|
||||
critic_1_loss.backward()
|
||||
self.critic_1_optimizer.step()
|
||||
|
||||
critic_2_prediction = self.critic_2.forward(critic_input, distribution=True)
|
||||
critic_2_loss = self.critic_2.quantile_huber_loss(critic_2_prediction, target)
|
||||
|
||||
self.critic_2_optimizer.zero_grad()
|
||||
critic_2_loss.backward()
|
||||
self.critic_2_optimizer.step()
|
||||
|
||||
return critic_1_loss, critic_2_loss
|
|
@ -0,0 +1,193 @@
|
|||
from abc import ABC, abstractmethod
|
||||
import torch
|
||||
from typing import Callable, Dict, Tuple, Type, Union
|
||||
|
||||
from rsl_rl.algorithms import D4PG, DSAC
|
||||
from rsl_rl.algorithms import TD3
|
||||
from rsl_rl.algorithms.actor_critic import AbstractActorCritic
|
||||
from rsl_rl.algorithms.agent import Agent
|
||||
from rsl_rl.env import VecEnv
|
||||
from rsl_rl.storage.storage import Dataset, Storage
|
||||
|
||||
|
||||
class AbstractHybridAgent(Agent, ABC):
|
||||
def __init__(
|
||||
self,
|
||||
env: VecEnv,
|
||||
agent_class: Type[Agent],
|
||||
agent_kwargs: dict,
|
||||
pretrain_agent_class: Type[Agent],
|
||||
pretrain_agent_kwargs: dict,
|
||||
pretrain_steps: int,
|
||||
freeze_steps: int = 0,
|
||||
**general_kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
env (VecEnv): A vectorized environment.
|
||||
"""
|
||||
agent_kwargs["env"] = env
|
||||
pretrain_agent_kwargs["env"] = env
|
||||
|
||||
self.agent = agent_class(**agent_kwargs, **general_kwargs)
|
||||
self.pretrain_agent = pretrain_agent_class(**pretrain_agent_kwargs, **general_kwargs)
|
||||
|
||||
self._freeze_steps = freeze_steps
|
||||
self._pretrain_steps = pretrain_steps
|
||||
self._steps = 0
|
||||
|
||||
self._register_serializable("agent", "pretrain_agent", "_freeze_steps", "_pretrain_steps", "_steps")
|
||||
|
||||
@property
|
||||
def active_agent(self):
|
||||
agent = self.pretrain_agent if self.pretraining else self.agent
|
||||
|
||||
return agent
|
||||
|
||||
def draw_actions(self, *args, **kwargs) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
|
||||
return self.active_agent.draw_actions(*args, **kwargs)
|
||||
|
||||
def draw_random_actions(self, *args, **kwargs) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
|
||||
return self.active_agent.draw_random_actions(*args, **kwargs)
|
||||
|
||||
def eval_mode(self, *args, **kwargs) -> Agent:
|
||||
self.agent.eval_mode(*args, **kwargs)
|
||||
|
||||
def get_inference_policy(self, *args, **kwargs) -> Callable:
|
||||
return self.active_agent.get_inference_policy(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def initialized(self) -> bool:
|
||||
return self.active_agent.initialized
|
||||
|
||||
@property
|
||||
def pretraining(self):
|
||||
return self._steps < self._pretrain_steps
|
||||
|
||||
def process_dataset(self, *args, **kwargs) -> Dataset:
|
||||
return self.active_agent.process_dataset(*args, **kwargs)
|
||||
|
||||
def process_transition(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
|
||||
return self.active_agent.process_transition(*args, **kwargs)
|
||||
|
||||
def register_terminations(self, *args, **kwargs) -> None:
|
||||
return self.active_agent.register_terminations(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def storage(self) -> Storage:
|
||||
return self.active_agent.storage
|
||||
|
||||
def to(self, *args, **kwargs) -> Agent:
|
||||
self.agent.to(*args, **kwargs)
|
||||
self.pretrain_agent.to(*args, **kwargs)
|
||||
|
||||
def train_mode(self, *args, **kwargs) -> Agent:
|
||||
self.agent.train_mode(*args, **kwargs)
|
||||
self.pretrain_agent.train_mode(*args, **kwargs)
|
||||
|
||||
def update(self, *args, **kwargs) -> Dict[str, Union[float, torch.Tensor]]:
|
||||
result = self.active_agent.update(*args, **kwargs)
|
||||
|
||||
if not self.active_agent.initialized:
|
||||
return
|
||||
|
||||
self._steps += 1
|
||||
|
||||
if self._steps == self._pretrain_steps:
|
||||
self._transfer_weights()
|
||||
self._freeze_weights(freeze=True)
|
||||
|
||||
if self._steps == self._pretrain_steps + self._freeze_steps:
|
||||
self._transfer_weights()
|
||||
self._freeze_weights(freeze=False)
|
||||
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
def _freeze_weights(self, freeze=True):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _transfer_weights(self):
|
||||
pass
|
||||
|
||||
|
||||
class HybridD4PG(AbstractHybridAgent):
|
||||
def __init__(
|
||||
self,
|
||||
env: VecEnv,
|
||||
d4pg_kwargs: dict,
|
||||
pretrain_kwargs: dict,
|
||||
pretrain_agent: Type[AbstractActorCritic] = TD3,
|
||||
**kwargs,
|
||||
):
|
||||
assert d4pg_kwargs["action_max"] == pretrain_kwargs["action_max"]
|
||||
assert d4pg_kwargs["action_min"] == pretrain_kwargs["action_min"]
|
||||
assert d4pg_kwargs["actor_activations"] == pretrain_kwargs["actor_activations"]
|
||||
assert d4pg_kwargs["actor_hidden_dims"] == pretrain_kwargs["actor_hidden_dims"]
|
||||
assert d4pg_kwargs["actor_input_normalization"] == pretrain_kwargs["actor_input_normalization"]
|
||||
|
||||
super().__init__(
|
||||
env,
|
||||
agent_class=D4PG,
|
||||
agent_kwargs=d4pg_kwargs,
|
||||
pretrain_agent_class=pretrain_agent,
|
||||
pretrain_agent_kwargs=pretrain_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _freeze_weights(self, freeze=True):
|
||||
for param in self.agent.actor.parameters():
|
||||
param.requires_grad = not freeze
|
||||
|
||||
def _transfer_weights(self):
|
||||
self.agent.actor.load_state_dict(self.pretrain_agent.actor.state_dict())
|
||||
self.agent.actor_optimizer.load_state_dict(self.pretrain_agent.actor_optimizer.state_dict())
|
||||
|
||||
|
||||
class HybridDSAC(AbstractHybridAgent):
|
||||
def __init__(
|
||||
self,
|
||||
env: VecEnv,
|
||||
dsac_kwargs: dict,
|
||||
pretrain_kwargs: dict,
|
||||
pretrain_agent: Type[AbstractActorCritic] = TD3,
|
||||
**kwargs,
|
||||
):
|
||||
assert dsac_kwargs["action_max"] == pretrain_kwargs["action_max"]
|
||||
assert dsac_kwargs["action_min"] == pretrain_kwargs["action_min"]
|
||||
assert dsac_kwargs["actor_activations"] == pretrain_kwargs["actor_activations"]
|
||||
assert dsac_kwargs["actor_hidden_dims"] == pretrain_kwargs["actor_hidden_dims"]
|
||||
assert dsac_kwargs["actor_input_normalization"] == pretrain_kwargs["actor_input_normalization"]
|
||||
|
||||
super().__init__(
|
||||
env,
|
||||
agent_class=DSAC,
|
||||
agent_kwargs=dsac_kwargs,
|
||||
pretrain_agent_class=pretrain_agent,
|
||||
pretrain_agent_kwargs=pretrain_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _freeze_weights(self, freeze=True):
|
||||
"""Freezes actor layers.
|
||||
|
||||
Freezes feature encoding and mean computation layers for gaussian network. Leaves log standard deviation layer
|
||||
unfreezed.
|
||||
"""
|
||||
for param in self.agent.actor._layers.parameters():
|
||||
param.requires_grad = not freeze
|
||||
|
||||
for param in self.agent.actor._mean_layer.parameters():
|
||||
param.requires_grad = not freeze
|
||||
|
||||
def _transfer_weights(self):
|
||||
"""Transfers actor layers.
|
||||
|
||||
Transfers only feature encoding and mean computation layers for gaussian network.
|
||||
"""
|
||||
for i, layer in enumerate(self.agent.actor._layers):
|
||||
layer.load_state_dict(self.pretrain_agent.actor._layers[i].state_dict())
|
||||
|
||||
for j, layer in enumerate(self.agent.actor._mean_layer):
|
||||
layer.load_state_dict(self.pretrain_agent.actor._layers[i + j + 1].state_dict())
|
|
@ -1,185 +1,384 @@
|
|||
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch import nn, optim
|
||||
from typing import Any, Dict, Tuple, Type, Union
|
||||
|
||||
from rsl_rl.modules import ActorCritic
|
||||
from rsl_rl.storage import RolloutStorage
|
||||
from rsl_rl.algorithms.actor_critic import AbstractActorCritic
|
||||
from rsl_rl.env import VecEnv
|
||||
from rsl_rl.utils.benchmarkable import Benchmarkable
|
||||
from rsl_rl.utils.recurrency import trajectories_to_transitions, transitions_to_trajectories
|
||||
from rsl_rl.modules import GaussianNetwork, Network
|
||||
from rsl_rl.storage.rollout_storage import RolloutStorage
|
||||
from rsl_rl.storage.storage import Dataset
|
||||
|
||||
|
||||
class PPO:
|
||||
actor_critic: ActorCritic
|
||||
class PPO(AbstractActorCritic):
|
||||
"""Proximal Policy Optimization algorithm.
|
||||
|
||||
This is an implementation of the PPO algorithm by Schulman et. al. for vectorized environments.
|
||||
|
||||
Paper: https://arxiv.org/pdf/1707.06347.pdf
|
||||
|
||||
The implementation works with recurrent neural networks. We implement adaptive learning rate based on the
|
||||
KL-divergence between the old and new policy, as described by Schulman et. al. in
|
||||
https://arxiv.org/pdf/1707.06347.pdf.
|
||||
"""
|
||||
|
||||
critic_network: Type[nn.Module] = Network
|
||||
_alg_features = dict(recurrent=True)
|
||||
|
||||
schedule_adaptive = "adaptive"
|
||||
schedule_fixed = "fixed"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
actor_critic,
|
||||
num_learning_epochs=1,
|
||||
num_mini_batches=1,
|
||||
clip_param=0.2,
|
||||
gamma=0.998,
|
||||
lam=0.95,
|
||||
value_loss_coef=1.0,
|
||||
entropy_coef=0.0,
|
||||
learning_rate=1e-3,
|
||||
max_grad_norm=1.0,
|
||||
use_clipped_value_loss=True,
|
||||
schedule="fixed",
|
||||
desired_kl=0.01,
|
||||
device="cpu",
|
||||
env: VecEnv,
|
||||
actor_noise_std: float = 1.0,
|
||||
clip_ratio: float = 0.2,
|
||||
entropy_coeff: float = 0.0,
|
||||
gae_lambda: float = 0.97,
|
||||
gradient_clip: float = 1.0,
|
||||
learning_rate: float = 1e-3,
|
||||
schedule: str = "fixed",
|
||||
target_kl: float = 0.01,
|
||||
value_coeff: float = 1.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.device = device
|
||||
"""
|
||||
Args:
|
||||
env (VecEnv): A vectorized environment.
|
||||
actor_noise_std (float): The standard deviation of the Gaussian noise to add to the actor network output.
|
||||
clip_ratio (float): The clipping ratio for the PPO objective.
|
||||
entropy_coeff (float): The coefficient for the entropy term in the PPO objective.
|
||||
gae_lambda (float): The lambda parameter for the GAE computation.
|
||||
gradient_clip (float): The gradient clipping value.
|
||||
learning_rate (float): The learning rate for the actor and critic networks.
|
||||
schedule (str): The learning rate schedule. Can be "fixed" or "adaptive". Defaults to "fixed".
|
||||
target_kl (float): The target KL-divergence for the adaptive learning rate schedule.
|
||||
value_coeff (float): The coefficient for the value function loss in the PPO objective.
|
||||
"""
|
||||
kwargs["batch_size"] = env.num_envs
|
||||
kwargs["return_steps"] = 1
|
||||
|
||||
self.desired_kl = desired_kl
|
||||
self.schedule = schedule
|
||||
self.learning_rate = learning_rate
|
||||
super().__init__(env, **kwargs)
|
||||
self._critic_input_size = self._critic_input_size - self._action_size # We use a state-value function (not Q)
|
||||
|
||||
# PPO components
|
||||
self.actor_critic = actor_critic
|
||||
self.actor_critic.to(self.device)
|
||||
self.storage = None # initialized later
|
||||
self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
|
||||
self.transition = RolloutStorage.Transition()
|
||||
self.storage = RolloutStorage(self.env.num_envs, device=self.device)
|
||||
|
||||
# PPO parameters
|
||||
self.clip_param = clip_param
|
||||
self.num_learning_epochs = num_learning_epochs
|
||||
self.num_mini_batches = num_mini_batches
|
||||
self.value_loss_coef = value_loss_coef
|
||||
self.entropy_coef = entropy_coef
|
||||
self.gamma = gamma
|
||||
self.lam = lam
|
||||
self.max_grad_norm = max_grad_norm
|
||||
self.use_clipped_value_loss = use_clipped_value_loss
|
||||
self._register_serializable("storage")
|
||||
|
||||
def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape):
|
||||
self.storage = RolloutStorage(
|
||||
num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, self.device
|
||||
self._clip_ratio = clip_ratio
|
||||
self._entropy_coeff = entropy_coeff
|
||||
self._gae_lambda = gae_lambda
|
||||
self._gradient_clip = gradient_clip
|
||||
self._schedule = schedule
|
||||
self._target_kl = target_kl
|
||||
self._value_coeff = value_coeff
|
||||
|
||||
self._register_serializable(
|
||||
"_clip_ratio",
|
||||
"_entropy_coeff",
|
||||
"_gae_lambda",
|
||||
"_gradient_clip",
|
||||
"_schedule",
|
||||
"_target_kl",
|
||||
"_value_coeff",
|
||||
)
|
||||
|
||||
def test_mode(self):
|
||||
self.actor_critic.test()
|
||||
self.actor = GaussianNetwork(
|
||||
self._actor_input_size, self._action_size, std_init=actor_noise_std, **self._actor_network_kwargs
|
||||
)
|
||||
self.critic = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
|
||||
|
||||
def train_mode(self):
|
||||
self.actor_critic.train()
|
||||
if self.recurrent:
|
||||
self.actor.reset_full_hidden_state(batch_size=self.env.num_envs)
|
||||
self.critic.reset_full_hidden_state(batch_size=self.env.num_envs)
|
||||
|
||||
def act(self, obs, critic_obs):
|
||||
if self.actor_critic.is_recurrent:
|
||||
self.transition.hidden_states = self.actor_critic.get_hidden_states()
|
||||
# Compute the actions and values
|
||||
self.transition.actions = self.actor_critic.act(obs).detach()
|
||||
self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
|
||||
self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
|
||||
self.transition.action_mean = self.actor_critic.action_mean.detach()
|
||||
self.transition.action_sigma = self.actor_critic.action_std.detach()
|
||||
# need to record obs and critic_obs before env.step()
|
||||
self.transition.observations = obs
|
||||
self.transition.critic_observations = critic_obs
|
||||
return self.transition.actions
|
||||
tp = lambda v: v.transpose(0, 1) # for storing, transpose (num_layers, batch, F) to (batch, num_layers, F)
|
||||
self.storage.register_processor("actor_state_h", tp)
|
||||
self.storage.register_processor("actor_state_c", tp)
|
||||
self.storage.register_processor("critic_state_h", tp)
|
||||
self.storage.register_processor("critic_state_c", tp)
|
||||
self.storage.register_processor("critic_next_state_h", tp)
|
||||
self.storage.register_processor("critic_next_state_c", tp)
|
||||
|
||||
def process_env_step(self, rewards, dones, infos):
|
||||
self.transition.rewards = rewards.clone()
|
||||
self.transition.dones = dones
|
||||
# Bootstrapping on time outs
|
||||
if "time_outs" in infos:
|
||||
self.transition.rewards += self.gamma * torch.squeeze(
|
||||
self.transition.values * infos["time_outs"].unsqueeze(1).to(self.device), 1
|
||||
)
|
||||
self._bm_fuse(self.actor, prefix="actor.")
|
||||
self._bm_fuse(self.critic, prefix="critic.")
|
||||
|
||||
# Record the transition
|
||||
self.storage.add_transitions(self.transition)
|
||||
self.transition.clear()
|
||||
self.actor_critic.reset(dones)
|
||||
self._register_serializable("actor", "critic")
|
||||
|
||||
def compute_returns(self, last_critic_obs):
|
||||
last_values = self.actor_critic.evaluate(last_critic_obs).detach()
|
||||
self.storage.compute_returns(last_values, self.gamma, self.lam)
|
||||
self.learning_rate = learning_rate
|
||||
self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
|
||||
|
||||
def update(self):
|
||||
mean_value_loss = 0
|
||||
mean_surrogate_loss = 0
|
||||
if self.actor_critic.is_recurrent:
|
||||
generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
|
||||
else:
|
||||
generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
|
||||
for (
|
||||
obs_batch,
|
||||
critic_obs_batch,
|
||||
actions_batch,
|
||||
target_values_batch,
|
||||
advantages_batch,
|
||||
returns_batch,
|
||||
old_actions_log_prob_batch,
|
||||
old_mu_batch,
|
||||
old_sigma_batch,
|
||||
hid_states_batch,
|
||||
masks_batch,
|
||||
) in generator:
|
||||
self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
|
||||
actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
|
||||
value_batch = self.actor_critic.evaluate(
|
||||
critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1]
|
||||
)
|
||||
mu_batch = self.actor_critic.action_mean
|
||||
sigma_batch = self.actor_critic.action_std
|
||||
entropy_batch = self.actor_critic.entropy
|
||||
self._register_serializable("learning_rate", "optimizer")
|
||||
|
||||
# KL
|
||||
if self.desired_kl is not None and self.schedule == "adaptive":
|
||||
with torch.inference_mode():
|
||||
kl = torch.sum(
|
||||
torch.log(sigma_batch / old_sigma_batch + 1.0e-5)
|
||||
+ (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch))
|
||||
/ (2.0 * torch.square(sigma_batch))
|
||||
- 0.5,
|
||||
axis=-1,
|
||||
)
|
||||
kl_mean = torch.mean(kl)
|
||||
def draw_random_actions(self, obs: torch.Tensor, env_info: Dict[str, Any]) -> torch.Tensor:
|
||||
raise NotImplementedError("PPO does not support drawing random actions.")
|
||||
|
||||
if kl_mean > self.desired_kl * 2.0:
|
||||
self.learning_rate = max(1e-5, self.learning_rate / 1.5)
|
||||
elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
|
||||
self.learning_rate = min(1e-2, self.learning_rate * 1.5)
|
||||
@Benchmarkable.register
|
||||
def draw_actions(
|
||||
self, obs: torch.Tensor, env_info: Dict[str, Any]
|
||||
) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
|
||||
actor_obs, critic_obs = self._process_observations(obs, env_info)
|
||||
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group["lr"] = self.learning_rate
|
||||
data = {}
|
||||
|
||||
# Surrogate loss
|
||||
ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
|
||||
surrogate = -torch.squeeze(advantages_batch) * ratio
|
||||
surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(
|
||||
ratio, 1.0 - self.clip_param, 1.0 + self.clip_param
|
||||
)
|
||||
surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()
|
||||
if self.recurrent:
|
||||
data["actor_state_h"] = self.actor.hidden_state[0].detach()
|
||||
data["actor_state_c"] = self.actor.hidden_state[1].detach()
|
||||
|
||||
# Value function loss
|
||||
if self.use_clipped_value_loss:
|
||||
value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(
|
||||
-self.clip_param, self.clip_param
|
||||
)
|
||||
value_losses = (value_batch - returns_batch).pow(2)
|
||||
value_losses_clipped = (value_clipped - returns_batch).pow(2)
|
||||
value_loss = torch.max(value_losses, value_losses_clipped).mean()
|
||||
mean, std = self.actor.forward(actor_obs, compute_std=True)
|
||||
action_distribution = torch.distributions.Normal(mean, std)
|
||||
actions = self._process_actions(action_distribution.rsample()).detach()
|
||||
action_prediction_logp = action_distribution.log_prob(actions).sum(-1)
|
||||
|
||||
data["actor_observations"] = actor_obs
|
||||
data["critic_observations"] = critic_obs
|
||||
data["actions_logp"] = action_prediction_logp.detach()
|
||||
data["actions_mean"] = action_distribution.mean.detach()
|
||||
data["actions_std"] = action_distribution.stddev.detach()
|
||||
|
||||
return actions, data
|
||||
|
||||
def eval_mode(self) -> AbstractActorCritic:
|
||||
super().eval_mode()
|
||||
|
||||
self.actor.eval()
|
||||
self.critic.eval()
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def initialized(self) -> bool:
|
||||
return True
|
||||
|
||||
@Benchmarkable.register
|
||||
def process_transition(self, *args) -> Dict[str, torch.Tensor]:
|
||||
transition = super().process_transition(*args)
|
||||
|
||||
if self.recurrent:
|
||||
transition["critic_state_h"] = self.critic.hidden_state[0].detach()
|
||||
transition["critic_state_c"] = self.critic.hidden_state[1].detach()
|
||||
|
||||
transition["values"] = self.critic.forward(transition["critic_observations"]).detach()
|
||||
|
||||
if self.recurrent:
|
||||
transition["critic_next_state_h"] = self.critic.hidden_state[0].detach()
|
||||
transition["critic_next_state_c"] = self.critic.hidden_state[1].detach()
|
||||
|
||||
return transition
|
||||
|
||||
def parameters(self):
|
||||
params = list(self.actor.parameters()) + list(self.critic.parameters())
|
||||
|
||||
return params
|
||||
|
||||
def register_terminations(self, terminations: torch.Tensor) -> None:
|
||||
"""Registers terminations with the agent.
|
||||
|
||||
Args:
|
||||
terminations (torch.Tensor): A 1-dimensional int tensor containing the indices of the terminated
|
||||
environments.
|
||||
"""
|
||||
if terminations.shape[0] == 0:
|
||||
return
|
||||
|
||||
if self.recurrent:
|
||||
self.actor.reset_hidden_state(terminations)
|
||||
self.critic.reset_hidden_state(terminations)
|
||||
|
||||
def to(self, device: str) -> AbstractActorCritic:
|
||||
super().to(device)
|
||||
|
||||
self.actor.to(device)
|
||||
self.critic.to(device)
|
||||
|
||||
return self
|
||||
|
||||
def train_mode(self) -> AbstractActorCritic:
|
||||
super().train_mode()
|
||||
|
||||
self.actor.train()
|
||||
self.critic.train()
|
||||
|
||||
return self
|
||||
|
||||
@Benchmarkable.register
|
||||
def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
|
||||
super().update(dataset)
|
||||
|
||||
assert self.storage.initialized
|
||||
|
||||
total_loss = torch.zeros(self._batch_count)
|
||||
total_surrogate_loss = torch.zeros(self._batch_count)
|
||||
total_value_loss = torch.zeros(self._batch_count)
|
||||
|
||||
for idx, batch in enumerate(self.storage.batch_generator(self._batch_count, trajectories=self.recurrent)):
|
||||
if self.recurrent:
|
||||
transition_obs = batch["actor_observations"].reshape(*batch["actor_observations"].shape[:2], -1)
|
||||
observations, data = transitions_to_trajectories(transition_obs, batch["dones"])
|
||||
hidden_state_h, _ = transitions_to_trajectories(batch["actor_state_h"], batch["dones"])
|
||||
hidden_state_c, _ = transitions_to_trajectories(batch["actor_state_c"], batch["dones"])
|
||||
# Init. sequence with each trajectory's first hidden state. Subsequent hidden states are produced by the
|
||||
# network, depending on the previous hidden state and the current observation.
|
||||
hidden_state = (hidden_state_h[0].transpose(0, 1), hidden_state_c[0].transpose(0, 1))
|
||||
|
||||
action_mean, action_std = self.actor.forward(observations, hidden_state=hidden_state, compute_std=True)
|
||||
|
||||
action_mean = action_mean.reshape(*observations.shape[:-1], self._action_size)
|
||||
action_std = action_std.reshape(*observations.shape[:-1], self._action_size)
|
||||
|
||||
action_mean = trajectories_to_transitions(action_mean, data)
|
||||
action_std = trajectories_to_transitions(action_std, data)
|
||||
else:
|
||||
value_loss = (returns_batch - value_batch).pow(2).mean()
|
||||
action_mean, action_std = self.actor.forward(batch["actor_observations"], compute_std=True)
|
||||
|
||||
loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()
|
||||
actions_dist = torch.distributions.Normal(action_mean, action_std)
|
||||
|
||||
if self._schedule == self.schedule_adaptive:
|
||||
self._update_learning_rate(batch, actions_dist)
|
||||
|
||||
surrogate_loss = self._compute_actor_loss(batch, actions_dist)
|
||||
value_loss = self._compute_value_loss(batch)
|
||||
actions_entropy = actions_dist.entropy().sum(-1)
|
||||
|
||||
loss = surrogate_loss + self._value_coeff * value_loss - self._entropy_coeff * actions_entropy.mean()
|
||||
|
||||
# Gradient step
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
|
||||
nn.utils.clip_grad_norm_(self.parameters(), self._gradient_clip)
|
||||
self.optimizer.step()
|
||||
|
||||
mean_value_loss += value_loss.item()
|
||||
mean_surrogate_loss += surrogate_loss.item()
|
||||
total_loss[idx] = loss.detach()
|
||||
total_surrogate_loss[idx] = surrogate_loss.detach()
|
||||
total_value_loss[idx] = value_loss.detach()
|
||||
|
||||
num_updates = self.num_learning_epochs * self.num_mini_batches
|
||||
mean_value_loss /= num_updates
|
||||
mean_surrogate_loss /= num_updates
|
||||
self.storage.clear()
|
||||
stats = {
|
||||
"total": total_loss.mean().item(),
|
||||
"surrogate": total_surrogate_loss.mean().item(),
|
||||
"value": total_value_loss.mean().item(),
|
||||
}
|
||||
|
||||
return mean_value_loss, mean_surrogate_loss
|
||||
return stats
|
||||
|
||||
@Benchmarkable.register
|
||||
def _compute_actor_loss(
|
||||
self, batch: Dict[str, torch.Tensor], actions_dist: torch.distributions.Normal
|
||||
) -> torch.Tensor:
|
||||
batch_actions_logp = actions_dist.log_prob(batch["actions"]).sum(-1)
|
||||
|
||||
ratio = (batch_actions_logp - batch["actions_logp"]).exp()
|
||||
surrogate = batch["normalized_advantages"] * ratio
|
||||
surrogate_clipped = batch["normalized_advantages"] * ratio.clamp(1.0 - self._clip_ratio, 1.0 + self._clip_ratio)
|
||||
surrogate_loss = -torch.min(surrogate, surrogate_clipped).mean()
|
||||
|
||||
return surrogate_loss
|
||||
|
||||
@Benchmarkable.register
|
||||
def _compute_value_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
if self.recurrent:
|
||||
observations, data = transitions_to_trajectories(batch["critic_observations"], batch["dones"])
|
||||
hidden_state_h, _ = transitions_to_trajectories(batch["critic_state_h"], batch["dones"])
|
||||
hidden_state_c, _ = transitions_to_trajectories(batch["critic_state_c"], batch["dones"])
|
||||
hidden_states = (hidden_state_h[0].transpose(0, 1), hidden_state_c[0].transpose(0, 1))
|
||||
|
||||
trajectory_evaluations = self.critic.forward(observations, hidden_state=hidden_states)
|
||||
trajectory_evaluations = trajectory_evaluations.reshape(*observations.shape[:-1])
|
||||
|
||||
evaluation = trajectories_to_transitions(trajectory_evaluations, data)
|
||||
else:
|
||||
evaluation = self.critic.forward(batch["critic_observations"])
|
||||
|
||||
value_clipped = batch["values"] + (evaluation - batch["values"]).clamp(-self._clip_ratio, self._clip_ratio)
|
||||
returns = batch["advantages"] + batch["values"]
|
||||
value_losses = (evaluation - returns).pow(2)
|
||||
value_losses_clipped = (value_clipped - returns).pow(2)
|
||||
|
||||
value_loss = torch.max(value_losses, value_losses_clipped).mean()
|
||||
|
||||
return value_loss
|
||||
|
||||
def _critic_input(self, observations, actions=None) -> torch.Tensor:
|
||||
return observations
|
||||
|
||||
def _entry_to_hs(
|
||||
self, entry: Dict[str, torch.Tensor], critic: bool = False, next: bool = False
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Helper function to turn a dataset entry into a hidden state tuple.
|
||||
|
||||
Args:
|
||||
entry (Dict[str, torch.Tensor]): The dataset entry.
|
||||
critic (bool): Whether to extract the hidden state for the critic instead of the actor. Defaults to False.
|
||||
next (bool): Whether the hidden state is for the next step or the current. Defaults to False.
|
||||
Returns:
|
||||
A tuple of hidden state tensors.
|
||||
"""
|
||||
key = ("critic" if critic else "actor") + "_" + ("next_state" if next else "state")
|
||||
hidden_state = entry[f"{key}_h"], entry[f"{key}_c"]
|
||||
|
||||
return hidden_state
|
||||
|
||||
@Benchmarkable.register
|
||||
def _process_dataset(self, dataset: Dataset) -> Dataset:
|
||||
"""Processes a dataset before it is added to the replay memory.
|
||||
|
||||
Computes advantages and returns.
|
||||
|
||||
Args:
|
||||
dataset (Dataset): The dataset to process.
|
||||
Returns:
|
||||
A Dataset object containing the processed data.
|
||||
"""
|
||||
rewards = torch.stack([entry["rewards"] for entry in dataset])
|
||||
dones = torch.stack([entry["dones"] for entry in dataset])
|
||||
timeouts = torch.stack([entry["timeouts"] for entry in dataset])
|
||||
values = torch.stack([entry["values"] for entry in dataset])
|
||||
|
||||
# We could alternatively compute the next hidden state from the current state and hidden state. But this
|
||||
# (storing the hidden state when evaluating the action in process_transition) is computationally more efficient
|
||||
# and doesn't change the result as the network is not updated between storing the data and computing advantages.
|
||||
|
||||
critic_kwargs = (
|
||||
{"hidden_state": (dataset[-1]["critic_state_h"], dataset[-1]["critic_state_c"])} if self.recurrent else {}
|
||||
)
|
||||
final_values = self.critic.forward(dataset[-1]["next_critic_observations"], **critic_kwargs)
|
||||
next_values = torch.cat((values[1:], final_values.unsqueeze(0)), dim=0)
|
||||
|
||||
rewards += self.gamma * timeouts * values
|
||||
deltas = (rewards + (1 - dones).float() * self.gamma * next_values - values).reshape(-1, self.env.num_envs)
|
||||
|
||||
advantages = torch.zeros((len(dataset) + 1, self.env.num_envs), device=self.device)
|
||||
for step in reversed(range(len(dataset))):
|
||||
advantages[step] = (
|
||||
deltas[step] + (1 - dones[step]).float() * self.gamma * self._gae_lambda * advantages[step + 1]
|
||||
)
|
||||
advantages = advantages[:-1]
|
||||
|
||||
amean, astd = advantages.mean(), torch.nan_to_num(advantages.std())
|
||||
for step in range(len(dataset)):
|
||||
dataset[step]["advantages"] = advantages[step]
|
||||
dataset[step]["normalized_advantages"] = (advantages[step] - amean) / (astd + 1e-8)
|
||||
|
||||
return dataset
|
||||
|
||||
@Benchmarkable.register
|
||||
def _update_learning_rate(self, batch: Dict[str, torch.Tensor], actions_dist: torch.distributions.Normal) -> None:
|
||||
with torch.inference_mode():
|
||||
actions_mean = actions_dist.mean
|
||||
actions_std = actions_dist.stddev
|
||||
|
||||
kl = torch.sum(
|
||||
torch.log(actions_std / batch["actions_std"] + 1.0e-5)
|
||||
+ (torch.square(batch["actions_std"]) + torch.square(batch["actions_mean"] - actions_mean))
|
||||
/ (2.0 * torch.square(actions_std))
|
||||
- 0.5,
|
||||
axis=-1,
|
||||
)
|
||||
kl_mean = torch.mean(kl)
|
||||
|
||||
if kl_mean > self._target_kl * 2.0:
|
||||
self.learning_rate = max(1e-5, self.learning_rate / 1.5)
|
||||
elif kl_mean < self._target_kl / 2.0 and kl_mean > 0.0:
|
||||
self.learning_rate = min(1e-2, self.learning_rate * 1.5)
|
||||
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group["lr"] = self.learning_rate
|
||||
|
|
|
@ -0,0 +1,319 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch import nn, optim
|
||||
from typing import Any, Callable, Dict, Tuple, Type, Union
|
||||
|
||||
from rsl_rl.algorithms.actor_critic import AbstractActorCritic
|
||||
from rsl_rl.env import VecEnv
|
||||
from rsl_rl.modules import Network, GaussianChimeraNetwork, GaussianNetwork
|
||||
from rsl_rl.storage.replay_storage import ReplayStorage
|
||||
from rsl_rl.storage.storage import Dataset
|
||||
|
||||
|
||||
class SAC(AbstractActorCritic):
|
||||
"""Soft Actor Critic algorithm.
|
||||
|
||||
This is an implementation of the SAC algorithm by Haarnoja et. al. for vectorized environments.
|
||||
|
||||
Paper: https://arxiv.org/pdf/1801.01290.pdf
|
||||
|
||||
We additionally implement automatic tuning of the temperature parameter (alpha) and tanh action scaling, as
|
||||
introduced by Haarnoja et. al. in https://arxiv.org/pdf/1812.05905.pdf.
|
||||
"""
|
||||
|
||||
critic_network: Type[nn.Module] = Network
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: VecEnv,
|
||||
action_max: float = 100.0,
|
||||
action_min: float = -100.0,
|
||||
actor_lr: float = 1e-4,
|
||||
actor_noise_std: float = 1.0,
|
||||
alpha: float = 0.2,
|
||||
alpha_lr: float = 1e-3,
|
||||
chimera: bool = True,
|
||||
critic_lr: float = 1e-3,
|
||||
gradient_clip: float = 1.0,
|
||||
log_std_max: float = 4.0,
|
||||
log_std_min: float = -20.0,
|
||||
storage_initial_size: int = 0,
|
||||
storage_size: int = 100000,
|
||||
target_entropy: float = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
env (VecEnv): A vectorized environment.
|
||||
actor_lr (float): Learning rate for the actor.
|
||||
alpha (float): Initial entropy regularization coefficient.
|
||||
alpha_lr (float): Learning rate for entropy regularization coefficient.
|
||||
chimera (bool): Whether to use separate heads for computing action mean and std (True) or treat the std as a
|
||||
tunable parameter (True).
|
||||
critic_lr (float): Learning rate for the critic.
|
||||
gradient_clip (float): Gradient clip value.
|
||||
log_std_max (float): Maximum log standard deviation.
|
||||
log_std_min (float): Minimum log standard deviation.
|
||||
storage_initial_size (int): Initial size of the replay storage.
|
||||
storage_size (int): Maximum size of the replay storage.
|
||||
target_entropy (float): Target entropy for the actor policy. Defaults to action space dimensionality.
|
||||
"""
|
||||
super().__init__(env, action_max=action_max, action_min=action_min, **kwargs)
|
||||
|
||||
self.storage = ReplayStorage(
|
||||
self.env.num_envs, storage_size, device=self.device, initial_size=storage_initial_size
|
||||
)
|
||||
|
||||
self._register_serializable("storage")
|
||||
|
||||
assert self._action_max < np.inf, 'Parameter "action_max" needs to be set for SAC.'
|
||||
assert self._action_min > -np.inf, 'Parameter "action_min" needs to be set for SAC.'
|
||||
|
||||
self._action_delta = 0.5 * (self._action_max - self._action_min)
|
||||
self._action_offset = 0.5 * (self._action_max + self._action_min)
|
||||
|
||||
self.log_alpha = torch.tensor(np.log(alpha), dtype=torch.float32).requires_grad_()
|
||||
self._gradient_clip = gradient_clip
|
||||
self._target_entropy = target_entropy if target_entropy else -self._action_size
|
||||
|
||||
self._register_serializable("log_alpha", "_gradient_clip")
|
||||
|
||||
network_class = GaussianChimeraNetwork if chimera else GaussianNetwork
|
||||
self.actor = network_class(
|
||||
self._actor_input_size,
|
||||
self._action_size,
|
||||
log_std_max=log_std_max,
|
||||
log_std_min=log_std_min,
|
||||
std_init=actor_noise_std,
|
||||
**self._actor_network_kwargs
|
||||
)
|
||||
|
||||
self.critic_1 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
|
||||
self.critic_2 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
|
||||
|
||||
self.target_critic_1 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
|
||||
self.target_critic_2 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
|
||||
self.target_critic_1.load_state_dict(self.critic_1.state_dict())
|
||||
self.target_critic_2.load_state_dict(self.critic_2.state_dict())
|
||||
|
||||
self._register_serializable("actor", "critic_1", "critic_2", "target_critic_1", "target_critic_2")
|
||||
|
||||
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
|
||||
self.log_alpha_optimizer = optim.Adam([self.log_alpha], lr=alpha_lr)
|
||||
self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(), lr=critic_lr)
|
||||
self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=critic_lr)
|
||||
|
||||
self._register_serializable(
|
||||
"actor_optimizer", "log_alpha_optimizer", "critic_1_optimizer", "critic_2_optimizer"
|
||||
)
|
||||
|
||||
self.critic = self.critic_1
|
||||
|
||||
@property
|
||||
def alpha(self):
|
||||
return self.log_alpha.exp()
|
||||
|
||||
def draw_actions(
|
||||
self, obs: torch.Tensor, env_info: Dict[str, Any]
|
||||
) -> Tuple[torch.Tensor, Union[Dict[str, torch.Tensor], None]]:
|
||||
actor_obs, critic_obs = self._process_observations(obs, env_info)
|
||||
|
||||
action = self._sample_action(actor_obs, compute_logp=False)
|
||||
data = {"actor_observations": actor_obs.clone(), "critic_observations": critic_obs.clone()}
|
||||
|
||||
return action, data
|
||||
|
||||
def eval_mode(self) -> AbstractActorCritic:
|
||||
super().eval_mode()
|
||||
|
||||
self.actor.eval()
|
||||
self.critic_1.eval()
|
||||
self.critic_2.eval()
|
||||
self.target_critic_1.eval()
|
||||
self.target_critic_2.eval()
|
||||
|
||||
return self
|
||||
|
||||
def get_inference_policy(self, device=None) -> Callable:
|
||||
self.to(device)
|
||||
self.eval_mode()
|
||||
|
||||
def policy(obs, env_info=None):
|
||||
obs, _ = self._process_observations(obs, env_info)
|
||||
actions = self._scale_actions(self.actor.forward(obs))
|
||||
|
||||
# actions, _ = self.draw_actions(obs, env_info)
|
||||
|
||||
return actions
|
||||
|
||||
return policy
|
||||
|
||||
def register_terminations(self, terminations: torch.Tensor) -> None:
|
||||
pass
|
||||
|
||||
def to(self, device: str) -> AbstractActorCritic:
|
||||
"""Transfers agent parameters to device."""
|
||||
super().to(device)
|
||||
|
||||
self.actor.to(device)
|
||||
self.critic_1.to(device)
|
||||
self.critic_2.to(device)
|
||||
self.target_critic_1.to(device)
|
||||
self.target_critic_2.to(device)
|
||||
self.log_alpha.to(device)
|
||||
|
||||
return self
|
||||
|
||||
def train_mode(self) -> AbstractActorCritic:
|
||||
super().train_mode()
|
||||
|
||||
self.actor.train()
|
||||
self.critic_1.train()
|
||||
self.critic_2.train()
|
||||
self.target_critic_1.train()
|
||||
self.target_critic_2.train()
|
||||
|
||||
return self
|
||||
|
||||
def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
|
||||
super().update(dataset)
|
||||
|
||||
if not self.initialized:
|
||||
return {}
|
||||
|
||||
total_actor_loss = torch.zeros(self._batch_count)
|
||||
total_alpha_loss = torch.zeros(self._batch_count)
|
||||
total_critic_1_loss = torch.zeros(self._batch_count)
|
||||
total_critic_2_loss = torch.zeros(self._batch_count)
|
||||
|
||||
for idx, batch in enumerate(self.storage.batch_generator(self._batch_size, self._batch_count)):
|
||||
actor_obs = batch["actor_observations"]
|
||||
critic_obs = batch["critic_observations"]
|
||||
actions = batch["actions"].reshape(-1, self._action_size)
|
||||
rewards = batch["rewards"]
|
||||
actor_next_obs = batch["next_actor_observations"]
|
||||
critic_next_obs = batch["next_critic_observations"]
|
||||
dones = batch["dones"]
|
||||
|
||||
critic_1_loss, critic_2_loss = self._update_critic(
|
||||
critic_obs, actions, rewards, dones, actor_next_obs, critic_next_obs
|
||||
)
|
||||
actor_loss, alpha_loss = self._update_actor_and_alpha(actor_obs, critic_obs)
|
||||
|
||||
# Update Target Networks
|
||||
|
||||
self._update_target(self.critic_1, self.target_critic_1)
|
||||
self._update_target(self.critic_2, self.target_critic_2)
|
||||
|
||||
total_actor_loss[idx] = actor_loss.item()
|
||||
total_alpha_loss[idx] = alpha_loss.item()
|
||||
total_critic_1_loss[idx] = critic_1_loss.item()
|
||||
total_critic_2_loss[idx] = critic_2_loss.item()
|
||||
|
||||
stats = {
|
||||
"actor": total_actor_loss.mean().item(),
|
||||
"alpha": total_alpha_loss.mean().item(),
|
||||
"critic1": total_critic_1_loss.mean().item(),
|
||||
"critic2": total_critic_2_loss.mean().item(),
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
def _sample_action(
|
||||
self, observation: torch.Tensor, compute_logp=True
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, float]]:
|
||||
"""Samples and action from the policy.
|
||||
|
||||
Args:
|
||||
observation (torch.Tensor): The observation to sample an action for.
|
||||
compute_logp (bool): Whether to compute and return the action log probability. Default to True.
|
||||
Returns:
|
||||
Either the action as a torch.Tensor or, if compute_logp is set to true, a tuple containing the actions as a
|
||||
torch.Tensor and the action log probability.
|
||||
"""
|
||||
mean, std = self.actor.forward(observation, compute_std=True)
|
||||
dist = torch.distributions.Normal(mean, std)
|
||||
|
||||
actions = dist.rsample()
|
||||
actions_normalized, actions_scaled = self._scale_actions(actions, intermediate=True)
|
||||
|
||||
if not compute_logp:
|
||||
return actions_scaled
|
||||
|
||||
action_logp = dist.log_prob(actions).sum(-1) - torch.log(1.0 - actions_normalized.pow(2) + 1e-6).sum(-1)
|
||||
|
||||
return actions_scaled, action_logp
|
||||
|
||||
def _scale_actions(self, actions: torch.Tensor, intermediate=False) -> torch.Tensor:
|
||||
actions = actions.reshape(-1, self._action_size)
|
||||
action_normalized = torch.tanh(actions)
|
||||
action_scaled = super()._process_actions(action_normalized * self._action_delta + self._action_offset)
|
||||
|
||||
if intermediate:
|
||||
return action_normalized, action_scaled
|
||||
|
||||
return action_scaled
|
||||
|
||||
def _update_actor_and_alpha(
|
||||
self, actor_obs: torch.Tensor, critic_obs: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
actor_prediction, actor_prediction_logp = self._sample_action(actor_obs)
|
||||
|
||||
# Update alpha (also called temperature / entropy coefficient)
|
||||
alpha_loss = -(self.log_alpha * (actor_prediction_logp + self._target_entropy).detach()).mean()
|
||||
|
||||
self.log_alpha_optimizer.zero_grad()
|
||||
alpha_loss.backward()
|
||||
self.log_alpha_optimizer.step()
|
||||
|
||||
# Update actor
|
||||
evaluation_input = self._critic_input(critic_obs, actor_prediction)
|
||||
evaluation_1 = self.critic_1.forward(evaluation_input)
|
||||
evaluation_2 = self.critic_2.forward(evaluation_input)
|
||||
actor_loss = (self.alpha.detach() * actor_prediction_logp - torch.min(evaluation_1, evaluation_2)).mean()
|
||||
|
||||
self.actor_optimizer.zero_grad()
|
||||
actor_loss.backward()
|
||||
nn.utils.clip_grad_norm_(self.actor.parameters(), self._gradient_clip)
|
||||
self.actor_optimizer.step()
|
||||
|
||||
return actor_loss, alpha_loss
|
||||
|
||||
def _update_critic(
|
||||
self,
|
||||
critic_obs: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
rewards: torch.Tensor,
|
||||
dones: torch.Tensor,
|
||||
actor_next_obs: torch.Tensor,
|
||||
critic_next_obs: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
with torch.no_grad():
|
||||
target_action, target_action_logp = self._sample_action(actor_next_obs)
|
||||
target_critic_input = self._critic_input(critic_next_obs, target_action)
|
||||
target_critic_prediction_1 = self.target_critic_1.forward(target_critic_input)
|
||||
target_critic_prediction_2 = self.target_critic_2.forward(target_critic_input)
|
||||
|
||||
target_next = (
|
||||
torch.min(target_critic_prediction_1, target_critic_prediction_2) - self.alpha * target_action_logp
|
||||
)
|
||||
target = rewards + self._discount_factor * (1 - dones) * target_next
|
||||
|
||||
critic_input = self._critic_input(critic_obs, actions)
|
||||
critic_prediction_1 = self.critic_1.forward(critic_input)
|
||||
critic_1_loss = nn.functional.mse_loss(critic_prediction_1, target)
|
||||
|
||||
self.critic_1_optimizer.zero_grad()
|
||||
critic_1_loss.backward()
|
||||
nn.utils.clip_grad_norm_(self.critic_1.parameters(), self._gradient_clip)
|
||||
self.critic_1_optimizer.step()
|
||||
|
||||
critic_prediction_2 = self.critic_2.forward(critic_input)
|
||||
critic_2_loss = nn.functional.mse_loss(critic_prediction_2, target)
|
||||
|
||||
self.critic_2_optimizer.zero_grad()
|
||||
critic_2_loss.backward()
|
||||
nn.utils.clip_grad_norm_(self.critic_2.parameters(), self._gradient_clip)
|
||||
self.critic_2_optimizer.step()
|
||||
|
||||
return critic_1_loss, critic_2_loss
|
|
@ -0,0 +1,198 @@
|
|||
from __future__ import annotations
|
||||
import torch
|
||||
from torch import nn, optim
|
||||
from typing import Dict, Type, Union
|
||||
|
||||
from rsl_rl.algorithms.dpg import AbstractDPG
|
||||
from rsl_rl.env import VecEnv
|
||||
from rsl_rl.modules.network import Network
|
||||
from rsl_rl.storage.storage import Dataset
|
||||
|
||||
|
||||
class TD3(AbstractDPG):
|
||||
"""Twin-Delayed Deep Deterministic Policy Gradients algorithm.
|
||||
|
||||
This is an implementation of the TD3 algorithm by Fujimoto et. al. for vectorized environments.
|
||||
|
||||
Paper: https://arxiv.org/pdf/1802.09477.pdf
|
||||
"""
|
||||
|
||||
critic_network: Type[nn.Module] = Network
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: VecEnv,
|
||||
actor_lr: float = 1e-4,
|
||||
critic_lr: float = 1e-3,
|
||||
noise_clip: float = 0.5,
|
||||
policy_delay: int = 2,
|
||||
target_noise_scale: float = 0.2,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(env, **kwargs)
|
||||
|
||||
self._noise_clip = noise_clip
|
||||
self._policy_delay = policy_delay
|
||||
self._target_noise_scale = target_noise_scale
|
||||
|
||||
self._register_serializable("_noise_clip", "_policy_delay", "_target_noise_scale")
|
||||
|
||||
self.actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
|
||||
self.critic_1 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
|
||||
self.critic_2 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
|
||||
|
||||
self.target_actor = Network(self._actor_input_size, self._action_size, **self._actor_network_kwargs)
|
||||
self.target_critic_1 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
|
||||
self.target_critic_2 = self.critic_network(self._critic_input_size, 1, **self._critic_network_kwargs)
|
||||
self.target_actor.load_state_dict(self.actor.state_dict())
|
||||
self.target_critic_1.load_state_dict(self.critic_1.state_dict())
|
||||
self.target_critic_2.load_state_dict(self.critic_2.state_dict())
|
||||
|
||||
self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
|
||||
self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(), lr=critic_lr)
|
||||
self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=critic_lr)
|
||||
|
||||
self._update_step = 0
|
||||
|
||||
self._register_serializable(
|
||||
"actor",
|
||||
"critic_1",
|
||||
"critic_2",
|
||||
"target_actor",
|
||||
"target_critic_1",
|
||||
"target_critic_2",
|
||||
"actor_optimizer",
|
||||
"critic_1_optimizer",
|
||||
"critic_2_optimizer",
|
||||
"_update_step",
|
||||
)
|
||||
|
||||
self.critic = self.critic_1
|
||||
self.to(self.device)
|
||||
|
||||
def eval_mode(self) -> TD3:
|
||||
super().eval_mode()
|
||||
|
||||
self.actor.eval()
|
||||
self.critic_1.eval()
|
||||
self.critic_2.eval()
|
||||
self.target_actor.eval()
|
||||
self.target_critic_1.eval()
|
||||
self.target_critic_2.eval()
|
||||
|
||||
return self
|
||||
|
||||
def to(self, device: str) -> TD3:
|
||||
"""Transfers agent parameters to device."""
|
||||
super().to(device)
|
||||
|
||||
self.actor.to(device)
|
||||
self.critic_1.to(device)
|
||||
self.critic_2.to(device)
|
||||
self.target_actor.to(device)
|
||||
self.target_critic_1.to(device)
|
||||
self.target_critic_2.to(device)
|
||||
|
||||
return self
|
||||
|
||||
def train_mode(self) -> TD3:
|
||||
super().train_mode()
|
||||
|
||||
self.actor.train()
|
||||
self.critic_1.train()
|
||||
self.critic_2.train()
|
||||
self.target_actor.train()
|
||||
self.target_critic_1.train()
|
||||
self.target_critic_2.train()
|
||||
|
||||
return self
|
||||
|
||||
def _apply_action_noise(self, actions: torch.Tensor, clip=False) -> torch.Tensor:
|
||||
noise = torch.normal(torch.zeros_like(actions), torch.ones_like(actions) * self._action_noise_scale)
|
||||
|
||||
if clip:
|
||||
noise = noise.clamp(-self._noise_clip, self._noise_clip)
|
||||
|
||||
noisy_actions = self._process_actions(actions + noise)
|
||||
|
||||
return noisy_actions
|
||||
|
||||
def update(self, dataset: Dataset) -> Dict[str, Union[float, torch.Tensor]]:
|
||||
super().update(dataset)
|
||||
|
||||
if not self.initialized:
|
||||
return {}
|
||||
|
||||
total_actor_loss = torch.zeros(self._batch_count)
|
||||
total_critic_1_loss = torch.zeros(self._batch_count)
|
||||
total_critic_2_loss = torch.zeros(self._batch_count)
|
||||
|
||||
for idx, batch in enumerate(self.storage.batch_generator(self._batch_size, self._batch_count)):
|
||||
actor_obs = batch["actor_observations"]
|
||||
critic_obs = batch["critic_observations"]
|
||||
actions = batch["actions"].reshape(self._batch_size, -1)
|
||||
rewards = batch["rewards"]
|
||||
actor_next_obs = batch["next_actor_observations"]
|
||||
critic_next_obs = batch["next_critic_observations"]
|
||||
dones = batch["dones"]
|
||||
|
||||
critic_1_loss, critic_2_loss = self._update_critic(
|
||||
critic_obs, actions, rewards, dones, actor_next_obs, critic_next_obs
|
||||
)
|
||||
|
||||
if self._update_step % self._policy_delay == 0:
|
||||
evaluation = self.critic_1.forward(
|
||||
self._critic_input(critic_obs, self._process_actions(self.actor.forward(actor_obs)))
|
||||
)
|
||||
actor_loss = -evaluation.mean()
|
||||
|
||||
self.actor_optimizer.zero_grad()
|
||||
actor_loss.backward()
|
||||
self.actor_optimizer.step()
|
||||
|
||||
self._update_target(self.actor, self.target_actor)
|
||||
self._update_target(self.critic_1, self.target_critic_1)
|
||||
self._update_target(self.critic_2, self.target_critic_2)
|
||||
|
||||
total_actor_loss[idx] = actor_loss.item()
|
||||
|
||||
self._update_step = self._update_step + 1
|
||||
|
||||
total_critic_1_loss[idx] = critic_1_loss.item()
|
||||
total_critic_2_loss[idx] = critic_2_loss.item()
|
||||
|
||||
stats = {
|
||||
"actor": total_actor_loss.mean().item(),
|
||||
"critic1": total_critic_1_loss.mean().item(),
|
||||
"critic2": total_critic_2_loss.mean().item(),
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
def _update_critic(self, critic_obs, actions, rewards, dones, actor_next_obs, critic_next_obs):
|
||||
target_actor_prediction = self._apply_action_noise(self.target_actor.forward(actor_next_obs), clip=True)
|
||||
target_critic_1_prediction = self.target_critic_1.forward(
|
||||
self._critic_input(critic_next_obs, target_actor_prediction)
|
||||
)
|
||||
target_critic_2_prediction = self.target_critic_2.forward(
|
||||
self._critic_input(critic_next_obs, target_actor_prediction)
|
||||
)
|
||||
target_critic_prediction = torch.min(target_critic_1_prediction, target_critic_2_prediction)
|
||||
|
||||
target = (rewards + self._discount_factor * (1 - dones) * target_critic_prediction).detach()
|
||||
|
||||
prediction_1 = self.critic_1.forward(self._critic_input(critic_obs, actions))
|
||||
critic_1_loss = (prediction_1 - target).pow(2).mean()
|
||||
|
||||
self.critic_1_optimizer.zero_grad()
|
||||
critic_1_loss.backward()
|
||||
self.critic_1_optimizer.step()
|
||||
|
||||
prediction_2 = self.critic_2.forward(self._critic_input(critic_obs, actions))
|
||||
critic_2_loss = (prediction_2 - target).pow(2).mean()
|
||||
|
||||
self.critic_2_optimizer.zero_grad()
|
||||
critic_2_loss.backward()
|
||||
self.critic_2_optimizer.step()
|
||||
|
||||
return critic_1_loss, critic_2_loss
|
|
@ -0,0 +1,2 @@
|
|||
from .distribution import Distribution
|
||||
from .quantile_distribution import QuantileDistribution
|
|
@ -0,0 +1,18 @@
|
|||
from abc import ABC, abstractmethod
|
||||
import torch
|
||||
|
||||
|
||||
class Distribution(ABC):
|
||||
def __init__(self, params: torch.Tensor) -> None:
|
||||
self._params = params
|
||||
|
||||
@abstractmethod
|
||||
def sample(self, sample_count: int = 1) -> torch.Tensor:
|
||||
"""Sample from the distribution.
|
||||
|
||||
Args:
|
||||
sample_count: The number of samples to draw.
|
||||
Returns:
|
||||
A tensor of shape (sample_count, *parameter_shape).
|
||||
"""
|
||||
pass
|
|
@ -0,0 +1,13 @@
|
|||
import torch
|
||||
|
||||
from .distribution import Distribution
|
||||
|
||||
|
||||
class QuantileDistribution(Distribution):
|
||||
def sample(self, sample_count: int = 1) -> torch.Tensor:
|
||||
idx = torch.randint(
|
||||
self._params.shape[-1], (*self._params.shape[:-1], sample_count), device=self._params.device
|
||||
)
|
||||
samples = torch.take_along_dim(self._params, idx, -1)
|
||||
|
||||
return samples, idx
|
|
@ -1,6 +1,5 @@
|
|||
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
"""Submodule defining the environment definitions."""
|
||||
|
||||
from .vec_env import VecEnv
|
||||
|
|
|
@ -0,0 +1,120 @@
|
|||
from datetime import datetime
|
||||
import gym
|
||||
import torch
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from rsl_rl.env.vec_env import VecEnv
|
||||
|
||||
|
||||
class GymEnv(VecEnv):
|
||||
"""A vectorized environment wrapper for OpenAI Gym environments.
|
||||
|
||||
This class wraps a single OpenAI Gym environment into a vectorized environment. It is assumed that the environment
|
||||
is a single agent environment. The environment is wrapped in a `gym.vector.SyncVectorEnv` environment, which
|
||||
allows for parallel execution of multiple environments.
|
||||
"""
|
||||
|
||||
def __init__(self, name, draw=False, draw_cb=None, draw_directory="videos/", gym_kwargs={}, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
name: The name of the OpenAI Gym environment.
|
||||
draw: Whether to record videos of the environment.
|
||||
draw_cb: A callback function that is called after each episode. The callback function is passed the episode
|
||||
number and the path to the video file. The callback function should return `True` if the video should
|
||||
be recorded and `False` otherwise.
|
||||
draw_directory: The directory in which to store the videos.
|
||||
gym_kwargs: Keyword arguments that are passed to the OpenAI Gym environment.
|
||||
**kwargs: Keyword arguments that are passed to the `VecEnv` constructor.
|
||||
"""
|
||||
self._gym_kwargs = gym_kwargs
|
||||
|
||||
env = gym.make(name, **self._gym_kwargs)
|
||||
|
||||
assert isinstance(env.observation_space, gym.spaces.Box)
|
||||
assert len(env.observation_space.shape) == 1
|
||||
assert isinstance(env.action_space, gym.spaces.Box)
|
||||
assert len(env.action_space.shape) == 1
|
||||
|
||||
super().__init__(env.observation_space.shape[0], env.observation_space.shape[0], **kwargs)
|
||||
|
||||
self.name = name
|
||||
self.draw_directory = draw_directory
|
||||
|
||||
self.num_actions = env.action_space.shape[0]
|
||||
self._gym_venv = gym.vector.SyncVectorEnv(
|
||||
[lambda: gym.make(self.name, **self._gym_kwargs) for _ in range(self.num_envs)]
|
||||
)
|
||||
|
||||
self._draw = False
|
||||
self._draw_cb = draw_cb if draw_cb is not None else lambda *args: True
|
||||
self._draw_uuid = None
|
||||
self.draw = draw
|
||||
|
||||
self.reset()
|
||||
|
||||
def close(self) -> None:
|
||||
self._gym_venv.close()
|
||||
|
||||
def get_observations(self) -> Tuple[torch.Tensor, Dict[str, Any]]:
|
||||
return self.obs_buf, self.extras
|
||||
|
||||
def get_privileged_observations(self) -> Union[torch.Tensor, None]:
|
||||
return self.obs_buf, self.extras
|
||||
|
||||
@property
|
||||
def draw(self) -> bool:
|
||||
return self._draw
|
||||
|
||||
@draw.setter
|
||||
def draw(self, value: bool) -> None:
|
||||
if value != self._draw:
|
||||
if value:
|
||||
self._draw_uuid = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
env = gym.make(self.name, render_mode="rgb_array", **self._gym_kwargs)
|
||||
env = gym.wrappers.RecordVideo(
|
||||
env,
|
||||
f"{self.draw_directory}/{self._draw_uuid}/",
|
||||
episode_trigger=lambda ep: (
|
||||
self._draw_cb(ep - 1, f"{self.draw_directory}/{self._draw_uuid}/rl-video-episode-{ep-1}.mp4")
|
||||
or True
|
||||
)
|
||||
if ep > 0
|
||||
else False,
|
||||
)
|
||||
else:
|
||||
env = gym.make(self.name, render_mode=None, **self._gym_kwargs)
|
||||
|
||||
self._gym_venv.envs[0] = env
|
||||
self._draw = value
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
|
||||
self.obs_buf = torch.from_numpy(self._gym_venv.reset()[0]).float().to(self.device)
|
||||
self.rew_buf = torch.zeros((self.num_envs,), device=self.device).float()
|
||||
self.reset_buf = torch.zeros((self.num_envs,), device=self.device).float()
|
||||
self.extras = {"observations": {}, "time_outs": torch.zeros((self.num_envs,), device=self.device).float()}
|
||||
|
||||
return self.obs_buf, self.extras
|
||||
|
||||
def step(
|
||||
self, actions: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, Dict[str, Any]]:
|
||||
obs, rew, reset, term, _ = self._gym_venv.step(actions.cpu().numpy())
|
||||
|
||||
self.obs_buf = torch.from_numpy(obs).float().to(self.device)
|
||||
self.rew_buf = torch.from_numpy(rew).float().to(self.device)
|
||||
self.reset_buf = torch.from_numpy(reset).float().to(self.device)
|
||||
self.extras = {
|
||||
"observations": {},
|
||||
"time_outs": torch.from_numpy(term).float().to(self.device).float().to(self.device),
|
||||
}
|
||||
|
||||
return self.obs_buf, self.rew_buf, self.reset_buf, self.extras
|
||||
|
||||
def to(self, device: str) -> None:
|
||||
self.device = device
|
||||
|
||||
self.obs_buf = self.obs_buf.to(device)
|
||||
self.rew_buf = self.rew_buf.to(device)
|
||||
self.reset_buf = self.reset_buf.to(device)
|
|
@ -0,0 +1,137 @@
|
|||
import math
|
||||
import numpy as np
|
||||
import time
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from rsl_rl.env.vec_env import VecEnv
|
||||
|
||||
|
||||
class PoleBalancing(VecEnv):
|
||||
"""Custom pole balancing environment.
|
||||
|
||||
This class implements a custom pole balancing environment. It demonstrates how to implement a custom `VecEnv`
|
||||
environment.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
**kwargs: Keyword arguments that are passed to the `VecEnv` constructor.
|
||||
"""
|
||||
super().__init__(2, 2, **kwargs)
|
||||
|
||||
self.num_actions = 1
|
||||
|
||||
self.gravity = 9.8
|
||||
self.length = 2.0
|
||||
self.dt = 0.1
|
||||
|
||||
# Angle at which to fail the episode (15 deg)
|
||||
self.theta_threshold_radians = 15 * 2 * math.pi / 360
|
||||
|
||||
# Max. angle at which to initialize the episode (5 deg)
|
||||
self.initial_max_position = 2 * 2 * math.pi / 360
|
||||
# Max. angular velocity at which to initialize the episode (1 deg/s)
|
||||
self.initial_max_velocity = 0.3 * 2 * math.pi / 360
|
||||
|
||||
self.initial_position_factor = self.initial_max_position / 0.5
|
||||
self.initial_velocity_factor = self.initial_max_velocity / 0.5
|
||||
|
||||
self.draw = False
|
||||
self.pushes = False
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_observations(self) -> Tuple[torch.Tensor, Dict[str, Any]]:
|
||||
return self.obs_buf, self.extras
|
||||
|
||||
def get_privileged_observations(self) -> Union[torch.Tensor, None]:
|
||||
return self.obs_buf, self.extras
|
||||
|
||||
def step(
|
||||
self, actions: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, Dict[str, Any]]:
|
||||
assert actions.size() == (self.num_envs, 1)
|
||||
|
||||
self.to(self.device)
|
||||
actions = actions.to(self.device)
|
||||
|
||||
noise = torch.normal(torch.zeros_like(actions), torch.ones_like(actions) * 0.005).squeeze()
|
||||
if self.pushes and np.random.rand() < 0.05:
|
||||
noise *= 100.0
|
||||
actions = actions.clamp(min=-0.2, max=0.2).squeeze()
|
||||
gravity = torch.sin(self.state[:, 0]) * self.gravity / self.length
|
||||
angular_acceleration = gravity + actions + noise
|
||||
|
||||
self.state[:, 1] = self.state[:, 1] + self.dt * angular_acceleration
|
||||
self.state[:, 0] = self.state[:, 0] + self.dt * self.state[:, 1]
|
||||
|
||||
self.reset_buf = torch.zeros(self.num_envs)
|
||||
self.reset_buf[(torch.abs(self.state[:, 0]) > self.theta_threshold_radians).nonzero()] = 1.0
|
||||
reset_idx = self.reset_buf.nonzero()
|
||||
|
||||
self.state[reset_idx, 0] = (
|
||||
torch.rand(reset_idx.size()[0], 1, device=self.device) - 0.5
|
||||
) * self.initial_position_factor
|
||||
self.state[reset_idx, 1] = (
|
||||
torch.rand(reset_idx.size()[0], 1, device=self.device) - 0.5
|
||||
) * self.initial_velocity_factor
|
||||
|
||||
self.rew_buf = torch.ones(self.num_envs, device=self.device)
|
||||
self.rew_buf[reset_idx] = -1.0
|
||||
self.rew_buf = self.rew_buf - actions.abs()
|
||||
self.rew_buf = self.rew_buf - self.state[:, 0].abs()
|
||||
|
||||
self._update_obs()
|
||||
|
||||
if self.draw:
|
||||
self._debug_draw(actions)
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
return self.obs_buf, self.rew_buf, self.reset_buf, self.extras
|
||||
|
||||
def _update_obs(self):
|
||||
self.obs_buf = self.state
|
||||
self.extras = {"observations": {}, "time_outs": torch.zeros_like(self.rew_buf)}
|
||||
|
||||
def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
|
||||
self.state = torch.zeros(self.num_envs, 2, device=self.device)
|
||||
self.state[:, 0] = (torch.rand(self.num_envs) - 0.5) * self.initial_position_factor
|
||||
self.state[:, 1] = (torch.rand(self.num_envs) - 0.5) * self.initial_velocity_factor
|
||||
self.rew_buf = torch.zeros(self.num_envs)
|
||||
self.reset_buf = torch.zeros(self.num_envs)
|
||||
self.extras = {}
|
||||
|
||||
self._update_obs()
|
||||
|
||||
return self.obs_buf, self.extras
|
||||
|
||||
def to(self, device):
|
||||
self.device = device
|
||||
|
||||
self.obs_buf = self.obs_buf.to(device)
|
||||
self.rew_buf = self.rew_buf.to(device)
|
||||
self.reset_buf = self.reset_buf.to(device)
|
||||
self.state = self.state.to(device)
|
||||
|
||||
def _debug_draw(self, actions):
|
||||
if not hasattr(self, "_visuals"):
|
||||
self._visuals = {"x": [0], "pos": [], "act": [], "done": []}
|
||||
plt.gca().figure.show()
|
||||
else:
|
||||
self._visuals["x"].append(self._visuals["x"][-1] + 1)
|
||||
|
||||
self._visuals["pos"].append(self.obs_buf[0, 0].cpu().item())
|
||||
self._visuals["done"].append(self.reset_buf[0].cpu().item())
|
||||
self._visuals["act"].append(actions.squeeze()[0].cpu().item())
|
||||
|
||||
plt.cla()
|
||||
plt.plot(self._visuals["x"][-100:], self._visuals["act"][-100:], color="green")
|
||||
plt.plot(self._visuals["x"][-100:], self._visuals["pos"][-100:], color="blue")
|
||||
plt.plot(self._visuals["x"][-100:], self._visuals["done"][-100:], color="red")
|
||||
plt.draw()
|
||||
plt.gca().figure.canvas.flush_events()
|
||||
time.sleep(0.0001)
|
|
@ -0,0 +1,97 @@
|
|||
import torch
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from rsl_rl.env.gym_env import GymEnv
|
||||
|
||||
|
||||
class GymPOMDP(GymEnv):
|
||||
"""A vectorized POMDP environment wrapper for OpenAI Gym environments.
|
||||
|
||||
This environment allows for the modification of the observation space of an OpenAI Gym environment. The modified
|
||||
observation space is a subset of the original observation space.
|
||||
"""
|
||||
|
||||
_reduced_observation_count: int = None
|
||||
|
||||
def __init__(self, name: str, **kwargs):
|
||||
assert self._reduced_observation_count is not None
|
||||
|
||||
super().__init__(name=name, **kwargs)
|
||||
|
||||
self.num_obs = self._reduced_observation_count
|
||||
self.num_privileged_obs = self._reduced_observation_count
|
||||
|
||||
def _process_observations(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Reduces observation space from original observation space to modified observation space.
|
||||
|
||||
Args:
|
||||
obs (torch.Tensor): Original observations.
|
||||
Returns:
|
||||
The modified observations as a torch.Tensor of shape (obs.shape[0], self.num_obs).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def reset(self, *args, **kwargs):
|
||||
obs, _ = super().reset(*args, **kwargs)
|
||||
|
||||
self.obs_buf = self._process_observations(obs)
|
||||
|
||||
return self.obs_buf, self.extras
|
||||
|
||||
def step(
|
||||
self, actions: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, Dict[str, Any]]:
|
||||
obs, _, _, _ = super().step(actions)
|
||||
|
||||
self.obs_buf = self._process_observations(obs)
|
||||
|
||||
return self.obs_buf, self.rew_buf, self.reset_buf, self.extras
|
||||
|
||||
|
||||
class BipedalWalkerP(GymPOMDP):
|
||||
"""
|
||||
Original observation space (24 values):
|
||||
[
|
||||
hull angle,
|
||||
hull angular velocity,
|
||||
horizontal velocity,
|
||||
vertical velocity,
|
||||
joint 1 angle,
|
||||
joint 1 speed,
|
||||
joint 2 angle,
|
||||
joint 2 speed,
|
||||
leg 1 ground contact,
|
||||
joint 3 angle,
|
||||
joint 3 speed,
|
||||
joint 4 angle,
|
||||
joint 4 speed,
|
||||
leg 2 ground contact,
|
||||
lidar (10 values),
|
||||
]
|
||||
Modified observation space (15 values):
|
||||
[
|
||||
hull angle,
|
||||
joint 1 angle,
|
||||
joint 2 angle,
|
||||
joint 3 angle,
|
||||
joint 4 angle,
|
||||
lidar (10 values),
|
||||
]
|
||||
"""
|
||||
|
||||
_reduced_observation_count: int = 15
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(name="BipedalWalker-v3", **kwargs)
|
||||
|
||||
def _process_observations(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Reduces observation space from original observation space to modified observation space."""
|
||||
reduced_obs = torch.zeros(obs.shape[0], self._reduced_observation_count, device=self.device)
|
||||
reduced_obs[:, 0] = obs[:, 0]
|
||||
reduced_obs[:, 1] = obs[:, 4]
|
||||
reduced_obs[:, 2] = obs[:, 6]
|
||||
reduced_obs[:, 3] = obs[:, 9]
|
||||
reduced_obs[:, 4] = obs[:, 11]
|
||||
reduced_obs[:, 5:] = obs[:, 14:]
|
||||
|
||||
return reduced_obs
|
|
@ -0,0 +1,44 @@
|
|||
import torch
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from rsl_rl.env.vec_env import VecEnv
|
||||
|
||||
|
||||
class RSLGymEnv(VecEnv):
|
||||
"""A wrapper for using rsl_rl with the rslgym library."""
|
||||
|
||||
def __init__(self, rslgym_env, **kwargs):
|
||||
self._rslgym_env = rslgym_env
|
||||
|
||||
observation_count = self._rslgym_env.observation_space.shape[0]
|
||||
super().__init__(observation_count, observation_count, **kwargs)
|
||||
|
||||
self.num_actions = self._rslgym_env.action_space.shape[0]
|
||||
|
||||
self.obs_buf = None
|
||||
self.rew_buf = None
|
||||
self.reset_buf = None
|
||||
self.extras = None
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_observations(self) -> Tuple[torch.Tensor, Dict[str, Any]]:
|
||||
return self.obs_buf, self.extras
|
||||
|
||||
def get_privileged_observations(self) -> Union[torch.Tensor, None]:
|
||||
return self.obs_buf
|
||||
|
||||
def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
|
||||
obs = self._rslgym_env.reset()
|
||||
|
||||
self.obs_buf = torch.from_numpy(obs)
|
||||
self.extras = {"observations": {}, "time_outs": torch.zeros((self.num_envs,), device=self.device).float()}
|
||||
|
||||
def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]:
|
||||
obs, reward, dones, info = self._rslgym_env.step(actions, True)
|
||||
|
||||
self.obs_buf = torch.from_numpy(obs)
|
||||
self.rew_buf = torch.from_numpy(reward)
|
||||
self.reset_buf = torch.from_numpy(dones).float()
|
||||
|
||||
return self.obs_buf, self.rew_buf, self.reset_buf, self.extras
|
|
@ -1,85 +1,74 @@
|
|||
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from abc import ABC, abstractmethod
|
||||
import torch
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
|
||||
# minimal interface of the environment
|
||||
class VecEnv(ABC):
|
||||
"""Abstract class for vectorized environment.
|
||||
|
||||
The vectorized environment is a collection of environments that are synchronized. This means that
|
||||
the same action is applied to all environments and the same observation is returned from all environments.
|
||||
|
||||
All extra observations must be provided as a dictionary to "extras" in the step() method. Based on the
|
||||
configuration, the extra observations are used for different purposes. The following keys are reserved
|
||||
in the "observations" dictionary (if they are present):
|
||||
|
||||
- "critic": The observation is used as input to the critic network. Useful for asymmetric observation spaces.
|
||||
"""
|
||||
"""Abstract class for vectorized environment."""
|
||||
|
||||
num_envs: int
|
||||
"""Number of environments."""
|
||||
num_obs: int
|
||||
"""Number of observations."""
|
||||
num_privileged_obs: int
|
||||
"""Number of privileged observations."""
|
||||
num_actions: int
|
||||
"""Number of actions."""
|
||||
max_episode_length: int
|
||||
"""Maximum episode length."""
|
||||
privileged_obs_buf: torch.Tensor
|
||||
"""Buffer for privileged observations."""
|
||||
obs_buf: torch.Tensor
|
||||
"""Buffer for observations."""
|
||||
rew_buf: torch.Tensor
|
||||
"""Buffer for rewards."""
|
||||
reset_buf: torch.Tensor
|
||||
"""Buffer for resets."""
|
||||
episode_length_buf: torch.Tensor # current episode duration
|
||||
"""Buffer for current episode lengths."""
|
||||
extras: dict
|
||||
"""Extra information (metrics).
|
||||
|
||||
Extra information is stored in a dictionary. This includes metrics such as the episode reward, episode length,
|
||||
etc. Additional information can be stored in the dictionary such as observations for the critic network, etc.
|
||||
"""
|
||||
device: torch.device
|
||||
"""Device to use."""
|
||||
|
||||
"""
|
||||
Operations.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_observations(self) -> tuple[torch.Tensor, dict]:
|
||||
"""Return the current observations.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, dict]: Tuple containing the observations and extras.
|
||||
def __init__(
|
||||
self, observation_count, privileged_observation_count, device="cpu", environment_count=1, max_episode_length=-1
|
||||
):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> tuple[torch.Tensor, dict]:
|
||||
"""Reset all environment instances.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, dict]: Tuple containing the observations and extras.
|
||||
Args:
|
||||
observation_count (int): Number of observations per environment.
|
||||
privileged_observation_count (int): Number of privileged observations per environment.
|
||||
device (str): Device to use for the tensors.
|
||||
environment_count (int): Number of environments to run in parallel.
|
||||
max_episode_length (int): Maximum length of an episode. If -1, the episode length is not limited.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
self.num_obs = observation_count
|
||||
self.num_privileged_obs = privileged_observation_count
|
||||
|
||||
self.num_envs = environment_count
|
||||
self.max_episode_length = max_episode_length
|
||||
self.device = device
|
||||
|
||||
@abstractmethod
|
||||
def step(self, actions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
|
||||
def get_observations(self) -> Tuple[torch.Tensor, Dict[str, Any]]:
|
||||
"""Return observations and extra information."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_privileged_observations(self) -> Union[torch.Tensor, None]:
|
||||
"""Return privileged observations."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def step(
|
||||
self, actions: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, Dict[str, Any]]:
|
||||
"""Apply input action on the environment.
|
||||
|
||||
Args:
|
||||
actions (torch.Tensor): Input actions to apply. Shape: (num_envs, num_actions)
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict]:
|
||||
A tuple containing the observations, rewards, dones and extra information (metrics).
|
||||
Tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, dict]:
|
||||
A tuple containing the observations, privileged observations, rewards, dones and
|
||||
extra information (metrics).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
|
||||
"""Reset all environment instances.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor | None]: Tuple containing the observations and privileged observations.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1,10 +1,19 @@
|
|||
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
"""Definitions for neural-network components for RL-agents."""
|
||||
|
||||
from .actor_critic import ActorCritic
|
||||
from .actor_critic_recurrent import ActorCriticRecurrent
|
||||
from .categorical_network import CategoricalNetwork
|
||||
from .gaussian_chimera_network import GaussianChimeraNetwork
|
||||
from .gaussian_network import GaussianNetwork
|
||||
from .implicit_quantile_network import ImplicitQuantileNetwork
|
||||
from .network import Network
|
||||
from .normalizer import EmpiricalNormalization
|
||||
from .quantile_network import QuantileNetwork
|
||||
from .transformer import Transformer
|
||||
|
||||
__all__ = ["ActorCritic", "ActorCriticRecurrent"]
|
||||
__all__ = [
|
||||
"CategoricalNetwork",
|
||||
"EmpiricalNormalization",
|
||||
"GaussianChimeraNetwork",
|
||||
"GaussianNetwork",
|
||||
"ImplicitQuantileNetwork",
|
||||
"Network",
|
||||
"QuantileNetwork",
|
||||
"Transformer",
|
||||
]
|
||||
|
|
|
@ -1,136 +0,0 @@
|
|||
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributions import Normal
|
||||
|
||||
|
||||
class ActorCritic(nn.Module):
|
||||
is_recurrent = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_actor_obs,
|
||||
num_critic_obs,
|
||||
num_actions,
|
||||
actor_hidden_dims=[256, 256, 256],
|
||||
critic_hidden_dims=[256, 256, 256],
|
||||
activation="elu",
|
||||
init_noise_std=1.0,
|
||||
**kwargs,
|
||||
):
|
||||
if kwargs:
|
||||
print(
|
||||
"ActorCritic.__init__ got unexpected arguments, which will be ignored: "
|
||||
+ str([key for key in kwargs.keys()])
|
||||
)
|
||||
super().__init__()
|
||||
activation = get_activation(activation)
|
||||
|
||||
mlp_input_dim_a = num_actor_obs
|
||||
mlp_input_dim_c = num_critic_obs
|
||||
# Policy
|
||||
actor_layers = []
|
||||
actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
|
||||
actor_layers.append(activation)
|
||||
for layer_index in range(len(actor_hidden_dims)):
|
||||
if layer_index == len(actor_hidden_dims) - 1:
|
||||
actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], num_actions))
|
||||
else:
|
||||
actor_layers.append(nn.Linear(actor_hidden_dims[layer_index], actor_hidden_dims[layer_index + 1]))
|
||||
actor_layers.append(activation)
|
||||
self.actor = nn.Sequential(*actor_layers)
|
||||
|
||||
# Value function
|
||||
critic_layers = []
|
||||
critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
|
||||
critic_layers.append(activation)
|
||||
for layer_index in range(len(critic_hidden_dims)):
|
||||
if layer_index == len(critic_hidden_dims) - 1:
|
||||
critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], 1))
|
||||
else:
|
||||
critic_layers.append(nn.Linear(critic_hidden_dims[layer_index], critic_hidden_dims[layer_index + 1]))
|
||||
critic_layers.append(activation)
|
||||
self.critic = nn.Sequential(*critic_layers)
|
||||
|
||||
print(f"Actor MLP: {self.actor}")
|
||||
print(f"Critic MLP: {self.critic}")
|
||||
|
||||
# Action noise
|
||||
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
|
||||
self.distribution = None
|
||||
# disable args validation for speedup
|
||||
Normal.set_default_validate_args = False
|
||||
|
||||
# seems that we get better performance without init
|
||||
# self.init_memory_weights(self.memory_a, 0.001, 0.)
|
||||
# self.init_memory_weights(self.memory_c, 0.001, 0.)
|
||||
|
||||
@staticmethod
|
||||
# not used at the moment
|
||||
def init_weights(sequential, scales):
|
||||
[
|
||||
torch.nn.init.orthogonal_(module.weight, gain=scales[idx])
|
||||
for idx, module in enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))
|
||||
]
|
||||
|
||||
def reset(self, dones=None):
|
||||
pass
|
||||
|
||||
def forward(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def action_mean(self):
|
||||
return self.distribution.mean
|
||||
|
||||
@property
|
||||
def action_std(self):
|
||||
return self.distribution.stddev
|
||||
|
||||
@property
|
||||
def entropy(self):
|
||||
return self.distribution.entropy().sum(dim=-1)
|
||||
|
||||
def update_distribution(self, observations):
|
||||
mean = self.actor(observations)
|
||||
self.distribution = Normal(mean, mean * 0.0 + self.std)
|
||||
|
||||
def act(self, observations, **kwargs):
|
||||
self.update_distribution(observations)
|
||||
return self.distribution.sample()
|
||||
|
||||
def get_actions_log_prob(self, actions):
|
||||
return self.distribution.log_prob(actions).sum(dim=-1)
|
||||
|
||||
def act_inference(self, observations):
|
||||
actions_mean = self.actor(observations)
|
||||
return actions_mean
|
||||
|
||||
def evaluate(self, critic_observations, **kwargs):
|
||||
value = self.critic(critic_observations)
|
||||
return value
|
||||
|
||||
|
||||
def get_activation(act_name):
|
||||
if act_name == "elu":
|
||||
return nn.ELU()
|
||||
elif act_name == "selu":
|
||||
return nn.SELU()
|
||||
elif act_name == "relu":
|
||||
return nn.ReLU()
|
||||
elif act_name == "crelu":
|
||||
return nn.CReLU()
|
||||
elif act_name == "lrelu":
|
||||
return nn.LeakyReLU()
|
||||
elif act_name == "tanh":
|
||||
return nn.Tanh()
|
||||
elif act_name == "sigmoid":
|
||||
return nn.Sigmoid()
|
||||
else:
|
||||
print("invalid activation function!")
|
||||
return None
|
|
@ -1,97 +0,0 @@
|
|||
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from rsl_rl.modules.actor_critic import ActorCritic, get_activation
|
||||
from rsl_rl.utils import unpad_trajectories
|
||||
|
||||
|
||||
class ActorCriticRecurrent(ActorCritic):
|
||||
is_recurrent = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_actor_obs,
|
||||
num_critic_obs,
|
||||
num_actions,
|
||||
actor_hidden_dims=[256, 256, 256],
|
||||
critic_hidden_dims=[256, 256, 256],
|
||||
activation="elu",
|
||||
rnn_type="lstm",
|
||||
rnn_hidden_size=256,
|
||||
rnn_num_layers=1,
|
||||
init_noise_std=1.0,
|
||||
**kwargs,
|
||||
):
|
||||
if kwargs:
|
||||
print(
|
||||
"ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
num_actor_obs=rnn_hidden_size,
|
||||
num_critic_obs=rnn_hidden_size,
|
||||
num_actions=num_actions,
|
||||
actor_hidden_dims=actor_hidden_dims,
|
||||
critic_hidden_dims=critic_hidden_dims,
|
||||
activation=activation,
|
||||
init_noise_std=init_noise_std,
|
||||
)
|
||||
|
||||
activation = get_activation(activation)
|
||||
|
||||
self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
|
||||
self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
|
||||
|
||||
print(f"Actor RNN: {self.memory_a}")
|
||||
print(f"Critic RNN: {self.memory_c}")
|
||||
|
||||
def reset(self, dones=None):
|
||||
self.memory_a.reset(dones)
|
||||
self.memory_c.reset(dones)
|
||||
|
||||
def act(self, observations, masks=None, hidden_states=None):
|
||||
input_a = self.memory_a(observations, masks, hidden_states)
|
||||
return super().act(input_a.squeeze(0))
|
||||
|
||||
def act_inference(self, observations):
|
||||
input_a = self.memory_a(observations)
|
||||
return super().act_inference(input_a.squeeze(0))
|
||||
|
||||
def evaluate(self, critic_observations, masks=None, hidden_states=None):
|
||||
input_c = self.memory_c(critic_observations, masks, hidden_states)
|
||||
return super().evaluate(input_c.squeeze(0))
|
||||
|
||||
def get_hidden_states(self):
|
||||
return self.memory_a.hidden_states, self.memory_c.hidden_states
|
||||
|
||||
|
||||
class Memory(torch.nn.Module):
|
||||
def __init__(self, input_size, type="lstm", num_layers=1, hidden_size=256):
|
||||
super().__init__()
|
||||
# RNN
|
||||
rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM
|
||||
self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
|
||||
self.hidden_states = None
|
||||
|
||||
def forward(self, input, masks=None, hidden_states=None):
|
||||
batch_mode = masks is not None
|
||||
if batch_mode:
|
||||
# batch mode (policy update): need saved hidden states
|
||||
if hidden_states is None:
|
||||
raise ValueError("Hidden states not passed to memory module during policy update")
|
||||
out, _ = self.rnn(input, hidden_states)
|
||||
out = unpad_trajectories(out, masks)
|
||||
else:
|
||||
# inference mode (collection): use hidden states of last step
|
||||
out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
|
||||
return out
|
||||
|
||||
def reset(self, dones=None):
|
||||
# When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
|
||||
for hidden_state in self.hidden_states:
|
||||
hidden_state[..., dones, :] = 0.0
|
|
@ -0,0 +1,105 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from rsl_rl.modules.network import Network
|
||||
from rsl_rl.utils.utils import squeeze_preserve_batch
|
||||
|
||||
eps = torch.finfo(torch.float32).eps
|
||||
|
||||
|
||||
class CategoricalNetwork(Network):
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
output_size,
|
||||
activations=["relu", "relu", "relu"],
|
||||
atom_count=51,
|
||||
hidden_dims=[256, 256, 256],
|
||||
init_gain=1.0,
|
||||
value_max=10.0,
|
||||
value_min=-10.0,
|
||||
**kwargs,
|
||||
):
|
||||
assert len(hidden_dims) == len(activations)
|
||||
assert value_max > value_min
|
||||
assert atom_count > 1
|
||||
|
||||
super().__init__(
|
||||
input_size,
|
||||
activations=activations,
|
||||
hidden_dims=hidden_dims[:-1],
|
||||
init_fade=False,
|
||||
init_gain=init_gain,
|
||||
output_size=hidden_dims[-1],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._value_max = value_max
|
||||
self._value_min = value_min
|
||||
self._atom_count = atom_count
|
||||
|
||||
self.value_delta = (self._value_max - self._value_min) / (self._atom_count - 1)
|
||||
action_values = torch.arange(self._value_min, self._value_max + eps, self.value_delta)
|
||||
self.register_buffer("action_values", action_values)
|
||||
|
||||
self._categorical_layers = nn.ModuleList([nn.Linear(hidden_dims[-1], atom_count) for _ in range(output_size)])
|
||||
|
||||
self._init(self._categorical_layers, fade=False, gain=init_gain)
|
||||
|
||||
def categorical_loss(
|
||||
self, predictions: torch.Tensor, target_probabilities: torch.Tensor, targets: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""Computes the categorical loss between the prediction and target categorical distributions.
|
||||
|
||||
Projects the targets back onto the categorical distribution supports before computing KL divergence.
|
||||
|
||||
Args:
|
||||
predictions (torch.Tensor): The network prediction.
|
||||
target_probabilities (torch.Tensor): The next-state value probabilities.
|
||||
targets (torch.Tensor): The targets to compute the loss from.
|
||||
Returns:
|
||||
A torch.Tensor of the cross-entropy loss between the projected targets and the prediction.
|
||||
"""
|
||||
b = (targets - self._value_min) / self.value_delta
|
||||
l = b.floor().long().clamp(0, self._atom_count - 1)
|
||||
u = b.ceil().long().clamp(0, self._atom_count - 1)
|
||||
|
||||
all_idx = torch.arange(b.shape[0])
|
||||
projected_targets = torch.zeros((b.shape[0], self._atom_count), device=self.device)
|
||||
for i in range(self._atom_count):
|
||||
# Correct for when l == u
|
||||
l[:, i][(l[:, i] == u[:, i]) * (l[:, i] > 0)] -= 1
|
||||
u[:, i][(l[:, i] == u[:, i]) * (u[:, i] < self._atom_count - 1)] += 1
|
||||
|
||||
projected_targets[all_idx, l[:, i]] += (u[:, i] - b[:, i]) * target_probabilities[..., i]
|
||||
projected_targets[all_idx, u[:, i]] += (b[:, i] - l[:, i]) * target_probabilities[..., i]
|
||||
|
||||
loss = torch.nn.functional.cross_entropy(
|
||||
predictions.reshape(*projected_targets.shape), projected_targets.detach()
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
def compute_targets(self, rewards: torch.Tensor, dones: torch.Tensor, discount: float) -> torch.Tensor:
|
||||
gamma = (discount * (1 - dones)).reshape(-1, 1)
|
||||
gamma_z = gamma * self.action_values.repeat(dones.size()[0], 1)
|
||||
targets = (rewards.reshape(-1, 1) + gamma_z).clamp(self._value_min, self._value_max)
|
||||
|
||||
return targets
|
||||
|
||||
def forward(self, x: torch.Tensor, distribution: bool = False) -> torch.Tensor:
|
||||
features = super().forward(x)
|
||||
probabilities = squeeze_preserve_batch(
|
||||
torch.stack([layer(features).softmax(dim=-1) for layer in self._categorical_layers], dim=1)
|
||||
)
|
||||
|
||||
if distribution:
|
||||
return probabilities
|
||||
|
||||
values = self.probabilities_to_values(probabilities)
|
||||
|
||||
return values
|
||||
|
||||
def probabilities_to_values(self, probabilities: torch.Tensor) -> torch.Tensor:
|
||||
values = probabilities @ self.action_values
|
||||
|
||||
return values
|
|
@ -0,0 +1,86 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from rsl_rl.modules.network import Network
|
||||
from rsl_rl.modules.utils import get_activation
|
||||
|
||||
|
||||
class GaussianChimeraNetwork(Network):
|
||||
"""A network to predict mean and std of a gaussian distribution with separate heads for mean and std."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
activations: List[str] = ["relu", "relu", "relu", "linear"],
|
||||
hidden_dims: List[int] = [256, 256, 256],
|
||||
init_fade: bool = True,
|
||||
init_gain: float = 0.5,
|
||||
log_std_max: float = 4.0,
|
||||
log_std_min: float = -20.0,
|
||||
std_init: float = 1.0,
|
||||
shared_dims: int = 1,
|
||||
**kwargs,
|
||||
):
|
||||
assert len(hidden_dims) + 1 == len(activations)
|
||||
assert shared_dims > 0 and shared_dims <= len(hidden_dims)
|
||||
|
||||
super().__init__(
|
||||
input_size,
|
||||
hidden_dims[shared_dims],
|
||||
activations=activations[: shared_dims + 1],
|
||||
hidden_dims=hidden_dims[:shared_dims],
|
||||
init_fade=False,
|
||||
init_gain=init_gain,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Since the network predicts log_std ~= 0 after initialization, compute std = std_init * exp(log_std).
|
||||
self._log_std_init = np.log(std_init)
|
||||
self._log_std_max = log_std_max
|
||||
self._log_std_min = log_std_min
|
||||
|
||||
separate_dims = len(hidden_dims) - shared_dims
|
||||
|
||||
mean_layers = []
|
||||
for i in range(separate_dims):
|
||||
isize = hidden_dims[shared_dims + i]
|
||||
osize = output_size if i == separate_dims - 1 else hidden_dims[shared_dims + i + 1]
|
||||
|
||||
layer = nn.Linear(isize, osize)
|
||||
activation = activations[shared_dims + i + 1]
|
||||
|
||||
mean_layers += [layer, get_activation(activation)]
|
||||
self._mean_layer = nn.Sequential(*mean_layers)
|
||||
|
||||
self._init(self._mean_layer, fade=init_fade, gain=init_gain)
|
||||
|
||||
log_std_layers = []
|
||||
for i in range(separate_dims):
|
||||
isize = hidden_dims[shared_dims + i]
|
||||
osize = output_size if i == separate_dims - 1 else hidden_dims[shared_dims + i + 1]
|
||||
|
||||
layer = nn.Linear(isize, osize)
|
||||
activation = activations[shared_dims + i + 1]
|
||||
|
||||
log_std_layers += [layer, get_activation(activation)]
|
||||
self._log_std_layer = nn.Sequential(*log_std_layers)
|
||||
|
||||
self._init(self._log_std_layer, fade=init_fade, gain=init_gain)
|
||||
|
||||
def forward(self, x: torch.Tensor, compute_std: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
|
||||
features = super().forward(x)
|
||||
|
||||
mean = self._mean_layer(features)
|
||||
|
||||
if not compute_std:
|
||||
return mean
|
||||
|
||||
# compute standard deviation as std = std_init * exp(log_std) = exp(log(std_init) + log(std)) since the network
|
||||
# will predict log_std ~= 0 after initialization.
|
||||
log_std = (self._log_std_init + self._log_std_layer(features)).clamp(self._log_std_min, self._log_std_max)
|
||||
std = log_std.exp()
|
||||
|
||||
return mean, std
|
|
@ -0,0 +1,37 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Tuple, Union
|
||||
|
||||
from rsl_rl.modules.network import Network
|
||||
|
||||
|
||||
class GaussianNetwork(Network):
|
||||
"""A network to predict mean and std of a gaussian distribution where std is a tunable parameter."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
log_std_max: float = 4.0,
|
||||
log_std_min: float = -20.0,
|
||||
std_init: float = 1.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(input_size, output_size, **kwargs)
|
||||
|
||||
self._log_std_max = log_std_max
|
||||
self._log_std_min = log_std_min
|
||||
|
||||
self._log_std = nn.Parameter(torch.ones(output_size) * np.log(std_init))
|
||||
|
||||
def forward(self, x: torch.Tensor, compute_std: bool = False, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
|
||||
mean = super().forward(x, **kwargs)
|
||||
|
||||
if not compute_std:
|
||||
return mean
|
||||
|
||||
log_std = torch.ones_like(mean) * self._log_std.clamp(self._log_std_min, self._log_std_max)
|
||||
std = log_std.exp()
|
||||
|
||||
return mean, std
|
|
@ -0,0 +1,168 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch.distributions import Normal
|
||||
import torch.nn as nn
|
||||
from typing import List, Union
|
||||
|
||||
from rsl_rl.modules.network import Network
|
||||
from rsl_rl.modules.quantile_network import energy_loss
|
||||
from rsl_rl.utils.benchmarkable import Benchmarkable
|
||||
|
||||
|
||||
def reshape_measure_param(tau: torch.Tensor, param: torch.Tensor) -> torch.Tensor:
|
||||
if not torch.is_tensor(param):
|
||||
param = torch.tensor([param])
|
||||
|
||||
param = param.expand(tau.shape[0], -1).to(tau.device)
|
||||
|
||||
return param
|
||||
|
||||
|
||||
def risk_measure_neutral(tau: torch.Tensor) -> torch.Tensor:
|
||||
return tau
|
||||
|
||||
|
||||
def risk_measure_wang(tau: torch.Tensor, beta: float = 0.0) -> torch.Tensor:
|
||||
beta = reshape_measure_param(tau, beta)
|
||||
|
||||
distorted_tau = Normal(0, 1).cdf(Normal(0, 1).icdf(tau) + beta)
|
||||
|
||||
return distorted_tau
|
||||
|
||||
|
||||
class ImplicitQuantileNetwork(Network):
|
||||
measure_neutral = "neutral"
|
||||
measure_wang = "wang"
|
||||
|
||||
measures = {
|
||||
measure_neutral: risk_measure_neutral,
|
||||
measure_wang: risk_measure_wang,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
activations: List[str] = ["relu", "relu", "relu"],
|
||||
feature_layers: int = 1,
|
||||
embedding_size: int = 64,
|
||||
hidden_dims: List[int] = [256, 256, 256],
|
||||
init_fade: bool = False,
|
||||
init_gain: float = 0.5,
|
||||
measure: str = None,
|
||||
measure_kwargs: dict = {},
|
||||
**kwargs,
|
||||
):
|
||||
assert len(hidden_dims) == len(activations), "hidden_dims and activations must have the same length."
|
||||
assert feature_layers > 0, "feature_layers must be greater than 0."
|
||||
assert feature_layers < len(hidden_dims), "feature_layers must be less than the number of hidden dimensions."
|
||||
assert embedding_size > 0, "embedding_size must be greater than 0."
|
||||
|
||||
super().__init__(
|
||||
input_size,
|
||||
hidden_dims[feature_layers - 1],
|
||||
activations=activations[:feature_layers],
|
||||
hidden_dims=hidden_dims[: feature_layers - 1],
|
||||
init_fade=init_fade,
|
||||
init_gain=init_gain,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._last_taus = None
|
||||
self._last_quantiles = None
|
||||
self._embedding_size = embedding_size
|
||||
self.register_buffer(
|
||||
"_embedding_pis",
|
||||
np.pi * (torch.arange(self._embedding_size, device=self.device).reshape(1, 1, self._embedding_size)),
|
||||
)
|
||||
self._embedding_layer = nn.Sequential(
|
||||
nn.Linear(self._embedding_size, hidden_dims[feature_layers - 1]), nn.ReLU()
|
||||
)
|
||||
|
||||
self._fusion_layers = Network(
|
||||
hidden_dims[feature_layers - 1],
|
||||
output_size,
|
||||
activations=activations[feature_layers:] + ["linear"],
|
||||
hidden_dims=hidden_dims[feature_layers:],
|
||||
init_fade=init_fade,
|
||||
init_gain=init_gain,
|
||||
)
|
||||
|
||||
measure_func = risk_measure_neutral if measure is None else self.measures[measure]
|
||||
self._measure_func = measure_func
|
||||
self._measure_kwargs = measure_kwargs
|
||||
|
||||
@Benchmarkable.register
|
||||
def _sample_taus(self, batch_size: int, sample_count: int, measure_args: list, use_measure: bool) -> torch.Tensor:
|
||||
"""Sample quantiles and distort them according to the risk metric.
|
||||
|
||||
Args:
|
||||
batch_size: The batch size.
|
||||
sample_count: The number of samples to draw.
|
||||
measure_args: The arguments to pass to the risk measure function.
|
||||
use_measure: Whether to use the risk measure or not.
|
||||
Returns:
|
||||
A tensor of shape (batch_size, sample_count, 1).
|
||||
"""
|
||||
taus = torch.rand(batch_size, sample_count, device=self.device)
|
||||
|
||||
if not use_measure:
|
||||
return taus
|
||||
|
||||
if measure_args:
|
||||
taus = self._measure_func(taus, *measure_args)
|
||||
else:
|
||||
taus = self._measure_func(taus, **self._measure_kwargs)
|
||||
|
||||
return taus
|
||||
|
||||
@Benchmarkable.register
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
distribution: bool = False,
|
||||
measure_args: list = [],
|
||||
sample_count: int = 8,
|
||||
taus: Union[torch.Tensor, None] = None,
|
||||
use_measure: bool = True,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert taus is None or not use_measure, "Cannot use taus and use_measure at the same time."
|
||||
|
||||
batch_size = x.shape[0]
|
||||
|
||||
features = super().forward(x, **kwargs)
|
||||
taus = self._sample_taus(batch_size, sample_count, measure_args, use_measure) if taus is None else taus
|
||||
|
||||
# Compute quantile embeddings
|
||||
singular_dims = [1] * taus.dim()
|
||||
cos = torch.cos(taus.unsqueeze(-1) * self._embedding_pis.reshape(*singular_dims, self._embedding_size))
|
||||
embeddings = self._embedding_layer(cos)
|
||||
|
||||
# Compute the fusion of the features and the embeddings
|
||||
fused_features = features.unsqueeze(-2) * embeddings
|
||||
quantiles = self._fusion_layers(fused_features)
|
||||
|
||||
self._last_quantiles = quantiles
|
||||
self._last_taus = taus
|
||||
|
||||
if distribution:
|
||||
return quantiles
|
||||
|
||||
values = quantiles.mean(-1)
|
||||
|
||||
return values
|
||||
|
||||
@property
|
||||
def last_taus(self):
|
||||
return self._last_taus
|
||||
|
||||
@property
|
||||
def last_quantiles(self):
|
||||
return self._last_quantiles
|
||||
|
||||
@Benchmarkable.register
|
||||
def sample_energy_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
||||
loss = energy_loss(predictions, targets)
|
||||
|
||||
return loss
|
|
@ -0,0 +1,210 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import List
|
||||
|
||||
from rsl_rl.modules.normalizer import EmpiricalNormalization
|
||||
from rsl_rl.modules.utils import get_activation
|
||||
from rsl_rl.modules.transformer import Transformer
|
||||
from rsl_rl.utils.benchmarkable import Benchmarkable
|
||||
from rsl_rl.utils.utils import squeeze_preserve_batch
|
||||
|
||||
|
||||
class Network(Benchmarkable, nn.Module):
|
||||
recurrent_module_lstm = "LSTM"
|
||||
recurrent_module_transformer = "TF"
|
||||
|
||||
recurrent_modules = {recurrent_module_lstm: nn.LSTM, recurrent_module_transformer: Transformer}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
activations: List[str] = ["relu", "relu", "relu", "tanh"],
|
||||
hidden_dims: List[int] = [256, 256, 256],
|
||||
init_fade: bool = True,
|
||||
init_gain: float = 1.0,
|
||||
input_normalization: bool = False,
|
||||
recurrent: bool = False,
|
||||
recurrent_layers: int = 1,
|
||||
recurrent_module: str = recurrent_module_lstm,
|
||||
recurrent_tf_context_length: int = 64,
|
||||
recurrent_tf_head_count: int = 8,
|
||||
) -> None:
|
||||
"""
|
||||
|
||||
Args:
|
||||
input_size (int): The size of the input.
|
||||
output_size (int): The size of the output.
|
||||
activations (List[str]): The activation functions to use. If the network is recurrent, the first activation
|
||||
function is used for the output of the recurrent layer.
|
||||
hidden_dims (List[int]): The hidden dimensions. If the network is recurrent, the first hidden dimension is
|
||||
used for the recurrent layer.
|
||||
init_fade (bool): Whether to use the fade in initialization.
|
||||
init_gain (float): The gain to use for the initialization.
|
||||
input_normalization (bool): Whether to use input normalization.
|
||||
recurrent (bool): Whether to use a recurrent network.
|
||||
recurrent_layers (int): The number of recurrent layers (LSTM) / blocks (Transformer) to use.
|
||||
recurrent_module (str): The recurrent module to use. Must be one of Network.recurrent_modules.
|
||||
recurrent_tf_context_length (int): The context length of the Transformer.
|
||||
recurrent_tf_head_count (int): The head count of the Transformer.
|
||||
"""
|
||||
assert len(hidden_dims) + 1 == len(activations)
|
||||
|
||||
super().__init__()
|
||||
|
||||
if input_normalization:
|
||||
self._normalization = EmpiricalNormalization(shape=(input_size,))
|
||||
else:
|
||||
self._normalization = nn.Identity()
|
||||
|
||||
dims = [input_size] + hidden_dims + [output_size]
|
||||
|
||||
self._recurrent = recurrent
|
||||
self._recurrent_module = recurrent_module
|
||||
self.hidden_state = None
|
||||
self._last_hidden_state = None
|
||||
if self._recurrent:
|
||||
recurrent_kwargs = dict()
|
||||
|
||||
if recurrent_module == self.recurrent_module_lstm:
|
||||
recurrent_kwargs["hidden_size"] = dims[1]
|
||||
recurrent_kwargs["input_size"] = dims[0]
|
||||
recurrent_kwargs["num_layers"] = recurrent_layers
|
||||
elif recurrent_module == self.recurrent_module_transformer:
|
||||
recurrent_kwargs["block_count"] = recurrent_layers
|
||||
recurrent_kwargs["context_length"] = recurrent_tf_context_length
|
||||
recurrent_kwargs["head_count"] = recurrent_tf_head_count
|
||||
recurrent_kwargs["hidden_size"] = dims[1]
|
||||
recurrent_kwargs["input_size"] = dims[0]
|
||||
recurrent_kwargs["output_size"] = dims[1]
|
||||
|
||||
rnn = self.recurrent_modules[recurrent_module](**recurrent_kwargs)
|
||||
activation = get_activation(activations[0])
|
||||
dims = dims[1:]
|
||||
activations = activations[1:]
|
||||
|
||||
self._features = nn.Sequential(rnn, activation)
|
||||
else:
|
||||
self._features = nn.Identity()
|
||||
|
||||
layers = []
|
||||
for i in range(len(activations)):
|
||||
layer = nn.Linear(dims[i], dims[i + 1])
|
||||
activation = get_activation(activations[i])
|
||||
|
||||
layers.append(layer)
|
||||
layers.append(activation)
|
||||
|
||||
self._layers = nn.Sequential(*layers)
|
||||
|
||||
if len(layers) > 0:
|
||||
self._init(self._layers, fade=init_fade, gain=init_gain)
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
"""Returns the device of the network."""
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(self, x: torch.Tensor, hidden_state=None) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): The input data.
|
||||
hidden_state (Tuple[torch.Tensor, torch.Tensor]): The hidden state of the network. If None, the hidden state
|
||||
of the network is used. If provided, the hidden state of the neural network will not be updated. To
|
||||
retrieve the new hidden state, use the last_hidden_state property. If the network is not recurrent,
|
||||
this argument is ignored.
|
||||
Returns:
|
||||
The output of the network as a torch.Tensor.
|
||||
"""
|
||||
assert hidden_state is None or self._recurrent, "Cannot pass hidden state to non-recurrent network."
|
||||
|
||||
input = self._normalization(x.to(self.device))
|
||||
|
||||
if self._recurrent:
|
||||
current_hidden_state = self.hidden_state if hidden_state is None else hidden_state
|
||||
current_hidden_state = (current_hidden_state[0].to(self.device), current_hidden_state[1].to(self.device))
|
||||
|
||||
input = input.unsqueeze(0) if len(input.shape) == 2 else input
|
||||
input, next_hidden_state = self._features[0](input, current_hidden_state)
|
||||
input = self._features[1](input).squeeze(0)
|
||||
|
||||
if hidden_state is None:
|
||||
self.hidden_state = next_hidden_state
|
||||
self._last_hidden_state = next_hidden_state
|
||||
|
||||
output = squeeze_preserve_batch(self._layers(input))
|
||||
|
||||
return output
|
||||
|
||||
@property
|
||||
def last_hidden_state(self):
|
||||
"""Returns the hidden state of the last forward pass.
|
||||
|
||||
Does not differentiate whether the hidden state depends on the hidden state kept in the network or whether it
|
||||
was passed into the forward pass.
|
||||
|
||||
Returns:
|
||||
The hidden state of the last forward pass as Tuple[torch.Tensor, torch.Tensor].
|
||||
"""
|
||||
return self._last_hidden_state
|
||||
|
||||
def normalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Normalizes the given input.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input to normalize.
|
||||
Returns:
|
||||
The normalized input as a torch.Tensor.
|
||||
"""
|
||||
output = self._normalization(x.to(self.device))
|
||||
|
||||
return output
|
||||
|
||||
@property
|
||||
def recurrent(self) -> bool:
|
||||
"""Returns whether the network is recurrent."""
|
||||
return self._recurrent
|
||||
|
||||
def reset_hidden_state(self, indices: torch.Tensor) -> None:
|
||||
"""Resets the hidden state of the neural network.
|
||||
|
||||
Throws an error if the network is not recurrent.
|
||||
|
||||
Args:
|
||||
indices (torch.Tensor): A 1-dimensional int tensor containing the indices of the terminated
|
||||
environments.
|
||||
"""
|
||||
assert self._recurrent
|
||||
|
||||
self.hidden_state[0][:, indices] = torch.zeros(len(indices), self._features[0].hidden_size, device=self.device)
|
||||
self.hidden_state[1][:, indices] = torch.zeros(len(indices), self._features[0].hidden_size, device=self.device)
|
||||
|
||||
def reset_full_hidden_state(self, batch_size=None) -> None:
|
||||
"""Resets the hidden state of the neural network.
|
||||
|
||||
Args:
|
||||
batch_size (int): The batch size of the hidden state. If None, the hidden state is reset to None.
|
||||
"""
|
||||
assert self._recurrent
|
||||
|
||||
if batch_size is None:
|
||||
self.hidden_state = None
|
||||
else:
|
||||
layer_count, hidden_size = self._features[0].num_layers, self._features[0].hidden_size
|
||||
self.hidden_state = (
|
||||
torch.zeros(layer_count, batch_size, hidden_size, device=self.device),
|
||||
torch.zeros(layer_count, batch_size, hidden_size, device=self.device),
|
||||
)
|
||||
|
||||
def _init(self, layers: List[nn.Module], fade: bool = True, gain: float = 1.0) -> List[nn.Module]:
|
||||
"""Initializes neural network layers."""
|
||||
last_layer_idx = len(layers) - 1 - next(i for i, l in enumerate(reversed(layers)) if isinstance(l, nn.Linear))
|
||||
|
||||
for idx, layer in enumerate(layers):
|
||||
if not isinstance(layer, nn.Linear):
|
||||
continue
|
||||
|
||||
current_gain = gain / 100.0 if fade and idx == last_layer_idx else gain
|
||||
nn.init.xavier_normal_(layer.weight, gain=current_gain)
|
||||
|
||||
return layers
|
|
@ -1,9 +1,3 @@
|
|||
# Copyright (c) 2020 Preferred Networks, Inc.
|
||||
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
@ -11,48 +5,58 @@ from torch import nn
|
|||
class EmpiricalNormalization(nn.Module):
|
||||
"""Normalize mean and variance of values based on empirical values."""
|
||||
|
||||
def __init__(self, shape, eps=1e-2, until=None):
|
||||
"""Initialize EmpiricalNormalization module.
|
||||
|
||||
def __init__(self, shape, eps=1e-6, until=None) -> None:
|
||||
"""
|
||||
Args:
|
||||
shape (int or tuple of int): Shape of input values except batch axis.
|
||||
eps (float): Small value for stability.
|
||||
until (int or None): If this arg is specified, the link learns input values until the sum of batch sizes
|
||||
exceeds it.
|
||||
exceeds it.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.eps = eps
|
||||
self.until = until
|
||||
|
||||
self.register_buffer("_mean", torch.zeros(shape).unsqueeze(0))
|
||||
self.register_buffer("_var", torch.ones(shape).unsqueeze(0))
|
||||
self.register_buffer("_std", torch.ones(shape).unsqueeze(0))
|
||||
|
||||
self.count = 0
|
||||
|
||||
@property
|
||||
def mean(self):
|
||||
return self._mean.squeeze(0).clone()
|
||||
def mean(self) -> torch.Tensor:
|
||||
"""Mean of input values."""
|
||||
return self._mean.squeeze(0).detach().clone()
|
||||
|
||||
@property
|
||||
def std(self):
|
||||
return self._std.squeeze(0).clone()
|
||||
|
||||
def forward(self, x):
|
||||
"""Normalize mean and variance of values based on empirical values.
|
||||
def std(self) -> torch.Tensor:
|
||||
"""Standard deviation of input values."""
|
||||
return self._std.squeeze(0).detach().clone()
|
||||
|
||||
def forward(self, x) -> torch.Tensor:
|
||||
"""Normalize mean and variance of values based on emprical values.
|
||||
Args:
|
||||
x (ndarray or Variable): Input values
|
||||
|
||||
Returns:
|
||||
ndarray or Variable: Normalized output values
|
||||
Normalized output values
|
||||
"""
|
||||
|
||||
if self.training:
|
||||
self.update(x)
|
||||
return (x - self._mean) / (self._std + self.eps)
|
||||
|
||||
x_normalized = (x - self._mean.detach()) / (self._std.detach() + self.eps)
|
||||
|
||||
return x_normalized
|
||||
|
||||
@torch.jit.unused
|
||||
def update(self, x):
|
||||
"""Learn input values without computing the output values of them"""
|
||||
def update(self, x: torch.Tensor) -> None:
|
||||
"""Learn input values without computing the output values of them.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input values.
|
||||
"""
|
||||
x = x.detach()
|
||||
|
||||
if self.until is not None and self.count >= self.until:
|
||||
return
|
||||
|
@ -69,5 +73,14 @@ class EmpiricalNormalization(nn.Module):
|
|||
self._std = torch.sqrt(self._var)
|
||||
|
||||
@torch.jit.unused
|
||||
def inverse(self, y):
|
||||
return y * (self._std + self.eps) + self._mean
|
||||
def inverse(self, y: torch.Tensor) -> torch.Tensor:
|
||||
"""Inverse normalized values.
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Normalized input values.
|
||||
Returns:
|
||||
Inverse normalized output values.
|
||||
"""
|
||||
inv = y * (self._std + self.eps) + self._mean
|
||||
|
||||
return inv
|
||||
|
|
|
@ -0,0 +1,333 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributions import Normal
|
||||
from typing import Callable, Tuple, Union
|
||||
|
||||
from rsl_rl.modules.network import Network
|
||||
from rsl_rl.utils.benchmarkable import Benchmarkable
|
||||
from rsl_rl.utils.utils import squeeze_preserve_batch
|
||||
|
||||
eps = torch.finfo(torch.float32).eps
|
||||
|
||||
|
||||
def reshape_measure_parameters(
|
||||
qn: Network, *params: Union[torch.Tensor, float]
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
|
||||
"""Reshapes the parameters of a measure function to match the shape of the quantile network.
|
||||
|
||||
Args:
|
||||
qn (Network): The quantile network.
|
||||
*params (Union[torch.Tensor, float]): The parameters of the measure function.
|
||||
Returns:
|
||||
Union[torch.Tensor, Tuple[torch.Tensor, ...]]: The reshaped parameters.
|
||||
"""
|
||||
if not params:
|
||||
return qn._tau.to(qn.device), *params
|
||||
|
||||
assert len([*set([torch.is_tensor(p) for p in params])]) == 1, "All parameters must be either tensors or scalars."
|
||||
|
||||
if torch.is_tensor(params[0]):
|
||||
assert all([p.dim() == 1 for p in params]), "All parameters must have dimensionality 1."
|
||||
assert len([*set([p.shape[0] for p in params])]) == 1, "All parameters must have the same size."
|
||||
|
||||
reshaped_params = [p.reshape(-1, 1).to(qn.device) for p in params]
|
||||
tau = qn._tau.expand(params[0].shape[0], -1).to(qn.device)
|
||||
else:
|
||||
reshaped_params = params
|
||||
tau = qn._tau.to(qn.device)
|
||||
|
||||
return tau, *reshaped_params
|
||||
|
||||
|
||||
def make_distorted_measure(distorted_tau: torch.Tensor) -> Callable:
|
||||
"""Creates a measure function for the distorted expectation under some distortion function.
|
||||
|
||||
The distorted expectation for some distortion function g(tau) is given by the integral w.r.t. tau
|
||||
"int_0^1 g'(tau) * F_Z^{-1}(tau) dtau" where g'(tau) is the derivative of g w.r.t. tau and F_Z^{-1} is the inverse
|
||||
cumulative distribution function of the value distribution.
|
||||
See https://arxiv.org/pdf/2004.14547.pdf and https://arxiv.org/pdf/1806.06923.pdf for details.
|
||||
"""
|
||||
distorted_tau = distorted_tau.reshape(-1, distorted_tau.shape[-1])
|
||||
distortion = (distorted_tau[:, 1:] - distorted_tau[:, :-1]).squeeze(0)
|
||||
|
||||
def distorted_measure(quantiles):
|
||||
sorted_quantiles, _ = quantiles.sort(-1)
|
||||
sorted_quantiles = sorted_quantiles.reshape(-1, sorted_quantiles.shape[-1])
|
||||
|
||||
# dtau = tau[1:] - tau[:-1] cancels the denominator of g'(tau) = g(tau)[1:] - g(tau)[:-1] / dtau.
|
||||
values = squeeze_preserve_batch((distortion.to(sorted_quantiles.device) * sorted_quantiles).sum(-1))
|
||||
|
||||
return values
|
||||
|
||||
return distorted_measure
|
||||
|
||||
|
||||
def risk_measure_cvar(qn: Network, confidence_level: float = 1.0) -> Callable:
|
||||
"""Conditional value at risk measure.
|
||||
|
||||
TODO: Handle confidence_level being a tensor.
|
||||
|
||||
Args:
|
||||
qn (QuantileNetwork): Quantile network to compute the risk measure for.
|
||||
confidence_level (float): Confidence level of the risk measure. Must be between 0 and 1.
|
||||
Returns:
|
||||
A risk measure function.
|
||||
"""
|
||||
tau, confidence_level = reshape_measure_parameters(qn, confidence_level)
|
||||
distorted_tau = torch.min(tau / confidence_level, torch.ones(*tau.shape).to(tau.device))
|
||||
|
||||
return make_distorted_measure(distorted_tau)
|
||||
|
||||
|
||||
def risk_measure_neutral(_: Network) -> Callable:
|
||||
"""Neutral risk measure (expected value).
|
||||
|
||||
Args:
|
||||
_ (QuantileNetwork): Quantile network to compute the risk measure for.
|
||||
Returns:
|
||||
A risk measure function.
|
||||
"""
|
||||
|
||||
def measure(quantiles):
|
||||
values = squeeze_preserve_batch(quantiles.mean(-1))
|
||||
|
||||
return values
|
||||
|
||||
return measure
|
||||
|
||||
|
||||
def risk_measure_percentile(_: Network, confidence_level: float = 1.0) -> Callable:
|
||||
"""Value at risk measure.
|
||||
|
||||
Args:
|
||||
_ (QuantileNetwork): Quantile network to compute the risk measure for.
|
||||
confidence_level (float): Confidence level of the risk measure. Must be between 0 and 1.
|
||||
Returns:
|
||||
A risk measure function.
|
||||
"""
|
||||
|
||||
def measure(quantiles):
|
||||
sorted_quantiles, _ = quantiles.sort(-1)
|
||||
sorted_quantiles = sorted_quantiles.reshape(-1, sorted_quantiles.shape[-1])
|
||||
idx = min(int(confidence_level * quantiles.shape[-1]), quantiles.shape[-1] - 1)
|
||||
|
||||
values = squeeze_preserve_batch(sorted_quantiles[:, idx])
|
||||
|
||||
return values
|
||||
|
||||
return measure
|
||||
|
||||
|
||||
def risk_measure_wang(qn: Network, beta: Union[float, torch.Tensor] = 0.0) -> Callable:
|
||||
"""Wang's risk measure.
|
||||
|
||||
The risk measure computes the distorted expectation under Wang's risk distortion function
|
||||
g(tau) = Phi(Phi^-1(tau) + beta) where Phi and Phi^-1 are the standard normal CDF and its inverse.
|
||||
See https://arxiv.org/pdf/2004.14547.pdf for details.
|
||||
|
||||
Args:
|
||||
qn (QuantileNetwork): Quantile network to compute the risk measure for.
|
||||
beta (float): Parameter of the risk distortion function.
|
||||
Returns:
|
||||
A risk measure function.
|
||||
"""
|
||||
tau, beta = reshape_measure_parameters(qn, beta)
|
||||
|
||||
distorted_tau = Normal(0, 1).cdf(Normal(0, 1).icdf(tau) + beta)
|
||||
|
||||
return make_distorted_measure(distorted_tau)
|
||||
|
||||
|
||||
def energy_loss(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes sample energy loss between predictions and targets.
|
||||
|
||||
The energy loss is computed as 2*E[||X - Y||_2] - E[||X - X'||_2] - E[||Y - Y'||_2], where X, X' and Y, Y' are
|
||||
random variables and ||.||_2 is the L2-norm. X, X' are the predictions and Y, Y' are the targets.
|
||||
|
||||
Args:
|
||||
predictions (torch.Tensor): Predictions to compute loss from.
|
||||
targets (torch.Tensor): Targets to compare predictions against.
|
||||
Returns:
|
||||
A torch.Tensor of shape (1,) containing the loss.
|
||||
"""
|
||||
dims = [-1 for _ in range(predictions.dim())]
|
||||
prediction_mat = predictions.unsqueeze(-1).expand(*dims, predictions.shape[-1])
|
||||
target_mat = targets.unsqueeze(-1).expand(*dims, predictions.shape[-1])
|
||||
|
||||
delta_xx = (prediction_mat - prediction_mat.transpose(-1, -2)).abs().mean()
|
||||
delta_yy = (target_mat - target_mat.transpose(-1, -2)).abs().mean()
|
||||
delta_xy = (prediction_mat - target_mat.transpose(-1, -2)).abs().mean()
|
||||
|
||||
loss = 2 * delta_xy - delta_xx - delta_yy
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class QuantileNetwork(Network):
|
||||
measure_cvar = "cvar"
|
||||
measure_neutral = "neutral"
|
||||
measure_percentile = "percentile"
|
||||
measure_wang = "wang"
|
||||
|
||||
measures = {
|
||||
measure_cvar: risk_measure_cvar,
|
||||
measure_neutral: risk_measure_neutral,
|
||||
measure_percentile: risk_measure_percentile,
|
||||
measure_wang: risk_measure_wang,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
output_size,
|
||||
activations=["relu", "relu", "relu"],
|
||||
hidden_dims=[256, 256, 256],
|
||||
init_fade=False,
|
||||
init_gain=0.5,
|
||||
measure=None,
|
||||
measure_kwargs={},
|
||||
quantile_count=200,
|
||||
**kwargs,
|
||||
):
|
||||
assert len(hidden_dims) == len(activations)
|
||||
assert quantile_count > 0
|
||||
|
||||
super().__init__(
|
||||
input_size,
|
||||
activations=activations,
|
||||
hidden_dims=hidden_dims[:-1],
|
||||
init_fade=False,
|
||||
init_gain=init_gain,
|
||||
output_size=hidden_dims[-1],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._quantile_count = quantile_count
|
||||
self._tau = torch.arange(self._quantile_count + 1) / self._quantile_count
|
||||
self._tau_hat = torch.tensor([(self._tau[i] + self._tau[i + 1]) / 2 for i in range(self._quantile_count)])
|
||||
self._tau_hat_mat = torch.empty((0,))
|
||||
|
||||
self._quantile_layers = nn.ModuleList([nn.Linear(hidden_dims[-1], quantile_count) for _ in range(output_size)])
|
||||
|
||||
self._init(self._quantile_layers, fade=init_fade, gain=init_gain)
|
||||
|
||||
measure_func = risk_measure_neutral if measure is None else self.measures[measure]
|
||||
self._measure_func = measure_func
|
||||
self._measure = measure_func(self, **measure_kwargs)
|
||||
|
||||
self._last_quantiles = None
|
||||
|
||||
@property
|
||||
def last_quantiles(self) -> torch.Tensor:
|
||||
return self._last_quantiles
|
||||
|
||||
def make_diracs(self, values: torch.Tensor) -> torch.Tensor:
|
||||
"""Generates value distributions that have a single spike at the given values.
|
||||
|
||||
Args:
|
||||
values (torch.Tensor): Values to generate dirac distributions for.
|
||||
Returns:
|
||||
A torch.Tensor of shape (*values.shape, quantile_count) containing the dirac distributions.
|
||||
"""
|
||||
dirac = values.unsqueeze(-1).expand(*[-1 for _ in range(values.dim())], self._quantile_count)
|
||||
|
||||
return dirac
|
||||
|
||||
@property
|
||||
def quantile_count(self) -> int:
|
||||
return self._quantile_count
|
||||
|
||||
@Benchmarkable.register
|
||||
def quantiles_to_values(self, quantiles: torch.Tensor, *measure_args) -> torch.Tensor:
|
||||
"""Computes values from quantiles.
|
||||
|
||||
Args:
|
||||
quantiles (torch.Tensor): Quantiles to compute values from.
|
||||
measure_kwargs (dict): Keyword arguments to pass to the risk measure function instead of the arguments
|
||||
passed when creating the network.
|
||||
Returns:
|
||||
A torch.Tensor of shape (1,) containing the values.
|
||||
"""
|
||||
if measure_args:
|
||||
values = self._measure_func(self, *[squeeze_preserve_batch(m) for m in measure_args])(quantiles)
|
||||
else:
|
||||
values = self._measure(quantiles)
|
||||
|
||||
return values
|
||||
|
||||
@Benchmarkable.register
|
||||
def quantile_l1_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes quantile-wise l1 loss between predictions and targets.
|
||||
|
||||
TODO: This function is a bottleneck.
|
||||
|
||||
Args:
|
||||
predictions (torch.Tensor): Predictions to compute loss from.
|
||||
targets (torch.Tensor): Targets to compare predictions against.
|
||||
Returns:
|
||||
A torch.Tensor of shape (1,) containing the loss.
|
||||
"""
|
||||
assert (
|
||||
predictions.dim() == 2 or predictions.dim() == 3
|
||||
), f"Predictions must be 2D or 3D. Got {predictions.dim()}."
|
||||
assert (
|
||||
predictions.shape == targets.shape
|
||||
), f"The shapes of predictions and targets must match. Got {predictions.shape} and {targets.shape}."
|
||||
|
||||
pre_dims = [-1] if predictions.dim() == 3 else []
|
||||
|
||||
prediction_mat = predictions.unsqueeze(-3).expand(*pre_dims, self._quantile_count, -1, -1)
|
||||
target_mat = targets.transpose(-2, -1).unsqueeze(-1).expand(*pre_dims, -1, -1, self._quantile_count)
|
||||
delta = target_mat - prediction_mat
|
||||
|
||||
tau_hat = self._tau_hat.expand(predictions.shape[-2], -1).to(self.device)
|
||||
loss = (torch.where(delta < 0, (tau_hat - 1), tau_hat) * delta).abs().mean()
|
||||
|
||||
return loss
|
||||
|
||||
@Benchmarkable.register
|
||||
def quantile_huber_loss(self, predictions: torch.Tensor, targets: torch.Tensor, kappa: float = 1.0) -> torch.Tensor:
|
||||
"""Computes quantile huber loss between predictions and targets.
|
||||
|
||||
TODO: This function is a bottleneck.
|
||||
|
||||
Args:
|
||||
predictions (torch.Tensor): Predictions to compute loss from.
|
||||
targets (torch.Tensor): Targets to compare predictions against.
|
||||
kappa (float): Defines the interval [-kappa, kappa] around zero where squared loss is used. Defaults to 1.
|
||||
Returns:
|
||||
A torch.tensor of shape (1,) containing the loss.
|
||||
"""
|
||||
pre_dims = [-1] if predictions.dim() == 3 else []
|
||||
|
||||
prediction_mat = predictions.unsqueeze(-3).expand(*pre_dims, self._quantile_count, -1, -1)
|
||||
target_mat = targets.transpose(-2, -1).unsqueeze(-1).expand(*pre_dims, -1, -1, self._quantile_count)
|
||||
delta = target_mat - prediction_mat
|
||||
delta_abs = delta.abs()
|
||||
|
||||
huber = torch.where(delta_abs <= kappa, 0.5 * delta.pow(2), kappa * (delta_abs - 0.5 * kappa))
|
||||
|
||||
tau_hat = self._tau_hat.expand(predictions.shape[-2], -1).to(self.device)
|
||||
loss = (torch.where(delta < 0, (tau_hat - 1), tau_hat).abs() * huber).mean()
|
||||
|
||||
return loss
|
||||
|
||||
@Benchmarkable.register
|
||||
def sample_energy_loss(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
||||
loss = energy_loss(predictions, targets)
|
||||
|
||||
return loss
|
||||
|
||||
@Benchmarkable.register
|
||||
def forward(self, x: torch.Tensor, distribution: bool = False, measure_args: list = [], **kwargs) -> torch.Tensor:
|
||||
features = super().forward(x, **kwargs)
|
||||
quantiles = squeeze_preserve_batch(torch.stack([layer(features) for layer in self._quantile_layers], dim=1))
|
||||
|
||||
self._last_quantiles = quantiles
|
||||
|
||||
if distribution:
|
||||
return quantiles
|
||||
|
||||
values = self.quantiles_to_values(quantiles, *measure_args)
|
||||
|
||||
return values
|
|
@ -0,0 +1,150 @@
|
|||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
class Head(torch.nn.Module):
|
||||
"""A single causal self-attention head."""
|
||||
|
||||
def __init__(self, hidden_size: int, head_size: int):
|
||||
super().__init__()
|
||||
|
||||
self.query = torch.nn.Linear(hidden_size, head_size)
|
||||
self.key = torch.nn.Linear(hidden_size, head_size)
|
||||
self.value = torch.nn.Linear(hidden_size, head_size)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x.transpose(0, 1)
|
||||
_, S, F = x.shape # (Batch, Sequence, Features)
|
||||
|
||||
q = self.query(x)
|
||||
k = self.key(x)
|
||||
|
||||
weight = q @ k.transpose(-1, -2) * F**-0.5 # shape: (B, S, S)
|
||||
weight.masked_fill(torch.tril(torch.ones(S, S, device=x.device)) == 0, float("-inf"))
|
||||
weight = torch.nn.functional.softmax(weight, dim=-1)
|
||||
|
||||
v = self.value(x) # shape: (B, S, F)
|
||||
out = (weight @ v).transpose(0, 1) # shape: (S, B, F)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class MultiHead(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, head_count: int):
|
||||
super().__init__()
|
||||
|
||||
assert hidden_size % head_count == 0, f"Multi-headed attention head size must be a multiple of the head count."
|
||||
|
||||
self.heads = torch.nn.ModuleList([Head(hidden_size, hidden_size // head_count) for _ in range(head_count)])
|
||||
self.proj = torch.nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = torch.cat([head(x) for head in self.heads], dim=-1)
|
||||
out = self.proj(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Block(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, head_count: int):
|
||||
super().__init__()
|
||||
|
||||
self.sa = MultiHead(hidden_size, head_count)
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(hidden_size, 4 * hidden_size),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(4 * hidden_size, hidden_size),
|
||||
)
|
||||
self.ln1 = torch.nn.LayerNorm(hidden_size)
|
||||
self.ln2 = torch.nn.LayerNorm(hidden_size)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.sa(self.ln1(x)) + x
|
||||
out = self.ff(self.ln2(x)) + x
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Transformer(torch.nn.Module):
|
||||
"""A Transformer-based recurrent module.
|
||||
|
||||
The Transformer module is a recurrent module that uses a Transformer architecture to process the input sequence. It
|
||||
uses a hidden state to emulate RNN-like behavior.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, input_size, output_size, hidden_size, block_count: int = 6, context_length: int = 64, head_count: int = 8
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
input_size (int): The size of the input.
|
||||
output_size (int): The size of the output.
|
||||
hidden_size (int): The size of the hidden layers.
|
||||
block_count (int): The number of Transformer blocks.
|
||||
context_length (int): The length of the context to consider when predicting the next token.
|
||||
head_count (int): The number of attention heads per block.
|
||||
"""
|
||||
|
||||
assert context_length % 2 == 0, f"Context length must be even."
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.context_length = context_length
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.feature_proj = torch.nn.Linear(input_size, hidden_size)
|
||||
self.position_embedding = torch.nn.Embedding(context_length, hidden_size)
|
||||
self.blocks = torch.nn.Sequential(
|
||||
*[Block(hidden_size, head_count) for _ in range(block_count)],
|
||||
torch.nn.LayerNorm(hidden_size),
|
||||
)
|
||||
self.head = torch.nn.Linear(hidden_size, output_size)
|
||||
|
||||
@property
|
||||
def num_layers(self):
|
||||
# Set num_layers to half the context length for simple torch.nn.LSTM compatibility. TODO: This is a bit hacky.
|
||||
return self.context_length // 2
|
||||
|
||||
def step(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
|
||||
"""Computes Transformer output given the full context and the input.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor of shape (Sequence, Batch, Features).
|
||||
context (torch.Tensor): The context tensor of shape (Context Length, Batch, Features).
|
||||
Returns:
|
||||
A tuple of the output tensor of shape (Sequence, Batch, Features) and the updated context with the input
|
||||
features appended. The context has shape (Context Length, Batch, Features).
|
||||
"""
|
||||
S = x.shape[0]
|
||||
|
||||
# Project input to feature space and add to context.
|
||||
ft_x = self.feature_proj(x)
|
||||
context = torch.cat((context, ft_x), dim=0)[-self.context_length :]
|
||||
|
||||
# Add positional embedding to context.
|
||||
ft_pos = self.position_embedding(torch.arange(self.context_length, device=x.device)).unsqueeze(1)
|
||||
x = context + ft_pos
|
||||
|
||||
# Compute output from Transformer blocks.
|
||||
x = self.blocks(x)
|
||||
out = self.head(x)[-S:]
|
||||
|
||||
return out, context
|
||||
|
||||
def forward(self, x: torch.Tensor, hidden_state: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
||||
"""Computes Transformer output given the input and the hidden state which encapsulates the context."""
|
||||
if hidden_state is None:
|
||||
hidden_state = self.reset_hidden_state(x.shape[1], device=x.device)
|
||||
context = torch.cat(hidden_state, dim=0)
|
||||
|
||||
out, context = self.step(x, context)
|
||||
|
||||
hidden_state = context[: self.num_layers], context[self.num_layers :]
|
||||
|
||||
return out, hidden_state
|
||||
|
||||
def reset_hidden_state(self, batch_size: int, device="cpu"):
|
||||
hidden_state = torch.zeros((self.context_length, batch_size, self.hidden_size), device=device)
|
||||
hidden_state = hidden_state[: self.num_layers], hidden_state[self.num_layers :]
|
||||
|
||||
return hidden_state
|
|
@ -0,0 +1,32 @@
|
|||
from torch import nn
|
||||
|
||||
|
||||
def get_activation(act_name):
|
||||
if act_name == "elu":
|
||||
return nn.ELU()
|
||||
elif act_name == "selu":
|
||||
return nn.SELU()
|
||||
elif act_name == "relu":
|
||||
return nn.ReLU()
|
||||
elif act_name == "crelu":
|
||||
return nn.ReLU()
|
||||
elif act_name == "lrelu":
|
||||
return nn.LeakyReLU()
|
||||
elif act_name == "tanh":
|
||||
return nn.Tanh()
|
||||
elif act_name == "sigmoid":
|
||||
return nn.Sigmoid()
|
||||
elif act_name == "linear":
|
||||
return nn.Identity()
|
||||
elif act_name == "softmax":
|
||||
return nn.Softmax()
|
||||
else:
|
||||
print("invalid activation function!")
|
||||
return None
|
||||
|
||||
|
||||
def init_xavier_uniform(layer, activation):
|
||||
try:
|
||||
nn.init.xavier_uniform_(layer.weight, gain=nn.init.calculate_gain(activation))
|
||||
except ValueError:
|
||||
nn.init.xavier_uniform_(layer.weight)
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
"""Implementation of runners for environment-agent interaction."""
|
||||
|
||||
from .on_policy_runner import OnPolicyRunner
|
||||
from .runner import Runner
|
||||
from .legacy_runner import LeggedGymRunner
|
||||
|
||||
__all__ = ["OnPolicyRunner"]
|
||||
__all__ = ["LeggedGymRunner", "Runner"]
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
import os
|
||||
import random
|
||||
import string
|
||||
import wandb
|
||||
|
||||
|
||||
def make_save_model_cb(directory):
|
||||
def cb(runner, stat):
|
||||
path = os.path.join(directory, "model_{}.pt".format(stat["current_iteration"]))
|
||||
runner.save(path)
|
||||
|
||||
return cb
|
||||
|
||||
|
||||
def make_save_model_onnx_cb(directory):
|
||||
def cb(runner, stat):
|
||||
path = os.path.join(directory, "model_{}.onnx".format(stat["current_iteration"]))
|
||||
runner.export_onnx(path)
|
||||
|
||||
return cb
|
||||
|
||||
|
||||
def make_interval_cb(callback, interval):
|
||||
def cb(runner, stat):
|
||||
if stat["current_iteration"] % interval != 0:
|
||||
return
|
||||
|
||||
callback(runner, stat)
|
||||
|
||||
return cb
|
||||
|
||||
|
||||
def make_final_cb(callback):
|
||||
def cb(runner, stat):
|
||||
if not runner._learning_should_terminate():
|
||||
return
|
||||
|
||||
callback(runner, stat)
|
||||
|
||||
return cb
|
||||
|
||||
|
||||
def make_first_cb(callback):
|
||||
uuid = "".join(random.choices(string.ascii_letters + string.digits, k=8))
|
||||
|
||||
def cb(runner, stat):
|
||||
if hasattr(runner, f"_first_cb_{uuid}"):
|
||||
return
|
||||
|
||||
setattr(runner, f"_first_cb_{uuid}", True)
|
||||
callback(runner, stat)
|
||||
|
||||
return cb
|
||||
|
||||
|
||||
def make_wandb_cb(init_kwargs):
|
||||
assert "project" in init_kwargs, "The project must be specified in the init_kwargs."
|
||||
|
||||
run = wandb.init(**init_kwargs)
|
||||
check_complete = make_final_cb(lambda *_: run.finish())
|
||||
|
||||
def cb(runner, stat):
|
||||
mean_reward = sum(stat["returns"]) / len(stat["returns"]) if len(stat["returns"]) > 0 else 0.0
|
||||
mean_steps = sum(stat["lengths"]) / len(stat["lengths"]) if len(stat["lengths"]) > 0 else 0.0
|
||||
total_steps = stat["current_iteration"] * runner.env.num_envs * runner._num_steps_per_env
|
||||
training_time = stat["training_time"]
|
||||
|
||||
run.log(
|
||||
{
|
||||
"mean_rewards": mean_reward,
|
||||
"mean_steps": mean_steps,
|
||||
"training_steps": total_steps,
|
||||
"training_time": training_time,
|
||||
}
|
||||
)
|
||||
|
||||
check_complete(runner, stat)
|
||||
|
||||
return cb
|
|
@ -0,0 +1,136 @@
|
|||
import os
|
||||
|
||||
from rsl_rl.algorithms import *
|
||||
from rsl_rl.env import VecEnv
|
||||
from rsl_rl.runners.callbacks import (
|
||||
make_final_cb,
|
||||
make_first_cb,
|
||||
make_interval_cb,
|
||||
make_save_model_onnx_cb,
|
||||
)
|
||||
from rsl_rl.runners.runner import Runner
|
||||
from rsl_rl.storage import *
|
||||
|
||||
|
||||
def make_legacy_save_model_cb(directory):
|
||||
def cb(runner, stat):
|
||||
data = {}
|
||||
|
||||
if hasattr(runner.env, "_persistent_data"):
|
||||
data["env_data"] = runner.env._persistent_data
|
||||
|
||||
path = os.path.join(directory, "model_{}.pt".format(stat["current_iteration"]))
|
||||
runner.save(path, data=data)
|
||||
|
||||
return cb
|
||||
|
||||
|
||||
class LeggedGymRunner(Runner):
|
||||
"""Runner for legged_gym environments."""
|
||||
|
||||
mappings = [
|
||||
("init_noise_std", "actor_noise_std"),
|
||||
("clip_param", "clip_ratio"),
|
||||
("desired_kl", "target_kl"),
|
||||
("entropy_coef", "entropy_coeff"),
|
||||
("lam", "gae_lambda"),
|
||||
("max_grad_norm", "gradient_clip"),
|
||||
("num_learning_epochs", None),
|
||||
("num_mini_batches", "batch_count"),
|
||||
("use_clipped_value_loss", None),
|
||||
("value_loss_coef", "value_coeff"),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _hook_env(env: VecEnv):
|
||||
old_step = env.step
|
||||
|
||||
def step_hook(*args, **kwargs):
|
||||
result = old_step(*args, **kwargs)
|
||||
|
||||
if len(result) == 4:
|
||||
obs, rewards, dones, env_info = result
|
||||
elif len(result) == 5:
|
||||
obs, _, rewards, dones, env_info = result
|
||||
else:
|
||||
raise ValueError("Invalid number of return values from env.step().")
|
||||
|
||||
return obs, rewards, dones.float(), env_info
|
||||
|
||||
env.step = step_hook
|
||||
|
||||
return env
|
||||
|
||||
def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
|
||||
env = self._hook_env(env)
|
||||
self.cfg = train_cfg["runner"]
|
||||
|
||||
alg_class = eval(self.cfg["algorithm_class_name"])
|
||||
if "policy_class_name" in self.cfg:
|
||||
print("WARNING: ignoring deprecated parameter 'runner.policy_class_name'.")
|
||||
|
||||
alg_cfg = train_cfg["algorithm"]
|
||||
alg_cfg.update(train_cfg["policy"])
|
||||
|
||||
if "activation" in alg_cfg:
|
||||
print(
|
||||
"WARNING: using deprecated parameter 'activation'. Use 'actor_activations' and 'critic_activations' instead."
|
||||
)
|
||||
alg_cfg["actor_activations"] = [alg_cfg["activation"] for _ in range(len(alg_cfg["actor_hidden_dims"]))]
|
||||
alg_cfg["actor_activations"] += ["linear"]
|
||||
alg_cfg["critic_activations"] = [alg_cfg["activation"] for _ in range(len(alg_cfg["critic_hidden_dims"]))]
|
||||
alg_cfg["critic_activations"] += ["linear"]
|
||||
del alg_cfg["activation"]
|
||||
|
||||
for old, new in self.mappings:
|
||||
if old not in alg_cfg:
|
||||
continue
|
||||
|
||||
if new is None:
|
||||
print(f"WARNING: ignoring deprecated parameter '{old}'.")
|
||||
del alg_cfg[old]
|
||||
continue
|
||||
|
||||
print(f"WARNING: using deprecated parameter '{old}'. Use '{new}' instead.")
|
||||
alg_cfg[new] = alg_cfg[old]
|
||||
del alg_cfg[old]
|
||||
|
||||
agent: Agent = alg_class(env, device=device, **train_cfg["algorithm"])
|
||||
|
||||
callbacks = []
|
||||
evaluation_callbacks = []
|
||||
|
||||
evaluation_callbacks.append(lambda *args: Runner._log_progress(*args, prefix="eval"))
|
||||
|
||||
if log_dir and "save_interval" in self.cfg:
|
||||
callbacks.append(make_first_cb(make_legacy_save_model_cb(log_dir)))
|
||||
callbacks.append(make_interval_cb(make_legacy_save_model_cb(log_dir), self.cfg["save_interval"]))
|
||||
|
||||
if log_dir:
|
||||
callbacks.append(Runner._log)
|
||||
callbacks.append(make_final_cb(make_legacy_save_model_cb(log_dir)))
|
||||
callbacks.append(make_final_cb(make_save_model_onnx_cb(log_dir)))
|
||||
# callbacks.append(make_first_cb(lambda *_: store_code_state(log_dir, self._git_status_repos)))
|
||||
else:
|
||||
callbacks.append(Runner._log_progress)
|
||||
|
||||
super().__init__(
|
||||
env,
|
||||
agent,
|
||||
learn_cb=callbacks,
|
||||
evaluation_cb=evaluation_callbacks,
|
||||
device=device,
|
||||
num_steps_per_env=self.cfg["num_steps_per_env"],
|
||||
)
|
||||
|
||||
self._iteration_time = 0.0
|
||||
|
||||
def learn(self, *args, num_learning_iterations=None, init_at_random_ep_len=None, **kwargs):
|
||||
if num_learning_iterations is not None:
|
||||
print("WARNING: using deprecated parameter 'num_learning_iterations'. Use 'iterations' instead.")
|
||||
kwargs["iterations"] = num_learning_iterations
|
||||
|
||||
if init_at_random_ep_len is not None:
|
||||
print("WARNING: ignoring deprecated parameter 'init_at_random_ep_len'.")
|
||||
|
||||
super().learn(*args, **kwargs)
|
|
@ -1,304 +0,0 @@
|
|||
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import statistics
|
||||
import time
|
||||
import torch
|
||||
from collections import deque
|
||||
from torch.utils.tensorboard import SummaryWriter as TensorboardSummaryWriter
|
||||
|
||||
import rsl_rl
|
||||
from rsl_rl.algorithms import PPO
|
||||
from rsl_rl.env import VecEnv
|
||||
from rsl_rl.modules import ActorCritic, ActorCriticRecurrent, EmpiricalNormalization
|
||||
from rsl_rl.utils import store_code_state
|
||||
|
||||
|
||||
class OnPolicyRunner:
|
||||
"""On-policy runner for training and evaluation."""
|
||||
|
||||
def __init__(self, env: VecEnv, train_cfg, log_dir=None, device="cpu"):
|
||||
self.cfg = train_cfg
|
||||
self.alg_cfg = train_cfg["algorithm"]
|
||||
self.policy_cfg = train_cfg["policy"]
|
||||
self.device = device
|
||||
self.env = env
|
||||
obs, extras = self.env.get_observations()
|
||||
num_obs = obs.shape[1]
|
||||
if "critic" in extras["observations"]:
|
||||
num_critic_obs = extras["observations"]["critic"].shape[1]
|
||||
else:
|
||||
num_critic_obs = num_obs
|
||||
actor_critic_class = eval(self.policy_cfg.pop("class_name")) # ActorCritic
|
||||
actor_critic: ActorCritic | ActorCriticRecurrent = actor_critic_class(
|
||||
num_obs, num_critic_obs, self.env.num_actions, **self.policy_cfg
|
||||
).to(self.device)
|
||||
alg_class = eval(self.alg_cfg.pop("class_name")) # PPO
|
||||
self.alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg)
|
||||
self.num_steps_per_env = self.cfg["num_steps_per_env"]
|
||||
self.save_interval = self.cfg["save_interval"]
|
||||
self.empirical_normalization = self.cfg["empirical_normalization"]
|
||||
if self.empirical_normalization:
|
||||
self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device)
|
||||
self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device)
|
||||
else:
|
||||
self.obs_normalizer = torch.nn.Identity() # no normalization
|
||||
self.critic_obs_normalizer = torch.nn.Identity() # no normalization
|
||||
# init storage and model
|
||||
self.alg.init_storage(
|
||||
self.env.num_envs,
|
||||
self.num_steps_per_env,
|
||||
[num_obs],
|
||||
[num_critic_obs],
|
||||
[self.env.num_actions],
|
||||
)
|
||||
|
||||
# Log
|
||||
self.log_dir = log_dir
|
||||
self.writer = None
|
||||
self.tot_timesteps = 0
|
||||
self.tot_time = 0
|
||||
self.current_learning_iteration = 0
|
||||
self.git_status_repos = [rsl_rl.__file__]
|
||||
|
||||
def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = False):
|
||||
# initialize writer
|
||||
if self.log_dir is not None and self.writer is None:
|
||||
# Launch either Tensorboard or Neptune & Tensorboard summary writer(s), default: Tensorboard.
|
||||
self.logger_type = self.cfg.get("logger", "tensorboard")
|
||||
self.logger_type = self.logger_type.lower()
|
||||
|
||||
if self.logger_type == "neptune":
|
||||
from rsl_rl.utils.neptune_utils import NeptuneSummaryWriter
|
||||
|
||||
self.writer = NeptuneSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg)
|
||||
self.writer.log_config(self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg)
|
||||
elif self.logger_type == "wandb":
|
||||
from rsl_rl.utils.wandb_utils import WandbSummaryWriter
|
||||
|
||||
self.writer = WandbSummaryWriter(log_dir=self.log_dir, flush_secs=10, cfg=self.cfg)
|
||||
self.writer.log_config(self.env.cfg, self.cfg, self.alg_cfg, self.policy_cfg)
|
||||
elif self.logger_type == "tensorboard":
|
||||
self.writer = TensorboardSummaryWriter(log_dir=self.log_dir, flush_secs=10)
|
||||
else:
|
||||
raise AssertionError("logger type not found")
|
||||
|
||||
if init_at_random_ep_len:
|
||||
self.env.episode_length_buf = torch.randint_like(
|
||||
self.env.episode_length_buf, high=int(self.env.max_episode_length)
|
||||
)
|
||||
obs, extras = self.env.get_observations()
|
||||
critic_obs = extras["observations"].get("critic", obs)
|
||||
obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
|
||||
self.train_mode() # switch to train mode (for dropout for example)
|
||||
|
||||
ep_infos = []
|
||||
rewbuffer = deque(maxlen=100)
|
||||
lenbuffer = deque(maxlen=100)
|
||||
cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
|
||||
cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
|
||||
|
||||
start_iter = self.current_learning_iteration
|
||||
tot_iter = start_iter + num_learning_iterations
|
||||
for it in range(start_iter, tot_iter):
|
||||
start = time.time()
|
||||
# Rollout
|
||||
with torch.inference_mode():
|
||||
for i in range(self.num_steps_per_env):
|
||||
actions = self.alg.act(obs, critic_obs)
|
||||
obs, rewards, dones, infos = self.env.step(actions)
|
||||
obs = self.obs_normalizer(obs)
|
||||
if "critic" in infos["observations"]:
|
||||
critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
|
||||
else:
|
||||
critic_obs = obs
|
||||
obs, critic_obs, rewards, dones = (
|
||||
obs.to(self.device),
|
||||
critic_obs.to(self.device),
|
||||
rewards.to(self.device),
|
||||
dones.to(self.device),
|
||||
)
|
||||
self.alg.process_env_step(rewards, dones, infos)
|
||||
|
||||
if self.log_dir is not None:
|
||||
# Book keeping
|
||||
# note: we changed logging to use "log" instead of "episode" to avoid confusion with
|
||||
# different types of logging data (rewards, curriculum, etc.)
|
||||
if "episode" in infos:
|
||||
ep_infos.append(infos["episode"])
|
||||
elif "log" in infos:
|
||||
ep_infos.append(infos["log"])
|
||||
cur_reward_sum += rewards
|
||||
cur_episode_length += 1
|
||||
new_ids = (dones > 0).nonzero(as_tuple=False)
|
||||
rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
|
||||
lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
|
||||
cur_reward_sum[new_ids] = 0
|
||||
cur_episode_length[new_ids] = 0
|
||||
|
||||
stop = time.time()
|
||||
collection_time = stop - start
|
||||
|
||||
# Learning step
|
||||
start = stop
|
||||
self.alg.compute_returns(critic_obs)
|
||||
|
||||
mean_value_loss, mean_surrogate_loss = self.alg.update()
|
||||
stop = time.time()
|
||||
learn_time = stop - start
|
||||
self.current_learning_iteration = it
|
||||
if self.log_dir is not None:
|
||||
self.log(locals())
|
||||
if it % self.save_interval == 0:
|
||||
self.save(os.path.join(self.log_dir, f"model_{it}.pt"))
|
||||
ep_infos.clear()
|
||||
if it == start_iter:
|
||||
# obtain all the diff files
|
||||
git_file_paths = store_code_state(self.log_dir, self.git_status_repos)
|
||||
# if possible store them to wandb
|
||||
if self.logger_type in ["wandb", "neptune"] and git_file_paths:
|
||||
for path in git_file_paths:
|
||||
self.writer.save_file(path)
|
||||
|
||||
self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt"))
|
||||
|
||||
def log(self, locs: dict, width: int = 80, pad: int = 35):
|
||||
self.tot_timesteps += self.num_steps_per_env * self.env.num_envs
|
||||
self.tot_time += locs["collection_time"] + locs["learn_time"]
|
||||
iteration_time = locs["collection_time"] + locs["learn_time"]
|
||||
|
||||
ep_string = ""
|
||||
if locs["ep_infos"]:
|
||||
for key in locs["ep_infos"][0]:
|
||||
infotensor = torch.tensor([], device=self.device)
|
||||
for ep_info in locs["ep_infos"]:
|
||||
# handle scalar and zero dimensional tensor infos
|
||||
if key not in ep_info:
|
||||
continue
|
||||
if not isinstance(ep_info[key], torch.Tensor):
|
||||
ep_info[key] = torch.Tensor([ep_info[key]])
|
||||
if len(ep_info[key].shape) == 0:
|
||||
ep_info[key] = ep_info[key].unsqueeze(0)
|
||||
infotensor = torch.cat((infotensor, ep_info[key].to(self.device)))
|
||||
value = torch.mean(infotensor)
|
||||
# log to logger and terminal
|
||||
if "/" in key:
|
||||
self.writer.add_scalar(key, value, locs["it"])
|
||||
ep_string += f"""{f'{key}:':>{pad}} {value:.4f}\n"""
|
||||
else:
|
||||
self.writer.add_scalar("Episode/" + key, value, locs["it"])
|
||||
ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n"""
|
||||
mean_std = self.alg.actor_critic.std.mean()
|
||||
fps = int(self.num_steps_per_env * self.env.num_envs / (locs["collection_time"] + locs["learn_time"]))
|
||||
|
||||
self.writer.add_scalar("Loss/value_function", locs["mean_value_loss"], locs["it"])
|
||||
self.writer.add_scalar("Loss/surrogate", locs["mean_surrogate_loss"], locs["it"])
|
||||
self.writer.add_scalar("Loss/learning_rate", self.alg.learning_rate, locs["it"])
|
||||
self.writer.add_scalar("Policy/mean_noise_std", mean_std.item(), locs["it"])
|
||||
self.writer.add_scalar("Perf/total_fps", fps, locs["it"])
|
||||
self.writer.add_scalar("Perf/collection time", locs["collection_time"], locs["it"])
|
||||
self.writer.add_scalar("Perf/learning_time", locs["learn_time"], locs["it"])
|
||||
if len(locs["rewbuffer"]) > 0:
|
||||
self.writer.add_scalar("Train/mean_reward", statistics.mean(locs["rewbuffer"]), locs["it"])
|
||||
self.writer.add_scalar("Train/mean_episode_length", statistics.mean(locs["lenbuffer"]), locs["it"])
|
||||
if self.logger_type != "wandb": # wandb does not support non-integer x-axis logging
|
||||
self.writer.add_scalar("Train/mean_reward/time", statistics.mean(locs["rewbuffer"]), self.tot_time)
|
||||
self.writer.add_scalar(
|
||||
"Train/mean_episode_length/time", statistics.mean(locs["lenbuffer"]), self.tot_time
|
||||
)
|
||||
|
||||
str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m "
|
||||
|
||||
if len(locs["rewbuffer"]) > 0:
|
||||
log_string = (
|
||||
f"""{'#' * width}\n"""
|
||||
f"""{str.center(width, ' ')}\n\n"""
|
||||
f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
|
||||
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
|
||||
f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
|
||||
f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
|
||||
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
|
||||
f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
|
||||
f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n"""
|
||||
)
|
||||
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
|
||||
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""")
|
||||
else:
|
||||
log_string = (
|
||||
f"""{'#' * width}\n"""
|
||||
f"""{str.center(width, ' ')}\n\n"""
|
||||
f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
|
||||
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
|
||||
f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
|
||||
f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
|
||||
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
|
||||
)
|
||||
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
|
||||
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""")
|
||||
|
||||
log_string += ep_string
|
||||
log_string += (
|
||||
f"""{'-' * width}\n"""
|
||||
f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n"""
|
||||
f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n"""
|
||||
f"""{'Total time:':>{pad}} {self.tot_time:.2f}s\n"""
|
||||
f"""{'ETA:':>{pad}} {self.tot_time / (locs['it'] + 1) * (
|
||||
locs['num_learning_iterations'] - locs['it']):.1f}s\n"""
|
||||
)
|
||||
print(log_string)
|
||||
|
||||
def save(self, path, infos=None):
|
||||
saved_dict = {
|
||||
"model_state_dict": self.alg.actor_critic.state_dict(),
|
||||
"optimizer_state_dict": self.alg.optimizer.state_dict(),
|
||||
"iter": self.current_learning_iteration,
|
||||
"infos": infos,
|
||||
}
|
||||
if self.empirical_normalization:
|
||||
saved_dict["obs_norm_state_dict"] = self.obs_normalizer.state_dict()
|
||||
saved_dict["critic_obs_norm_state_dict"] = self.critic_obs_normalizer.state_dict()
|
||||
torch.save(saved_dict, path)
|
||||
|
||||
# Upload model to external logging service
|
||||
if self.logger_type in ["neptune", "wandb"]:
|
||||
self.writer.save_model(path, self.current_learning_iteration)
|
||||
|
||||
def load(self, path, load_optimizer=True):
|
||||
loaded_dict = torch.load(path)
|
||||
self.alg.actor_critic.load_state_dict(loaded_dict["model_state_dict"])
|
||||
if self.empirical_normalization:
|
||||
self.obs_normalizer.load_state_dict(loaded_dict["obs_norm_state_dict"])
|
||||
self.critic_obs_normalizer.load_state_dict(loaded_dict["critic_obs_norm_state_dict"])
|
||||
if load_optimizer:
|
||||
self.alg.optimizer.load_state_dict(loaded_dict["optimizer_state_dict"])
|
||||
self.current_learning_iteration = loaded_dict["iter"]
|
||||
return loaded_dict["infos"]
|
||||
|
||||
def get_inference_policy(self, device=None):
|
||||
self.eval_mode() # switch to evaluation mode (dropout for example)
|
||||
if device is not None:
|
||||
self.alg.actor_critic.to(device)
|
||||
policy = self.alg.actor_critic.act_inference
|
||||
if self.cfg["empirical_normalization"]:
|
||||
if device is not None:
|
||||
self.obs_normalizer.to(device)
|
||||
policy = lambda x: self.alg.actor_critic.act_inference(self.obs_normalizer(x)) # noqa: E731
|
||||
return policy
|
||||
|
||||
def train_mode(self):
|
||||
self.alg.actor_critic.train()
|
||||
if self.empirical_normalization:
|
||||
self.obs_normalizer.train()
|
||||
self.critic_obs_normalizer.train()
|
||||
|
||||
def eval_mode(self):
|
||||
self.alg.actor_critic.eval()
|
||||
if self.empirical_normalization:
|
||||
self.obs_normalizer.eval()
|
||||
self.critic_obs_normalizer.eval()
|
||||
|
||||
def add_git_repo_to_log(self, repo_file_path):
|
||||
self.git_status_repos.append(repo_file_path)
|
|
@ -0,0 +1,498 @@
|
|||
from __future__ import annotations
|
||||
import copy
|
||||
from datetime import timedelta
|
||||
import numpy as np
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
from typing import Any, Callable, Dict, List, Tuple, TypedDict, Union
|
||||
|
||||
import rsl_rl
|
||||
from rsl_rl.storage.storage import Dataset
|
||||
from rsl_rl.algorithms import Agent
|
||||
from rsl_rl.env import VecEnv
|
||||
|
||||
|
||||
class EpisodeStatistics(TypedDict):
|
||||
"""The statistics of an episode."""
|
||||
|
||||
# Time it took to collect samples for the current interation.
|
||||
collection_time: Union[int, None]
|
||||
# The counter of the current interation.
|
||||
current_iteration: int
|
||||
# The number of the final iteration of the current run.
|
||||
final_iteration: int
|
||||
# The number of the first iteration of the current run.
|
||||
first_iteration: int
|
||||
# Environment information about the current interation.
|
||||
info: list
|
||||
# The lengths of the episodes.
|
||||
lengths: Union[List[int], None]
|
||||
# The loss of the current interation.
|
||||
loss: Union[dict, None]
|
||||
# The returns of the episodes.
|
||||
returns: Union[List[float], None]
|
||||
# The total time it took to run the current interation.
|
||||
total_time: Union[int, None]
|
||||
# The time it took to update the agent.
|
||||
update_time: Union[int, None]
|
||||
|
||||
|
||||
Callback = Callable[[EpisodeStatistics], None]
|
||||
|
||||
|
||||
class Runner:
|
||||
"""The runner class for running an agent in an environment.
|
||||
|
||||
This class is responsible for running an agent in an environment. It is responsible for collecting data from the
|
||||
environment, updating the agent, and evaluating the agent. It also provides a number of callbacks that can be used
|
||||
to log and visualize the training progress.
|
||||
"""
|
||||
|
||||
_dataset: Dataset
|
||||
_episode_statistics: EpisodeStatistics
|
||||
_num_steps_per_env: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
environment: VecEnv,
|
||||
agent: Agent,
|
||||
device: str = "cpu",
|
||||
evaluation_cb: List[Callback] = None,
|
||||
learn_cb: List[Callback] = None,
|
||||
observation_history_length: int = 1,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
environment (rsl_rl.env.VecEnv): The environment to run the agent in.
|
||||
agent (rsl_rl.algorithms.agent): The RL agent to run.
|
||||
device (str): The device to run on.
|
||||
evaluation_cb (List[Callable[[dict], None]], optional): A list of callbacks that are called after each round
|
||||
of evaluation.
|
||||
learn_cb (List[Callable[[dict], None]], optional): A list of callbacks that are called after each round of
|
||||
learning.
|
||||
observation_history_length: The number of observations to concatenate into a single observation.
|
||||
"""
|
||||
self.env = environment
|
||||
self.agent = agent
|
||||
self.device = device
|
||||
self._obs_hist_len = observation_history_length
|
||||
self._learn_cb = learn_cb if learn_cb else []
|
||||
self._eval_cb = evaluation_cb if evaluation_cb else []
|
||||
|
||||
self._set_kwarg(kwargs, "num_steps_per_env", default=1)
|
||||
|
||||
self._current_learning_iteration = 0
|
||||
self._git_status_repos = [rsl_rl.__file__]
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
self._stored_dataset = [] # For computing observation history over multiple steps.
|
||||
|
||||
def add_git_repo_to_log(self, repo_file_path):
|
||||
self._git_status_repos.append(repo_file_path)
|
||||
|
||||
def eval_mode(self):
|
||||
"""Sets the agent to evaluation mode."""
|
||||
self.agent.eval_mode()
|
||||
|
||||
def evaluate(self, steps: int, return_epochs: int = 100) -> float:
|
||||
"""Evaluates the agent for a number of steps.
|
||||
|
||||
Args:
|
||||
steps (int): The number of steps to evaluate the agent for.
|
||||
return_epochs (int): The number of epochs over which to aggregate the return. Defaults to 100.
|
||||
Returns:
|
||||
The mean return of the agent.
|
||||
"""
|
||||
cumulative_rewards = []
|
||||
current_cumulative_rewards = torch.zeros(self.env.num_envs, dtype=torch.float)
|
||||
current_episode_lengths = torch.zeros(self.env.num_envs, dtype=torch.int)
|
||||
episode_lengths = []
|
||||
|
||||
self.eval_mode()
|
||||
|
||||
policy = self.get_inference_policy()
|
||||
obs, env_info = self.env.get_observations()
|
||||
|
||||
with torch.inference_mode():
|
||||
for step in range(steps):
|
||||
actions = policy(obs.clone(), copy.deepcopy(env_info))
|
||||
obs, rewards, dones, env_info, episode_statistics = self.evaluate_step(obs, env_info, actions)
|
||||
|
||||
dones_idx = dones.nonzero().cpu()
|
||||
current_cumulative_rewards += rewards.clone().cpu()
|
||||
current_episode_lengths += 1
|
||||
cumulative_rewards.extend(current_cumulative_rewards[dones_idx].squeeze(1).cpu().tolist())
|
||||
episode_lengths.extend(current_episode_lengths[dones_idx].squeeze(1).cpu().tolist())
|
||||
current_cumulative_rewards[dones_idx] = 0.0
|
||||
current_episode_lengths[dones_idx] = 0
|
||||
|
||||
episode_statistics["current_iteration"] = step
|
||||
episode_statistics["final_iteration"] = steps
|
||||
episode_statistics["lengths"] = episode_lengths[-return_epochs:]
|
||||
episode_statistics["returns"] = cumulative_rewards[-return_epochs:]
|
||||
|
||||
for cb in self._eval_cb:
|
||||
cb(self, episode_statistics)
|
||||
|
||||
cumulative_rewards.extend(current_cumulative_rewards.cpu().tolist())
|
||||
mean_return = np.mean(cumulative_rewards)
|
||||
|
||||
return mean_return
|
||||
|
||||
def evaluate_step(
|
||||
self, observations=None, environment_info=None, actions=None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict, Dict]:
|
||||
"""Evaluates the agent for a single step.
|
||||
|
||||
Args:
|
||||
observations (torch.Tensor): The observations to evaluate the agent for.
|
||||
environment_info (Dict[str, Any]): The environment information for the observations.
|
||||
actions (torch.Tensor): The actions to evaluate the agent for.
|
||||
Returns:
|
||||
A tuple containing the observations, rewards, dones, environment information, and episode statistics after
|
||||
the evaluation step.
|
||||
"""
|
||||
episode_statistics = {
|
||||
"current_actions": None,
|
||||
"current_dones": None,
|
||||
"current_iteration": 0,
|
||||
"current_observations": None,
|
||||
"current_rewards": None,
|
||||
"final_iteration": 0,
|
||||
"first_iteration": 0,
|
||||
"info": [],
|
||||
"lengths": [],
|
||||
"returns": [],
|
||||
"timeout": None,
|
||||
"total_time": None,
|
||||
}
|
||||
|
||||
self.eval_mode()
|
||||
|
||||
with torch.inference_mode():
|
||||
obs, env_info = self.env.get_observations() if observations is None else (observations, environment_info)
|
||||
|
||||
with torch.inference_mode():
|
||||
start = time.time()
|
||||
|
||||
actions = self.get_inference_policy()(obs.clone(), copy.deepcopy(env_info)) if actions is None else actions
|
||||
obs, rewards, dones, env_info = self.env.step(actions.clone())
|
||||
|
||||
self.agent.register_terminations(dones.nonzero().reshape(-1))
|
||||
|
||||
end = time.time()
|
||||
|
||||
if "episode" in env_info:
|
||||
episode_statistics["info"].append(env_info["episode"])
|
||||
episode_statistics["current_actions"] = actions
|
||||
episode_statistics["current_dones"] = dones
|
||||
episode_statistics["current_observations"] = obs
|
||||
episode_statistics["current_rewards"] = rewards
|
||||
episode_statistics["total_time"] = end - start
|
||||
|
||||
return obs, rewards, dones, env_info, episode_statistics
|
||||
|
||||
def get_inference_policy(self, device=None):
|
||||
self.eval_mode()
|
||||
|
||||
return self.agent.get_inference_policy(device)
|
||||
|
||||
def learn(
|
||||
self, iterations: Union[int, None] = None, timeout: Union[int, None] = None, return_epochs: int = 100
|
||||
) -> None:
|
||||
"""Runs a number of learning iterations.
|
||||
|
||||
Args:
|
||||
iterations (int): The number of iterations to run.
|
||||
timeout (int): Optional number of seconds after which to terminate training. Defaults to None.
|
||||
return_epochs (int): The number of epochs over which to aggregate the return. Defaults to 100.
|
||||
"""
|
||||
assert iterations is not None or timeout is not None
|
||||
|
||||
self._episode_statistics = {
|
||||
"collection_time": None,
|
||||
"current_actions": None,
|
||||
"current_iteration": self._current_learning_iteration,
|
||||
"current_observations": None,
|
||||
"final_iteration": self._current_learning_iteration + iterations if iterations is not None else None,
|
||||
"first_iteration": self._current_learning_iteration,
|
||||
"info": [],
|
||||
"lengths": [],
|
||||
"loss": {},
|
||||
"returns": [],
|
||||
"storage_initialized": False,
|
||||
"timeout": timeout,
|
||||
"total_time": None,
|
||||
"training_time": 0,
|
||||
"update_time": None,
|
||||
}
|
||||
self._current_episode_lengths = torch.zeros(self.env.num_envs, dtype=torch.float)
|
||||
self._current_cumulative_rewards = torch.zeros(self.env.num_envs, dtype=torch.float)
|
||||
|
||||
self.train_mode()
|
||||
|
||||
self._obs, self._env_info = self.env.get_observations()
|
||||
while True:
|
||||
if self._learning_should_terminate():
|
||||
break
|
||||
|
||||
# Collect data
|
||||
start = time.time()
|
||||
|
||||
with torch.inference_mode():
|
||||
self._dataset = []
|
||||
|
||||
for _ in range(self._num_steps_per_env):
|
||||
self._collect()
|
||||
|
||||
self._episode_statistics["lengths"] = self._episode_statistics["lengths"][-return_epochs:]
|
||||
self._episode_statistics["returns"] = self._episode_statistics["returns"][-return_epochs:]
|
||||
|
||||
self._episode_statistics["collection_time"] = time.time() - start
|
||||
|
||||
# Update agent
|
||||
|
||||
start = time.time()
|
||||
|
||||
self._update()
|
||||
|
||||
self._episode_statistics["update_time"] = time.time() - start
|
||||
|
||||
# Housekeeping
|
||||
|
||||
self._episode_statistics["total_time"] = (
|
||||
self._episode_statistics["collection_time"] + self._episode_statistics["update_time"]
|
||||
)
|
||||
self._episode_statistics["training_time"] += self._episode_statistics["total_time"]
|
||||
|
||||
if self.agent.initialized:
|
||||
self._episode_statistics["current_iteration"] += 1
|
||||
|
||||
terminate = False
|
||||
for cb in self._learn_cb:
|
||||
terminate = (cb(self, self._episode_statistics) == False) or terminate
|
||||
|
||||
if terminate:
|
||||
break
|
||||
|
||||
self._episode_statistics["info"].clear()
|
||||
self._current_learning_iteration = self._episode_statistics["current_iteration"]
|
||||
|
||||
def _collect(self) -> None:
|
||||
"""Runs a single step in the environment to collect a transition and stores it in the dataset.
|
||||
|
||||
This method runs a single step in the environment to collect a transition and stores it in the dataset. If the
|
||||
agent is not initialized, random actions are drawn from the action space. Furthermore, the method gathers
|
||||
statistics about the episode and stores them in the episode statistics dictionary of the runner.
|
||||
"""
|
||||
if self.agent.initialized:
|
||||
actions, data = self.agent.draw_actions(self._obs, self._env_info)
|
||||
else:
|
||||
actions, data = self.agent.draw_random_actions(self._obs, self._env_info)
|
||||
|
||||
next_obs, rewards, dones, next_env_info = self.env.step(actions)
|
||||
|
||||
self._dataset.append(
|
||||
self.agent.process_transition(
|
||||
self._obs.clone(),
|
||||
copy.deepcopy(self._env_info),
|
||||
actions.clone(),
|
||||
rewards.clone(),
|
||||
next_obs.clone(),
|
||||
copy.deepcopy(next_env_info),
|
||||
dones.clone(),
|
||||
copy.deepcopy(data),
|
||||
)
|
||||
)
|
||||
|
||||
self.agent.register_terminations(dones.nonzero().reshape(-1))
|
||||
|
||||
self._obs, self._env_info = next_obs, next_env_info
|
||||
|
||||
# Gather statistics
|
||||
if "episode" in self._env_info:
|
||||
self._episode_statistics["info"].append(self._env_info["episode"])
|
||||
dones_idx = (dones + next_env_info["time_outs"]).nonzero().cpu()
|
||||
self._current_episode_lengths += 1
|
||||
self._current_cumulative_rewards += rewards.cpu()
|
||||
|
||||
completed_lengths = self._current_episode_lengths[dones_idx][:, 0].cpu()
|
||||
completed_returns = self._current_cumulative_rewards[dones_idx][:, 0].cpu()
|
||||
self._episode_statistics["lengths"].extend(completed_lengths.tolist())
|
||||
self._episode_statistics["returns"].extend(completed_returns.tolist())
|
||||
self._current_episode_lengths[dones_idx] = 0.0
|
||||
self._current_cumulative_rewards[dones_idx] = 0.0
|
||||
|
||||
self._episode_statistics["current_actions"] = actions
|
||||
self._episode_statistics["current_observations"] = self._obs
|
||||
self._episode_statistics["sample_count"] = self.agent.storage.sample_count
|
||||
|
||||
def _learning_should_terminate(self):
|
||||
"""Checks whether the learning should terminate.
|
||||
|
||||
Termination is triggered if the number of iterations or the timeout is reached.
|
||||
|
||||
Returns:
|
||||
Whether the learning should terminate.
|
||||
"""
|
||||
if (
|
||||
self._episode_statistics["final_iteration"] is not None
|
||||
and self._episode_statistics["current_iteration"] >= self._episode_statistics["final_iteration"]
|
||||
):
|
||||
return True
|
||||
|
||||
if (
|
||||
self._episode_statistics["timeout"] is not None
|
||||
and self._episode_statistics["training_time"] >= self._episode_statistics["timeout"]
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _update(self) -> None:
|
||||
"""Updates the agent using the collected data."""
|
||||
loss = self.agent.update(self._dataset)
|
||||
self._dataset = []
|
||||
|
||||
if not self.agent.initialized:
|
||||
return
|
||||
|
||||
self._episode_statistics["loss"] = loss
|
||||
self._episode_statistics["storage_initialized"] = True
|
||||
|
||||
def load(self, path: str) -> Any:
|
||||
"""Restores the agent and runner state from a file."""
|
||||
content = torch.load(path, map_location=self.device)
|
||||
|
||||
assert "agent" in content
|
||||
assert "data" in content
|
||||
assert "iteration" in content
|
||||
|
||||
self.agent.load_state_dict(content["agent"])
|
||||
self._current_learning_iteration = content["iteration"]
|
||||
|
||||
return content["data"]
|
||||
|
||||
def save(self, path: str, data: Any = None) -> None:
|
||||
"""Saves the agent and runner state to a file."""
|
||||
content = {
|
||||
"agent": self.agent.state_dict(),
|
||||
"data": data,
|
||||
"iteration": self._current_learning_iteration,
|
||||
}
|
||||
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
torch.save(content, path)
|
||||
|
||||
def export_onnx(self, path: str) -> None:
|
||||
"""Exports the agent's policy network to ONNX format."""
|
||||
model, args, kwargs = self.agent.export_onnx()
|
||||
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
torch.onnx.export(model, args, path, **kwargs)
|
||||
|
||||
def to(self, device) -> Runner:
|
||||
"""Sets the device of the runner and its components."""
|
||||
self.device = device
|
||||
|
||||
self.agent.to(device)
|
||||
|
||||
try:
|
||||
self.env.to(device)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
return self
|
||||
|
||||
def train_mode(self):
|
||||
"""Sets the agent to training mode."""
|
||||
self.agent.train_mode()
|
||||
|
||||
def _set_kwarg(self, args, key, default=None, private=True):
|
||||
setattr(self, f"_{key}" if private else key, args[key] if key in args else default)
|
||||
|
||||
def _log_progress(self, stat, clear_line=True, prefix=""):
|
||||
"""Logs the progress of the runner."""
|
||||
if not hasattr(self, "_iteration_times"):
|
||||
self._iteration_times = []
|
||||
|
||||
self._iteration_times = (self._iteration_times + [stat["total_time"]])[-100:]
|
||||
average_total_time = np.mean(self._iteration_times)
|
||||
|
||||
if stat["final_iteration"] is not None:
|
||||
first_iteration = stat["first_iteration"]
|
||||
final_iteration = stat["final_iteration"]
|
||||
current_iteration = stat["current_iteration"]
|
||||
final_run_iteration = final_iteration - first_iteration
|
||||
remaining_iterations = final_iteration - current_iteration
|
||||
|
||||
remaining_iteration_time = remaining_iterations * average_total_time
|
||||
iteration_completion_percentage = 100 * (current_iteration - first_iteration) / final_run_iteration
|
||||
else:
|
||||
remaining_iteration_time = np.inf
|
||||
iteration_completion_percentage = 0
|
||||
|
||||
if stat["timeout"] is not None:
|
||||
training_time = stat["training_time"]
|
||||
timeout = stat["timeout"]
|
||||
|
||||
remaining_timeout_time = stat["timeout"] - stat["training_time"]
|
||||
timeout_completion_percentage = 100 * stat["training_time"] / stat["timeout"]
|
||||
else:
|
||||
remaining_timeout_time = np.inf
|
||||
timeout_completion_percentage = 0
|
||||
|
||||
if remaining_iteration_time > remaining_timeout_time:
|
||||
completion_percentage = timeout_completion_percentage
|
||||
remaining_time = remaining_timeout_time
|
||||
step_string = f"({int(training_time)}s / {timeout}s)"
|
||||
else:
|
||||
completion_percentage = iteration_completion_percentage
|
||||
remaining_time = remaining_iteration_time
|
||||
step_string = f"({current_iteration} / {final_iteration})"
|
||||
|
||||
prefix = f"[{prefix}] " if prefix else ""
|
||||
progress = "".join(["#" if i <= int(completion_percentage) else "_" for i in range(10, 101, 5)])
|
||||
remaining_time_string = str(timedelta(seconds=int(np.ceil(remaining_time))))
|
||||
print(
|
||||
f"{prefix}{progress} {step_string} [{completion_percentage:.1f}%, {1/average_total_time:.2f}it/s, {remaining_time_string} ETA]",
|
||||
end="\r" if clear_line else "\n",
|
||||
)
|
||||
|
||||
def _log(self, stat, prefix=""):
|
||||
"""Logs the progress and statistics of the runner."""
|
||||
current_iteration = stat["current_iteration"]
|
||||
|
||||
collection_time = stat["collection_time"]
|
||||
update_time = stat["update_time"]
|
||||
total_time = stat["total_time"]
|
||||
collection_percentage = 100 * collection_time / total_time
|
||||
update_percentage = 100 * update_time / total_time
|
||||
|
||||
if prefix == "":
|
||||
prefix = "learn" if stat["storage_initialized"] else "init"
|
||||
self._log_progress(stat, clear_line=False, prefix=prefix)
|
||||
print(
|
||||
f"iteration time:\t{total_time:.4f}s (collection: {collection_time:.2f}s [{collection_percentage:.1f}%], update: {update_time:.2f}s [{update_percentage:.1f}%])"
|
||||
)
|
||||
|
||||
mean_reward = sum(stat["returns"]) / len(stat["returns"]) if len(stat["returns"]) > 0 else 0.0
|
||||
mean_steps = sum(stat["lengths"]) / len(stat["lengths"]) if len(stat["lengths"]) > 0 else 0.0
|
||||
total_steps = current_iteration * self.env.num_envs * self._num_steps_per_env
|
||||
sample_count = stat["sample_count"]
|
||||
print(f"avg. reward:\t{mean_reward:.4f}")
|
||||
print(f"avg. steps:\t{mean_steps:.4f}")
|
||||
print(f"stored samples:\t{sample_count}")
|
||||
print(f"total steps:\t{total_steps}")
|
||||
|
||||
for key, value in stat["loss"].items():
|
||||
print(f"{key} loss:\t{value:.4f}")
|
||||
|
||||
for key, value in self.agent._bm_report().items():
|
||||
mean, count = value
|
||||
print(f"BM {key}:\t{mean/1000000.0:.4f}ms ({count} calls, total {mean*count/1000000.0:.4f}ms)")
|
||||
|
||||
self.agent._bm_flush()
|
|
@ -4,5 +4,6 @@
|
|||
"""Implementation of transitions storage for RL-agent."""
|
||||
|
||||
from .rollout_storage import RolloutStorage
|
||||
from .replay_storage import ReplayStorage
|
||||
|
||||
__all__ = ["RolloutStorage"]
|
||||
__all__ = ["RolloutStorage", "ReplayStorage"]
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
import torch
|
||||
from typing import Callable, Dict, Generator, Tuple, Optional
|
||||
|
||||
from rsl_rl.storage.storage import Dataset, Storage, Transition
|
||||
|
||||
|
||||
class ReplayStorage(Storage):
|
||||
def __init__(self, environment_count: int, max_size: int, device: str = "cpu", initial_size: int = 0) -> None:
|
||||
self._env_count = environment_count
|
||||
self.initial_size = initial_size // environment_count
|
||||
self.max_size = max_size
|
||||
self.device = device
|
||||
|
||||
self._register_serializable("max_size", "initial_size")
|
||||
|
||||
self._idx = 0
|
||||
self._full = False
|
||||
self._initialized = initial_size == 0
|
||||
self._data = {}
|
||||
|
||||
self._processors: Dict[Tuple[Callable, Callable]] = {}
|
||||
|
||||
@property
|
||||
def max_size(self):
|
||||
return self._size * self._env_count
|
||||
|
||||
@max_size.setter
|
||||
def max_size(self, value):
|
||||
self._size = value // self._env_count
|
||||
|
||||
assert self.initial_size <= self._size
|
||||
|
||||
def _add_item(self, name: str, value: torch.Tensor) -> None:
|
||||
"""Adds a transition item to the storage.
|
||||
|
||||
Args:
|
||||
name (str): The name of the item.
|
||||
value (torch.Tensor): The value of the item.
|
||||
"""
|
||||
value = self._process(name, value.clone().to(self.device))
|
||||
|
||||
if name not in self._data:
|
||||
if self._full or self._idx != 0:
|
||||
raise ValueError(f'Tried to store invalid transition data for "{name}".')
|
||||
self._data[name] = torch.empty(
|
||||
self._size * self._env_count, *value.shape[1:], device=self.device, dtype=value.dtype
|
||||
)
|
||||
|
||||
start_idx = self._idx * self._env_count
|
||||
end_idx = (self._idx + 1) * self._env_count
|
||||
self._data[name][start_idx:end_idx] = value
|
||||
|
||||
def _process(self, name: str, value: torch.Tensor) -> torch.Tensor:
|
||||
if name not in self._processors:
|
||||
return value
|
||||
|
||||
for process, _ in self._processors[name]:
|
||||
if process is None:
|
||||
continue
|
||||
|
||||
value = process(value)
|
||||
|
||||
return value
|
||||
|
||||
def _process_undo(self, name: str, value: torch.Tensor) -> torch.Tensor:
|
||||
if name not in self._processors:
|
||||
return value
|
||||
|
||||
for _, undo in reversed(self._processors[name]):
|
||||
if undo is None:
|
||||
continue
|
||||
|
||||
value = undo(value)
|
||||
|
||||
return value
|
||||
|
||||
def append(self, dataset: Dataset) -> None:
|
||||
"""Appends a dataset of transitions to the storage.
|
||||
|
||||
Args:
|
||||
dataset (Dataset): The dataset of transitions.
|
||||
"""
|
||||
for transition in dataset:
|
||||
for name, value in transition.items():
|
||||
self._add_item(name, value)
|
||||
|
||||
self._idx += 1
|
||||
|
||||
if self._idx >= self.initial_size:
|
||||
self._initialized = True
|
||||
|
||||
if self._idx == self._size:
|
||||
self._full = True
|
||||
self._idx = 0
|
||||
|
||||
def batch_generator(self, batch_size: int, batch_count: int) -> Generator[Transition, None, None]:
|
||||
"""Returns a generator that yields batches of transitions.
|
||||
|
||||
Args:
|
||||
batch_size (int): The size of the batches.
|
||||
batch_count (int): The number of batches to yield.
|
||||
Returns:
|
||||
A generator that yields batches of transitions.
|
||||
"""
|
||||
assert self._full or self._idx > 0
|
||||
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
max_idx = self._env_count * (self._size if self._full else self._idx)
|
||||
|
||||
for _ in range(batch_count):
|
||||
batch_idx = torch.randint(high=max_idx, size=(batch_size,))
|
||||
|
||||
batch = {}
|
||||
for key, value in self._data.items():
|
||||
batch[key] = self._process_undo(key, value[batch_idx].clone())
|
||||
|
||||
yield batch
|
||||
|
||||
def register_processor(self, key: str, process: Callable, undo: Optional[Callable] = None) -> None:
|
||||
"""Registers a processor for a transition item.
|
||||
|
||||
The processor is called before the item is stored in the storage. The undo function is called when the item is
|
||||
retrieved from the storage. The undo function is called in reverse order of the processors so that the order of
|
||||
the processors does not matter.
|
||||
|
||||
Args:
|
||||
key (str): The name of the transition item.
|
||||
process (Callable): The function to process the item.
|
||||
undo (Optional[Callable], optional): The function to undo the processing. Defaults to None.
|
||||
"""
|
||||
if key not in self._processors:
|
||||
self._processors[key] = []
|
||||
|
||||
self._processors[key].append((process, undo))
|
||||
|
||||
@property
|
||||
def initialized(self) -> bool:
|
||||
return self._initialized
|
||||
|
||||
@property
|
||||
def sample_count(self) -> int:
|
||||
"""Returns the number of individual transitions stored in the storage."""
|
||||
transition_count = self._size * self._env_count if self._full else self._idx * self._env_count
|
||||
|
||||
return transition_count
|
|
@ -1,228 +1,71 @@
|
|||
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from typing import Generator
|
||||
|
||||
from rsl_rl.utils import split_and_pad_trajectories
|
||||
from rsl_rl.storage.replay_storage import ReplayStorage
|
||||
from rsl_rl.storage.storage import Dataset, Transition
|
||||
|
||||
|
||||
class RolloutStorage:
|
||||
class Transition:
|
||||
def __init__(self):
|
||||
self.observations = None
|
||||
self.critic_observations = None
|
||||
self.actions = None
|
||||
self.rewards = None
|
||||
self.dones = None
|
||||
self.values = None
|
||||
self.actions_log_prob = None
|
||||
self.action_mean = None
|
||||
self.action_sigma = None
|
||||
self.hidden_states = None
|
||||
class RolloutStorage(ReplayStorage):
|
||||
"""Implementation of rollout storage for RL-agent."""
|
||||
|
||||
def clear(self):
|
||||
self.__init__()
|
||||
def __init__(self, environment_count: int, device: str = "cpu"):
|
||||
"""
|
||||
Args:
|
||||
environment_count (int): Number of environments.
|
||||
device (str, optional): Device to use. Defaults to "cpu".
|
||||
"""
|
||||
super().__init__(environment_count, environment_count, device=device, initial_size=0)
|
||||
|
||||
def __init__(self, num_envs, num_transitions_per_env, obs_shape, privileged_obs_shape, actions_shape, device="cpu"):
|
||||
self.device = device
|
||||
self._size_initialized = False
|
||||
|
||||
self.obs_shape = obs_shape
|
||||
self.privileged_obs_shape = privileged_obs_shape
|
||||
self.actions_shape = actions_shape
|
||||
def append(self, dataset: Dataset) -> None:
|
||||
"""Appends a dataset to the rollout storage.
|
||||
|
||||
# Core
|
||||
self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device)
|
||||
if privileged_obs_shape[0] is not None:
|
||||
self.privileged_observations = torch.zeros(
|
||||
num_transitions_per_env, num_envs, *privileged_obs_shape, device=self.device
|
||||
)
|
||||
Args:
|
||||
dataset (Dataset): Dataset to append.
|
||||
Raises:
|
||||
AssertionError: If the dataset is not of the correct size.
|
||||
"""
|
||||
assert self._idx == 0
|
||||
|
||||
if not self._size_initialized:
|
||||
self.max_size = len(dataset) * self._env_count
|
||||
|
||||
assert len(dataset) == self._size
|
||||
|
||||
super().append(dataset)
|
||||
|
||||
def batch_generator(self, batch_count: int, trajectories: bool = False) -> Generator[Transition, None, None]:
|
||||
"""Yields batches of transitions or trajectories.
|
||||
|
||||
Args:
|
||||
batch_count (int): Number of batches to yield.
|
||||
trajectories (bool, optional): Whether to yield batches of trajectories. Defaults to False.
|
||||
Raises:
|
||||
AssertionError: If the rollout storage is not full.
|
||||
Returns:
|
||||
Generator yielding batches of transitions of shape (batch_size, *shape). If trajectories is True, yields
|
||||
batches of trajectories of shape (env_count, steps_per_env, *shape).
|
||||
"""
|
||||
assert self._full and self._initialized, "Rollout storage must be full and initialized."
|
||||
|
||||
total_size = self._env_count if trajectories else self._size * self._env_count
|
||||
batch_size = total_size // batch_count
|
||||
indices = torch.randperm(total_size)
|
||||
|
||||
assert batch_size > 0, "Batch count is too large."
|
||||
|
||||
if trajectories:
|
||||
# Reshape to (env_count, steps_per_env, *shape)
|
||||
data = {k: v.reshape(-1, self._env_count, *v.shape[1:]).transpose(0, 1) for k, v in self._data.items()}
|
||||
else:
|
||||
self.privileged_observations = None
|
||||
self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
|
||||
self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
|
||||
self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte()
|
||||
data = self._data
|
||||
|
||||
# For PPO
|
||||
self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
|
||||
self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
|
||||
self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
|
||||
self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
|
||||
self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
|
||||
self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
|
||||
for i in range(batch_count):
|
||||
batch_idx = indices[i * batch_size : (i + 1) * batch_size].detach().to(self.device)
|
||||
|
||||
self.num_transitions_per_env = num_transitions_per_env
|
||||
self.num_envs = num_envs
|
||||
batch = {}
|
||||
for key, value in data.items():
|
||||
batch[key] = self._process_undo(key, value[batch_idx].clone())
|
||||
|
||||
# rnn
|
||||
self.saved_hidden_states_a = None
|
||||
self.saved_hidden_states_c = None
|
||||
|
||||
self.step = 0
|
||||
|
||||
def add_transitions(self, transition: Transition):
|
||||
if self.step >= self.num_transitions_per_env:
|
||||
raise AssertionError("Rollout buffer overflow")
|
||||
self.observations[self.step].copy_(transition.observations)
|
||||
if self.privileged_observations is not None:
|
||||
self.privileged_observations[self.step].copy_(transition.critic_observations)
|
||||
self.actions[self.step].copy_(transition.actions)
|
||||
self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
|
||||
self.dones[self.step].copy_(transition.dones.view(-1, 1))
|
||||
self.values[self.step].copy_(transition.values)
|
||||
self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
|
||||
self.mu[self.step].copy_(transition.action_mean)
|
||||
self.sigma[self.step].copy_(transition.action_sigma)
|
||||
self._save_hidden_states(transition.hidden_states)
|
||||
self.step += 1
|
||||
|
||||
def _save_hidden_states(self, hidden_states):
|
||||
if hidden_states is None or hidden_states == (None, None):
|
||||
return
|
||||
# make a tuple out of GRU hidden state sto match the LSTM format
|
||||
hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
|
||||
hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)
|
||||
|
||||
# initialize if needed
|
||||
if self.saved_hidden_states_a is None:
|
||||
self.saved_hidden_states_a = [
|
||||
torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a))
|
||||
]
|
||||
self.saved_hidden_states_c = [
|
||||
torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c))
|
||||
]
|
||||
# copy the states
|
||||
for i in range(len(hid_a)):
|
||||
self.saved_hidden_states_a[i][self.step].copy_(hid_a[i])
|
||||
self.saved_hidden_states_c[i][self.step].copy_(hid_c[i])
|
||||
|
||||
def clear(self):
|
||||
self.step = 0
|
||||
|
||||
def compute_returns(self, last_values, gamma, lam):
|
||||
advantage = 0
|
||||
for step in reversed(range(self.num_transitions_per_env)):
|
||||
if step == self.num_transitions_per_env - 1:
|
||||
next_values = last_values
|
||||
else:
|
||||
next_values = self.values[step + 1]
|
||||
next_is_not_terminal = 1.0 - self.dones[step].float()
|
||||
delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
|
||||
advantage = delta + next_is_not_terminal * gamma * lam * advantage
|
||||
self.returns[step] = advantage + self.values[step]
|
||||
|
||||
# Compute and normalize the advantages
|
||||
self.advantages = self.returns - self.values
|
||||
self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)
|
||||
|
||||
def get_statistics(self):
|
||||
done = self.dones
|
||||
done[-1] = 1
|
||||
flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
|
||||
done_indices = torch.cat(
|
||||
(flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0])
|
||||
)
|
||||
trajectory_lengths = done_indices[1:] - done_indices[:-1]
|
||||
return trajectory_lengths.float().mean(), self.rewards.mean()
|
||||
|
||||
def mini_batch_generator(self, num_mini_batches, num_epochs=8):
|
||||
batch_size = self.num_envs * self.num_transitions_per_env
|
||||
mini_batch_size = batch_size // num_mini_batches
|
||||
indices = torch.randperm(num_mini_batches * mini_batch_size, requires_grad=False, device=self.device)
|
||||
|
||||
observations = self.observations.flatten(0, 1)
|
||||
if self.privileged_observations is not None:
|
||||
critic_observations = self.privileged_observations.flatten(0, 1)
|
||||
else:
|
||||
critic_observations = observations
|
||||
|
||||
actions = self.actions.flatten(0, 1)
|
||||
values = self.values.flatten(0, 1)
|
||||
returns = self.returns.flatten(0, 1)
|
||||
old_actions_log_prob = self.actions_log_prob.flatten(0, 1)
|
||||
advantages = self.advantages.flatten(0, 1)
|
||||
old_mu = self.mu.flatten(0, 1)
|
||||
old_sigma = self.sigma.flatten(0, 1)
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
for i in range(num_mini_batches):
|
||||
start = i * mini_batch_size
|
||||
end = (i + 1) * mini_batch_size
|
||||
batch_idx = indices[start:end]
|
||||
|
||||
obs_batch = observations[batch_idx]
|
||||
critic_observations_batch = critic_observations[batch_idx]
|
||||
actions_batch = actions[batch_idx]
|
||||
target_values_batch = values[batch_idx]
|
||||
returns_batch = returns[batch_idx]
|
||||
old_actions_log_prob_batch = old_actions_log_prob[batch_idx]
|
||||
advantages_batch = advantages[batch_idx]
|
||||
old_mu_batch = old_mu[batch_idx]
|
||||
old_sigma_batch = old_sigma[batch_idx]
|
||||
yield obs_batch, critic_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (
|
||||
None,
|
||||
None,
|
||||
), None
|
||||
|
||||
# for RNNs only
|
||||
def reccurent_mini_batch_generator(self, num_mini_batches, num_epochs=8):
|
||||
padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories(self.observations, self.dones)
|
||||
if self.privileged_observations is not None:
|
||||
padded_critic_obs_trajectories, _ = split_and_pad_trajectories(self.privileged_observations, self.dones)
|
||||
else:
|
||||
padded_critic_obs_trajectories = padded_obs_trajectories
|
||||
|
||||
mini_batch_size = self.num_envs // num_mini_batches
|
||||
for ep in range(num_epochs):
|
||||
first_traj = 0
|
||||
for i in range(num_mini_batches):
|
||||
start = i * mini_batch_size
|
||||
stop = (i + 1) * mini_batch_size
|
||||
|
||||
dones = self.dones.squeeze(-1)
|
||||
last_was_done = torch.zeros_like(dones, dtype=torch.bool)
|
||||
last_was_done[1:] = dones[:-1]
|
||||
last_was_done[0] = True
|
||||
trajectories_batch_size = torch.sum(last_was_done[:, start:stop])
|
||||
last_traj = first_traj + trajectories_batch_size
|
||||
|
||||
masks_batch = trajectory_masks[:, first_traj:last_traj]
|
||||
obs_batch = padded_obs_trajectories[:, first_traj:last_traj]
|
||||
critic_obs_batch = padded_critic_obs_trajectories[:, first_traj:last_traj]
|
||||
|
||||
actions_batch = self.actions[:, start:stop]
|
||||
old_mu_batch = self.mu[:, start:stop]
|
||||
old_sigma_batch = self.sigma[:, start:stop]
|
||||
returns_batch = self.returns[:, start:stop]
|
||||
advantages_batch = self.advantages[:, start:stop]
|
||||
values_batch = self.values[:, start:stop]
|
||||
old_actions_log_prob_batch = self.actions_log_prob[:, start:stop]
|
||||
|
||||
# reshape to [num_envs, time, num layers, hidden dim] (original shape: [time, num_layers, num_envs, hidden_dim])
|
||||
# then take only time steps after dones (flattens num envs and time dimensions),
|
||||
# take a batch of trajectories and finally reshape back to [num_layers, batch, hidden_dim]
|
||||
last_was_done = last_was_done.permute(1, 0)
|
||||
hid_a_batch = [
|
||||
saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj]
|
||||
.transpose(1, 0)
|
||||
.contiguous()
|
||||
for saved_hidden_states in self.saved_hidden_states_a
|
||||
]
|
||||
hid_c_batch = [
|
||||
saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj]
|
||||
.transpose(1, 0)
|
||||
.contiguous()
|
||||
for saved_hidden_states in self.saved_hidden_states_c
|
||||
]
|
||||
# remove the tuple for GRU
|
||||
hid_a_batch = hid_a_batch[0] if len(hid_a_batch) == 1 else hid_a_batch
|
||||
hid_c_batch = hid_c_batch[0] if len(hid_c_batch) == 1 else hid_c_batch
|
||||
|
||||
yield obs_batch, critic_obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (
|
||||
hid_a_batch,
|
||||
hid_c_batch,
|
||||
), masks_batch
|
||||
|
||||
first_traj = last_traj
|
||||
yield batch
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
from abc import abstractmethod
|
||||
import torch
|
||||
from typing import Dict, Generator, List
|
||||
|
||||
from rsl_rl.utils.serializable import Serializable
|
||||
|
||||
|
||||
# prev_obs, prev_obs_info, actions, rewards, next_obs, next_obs_info, dones, data
|
||||
Transition = Dict[str, torch.Tensor]
|
||||
Dataset = List[Transition]
|
||||
|
||||
|
||||
class Storage(Serializable):
|
||||
@abstractmethod
|
||||
def append(self, dataset: Dataset) -> None:
|
||||
"""Adds transitions to the storage.
|
||||
|
||||
Args:
|
||||
dataset (Dataset): The transitions to add to the storage.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def batch_generator(self, batch_size: int, batch_count: int) -> Generator[Dict[str, torch.Tensor], None, None]:
|
||||
"""Generates a batch of transitions.
|
||||
|
||||
Args:
|
||||
batch_size (int): The size of each batch to generate.
|
||||
batch_count (int): The number of batches to generate.
|
||||
Returns:
|
||||
A generator that yields transitions.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def initialized(self) -> bool:
|
||||
"""Returns whether the storage is initialized."""
|
||||
return True
|
||||
|
||||
@abstractmethod
|
||||
def sample_count(self) -> int:
|
||||
"""Returns how many individual samples are stored in the storage.
|
||||
|
||||
Returns:
|
||||
The number of stored samples.
|
||||
"""
|
||||
pass
|
|
@ -1,6 +1,3 @@
|
|||
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
"""Helper functions."""
|
||||
|
||||
from .utils import split_and_pad_trajectories, store_code_state, unpad_trajectories
|
||||
from .utils import split_and_pad_trajectories, unpad_trajectories, store_code_state
|
||||
|
|
|
@ -0,0 +1,111 @@
|
|||
import numpy as np
|
||||
import time
|
||||
from typing import Callable, Dict
|
||||
|
||||
|
||||
class Benchmark:
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def __call__(self):
|
||||
if self.running:
|
||||
self.end()
|
||||
else:
|
||||
self.start()
|
||||
|
||||
def end(self):
|
||||
now = time.process_time_ns()
|
||||
|
||||
assert self.running
|
||||
|
||||
difference = now - self._timer
|
||||
self._timings.append(difference)
|
||||
|
||||
self._timer = None
|
||||
|
||||
def reset(self):
|
||||
self._timer = None
|
||||
self._timings = []
|
||||
|
||||
@property
|
||||
def running(self):
|
||||
return self._timer is not None
|
||||
|
||||
def start(self):
|
||||
self._timer = time.process_time_ns()
|
||||
|
||||
@property
|
||||
def timings(self):
|
||||
return self._timings
|
||||
|
||||
|
||||
class Benchmarkable:
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._benchmark = False
|
||||
self._bm_data = dict()
|
||||
self._bm_fusions = []
|
||||
|
||||
def _bm(self, name: str) -> None:
|
||||
if not self._benchmark:
|
||||
return
|
||||
|
||||
if name not in self._bm_data:
|
||||
self._bm_data[name] = Benchmark()
|
||||
|
||||
self._bm_data[name]()
|
||||
|
||||
def _bm_flush(self) -> None:
|
||||
# TODO: implement
|
||||
for val in self._bm_data.values():
|
||||
val.reset()
|
||||
|
||||
for fusion in self._bm_fusions:
|
||||
fusion["target"]._bm_flush()
|
||||
|
||||
def _bm_fuse(self, target, prefix="") -> None:
|
||||
assert isinstance(target, Benchmarkable)
|
||||
assert target not in self._bm_fusions
|
||||
|
||||
target._bm_toggle(self._benchmark)
|
||||
self._bm_fusions.append(dict(target=target, prefix=prefix))
|
||||
|
||||
def _bm_report(self) -> Dict:
|
||||
data = dict()
|
||||
|
||||
if not self._benchmark:
|
||||
return data
|
||||
|
||||
for key, val in self._bm_data.items():
|
||||
data[key] = (np.mean(val.timings), len(val.timings))
|
||||
|
||||
for fusion in self._bm_fusions:
|
||||
target = fusion["target"]
|
||||
prefix = fusion["prefix"]
|
||||
|
||||
for key, val in target._bm_report().items():
|
||||
data[f"{prefix}{key}"] = val
|
||||
|
||||
return data
|
||||
|
||||
def _bm_toggle(self, value: bool) -> None:
|
||||
self._benchmark = value
|
||||
|
||||
for fusion in self._bm_fusions:
|
||||
fusion["target"]._bm_toggle(value)
|
||||
|
||||
@staticmethod
|
||||
def register(method: Callable, name=None) -> Callable:
|
||||
benchmark_name = method.__name__ if name is None else name
|
||||
|
||||
def wrapper(self, *args, **kwargs):
|
||||
assert isinstance(self, Benchmarkable)
|
||||
|
||||
self._bm(benchmark_name)
|
||||
result = method(self, *args, **kwargs)
|
||||
self._bm(benchmark_name)
|
||||
|
||||
return result
|
||||
|
||||
return wrapper
|
|
@ -1,32 +1,26 @@
|
|||
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from legged_gym.utils import class_to_dict
|
||||
|
||||
try:
|
||||
import neptune
|
||||
import neptune.new as neptune
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError("neptune-client is required to log to Neptune.")
|
||||
|
||||
|
||||
class NeptuneLogger:
|
||||
def __init__(self, project, token):
|
||||
self.run = neptune.init_run(project=project, api_token=token)
|
||||
self.run = neptune.init(project=project, api_token=token)
|
||||
|
||||
def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
|
||||
self.run["runner_cfg"] = runner_cfg
|
||||
self.run["policy_cfg"] = policy_cfg
|
||||
self.run["alg_cfg"] = alg_cfg
|
||||
self.run["env_cfg"] = asdict(env_cfg)
|
||||
self.run["env_cfg"] = class_to_dict(env_cfg)
|
||||
|
||||
|
||||
class NeptuneSummaryWriter(SummaryWriter):
|
||||
"""Summary writer for Neptune."""
|
||||
|
||||
def __init__(self, log_dir: str, flush_secs: int, cfg):
|
||||
super().__init__(log_dir, flush_secs)
|
||||
|
||||
|
@ -86,7 +80,3 @@ class NeptuneSummaryWriter(SummaryWriter):
|
|||
|
||||
def save_model(self, model_path, iter):
|
||||
self.neptune_logger.run["model/saved_model_" + str(iter)].upload(model_path)
|
||||
|
||||
def save_file(self, path, iter=None):
|
||||
name = path.rsplit("/", 1)[-1].split(".")[0]
|
||||
self.neptune_logger.run["git_diff/" + name].upload(path)
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def trajectories_to_transitions(trajectories: torch.Tensor, data: Tuple[torch.Tensor, int, bool]) -> torch.Tensor:
|
||||
"""Unpacks a tensor of trajectories into a tensor of transitions.
|
||||
|
||||
Args:
|
||||
trajectories (torch.Tensor): A tensor of trajectories.
|
||||
data (Tuple[torch.Tensor, int, bool]): A tuple containing the mask and data for the conversion.
|
||||
batch_first (bool, optional): Whether the first dimension of the trajectories tensor is the batch dimension.
|
||||
Defaults to False.
|
||||
Returns:
|
||||
A tensor of transitions of shape (batch_size, time, *).
|
||||
"""
|
||||
mask, batch_size, batch_first = data
|
||||
|
||||
if not batch_first:
|
||||
trajectories, mask = trajectories.transpose(0, 1), mask.transpose(0, 1)
|
||||
|
||||
transitions = trajectories[mask == 1.0].reshape(batch_size, -1, *trajectories.shape[2:])
|
||||
|
||||
return transitions
|
||||
|
||||
|
||||
def transitions_to_trajectories(
|
||||
transitions: torch.Tensor, dones: torch.Tensor, batch_first: bool = False
|
||||
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, int, bool]]:
|
||||
"""Packs a tensor of transitions into a tensor of trajectories.
|
||||
|
||||
Example:
|
||||
>>> transitions = torch.tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [11, 12]]])
|
||||
>>> dones = torch.tensor([[0, 0, 1], [0, 1, 0]])
|
||||
>>> transitions_to_trajectories(None, transitions, dones, batch_first=True)
|
||||
(tensor([[[1, 2], [3, 4], [5, 6]], [[7, 8], [9, 10], [0, 0]], [[11, 12], [0, 0], [0, 0]]]), tensor([[1, 1, 1], [1, 1, 0], [1, 0, 0]]))
|
||||
|
||||
Args:
|
||||
transitions (torch.Tensor): Tensor of transitions of shape (batch_size, time, *).
|
||||
dones (torch.Tensor): Tensor of transition terminations of shape (batch_size, time).
|
||||
batch_first (bool): Whether the first dimension of the output tensor should be the batch dimension. Defaults to
|
||||
False.
|
||||
Returns:
|
||||
A torch.Tensor of trajectories of shape (time, trajectory_count, *) that is padded with zeros and data for
|
||||
reverting the operation. If batch_first is True, the shape of the trajectories is (trajectory_count, time, *).
|
||||
"""
|
||||
batch_size = transitions.shape[0]
|
||||
|
||||
# Count the trajectory lengths by (1) padding dones with a 1 at the end to indicate the end of the trajectory,
|
||||
# (2) stacking up the padded dones in a single column, and (3) counting the number of steps between each done by
|
||||
# using the row index.
|
||||
padded_dones = dones.clone()
|
||||
padded_dones[:, -1] = 1
|
||||
stacked_dones = torch.cat((padded_dones.new([-1]), padded_dones.reshape(-1, 1).nonzero()[:, 0]))
|
||||
trajectory_lengths = stacked_dones[1:] - stacked_dones[:-1]
|
||||
|
||||
# Compute trajectories by splitting transitions according to previously computed trajectory lengths.
|
||||
trajectory_list = torch.split(transitions.flatten(0, 1), trajectory_lengths.int().tolist())
|
||||
trajectories = torch.nn.utils.rnn.pad_sequence(trajectory_list, batch_first=batch_first)
|
||||
|
||||
# The mask is generated by computing a 2d matrix of increasing counts in the 2nd dimension and comparing it to the
|
||||
# trajectory lengths.
|
||||
range = torch.arange(0, trajectory_lengths.max()).repeat(len(trajectory_lengths), 1)
|
||||
range = range.cuda(dones.device) if dones.is_cuda else range
|
||||
mask = (trajectory_lengths.unsqueeze(1) > range).float()
|
||||
|
||||
if not batch_first:
|
||||
mask = mask.T
|
||||
|
||||
return trajectories, (mask, batch_size, batch_first)
|
|
@ -0,0 +1,43 @@
|
|||
from typing import Any, Dict
|
||||
|
||||
|
||||
class Serializable:
|
||||
def load_state_dict(self, data: Dict[str, Any]) -> None:
|
||||
"""Loads agent parameters from a dictionary."""
|
||||
assert hasattr(self, "_serializable_objects")
|
||||
|
||||
for name in self._serializable_objects:
|
||||
assert hasattr(self, name)
|
||||
assert name in data, f'Object "{name}" was not found while loading "{self.__class__.__name__}".'
|
||||
|
||||
attr = getattr(self, name)
|
||||
if hasattr(attr, "load_state_dict"):
|
||||
print(f"Loading {name}")
|
||||
attr.load_state_dict(data[name])
|
||||
else:
|
||||
print(f"Loading value {name}={data[name]}")
|
||||
setattr(self, name, data[name])
|
||||
|
||||
def state_dict(self) -> Dict[str, Any]:
|
||||
"""Returns a dictionary containing the agent parameters."""
|
||||
assert hasattr(self, "_serializable_objects")
|
||||
|
||||
data = {}
|
||||
|
||||
for name in self._serializable_objects:
|
||||
assert hasattr(self, name)
|
||||
|
||||
attr = getattr(self, name)
|
||||
data[name] = attr.state_dict() if hasattr(attr, "state_dict") else attr
|
||||
|
||||
return data
|
||||
|
||||
def _register_serializable(self, *objects) -> None:
|
||||
if not hasattr(self, "_serializable_objects"):
|
||||
self._serializable_objects = []
|
||||
|
||||
for name in objects:
|
||||
if name in self._serializable_objects:
|
||||
continue
|
||||
|
||||
self._serializable_objects.append(name)
|
|
@ -1,16 +1,38 @@
|
|||
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
import git
|
||||
import os
|
||||
import numpy as np
|
||||
import pathlib
|
||||
import random
|
||||
import torch
|
||||
|
||||
|
||||
def environment_dimensions(env):
|
||||
obs = env.get_observations()
|
||||
|
||||
if isinstance(obs, tuple):
|
||||
obs, env_info = obs
|
||||
else:
|
||||
env_info = {}
|
||||
|
||||
dims = {}
|
||||
|
||||
dims["observations"] = obs.shape[1]
|
||||
|
||||
if "observations" in env_info and "critic" in env_info["observations"]:
|
||||
dims["actor_observations"] = dims["observations"]
|
||||
dims["critic_observations"] = env_info["observations"]["critic"].shape[1]
|
||||
else:
|
||||
dims["actor_observations"] = dims["observations"]
|
||||
dims["critic_observations"] = dims["observations"]
|
||||
|
||||
dims["actions"] = env.num_actions
|
||||
|
||||
return dims
|
||||
|
||||
|
||||
def split_and_pad_trajectories(tensor, dones):
|
||||
"""Splits trajectories at done indices. Then concatenates them and pads with zeros up to the length og the longest trajectory.
|
||||
"""Splits trajectories at done indices. Then concatenates them and padds with zeros up to the length og the longest trajectory.
|
||||
Returns masks corresponding to valid parts of the trajectories
|
||||
Example:
|
||||
Input: [ [a1, a2, a3, a4 | a5, a6],
|
||||
|
@ -24,7 +46,7 @@ def split_and_pad_trajectories(tensor, dones):
|
|||
[b6, 0, 0, 0] | [True, False, False, False],
|
||||
] | ]
|
||||
|
||||
Assumes that the inputy has the following dimension order: [time, number of envs, additional dimensions]
|
||||
Assumes that the inputy has the following dimension order: [time, number of envs, aditional dimensions]
|
||||
"""
|
||||
dones = dones.clone()
|
||||
dones[-1] = 1
|
||||
|
@ -37,12 +59,7 @@ def split_and_pad_trajectories(tensor, dones):
|
|||
trajectory_lengths_list = trajectory_lengths.tolist()
|
||||
# Extract the individual trajectories
|
||||
trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1), trajectory_lengths_list)
|
||||
# add at least one full length trajectory
|
||||
trajectories = trajectories + (torch.zeros(tensor.shape[0], tensor.shape[-1], device=tensor.device),)
|
||||
# pad the trajectories to the length of the longest trajectory
|
||||
padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories)
|
||||
# remove the added tensor
|
||||
padded_trajectories = padded_trajectories[:, :-1]
|
||||
|
||||
trajectory_masks = trajectory_lengths > torch.arange(0, tensor.shape[0], device=tensor.device).unsqueeze(1)
|
||||
return padded_trajectories, trajectory_masks
|
||||
|
@ -58,29 +75,34 @@ def unpad_trajectories(trajectories, masks):
|
|||
)
|
||||
|
||||
|
||||
def store_code_state(logdir, repositories) -> list:
|
||||
git_log_dir = os.path.join(logdir, "git")
|
||||
os.makedirs(git_log_dir, exist_ok=True)
|
||||
file_paths = []
|
||||
def store_code_state(logdir, repositories):
|
||||
for repository_file_path in repositories:
|
||||
try:
|
||||
repo = git.Repo(repository_file_path, search_parent_directories=True)
|
||||
except Exception:
|
||||
print(f"Could not find git repository in {repository_file_path}. Skipping.")
|
||||
# skip if not a git repository
|
||||
continue
|
||||
# get the name of the repository
|
||||
repo = git.Repo(repository_file_path, search_parent_directories=True)
|
||||
repo_name = pathlib.Path(repo.working_dir).name
|
||||
t = repo.head.commit.tree
|
||||
diff_file_name = os.path.join(git_log_dir, f"{repo_name}.diff")
|
||||
# check if the diff file already exists
|
||||
if os.path.isfile(diff_file_name):
|
||||
continue
|
||||
# write the diff file
|
||||
print(f"Storing git diff for '{repo_name}' in: {diff_file_name}")
|
||||
with open(diff_file_name, "x") as f:
|
||||
content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}"
|
||||
content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}"
|
||||
with open(os.path.join(logdir, f"{repo_name}_git.diff"), "x", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
# add the file path to the list of files to be uploaded
|
||||
file_paths.append(diff_file_name)
|
||||
return file_paths
|
||||
|
||||
|
||||
def seed(s=None):
|
||||
seed = int(datetime.now().timestamp() * 1e6) % 2**32 if s is None else s
|
||||
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def squeeze_preserve_batch(tensor):
|
||||
"""Squeezes a tensor, but preserves the batch dimension"""
|
||||
single_batch = tensor.shape[0] == 1
|
||||
|
||||
squeezed_tensor = tensor.squeeze()
|
||||
|
||||
if single_batch:
|
||||
squeezed_tensor = squeezed_tensor.unsqueeze(0)
|
||||
|
||||
return squeezed_tensor
|
||||
|
|
|
@ -1,11 +1,7 @@
|
|||
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import asdict
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from legged_gym.utils import class_to_dict
|
||||
|
||||
try:
|
||||
import wandb
|
||||
|
@ -14,8 +10,6 @@ except ModuleNotFoundError:
|
|||
|
||||
|
||||
class WandbSummaryWriter(SummaryWriter):
|
||||
"""Summary writer for Weights and Biases."""
|
||||
|
||||
def __init__(self, log_dir: str, flush_secs: int, cfg):
|
||||
super().__init__(log_dir, flush_secs)
|
||||
|
||||
|
@ -49,7 +43,7 @@ class WandbSummaryWriter(SummaryWriter):
|
|||
wandb.config.update({"runner_cfg": runner_cfg})
|
||||
wandb.config.update({"policy_cfg": policy_cfg})
|
||||
wandb.config.update({"alg_cfg": alg_cfg})
|
||||
wandb.config.update({"env_cfg": asdict(env_cfg)})
|
||||
wandb.config.update({"env_cfg": class_to_dict(env_cfg)})
|
||||
|
||||
def _map_path(self, path):
|
||||
if path in self.name_map:
|
||||
|
@ -74,7 +68,4 @@ class WandbSummaryWriter(SummaryWriter):
|
|||
self.store_config(env_cfg, runner_cfg, alg_cfg, policy_cfg)
|
||||
|
||||
def save_model(self, model_path, iter):
|
||||
wandb.save(model_path, base_path=os.path.dirname(model_path))
|
||||
|
||||
def save_file(self, path, iter=None):
|
||||
wandb.save(path, base_path=os.path.dirname(path))
|
||||
wandb.save(model_path)
|
||||
|
|
17
setup.py
17
setup.py
|
@ -1,24 +1,13 @@
|
|||
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
from setuptools import setup, find_packages
|
||||
|
||||
setup(
|
||||
name="rsl_rl",
|
||||
version="2.0.2",
|
||||
version="1.0.2",
|
||||
packages=find_packages(),
|
||||
author="ETH Zurich, NVIDIA CORPORATION",
|
||||
maintainer="Nikita Rudin, David Hoeller",
|
||||
maintainer_email="rudinn@ethz.ch",
|
||||
url="https://github.com/leggedrobotics/rsl_rl",
|
||||
license="BSD-3",
|
||||
description="Fast and simple RL algorithms implemented in pytorch",
|
||||
python_requires=">=3.6",
|
||||
install_requires=[
|
||||
"torch>=1.10.0",
|
||||
"torchvision>=0.5.0",
|
||||
"numpy>=1.16.4",
|
||||
"GitPython",
|
||||
"onnx",
|
||||
"GitPython", "gym", "numpy", "onnx", "tensorboard", "torch", "torchvision", "wandb",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,169 @@
|
|||
import unittest
|
||||
|
||||
from rsl_rl.algorithms import D4PG, DDPG, DPPO, DSAC, PPO, SAC, TD3
|
||||
from rsl_rl.env.gym_env import GymEnv
|
||||
from rsl_rl.modules import Network
|
||||
from rsl_rl.runners.runner import Runner
|
||||
|
||||
DEVICE = "cpu"
|
||||
|
||||
|
||||
class AlgorithmTestCaseMixin:
|
||||
algorithm_class = None
|
||||
|
||||
def _make_env(self, params={}):
|
||||
my_params = dict(name="LunarLanderContinuous-v2", device=DEVICE, environment_count=4)
|
||||
my_params.update(params)
|
||||
|
||||
return GymEnv(**my_params)
|
||||
|
||||
def _make_agent(self, env, agent_params={}):
|
||||
return self.algorithm_class(env, device=DEVICE, **agent_params)
|
||||
|
||||
def _make_runner(self, env, agent, runner_params={}):
|
||||
if not runner_params or "num_steps_per_env" not in runner_params:
|
||||
runner_params["num_steps_per_env"] = 6
|
||||
|
||||
return Runner(env, agent, device=DEVICE, **runner_params)
|
||||
|
||||
def _learn(self, env, agent, runner_params={}):
|
||||
runner = self._make_runner(env, agent, runner_params)
|
||||
runner.learn(10)
|
||||
|
||||
def test_default(self):
|
||||
env = self._make_env()
|
||||
agent = self._make_agent(env)
|
||||
|
||||
self._learn(env, agent)
|
||||
|
||||
def test_single_env_single_step(self):
|
||||
env = self._make_env(dict(environment_count=1))
|
||||
agent = self._make_agent(env)
|
||||
|
||||
self._learn(env, agent, dict(num_steps_per_env=1))
|
||||
|
||||
|
||||
class RecurrentAlgorithmTestCaseMixin(AlgorithmTestCaseMixin):
|
||||
def test_recurrent(self):
|
||||
env = self._make_env()
|
||||
agent = self._make_agent(env, dict(recurrent=True))
|
||||
|
||||
self._learn(env, agent)
|
||||
|
||||
def test_single_env_single_step_recurrent(self):
|
||||
env = self._make_env(dict(environment_count=1))
|
||||
agent = self._make_agent(env, dict(recurrent=True))
|
||||
|
||||
self._learn(env, agent, dict(num_steps_per_env=1))
|
||||
|
||||
|
||||
class D4PGTest(AlgorithmTestCaseMixin, unittest.TestCase):
|
||||
algorithm_class = D4PG
|
||||
|
||||
|
||||
class DDPGTest(AlgorithmTestCaseMixin, unittest.TestCase):
|
||||
algorithm_class = DDPG
|
||||
|
||||
|
||||
iqn_params = dict(
|
||||
critic_network=DPPO.network_iqn,
|
||||
iqn_action_samples=8,
|
||||
iqn_embedding_size=16,
|
||||
iqn_feature_layers=2,
|
||||
iqn_value_samples=4,
|
||||
value_loss=DPPO.value_loss_energy,
|
||||
)
|
||||
|
||||
qrdqn_params = dict(
|
||||
critic_network=DPPO.network_qrdqn,
|
||||
qrdqn_quantile_count=16,
|
||||
value_loss=DPPO.value_loss_l1,
|
||||
)
|
||||
|
||||
|
||||
class DPPOTest(RecurrentAlgorithmTestCaseMixin, unittest.TestCase):
|
||||
algorithm_class = DPPO
|
||||
|
||||
def test_qrdqn(self):
|
||||
env = self._make_env()
|
||||
agent = self._make_agent(env, qrdqn_params)
|
||||
|
||||
self._learn(env, agent)
|
||||
|
||||
def test_qrdqn_sing_env_single_step(self):
|
||||
env = self._make_env(dict(environment_count=1))
|
||||
agent = self._make_agent(env, qrdqn_params)
|
||||
|
||||
self._learn(env, agent, dict(num_steps_per_env=1))
|
||||
|
||||
def test_qrdqn_energy_loss(self):
|
||||
my_agent_params = qrdqn_params.copy()
|
||||
my_agent_params["value_loss"] = DPPO.value_loss_energy
|
||||
|
||||
env = self._make_env()
|
||||
agent = self._make_agent(env, my_agent_params)
|
||||
|
||||
self._learn(env, agent)
|
||||
|
||||
def test_qrdqn_huber_loss(self):
|
||||
my_agent_params = qrdqn_params.copy()
|
||||
my_agent_params["value_loss"] = DPPO.value_loss_huber
|
||||
|
||||
env = self._make_env()
|
||||
agent = self._make_agent(env, my_agent_params)
|
||||
|
||||
self._learn(env, agent)
|
||||
|
||||
def test_qrdqn_transformer(self):
|
||||
my_agent_params = qrdqn_params.copy()
|
||||
my_agent_params["recurrent"] = True
|
||||
my_agent_params["critic_recurrent_layers"] = 2
|
||||
my_agent_params["critic_recurrent_module"] = Network.recurrent_module_transformer
|
||||
my_agent_params["critic_recurrent_tf_context_length"] = 8
|
||||
my_agent_params["critic_recurrent_tf_head_count"] = 2
|
||||
|
||||
env = self._make_env()
|
||||
agent = self._make_agent(env, my_agent_params)
|
||||
|
||||
self._learn(env, agent)
|
||||
|
||||
def test_iqn(self):
|
||||
env = self._make_env()
|
||||
agent = self._make_agent(env, iqn_params)
|
||||
|
||||
self._learn(env, agent)
|
||||
|
||||
def test_iqn_single_step_single_env(self):
|
||||
env = self._make_env(dict(environment_count=1))
|
||||
agent = self._make_agent(env, iqn_params)
|
||||
|
||||
self._learn(env, agent, dict(num_steps_per_env=1))
|
||||
|
||||
def test_iqn_recurrent(self):
|
||||
my_agent_params = iqn_params.copy()
|
||||
my_agent_params["recurrent"] = True
|
||||
|
||||
env = self._make_env()
|
||||
agent = self._make_agent(env, my_agent_params)
|
||||
|
||||
self._learn(env, agent)
|
||||
|
||||
|
||||
class DSACTest(AlgorithmTestCaseMixin, unittest.TestCase):
|
||||
algorithm_class = DSAC
|
||||
|
||||
|
||||
class PPOTest(RecurrentAlgorithmTestCaseMixin, unittest.TestCase):
|
||||
algorithm_class = PPO
|
||||
|
||||
|
||||
class SACTest(AlgorithmTestCaseMixin, unittest.TestCase):
|
||||
algorithm_class = SAC
|
||||
|
||||
|
||||
class TD3Test(AlgorithmTestCaseMixin, unittest.TestCase):
|
||||
algorithm_class = TD3
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,126 @@
|
|||
import torch
|
||||
import unittest
|
||||
from rsl_rl.algorithms.dpg import AbstractDPG
|
||||
from rsl_rl.env.pole_balancing import PoleBalancing
|
||||
|
||||
|
||||
class DPG(AbstractDPG):
|
||||
def draw_actions(self, obs, env_info):
|
||||
pass
|
||||
|
||||
def eval_mode(self):
|
||||
pass
|
||||
|
||||
def get_inference_policy(self, device=None):
|
||||
pass
|
||||
|
||||
def process_transition(
|
||||
self, observations, environment_info, actions, rewards, next_observations, next_environment_info, dones, data
|
||||
):
|
||||
pass
|
||||
|
||||
def register_terminations(self, terminations):
|
||||
pass
|
||||
|
||||
def to(self, device):
|
||||
pass
|
||||
|
||||
def train_mode(self):
|
||||
pass
|
||||
|
||||
def update(self, storage):
|
||||
pass
|
||||
|
||||
|
||||
class FakeCritic(torch.nn.Module):
|
||||
def __init__(self, values):
|
||||
self.values = values
|
||||
|
||||
def forward(self, _):
|
||||
return self.values
|
||||
|
||||
|
||||
class DPGTest(unittest.TestCase):
|
||||
def test_timeout_bootstrapping(self):
|
||||
env = PoleBalancing(environment_count=4)
|
||||
dpg = DPG(env, device="cpu", return_steps=3)
|
||||
|
||||
rewards = torch.tensor(
|
||||
[
|
||||
[0.1000, 0.4000, 0.6000, 0.2000, -0.6000, -0.2000],
|
||||
[0.0000, 0.9000, 0.5000, -0.9000, -0.4000, 0.8000],
|
||||
[-0.5000, 0.4000, 0.0000, -0.2000, 0.3000, 0.1000],
|
||||
[-0.8000, 0.9000, -0.6000, 0.7000, 0.5000, 0.1000],
|
||||
]
|
||||
)
|
||||
dones = torch.tensor(
|
||||
[
|
||||
[0, 0, 1, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 1, 0, 0],
|
||||
]
|
||||
)
|
||||
timeouts = torch.tensor(
|
||||
[
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 1, 0, 0, 0, 0],
|
||||
]
|
||||
)
|
||||
actions = torch.zeros((4, 6, 1))
|
||||
observations = torch.zeros((4, 6, 2))
|
||||
values = torch.tensor([-0.1000, -0.8000, 0.4000, 0.7000])
|
||||
|
||||
dpg.critic = FakeCritic(values)
|
||||
dataset = [
|
||||
{
|
||||
"actions": actions[:, i],
|
||||
"critic_observations": observations[:, i],
|
||||
"dones": dones[:, i],
|
||||
"rewards": rewards[:, i],
|
||||
"timeouts": timeouts[:, i],
|
||||
}
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
processed_dataset = dpg._process_dataset(dataset)
|
||||
processed_rewards = torch.stack([processed_dataset[i]["rewards"] for i in range(1)], dim=-1)
|
||||
|
||||
expected_rewards = torch.tensor(
|
||||
[
|
||||
[1.08406],
|
||||
[1.38105],
|
||||
[-0.5],
|
||||
[0.77707],
|
||||
]
|
||||
)
|
||||
self.assertTrue(len(processed_dataset) == 1)
|
||||
self.assertTrue(torch.isclose(processed_rewards, expected_rewards).all())
|
||||
|
||||
dataset = [
|
||||
{
|
||||
"actions": actions[:, i + 3],
|
||||
"critic_observations": observations[:, i + 3],
|
||||
"dones": dones[:, i + 3],
|
||||
"rewards": rewards[:, i + 3],
|
||||
"timeouts": timeouts[:, i + 3],
|
||||
}
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
processed_dataset = dpg._process_dataset(dataset)
|
||||
processed_rewards = torch.stack([processed_dataset[i]["rewards"] for i in range(3)], dim=-1)
|
||||
|
||||
expected_rewards = torch.tensor(
|
||||
[
|
||||
[0.994, 0.6, -0.59002],
|
||||
[0.51291, -1.5592792, -2.08008],
|
||||
[0.20398, 0.09603, 0.19501],
|
||||
[1.593, 0.093, 0.7],
|
||||
]
|
||||
)
|
||||
|
||||
self.assertTrue(len(processed_dataset) == 3)
|
||||
self.assertTrue(torch.isclose(processed_rewards, expected_rewards).all())
|
|
@ -0,0 +1,278 @@
|
|||
import torch
|
||||
import unittest
|
||||
from rsl_rl.algorithms import DPPO
|
||||
from rsl_rl.env.pole_balancing import PoleBalancing
|
||||
|
||||
|
||||
class FakeCritic(torch.nn.Module):
|
||||
def __init__(self, values, quantile_count=1):
|
||||
self.quantile_count = quantile_count
|
||||
self.recurrent = False
|
||||
self.values = values
|
||||
self.last_quantiles = values
|
||||
|
||||
def forward(self, _, distribution=False, measure_args=None):
|
||||
if distribution:
|
||||
return self.values
|
||||
|
||||
return self.values.mean(-1)
|
||||
|
||||
def quantiles_to_values(self, quantiles):
|
||||
return quantiles.mean(-1)
|
||||
|
||||
|
||||
class DPPOTest(unittest.TestCase):
|
||||
def test_gae_computation(self):
|
||||
# GT taken from old PPO implementation.
|
||||
|
||||
env = PoleBalancing(environment_count=4)
|
||||
dppo = DPPO(env, device="cpu", gae_lambda=0.97, gamma=0.99, qrdqn_quantile_count=1)
|
||||
|
||||
rewards = torch.tensor(
|
||||
[
|
||||
[-1.0000e02, -1.4055e-01, -3.0476e-02, -2.7149e-01, -1.1157e-01, -2.3366e-01, -3.3658e-01, -1.6447e-01],
|
||||
[
|
||||
-1.7633e-01,
|
||||
-2.6533e-01,
|
||||
-3.0786e-01,
|
||||
-2.6038e-01,
|
||||
-2.7176e-01,
|
||||
-2.1655e-01,
|
||||
-1.5441e-01,
|
||||
-2.9580e-01,
|
||||
],
|
||||
[-1.5952e-01, -1.5177e-01, -1.4296e-01, -1.6131e-01, -3.1395e-02, 2.8808e-03, -3.1242e-02, 4.8696e-03],
|
||||
[1.1407e-02, -1.0000e02, -6.2290e-02, -3.7030e-01, -2.7648e-01, -3.6655e-01, -2.8456e-01, -2.3165e-01],
|
||||
]
|
||||
)
|
||||
dones = torch.tensor(
|
||||
[
|
||||
[1, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 1, 0, 0, 0, 0, 0, 0],
|
||||
]
|
||||
)
|
||||
observations = torch.zeros((dones.shape[0], dones.shape[1], 24))
|
||||
timeouts = torch.zeros((dones.shape[0], dones.shape[1]))
|
||||
values = torch.tensor(
|
||||
[
|
||||
[-4.6342, -7.6510, -7.0166, -7.6137, -7.4130, -7.7071, -7.7413, -7.8301],
|
||||
[-7.0442, -7.0032, -6.9321, -6.7765, -6.5433, -6.3503, -6.2529, -5.9337],
|
||||
[-7.5753, -7.8146, -7.6142, -7.8443, -7.8791, -7.7973, -7.7853, -7.7724],
|
||||
[-6.4326, -6.1673, -7.6511, -7.7505, -8.0004, -7.8584, -7.5949, -7.9023],
|
||||
]
|
||||
)
|
||||
value_quants = values.unsqueeze(-1)
|
||||
last_values = torch.tensor([-7.9343, -5.8734, -7.8527, -8.1257])
|
||||
|
||||
dppo.critic = FakeCritic(last_values.unsqueeze(-1))
|
||||
dataset = [
|
||||
{
|
||||
"dones": dones[:, i],
|
||||
"full_next_critic_observations": observations[:, i].clone(),
|
||||
"next_critic_observations": observations[:, i],
|
||||
"rewards": rewards[:, i],
|
||||
"timeouts": timeouts[:, i],
|
||||
"values": values[:, i],
|
||||
"value_quants": value_quants[:, i],
|
||||
}
|
||||
for i in range(dones.shape[1])
|
||||
]
|
||||
|
||||
processed_dataset = dppo._process_dataset(dataset)
|
||||
processed_returns = torch.stack(
|
||||
[processed_dataset[i]["advantages"] + processed_dataset[i]["values"] for i in range(dones.shape[1])],
|
||||
dim=-1,
|
||||
)
|
||||
processed_advantages = torch.stack(
|
||||
[processed_dataset[i]["normalized_advantages"] for i in range(dones.shape[1])], dim=-1
|
||||
)
|
||||
|
||||
expected_returns = torch.tensor(
|
||||
[
|
||||
[-100.0000, -8.4983, -8.4863, -8.5699, -8.4122, -8.4054, -8.2702, -8.0194],
|
||||
[-7.2900, -7.1912, -6.9978, -6.7569, -6.5627, -6.3547, -6.1985, -6.1104],
|
||||
[-7.9179, -7.8374, -7.7679, -7.6976, -7.6041, -7.6446, -7.7229, -7.7693],
|
||||
[-96.2018, -100.0000, -9.0710, -9.1415, -8.8863, -8.7228, -8.4668, -8.2761],
|
||||
]
|
||||
)
|
||||
expected_advantages = torch.tensor(
|
||||
[
|
||||
[-3.1452, 0.3006, 0.2779, 0.2966, 0.2951, 0.3060, 0.3122, 0.3246],
|
||||
[0.3225, 0.3246, 0.3291, 0.3322, 0.3308, 0.3313, 0.3335, 0.3250],
|
||||
[0.3190, 0.3307, 0.3259, 0.3368, 0.3415, 0.3371, 0.3338, 0.3316],
|
||||
[-2.9412, -3.0893, 0.2797, 0.2808, 0.2992, 0.3000, 0.2997, 0.3179],
|
||||
]
|
||||
)
|
||||
|
||||
self.assertTrue(torch.isclose(processed_returns, expected_returns, atol=1e-4).all())
|
||||
self.assertTrue(torch.isclose(processed_advantages, expected_advantages, atol=1e-4).all())
|
||||
|
||||
def test_target_computation_nstep(self):
|
||||
# GT taken from old PPO implementation.
|
||||
|
||||
env = PoleBalancing(environment_count=2)
|
||||
dppo = DPPO(env, device="cpu", gae_lambda=0.97, gamma=0.99, qrdqn_quantile_count=3, value_lambda=1.0)
|
||||
|
||||
rewards = torch.tensor(
|
||||
[
|
||||
[0.6, 1.0, 0.0, 0.6, 0.1, -0.2],
|
||||
[0.1, -0.2, -0.4, 0.1, 1.0, 1.0],
|
||||
]
|
||||
)
|
||||
dones = torch.tensor(
|
||||
[
|
||||
[1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 1.0, 0.0, 1.0, 0.0],
|
||||
]
|
||||
)
|
||||
observations = torch.zeros((dones.shape[0], dones.shape[1], 24))
|
||||
timeouts = torch.zeros((dones.shape[0], dones.shape[1]))
|
||||
value_quants = torch.tensor(
|
||||
[
|
||||
[
|
||||
[-0.4, 0.0, 0.1],
|
||||
[0.9, 0.8, 0.7],
|
||||
[0.7, 1.3, 0.0],
|
||||
[1.2, 0.4, 1.2],
|
||||
[1.3, 1.3, 1.1],
|
||||
[0.0, 0.7, 0.5],
|
||||
],
|
||||
[
|
||||
[1.3, 1.3, 0.9],
|
||||
[0.4, -0.4, -0.1],
|
||||
[0.4, 0.6, 0.1],
|
||||
[0.7, 0.1, 0.3],
|
||||
[0.2, 0.1, 0.3],
|
||||
[1.4, 1.4, -0.3],
|
||||
],
|
||||
]
|
||||
)
|
||||
values = value_quants.mean(dim=-1)
|
||||
last_values = torch.rand((dones.shape[0], 3))
|
||||
|
||||
dppo.critic = FakeCritic(last_values, quantile_count=3)
|
||||
dataset = [
|
||||
{
|
||||
"dones": dones[:, i],
|
||||
"full_next_critic_observations": observations[:, i].clone(),
|
||||
"next_critic_observations": observations[:, i],
|
||||
"rewards": rewards[:, i],
|
||||
"timeouts": timeouts[:, i],
|
||||
"values": values[:, i],
|
||||
"value_quants": value_quants[:, i],
|
||||
}
|
||||
for i in range(dones.shape[1])
|
||||
]
|
||||
|
||||
processed_dataset = dppo._process_dataset(dataset)
|
||||
processed_value_target_quants = torch.stack(
|
||||
[processed_dataset[i]["value_target_quants"] for i in range(dones.shape[1])],
|
||||
dim=-2,
|
||||
)
|
||||
|
||||
# N-step returns
|
||||
# These exclude the reward received on the final step since it should not be added to the value target.
|
||||
expected_returns = torch.tensor(
|
||||
[
|
||||
[0.6, 1.58806, 0.594, 0.6, 0.1, -0.2],
|
||||
[-0.098, -0.2, -0.4, 0.1, 1.0, 1.0],
|
||||
]
|
||||
)
|
||||
reset = lambda x: [0.0 for _ in x]
|
||||
dscnt = lambda x, s=0: [x[v] * dppo.gamma**s for v in range(len(x))]
|
||||
expected_value_target_quants = expected_returns.unsqueeze(-1) + torch.tensor(
|
||||
[
|
||||
[
|
||||
reset([-0.4, 0.0, 0.1]),
|
||||
dscnt([1.3, 1.3, 1.1], 3),
|
||||
dscnt([1.3, 1.3, 1.1], 2),
|
||||
dscnt([1.3, 1.3, 1.1], 1),
|
||||
reset([1.3, 1.3, 1.1]),
|
||||
dscnt(last_values[0], 1),
|
||||
],
|
||||
[
|
||||
dscnt([0.4, 0.6, 0.1], 2),
|
||||
dscnt([0.4, 0.6, 0.1], 1),
|
||||
reset([0.4, 0.6, 0.1]),
|
||||
dscnt([0.2, 0.1, 0.3], 1),
|
||||
reset([0.2, 0.1, 0.3]),
|
||||
dscnt(last_values[1], 1),
|
||||
],
|
||||
]
|
||||
)
|
||||
|
||||
self.assertTrue(torch.isclose(processed_value_target_quants, expected_value_target_quants, atol=1e-4).all())
|
||||
|
||||
def test_target_computation_1step(self):
|
||||
# GT taken from old PPO implementation.
|
||||
|
||||
env = PoleBalancing(environment_count=2)
|
||||
dppo = DPPO(env, device="cpu", gae_lambda=0.97, gamma=0.99, qrdqn_quantile_count=3, value_lambda=0.0)
|
||||
|
||||
rewards = torch.tensor(
|
||||
[
|
||||
[0.6, 1.0, 0.0, 0.6, 0.1, -0.2],
|
||||
[0.1, -0.2, -0.4, 0.1, 1.0, 1.0],
|
||||
]
|
||||
)
|
||||
dones = torch.tensor(
|
||||
[
|
||||
[1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 1.0, 0.0, 1.0, 0.0],
|
||||
]
|
||||
)
|
||||
observations = torch.zeros((dones.shape[0], dones.shape[1], 24))
|
||||
timeouts = torch.zeros((dones.shape[0], dones.shape[1]))
|
||||
value_quants = torch.tensor(
|
||||
[
|
||||
[
|
||||
[-0.4, 0.0, 0.1],
|
||||
[0.9, 0.8, 0.7],
|
||||
[0.7, 1.3, 0.0],
|
||||
[1.2, 0.4, 1.2],
|
||||
[1.3, 1.3, 1.1],
|
||||
[0.0, 0.7, 0.5],
|
||||
],
|
||||
[
|
||||
[1.3, 1.3, 0.9],
|
||||
[0.4, -0.4, -0.1],
|
||||
[0.4, 0.6, 0.1],
|
||||
[0.7, 0.1, 0.3],
|
||||
[0.2, 0.1, 0.3],
|
||||
[1.4, 1.4, -0.3],
|
||||
],
|
||||
]
|
||||
)
|
||||
values = value_quants.mean(dim=-1)
|
||||
last_values = torch.rand((dones.shape[0], 3))
|
||||
|
||||
dppo.critic = FakeCritic(last_values, quantile_count=3)
|
||||
dataset = [
|
||||
{
|
||||
"dones": dones[:, i],
|
||||
"full_next_critic_observations": observations[:, i].clone(),
|
||||
"next_critic_observations": observations[:, i],
|
||||
"rewards": rewards[:, i],
|
||||
"timeouts": timeouts[:, i],
|
||||
"values": values[:, i],
|
||||
"value_quants": value_quants[:, i],
|
||||
}
|
||||
for i in range(dones.shape[1])
|
||||
]
|
||||
|
||||
processed_dataset = dppo._process_dataset(dataset)
|
||||
processed_value_target_quants = torch.stack(
|
||||
[processed_dataset[i]["value_target_quants"] for i in range(dones.shape[1])],
|
||||
dim=-2,
|
||||
)
|
||||
|
||||
# 1-step returns
|
||||
expected_value_target_quants = rewards.unsqueeze(-1) + (
|
||||
(1.0 - dones).float().unsqueeze(-1)
|
||||
* dppo.gamma
|
||||
* torch.cat((value_quants[:, 1:], last_values.unsqueeze(1)), dim=1)
|
||||
)
|
||||
|
||||
self.assertTrue(torch.isclose(processed_value_target_quants, expected_value_target_quants, atol=1e-4).all())
|
|
@ -0,0 +1,171 @@
|
|||
import torch
|
||||
import unittest
|
||||
from rsl_rl.algorithms import DPPO
|
||||
from rsl_rl.env.vec_env import VecEnv
|
||||
|
||||
ACTION_SIZE = 3
|
||||
ENV_COUNT = 3
|
||||
OBS_SIZE = 24
|
||||
|
||||
|
||||
class FakeEnv(VecEnv):
|
||||
def __init__(self, rewards, dones, environment_count=1):
|
||||
super().__init__(OBS_SIZE, OBS_SIZE, environment_count=environment_count)
|
||||
|
||||
self.num_actions = ACTION_SIZE
|
||||
self.rewards = rewards
|
||||
self.dones = dones
|
||||
|
||||
self._step = 0
|
||||
|
||||
def get_observations(self):
|
||||
return torch.zeros((self.num_envs, self.num_obs)), {"observations": {}}
|
||||
|
||||
def get_privileged_observations(self):
|
||||
return torch.zeros((self.num_envs, self.num_privileged_obs)), {"observations": {}}
|
||||
|
||||
def step(self, actions):
|
||||
obs, _ = self.get_observations()
|
||||
rewards = self.rewards[self._step]
|
||||
dones = self.dones[self._step]
|
||||
|
||||
self._step += 1
|
||||
|
||||
return obs, rewards, dones, {"observations": {}}
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
|
||||
class FakeCritic(torch.nn.Module):
|
||||
def __init__(self, action_samples, value_samples, action_values, value_values, action_taus, value_taus):
|
||||
self.recurrent = False
|
||||
self.action_samples = action_samples
|
||||
self.value_samples = value_samples
|
||||
self.action_values = action_values
|
||||
self.value_values = value_values
|
||||
self.action_taus = action_taus
|
||||
self.value_taus = value_taus
|
||||
|
||||
self.last_quantiles = None
|
||||
self.last_taus = None
|
||||
|
||||
def forward(self, _, distribution=False, measure_args=None, sample_count=8, taus=None, use_measure=True):
|
||||
if taus is not None:
|
||||
sample_count = taus.shape[-1]
|
||||
|
||||
if sample_count == self.action_samples:
|
||||
self.last_taus = self.action_taus
|
||||
self.last_quantiles = self.action_values
|
||||
elif sample_count == self.value_samples:
|
||||
self.last_taus = self.value_taus
|
||||
self.last_quantiles = self.value_values
|
||||
else:
|
||||
raise ValueError(f"Invalid sample count: {sample_count}")
|
||||
|
||||
if distribution:
|
||||
return self.last_quantiles
|
||||
|
||||
return self.last_quantiles.mean(-1)
|
||||
|
||||
|
||||
def fake_process_quants(self, x):
|
||||
idx = torch.arange(0, x.shape[-1]).expand(*x.shape[:-1])
|
||||
|
||||
return x, idx
|
||||
|
||||
|
||||
class DPPOTest(unittest.TestCase):
|
||||
def test_value_target_computation(self):
|
||||
rewards = torch.tensor(
|
||||
[
|
||||
[-1.0000e02, -1.4055e-01, -3.0476e-02],
|
||||
[-1.7633e-01, -2.6533e-01, -3.0786e-01],
|
||||
[-1.5952e-01, -1.5177e-01, -1.4296e-01],
|
||||
[1.1407e-02, -1.0000e02, -6.2290e-02],
|
||||
]
|
||||
)
|
||||
dones = torch.tensor(
|
||||
[
|
||||
[1, 0, 0],
|
||||
[0, 0, 0],
|
||||
[0, 0, 0],
|
||||
[0, 1, 0],
|
||||
]
|
||||
)
|
||||
|
||||
env = FakeEnv(rewards, dones, environment_count=ENV_COUNT)
|
||||
dppo = DPPO(
|
||||
env,
|
||||
critic_network=DPPO.network_iqn,
|
||||
device="cpu",
|
||||
gae_lambda=0.97,
|
||||
gamma=0.99,
|
||||
iqn_action_samples=4,
|
||||
iqn_value_samples=2,
|
||||
value_lambda=1.0,
|
||||
value_loss=DPPO.value_loss_energy,
|
||||
)
|
||||
|
||||
# Generate fake dataset
|
||||
|
||||
action_taus = torch.tensor(
|
||||
[
|
||||
[[0.3, 0.5, 1.0, 0.2], [0.8, 0.9, 0.0, 0.9], [0.6, 0.1, 0.6, 0.5]],
|
||||
[[0.7, 0.9, 0.3, 0.0], [1.0, 0.7, 0.7, 0.7], [0.3, 0.8, 0.8, 0.1]],
|
||||
[[0.3, 0.8, 0.3, 0.2], [0.2, 0.9, 0.6, 0.4], [0.8, 0.4, 0.8, 1.0]],
|
||||
[[0.6, 0.6, 0.8, 0.8], [0.8, 0.0, 0.9, 0.1], [0.2, 0.3, 0.6, 0.2]],
|
||||
]
|
||||
)
|
||||
action_value_quants = torch.tensor(
|
||||
[
|
||||
[[0.2, 0.2, 0.6, 0.5], [0.5, 0.8, 0.1, 0.0], [1.0, 0.1, 0.8, 0.8]],
|
||||
[[0.0, 0.6, 0.1, 0.9], [0.2, 1.0, 0.9, 1.0], [0.4, 0.1, 0.1, 0.8]],
|
||||
[[0.7, 0.0, 0.6, 0.8], [0.7, 0.7, 0.7, 0.8], [0.0, 0.1, 0.5, 0.8]],
|
||||
[[0.5, 0.8, 0.1, 0.1], [0.9, 0.4, 0.7, 0.6], [0.6, 0.3, 0.1, 0.4]],
|
||||
]
|
||||
)
|
||||
value_taus = torch.tensor(
|
||||
[
|
||||
[[0.3, 0.5], [0.8, 0.9], [0.6, 0.1]],
|
||||
[[0.7, 0.9], [1.0, 0.7], [0.3, 0.8]],
|
||||
[[0.3, 0.8], [0.2, 0.9], [0.8, 0.4]],
|
||||
[[0.6, 0.6], [0.8, 0.0], [0.2, 0.3]],
|
||||
]
|
||||
)
|
||||
value_value_quants = torch.tensor(
|
||||
[
|
||||
[[0.9, 0.8], [0.1, 0.3], [0.3, 0.5]],
|
||||
[[0.2, 0.1], [0.9, 0.3], [0.4, 0.2]],
|
||||
[[0.7, 1.0], [0.6, 0.2], [0.2, 0.6]],
|
||||
[[0.4, 1.0], [0.3, 0.6], [0.3, 0.1]],
|
||||
]
|
||||
)
|
||||
|
||||
actions = torch.zeros(ENV_COUNT, ACTION_SIZE)
|
||||
env_info = {"observations": {}}
|
||||
obs = torch.zeros(ENV_COUNT, OBS_SIZE)
|
||||
dataset = []
|
||||
for i in range(4):
|
||||
dppo.critic = FakeCritic(4, 2, action_value_quants[i], value_value_quants[i], action_taus[i], value_taus[i])
|
||||
dppo.critic._process_quants = fake_process_quants
|
||||
|
||||
_, data = dppo.draw_actions(obs, {})
|
||||
_, rewards, dones, _ = env.step(actions)
|
||||
|
||||
dataset.append(
|
||||
dppo.process_transition(
|
||||
obs,
|
||||
env_info,
|
||||
actions,
|
||||
rewards,
|
||||
obs,
|
||||
env_info,
|
||||
dones,
|
||||
data,
|
||||
)
|
||||
)
|
||||
|
||||
processed_dataset = dppo._process_dataset(dataset)
|
||||
|
||||
# TODO: Test that the value targets are correct.
|
|
@ -0,0 +1,170 @@
|
|||
import torch
|
||||
import unittest
|
||||
from rsl_rl.algorithms import DPPO
|
||||
from rsl_rl.env.vec_env import VecEnv
|
||||
from rsl_rl.runners.runner import Runner
|
||||
from rsl_rl.utils.benchmarkable import Benchmarkable
|
||||
|
||||
|
||||
class FakeNetwork(torch.nn.Module, Benchmarkable):
|
||||
def __init__(self, values):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_state = None
|
||||
self.quantile_count = 1
|
||||
self.recurrent = True
|
||||
self.values = values
|
||||
|
||||
self._hidden_size = 2
|
||||
|
||||
def forward(self, x, hidden_state=None):
|
||||
if not hidden_state:
|
||||
self.hidden_state = (self.hidden_state[0] + 1, self.hidden_state[1] - 1)
|
||||
|
||||
values = self.values.repeat((*x.shape[:-1], 1)).squeeze(-1)
|
||||
values.requires_grad_(True)
|
||||
|
||||
return values
|
||||
|
||||
def reset_full_hidden_state(self, batch_size=None):
|
||||
assert batch_size is None or batch_size == 4, f"batch_size={batch_size}"
|
||||
|
||||
self.hidden_state = (torch.zeros((1, 4, self._hidden_size)), torch.zeros((1, 4, self._hidden_size)))
|
||||
|
||||
def reset_hidden_state(self, indices):
|
||||
self.hidden_state[0][:, indices] = torch.zeros((len(indices), self._hidden_size))
|
||||
self.hidden_state[1][:, indices] = torch.zeros((len(indices), self._hidden_size))
|
||||
|
||||
|
||||
class FakeActorNetwork(FakeNetwork):
|
||||
def forward(self, x, compute_std=False, hidden_state=None):
|
||||
values = super().forward(x, hidden_state=hidden_state)
|
||||
|
||||
if compute_std:
|
||||
return values, torch.ones_like(values)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class FakeCriticNetwork(FakeNetwork):
|
||||
_quantile_count = 1
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, x, distribution=False, hidden_state=None, measure_args=None):
|
||||
values = super().forward(x, hidden_state=hidden_state)
|
||||
self.last_quantiles = values.reshape(*values.shape, 1)
|
||||
|
||||
if distribution:
|
||||
return self.last_quantiles
|
||||
|
||||
return values
|
||||
|
||||
def quantile_l1_loss(self, *args, **kwargs):
|
||||
return torch.tensor(0.0)
|
||||
|
||||
def quantiles_to_values(self, quantiles):
|
||||
return quantiles.squeeze()
|
||||
|
||||
|
||||
class FakeEnv(VecEnv):
|
||||
def __init__(self, dones=None, **kwargs):
|
||||
super().__init__(3, 3, **kwargs)
|
||||
|
||||
self.num_actions = 3
|
||||
self._extra = {"observations": {}, "time_outs": torch.zeros((self.num_envs, 1))}
|
||||
|
||||
self._step = 0
|
||||
self._dones = dones
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_observations(self):
|
||||
return self._state_buf, self._extra
|
||||
|
||||
def get_privileged_observations(self):
|
||||
return self._state_buf, self._extra
|
||||
|
||||
def reset(self):
|
||||
self._state_buf = torch.zeros((self.num_envs, self.num_obs))
|
||||
|
||||
return self._state_buf, self._extra
|
||||
|
||||
def step(self, actions):
|
||||
assert actions.shape[0] == self.num_envs
|
||||
assert actions.shape[1] == self.num_actions
|
||||
|
||||
self._state_buf += actions
|
||||
|
||||
rewards = torch.zeros((self.num_envs))
|
||||
dones = torch.zeros((self.num_envs)) if self._dones is None else self._dones[self._step % self._dones.shape[0]]
|
||||
|
||||
self._step += 1
|
||||
|
||||
return self._state_buf, rewards, dones, self._extra
|
||||
|
||||
|
||||
class DPPORecurrencyTest(unittest.TestCase):
|
||||
def test_draw_action_produces_hidden_state(self):
|
||||
"""Test that the hidden state is correctly added to the data dictionary when drawing actions."""
|
||||
env = FakeEnv(environment_count=4)
|
||||
dppo = DPPO(env, device="cpu", recurrent=True)
|
||||
|
||||
dppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
|
||||
dppo.critic = FakeCriticNetwork(torch.zeros(1))
|
||||
|
||||
# Done during DPPO.__init__, however we need to reset the hidden state here again since we are using a fake
|
||||
# network that is added after initialization.
|
||||
dppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
|
||||
dppo.critic.reset_full_hidden_state(batch_size=env.num_envs)
|
||||
|
||||
ones = torch.ones((1, env.num_envs, dppo.actor._hidden_size))
|
||||
state, extra = env.reset()
|
||||
for ctr in range(10):
|
||||
_, data = dppo.draw_actions(state, extra)
|
||||
|
||||
# Actor state is changed every time an action is drawn.
|
||||
self.assertTrue(torch.allclose(data["actor_state_h"], ones * ctr))
|
||||
self.assertTrue(torch.allclose(data["actor_state_c"], -ones * ctr))
|
||||
# Critic state is only changed and saved when processing the transition (evaluating the action) so we can't
|
||||
# check it here.
|
||||
|
||||
def test_update_produces_hidden_state(self):
|
||||
"""Test that the hidden state is correctly added to the data dictionary when updating."""
|
||||
dones = torch.cat((torch.tensor([[0, 0, 0, 1]]), torch.zeros((4, 4)), torch.tensor([[1, 0, 0, 0]])), dim=0)
|
||||
|
||||
env = FakeEnv(dones=dones, environment_count=4)
|
||||
dppo = DPPO(env, device="cpu", recurrent=True)
|
||||
runner = Runner(env, dppo, num_steps_per_env=6)
|
||||
|
||||
dppo._value_loss = lambda *args, **kwargs: torch.tensor(0.0)
|
||||
|
||||
dppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
|
||||
dppo.critic = FakeCriticNetwork(torch.zeros(1))
|
||||
|
||||
dppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
|
||||
dppo.critic.reset_full_hidden_state(batch_size=env.num_envs)
|
||||
|
||||
runner.learn(1)
|
||||
|
||||
state_h_0 = torch.tensor([[0, 0], [0, 0], [0, 0], [0, 0]])
|
||||
state_h_1 = torch.tensor([[1, 1], [1, 1], [1, 1], [0, 0]])
|
||||
state_h_2 = state_h_1 + 1
|
||||
state_h_3 = state_h_2 + 1
|
||||
state_h_4 = state_h_3 + 1
|
||||
state_h_5 = state_h_4 + 1
|
||||
state_h_6 = torch.tensor([[0, 0], [6, 6], [6, 6], [5, 5]])
|
||||
state_h = (
|
||||
torch.cat((state_h_0, state_h_1, state_h_2, state_h_3, state_h_4, state_h_5), dim=0).float().unsqueeze(1)
|
||||
)
|
||||
next_state_h = (
|
||||
torch.cat((state_h_1, state_h_2, state_h_3, state_h_4, state_h_5, state_h_6), dim=0).float().unsqueeze(1)
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(dppo.storage._data["critic_state_h"], state_h))
|
||||
self.assertTrue(torch.allclose(dppo.storage._data["critic_state_c"], -state_h))
|
||||
self.assertTrue(torch.allclose(dppo.storage._data["critic_next_state_h"], next_state_h))
|
||||
self.assertTrue(torch.allclose(dppo.storage._data["critic_next_state_c"], -next_state_h))
|
||||
self.assertTrue(torch.allclose(dppo.storage._data["actor_state_h"], state_h))
|
||||
self.assertTrue(torch.allclose(dppo.storage._data["actor_state_c"], -state_h))
|
|
@ -0,0 +1,99 @@
|
|||
import torch
|
||||
import unittest
|
||||
from rsl_rl.algorithms import PPO
|
||||
from rsl_rl.env.pole_balancing import PoleBalancing
|
||||
|
||||
|
||||
class FakeCritic(torch.nn.Module):
|
||||
def __init__(self, values):
|
||||
self.recurrent = False
|
||||
self.values = values
|
||||
|
||||
def forward(self, _):
|
||||
return self.values
|
||||
|
||||
|
||||
class PPOTest(unittest.TestCase):
|
||||
def test_gae_computation(self):
|
||||
# GT taken from old PPO implementation.
|
||||
|
||||
env = PoleBalancing(environment_count=4)
|
||||
ppo = PPO(env, device="cpu", gae_lambda=0.97, gamma=0.99)
|
||||
|
||||
rewards = torch.tensor(
|
||||
[
|
||||
[-1.0000e02, -1.4055e-01, -3.0476e-02, -2.7149e-01, -1.1157e-01, -2.3366e-01, -3.3658e-01, -1.6447e-01],
|
||||
[
|
||||
-1.7633e-01,
|
||||
-2.6533e-01,
|
||||
-3.0786e-01,
|
||||
-2.6038e-01,
|
||||
-2.7176e-01,
|
||||
-2.1655e-01,
|
||||
-1.5441e-01,
|
||||
-2.9580e-01,
|
||||
],
|
||||
[-1.5952e-01, -1.5177e-01, -1.4296e-01, -1.6131e-01, -3.1395e-02, 2.8808e-03, -3.1242e-02, 4.8696e-03],
|
||||
[1.1407e-02, -1.0000e02, -6.2290e-02, -3.7030e-01, -2.7648e-01, -3.6655e-01, -2.8456e-01, -2.3165e-01],
|
||||
]
|
||||
)
|
||||
dones = torch.tensor(
|
||||
[
|
||||
[1, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0, 0, 0],
|
||||
[0, 1, 0, 0, 0, 0, 0, 0],
|
||||
]
|
||||
)
|
||||
observations = torch.zeros((dones.shape[0], dones.shape[1], 24))
|
||||
timeouts = torch.zeros((dones.shape[0], dones.shape[1]))
|
||||
values = torch.tensor(
|
||||
[
|
||||
[-4.6342, -7.6510, -7.0166, -7.6137, -7.4130, -7.7071, -7.7413, -7.8301],
|
||||
[-7.0442, -7.0032, -6.9321, -6.7765, -6.5433, -6.3503, -6.2529, -5.9337],
|
||||
[-7.5753, -7.8146, -7.6142, -7.8443, -7.8791, -7.7973, -7.7853, -7.7724],
|
||||
[-6.4326, -6.1673, -7.6511, -7.7505, -8.0004, -7.8584, -7.5949, -7.9023],
|
||||
]
|
||||
)
|
||||
last_values = torch.tensor([-7.9343, -5.8734, -7.8527, -8.1257])
|
||||
|
||||
ppo.critic = FakeCritic(last_values)
|
||||
dataset = [
|
||||
{
|
||||
"dones": dones[:, i],
|
||||
"next_critic_observations": observations[:, i],
|
||||
"rewards": rewards[:, i],
|
||||
"timeouts": timeouts[:, i],
|
||||
"values": values[:, i],
|
||||
}
|
||||
for i in range(dones.shape[1])
|
||||
]
|
||||
|
||||
processed_dataset = ppo._process_dataset(dataset)
|
||||
processed_returns = torch.stack(
|
||||
[processed_dataset[i]["advantages"] + processed_dataset[i]["values"] for i in range(dones.shape[1])],
|
||||
dim=-1,
|
||||
)
|
||||
processed_advantages = torch.stack(
|
||||
[processed_dataset[i]["normalized_advantages"] for i in range(dones.shape[1])], dim=-1
|
||||
)
|
||||
|
||||
expected_returns = torch.tensor(
|
||||
[
|
||||
[-100.0000, -8.4983, -8.4863, -8.5699, -8.4122, -8.4054, -8.2702, -8.0194],
|
||||
[-7.2900, -7.1912, -6.9978, -6.7569, -6.5627, -6.3547, -6.1985, -6.1104],
|
||||
[-7.9179, -7.8374, -7.7679, -7.6976, -7.6041, -7.6446, -7.7229, -7.7693],
|
||||
[-96.2018, -100.0000, -9.0710, -9.1415, -8.8863, -8.7228, -8.4668, -8.2761],
|
||||
]
|
||||
)
|
||||
expected_advantages = torch.tensor(
|
||||
[
|
||||
[-3.1452, 0.3006, 0.2779, 0.2966, 0.2951, 0.3060, 0.3122, 0.3246],
|
||||
[0.3225, 0.3246, 0.3291, 0.3322, 0.3308, 0.3313, 0.3335, 0.3250],
|
||||
[0.3190, 0.3307, 0.3259, 0.3368, 0.3415, 0.3371, 0.3338, 0.3316],
|
||||
[-2.9412, -3.0893, 0.2797, 0.2808, 0.2992, 0.3000, 0.2997, 0.3179],
|
||||
]
|
||||
)
|
||||
|
||||
self.assertTrue(torch.isclose(processed_returns, expected_returns, atol=1e-4).all())
|
||||
self.assertTrue(torch.isclose(processed_advantages, expected_advantages, atol=1e-4).all())
|
|
@ -0,0 +1,144 @@
|
|||
import torch
|
||||
import unittest
|
||||
from rsl_rl.algorithms import PPO
|
||||
from rsl_rl.env.vec_env import VecEnv
|
||||
from rsl_rl.runners.runner import Runner
|
||||
|
||||
|
||||
class FakeNetwork(torch.nn.Module):
|
||||
def __init__(self, values):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_state = None
|
||||
self.recurrent = True
|
||||
self.values = values
|
||||
|
||||
self._hidden_size = 2
|
||||
|
||||
def forward(self, x, hidden_state=None):
|
||||
if not hidden_state:
|
||||
self.hidden_state = (self.hidden_state[0] + 1, self.hidden_state[1] - 1)
|
||||
|
||||
values = self.values.repeat((*x.shape[:-1], 1)).squeeze(-1)
|
||||
values.requires_grad_(True)
|
||||
|
||||
return values
|
||||
|
||||
def reset_full_hidden_state(self, batch_size=None):
|
||||
assert batch_size is None or batch_size == 4, f"batch_size={batch_size}"
|
||||
|
||||
self.hidden_state = (torch.zeros((1, 4, self._hidden_size)), torch.zeros((1, 4, self._hidden_size)))
|
||||
|
||||
def reset_hidden_state(self, indices):
|
||||
self.hidden_state[0][:, indices] = torch.zeros((len(indices), self._hidden_size))
|
||||
self.hidden_state[1][:, indices] = torch.zeros((len(indices), self._hidden_size))
|
||||
|
||||
|
||||
class FakeActorNetwork(FakeNetwork):
|
||||
def forward(self, x, compute_std=False, hidden_state=None):
|
||||
values = super().forward(x, hidden_state=hidden_state)
|
||||
|
||||
if compute_std:
|
||||
return values, torch.ones_like(values)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class FakeEnv(VecEnv):
|
||||
def __init__(self, dones=None, **kwargs):
|
||||
super().__init__(3, 3, **kwargs)
|
||||
|
||||
self.num_actions = 3
|
||||
self._extra = {"observations": {}, "time_outs": torch.zeros((self.num_envs, 1))}
|
||||
|
||||
self._step = 0
|
||||
self._dones = dones
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_observations(self):
|
||||
return self._state_buf, self._extra
|
||||
|
||||
def get_privileged_observations(self):
|
||||
return self._state_buf, self._extra
|
||||
|
||||
def reset(self):
|
||||
self._state_buf = torch.zeros((self.num_envs, self.num_obs))
|
||||
|
||||
return self._state_buf, self._extra
|
||||
|
||||
def step(self, actions):
|
||||
assert actions.shape[0] == self.num_envs
|
||||
assert actions.shape[1] == self.num_actions
|
||||
|
||||
self._state_buf += actions
|
||||
|
||||
rewards = torch.zeros((self.num_envs))
|
||||
dones = torch.zeros((self.num_envs)) if self._dones is None else self._dones[self._step % self._dones.shape[0]]
|
||||
|
||||
self._step += 1
|
||||
|
||||
return self._state_buf, rewards, dones, self._extra
|
||||
|
||||
|
||||
class PPORecurrencyTest(unittest.TestCase):
|
||||
def test_draw_action_produces_hidden_state(self):
|
||||
"""Test that the hidden state is correctly added to the data dictionary when drawing actions."""
|
||||
env = FakeEnv(environment_count=4)
|
||||
ppo = PPO(env, device="cpu", recurrent=True)
|
||||
|
||||
ppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
|
||||
ppo.critic = FakeNetwork(torch.zeros(1))
|
||||
|
||||
# Done during PPO.__init__, however we need to reset the hidden state here again since we are using a fake
|
||||
# network that is added after initialization.
|
||||
ppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
|
||||
ppo.critic.reset_full_hidden_state(batch_size=env.num_envs)
|
||||
|
||||
ones = torch.ones((1, env.num_envs, ppo.actor._hidden_size))
|
||||
state, extra = env.reset()
|
||||
for ctr in range(10):
|
||||
_, data = ppo.draw_actions(state, extra)
|
||||
|
||||
# Actor state is changed every time an action is drawn.
|
||||
self.assertTrue(torch.allclose(data["actor_state_h"], ones * ctr))
|
||||
self.assertTrue(torch.allclose(data["actor_state_c"], -ones * ctr))
|
||||
# Critic state is only changed and saved when processing the transition (evaluating the action) so we can't
|
||||
# check it here.
|
||||
|
||||
def test_update_produces_hidden_state(self):
|
||||
"""Test that the hidden state is correctly added to the data dictionary when updating."""
|
||||
dones = torch.cat((torch.tensor([[0, 0, 0, 1]]), torch.zeros((4, 4)), torch.tensor([[1, 0, 0, 0]])), dim=0)
|
||||
|
||||
env = FakeEnv(dones=dones, environment_count=4)
|
||||
ppo = PPO(env, device="cpu", recurrent=True)
|
||||
runner = Runner(env, ppo, num_steps_per_env=6)
|
||||
|
||||
ppo.actor = FakeActorNetwork(torch.ones(env.num_actions))
|
||||
ppo.critic = FakeNetwork(torch.zeros(1))
|
||||
|
||||
ppo.actor.reset_full_hidden_state(batch_size=env.num_envs)
|
||||
ppo.critic.reset_full_hidden_state(batch_size=env.num_envs)
|
||||
|
||||
runner.learn(1)
|
||||
|
||||
state_h_0 = torch.tensor([[0, 0], [0, 0], [0, 0], [0, 0]])
|
||||
state_h_1 = torch.tensor([[1, 1], [1, 1], [1, 1], [0, 0]])
|
||||
state_h_2 = state_h_1 + 1
|
||||
state_h_3 = state_h_2 + 1
|
||||
state_h_4 = state_h_3 + 1
|
||||
state_h_5 = state_h_4 + 1
|
||||
state_h_6 = torch.tensor([[0, 0], [6, 6], [6, 6], [5, 5]])
|
||||
state_h = (
|
||||
torch.cat((state_h_0, state_h_1, state_h_2, state_h_3, state_h_4, state_h_5), dim=0).float().unsqueeze(1)
|
||||
)
|
||||
next_state_h = (
|
||||
torch.cat((state_h_1, state_h_2, state_h_3, state_h_4, state_h_5, state_h_6), dim=0).float().unsqueeze(1)
|
||||
)
|
||||
|
||||
self.assertTrue(torch.allclose(ppo.storage._data["critic_state_h"], state_h))
|
||||
self.assertTrue(torch.allclose(ppo.storage._data["critic_state_c"], -state_h))
|
||||
self.assertTrue(torch.allclose(ppo.storage._data["critic_next_state_h"], next_state_h))
|
||||
self.assertTrue(torch.allclose(ppo.storage._data["critic_next_state_c"], -next_state_h))
|
||||
self.assertTrue(torch.allclose(ppo.storage._data["actor_state_h"], state_h))
|
||||
self.assertTrue(torch.allclose(ppo.storage._data["actor_state_c"], -state_h))
|
|
@ -0,0 +1,286 @@
|
|||
import torch
|
||||
import unittest
|
||||
from rsl_rl.modules.quantile_network import QuantileNetwork
|
||||
|
||||
|
||||
class QuantileNetworkTest(unittest.TestCase):
|
||||
def test_l1_loss(self):
|
||||
qn = QuantileNetwork(10, 1, quantile_count=5)
|
||||
|
||||
prediction = torch.tensor(
|
||||
[
|
||||
[0.8510, 0.2329, 0.4244, 0.5241, 0.2144],
|
||||
[0.7693, 0.2522, 0.3909, 0.0858, 0.7914],
|
||||
[0.8701, 0.2144, 0.9661, 0.9975, 0.5043],
|
||||
[0.2653, 0.6951, 0.9787, 0.2244, 0.0430],
|
||||
[0.7907, 0.5209, 0.7276, 0.1735, 0.2757],
|
||||
[0.1696, 0.7167, 0.6363, 0.2188, 0.7025],
|
||||
[0.0445, 0.6008, 0.5334, 0.1838, 0.7387],
|
||||
[0.4934, 0.5117, 0.4488, 0.0591, 0.6442],
|
||||
]
|
||||
)
|
||||
target = torch.tensor(
|
||||
[
|
||||
[0.3918, 0.8979, 0.4347, 0.1076, 0.5303],
|
||||
[0.5449, 0.9974, 0.3197, 0.8686, 0.0631],
|
||||
[0.7397, 0.7734, 0.6559, 0.3020, 0.7229],
|
||||
[0.9519, 0.8138, 0.1502, 0.3445, 0.3356],
|
||||
[0.8970, 0.0910, 0.7536, 0.6069, 0.2556],
|
||||
[0.1741, 0.6863, 0.7142, 0.2911, 0.3142],
|
||||
[0.8835, 0.0215, 0.4774, 0.5362, 0.4998],
|
||||
[0.8037, 0.8269, 0.5518, 0.4368, 0.5323],
|
||||
]
|
||||
)
|
||||
|
||||
loss = qn.quantile_l1_loss(prediction, target)
|
||||
|
||||
self.assertAlmostEqual(loss.item(), 0.16419549)
|
||||
|
||||
def test_l1_loss_3d(self):
|
||||
qn = QuantileNetwork(10, 1, quantile_count=5)
|
||||
|
||||
prediction = torch.tensor(
|
||||
[
|
||||
[
|
||||
[0.8510, 0.2329, 0.4244, 0.5241, 0.2144],
|
||||
[0.7693, 0.2522, 0.3909, 0.0858, 0.7914],
|
||||
[0.8701, 0.2144, 0.9661, 0.9975, 0.5043],
|
||||
[0.2653, 0.6951, 0.9787, 0.2244, 0.0430],
|
||||
[0.7907, 0.5209, 0.7276, 0.1735, 0.2757],
|
||||
[0.1696, 0.7167, 0.6363, 0.2188, 0.7025],
|
||||
[0.0445, 0.6008, 0.5334, 0.1838, 0.7387],
|
||||
[0.4934, 0.5117, 0.4488, 0.0591, 0.6442],
|
||||
],
|
||||
[
|
||||
[0.6874, 0.6214, 0.7796, 0.8148, 0.2070],
|
||||
[0.0276, 0.5764, 0.5516, 0.9682, 0.6901],
|
||||
[0.4020, 0.7084, 0.9965, 0.4311, 0.3789],
|
||||
[0.5350, 0.9431, 0.1032, 0.6959, 0.4992],
|
||||
[0.5059, 0.5479, 0.2302, 0.6753, 0.1593],
|
||||
[0.6753, 0.4590, 0.9956, 0.6117, 0.1410],
|
||||
[0.7464, 0.7184, 0.2972, 0.7694, 0.7999],
|
||||
[0.3907, 0.2112, 0.6485, 0.0139, 0.6252],
|
||||
],
|
||||
]
|
||||
)
|
||||
target = torch.tensor(
|
||||
[
|
||||
[
|
||||
[0.3918, 0.8979, 0.4347, 0.1076, 0.5303],
|
||||
[0.5449, 0.9974, 0.3197, 0.8686, 0.0631],
|
||||
[0.7397, 0.7734, 0.6559, 0.3020, 0.7229],
|
||||
[0.9519, 0.8138, 0.1502, 0.3445, 0.3356],
|
||||
[0.8970, 0.0910, 0.7536, 0.6069, 0.2556],
|
||||
[0.1741, 0.6863, 0.7142, 0.2911, 0.3142],
|
||||
[0.8835, 0.0215, 0.4774, 0.5362, 0.4998],
|
||||
[0.8037, 0.8269, 0.5518, 0.4368, 0.5323],
|
||||
],
|
||||
[
|
||||
[0.5120, 0.7683, 0.3579, 0.8640, 0.4374],
|
||||
[0.2533, 0.3039, 0.2214, 0.7069, 0.3093],
|
||||
[0.6993, 0.4288, 0.0827, 0.9156, 0.2043],
|
||||
[0.6739, 0.2303, 0.3263, 0.6884, 0.3847],
|
||||
[0.3990, 0.1458, 0.8918, 0.8036, 0.5012],
|
||||
[0.9061, 0.2024, 0.7276, 0.8619, 0.1198],
|
||||
[0.7379, 0.2005, 0.7634, 0.5691, 0.6132],
|
||||
[0.4341, 0.5711, 0.1119, 0.4286, 0.7521],
|
||||
],
|
||||
]
|
||||
)
|
||||
|
||||
loss = qn.quantile_l1_loss(prediction, target)
|
||||
|
||||
self.assertAlmostEqual(loss.item(), 0.15836075)
|
||||
|
||||
def test_l1_loss_multi_output(self):
|
||||
qn = QuantileNetwork(10, 3, quantile_count=10)
|
||||
|
||||
prediction = torch.tensor(
|
||||
[
|
||||
[0.3003, 0.8692, 0.4608, 0.7158, 0.2640, 0.3928, 0.4557, 0.4620, 0.1331, 0.6356],
|
||||
[0.8867, 0.1521, 0.5827, 0.0501, 0.4401, 0.7216, 0.6081, 0.5758, 0.2772, 0.6048],
|
||||
[0.0763, 0.1609, 0.1860, 0.9173, 0.2121, 0.1920, 0.8509, 0.8588, 0.3321, 0.7202],
|
||||
[0.8375, 0.5339, 0.4287, 0.9228, 0.8519, 0.0420, 0.5736, 0.9156, 0.4444, 0.2039],
|
||||
[0.0704, 0.1833, 0.0839, 0.9573, 0.9852, 0.4191, 0.3562, 0.7225, 0.8481, 0.2096],
|
||||
[0.4054, 0.8172, 0.8737, 0.2138, 0.4455, 0.7538, 0.1936, 0.9346, 0.8710, 0.0178],
|
||||
[0.2139, 0.6619, 0.6889, 0.5726, 0.0595, 0.3278, 0.7673, 0.0803, 0.0374, 0.9011],
|
||||
[0.2757, 0.0309, 0.8913, 0.0958, 0.1828, 0.9624, 0.6529, 0.7451, 0.9996, 0.8877],
|
||||
[0.0722, 0.4240, 0.0716, 0.3199, 0.5570, 0.1056, 0.5950, 0.9926, 0.2991, 0.7334],
|
||||
[0.0576, 0.6353, 0.5078, 0.4456, 0.9119, 0.6897, 0.1720, 0.5172, 0.9939, 0.5044],
|
||||
[0.6300, 0.2304, 0.4064, 0.9195, 0.3299, 0.8631, 0.5842, 0.6751, 0.2964, 0.1215],
|
||||
[0.7418, 0.5448, 0.7615, 0.6333, 0.9255, 0.1129, 0.0552, 0.4198, 0.9953, 0.7482],
|
||||
[0.9910, 0.7644, 0.7047, 0.1395, 0.3688, 0.7688, 0.8574, 0.3494, 0.6153, 0.1286],
|
||||
[0.2325, 0.7908, 0.3036, 0.4504, 0.3775, 0.6004, 0.0199, 0.9581, 0.8078, 0.8337],
|
||||
[0.4038, 0.8313, 0.5441, 0.4778, 0.5777, 0.0580, 0.5314, 0.5336, 0.0740, 0.0094],
|
||||
[0.9025, 0.5814, 0.4711, 0.2683, 0.4443, 0.5799, 0.6703, 0.2678, 0.7538, 0.1317],
|
||||
[0.6755, 0.5696, 0.3334, 0.9146, 0.6203, 0.2080, 0.0799, 0.0059, 0.8347, 0.1874],
|
||||
[0.0932, 0.0264, 0.9006, 0.3124, 0.3421, 0.8271, 0.3495, 0.2814, 0.9888, 0.5042],
|
||||
[0.4893, 0.3514, 0.2564, 0.8117, 0.3738, 0.9085, 0.3055, 0.1456, 0.3624, 0.4095],
|
||||
[0.0726, 0.2145, 0.6295, 0.7423, 0.1292, 0.7570, 0.4645, 0.0775, 0.1280, 0.7312],
|
||||
[0.8763, 0.5302, 0.8627, 0.0429, 0.2833, 0.4745, 0.6308, 0.2245, 0.2755, 0.6823],
|
||||
[0.9997, 0.3519, 0.0312, 0.1468, 0.5145, 0.0286, 0.6333, 0.1323, 0.2264, 0.9109],
|
||||
[0.7742, 0.4857, 0.0413, 0.4523, 0.6847, 0.5774, 0.9478, 0.5861, 0.9834, 0.9437],
|
||||
[0.7590, 0.5697, 0.7509, 0.3562, 0.9926, 0.3380, 0.0337, 0.7871, 0.1351, 0.9184],
|
||||
[0.5701, 0.0234, 0.8088, 0.0681, 0.7090, 0.5925, 0.5266, 0.7198, 0.4121, 0.0268],
|
||||
[0.5377, 0.1420, 0.2649, 0.0885, 0.1987, 0.1475, 0.1562, 0.2283, 0.9447, 0.4679],
|
||||
[0.0306, 0.9763, 0.1234, 0.5009, 0.8800, 0.9409, 0.3525, 0.7264, 0.2209, 0.1436],
|
||||
[0.2492, 0.4041, 0.9044, 0.3730, 0.3152, 0.7515, 0.2614, 0.9726, 0.6402, 0.5211],
|
||||
[0.8626, 0.2828, 0.6946, 0.7066, 0.4395, 0.3015, 0.2643, 0.4421, 0.6036, 0.9009],
|
||||
[0.7721, 0.1706, 0.7043, 0.4097, 0.7685, 0.3818, 0.1468, 0.6452, 0.1102, 0.1826],
|
||||
[0.7156, 0.1795, 0.5574, 0.9478, 0.0058, 0.8037, 0.8712, 0.7730, 0.5638, 0.5843],
|
||||
[0.8775, 0.6133, 0.4118, 0.3038, 0.2612, 0.2424, 0.8960, 0.8194, 0.3588, 0.3198],
|
||||
]
|
||||
)
|
||||
|
||||
target = torch.tensor(
|
||||
[
|
||||
[0.0986, 0.4029, 0.3110, 0.9976, 0.5668, 0.2658, 0.0660, 0.8492, 0.7872, 0.6368],
|
||||
[0.3556, 0.9007, 0.0227, 0.7684, 0.0105, 0.9890, 0.7468, 0.0642, 0.5164, 0.1976],
|
||||
[0.1331, 0.0998, 0.0959, 0.5596, 0.5984, 0.3880, 0.8050, 0.8320, 0.8977, 0.3486],
|
||||
[0.3297, 0.8110, 0.2844, 0.4594, 0.0739, 0.2865, 0.2957, 0.9357, 0.9898, 0.4419],
|
||||
[0.0495, 0.2826, 0.8306, 0.2968, 0.5690, 0.7251, 0.5947, 0.7526, 0.5076, 0.6480],
|
||||
[0.0381, 0.8645, 0.7774, 0.9158, 0.9682, 0.5851, 0.0913, 0.8948, 0.1251, 0.1205],
|
||||
[0.9059, 0.2758, 0.1948, 0.2694, 0.0946, 0.4381, 0.4667, 0.2176, 0.3494, 0.6073],
|
||||
[0.1778, 0.8632, 0.3015, 0.2882, 0.4214, 0.2420, 0.8394, 0.1468, 0.9679, 0.6730],
|
||||
[0.2400, 0.4344, 0.9765, 0.6544, 0.6338, 0.3434, 0.4776, 0.7981, 0.2008, 0.2267],
|
||||
[0.5574, 0.8110, 0.0264, 0.4199, 0.8178, 0.8421, 0.8237, 0.2623, 0.8025, 0.9030],
|
||||
[0.8652, 0.2872, 0.9463, 0.5543, 0.4866, 0.2842, 0.6692, 0.2306, 0.3136, 0.4570],
|
||||
[0.0651, 0.8955, 0.7531, 0.9373, 0.0265, 0.0795, 0.7755, 0.1123, 0.1920, 0.3273],
|
||||
[0.9824, 0.4177, 0.2729, 0.9447, 0.3987, 0.5495, 0.3674, 0.8067, 0.8668, 0.2394],
|
||||
[0.4874, 0.3616, 0.7577, 0.6439, 0.2927, 0.8110, 0.6821, 0.0702, 0.5514, 0.7358],
|
||||
[0.3627, 0.6392, 0.9085, 0.3646, 0.6051, 0.0586, 0.8763, 0.3899, 0.3242, 0.4598],
|
||||
[0.0167, 0.0558, 0.3862, 0.7017, 0.0403, 0.6604, 0.9992, 0.2337, 0.5128, 0.1959],
|
||||
[0.7774, 0.9201, 0.0405, 0.7894, 0.1406, 0.2458, 0.2616, 0.8787, 0.8158, 0.8591],
|
||||
[0.3225, 0.9827, 0.4032, 0.2621, 0.7949, 0.9796, 0.9480, 0.3353, 0.1430, 0.5747],
|
||||
[0.4734, 0.8714, 0.9320, 0.4265, 0.7765, 0.6980, 0.1587, 0.8784, 0.7119, 0.5141],
|
||||
[0.7263, 0.4754, 0.8234, 0.0649, 0.4343, 0.5201, 0.8274, 0.9632, 0.3525, 0.8893],
|
||||
[0.3324, 0.0142, 0.7222, 0.5026, 0.6011, 0.9275, 0.9351, 0.9236, 0.2621, 0.0768],
|
||||
[0.8456, 0.1005, 0.5550, 0.0586, 0.3811, 0.0168, 0.9724, 0.9225, 0.7242, 0.0678],
|
||||
[0.2167, 0.5423, 0.9059, 0.3320, 0.4026, 0.2128, 0.4562, 0.3564, 0.2573, 0.1076],
|
||||
[0.8385, 0.2233, 0.0736, 0.3407, 0.4702, 0.1668, 0.5174, 0.4154, 0.4407, 0.1843],
|
||||
[0.1828, 0.5321, 0.6651, 0.4108, 0.5736, 0.4012, 0.0434, 0.0034, 0.9282, 0.3111],
|
||||
[0.1754, 0.8750, 0.6629, 0.7052, 0.9739, 0.7441, 0.8954, 0.9273, 0.3836, 0.5735],
|
||||
[0.5586, 0.0381, 0.1493, 0.8575, 0.9351, 0.5222, 0.5600, 0.2369, 0.9217, 0.2545],
|
||||
[0.1054, 0.8020, 0.8463, 0.6495, 0.3011, 0.3734, 0.7263, 0.8736, 0.9258, 0.5804],
|
||||
[0.7614, 0.4748, 0.6588, 0.7717, 0.9811, 0.1659, 0.7851, 0.2135, 0.1767, 0.6724],
|
||||
[0.7655, 0.8571, 0.4224, 0.9397, 0.1363, 0.9431, 0.9326, 0.3762, 0.1077, 0.9514],
|
||||
[0.4115, 0.2169, 0.1340, 0.6564, 0.9989, 0.8068, 0.0387, 0.5064, 0.9964, 0.9427],
|
||||
[0.5760, 0.2967, 0.3891, 0.6596, 0.8037, 0.1060, 0.0102, 0.8672, 0.5922, 0.6684],
|
||||
]
|
||||
)
|
||||
|
||||
loss = qn.quantile_l1_loss(prediction, target)
|
||||
|
||||
self.assertAlmostEqual(loss.item(), 0.17235948)
|
||||
|
||||
def test_quantile_huber_loss(self):
|
||||
qn = QuantileNetwork(10, 1, quantile_count=5)
|
||||
|
||||
prediction = torch.tensor(
|
||||
[
|
||||
[0.8510, 0.2329, 0.4244, 0.5241, 0.2144],
|
||||
[0.7693, 0.2522, 0.3909, 0.0858, 0.7914],
|
||||
[0.8701, 0.2144, 0.9661, 0.9975, 0.5043],
|
||||
[0.2653, 0.6951, 0.9787, 0.2244, 0.0430],
|
||||
[0.7907, 0.5209, 0.7276, 0.1735, 0.2757],
|
||||
[0.1696, 0.7167, 0.6363, 0.2188, 0.7025],
|
||||
[0.0445, 0.6008, 0.5334, 0.1838, 0.7387],
|
||||
[0.4934, 0.5117, 0.4488, 0.0591, 0.6442],
|
||||
]
|
||||
)
|
||||
target = torch.tensor(
|
||||
[
|
||||
[0.3918, 0.8979, 0.4347, 0.1076, 0.5303],
|
||||
[0.5449, 0.9974, 0.3197, 0.8686, 0.0631],
|
||||
[0.7397, 0.7734, 0.6559, 0.3020, 0.7229],
|
||||
[0.9519, 0.8138, 0.1502, 0.3445, 0.3356],
|
||||
[0.8970, 0.0910, 0.7536, 0.6069, 0.2556],
|
||||
[0.1741, 0.6863, 0.7142, 0.2911, 0.3142],
|
||||
[0.8835, 0.0215, 0.4774, 0.5362, 0.4998],
|
||||
[0.8037, 0.8269, 0.5518, 0.4368, 0.5323],
|
||||
]
|
||||
)
|
||||
|
||||
loss = qn.quantile_huber_loss(prediction, target)
|
||||
|
||||
self.assertAlmostEqual(loss.item(), 0.04035041)
|
||||
|
||||
def test_sample_energy_loss(self):
|
||||
qn = QuantileNetwork(10, 1, quantile_count=5)
|
||||
|
||||
prediction = torch.tensor(
|
||||
[
|
||||
[0.9813, 0.5331, 0.3298, 0.2428, 0.0737],
|
||||
[0.5442, 0.9623, 0.6070, 0.9360, 0.1145],
|
||||
[0.3642, 0.0887, 0.1696, 0.8027, 0.7121],
|
||||
[0.2005, 0.9889, 0.4350, 0.0301, 0.4546],
|
||||
[0.8360, 0.6766, 0.2257, 0.7589, 0.3443],
|
||||
[0.0835, 0.1747, 0.1734, 0.6668, 0.4522],
|
||||
[0.0851, 0.3146, 0.0316, 0.2250, 0.5729],
|
||||
[0.7725, 0.4596, 0.2495, 0.3633, 0.6340],
|
||||
]
|
||||
)
|
||||
target = torch.tensor(
|
||||
[
|
||||
[0.5365, 0.1495, 0.8120, 0.2595, 0.1409],
|
||||
[0.7784, 0.7070, 0.9066, 0.0123, 0.5587],
|
||||
[0.9097, 0.0773, 0.9430, 0.2747, 0.1912],
|
||||
[0.2307, 0.5068, 0.4624, 0.6708, 0.2844],
|
||||
[0.3356, 0.5885, 0.2484, 0.8468, 0.1833],
|
||||
[0.3354, 0.8831, 0.3489, 0.7165, 0.7953],
|
||||
[0.7577, 0.8578, 0.2735, 0.1029, 0.5621],
|
||||
[0.9124, 0.3476, 0.2012, 0.5830, 0.4615],
|
||||
]
|
||||
)
|
||||
|
||||
loss = qn.sample_energy_loss(prediction, target)
|
||||
|
||||
self.assertAlmostEqual(loss.item(), 0.09165202)
|
||||
|
||||
def test_cvar(self):
|
||||
qn = QuantileNetwork(10, 1, quantile_count=5)
|
||||
measure = qn.measures[qn.measure_cvar](qn, 0.5)
|
||||
|
||||
# Quantiles for 3 agents
|
||||
input = torch.tensor(
|
||||
[
|
||||
[0.1056, 0.0609, 0.3523, 0.3033, 0.1779],
|
||||
[0.2049, 0.1425, 0.0767, 0.1868, 0.3891],
|
||||
[0.1899, 0.1527, 0.2420, 0.2623, 0.1532],
|
||||
]
|
||||
)
|
||||
correct_output = torch.tensor(
|
||||
[
|
||||
(0.4 * 0.0609 + 0.4 * 0.1056 + 0.2 * 0.1779),
|
||||
(0.4 * 0.0767 + 0.4 * 0.1425 + 0.2 * 0.1868),
|
||||
(0.4 * 0.1527 + 0.4 * 0.1532 + 0.2 * 0.1899),
|
||||
]
|
||||
)
|
||||
|
||||
computed_output = measure(input)
|
||||
|
||||
self.assertTrue(torch.isclose(computed_output, correct_output).all())
|
||||
|
||||
def test_cvar_adaptive(self):
|
||||
qn = QuantileNetwork(10, 1, quantile_count=5)
|
||||
|
||||
input = torch.tensor(
|
||||
[
|
||||
[0.95, 0.21, 0.27, 0.26, 0.19],
|
||||
[0.38, 0.34, 0.18, 0.32, 0.97],
|
||||
[0.70, 0.24, 0.38, 0.89, 0.96],
|
||||
]
|
||||
)
|
||||
confidence_levels = torch.tensor([0.1, 0.7, 0.9])
|
||||
correct_output = torch.tensor(
|
||||
[
|
||||
0.19,
|
||||
(0.18 / 3.5 + 0.32 / 3.5 + 0.34 / 3.5 + 0.38 / 7.0),
|
||||
(0.24 / 4.5 + 0.38 / 4.5 + 0.70 / 4.5 + 0.89 / 4.5 + 0.96 / 9.0),
|
||||
]
|
||||
)
|
||||
|
||||
measure = qn.measures[qn.measure_cvar](qn, confidence_levels)
|
||||
computed_output = measure(input)
|
||||
|
||||
self.assertTrue(torch.isclose(computed_output, correct_output).all())
|
|
@ -0,0 +1,33 @@
|
|||
import torch
|
||||
import unittest
|
||||
|
||||
from rsl_rl.utils.recurrency import trajectories_to_transitions, transitions_to_trajectories
|
||||
|
||||
|
||||
class TrajectoryConversionTest(unittest.TestCase):
|
||||
def test_basic_conversion(self):
|
||||
input = torch.rand(128, 24)
|
||||
dones = (torch.rand(128, 24) > 0.8).float()
|
||||
|
||||
trajectories, data = transitions_to_trajectories(input, dones)
|
||||
transitions = trajectories_to_transitions(trajectories, data)
|
||||
|
||||
self.assertTrue(torch.allclose(input, transitions))
|
||||
|
||||
def test_2d_observations(self):
|
||||
input = torch.rand(128, 24, 32)
|
||||
dones = (torch.rand(128, 24) > 0.8).float()
|
||||
|
||||
trajectories, data = transitions_to_trajectories(input, dones)
|
||||
transitions = trajectories_to_transitions(trajectories, data)
|
||||
|
||||
self.assertTrue(torch.allclose(input, transitions))
|
||||
|
||||
def test_batch_first(self):
|
||||
input = torch.rand(128, 24, 32)
|
||||
dones = (torch.rand(128, 24) > 0.8).float()
|
||||
|
||||
trajectories, data = transitions_to_trajectories(input, dones, batch_first=True)
|
||||
transitions = trajectories_to_transitions(trajectories, data)
|
||||
|
||||
self.assertTrue(torch.allclose(input, transitions))
|
|
@ -0,0 +1,58 @@
|
|||
import unittest
|
||||
import torch
|
||||
|
||||
from rsl_rl.modules import Transformer # Assuming the Transformer class is in a module named my_module
|
||||
|
||||
|
||||
class TestTransformer(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.input_size = 9
|
||||
self.output_size = 12
|
||||
self.hidden_size = 64
|
||||
self.block_count = 2
|
||||
self.context_length = 32
|
||||
self.head_count = 4
|
||||
self.batch_size = 10
|
||||
self.sequence_length = 16
|
||||
|
||||
self.transformer = Transformer(
|
||||
self.input_size, self.output_size, self.hidden_size, self.block_count, self.context_length, self.head_count
|
||||
)
|
||||
|
||||
def test_num_layers(self):
|
||||
self.assertEqual(self.transformer.num_layers, self.context_length // 2)
|
||||
|
||||
def test_reset_hidden_state(self):
|
||||
hidden_state = self.transformer.reset_hidden_state(self.batch_size)
|
||||
self.assertIsInstance(hidden_state, tuple)
|
||||
self.assertEqual(len(hidden_state), 2)
|
||||
self.assertTrue(
|
||||
torch.equal(hidden_state[0], torch.zeros((self.transformer.num_layers, self.batch_size, self.hidden_size)))
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.equal(hidden_state[1], torch.zeros((self.transformer.num_layers, self.batch_size, self.hidden_size)))
|
||||
)
|
||||
|
||||
def test_step(self):
|
||||
x = torch.rand(self.sequence_length, self.batch_size, self.input_size)
|
||||
context = torch.rand(self.context_length, self.batch_size, self.hidden_size)
|
||||
|
||||
out, new_context = self.transformer.step(x, context)
|
||||
|
||||
self.assertEqual(out.shape, (self.sequence_length, self.batch_size, self.output_size))
|
||||
self.assertEqual(new_context.shape, (self.context_length, self.batch_size, self.hidden_size))
|
||||
|
||||
def test_forward(self):
|
||||
x = torch.rand(self.sequence_length, self.batch_size, self.input_size)
|
||||
hidden_state = self.transformer.reset_hidden_state(self.batch_size)
|
||||
|
||||
out, new_hidden_state = self.transformer.forward(x, hidden_state)
|
||||
|
||||
self.assertEqual(out.shape, (self.sequence_length, self.batch_size, self.output_size))
|
||||
self.assertEqual(len(new_hidden_state), 2)
|
||||
self.assertEqual(new_hidden_state[0].shape, (self.transformer.num_layers, self.batch_size, self.hidden_size))
|
||||
self.assertEqual(new_hidden_state[1].shape, (self.transformer.num_layers, self.batch_size, self.hidden_size))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue