Fix handling of tensor max episode length
This commit is contained in:
parent
73fd7c621b
commit
dbbae6f103
|
@ -87,8 +87,10 @@ class OnPolicyRunner:
|
||||||
raise AssertionError("logger type not found")
|
raise AssertionError("logger type not found")
|
||||||
|
|
||||||
if init_at_random_ep_len:
|
if init_at_random_ep_len:
|
||||||
|
if isinstance(self.env.max_episode_length, float):
|
||||||
|
raise ValueError("Cannot initialize at random episode length with float max_episode_length!")
|
||||||
self.env.episode_length_buf = torch.randint_like(
|
self.env.episode_length_buf = torch.randint_like(
|
||||||
self.env.episode_length_buf, high=int(self.env.max_episode_length)
|
self.env.episode_length_buf, high=self.env.max_episode_length
|
||||||
)
|
)
|
||||||
obs, extras = self.env.get_observations()
|
obs, extras = self.env.get_observations()
|
||||||
critic_obs = extras["observations"].get("critic", obs)
|
critic_obs = extras["observations"].get("critic", obs)
|
||||||
|
|
Loading…
Reference in New Issue