From 90e099b39fb7e40ff87904dbed5685ffefd05777 Mon Sep 17 00:00:00 2001 From: Simon Alibert <75076266+aliberts@users.noreply.github.com> Date: Tue, 11 Feb 2025 10:36:06 +0100 Subject: [PATCH] Remove offline training, refactor `train.py` and logging/checkpointing (#670) Co-authored-by: Remi --- Makefile | 39 +- examples/3_train_policy.py | 3 +- examples/4_train_policy_with_script.md | 22 +- .../advanced/2_calculate_validation_loss.py | 4 +- lerobot/common/constants.py | 11 + lerobot/common/logger.py | 240 --------- lerobot/common/optim/factory.py | 23 +- lerobot/common/optim/optimizers.py | 48 ++ lerobot/common/optim/schedulers.py | 31 ++ lerobot/common/policies/act/modeling_act.py | 8 +- .../policies/diffusion/modeling_diffusion.py | 5 +- lerobot/common/policies/pretrained.py | 13 +- .../common/policies/tdmpc/modeling_tdmpc.py | 5 +- .../common/policies/vqbet/modeling_vqbet.py | 12 +- lerobot/common/utils/io_utils.py | 84 +++ lerobot/common/utils/logging_utils.py | 163 ++++++ lerobot/common/utils/random_utils.py | 191 +++++++ lerobot/common/utils/train_utils.py | 161 ++++++ lerobot/common/utils/utils.py | 57 -- lerobot/common/utils/wandb_utils.py | 121 +++++ lerobot/configs/train.py | 75 +-- lerobot/scripts/eval.py | 11 +- lerobot/scripts/train.py | 488 ++++-------------- poetry.lock | 24 +- pyproject.toml | 4 +- tests/conftest.py | 1 + tests/fixtures/optimizers.py | 26 + .../save_image_transforms_to_safetensors.py | 2 +- tests/scripts/save_policy_to_safetensors.py | 8 +- tests/test_control_robot.py | 17 +- tests/test_datasets.py | 2 +- tests/test_image_transforms.py | 2 +- tests/test_io_utils.py | 74 +++ tests/test_logging_utils.py | 107 ++++ tests/test_optimizers.py | 43 ++ tests/test_policies.py | 2 +- tests/test_random_utils.py | 109 ++++ tests/test_schedulers.py | 81 +++ tests/test_train_utils.py | 84 +++ tests/test_utils.py | 49 -- 40 files changed, 1515 insertions(+), 935 deletions(-) delete mode 100644 lerobot/common/logger.py create mode 100644 lerobot/common/utils/logging_utils.py create mode 100644 lerobot/common/utils/random_utils.py create mode 100644 lerobot/common/utils/train_utils.py create mode 100644 lerobot/common/utils/wandb_utils.py create mode 100644 tests/fixtures/optimizers.py create mode 100644 tests/test_io_utils.py create mode 100644 tests/test_logging_utils.py create mode 100644 tests/test_optimizers.py create mode 100644 tests/test_random_utils.py create mode 100644 tests/test_schedulers.py create mode 100644 tests/test_train_utils.py diff --git a/Makefile b/Makefile index c216e009..bc10141a 100644 --- a/Makefile +++ b/Makefile @@ -39,8 +39,8 @@ test-act-ete-train: --dataset.image_transforms.enable=true \ --dataset.episodes="[0]" \ --batch_size=2 \ - --offline.steps=4 \ - --online.steps=0 \ + --steps=4 \ + --eval_freq=2 \ --eval.n_episodes=1 \ --eval.batch_size=1 \ --save_freq=2 \ @@ -76,8 +76,8 @@ test-diffusion-ete-train: --dataset.image_transforms.enable=true \ --dataset.episodes="[0]" \ --batch_size=2 \ - --offline.steps=2 \ - --online.steps=0 \ + --steps=2 \ + --eval_freq=2 \ --eval.n_episodes=1 \ --eval.batch_size=1 \ --save_checkpoint=true \ @@ -106,8 +106,8 @@ test-tdmpc-ete-train: --dataset.image_transforms.enable=true \ --dataset.episodes="[0]" \ --batch_size=2 \ - --offline.steps=2 \ - --online.steps=0 \ + --steps=2 \ + --eval_freq=2 \ --eval.n_episodes=1 \ --eval.batch_size=1 \ --save_checkpoint=true \ @@ -126,30 +126,3 @@ test-tdmpc-ete-eval: --eval.n_episodes=1 \ --eval.batch_size=1 \ --device=$(DEVICE) - -# TODO(rcadene): fix online buffer to 
storing "task" -# test-tdmpc-ete-train-with-online: -# python lerobot/scripts/train.py \ -# --policy.type=tdmpc \ -# --env.type=pusht \ -# --env.obs_type=environment_state_agent_pos \ -# --env.episode_length=5 \ -# --dataset.repo_id=lerobot/pusht_keypoints \ -# --dataset.image_transforms.enable=true \ -# --dataset.episodes="[0]" \ -# --batch_size=2 \ -# --offline.steps=2 \ -# --online.steps=20 \ -# --online.rollout_n_episodes=2 \ -# --online.rollout_batch_size=2 \ -# --online.steps_between_rollouts=10 \ -# --online.buffer_capacity=1000 \ -# --online.env_seed=10000 \ -# --save_checkpoint=false \ -# --save_freq=10 \ -# --log_freq=1 \ -# --eval.use_async_envs=true \ -# --eval.n_episodes=1 \ -# --eval.batch_size=1 \ -# --device=$(DEVICE) \ -# --output_dir=tests/outputs/tdmpc_online/ diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index 635c7293..cf5d4d3e 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -86,8 +86,7 @@ def main(): while not done: for batch in dataloader: batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()} - output_dict = policy.forward(batch) - loss = output_dict["loss"] + loss, _ = policy.forward(batch) loss.backward() optimizer.step() optimizer.zero_grad() diff --git a/examples/4_train_policy_with_script.md b/examples/4_train_policy_with_script.md index 9d57d424..58ed239a 100644 --- a/examples/4_train_policy_with_script.md +++ b/examples/4_train_policy_with_script.md @@ -161,13 +161,13 @@ python lerobot/scripts/train.py \ ``` You should see from the logging that your training picks up from where it left off. -Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--offline.steps`, which is 100 000 by default. +Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--steps`, which is 100 000 by default. 
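Before extending a run, you can also double-check the step at which a checkpoint was saved by reading its `training_state/training_step.json` file (the full checkpoint layout is shown further below). A minimal sketch, assuming the `run_resumption` output directory from the example above:
```python
import json
from pathlib import Path

# "last" is a symlink to the most recent checkpoint directory.
checkpoint = Path("outputs/train/run_resumption/checkpoints/last")
with open(checkpoint / "training_state" / "training_step.json") as f:
    print(f"checkpoint saved at step {json.load(f)['step']}")
```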
You could double the number of steps of the previous run with: ```bash python lerobot/scripts/train.py \ --config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \ --resume=true \ - --offline.steps=200000 + --steps=200000 ``` ## Outputs of a run @@ -175,12 +175,16 @@ In the output directory, there will be a folder called `checkpoints` with the fo ```bash outputs/train/run_resumption/checkpoints ├── 000100 # checkpoint_dir for training step 100 -│   ├── pretrained_model -│   │   ├── config.json # pretrained policy config -│   │   ├── model.safetensors # model weights -│   │   ├── train_config.json # train config -│ │ └── README.md # model card -│   └── training_state.pth # optimizer/scheduler/rng state and training step +│ ├── pretrained_model/ +│ │ ├── config.json # policy config +│ │ ├── model.safetensors # policy weights +│ │ └── train_config.json # train config +│ └── training_state/ +│ ├── optimizer_param_groups.json # optimizer param groups +│ ├── optimizer_state.safetensors # optimizer state +│ ├── rng_state.safetensors # rng states +│ ├── scheduler_state.json # scheduler state +│ └── training_step.json # training step ├── 000200 └── last -> 000200 # symlink to the last available checkpoint ``` @@ -250,7 +254,7 @@ python lerobot/scripts/train.py \ python lerobot/scripts/train.py \ --config_path=checkpoint/pretrained_model/ \ --resume=true \ - --offline.steps=200000 # <- you can change some training parameters + --steps=200000 # <- you can change some training parameters ``` #### Fine-tuning diff --git a/examples/advanced/2_calculate_validation_loss.py b/examples/advanced/2_calculate_validation_loss.py index 71e76072..6f234719 100644 --- a/examples/advanced/2_calculate_validation_loss.py +++ b/examples/advanced/2_calculate_validation_loss.py @@ -75,9 +75,9 @@ def main(): n_examples_evaluated = 0 for batch in val_dataloader: batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()} - output_dict = policy.forward(batch) + loss, _ = policy.forward(batch) - loss_cumsum += output_dict["loss"].item() + loss_cumsum += loss.item() n_examples_evaluated += batch["index"].shape[0] # Calculate the average loss over the validation set. diff --git a/lerobot/common/constants.py b/lerobot/common/constants.py index 73889594..34da4ac0 100644 --- a/lerobot/common/constants.py +++ b/lerobot/common/constants.py @@ -4,3 +4,14 @@ OBS_ROBOT = "observation.state" OBS_IMAGE = "observation.image" OBS_IMAGES = "observation.images" ACTION = "action" + +# files & directories +CHECKPOINTS_DIR = "checkpoints" +LAST_CHECKPOINT_LINK = "last" +PRETRAINED_MODEL_DIR = "pretrained_model" +TRAINING_STATE_DIR = "training_state" +RNG_STATE = "rng_state.safetensors" +TRAINING_STEP = "training_step.json" +OPTIMIZER_STATE = "optimizer_state.safetensors" +OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json" +SCHEDULER_STATE = "scheduler_state.json" diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py deleted file mode 100644 index 5f863f68..00000000 --- a/lerobot/common/logger.py +++ /dev/null @@ -1,240 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py - -# TODO(rcadene, alexander-soare): clean this file -""" - -import logging -import os -import re -from dataclasses import asdict -from glob import glob -from pathlib import Path - -import draccus -import torch -from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE -from termcolor import colored -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LRScheduler - -from lerobot.common.policies.pretrained import PreTrainedPolicy -from lerobot.common.utils.utils import get_global_random_state -from lerobot.configs.train import TrainPipelineConfig -from lerobot.configs.types import FeatureType, NormalizationMode - -PRETRAINED_MODEL = "pretrained_model" -TRAINING_STATE = "training_state.pth" - - -def log_output_dir(out_dir): - logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}") - - -def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str: - """Return a group name for logging. Optionally returns group name as list.""" - lst = [ - f"policy:{cfg.policy.type}", - f"dataset:{cfg.dataset.repo_id}", - f"seed:{cfg.seed}", - ] - if cfg.env is not None: - lst.append(f"env:{cfg.env.type}") - return lst if return_list else "-".join(lst) - - -def get_wandb_run_id_from_filesystem(checkpoint_dir: Path) -> str: - # Get the WandB run ID. - paths = glob(str(checkpoint_dir / "../wandb/latest-run/run-*")) - if len(paths) != 1: - raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.") - match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1]) - if match is None: - raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.") - wandb_run_id = match.groups(0)[0] - return wandb_run_id - - -class Logger: - """Primary logger object. Logs either locally or using wandb. - - The logger creates the following directory structure: - - provided_log_dir - ├── checkpoints - │ ├── specific_checkpoint_name - │ │ ├── pretrained_model # Hugging Face pretrained model directory - │ │ │ ├── ... - │ │ └── training_state.pth # optimizer, scheduler, and random states + training step - | ├── another_specific_checkpoint_name - │ │ ├── ... - | ├── ... - │ └── last # a softlink to the last logged checkpoint - """ - - pretrained_model_dir_name = PRETRAINED_MODEL - training_state_file_name = TRAINING_STATE - - def __init__(self, cfg: TrainPipelineConfig): - self._cfg = cfg - self.log_dir = cfg.output_dir - self.log_dir.mkdir(parents=True, exist_ok=True) - self.job_name = cfg.job_name - self.checkpoints_dir = self.get_checkpoints_dir(self.log_dir) - self.last_checkpoint_dir = self.get_last_checkpoint_dir(self.log_dir) - self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(self.log_dir) - - # Set up WandB. 
- self._group = cfg_to_group(cfg) - run_offline = not cfg.wandb.enable or not cfg.wandb.project - if run_offline: - logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) - self._wandb = None - else: - os.environ["WANDB_SILENT"] = "true" - import wandb - - wandb_run_id = None - if cfg.resume: - wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir) - - wandb.init( - id=wandb_run_id, - project=cfg.wandb.project, - entity=cfg.wandb.entity, - name=self.job_name, - notes=cfg.wandb.notes, - tags=cfg_to_group(cfg, return_list=True), - dir=self.log_dir, - config=asdict(self._cfg), - # TODO(rcadene): try set to True - save_code=False, - # TODO(rcadene): split train and eval, and run async eval with job_type="eval" - job_type="train_eval", - resume="must" if cfg.resume else None, - ) - print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) - logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") - self._wandb = wandb - - @classmethod - def get_checkpoints_dir(cls, log_dir: str | Path) -> Path: - """Given the log directory, get the sub-directory in which checkpoints will be saved.""" - return Path(log_dir) / "checkpoints" - - @classmethod - def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path: - """Given the log directory, get the sub-directory in which the last checkpoint will be saved.""" - return cls.get_checkpoints_dir(log_dir) / "last" - - @classmethod - def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path: - """ - Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will - be saved. - """ - return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name - - def save_model(self, save_dir: Path, policy: PreTrainedPolicy, wandb_artifact_name: str | None = None): - """Save the weights of the Policy model using PyTorchModelHubMixin. - - The weights are saved in a folder called "pretrained_model" under the checkpoint directory. - - Optionally also upload the model to WandB. - """ - - self.checkpoints_dir.mkdir(parents=True, exist_ok=True) - register_features_types() - policy.save_pretrained(save_dir) - # Also save the full config for the env configuration. - self._cfg.save_pretrained(save_dir) - if self._wandb and not self._cfg.wandb.disable_artifact: - # note wandb artifact does not accept ":" or "/" in its name - artifact = self._wandb.Artifact(wandb_artifact_name, type="model") - artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE) - self._wandb.log_artifact(artifact) - if self.last_checkpoint_dir.exists(): - os.remove(self.last_checkpoint_dir) - - def save_training_state( - self, - save_dir: Path, - train_step: int, - optimizer: Optimizer | None = None, - scheduler: LRScheduler | None = None, - ): - """Checkpoint the global training_step, optimizer state, scheduler state, and random state. - - All of these are saved as "training_state.pth" under the checkpoint directory. 
- """ - training_state = {} - training_state["step"] = train_step - training_state.update(get_global_random_state()) - if optimizer is not None: - training_state["optimizer"] = optimizer.state_dict() - if scheduler is not None: - training_state["scheduler"] = scheduler.state_dict() - torch.save(training_state, save_dir / self.training_state_file_name) - - def save_checkpoint( - self, - train_step: int, - identifier: str, - policy: PreTrainedPolicy, - optimizer: Optimizer | None = None, - scheduler: LRScheduler | None = None, - ): - """Checkpoint the model weights and the training state.""" - checkpoint_dir = self.checkpoints_dir / str(identifier) - wandb_artifact_name = ( - None - if self._wandb is None - else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}" - ) - self.save_model( - checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name - ) - self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler) - - relative_target = checkpoint_dir.relative_to(self.last_checkpoint_dir.parent) - self.last_checkpoint_dir.symlink_to(relative_target) - - def log_dict(self, d: dict, step: int, mode: str = "train"): - assert mode in {"train", "eval"} - # TODO(alexander-soare): Add local text log. - if self._wandb is not None: - for k, v in d.items(): - if not isinstance(v, (int, float, str)): - logging.warning( - f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.' - ) - continue - self._wandb.log({f"{mode}/{k}": v}, step=step) - - def log_video(self, video_path: str, step: int, mode: str = "train"): - assert mode in {"train", "eval"} - assert self._wandb is not None - wandb_video = self._wandb.Video(video_path, fps=self._cfg.env.fps, format="mp4") - self._wandb.log({f"{mode}/video": wandb_video}, step=step) - - -def register_features_types(): - draccus.decode.register(FeatureType, lambda x: FeatureType[x]) - draccus.encode.register(FeatureType, lambda x: x.name) - - draccus.decode.register(NormalizationMode, lambda x: NormalizationMode[x]) - draccus.encode.register(NormalizationMode, lambda x: x.name) diff --git a/lerobot/common/optim/factory.py b/lerobot/common/optim/factory.py index 010cd461..10ff3df7 100644 --- a/lerobot/common/optim/factory.py +++ b/lerobot/common/optim/factory.py @@ -14,15 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pathlib import Path -import torch from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler -from lerobot.common.logger import TRAINING_STATE from lerobot.common.policies.pretrained import PreTrainedPolicy -from lerobot.common.utils.utils import get_global_random_state, set_global_random_state from lerobot.configs.train import TrainPipelineConfig @@ -40,22 +36,5 @@ def make_optimizer_and_scheduler( """ params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters() optimizer = cfg.optimizer.build(params) - lr_scheduler = cfg.scheduler.build(optimizer, cfg.offline.steps) if cfg.scheduler is not None else None + lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None return optimizer, lr_scheduler - - -def load_training_state(checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None) -> int: - """ - Given the checkpoint directory, load the optimizer state, scheduler state, and random state, and - return the global training step. 
- """ - # TODO(aliberts): use safetensors instead as weights_only=False is unsafe - training_state = torch.load(checkpoint_dir / TRAINING_STATE, weights_only=False) - optimizer.load_state_dict(training_state["optimizer"]) - if scheduler is not None: - scheduler.load_state_dict(training_state["scheduler"]) - elif "scheduler" in training_state: - raise ValueError("The checkpoint contains a scheduler state_dict, but no LRScheduler was provided.") - # Small HACK to get the expected keys: use `get_global_random_state`. - set_global_random_state({k: training_state[k] for k in get_global_random_state()}) - return training_state["step"], optimizer, scheduler diff --git a/lerobot/common/optim/optimizers.py b/lerobot/common/optim/optimizers.py index 737305ad..0cf4124c 100644 --- a/lerobot/common/optim/optimizers.py +++ b/lerobot/common/optim/optimizers.py @@ -1,8 +1,32 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import abc from dataclasses import asdict, dataclass +from pathlib import Path import draccus import torch +from safetensors.torch import load_file, save_file + +from lerobot.common.constants import ( + OPTIMIZER_PARAM_GROUPS, + OPTIMIZER_STATE, +) +from lerobot.common.datasets.utils import flatten_dict, unflatten_dict, write_json +from lerobot.common.utils.io_utils import deserialize_json_into_object @dataclass @@ -68,3 +92,27 @@ class SGDConfig(OptimizerConfig): kwargs = asdict(self) kwargs.pop("grad_clip_norm") return torch.optim.SGD(params, **kwargs) + + +def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None: + state = optimizer.state_dict() + param_groups = state.pop("param_groups") + flat_state = flatten_dict(state) + save_file(flat_state, save_dir / OPTIMIZER_STATE) + write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS) + + +def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer: + current_state_dict = optimizer.state_dict() + flat_state = load_file(save_dir / OPTIMIZER_STATE) + state = unflatten_dict(flat_state) + loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}} + + if "param_groups" in current_state_dict: + param_groups = deserialize_json_into_object( + save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"] + ) + loaded_state_dict["param_groups"] = param_groups + + optimizer.load_state_dict(loaded_state_dict) + return optimizer diff --git a/lerobot/common/optim/schedulers.py b/lerobot/common/optim/schedulers.py index 80d83bdf..7e158394 100644 --- a/lerobot/common/optim/schedulers.py +++ b/lerobot/common/optim/schedulers.py @@ -1,11 +1,31 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import abc import math from dataclasses import asdict, dataclass +from pathlib import Path import draccus from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR, LRScheduler +from lerobot.common.constants import SCHEDULER_STATE +from lerobot.common.datasets.utils import write_json +from lerobot.common.utils.io_utils import deserialize_json_into_object + @dataclass class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC): @@ -89,3 +109,14 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig): return cosine_decay_schedule(current_step) return LambdaLR(optimizer, lr_lambda, -1) + + +def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None: + state_dict = scheduler.state_dict() + write_json(state_dict, save_dir / SCHEDULER_STATE) + + +def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler: + state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict()) + scheduler.load_state_dict(state_dict) + return scheduler diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 615f238f..9a1036c3 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -144,7 +144,7 @@ class ACTPolicy(PreTrainedPolicy): self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) if self.config.image_features: @@ -169,11 +169,11 @@ class ACTPolicy(PreTrainedPolicy): (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() ) loss_dict["kld_loss"] = mean_kld.item() - loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight + loss = l1_loss + mean_kld * self.config.kl_weight else: - loss_dict["loss"] = l1_loss + loss = l1_loss - return loss_dict + return loss, loss_dict class ACTTemporalEnsembler: diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 7147f550..9ecadcb0 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -143,7 +143,7 @@ class DiffusionPolicy(PreTrainedPolicy): action = self._queues["action"].popleft() return action - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) if self.config.image_features: @@ -153,7 +153,8 @@ class DiffusionPolicy(PreTrainedPolicy): ) batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) - return {"loss": loss} + # no output_dict so returning None + return loss, None def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler: diff 
--git a/lerobot/common/policies/pretrained.py b/lerobot/common/policies/pretrained.py index 84767594..1729dfb0 100644 --- a/lerobot/common/policies/pretrained.py +++ b/lerobot/common/policies/pretrained.py @@ -163,12 +163,17 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): """ raise NotImplementedError + # TODO(aliberts, rcadene): split into 'forward' and 'compute_loss'? @abc.abstractmethod - def forward(self, batch: dict[str, Tensor]) -> dict: - """Run the batch through the model and compute the loss for training or validation. + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]: + """_summary_ - Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all - other items should be logging-friendly, native Python types. + Args: + batch (dict[str, Tensor]): _description_ + + Returns: + tuple[Tensor, dict | None]: The loss and potentially other information. Apart from the loss which + is a Tensor, all other items should be logging-friendly, native Python types. """ raise NotImplementedError diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 6366a5a4..c4f90b8d 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -302,7 +302,7 @@ class TDMPCPolicy(PreTrainedPolicy): G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0) return G - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Run the batch through the model and compute the loss. Returns a dictionary with loss as a tensor, and other information as native floats. @@ -495,7 +495,6 @@ class TDMPCPolicy(PreTrainedPolicy): "Q_value_loss": q_value_loss.item(), "V_value_loss": v_value_loss.item(), "pi_loss": pi_loss.item(), - "loss": loss, "sum_loss": loss.item() * self.config.horizon, } ) @@ -505,7 +504,7 @@ class TDMPCPolicy(PreTrainedPolicy): if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1: batch[key] = batch[key].transpose(1, 0) - return info + return loss, info def update(self): """Update the target model's parameters with an EMA step.""" diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index c4d4a46d..1f70b186 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -156,7 +156,7 @@ class VQBeTPolicy(PreTrainedPolicy): action = self._queues["action"].popleft() return action - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) batch = dict(batch) # shallow copy so that adding a key doesn't modify the original @@ -170,16 +170,16 @@ class VQBeTPolicy(PreTrainedPolicy): loss, n_different_codes, n_different_combinations, recon_l1_error = ( self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"]) ) - return { - "loss": loss, + return loss, { "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations, "recon_l1_error": recon_l1_error, } # if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts. 
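+        # `self.vqbet(batch, rollout=False)` returns a two-element tuple; only the second element, the loss
+        # dictionary, is needed here. Its total "loss" tensor is popped out so that, like the other policies,
+        # `forward()` returns a `(loss, loss_dict)` pair with only logging-friendly values left in the dict.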
_, loss_dict = self.vqbet(batch, rollout=False) + loss = loss_dict.pop("loss") - return loss_dict + return loss, loss_dict class SpatialSoftmax(nn.Module): @@ -342,7 +342,7 @@ class VQBeTModel(nn.Module): torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]), ) - def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor: + def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]: # Input validation. assert set(batch).issuperset({"observation.state", "observation.images"}) batch_size, n_obs_steps = batch["observation.state"].shape[:2] @@ -482,7 +482,7 @@ class VQBeTHead(nn.Module): param.requires_grad = False return loss, n_different_codes, n_different_combinations, recon_l1_error - def forward(self, x, **kwargs): + def forward(self, x, **kwargs) -> dict: # N is the batch size, and T is number of action query tokens, which are process through same GPT N, T, _ = x.shape # we calculate N and T side parallely. Thus, the dimensions would be diff --git a/lerobot/common/utils/io_utils.py b/lerobot/common/utils/io_utils.py index b85f17c7..3fc405f7 100644 --- a/lerobot/common/utils/io_utils.py +++ b/lerobot/common/utils/io_utils.py @@ -13,10 +13,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json import warnings +from pathlib import Path +from typing import TypeVar import imageio +JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...] +T = TypeVar("T", bound=JsonLike) + def write_video(video_path, stacked_frames, fps): # Filter out DeprecationWarnings raised from pkg_resources @@ -25,3 +31,81 @@ def write_video(video_path, stacked_frames, fps): "ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning ) imageio.mimsave(video_path, stacked_frames, fps=fps) + + +def deserialize_json_into_object(fpath: Path, obj: T) -> T: + """ + Loads the JSON data from `fpath` and recursively fills `obj` with the + corresponding values (strictly matching structure and types). + Tuples in `obj` are expected to be lists in the JSON data, which will be + converted back into tuples. + """ + with open(fpath, encoding="utf-8") as f: + data = json.load(f) + + def _deserialize(target, source): + """ + Recursively overwrite the structure in `target` with data from `source`, + performing strict checks on structure and type. + Returns the updated version of `target` (especially important for tuples). + """ + + # If the target is a dictionary, source must be a dictionary as well. + if isinstance(target, dict): + if not isinstance(source, dict): + raise TypeError(f"Type mismatch: expected dict, got {type(source)}") + + # Check that they have exactly the same set of keys. + if target.keys() != source.keys(): + raise ValueError( + f"Dictionary keys do not match.\n" f"Expected: {target.keys()}, got: {source.keys()}" + ) + + # Recursively update each key. + for k in target: + target[k] = _deserialize(target[k], source[k]) + + return target + + # If the target is a list, source must be a list as well. + elif isinstance(target, list): + if not isinstance(source, list): + raise TypeError(f"Type mismatch: expected list, got {type(source)}") + + # Check length + if len(target) != len(source): + raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}") + + # Recursively update each element. 
+ for i in range(len(target)): + target[i] = _deserialize(target[i], source[i]) + + return target + + # If the target is a tuple, the source must be a list in JSON, + # which we'll convert back to a tuple. + elif isinstance(target, tuple): + if not isinstance(source, list): + raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}") + + if len(target) != len(source): + raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}") + + # Convert each element, forming a new tuple. + converted_items = [] + for t_item, s_item in zip(target, source, strict=False): + converted_items.append(_deserialize(t_item, s_item)) + + # Return a brand new tuple (tuples are immutable in Python). + return tuple(converted_items) + + # Otherwise, we're dealing with a "primitive" (int, float, str, bool, None). + else: + # Check the exact type. If these must match 1:1, do: + if type(target) is not type(source): + raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}") + return source + + # Perform the in-place/recursive deserialization + updated_obj = _deserialize(obj, data) + return updated_obj diff --git a/lerobot/common/utils/logging_utils.py b/lerobot/common/utils/logging_utils.py new file mode 100644 index 00000000..b99c348f --- /dev/null +++ b/lerobot/common/utils/logging_utils.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +from lerobot.common.utils.utils import format_big_number + + +class AverageMeter: + """ + Computes and stores the average and current value + Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py + """ + + def __init__(self, name: str, fmt: str = ":f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self) -> None: + self.val = 0.0 + self.avg = 0.0 + self.sum = 0.0 + self.count = 0.0 + + def update(self, val: float, n: int = 1) -> None: + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name}:{avg" + self.fmt + "}" + return fmtstr.format(**self.__dict__) + + +class MetricsTracker: + """ + A helper class to track and log metrics over time. + + Usage pattern: + + ```python + # initialize, potentially with non-zero initial step (e.g. 
if resuming run) + metrics = {"loss": AverageMeter("loss", ":.3f")} + train_metrics = MetricsTracker(cfg, dataset, metrics, initial_step=step) + + # update metrics derived from step (samples, episodes, epochs) at each training step + train_metrics.step() + + # update various metrics + loss = policy.forward(batch) + train_metrics.loss = loss + + # display current metrics + logging.info(train_metrics) + + # export for wandb + wandb.log(train_metrics.to_dict()) + + # reset averages after logging + train_metrics.reset_averages() + ``` + """ + + __keys__ = [ + "_batch_size", + "_num_frames", + "_avg_samples_per_ep", + "metrics", + "steps", + "samples", + "episodes", + "epochs", + ] + + def __init__( + self, + batch_size: int, + num_frames: int, + num_episodes: int, + metrics: dict[str, AverageMeter], + initial_step: int = 0, + ): + self.__dict__.update({k: None for k in self.__keys__}) + self._batch_size = batch_size + self._num_frames = num_frames + self._avg_samples_per_ep = num_frames / num_episodes + self.metrics = metrics + + self.steps = initial_step + # A sample is an (observation,action) pair, where observation and action + # can be on multiple timestamps. In a batch, we have `batch_size` number of samples. + self.samples = self.steps * self._batch_size + self.episodes = self.samples / self._avg_samples_per_ep + self.epochs = self.samples / self._num_frames + + def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any: + if name in self.__dict__: + return self.__dict__[name] + elif name in self.metrics: + return self.metrics[name] + else: + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + def __setattr__(self, name: str, value: Any) -> None: + if name in self.__dict__: + super().__setattr__(name, value) + elif name in self.metrics: + self.metrics[name].update(value) + else: + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + def step(self) -> None: + """ + Updates metrics that depend on 'step' for one step. + """ + self.steps += 1 + self.samples += self._batch_size + self.episodes = self.samples / self._avg_samples_per_ep + self.epochs = self.samples / self._num_frames + + def __str__(self) -> str: + display_list = [ + f"step:{format_big_number(self.steps)}", + # number of samples seen during training + f"smpl:{format_big_number(self.samples)}", + # number of episodes seen during training + f"ep:{format_big_number(self.episodes)}", + # number of time all unique samples are seen + f"epch:{self.epochs:.2f}", + *[str(m) for m in self.metrics.values()], + ] + return " ".join(display_list) + + def to_dict(self, use_avg: bool = True) -> dict[str, int | float]: + """ + Returns the current metric values (or averages if `use_avg=True`) as a dict. + """ + return { + "steps": self.steps, + "samples": self.samples, + "episodes": self.episodes, + "epochs": self.epochs, + **{k: m.avg if use_avg else m.val for k, m in self.metrics.items()}, + } + + def reset_averages(self) -> None: + """Resets average meters.""" + for m in self.metrics.values(): + m.reset() diff --git a/lerobot/common/utils/random_utils.py b/lerobot/common/utils/random_utils.py new file mode 100644 index 00000000..3d9bf4dd --- /dev/null +++ b/lerobot/common/utils/random_utils.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import random +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Generator + +import numpy as np +import torch +from safetensors.torch import load_file, save_file + +from lerobot.common.constants import RNG_STATE +from lerobot.common.datasets.utils import flatten_dict, unflatten_dict + + +def serialize_python_rng_state() -> dict[str, torch.Tensor]: + """ + Returns the rng state for `random` in the form of a flat dict[str, torch.Tensor] to be saved using + `safetensors.save_file()` or `torch.save()`. + """ + py_state = random.getstate() + return { + "py_rng_version": torch.tensor([py_state[0]], dtype=torch.int64), + "py_rng_state": torch.tensor(py_state[1], dtype=torch.int64), + } + + +def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None: + """ + Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`. + """ + py_state = (rng_state_dict["py_rng_version"].item(), tuple(rng_state_dict["py_rng_state"].tolist()), None) + random.setstate(py_state) + + +def serialize_numpy_rng_state() -> dict[str, torch.Tensor]: + """ + Returns the rng state for `numpy` in the form of a flat dict[str, torch.Tensor] to be saved using + `safetensors.save_file()` or `torch.save()`. + """ + np_state = np.random.get_state() + # Ensure no breaking changes from numpy + assert np_state[0] == "MT19937" + return { + "np_rng_state_values": torch.tensor(np_state[1], dtype=torch.int64), + "np_rng_state_index": torch.tensor([np_state[2]], dtype=torch.int64), + "np_rng_has_gauss": torch.tensor([np_state[3]], dtype=torch.int64), + "np_rng_cached_gaussian": torch.tensor([np_state[4]], dtype=torch.float32), + } + + +def deserialize_numpy_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None: + """ + Restores the rng state for `numpy` from a dictionary produced by `serialize_numpy_rng_state()`. + """ + np_state = ( + "MT19937", + rng_state_dict["np_rng_state_values"].numpy(), + rng_state_dict["np_rng_state_index"].item(), + rng_state_dict["np_rng_has_gauss"].item(), + rng_state_dict["np_rng_cached_gaussian"].item(), + ) + np.random.set_state(np_state) + + +def serialize_torch_rng_state() -> dict[str, torch.Tensor]: + """ + Returns the rng state for `torch` in the form of a flat dict[str, torch.Tensor] to be saved using + `safetensors.save_file()` or `torch.save()`. + """ + torch_rng_state_dict = {"torch_rng_state": torch.get_rng_state()} + if torch.cuda.is_available(): + torch_rng_state_dict["torch_cuda_rng_state"] = torch.cuda.get_rng_state() + return torch_rng_state_dict + + +def deserialize_torch_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None: + """ + Restores the rng state for `torch` from a dictionary produced by `serialize_torch_rng_state()`. 
+ """ + torch.set_rng_state(rng_state_dict["torch_rng_state"]) + if torch.cuda.is_available() and "torch_cuda_rng_state" in rng_state_dict: + torch.cuda.set_rng_state(rng_state_dict["torch_cuda_rng_state"]) + + +def serialize_rng_state() -> dict[str, torch.Tensor]: + """ + Returns the rng state for `random`, `numpy`, and `torch`, in the form of a flat + dict[str, torch.Tensor] to be saved using `safetensors.save_file()` `torch.save()`. + """ + py_rng_state_dict = serialize_python_rng_state() + np_rng_state_dict = serialize_numpy_rng_state() + torch_rng_state_dict = serialize_torch_rng_state() + + return { + **py_rng_state_dict, + **np_rng_state_dict, + **torch_rng_state_dict, + } + + +def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None: + """ + Restores the rng state for `random`, `numpy`, and `torch` from a dictionary produced by + `serialize_rng_state()`. + """ + py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")} + np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")} + torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")} + + deserialize_python_rng_state(py_rng_state_dict) + deserialize_numpy_rng_state(np_rng_state_dict) + deserialize_torch_rng_state(torch_rng_state_dict) + + +def save_rng_state(save_dir: Path) -> None: + rng_state_dict = serialize_rng_state() + flat_rng_state_dict = flatten_dict(rng_state_dict) + save_file(flat_rng_state_dict, save_dir / RNG_STATE) + + +def load_rng_state(save_dir: Path) -> None: + flat_rng_state_dict = load_file(save_dir / RNG_STATE) + rng_state_dict = unflatten_dict(flat_rng_state_dict) + deserialize_rng_state(rng_state_dict) + + +def get_rng_state() -> dict[str, Any]: + """Get the random state for `random`, `numpy`, and `torch`.""" + random_state_dict = { + "random_state": random.getstate(), + "numpy_random_state": np.random.get_state(), + "torch_random_state": torch.random.get_rng_state(), + } + if torch.cuda.is_available(): + random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state() + return random_state_dict + + +def set_rng_state(random_state_dict: dict[str, Any]): + """Set the random state for `random`, `numpy`, and `torch`. + + Args: + random_state_dict: A dictionary of the form returned by `get_rng_state`. + """ + random.setstate(random_state_dict["random_state"]) + np.random.set_state(random_state_dict["numpy_random_state"]) + torch.random.set_rng_state(random_state_dict["torch_random_state"]) + if torch.cuda.is_available(): + torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"]) + + +def set_seed(seed) -> None: + """Set seed for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +@contextmanager +def seeded_context(seed: int) -> Generator[None, None, None]: + """Set the seed when entering a context, and restore the prior random state at exit. 
+ + Example usage: + + ``` + a = random.random() # produces some random number + with seeded_context(1337): + b = random.random() # produces some other random number + c = random.random() # produces yet another random number, but the same it would have if we never made `b` + ``` + """ + random_state_dict = get_rng_state() + set_seed(seed) + yield None + set_rng_state(random_state_dict) diff --git a/lerobot/common/utils/train_utils.py b/lerobot/common/utils/train_utils.py new file mode 100644 index 00000000..a7998312 --- /dev/null +++ b/lerobot/common/utils/train_utils.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from pathlib import Path + +from termcolor import colored +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler + +from lerobot.common.constants import ( + CHECKPOINTS_DIR, + LAST_CHECKPOINT_LINK, + PRETRAINED_MODEL_DIR, + TRAINING_STATE_DIR, + TRAINING_STEP, +) +from lerobot.common.datasets.utils import load_json, write_json +from lerobot.common.optim.optimizers import load_optimizer_state, save_optimizer_state +from lerobot.common.optim.schedulers import load_scheduler_state, save_scheduler_state +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.utils.random_utils import load_rng_state, save_rng_state +from lerobot.configs.train import TrainPipelineConfig + + +def log_output_dir(out_dir): + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}") + + +def get_step_identifier(step: int, total_steps: int) -> str: + num_digits = max(6, len(str(total_steps))) + return f"{step:0{num_digits}d}" + + +def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Path: + """Returns the checkpoint sub-directory corresponding to the step number.""" + step_identifier = get_step_identifier(step, total_steps) + return output_dir / CHECKPOINTS_DIR / step_identifier + + +def save_training_step(step: int, save_dir: Path) -> None: + write_json({"step": step}, save_dir / TRAINING_STEP) + + +def load_training_step(save_dir: Path) -> int: + training_step = load_json(save_dir / TRAINING_STEP) + return training_step["step"] + + +def update_last_checkpoint(checkpoint_dir: Path) -> Path: + last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK + if last_checkpoint_dir.is_symlink(): + last_checkpoint_dir.unlink() + relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent) + last_checkpoint_dir.symlink_to(relative_target) + + +def save_checkpoint( + checkpoint_dir: Path, + step: int, + cfg: TrainPipelineConfig, + policy: PreTrainedPolicy, + optimizer: Optimizer, + scheduler: LRScheduler | None = None, +) -> None: + """This function creates the following directory structure: + + 005000/ # training step at checkpoint + ├── pretrained_model/ + │ ├── config.json # policy config + │ ├── model.safetensors # policy weights + │ └── train_config.json # train config 
+ └── training_state/ + ├── optimizer_param_groups.json # optimizer param groups + ├── optimizer_state.safetensors # optimizer state + ├── rng_state.safetensors # rng states + ├── scheduler_state.json # scheduler state + └── training_step.json # training step + + Args: + cfg (TrainPipelineConfig): The training config used for this run. + step (int): The training step at that checkpoint. + policy (PreTrainedPolicy): The policy to save. + optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None. + scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None. + """ + pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR + policy.save_pretrained(pretrained_dir) + cfg.save_pretrained(pretrained_dir) + save_training_state(checkpoint_dir, step, optimizer, scheduler) + + +def save_training_state( + checkpoint_dir: Path, + train_step: int, + optimizer: Optimizer | None = None, + scheduler: LRScheduler | None = None, +) -> None: + """ + Saves the training step, optimizer state, scheduler state, and rng state. + + Args: + save_dir (Path): The directory to save artifacts to. + train_step (int): Current training step. + optimizer (Optimizer | None, optional): The optimizer from which to save the state_dict. + Defaults to None. + scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict. + Defaults to None. + """ + save_dir = checkpoint_dir / TRAINING_STATE_DIR + save_dir.mkdir(parents=True, exist_ok=True) + save_training_step(train_step, save_dir) + save_rng_state(save_dir) + if optimizer is not None: + save_optimizer_state(optimizer, save_dir) + if scheduler is not None: + save_scheduler_state(scheduler, save_dir) + + +def load_training_state( + checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None +) -> tuple[int, Optimizer, LRScheduler | None]: + """ + Loads the training step, optimizer state, scheduler state, and rng state. + This is used to resume a training run. + + Args: + checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir. + optimizer (Optimizer): The optimizer to load the state_dict to. + scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None). + + Raises: + NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir + + Returns: + tuple[int, Optimizer, LRScheduler | None]: training step, optimizer and scheduler with their + state_dict loaded. 
+ """ + training_state_dir = checkpoint_dir / TRAINING_STATE_DIR + if not training_state_dir.is_dir(): + raise NotADirectoryError(training_state_dir) + + load_rng_state(training_state_dir) + step = load_training_step(training_state_dir) + optimizer = load_optimizer_state(optimizer, training_state_dir) + if scheduler is not None: + scheduler = load_scheduler_state(scheduler, training_state_dir) + + return step, optimizer, scheduler diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index cb4f1874..015d1ede 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -17,14 +17,10 @@ import logging import os import os.path as osp import platform -import random -from contextlib import contextmanager from copy import copy from datetime import datetime, timezone from pathlib import Path -from typing import Any, Generator -import numpy as np import torch @@ -106,59 +102,6 @@ def is_amp_available(device: str): raise ValueError(f"Unknown device '{device}.") -def get_global_random_state() -> dict[str, Any]: - """Get the random state for `random`, `numpy`, and `torch`.""" - random_state_dict = { - "random_state": random.getstate(), - "numpy_random_state": np.random.get_state(), - "torch_random_state": torch.random.get_rng_state(), - } - if torch.cuda.is_available(): - random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state() - return random_state_dict - - -def set_global_random_state(random_state_dict: dict[str, Any]): - """Set the random state for `random`, `numpy`, and `torch`. - - Args: - random_state_dict: A dictionary of the form returned by `get_global_random_state`. - """ - random.setstate(random_state_dict["random_state"]) - np.random.set_state(random_state_dict["numpy_random_state"]) - torch.random.set_rng_state(random_state_dict["torch_random_state"]) - if torch.cuda.is_available(): - torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"]) - - -def set_global_seed(seed): - """Set seed for reproducibility.""" - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - - -@contextmanager -def seeded_context(seed: int) -> Generator[None, None, None]: - """Set the seed when entering a context, and restore the prior random state at exit. - - Example usage: - - ``` - a = random.random() # produces some random number - with seeded_context(1337): - b = random.random() # produces some other random number - c = random.random() # produces yet another random number, but the same it would have if we never made `b` - ``` - """ - random_state_dict = get_global_random_state() - set_global_seed(seed) - yield None - set_global_random_state(random_state_dict) - - def init_logging(): def custom_format(record): dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") diff --git a/lerobot/common/utils/wandb_utils.py b/lerobot/common/utils/wandb_utils.py new file mode 100644 index 00000000..2ab3c3fd --- /dev/null +++ b/lerobot/common/utils/wandb_utils.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +import re +from glob import glob +from pathlib import Path + +from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE +from termcolor import colored + +from lerobot.common.constants import PRETRAINED_MODEL_DIR +from lerobot.configs.train import TrainPipelineConfig + + +def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str: + """Return a group name for logging. Optionally returns group name as list.""" + lst = [ + f"policy:{cfg.policy.type}", + f"dataset:{cfg.dataset.repo_id}", + f"seed:{cfg.seed}", + ] + if cfg.env is not None: + lst.append(f"env:{cfg.env.type}") + return lst if return_list else "-".join(lst) + + +def get_wandb_run_id_from_filesystem(log_dir: Path) -> str: + # Get the WandB run ID. + paths = glob(str(log_dir / "wandb/latest-run/run-*")) + if len(paths) != 1: + raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.") + match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1]) + if match is None: + raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.") + wandb_run_id = match.groups(0)[0] + return wandb_run_id + + +def get_safe_wandb_artifact_name(name: str): + """WandB artifacts don't accept ":" or "/" in their name.""" + return name.replace(":", "_").replace("/", "_") + + +class WandBLogger: + """A helper class to log object using wandb.""" + + def __init__(self, cfg: TrainPipelineConfig): + self.cfg = cfg.wandb + self.log_dir = cfg.output_dir + self.job_name = cfg.job_name + self.env_fps = cfg.env.fps if cfg.env else None + self._group = cfg_to_group(cfg) + + # Set up WandB. 
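+        # WANDB_SILENT only silences wandb's console output; the run itself is still synced to the server.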
+        os.environ["WANDB_SILENT"] = "True"
+        import wandb
+
+        wandb_run_id = get_wandb_run_id_from_filesystem(self.log_dir) if cfg.resume else None
+        wandb.init(
+            id=wandb_run_id,
+            project=self.cfg.project,
+            entity=self.cfg.entity,
+            name=self.job_name,
+            notes=self.cfg.notes,
+            tags=cfg_to_group(cfg, return_list=True),
+            dir=self.log_dir,
+            config=cfg.to_dict(),
+            # TODO(rcadene): try set to True
+            save_code=False,
+            # TODO(rcadene): split train and eval, and run async eval with job_type="eval"
+            job_type="train_eval",
+            resume="must" if cfg.resume else None,
+        )
+        print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
+        logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
+        self._wandb = wandb
+
+    def log_policy(self, checkpoint_dir: Path):
+        """Checkpoints the policy to wandb."""
+        if self.cfg.disable_artifact:
+            return
+
+        step_id = checkpoint_dir.name
+        artifact_name = f"{self._group}-{step_id}"
+        artifact_name = get_safe_wandb_artifact_name(artifact_name)
+        artifact = self._wandb.Artifact(artifact_name, type="model")
+        artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
+        self._wandb.log_artifact(artifact)
+
+    def log_dict(self, d: dict, step: int, mode: str = "train"):
+        if mode not in {"train", "eval"}:
+            raise ValueError(mode)
+
+        for k, v in d.items():
+            if not isinstance(v, (int, float, str)):
+                logging.warning(
+                    f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
+                )
+                continue
+            self._wandb.log({f"{mode}/{k}": v}, step=step)
+
+    def log_video(self, video_path: str, step: int, mode: str = "train"):
+        if mode not in {"train", "eval"}:
+            raise ValueError(mode)
+
+        wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
+        self._wandb.log({f"{mode}/video": wandb_video}, step=step)
diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py
index 3d976e81..9d63339d 100644
--- a/lerobot/configs/train.py
+++ b/lerobot/configs/train.py
@@ -21,68 +21,6 @@ from lerobot.configs.policies import PreTrainedConfig
 TRAIN_CONFIG_NAME = "train_config.json"
-@dataclass
-class OfflineConfig:
-    steps: int = 100_000
-
-
-@dataclass
-class OnlineConfig:
-    """
-    The online training loop looks something like:
-
-    ```python
-    for i in range(steps):
-        do_online_rollout_and_update_online_buffer()
-        for j in range(steps_between_rollouts):
-            batch = next(dataloader_with_offline_and_online_data)
-            loss = policy(batch)
-            loss.backward()
-            optimizer.step()
-    ```
-
-    Note that the online training loop adopts most of the options from the offline loop unless specified
-    otherwise.
-    """
-
-    steps: int = 0
-    # How many episodes to collect at once when we reach the online rollout part of the training loop.
-    rollout_n_episodes: int = 1
-    # The number of environments to use in the gym.vector.VectorEnv. This ends up also being the batch size for
-    # the policy. Ideally you should set this to by an even divisor of rollout_n_episodes.
-    rollout_batch_size: int = 1
-    # How many optimization steps (forward, backward, optimizer step) to do between running rollouts.
-    steps_between_rollouts: int | None = None
-    # The proportion of online samples (vs offline samples) to include in the online training batches.
-    sampling_ratio: float = 0.5
-    # First seed to use for the online rollout environment. Seeds for subsequent rollouts are incremented by 1.
-    env_seed: int | None = None
-    # Sets the maximum number of frames that are stored in the online buffer for online training.
The buffer is - # FIFO. - buffer_capacity: int | None = None - # The minimum number of frames to have in the online buffer before commencing online training. - # If buffer_seed_size > rollout_n_episodes, the rollout will be run multiple times until the - # seed size condition is satisfied. - buffer_seed_size: int = 0 - # Whether to run the online rollouts asynchronously. This means we can run the online training steps in - # parallel with the rollouts. This might be advised if your GPU has the bandwidth to handle training - # + eval + environment rendering simultaneously. - do_rollout_async: bool = False - - def __post_init__(self): - if self.steps == 0: - return - - if self.steps_between_rollouts is None: - raise ValueError( - "'steps_between_rollouts' must be set to a positive integer, but it is currently None." - ) - if self.env_seed is None: - raise ValueError("'env_seed' must be set to a positive integer, but it is currently None.") - if self.buffer_capacity is None: - raise ValueError("'buffer_capacity' must be set to a positive integer, but it is currently None.") - - @dataclass class TrainPipelineConfig(HubMixin): dataset: DatasetConfig @@ -107,13 +45,12 @@ class TrainPipelineConfig(HubMixin): # Number of workers for the dataloader. num_workers: int = 4 batch_size: int = 8 + steps: int = 100_000 eval_freq: int = 20_000 log_freq: int = 200 save_checkpoint: bool = True # Checkpoint is saved every `save_freq` training iterations and after the last training step. save_freq: int = 20_000 - offline: OfflineConfig = field(default_factory=OfflineConfig) - online: OnlineConfig = field(default_factory=OnlineConfig) use_policy_training_preset: bool = True optimizer: OptimizerConfig | None = None scheduler: LRSchedulerConfig | None = None @@ -168,11 +105,8 @@ class TrainPipelineConfig(HubMixin): train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}" self.output_dir = Path("outputs/train") / train_dir - if self.online.steps > 0: - if isinstance(self.dataset.repo_id, list): - raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.") - if self.env is None: - raise ValueError("An environment is required for online training") + if isinstance(self.dataset.repo_id, list): + raise NotImplementedError("LeRobotMultiDataset is not currently implemented.") if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None): raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.") @@ -185,6 +119,9 @@ class TrainPipelineConfig(HubMixin): """This enables the parser to load config from the policy using `--policy.path=local/dir`""" return ["policy"] + def to_dict(self) -> dict: + return draccus.encode(self) + def _save_pretrained(self, save_directory: Path) -> None: with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"): draccus.dump(self, f, indent=4) diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 253bc45c..7318748f 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -61,21 +61,21 @@ import einops import gymnasium as gym import numpy as np import torch +from termcolor import colored from torch import Tensor, nn from tqdm import trange from lerobot.common.envs.factory import make_env from lerobot.common.envs.utils import preprocess_observation -from lerobot.common.logger import log_output_dir from lerobot.common.policies.factory import make_policy from lerobot.common.policies.pretrained import PreTrainedPolicy from 
lerobot.common.policies.utils import get_device_from_parameters from lerobot.common.utils.io_utils import write_video +from lerobot.common.utils.random_utils import set_seed from lerobot.common.utils.utils import ( get_safe_torch_device, init_logging, inside_slurm, - set_global_seed, ) from lerobot.configs import parser from lerobot.configs.eval import EvalPipelineConfig @@ -125,9 +125,6 @@ def rollout( # Reset the policy and environments. policy.reset() - if hasattr(policy, "use_ema_modules"): - policy.use_ema_modules() - observation, info = env.reset(seed=seeds) if render_callback is not None: render_callback(env) @@ -463,9 +460,9 @@ def eval(cfg: EvalPipelineConfig): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True - set_global_seed(cfg.seed) + set_seed(cfg.seed) - log_output_dir(cfg.output_dir) + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") logging.info("Making environment.") env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 9af1a972..a840b33d 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -15,58 +15,61 @@ # limitations under the License. import logging import time -from concurrent.futures import ThreadPoolExecutor from contextlib import nullcontext -from copy import deepcopy -from dataclasses import asdict from pprint import pformat -from threading import Lock +from typing import Any -import numpy as np import torch +from termcolor import colored from torch.amp import GradScaler +from torch.optim import Optimizer -from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights +from lerobot.common.datasets.factory import make_dataset from lerobot.common.datasets.sampler import EpisodeAwareSampler from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env -from lerobot.common.logger import Logger, log_output_dir -from lerobot.common.optim.factory import load_training_state, make_optimizer_and_scheduler +from lerobot.common.optim.factory import make_optimizer_and_scheduler from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.utils import get_device_from_parameters +from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker +from lerobot.common.utils.random_utils import set_seed +from lerobot.common.utils.train_utils import ( + get_step_checkpoint_dir, + get_step_identifier, + load_training_state, + save_checkpoint, + update_last_checkpoint, +) from lerobot.common.utils.utils import ( format_big_number, - get_safe_dtype, get_safe_torch_device, has_method, init_logging, - set_global_seed, ) +from lerobot.common.utils.wandb_utils import WandBLogger from lerobot.configs import parser from lerobot.configs.train import TrainPipelineConfig from lerobot.scripts.eval import eval_policy def update_policy( - policy, - batch, - optimizer, - grad_clip_norm, + train_metrics: MetricsTracker, + policy: PreTrainedPolicy, + batch: Any, + optimizer: Optimizer, + grad_clip_norm: float, grad_scaler: GradScaler, lr_scheduler=None, use_amp: bool = False, lock=None, -): - """Returns a dictionary of items for logging.""" +) -> tuple[MetricsTracker, dict]: start_time = 
time.perf_counter() device = get_device_from_parameters(policy) policy.train() with torch.autocast(device_type=device.type) if use_amp else nullcontext(): - output_dict = policy.forward(batch) + loss, output_dict = policy.forward(batch) # TODO(rcadene): policy.unnormalize_outputs(out_dict) - loss = output_dict["loss"] grad_scaler.scale(loss).backward() # Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**. @@ -87,9 +90,6 @@ def update_policy( optimizer.zero_grad() - if hasattr(policy, "update_ema_modules"): - policy.update_ema_modules() - # Step through pytorch scheduler at every batch instead of epoch if lr_scheduler is not None: lr_scheduler.step() @@ -98,113 +98,34 @@ def update_policy( # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC). policy.update() - info = { - "loss": loss.item(), - "grad_norm": float(grad_norm), - "lr": optimizer.param_groups[0]["lr"], - "update_s": time.perf_counter() - start_time, - **{k: v for k, v in output_dict.items() if k != "loss"}, - } - info.update({k: v for k, v in output_dict.items() if k not in info}) - - return info - - -def log_train_info( - logger: Logger, info: dict, step: int, cfg: TrainPipelineConfig, dataset: LeRobotDataset, is_online: bool -): - loss = info["loss"] - grad_norm = info["grad_norm"] - lr = info["lr"] - update_s = info["update_s"] - dataloading_s = info["dataloading_s"] - - # A sample is an (observation,action) pair, where observation and action - # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples. - num_samples = (step + 1) * cfg.batch_size - avg_samples_per_ep = dataset.num_frames / dataset.num_episodes - num_episodes = num_samples / avg_samples_per_ep - num_epochs = num_samples / dataset.num_frames - log_items = [ - f"step:{format_big_number(step)}", - # number of samples seen during training - f"smpl:{format_big_number(num_samples)}", - # number of episodes seen during training - f"ep:{format_big_number(num_episodes)}", - # number of time all unique samples are seen - f"epch:{num_epochs:.2f}", - f"loss:{loss:.3f}", - f"grdn:{grad_norm:.3f}", - f"lr:{lr:0.1e}", - # in seconds - f"updt_s:{update_s:.3f}", - f"data_s:{dataloading_s:.3f}", # if not ~0, you are bottlenecked by cpu or io - ] - logging.info(" ".join(log_items)) - - info["step"] = step - info["num_samples"] = num_samples - info["num_episodes"] = num_episodes - info["num_epochs"] = num_epochs - info["is_online"] = is_online - - logger.log_dict(info, step, mode="train") - - -def log_eval_info(logger, info, step, cfg, dataset, is_online): - eval_s = info["eval_s"] - avg_sum_reward = info["avg_sum_reward"] - pc_success = info["pc_success"] - - # A sample is an (observation,action) pair, where observation and action - # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples. 
- num_samples = (step + 1) * cfg.batch_size - avg_samples_per_ep = dataset.num_frames / dataset.num_episodes - num_episodes = num_samples / avg_samples_per_ep - num_epochs = num_samples / dataset.num_frames - log_items = [ - f"step:{format_big_number(step)}", - # number of samples seen during training - f"smpl:{format_big_number(num_samples)}", - # number of episodes seen during training - f"ep:{format_big_number(num_episodes)}", - # number of time all unique samples are seen - f"epch:{num_epochs:.2f}", - f"∑rwrd:{avg_sum_reward:.3f}", - f"success:{pc_success:.1f}%", - f"eval_s:{eval_s:.3f}", - ] - logging.info(" ".join(log_items)) - - info["step"] = step - info["num_samples"] = num_samples - info["num_episodes"] = num_episodes - info["num_epochs"] = num_epochs - info["is_online"] = is_online - - logger.log_dict(info, step, mode="eval") + train_metrics.loss = loss.item() + train_metrics.grad_norm = grad_norm.item() + train_metrics.lr = optimizer.param_groups[0]["lr"] + train_metrics.update_s = time.perf_counter() - start_time + return train_metrics, output_dict @parser.wrap() def train(cfg: TrainPipelineConfig): cfg.validate() + logging.info(pformat(cfg.to_dict())) - logging.info(pformat(asdict(cfg))) - - # log metrics to terminal and wandb - logger = Logger(cfg) + if cfg.wandb.enable and cfg.wandb.project: + wandb_logger = WandBLogger(cfg) + else: + wandb_logger = None + logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) if cfg.seed is not None: - set_global_seed(cfg.seed) + set_seed(cfg.seed) # Check device is available device = get_safe_torch_device(cfg.device, log=True) - torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True logging.info("Creating dataset") - offline_dataset = make_dataset(cfg) + dataset = make_dataset(cfg) # Create environment used for evaluating checkpoints during training on simulation data. # On real-world data, no need to create an environment as evaluations are done outside train.py, @@ -218,7 +139,7 @@ def train(cfg: TrainPipelineConfig): policy = make_policy( cfg=cfg.policy, device=device, - ds_meta=offline_dataset.meta, + ds_meta=dataset.meta, ) logging.info("Creating optimizer and scheduler") @@ -233,65 +154,29 @@ def train(cfg: TrainPipelineConfig): num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) - log_output_dir(cfg.output_dir) + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") if cfg.env is not None: logging.info(f"{cfg.env.task=}") - logging.info(f"{cfg.offline.steps=} ({format_big_number(cfg.offline.steps)})") - logging.info(f"{cfg.online.steps=}") - logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})") - logging.info(f"{offline_dataset.num_episodes=}") + logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})") + logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})") + logging.info(f"{dataset.num_episodes=}") logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") - # Note: this helper will be used in offline and online training loops. 
- def evaluate_and_checkpoint_if_needed(step: int, is_online: bool): - _num_digits = max(6, len(str(cfg.offline.steps + cfg.online.steps))) - step_identifier = f"{step:0{_num_digits}d}" - - if cfg.env is not None and cfg.eval_freq > 0 and step % cfg.eval_freq == 0: - logging.info(f"Eval policy at step {step}") - with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext(): - eval_info = eval_policy( - eval_env, - policy, - cfg.eval.n_episodes, - videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_identifier}", - max_episodes_rendered=4, - start_seed=cfg.seed, - ) - log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_online=is_online) - if cfg.wandb.enable: - logger.log_video(eval_info["video_paths"][0], step, mode="eval") - logging.info("Resume training") - - if cfg.save_checkpoint and ( - step % cfg.save_freq == 0 or step == cfg.offline.steps + cfg.online.steps - ): - logging.info(f"Checkpoint policy after step {step}") - # Note: Save with step as the identifier, and format it to have at least 6 digits but more if - # needed (choose 6 as a minimum for consistency without being overkill). - logger.save_checkpoint( - step, - step_identifier, - policy, - optimizer, - lr_scheduler, - ) - logging.info("Resume training") - # create dataloader for offline training - if getattr(cfg.policy, "drop_n_last_frames", None): + if hasattr(cfg.policy, "drop_n_last_frames"): shuffle = False sampler = EpisodeAwareSampler( - offline_dataset.episode_data_index, + dataset.episode_data_index, drop_n_last_frames=cfg.policy.drop_n_last_frames, shuffle=True, ) else: shuffle = True sampler = None + dataloader = torch.utils.data.DataLoader( - offline_dataset, + dataset, num_workers=cfg.num_workers, batch_size=cfg.batch_size, shuffle=shuffle, @@ -303,23 +188,30 @@ def train(cfg: TrainPipelineConfig): policy.train() - if hasattr(policy, "init_ema_modules"): - policy.init_ema_modules() + train_metrics = { + "loss": AverageMeter("loss", ":.3f"), + "grad_norm": AverageMeter("grdn", ":.3f"), + "lr": AverageMeter("lr", ":0.1e"), + "update_s": AverageMeter("updt_s", ":.3f"), + "dataloading_s": AverageMeter("data_s", ":.3f"), + } - offline_step = 0 - for _ in range(step, cfg.offline.steps): - if offline_step == 0: - logging.info("Start offline training on a fixed dataset") + train_tracker = MetricsTracker( + cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step + ) + logging.info("Start offline training on a fixed dataset") + for _ in range(step, cfg.steps): start_time = time.perf_counter() batch = next(dl_iter) - dataloading_s = time.perf_counter() - start_time + train_tracker.dataloading_s = time.perf_counter() - start_time for key in batch: if isinstance(batch[key], torch.Tensor): batch[key] = batch[key].to(device, non_blocking=True) - train_info = update_policy( + train_tracker, output_dict = update_policy( + train_tracker, policy, batch, optimizer, @@ -329,231 +221,57 @@ def train(cfg: TrainPipelineConfig): use_amp=cfg.use_amp, ) - train_info["dataloading_s"] = dataloading_s - - if step % cfg.log_freq == 0: - log_train_info(logger, train_info, step, cfg, offline_dataset, is_online=False) - - # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed, - # so we pass in step + 1. - evaluate_and_checkpoint_if_needed(step + 1, is_online=False) - + # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we + # increment `step` here. 
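+        # `train_tracker.step()` below advances the derived counters (samples, episodes,
+        # epochs) in lockstep with `step`, so the logged progress stays consistent.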
step += 1 - offline_step += 1 # noqa: SIM113 + train_tracker.step() + is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 + is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps + is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 - if cfg.online.steps == 0: - if eval_env: - eval_env.close() - logging.info("End of training") - return + if is_log_step: + logging.info(train_tracker) + if wandb_logger: + wandb_log_dict = {**train_tracker.to_dict(), **output_dict} + wandb_logger.log_dict(wandb_log_dict) + train_tracker.reset_averages() - # Online training. + if cfg.save_checkpoint and is_saving_step: + logging.info(f"Checkpoint policy after step {step}") + checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) + save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler) + update_last_checkpoint(checkpoint_dir) + if wandb_logger: + wandb_logger.log_policy(checkpoint_dir) - # Create an env dedicated to online episodes collection from policy rollout. - online_env = make_env(cfg.env, n_envs=cfg.online.rollout_batch_size) - delta_timestamps = resolve_delta_timestamps(cfg.policy, offline_dataset.meta) - online_buffer_path = logger.log_dir / "online_buffer" - if cfg.resume and not online_buffer_path.exists(): - # If we are resuming a run, we default to the data shapes and buffer capacity from the saved online - # buffer. - logging.warning( - "When online training is resumed, we load the latest online buffer from the prior run, " - "and this might not coincide with the state of the buffer as it was at the moment the checkpoint " - "was made. This is because the online buffer is updated on disk during training, independently " - "of our explicit checkpointing mechanisms." - ) - online_dataset = OnlineBuffer( - online_buffer_path, - data_spec={ - **{ - key: {"shape": ft.shape, "dtype": np.dtype("float32")} - for key, ft in policy.config.input_features.items() - }, - **{ - key: {"shape": ft.shape, "dtype": np.dtype("float32")} - for key, ft in policy.config.output_features.items() - }, - "next.reward": {"shape": (), "dtype": np.dtype("float32")}, - "next.done": {"shape": (), "dtype": np.dtype("?")}, - "task_index": {"shape": (), "dtype": np.dtype("int64")}, - # FIXME: 'task' is a string - # "task": {"shape": (), "dtype": np.dtype("?")}, - # FIXME: 'next.success' is expected by pusht env but not xarm - "next.success": {"shape": (), "dtype": np.dtype("?")}, - }, - buffer_capacity=cfg.online.buffer_capacity, - fps=online_env.unwrapped.metadata["render_fps"], - delta_timestamps=delta_timestamps, - ) - - # If we are doing online rollouts asynchronously, deepcopy the policy to use for online rollouts (this - # makes it possible to do online rollouts in parallel with training updates). - online_rollout_policy = deepcopy(policy) if cfg.online.do_rollout_async else policy - - # Create dataloader for online training. - concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset]) - sampler_weights = compute_sampler_weights( - offline_dataset, - offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0), - online_dataset=online_dataset, - # +1 because online rollouts return an extra frame for the "final observation". Note: we don't have - # this final observation in the offline datasets, but we might add them in future. 
- online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1, - online_sampling_ratio=cfg.online.sampling_ratio, - ) - sampler = torch.utils.data.WeightedRandomSampler( - sampler_weights, - num_samples=len(concat_dataset), - replacement=True, - ) - dataloader = torch.utils.data.DataLoader( - concat_dataset, - batch_size=cfg.batch_size, - num_workers=cfg.num_workers, - sampler=sampler, - pin_memory=device.type != "cpu", - drop_last=True, - ) - dl_iter = cycle(dataloader) - - if cfg.online.do_rollout_async: - # Lock and thread pool executor for asynchronous online rollouts. - lock = Lock() - # Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch - # parallelization of rollouts is handled within the job. - executor = ThreadPoolExecutor(max_workers=1) - else: - lock = None - - online_step = 0 - online_rollout_s = 0 # time take to do online rollout - update_online_buffer_s = 0 # time taken to update the online buffer with the online rollout data - # Time taken waiting for the online buffer to finish being updated. This is relevant when using the async - # online rollout option. - await_update_online_buffer_s = 0 - rollout_start_seed = cfg.online.env_seed - - while True: - if online_step == cfg.online.steps: - break - - if online_step == 0: - logging.info("Start online training by interacting with environment") - - def sample_trajectory_and_update_buffer(): - nonlocal rollout_start_seed - - with lock if lock is not None else nullcontext(): - online_rollout_policy.load_state_dict(policy.state_dict()) - - online_rollout_policy.eval() - start_rollout_time = time.perf_counter() - - with torch.no_grad(): + if cfg.env and is_eval_step: + step_id = get_step_identifier(step, cfg.steps) + logging.info(f"Eval policy at step {step}") + with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext(): eval_info = eval_policy( - online_env, - online_rollout_policy, - n_episodes=cfg.online.rollout_n_episodes, - max_episodes_rendered=min(10, cfg.online.rollout_n_episodes), - videos_dir=logger.log_dir / "online_rollout_videos", - return_episode_data=True, - start_seed=(rollout_start_seed := (rollout_start_seed + cfg.batch_size) % 1000000), + eval_env, + policy, + cfg.eval.n_episodes, + videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}", + max_episodes_rendered=4, + start_seed=cfg.seed, ) - online_rollout_s = time.perf_counter() - start_rollout_time - if len(offline_dataset.meta.tasks) > 1: - raise NotImplementedError("Add support for multi task.") - - # TODO(rcadene, aliberts): Hack to add a task to the online_dataset (0 is the first task of the offline_dataset) - total_num_frames = eval_info["episodes"]["index"].shape[0] - eval_info["episodes"]["task_index"] = torch.tensor([0] * total_num_frames, dtype=torch.int64) - eval_info["episodes"]["task"] = ["do the thing"] * total_num_frames - - with lock if lock is not None else nullcontext(): - start_update_buffer_time = time.perf_counter() - online_dataset.add_data(eval_info["episodes"]) - - # Update the concatenated dataset length used during sampling. - concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets) - - # Update the sampling weights. - sampler.weights = compute_sampler_weights( - offline_dataset, - offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0), - online_dataset=online_dataset, - # +1 because online rollouts return an extra frame for the "final observation". 
Note: we don't have - # this final observation in the offline datasets, but we might add them in future. - online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1, - online_sampling_ratio=cfg.online.sampling_ratio, - ) - sampler.num_frames = len(concat_dataset) - - update_online_buffer_s = time.perf_counter() - start_update_buffer_time - - return online_rollout_s, update_online_buffer_s - - if lock is None: - online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer() - else: - future = executor.submit(sample_trajectory_and_update_buffer) - # If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait - # here until the rollout and buffer update is done, before proceeding to the policy update steps. - if len(online_dataset) <= cfg.online.buffer_seed_size: - online_rollout_s, update_online_buffer_s = future.result() - - if len(online_dataset) <= cfg.online.buffer_seed_size: - logging.info(f"Seeding online buffer: {len(online_dataset)}/{cfg.online.buffer_seed_size}") - continue - - policy.train() - for _ in range(cfg.online.steps_between_rollouts): - with lock if lock is not None else nullcontext(): - start_time = time.perf_counter() - batch = next(dl_iter) - dataloading_s = time.perf_counter() - start_time - - for key in batch: - if isinstance(batch[key], torch.Tensor): - dtype = get_safe_dtype(batch[key].dtype, device) - batch[key] = batch[key].to(device=device, dtype=dtype, non_blocking=True) - - train_info = update_policy( - policy, - batch, - optimizer, - cfg.optimizer.grad_clip_norm, - grad_scaler=grad_scaler, - lr_scheduler=lr_scheduler, - use_amp=cfg.use_amp, - lock=lock, + eval_metrics = { + "avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), + "pc_success": AverageMeter("success", ":.1f"), + "eval_s": AverageMeter("eval_s", ":.3f"), + } + eval_tracker = MetricsTracker( + cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step ) - - train_info["dataloading_s"] = dataloading_s - train_info["online_rollout_s"] = online_rollout_s - train_info["update_online_buffer_s"] = update_online_buffer_s - train_info["await_update_online_buffer_s"] = await_update_online_buffer_s - with lock if lock is not None else nullcontext(): - train_info["online_buffer_size"] = len(online_dataset) - - if step % cfg.log_freq == 0: - log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True) - - # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed, - # so we pass in step + 1. - evaluate_and_checkpoint_if_needed(step + 1, is_online=True) - - step += 1 - online_step += 1 - - # If we're doing async rollouts, we should now wait until we've completed them before proceeding - # to do the next batch of rollouts. 
- if cfg.online.do_rollout_async and future.running(): - start = time.perf_counter() - online_rollout_s, update_online_buffer_s = future.result() - await_update_online_buffer_s = time.perf_counter() - start - - if online_step >= cfg.online.steps: - break + eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s") + eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward") + eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success") + logging.info(eval_tracker) + if wandb_logger: + wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} + wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval") if eval_env: eval_env.close() diff --git a/poetry.lock b/poetry.lock index 7f53cfb0..bfa3f084 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1257,12 +1257,14 @@ pyarrow = "*" [[package]] name = "draccus" -version = "0.9.3" +version = "0.10.0" description = "A slightly opinionated framework for simple dataclass-based configurations based on Pyrallis." optional = false -python-versions = ">=3.8" -files = [] -develop = false +python-versions = ">=3.9" +files = [ + {file = "draccus-0.10.0-py3-none-any.whl", hash = "sha256:90243418ae0e9271c390a59cafb6acfd37001193696ed36fcc8525f791a83282"}, + {file = "draccus-0.10.0.tar.gz", hash = "sha256:8dd08304219becdcd66cd16058ba98e9c3e6b7bfe48ccb9579dae39f8d37ae19"}, +] [package.dependencies] mergedeep = ">=1.3,<2.0" @@ -1274,12 +1276,6 @@ typing-inspect = ">=0.9.0,<0.10.0" [package.extras] dev = ["black", "mypy", "pre-commit", "pytest", "ruff"] -[package.source] -type = "git" -url = "https://github.com/dlwh/draccus.git" -reference = "HEAD" -resolved_reference = "9b690730ca108930519f48cc5dead72a72fd27cb" - [[package]] name = "drawnow" version = "0.72.5" @@ -4911,8 +4907,6 @@ files = [ {file = "PyAudio-0.2.14-cp311-cp311-win_amd64.whl", hash = "sha256:bbeb01d36a2f472ae5ee5e1451cacc42112986abe622f735bb870a5db77cf903"}, {file = "PyAudio-0.2.14-cp312-cp312-win32.whl", hash = "sha256:5fce4bcdd2e0e8c063d835dbe2860dac46437506af509353c7f8114d4bacbd5b"}, {file = "PyAudio-0.2.14-cp312-cp312-win_amd64.whl", hash = "sha256:12f2f1ba04e06ff95d80700a78967897a489c05e093e3bffa05a84ed9c0a7fa3"}, - {file = "PyAudio-0.2.14-cp313-cp313-win32.whl", hash = "sha256:95328285b4dab57ea8c52a4a996cb52be6d629353315be5bfda403d15932a497"}, - {file = "PyAudio-0.2.14-cp313-cp313-win_amd64.whl", hash = "sha256:692d8c1446f52ed2662120bcd9ddcb5aa2b71f38bda31e58b19fb4672fffba69"}, {file = "PyAudio-0.2.14-cp38-cp38-win32.whl", hash = "sha256:858caf35b05c26d8fc62f1efa2e8f53d5fa1a01164842bd622f70ddc41f55000"}, {file = "PyAudio-0.2.14-cp38-cp38-win_amd64.whl", hash = "sha256:2dac0d6d675fe7e181ba88f2de88d321059b69abd52e3f4934a8878e03a7a074"}, {file = "PyAudio-0.2.14-cp39-cp39-win32.whl", hash = "sha256:f745109634a7c19fa4d6b8b7d6967c3123d988c9ade0cd35d4295ee1acdb53e9"}, @@ -6934,7 +6928,7 @@ test = ["pytest", "ruff"] name = "tokenizers" version = "0.21.0" description = "" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "tokenizers-0.21.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3c4c93eae637e7d2aaae3d376f06085164e1660f89304c0ab2b1d08a406636b2"}, @@ -7168,7 +7162,7 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0, name = "transformers" version = "4.48.0" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" -optional = false +optional = true python-versions = ">=3.9.0" files = [ {file = "transformers-4.48.0-py3-none-any.whl", hash = 
"sha256:6d3de6d71cb5f2a10f9775ccc17abce9620195caaf32ec96542bd2a6937f25b0"}, @@ -7936,4 +7930,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "089df92a0455bb58d2f600ad645b99362c9ff2b800a0c6108edc09f51509c716" +content-hash = "3df78b6f373b1e4ee79530932f386fb4cfd302c1ceffe26617d26fb1a7e751ce" diff --git a/pyproject.toml b/pyproject.toml index ec30aaea..4469f5ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,8 +69,8 @@ pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platfo hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true} pyserial = {version = ">=3.5", optional = true} jsonlines = ">=4.0.0" -transformers = ">=4.48.0" -draccus = {git = "https://github.com/dlwh/draccus.git"} # replace with draccus = ">=0.9.4" when https://github.com/dlwh/draccus/pull/29 is in release +transformers = {version = ">=4.48.0", optional = true} +draccus = ">=0.10.0" [tool.poetry.extras] diff --git a/tests/conftest.py b/tests/conftest.py index b49bebd9..dc746142 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,6 +28,7 @@ pytest_plugins = [ "tests.fixtures.dataset_factories", "tests.fixtures.files", "tests.fixtures.hub", + "tests.fixtures.optimizers", ] diff --git a/tests/fixtures/optimizers.py b/tests/fixtures/optimizers.py new file mode 100644 index 00000000..1a9b9d11 --- /dev/null +++ b/tests/fixtures/optimizers.py @@ -0,0 +1,26 @@ +import pytest +import torch + +from lerobot.common.optim.optimizers import AdamConfig +from lerobot.common.optim.schedulers import VQBeTSchedulerConfig + + +@pytest.fixture +def model_params(): + return [torch.nn.Parameter(torch.randn(10, 10))] + + +@pytest.fixture +def optimizer(model_params): + optimizer = AdamConfig().build(model_params) + # Dummy step to populate state + loss = sum(param.sum() for param in model_params) + loss.backward() + optimizer.step() + return optimizer + + +@pytest.fixture +def scheduler(optimizer): + config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5) + return config.build(optimizer, num_training_steps=100) diff --git a/tests/scripts/save_image_transforms_to_safetensors.py b/tests/scripts/save_image_transforms_to_safetensors.py index 8a011a22..bd2c3add 100644 --- a/tests/scripts/save_image_transforms_to_safetensors.py +++ b/tests/scripts/save_image_transforms_to_safetensors.py @@ -25,7 +25,7 @@ from lerobot.common.datasets.transforms import ( ImageTransformsConfig, make_transform_from_config, ) -from lerobot.common.utils.utils import seeded_context +from lerobot.common.utils.random_utils import seeded_context ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors") DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp" diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py index 7bfdc6b1..de784db3 100644 --- a/tests/scripts/save_policy_to_safetensors.py +++ b/tests/scripts/save_policy_to_safetensors.py @@ -22,14 +22,14 @@ from safetensors.torch import save_file from lerobot.common.datasets.factory import make_dataset from lerobot.common.optim.factory import make_optimizer_and_scheduler from lerobot.common.policies.factory import make_policy, make_policy_config -from lerobot.common.utils.utils import set_global_seed +from lerobot.common.utils.random_utils import set_seed from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig def get_policy_stats(ds_repo_id, 
env_name, policy_name, policy_kwargs, train_kwargs): # TODO(rcadene, aliberts): env_name? - set_global_seed(1337) + set_seed(1337) train_cfg = TrainPipelineConfig( # TODO(rcadene, aliberts): remove dataset download @@ -53,9 +53,9 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa ) batch = next(iter(dataloader)) - output_dict = policy.forward(batch) + loss, output_dict = policy.forward(batch) output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)} - loss = output_dict["loss"] + output_dict["loss"] = loss loss.backward() grad_stats = {} diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index e80b507c..36ee096f 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -29,7 +29,6 @@ from unittest.mock import patch import pytest -from lerobot.common.logger import Logger from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.factory import make_policy from lerobot.common.robot_devices.control_configs import ( @@ -38,9 +37,7 @@ from lerobot.common.robot_devices.control_configs import ( ReplayControlConfig, TeleoperateControlConfig, ) -from lerobot.configs.default import DatasetConfig from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.train import TrainPipelineConfig from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate from tests.test_robots import make_robot from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_robot @@ -185,20 +182,8 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): out_dir = tmpdir / "logger" - ds_cfg = DatasetConfig(repo_id, local_files_only=True) - train_cfg = TrainPipelineConfig( - dataset=ds_cfg, - policy=policy_cfg, - output_dir=out_dir, - device=DEVICE, - ) - logger = Logger(train_cfg) - logger.save_checkpoint( - train_step=0, - identifier=0, - policy=policy, - ) pretrained_policy_path = out_dir / "checkpoints/last/pretrained_model" + policy.save_pretrained(pretrained_policy_path) # In `examples/9_use_aloha.md`, we advise using `num_image_writer_processes=1` # during inference, to reach constent fps, so we test this here. 
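The seeding helpers used by these test scripts now live in `lerobot.common.utils.random_utils`. A minimal usage sketch of the two entry points the test suite below relies on, `set_seed` and the `seeded_context` context manager (the seed values here are arbitrary):

```python
import torch

from lerobot.common.utils.random_utils import seeded_context, set_seed

# Seed python's `random`, numpy and torch in one call.
set_seed(1337)

# Draws made inside `seeded_context` use a temporarily re-seeded RNG; the
# surrounding global RNG state is restored when the context exits.
with seeded_context(42):
    first = torch.rand(1)

# Re-entering with the same seed reproduces the same draws.
with seeded_context(42):
    second = torch.rand(1)

assert torch.equal(first, second)
```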
diff --git a/tests/test_datasets.py b/tests/test_datasets.py index b11534ee..2945df41 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -46,7 +46,7 @@ from lerobot.common.datasets.utils import ( from lerobot.common.envs.factory import make_env_config from lerobot.common.policies.factory import make_policy_config from lerobot.common.robot_devices.robots.utils import make_robot -from lerobot.common.utils.utils import seeded_context +from lerobot.common.utils.random_utils import seeded_context from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from tests.fixtures.constants import DUMMY_REPO_ID diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py index 72a4ee6e..c118018a 100644 --- a/tests/test_image_transforms.py +++ b/tests/test_image_transforms.py @@ -28,7 +28,7 @@ from lerobot.common.datasets.transforms import ( SharpnessJitter, make_transform_from_config, ) -from lerobot.common.utils.utils import seeded_context +from lerobot.common.utils.random_utils import seeded_context from lerobot.scripts.visualize_image_transforms import ( save_all_transforms, save_each_transform, diff --git a/tests/test_io_utils.py b/tests/test_io_utils.py new file mode 100644 index 00000000..d14f7adc --- /dev/null +++ b/tests/test_io_utils.py @@ -0,0 +1,74 @@ +import json +from pathlib import Path +from typing import Any + +import pytest + +from lerobot.common.utils.io_utils import deserialize_json_into_object + + +@pytest.fixture +def tmp_json_file(tmp_path: Path): + """Writes `data` to a temporary JSON file and returns the file's path.""" + + def _write(data: Any) -> Path: + file_path = tmp_path / "data.json" + with file_path.open("w", encoding="utf-8") as f: + json.dump(data, f) + return file_path + + return _write + + +def test_simple_dict(tmp_json_file): + data = {"name": "Alice", "age": 30} + json_path = tmp_json_file(data) + obj = {"name": "", "age": 0} + assert deserialize_json_into_object(json_path, obj) == data + + +def test_nested_structure(tmp_json_file): + data = {"items": [1, 2, 3], "info": {"active": True}} + json_path = tmp_json_file(data) + obj = {"items": [0, 0, 0], "info": {"active": False}} + assert deserialize_json_into_object(json_path, obj) == data + + +def test_tuple_conversion(tmp_json_file): + data = {"coords": [10.5, 20.5]} + json_path = tmp_json_file(data) + obj = {"coords": (0.0, 0.0)} + result = deserialize_json_into_object(json_path, obj) + assert result["coords"] == (10.5, 20.5) + + +def test_type_mismatch_raises(tmp_json_file): + data = {"numbers": {"bad": "structure"}} + json_path = tmp_json_file(data) + obj = {"numbers": [0, 0]} + with pytest.raises(TypeError): + deserialize_json_into_object(json_path, obj) + + +def test_missing_key_raises(tmp_json_file): + data = {"one": 1} + json_path = tmp_json_file(data) + obj = {"one": 0, "two": 0} + with pytest.raises(ValueError): + deserialize_json_into_object(json_path, obj) + + +def test_extra_key_raises(tmp_json_file): + data = {"one": 1, "two": 2} + json_path = tmp_json_file(data) + obj = {"one": 0} + with pytest.raises(ValueError): + deserialize_json_into_object(json_path, obj) + + +def test_list_length_mismatch_raises(tmp_json_file): + data = {"nums": [1, 2, 3]} + json_path = tmp_json_file(data) + obj = {"nums": [0, 0]} + with pytest.raises(ValueError): + deserialize_json_into_object(json_path, obj) diff --git a/tests/test_logging_utils.py b/tests/test_logging_utils.py new file mode 100644 index 00000000..72385496 --- /dev/null +++ 
b/tests/test_logging_utils.py @@ -0,0 +1,107 @@ +import pytest + +from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker + + +@pytest.fixture +def mock_metrics(): + return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")} + + +def test_average_meter_initialization(): + meter = AverageMeter("loss", ":.2f") + assert meter.name == "loss" + assert meter.fmt == ":.2f" + assert meter.val == 0.0 + assert meter.avg == 0.0 + assert meter.sum == 0.0 + assert meter.count == 0.0 + + +def test_average_meter_update(): + meter = AverageMeter("accuracy") + meter.update(5, n=2) + assert meter.val == 5 + assert meter.sum == 10 + assert meter.count == 2 + assert meter.avg == 5 + + +def test_average_meter_reset(): + meter = AverageMeter("loss") + meter.update(3, 4) + meter.reset() + assert meter.val == 0.0 + assert meter.avg == 0.0 + assert meter.sum == 0.0 + assert meter.count == 0.0 + + +def test_average_meter_str(): + meter = AverageMeter("metric", ":.1f") + meter.update(4.567, 3) + assert str(meter) == "metric:4.6" + + +def test_metrics_tracker_initialization(mock_metrics): + tracker = MetricsTracker( + batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=10 + ) + assert tracker.steps == 10 + assert tracker.samples == 10 * 32 + assert tracker.episodes == tracker.samples / (1000 / 50) + assert tracker.epochs == tracker.samples / 1000 + assert "loss" in tracker.metrics + assert "accuracy" in tracker.metrics + + +def test_metrics_tracker_step(mock_metrics): + tracker = MetricsTracker( + batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=5 + ) + tracker.step() + assert tracker.steps == 6 + assert tracker.samples == 6 * 32 + assert tracker.episodes == tracker.samples / (1000 / 50) + assert tracker.epochs == tracker.samples / 1000 + + +def test_metrics_tracker_getattr(mock_metrics): + tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) + assert tracker.loss == mock_metrics["loss"] + assert tracker.accuracy == mock_metrics["accuracy"] + with pytest.raises(AttributeError): + _ = tracker.non_existent_metric + + +def test_metrics_tracker_setattr(mock_metrics): + tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) + tracker.loss = 2.0 + assert tracker.loss.val == 2.0 + + +def test_metrics_tracker_str(mock_metrics): + tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) + tracker.loss.update(3.456, 1) + tracker.accuracy.update(0.876, 1) + output = str(tracker) + assert "loss:3.456" in output + assert "accuracy:0.88" in output + + +def test_metrics_tracker_to_dict(mock_metrics): + tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) + tracker.loss.update(5, 2) + metrics_dict = tracker.to_dict() + assert isinstance(metrics_dict, dict) + assert metrics_dict["loss"] == 5 # average value + assert metrics_dict["steps"] == tracker.steps + + +def test_metrics_tracker_reset_averages(mock_metrics): + tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics) + tracker.loss.update(10, 3) + tracker.accuracy.update(0.95, 5) + tracker.reset_averages() + assert tracker.loss.avg == 0.0 + assert tracker.accuracy.avg == 0.0 diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py new file mode 100644 index 00000000..cf5c5b18 --- /dev/null +++ b/tests/test_optimizers.py @@ -0,0 +1,43 @@ +import pytest 
+import torch + +from lerobot.common.constants import ( + OPTIMIZER_PARAM_GROUPS, + OPTIMIZER_STATE, +) +from lerobot.common.optim.optimizers import ( + AdamConfig, + AdamWConfig, + SGDConfig, + load_optimizer_state, + save_optimizer_state, +) + + +@pytest.mark.parametrize( + "config_cls, expected_class", + [ + (AdamConfig, torch.optim.Adam), + (AdamWConfig, torch.optim.AdamW), + (SGDConfig, torch.optim.SGD), + ], +) +def test_optimizer_build(config_cls, expected_class, model_params): + config = config_cls() + optimizer = config.build(model_params) + assert isinstance(optimizer, expected_class) + assert optimizer.defaults["lr"] == config.lr + + +def test_save_optimizer_state(optimizer, tmp_path): + save_optimizer_state(optimizer, tmp_path) + assert (tmp_path / OPTIMIZER_STATE).is_file() + assert (tmp_path / OPTIMIZER_PARAM_GROUPS).is_file() + + +def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path): + save_optimizer_state(optimizer, tmp_path) + loaded_optimizer = AdamConfig().build(model_params) + loaded_optimizer = load_optimizer_state(loaded_optimizer, tmp_path) + + torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict()) diff --git a/tests/test_policies.py b/tests/test_policies.py index a949a52c..4374157d 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -36,7 +36,7 @@ from lerobot.common.policies.factory import ( ) from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy -from lerobot.common.utils.utils import seeded_context +from lerobot.common.utils.random_utils import seeded_context from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature diff --git a/tests/test_random_utils.py b/tests/test_random_utils.py new file mode 100644 index 00000000..8eee2b68 --- /dev/null +++ b/tests/test_random_utils.py @@ -0,0 +1,109 @@ +import random + +import numpy as np +import pytest +import torch + +from lerobot.common.utils.random_utils import ( + deserialize_numpy_rng_state, + deserialize_python_rng_state, + deserialize_rng_state, + deserialize_torch_rng_state, + get_rng_state, + seeded_context, + serialize_numpy_rng_state, + serialize_python_rng_state, + serialize_rng_state, + serialize_torch_rng_state, + set_rng_state, + set_seed, +) + + +@pytest.fixture +def fixed_seed(): + """Fixture to set a consistent initial seed for each test.""" + set_seed(12345) + yield + + +def test_serialize_deserialize_python_rng(fixed_seed): + # Save state after generating val1 + _ = random.random() + st = serialize_python_rng_state() + # Next random is val2 + val2 = random.random() + # Restore the state, so the next random should match val2 + deserialize_python_rng_state(st) + val3 = random.random() + assert val2 == val3 + + +def test_serialize_deserialize_numpy_rng(fixed_seed): + _ = np.random.rand() + st = serialize_numpy_rng_state() + val2 = np.random.rand() + deserialize_numpy_rng_state(st) + val3 = np.random.rand() + assert val2 == val3 + + +def test_serialize_deserialize_torch_rng(fixed_seed): + _ = torch.rand(1).item() + st = serialize_torch_rng_state() + val2 = torch.rand(1).item() + deserialize_torch_rng_state(st) + val3 = torch.rand(1).item() + assert val2 == val3 + + +def test_serialize_deserialize_rng(fixed_seed): + # Generate one from each library + _ = random.random() + _ = np.random.rand() + _ = torch.rand(1).item() + # Serialize + st = 
serialize_rng_state() + # Generate second set + val_py2 = random.random() + val_np2 = np.random.rand() + val_th2 = torch.rand(1).item() + # Restore, so the next draws should match val_py2, val_np2, val_th2 + deserialize_rng_state(st) + assert random.random() == val_py2 + assert np.random.rand() == val_np2 + assert torch.rand(1).item() == val_th2 + + +def test_get_set_rng_state(fixed_seed): + st = get_rng_state() + val1 = (random.random(), np.random.rand(), torch.rand(1).item()) + # Change states + random.random() + np.random.rand() + torch.rand(1) + # Restore + set_rng_state(st) + val2 = (random.random(), np.random.rand(), torch.rand(1).item()) + assert val1 == val2 + + +def test_set_seed(): + set_seed(1337) + val1 = (random.random(), np.random.rand(), torch.rand(1).item()) + set_seed(1337) + val2 = (random.random(), np.random.rand(), torch.rand(1).item()) + assert val1 == val2 + + +def test_seeded_context(fixed_seed): + val1 = (random.random(), np.random.rand(), torch.rand(1).item()) + with seeded_context(1337): + seeded_val1 = (random.random(), np.random.rand(), torch.rand(1).item()) + val2 = (random.random(), np.random.rand(), torch.rand(1).item()) + with seeded_context(1337): + seeded_val2 = (random.random(), np.random.rand(), torch.rand(1).item()) + + assert seeded_val1 == seeded_val2 + assert all(a != b for a, b in zip(val1, seeded_val1, strict=True)) # changed inside the context + assert all(a != b for a, b in zip(val2, seeded_val2, strict=True)) # changed again after exiting diff --git a/tests/test_schedulers.py b/tests/test_schedulers.py new file mode 100644 index 00000000..e871fee1 --- /dev/null +++ b/tests/test_schedulers.py @@ -0,0 +1,81 @@ +from torch.optim.lr_scheduler import LambdaLR + +from lerobot.common.constants import SCHEDULER_STATE +from lerobot.common.optim.schedulers import ( + CosineDecayWithWarmupSchedulerConfig, + DiffuserSchedulerConfig, + VQBeTSchedulerConfig, + load_scheduler_state, + save_scheduler_state, +) + + +def test_diffuser_scheduler(optimizer): + config = DiffuserSchedulerConfig(name="cosine", num_warmup_steps=5) + scheduler = config.build(optimizer, num_training_steps=100) + assert isinstance(scheduler, LambdaLR) + + optimizer.step() # so that we don't get torch warning + scheduler.step() + expected_state_dict = { + "_get_lr_called_within_step": False, + "_last_lr": [0.0002], + "_step_count": 2, + "base_lrs": [0.001], + "last_epoch": 1, + "lr_lambdas": [None], + "verbose": False, + } + assert scheduler.state_dict() == expected_state_dict + + +def test_vqbet_scheduler(optimizer): + config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5) + scheduler = config.build(optimizer, num_training_steps=100) + assert isinstance(scheduler, LambdaLR) + + optimizer.step() + scheduler.step() + expected_state_dict = { + "_get_lr_called_within_step": False, + "_last_lr": [0.001], + "_step_count": 2, + "base_lrs": [0.001], + "last_epoch": 1, + "lr_lambdas": [None], + "verbose": False, + } + assert scheduler.state_dict() == expected_state_dict + + +def test_cosine_decay_with_warmup_scheduler(optimizer): + config = CosineDecayWithWarmupSchedulerConfig( + num_warmup_steps=10, num_decay_steps=90, peak_lr=0.01, decay_lr=0.001 + ) + scheduler = config.build(optimizer, num_training_steps=100) + assert isinstance(scheduler, LambdaLR) + + optimizer.step() + scheduler.step() + expected_state_dict = { + "_get_lr_called_within_step": False, + "_last_lr": [0.0001818181818181819], + "_step_count": 2, + "base_lrs": [0.001], + "last_epoch": 1, + 
"lr_lambdas": [None], + "verbose": False, + } + assert scheduler.state_dict() == expected_state_dict + + +def test_save_scheduler_state(scheduler, tmp_path): + save_scheduler_state(scheduler, tmp_path) + assert (tmp_path / SCHEDULER_STATE).is_file() + + +def test_save_load_scheduler_state(scheduler, tmp_path): + save_scheduler_state(scheduler, tmp_path) + loaded_scheduler = load_scheduler_state(scheduler, tmp_path) + + assert scheduler.state_dict() == loaded_scheduler.state_dict() diff --git a/tests/test_train_utils.py b/tests/test_train_utils.py new file mode 100644 index 00000000..d6ed0063 --- /dev/null +++ b/tests/test_train_utils.py @@ -0,0 +1,84 @@ +from pathlib import Path +from unittest.mock import Mock, patch + +from lerobot.common.constants import ( + CHECKPOINTS_DIR, + LAST_CHECKPOINT_LINK, + OPTIMIZER_PARAM_GROUPS, + OPTIMIZER_STATE, + RNG_STATE, + SCHEDULER_STATE, + TRAINING_STATE_DIR, + TRAINING_STEP, +) +from lerobot.common.utils.train_utils import ( + get_step_checkpoint_dir, + get_step_identifier, + load_training_state, + load_training_step, + save_checkpoint, + save_training_state, + save_training_step, + update_last_checkpoint, +) + + +def test_get_step_identifier(): + assert get_step_identifier(5, 1000) == "000005" + assert get_step_identifier(123, 100_000) == "000123" + assert get_step_identifier(456789, 1_000_000) == "0456789" + + +def test_get_step_checkpoint_dir(): + output_dir = Path("/checkpoints") + step_dir = get_step_checkpoint_dir(output_dir, 1000, 5) + assert step_dir == output_dir / CHECKPOINTS_DIR / "000005" + + +def test_save_load_training_step(tmp_path): + save_training_step(5000, tmp_path) + assert (tmp_path / TRAINING_STEP).is_file() + + +def test_load_training_step(tmp_path): + step = 5000 + save_training_step(step, tmp_path) + loaded_step = load_training_step(tmp_path) + assert loaded_step == step + + +def test_update_last_checkpoint(tmp_path): + checkpoint = tmp_path / "0005" + checkpoint.mkdir() + update_last_checkpoint(checkpoint) + last_checkpoint = tmp_path / LAST_CHECKPOINT_LINK + assert last_checkpoint.is_symlink() + assert last_checkpoint.resolve() == checkpoint + + +@patch("lerobot.common.utils.train_utils.save_training_state") +def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer): + policy = Mock() + cfg = Mock() + save_checkpoint(tmp_path, 10, cfg, policy, optimizer) + policy.save_pretrained.assert_called_once() + cfg.save_pretrained.assert_called_once() + mock_save_training_state.assert_called_once() + + +def test_save_training_state(tmp_path, optimizer, scheduler): + save_training_state(tmp_path, 10, optimizer, scheduler) + assert (tmp_path / TRAINING_STATE_DIR).is_dir() + assert (tmp_path / TRAINING_STATE_DIR / TRAINING_STEP).is_file() + assert (tmp_path / TRAINING_STATE_DIR / RNG_STATE).is_file() + assert (tmp_path / TRAINING_STATE_DIR / OPTIMIZER_STATE).is_file() + assert (tmp_path / TRAINING_STATE_DIR / OPTIMIZER_PARAM_GROUPS).is_file() + assert (tmp_path / TRAINING_STATE_DIR / SCHEDULER_STATE).is_file() + + +def test_save_load_training_state(tmp_path, optimizer, scheduler): + save_training_state(tmp_path, 10, optimizer, scheduler) + loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(tmp_path, optimizer, scheduler) + assert loaded_step == 10 + assert loaded_optimizer is optimizer + assert loaded_scheduler is scheduler diff --git a/tests/test_utils.py b/tests/test_utils.py index 071e17b8..b2f14694 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,8 +1,3 @@ -import random -from typing 
import Callable - -import numpy as np -import pytest import torch from datasets import Dataset @@ -10,50 +5,6 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_ from lerobot.common.datasets.utils import ( hf_transform_to_torch, ) -from lerobot.common.utils.utils import ( - get_global_random_state, - seeded_context, - set_global_random_state, - set_global_seed, -) - -# Random generation functions for testing the seeding and random state get/set. -rand_fns = [ - random.random, - np.random.random, - lambda: torch.rand(1).item(), -] -if torch.cuda.is_available(): - rand_fns.append(lambda: torch.rand(1, device="cuda")) - - -@pytest.mark.parametrize("rand_fn", rand_fns) -def test_seeding(rand_fn: Callable[[], int]): - set_global_seed(0) - a = rand_fn() - with seeded_context(1337): - c = rand_fn() - b = rand_fn() - set_global_seed(0) - a_ = rand_fn() - b_ = rand_fn() - # Check that `set_global_seed` lets us reproduce a and b. - assert a_ == a - # Additionally, check that the `seeded_context` didn't interrupt the global RNG. - assert b_ == b - set_global_seed(1337) - c_ = rand_fn() - # Check that `seeded_context` and `global_seed` give the same reproducibility. - assert c_ == c - - -def test_get_set_random_state(): - """Check that getting the random state, then setting it results in the same random number generation.""" - random_state_dict = get_global_random_state() - rand_numbers = [rand_fn() for rand_fn in rand_fns] - set_global_random_state(random_state_dict) - rand_numbers_ = [rand_fn() for rand_fn in rand_fns] - assert rand_numbers_ == rand_numbers def test_calculate_episode_data_index():