refactor(config): Move device & amp args to PreTrainedConfig (#812)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Steven Palma 2025-03-06 17:59:28 +01:00 committed by GitHub
parent 10706ed753
commit 5e9473806c
19 changed files with 62 additions and 136 deletions

View File

@@ -47,6 +47,7 @@ test-act-ete-train:
 		--policy.dim_model=64 \
 		--policy.n_action_steps=20 \
 		--policy.chunk_size=20 \
+		--policy.device=$(DEVICE) \
 		--env.type=aloha \
 		--env.episode_length=5 \
 		--dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \
@@ -61,7 +62,6 @@ test-act-ete-train:
 		--save_checkpoint=true \
 		--log_freq=1 \
 		--wandb.enable=false \
-		--device=$(DEVICE) \
 		--output_dir=tests/outputs/act/

 test-act-ete-train-resume:
@@ -72,11 +72,11 @@ test-act-ete-train-resume:
 test-act-ete-eval:
 	python lerobot/scripts/eval.py \
 		--policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \
+		--policy.device=$(DEVICE) \
 		--env.type=aloha \
 		--env.episode_length=5 \
 		--eval.n_episodes=1 \
-		--eval.batch_size=1 \
-		--device=$(DEVICE)
+		--eval.batch_size=1

 test-diffusion-ete-train:
 	python lerobot/scripts/train.py \
@@ -84,6 +84,7 @@ test-diffusion-ete-train:
 		--policy.down_dims='[64,128,256]' \
 		--policy.diffusion_step_embed_dim=32 \
 		--policy.num_inference_steps=10 \
+		--policy.device=$(DEVICE) \
 		--env.type=pusht \
 		--env.episode_length=5 \
 		--dataset.repo_id=lerobot/pusht \
@@ -98,21 +99,21 @@ test-diffusion-ete-train:
 		--save_freq=2 \
 		--log_freq=1 \
 		--wandb.enable=false \
-		--device=$(DEVICE) \
 		--output_dir=tests/outputs/diffusion/

 test-diffusion-ete-eval:
 	python lerobot/scripts/eval.py \
 		--policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \
+		--policy.device=$(DEVICE) \
 		--env.type=pusht \
 		--env.episode_length=5 \
 		--eval.n_episodes=1 \
-		--eval.batch_size=1 \
-		--device=$(DEVICE)
+		--eval.batch_size=1

 test-tdmpc-ete-train:
 	python lerobot/scripts/train.py \
 		--policy.type=tdmpc \
+		--policy.device=$(DEVICE) \
 		--env.type=xarm \
 		--env.task=XarmLift-v0 \
 		--env.episode_length=5 \
@@ -128,15 +129,14 @@ test-tdmpc-ete-train:
 		--save_freq=2 \
 		--log_freq=1 \
 		--wandb.enable=false \
-		--device=$(DEVICE) \
 		--output_dir=tests/outputs/tdmpc/

 test-tdmpc-ete-eval:
 	python lerobot/scripts/eval.py \
 		--policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \
+		--policy.device=$(DEVICE) \
 		--env.type=xarm \
 		--env.episode_length=5 \
 		--env.task=XarmLift-v0 \
 		--eval.n_episodes=1 \
-		--eval.batch_size=1 \
-		--device=$(DEVICE)
+		--eval.batch_size=1

View File

@@ -16,7 +16,6 @@
 import logging

-import torch
 from torch import nn

 from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
@@ -76,7 +75,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
 def make_policy(
     cfg: PreTrainedConfig,
-    device: str | torch.device,
     ds_meta: LeRobotDatasetMetadata | None = None,
     env_cfg: EnvConfig | None = None,
 ) -> PreTrainedPolicy:
@@ -88,7 +86,6 @@ def make_policy(
     Args:
         cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
             be loaded with the weights from that path.
-        device (str): the device to load the policy onto.
         ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
             statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
         env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
@@ -96,7 +93,7 @@ def make_policy(
     Raises:
         ValueError: Either ds_meta or env and env_cfg must be provided.
-        NotImplementedError: if the policy.type is 'vqbet' and the device 'mps' (due to an incompatibility)
+        NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility)

     Returns:
         PreTrainedPolicy: _description_
@@ -111,7 +108,7 @@ def make_policy(
     # https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
     # variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
     # slower than running natively on MPS.
-    if cfg.type == "vqbet" and str(device) == "mps":
+    if cfg.type == "vqbet" and cfg.device == "mps":
         raise NotImplementedError(
             "Current implementation of VQBeT does not support `mps` backend. "
             "Please use `cpu` or `cuda` backend."
@@ -145,7 +142,7 @@ def make_policy(
     # Make a fresh policy.
     policy = policy_cls(**kwargs)

-    policy.to(device)
+    policy.to(cfg.device)
     assert isinstance(policy, nn.Module)

     # policy = torch.compile(policy, mode="reduce-overhead")
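
Usage note (not part of the diff): `make_policy` no longer takes a `device` argument; the target device is read from `cfg.device`. A minimal sketch, assuming an already-loaded `LeRobotDataset` named `dataset`:

    from lerobot.common.policies.factory import make_policy, make_policy_config

    cfg = make_policy_config("act", device="cuda")   # device now lives on the policy config
    policy = make_policy(cfg, ds_meta=dataset.meta)  # moved to cfg.device internally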

View File

@@ -90,6 +90,7 @@ class PI0Config(PreTrainedConfig):
     def __post_init__(self):
         super().__post_init__()

+        # TODO(Steven): Validate device and amp? in all policy configs?
         """Input validation (not exhaustive)."""
         if self.n_action_steps > self.chunk_size:
             raise ValueError(

View File

@@ -45,7 +45,7 @@ def main():
     cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
     cfg.pretrained_path = ckpt_torch_dir
-    policy = make_policy(cfg, device, ds_meta=dataset.meta)
+    policy = make_policy(cfg, ds_meta=dataset.meta)

     # policy = torch.compile(policy, mode="reduce-overhead")

View File

@@ -101,7 +101,7 @@ def main():
     cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
     cfg.pretrained_path = ckpt_torch_dir
-    policy = make_policy(cfg, device, dataset_meta)
+    policy = make_policy(cfg, dataset_meta)

     # loss_dict = policy.forward(batch, noise=noise, time=time_beta)
     # loss_dict["loss"].backward()

View File

@@ -86,7 +86,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
         cache_dir: str | Path | None = None,
         local_files_only: bool = False,
         revision: str | None = None,
-        map_location: str = "cpu",
         strict: bool = False,
         **kwargs,
     ) -> T:
@@ -111,7 +110,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
         if os.path.isdir(model_id):
             print("Loading weights from local directory")
             model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
-            policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
+            policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
         else:
             try:
                 model_file = hf_hub_download(
@@ -125,13 +124,13 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
                     token=token,
                     local_files_only=local_files_only,
                 )
-                policy = cls._load_as_safetensor(instance, model_file, map_location, strict)
+                policy = cls._load_as_safetensor(instance, model_file, config.device, strict)
             except HfHubHTTPError as e:
                 raise FileNotFoundError(
                     f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
                 ) from e

-        policy.to(map_location)
+        policy.to(config.device)
         policy.eval()
         return policy
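
Usage note (not part of the diff): `from_pretrained` drops `map_location`; weights are loaded onto `config.device` instead. A sketch, assuming an `ACTPolicy` checkpoint in a local directory and that `config` is an accepted keyword of `from_pretrained` (the `config` variable used in the body suggests so):

    from lerobot.common.policies.act.modeling_act import ACTPolicy
    from lerobot.configs.policies import PreTrainedConfig

    cfg = PreTrainedConfig.from_pretrained("outputs/act/pretrained_model")
    cfg.device = "cpu"  # override whatever device the checkpoint config carries
    policy = ACTPolicy.from_pretrained("outputs/act/pretrained_model", config=cfg)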

View File

@@ -12,17 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
 from dataclasses import dataclass
 from pathlib import Path

 import draccus

 from lerobot.common.robot_devices.robots.configs import RobotConfig
-from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
 from lerobot.configs import parser
 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.train import TrainPipelineConfig


 @dataclass
@@ -57,11 +54,6 @@ class RecordControlConfig(ControlConfig):
     # Root directory where the dataset will be stored (e.g. 'dataset/path').
     root: str | Path | None = None
     policy: PreTrainedConfig | None = None
-    # TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
-    device: str | None = None  # cuda | cpu | mps
-    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
-    # automatic gradient scaling is used.
-    use_amp: bool | None = None
     # Limit the frames per second. By default, uses the policy fps.
     fps: int | None = None
     # Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
@@ -104,27 +96,6 @@ class RecordControlConfig(ControlConfig):
             self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
             self.policy.pretrained_path = policy_path

-        # When no device or use_amp are given, use the one from training config.
-        if self.device is None or self.use_amp is None:
-            train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
-            if self.device is None:
-                self.device = train_cfg.device
-            if self.use_amp is None:
-                self.use_amp = train_cfg.use_amp
-
-        # Automatically switch to available device if necessary
-        if not is_torch_device_available(self.device):
-            auto_device = auto_select_torch_device()
-            logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
-            self.device = auto_device
-
-        # Automatically deactivate AMP if necessary
-        if self.use_amp and not is_amp_available(self.device):
-            logging.warning(
-                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
-            )
-            self.use_amp = False


 @ControlConfig.register_subclass("replay")
 @dataclass

View File

@@ -32,6 +32,7 @@ from termcolor import colored
 from lerobot.common.datasets.image_writer import safe_stop_image_writer
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.utils import get_features_from_robot
+from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.robot_devices.utils import busy_wait
 from lerobot.common.utils.utils import get_safe_torch_device, has_method
@@ -193,8 +194,6 @@ def record_episode(
     episode_time_s,
     display_cameras,
     policy,
-    device,
-    use_amp,
     fps,
     single_task,
 ):
@@ -205,8 +204,6 @@ def record_episode(
         dataset=dataset,
         events=events,
         policy=policy,
-        device=device,
-        use_amp=use_amp,
         fps=fps,
         teleoperate=policy is None,
         single_task=single_task,
@@ -221,9 +218,7 @@ def control_loop(
     display_cameras=False,
     dataset: LeRobotDataset | None = None,
     events=None,
-    policy=None,
-    device: torch.device | str | None = None,
-    use_amp: bool | None = None,
+    policy: PreTrainedPolicy = None,
     fps: int | None = None,
     single_task: str | None = None,
 ):
@@ -246,9 +241,6 @@ def control_loop(
     if dataset is not None and fps is not None and dataset.fps != fps:
         raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")

-    if isinstance(device, str):
-        device = get_safe_torch_device(device)
-
     timestamp = 0
     start_episode_t = time.perf_counter()
     while timestamp < control_time_s:
@@ -260,7 +252,9 @@ def control_loop(
         observation = robot.capture_observation()

         if policy is not None:
-            pred_action = predict_action(observation, policy, device, use_amp)
+            pred_action = predict_action(
+                observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
+            )

             # Action can eventually be clipped using `max_relative_target`,
             # so action actually sent is saved in the dataset.
             action = robot.send_action(pred_action)
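
Note (not part of the diff): callers of `control_loop`/`record_episode` no longer thread `device` and `use_amp` through; both are read off the policy. A sketch, with the argument list abridged and `robot` and `policy` assumed to exist:

    # device/use_amp are resolved inside the loop:
    #   get_safe_torch_device(policy.config.device), policy.config.use_amp
    control_loop(robot, control_time_s=30, policy=policy, fps=30, single_task="pick cube")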

View File

@@ -51,8 +51,10 @@ def auto_select_torch_device() -> torch.device:
         return torch.device("cpu")


+# TODO(Steven): Remove log. log shouldn't be an argument, this should be handled by the logger level
 def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
     """Given a string, return a torch.device with checks on whether the device is available."""
+    try_device = str(try_device)
     match try_device:
         case "cuda":
             assert torch.cuda.is_available()
@@ -85,6 +87,7 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):


 def is_torch_device_available(try_device: str) -> bool:
+    try_device = str(try_device)  # Ensure try_device is a string
     if try_device == "cuda":
         return torch.cuda.is_available()
     elif try_device == "mps":
@@ -92,7 +95,7 @@ def is_torch_device_available(try_device: str) -> bool:
     elif try_device == "cpu":
         return True
     else:
-        raise ValueError(f"Unknown device '{try_device}.")
+        raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")


 def is_amp_available(device: str):
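
Note (not part of the diff): both helpers now coerce their argument to `str`, so a `torch.device` is accepted as well as a string. Sketch:

    import torch
    from lerobot.common.utils.utils import get_safe_torch_device, is_torch_device_available

    get_safe_torch_device(torch.device("cpu"))      # ok: coerced to "cpu" internally
    is_torch_device_available(torch.device("cpu"))  # True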

View File

@@ -18,11 +18,9 @@ from dataclasses import dataclass, field
 from pathlib import Path

 from lerobot.common import envs, policies  # noqa: F401
-from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
 from lerobot.configs import parser
 from lerobot.configs.default import EvalConfig
 from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.train import TrainPipelineConfig


 @dataclass
@@ -35,11 +33,6 @@ class EvalPipelineConfig:
     policy: PreTrainedConfig | None = None
     output_dir: Path | None = None
     job_name: str | None = None
-    # TODO(rcadene, aliberts): By default, use device and use_amp values from policy checkpoint.
-    device: str | None = None  # cuda | cpu | mps
-    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
-    # automatic gradient scaling is used.
-    use_amp: bool = False
     seed: int | None = 1000

     def __post_init__(self):
@@ -50,27 +43,6 @@ class EvalPipelineConfig:
             self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
             self.policy.pretrained_path = policy_path

-            # When no device or use_amp are given, use the one from training config.
-            if self.device is None or self.use_amp is None:
-                train_cfg = TrainPipelineConfig.from_pretrained(policy_path)
-                if self.device is None:
-                    self.device = train_cfg.device
-                if self.use_amp is None:
-                    self.use_amp = train_cfg.use_amp
-
-            # Automatically switch to available device if necessary
-            if not is_torch_device_available(self.device):
-                auto_device = auto_select_torch_device()
-                logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
-                self.device = auto_device
-
-            # Automatically deactivate AMP if necessary
-            if self.use_amp and not is_amp_available(self.device):
-                logging.warning(
-                    f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
-                )
-                self.use_amp = False
         else:
             logging.warning(
                 "No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
@@ -87,11 +59,6 @@ class EvalPipelineConfig:
             eval_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
             self.output_dir = Path("outputs/eval") / eval_dir

-        if self.device is None:
-            raise ValueError("Set one of the following device: cuda, cpu or mps")
-        elif self.device == "cuda" and self.use_amp is None:
-            raise ValueError("Set 'use_amp' to True or False.")
-
     @classmethod
     def __get_path_fields__(cls) -> list[str]:
         """This enables the parser to load config from the policy using `--policy.path=local/dir`"""

View File

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
+import logging
 import os
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -25,6 +26,7 @@ from huggingface_hub.errors import HfHubHTTPError
 from lerobot.common.optim.optimizers import OptimizerConfig
 from lerobot.common.optim.schedulers import LRSchedulerConfig
 from lerobot.common.utils.hub import HubMixin
+from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

 # Generic variable that is either PreTrainedConfig or a subclass thereof
@@ -53,8 +55,24 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
     input_features: dict[str, PolicyFeature] = field(default_factory=dict)
     output_features: dict[str, PolicyFeature] = field(default_factory=dict)

+    device: str | None = None  # cuda | cpu | mps
+    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
+    # automatic gradient scaling is used.
+    use_amp: bool = False
+
     def __post_init__(self):
         self.pretrained_path = None
+        if not self.device or not is_torch_device_available(self.device):
+            auto_device = auto_select_torch_device()
+            logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
+            self.device = auto_device.type
+
+        # Automatically deactivate AMP if necessary
+        if self.use_amp and not is_amp_available(self.device):
+            logging.warning(
+                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
+            )
+            self.use_amp = False

     @property
     def type(self) -> str:
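
Behavior note (not part of the diff): `PreTrainedConfig.__post_init__` now falls back to an available device when none is given (or the requested one is missing) and deactivates AMP where it is unsupported, logging a warning in both cases. Sketch, assuming a machine without CUDA:

    from lerobot.common.policies.factory import make_policy_config

    cfg = make_policy_config("act", device="cuda", use_amp=True)
    # __post_init__ logs warnings and adjusts the fields, e.g.:
    # cfg.device == "cpu" (or "mps"), cfg.use_amp == False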

View File

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import datetime as dt
-import logging
 import os
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -26,7 +25,6 @@ from lerobot.common import envs
 from lerobot.common.optim import OptimizerConfig
 from lerobot.common.optim.schedulers import LRSchedulerConfig
 from lerobot.common.utils.hub import HubMixin
-from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available
 from lerobot.configs import parser
 from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
 from lerobot.configs.policies import PreTrainedConfig
@@ -48,10 +46,6 @@ class TrainPipelineConfig(HubMixin):
     # Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
     # regardless of what's provided with the training command at the time of resumption.
     resume: bool = False
-    device: str | None = None  # cuda | cpu | mp
-    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
-    # automatic gradient scaling is used.
-    use_amp: bool = False
     # `seed` is used for training (eg: model initialization, dataset shuffling)
     # AND for the evaluation environments.
     seed: int | None = 1000
@@ -74,18 +68,6 @@ class TrainPipelineConfig(HubMixin):
         self.checkpoint_path = None

     def validate(self):
-        if not self.device:
-            logging.warning("No device specified, trying to infer device automatically")
-            device = auto_select_torch_device()
-            self.device = device.type
-
-        # Automatically deactivate AMP if necessary
-        if self.use_amp and not is_amp_available(self.device):
-            logging.warning(
-                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
-            )
-            self.use_amp = False
-
         # HACK: We parse again the cli args here to get the pretrained paths if there was some.
         policy_path = parser.get_path_arg("policy")
         if policy_path:

View File

@@ -267,7 +267,7 @@ def record(
     )

     # Load pretrained policy
-    policy = None if cfg.policy is None else make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)
+    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)

     if not robot.is_connected:
         robot.connect()
@@ -298,8 +298,6 @@ def record(
         episode_time_s=cfg.episode_time_s,
         display_cameras=cfg.display_cameras,
         policy=policy,
-        device=cfg.device,
-        use_amp=cfg.use_amp,
         fps=cfg.fps,
         single_task=cfg.single_task,
     )

View File

@@ -458,7 +458,7 @@ def eval_main(cfg: EvalPipelineConfig):
     logging.info(pformat(asdict(cfg)))

     # Check device is available
-    device = get_safe_torch_device(cfg.device, log=True)
+    device = get_safe_torch_device(cfg.policy.device, log=True)

     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -470,14 +470,14 @@ def eval_main(cfg: EvalPipelineConfig):
     env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)

     logging.info("Making policy.")

     policy = make_policy(
         cfg=cfg.policy,
-        device=device,
         env_cfg=cfg.env,
     )
     policy.eval()

-    with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+    with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
         info = eval_policy(
             env,
             policy,

View File

@@ -120,7 +120,7 @@ def train(cfg: TrainPipelineConfig):
         set_seed(cfg.seed)

     # Check device is available
-    device = get_safe_torch_device(cfg.device, log=True)
+    device = get_safe_torch_device(cfg.policy.device, log=True)
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -138,13 +138,12 @@ def train(cfg: TrainPipelineConfig):
     logging.info("Creating policy")
     policy = make_policy(
         cfg=cfg.policy,
-        device=device,
         ds_meta=dataset.meta,
     )

     logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
-    grad_scaler = GradScaler(device, enabled=cfg.use_amp)
+    grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)

     step = 0  # number of policy updates (forward + backward + optim)
@@ -218,7 +217,7 @@ def train(cfg: TrainPipelineConfig):
             cfg.optimizer.grad_clip_norm,
             grad_scaler=grad_scaler,
             lr_scheduler=lr_scheduler,
-            use_amp=cfg.use_amp,
+            use_amp=cfg.policy.use_amp,
         )

         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -249,7 +248,10 @@ def train(cfg: TrainPipelineConfig):
         if cfg.env and is_eval_step:
             step_id = get_step_identifier(step, cfg.steps)
             logging.info(f"Eval policy at step {step}")
-            with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+            with (
+                torch.no_grad(),
+                torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
+            ):
                 eval_info = eval_policy(
                     eval_env,
                     policy,
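
Note (not part of the diff): training now sources device and AMP from `cfg.policy`, and `GradScaler` receives `device.type` (a string such as "cuda") rather than the `torch.device`. A condensed sketch of the resulting AMP plumbing, with `batch` assumed:

    device = get_safe_torch_device(cfg.policy.device, log=True)
    grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)
    with torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
        loss = policy.forward(batch)["loss"]  # mirrors the train loop's forward pass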

View File

@@ -33,12 +33,11 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
         # TODO(rcadene, aliberts): remove dataset download
         dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
         policy=make_policy_config(policy_name, **policy_kwargs),
-        device="cpu",
     )
     train_cfg.validate()  # Needed for auto-setting some parameters

     dataset = make_dataset(train_cfg)
-    policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=train_cfg.device)
+    policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
     policy.train()

     optimizer, _ = make_optimizer_and_scheduler(train_cfg, policy)

View File

@@ -52,7 +52,7 @@ from lerobot.common.robot_devices.control_configs import (
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
 from tests.test_robots import make_robot
-from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
+from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot


 @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@@ -184,7 +184,7 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
     replay(robot, replay_cfg)

     policy_cfg = ACTConfig()
-    policy = make_policy(policy_cfg, ds_meta=dataset.meta, device=DEVICE)
+    policy = make_policy(policy_cfg, ds_meta=dataset.meta)

     out_dir = tmp_path / "logger"
@@ -229,8 +229,6 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
         display_cameras=False,
         play_sounds=False,
         num_image_writer_processes=num_image_writer_processes,
-        device=DEVICE,
-        use_amp=False,
     )

     rec_eval_cfg.policy = PreTrainedConfig.from_pretrained(pretrained_policy_path)

View File

@@ -45,7 +45,7 @@ from lerobot.common.robot_devices.robots.utils import make_robot
 from lerobot.configs.default import DatasetConfig
 from lerobot.configs.train import TrainPipelineConfig
 from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
-from tests.utils import DEVICE, require_x86_64_kernel
+from tests.utils import require_x86_64_kernel


 @pytest.fixture
@@ -349,7 +349,6 @@ def test_factory(env_name, repo_id, policy_name):
         dataset=DatasetConfig(repo_id=repo_id, episodes=[0]),
         env=make_env_config(env_name),
         policy=make_policy_config(policy_name),
-        device=DEVICE,
     )

     dataset = make_dataset(cfg)

View File

@@ -143,12 +143,11 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
         dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
         policy=make_policy_config(policy_name, **policy_kwargs),
         env=make_env_config(env_name, **env_kwargs),
-        device=DEVICE,
     )

     # Check that we can make the policy object.
     dataset = make_dataset(train_cfg)
-    policy = make_policy(train_cfg.policy, ds_meta=dataset.meta, device=DEVICE)
+    policy = make_policy(train_cfg.policy, ds_meta=dataset.meta)
     assert isinstance(policy, PreTrainedPolicy)

     # Check that we run select_actions and get the appropriate output.
@@ -214,7 +213,6 @@ def test_act_backbone_lr():
         # TODO(rcadene, aliberts): remove dataset download
         dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
         policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
-        device=DEVICE,
     )

     cfg.validate()  # Needed for auto-setting some parameters
@@ -222,7 +220,7 @@ def test_act_backbone_lr():
     assert cfg.policy.optimizer_lr_backbone == 0.001

     dataset = make_dataset(cfg)
-    policy = make_policy(cfg.policy, device=DEVICE, ds_meta=dataset.meta)
+    policy = make_policy(cfg.policy, ds_meta=dataset.meta)
     optimizer, _ = make_optimizer_and_scheduler(cfg, policy)

     assert len(optimizer.param_groups) == 2
     assert optimizer.param_groups[0]["lr"] == cfg.policy.optimizer_lr