From c6336b1e70cb58730270e37f5668021c772840ef Mon Sep 17 00:00:00 2001
From: Mayank Mittal
Date: Fri, 20 Dec 2024 22:44:57 +0100
Subject: [PATCH] fixes sampler

---
 rsl_rl/runners/on_policy_runner.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py
index c2bc00c..b451618 100644
--- a/rsl_rl/runners/on_policy_runner.py
+++ b/rsl_rl/runners/on_policy_runner.py
@@ -89,9 +89,14 @@ class OnPolicyRunner:
         if init_at_random_ep_len:
             if isinstance(self.env.max_episode_length, float):
                 raise ValueError("Cannot initialize at random episode length with float max_episode_length!")
-            self.env.episode_length_buf = torch.randint_like(
-                self.env.episode_length_buf, high=self.env.max_episode_length
-            )
+            elif isinstance(self.env.max_episode_length, torch.Tensor):
+                # ref: https://github.com/pytorch/pytorch/issues/89438
+                samples = torch.randint(2**63 - 1, size=self.env.episode_length_buf.shape, device=self.env.episode_length_buf.device)
+                self.env.episode_length_buf = samples % self.env.max_episode_length
+            else:
+                self.env.episode_length_buf = torch.randint_like(
+                    self.env.episode_length_buf, high=self.env.max_episode_length
+                )
         obs, extras = self.env.get_observations()
         critic_obs = extras["observations"].get("critic", obs)
         obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)
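
For reference, a minimal standalone sketch of the modulo-based sampling the hunk introduces: torch.randint does not accept a tensor `high` (see the referenced PyTorch issue), so the patch draws uniform int64 samples over the full range and reduces them modulo the per-environment bound. Tensor names mirror the patch; the shapes and example values below are hypothetical, and the modulo bias is negligible when the bound is tiny relative to 2**63.

    import torch

    # Hypothetical buffers standing in for self.env.episode_length_buf and
    # self.env.max_episode_length (one entry per environment).
    episode_length_buf = torch.zeros(4, dtype=torch.long)
    max_episode_length = torch.tensor([100, 250, 500, 1000])

    # Draw uniform int64 samples over [0, 2**63 - 2], then wrap each sample
    # into [0, max_episode_length[i]) with an element-wise modulo.
    samples = torch.randint(
        2**63 - 1, size=episode_length_buf.shape, device=episode_length_buf.device
    )
    episode_length_buf = samples % max_episode_length
    print(episode_length_buf)  # each entry lies in [0, max_episode_length[i])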