diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index c2bc00c..b451618 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -89,9 +89,14 @@ class OnPolicyRunner: if init_at_random_ep_len: if isinstance(self.env.max_episode_length, float): raise ValueError("Cannot initialize at random episode length with float max_episode_length!") - self.env.episode_length_buf = torch.randint_like( - self.env.episode_length_buf, high=self.env.max_episode_length - ) + elif isinstance(self.env.max_episode_length, torch.Tensor): + # ref: https://github.com/pytorch/pytorch/issues/89438 + samples = torch.randint(2**63 - 1, size=self.env.episode_length_buf.shape, device=self.env.episode_length_buf.device) + self.env.episode_length_buf = samples % self.env.max_episode_length + else: + self.env.episode_length_buf = torch.randint_like( + self.env.episode_length_buf, high=self.env.max_episode_length + ) obs, extras = self.env.get_observations() critic_obs = extras["observations"].get("critic", obs) obs, critic_obs = obs.to(self.device), critic_obs.to(self.device)