fix environment seeding

This commit is contained in:
Alexander Soare 2024-03-22 13:25:23 +00:00
parent b633748987
commit b9047fbdd2
3 changed files with 16 additions and 11 deletions

View File

@ -4,6 +4,8 @@ from typing import Optional
from tensordict import TensorDict
from torchrl.envs import EnvBase
from lerobot.common.utils import set_seed
class AbstractEnv(EnvBase):
def __init__(
@ -34,7 +36,13 @@ class AbstractEnv(EnvBase):
self._make_env()
self._make_spec()
self._current_seed = self.set_seed(seed)
# self._next_seed will be used for the next reset. It is recommended that when self.set_seed is called
# you store the return value in self._next_seed (it will be a new randomly generated seed).
self._next_seed = seed
# Don't store the result of this in self._next_seed, as we want to make sure that the first time
# self._reset is called, we use seed.
self.set_seed(seed)
if self.num_prev_obs > 0:
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
@ -59,4 +67,4 @@ class AbstractEnv(EnvBase):
raise NotImplementedError("Abstract method")
def _set_seed(self, seed: Optional[int]):
raise NotImplementedError("Abstract method")
set_seed(seed)

View File

@ -126,9 +126,8 @@ class AlohaEnv(AbstractEnv):
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
AlohaEnv._reset_warning_issued = True
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
# Seed the environment and update the seed to be used for the next reset.
self._next_seed = self.set_seed(self._next_seed)
# TODO(rcadene): do not use global variable for this
if "sim_transfer_cube" in self.task:
@ -137,8 +136,6 @@ class AlohaEnv(AbstractEnv):
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
raw_obs = self._env.reset()
# TODO(rcadene): add assert
# assert self._current_seed == self._env._seed
obs = self._format_raw_obs(raw_obs.observation)

View File

@ -106,11 +106,9 @@ class PushtEnv(AbstractEnv):
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
PushtEnv._reset_warning_issued = True
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
self._current_seed += 1
self.set_seed(self._current_seed)
# Seed the environment and update the seed to be used for the next reset.
self._next_seed = self.set_seed(self._next_seed)
raw_obs = self._env.reset()
assert self._current_seed == self._env._seed
obs = self._format_raw_obs(raw_obs)
@ -239,5 +237,7 @@ class PushtEnv(AbstractEnv):
)
def _set_seed(self, seed: Optional[int]):
# Set global seed.
set_seed(seed)
# Set PushTImageEnv seed as it relies on it's own internal _seed attribute.
self._env.seed(seed)