Move reset_warning_issued flag to class attribute

This commit is contained in:
Alexander Soare 2024-03-20 08:09:38 +00:00
parent b54cdc9a0f
commit 18fa88475b
2 changed files with 8 additions and 6 deletions

View File

@ -35,6 +35,8 @@ _has_gym = importlib.util.find_spec("gym") is not None
class AlohaEnv(AbstractEnv):
_reset_warning_issued = False
def __init__(
self,
task,
@ -58,7 +60,6 @@ class AlohaEnv(AbstractEnv):
num_prev_obs=num_prev_obs,
num_prev_action=num_prev_action,
)
self._reset_warning_issued = False
def _make_env(self):
if not _has_gym:
@ -121,9 +122,9 @@ class AlohaEnv(AbstractEnv):
return obs
def _reset(self, tensordict: Optional[TensorDict] = None):
if tensordict is not None and not self._reset_warning_issued:
if tensordict is not None and not AlohaEnv._reset_warning_issued:
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
self._reset_warning_issued = True
AlohaEnv._reset_warning_issued = True
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
self._current_seed += 1

View File

@ -20,6 +20,8 @@ _has_gym = importlib.util.find_spec("gym") is not None
class PushtEnv(AbstractEnv):
_reset_warning_issued = False
def __init__(
self,
task="pusht",
@ -43,7 +45,6 @@ class PushtEnv(AbstractEnv):
num_prev_obs=num_prev_obs,
num_prev_action=num_prev_action,
)
self._reset_warning_issued = False
def _make_env(self):
if not _has_gym:
@ -81,9 +82,9 @@ class PushtEnv(AbstractEnv):
return obs
def _reset(self, tensordict: Optional[TensorDict] = None):
if tensordict is not None and not self._reset_warning_issued:
if tensordict is not None and not PushtEnv._reset_warning_issued:
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
self._reset_warning_issued = True
PushtEnv._reset_warning_issued = True
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
self._current_seed += 1