Compare commits

...

4 Commits

Author           SHA1        Date                        Message
Mayank Mittal    73fd7c621b  2024-10-11 14:26:58 +02:00  Merge branch 'release'
Mayank Mittal    2fab9bbe1a  2024-10-11 12:24:56 +00:00  Fixes device discrepancy for environment and RL agent (Approved-by: Fan Yang)
Nikita Rudin     a1d25d1fef  2024-01-31 18:47:36 +01:00  Merge pull request #19 from leggedrobotics/master-algorithms-notice (added notice on algorithms branch to README)
Lukas Schneider  dbebd60086  2023-12-12 18:44:24 +01:00  added notice on algorithms branch to README
2 changed files with 14 additions and 8 deletions

README.md

@@ -3,6 +3,9 @@
 Fast and simple implementation of RL algorithms, designed to run fully on GPU.
 This code is an evolution of `rl-pytorch` provided with NVIDIA's Isaac GYM.
 
+| :zap: The `algorithms` branch supports additional algorithms (SAC, DDPG, DSAC, and more)! |
+| ------------------------------------------------------------------------------------------------ |
+
 Only PPO is implemented for now. More algorithms will be added later.
 Contributions are welcome.

rsl_rl/runners/on_policy_runner.py

@@ -45,8 +45,8 @@ class OnPolicyRunner:
             self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device)
             self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device)
         else:
-            self.obs_normalizer = torch.nn.Identity()  # no normalization
-            self.critic_obs_normalizer = torch.nn.Identity()  # no normalization
+            self.obs_normalizer = torch.nn.Identity().to(self.device)  # no normalization
+            self.critic_obs_normalizer = torch.nn.Identity().to(self.device)  # no normalization
         # init storage and model
         self.alg.init_storage(
             self.env.num_envs,
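
The hunk above pins the `torch.nn.Identity` fallback to the same device as the `EmpiricalNormalization` modules, so the normalizer always lives on the agent's device regardless of which branch runs. Below is a minimal standalone sketch of that pattern; the `EmpiricalNormalization` class here is a simplified stand-in for illustration, not the repo's actual implementation, and `num_obs = 48` is an arbitrary choice.

import torch


class EmpiricalNormalization(torch.nn.Module):
    """Simplified stand-in: normalizes inputs with stored mean/std buffers."""

    def __init__(self, shape, until=1.0e8):
        super().__init__()
        self.until = until  # kept for API parity with the call above; unused here
        self.register_buffer("mean", torch.zeros(shape))
        self.register_buffer("std", torch.ones(shape))

    def forward(self, x):
        return (x - self.mean) / (self.std + 1e-8)


device = "cuda:0" if torch.cuda.is_available() else "cpu"
empirical_normalization = True  # config flag selecting between the two branches
num_obs = 48  # arbitrary observation size for the example

if empirical_normalization:
    obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(device)
else:
    # Identity has no parameters or buffers, so .to(device) is effectively a
    # no-op, but it keeps both branches uniform: the normalizer is always
    # created "on" the agent's device.
    obs_normalizer = torch.nn.Identity().to(device)

obs = torch.randn(4096, num_obs, device=device)
obs = obs_normalizer(obs)  # identical call site whichever branch ran
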
@@ -109,18 +109,21 @@
             with torch.inference_mode():
                 for i in range(self.num_steps_per_env):
                     actions = self.alg.act(obs, critic_obs)
-                    obs, rewards, dones, infos = self.env.step(actions)
-                    obs = self.obs_normalizer(obs)
-                    if "critic" in infos["observations"]:
-                        critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
-                    else:
-                        critic_obs = obs
+                    obs, rewards, dones, infos = self.env.step(actions.to(self.env.device))
+                    # move to the right device
+                    obs, critic_obs, rewards, dones = (
+                        obs.to(self.device),
+                        critic_obs.to(self.device),
+                        rewards.to(self.device),
+                        dones.to(self.device),
+                    )
+                    # perform normalization
+                    obs = self.obs_normalizer(obs)
+                    if "critic" in infos["observations"]:
+                        critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
+                    else:
+                        critic_obs = obs
                     # process the step
                     self.alg.process_env_step(rewards, dones, infos)
                     if self.log_dir is not None:
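
This second hunk is the core of the device-discrepancy fix: the environment and the RL agent may live on different devices, so actions are moved to `self.env.device` before stepping and the returned tensors are moved back to the runner's device before normalization and storage. A self-contained sketch of that handoff follows, under stated assumptions: `DummyEnv`, its sizes, and the stand-in policy output are hypothetical and not part of rsl_rl.

import torch


class DummyEnv:
    """Hypothetical environment that keeps its tensors on its own device."""

    def __init__(self, num_envs=8, num_obs=4, device="cpu"):
        self.num_envs, self.num_obs, self.device = num_envs, num_obs, device

    def step(self, actions):
        # the env expects actions on *its* device, hence actions.to(env.device)
        assert actions.device.type == torch.device(self.device).type
        obs = torch.randn(self.num_envs, self.num_obs, device=self.device)
        rewards = torch.zeros(self.num_envs, device=self.device)
        dones = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device)
        return obs, rewards, dones, {"observations": {}}


agent_device = "cuda:0" if torch.cuda.is_available() else "cpu"
env = DummyEnv(device="cpu")  # e.g. a CPU-side simulator while the agent sits on GPU

obs = torch.randn(env.num_envs, env.num_obs, device=agent_device)
critic_obs = obs

for _ in range(3):
    actions = torch.zeros(env.num_envs, 2, device=agent_device)  # stand-in policy output
    # step on the environment's device ...
    obs, rewards, dones, infos = env.step(actions.to(env.device))
    # ... then bring everything back to the agent's device, as in the hunk above
    obs, critic_obs, rewards, dones = (
        obs.to(agent_device),
        critic_obs.to(agent_device),
        rewards.to(agent_device),
        dones.to(agent_device),
    )
    # privileged critic observations are read from infos when present,
    # mirroring the if/else in the hunk above
    critic_obs = infos["observations"].get("critic", obs)

Note that `tensor.to(device)` returns the tensor unchanged when it is already on that device, so this pattern costs nothing in the usual single-GPU case.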