From ff9e971c979a1cf42cd984cf6e0f50922229cb9c Mon Sep 17 00:00:00 2001 From: Nikita Rudin Date: Wed, 27 Oct 2021 10:18:22 +0200 Subject: [PATCH] initial open-source commit --- .gitignore | 12 ++ LICENSE | 30 +++ README.md | 13 ++ licenses/dependencies/numpy_license.txt | 30 +++ licenses/dependencies/torch_license.txt | 73 +++++++ rsl_rl/__init__.py | 29 +++ rsl_rl/algorithms/__init__.py | 31 +++ rsl_rl/algorithms/ppo.py | 187 ++++++++++++++++++ rsl_rl/env/__init__.py | 31 +++ rsl_rl/env/vec_env.py | 60 ++++++ rsl_rl/modules/__init__.py | 32 +++ rsl_rl/modules/actor_critic.py | 155 +++++++++++++++ rsl_rl/modules/actor_critic_recurrent.py | 116 +++++++++++ rsl_rl/runners/__init__.py | 31 +++ rsl_rl/runners/on_policy_runner.py | 232 ++++++++++++++++++++++ rsl_rl/storage/__init__.py | 4 + rsl_rl/storage/rollout_storage.py | 235 +++++++++++++++++++++++ rsl_rl/utils/__init__.py | 31 +++ rsl_rl/utils/utils.py | 71 +++++++ setup.py | 17 ++ 20 files changed, 1420 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 licenses/dependencies/numpy_license.txt create mode 100644 licenses/dependencies/torch_license.txt create mode 100644 rsl_rl/__init__.py create mode 100644 rsl_rl/algorithms/__init__.py create mode 100644 rsl_rl/algorithms/ppo.py create mode 100644 rsl_rl/env/__init__.py create mode 100644 rsl_rl/env/vec_env.py create mode 100644 rsl_rl/modules/__init__.py create mode 100644 rsl_rl/modules/actor_critic.py create mode 100644 rsl_rl/modules/actor_critic_recurrent.py create mode 100644 rsl_rl/runners/__init__.py create mode 100644 rsl_rl/runners/on_policy_runner.py create mode 100644 rsl_rl/storage/__init__.py create mode 100644 rsl_rl/storage/rollout_storage.py create mode 100644 rsl_rl/utils/__init__.py create mode 100644 rsl_rl/utils/utils.py create mode 100644 setup.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..34f42f1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +# IDEs +.idea + +# builds +*.egg-info + +# cache +__pycache__ +.pytest_cache + +# vs code +.vscode \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..01d9567 --- /dev/null +++ b/LICENSE @@ -0,0 +1,30 @@ +Copyright (c) 2021, ETH Zurich, Nikita Rudin +Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +See licenses/dependencies for license information of dependencies of this package. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..cc4d0b6 --- /dev/null +++ b/README.md @@ -0,0 +1,13 @@ +# RSL RL +Fast and simple implementation of RL algorithms, designed to run fully on GPU. +This code is an evolution of `rl-pytorch` provided with NVIDIA's Isaac GYM. + +Only PPO is implemented for now. More algorithms will be added later. +Contributions are welcome. + +**Maintainer**: Nikita Rudin +**Affiliation**: Robotic Systems Lab, ETH Zurich & NVIDIA +**Contact**: rudinn@ethz.ch + + + diff --git a/licenses/dependencies/numpy_license.txt b/licenses/dependencies/numpy_license.txt new file mode 100644 index 0000000..84e9bfe --- /dev/null +++ b/licenses/dependencies/numpy_license.txt @@ -0,0 +1,30 @@ +Copyright (c) 2005-2021, NumPy Developers. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + * Neither the name of the NumPy Developers nor the names of any + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
\ No newline at end of file diff --git a/licenses/dependencies/torch_license.txt b/licenses/dependencies/torch_license.txt new file mode 100644 index 0000000..244b249 --- /dev/null +++ b/licenses/dependencies/torch_license.txt @@ -0,0 +1,73 @@ +From PyTorch: + +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +From Caffe2: + +Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +All contributions by Facebook: +Copyright (c) 2016 Facebook Inc. + +All contributions by Google: +Copyright (c) 2015 Google Inc. +All rights reserved. + +All contributions by Yangqing Jia: +Copyright (c) 2015 Yangqing Jia +All rights reserved. + +All contributions by Kakao Brain: +Copyright 2019-2020 Kakao Brain + +All contributions from Caffe: +Copyright(c) 2013, 2014, 2015, the respective contributors +All rights reserved. + +All other contributions: +Copyright(c) 2015, 2016 the respective contributors +All rights reserved. + +Caffe2 uses a copyright model similar to Caffe: each contributor holds +copyright over their contributions to Caffe2. The project versioning records +all such contribution and copyright details. If a contributor wants to further +mark their specific copyright on a particular contribution, they should +indicate their copyright solely in the commit message of the change when it is +committed. + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. 
\ No newline at end of file diff --git a/rsl_rl/__init__.py b/rsl_rl/__init__.py new file mode 100644 index 0000000..466dfca --- /dev/null +++ b/rsl_rl/__init__.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# Copyright (c) 2021 ETH Zurich, Nikita Rudin \ No newline at end of file diff --git a/rsl_rl/algorithms/__init__.py b/rsl_rl/algorithms/__init__.py new file mode 100644 index 0000000..6f94329 --- /dev/null +++ b/rsl_rl/algorithms/__init__.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# Copyright (c) 2021 ETH Zurich, Nikita Rudin + +from .ppo import PPO \ No newline at end of file diff --git a/rsl_rl/algorithms/ppo.py b/rsl_rl/algorithms/ppo.py new file mode 100644 index 0000000..2017042 --- /dev/null +++ b/rsl_rl/algorithms/ppo.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +# Copyright (c) 2021 ETH Zurich, Nikita Rudin + +import torch +import torch.nn as nn +import torch.optim as optim + +from rsl_rl.modules import ActorCritic +from rsl_rl.storage import RolloutStorage + +class PPO: + actor_critic: ActorCritic + def __init__(self, + actor_critic, + num_learning_epochs=1, + num_mini_batches=1, + clip_param=0.2, + gamma=0.998, + lam=0.95, + value_loss_coef=1.0, + entropy_coef=0.0, + learning_rate=1e-3, + max_grad_norm=1.0, + use_clipped_value_loss=True, + schedule="fixed", + desired_kl=0.01, + device='cpu', + ): + + self.device = device + + self.desired_kl = desired_kl + self.schedule = schedule + self.learning_rate = learning_rate + + # PPO components + self.actor_critic = actor_critic + self.actor_critic.to(self.device) + self.storage = None # initialized later + self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate) + self.transition = RolloutStorage.Transition() + + # PPO parameters + self.clip_param = clip_param + self.num_learning_epochs = num_learning_epochs + self.num_mini_batches = num_mini_batches + self.value_loss_coef = value_loss_coef + self.entropy_coef = entropy_coef + self.gamma = gamma + self.lam = lam + self.max_grad_norm = max_grad_norm + self.use_clipped_value_loss = use_clipped_value_loss + + def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape): + self.storage = RolloutStorage(num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, self.device) + + def test_mode(self): + self.actor_critic.test() + + def train_mode(self): + self.actor_critic.train() + + def act(self, obs, critic_obs): + if self.actor_critic.is_recurrent: + self.transition.hidden_states = self.actor_critic.get_hidden_states() + # Compute the actions and values + self.transition.actions = self.actor_critic.act(obs).detach() + self.transition.values = self.actor_critic.evaluate(critic_obs).detach() + self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach() + self.transition.action_mean = self.actor_critic.action_mean.detach() + self.transition.action_sigma = self.actor_critic.action_std.detach() + # need to record obs and critic_obs before env.step() + self.transition.observations = obs + self.transition.critic_observations = critic_obs + return self.transition.actions + + def process_env_step(self, rewards, dones, infos): + self.transition.rewards = rewards.clone() + self.transition.dones = dones + # Bootstrapping on time outs + if 'time_outs' in infos: + self.transition.rewards += self.gamma * torch.squeeze(self.transition.values * infos['time_outs'].unsqueeze(1).to(self.device), 1) + + # Record the transition + self.storage.add_transitions(self.transition) + self.transition.clear() + self.actor_critic.reset(dones) + + def compute_returns(self, last_critic_obs): + last_values= self.actor_critic.evaluate(last_critic_obs).detach() + self.storage.compute_returns(last_values, self.gamma, self.lam) + + def update(self): + mean_value_loss = 0 + mean_surrogate_loss = 0 + if self.actor_critic.is_recurrent: + generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs) + else: + generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs) + for obs_batch, critic_obs_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, \ + old_mu_batch, old_sigma_batch, hid_states_batch, masks_batch in generator: + + + 
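+                # Re-evaluate the current policy on the stored mini-batch: act() refreshes the action
+                # distribution (using the saved masks and hidden states for recurrent policies), so the
+                # new log-probabilities, values and entropy used by the clipped surrogate and value
+                # losses below are computed under the parameters currently being optimized.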
self.actor_critic.act(obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0]) + actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch) + value_batch = self.actor_critic.evaluate(critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1]) + mu_batch = self.actor_critic.action_mean + sigma_batch = self.actor_critic.action_std + entropy_batch = self.actor_critic.entropy + + # KL + if self.desired_kl != None and self.schedule == 'adaptive': + with torch.inference_mode(): + kl = torch.sum( + torch.log(sigma_batch / old_sigma_batch + 1.e-5) + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) / (2.0 * torch.square(sigma_batch)) - 0.5, axis=-1) + kl_mean = torch.mean(kl) + + if kl_mean > self.desired_kl * 2.0: + self.learning_rate = max(1e-5, self.learning_rate / 1.5) + elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0: + self.learning_rate = min(1e-2, self.learning_rate * 1.5) + + for param_group in self.optimizer.param_groups: + param_group['lr'] = self.learning_rate + + + # Surrogate loss + ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch)) + surrogate = -torch.squeeze(advantages_batch) * ratio + surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(ratio, 1.0 - self.clip_param, + 1.0 + self.clip_param) + surrogate_loss = torch.max(surrogate, surrogate_clipped).mean() + + # Value function loss + if self.use_clipped_value_loss: + value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(-self.clip_param, + self.clip_param) + value_losses = (value_batch - returns_batch).pow(2) + value_losses_clipped = (value_clipped - returns_batch).pow(2) + value_loss = torch.max(value_losses, value_losses_clipped).mean() + else: + value_loss = (returns_batch - value_batch).pow(2).mean() + + loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean() + + # Gradient step + self.optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm) + self.optimizer.step() + + mean_value_loss += value_loss.item() + mean_surrogate_loss += surrogate_loss.item() + + num_updates = self.num_learning_epochs * self.num_mini_batches + mean_value_loss /= num_updates + mean_surrogate_loss /= num_updates + self.storage.clear() + + return mean_value_loss, mean_surrogate_loss diff --git a/rsl_rl/env/__init__.py b/rsl_rl/env/__init__.py new file mode 100644 index 0000000..9539b9f --- /dev/null +++ b/rsl_rl/env/__init__.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# Copyright (c) 2021 ETH Zurich, Nikita Rudin + +from .vec_env import VecEnv \ No newline at end of file diff --git a/rsl_rl/env/vec_env.py b/rsl_rl/env/vec_env.py new file mode 100644 index 0000000..6a7ef1b --- /dev/null +++ b/rsl_rl/env/vec_env.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +# Copyright (c) 2021 ETH Zurich, Nikita Rudin + +from abc import ABC, abstractmethod +import torch +from typing import Tuple, Union + +# minimal interface of the environment +class VecEnv(ABC): + num_envs: int + num_obs: int + num_privileged_obs: int + num_actions: int + max_episode_length: int + privileged_obs_buf: torch.Tensor + obs_buf: torch.Tensor + rew_buf: torch.Tensor + reset_buf: torch.Tensor + episode_length_buf: torch.Tensor # current episode duration + extras: dict + device: torch.device + @abstractmethod + def step(self, actions: torch.Tensor) -> Tuple[torch.Tensor, Union[torch.Tensor, None], torch.Tensor, torch.Tensor, dict]: + pass + @abstractmethod + def reset(self, env_ids: Union[list, torch.Tensor]): + pass + @abstractmethod + def get_observations(self) -> torch.Tensor: + pass + @abstractmethod + def get_privileged_observations(self) -> Union[torch.Tensor, None]: + pass \ No newline at end of file diff --git a/rsl_rl/modules/__init__.py b/rsl_rl/modules/__init__.py new file mode 100644 index 0000000..3b9d30f --- /dev/null +++ b/rsl_rl/modules/__init__.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# Copyright (c) 2021 ETH Zurich, Nikita Rudin + +from .actor_critic import ActorCritic +from .actor_critic_recurrent import ActorCriticRecurrent \ No newline at end of file diff --git a/rsl_rl/modules/actor_critic.py b/rsl_rl/modules/actor_critic.py new file mode 100644 index 0000000..1864ff4 --- /dev/null +++ b/rsl_rl/modules/actor_critic.py @@ -0,0 +1,155 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. 
+# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# Copyright (c) 2021 ETH Zurich, Nikita Rudin + +import numpy as np + +import torch +import torch.nn as nn +from torch.distributions import Normal +from torch.nn.modules import rnn + +class ActorCritic(nn.Module): + is_recurrent = False + def __init__(self, num_actor_obs, + num_critic_obs, + num_actions, + actor_hidden_dims=[256, 256, 256], + critic_hidden_dims=[256, 256, 256], + activation='elu', + init_noise_std=1.0, + **kwargs): + if kwargs: + print("ActorCritic.__init__ got unexpected arguments, which will be ignored: " + str([key for key in kwargs.keys()])) + super(ActorCritic, self).__init__() + + activation = get_activation(activation) + + mlp_input_dim_a = num_actor_obs + mlp_input_dim_c = num_critic_obs + + # Policy + actor_layers = [] + actor_layers.append(nn.Linear(mlp_input_dim_a, actor_hidden_dims[0])) + actor_layers.append(activation) + for l in range(len(actor_hidden_dims)): + if l == len(actor_hidden_dims) - 1: + actor_layers.append(nn.Linear(actor_hidden_dims[l], num_actions)) + else: + actor_layers.append(nn.Linear(actor_hidden_dims[l], actor_hidden_dims[l + 1])) + actor_layers.append(activation) + self.actor = nn.Sequential(*actor_layers) + + # Value function + critic_layers = [] + critic_layers.append(nn.Linear(mlp_input_dim_c, critic_hidden_dims[0])) + critic_layers.append(activation) + for l in range(len(critic_hidden_dims)): + if l == len(critic_hidden_dims) - 1: + critic_layers.append(nn.Linear(critic_hidden_dims[l], 1)) + else: + critic_layers.append(nn.Linear(critic_hidden_dims[l], critic_hidden_dims[l + 1])) + critic_layers.append(activation) + self.critic = nn.Sequential(*critic_layers) + + print(f"Actor MLP: {self.actor}") + print(f"Critic MLP: {self.critic}") + + # Action noise + self.std = nn.Parameter(init_noise_std * torch.ones(num_actions)) + self.distribution = None + # disable args validation for speedup + Normal.set_default_validate_args = False + + # seems that we get better performance without init + # self.init_memory_weights(self.memory_a, 0.001, 0.) + # self.init_memory_weights(self.memory_c, 0.001, 0.) 
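+        # The policy is a diagonal Gaussian: the actor MLP outputs the action mean, while self.std is a
+        # single learned, state-independent standard-deviation vector (see update_distribution below);
+        # the critic MLP outputs a scalar value estimate for the (possibly privileged) critic observations.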
+ + @staticmethod + # not used at the moment + def init_weights(sequential, scales): + [torch.nn.init.orthogonal_(module.weight, gain=scales[idx]) for idx, module in + enumerate(mod for mod in sequential if isinstance(mod, nn.Linear))] + + + def reset(self, dones=None): + pass + + def forward(self): + raise NotImplementedError + + @property + def action_mean(self): + return self.distribution.mean + + @property + def action_std(self): + return self.distribution.stddev + + @property + def entropy(self): + return self.distribution.entropy().sum(dim=-1) + + def update_distribution(self, observations): + mean = self.actor(observations) + self.distribution = Normal(mean, mean*0. + self.std) + + def act(self, observations, **kwargs): + self.update_distribution(observations) + return self.distribution.sample() + + def get_actions_log_prob(self, actions): + return self.distribution.log_prob(actions).sum(dim=-1) + + def act_inference(self, observations): + actions_mean = self.actor(observations) + return actions_mean + + def evaluate(self, critic_observations, **kwargs): + value = self.critic(critic_observations) + return value + +def get_activation(act_name): + if act_name == "elu": + return nn.ELU() + elif act_name == "selu": + return nn.SELU() + elif act_name == "relu": + return nn.ReLU() + elif act_name == "crelu": + return nn.ReLU() + elif act_name == "lrelu": + return nn.LeakyReLU() + elif act_name == "tanh": + return nn.Tanh() + elif act_name == "sigmoid": + return nn.Sigmoid() + else: + print("invalid activation function!") + return None diff --git a/rsl_rl/modules/actor_critic_recurrent.py b/rsl_rl/modules/actor_critic_recurrent.py new file mode 100644 index 0000000..07d9580 --- /dev/null +++ b/rsl_rl/modules/actor_critic_recurrent.py @@ -0,0 +1,116 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +# Copyright (c) 2021 ETH Zurich, Nikita Rudin + +import numpy as np + +import torch +import torch.nn as nn +from torch.distributions import Normal +from torch.nn.modules import rnn +from .actor_critic import ActorCritic, get_activation +from rsl_rl.utils import unpad_trajectories + +class ActorCriticRecurrent(ActorCritic): + is_recurrent = True + def __init__(self, num_actor_obs, + num_critic_obs, + num_actions, + actor_hidden_dims=[256, 256, 256], + critic_hidden_dims=[256, 256, 256], + activation='elu', + rnn_type='lstm', + rnn_hidden_size=256, + rnn_num_layers=1, + init_noise_std=1.0, + **kwargs): + if kwargs: + print("ActorCriticRecurrent.__init__ got unexpected arguments, which will be ignored: " + str(kwargs.keys()),) + + super().__init__(num_actor_obs=rnn_hidden_size, + num_critic_obs=rnn_hidden_size, + num_actions=num_actions, + actor_hidden_dims=actor_hidden_dims, + critic_hidden_dims=critic_hidden_dims, + activation=activation, + init_noise_std=init_noise_std) + + activation = get_activation(activation) + + self.memory_a = Memory(num_actor_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size) + self.memory_c = Memory(num_critic_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_size) + + print(f"Actor RNN: {self.memory_a}") + print(f"Critic RNN: {self.memory_c}") + + def reset(self, dones=None): + self.memory_a.reset(dones) + self.memory_c.reset(dones) + + def act(self, observations, masks=None, hidden_states=None): + input_a = self.memory_a(observations, masks, hidden_states) + return super().act(input_a.squeeze(0)) + + def act_inference(self, observations): + input_a = self.memory_a(observations) + return super().act_inference(input_a.squeeze(0)) + + def evaluate(self, critic_observations, masks=None, hidden_states=None): + input_c = self.memory_c(critic_observations, masks, hidden_states) + return super().evaluate(input_c.squeeze(0)) + + def get_hidden_states(self): + return self.memory_a.hidden_states, self.memory_c.hidden_states + + +class Memory(torch.nn.Module): + def __init__(self, input_size, type='lstm', num_layers=1, hidden_size=256): + super().__init__() + # RNN + rnn_cls = nn.GRU if type.lower() == 'gru' else nn.LSTM + self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers) + self.hidden_states = None + + def forward(self, input, masks=None, hidden_states=None): + batch_mode = masks is not None + if batch_mode: + # batch mode (policy update): need saved hidden states + if hidden_states is None: + raise ValueError("Hidden states not passed to memory module during policy update") + out, _ = self.rnn(input, hidden_states) + out = unpad_trajectories(out, masks) + else: + # inference mode (collection): use hidden states of last step + out, self.hidden_states = self.rnn(input.unsqueeze(0), self.hidden_states) + return out + + def reset(self, dones=None): + # When the RNN is an LSTM, self.hidden_states_a is a list with hidden_state and cell_state + for hidden_state in self.hidden_states: + hidden_state[..., dones, :] = 0.0 \ No newline at end of file diff --git a/rsl_rl/runners/__init__.py b/rsl_rl/runners/__init__.py new file mode 100644 index 0000000..4753f9f --- /dev/null +++ b/rsl_rl/runners/__init__.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# Copyright (c) 2021 ETH Zurich, Nikita Rudin + +from .on_policy_runner import OnPolicyRunner \ No newline at end of file diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py new file mode 100644 index 0000000..fe3d930 --- /dev/null +++ b/rsl_rl/runners/on_policy_runner.py @@ -0,0 +1,232 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +# Copyright (c) 2021 ETH Zurich, Nikita Rudin + +import time +import os +from collections import deque +import statistics + +from torch.utils.tensorboard import SummaryWriter +import torch + +from rsl_rl.algorithms import PPO +from rsl_rl.modules import ActorCritic, ActorCriticRecurrent +from rsl_rl.env import VecEnv + + +class OnPolicyRunner: + + def __init__(self, + env: VecEnv, + train_cfg, + log_dir=None, + device='cpu'): + + self.cfg=train_cfg["runner"] + self.alg_cfg = train_cfg["algorithm"] + self.policy_cfg = train_cfg["policy"] + self.device = device + self.env = env + if self.env.num_privileged_obs is not None: + num_critic_obs = self.env.num_privileged_obs + else: + num_critic_obs = self.env.num_obs + actor_critic_class = eval(self.cfg["policy_class_name"]) # ActorCritic + actor_critic: ActorCritic = actor_critic_class( self.env.num_obs, + num_critic_obs, + self.env.num_actions, + **self.policy_cfg).to(self.device) + alg_class = eval(self.cfg["algorithm_class_name"]) # PPO + self.alg: PPO = alg_class(actor_critic, device=self.device, **self.alg_cfg) + self.num_steps_per_env = self.cfg["num_steps_per_env"] + self.save_interval = self.cfg["save_interval"] + + # init storage and model + self.alg.init_storage(self.env.num_envs, self.num_steps_per_env, [self.env.num_obs], [self.env.num_privileged_obs], [self.env.num_actions]) + + # Log + self.log_dir = log_dir + self.writer = None + self.tot_timesteps = 0 + self.tot_time = 0 + self.current_learning_iteration = 0 + + _, _ = self.env.reset() + + def learn(self, num_learning_iterations, init_at_random_ep_len=False): + # initialize writer + if self.log_dir is not None and self.writer is None: + self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10) + if init_at_random_ep_len: + self.env.episode_length_buf = torch.randint_like(self.env.episode_length_buf, high=int(self.env.max_episode_length)) + obs = self.env.get_observations() + privileged_obs = self.env.get_privileged_observations() + critic_obs = privileged_obs if privileged_obs is not None else obs + obs, critic_obs = obs.to(self.device), critic_obs.to(self.device) + self.alg.actor_critic.train() # switch to train mode (for dropout for example) + + ep_infos = [] + rewbuffer = deque(maxlen=100) + lenbuffer = deque(maxlen=100) + cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device) + cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device) + + tot_iter = self.current_learning_iteration + num_learning_iterations + for it in range(self.current_learning_iteration, tot_iter): + start = time.time() + # Rollout + with torch.inference_mode(): + for i in range(self.num_steps_per_env): + actions = self.alg.act(obs, critic_obs) + obs, privileged_obs, rewards, dones, infos = self.env.step(actions) + critic_obs = privileged_obs if privileged_obs is not None else obs + obs, critic_obs, rewards, dones = obs.to(self.device), critic_obs.to(self.device), rewards.to(self.device), dones.to(self.device) + self.alg.process_env_step(rewards, dones, infos) + + if self.log_dir is not None: + # Book keeping + if 'episode' in infos: + ep_infos.append(infos['episode']) + cur_reward_sum += rewards + cur_episode_length += 1 + new_ids = (dones > 0).nonzero(as_tuple=False) + rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist()) + lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist()) + cur_reward_sum[new_ids] = 0 + cur_episode_length[new_ids] = 0 + + stop = time.time() + collection_time = stop - 
start + + # Learning step + start = stop + self.alg.compute_returns(critic_obs) + + mean_value_loss, mean_surrogate_loss = self.alg.update() + stop = time.time() + learn_time = stop - start + if self.log_dir is not None: + self.log(locals()) + if it % self.save_interval == 0: + self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(it))) + ep_infos.clear() + self.current_learning_iteration = num_learning_iterations + self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(num_learning_iterations))) + + def log(self, locs, width=80, pad=35): + self.tot_timesteps += self.num_steps_per_env * self.env.num_envs + self.tot_time += locs['collection_time'] + locs['learn_time'] + iteration_time = locs['collection_time'] + locs['learn_time'] + + ep_string = f'' + if locs['ep_infos']: + for key in locs['ep_infos'][0]: + infotensor = torch.tensor([], device=self.device) + for ep_info in locs['ep_infos']: + # handle scalar and zero dimensional tensor infos + if not isinstance(ep_info[key], torch.Tensor): + ep_info[key] = torch.Tensor([ep_info[key]]) + if len(ep_info[key].shape) == 0: + ep_info[key] = ep_info[key].unsqueeze(0) + infotensor = torch.cat((infotensor, ep_info[key].to(self.device))) + value = torch.mean(infotensor) + self.writer.add_scalar('Episode/' + key, value, locs['it']) + ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n""" + mean_std = self.alg.actor_critic.std.mean() + fps = int(self.num_steps_per_env * self.env.num_envs / (locs['collection_time'] + locs['learn_time'])) + + self.writer.add_scalar('Loss/value_function', locs['mean_value_loss'], locs['it']) + self.writer.add_scalar('Loss/surrogate', locs['mean_surrogate_loss'], locs['it']) + self.writer.add_scalar('Loss/learning_rate', self.alg.learning_rate, locs['it']) + self.writer.add_scalar('Policy/mean_noise_std', mean_std.item(), locs['it']) + self.writer.add_scalar('Perf/total_fps', fps, locs['it']) + self.writer.add_scalar('Perf/collection time', locs['collection_time'], locs['it']) + self.writer.add_scalar('Perf/learning_time', locs['learn_time'], locs['it']) + if len(locs['rewbuffer']) > 0: + self.writer.add_scalar('Train/mean_reward', statistics.mean(locs['rewbuffer']), locs['it']) + self.writer.add_scalar('Train/mean_episode_length', statistics.mean(locs['lenbuffer']), locs['it']) + self.writer.add_scalar('Train/mean_reward/time', statistics.mean(locs['rewbuffer']), self.tot_time) + self.writer.add_scalar('Train/mean_episode_length/time', statistics.mean(locs['lenbuffer']), self.tot_time) + + str = f" \033[1m Learning iteration {locs['it']}/{locs['num_learning_iterations']} \033[0m " + + if len(locs['rewbuffer']) > 0: + log_string = (f"""{'#' * width}\n""" + f"""{str.center(width, ' ')}\n\n""" + f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[ + 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n""" + f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n""" + f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n""" + f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""" + f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n""" + f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""") + # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n""" + # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""") + else: + log_string = (f"""{'#' * width}\n""" + f"""{str.center(width, ' ')}\n\n""" + f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[ + 
'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n""" + f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n""" + f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n""" + f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""") + # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n""" + # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""") + + log_string += ep_string + log_string += (f"""{'-' * width}\n""" + f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n""" + f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n""" + f"""{'Total time:':>{pad}} {self.tot_time:.2f}s\n""" + f"""{'ETA:':>{pad}} {self.tot_time / (locs['it'] + 1) * ( + locs['num_learning_iterations'] - locs['it']):.1f}s\n""") + print(log_string) + + def save(self, path, infos=None): + torch.save({ + 'model_state_dict': self.alg.actor_critic.state_dict(), + 'optimizer_state_dict': self.alg.optimizer.state_dict(), + 'iter': self.current_learning_iteration, + 'infos': infos, + }, path) + + def load(self, path, load_optimizer=True): + loaded_dict = torch.load(path) + self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict']) + if load_optimizer: + self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict']) + self.current_learning_iteration = loaded_dict['iter'] + return loaded_dict['infos'] + + def get_inference_policy(self, device=None): + self.alg.actor_critic.eval() # switch to evaluation mode (dropout for example) + if device is not None: + self.alg.actor_critic.to(device) + return self.alg.actor_critic.act_inference diff --git a/rsl_rl/storage/__init__.py b/rsl_rl/storage/__init__.py new file mode 100644 index 0000000..e8f67f5 --- /dev/null +++ b/rsl_rl/storage/__init__.py @@ -0,0 +1,4 @@ +# Copyright 2021 ETH Zurich, NVIDIA CORPORATION +# SPDX-License-Identifier: BSD-3-Clause + +from .rollout_storage import RolloutStorage \ No newline at end of file diff --git a/rsl_rl/storage/rollout_storage.py b/rsl_rl/storage/rollout_storage.py new file mode 100644 index 0000000..66ee55f --- /dev/null +++ b/rsl_rl/storage/rollout_storage.py @@ -0,0 +1,235 @@ +# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# Copyright (c) 2021 ETH Zurich, Nikita Rudin + +import torch +import numpy as np + +from rsl_rl.utils import split_and_pad_trajectories + +class RolloutStorage: + class Transition: + def __init__(self): + self.observations = None + self.critic_observations = None + self.actions = None + self.rewards = None + self.dones = None + self.values = None + self.actions_log_prob = None + self.action_mean = None + self.action_sigma = None + self.hidden_states = None + + def clear(self): + self.__init__() + + def __init__(self, num_envs, num_transitions_per_env, obs_shape, privileged_obs_shape, actions_shape, device='cpu'): + + self.device = device + + self.obs_shape = obs_shape + self.privileged_obs_shape = privileged_obs_shape + self.actions_shape = actions_shape + + # Core + self.observations = torch.zeros(num_transitions_per_env, num_envs, *obs_shape, device=self.device) + if privileged_obs_shape[0] is not None: + self.privileged_observations = torch.zeros(num_transitions_per_env, num_envs, *privileged_obs_shape, device=self.device) + else: + self.privileged_observations = None + self.rewards = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) + self.actions = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) + self.dones = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device).byte() + + # For PPO + self.actions_log_prob = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) + self.values = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) + self.returns = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) + self.advantages = torch.zeros(num_transitions_per_env, num_envs, 1, device=self.device) + self.mu = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) + self.sigma = torch.zeros(num_transitions_per_env, num_envs, *actions_shape, device=self.device) + + self.num_transitions_per_env = num_transitions_per_env + self.num_envs = num_envs + + # rnn + self.saved_hidden_states_a = None + self.saved_hidden_states_c = None + + self.step = 0 + + def add_transitions(self, transition: Transition): + if self.step >= self.num_transitions_per_env: + raise AssertionError("Rollout buffer overflow") + self.observations[self.step].copy_(transition.observations) + if self.privileged_observations is not None: self.privileged_observations[self.step].copy_(transition.critic_observations) + self.actions[self.step].copy_(transition.actions) + self.rewards[self.step].copy_(transition.rewards.view(-1, 1)) + self.dones[self.step].copy_(transition.dones.view(-1, 1)) + self.values[self.step].copy_(transition.values) + self.actions_log_prob[self.step].copy_(transition.actions_log_prob.view(-1, 1)) + self.mu[self.step].copy_(transition.action_mean) + self.sigma[self.step].copy_(transition.action_sigma) + self._save_hidden_states(transition.hidden_states) + self.step += 1 + + def _save_hidden_states(self, hidden_states): + if hidden_states is 
None or hidden_states==(None, None): + return + # make a tuple out of GRU hidden state sto match the LSTM format + hid_a = hidden_states[0] if isinstance(hidden_states[0], tuple) else (hidden_states[0],) + hid_c = hidden_states[1] if isinstance(hidden_states[1], tuple) else (hidden_states[1],) + + # initialize if needed + if self.saved_hidden_states_a is None: + self.saved_hidden_states_a = [torch.zeros(self.observations.shape[0], *hid_a[i].shape, device=self.device) for i in range(len(hid_a))] + self.saved_hidden_states_c = [torch.zeros(self.observations.shape[0], *hid_c[i].shape, device=self.device) for i in range(len(hid_c))] + # copy the states + for i in range(len(hid_a)): + self.saved_hidden_states_a[i][self.step].copy_(hid_a[i]) + self.saved_hidden_states_c[i][self.step].copy_(hid_c[i]) + + + def clear(self): + self.step = 0 + + def compute_returns(self, last_values, gamma, lam): + advantage = 0 + for step in reversed(range(self.num_transitions_per_env)): + if step == self.num_transitions_per_env - 1: + next_values = last_values + else: + next_values = self.values[step + 1] + next_is_not_terminal = 1.0 - self.dones[step].float() + delta = self.rewards[step] + next_is_not_terminal * gamma * next_values - self.values[step] + advantage = delta + next_is_not_terminal * gamma * lam * advantage + self.returns[step] = advantage + self.values[step] + + # Compute and normalize the advantages + self.advantages = self.returns - self.values + self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8) + + def get_statistics(self): + done = self.dones + done[-1] = 1 + flat_dones = done.permute(1, 0, 2).reshape(-1, 1) + done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero(as_tuple=False)[:, 0])) + trajectory_lengths = (done_indices[1:] - done_indices[:-1]) + return trajectory_lengths.float().mean(), self.rewards.mean() + + def mini_batch_generator(self, num_mini_batches, num_epochs=8): + batch_size = self.num_envs * self.num_transitions_per_env + mini_batch_size = batch_size // num_mini_batches + indices = torch.randperm(num_mini_batches*mini_batch_size, requires_grad=False, device=self.device) + + observations = self.observations.flatten(0, 1) + if self.privileged_observations is not None: + critic_observations = self.privileged_observations.flatten(0, 1) + else: + critic_observations = observations + + actions = self.actions.flatten(0, 1) + values = self.values.flatten(0, 1) + returns = self.returns.flatten(0, 1) + old_actions_log_prob = self.actions_log_prob.flatten(0, 1) + advantages = self.advantages.flatten(0, 1) + old_mu = self.mu.flatten(0, 1) + old_sigma = self.sigma.flatten(0, 1) + + for epoch in range(num_epochs): + for i in range(num_mini_batches): + + start = i*mini_batch_size + end = (i+1)*mini_batch_size + batch_idx = indices[start:end] + + obs_batch = observations[batch_idx] + critic_observations_batch = critic_observations[batch_idx] + actions_batch = actions[batch_idx] + target_values_batch = values[batch_idx] + returns_batch = returns[batch_idx] + old_actions_log_prob_batch = old_actions_log_prob[batch_idx] + advantages_batch = advantages[batch_idx] + old_mu_batch = old_mu[batch_idx] + old_sigma_batch = old_sigma[batch_idx] + yield obs_batch, critic_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, \ + old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (None, None), None + + # for RNNs only + def reccurent_mini_batch_generator(self, num_mini_batches, 
+    # for RNNs only
+    def reccurent_mini_batch_generator(self, num_mini_batches, num_epochs=8):
+
+        padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories(self.observations, self.dones)
+        if self.privileged_observations is not None:
+            padded_critic_obs_trajectories, _ = split_and_pad_trajectories(self.privileged_observations, self.dones)
+        else:
+            padded_critic_obs_trajectories = padded_obs_trajectories
+
+        mini_batch_size = self.num_envs // num_mini_batches
+        for ep in range(num_epochs):
+            first_traj = 0
+            for i in range(num_mini_batches):
+                start = i*mini_batch_size
+                stop = (i+1)*mini_batch_size
+
+                dones = self.dones.squeeze(-1)
+                last_was_done = torch.zeros_like(dones, dtype=torch.bool)
+                last_was_done[1:] = dones[:-1]
+                last_was_done[0] = True
+                trajectories_batch_size = torch.sum(last_was_done[:, start:stop])
+                last_traj = first_traj + trajectories_batch_size
+
+                masks_batch = trajectory_masks[:, first_traj:last_traj]
+                obs_batch = padded_obs_trajectories[:, first_traj:last_traj]
+                critic_obs_batch = padded_critic_obs_trajectories[:, first_traj:last_traj]
+
+                actions_batch = self.actions[:, start:stop]
+                old_mu_batch = self.mu[:, start:stop]
+                old_sigma_batch = self.sigma[:, start:stop]
+                returns_batch = self.returns[:, start:stop]
+                advantages_batch = self.advantages[:, start:stop]
+                values_batch = self.values[:, start:stop]
+                old_actions_log_prob_batch = self.actions_log_prob[:, start:stop]
+
+                # reshape to [num_envs, time, num_layers, hidden_dim] (original shape: [time, num_layers, num_envs, hidden_dim])
+                # then take only time steps after dones (flattens num envs and time dimensions),
+                # take a batch of trajectories and finally reshape back to [num_layers, batch, hidden_dim]
+                last_was_done = last_was_done.permute(1, 0)
+                hid_a_batch = [ saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj].transpose(1, 0).contiguous()
+                                for saved_hidden_states in self.saved_hidden_states_a ]
+                hid_c_batch = [ saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj].transpose(1, 0).contiguous()
+                                for saved_hidden_states in self.saved_hidden_states_c ]
+                # remove the tuple for GRU
+                hid_a_batch = hid_a_batch[0] if len(hid_a_batch)==1 else hid_a_batch
+                hid_c_batch = hid_c_batch[0] if len(hid_c_batch)==1 else hid_c_batch
+
+                yield obs_batch, critic_obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, \
+                       old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (hid_a_batch, hid_c_batch), masks_batch
+
+                first_traj = last_traj
\ No newline at end of file
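
For orientation, the sketch below shows how this storage is typically driven over one rollout: fill a Transition per step, push it with add_transitions, bootstrap the returns/advantages with compute_returns (gamma/lambda discounting as in the method above), then iterate mini_batch_generator. It is a minimal, self-contained example; the shapes, hyperparameters, and the random placeholder data standing in for a real environment and policy are assumptions for illustration only, not part of this patch.

import torch
from rsl_rl.storage.rollout_storage import RolloutStorage

num_envs, num_steps, num_obs, num_actions = 4, 24, 48, 12
storage = RolloutStorage(num_envs, num_steps, [num_obs], [None], [num_actions], device='cpu')
transition = RolloutStorage.Transition()

for _ in range(num_steps):
    # random placeholders standing in for env observations and policy outputs
    transition.observations = torch.randn(num_envs, num_obs)
    transition.actions = torch.randn(num_envs, num_actions)
    transition.rewards = torch.randn(num_envs)
    transition.dones = torch.zeros(num_envs)
    transition.values = torch.randn(num_envs, 1)
    transition.actions_log_prob = torch.randn(num_envs)
    transition.action_mean = torch.zeros(num_envs, num_actions)
    transition.action_sigma = torch.ones(num_envs, num_actions)
    transition.hidden_states = None          # feed-forward policy: nothing to save
    storage.add_transitions(transition)
    transition.clear()

# bootstrap from the value of the last observation, then build returns and normalized advantages
storage.compute_returns(last_values=torch.randn(num_envs, 1), gamma=0.99, lam=0.95)

for batch in storage.mini_batch_generator(num_mini_batches=4, num_epochs=1):
    obs_batch, critic_obs_batch, actions_batch, *rest = batch   # unpack as done by the PPO algorithm in this package
storage.clear()
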
diff --git a/rsl_rl/utils/__init__.py b/rsl_rl/utils/__init__.py
new file mode 100644
index 0000000..1b505f3
--- /dev/null
+++ b/rsl_rl/utils/__init__.py
@@ -0,0 +1,31 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+#    contributors may be used to endorse or promote products derived from
+#    this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+# Copyright (c) 2021 ETH Zurich, Nikita Rudin
+
+from .utils import split_and_pad_trajectories, unpad_trajectories
\ No newline at end of file
diff --git a/rsl_rl/utils/utils.py b/rsl_rl/utils/utils.py
new file mode 100644
index 0000000..b6affab
--- /dev/null
+++ b/rsl_rl/utils/utils.py
@@ -0,0 +1,71 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+#    contributors may be used to endorse or promote products derived from
+#    this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+# Copyright (c) 2021 ETH Zurich, Nikita Rudin
+
+import torch
+
+def split_and_pad_trajectories(tensor, dones):
+    """ Splits trajectories at done indices. Then concatenates them and pads with zeros up to the length of the longest trajectory.
+    Returns masks corresponding to valid parts of the trajectories
+    Example:
+        Input: [ [a1, a2, a3, a4 | a5, a6],
+                 [b1, b2 | b3, b4, b5 | b6]
+               ]
+
+        Output: [ [a1, a2, a3, a4],  |  [ [True, True, True, True],
+                  [a5, a6, 0, 0],    |    [True, True, False, False],
+                  [b1, b2, 0, 0],    |    [True, True, False, False],
+                  [b3, b4, b5, 0],   |    [True, True, True, False],
+                  [b6, 0, 0, 0]      |    [True, False, False, False],
+                ]                    |  ]
+
+    Assumes that the input has the following dimension order: [time, number of envs, additional dimensions]
+    """
+    dones = dones.clone()
+    dones[-1] = 1
+    # Permute the buffers to have order (num_envs, num_transitions_per_env, ...), for correct reshaping
+    flat_dones = dones.transpose(1, 0).reshape(-1, 1)
+
+    # Get length of each trajectory by counting the number of successive not-done elements
+    done_indices = torch.cat((flat_dones.new_tensor([-1], dtype=torch.int64), flat_dones.nonzero()[:, 0]))
+    trajectory_lengths = done_indices[1:] - done_indices[:-1]
+    trajectory_lengths_list = trajectory_lengths.tolist()
+    # Extract the individual trajectories
+    trajectories = torch.split(tensor.transpose(1, 0).flatten(0, 1), trajectory_lengths_list)
+    padded_trajectories = torch.nn.utils.rnn.pad_sequence(trajectories)
+
+    trajectory_masks = trajectory_lengths > torch.arange(0, tensor.shape[0], device=tensor.device).unsqueeze(1)
+    return padded_trajectories, trajectory_masks
+
+def unpad_trajectories(trajectories, masks):
+    """ Does the inverse operation of split_and_pad_trajectories()
+    """
+    # Need to transpose before and after the masking to have proper reshaping
+    return trajectories.transpose(1, 0)[masks.transpose(1, 0)].view(-1, trajectories.shape[0], trajectories.shape[-1]).transpose(1, 0)
\ No newline at end of file
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..525adc4
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,17 @@
+from setuptools import setup, find_packages
+
+setup(name='rsl_rl',
+      version='1.0.2',
+      packages=find_packages(),
+      author='Nikita Rudin',
+      author_email='rudinn@ethz.ch',
+      license="BSD-3-Clause",
+      description='Fast and simple RL algorithms implemented in pytorch',
+      python_requires='>=3.6',
+      install_requires=[
+          "torch>=1.4.0",
+          "torchvision>=0.5.0",
+          "numpy>=1.16.4",
+      ],
+      )
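
As a quick sanity check of the padding helpers added in rsl_rl/utils/utils.py, the following self-contained snippet round-trips a small batch through split_and_pad_trajectories and unpad_trajectories. The tensor sizes and done pattern are arbitrary choices for illustration; note that one environment runs to the end of the window, so the longest trajectory spans all time steps.

import torch
from rsl_rl.utils import split_and_pad_trajectories, unpad_trajectories

T, N, D = 6, 2, 3                               # time steps, envs, feature dim
obs = torch.arange(T * N * D, dtype=torch.float32).reshape(T, N, D)
dones = torch.zeros(T, N, 1)
dones[2, 0] = 1                                 # env 0 finishes an episode at t=2; env 1 never terminates inside the window

padded, masks = split_and_pad_trajectories(obs, dones)
print(padded.shape, masks.shape)                # torch.Size([6, 3, 3]) torch.Size([6, 3]): [longest_traj_len, num_trajectories, D] and the valid-step masks

recovered = unpad_trajectories(padded, masks)   # inverse operation
assert torch.equal(recovered, obs)              # round trip restores the original [T, N, D] layout
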