Compare commits

...

2 Commits

Author SHA1 Message Date
Mayank Mittal c6336b1e70 fixes sampler 2024-12-20 22:44:57 +01:00
Mayank Mittal dbbae6f103 fixes tensor max ep length 2024-12-20 22:25:05 +01:00
1 changed files with 10 additions and 3 deletions

View File

@ -87,9 +87,16 @@ class OnPolicyRunner:
raise AssertionError("logger type not found")
if init_at_random_ep_len:
self.env.episode_length_buf = torch.randint_like(
self.env.episode_length_buf, high=int(self.env.max_episode_length)
)
if isinstance(self.env.max_episode_length, float):
raise ValueError("Cannot initialize at random episode length with float max_episode_length!")
elif isinstance(self.env.max_episode_length, torch.Tensor):
# ref: https://github.com/pytorch/pytorch/issues/89438
samples = torch.randint(2**63 - 1, size=self.env.episode_length_buf.shape, device=self.env.episode_length_buf.device)
self.env.episode_length_buf = samples % self.env.max_episode_length
else:
self.env.episode_length_buf = torch.randint_like(
self.env.episode_length_buf, high=self.env.max_episode_length
)
obs, extras = self.env.get_observations()
critic_obs = extras["observations"].get("critic", obs)
obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)