fixes sampler

This commit is contained in:
Mayank Mittal 2024-12-20 22:44:57 +01:00
parent dbbae6f103
commit c6336b1e70
1 changed files with 8 additions and 3 deletions

View File

@ -89,9 +89,14 @@ class OnPolicyRunner:
if init_at_random_ep_len:
if isinstance(self.env.max_episode_length, float):
raise ValueError("Cannot initialize at random episode length with float max_episode_length!")
self.env.episode_length_buf = torch.randint_like(
self.env.episode_length_buf, high=self.env.max_episode_length
)
elif isinstance(self.env.max_episode_length, torch.Tensor):
# ref: https://github.com/pytorch/pytorch/issues/89438
samples = torch.randint(2**63 - 1, size=self.env.episode_length_buf.shape, device=self.env.episode_length_buf.device)
self.env.episode_length_buf = samples % self.env.max_episode_length
else:
self.env.episode_length_buf = torch.randint_like(
self.env.episode_length_buf, high=self.env.max_episode_length
)
obs, extras = self.env.get_observations()
critic_obs = extras["observations"].get("critic", obs)
obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)