Compare commits

...

4 Commits

Author SHA1 Message Date
Mayank Mittal 73fd7c621b Merge branch 'release' 2024-10-11 14:26:58 +02:00
Mayank Mittal 2fab9bbe1a Fixes device discrepancy for environment and RL agent (Approved-by: Fan Yang) 2024-10-11 12:24:56 +00:00
Nikita Rudin a1d25d1fef Merge pull request #19 from leggedrobotics/master-algorithms-notice (added notice on algorithms branch to README) 2024-01-31 18:47:36 +01:00
Lukas Schneider dbebd60086 added notice on algorithms branch to README 2023-12-12 18:44:24 +01:00
2 changed files with 14 additions and 8 deletions

View File

@@ -3,6 +3,9 @@
 Fast and simple implementation of RL algorithms, designed to run fully on GPU.
 This code is an evolution of `rl-pytorch` provided with NVIDIA's Isaac GYM.
+| :zap: The `algorithms` branch supports additional algorithms (SAC, DDPG, DSAC, and more)! |
+| ------------------------------------------------------------------------------------------------ |
 Only PPO is implemented for now. More algorithms will be added later.
 Contributions are welcome.

View File

@@ -45,8 +45,8 @@ class OnPolicyRunner:
             self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device)
             self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device)
         else:
-            self.obs_normalizer = torch.nn.Identity()  # no normalization
-            self.critic_obs_normalizer = torch.nn.Identity()  # no normalization
+            self.obs_normalizer = torch.nn.Identity().to(self.device)  # no normalization
+            self.critic_obs_normalizer = torch.nn.Identity().to(self.device)  # no normalization
         # init storage and model
         self.alg.init_storage(
             self.env.num_envs,
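
Regarding the hunk above: `torch.nn.Identity` is a stateless pass-through, so calling `.to(self.device)` on it moves no parameters; it only keeps the disabled-normalization path consistent with the `EmpiricalNormalization` path, which does live on the runner device. A minimal sketch in plain PyTorch (the batch shape and device choice below are illustrative, not taken from rsl_rl):

import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# When empirical normalization is disabled, the runner still exposes a callable normalizer.
obs_normalizer = torch.nn.Identity().to(device)  # Identity has no parameters; .to() is effectively a no-op

obs = torch.randn(4096, 48, device=device)  # hypothetical batch of observations
out = obs_normalizer(obs)                    # pass-through: same values, same device as the input
assert out.device == obs.device and torch.equal(out, obs)
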
@@ -109,18 +109,21 @@ class OnPolicyRunner:
             with torch.inference_mode():
                 for i in range(self.num_steps_per_env):
                     actions = self.alg.act(obs, critic_obs)
-                    obs, rewards, dones, infos = self.env.step(actions)
-                    obs = self.obs_normalizer(obs)
-                    if "critic" in infos["observations"]:
-                        critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
-                    else:
-                        critic_obs = obs
+                    obs, rewards, dones, infos = self.env.step(actions.to(self.env.device))
+                    # move to the right device
                     obs, critic_obs, rewards, dones = (
                         obs.to(self.device),
                         critic_obs.to(self.device),
                         rewards.to(self.device),
                         dones.to(self.device),
                     )
+                    # perform normalization
+                    obs = self.obs_normalizer(obs)
+                    if "critic" in infos["observations"]:
+                        critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
+                    else:
+                        critic_obs = obs
+                    # process the step
                     self.alg.process_env_step(rewards, dones, infos)
                     if self.log_dir is not None:
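
Taken together, the hunk above establishes the pattern the commit message describes: the environment and the RL agent may live on different devices, so actions are moved to the environment's device before `env.step()`, and everything the agent consumes is moved back to the agent's device before normalization and storage. A self-contained sketch of that pattern follows; the ToyEnv class, shapes, and device choices are hypothetical stand-ins, not part of rsl_rl.

import torch

class ToyEnv:
    """Hypothetical environment that keeps all of its tensors on its own device."""
    def __init__(self, num_envs=8, num_obs=4, num_actions=2, device="cpu"):
        self.num_envs, self.num_obs, self.num_actions, self.device = num_envs, num_obs, num_actions, device

    def step(self, actions):
        # expects actions on the environment's device; returns outputs on that same device
        assert actions.device.type == torch.device(self.device).type
        obs = torch.randn(self.num_envs, self.num_obs, device=self.device)
        rewards = torch.zeros(self.num_envs, device=self.device)
        dones = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device)
        return obs, rewards, dones, {"observations": {}}

env = ToyEnv(device="cpu")
agent_device = "cuda:0" if torch.cuda.is_available() else "cpu"
obs_normalizer = torch.nn.Identity().to(agent_device)  # stand-in for the runner's normalizer

actions = torch.zeros(env.num_envs, env.num_actions, device=agent_device)

# 1) move actions onto the environment's device before stepping
obs, rewards, dones, infos = env.step(actions.to(env.device))

# 2) move everything the agent consumes back to the agent's device
obs, rewards, dones = obs.to(agent_device), rewards.to(agent_device), dones.to(agent_device)

# 3) normalize only after the device move, so all learning-side tensors share one device
obs = obs_normalizer(obs)
critic_obs = infos["observations"].get("critic", obs)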