import datetime as dt
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Type

import draccus
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import HfHubHTTPError

from lerobot.common import envs
from lerobot.common.optim import OptimizerConfig
from lerobot.common.optim.schedulers import LRSchedulerConfig
from lerobot.common.utils.hub import HubMixin
from lerobot.common.utils.utils import auto_select_torch_device, is_amp_available
from lerobot.configs import parser
from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig
from lerobot.configs.policies import PreTrainedConfig

TRAIN_CONFIG_NAME = "train_config.json"


@dataclass
class TrainPipelineConfig(HubMixin):
    dataset: DatasetConfig
    env: envs.EnvConfig | None = None
    policy: PreTrainedConfig | None = None
    # Set `dir` to where you would like to save all of the run outputs. If you run another training
    # session with the same value for `dir`, its contents will be overwritten unless you set `resume`
    # to true.
    output_dir: Path | None = None
    job_name: str | None = None
    # Set `resume` to true to resume a previous run. In order for this to work, you will need to make
    # sure `dir` is the directory of an existing run with at least one checkpoint in it.
    # Note that when resuming a run, the default behavior is to use the configuration from the
    # checkpoint, regardless of what's provided with the training command at the time of resumption.
    resume: bool = False
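    # Illustrative resume invocation (the run path below is a placeholder, not a value defined in
    # this file):
    #     python lerobot/scripts/train.py \
    #         --config_path=outputs/train/<run>/checkpoints/last/pretrained_model/train_config.json \
    #         --resume=true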
    device: str | None = None  # cuda | cpu | mps
    # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and
    # evaluation. With AMP, automatic gradient scaling is used.
    use_amp: bool = False
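    # Note (a convention assumed by the training loop, not enforced in this file): with `use_amp`
    # enabled, forward passes typically run under `torch.autocast(...)` and gradients are scaled
    # with `torch.amp.GradScaler`.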
    # `seed` is used for training (e.g. model initialization, dataset shuffling)
    # AND for the evaluation environments.
    seed: int | None = 1000
    # Number of workers for the dataloader.
    num_workers: int = 4
    batch_size: int = 8
    steps: int = 100_000
    eval_freq: int = 20_000
    log_freq: int = 200
    save_checkpoint: bool = True
    # Checkpoint is saved every `save_freq` training iterations and after the last training step.
    save_freq: int = 20_000
    use_policy_training_preset: bool = True
    optimizer: OptimizerConfig | None = None
    scheduler: LRSchedulerConfig | None = None
    eval: EvalConfig = field(default_factory=EvalConfig)
    wandb: WandBConfig = field(default_factory=WandBConfig)

    def __post_init__(self):
        self.checkpoint_path = None

    def validate(self):
        if not self.device:
            logging.warning("No device specified, trying to infer device automatically")
            device = auto_select_torch_device()
            self.device = device.type

        # Automatically deactivate AMP if necessary
        if self.use_amp and not is_amp_available(self.device):
            logging.warning(
                f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
            )
            self.use_amp = False

        # HACK: We parse the cli args again here to get the pretrained paths, if any were provided.
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            # Only load the policy config
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path
        elif self.resume:
            # The entire train config is already loaded; we just need to get the checkpoint dir.
            config_path = parser.parse_arg("config_path")
            if not config_path:
                raise ValueError("A config_path is expected when resuming a run.")
            if not Path(config_path).resolve().exists():
                raise NotADirectoryError(
                    f"{config_path=} is expected to be a local path. "
                    "Resuming from the hub is not supported for now."
                )
            policy_path = Path(config_path).parent
            self.policy.pretrained_path = policy_path
            self.checkpoint_path = policy_path.parent
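            # Expected layout (derived from the lines above): config_path points at
            # <checkpoint_dir>/pretrained_model/train_config.json, so the policy dir is its parent
            # and the checkpoint dir is the policy dir's parent.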

        if not self.job_name:
            if self.env is None:
                self.job_name = f"{self.policy.type}"
            else:
                self.job_name = f"{self.env.type}_{self.policy.type}"

        if not self.resume and isinstance(self.output_dir, Path) and self.output_dir.is_dir():
            raise FileExistsError(
                f"Output directory {self.output_dir} already exists and resume is {self.resume}. "
                f"Please change your output directory so that {self.output_dir} is not overwritten."
            )
        elif not self.output_dir:
            now = dt.datetime.now()
            train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
            self.output_dir = Path("outputs/train") / train_dir

        if isinstance(self.dataset.repo_id, list):
            raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")

        if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
            raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
        elif self.use_policy_training_preset and not self.resume:
            self.optimizer = self.policy.get_optimizer_preset()
            self.scheduler = self.policy.get_scheduler_preset()

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`."""
        return ["policy"]

    def to_dict(self) -> dict:
        return draccus.encode(self)

    def _save_pretrained(self, save_directory: Path) -> None:
        with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
            draccus.dump(self, f, indent=4)
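        # The train_config.json written here is what `from_pretrained` below reads back.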

    @classmethod
    def from_pretrained(
        cls: Type["TrainPipelineConfig"],
        pretrained_name_or_path: str | Path,
        *,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **kwargs,
    ) -> "TrainPipelineConfig":
        model_id = str(pretrained_name_or_path)
        config_file: str | None = None
        if Path(model_id).is_dir():
            if TRAIN_CONFIG_NAME in os.listdir(model_id):
                config_file = os.path.join(model_id, TRAIN_CONFIG_NAME)
            else:
                logging.warning(f"{TRAIN_CONFIG_NAME} not found in {Path(model_id).resolve()}")
        elif Path(model_id).is_file():
            config_file = model_id
        else:
            # Neither a local directory nor a local file: treat `model_id` as a hub repo id.
            try:
                config_file = hf_hub_download(
                    repo_id=model_id,
                    filename=TRAIN_CONFIG_NAME,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
                ) from e

        cli_args = kwargs.pop("cli_args", [])
        cfg = draccus.parse(cls, config_file, args=cli_args)

        return cfg
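

# Illustrative programmatic use (a sketch; the checkpoint path below is a placeholder, not a value
# defined in this file):
#
#     cfg = TrainPipelineConfig.from_pretrained(
#         "outputs/train/<run>/checkpoints/last/pretrained_model"
#     )
#     cfg.validate()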