fixes tensor max ep length

This commit is contained in:
Mayank Mittal 2024-12-20 22:25:05 +01:00
parent 73fd7c621b
commit dbbae6f103
1 changed files with 3 additions and 1 deletions

View File

@ -87,8 +87,10 @@ class OnPolicyRunner:
raise AssertionError("logger type not found")
if init_at_random_ep_len:
if isinstance(self.env.max_episode_length, float):
raise ValueError("Cannot initialize at random episode length with float max_episode_length!")
self.env.episode_length_buf = torch.randint_like(
self.env.episode_length_buf, high=int(self.env.max_episode_length)
self.env.episode_length_buf, high=self.env.max_episode_length
)
obs, extras = self.env.get_observations()
critic_obs = extras["observations"].get("critic", obs)