diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py index 4fe58bf..c2bc00c 100644 --- a/rsl_rl/runners/on_policy_runner.py +++ b/rsl_rl/runners/on_policy_runner.py @@ -87,8 +87,10 @@ class OnPolicyRunner: raise AssertionError("logger type not found") if init_at_random_ep_len: + if isinstance(self.env.max_episode_length, float): + raise ValueError("Cannot initialize at random episode length with float max_episode_length!") self.env.episode_length_buf = torch.randint_like( - self.env.episode_length_buf, high=int(self.env.max_episode_length) + self.env.episode_length_buf, high=self.env.max_episode_length ) obs, extras = self.env.get_observations() critic_obs = extras["observations"].get("critic", obs)