Compare commits
4 Commits
algorithms...master
| Author | SHA1 | Date |
|---|---|---|
| Mayank Mittal | 73fd7c621b | |
| Mayank Mittal | 2fab9bbe1a | |
| Nikita Rudin | a1d25d1fef | |
| Lukas Schneider | dbebd60086 | |
```diff
@@ -3,6 +3,9 @@
 Fast and simple implementation of RL algorithms, designed to run fully on GPU.
 This code is an evolution of `rl-pytorch` provided with NVIDIA's Isaac GYM.
 
+| :zap: The `algorithms` branch supports additional algorithms (SAC, DDPG, DSAC, and more)! |
+| ------------------------------------------------------------------------------------------------ |
+
 Only PPO is implemented for now. More algorithms will be added later.
 
 Contributions are welcome.
```
```diff
@@ -45,8 +45,8 @@ class OnPolicyRunner:
             self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device)
             self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device)
         else:
-            self.obs_normalizer = torch.nn.Identity() # no normalization
-            self.critic_obs_normalizer = torch.nn.Identity() # no normalization
+            self.obs_normalizer = torch.nn.Identity().to(self.device) # no normalization
+            self.critic_obs_normalizer = torch.nn.Identity().to(self.device) # no normalization
         # init storage and model
         self.alg.init_storage(
             self.env.num_envs,
```
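The `.to(self.device)` on the `torch.nn.Identity()` fallback makes the no-normalization branch mirror the `EmpiricalNormalization` branch: in both cases the runner ends up with a device-placed module it can simply call on a batch of observations (for `Identity`, which holds no parameters or buffers, the move is effectively a no-op). A minimal sketch of that pattern, with a hypothetical `RunningMeanStdNormalizer` and `make_normalizer` standing in for the project's own classes:

```python
import torch
import torch.nn as nn


class RunningMeanStdNormalizer(nn.Module):
    """Hypothetical stand-in for EmpiricalNormalization (which accumulates
    statistics up to a sample cap, e.g. until=1.0e8 in the diff above)."""

    def __init__(self, shape):
        super().__init__()
        # Buffers follow .to(device), just like model parameters.
        self.register_buffer("mean", torch.zeros(shape))
        self.register_buffer("std", torch.ones(shape))

    @torch.no_grad()
    def update(self, x):
        # Simple exponential update; only meant to illustrate the interface.
        self.mean.lerp_(x.mean(dim=0), 0.01)
        self.std.lerp_(x.std(dim=0), 0.01)

    def forward(self, x):
        return (x - self.mean) / (self.std + 1e-8)


def make_normalizer(num_obs, empirical_normalization, device):
    # Both branches return an nn.Module already placed on the learner device,
    # so callers can always write `obs = normalizer(obs)` regardless of the flag.
    if empirical_normalization:
        return RunningMeanStdNormalizer(shape=(num_obs,)).to(device)
    return nn.Identity().to(device)  # no normalization; .to() kept for symmetry


# Example (hypothetical sizes/devices):
# normalizer = make_normalizer(48, empirical_normalization=True, device="cuda:0")
```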
```diff
@@ -109,18 +109,21 @@ class OnPolicyRunner:
             with torch.inference_mode():
                 for i in range(self.num_steps_per_env):
                     actions = self.alg.act(obs, critic_obs)
-                    obs, rewards, dones, infos = self.env.step(actions)
-                    obs = self.obs_normalizer(obs)
-                    if "critic" in infos["observations"]:
-                        critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
-                    else:
-                        critic_obs = obs
+                    obs, rewards, dones, infos = self.env.step(actions.to(self.env.device))
+                    # move to the right device
                     obs, critic_obs, rewards, dones = (
                         obs.to(self.device),
                         critic_obs.to(self.device),
                         rewards.to(self.device),
                         dones.to(self.device),
                     )
+                    # perform normalization
+                    obs = self.obs_normalizer(obs)
+                    if "critic" in infos["observations"]:
+                        critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
+                    else:
+                        critic_obs = obs
+                    # process the step
                     self.alg.process_env_step(rewards, dones, infos)
 
                     if self.log_dir is not None:
```
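The reordering in this hunk sends actions to the environment's device, pulls the step results back onto the learner's device, and only then applies normalization, so the normalizer buffers and the data they touch always share `self.device` even when the simulation runs on a different GPU. A small runnable sketch of that multi-device rollout pattern, using a hypothetical `DummyVecEnv` and a random stand-in policy rather than the project's API:

```python
import torch


class DummyVecEnv:
    """Hypothetical toy vectorized env living on `device`, standing in for the simulator."""

    def __init__(self, num_envs=4, num_obs=8, num_actions=2, device="cpu"):
        self.num_envs, self.num_obs, self.num_actions = num_envs, num_obs, num_actions
        self.device = torch.device(device)

    def reset(self):
        return torch.zeros(self.num_envs, self.num_obs, device=self.device)

    def step(self, actions):
        obs = torch.randn(self.num_envs, self.num_obs, device=self.device)
        rewards = torch.randn(self.num_envs, device=self.device)
        dones = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device)
        infos = {"observations": {}}  # may also carry a separate "critic" observation
        return obs, rewards, dones, infos


# Learner device may differ from the simulation device (e.g. sim on "cuda:0", learning on "cpu").
env = DummyVecEnv(device="cpu")
learn_device = torch.device("cpu")
obs_normalizer = torch.nn.Identity().to(learn_device)
critic_obs_normalizer = torch.nn.Identity().to(learn_device)

obs = env.reset().to(learn_device)
critic_obs = obs
with torch.inference_mode():
    for _ in range(3):
        actions = torch.randn(env.num_envs, env.num_actions, device=learn_device)  # stand-in policy
        # 1) step the environment with actions on ITS device
        obs, rewards, dones, infos = env.step(actions.to(env.device))
        # 2) move results to the learner device before anything else touches them
        obs, rewards, dones = obs.to(learn_device), rewards.to(learn_device), dones.to(learn_device)
        # 3) normalize only after the transfer, so normalizer state and data share a device
        obs = obs_normalizer(obs)
        if "critic" in infos["observations"]:
            critic_obs = critic_obs_normalizer(infos["observations"]["critic"].to(learn_device))
        else:
            critic_obs = obs
```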