From 9b9e164cf4f6fe1595a32ea593c291e868407e2a Mon Sep 17 00:00:00 2001 From: Ville Kuosmanen Date: Wed, 5 Feb 2025 23:28:57 +0000 Subject: [PATCH 1/4] fix: support multi repo datasets for training --- lerobot/common/datasets/factory.py | 49 +++-- lerobot/common/datasets/lerobot_dataset.py | 4 +- lerobot/common/logger.py | 245 +++++++++++++++++++++ lerobot/scripts/train.py | 12 +- 4 files changed, 287 insertions(+), 23 deletions(-) create mode 100644 lerobot/common/logger.py diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 38c01b42..f0e6abe7 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -81,11 +81,27 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas image_transforms = ( ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None ) - - if isinstance(cfg.dataset.repo_id, str): - ds_meta = LeRobotDatasetMetadata( - cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision + if cfg.dataset.repo_id.startswith('['): + datasets = cfg.dataset.repo_id.strip('[]').split(',') + datasets = [x.strip() for x in datasets] + delta_timestamps = {} + for ds in datasets: + ds_meta = LeRobotDatasetMetadata(ds) + d_ts = resolve_delta_timestamps(cfg.policy, ds_meta) + delta_timestamps[ds] = d_ts + dataset = MultiLeRobotDataset( + datasets, + # TODO(aliberts): add proper support for multi dataset + delta_timestamps=delta_timestamps, + image_transforms=image_transforms, + video_backend=cfg.dataset.video_backend, ) + logging.info( + "Multiple datasets were provided. Applied the following index mapping to the provided datasets: " + f"{pformat(dataset.repo_id_to_index , indent=2)}" + ) + else: + ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id) delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta) dataset = LeRobotDataset( cfg.dataset.repo_id, @@ -96,23 +112,16 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas revision=cfg.dataset.revision, video_backend=cfg.dataset.video_backend, ) - else: - raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.") - dataset = MultiLeRobotDataset( - cfg.dataset.repo_id, - # TODO(aliberts): add proper support for multi dataset - # delta_timestamps=delta_timestamps, - image_transforms=image_transforms, - video_backend=cfg.dataset.video_backend, - ) - logging.info( - "Multiple datasets were provided. 
Applied the following index mapping to the provided datasets: " - f"{pformat(dataset.repo_id_to_index, indent=2)}" - ) if cfg.dataset.use_imagenet_stats: - for key in dataset.meta.camera_keys: - for stats_type, stats in IMAGENET_STATS.items(): - dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) + if isinstance(dataset, MultiLeRobotDataset): + for ds in dataset._datasets: + for key in ds.meta.camera_keys: + for stats_type, stats in IMAGENET_STATS.items(): + ds.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) + else: + for key in dataset.meta.camera_keys: + for stats_type, stats in IMAGENET_STATS.items(): + dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) return dataset diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 5414c76d..0544daf1 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -1046,7 +1046,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): root: str | Path | None = None, episodes: dict | None = None, image_transforms: Callable | None = None, - delta_timestamps: dict[list[float]] | None = None, + delta_timestamps: dict[str, dict[list[float]]] | None = None, tolerances_s: dict | None = None, download_videos: bool = True, video_backend: str | None = None, @@ -1063,7 +1063,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): root=self.root / repo_id, episodes=episodes[repo_id] if episodes else None, image_transforms=image_transforms, - delta_timestamps=delta_timestamps, + delta_timestamps=delta_timestamps[repo_id], tolerance_s=self.tolerances_s[repo_id], download_videos=download_videos, video_backend=video_backend, diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py new file mode 100644 index 00000000..18458ae6 --- /dev/null +++ b/lerobot/common/logger.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py + +# TODO(rcadene, alexander-soare): clean this file +""" + +import logging +import os +import re +from dataclasses import asdict +from glob import glob +from pathlib import Path + +import draccus +import torch +from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE +from termcolor import colored +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler + +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.utils.utils import get_global_random_state +from lerobot.configs.train import TrainPipelineConfig +from lerobot.configs.types import FeatureType, NormalizationMode + +PRETRAINED_MODEL = "pretrained_model" +TRAINING_STATE = "training_state.pth" + + +def log_output_dir(out_dir): + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}") + + +def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str: + """Return a group name for logging. Optionally returns group name as list.""" + # TODO: these were used to support multirepodataset in the past - think how they could be supported in new way? + dataset_tag = cfg.dataset.repo_id + if dataset_tag.startswith('['): + tags = dataset_tag.strip('[]').split(',') + dataset_tag = f"{tags[0].strip()}_and_more" + lst = [ + f"policy:{cfg.policy.type}", + f"dataset:{dataset_tag}", + f"seed:{cfg.seed}", + ] + if cfg.env is not None: + lst.append(f"env:{cfg.env.type}") + return lst if return_list else "-".join(lst) + + +def get_wandb_run_id_from_filesystem(checkpoint_dir: Path) -> str: + # Get the WandB run ID. + paths = glob(str(checkpoint_dir / "../wandb/latest-run/run-*")) + if len(paths) != 1: + raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.") + match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1]) + if match is None: + raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.") + wandb_run_id = match.groups(0)[0] + return wandb_run_id + + +class Logger: + """Primary logger object. Logs either locally or using wandb. + + The logger creates the following directory structure: + + provided_log_dir + ├── checkpoints + │ ├── specific_checkpoint_name + │ │ ├── pretrained_model # Hugging Face pretrained model directory + │ │ │ ├── ... + │ │ └── training_state.pth # optimizer, scheduler, and random states + training step + | ├── another_specific_checkpoint_name + │ │ ├── ... + | ├── ... + │ └── last # a softlink to the last logged checkpoint + """ + + pretrained_model_dir_name = PRETRAINED_MODEL + training_state_file_name = TRAINING_STATE + + def __init__(self, cfg: TrainPipelineConfig): + self._cfg = cfg + self.log_dir = cfg.output_dir + self.log_dir.mkdir(parents=True, exist_ok=True) + self.job_name = cfg.job_name + self.checkpoints_dir = self.get_checkpoints_dir(self.log_dir) + self.last_checkpoint_dir = self.get_last_checkpoint_dir(self.log_dir) + self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(self.log_dir) + + # Set up WandB. 
+ self._group = cfg_to_group(cfg) + run_offline = not cfg.wandb.enable or not cfg.wandb.project + if run_offline: + logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) + self._wandb = None + else: + os.environ["WANDB_SILENT"] = "true" + import wandb + + wandb_run_id = None + if cfg.resume: + wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir) + + wandb.init( + id=wandb_run_id, + project=cfg.wandb.project, + entity=cfg.wandb.entity, + name=self.job_name, + notes=cfg.wandb.notes, + tags=cfg_to_group(cfg, return_list=True), + dir=self.log_dir, + config=asdict(self._cfg), + # TODO(rcadene): try set to True + save_code=False, + # TODO(rcadene): split train and eval, and run async eval with job_type="eval" + job_type="train_eval", + resume="must" if cfg.resume else None, + ) + print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) + logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") + self._wandb = wandb + + @classmethod + def get_checkpoints_dir(cls, log_dir: str | Path) -> Path: + """Given the log directory, get the sub-directory in which checkpoints will be saved.""" + return Path(log_dir) / "checkpoints" + + @classmethod + def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path: + """Given the log directory, get the sub-directory in which the last checkpoint will be saved.""" + return cls.get_checkpoints_dir(log_dir) / "last" + + @classmethod + def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path: + """ + Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will + be saved. + """ + return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name + + def save_model(self, save_dir: Path, policy: PreTrainedPolicy, wandb_artifact_name: str | None = None): + """Save the weights of the Policy model using PyTorchModelHubMixin. + + The weights are saved in a folder called "pretrained_model" under the checkpoint directory. + + Optionally also upload the model to WandB. + """ + + self.checkpoints_dir.mkdir(parents=True, exist_ok=True) + register_features_types() + policy.save_pretrained(save_dir) + # Also save the full config for the env configuration. + self._cfg.save_pretrained(save_dir) + if self._wandb and not self._cfg.wandb.disable_artifact: + # note wandb artifact does not accept ":" or "/" in its name + artifact = self._wandb.Artifact(wandb_artifact_name, type="model") + artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE) + self._wandb.log_artifact(artifact) + if self.last_checkpoint_dir.exists(): + os.remove(self.last_checkpoint_dir) + + def save_training_state( + self, + save_dir: Path, + train_step: int, + optimizer: Optimizer | None = None, + scheduler: LRScheduler | None = None, + ): + """Checkpoint the global training_step, optimizer state, scheduler state, and random state. + + All of these are saved as "training_state.pth" under the checkpoint directory. 
+ """ + training_state = {} + training_state["step"] = train_step + training_state.update(get_global_random_state()) + if optimizer is not None: + training_state["optimizer"] = optimizer.state_dict() + if scheduler is not None: + training_state["scheduler"] = scheduler.state_dict() + torch.save(training_state, save_dir / self.training_state_file_name) + + def save_checkpoint( + self, + train_step: int, + identifier: str, + policy: PreTrainedPolicy, + optimizer: Optimizer | None = None, + scheduler: LRScheduler | None = None, + ): + """Checkpoint the model weights and the training state.""" + checkpoint_dir = self.checkpoints_dir / str(identifier) + wandb_artifact_name = ( + None + if self._wandb is None + else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}" + ) + self.save_model( + checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name + ) + self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler) + + relative_target = checkpoint_dir.relative_to(self.last_checkpoint_dir.parent) + self.last_checkpoint_dir.symlink_to(relative_target) + + def log_dict(self, d: dict, step: int, mode: str = "train"): + assert mode in {"train", "eval"} + # TODO(alexander-soare): Add local text log. + if self._wandb is not None: + for k, v in d.items(): + if not isinstance(v, (int, float, str)): + logging.warning( + f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.' + ) + continue + self._wandb.log({f"{mode}/{k}": v}, step=step) + + def log_video(self, video_path: str, step: int, mode: str = "train"): + assert mode in {"train", "eval"} + assert self._wandb is not None + wandb_video = self._wandb.Video(video_path, fps=self._cfg.env.fps, format="mp4") + self._wandb.log({f"{mode}/video": wandb_video}, step=step) + + +def register_features_types(): + draccus.decode.register(FeatureType, lambda x: FeatureType[x]) + draccus.encode.register(FeatureType, lambda x: x.name) + + draccus.decode.register(NormalizationMode, lambda x: NormalizationMode[x]) + draccus.encode.register(NormalizationMode, lambda x: x.name) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index e36c697a..8576068a 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -24,7 +24,13 @@ from termcolor import colored from torch.amp import GradScaler from torch.optim import Optimizer +<<<<<<< HEAD from lerobot.common.datasets.factory import make_dataset +======= +from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset +from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights +>>>>>>> 1c9d53c1 (fix: support multi repo datasets for training) from lerobot.common.datasets.sampler import EpisodeAwareSampler from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env @@ -136,10 +142,14 @@ def train(cfg: TrainPipelineConfig): eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size) logging.info("Creating policy") + if isinstance(dataset, MultiLeRobotDataset): + ds_meta = dataset._datasets[0].meta + else: + ds_meta = dataset.meta policy = make_policy( cfg=cfg.policy, device=device, - ds_meta=dataset.meta, + ds_meta=ds_meta, ) logging.info("Creating optimizer and scheduler") From 33f6b63a7cfcdd04f208e98d3ee06c607a27fc0e Mon Sep 17 00:00:00 2001 From: Ville Kuosmanen Date: Sun, 13 Apr 2025 12:04:39 +0200 Subject: [PATCH 2/4] 
fix --- lerobot/scripts/train.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 8576068a..c205f35e 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -24,13 +24,8 @@ from termcolor import colored from torch.amp import GradScaler from torch.optim import Optimizer -<<<<<<< HEAD from lerobot.common.datasets.factory import make_dataset -======= -from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset -from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights ->>>>>>> 1c9d53c1 (fix: support multi repo datasets for training) +from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset from lerobot.common.datasets.sampler import EpisodeAwareSampler from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env From 6882dfd0129a91140f70d1f7b978ab1637ce3db0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 13 Apr 2025 10:07:08 +0000 Subject: [PATCH 3/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- lerobot/common/datasets/factory.py | 6 +++--- lerobot/common/logger.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index f0e6abe7..416074fe 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -81,8 +81,8 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas image_transforms = ( ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None ) - if cfg.dataset.repo_id.startswith('['): - datasets = cfg.dataset.repo_id.strip('[]').split(',') + if cfg.dataset.repo_id.startswith("["): + datasets = cfg.dataset.repo_id.strip("[]").split(",") datasets = [x.strip() for x in datasets] delta_timestamps = {} for ds in datasets: @@ -98,7 +98,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas ) logging.info( "Multiple datasets were provided. Applied the following index mapping to the provided datasets: " - f"{pformat(dataset.repo_id_to_index , indent=2)}" + f"{pformat(dataset.repo_id_to_index, indent=2)}" ) else: ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id) diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index 18458ae6..b874f4ba 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -49,8 +49,8 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st """Return a group name for logging. Optionally returns group name as list.""" # TODO: these were used to support multirepodataset in the past - think how they could be supported in new way? 
dataset_tag = cfg.dataset.repo_id - if dataset_tag.startswith('['): - tags = dataset_tag.strip('[]').split(',') + if dataset_tag.startswith("["): + tags = dataset_tag.strip("[]").split(",") dataset_tag = f"{tags[0].strip()}_and_more" lst = [ f"policy:{cfg.policy.type}", From 1dbcf584d653e5833c91a2385b83aa872e391613 Mon Sep 17 00:00:00 2001 From: Ville Kuosmanen Date: Sun, 13 Apr 2025 12:08:52 +0200 Subject: [PATCH 4/4] lint --- lerobot/scripts/train.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index b73b1171..468e7a1a 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -137,10 +137,7 @@ def train(cfg: TrainPipelineConfig): eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) logging.info("Creating policy") - if isinstance(dataset, MultiLeRobotDataset): - ds_meta = dataset._datasets[0].meta - else: - ds_meta = dataset.meta + ds_meta = dataset._datasets[0].meta if isinstance(dataset, MultiLeRobotDataset) else dataset.meta policy = make_policy( cfg=cfg.policy, ds_meta=ds_meta,
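
Note on usage (illustrative, not part of the patch): with this change a training run can point at several datasets by passing a bracketed list as the repo id, e.g. something like

    python lerobot/scripts/train.py --dataset.repo_id="[user/dataset_a, user/dataset_b]" --policy.type=act

where the dataset names and the policy flag are placeholders. A minimal sketch of how make_dataset() interprets that string after this patch follows; the helper name parse_repo_ids is purely illustrative and does not exist in the codebase:

    def parse_repo_ids(repo_id: str) -> list[str]:
        # "[user/dataset_a, user/dataset_b]" -> ["user/dataset_a", "user/dataset_b"];
        # a plain repo id is returned as a single-element list.
        if repo_id.startswith("["):
            return [x.strip() for x in repo_id.strip("[]").split(",")]
        return [repo_id]

    repo_ids = parse_repo_ids("[user/dataset_a, user/dataset_b]")
    # delta_timestamps is then resolved per repo id, matching the
    # delta_timestamps[repo_id] lookup added to MultiLeRobotDataset.__init__, e.g.
    # {"user/dataset_a": {...}, "user/dataset_b": {...}}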