Compare commits

...

10 Commits

Author SHA1 Message Date
Mayank Mittal 73fd7c621b Merge branch 'release' 2024-10-11 14:26:58 +02:00
Mayank Mittal 2fab9bbe1a Fixes device discrepancy for environment and RL agent (Approved-by: Fan Yang) 2024-10-11 12:24:56 +00:00
Nikita Rudin a1d25d1fef Merge pull request #19 from leggedrobotics/master-algorithms-notice (added notice on algorithms branch to README) 2024-01-31 18:47:36 +01:00
Lukas Schneider dbebd60086 added notice on algorithms branch to README 2023-12-12 18:44:24 +01:00
Mayank Mittal 8804e4f730 bumps version 2023-12-11 19:09:35 +01:00
Mayank Mittal 0dc9544952 updates to new neptune 2023-12-11 12:02:31 +01:00
Mayank Mittal cde1e87a19 fixes diff storage location 2023-12-11 11:50:05 +01:00
Mayank Mittal d7a15a7436 adds logging of git files for wandb and neptune 2023-12-11 11:34:39 +01:00
Mayank Mittal 3c5190162c adds missing info 2023-11-09 12:49:41 +01:00
Mayank Mittal 112dd68f16 fixes git repo issue 2023-11-09 12:40:07 +01:00
7 changed files with 58 additions and 18 deletions

View File

@@ -3,6 +3,9 @@
 Fast and simple implementation of RL algorithms, designed to run fully on GPU.
 This code is an evolution of `rl-pytorch` provided with NVIDIA's Isaac GYM.
 
+| :zap: The `algorithms` branch supports additional algorithms (SAC, DDPG, DSAC, and more)! |
+| ------------------------------------------------------------------------------------------------ |
+
 Only PPO is implemented for now. More algorithms will be added later.
 Contributions are welcome.

View File

@@ -3,5 +3,5 @@
 """Main module for the rsl_rl package."""
 
-__version__ = "2.0.0"
+__version__ = "2.0.1"
 __license__ = "BSD-3"
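A quick sanity check of the installed version, assuming the package is importable in the current environment:

    import rsl_rl

    print(rsl_rl.__version__)  # "2.0.1" after this bump
    print(rsl_rl.__license__)  # "BSD-3"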

View File

@@ -45,8 +45,8 @@ class OnPolicyRunner:
             self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device)
             self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device)
         else:
-            self.obs_normalizer = torch.nn.Identity()  # no normalization
-            self.critic_obs_normalizer = torch.nn.Identity()  # no normalization
+            self.obs_normalizer = torch.nn.Identity().to(self.device)  # no normalization
+            self.critic_obs_normalizer = torch.nn.Identity().to(self.device)  # no normalization
         # init storage and model
         self.alg.init_storage(
             self.env.num_envs,
@@ -109,18 +109,21 @@ class OnPolicyRunner:
             with torch.inference_mode():
                 for i in range(self.num_steps_per_env):
                     actions = self.alg.act(obs, critic_obs)
-                    obs, rewards, dones, infos = self.env.step(actions)
-                    obs = self.obs_normalizer(obs)
-                    if "critic" in infos["observations"]:
-                        critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
-                    else:
-                        critic_obs = obs
+                    obs, rewards, dones, infos = self.env.step(actions.to(self.env.device))
+                    # move to the right device
                     obs, critic_obs, rewards, dones = (
                         obs.to(self.device),
                         critic_obs.to(self.device),
                         rewards.to(self.device),
                         dones.to(self.device),
                     )
+                    # perform normalization
+                    obs = self.obs_normalizer(obs)
+                    if "critic" in infos["observations"]:
+                        critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
+                    else:
+                        critic_obs = obs
+                    # process the step
                     self.alg.process_env_step(rewards, dones, infos)
 
                     if self.log_dir is not None:
@@ -156,7 +159,12 @@ class OnPolicyRunner:
                 self.save(os.path.join(self.log_dir, f"model_{it}.pt"))
             ep_infos.clear()
             if it == start_iter:
-                store_code_state(self.log_dir, self.git_status_repos)
+                # obtain all the diff files
+                git_file_paths = store_code_state(self.log_dir, self.git_status_repos)
+                # if possible store them to wandb
+                if self.logger_type in ["wandb", "neptune"] and git_file_paths:
+                    for path in git_file_paths:
+                        self.writer.save_file(path)
 
         self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt"))
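The device handling above reduces to the following standalone pattern (a minimal sketch; the devices and tensor shape are illustrative, not taken from this repository): tensors returned by the environment are moved to the runner's device before normalization, and even the no-op torch.nn.Identity normalizer is placed on that device.

    import torch

    runner_device = "cuda:0" if torch.cuda.is_available() else "cpu"
    env_device = "cpu"  # pretend the environment steps on a different device than the runner

    obs_normalizer = torch.nn.Identity().to(runner_device)  # no normalization, but kept on the runner's device

    obs = torch.randn(8, 4, device=env_device)  # stand-in for what env.step() returns
    obs = obs.to(runner_device)                 # move to the runner's device first
    obs = obs_normalizer(obs)                   # then normalize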

View File

@@ -8,14 +8,14 @@ from dataclasses import asdict
 from torch.utils.tensorboard import SummaryWriter
 
 try:
-    import neptune.new as neptune
+    import neptune
 except ModuleNotFoundError:
     raise ModuleNotFoundError("neptune-client is required to log to Neptune.")
 
 
 class NeptuneLogger:
     def __init__(self, project, token):
-        self.run = neptune.init(project=project, api_token=token)
+        self.run = neptune.init_run(project=project, api_token=token)
 
     def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
         self.run["runner_cfg"] = runner_cfg
@@ -86,3 +86,7 @@ class NeptuneSummaryWriter(SummaryWriter):
 
     def save_model(self, model_path, iter):
         self.neptune_logger.run["model/saved_model_" + str(iter)].upload(model_path)
+
+    def save_file(self, path, iter=None):
+        name = path.rsplit("/", 1)[-1].split(".")[0]
+        self.neptune_logger.run["git_diff/" + name].upload(path)
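The two changed lines follow the neptune-client >= 1.0 API. A minimal sketch of the calls involved; the project name, token, and file path below are placeholders, not values from this repository:

    import neptune

    run = neptune.init_run(project="my-workspace/my-project", api_token="<NEPTUNE_API_TOKEN>")
    run["git_diff/rsl_rl"].upload("logs/run_0/git/rsl_rl.diff")  # same pattern as save_file() above
    run.stop()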

View File

@@ -58,11 +58,29 @@ def unpad_trajectories(trajectories, masks):
     )
 
 
-def store_code_state(logdir, repositories):
+def store_code_state(logdir, repositories) -> list:
+    git_log_dir = os.path.join(logdir, "git")
+    os.makedirs(git_log_dir, exist_ok=True)
+    file_paths = []
     for repository_file_path in repositories:
-        repo = git.Repo(repository_file_path, search_parent_directories=True)
+        try:
+            repo = git.Repo(repository_file_path, search_parent_directories=True)
+        except Exception:
+            print(f"Could not find git repository in {repository_file_path}. Skipping.")
+            # skip if not a git repository
+            continue
+        # get the name of the repository
+        repo_name = pathlib.Path(repo.working_dir).name
         t = repo.head.commit.tree
-        content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}"
-        with open(os.path.join(logdir, f"{repo_name}_git.diff"), "x") as f:
+        diff_file_name = os.path.join(git_log_dir, f"{repo_name}.diff")
+        # check if the diff file already exists
+        if os.path.isfile(diff_file_name):
+            continue
+        # write the diff file
+        print(f"Storing git diff for '{repo_name}' in: {diff_file_name}")
+        with open(diff_file_name, "x") as f:
+            content = f"--- git status ---\n{repo.git.status()} \n\n\n--- git diff ---\n{repo.git.diff(t)}"
             f.write(content)
+        # add the file path to the list of files to be uploaded
+        file_paths.append(diff_file_name)
+    return file_paths
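A usage sketch for the updated helper, assuming it is importable from rsl_rl.utils; the log directory is illustrative, and any file path inside a git checkout works as a repository entry because git.Repo is opened with search_parent_directories=True:

    import rsl_rl
    from rsl_rl.utils import store_code_state

    diff_paths = store_code_state("logs/run_0", [rsl_rl.__file__])
    print(diff_paths)  # e.g. ["logs/run_0/git/rsl_rl.diff"]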

View File

@@ -74,4 +74,7 @@ class WandbSummaryWriter(SummaryWriter):
         self.store_config(env_cfg, runner_cfg, alg_cfg, policy_cfg)
 
     def save_model(self, model_path, iter):
-        wandb.save(model_path)
+        wandb.save(model_path, base_path=os.path.dirname(model_path))
+
+    def save_file(self, path, iter=None):
+        wandb.save(path, base_path=os.path.dirname(path))
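For reference, base_path controls which directory prefix wandb strips when mirroring a file into the run, so the diff is stored under its bare file name. A small sketch with placeholder project and paths:

    import os

    import wandb

    wandb.init(project="rsl_rl-example")
    path = "logs/run_0/git/rsl_rl.diff"
    wandb.save(path, base_path=os.path.dirname(path))  # uploaded as "rsl_rl.diff" rather than under the full path
    wandb.finish()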

View File

@@ -5,8 +5,12 @@ from setuptools import find_packages, setup
setup(
name="rsl_rl",
version="2.0.0",
version="2.0.2",
packages=find_packages(),
author="ETH Zurich, NVIDIA CORPORATION",
maintainer="Nikita Rudin, David Hoeller",
maintainer_email="rudinn@ethz.ch",
url="https://github.com/leggedrobotics/rsl_rl",
license="BSD-3",
description="Fast and simple RL algorithms implemented in pytorch",
python_requires=">=3.6",