Merge remote-tracking branch 'origin/main' into user/aliberts/2024_05_28_compile_torchvision

Simon Alibert 2024-05-30 17:40:54 +00:00
commit af19a89614
16 changed files with 674 additions and 107 deletions

View File

@@ -70,6 +70,8 @@ jobs:
# files: ./coverage.xml
# verbose: true
- name: Tests end-to-end
env:
DEVICE: cuda
run: make test-end-to-end
# - name: Generate Report

View File

@@ -10,6 +10,7 @@ endif
export PATH := $(dir $(PYTHON_PATH)):$(PATH)
DEVICE ?= cpu
build-cpu:
docker build -t lerobot:latest -f docker/lerobot-cpu/Dockerfile .
@@ -18,16 +19,16 @@ build-gpu:
docker build -t lerobot:latest -f docker/lerobot-gpu/Dockerfile .
test-end-to-end:
${MAKE} test-act-ete-train
${MAKE} test-act-ete-eval
${MAKE} test-act-ete-train-amp
${MAKE} test-act-ete-eval-amp
${MAKE} test-diffusion-ete-train
${MAKE} test-diffusion-ete-eval
${MAKE} test-tdmpc-ete-train
${MAKE} test-tdmpc-ete-eval
${MAKE} test-default-ete-eval
${MAKE} test-act-pusht-tutorial
${MAKE} DEVICE=$(DEVICE) test-act-ete-train
${MAKE} DEVICE=$(DEVICE) test-act-ete-eval
${MAKE} DEVICE=$(DEVICE) test-act-ete-train-amp
${MAKE} DEVICE=$(DEVICE) test-act-ete-eval-amp
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-train
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval
${MAKE} DEVICE=$(DEVICE) test-default-ete-eval
${MAKE} DEVICE=$(DEVICE) test-act-pusht-tutorial
test-act-ete-train:
python lerobot/scripts/train.py \
@@ -39,7 +40,7 @@ test-act-ete-train:
training.online_steps=0 \
eval.n_episodes=1 \
eval.batch_size=1 \
device=cpu \
device=$(DEVICE) \
training.save_checkpoint=true \
training.save_freq=2 \
policy.n_action_steps=20 \
@@ -53,7 +54,7 @@ test-act-ete-eval:
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=8 \
device=cpu \
device=$(DEVICE) \
test-act-ete-train-amp:
python lerobot/scripts/train.py \
@@ -65,7 +66,7 @@ test-act-ete-train-amp:
training.online_steps=0 \
eval.n_episodes=1 \
eval.batch_size=1 \
device=cpu \
device=$(DEVICE) \
training.save_checkpoint=true \
training.save_freq=2 \
policy.n_action_steps=20 \
@@ -80,7 +81,7 @@ test-act-ete-eval-amp:
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=8 \
device=cpu \
device=$(DEVICE) \
use_amp=true
test-diffusion-ete-train:
@@ -95,7 +96,7 @@ test-diffusion-ete-train:
training.online_steps=0 \
eval.n_episodes=1 \
eval.batch_size=1 \
device=cpu \
device=$(DEVICE) \
training.save_checkpoint=true \
training.save_freq=2 \
training.batch_size=2 \
@@ -107,7 +108,7 @@ test-diffusion-ete-eval:
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=8 \
device=cpu \
device=$(DEVICE) \
# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated.
test-tdmpc-ete-train:
@@ -122,7 +123,7 @@ test-tdmpc-ete-train:
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=2 \
device=cpu \
device=$(DEVICE) \
training.save_checkpoint=true \
training.save_freq=2 \
training.batch_size=2 \
@@ -134,7 +135,7 @@ test-tdmpc-ete-eval:
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=8 \
device=cpu \
device=$(DEVICE) \
test-default-ete-eval:
python lerobot/scripts/eval.py \
@@ -142,7 +143,7 @@ test-default-ete-eval:
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=8 \
device=cpu \
device=$(DEVICE) \
test-act-pusht-tutorial:
cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/created_by_Makefile.yaml
@@ -154,7 +155,7 @@ test-act-pusht-tutorial:
eval.n_episodes=1 \
eval.batch_size=1 \
env.episode_length=2 \
device=cpu \
device=$(DEVICE) \
training.save_model=true \
training.save_freq=2 \
training.batch_size=2 \

View File

@@ -16,17 +16,15 @@
from copy import deepcopy
from math import ceil
import datasets
import einops
import torch
import tqdm
from datasets import Image
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.video_utils import VideoFrame
def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0):
def get_stats_einops_patterns(dataset, num_workers=0):
"""These einops patterns will be used to aggregate batches and compute statistics.
Note: We assume the images are in channel first format
@@ -66,9 +64,8 @@ def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0):
return stats_patterns
def compute_stats(
dataset: LeRobotDataset | datasets.Dataset, batch_size=32, num_workers=16, max_num_samples=None
):
def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None):
"""Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
if max_num_samples is None:
max_num_samples = len(dataset)
@@ -159,3 +156,54 @@ def compute_stats(
"min": min[key],
}
return stats
def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
"""Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
The final stats will have the union of all data keys from each of the datasets. For instance:
- new_max = max(max_dataset_0, max_dataset_1, ...)
- new_min = min(min_dataset_0, min_dataset_1, ...)
- new_mean = (mean of all data)
- new_std = (std of all data)
"""
data_keys = set()
for dataset in ls_datasets:
data_keys.update(dataset.stats.keys())
stats = {k: {} for k in data_keys}
for data_key in data_keys:
for stat_key in ["min", "max"]:
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
stats[data_key][stat_key] = einops.reduce(
torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
"n ... -> ...",
stat_key,
)
total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
# dataset, then divide by total_samples to get the overall "mean".
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
# numerical overflow!
stats[data_key]["mean"] = sum(
d.stats[data_key]["mean"] * (d.num_samples / total_samples)
for d in ls_datasets
if data_key in d.stats
)
# The derivation for standard deviation is a little more involved but is much in the same spirit as
# the computation of the mean.
# Given two sets of data where the statistics are known:
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
# NOTE: the brackets around (d.num_samples / total_samples) are needed to minimize the risk of
# numerical overflow!
stats[data_key]["std"] = torch.sqrt(
sum(
(d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
* (d.num_samples / total_samples)
for d in ls_datasets
if data_key in d.stats
)
)
return stats
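The pooled mean/std formulas used by `aggregate_stats` can be sanity-checked with a small standalone snippet (a minimal sketch using only torch; the synthetic tensors are arbitrary):

```python
import torch

# Two synthetic "datasets" with known per-dataset statistics (population std, correction=0).
x1, x2 = torch.randn(100), torch.randn(50) + 1.0
n1, n2 = x1.numel(), x2.numel()
w1, w2 = n1 / (n1 + n2), n2 / (n1 + n2)
mu1, mu2 = x1.mean(), x2.mean()
var1, var2 = x1.var(correction=0), x2.var(correction=0)

# Combined mean: sample-count weighted average of the per-dataset means.
mu = mu1 * w1 + mu2 * w2
# Combined std: sqrt of the weighted mean of (variance + squared distance to the combined mean),
# i.e. sigma_combined = sqrt[(n1*(s1^2 + d1^2) + n2*(s2^2 + d2^2)) / (n1 + n2)].
std = torch.sqrt((var1 + (mu1 - mu) ** 2) * w1 + (var2 + (mu2 - mu) ** 2) * w2)

both = torch.cat([x1, x2])
assert torch.allclose(mu, both.mean(), atol=1e-6)
assert torch.allclose(std, both.std(correction=0), atol=1e-6)
```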

View File

@@ -16,9 +16,9 @@
import logging
import torch
from omegaconf import OmegaConf
from omegaconf import ListConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
def resolve_delta_timestamps(cfg):
@@ -35,11 +35,27 @@ def resolve_delta_timestamps(cfg):
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
def make_dataset(
cfg,
split="train",
):
if cfg.env.name not in cfg.dataset_repo_id:
def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
"""
Args:
cfg: A Hydra config as per the LeRobot config scheme.
split: Select the data subset used to create an instance of LeRobotDataset.
All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
slicer in the Hugging Face datasets library:
https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
`split="train[n:]"` to load the last n frames. For instance, `split="train[:1000]"`.
Returns:
The LeRobotDataset.
"""
if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
raise ValueError(
"Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
"strings to load multiple datasets."
)
if isinstance(cfg.dataset_repo_id, str) and cfg.env.name not in cfg.dataset_repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
f"environment ({cfg.env.name=})."
@@ -49,12 +65,17 @@ def make_dataset(
# TODO(rcadene): add data augmentations
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
video_backend=cfg.video_backend,
)
if isinstance(cfg.dataset_repo_id, str):
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
video_backend=cfg.video_backend,
)
else:
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
)
if cfg.get("override_dataset_stats"):
for key, stats_dict in cfg.override_dataset_stats.items():

View File

@ -13,12 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from pathlib import Path
from typing import Callable
import datasets
import torch
import torch.utils
from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
load_episode_data_index,
@@ -42,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: callable = None,
transform: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
video_backend: str | None = None,
):
@@ -174,7 +178,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
@classmethod
def from_preloaded(
cls,
repo_id: str,
repo_id: str = "from_preloaded",
version: str | None = CODEBASE_VERSION,
root: Path | None = None,
split: str = "train",
@@ -186,7 +190,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
stats=None,
info=None,
videos_dir=None,
):
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.
It is especially useful when converting raw data into LeRobotDataset before saving the dataset
on the filesystem or uploading to the hub.
Note: Meta-data attributes like `repo_id`, `version`, `root`, etc are optional and potentially
meaningless depending on the downstream usage of the return dataset.
"""
# create an empty object of type LeRobotDataset
obj = cls.__new__(cls)
obj.repo_id = repo_id
@@ -198,6 +210,194 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.hf_dataset = hf_dataset
obj.episode_data_index = episode_data_index
obj.stats = stats
obj.info = info
obj.info = info if info is not None else {}
obj.videos_dir = videos_dir
return obj
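For instance, mirroring how the tests at the bottom of this commit use it (a minimal sketch; the repo id is a placeholder):

```python
import torch
from datasets import Dataset

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import hf_transform_to_torch

# Wrap a small in-memory Hugging Face dataset without touching the filesystem.
hf_dataset = Dataset.from_dict({"a": torch.rand(10), "index": torch.arange(10)})
hf_dataset.set_transform(hf_transform_to_torch)
dataset = LeRobotDataset.from_preloaded("placeholder_repo_id", hf_dataset=hf_dataset)
```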
class MultiLeRobotDataset(torch.utils.data.Dataset):
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
structure of `LeRobotDataset`.
"""
def __init__(
self,
repo_ids: list[str],
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
video_backend: str | None = None,
):
super().__init__()
self.repo_ids = repo_ids
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
LeRobotDataset(
repo_id,
version=version,
root=root,
split=split,
delta_timestamps=delta_timestamps,
transform=transform,
video_backend=video_backend,
)
for repo_id in repo_ids
]
# Check that some properties are consistent across datasets. Note: We may relax some of these
# consistency requirements in future iterations of this class.
for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
if dataset.info != self._datasets[0].info:
raise ValueError(
f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
"not yet supported."
)
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
# restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function.
self.disabled_data_keys = set()
intersection_data_keys = set(self._datasets[0].hf_dataset.features)
for dataset in self._datasets:
intersection_data_keys.intersection_update(dataset.hf_dataset.features)
if len(intersection_data_keys) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. The "
"multi-dataset functionality currently only keeps common keys."
)
for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_data_keys.update(extra_keys)
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets)
@property
def repo_id_to_index(self):
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
"""
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
@property
def repo_index_to_id(self):
"""Return the inverse mapping if repo_id_to_index."""
return {v: k for k, v in self.repo_id_to_index}
@property
def fps(self) -> int:
"""Frames per second used during data collection.
NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].info["fps"]
@property
def video(self) -> bool:
"""Returns True if this dataset loads video frames from mp4 files.
Returns False if it only loads images from png files.
NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].info.get("video", False)
@property
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
return features
@property
def camera_keys(self) -> list[str]:
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.features.items():
if isinstance(feats, (datasets.Image, VideoFrame)):
keys.append(key)
return keys
@property
def video_frame_keys(self) -> list[str]:
"""Keys to access video frames that requires to be decoded into images.
Note: It is empty if the dataset contains images only,
or equal to `self.cameras` if the dataset contains videos only,
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
"""
video_frame_keys = []
for key, feats in self.features.items():
if isinstance(feats, VideoFrame):
video_frame_keys.append(key)
return video_frame_keys
@property
def num_samples(self) -> int:
"""Number of samples/frames."""
return sum(d.num_samples for d in self._datasets)
@property
def num_episodes(self) -> int:
"""Number of episodes."""
return sum(d.num_episodes for d in self._datasets)
@property
def tolerance_s(self) -> float:
"""Tolerance in seconds used to discard loaded frames when their timestamps
are not close enough from the requested frames. It is only used when `delta_timestamps`
is provided or when loading video frames from mp4 files.
"""
# 1e-4 to account for possible numerical error
return 1 / self.fps - 1e-4
def __len__(self):
return self.num_samples
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self):
raise IndexError(f"Index {idx} out of bounds.")
# Determine which dataset to get an item from based on the index.
start_idx = 0
dataset_idx = 0
for dataset in self._datasets:
    if idx >= start_idx + dataset.num_samples:
        start_idx += dataset.num_samples
        dataset_idx += 1
        continue
    break
else:
    raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
item = self._datasets[dataset_idx][idx - start_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
for data_key in self.disabled_data_keys:
if data_key in item:
del item[data_key]
return item
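The global-to-local index arithmetic above can be pictured with a standalone sketch (hypothetical dataset sizes):

```python
# Sketch: map a global index onto (dataset_idx, local_idx) for sub-datasets of sizes 5, 3, 4.
sizes = [5, 3, 4]

def locate(idx: int) -> tuple[int, int]:
    start = 0
    for dataset_idx, size in enumerate(sizes):
        if idx < start + size:
            return dataset_idx, idx - start
        start += size
    raise IndexError(f"Index {idx} out of bounds.")

assert locate(0) == (0, 0)   # first frame of the first dataset
assert locate(5) == (1, 0)   # first frame of the second dataset
assert locate(11) == (2, 3)  # last frame of the third dataset
```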
def __repr__(self):
return (
f"{self.__class__.__name__}(\n"
f" Repository IDs: '{self.repo_ids}',\n"
f" Version: '{self.version}',\n"
f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.transform},\n"
f")"
)
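Putting the class together, a minimal usage sketch (repo ids borrowed from the tests below; assumes both datasets can be downloaded from the hub):

```python
import torch

from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset

dataset = MultiLeRobotDataset(
    ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"]
)
print(dataset.repo_id_to_index)  # e.g. {'lerobot/aloha_sim_insertion_human': 0, ...}

dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
batch = next(iter(dataloader))
print(batch["dataset_index"])  # which sub-dataset each frame came from
```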

View File

@@ -0,0 +1,230 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains utilities to process raw data format from dora-record
"""
import logging
import re
from pathlib import Path
import pandas as pd
import torch
from datasets import Dataset, Features, Image, Sequence, Value
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame
from lerobot.common.utils.utils import init_logging
def check_format(raw_dir) -> bool:
assert raw_dir.exists()
leader_file = list(raw_dir.glob("*.parquet"))
if len(leader_file) == 0:
raise ValueError(f"Missing parquet files in '{raw_dir}'")
return True
def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
# Load the data stream that will be used as the reference for timestamp synchronization
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
if len(reference_files) == 0:
raise ValueError(f"Missing reference files for camera, starting with in '{raw_dir}'")
# select first camera in alphanumeric order
reference_key = sorted(reference_files)[0].stem
reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
reference_df = reference_df[["timestamp_utc", reference_key]]
# Merge all data streams using the "nearest" timestamp strategy
df = reference_df
for path in raw_dir.glob("*.parquet"):
key = path.stem # action or observation.state or ...
if key == reference_key:
continue
if "failed_episode_index" in key:
# TODO(rcadene): add support for removing episodes that are tagged as "failed"
continue
modality_df = pd.read_parquet(path)
modality_df = modality_df[["timestamp_utc", key]]
df = pd.merge_asof(
df,
modality_df,
on="timestamp_utc",
# "nearest" is the best option over "backward", since the latter can desynchronizes camera timestamps by
# matching timestamps that are too far appart, in order to fit the backward constraints. It's not the case for "nearest".
# However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
# are too far appart.
direction="nearest",
tolerance=pd.Timedelta(f"{1/fps} seconds"),
)
# Remove rows with episode_index -1, which indicates data recorded in between episodes
df = df[df["episode_index"] != -1]
image_keys = [key for key in df if "observation.images." in key]
def get_episode_index(row):
episode_index_per_cam = {}
for key in image_keys:
path = row[key][0]["path"]
match = re.search(r"_(\d{6}).mp4", path)
if not match:
raise ValueError(path)
episode_index = int(match.group(1))
episode_index_per_cam[key] = episode_index
if len(set(episode_index_per_cam.values())) != 1:
raise ValueError(
f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}"
)
return episode_index
df["episode_index"] = df.apply(get_episode_index, axis=1)
# dora only uses arrays, so single values are encapsulated into a list
df["frame_index"] = df.groupby("episode_index").cumcount()
df = df.reset_index()
df["index"] = df.index
# set 'next.done' to True for the last frame of each episode
df["next.done"] = False
df.loc[df.groupby("episode_index").tail(1).index, "next.done"] = True
df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
# each episode starts with timestamp 0 to match the ones from the video
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])
del df["timestamp_utc"]
# sanity check
has_nan = df.isna().any().any()
if has_nan:
raise ValueError("Dataset contains Nan values.")
# sanity check episode indices go from 0 to n-1
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
expected_ep_ids = list(range(df["episode_index"].max() + 1))
if ep_ids != expected_ep_ids:
raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")
# Create a symlink to the raw videos directory (the symlink target needs to be absolute, not relative)
out_dir.mkdir(parents=True, exist_ok=True)
videos_dir = out_dir / "videos"
videos_dir.symlink_to((raw_dir / "videos").absolute())
# sanity check the video paths are well formatted
for key in df:
if "observation.images." not in key:
continue
for ep_idx in ep_ids:
video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
if not video_path.exists():
raise ValueError(f"Video file not found in {video_path}")
data_dict = {}
for key in df:
# is video frame
if "observation.images." in key:
# we need `[0]` because dora only uses arrays, so single values are encapsulated into a list;
# this is the case for video_frame dictionaries: [{"path": ..., "timestamp": ...}]
data_dict[key] = [video_frame[0] for video_frame in df[key].values]
# sanity check the video path is well formatted
video_path = videos_dir.parent / data_dict[key][0]["path"]
if not video_path.exists():
raise ValueError(f"Video file not found in {video_path}")
# is number
elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1:
data_dict[key] = torch.from_numpy(df[key].values)
# is vector
elif df[key].iloc[0].shape[0] > 1:
data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
else:
raise ValueError(key)
# Get the index of the first frame for each unique episode index
first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
from_ = first_ep_index_df["start_index"].tolist()
to_ = from_[1:] + [len(df)]
episode_data_index = {
"from": from_,
"to": to_,
}
return data_dict, episode_data_index
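The `from`/`to` construction at the end of `load_from_raw` can be checked on a toy frame (hypothetical episode layout):

```python
import pandas as pd

# Three episodes with 2, 3 and 1 frames. "from" holds each episode's first global index,
# "to" the exclusive end, mirroring the groupby/agg logic above.
df = pd.DataFrame({"episode_index": [0, 0, 1, 1, 1, 2]})
df["index"] = df.index
first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
from_ = first_ep_index_df["start_index"].tolist()
to_ = from_[1:] + [len(df)]
assert from_ == [0, 2, 5] and to_ == [2, 5, 6]
```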
def to_hf_dataset(data_dict, video) -> Dataset:
features = {}
keys = [key for key in data_dict if "observation.images." in key]
for key in keys:
if video:
features[key] = VideoFrame()
else:
features[key] = Image()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.velocity" in data_dict:
features["observation.velocity"] = Sequence(
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
)
if "observation.effort" in data_dict:
features["observation.effort"] = Sequence(
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
)
features["episode_index"] = Value(dtype="int64", id=None)
features["frame_index"] = Value(dtype="int64", id=None)
features["timestamp"] = Value(dtype="float32", id=None)
features["next.done"] = Value(dtype="bool", id=None)
features["index"] = Value(dtype="int64", id=None)
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
init_logging()
if debug:
logging.warning("debug=True not implemented. Falling back to debug=False.")
# sanity check
check_format(raw_dir)
if fps is None:
fps = 30
else:
raise NotImplementedError()
if not video:
raise NotImplementedError()
data_df, episode_data_index = load_from_raw(raw_dir, out_dir, fps)
hf_dataset = to_hf_dataset(data_df, video)
info = {
"fps": fps,
"video": video,
}
return hf_dataset, episode_data_index, info
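To see how the `merge_asof` synchronization in `load_from_raw` behaves, here is a standalone toy example (timestamps and values are made up):

```python
import pandas as pd

# Two streams sampled at slightly different times. direction="nearest" pairs each
# reference row with the closest modality row within the tolerance, past or future.
ref = pd.DataFrame({
    "timestamp_utc": pd.to_datetime(["2024-01-01 00:00:00.000", "2024-01-01 00:00:00.033"]),
    "cam": [0, 1],
})
state = pd.DataFrame({
    "timestamp_utc": pd.to_datetime(["2024-01-01 00:00:00.002", "2024-01-01 00:00:00.031"]),
    "state": [10, 11],
})
merged = pd.merge_asof(
    ref, state, on="timestamp_utc", direction="nearest", tolerance=pd.Timedelta(f"{1/30} seconds")
)
print(merged)  # cam 0 pairs with state 10, cam 1 with state 11
```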

View File

@@ -59,7 +59,7 @@ def unflatten_dict(d, sep="/"):
return outdict
def hf_transform_to_torch(items_dict):
def hf_transform_to_torch(items_dict: dict[str, list]):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
a channel last representation (h w c) of uint8 type, to a torch image representation
@@ -73,6 +73,8 @@ def hf_transform_to_torch(items_dict):
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
# video frame will be processed downstream
pass
elif first_item is None:
pass
else:
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
return items_dict
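A minimal check of the transform's behavior, including the new `None` passthrough (a sketch; relies on `datasets`' batched transform protocol):

```python
import torch
from datasets import Dataset

from lerobot.common.datasets.utils import hf_transform_to_torch

# Numeric entries come back as torch tensors; None entries pass through unchanged.
ds = Dataset.from_dict({"state": [[0.0, 1.0], [2.0, 3.0]], "language": [None, None]})
ds.set_transform(hf_transform_to_torch)
assert torch.equal(ds[0]["state"], torch.tensor([0.0, 1.0]))
assert ds[0]["language"] is None
```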
@@ -318,8 +320,7 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
"""
Reset the `episode_index` of the provided HuggingFace Dataset.
"""Reset the `episode_index` of the provided HuggingFace Dataset.
`episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
`episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
@@ -338,6 +339,7 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
return example
hf_dataset = hf_dataset.map(modify_ep_idx_func)
return hf_dataset

View File

@@ -27,14 +27,6 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv:
if n_envs is not None and n_envs < 1:
raise ValueError("`n_envs must be at least 1")
kwargs = {
"obs_type": "pixels_agent_pos",
"render_mode": "rgb_array",
"max_episode_steps": cfg.env.episode_length,
"visualization_width": 384,
"visualization_height": 384,
}
package_name = f"gym_{cfg.env.name}"
try:
@@ -46,12 +38,16 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv:
raise e
gym_handle = f"{package_name}/{cfg.env.task}"
gym_kwgs = dict(cfg.env.get("gym", {}))
if cfg.env.get("episode_length"):
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
# batched version of the env that returns an observation of shape (b, c)
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
env = env_cls(
[
lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
]
)
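With this change, the per-env kwargs come from the `env.gym` config node. For pusht, the resulting call is roughly equivalent to the following (a sketch with values taken from the `pusht.yaml` hunk further down; assumes `gym_pusht` is installed):

```python
import gym_pusht  # noqa: F401  # registers the environment (assumed installed)
import gymnasium as gym

# Roughly what make_env builds per worker for the pusht config:
env = gym.make(
    "gym_pusht/PushT-v0",
    disable_env_checker=True,
    obs_type="pixels_agent_pos",
    render_mode="rgb_array",
    visualization_width=384,
    visualization_height=384,
    max_episode_steps=300,  # injected from env.episode_length
)
```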

View File

@@ -198,7 +198,7 @@ class ACT(nn.Module):
def __init__(self, config: ACTConfig):
super().__init__()
self.config = config
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
if self.config.use_vae:
self.vae_encoder = ACTEncoder(config)
@@ -214,7 +214,7 @@ class ACT(nn.Module):
self.latent_dim = config.latent_dim
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
# Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension.
self.register_buffer(
"vae_encoder_pos_enc",

View File

@@ -23,6 +23,10 @@ use_amp: false
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: ???
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
# keys common between the datasets are kept. Each dataset gets an additional transform that inserts the
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided (see the sketch below).
dataset_repo_id: lerobot/pusht
video_backend: pyav
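For example, a hedged sketch of creating the concatenated dataset from Python (it mirrors the factory test at the bottom of this commit; the list override follows Hydra's syntax):

```python
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
from lerobot.common.utils.utils import init_hydra_config
from tests.utils import DEFAULT_CONFIG_PATH

# A list of repo ids yields a MultiLeRobotDataset; a single string yields a LeRobotDataset.
cfg = init_hydra_config(
    DEFAULT_CONFIG_PATH,
    overrides=[
        "env=aloha",
        "dataset_repo_id=[lerobot/aloha_sim_insertion_human,lerobot/aloha_sim_transfer_cube_human]",
    ],
)
dataset = make_dataset(cfg)
assert isinstance(dataset, MultiLeRobotDataset)
```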
@@ -38,6 +42,8 @@ training:
save_freq: ???
log_freq: 250
save_checkpoint: true
num_workers: 4
batch_size: ???
eval:
n_episodes: 1

View File

@@ -5,10 +5,10 @@ fps: 50
env:
name: aloha
task: AlohaInsertion-v0
from_pixels: True
pixels_only: False
image_size: [3, 480, 640]
episode_length: 400
fps: ${fps}
state_dim: 14
action_dim: 14
fps: ${fps}
episode_length: 400
gym:
obs_type: pixels_agent_pos
render_mode: rgb_array

View File

@@ -5,10 +5,13 @@ fps: 10
env:
name: pusht
task: PushT-v0
from_pixels: True
pixels_only: False
image_size: 96
episode_length: 300
fps: ${fps}
state_dim: 2
action_dim: 2
fps: ${fps}
episode_length: 300
gym:
obs_type: pixels_agent_pos
render_mode: rgb_array
visualization_width: 384
visualization_height: 384

View File

@@ -5,10 +5,13 @@ fps: 15
env:
name: xarm
task: XarmLift-v0
from_pixels: True
pixels_only: False
image_size: 84
episode_length: 25
fps: ${fps}
state_dim: 4
action_dim: 4
fps: ${fps}
episode_length: 25
gym:
obs_type: pixels_agent_pos
render_mode: rgb_array
visualization_width: 384
visualization_height: 384

View File

@@ -71,9 +71,9 @@ import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
from lerobot.common.datasets.utils import flatten_dict
@@ -84,10 +84,14 @@ def get_from_raw_to_lerobot_format_fn(raw_format):
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
elif raw_format == "aloha_hdf5":
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
elif raw_format == "aloha_dora":
from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format
elif raw_format == "xarm_pkl":
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
else:
raise ValueError(raw_format)
raise ValueError(
f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
)
return from_raw_to_lerobot_format

View File

@@ -16,7 +16,6 @@
import logging
import time
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
from pprint import pformat
@@ -28,6 +27,7 @@ from termcolor import colored
from torch.cuda.amp import GradScaler
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
@@ -280,9 +280,18 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
logging.info("make_dataset")
offline_dataset = make_dataset(cfg)
if isinstance(offline_dataset, MultiLeRobotDataset):
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
f"{pformat(offline_dataset.repo_id_to_index , indent=2)}"
)
logging.info("make_env")
eval_env = make_env(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
# using eval.py instead, with the gym_dora environment and dora-rs.
if cfg.training.eval_freq > 0:
logging.info("make_env")
eval_env = make_env(cfg)
logging.info("make_policy")
policy = make_policy(
@@ -315,7 +324,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
# Note: this helper will be used in offline and online training loops.
def evaluate_and_checkpoint_if_needed(step):
if step % cfg.training.eval_freq == 0:
if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
eval_info = eval_policy(
@@ -326,7 +335,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
max_episodes_rendered=4,
start_seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True)
if cfg.wandb.enable:
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
logging.info("Resume training")
@@ -349,7 +358,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
# create dataloader for offline training
dataloader = torch.utils.data.DataLoader(
offline_dataset,
num_workers=4,
num_workers=cfg.training.num_workers,
batch_size=cfg.training.batch_size,
shuffle=True,
pin_memory=device.type != "cpu",
@@ -358,7 +367,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
dl_iter = cycle(dataloader)
policy.train()
is_offline = True
for _ in range(step, cfg.training.offline_steps):
if step == 0:
logging.info("Start offline training on a fixed dataset")
@@ -378,7 +386,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
)
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
@@ -386,26 +394,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
step += 1
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)
online_dataset.hf_dataset = {}
online_dataset.episode_data_index = {}
# create dataloader for online training
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
weights = [1.0] * len(concat_dataset)
sampler = torch.utils.data.WeightedRandomSampler(
weights, num_samples=len(concat_dataset), replacement=True
)
dataloader = torch.utils.data.DataLoader(
concat_dataset,
num_workers=4,
batch_size=cfg.training.batch_size,
sampler=sampler,
pin_memory=device.type != "cpu",
drop_last=False,
)
eval_env.close()
logging.info("End of training")

View File

@@ -16,6 +16,7 @@
import json
import logging
from copy import deepcopy
from itertools import chain
from pathlib import Path
import einops
@@ -25,26 +26,34 @@ from datasets import Dataset
from safetensors.torch import load_file
import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
)
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import (
from lerobot.common.datasets.compute_stats import (
aggregate_stats,
compute_stats,
get_stats_einops_patterns,
)
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.utils import (
flatten_dict,
hf_transform_to_torch,
load_previous_and_future_frames,
unflatten_dict,
)
from lerobot.common.utils.utils import init_hydra_config
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
@pytest.mark.parametrize(
"env_name, repo_id, policy_name",
lerobot.env_dataset_policy_triplets
+ [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
)
def test_factory(env_name, repo_id, policy_name):
"""
Tests that:
- we can create a dataset with the factory.
- for a commonly used set of data keys, the data dimensions are correct.
"""
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[
@@ -105,6 +114,32 @@ def test_factory(env_name, repo_id, policy_name):
assert key in item, f"{key}"
def test_multilerobotdataset_frames():
"""Check that all dataset frames are incorporated."""
# Note: use the image variants of the dataset to make the test approx 3x faster.
repo_ids = ["lerobot/aloha_sim_insertion_human_image", "lerobot/aloha_sim_transfer_cube_human_image"]
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
dataset = MultiLeRobotDataset(repo_ids)
assert len(dataset) == sum(len(d) for d in sub_datasets)
assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLeRobotDataset and
# check they match.
expected_dataset_indices = []
for i, sub_dataset in enumerate(sub_datasets):
expected_dataset_indices.extend([i] * len(sub_dataset))
for expected_dataset_index, sub_dataset_item, dataset_item in zip(
expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
):
dataset_index = dataset_item.pop("dataset_index")
assert dataset_index == expected_dataset_index
assert sub_dataset_item.keys() == dataset_item.keys()
for k in sub_dataset_item:
assert torch.equal(sub_dataset_item[k], dataset_item[k])
def test_compute_stats_on_xarm():
"""Check that the statistics are computed correctly according to the stats_patterns property.
@@ -315,3 +350,31 @@ def test_backward_compatibility(repo_id):
# i = dataset.episode_data_index["to"][-1].item()
# load_and_compare(i - 2)
# load_and_compare(i - 1)
def test_aggregate_stats():
"""Makes 3 basic datasets and checks that aggregate stats are computed correctly."""
with seeded_context(0):
data_a = torch.rand(30, dtype=torch.float32)
data_b = torch.rand(20, dtype=torch.float32)
data_c = torch.rand(20, dtype=torch.float32)
hf_dataset_1 = Dataset.from_dict(
{"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)}
)
hf_dataset_1.set_transform(hf_transform_to_torch)
hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)})
hf_dataset_2.set_transform(hf_transform_to_torch)
hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)})
hf_dataset_3.set_transform(hf_transform_to_torch)
dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1)
dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0)
dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2)
dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0)
dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3)
dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0)
stats = aggregate_stats([dataset_1, dataset_2, dataset_3])
for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True):
for agg_fn in ["mean", "min", "max"]:
assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))