Merge branch 'release'
This commit is contained in:
commit
73fd7c621b
|
@ -45,8 +45,8 @@ class OnPolicyRunner:
|
|||
self.obs_normalizer = EmpiricalNormalization(shape=[num_obs], until=1.0e8).to(self.device)
|
||||
self.critic_obs_normalizer = EmpiricalNormalization(shape=[num_critic_obs], until=1.0e8).to(self.device)
|
||||
else:
|
||||
self.obs_normalizer = torch.nn.Identity() # no normalization
|
||||
self.critic_obs_normalizer = torch.nn.Identity() # no normalization
|
||||
self.obs_normalizer = torch.nn.Identity().to(self.device) # no normalization
|
||||
self.critic_obs_normalizer = torch.nn.Identity().to(self.device) # no normalization
|
||||
# init storage and model
|
||||
self.alg.init_storage(
|
||||
self.env.num_envs,
|
||||
|
@ -109,18 +109,21 @@ class OnPolicyRunner:
|
|||
with torch.inference_mode():
|
||||
for i in range(self.num_steps_per_env):
|
||||
actions = self.alg.act(obs, critic_obs)
|
||||
obs, rewards, dones, infos = self.env.step(actions)
|
||||
obs = self.obs_normalizer(obs)
|
||||
if "critic" in infos["observations"]:
|
||||
critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
|
||||
else:
|
||||
critic_obs = obs
|
||||
obs, rewards, dones, infos = self.env.step(actions.to(self.env.device))
|
||||
# move to the right device
|
||||
obs, critic_obs, rewards, dones = (
|
||||
obs.to(self.device),
|
||||
critic_obs.to(self.device),
|
||||
rewards.to(self.device),
|
||||
dones.to(self.device),
|
||||
)
|
||||
# perform normalization
|
||||
obs = self.obs_normalizer(obs)
|
||||
if "critic" in infos["observations"]:
|
||||
critic_obs = self.critic_obs_normalizer(infos["observations"]["critic"])
|
||||
else:
|
||||
critic_obs = obs
|
||||
# process the step
|
||||
self.alg.process_env_step(rewards, dones, infos)
|
||||
|
||||
if self.log_dir is not None:
|
||||
|
|
Loading…
Reference in New Issue