Move reset_warning_issued flag to class attribute
This commit is contained in:
parent
b54cdc9a0f
commit
18fa88475b
|
@ -35,6 +35,8 @@ _has_gym = importlib.util.find_spec("gym") is not None
|
||||||
|
|
||||||
|
|
||||||
class AlohaEnv(AbstractEnv):
|
class AlohaEnv(AbstractEnv):
|
||||||
|
_reset_warning_issued = False
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
task,
|
task,
|
||||||
|
@ -58,7 +60,6 @@ class AlohaEnv(AbstractEnv):
|
||||||
num_prev_obs=num_prev_obs,
|
num_prev_obs=num_prev_obs,
|
||||||
num_prev_action=num_prev_action,
|
num_prev_action=num_prev_action,
|
||||||
)
|
)
|
||||||
self._reset_warning_issued = False
|
|
||||||
|
|
||||||
def _make_env(self):
|
def _make_env(self):
|
||||||
if not _has_gym:
|
if not _has_gym:
|
||||||
|
@ -121,9 +122,9 @@ class AlohaEnv(AbstractEnv):
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||||
if tensordict is not None and not self._reset_warning_issued:
|
if tensordict is not None and not AlohaEnv._reset_warning_issued:
|
||||||
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
|
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
|
||||||
self._reset_warning_issued = True
|
AlohaEnv._reset_warning_issued = True
|
||||||
|
|
||||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||||
self._current_seed += 1
|
self._current_seed += 1
|
||||||
|
|
|
@ -20,6 +20,8 @@ _has_gym = importlib.util.find_spec("gym") is not None
|
||||||
|
|
||||||
|
|
||||||
class PushtEnv(AbstractEnv):
|
class PushtEnv(AbstractEnv):
|
||||||
|
_reset_warning_issued = False
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
task="pusht",
|
task="pusht",
|
||||||
|
@ -43,7 +45,6 @@ class PushtEnv(AbstractEnv):
|
||||||
num_prev_obs=num_prev_obs,
|
num_prev_obs=num_prev_obs,
|
||||||
num_prev_action=num_prev_action,
|
num_prev_action=num_prev_action,
|
||||||
)
|
)
|
||||||
self._reset_warning_issued = False
|
|
||||||
|
|
||||||
def _make_env(self):
|
def _make_env(self):
|
||||||
if not _has_gym:
|
if not _has_gym:
|
||||||
|
@ -81,9 +82,9 @@ class PushtEnv(AbstractEnv):
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def _reset(self, tensordict: Optional[TensorDict] = None):
|
def _reset(self, tensordict: Optional[TensorDict] = None):
|
||||||
if tensordict is not None and not self._reset_warning_issued:
|
if tensordict is not None and not PushtEnv._reset_warning_issued:
|
||||||
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
|
logging.warning(f"{self.__class__.__name__}._reset ignores the provided tensordict.")
|
||||||
self._reset_warning_issued = True
|
PushtEnv._reset_warning_issued = True
|
||||||
|
|
||||||
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
# we need to handle seed iteration, since self._env.reset() rely an internal _seed.
|
||||||
self._current_seed += 1
|
self._current_seed += 1
|
||||||
|
|
Loading…
Reference in New Issue