Compare commits
2 Commits
master
...
fix/episod
Author | SHA1 | Date |
---|---|---|
Mayank Mittal | c6336b1e70 | |
Mayank Mittal | dbbae6f103 |
|
@ -87,9 +87,16 @@ class OnPolicyRunner:
|
|||
raise AssertionError("logger type not found")
|
||||
|
||||
if init_at_random_ep_len:
|
||||
self.env.episode_length_buf = torch.randint_like(
|
||||
self.env.episode_length_buf, high=int(self.env.max_episode_length)
|
||||
)
|
||||
if isinstance(self.env.max_episode_length, float):
|
||||
raise ValueError("Cannot initialize at random episode length with float max_episode_length!")
|
||||
elif isinstance(self.env.max_episode_length, torch.Tensor):
|
||||
# ref: https://github.com/pytorch/pytorch/issues/89438
|
||||
samples = torch.randint(2**63 - 1, size=self.env.episode_length_buf.shape, device=self.env.episode_length_buf.device)
|
||||
self.env.episode_length_buf = samples % self.env.max_episode_length
|
||||
else:
|
||||
self.env.episode_length_buf = torch.randint_like(
|
||||
self.env.episode_length_buf, high=self.env.max_episode_length
|
||||
)
|
||||
obs, extras = self.env.get_observations()
|
||||
critic_obs = extras["observations"].get("critic", obs)
|
||||
obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
|
||||
|
|
Loading…
Reference in New Issue