Compare commits

..

No commits in common. "fix/episode-length" and "master" have entirely different histories.

1 changed file with 3 additions and 10 deletions

View File

@@ -87,16 +87,9 @@ class OnPolicyRunner:
raise AssertionError("logger type not found") raise AssertionError("logger type not found")
if init_at_random_ep_len: if init_at_random_ep_len:
if isinstance(self.env.max_episode_length, float): self.env.episode_length_buf = torch.randint_like(
raise ValueError("Cannot initialize at random episode length with float max_episode_length!") self.env.episode_length_buf, high=int(self.env.max_episode_length)
elif isinstance(self.env.max_episode_length, torch.Tensor): )
# ref: https://github.com/pytorch/pytorch/issues/89438
samples = torch.randint(2**63 - 1, size=self.env.episode_length_buf.shape, device=self.env.episode_length_buf.device)
self.env.episode_length_buf = samples % self.env.max_episode_length
else:
self.env.episode_length_buf = torch.randint_like(
self.env.episode_length_buf, high=self.env.max_episode_length
)
obs, extras = self.env.get_observations() obs, extras = self.env.get_observations()
critic_obs = extras["observations"].get("critic", obs) critic_obs = extras["observations"].get("critic", obs)
obs, critic_obs = obs.to(self.device), critic_obs.to(self.device) obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)