Remove offline training, refactor `train.py` and logging/checkpointing (#670)

Co-authored-by: Remi <remi.cadene@huggingface.co>
This commit is contained in:
Simon Alibert 2025-02-11 10:36:06 +01:00 committed by GitHub
parent 334deb985d
commit 90e099b39f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
40 changed files with 1515 additions and 935 deletions

View File

@ -39,8 +39,8 @@ test-act-ete-train:
--dataset.image_transforms.enable=true \
--dataset.episodes="[0]" \
--batch_size=2 \
--offline.steps=4 \
--online.steps=0 \
--steps=4 \
--eval_freq=2 \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--save_freq=2 \
@ -76,8 +76,8 @@ test-diffusion-ete-train:
--dataset.image_transforms.enable=true \
--dataset.episodes="[0]" \
--batch_size=2 \
--offline.steps=2 \
--online.steps=0 \
--steps=2 \
--eval_freq=2 \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--save_checkpoint=true \
@ -106,8 +106,8 @@ test-tdmpc-ete-train:
--dataset.image_transforms.enable=true \
--dataset.episodes="[0]" \
--batch_size=2 \
--offline.steps=2 \
--online.steps=0 \
--steps=2 \
--eval_freq=2 \
--eval.n_episodes=1 \
--eval.batch_size=1 \
--save_checkpoint=true \
@ -126,30 +126,3 @@ test-tdmpc-ete-eval:
--eval.n_episodes=1 \
--eval.batch_size=1 \
--device=$(DEVICE)
# TODO(rcadene): fix online buffer to storing "task"
# test-tdmpc-ete-train-with-online:
# python lerobot/scripts/train.py \
# --policy.type=tdmpc \
# --env.type=pusht \
# --env.obs_type=environment_state_agent_pos \
# --env.episode_length=5 \
# --dataset.repo_id=lerobot/pusht_keypoints \
# --dataset.image_transforms.enable=true \
# --dataset.episodes="[0]" \
# --batch_size=2 \
# --offline.steps=2 \
# --online.steps=20 \
# --online.rollout_n_episodes=2 \
# --online.rollout_batch_size=2 \
# --online.steps_between_rollouts=10 \
# --online.buffer_capacity=1000 \
# --online.env_seed=10000 \
# --save_checkpoint=false \
# --save_freq=10 \
# --log_freq=1 \
# --eval.use_async_envs=true \
# --eval.n_episodes=1 \
# --eval.batch_size=1 \
# --device=$(DEVICE) \
# --output_dir=tests/outputs/tdmpc_online/

View File

@ -86,8 +86,7 @@ def main():
while not done:
for batch in dataloader:
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
output_dict = policy.forward(batch)
loss = output_dict["loss"]
loss, _ = policy.forward(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()

View File

@ -161,13 +161,13 @@ python lerobot/scripts/train.py \
```
You should see from the logging that your training picks up from where it left off.
Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--offline.steps`, which is 100 000 by default.
Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--steps`, which is 100 000 by default.
You could double the number of steps of the previous run with:
```bash
python lerobot/scripts/train.py \
--config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \
--resume=true \
--offline.steps=200000
--steps=200000
```
## Outputs of a run
@ -175,12 +175,16 @@ In the output directory, there will be a folder called `checkpoints` with the fo
```bash
outputs/train/run_resumption/checkpoints
├── 000100 # checkpoint_dir for training step 100
│   ├── pretrained_model
│   │   ├── config.json # pretrained policy config
│   │   ├── model.safetensors # model weights
│   │   ├── train_config.json # train config
│ │ └── README.md # model card
│   └── training_state.pth # optimizer/scheduler/rng state and training step
│ ├── pretrained_model/
│ │ ├── config.json # policy config
│ │ ├── model.safetensors # policy weights
│ │ └── train_config.json # train config
│ └── training_state/
│ ├── optimizer_param_groups.json # optimizer param groups
│ ├── optimizer_state.safetensors # optimizer state
│ ├── rng_state.safetensors # rng states
│ ├── scheduler_state.json # scheduler state
│ └── training_step.json # training step
├── 000200
└── last -> 000200 # symlink to the last available checkpoint
```
@ -250,7 +254,7 @@ python lerobot/scripts/train.py \
python lerobot/scripts/train.py \
--config_path=checkpoint/pretrained_model/ \
--resume=true \
--offline.steps=200000 # <- you can change some training parameters
--steps=200000 # <- you can change some training parameters
```
#### Fine-tuning

View File

@ -75,9 +75,9 @@ def main():
n_examples_evaluated = 0
for batch in val_dataloader:
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
output_dict = policy.forward(batch)
loss, _ = policy.forward(batch)
loss_cumsum += output_dict["loss"].item()
loss_cumsum += loss.item()
n_examples_evaluated += batch["index"].shape[0]
# Calculate the average loss over the validation set.

View File

@ -4,3 +4,14 @@ OBS_ROBOT = "observation.state"
OBS_IMAGE = "observation.image"
OBS_IMAGES = "observation.images"
ACTION = "action"
# files & directories
CHECKPOINTS_DIR = "checkpoints"
LAST_CHECKPOINT_LINK = "last"
PRETRAINED_MODEL_DIR = "pretrained_model"
TRAINING_STATE_DIR = "training_state"
RNG_STATE = "rng_state.safetensors"
TRAINING_STEP = "training_step.json"
OPTIMIZER_STATE = "optimizer_state.safetensors"
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
SCHEDULER_STATE = "scheduler_state.json"

View File

@ -1,240 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Borrowed from https://github.com/fyhMer/fowm/blob/main/src/logger.py
# TODO(rcadene, alexander-soare): clean this file
"""
import logging
import os
import re
from dataclasses import asdict
from glob import glob
from pathlib import Path
import draccus
import torch
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.utils import get_global_random_state
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, NormalizationMode
PRETRAINED_MODEL = "pretrained_model"
TRAINING_STATE = "training_state.pth"
def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.type}",
f"dataset:{cfg.dataset.repo_id}",
f"seed:{cfg.seed}",
]
if cfg.env is not None:
lst.append(f"env:{cfg.env.type}")
return lst if return_list else "-".join(lst)
def get_wandb_run_id_from_filesystem(checkpoint_dir: Path) -> str:
# Get the WandB run ID.
paths = glob(str(checkpoint_dir / "../wandb/latest-run/run-*"))
if len(paths) != 1:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
if match is None:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
wandb_run_id = match.groups(0)[0]
return wandb_run_id
class Logger:
"""Primary logger object. Logs either locally or using wandb.
The logger creates the following directory structure:
provided_log_dir
checkpoints
specific_checkpoint_name
pretrained_model # Hugging Face pretrained model directory
...
training_state.pth # optimizer, scheduler, and random states + training step
| another_specific_checkpoint_name
...
| ...
last # a softlink to the last logged checkpoint
"""
pretrained_model_dir_name = PRETRAINED_MODEL
training_state_file_name = TRAINING_STATE
def __init__(self, cfg: TrainPipelineConfig):
self._cfg = cfg
self.log_dir = cfg.output_dir
self.log_dir.mkdir(parents=True, exist_ok=True)
self.job_name = cfg.job_name
self.checkpoints_dir = self.get_checkpoints_dir(self.log_dir)
self.last_checkpoint_dir = self.get_last_checkpoint_dir(self.log_dir)
self.last_pretrained_model_dir = self.get_last_pretrained_model_dir(self.log_dir)
# Set up WandB.
self._group = cfg_to_group(cfg)
run_offline = not cfg.wandb.enable or not cfg.wandb.project
if run_offline:
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
self._wandb = None
else:
os.environ["WANDB_SILENT"] = "true"
import wandb
wandb_run_id = None
if cfg.resume:
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
wandb.init(
id=wandb_run_id,
project=cfg.wandb.project,
entity=cfg.wandb.entity,
name=self.job_name,
notes=cfg.wandb.notes,
tags=cfg_to_group(cfg, return_list=True),
dir=self.log_dir,
config=asdict(self._cfg),
# TODO(rcadene): try set to True
save_code=False,
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
job_type="train_eval",
resume="must" if cfg.resume else None,
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
@classmethod
def get_checkpoints_dir(cls, log_dir: str | Path) -> Path:
"""Given the log directory, get the sub-directory in which checkpoints will be saved."""
return Path(log_dir) / "checkpoints"
@classmethod
def get_last_checkpoint_dir(cls, log_dir: str | Path) -> Path:
"""Given the log directory, get the sub-directory in which the last checkpoint will be saved."""
return cls.get_checkpoints_dir(log_dir) / "last"
@classmethod
def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path:
"""
Given the log directory, get the sub-directory in which the last checkpoint's pretrained weights will
be saved.
"""
return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
def save_model(self, save_dir: Path, policy: PreTrainedPolicy, wandb_artifact_name: str | None = None):
"""Save the weights of the Policy model using PyTorchModelHubMixin.
The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
Optionally also upload the model to WandB.
"""
self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
register_features_types()
policy.save_pretrained(save_dir)
# Also save the full config for the env configuration.
self._cfg.save_pretrained(save_dir)
if self._wandb and not self._cfg.wandb.disable_artifact:
# note wandb artifact does not accept ":" or "/" in its name
artifact = self._wandb.Artifact(wandb_artifact_name, type="model")
artifact.add_file(save_dir / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
if self.last_checkpoint_dir.exists():
os.remove(self.last_checkpoint_dir)
def save_training_state(
self,
save_dir: Path,
train_step: int,
optimizer: Optimizer | None = None,
scheduler: LRScheduler | None = None,
):
"""Checkpoint the global training_step, optimizer state, scheduler state, and random state.
All of these are saved as "training_state.pth" under the checkpoint directory.
"""
training_state = {}
training_state["step"] = train_step
training_state.update(get_global_random_state())
if optimizer is not None:
training_state["optimizer"] = optimizer.state_dict()
if scheduler is not None:
training_state["scheduler"] = scheduler.state_dict()
torch.save(training_state, save_dir / self.training_state_file_name)
def save_checkpoint(
self,
train_step: int,
identifier: str,
policy: PreTrainedPolicy,
optimizer: Optimizer | None = None,
scheduler: LRScheduler | None = None,
):
"""Checkpoint the model weights and the training state."""
checkpoint_dir = self.checkpoints_dir / str(identifier)
wandb_artifact_name = (
None
if self._wandb is None
else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
)
self.save_model(
checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
)
self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
relative_target = checkpoint_dir.relative_to(self.last_checkpoint_dir.parent)
self.last_checkpoint_dir.symlink_to(relative_target)
def log_dict(self, d: dict, step: int, mode: str = "train"):
assert mode in {"train", "eval"}
# TODO(alexander-soare): Add local text log.
if self._wandb is not None:
for k, v in d.items():
if not isinstance(v, (int, float, str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
)
continue
self._wandb.log({f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
assert mode in {"train", "eval"}
assert self._wandb is not None
wandb_video = self._wandb.Video(video_path, fps=self._cfg.env.fps, format="mp4")
self._wandb.log({f"{mode}/video": wandb_video}, step=step)
def register_features_types():
draccus.decode.register(FeatureType, lambda x: FeatureType[x])
draccus.encode.register(FeatureType, lambda x: x.name)
draccus.decode.register(NormalizationMode, lambda x: NormalizationMode[x])
draccus.encode.register(NormalizationMode, lambda x: x.name)

View File

@ -14,15 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.common.logger import TRAINING_STATE
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.utils import get_global_random_state, set_global_random_state
from lerobot.configs.train import TrainPipelineConfig
@ -40,22 +36,5 @@ def make_optimizer_and_scheduler(
"""
params = policy.get_optim_params() if cfg.use_policy_training_preset else policy.parameters()
optimizer = cfg.optimizer.build(params)
lr_scheduler = cfg.scheduler.build(optimizer, cfg.offline.steps) if cfg.scheduler is not None else None
lr_scheduler = cfg.scheduler.build(optimizer, cfg.steps) if cfg.scheduler is not None else None
return optimizer, lr_scheduler
def load_training_state(checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
"""
Given the checkpoint directory, load the optimizer state, scheduler state, and random state, and
return the global training step.
"""
# TODO(aliberts): use safetensors instead as weights_only=False is unsafe
training_state = torch.load(checkpoint_dir / TRAINING_STATE, weights_only=False)
optimizer.load_state_dict(training_state["optimizer"])
if scheduler is not None:
scheduler.load_state_dict(training_state["scheduler"])
elif "scheduler" in training_state:
raise ValueError("The checkpoint contains a scheduler state_dict, but no LRScheduler was provided.")
# Small HACK to get the expected keys: use `get_global_random_state`.
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
return training_state["step"], optimizer, scheduler

View File

@ -1,8 +1,32 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from dataclasses import asdict, dataclass
from pathlib import Path
import draccus
import torch
from safetensors.torch import load_file, save_file
from lerobot.common.constants import (
OPTIMIZER_PARAM_GROUPS,
OPTIMIZER_STATE,
)
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict, write_json
from lerobot.common.utils.io_utils import deserialize_json_into_object
@dataclass
@ -68,3 +92,27 @@ class SGDConfig(OptimizerConfig):
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
return torch.optim.SGD(params, **kwargs)
def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
state = optimizer.state_dict()
param_groups = state.pop("param_groups")
flat_state = flatten_dict(state)
save_file(flat_state, save_dir / OPTIMIZER_STATE)
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
current_state_dict = optimizer.state_dict()
flat_state = load_file(save_dir / OPTIMIZER_STATE)
state = unflatten_dict(flat_state)
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
if "param_groups" in current_state_dict:
param_groups = deserialize_json_into_object(
save_dir / OPTIMIZER_PARAM_GROUPS, current_state_dict["param_groups"]
)
loaded_state_dict["param_groups"] = param_groups
optimizer.load_state_dict(loaded_state_dict)
return optimizer

View File

@ -1,11 +1,31 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import math
from dataclasses import asdict, dataclass
from pathlib import Path
import draccus
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from lerobot.common.constants import SCHEDULER_STATE
from lerobot.common.datasets.utils import write_json
from lerobot.common.utils.io_utils import deserialize_json_into_object
@dataclass
class LRSchedulerConfig(draccus.ChoiceRegistry, abc.ABC):
@ -89,3 +109,14 @@ class CosineDecayWithWarmupSchedulerConfig(LRSchedulerConfig):
return cosine_decay_schedule(current_step)
return LambdaLR(optimizer, lr_lambda, -1)
def save_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> None:
state_dict = scheduler.state_dict()
write_json(state_dict, save_dir / SCHEDULER_STATE)
def load_scheduler_state(scheduler: LRScheduler, save_dir: Path) -> LRScheduler:
state_dict = deserialize_json_into_object(save_dir / SCHEDULER_STATE, scheduler.state_dict())
scheduler.load_state_dict(state_dict)
return scheduler

View File

@ -144,7 +144,7 @@ class ACTPolicy(PreTrainedPolicy):
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
@ -169,11 +169,11 @@ class ACTPolicy(PreTrainedPolicy):
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
loss = l1_loss + mean_kld * self.config.kl_weight
else:
loss_dict["loss"] = l1_loss
loss = l1_loss
return loss_dict
return loss, loss_dict
class ACTTemporalEnsembler:

View File

@ -143,7 +143,7 @@ class DiffusionPolicy(PreTrainedPolicy):
action = self._queues["action"].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
@ -153,7 +153,8 @@ class DiffusionPolicy(PreTrainedPolicy):
)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
# no output_dict so returning None
return loss, None
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:

View File

@ -163,12 +163,17 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
"""
raise NotImplementedError
# TODO(aliberts, rcadene): split into 'forward' and 'compute_loss'?
@abc.abstractmethod
def forward(self, batch: dict[str, Tensor]) -> dict:
"""Run the batch through the model and compute the loss for training or validation.
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""_summary_
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
other items should be logging-friendly, native Python types.
Args:
batch (dict[str, Tensor]): _description_
Returns:
tuple[Tensor, dict | None]: The loss and potentially other information. Apart from the loss which
is a Tensor, all other items should be logging-friendly, native Python types.
"""
raise NotImplementedError

View File

@ -302,7 +302,7 @@ class TDMPCPolicy(PreTrainedPolicy):
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
return G
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss.
Returns a dictionary with loss as a tensor, and other information as native floats.
@ -495,7 +495,6 @@ class TDMPCPolicy(PreTrainedPolicy):
"Q_value_loss": q_value_loss.item(),
"V_value_loss": v_value_loss.item(),
"pi_loss": pi_loss.item(),
"loss": loss,
"sum_loss": loss.item() * self.config.horizon,
}
)
@ -505,7 +504,7 @@ class TDMPCPolicy(PreTrainedPolicy):
if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0)
return info
return loss, info
def update(self):
"""Update the target model's parameters with an EMA step."""

View File

@ -156,7 +156,7 @@ class VQBeTPolicy(PreTrainedPolicy):
action = self._queues["action"].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
@ -170,16 +170,16 @@ class VQBeTPolicy(PreTrainedPolicy):
loss, n_different_codes, n_different_combinations, recon_l1_error = (
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
)
return {
"loss": loss,
return loss, {
"n_different_codes": n_different_codes,
"n_different_combinations": n_different_combinations,
"recon_l1_error": recon_l1_error,
}
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
_, loss_dict = self.vqbet(batch, rollout=False)
loss = loss_dict.pop("loss")
return loss_dict
return loss, loss_dict
class SpatialSoftmax(nn.Module):
@ -342,7 +342,7 @@ class VQBeTModel(nn.Module):
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
)
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
# Input validation.
assert set(batch).issuperset({"observation.state", "observation.images"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
@ -482,7 +482,7 @@ class VQBeTHead(nn.Module):
param.requires_grad = False
return loss, n_different_codes, n_different_combinations, recon_l1_error
def forward(self, x, **kwargs):
def forward(self, x, **kwargs) -> dict:
# N is the batch size, and T is number of action query tokens, which are process through same GPT
N, T, _ = x.shape
# we calculate N and T side parallely. Thus, the dimensions would be

View File

@ -13,10 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import warnings
from pathlib import Path
from typing import TypeVar
import imageio
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
T = TypeVar("T", bound=JsonLike)
def write_video(video_path, stacked_frames, fps):
# Filter out DeprecationWarnings raised from pkg_resources
@ -25,3 +31,81 @@ def write_video(video_path, stacked_frames, fps):
"ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
)
imageio.mimsave(video_path, stacked_frames, fps=fps)
def deserialize_json_into_object(fpath: Path, obj: T) -> T:
"""
Loads the JSON data from `fpath` and recursively fills `obj` with the
corresponding values (strictly matching structure and types).
Tuples in `obj` are expected to be lists in the JSON data, which will be
converted back into tuples.
"""
with open(fpath, encoding="utf-8") as f:
data = json.load(f)
def _deserialize(target, source):
"""
Recursively overwrite the structure in `target` with data from `source`,
performing strict checks on structure and type.
Returns the updated version of `target` (especially important for tuples).
"""
# If the target is a dictionary, source must be a dictionary as well.
if isinstance(target, dict):
if not isinstance(source, dict):
raise TypeError(f"Type mismatch: expected dict, got {type(source)}")
# Check that they have exactly the same set of keys.
if target.keys() != source.keys():
raise ValueError(
f"Dictionary keys do not match.\n" f"Expected: {target.keys()}, got: {source.keys()}"
)
# Recursively update each key.
for k in target:
target[k] = _deserialize(target[k], source[k])
return target
# If the target is a list, source must be a list as well.
elif isinstance(target, list):
if not isinstance(source, list):
raise TypeError(f"Type mismatch: expected list, got {type(source)}")
# Check length
if len(target) != len(source):
raise ValueError(f"List length mismatch: expected {len(target)}, got {len(source)}")
# Recursively update each element.
for i in range(len(target)):
target[i] = _deserialize(target[i], source[i])
return target
# If the target is a tuple, the source must be a list in JSON,
# which we'll convert back to a tuple.
elif isinstance(target, tuple):
if not isinstance(source, list):
raise TypeError(f"Type mismatch: expected list (for tuple), got {type(source)}")
if len(target) != len(source):
raise ValueError(f"Tuple length mismatch: expected {len(target)}, got {len(source)}")
# Convert each element, forming a new tuple.
converted_items = []
for t_item, s_item in zip(target, source, strict=False):
converted_items.append(_deserialize(t_item, s_item))
# Return a brand new tuple (tuples are immutable in Python).
return tuple(converted_items)
# Otherwise, we're dealing with a "primitive" (int, float, str, bool, None).
else:
# Check the exact type. If these must match 1:1, do:
if type(target) is not type(source):
raise TypeError(f"Type mismatch: expected {type(target)}, got {type(source)}")
return source
# Perform the in-place/recursive deserialization
updated_obj = _deserialize(obj, data)
return updated_obj

View File

@ -0,0 +1,163 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from lerobot.common.utils.utils import format_big_number
class AverageMeter:
"""
Computes and stores the average and current value
Adapted from https://github.com/pytorch/examples/blob/main/imagenet/main.py
"""
def __init__(self, name: str, fmt: str = ":f"):
self.name = name
self.fmt = fmt
self.reset()
def reset(self) -> None:
self.val = 0.0
self.avg = 0.0
self.sum = 0.0
self.count = 0.0
def update(self, val: float, n: int = 1) -> None:
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = "{name}:{avg" + self.fmt + "}"
return fmtstr.format(**self.__dict__)
class MetricsTracker:
"""
A helper class to track and log metrics over time.
Usage pattern:
```python
# initialize, potentially with non-zero initial step (e.g. if resuming run)
metrics = {"loss": AverageMeter("loss", ":.3f")}
train_metrics = MetricsTracker(cfg, dataset, metrics, initial_step=step)
# update metrics derived from step (samples, episodes, epochs) at each training step
train_metrics.step()
# update various metrics
loss = policy.forward(batch)
train_metrics.loss = loss
# display current metrics
logging.info(train_metrics)
# export for wandb
wandb.log(train_metrics.to_dict())
# reset averages after logging
train_metrics.reset_averages()
```
"""
__keys__ = [
"_batch_size",
"_num_frames",
"_avg_samples_per_ep",
"metrics",
"steps",
"samples",
"episodes",
"epochs",
]
def __init__(
self,
batch_size: int,
num_frames: int,
num_episodes: int,
metrics: dict[str, AverageMeter],
initial_step: int = 0,
):
self.__dict__.update({k: None for k in self.__keys__})
self._batch_size = batch_size
self._num_frames = num_frames
self._avg_samples_per_ep = num_frames / num_episodes
self.metrics = metrics
self.steps = initial_step
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
self.samples = self.steps * self._batch_size
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
def __getattr__(self, name: str) -> int | dict[str, AverageMeter] | AverageMeter | Any:
if name in self.__dict__:
return self.__dict__[name]
elif name in self.metrics:
return self.metrics[name]
else:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def __setattr__(self, name: str, value: Any) -> None:
if name in self.__dict__:
super().__setattr__(name, value)
elif name in self.metrics:
self.metrics[name].update(value)
else:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def step(self) -> None:
"""
Updates metrics that depend on 'step' for one step.
"""
self.steps += 1
self.samples += self._batch_size
self.episodes = self.samples / self._avg_samples_per_ep
self.epochs = self.samples / self._num_frames
def __str__(self) -> str:
display_list = [
f"step:{format_big_number(self.steps)}",
# number of samples seen during training
f"smpl:{format_big_number(self.samples)}",
# number of episodes seen during training
f"ep:{format_big_number(self.episodes)}",
# number of time all unique samples are seen
f"epch:{self.epochs:.2f}",
*[str(m) for m in self.metrics.values()],
]
return " ".join(display_list)
def to_dict(self, use_avg: bool = True) -> dict[str, int | float]:
"""
Returns the current metric values (or averages if `use_avg=True`) as a dict.
"""
return {
"steps": self.steps,
"samples": self.samples,
"episodes": self.episodes,
"epochs": self.epochs,
**{k: m.avg if use_avg else m.val for k, m in self.metrics.items()},
}
def reset_averages(self) -> None:
"""Resets average meters."""
for m in self.metrics.values():
m.reset()

View File

@ -0,0 +1,191 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Generator
import numpy as np
import torch
from safetensors.torch import load_file, save_file
from lerobot.common.constants import RNG_STATE
from lerobot.common.datasets.utils import flatten_dict, unflatten_dict
def serialize_python_rng_state() -> dict[str, torch.Tensor]:
"""
Returns the rng state for `random` in the form of a flat dict[str, torch.Tensor] to be saved using
`safetensors.save_file()` or `torch.save()`.
"""
py_state = random.getstate()
return {
"py_rng_version": torch.tensor([py_state[0]], dtype=torch.int64),
"py_rng_state": torch.tensor(py_state[1], dtype=torch.int64),
}
def deserialize_python_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
"""
Restores the rng state for `random` from a dictionary produced by `serialize_python_rng_state()`.
"""
py_state = (rng_state_dict["py_rng_version"].item(), tuple(rng_state_dict["py_rng_state"].tolist()), None)
random.setstate(py_state)
def serialize_numpy_rng_state() -> dict[str, torch.Tensor]:
"""
Returns the rng state for `numpy` in the form of a flat dict[str, torch.Tensor] to be saved using
`safetensors.save_file()` or `torch.save()`.
"""
np_state = np.random.get_state()
# Ensure no breaking changes from numpy
assert np_state[0] == "MT19937"
return {
"np_rng_state_values": torch.tensor(np_state[1], dtype=torch.int64),
"np_rng_state_index": torch.tensor([np_state[2]], dtype=torch.int64),
"np_rng_has_gauss": torch.tensor([np_state[3]], dtype=torch.int64),
"np_rng_cached_gaussian": torch.tensor([np_state[4]], dtype=torch.float32),
}
def deserialize_numpy_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
"""
Restores the rng state for `numpy` from a dictionary produced by `serialize_numpy_rng_state()`.
"""
np_state = (
"MT19937",
rng_state_dict["np_rng_state_values"].numpy(),
rng_state_dict["np_rng_state_index"].item(),
rng_state_dict["np_rng_has_gauss"].item(),
rng_state_dict["np_rng_cached_gaussian"].item(),
)
np.random.set_state(np_state)
def serialize_torch_rng_state() -> dict[str, torch.Tensor]:
"""
Returns the rng state for `torch` in the form of a flat dict[str, torch.Tensor] to be saved using
`safetensors.save_file()` or `torch.save()`.
"""
torch_rng_state_dict = {"torch_rng_state": torch.get_rng_state()}
if torch.cuda.is_available():
torch_rng_state_dict["torch_cuda_rng_state"] = torch.cuda.get_rng_state()
return torch_rng_state_dict
def deserialize_torch_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
"""
Restores the rng state for `torch` from a dictionary produced by `serialize_torch_rng_state()`.
"""
torch.set_rng_state(rng_state_dict["torch_rng_state"])
if torch.cuda.is_available() and "torch_cuda_rng_state" in rng_state_dict:
torch.cuda.set_rng_state(rng_state_dict["torch_cuda_rng_state"])
def serialize_rng_state() -> dict[str, torch.Tensor]:
"""
Returns the rng state for `random`, `numpy`, and `torch`, in the form of a flat
dict[str, torch.Tensor] to be saved using `safetensors.save_file()` `torch.save()`.
"""
py_rng_state_dict = serialize_python_rng_state()
np_rng_state_dict = serialize_numpy_rng_state()
torch_rng_state_dict = serialize_torch_rng_state()
return {
**py_rng_state_dict,
**np_rng_state_dict,
**torch_rng_state_dict,
}
def deserialize_rng_state(rng_state_dict: dict[str, torch.Tensor]) -> None:
"""
Restores the rng state for `random`, `numpy`, and `torch` from a dictionary produced by
`serialize_rng_state()`.
"""
py_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("py")}
np_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("np")}
torch_rng_state_dict = {k: v for k, v in rng_state_dict.items() if k.startswith("torch")}
deserialize_python_rng_state(py_rng_state_dict)
deserialize_numpy_rng_state(np_rng_state_dict)
deserialize_torch_rng_state(torch_rng_state_dict)
def save_rng_state(save_dir: Path) -> None:
rng_state_dict = serialize_rng_state()
flat_rng_state_dict = flatten_dict(rng_state_dict)
save_file(flat_rng_state_dict, save_dir / RNG_STATE)
def load_rng_state(save_dir: Path) -> None:
flat_rng_state_dict = load_file(save_dir / RNG_STATE)
rng_state_dict = unflatten_dict(flat_rng_state_dict)
deserialize_rng_state(rng_state_dict)
def get_rng_state() -> dict[str, Any]:
"""Get the random state for `random`, `numpy`, and `torch`."""
random_state_dict = {
"random_state": random.getstate(),
"numpy_random_state": np.random.get_state(),
"torch_random_state": torch.random.get_rng_state(),
}
if torch.cuda.is_available():
random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
return random_state_dict
def set_rng_state(random_state_dict: dict[str, Any]):
"""Set the random state for `random`, `numpy`, and `torch`.
Args:
random_state_dict: A dictionary of the form returned by `get_rng_state`.
"""
random.setstate(random_state_dict["random_state"])
np.random.set_state(random_state_dict["numpy_random_state"])
torch.random.set_rng_state(random_state_dict["torch_random_state"])
if torch.cuda.is_available():
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
def set_seed(seed) -> None:
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
"""Set the seed when entering a context, and restore the prior random state at exit.
Example usage:
```
a = random.random() # produces some random number
with seeded_context(1337):
b = random.random() # produces some other random number
c = random.random() # produces yet another random number, but the same it would have if we never made `b`
```
"""
random_state_dict = get_rng_state()
set_seed(seed)
yield None
set_rng_state(random_state_dict)

View File

@ -0,0 +1,161 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from termcolor import colored
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.common.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
TRAINING_STEP,
)
from lerobot.common.datasets.utils import load_json, write_json
from lerobot.common.optim.optimizers import load_optimizer_state, save_optimizer_state
from lerobot.common.optim.schedulers import load_scheduler_state, save_scheduler_state
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.random_utils import load_rng_state, save_rng_state
from lerobot.configs.train import TrainPipelineConfig
def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
def get_step_identifier(step: int, total_steps: int) -> str:
num_digits = max(6, len(str(total_steps)))
return f"{step:0{num_digits}d}"
def get_step_checkpoint_dir(output_dir: Path, total_steps: int, step: int) -> Path:
"""Returns the checkpoint sub-directory corresponding to the step number."""
step_identifier = get_step_identifier(step, total_steps)
return output_dir / CHECKPOINTS_DIR / step_identifier
def save_training_step(step: int, save_dir: Path) -> None:
write_json({"step": step}, save_dir / TRAINING_STEP)
def load_training_step(save_dir: Path) -> int:
training_step = load_json(save_dir / TRAINING_STEP)
return training_step["step"]
def update_last_checkpoint(checkpoint_dir: Path) -> Path:
last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
if last_checkpoint_dir.is_symlink():
last_checkpoint_dir.unlink()
relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent)
last_checkpoint_dir.symlink_to(relative_target)
def save_checkpoint(
checkpoint_dir: Path,
step: int,
cfg: TrainPipelineConfig,
policy: PreTrainedPolicy,
optimizer: Optimizer,
scheduler: LRScheduler | None = None,
) -> None:
"""This function creates the following directory structure:
005000/ # training step at checkpoint
pretrained_model/
config.json # policy config
model.safetensors # policy weights
train_config.json # train config
training_state/
optimizer_param_groups.json # optimizer param groups
optimizer_state.safetensors # optimizer state
rng_state.safetensors # rng states
scheduler_state.json # scheduler state
training_step.json # training step
Args:
cfg (TrainPipelineConfig): The training config used for this run.
step (int): The training step at that checkpoint.
policy (PreTrainedPolicy): The policy to save.
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
"""
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir)
cfg.save_pretrained(pretrained_dir)
save_training_state(checkpoint_dir, step, optimizer, scheduler)
def save_training_state(
checkpoint_dir: Path,
train_step: int,
optimizer: Optimizer | None = None,
scheduler: LRScheduler | None = None,
) -> None:
"""
Saves the training step, optimizer state, scheduler state, and rng state.
Args:
save_dir (Path): The directory to save artifacts to.
train_step (int): Current training step.
optimizer (Optimizer | None, optional): The optimizer from which to save the state_dict.
Defaults to None.
scheduler (LRScheduler | None, optional): The scheduler from which to save the state_dict.
Defaults to None.
"""
save_dir = checkpoint_dir / TRAINING_STATE_DIR
save_dir.mkdir(parents=True, exist_ok=True)
save_training_step(train_step, save_dir)
save_rng_state(save_dir)
if optimizer is not None:
save_optimizer_state(optimizer, save_dir)
if scheduler is not None:
save_scheduler_state(scheduler, save_dir)
def load_training_state(
checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None
) -> tuple[int, Optimizer, LRScheduler | None]:
"""
Loads the training step, optimizer state, scheduler state, and rng state.
This is used to resume a training run.
Args:
checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir.
optimizer (Optimizer): The optimizer to load the state_dict to.
scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None).
Raises:
NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir
Returns:
tuple[int, Optimizer, LRScheduler | None]: training step, optimizer and scheduler with their
state_dict loaded.
"""
training_state_dir = checkpoint_dir / TRAINING_STATE_DIR
if not training_state_dir.is_dir():
raise NotADirectoryError(training_state_dir)
load_rng_state(training_state_dir)
step = load_training_step(training_state_dir)
optimizer = load_optimizer_state(optimizer, training_state_dir)
if scheduler is not None:
scheduler = load_scheduler_state(scheduler, training_state_dir)
return step, optimizer, scheduler

