From b9047fbdd246143def84e8aa2bdb97e4bcc1811f Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Fri, 22 Mar 2024 13:25:23 +0000 Subject: [PATCH] fix environment seeding --- lerobot/common/envs/abstract.py | 12 ++++++++++-- lerobot/common/envs/aloha/env.py | 7 ++----- lerobot/common/envs/pusht/env.py | 8 ++++---- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/lerobot/common/envs/abstract.py b/lerobot/common/envs/abstract.py index a449e23f..01250d1c 100644 --- a/lerobot/common/envs/abstract.py +++ b/lerobot/common/envs/abstract.py @@ -4,6 +4,8 @@ from typing import Optional from tensordict import TensorDict from torchrl.envs import EnvBase +from lerobot.common.utils import set_seed + class AbstractEnv(EnvBase): def __init__( @@ -34,7 +36,13 @@ class AbstractEnv(EnvBase): self._make_env() self._make_spec() - self._current_seed = self.set_seed(seed) + + # self._next_seed will be used for the next reset. It is recommended that when self.set_seed is called + # you store the return value in self._next_seed (it will be a new randomly generated seed). + self._next_seed = seed + # Don't store the result of this in self._next_seed, as we want to make sure that the first time + # self._reset is called, we use seed. + self.set_seed(seed) if self.num_prev_obs > 0: self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs) @@ -59,4 +67,4 @@ class AbstractEnv(EnvBase): raise NotImplementedError("Abstract method") def _set_seed(self, seed: Optional[int]): - raise NotImplementedError("Abstract method") + set_seed(seed) diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py index af2b354b..a001ca55 100644 --- a/lerobot/common/envs/aloha/env.py +++ b/lerobot/common/envs/aloha/env.py @@ -126,9 +126,8 @@ class AlohaEnv(AbstractEnv): logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.") AlohaEnv._reset_warning_issued = True - # we need to handle seed iteration, since self._env.reset() rely an internal _seed. 
- self._current_seed += 1 - self.set_seed(self._current_seed) + # Seed the environment and update the seed to be used for the next reset. + self._next_seed = self.set_seed(self._next_seed) # TODO(rcadene): do not use global variable for this if "sim_transfer_cube" in self.task: @@ -137,8 +136,6 @@ class AlohaEnv(AbstractEnv): BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset raw_obs = self._env.reset() - # TODO(rcadene): add assert - # assert self._current_seed == self._env._seed obs = self._format_raw_obs(raw_obs.observation) diff --git a/lerobot/common/envs/pusht/env.py b/lerobot/common/envs/pusht/env.py index 3824a5d2..070c718f 100644 --- a/lerobot/common/envs/pusht/env.py +++ b/lerobot/common/envs/pusht/env.py @@ -106,11 +106,9 @@ class PushtEnv(AbstractEnv): logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.") PushtEnv._reset_warning_issued = True - # we need to handle seed iteration, since self._env.reset() rely an internal _seed. - self._current_seed += 1 - self.set_seed(self._current_seed) + # Seed the environment and update the seed to be used for the next reset. + self._next_seed = self.set_seed(self._next_seed) raw_obs = self._env.reset() - assert self._current_seed == self._env._seed obs = self._format_raw_obs(raw_obs) @@ -239,5 +237,7 @@ class PushtEnv(AbstractEnv): ) def _set_seed(self, seed: Optional[int]): + # Set global seed. set_seed(seed) + # Set PushTImageEnv seed as it relies on its own internal _seed attribute. self._env.seed(seed)