fixes sampler
parent dbbae6f103
commit c6336b1e70
@@ -89,6 +89,11 @@ class OnPolicyRunner:
         if init_at_random_ep_len:
             if isinstance(self.env.max_episode_length, float):
                 raise ValueError("Cannot initialize at random episode length with float max_episode_length!")
-            self.env.episode_length_buf = torch.randint_like(
-                self.env.episode_length_buf, high=self.env.max_episode_length
-            )
+            elif isinstance(self.env.max_episode_length, torch.Tensor):
+                # ref: https://github.com/pytorch/pytorch/issues/89438
+                samples = torch.randint(2**63 - 1, size=self.env.episode_length_buf.shape, device=self.env.episode_length_buf.device)
+                self.env.episode_length_buf = samples % self.env.max_episode_length
+            else:
+                self.env.episode_length_buf = torch.randint_like(
+                    self.env.episode_length_buf, high=self.env.max_episode_length
+                )
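For context: torch.randint and torch.randint_like only accept a plain Python integer for the high bound, so they cannot draw per-environment values when max_episode_length is a tensor (presumably why the diff links a pytorch issue). The commit works around this by sampling over the full int64 range and reducing the result modulo the per-environment maximum. A minimal sketch of the same trick, with illustrative names (max_len and buf are not from the commit):

    import torch

    # Hypothetical per-environment episode length limits.
    max_len = torch.tensor([250, 500, 1000])
    buf = torch.zeros(3, dtype=torch.long)

    # randint cannot take a tensor as its upper bound, so draw from the full
    # int64 range and fold the samples into [0, max_len) element-wise.
    samples = torch.randint(2**63 - 1, size=buf.shape, device=buf.device)
    buf = samples % max_len

    assert bool((buf >= 0).all()) and bool((buf < max_len).all())

The modulo makes smaller values very slightly more likely, but with maxima this far below 2**63 the bias is negligible.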