174 lines
7.4 KiB
Python
174 lines
7.4 KiB
Python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import datetime as dt
|
|
import os
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Type
|
|
|
|
import draccus
|
|
from huggingface_hub import hf_hub_download
|
|
from huggingface_hub.errors import HfHubHTTPError
|
|
|
|
from lerobot.common import envs
|
|
from lerobot.common.optim import OptimizerConfig
|
|
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
|
from lerobot.common.utils.hub import HubMixin
|
|
from lerobot.configs import parser
|
|
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
|
|
from lerobot.configs.policies import PreTrainedConfig
|
|
|
|
TRAIN_CONFIG_NAME = "train_config.json"
|
|
|
|
|
|
@dataclass
|
|
class TrainPipelineConfig(HubMixin):
|
|
dataset: DatasetConfig
|
|
env: envs.EnvConfig | None = None
|
|
policy: PreTrainedConfig | None = None
|
|
# Set `dir` to where you would like to save all of the run outputs. If you run another training session
|
|
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
|
output_dir: Path | None = None
|
|
job_name: str | None = None
|
|
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
|
|
# `dir` is the directory of an existing run with at least one checkpoint in it.
|
|
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
|
|
# regardless of what's provided with the training command at the time of resumption.
|
|
resume: bool = False
|
|
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
|
# AND for the evaluation environments.
|
|
seed: int | None = 1000
|
|
# Number of workers for the dataloader.
|
|
num_workers: int = 4
|
|
batch_size: int = 8
|
|
steps: int = 100_000
|
|
eval_freq: int = 20_000
|
|
log_freq: int = 200
|
|
save_checkpoint: bool = True
|
|
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
|
save_freq: int = 20_000
|
|
use_policy_training_preset: bool = True
|
|
optimizer: OptimizerConfig | None = None
|
|
scheduler: LRSchedulerConfig | None = None
|
|
eval: EvalConfig = field(default_factory=EvalConfig)
|
|
wandb: WandBConfig = field(default_factory=WandBConfig)
|
|
|
|
def __post_init__(self):
|
|
self.checkpoint_path = None
|
|
|
|
def validate(self):
|
|
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
|
policy_path = parser.get_path_arg("policy")
|
|
if policy_path:
|
|
# Only load the policy config
|
|
cli_overrides = parser.get_cli_overrides("policy")
|
|
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
|
self.policy.pretrained_path = policy_path
|
|
elif self.resume:
|
|
# The entire train config is already loaded, we just need to get the checkpoint dir
|
|
config_path = parser.parse_arg("config_path")
|
|
if not config_path:
|
|
raise ValueError("A config_path is expected when resuming a run.")
|
|
if not Path(config_path).resolve().exists():
|
|
raise NotADirectoryError(
|
|
f"{config_path=} is expected to be a local path. "
|
|
"Resuming from the hub is not supported for now."
|
|
)
|
|
policy_path = Path(config_path).parent
|
|
self.policy.pretrained_path = policy_path
|
|
self.checkpoint_path = policy_path.parent
|
|
|
|
if not self.job_name:
|
|
if self.env is None:
|
|
self.job_name = f"{self.policy.type}"
|
|
else:
|
|
self.job_name = f"{self.env.type}_{self.policy.type}"
|
|
|
|
if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
|
|
raise FileExistsError(
|
|
f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
|
|
f"Please change your output directory so that {self.output_dir} is not overwritten."
|
|
)
|
|
elif not self.output_dir:
|
|
now = dt.datetime.now()
|
|
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
|
|
self.output_dir = Path("outputs/train") / train_dir
|
|
|
|
if isinstance(self.dataset.repo_id, list):
|
|
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
|
|
|
|
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
|
|
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
|
|
elif self.use_policy_training_preset and not self.resume:
|
|
self.optimizer = self.policy.get_optimizer_preset()
|
|
self.scheduler = self.policy.get_scheduler_preset()
|
|
|
|
@classmethod
|
|
def __get_path_fields__(cls) -> list[str]:
|
|
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
|
return ["policy"]
|
|
|
|
def to_dict(self) -> dict:
|
|
return draccus.encode(self)
|
|
|
|
def _save_pretrained(self, save_directory: Path) -> None:
|
|
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
|
draccus.dump(self, f, indent=4)
|
|
|
|
@classmethod
|
|
def from_pretrained(
|
|
cls: Type["TrainPipelineConfig"],
|
|
pretrained_name_or_path: str | Path,
|
|
*,
|
|
force_download: bool = False,
|
|
resume_download: bool = None,
|
|
proxies: dict | None = None,
|
|
token: str | bool | None = None,
|
|
cache_dir: str | Path | None = None,
|
|
local_files_only: bool = False,
|
|
revision: str | None = None,
|
|
**kwargs,
|
|
) -> "TrainPipelineConfig":
|
|
model_id = str(pretrained_name_or_path)
|
|
config_file: str | None = None
|
|
if Path(model_id).is_dir():
|
|
if TRAIN_CONFIG_NAME in os.listdir(model_id):
|
|
config_file = os.path.join(model_id, TRAIN_CONFIG_NAME)
|
|
else:
|
|
print(f"{TRAIN_CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
|
elif Path(model_id).is_file():
|
|
config_file = model_id
|
|
else:
|
|
try:
|
|
config_file = hf_hub_download(
|
|
repo_id=model_id,
|
|
filename=TRAIN_CONFIG_NAME,
|
|
revision=revision,
|
|
cache_dir=cache_dir,
|
|
force_download=force_download,
|
|
proxies=proxies,
|
|
resume_download=resume_download,
|
|
token=token,
|
|
local_files_only=local_files_only,
|
|
)
|
|
except HfHubHTTPError as e:
|
|
raise FileNotFoundError(
|
|
f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
|
) from e
|
|
|
|
cli_args = kwargs.pop("cli_args", [])
|
|
cfg = draccus.parse(cls, config_file, args=cli_args)
|
|
|
|
return cfg
|