Move reset_warning_issued flag to class attribute
This commit is contained in:
parent
b54cdc9a0f
commit
18fa88475b
|
@ -35,6 +35,8 @@ _has_gym = importlib.util.find_spec("gym") is not None
|
|||
|
||||
|
||||
class AlohaEnv(AbstractEnv):
|
||||
_reset_warning_issued = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
|
@ -58,7 +60,6 @@ class AlohaEnv(AbstractEnv):
|
|||
num_prev_obs=num_prev_obs,
|
||||
num_prev_action=num_prev_action,
|
||||
)
|
||||
self._reset_warning_issued = False
|
||||
|
||||
def _make_env(self):
|
||||
if not _has_gym:
|
||||
|
@ -121,9 +122,9 @@ class AlohaEnv(AbstractEnv):
|
|||
return obs
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
if tensordict is not None and not self._reset_warning_issued:
|
||||
if tensordict is not None and not AlohaEnv._reset_warning_issued:
|
||||
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
|
||||
self._reset_warning_issued = True
|
||||
AlohaEnv._reset_warning_issued = True
|
||||
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
|
|
|
@ -20,6 +20,8 @@ _has_gym = importlib.util.find_spec("gym") is not None
|
|||
|
||||
|
||||
class PushtEnv(AbstractEnv):
|
||||
_reset_warning_issued = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task="pusht",
|
||||
|
@ -43,7 +45,6 @@ class PushtEnv(AbstractEnv):
|
|||
num_prev_obs=num_prev_obs,
|
||||
num_prev_action=num_prev_action,
|
||||
)
|
||||
self._reset_warning_issued = False
|
||||
|
||||
def _make_env(self):
|
||||
if not _has_gym:
|
||||
|
@ -81,9 +82,9 @@ class PushtEnv(AbstractEnv):
|
|||
return obs
|
||||
|
||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||
if tensordict is not None and not self._reset_warning_issued:
|
||||
if tensordict is not None and not PushtEnv._reset_warning_issued:
|
||||
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
|
||||
self._reset_warning_issued = True
|
||||
PushtEnv._reset_warning_issued = True
|
||||
|
||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||
self._current_seed += 1
|
||||
|
|
Loading…
Reference in New Issue