From 5e9473806c78a969f13e2faf941ba1b2950649c4 Mon Sep 17 00:00:00 2001
From: Steven Palma
Date: Thu, 6 Mar 2025 17:59:28 +0100
Subject: [PATCH] refactor(config): Move device & amp args to PreTrainedConfig
 (#812)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
---
 Makefile                                      | 18 +++++-----
 lerobot/common/policies/factory.py            |  9 ++---
 .../common/policies/pi0/configuration_pi0.py  |  1 +
 .../pi0/conversion_scripts/benchmark.py       |  2 +-
 .../conversion_scripts/compare_with_jax.py    |  2 +-
 lerobot/common/policies/pretrained.py         |  7 ++--
 .../common/robot_devices/control_configs.py   | 29 ----------------
 lerobot/common/robot_devices/control_utils.py | 16 +++------
 lerobot/common/utils/utils.py                 |  5 ++-
 lerobot/configs/eval.py                       | 33 -------------------
 lerobot/configs/policies.py                   | 18 ++++++++++
 lerobot/configs/train.py                      | 18 ----------
 lerobot/scripts/control_robot.py              |  4 +--
 lerobot/scripts/eval.py                       |  6 ++--
 lerobot/scripts/train.py                      | 12 ++++---
 tests/scripts/save_policy_to_safetensors.py   |  3 +-
 tests/test_control_robot.py                   |  6 ++--
 tests/test_datasets.py                        |  3 +-
 tests/test_policies.py                        |  6 ++--
 19 files changed, 62 insertions(+), 136 deletions(-)

diff --git a/Makefile b/Makefile
index 68f07b21..c82483cc 100644
--- a/Makefile
+++ b/Makefile
@@ -47,6 +47,7 @@ test-act-ete-train:
 		--policy.dim_model=64 \
 		--policy.n_action_steps=20 \
 		--policy.chunk_size=20 \
+		--policy.device=$(DEVICE) \
 		--env.type=aloha \
 		--env.episode_length=5 \
 		--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
@@ -61,7 +62,6 @@ test-act-ete-train:
 		--save_checkpoint=true \
 		--log_freq=1 \
 		--wandb.enable=false \
-		--device=$(DEVICE) \
 		--output_dir=tests/outputs/act/
 
 test-act-ete-train-resume:
@@ -72,11 +72,11 @@ test-act-ete-train-resume:
 test-act-ete-eval:
 	python lerobot/scripts/eval.py \
 		--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
+		--policy.device=$(DEVICE) \
 		--env.type=aloha \
 		--env.episode_length=5 \
 		--eval.n_episodes=1 \
-		--eval.batch_size=1 \
-		--device=$(DEVICE)
+		--eval.batch_size=1
 
 test-diffusion-ete-train:
 	python lerobot/scripts/train.py \
@@ -84,6 +84,7 @@ test-diffusion-ete-train:
 		--policy.down_dims='[64,128,256]' \
 		--policy.diffusion_step_embed_dim=32 \
 		--policy.num_inference_steps=10 \
+		--policy.device=$(DEVICE) \
 		--env.type=pusht \
 		--env.episode_length=5 \
 		--dataset.repo_id=lerobot/pusht \
@@ -98,21 +99,21 @@ test-diffusion-ete-train:
 		--save_freq=2 \
 		--log_freq=1 \
 		--wandb.enable=false \
-		--device=$(DEVICE) \
 		--output_dir=tests/outputs/diffusion/
 
 test-diffusion-ete-eval:
 	python lerobot/scripts/eval.py \
 		--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
+		--policy.device=$(DEVICE) \
 		--env.type=pusht \
 		--env.episode_length=5 \
 		--eval.n_episodes=1 \
-		--eval.batch_size=1 \
-		--device=$(DEVICE)
+		--eval.batch_size=1
 
 test-tdmpc-ete-train:
 	python lerobot/scripts/train.py \
 		--policy.type=tdmpc \
+		--policy.device=$(DEVICE) \
 		--env.type=xarm \
 		--env.task=XarmLift-v0 \
 		--env.episode_length=5 \
@@ -128,15 +129,14 @@ test-tdmpc-ete-train:
 		--save_freq=2 \
 		--log_freq=1 \
 		--wandb.enable=false \
-		--device=$(DEVICE) \
 		--output_dir=tests/outputs/tdmpc/
 
 test-tdmpc-ete-eval:
 	python lerobot/scripts/eval.py \
 		--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
+		--policy.device=$(DEVICE) \
 		--env.type=xarm \
 		--env.episode_length=5 \
 		--env.task=XarmLift-v0 \
 		--eval.n_episodes=1 \
-		--eval.batch_size=1 \
-		--device=$(DEVICE)
+		--eval.batch_size=1
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index cd440f7a..5d2f6cb5 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -16,7 +16,6 @@
 
 import logging
 
-import torch
 from torch import nn
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
@@ -76,7 +75,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
 
 def make_policy(
     cfg: PreTrainedConfig,
-    device: str | torch.device,
     ds_meta: LeRobotDatasetMetadata | None = None,
     env_cfg: EnvConfig | None = None,
 ) -> PreTrainedPolicy:
@@ -88,7 +86,6 @@ def make_policy(
     Args:
         cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy
             will be loaded with the weights from that path.
-        device (str): the device to load the policy onto.
         ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
             statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
         env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
@@ -96,7 +93,7 @@ def make_policy(
 
     Raises:
         ValueError: Either ds_meta or env and env_cfg must be provided.
-        NotImplementedError: if the policy.type is 'vqbet' and the device 'mps' (due to an incompatibility)
+        NotImplementedError: if the policy.type is 'vqbet' and the policy device is 'mps' (due to an incompatibility)
 
     Returns:
         PreTrainedPolicy: _description_
@@ -111,7 +108,7 @@
     # https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
     # variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
     # slower than running natively on MPS.
-    if cfg.type == "vqbet" and str(device) == "mps":
+    if cfg.type == "vqbet" and cfg.device == "mps":
         raise NotImplementedError(
             "Current implementation of VQBeT does not support `mps` backend. "
             "Please use `cpu` or `cuda` backend."
@@ -145,7 +142,7 @@
         # Make a fresh policy.
         policy = policy_cls(**kwargs)
 
-    policy.to(device)
+    policy.to(cfg.device)
     assert isinstance(policy, nn.Module)
 
     # policy = torch.compile(policy, mode="reduce-overhead")
diff --git a/lerobot/common/policies/pi0/configuration_pi0.py b/lerobot/common/policies/pi0/configuration_pi0.py
index ac3a62a8..8c7cc130 100644
--- a/lerobot/common/policies/pi0/configuration_pi0.py
+++ b/lerobot/common/policies/pi0/configuration_pi0.py
@@ -90,6 +90,7 @@ class PI0Config(PreTrainedConfig):
     def __post_init__(self):
         super().__post_init__()
 
+        # TODO(Steven): Validate device and amp? in all policy configs?
"""Input validation (not exhaustive).""" if self.n_action_steps > self.chunk_size: raise ValueError( diff --git a/lerobot/common/policies/pi0/conversion_scripts/benchmark.py b/lerobot/common/policies/pi0/conversion_scripts/benchmark.py index 35b9a45b..cb3c0e9b 100644 --- a/lerobot/common/policies/pi0/conversion_scripts/benchmark.py +++ b/lerobot/common/policies/pi0/conversion_scripts/benchmark.py @@ -45,7 +45,7 @@ def main(): cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir) cfg.pretrained_path = ckpt_torch_dir - policy = make_policy(cfg, device, ds_meta=dataset.meta) + policy = make_policy(cfg, ds_meta=dataset.meta) # policy = torch.compile(policy, mode="reduce-overhead") diff --git a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py index ceb8ada0..6bd7c91f 100644 --- a/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py +++ b/lerobot/common/policies/pi0/conversion_scripts/compare_with_jax.py @@ -101,7 +101,7 @@ def main(): cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir) cfg.pretrained_path = ckpt_torch_dir - policy = make_policy(cfg, device, dataset_meta) + policy = make_policy(cfg, dataset_meta) # loss_dict = policy.forward(batch, noise=noise, time=time_beta) # loss_dict["loss"].backward() diff --git a/lerobot/common/policies/pretrained.py b/lerobot/common/policies/pretrained.py index 549ea92b..da4ef157 100644 --- a/lerobot/common/policies/pretrained.py +++ b/lerobot/common/policies/pretrained.py @@ -86,7 +86,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): cache_dir: str | Path | None = None, local_files_only: bool = False, revision: str | None = None, - map_location: str = "cpu", strict: bool = False, **kwargs, ) -> T: @@ -111,7 +110,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): if os.path.isdir(model_id): print("Loading weights from local directory") model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE) - policy = cls._load_as_safetensor(instance, model_file, map_location, strict) + policy = cls._load_as_safetensor(instance, model_file, config.device, strict) else: try: model_file = hf_hub_download( @@ -125,13 +124,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): token=token, local_files_only=local_files_only, ) - policy = cls._load_as_safetensor(instance, model_file, map_location, strict) + policy = cls._load_as_safetensor(instance, model_file, config.device, strict) except HfHubHTTPError as e: raise FileNotFoundError( f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}" ) from e - policy.to(map_location) + policy.to(config.device) policy.eval() return policy diff --git a/lerobot/common/robot_devices/control_configs.py b/lerobot/common/robot_devices/control_configs.py index 01908472..0ecd8683 100644 --- a/lerobot/common/robot_devices/control_configs.py +++ b/lerobot/common/robot_devices/control_configs.py @@ -12,17 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import logging
 from dataclasses import dataclass
 from pathlib import Path
 
 import draccus
 
 from lerobot.common.robot_devices.robots.configs import RobotConfig
-from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
 from lerobot.configs import parser
 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.train import TrainPipelineConfig
 
 
 @dataclass
@@ -57,11 +54,6 @@ class RecordControlConfig(ControlConfig):
     # Root directory where the dataset will be stored (e.g. 'dataset/path').
     root: str | Path | None = None
     policy: PreTrainedConfig | None = None
-    # TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
-    device: str | None = None  # cuda | cpu | mps
-    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
-    # automatic gradient scaling is used.
-    use_amp: bool | None = None
     # Limit the frames per second. By default, uses the policy fps.
     fps: int | None = None
     # Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
@@ -104,27 +96,6 @@ class RecordControlConfig(ControlConfig):
             self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
             self.policy.pretrained_path = policy_path
 
-        # When no device or use_amp are given, use the one from training config.
-        if self.device is None or self.use_amp is None:
-            train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
-            if self.device is None:
-                self.device = train_cfg.device
-            if self.use_amp is None:
-                self.use_amp = train_cfg.use_amp
-
-        # Automatically switch to available device if necessary
-        if not is_torch_device_available(self.device):
-            auto_device = auto_select_torch_device()
-            logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
-            self.device = auto_device
-
-        # Automatically deactivate AMP if necessary
-        if self.use_amp and not is_amp_available(self.device):
-            logging.warning(
-                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
-            )
-            self.use_amp = False
-
 
 @ControlConfig.register_subclass("replay")
 @dataclass
diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py
index 4782959a..78a8c6a6 100644
--- a/lerobot/common/robot_devices/control_utils.py
+++ b/lerobot/common/robot_devices/control_utils.py
@@ -32,6 +32,7 @@ from termcolor import colored
 from lerobot.common.datasets.image_writer import safe_stop_image_writer
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.utils import get_features_from_robot
+from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.robot_devices.utils import busy_wait
 from lerobot.common.utils.utils import get_safe_torch_device, has_method
@@ -193,8 +194,6 @@ def record_episode(
     episode_time_s,
     display_cameras,
     policy,
-    device,
-    use_amp,
     fps,
     single_task,
 ):
@@ -205,8 +204,6 @@
         dataset=dataset,
         events=events,
         policy=policy,
-        device=device,
-        use_amp=use_amp,
         fps=fps,
         teleoperate=policy is None,
         single_task=single_task,
@@ -221,9 +218,7 @@ def control_loop(
     display_cameras=False,
     dataset: LeRobotDataset | None = None,
     events=None,
-    policy=None,
-    device: torch.device | str | None = None,
-    use_amp: bool | None = None,
+    policy: PreTrainedPolicy | None = None,
     fps: int | None = None,
     single_task: str | None = None,
 ):
@@ -246,9 +241,6 @@ def control_loop(
     if dataset is not None and fps is not None and dataset.fps != fps:
         raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
 
-    if isinstance(device, str):
-        device = get_safe_torch_device(device)
-
     timestamp = 0
     start_episode_t = time.perf_counter()
     while timestamp < control_time_s:
@@ -260,7 +252,9 @@ def control_loop(
         observation = robot.capture_observation()
 
         if policy is not None:
-            pred_action = predict_action(observation, policy, device, use_amp)
+            pred_action = predict_action(
+                observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
+            )
             # Action can eventually be clipped using `max_relative_target`,
             # so action actually sent is saved in the dataset.
             action = robot.send_action(pred_action)
diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py
index cd26f04b..563a7b81 100644
--- a/lerobot/common/utils/utils.py
+++ b/lerobot/common/utils/utils.py
@@ -51,8 +51,10 @@ def auto_select_torch_device() -> torch.device:
         return torch.device("cpu")
 
 
+# TODO(Steven): Remove `log`; it shouldn't be an argument, this should be handled by the logger level
 def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
     """Given a string, return a torch.device with checks on whether the device is available."""
+    try_device = str(try_device)
     match try_device:
         case "cuda":
             assert torch.cuda.is_available()
@@ -85,6 +87,7 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
 
 
 def is_torch_device_available(try_device: str) -> bool:
+    try_device = str(try_device)  # Ensure try_device is a string
     if try_device == "cuda":
         return torch.cuda.is_available()
     elif try_device == "mps":
@@ -92,7 +95,7 @@ def is_torch_device_available(try_device: str) -> bool:
     elif try_device == "cpu":
         return True
     else:
-        raise ValueError(f"Unknown device '{try_device}.")
+        raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
 
 
 def is_amp_available(device: str):
diff --git a/lerobot/configs/eval.py b/lerobot/configs/eval.py
index eb2a6df9..16b35291 100644
--- a/lerobot/configs/eval.py
+++ b/lerobot/configs/eval.py
@@ -18,11 +18,9 @@ from dataclasses import dataclass, field
 from pathlib import Path
 
 from lerobot.common import envs, policies  # noqa: F401
-from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
 from lerobot.configs import parser
 from lerobot.configs.default import EvalConfig
 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.train import TrainPipelineConfig
 
 
 @dataclass
@@ -35,11 +33,6 @@ class EvalPipelineConfig:
     policy: PreTrainedConfig | None = None
     output_dir: Path | None = None
     job_name: str | None = None
-    # TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
-    device: str | None = None  # cuda | cpu | mps
-    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
-    # automatic gradient scaling is used.
-    use_amp: bool = False
     seed: int | None = 1000
 
     def __post_init__(self):
@@ -50,27 +43,6 @@ class EvalPipelineConfig:
             self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
             self.policy.pretrained_path = policy_path
 
-            # When no device or use_amp are given, use the one from training config.
-            if self.device is None or self.use_amp is None:
-                train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
-                if self.device is None:
-                    self.device = train_cfg.device
-                if self.use_amp is None:
-                    self.use_amp = train_cfg.use_amp
-
-            # Automatically switch to available device if necessary
-            if not is_torch_device_available(self.device):
-                auto_device = auto_select_torch_device()
-                logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
-                self.device = auto_device
-
-            # Automatically deactivate AMP if necessary
-            if self.use_amp and not is_amp_available(self.device):
-                logging.warning(
-                    f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
-                )
-                self.use_amp = False
-
         else:
             logging.warning(
                 "No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
@@ -87,11 +59,6 @@ class EvalPipelineConfig:
             eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
             self.output_dir = Path("outputs/eval") / eval_dir
 
-        if self.device is None:
-            raise ValueError("Set one of the following device: cuda, cpu or mps")
-        elif self.device == "cuda" and self.use_amp is None:
-            raise ValueError("Set 'use_amp' to True or False.")
-
     @classmethod
     def __get_path_fields__(cls) -> list[str]:
         """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
diff --git a/lerobot/configs/policies.py b/lerobot/configs/policies.py
index 4f52b16c..022d1fb5 100644
--- a/lerobot/configs/policies.py
+++ b/lerobot/configs/policies.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
+import logging
 import os
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -25,6 +26,7 @@ from huggingface_hub.errors import HfHubHTTPError
 from lerobot.common.optim.optimizers import OptimizerConfig
 from lerobot.common.optim.schedulers import LRSchedulerConfig
 from lerobot.common.utils.hub import HubMixin
+from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
 
 # Generic variable that is either PreTrainedConfig or a subclass thereof
@@ -53,8 +55,24 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
     input_features: dict[str, PolicyFeature] = field(default_factory=dict)
     output_features: dict[str, PolicyFeature] = field(default_factory=dict)
 
+    device: str | None = None  # cuda | cpu | mps
+    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
+    # automatic gradient scaling is used.
+    use_amp: bool = False
+
     def __post_init__(self):
         self.pretrained_path = None
+        if not self.device or not is_torch_device_available(self.device):
+            auto_device = auto_select_torch_device()
+            logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
+            self.device = auto_device.type
+
+        # Automatically deactivate AMP if necessary
+        if self.use_amp and not is_amp_available(self.device):
+            logging.warning(
+                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
+            )
+            self.use_amp = False
 
     @property
     def type(self) -> str:
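The `PreTrainedConfig` hunk above is the heart of the refactor: device and AMP resolution now happens once, in `__post_init__`, for every policy config subclass. A minimal sketch of the resulting behavior, assuming a checkout with this patch applied and using `ACTConfig` merely as one concrete subclass (actual outcomes depend on which torch backends the host exposes):

    from lerobot.common.policies.act.configuration_act import ACTConfig

    # An unset or unavailable device falls back to auto_select_torch_device().
    # On a CPU-only host this logs a warning and resolves "cuda" -> "cpu".
    cfg = ACTConfig(device="cuda")
    print(cfg.device)  # e.g. "cpu" on a machine without CUDA

    # use_amp is validated against the *resolved* device; is_amp_available()
    # reports False for "mps", so AMP is switched off there with a warning.
    cfg = ACTConfig(device="mps", use_amp=True)
    print(cfg.use_amp)  # False on an Apple-silicon host where "mps" resolves

Because this runs at construction time, downstream code no longer needs its own fallback logic, which is why the pipeline configs' equivalent blocks are deleted elsewhere in this patch.
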
diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py
index 33e98b88..2b147a5b 100644
--- a/lerobot/configs/train.py
+++ b/lerobot/configs/train.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import datetime as dt
-import logging
 import os
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -26,7 +25,6 @@ from lerobot.common import envs
 from lerobot.common.optim import OptimizerConfig
 from lerobot.common.optim.schedulers import LRSchedulerConfig
 from lerobot.common.utils.hub import HubMixin
-from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available
 from lerobot.configs import parser
 from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
 from lerobot.configs.policies import PreTrainedConfig
@@ -48,10 +46,6 @@ class TrainPipelineConfig(HubMixin):
     # Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
     # regardless of what's provided with the training command at the time of resumption.
     resume: bool = False
-    device: str | None = None  # cuda | cpu | mp
-    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
-    # automatic gradient scaling is used.
-    use_amp: bool = False
     # `seed` is used for training (eg: model initialization, dataset shuffling)
     # AND for the evaluation environments.
     seed: int | None = 1000
@@ -74,18 +68,6 @@
         self.checkpoint_path = None
 
     def validate(self):
-        if not self.device:
-            logging.warning("No device specified, trying to infer device automatically")
-            device = auto_select_torch_device()
-            self.device = device.type
-
-        # Automatically deactivate AMP if necessary
-        if self.use_amp and not is_amp_available(self.device):
-            logging.warning(
-                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
-            )
-            self.use_amp = False
-
         # HACK: We parse again the cli args here to get the pretrained paths if there was some.
         policy_path = parser.get_path_arg("policy")
         if policy_path:
diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py
index 3a82e5c3..3c3c43f9 100644
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -267,7 +267,7 @@ def record(
     )
 
     # Load pretrained policy
-    policy = None if cfg.policy is None else make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)
+    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
 
     if not robot.is_connected:
         robot.connect()
@@ -298,8 +298,6 @@ def record(
             episode_time_s=cfg.episode_time_s,
             display_cameras=cfg.display_cameras,
             policy=policy,
-            device=cfg.device,
-            use_amp=cfg.use_amp,
             fps=cfg.fps,
             single_task=cfg.single_task,
         )
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 47225993..d7a4201f 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -458,7 +458,7 @@ def eval_main(cfg: EvalPipelineConfig):
     logging.info(pformat(asdict(cfg)))
 
     # Check device is available
-    device = get_safe_torch_device(cfg.device, log=True)
+    device = get_safe_torch_device(cfg.policy.device, log=True)
 
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -470,14 +470,14 @@
     env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
 
     logging.info("Making policy.")
+
     policy = make_policy(
         cfg=cfg.policy,
-        device=device,
         env_cfg=cfg.env,
     )
     policy.eval()
 
-    with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+    with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
         info = eval_policy(
             env,
             policy,
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index e36c697a..f2b1e29e 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -120,7 +120,7 @@ def train(cfg: TrainPipelineConfig):
         set_seed(cfg.seed)
 
     # Check device is available
-    device = get_safe_torch_device(cfg.device, log=True)
+    device = get_safe_torch_device(cfg.policy.device, log=True)
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
 
@@ -138,13 +138,12 @@
     logging.info("Creating policy")
     policy = make_policy(
         cfg=cfg.policy,
-        device=device,
         ds_meta=dataset.meta,
     )
 
     logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
-    grad_scaler = GradScaler(device, enabled=cfg.use_amp)
+    grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
 
     step = 0  # number of policy updates (forward + backward + optim)
 
@@ -218,7 +217,7 @@
             cfg.optimizer.grad_clip_norm,
             grad_scaler=grad_scaler,
             lr_scheduler=lr_scheduler,
-            use_amp=cfg.use_amp,
+            use_amp=cfg.policy.use_amp,
         )
 
         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -249,7 +248,10 @@
         if cfg.env and is_eval_step:
             step_id = get_step_identifier(step, cfg.steps)
             logging.info(f"Eval policy at step {step}")
-            with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+            with (
+                torch.no_grad(),
+                torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
+            ):
                 eval_info = eval_policy(
                     eval_env,
                     policy,
diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py
index 03726163..60fd9fc0 100644
--- a/tests/scripts/save_policy_to_safetensors.py
+++ b/tests/scripts/save_policy_to_safetensors.py
@@ -33,12 +33,11 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
         # TODO(rcadene, aliberts): remove dataset download
         dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
         policy=make_policy_config(policy_name, **policy_kwargs),
-        device="cpu",
     )
     train_cfg.validate()  # Needed for auto-setting some parameters
 
     dataset = make_dataset(train_cfg)
-    policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=train_cfg.device)
+    policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
     policy.train()
 
     optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)
diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py
index 2f24af82..02041e30 100644
--- a/tests/test_control_robot.py
+++ b/tests/test_control_robot.py
@@ -52,7 +52,7 @@ from lerobot.common.robot_devices.control_configs import (
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
 from tests.test_robots import make_robot
-from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
+from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
 
 
 @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@@ -184,7 +184,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
     replay(robot, replay_cfg)
 
     policy_cfg = ACTConfig()
-    policy = make_policy(policy_cfg, ds_meta=dataset.meta, device=DEVICE)
+    policy = make_policy(policy_cfg, ds_meta=dataset.meta)
 
     out_dir = tmp_path / "logger"
 
@@ -229,8 +229,6 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
         display_cameras=False,
         play_sounds=False,
         num_image_writer_processes=num_image_writer_processes,
-        device=DEVICE,
-        use_amp=False,
     )
 
     rec_eval_cfg.policy = PreTrainedConfig.from_pretrained(pretrained_policy_path)
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 003a60c9..0deaceba 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -45,7 +45,7 @@ from lerobot.common.robot_devices.robots.utils import make_robot
 from lerobot.configs.default import DatasetConfig
 from lerobot.configs.train import TrainPipelineConfig
 from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
-from tests.utils import DEVICE, require_x86_64_kernel
+from tests.utils import require_x86_64_kernel
 
 
 @pytest.fixture
@@ -349,7 +349,6 @@ def test_factory(env_name, repo_id, policy_name):
         dataset=DatasetConfig(repo_id=repo_id, episodes=[0]),
         env=make_env_config(env_name),
         policy=make_policy_config(policy_name),
-        device=DEVICE,
     )
 
     dataset = make_dataset(cfg)
diff --git a/tests/test_policies.py b/tests/test_policies.py
index 9dab6176..f8e7359c 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -143,12 +143,11 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
         dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
         policy=make_policy_config(policy_name, **policy_kwargs),
         env=make_env_config(env_name, **env_kwargs),
-        device=DEVICE,
     )
 
     # Check that we can make the policy object.
     dataset = make_dataset(train_cfg)
-    policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=DEVICE)
+    policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
     assert isinstance(policy, PreTrainedPolicy)
 
     # Check that we run select_actions and get the appropriate output.
@@ -214,7 +213,6 @@ def test_act_backbone_lr():
         # TODO(rcadene, aliberts): remove dataset download
         dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
         policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
-        device=DEVICE,
     )
 
     cfg.validate()  # Needed for auto-setting some parameters
@@ -222,7 +220,7 @@
     assert cfg.policy.optimizer_lr_backbone == 0.001
 
     dataset = make_dataset(cfg)
-    policy = make_policy(cfg.policy, device=DEVICE, ds_meta=dataset.meta)
+    policy = make_policy(cfg.policy, ds_meta=dataset.meta)
     optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
     assert len(optimizer.param_groups) == 2
     assert optimizer.param_groups[0]["lr"] == cfg.policy.optimizer_lr
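Taken together, the call-site changes reduce to one rule: device and AMP settings travel with the policy configuration rather than the pipeline. A usage sketch under that assumption (the repo id is illustrative, borrowed from the Makefile targets above):

    from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
    from lerobot.common.policies.factory import make_policy, make_policy_config

    # device/use_amp are now ordinary fields on every policy config,
    # settable like any other kwarg.
    cfg = make_policy_config("act", device="cuda", use_amp=False)

    # make_policy() reads cfg.device internally; the old `device` argument is gone.
    ds_meta = LeRobotDatasetMetadata("lerobot/aloha_sim_transfer_cube_human")
    policy = make_policy(cfg, ds_meta=ds_meta)

On the command line the flags move under the policy namespace accordingly: `--policy.device=$(DEVICE)` replaces `--device=$(DEVICE)` in the Makefile targets, and the same `--policy.device=...` / `--policy.use_amp=...` spelling applies to train.py, eval.py and control_robot.py invocations.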