View File

@ -17,14 +17,10 @@ import logging
import os
import os.path as osp
import platform
import random
from contextlib import contextmanager
from copy import copy
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Generator
import numpy as np
import torch
@ -106,59 +102,6 @@ def is_amp_available(device: str):
raise ValueError(f"Unknown device '{device}.")
def get_global_random_state() -> dict[str, Any]:
"""Get the random state for `random`, `numpy`, and `torch`."""
random_state_dict = {
"random_state": random.getstate(),
"numpy_random_state": np.random.get_state(),
"torch_random_state": torch.random.get_rng_state(),
}
if torch.cuda.is_available():
random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
return random_state_dict
def set_global_random_state(random_state_dict: dict[str, Any]):
"""Set the random state for `random`, `numpy`, and `torch`.
Args:
random_state_dict: A dictionary of the form returned by `get_global_random_state`.
"""
random.setstate(random_state_dict["random_state"])
np.random.set_state(random_state_dict["numpy_random_state"])
torch.random.set_rng_state(random_state_dict["torch_random_state"])
if torch.cuda.is_available():
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
def set_global_seed(seed):
"""Set seed for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
"""Set the seed when entering a context, and restore the prior random state at exit.
Example usage:
```
a = random.random() # produces some random number
with seeded_context(1337):
b = random.random() # produces some other random number
c = random.random() # produces yet another random number, but the same it would have if we never made `b`
```
"""
random_state_dict = get_global_random_state()
set_global_seed(seed)
yield None
set_global_random_state(random_state_dict)
def init_logging():
def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

View File

@ -0,0 +1,121 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import re
from glob import glob
from pathlib import Path
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from termcolor import colored
from lerobot.common.constants import PRETRAINED_MODEL_DIR
from lerobot.configs.train import TrainPipelineConfig
def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str:
"""Return a group name for logging. Optionally returns group name as list."""
lst = [
f"policy:{cfg.policy.type}",
f"dataset:{cfg.dataset.repo_id}",
f"seed:{cfg.seed}",
]
if cfg.env is not None:
lst.append(f"env:{cfg.env.type}")
return lst if return_list else "-".join(lst)
def get_wandb_run_id_from_filesystem(log_dir: Path) -> str:
# Get the WandB run ID.
paths = glob(str(log_dir / "wandb/latest-run/run-*"))
if len(paths) != 1:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
match = re.search(r"run-([^\.]+).wandb", paths[0].split("/")[-1])
if match is None:
raise RuntimeError("Couldn't get the previous WandB run ID for run resumption.")
wandb_run_id = match.groups(0)[0]
return wandb_run_id
def get_safe_wandb_artifact_name(name: str):
"""WandB artifacts don't accept ":" or "/" in their name."""
return name.replace(":", "_").replace("/", "_")
class WandBLogger:
"""A helper class to log object using wandb."""
def __init__(self, cfg: TrainPipelineConfig):
self.cfg = cfg.wandb
self.log_dir = cfg.output_dir
self.job_name = cfg.job_name
self.env_fps = cfg.env.fps if cfg.env else None
self._group = cfg_to_group(cfg)
# Set up WandB.
os.environ["WANDB_SILENT"] = "True"
import wandb
wandb_run_id = get_wandb_run_id_from_filesystem(self.log_dir) if cfg.resume else None
wandb.init(
id=wandb_run_id,
project=self.cfg.project,
entity=self.cfg.entity,
name=self.job_name,
notes=self.cfg.notes,
tags=cfg_to_group(cfg, return_list=True),
dir=self.log_dir,
config=cfg.to_dict(),
# TODO(rcadene): try set to True
save_code=False,
# TODO(rcadene): split train and eval, and run async eval with job_type="eval"
job_type="train_eval",
resume="must" if cfg.resume else None,
)
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
def log_policy(self, checkpoint_dir: Path):
"""Checkpoints the policy to wandb."""
if self.cfg.disable_artifact:
return
step_id = checkpoint_dir.name
artifact_name = f"{self._group}-{step_id}"
artifact_name = get_safe_wandb_artifact_name(artifact_name)
artifact = self._wandb.Artifact(artifact_name, type="model")
artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE)
self._wandb.log_artifact(artifact)
def log_dict(self, d: dict, step: int, mode: str = "train"):
if mode in {"train", "eval"}:
raise ValueError(mode)
for k, v in d.items():
if not isinstance(v, (int, float, str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
)
continue
self._wandb.log({f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode in {"train", "eval"}:
raise ValueError(mode)
wandb_video = self._wandb.Video(video_path, fps=self.env_fps, format="mp4")
self._wandb.log({f"{mode}/video": wandb_video}, step=step)

View File

@ -21,68 +21,6 @@ from lerobot.configs.policies import PreTrainedConfig
TRAIN_CONFIG_NAME = "train_config.json"
@dataclass
class OfflineConfig:
steps: int = 100_000
@dataclass
class OnlineConfig:
"""
The online training loop looks something like:
```python
for i in range(steps):
do_online_rollout_and_update_online_buffer()
for j in range(steps_between_rollouts):
batch = next(dataloader_with_offline_and_online_data)
loss = policy(batch)
loss.backward()
optimizer.step()
```
Note that the online training loop adopts most of the options from the offline loop unless specified
otherwise.
"""
steps: int = 0
# How many episodes to collect at once when we reach the online rollout part of the training loop.
rollout_n_episodes: int = 1
# The number of environments to use in the gym.vector.VectorEnv. This ends up also being the batch size for
# the policy. Ideally you should set this to by an even divisor of rollout_n_episodes.
rollout_batch_size: int = 1
# How many optimization steps (forward, backward, optimizer step) to do between running rollouts.
steps_between_rollouts: int | None = None
# The proportion of online samples (vs offline samples) to include in the online training batches.
sampling_ratio: float = 0.5
# First seed to use for the online rollout environment. Seeds for subsequent rollouts are incremented by 1.
env_seed: int | None = None
# Sets the maximum number of frames that are stored in the online buffer for online training. The buffer is
# FIFO.
buffer_capacity: int | None = None
# The minimum number of frames to have in the online buffer before commencing online training.
# If buffer_seed_size > rollout_n_episodes, the rollout will be run multiple times until the
# seed size condition is satisfied.
buffer_seed_size: int = 0
# Whether to run the online rollouts asynchronously. This means we can run the online training steps in
# parallel with the rollouts. This might be advised if your GPU has the bandwidth to handle training
# + eval + environment rendering simultaneously.
do_rollout_async: bool = False
def __post_init__(self):
if self.steps == 0:
return
if self.steps_between_rollouts is None:
raise ValueError(
"'steps_between_rollouts' must be set to a positive integer, but it is currently None."
)
if self.env_seed is None:
raise ValueError("'env_seed' must be set to a positive integer, but it is currently None.")
if self.buffer_capacity is None:
raise ValueError("'buffer_capacity' must be set to a positive integer, but it is currently None.")
@dataclass
class TrainPipelineConfig(HubMixin):
dataset: DatasetConfig
@ -107,13 +45,12 @@ class TrainPipelineConfig(HubMixin):
# Number of workers for the dataloader.
num_workers: int = 4
batch_size: int = 8
steps: int = 100_000
eval_freq: int = 20_000
log_freq: int = 200
save_checkpoint: bool = True
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
save_freq: int = 20_000
offline: OfflineConfig = field(default_factory=OfflineConfig)
online: OnlineConfig = field(default_factory=OnlineConfig)
use_policy_training_preset: bool = True
optimizer: OptimizerConfig | None = None
scheduler: LRSchedulerConfig | None = None
@ -168,11 +105,8 @@ class TrainPipelineConfig(HubMixin):
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/train") / train_dir
if self.online.steps > 0:
if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
if self.env is None:
raise ValueError("An environment is required for online training")
if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
@ -185,6 +119,9 @@ class TrainPipelineConfig(HubMixin):
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
return ["policy"]
def to_dict(self) -> dict:
return draccus.encode(self)
def _save_pretrained(self, save_directory: Path) -> None:
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
draccus.dump(self, f, indent=4)

View File

@ -61,21 +61,21 @@ import einops
import gymnasium as gym
import numpy as np
import torch
from termcolor import colored
from torch import Tensor, nn
from tqdm import trange
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.io_utils import write_video
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.utils import (
get_safe_torch_device,
init_logging,
inside_slurm,
set_global_seed,
)
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
@ -125,9 +125,6 @@ def rollout(
# Reset the policy and environments.
policy.reset()
if hasattr(policy, "use_ema_modules"):
policy.use_ema_modules()
observation, info = env.reset(seed=seeds)
if render_callback is not None:
render_callback(env)
@ -463,9 +460,9 @@ def eval(cfg: EvalPipelineConfig):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(cfg.seed)
set_seed(cfg.seed)
log_output_dir(cfg.output_dir)
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info("Making environment.")
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)

View File

@ -15,58 +15,61 @@
# limitations under the License.
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from copy import deepcopy
from dataclasses import asdict
from pprint import pformat
from threading import Lock
from typing import Any
import numpy as np
import torch
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.optim.factory import load_training_state, make_optimizer_and_scheduler
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
save_checkpoint,
update_last_checkpoint,
)
from lerobot.common.utils.utils import (
format_big_number,
get_safe_dtype,
get_safe_torch_device,
has_method,
init_logging,
set_global_seed,
)
from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.eval import eval_policy
def update_policy(
policy,
batch,
optimizer,
grad_clip_norm,
train_metrics: MetricsTracker,
policy: PreTrainedPolicy,
batch: Any,
optimizer: Optimizer,
grad_clip_norm: float,
grad_scaler: GradScaler,
lr_scheduler=None,
use_amp: bool = False,
lock=None,
):
"""Returns a dictionary of items for logging."""
) -> tuple[MetricsTracker, dict]:
start_time = time.perf_counter()
device = get_device_from_parameters(policy)
policy.train()
with torch.autocast(device_type=device.type) if use_amp else nullcontext():
output_dict = policy.forward(batch)
loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
loss = output_dict["loss"]
grad_scaler.scale(loss).backward()
# Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
@ -87,9 +90,6 @@ def update_policy(
optimizer.zero_grad()
if hasattr(policy, "update_ema_modules"):
policy.update_ema_modules()
# Step through pytorch scheduler at every batch instead of epoch
if lr_scheduler is not None:
lr_scheduler.step()
@ -98,113 +98,34 @@ def update_policy(
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
policy.update()
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": optimizer.param_groups[0]["lr"],
"update_s": time.perf_counter() - start_time,
**{k: v for k, v in output_dict.items() if k != "loss"},
}
info.update({k: v for k, v in output_dict.items() if k not in info})
return info
def log_train_info(
logger: Logger, info: dict, step: int, cfg: TrainPipelineConfig, dataset: LeRobotDataset, is_online: bool
):
loss = info["loss"]
grad_norm = info["grad_norm"]
lr = info["lr"]
update_s = info["update_s"]
dataloading_s = info["dataloading_s"]
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
num_samples = (step + 1) * cfg.batch_size
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_frames
log_items = [
f"step:{format_big_number(step)}",
# number of samples seen during training
f"smpl:{format_big_number(num_samples)}",
# number of episodes seen during training
f"ep:{format_big_number(num_episodes)}",
# number of time all unique samples are seen
f"epch:{num_epochs:.2f}",
f"loss:{loss:.3f}",
f"grdn:{grad_norm:.3f}",
f"lr:{lr:0.1e}",
# in seconds
f"updt_s:{update_s:.3f}",
f"data_s:{dataloading_s:.3f}", # if not ~0, you are bottlenecked by cpu or io
]
logging.info(" ".join(log_items))
info["step"] = step
info["num_samples"] = num_samples
info["num_episodes"] = num_episodes
info["num_epochs"] = num_epochs
info["is_online"] = is_online
logger.log_dict(info, step, mode="train")
def log_eval_info(logger, info, step, cfg, dataset, is_online):
eval_s = info["eval_s"]
avg_sum_reward = info["avg_sum_reward"]
pc_success = info["pc_success"]
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
num_samples = (step + 1) * cfg.batch_size
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_frames
log_items = [
f"step:{format_big_number(step)}",
# number of samples seen during training
f"smpl:{format_big_number(num_samples)}",
# number of episodes seen during training
f"ep:{format_big_number(num_episodes)}",
# number of time all unique samples are seen
f"epch:{num_epochs:.2f}",
f"∑rwrd:{avg_sum_reward:.3f}",
f"success:{pc_success:.1f}%",
f"eval_s:{eval_s:.3f}",
]
logging.info(" ".join(log_items))
info["step"] = step
info["num_samples"] = num_samples
info["num_episodes"] = num_episodes
info["num_epochs"] = num_epochs
info["is_online"] = is_online
logger.log_dict(info, step, mode="eval")
train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item()
train_metrics.lr = optimizer.param_groups[0]["lr"]
train_metrics.update_s = time.perf_counter() - start_time
return train_metrics, output_dict
@parser.wrap()
def train(cfg: TrainPipelineConfig):
cfg.validate()
logging.info(pformat(cfg.to_dict()))
logging.info(pformat(asdict(cfg)))
# log metrics to terminal and wandb
logger = Logger(cfg)
if cfg.wandb.enable and cfg.wandb.project:
wandb_logger = WandBLogger(cfg)
else:
wandb_logger = None
logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
if cfg.seed is not None:
set_global_seed(cfg.seed)
set_seed(cfg.seed)
# Check device is available
device = get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("Creating dataset")
offline_dataset = make_dataset(cfg)
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
@ -218,7 +139,7 @@ def train(cfg: TrainPipelineConfig):
policy = make_policy(
cfg=cfg.policy,
device=device,
ds_meta=offline_dataset.meta,
ds_meta=dataset.meta,
)
logging.info("Creating optimizer and scheduler")
@ -233,65 +154,29 @@ def train(cfg: TrainPipelineConfig):
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(cfg.output_dir)
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.offline.steps=} ({format_big_number(cfg.offline.steps)})")
logging.info(f"{cfg.online.steps=}")
logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
logging.info(f"{offline_dataset.num_episodes=}")
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# Note: this helper will be used in offline and online training loops.
def evaluate_and_checkpoint_if_needed(step: int, is_online: bool):
_num_digits = max(6, len(str(cfg.offline.steps + cfg.online.steps)))
step_identifier = f"{step:0{_num_digits}d}"
if cfg.env is not None and cfg.eval_freq > 0 and step % cfg.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
eval_info = eval_policy(
eval_env,
policy,
cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_identifier}",
max_episodes_rendered=4,
start_seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_online=is_online)
if cfg.wandb.enable:
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
logging.info("Resume training")
if cfg.save_checkpoint and (
step % cfg.save_freq == 0 or step == cfg.offline.steps + cfg.online.steps
):
logging.info(f"Checkpoint policy after step {step}")
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
# needed (choose 6 as a minimum for consistency without being overkill).
logger.save_checkpoint(
step,
step_identifier,
policy,
optimizer,
lr_scheduler,
)
logging.info("Resume training")
# create dataloader for offline training
if getattr(cfg.policy, "drop_n_last_frames", None):
if hasattr(cfg.policy, "drop_n_last_frames"):
shuffle = False
sampler = EpisodeAwareSampler(
offline_dataset.episode_data_index,
dataset.episode_data_index,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
else:
shuffle = True
sampler = None
dataloader = torch.utils.data.DataLoader(
offline_dataset,
dataset,
num_workers=cfg.num_workers,
batch_size=cfg.batch_size,
shuffle=shuffle,
@ -303,23 +188,30 @@ def train(cfg: TrainPipelineConfig):
policy.train()
if hasattr(policy, "init_ema_modules"):
policy.init_ema_modules()
train_metrics = {
"loss": AverageMeter("loss", ":.3f"),
"grad_norm": AverageMeter("grdn", ":.3f"),
"lr": AverageMeter("lr", ":0.1e"),
"update_s": AverageMeter("updt_s", ":.3f"),
"dataloading_s": AverageMeter("data_s", ":.3f"),
}
offline_step = 0
for _ in range(step, cfg.offline.steps):
if offline_step == 0:
logging.info("Start offline training on a fixed dataset")
train_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
)
logging.info("Start offline training on a fixed dataset")
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
batch = next(dl_iter)
dataloading_s = time.perf_counter() - start_time
train_tracker.dataloading_s = time.perf_counter() - start_time
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(device, non_blocking=True)
train_info = update_policy(
train_tracker, output_dict = update_policy(
train_tracker,
policy,
batch,
optimizer,
@ -329,231 +221,57 @@ def train(cfg: TrainPipelineConfig):
use_amp=cfg.use_amp,
)
train_info["dataloading_s"] = dataloading_s
if step % cfg.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_online=False)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
evaluate_and_checkpoint_if_needed(step + 1, is_online=False)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here.
step += 1
offline_step += 1 # noqa: SIM113
train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
if cfg.online.steps == 0:
if eval_env:
eval_env.close()
logging.info("End of training")
return
if is_log_step:
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = {**train_tracker.to_dict(), **output_dict}
wandb_logger.log_dict(wandb_log_dict)
train_tracker.reset_averages()
# Online training.
if cfg.save_checkpoint and is_saving_step:
logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
update_last_checkpoint(checkpoint_dir)
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
# Create an env dedicated to online episodes collection from policy rollout.
online_env = make_env(cfg.env, n_envs=cfg.online.rollout_batch_size)
delta_timestamps = resolve_delta_timestamps(cfg.policy, offline_dataset.meta)
online_buffer_path = logger.log_dir / "online_buffer"
if cfg.resume and not online_buffer_path.exists():
# If we are resuming a run, we default to the data shapes and buffer capacity from the saved online
# buffer.
logging.warning(
"When online training is resumed, we load the latest online buffer from the prior run, "
"and this might not coincide with the state of the buffer as it was at the moment the checkpoint "
"was made. This is because the online buffer is updated on disk during training, independently "
"of our explicit checkpointing mechanisms."
)
online_dataset = OnlineBuffer(
online_buffer_path,
data_spec={
**{
key: {"shape": ft.shape, "dtype": np.dtype("float32")}
for key, ft in policy.config.input_features.items()
},
**{
key: {"shape": ft.shape, "dtype": np.dtype("float32")}
for key, ft in policy.config.output_features.items()
},
"next.reward": {"shape": (), "dtype": np.dtype("float32")},
"next.done": {"shape": (), "dtype": np.dtype("?")},
"task_index": {"shape": (), "dtype": np.dtype("int64")},
# FIXME: 'task' is a string
# "task": {"shape": (), "dtype": np.dtype("?")},
# FIXME: 'next.success' is expected by pusht env but not xarm
"next.success": {"shape": (), "dtype": np.dtype("?")},
},
buffer_capacity=cfg.online.buffer_capacity,
fps=online_env.unwrapped.metadata["render_fps"],
delta_timestamps=delta_timestamps,
)
# If we are doing online rollouts asynchronously, deepcopy the policy to use for online rollouts (this
# makes it possible to do online rollouts in parallel with training updates).
online_rollout_policy = deepcopy(policy) if cfg.online.do_rollout_async else policy
# Create dataloader for online training.
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
sampler_weights = compute_sampler_weights(
offline_dataset,
offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
online_dataset=online_dataset,
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
# this final observation in the offline datasets, but we might add them in future.
online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1,
online_sampling_ratio=cfg.online.sampling_ratio,
)
sampler = torch.utils.data.WeightedRandomSampler(
sampler_weights,
num_samples=len(concat_dataset),
replacement=True,
)
dataloader = torch.utils.data.DataLoader(
concat_dataset,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
sampler=sampler,
pin_memory=device.type != "cpu",
drop_last=True,
)
dl_iter = cycle(dataloader)
if cfg.online.do_rollout_async:
# Lock and thread pool executor for asynchronous online rollouts.
lock = Lock()
# Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
# parallelization of rollouts is handled within the job.
executor = ThreadPoolExecutor(max_workers=1)
else:
lock = None
online_step = 0
online_rollout_s = 0 # time take to do online rollout
update_online_buffer_s = 0 # time taken to update the online buffer with the online rollout data
# Time taken waiting for the online buffer to finish being updated. This is relevant when using the async
# online rollout option.
await_update_online_buffer_s = 0
rollout_start_seed = cfg.online.env_seed
while True:
if online_step == cfg.online.steps:
break
if online_step == 0:
logging.info("Start online training by interacting with environment")
def sample_trajectory_and_update_buffer():
nonlocal rollout_start_seed
with lock if lock is not None else nullcontext():
online_rollout_policy.load_state_dict(policy.state_dict())
online_rollout_policy.eval()
start_rollout_time = time.perf_counter()
with torch.no_grad():
if cfg.env and is_eval_step:
step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}")
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
eval_info = eval_policy(
online_env,
online_rollout_policy,
n_episodes=cfg.online.rollout_n_episodes,
max_episodes_rendered=min(10, cfg.online.rollout_n_episodes),
videos_dir=logger.log_dir / "online_rollout_videos",
return_episode_data=True,
start_seed=(rollout_start_seed := (rollout_start_seed + cfg.batch_size) % 1000000),
eval_env,
policy,
cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
start_seed=cfg.seed,
)
online_rollout_s = time.perf_counter() - start_rollout_time
if len(offline_dataset.meta.tasks) > 1:
raise NotImplementedError("Add support for multi task.")
# TODO(rcadene, aliberts): Hack to add a task to the online_dataset (0 is the first task of the offline_dataset)
total_num_frames = eval_info["episodes"]["index"].shape[0]
eval_info["episodes"]["task_index"] = torch.tensor([0] * total_num_frames, dtype=torch.int64)
eval_info["episodes"]["task"] = ["do the thing"] * total_num_frames
with lock if lock is not None else nullcontext():
start_update_buffer_time = time.perf_counter()
online_dataset.add_data(eval_info["episodes"])
# Update the concatenated dataset length used during sampling.
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
# Update the sampling weights.
sampler.weights = compute_sampler_weights(
offline_dataset,
offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
online_dataset=online_dataset,
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
# this final observation in the offline datasets, but we might add them in future.
online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1,
online_sampling_ratio=cfg.online.sampling_ratio,
)
sampler.num_frames = len(concat_dataset)
update_online_buffer_s = time.perf_counter() - start_update_buffer_time
return online_rollout_s, update_online_buffer_s
if lock is None:
online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()
else:
future = executor.submit(sample_trajectory_and_update_buffer)
# If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
# here until the rollout and buffer update is done, before proceeding to the policy update steps.
if len(online_dataset) <= cfg.online.buffer_seed_size:
online_rollout_s, update_online_buffer_s = future.result()
if len(online_dataset) <= cfg.online.buffer_seed_size:
logging.info(f"Seeding online buffer: {len(online_dataset)}/{cfg.online.buffer_seed_size}")
continue
policy.train()
for _ in range(cfg.online.steps_between_rollouts):
with lock if lock is not None else nullcontext():
start_time = time.perf_counter()
batch = next(dl_iter)
dataloading_s = time.perf_counter() - start_time
for key in batch:
if isinstance(batch[key], torch.Tensor):
dtype = get_safe_dtype(batch[key].dtype, device)
batch[key] = batch[key].to(device=device, dtype=dtype, non_blocking=True)
train_info = update_policy(
policy,
batch,
optimizer,
cfg.optimizer.grad_clip_norm,
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
use_amp=cfg.use_amp,
lock=lock,
eval_metrics = {
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
"pc_success": AverageMeter("success", ":.1f"),
"eval_s": AverageMeter("eval_s", ":.3f"),
}
eval_tracker = MetricsTracker(
cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
)
train_info["dataloading_s"] = dataloading_s
train_info["online_rollout_s"] = online_rollout_s
train_info["update_online_buffer_s"] = update_online_buffer_s
train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
with lock if lock is not None else nullcontext():
train_info["online_buffer_size"] = len(online_dataset)
if step % cfg.log_freq == 0:
log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
evaluate_and_checkpoint_if_needed(step + 1, is_online=True)
step += 1
online_step += 1
# If we're doing async rollouts, we should now wait until we've completed them before proceeding
# to do the next batch of rollouts.
if cfg.online.do_rollout_async and future.running():
start = time.perf_counter()
online_rollout_s, update_online_buffer_s = future.result()
await_update_online_buffer_s = time.perf_counter() - start
if online_step >= cfg.online.steps:
break
eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
eval_tracker.pc_success = eval_info["aggregated"].pop("pc_success")
logging.info(eval_tracker)
if wandb_logger:
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
wandb_logger.log_video(eval_info["video_paths"][0], step, mode="eval")
if eval_env:
eval_env.close()

24
poetry.lock generated
View File

@ -1257,12 +1257,14 @@ pyarrow = "*"
[[package]]
name = "draccus"
version = "0.9.3"
version = "0.10.0"
description = "A slightly opinionated framework for simple dataclass-based configurations based on Pyrallis."
optional = false
python-versions = ">=3.8"
files = []
develop = false
python-versions = ">=3.9"
files = [
{file = "draccus-0.10.0-py3-none-any.whl", hash = "sha256:90243418ae0e9271c390a59cafb6acfd37001193696ed36fcc8525f791a83282"},
{file = "draccus-0.10.0.tar.gz", hash = "sha256:8dd08304219becdcd66cd16058ba98e9c3e6b7bfe48ccb9579dae39f8d37ae19"},
]
[package.dependencies]
mergedeep = ">=1.3,<2.0"
@ -1274,12 +1276,6 @@ typing-inspect = ">=0.9.0,<0.10.0"
[package.extras]
dev = ["black", "mypy", "pre-commit", "pytest", "ruff"]
[package.source]
type = "git"
url = "https://github.com/dlwh/draccus.git"
reference = "HEAD"
resolved_reference = "9b690730ca108930519f48cc5dead72a72fd27cb"
[[package]]
name = "drawnow"
version = "0.72.5"
@ -4911,8 +4907,6 @@ files = [
{file = "PyAudio-0.2.14-cp311-cp311-win_amd64.whl", hash = "sha256:bbeb01d36a2f472ae5ee5e1451cacc42112986abe622f735bb870a5db77cf903"},
{file = "PyAudio-0.2.14-cp312-cp312-win32.whl", hash = "sha256:5fce4bcdd2e0e8c063d835dbe2860dac46437506af509353c7f8114d4bacbd5b"},
{file = "PyAudio-0.2.14-cp312-cp312-win_amd64.whl", hash = "sha256:12f2f1ba04e06ff95d80700a78967897a489c05e093e3bffa05a84ed9c0a7fa3"},
{file = "PyAudio-0.2.14-cp313-cp313-win32.whl", hash = "sha256:95328285b4dab57ea8c52a4a996cb52be6d629353315be5bfda403d15932a497"},
{file = "PyAudio-0.2.14-cp313-cp313-win_amd64.whl", hash = "sha256:692d8c1446f52ed2662120bcd9ddcb5aa2b71f38bda31e58b19fb4672fffba69"},
{file = "PyAudio-0.2.14-cp38-cp38-win32.whl", hash = "sha256:858caf35b05c26d8fc62f1efa2e8f53d5fa1a01164842bd622f70ddc41f55000"},
{file = "PyAudio-0.2.14-cp38-cp38-win_amd64.whl", hash = "sha256:2dac0d6d675fe7e181ba88f2de88d321059b69abd52e3f4934a8878e03a7a074"},
{file = "PyAudio-0.2.14-cp39-cp39-win32.whl", hash = "sha256:f745109634a7c19fa4d6b8b7d6967c3123d988c9ade0cd35d4295ee1acdb53e9"},
@ -6934,7 +6928,7 @@ test = ["pytest", "ruff"]
name = "tokenizers"
version = "0.21.0"
description = ""
optional = false
optional = true
python-versions = ">=3.7"
files = [
{file = "tokenizers-0.21.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3c4c93eae637e7d2aaae3d376f06085164e1660f89304c0ab2b1d08a406636b2"},
@ -7168,7 +7162,7 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,
name = "transformers"
version = "4.48.0"
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
optional = false
optional = true
python-versions = ">=3.9.0"
files = [
{file = "transformers-4.48.0-py3-none-any.whl", hash = "sha256:6d3de6d71cb5f2a10f9775ccc17abce9620195caaf32ec96542bd2a6937f25b0"},
@ -7936,4 +7930,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "089df92a0455bb58d2f600ad645b99362c9ff2b800a0c6108edc09f51509c716"
content-hash = "3df78b6f373b1e4ee79530932f386fb4cfd302c1ceffe26617d26fb1a7e751ce"

View File

@ -69,8 +69,8 @@ pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platfo
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
pyserial = {version = ">=3.5", optional = true}
jsonlines = ">=4.0.0"
transformers = ">=4.48.0"
draccus = {git = "https://github.com/dlwh/draccus.git"} # replace with draccus = ">=0.9.4" when https://github.com/dlwh/draccus/pull/29 is in release
transformers = {version = ">=4.48.0", optional = true}
draccus = ">=0.10.0"
[tool.poetry.extras]

View File

@ -28,6 +28,7 @@ pytest_plugins = [
"tests.fixtures.dataset_factories",
"tests.fixtures.files",
"tests.fixtures.hub",
"tests.fixtures.optimizers",
]

26
tests/fixtures/optimizers.py vendored Normal file
View File

@ -0,0 +1,26 @@
import pytest
import torch
from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
@pytest.fixture
def model_params():
return [torch.nn.Parameter(torch.randn(10, 10))]
@pytest.fixture
def optimizer(model_params):
optimizer = AdamConfig().build(model_params)
# Dummy step to populate state
loss = sum(param.sum() for param in model_params)
loss.backward()
optimizer.step()
return optimizer
@pytest.fixture
def scheduler(optimizer):
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
return config.build(optimizer, num_training_steps=100)

View File

@ -25,7 +25,7 @@ from lerobot.common.datasets.transforms import (
ImageTransformsConfig,
make_transform_from_config,
)
from lerobot.common.utils.utils import seeded_context
from lerobot.common.utils.random_utils import seeded_context
ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"

View File

@ -22,14 +22,14 @@ from safetensors.torch import save_file
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.factory import make_policy, make_policy_config
from lerobot.common.utils.utils import set_global_seed
from lerobot.common.utils.random_utils import set_seed
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs):
# TODO(rcadene, aliberts): env_name?
set_global_seed(1337)
set_seed(1337)
train_cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
@ -53,9 +53,9 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa
)
batch = next(iter(dataloader))
output_dict = policy.forward(batch)
loss, output_dict = policy.forward(batch)
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
loss = output_dict["loss"]
output_dict["loss"] = loss
loss.backward()
grad_stats = {}

View File

@ -29,7 +29,6 @@ from unittest.mock import patch
import pytest
from lerobot.common.logger import Logger
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.control_configs import (
@ -38,9 +37,7 @@ from lerobot.common.robot_devices.control_configs import (
ReplayControlConfig,
TeleoperateControlConfig,
)
from lerobot.configs.default import DatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
from tests.test_robots import make_robot
from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
@ -185,20 +182,8 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
out_dir = tmpdir / "logger"
ds_cfg = DatasetConfig(repo_id, local_files_only=True)
train_cfg = TrainPipelineConfig(
dataset=ds_cfg,
policy=policy_cfg,
output_dir=out_dir,
device=DEVICE,
)
logger = Logger(train_cfg)
logger.save_checkpoint(
train_step=0,
identifier=0,
policy=policy,
)
pretrained_policy_path = out_dir / "checkpoints/last/pretrained_model"
policy.save_pretrained(pretrained_policy_path)
# In `examples/9_use_aloha.md`, we advise using `num_image_writer_processes=1`
# during inference, to reach constent fps, so we test this here.

View File

@ -46,7 +46,7 @@ from lerobot.common.datasets.utils import (
from lerobot.common.envs.factory import make_env_config
from lerobot.common.policies.factory import make_policy_config
from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.common.utils.utils import seeded_context
from lerobot.common.utils.random_utils import seeded_context
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_REPO_ID

View File

@ -28,7 +28,7 @@ from lerobot.common.datasets.transforms import (
SharpnessJitter,
make_transform_from_config,
)
from lerobot.common.utils.utils import seeded_context
from lerobot.common.utils.random_utils import seeded_context
from lerobot.scripts.visualize_image_transforms import (
save_all_transforms,
save_each_transform,

74
tests/test_io_utils.py Normal file
View File

@ -0,0 +1,74 @@
import json
from pathlib import Path
from typing import Any
import pytest
from lerobot.common.utils.io_utils import deserialize_json_into_object
@pytest.fixture
def tmp_json_file(tmp_path: Path):
"""Writes `data` to a temporary JSON file and returns the file's path."""
def _write(data: Any) -> Path:
file_path = tmp_path / "data.json"
with file_path.open("w", encoding="utf-8") as f:
json.dump(data, f)
return file_path
return _write
def test_simple_dict(tmp_json_file):
data = {"name": "Alice", "age": 30}
json_path = tmp_json_file(data)
obj = {"name": "", "age": 0}
assert deserialize_json_into_object(json_path, obj) == data
def test_nested_structure(tmp_json_file):
data = {"items": [1, 2, 3], "info": {"active": True}}
json_path = tmp_json_file(data)
obj = {"items": [0, 0, 0], "info": {"active": False}}
assert deserialize_json_into_object(json_path, obj) == data
def test_tuple_conversion(tmp_json_file):
data = {"coords": [10.5, 20.5]}
json_path = tmp_json_file(data)
obj = {"coords": (0.0, 0.0)}
result = deserialize_json_into_object(json_path, obj)
assert result["coords"] == (10.5, 20.5)
def test_type_mismatch_raises(tmp_json_file):
data = {"numbers": {"bad": "structure"}}
json_path = tmp_json_file(data)
obj = {"numbers": [0, 0]}
with pytest.raises(TypeError):
deserialize_json_into_object(json_path, obj)
def test_missing_key_raises(tmp_json_file):
data = {"one": 1}
json_path = tmp_json_file(data)
obj = {"one": 0, "two": 0}
with pytest.raises(ValueError):
deserialize_json_into_object(json_path, obj)
def test_extra_key_raises(tmp_json_file):
data = {"one": 1, "two": 2}
json_path = tmp_json_file(data)
obj = {"one": 0}
with pytest.raises(ValueError):
deserialize_json_into_object(json_path, obj)
def test_list_length_mismatch_raises(tmp_json_file):
data = {"nums": [1, 2, 3]}
json_path = tmp_json_file(data)
obj = {"nums": [0, 0]}
with pytest.raises(ValueError):
deserialize_json_into_object(json_path, obj)

107
tests/test_logging_utils.py Normal file
View File

@ -0,0 +1,107 @@
import pytest
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
@pytest.fixture
def mock_metrics():
return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
def test_average_meter_initialization():
meter = AverageMeter("loss", ":.2f")
assert meter.name == "loss"
assert meter.fmt == ":.2f"
assert meter.val == 0.0
assert meter.avg == 0.0
assert meter.sum == 0.0
assert meter.count == 0.0
def test_average_meter_update():
meter = AverageMeter("accuracy")
meter.update(5, n=2)
assert meter.val == 5
assert meter.sum == 10
assert meter.count == 2
assert meter.avg == 5
def test_average_meter_reset():
meter = AverageMeter("loss")
meter.update(3, 4)
meter.reset()
assert meter.val == 0.0
assert meter.avg == 0.0
assert meter.sum == 0.0
assert meter.count == 0.0
def test_average_meter_str():
meter = AverageMeter("metric", ":.1f")
meter.update(4.567, 3)
assert str(meter) == "metric:4.6"
def test_metrics_tracker_initialization(mock_metrics):
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=10
)
assert tracker.steps == 10
assert tracker.samples == 10 * 32
assert tracker.episodes == tracker.samples / (1000 / 50)
assert tracker.epochs == tracker.samples / 1000
assert "loss" in tracker.metrics
assert "accuracy" in tracker.metrics
def test_metrics_tracker_step(mock_metrics):
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=5
)
tracker.step()
assert tracker.steps == 6
assert tracker.samples == 6 * 32
assert tracker.episodes == tracker.samples / (1000 / 50)
assert tracker.epochs == tracker.samples / 1000
def test_metrics_tracker_getattr(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
assert tracker.loss == mock_metrics["loss"]
assert tracker.accuracy == mock_metrics["accuracy"]
with pytest.raises(AttributeError):
_ = tracker.non_existent_metric
def test_metrics_tracker_setattr(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker.loss = 2.0
assert tracker.loss.val == 2.0
def test_metrics_tracker_str(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker.loss.update(3.456, 1)
tracker.accuracy.update(0.876, 1)
output = str(tracker)
assert "loss:3.456" in output
assert "accuracy:0.88" in output
def test_metrics_tracker_to_dict(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker.loss.update(5, 2)
metrics_dict = tracker.to_dict()
assert isinstance(metrics_dict, dict)
assert metrics_dict["loss"] == 5 # average value
assert metrics_dict["steps"] == tracker.steps
def test_metrics_tracker_reset_averages(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker.loss.update(10, 3)
tracker.accuracy.update(0.95, 5)
tracker.reset_averages()
assert tracker.loss.avg == 0.0
assert tracker.accuracy.avg == 0.0

43
tests/test_optimizers.py Normal file
View File

@ -0,0 +1,43 @@
import pytest
import torch
from lerobot.common.constants import (
OPTIMIZER_PARAM_GROUPS,
OPTIMIZER_STATE,
)
from lerobot.common.optim.optimizers import (
AdamConfig,
AdamWConfig,
SGDConfig,
load_optimizer_state,
save_optimizer_state,
)
@pytest.mark.parametrize(
"config_cls, expected_class",
[
(AdamConfig, torch.optim.Adam),
(AdamWConfig, torch.optim.AdamW),
(SGDConfig, torch.optim.SGD),
],
)
def test_optimizer_build(config_cls, expected_class, model_params):
config = config_cls()
optimizer = config.build(model_params)
assert isinstance(optimizer, expected_class)
assert optimizer.defaults["lr"] == config.lr
def test_save_optimizer_state(optimizer, tmp_path):
save_optimizer_state(optimizer, tmp_path)
assert (tmp_path / OPTIMIZER_STATE).is_file()
assert (tmp_path / OPTIMIZER_PARAM_GROUPS).is_file()
def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
save_optimizer_state(optimizer, tmp_path)
loaded_optimizer = AdamConfig().build(model_params)
loaded_optimizer = load_optimizer_state(loaded_optimizer, tmp_path)
torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())

View File

@ -36,7 +36,7 @@ from lerobot.common.policies.factory import (
)
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.utils import seeded_context
from lerobot.common.utils.random_utils import seeded_context
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature

109
tests/test_random_utils.py Normal file
View File

@ -0,0 +1,109 @@
import random
import numpy as np
import pytest
import torch
from lerobot.common.utils.random_utils import (
deserialize_numpy_rng_state,
deserialize_python_rng_state,
deserialize_rng_state,
deserialize_torch_rng_state,
get_rng_state,
seeded_context,
serialize_numpy_rng_state,
serialize_python_rng_state,
serialize_rng_state,
serialize_torch_rng_state,
set_rng_state,
set_seed,
)
@pytest.fixture
def fixed_seed():
"""Fixture to set a consistent initial seed for each test."""
set_seed(12345)
yield
def test_serialize_deserialize_python_rng(fixed_seed):
# Save state after generating val1
_ = random.random()
st = serialize_python_rng_state()
# Next random is val2
val2 = random.random()
# Restore the state, so the next random should match val2
deserialize_python_rng_state(st)
val3 = random.random()
assert val2 == val3
def test_serialize_deserialize_numpy_rng(fixed_seed):
_ = np.random.rand()
st = serialize_numpy_rng_state()
val2 = np.random.rand()
deserialize_numpy_rng_state(st)
val3 = np.random.rand()
assert val2 == val3
def test_serialize_deserialize_torch_rng(fixed_seed):
_ = torch.rand(1).item()
st = serialize_torch_rng_state()
val2 = torch.rand(1).item()
deserialize_torch_rng_state(st)
val3 = torch.rand(1).item()
assert val2 == val3
def test_serialize_deserialize_rng(fixed_seed):
# Generate one from each library
_ = random.random()
_ = np.random.rand()
_ = torch.rand(1).item()
# Serialize
st = serialize_rng_state()
# Generate second set
val_py2 = random.random()
val_np2 = np.random.rand()
val_th2 = torch.rand(1).item()
# Restore, so the next draws should match val_py2, val_np2, val_th2
deserialize_rng_state(st)
assert random.random() == val_py2
assert np.random.rand() == val_np2
assert torch.rand(1).item() == val_th2
def test_get_set_rng_state(fixed_seed):
st = get_rng_state()
val1 = (random.random(), np.random.rand(), torch.rand(1).item())
# Change states
random.random()
np.random.rand()
torch.rand(1)
# Restore
set_rng_state(st)
val2 = (random.random(), np.random.rand(), torch.rand(1).item())
assert val1 == val2
def test_set_seed():
set_seed(1337)
val1 = (random.random(), np.random.rand(), torch.rand(1).item())
set_seed(1337)
val2 = (random.random(), np.random.rand(), torch.rand(1).item())
assert val1 == val2
def test_seeded_context(fixed_seed):
val1 = (random.random(), np.random.rand(), torch.rand(1).item())
with seeded_context(1337):
seeded_val1 = (random.random(), np.random.rand(), torch.rand(1).item())
val2 = (random.random(), np.random.rand(), torch.rand(1).item())
with seeded_context(1337):
seeded_val2 = (random.random(), np.random.rand(), torch.rand(1).item())
assert seeded_val1 == seeded_val2
assert all(a != b for a, b in zip(val1, seeded_val1, strict=True)) # changed inside the context
assert all(a != b for a, b in zip(val2, seeded_val2, strict=True)) # changed again after exiting

81
tests/test_schedulers.py Normal file
View File

@ -0,0 +1,81 @@
from torch.optim.lr_scheduler import LambdaLR
from lerobot.common.constants import SCHEDULER_STATE
from lerobot.common.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
DiffuserSchedulerConfig,
VQBeTSchedulerConfig,
load_scheduler_state,
save_scheduler_state,
)
def test_diffuser_scheduler(optimizer):
config = DiffuserSchedulerConfig(name="cosine", num_warmup_steps=5)
scheduler = config.build(optimizer, num_training_steps=100)
assert isinstance(scheduler, LambdaLR)
optimizer.step() # so that we don't get torch warning
scheduler.step()
expected_state_dict = {
"_get_lr_called_within_step": False,
"_last_lr": [0.0002],
"_step_count": 2,
"base_lrs": [0.001],
"last_epoch": 1,
"lr_lambdas": [None],
"verbose": False,
}
assert scheduler.state_dict() == expected_state_dict
def test_vqbet_scheduler(optimizer):
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
scheduler = config.build(optimizer, num_training_steps=100)
assert isinstance(scheduler, LambdaLR)
optimizer.step()
scheduler.step()
expected_state_dict = {
"_get_lr_called_within_step": False,
"_last_lr": [0.001],
"_step_count": 2,
"base_lrs": [0.001],
"last_epoch": 1,
"lr_lambdas": [None],
"verbose": False,
}
assert scheduler.state_dict() == expected_state_dict
def test_cosine_decay_with_warmup_scheduler(optimizer):
config = CosineDecayWithWarmupSchedulerConfig(
num_warmup_steps=10, num_decay_steps=90, peak_lr=0.01, decay_lr=0.001
)
scheduler = config.build(optimizer, num_training_steps=100)
assert isinstance(scheduler, LambdaLR)
optimizer.step()
scheduler.step()
expected_state_dict = {
"_get_lr_called_within_step": False,
"_last_lr": [0.0001818181818181819],
"_step_count": 2,
"base_lrs": [0.001],
"last_epoch": 1,
"lr_lambdas": [None],
"verbose": False,
}
assert scheduler.state_dict() == expected_state_dict
def test_save_scheduler_state(scheduler, tmp_path):
save_scheduler_state(scheduler, tmp_path)
assert (tmp_path / SCHEDULER_STATE).is_file()
def test_save_load_scheduler_state(scheduler, tmp_path):
save_scheduler_state(scheduler, tmp_path)
loaded_scheduler = load_scheduler_state(scheduler, tmp_path)
assert scheduler.state_dict() == loaded_scheduler.state_dict()

84
tests/test_train_utils.py Normal file
View File

@ -0,0 +1,84 @@
from pathlib import Path
from unittest.mock import Mock, patch
from lerobot.common.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
OPTIMIZER_PARAM_GROUPS,
OPTIMIZER_STATE,
RNG_STATE,
SCHEDULER_STATE,
TRAINING_STATE_DIR,
TRAINING_STEP,
)
from lerobot.common.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
load_training_step,
save_checkpoint,
save_training_state,
save_training_step,
update_last_checkpoint,
)
def test_get_step_identifier():
assert get_step_identifier(5, 1000) == "000005"
assert get_step_identifier(123, 100_000) == "000123"
assert get_step_identifier(456789, 1_000_000) == "0456789"
def test_get_step_checkpoint_dir():
output_dir = Path("/checkpoints")
step_dir = get_step_checkpoint_dir(output_dir, 1000, 5)
assert step_dir == output_dir / CHECKPOINTS_DIR / "000005"
def test_save_load_training_step(tmp_path):
save_training_step(5000, tmp_path)
assert (tmp_path / TRAINING_STEP).is_file()
def test_load_training_step(tmp_path):
step = 5000
save_training_step(step, tmp_path)
loaded_step = load_training_step(tmp_path)
assert loaded_step == step
def test_update_last_checkpoint(tmp_path):
checkpoint = tmp_path / "0005"
checkpoint.mkdir()
update_last_checkpoint(checkpoint)
last_checkpoint = tmp_path / LAST_CHECKPOINT_LINK
assert last_checkpoint.is_symlink()
assert last_checkpoint.resolve() == checkpoint
@patch("lerobot.common.utils.train_utils.save_training_state")
def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
policy = Mock()
cfg = Mock()
save_checkpoint(tmp_path, 10, cfg, policy, optimizer)
policy.save_pretrained.assert_called_once()
cfg.save_pretrained.assert_called_once()
mock_save_training_state.assert_called_once()
def test_save_training_state(tmp_path, optimizer, scheduler):
save_training_state(tmp_path, 10, optimizer, scheduler)
assert (tmp_path / TRAINING_STATE_DIR).is_dir()
assert (tmp_path / TRAINING_STATE_DIR / TRAINING_STEP).is_file()
assert (tmp_path / TRAINING_STATE_DIR / RNG_STATE).is_file()
assert (tmp_path / TRAINING_STATE_DIR / OPTIMIZER_STATE).is_file()
assert (tmp_path / TRAINING_STATE_DIR / OPTIMIZER_PARAM_GROUPS).is_file()
assert (tmp_path / TRAINING_STATE_DIR / SCHEDULER_STATE).is_file()
def test_save_load_training_state(tmp_path, optimizer, scheduler):
save_training_state(tmp_path, 10, optimizer, scheduler)
loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(tmp_path, optimizer, scheduler)
assert loaded_step == 10
assert loaded_optimizer is optimizer
assert loaded_scheduler is scheduler

View File

@ -1,8 +1,3 @@
import random
from typing import Callable
import numpy as np
import pytest
import torch
from datasets import Dataset
@ -10,50 +5,6 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
from lerobot.common.utils.utils import (
get_global_random_state,
seeded_context,
set_global_random_state,
set_global_seed,
)
# Random generation functions for testing the seeding and random state get/set.
rand_fns = [
random.random,
np.random.random,
lambda: torch.rand(1).item(),
]
if torch.cuda.is_available():
rand_fns.append(lambda: torch.rand(1, device="cuda"))
@pytest.mark.parametrize("rand_fn", rand_fns)
def test_seeding(rand_fn: Callable[[], int]):
set_global_seed(0)
a = rand_fn()
with seeded_context(1337):
c = rand_fn()
b = rand_fn()
set_global_seed(0)
a_ = rand_fn()
b_ = rand_fn()
# Check that `set_global_seed` lets us reproduce a and b.
assert a_ == a
# Additionally, check that the `seeded_context` didn't interrupt the global RNG.
assert b_ == b
set_global_seed(1337)
c_ = rand_fn()
# Check that `seeded_context` and `global_seed` give the same reproducibility.
assert c_ == c
def test_get_set_random_state():
"""Check that getting the random state, then setting it results in the same random number generation."""
random_state_dict = get_global_random_state()
rand_numbers = [rand_fn() for rand_fn in rand_fns]
set_global_random_state(random_state_dict)
rand_numbers_ = [rand_fn() for rand_fn in rand_fns]
assert rand_numbers_ == rand_numbers
def test_calculate_episode_data_index():