From 7cdd6d24506a890f7b027d72a9da1fd42f8daa2a Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Mon, 25 Mar 2024 17:19:28 +0100 Subject: [PATCH] Renamed set_seed -> set_global_seed --- lerobot/common/envs/abstract.py | 4 ++-- lerobot/common/envs/aloha/env.py | 4 ++-- lerobot/common/envs/pusht/env.py | 4 ++-- lerobot/common/envs/simxarm/env.py | 6 ++++-- lerobot/common/utils.py | 2 +- lerobot/scripts/eval.py | 4 ++-- lerobot/scripts/train.py | 4 ++-- tests/test_datasets.py | 1 - tests/test_envs.py | 3 +-- 9 files changed, 16 insertions(+), 16 deletions(-) diff --git a/lerobot/common/envs/abstract.py b/lerobot/common/envs/abstract.py index 01250d1c..bca0af3e 100644 --- a/lerobot/common/envs/abstract.py +++ b/lerobot/common/envs/abstract.py @@ -4,7 +4,7 @@ from typing import Optional from tensordict import TensorDict from torchrl.envs import EnvBase -from lerobot.common.utils import set_seed +from lerobot.common.utils import set_global_seed class AbstractEnv(EnvBase): @@ -67,4 +67,4 @@ class AbstractEnv(EnvBase): raise NotImplementedError("Abstract method") def _set_seed(self, seed: Optional[int]): - set_seed(seed) + set_global_seed(seed) diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py index 8e735237..99f12cf0 100644 --- a/lerobot/common/envs/aloha/env.py +++ b/lerobot/common/envs/aloha/env.py @@ -29,7 +29,7 @@ from lerobot.common.envs.aloha.tasks.sim_end_effector import ( TransferCubeEndEffectorTask, ) from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose -from lerobot.common.utils import set_seed +from lerobot.common.utils import set_global_seed _has_gym = importlib.util.find_spec("gymnasium") is not None @@ -290,7 +290,7 @@ class AlohaEnv(AbstractEnv): ) def _set_seed(self, seed: Optional[int]): - set_seed(seed) + set_global_seed(seed) # TODO(rcadene): seed the env # self._env.seed(seed) logging.warning("Aloha env is not seeded") diff --git a/lerobot/common/envs/pusht/env.py b/lerobot/common/envs/pusht/env.py index 22dbef7f..6c9d211d 100644 --- a/lerobot/common/envs/pusht/env.py +++ b/lerobot/common/envs/pusht/env.py @@ -16,7 +16,7 @@ from torchrl.data.tensor_specs import ( from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform from lerobot.common.envs.abstract import AbstractEnv -from lerobot.common.utils import set_seed +from lerobot.common.utils import set_global_seed _has_gym = importlib.util.find_spec("gymnasium") is not None @@ -238,6 +238,6 @@ class PushtEnv(AbstractEnv): def _set_seed(self, seed: Optional[int]): # Set global seed. - set_seed(seed) + set_global_seed(seed) # Set PushTImageEnv seed as it relies on it's own internal _seed attribute. self._env.seed(seed) diff --git a/lerobot/common/envs/simxarm/env.py b/lerobot/common/envs/simxarm/env.py index 9b08be6a..fc66d013 100644 --- a/lerobot/common/envs/simxarm/env.py +++ b/lerobot/common/envs/simxarm/env.py @@ -1,4 +1,5 @@ import importlib +import logging from collections import deque from typing import Optional @@ -15,7 +16,7 @@ from torchrl.data.tensor_specs import ( from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform from lerobot.common.envs.abstract import AbstractEnv -from lerobot.common.utils import set_seed +from lerobot.common.utils import set_global_seed MAX_NUM_ACTIONS = 4 @@ -229,8 +230,9 @@ class SimxarmEnv(AbstractEnv): ) def _set_seed(self, seed: Optional[int]): - set_seed(seed) + set_global_seed(seed) # self._env.seed(seed) # self._env.action_space.seed(seed) # self.set_seed(seed) + logging.warning("simxarm env is not seeded") self._seed = seed diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py index a56543b7..2af1d966 100644 --- a/lerobot/common/utils.py +++ b/lerobot/common/utils.py @@ -26,7 +26,7 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device: return device -def set_seed(seed): +def set_global_seed(seed): """Set seed for reproducibility.""" random.seed(seed) np.random.seed(seed) diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index b3d107ab..e30cd9dd 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -50,7 +50,7 @@ from lerobot.common.envs.factory import make_env from lerobot.common.logger import log_output_dir from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.factory import make_policy -from lerobot.common.utils import get_safe_torch_device, init_logging, set_seed +from lerobot.common.utils import get_safe_torch_device, init_logging, set_global_seed def write_video(video_path, stacked_frames, fps): @@ -188,7 +188,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True - set_seed(cfg.seed) + set_global_seed(cfg.seed) log_output_dir(out_dir) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 91d2cf00..3a45ddc9 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -12,7 +12,7 @@ from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env from lerobot.common.logger import Logger, log_output_dir from lerobot.common.policies.factory import make_policy -from lerobot.common.utils import format_big_number, get_safe_torch_device, init_logging, set_seed +from lerobot.common.utils import format_big_number, get_safe_torch_device, init_logging, set_global_seed from lerobot.scripts.eval import eval_policy @@ -122,7 +122,7 @@ def train(cfg: dict, out_dir=None, job_name=None): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True - set_seed(cfg.seed) + set_global_seed(cfg.seed) logging.info("make_offline_buffer") offline_buffer = make_offline_buffer(cfg) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index c3fcfccd..252e0046 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -9,7 +9,6 @@ from .utils import DEVICE, init_config @pytest.mark.parametrize( "env_name,dataset_id", [ - # TODO(rcadene): simxarm is depreciated for now ("simxarm", "lift"), ("pusht", "pusht"), ("aloha", "sim_insertion_human"), diff --git a/tests/test_envs.py b/tests/test_envs.py index 6535535e..1db83afd 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -39,14 +39,13 @@ def print_spec_rollout(env): print("data from rollout:", simple_rollout(100)) -# @pytest.mark.skip(reason="Simxarm is deprecated") @pytest.mark.parametrize( "task,from_pixels,pixels_only", [ ("lift", False, False), ("lift", True, False), ("lift", True, True), - # TODO(aliberts): Add simxarm other task or remove them completely from repo + # TODO(aliberts): Add simxarm other tasks # ("reach", False, False), # ("reach", True, False), # ("push", False, False),