fix environment seeding
This commit is contained in:
parent
b633748987
commit
b9047fbdd2
|
@ -4,6 +4,8 @@ from typing import Optional
|
|||
from tensordict import TensorDict
|
||||
from torchrl.envs import EnvBase
|
||||
|
||||
from lerobot.common.utils import set_seed
|
||||
|
||||
|
||||
class AbstractEnv(EnvBase):
|
||||
def __init__(
|
||||
|
@ -34,7 +36,13 @@ class AbstractEnv(EnvBase):
|
|||
|
||||
self._make_env()
|
||||
self._make_spec()
|
||||
self._current_seed = self.set_seed(seed)
|
||||
|
||||
# self._next_seed will be used for the next reset. It is recommended that when self.set_seed is called
|
||||
# you store the return value in self._next_seed (it will be a new randomly generated seed).
|
||||
self._next_seed = seed
|
||||
# Don't store the result of this in self._next_seed, as we want to make sure that the first time
|
||||
# self._reset is called, we use seed.
|
||||
self.set_seed(seed)
|
||||
|
||||
if self.num_prev_obs > 0:
|
||||
self._prev_obs_image_queue = deque(maxlen=self.num_prev_obs)
|
||||
|
@ -59,4 +67,4 @@ class AbstractEnv(EnvBase):
|
|||
raise NotImplementedError("Abstract method")
|
||||
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
raise NotImplementedError("Abstract method")
|
||||
set_seed(seed)
|
||||
|
|
|
@ -126,9 +126,8 @@ class AlohaEnv(AbstractEnv):
|
|||
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
|
||||
AlohaEnv._reset_warning_issued = True
|
||||
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
self.set_seed(self._current_seed)
|
||||
# Seed the environment and update the seed to be used for the next reset.
|
||||
self._next_seed = self.set_seed(self._next_seed)
|
||||
|
||||
# TODO(rcadene): do not use global variable for this
|
||||
if "sim_transfer_cube" in self.task:
|
||||
|
@ -137,8 +136,6 @@ class AlohaEnv(AbstractEnv):
|
|||
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
|
||||
|
||||
raw_obs = self._env.reset()
|
||||
# TODO(rcadene): add assert
|
||||
# assert self._current_seed == self._env._seed
|
||||
|
||||
obs = self._format_raw_obs(raw_obs.observation)
|
||||
|
||||
|
|
|
@ -106,11 +106,9 @@ class PushtEnv(AbstractEnv):
|
|||
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
|
||||
PushtEnv._reset_warning_issued = True
|
||||
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
self.set_seed(self._current_seed)
|
||||
# Seed the environment and update the seed to be used for the next reset.
|
||||
self._next_seed = self.set_seed(self._next_seed)
|
||||
raw_obs = self._env.reset()
|
||||
assert self._current_seed == self._env._seed
|
||||
|
||||
obs = self._format_raw_obs(raw_obs)
|
||||
|
||||
|
@ -239,5 +237,7 @@ class PushtEnv(AbstractEnv):
|
|||
)
|
||||
|
||||
def _set_seed(self, seed: Optional[int]):
|
||||
# Set global seed.
|
||||
set_seed(seed)
|
||||
# Set PushTImageEnv seed as it relies on it's own internal _seed attribute.
|
||||
self._env.seed(seed)
|
||||
|
|
Loading…
Reference in New Issue