Renamed set_seed -> set_global_seed
This commit is contained in:
parent
058ac991eb
commit
7cdd6d2450
|
@ -4,7 +4,7 @@ from typing import Optional
|
||||||
from tensordict import TensorDict
|
from tensordict import TensorDict
|
||||||
from torchrl.envs import EnvBase
|
from torchrl.envs import EnvBase
|
||||||
|
|
||||||
from lerobot.common.utils import set_seed
|
from lerobot.common.utils import set_global_seed
|
||||||
|
|
||||||
|
|
||||||
class AbstractEnv(EnvBase):
|
class AbstractEnv(EnvBase):
|
||||||
|
@ -67,4 +67,4 @@ class AbstractEnv(EnvBase):
|
||||||
raise NotImplementedError("Abstract method")
|
raise NotImplementedError("Abstract method")
|
||||||
|
|
||||||
def _set_seed(self, seed: Optional[int]):
|
def _set_seed(self, seed: Optional[int]):
|
||||||
set_seed(seed)
|
set_global_seed(seed)
|
||||||
|
|
|
@ -29,7 +29,7 @@ from lerobot.common.envs.aloha.tasks.sim_end_effector import (
|
||||||
TransferCubeEndEffectorTask,
|
TransferCubeEndEffectorTask,
|
||||||
)
|
)
|
||||||
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
|
from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
|
||||||
from lerobot.common.utils import set_seed
|
from lerobot.common.utils import set_global_seed
|
||||||
|
|
||||||
_has_gym = importlib.util.find_spec("gymnasium") is not None
|
_has_gym = importlib.util.find_spec("gymnasium") is not None
|
||||||
|
|
||||||
|
@ -290,7 +290,7 @@ class AlohaEnv(AbstractEnv):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _set_seed(self, seed: Optional[int]):
|
def _set_seed(self, seed: Optional[int]):
|
||||||
set_seed(seed)
|
set_global_seed(seed)
|
||||||
# TODO(rcadene): seed the env
|
# TODO(rcadene): seed the env
|
||||||
# self._env.seed(seed)
|
# self._env.seed(seed)
|
||||||
logging.warning("Aloha env is not seeded")
|
logging.warning("Aloha env is not seeded")
|
||||||
|
|
|
@ -16,7 +16,7 @@ from torchrl.data.tensor_specs import (
|
||||||
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
||||||
|
|
||||||
from lerobot.common.envs.abstract import AbstractEnv
|
from lerobot.common.envs.abstract import AbstractEnv
|
||||||
from lerobot.common.utils import set_seed
|
from lerobot.common.utils import set_global_seed
|
||||||
|
|
||||||
_has_gym = importlib.util.find_spec("gymnasium") is not None
|
_has_gym = importlib.util.find_spec("gymnasium") is not None
|
||||||
|
|
||||||
|
@ -238,6 +238,6 @@ class PushtEnv(AbstractEnv):
|
||||||
|
|
||||||
def _set_seed(self, seed: Optional[int]):
|
def _set_seed(self, seed: Optional[int]):
|
||||||
# Set global seed.
|
# Set global seed.
|
||||||
set_seed(seed)
|
set_global_seed(seed)
|
||||||
# Set PushTImageEnv seed as it relies on it's own internal _seed attribute.
|
# Set PushTImageEnv seed as it relies on it's own internal _seed attribute.
|
||||||
self._env.seed(seed)
|
self._env.seed(seed)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import importlib
|
import importlib
|
||||||
|
import logging
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
@ -15,7 +16,7 @@ from torchrl.data.tensor_specs import (
|
||||||
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
from torchrl.envs.libs.gym import _gym_to_torchrl_spec_transform
|
||||||
|
|
||||||
from lerobot.common.envs.abstract import AbstractEnv
|
from lerobot.common.envs.abstract import AbstractEnv
|
||||||
from lerobot.common.utils import set_seed
|
from lerobot.common.utils import set_global_seed
|
||||||
|
|
||||||
MAX_NUM_ACTIONS = 4
|
MAX_NUM_ACTIONS = 4
|
||||||
|
|
||||||
|
@ -229,8 +230,9 @@ class SimxarmEnv(AbstractEnv):
|
||||||
)
|
)
|
||||||
|
|
||||||
def _set_seed(self, seed: Optional[int]):
|
def _set_seed(self, seed: Optional[int]):
|
||||||
set_seed(seed)
|
set_global_seed(seed)
|
||||||
# self._env.seed(seed)
|
# self._env.seed(seed)
|
||||||
# self._env.action_space.seed(seed)
|
# self._env.action_space.seed(seed)
|
||||||
# self.set_seed(seed)
|
# self.set_seed(seed)
|
||||||
|
logging.warning("simxarm env is not seeded")
|
||||||
self._seed = seed
|
self._seed = seed
|
||||||
|
|
|
@ -26,7 +26,7 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
|
||||||
return device
|
return device
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed):
|
def set_global_seed(seed):
|
||||||
"""Set seed for reproducibility."""
|
"""Set seed for reproducibility."""
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
|
|
|
@ -50,7 +50,7 @@ from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.logger import log_output_dir
|
from lerobot.common.logger import log_output_dir
|
||||||
from lerobot.common.policies.abstract import AbstractPolicy
|
from lerobot.common.policies.abstract import AbstractPolicy
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.utils import get_safe_torch_device, init_logging, set_seed
|
from lerobot.common.utils import get_safe_torch_device, init_logging, set_global_seed
|
||||||
|
|
||||||
|
|
||||||
def write_video(video_path, stacked_frames, fps):
|
def write_video(video_path, stacked_frames, fps):
|
||||||
|
@ -188,7 +188,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
set_seed(cfg.seed)
|
set_global_seed(cfg.seed)
|
||||||
|
|
||||||
log_output_dir(out_dir)
|
log_output_dir(out_dir)
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ from lerobot.common.datasets.factory import make_offline_buffer
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.logger import Logger, log_output_dir
|
from lerobot.common.logger import Logger, log_output_dir
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.utils import format_big_number, get_safe_torch_device, init_logging, set_seed
|
from lerobot.common.utils import format_big_number, get_safe_torch_device, init_logging, set_global_seed
|
||||||
from lerobot.scripts.eval import eval_policy
|
from lerobot.scripts.eval import eval_policy
|
||||||
|
|
||||||
|
|
||||||
|
@ -122,7 +122,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
set_seed(cfg.seed)
|
set_global_seed(cfg.seed)
|
||||||
|
|
||||||
logging.info("make_offline_buffer")
|
logging.info("make_offline_buffer")
|
||||||
offline_buffer = make_offline_buffer(cfg)
|
offline_buffer = make_offline_buffer(cfg)
|
||||||
|
|
|
@ -9,7 +9,6 @@ from .utils import DEVICE, init_config
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"env_name,dataset_id",
|
"env_name,dataset_id",
|
||||||
[
|
[
|
||||||
# TODO(rcadene): simxarm is depreciated for now
|
|
||||||
("simxarm", "lift"),
|
("simxarm", "lift"),
|
||||||
("pusht", "pusht"),
|
("pusht", "pusht"),
|
||||||
("aloha", "sim_insertion_human"),
|
("aloha", "sim_insertion_human"),
|
||||||
|
|
|
@ -39,14 +39,13 @@ def print_spec_rollout(env):
|
||||||
print("data from rollout:", simple_rollout(100))
|
print("data from rollout:", simple_rollout(100))
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.skip(reason="Simxarm is deprecated")
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"task,from_pixels,pixels_only",
|
"task,from_pixels,pixels_only",
|
||||||
[
|
[
|
||||||
("lift", False, False),
|
("lift", False, False),
|
||||||
("lift", True, False),
|
("lift", True, False),
|
||||||
("lift", True, True),
|
("lift", True, True),
|
||||||
# TODO(aliberts): Add simxarm other task or remove them completely from repo
|
# TODO(aliberts): Add simxarm other tasks
|
||||||
# ("reach", False, False),
|
# ("reach", False, False),
|
||||||
# ("reach", True, False),
|
# ("reach", True, False),
|
||||||
# ("push", False, False),
|
# ("push", False, False),
|
||||||
|
|
Loading…
Reference in New Issue