initial open-source commit

commit ff9e971c97
Nikita Rudin, 2021-10-27 10:18:22 +02:00
20 changed files with 1420 additions and 0 deletions

.gitignore (vendored, new file, 12 lines)

@@ -0,0 +1,12 @@
# IDEs
.idea
# builds
*.egg-info
# cache
__pycache__
.pytest_cache
# vs code
.vscode

LICENSE (new file, 30 lines)

@@ -0,0 +1,30 @@
Copyright (c) 2021, ETH Zurich, Nikita Rudin
Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES
All rights reserved.
Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors
may be used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
See licenses/dependencies for license information of dependencies of this package.

README.md (new file, 13 lines)

@@ -0,0 +1,13 @@
# RSL RL
Fast and simple implementation of RL algorithms, designed to run fully on GPU.
This code is an evolution of `rl-pytorch` provided with NVIDIA's Isaac Gym.
Only PPO is implemented for now. More algorithms will be added later.
Contributions are welcome.
**Maintainer**: Nikita Rudin
**Affiliation**: Robotic Systems Lab, ETH Zurich & NVIDIA
**Contact**: rudinn@ethz.ch
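
A minimal usage sketch (illustrative only, not part of this commit), assuming `env` is a simulator implementing the `rsl_rl.env.VecEnv` interface; the configuration keys mirror what `OnPolicyRunner` reads, the values are placeholders:

```python
from rsl_rl.runners import OnPolicyRunner

train_cfg = {
    "runner": {
        "policy_class_name": "ActorCritic",
        "algorithm_class_name": "PPO",
        "num_steps_per_env": 24,   # rollout length per iteration
        "save_interval": 50,       # checkpoint every 50 iterations
    },
    "policy": {                    # forwarded to ActorCritic(...)
        "actor_hidden_dims": [256, 256, 256],
        "critic_hidden_dims": [256, 256, 256],
        "activation": "elu",
        "init_noise_std": 1.0,
    },
    "algorithm": {                 # forwarded to PPO(...)
        "num_learning_epochs": 5,
        "num_mini_batches": 4,
        "learning_rate": 1e-3,
        "gamma": 0.99,
        "lam": 0.95,
    },
}

runner = OnPolicyRunner(env, train_cfg, log_dir="logs/example", device="cuda:0")
runner.learn(num_learning_iterations=1500)
policy = runner.get_inference_policy(device="cuda:0")  # callable: obs -> actions
```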

licenses/dependencies (NumPy license, new file, 30 lines)

@@ -0,0 +1,30 @@
Copyright (c) 2005-2021, NumPy Developers.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.
* Neither the name of the NumPy Developers nor the names of any
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

licenses/dependencies (PyTorch license, new file, 73 lines)

@@ -0,0 +1,73 @@
From PyTorch:
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
From Caffe2:
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
All contributions by Facebook:
Copyright (c) 2016 Facebook Inc.
All contributions by Google:
Copyright (c) 2015 Google Inc.
All rights reserved.
All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.
All contributions by Kakao Brain:
Copyright 2019-2020 Kakao Brain
All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
All rights reserved.
All other contributions:
Copyright(c) 2015, 2016 the respective contributors
All rights reserved.
Caffe2 uses a copyright model similar to Caffe: each contributor holds
copyright over their contributions to Caffe2. The project versioning records
all such contribution and copyright details. If a contributor wants to further
mark their specific copyright on a particular contribution, they should
indicate their copyright solely in the commit message of the change when it is
committed.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
and IDIAP Research Institute nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

rsl_rl/__init__.py (new file, 29 lines)

@@ -0,0 +1,29 @@
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

rsl_rl/algorithms/__init__.py (new file, 31 lines)

@@ -0,0 +1,31 @@
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
from .ppo import PPO

rsl_rl/algorithms/ppo.py (new file, 187 lines)

@@ -0,0 +1,187 @@
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
import torch
import torch.nn as nn
import torch.optim as optim
from rsl_rl.modules import ActorCritic
from rsl_rl.storage import RolloutStorage
class PPO:
actor_critic: ActorCritic
def __init__(self,
actor_critic,
num_learning_epochs=1,
num_mini_batches=1,
clip_param=0.2,
gamma=0.998,
lam=0.95,
value_loss_coef=1.0,
entropy_coef=0.0,
learning_rate=1e-3,
max_grad_norm=1.0,
use_clipped_value_loss=True,
schedule="fixed",
desired_kl=0.01,
device='cpu',
):
self.device = device
self.desired_kl = desired_kl
self.schedule = schedule
self.learning_rate = learning_rate
# PPO components
self.actor_critic = actor_critic
self.actor_critic.to(self.device)
self.storage = None # initialized later
self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
self.transition = RolloutStorage.Transition()
# PPO parameters
self.clip_param = clip_param
self.num_learning_epochs = num_learning_epochs
self.num_mini_batches = num_mini_batches
self.value_loss_coef = value_loss_coef
self.entropy_coef = entropy_coef
self.gamma = gamma
self.lam = lam
self.max_grad_norm = max_grad_norm
self.use_clipped_value_loss = use_clipped_value_loss
def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape):
self.storage = RolloutStorage(num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, self.device)
def test_mode(self):
self.actor_critic.test()
def train_mode(self):
self.actor_critic.train()
def act(self, obs, critic_obs):
if self.actor_critic.is_recurrent:
self.transition.hidden_states = self.actor_critic.get_hidden_states()
# Compute the actions and values
self.transition.actions = self.actor_critic.act(obs).detach()
self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
self.transition.action_mean = self.actor_critic.action_mean.detach()
self.transition.action_sigma = self.actor_critic.action_std.detach()
# need to record obs and critic_obs before env.step()
self.transition.observations = obs
self.transition.critic_observations = critic_obs
return self.transition.actions
def process_env_step(self, rewards, dones, infos):
self.transition.rewards = rewards.clone()
self.transition.dones = dones
# Bootstrapping on time outs
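        # environments stopped by the time limit are not true terminations; for them the
        # (discounted) value estimate is added back as a bootstrap, so the return target
        # does not treat the truncation as a terminal state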
if 'time_outs' in infos:
self.transition.rewards += self.gamma * torch.squeeze(self.transition.values * infos['time_outs'].unsqueeze(1).to(self.device), 1)
# Record the transition
self.storage.add_transitions(self.transition)
self.transition.clear()
self.actor_critic.reset(dones)
def compute_returns(self, last_critic_obs):
        last_values = self.actor_critic.evaluate(last_critic_obs).detach()
self.storage.compute_returns(last_values, self.gamma, self.lam)
def update(self):
mean_value_loss = 0
mean_surrogate_loss = 0
if self.actor_critic.is_recurrent:
generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
else:
generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
for obs_batch, critic_obs_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, \
old_mu_batch, old_sigma_batch, hid_states_batch, masks_batch in generator:
self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])
actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
value_batch = self.actor_critic.evaluate(critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
mu_batch = self.actor_critic.action_mean
sigma_batch = self.actor_critic.action_std
entropy_batch = self.actor_critic.entropy
# KL
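            # closed-form KL divergence between the old and new diagonal Gaussian policies,
            # KL(old || new) = sum_i [ log(sigma_new_i / sigma_old_i)
            #                          + (sigma_old_i^2 + (mu_old_i - mu_new_i)^2) / (2 * sigma_new_i^2) - 0.5 ],
            # used below to adapt the learning rate towards desired_kl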
            if self.desired_kl is not None and self.schedule == 'adaptive':
with torch.inference_mode():
kl = torch.sum(
torch.log(sigma_batch / old_sigma_batch + 1.e-5) + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) / (2.0 * torch.square(sigma_batch)) - 0.5, axis=-1)
kl_mean = torch.mean(kl)
if kl_mean > self.desired_kl * 2.0:
self.learning_rate = max(1e-5, self.learning_rate / 1.5)
elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
self.learning_rate = min(1e-2, self.learning_rate * 1.5)
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.learning_rate
# Surrogate loss
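            # PPO clipped surrogate objective: with probability ratio r = pi_new / pi_old,
            # maximize E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)]; the sign is flipped here
            # so it can be minimized together with the value loss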
ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
surrogate = -torch.squeeze(advantages_batch) * ratio
surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(ratio, 1.0 - self.clip_param,
1.0 + self.clip_param)
surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()
# Value function loss
if self.use_clipped_value_loss:
value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(-self.clip_param,
self.clip_param)
value_losses = (value_batch - returns_batch).pow(2)
value_losses_clipped = (value_clipped - returns_batch).pow(2)
value_loss = torch.max(value_losses, value_losses_clipped).mean()
else:
value_loss = (returns_batch - value_batch).pow(2).mean()
loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean()
# Gradient step
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
self.optimizer.step()
mean_value_loss += value_loss.item()
mean_surrogate_loss += surrogate_loss.item()
num_updates = self.num_learning_epochs * self.num_mini_batches
mean_value_loss /= num_updates
mean_surrogate_loss /= num_updates
self.storage.clear()
return mean_value_loss, mean_surrogate_loss

rsl_rl/env/__init__.py (vendored, new file, 31 lines)

@@ -0,0 +1,31 @@
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
from .vec_env import VecEnv

rsl_rl/env/vec_env.py (vendored, new file, 60 lines)

@@ -0,0 +1,60 @@
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
from abc import ABC, abstractmethod
import torch
from typing import Tuple, Union
# minimal interface of the environment
class VecEnv(ABC):
num_envs: int
num_obs: int
num_privileged_obs: int
num_actions: int
max_episode_length: int
privileged_obs_buf: torch.Tensor
obs_buf: torch.Tensor
rew_buf: torch.Tensor
reset_buf: torch.Tensor
episode_length_buf: torch.Tensor # current episode duration
extras: dict
device: torch.device
@abstractmethod
def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, dict]:
pass
@abstractmethod
def reset(self, env_ids: Union[list, torch.Tensor]):
pass
@abstractmethod
def get_observations(self) -> torch.Tensor:
pass
@abstractmethod
def get_privileged_observations(self) -> Union[torch.Tensor, None]:
pass
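
A minimal sketch (illustrative only, not part of this commit) of how a simulator could fill this interface; the dynamics are random placeholders:

import torch
from rsl_rl.env import VecEnv

class DummyVecEnv(VecEnv):
    def __init__(self, num_envs=4, num_obs=8, num_actions=2, device='cpu'):
        self.num_envs = num_envs
        self.num_obs = num_obs
        self.num_privileged_obs = None
        self.num_actions = num_actions
        self.max_episode_length = 100
        self.device = torch.device(device)
        self.obs_buf = torch.zeros(num_envs, num_obs, device=self.device)
        self.privileged_obs_buf = None
        self.rew_buf = torch.zeros(num_envs, device=self.device)
        self.reset_buf = torch.zeros(num_envs, device=self.device)
        self.episode_length_buf = torch.zeros(num_envs, dtype=torch.long, device=self.device)
        self.extras = {}
    def step(self, actions):
        self.episode_length_buf += 1
        self.obs_buf = torch.randn(self.num_envs, self.num_obs, device=self.device)
        self.rew_buf = -actions.pow(2).sum(dim=-1)  # placeholder reward
        self.reset_buf = (self.episode_length_buf >= self.max_episode_length).float()
        self.episode_length_buf[self.reset_buf > 0] = 0
        return self.obs_buf, None, self.rew_buf, self.reset_buf, self.extras
    def reset(self, env_ids=None):
        self.episode_length_buf[:] = 0
        self.obs_buf = torch.randn(self.num_envs, self.num_obs, device=self.device)
        return self.obs_buf, None
    def get_observations(self):
        return self.obs_buf
    def get_privileged_observations(self):
        return None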

rsl_rl/modules/__init__.py (new file, 32 lines)

@@ -0,0 +1,32 @@
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
from .actor_critic import ActorCritic
from .actor_critic_recurrent import ActorCriticRecurrent

rsl_rl/modules/actor_critic.py (new file, 155 lines)

@@ -0,0 +1,155 @@
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.nn.modules import rnn
class ActorCritic(nn.Module):
is_recurrent = False
def __init__(self, num_actor_obs,
num_critic_obs,
num_actions,
actor_hidden_dims=[256, 256, 256],
critic_hidden_dims=[256, 256, 256],
activation='elu',
init_noise_std=1.0,
**kwargs):
if kwargs:
print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str([key for key in kwargs.keys()]))
super(ActorCritic, self).__init__()
activation = get_activation(activation)
mlp_input_dim_a = num_actor_obs
mlp_input_dim_c = num_critic_obs
# Policy
actor_layers = []
actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0]))
actor_layers.append(activation)
for l in range(len(actor_hidden_dims)):
if l == len(actor_hidden_dims) - 1:
actor_layers.append(nn.Linear(actor_hidden_dims[l], num_actions))
else:
actor_layers.append(nn.Linear(actor_hidden_dims[l], actor_hidden_dims[l + 1]))
actor_layers.append(activation)
self.actor = nn.Sequential(*actor_layers)
# Value function
critic_layers = []
critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0]))
critic_layers.append(activation)
for l in range(len(critic_hidden_dims)):
if l == len(critic_hidden_dims) - 1:
critic_layers.append(nn.Linear(critic_hidden_dims[l], 1))
else:
critic_layers.append(nn.Linear(critic_hidden_dims[l], critic_hidden_dims[l + 1]))
critic_layers.append(activation)
self.critic = nn.Sequential(*critic_layers)
print(f"Actor MLP: {self.actor}")
print(f"Critic MLP: {self.critic}")
# Action noise
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
self.distribution = None
# disable args validation for speedup
Normal.set_default_validate_args = False
# seems that we get better performance without init
# self.init_memory_weights(self.memory_a, 0.001, 0.)
# self.init_memory_weights(self.memory_c, 0.001, 0.)
@staticmethod
# not used at the moment
def init_weights(sequential, scales):
[torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) for idx, module in
enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))]
def reset(self, dones=None):
pass
def forward(self):
raise NotImplementedError
@property
def action_mean(self):
return self.distribution.mean
@property
def action_std(self):
return self.distribution.stddev
@property
def entropy(self):
return self.distribution.entropy().sum(dim=-1)
def update_distribution(self, observations):
mean = self.actor(observations)
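        # mean * 0. + self.std broadcasts the single learned std vector to the batch shape of the mean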
self.distribution = Normal(mean, mean*0. + self.std)
def act(self, observations, **kwargs):
self.update_distribution(observations)
return self.distribution.sample()
def get_actions_log_prob(self, actions):
return self.distribution.log_prob(actions).sum(dim=-1)
def act_inference(self, observations):
actions_mean = self.actor(observations)
return actions_mean
def evaluate(self, critic_observations, **kwargs):
value = self.critic(critic_observations)
return value
def get_activation(act_name):
if act_name == "elu":
return nn.ELU()
elif act_name == "selu":
return nn.SELU()
elif act_name == "relu":
return nn.ReLU()
elif act_name == "crelu":
return nn.ReLU()
elif act_name == "lrelu":
return nn.LeakyReLU()
elif act_name == "tanh":
return nn.Tanh()
elif act_name == "sigmoid":
return nn.Sigmoid()
else:
print("invalid activation function!")
return None

rsl_rl/modules/actor_critic_recurrent.py (new file, 116 lines)

@@ -0,0 +1,116 @@
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Normal
from torch.nn.modules import rnn
from .actor_critic import ActorCritic, get_activation
from rsl_rl.utils import unpad_trajectories
class ActorCriticRecurrent(ActorCritic):
is_recurrent = True
def __init__(self, num_actor_obs,
num_critic_obs,
num_actions,
actor_hidden_dims=[256, 256, 256],
critic_hidden_dims=[256, 256, 256],
activation='elu',
rnn_type='lstm',
rnn_hidden_size=256,
rnn_num_layers=1,
init_noise_std=1.0,
**kwargs):
if kwargs:
print("ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()),)
super().__init__(num_actor_obs=rnn_hidden_size,
num_critic_obs=rnn_hidden_size,
num_actions=num_actions,
actor_hidden_dims=actor_hidden_dims,
critic_hidden_dims=critic_hidden_dims,
activation=activation,
init_noise_std=init_noise_std)
activation = get_activation(activation)
self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size)
print(f"Actor RNN: {self.memory_a}")
print(f"Critic RNN: {self.memory_c}")
def reset(self, dones=None):
self.memory_a.reset(dones)
self.memory_c.reset(dones)
def act(self, observations, masks=None, hidden_states=None):
input_a = self.memory_a(observations, masks, hidden_states)
return super().act(input_a.squeeze(0))
def act_inference(self, observations):
input_a = self.memory_a(observations)
return super().act_inference(input_a.squeeze(0))
def evaluate(self, critic_observations, masks=None, hidden_states=None):
input_c = self.memory_c(critic_observations, masks, hidden_states)
return super().evaluate(input_c.squeeze(0))
def get_hidden_states(self):
return self.memory_a.hidden_states, self.memory_c.hidden_states
class Memory(torch.nn.Module):
def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256):
super().__init__()
# RNN
rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM
self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers)
self.hidden_states = None
def forward(self, input, masks=None, hidden_states=None):
batch_mode = masks is not None
if batch_mode:
# batch mode (policy update): need saved hidden states
if hidden_states is None:
raise ValueError("Hidden states not passed to memory module during policy update")
out, _ = self.rnn(input, hidden_states)
out = unpad_trajectories(out, masks)
else:
# inference mode (collection): use hidden states of last step
out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states)
return out
def reset(self, dones=None):
# When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state
for hidden_state in self.hidden_states:
hidden_state[..., dones, :] = 0.0

rsl_rl/runners/__init__.py (new file, 31 lines)

@@ -0,0 +1,31 @@
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
from .on_policy_runner import OnPolicyRunner

rsl_rl/runners/on_policy_runner.py (new file, 232 lines)

@@ -0,0 +1,232 @@
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
import time
import os
from collections import deque
import statistics
from torch.utils.tensorboard import SummaryWriter
import torch
from rsl_rl.algorithms import PPO
from rsl_rl.modules import ActorCritic, ActorCriticRecurrent
from rsl_rl.env import VecEnv
class OnPolicyRunner:
def __init__(self,
env: VecEnv,
train_cfg,
log_dir=None,
device='cpu'):
self.cfg=train_cfg["runner"]
self.alg_cfg = train_cfg["algorithm"]
self.policy_cfg = train_cfg["policy"]
self.device = device
self.env = env
if self.env.num_privileged_obs is not None:
num_critic_obs = self.env.num_privileged_obs
else:
num_critic_obs = self.env.num_obs
actor_critic_class = eval(self.cfg["policy_class_name"]) # ActorCritic
actor_critic: ActorCritic = actor_critic_class( self.env.num_obs,
num_critic_obs,
self.env.num_actions,
**self.policy_cfg).to(self.device)
alg_class = eval(self.cfg["algorithm_class_name"]) # PPO
self.alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg)
self.num_steps_per_env = self.cfg["num_steps_per_env"]
self.save_interval = self.cfg["save_interval"]
# init storage and model
self.alg.init_storage(self.env.num_envs, self.num_steps_per_env, [self.env.num_obs], [self.env.num_privileged_obs], [self.env.num_actions])
# Log
self.log_dir = log_dir
self.writer = None
self.tot_timesteps = 0
self.tot_time = 0
self.current_learning_iteration = 0
_, _ = self.env.reset()
def learn(self, num_learning_iterations, init_at_random_ep_len=False):
# initialize writer
if self.log_dir is not None and self.writer is None:
self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10)
if init_at_random_ep_len:
self.env.episode_length_buf = torch.randint_like(self.env.episode_length_buf, high=int(self.env.max_episode_length))
obs = self.env.get_observations()
privileged_obs = self.env.get_privileged_observations()
critic_obs = privileged_obs if privileged_obs is not None else obs
obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
self.alg.actor_critic.train() # switch to train mode (for dropout for example)
ep_infos = []
rewbuffer = deque(maxlen=100)
lenbuffer = deque(maxlen=100)
cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
tot_iter = self.current_learning_iteration + num_learning_iterations
for it in range(self.current_learning_iteration, tot_iter):
start = time.time()
# Rollout
with torch.inference_mode():
for i in range(self.num_steps_per_env):
actions = self.alg.act(obs, critic_obs)
obs, privileged_obs, rewards, dones, infos = self.env.step(actions)
critic_obs = privileged_obs if privileged_obs is not None else obs
obs, critic_obs, rewards, dones = obs.to(self.device), critic_obs.to(self.device), rewards.to(self.device), dones.to(self.device)
self.alg.process_env_step(rewards, dones, infos)
if self.log_dir is not None:
# Book keeping
if 'episode' in infos:
ep_infos.append(infos['episode'])
cur_reward_sum += rewards
cur_episode_length += 1
new_ids = (dones > 0).nonzero(as_tuple=False)
rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
cur_reward_sum[new_ids] = 0
cur_episode_length[new_ids] = 0
stop = time.time()
collection_time = stop - start
# Learning step
start = stop
self.alg.compute_returns(critic_obs)
mean_value_loss, mean_surrogate_loss = self.alg.update()
stop = time.time()
learn_time = stop - start
if self.log_dir is not None:
self.log(locals())
if it % self.save_interval == 0:
self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(it)))
ep_infos.clear()
        self.current_learning_iteration = tot_iter
        self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(tot_iter)))
def log(self, locs, width=80, pad=35):
self.tot_timesteps += self.num_steps_per_env * self.env.num_envs
self.tot_time += locs['collection_time'] + locs['learn_time']
iteration_time = locs['collection_time'] + locs['learn_time']
ep_string = f''
if locs['ep_infos']:
for key in locs['ep_infos'][0]:
infotensor = torch.tensor([], device=self.device)
for ep_info in locs['ep_infos']:
# handle scalar and zero dimensional tensor infos
if not isinstance(ep_info[key], torch.Tensor):
ep_info[key] = torch.Tensor([ep_info[key]])
if len(ep_info[key].shape) == 0:
ep_info[key] = ep_info[key].unsqueeze(0)
infotensor = torch.cat((infotensor, ep_info[key].to(self.device)))
value = torch.mean(infotensor)
self.writer.add_scalar('Episode/' + key, value, locs['it'])
ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n"""
mean_std = self.alg.actor_critic.std.mean()
fps = int(self.num_steps_per_env * self.env.num_envs / (locs['collection_time'] + locs['learn_time']))
self.writer.add_scalar('Loss/value_function', locs['mean_value_loss'], locs['it'])
self.writer.add_scalar('Loss/surrogate', locs['mean_surrogate_loss'], locs['it'])
self.writer.add_scalar('Loss/learning_rate', self.alg.learning_rate, locs['it'])
self.writer.add_scalar('Policy/mean_noise_std', mean_std.item(), locs['it'])
self.writer.add_scalar('Perf/total_fps', fps, locs['it'])
self.writer.add_scalar('Perf/collection time', locs['collection_time'], locs['it'])
self.writer.add_scalar('Perf/learning_time', locs['learn_time'], locs['it'])
if len(locs['rewbuffer']) > 0:
self.writer.add_scalar('Train/mean_reward', statistics.mean(locs['rewbuffer']), locs['it'])
self.writer.add_scalar('Train/mean_episode_length', statistics.mean(locs['lenbuffer']), locs['it'])
self.writer.add_scalar('Train/mean_reward/time', statistics.mean(locs['rewbuffer']), self.tot_time)
self.writer.add_scalar('Train/mean_episode_length/time', statistics.mean(locs['lenbuffer']), self.tot_time)
str = f" \033[1m Learning iteration {locs['it']}/{locs['num_learning_iterations']} \033[0m "
if len(locs['rewbuffer']) > 0:
log_string = (f"""{'#' * width}\n"""
f"""{str.center(width, ' ')}\n\n"""
f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n"""
f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n"""
f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""")
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""")
else:
log_string = (f"""{'#' * width}\n"""
f"""{str.center(width, ' ')}\n\n"""
f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n"""
f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n"""
f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n"""
f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""")
# f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n"""
# f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""")
log_string += ep_string
log_string += (f"""{'-' * width}\n"""
f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n"""
f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n"""
f"""{'Total time:':>{pad}} {self.tot_time:.2f}s\n"""
f"""{'ETA:':>{pad}} {self.tot_time / (locs['it'] + 1) * (
locs['num_learning_iterations'] - locs['it']):.1f}s\n""")
print(log_string)
def save(self, path, infos=None):
torch.save({
'model_state_dict': self.alg.actor_critic.state_dict(),
'optimizer_state_dict': self.alg.optimizer.state_dict(),
'iter': self.current_learning_iteration,
'infos': infos,
}, path)
def load(self, path, load_optimizer=True):
loaded_dict = torch.load(path)
self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'])
if load_optimizer:
self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict'])
self.current_learning_iteration = loaded_dict['iter']
return loaded_dict['infos']
def get_inference_policy(self, device=None):
self.alg.actor_critic.eval() # switch to evaluation mode (dropout for example)
if device is not None:
self.alg.actor_critic.to(device)
return self.alg.actor_critic.act_inference

rsl_rl/storage/__init__.py (new file, 4 lines)

@@ -0,0 +1,4 @@
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
from .rollout_storage import RolloutStorage

rsl_rl/storage/rollout_storage.py (new file, 235 lines)

@@ -0,0 +1,235 @@
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
import torch
import numpy as np
from rsl_rl.utils import split_and_pad_trajectories
class RolloutStorage:
class Transition:
def __init__(self):
self.observations = None
self.critic_observations = None
self.actions = None
self.rewards = None
self.dones = None
self.values = None
self.actions_log_prob = None
self.action_mean = None
self.action_sigma = None
self.hidden_states = None
def clear(self):
self.__init__()
def __init__(self, num_envs, num_transitions_per_env, obs_shape, privileged_obs_shape, actions_shape, device='cpu'):
self.device = device
self.obs_shape = obs_shape
self.privileged_obs_shape = privileged_obs_shape
self.actions_shape = actions_shape
# Core
self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device)
if privileged_obs_shape[0] is not None:
self.privileged_observations = torch.zeros(num_transitions_per_env, num_envs, *privileged_obs_shape, device=self.device)
else:
self.privileged_observations = None
self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte()
# For PPO
self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device)
self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device)
self.num_transitions_per_env = num_transitions_per_env
self.num_envs = num_envs
# rnn
self.saved_hidden_states_a = None
self.saved_hidden_states_c = None
self.step = 0
def add_transitions(self, transition: Transition):
if self.step >= self.num_transitions_per_env:
raise AssertionError("Rollout buffer overflow")
self.observations[self.step].copy_(transition.observations)
if self.privileged_observations is not None: self.privileged_observations[self.step].copy_(transition.critic_observations)
self.actions[self.step].copy_(transition.actions)
self.rewards[self.step].copy_(transition.rewards.view(-1, 1))
self.dones[self.step].copy_(transition.dones.view(-1, 1))
self.values[self.step].copy_(transition.values)
self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1))
self.mu[self.step].copy_(transition.action_mean)
self.sigma[self.step].copy_(transition.action_sigma)
self._save_hidden_states(transition.hidden_states)
self.step += 1
def _save_hidden_states(self, hidden_states):
if hidden_states is None or hidden_states==(None, None):
return
        # make a tuple out of GRU hidden states to match the LSTM format
hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],)
hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],)
# initialize if needed
if self.saved_hidden_states_a is None:
self.saved_hidden_states_a = [torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a))]
self.saved_hidden_states_c = [torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c))]
# copy the states
for i in range(len(hid_a)):
self.saved_hidden_states_a[i][self.step].copy_(hid_a[i])
self.saved_hidden_states_c[i][self.step].copy_(hid_c[i])
def clear(self):
self.step = 0
def compute_returns(self, last_values, gamma, lam):
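        # Generalized Advantage Estimation (GAE):
        #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        #   A_t = delta_t + gamma * lam * (1 - done_t) * A_{t+1}
        # and the return target is R_t = A_t + V(s_t)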
advantage = 0
for step in reversed(range(self.num_transitions_per_env)):
if step == self.num_transitions_per_env - 1:
next_values = last_values
else:
next_values = self.values[step + 1]
next_is_not_terminal = 1.0 - self.dones[step].float()
delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step]
advantage = delta + next_is_not_terminal * gamma * lam * advantage
self.returns[step] = advantage + self.values[step]
# Compute and normalize the advantages
self.advantages = self.returns - self.values
self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)
def get_statistics(self):
done = self.dones
done[-1] = 1
flat_dones = done.permute(1, 0, 2).reshape(-1, 1)
done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0]))
trajectory_lengths = (done_indices[1:] - done_indices[:-1])
return trajectory_lengths.float().mean(), self.rewards.mean()
def mini_batch_generator(self, num_mini_batches, num_epochs=8):
batch_size = self.num_envs * self.num_transitions_per_env
mini_batch_size = batch_size // num_mini_batches
indices = torch.randperm(num_mini_batches*mini_batch_size, requires_grad=False, device=self.device)
observations = self.observations.flatten(0, 1)
if self.privileged_observations is not None:
critic_observations = self.privileged_observations.flatten(0, 1)
else:
critic_observations = observations
actions = self.actions.flatten(0, 1)
values = self.values.flatten(0, 1)
returns = self.returns.flatten(0, 1)
old_actions_log_prob = self.actions_log_prob.flatten(0, 1)
advantages = self.advantages.flatten(0, 1)
old_mu = self.mu.flatten(0, 1)
old_sigma = self.sigma.flatten(0, 1)
for epoch in range(num_epochs):
for i in range(num_mini_batches):
start = i*mini_batch_size
end = (i+1)*mini_batch_size
batch_idx = indices[start:end]
obs_batch = observations[batch_idx]
critic_observations_batch = critic_observations[batch_idx]
actions_batch = actions[batch_idx]
target_values_batch = values[batch_idx]
returns_batch = returns[batch_idx]
old_actions_log_prob_batch = old_actions_log_prob[batch_idx]
advantages_batch = advantages[batch_idx]
old_mu_batch = old_mu[batch_idx]
old_sigma_batch = old_sigma[batch_idx]
yield obs_batch, critic_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, \
old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (None, None), None
# for RNNs only
def reccurent_mini_batch_generator(self, num_mini_batches, num_epochs=8):
padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories(self.observations, self.dones)
if self.privileged_observations is not None:
padded_critic_obs_trajectories, _ = split_and_pad_trajectories(self.privileged_observations, self.dones)
else:
padded_critic_obs_trajectories = padded_obs_trajectories
mini_batch_size = self.num_envs // num_mini_batches
for ep in range(num_epochs):
first_traj = 0
for i in range(num_mini_batches):
start = i*mini_batch_size
stop = (i+1)*mini_batch_size
dones = self.dones.squeeze(-1)
last_was_done = torch.zeros_like(dones, dtype=torch.bool)
last_was_done[1:] = dones[:-1]
last_was_done[0] = True
trajectories_batch_size = torch.sum(last_was_done[:, start:stop])
last_traj = first_traj + trajectories_batch_size
masks_batch = trajectory_masks[:, first_traj:last_traj]
obs_batch = padded_obs_trajectories[:, first_traj:last_traj]
critic_obs_batch = padded_critic_obs_trajectories[:, first_traj:last_traj]
actions_batch = self.actions[:, start:stop]
old_mu_batch = self.mu[:, start:stop]
old_sigma_batch = self.sigma[:, start:stop]
returns_batch = self.returns[:, start:stop]
advantages_batch = self.advantages[:, start:stop]
values_batch = self.values[:, start:stop]
old_actions_log_prob_batch = self.actions_log_prob[:, start:stop]
# reshape to [num_envs, time, num layers, hidden dim] (original shape: [time, num_layers, num_envs, hidden_dim])
# then take only time steps after dones (flattens num envs and time dimensions),
# take a batch of trajectories and finally reshape back to [num_layers, batch, hidden_dim]
last_was_done = last_was_done.permute(1, 0)
hid_a_batch = [ saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj].transpose(1, 0).contiguous()
for saved_hidden_states in self.saved_hidden_states_a ]
hid_c_batch = [ saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj].transpose(1, 0).contiguous()
for saved_hidden_states in self.saved_hidden_states_c ]
# remove the tuple for GRU
hid_a_batch = hid_a_batch[0] if len(hid_a_batch)==1 else hid_a_batch
                hid_c_batch = hid_c_batch[0] if len(hid_c_batch)==1 else hid_c_batch
yield obs_batch, critic_obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, \
old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (hid_a_batch, hid_c_batch), masks_batch
first_traj = last_traj

rsl_rl/utils/__init__.py (new file, 31 lines)

@@ -0,0 +1,31 @@
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
from .utils import split_and_pad_trajectories, unpad_trajectories

rsl_rl/utils/utils.py (new file, 71 lines)

@@ -0,0 +1,71 @@
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
import torch
def split_and_pad_trajectories(tensor, dones):
""" Splits trajectories at done indices. Then concatenates them and padds with zeros up to the length og the longest trajectory.
Returns masks corresponding to valid parts of the trajectories
Example:
Input: [ [a1, a2, a3, a4 | a5, a6],
[b1, b2 | b3, b4, b5 | b6]
]
Output:[ [a1, a2, a3, a4], | [ [True, True, True, True],
[a5, a6, 0, 0], | [True, True, False, False],
[b1, b2, 0, 0], | [True, True, False, False],
[b3, b4, b5, 0], | [True, True, True, False],
[b6, 0, 0, 0] | [True, False, False, False],
] | ]
    Assumes that the input has the following dimension order: [time, number of envs, additional dimensions]
"""
dones = dones.clone()
dones[-1] = 1
# Permute the buffers to have order (num_envs, num_transitions_per_env, ...), for correct reshaping
flat_dones = dones.transpose(1, 0).reshape(-1, 1)
# Get length of trajectory by counting the number of successive not done elements
done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero()[:, 0]))
trajectory_lengths = done_indices[1:] - done_indices[:-1]
trajectory_lengths_list = trajectory_lengths.tolist()
# Extract the individual trajectories
trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1),trajectory_lengths_list)
padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories)
trajectory_masks = trajectory_lengths > torch.arange(0, tensor.shape[0], device=tensor.device).unsqueeze(1)
return padded_trajectories, trajectory_masks
def unpad_trajectories(trajectories, masks):
""" Does the inverse operation of split_and_pad_trajectories()
"""
# Need to transpose before and after the masking to have proper reshaping
return trajectories.transpose(1, 0)[masks.transpose(1, 0)].view(-1, trajectories.shape[0], trajectories.shape[-1]).transpose(1, 0)
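
A small, self-contained check (illustrative only, not part of this commit) of the two helpers above; shapes and values are arbitrary:

if __name__ == "__main__":
    T, N, D = 6, 2, 3  # time steps, environments, feature dimension
    obs = torch.arange(T * N * D, dtype=torch.float32).reshape(T, N, D)
    dones = torch.zeros(T, N, 1)
    dones[3, 0] = 1  # env 0 finishes an episode after its 4th step; env 1 never does
    padded, masks = split_and_pad_trajectories(obs, dones)
    # 3 trajectories of lengths 4, 2 and 6, padded to the longest one
    print(padded.shape, masks.shape)  # torch.Size([6, 3, 3]) torch.Size([6, 3])
    restored = unpad_trajectories(padded, masks)
    assert torch.equal(restored, obs)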

setup.py (new file, 17 lines)

@@ -0,0 +1,17 @@
from setuptools import setup, find_packages
setup(name='rsl_rl',
version='1.0.2',
packages=find_packages(),
author='Nikita Rudin',
author_email='rudinn@ethz.ch',
license="BSD-3-Clause",
description='Fast and simple RL algorithms implemented in pytorch',
python_requires='>=3.6',
install_requires=[
"torch>=1.4.0",
"torchvision>=0.5.0",
"numpy>=1.16.4",
],
)