import datetime as dt
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Type

import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import HfHubHTTPError

from lerobot.common import envs
from lerobot.common.optim import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.utils.hub import HubMixin
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
from lerobot.configs.policies import PreTrainedConfig

TRAIN_CONFIG_NAME = "train_config.json"


@dataclass
class TrainPipelineConfig(HubMixin):
    """Top-level configuration for a training run.

    Bundles the dataset, optional environment and policy configs, optimization settings,
    checkpoint/log/eval frequencies, and the eval / W&B sub-configs. Call `validate()` after
    parsing to resolve derived values (device, job name, output directory, optimizer and
    scheduler presets). The config is written to and read from `TRAIN_CONFIG_NAME` by
    `_save_pretrained` / `from_pretrained`.
    """

    dataset: DatasetConfig
    env: envs.EnvConfig | None = None
    policy: PreTrainedConfig | None = None
    # Set `output_dir` to where you would like to save all of the run outputs. If the directory already
    # exists and `resume` is false, `validate()` raises an error rather than overwriting it. When left
    # unset, a timestamped directory under `outputs/train` is generated.
    output_dir: Path | None = None
    job_name: str | None = None
    # Set `resume` to true to resume a previous run. In order for this to work, `config_path` must point to
    # the (local) train config saved inside an existing checkpoint.
    # Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
    # regardless of what's provided with the training command at the time of resumption.
    resume: bool = False
    # e.g. cuda | cpu | mps. Inferred automatically by `validate()` when unset.
    device: str | None = None
    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
    # automatic gradient scaling is used.
    use_amp: bool = False
    # `seed` is used for training (eg: model initialization, dataset shuffling)
    # AND for the evaluation environments.
    seed: int | None = 1000
    # Number of workers for the dataloader.
    num_workers: int = 4
    batch_size: int = 8
    steps: int = 100_000
    eval_freq: int = 20_000
    log_freq: int = 200
    save_checkpoint: bool = True
    # Checkpoint is saved every `save_freq` training iterations and after the last training step.
    save_freq: int = 20_000
    # When true, `validate()` replaces `optimizer` / `scheduler` with the policy's presets (unless resuming).
    use_policy_training_preset: bool = True
    optimizer: OptimizerConfig | None = None
    scheduler: LRSchedulerConfig | None = None
    eval: EvalConfig = field(default_factory=EvalConfig)
    wandb: WandBConfig = field(default_factory=WandBConfig)

    def __post_init__(self):
        # Set by `validate()` when resuming; it is a plain attribute, not a dataclass field.
        self.checkpoint_path = None

    def validate(self):
        """Resolve derived settings and check the configuration for consistency.

        - Infers `device` when unset and disables AMP if the device does not support it.
        - Loads the policy config from `--policy.path` when given; otherwise, when resuming,
          points the policy and `checkpoint_path` at the checkpoint referenced by `config_path`.
        - Derives `job_name` and `output_dir` when they are not set.
        - Fills in `optimizer` / `scheduler` from the policy presets when requested.

        Raises:
            ValueError: if resuming without a `config_path`, or if presets are disabled and the
                optimizer or scheduler is missing.
            NotADirectoryError: if the resume `config_path` does not exist locally.
            FileExistsError: if `output_dir` already exists and `resume` is false.
            NotImplementedError: if a list of dataset repo ids is given.
        """
        if not self.device:
            logging.warning("No device specified, trying to infer device automatically")
            device = auto_select_torch_device()
            self.device = device.type

        # Automatically deactivate AMP if necessary
        if self.use_amp and not is_amp_available(self.device):
            logging.warning(
                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
            )
            self.use_amp = False

        # HACK: We parse again the cli args here to get the pretrained paths if there was some.
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            # Only load the policy config
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path
        elif self.resume:
            # The entire train config is already loaded, we just need to get the checkpoint dir
            config_path = parser.parse_arg("config_path")
            if not config_path:
                raise ValueError("A config_path is expected when resuming a run.")
            if not Path(config_path).resolve().exists():
                raise NotADirectoryError(
                    f"{config_path=} is expected to be a local path. "
                    "Resuming from the hub is not supported for now."
                )
            # The policy is loaded from the directory containing the config file, and
            # `checkpoint_path` is that directory's parent.
            policy_path = Path(config_path).parent
            self.policy.pretrained_path = policy_path
            self.checkpoint_path = policy_path.parent

        if not self.job_name:
            if self.env is None:
                self.job_name = f"{self.policy.type}"
            else:
                self.job_name = f"{self.env.type}_{self.policy.type}"

        if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
            raise FileExistsError(
                f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
                f"Please change your output directory so that {self.output_dir} is not overwritten."
            )
        elif not self.output_dir:
            # Local wall-clock time is used only to build a unique, human-readable directory name.
            now = dt.datetime.now()
            train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
            self.output_dir = Path("outputs/train") / train_dir

        if isinstance(self.dataset.repo_id, list):
            raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")

        if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
            raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
        elif self.use_policy_training_preset and not self.resume:
            # When resuming, the optimizer/scheduler come from the checkpoint's config instead.
            self.optimizer = self.policy.get_optimizer_preset()
            self.scheduler = self.policy.get_scheduler_preset()

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]

    def to_dict(self) -> dict:
        """Encode this config into plain Python containers via draccus."""
        return draccus.encode(self)

    def _save_pretrained(self, save_directory: Path) -> None:
        """Write this config as JSON to `save_directory / TRAIN_CONFIG_NAME`."""
        with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
            draccus.dump(self, f, indent=4)

    @classmethod
    def from_pretrained(
        cls: Type["TrainPipelineConfig"],
        pretrained_name_or_path: str | Path,
        *,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **kwargs,
    ) -> "TrainPipelineConfig":
        """Load a train config from a local directory, a local file, or a Hugging Face Hub repo.

        Args:
            pretrained_name_or_path: a directory containing `TRAIN_CONFIG_NAME`, a path to the
                config file itself, or a Hub repo id.
            force_download, resume_download, proxies, token, cache_dir, local_files_only,
            revision: forwarded to `hf_hub_download` when loading from the Hub.
            **kwargs: may contain `cli_args`, a list of command-line style overrides passed to
                `draccus.parse`.

        Returns:
            The parsed `TrainPipelineConfig`.

        Raises:
            FileNotFoundError: if the config cannot be downloaded from the Hub.
        """
        model_id = str(pretrained_name_or_path)
        config_file: str | None = None
        if Path(model_id).is_dir():
            if TRAIN_CONFIG_NAME in os.listdir(model_id):
                config_file = os.path.join(model_id, TRAIN_CONFIG_NAME)
            else:
                # Not fatal here: parsing proceeds with `config_file=None` and relies on `cli_args`.
                logging.warning(f"{TRAIN_CONFIG_NAME} not found in {Path(model_id).resolve()}")
        elif Path(model_id).is_file():
            config_file = model_id
        else:
            try:
                config_file = hf_hub_download(
                    repo_id=model_id,
                    filename=TRAIN_CONFIG_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
                ) from e

        cli_args = kwargs.pop("cli_args", [])
        cfg = draccus.parse(cls, config_file, args=cli_args)

        return cfg