This commit is contained in:
Alexander Soare 2024-05-30 15:12:55 +01:00
parent 265b0ec44d
commit cfa956bd3b
9 changed files with 344 additions and 74 deletions

View File

@ -16,17 +16,15 @@
from copy import deepcopy
from math import ceil
import datasets
import einops
import torch
import tqdm
from datasets import Image
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.video_utils import VideoFrame
def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0):
def get_stats_einops_patterns(dataset, num_workers=0):
"""These einops patterns will be used to aggregate batches and compute statistics.
Note: We assume the images are in channel first format
@ -66,9 +64,8 @@ def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_wo
return stats_patterns
def compute_stats(
dataset: LeRobotDataset | datasets.Dataset, batch_size=32, num_workers=16, max_num_samples=None
):
def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None):
"""Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
if max_num_samples is None:
max_num_samples = len(dataset)
@ -159,3 +156,48 @@ def compute_stats(
"min": min[key],
}
return stats
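For context, `compute_stats` is typically called on a freshly created dataset before pushing it to the hub. A minimal usage sketch (the repo id and the data key printed at the end are illustrative assumptions, not part of this diff):

```python
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Load a dataset and compute per-key statistics over (at most) a capped number of samples.
dataset = LeRobotDataset("lerobot/pusht")  # illustrative repo id
stats = compute_stats(dataset, batch_size=32, num_workers=8, max_num_samples=10_000)

# `stats` maps each data key to {"mean", "std", "min", "max"} torch tensors.
print(stats["observation.state"]["mean"])  # "observation.state" is assumed to exist in this dataset
```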
def consolidate_stats(ls_datasets) -> dict[str, torch.Tensor]:
"""Consolidate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
The final stats will have the union of all data keys from each of the datasets.
"""
data_keys = set()
for dataset in ls_datasets:
data_keys.update(dataset.stats.keys())
stats = {k: {} for k in data_keys}
for data_key in data_keys:
for stat_key in ["min", "max"]:
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
stats[data_key][stat_key] = einops.reduce(
torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
"n ... -> ...",
stat_key,
)
total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
# dataset, then divide by total_samples to get the overall "mean".
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
# numerical overflow!
stats[data_key]["mean"] = sum(
d.stats[data_key]["mean"] * (d.num_samples / total_samples)
for d in ls_datasets
if data_key in d.stats
)
# The derivation for the standard deviation is a little more involved, but is in much the same spirit
# as the computation of the mean.
# Given two sets of data where the statistics are known:
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
# NOTE: the brackets around (d.num_samples / total_samples) are needed to minimize the risk of
# numerical overflow!
stats[data_key]["std"] = torch.sqrt(
sum(
(d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
* (d.num_samples / total_samples)
for d in ls_datasets
if data_key in d.stats
)
)
return stats
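To make the aggregation above concrete, here is a small self-contained check of the weighted mean and pooled standard deviation formulas on plain tensors (no LeRobot objects involved). It assumes the per-dataset std is the population (uncorrected) standard deviation, which is what `compute_stats` produces:

```python
import torch

# Two "datasets" of raw scalar data.
x1 = torch.rand(100, dtype=torch.float64)
x2 = torch.rand(300, dtype=torch.float64)
n1, n2 = len(x1), len(x2)

# Weighted mean, as in consolidate_stats.
mean = x1.mean() * (n1 / (n1 + n2)) + x2.mean() * (n2 / (n1 + n2))

# Pooled std: weighted sum of (variance + squared distance of each mean from the pooled mean).
std = torch.sqrt(
    (x1.var(correction=0) + (x1.mean() - mean) ** 2) * (n1 / (n1 + n2))
    + (x2.var(correction=0) + (x2.mean() - mean) ** 2) * (n2 / (n1 + n2))
)

# Matches the statistics computed directly over the concatenated data.
both = torch.cat([x1, x2])
assert torch.allclose(mean, both.mean())
assert torch.allclose(std, both.std(correction=0))
```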

View File

@ -16,9 +16,9 @@
import logging
import torch
from omegaconf import OmegaConf
from omegaconf import ListConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
def resolve_delta_timestamps(cfg):
@ -35,11 +35,18 @@ def resolve_delta_timestamps(cfg):
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
def make_dataset(
cfg,
split="train",
):
if cfg.env.name not in cfg.dataset_repo_id:
def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
"""
Args:
cfg: A Hydra config as per the LeRobot config scheme.
split: Select the data subset used to create an instance of LeRobotDataset (e.g. "train" or
"train[:10%]"), following the Hugging Face datasets split syntax.
Returns:
The LeRobotDataset, or a MultiLeRobotDataset if `cfg.dataset_repo_id` is a list of repository IDs.
"""
if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
raise ValueError("Expected cfg.dataset_repo_id to be either a single string or a list of strings.")
if isinstance(cfg.dataset_repo_id, str) and cfg.env.name not in cfg.dataset_repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
f"environment ({cfg.env.name=})."
@ -49,11 +56,16 @@ def make_dataset(
# TODO(rcadene): add data augmentations
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
)
if isinstance(cfg.dataset_repo_id, str):
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
)
else:
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
)
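As a rough usage sketch of the branching above: a single repo id yields a `LeRobotDataset`, while a list yields a `MultiLeRobotDataset`. The config below is a hand-built fragment for illustration only; real runs use the Hydra YAML configs, the exact set of required keys depends on the full schema, and the repo ids are assumptions (the listed datasets must have compatible metadata):

```python
from omegaconf import OmegaConf
from lerobot.common.datasets.factory import make_dataset

cfg = OmegaConf.create(
    {
        "env": {"name": "aloha"},
        # A list of repository IDs triggers the MultiLeRobotDataset branch.
        "dataset_repo_id": [
            "lerobot/aloha_sim_insertion_human",
            "lerobot/aloha_sim_transfer_cube_human",
        ],
        "training": {"delta_timestamps": None},
    }
)
dataset = make_dataset(cfg)  # MultiLeRobotDataset in this case
```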
if cfg.get("override_dataset_stats"):
for key, stats_dict in cfg.override_dataset_stats.items():

View File

@ -13,12 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from pathlib import Path
from typing import Callable
import datasets
import torch
import torch.utils
from lerobot.common.datasets.compute_stats import consolidate_stats
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
load_episode_data_index,
@ -42,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: callable = None,
transform: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
):
super().__init__()
@ -171,7 +175,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
@classmethod
def from_preloaded(
cls,
repo_id: str,
repo_id: str = "from_preloaded",
version: str | None = CODEBASE_VERSION,
root: Path | None = None,
split: str = "train",
@ -183,7 +187,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
stats=None,
info=None,
videos_dir=None,
):
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.
It is especially useful when converting raw data into a LeRobotDataset before saving the dataset
to the filesystem or uploading it to the hub.
Note: Meta-data attributes like `repo_id`, `version`, `root`, etc. are optional and potentially
meaningless depending on the downstream usage of the returned dataset.
"""
# create an empty object of type LeRobotDataset
obj = cls.__new__(cls)
obj.repo_id = repo_id
@ -195,6 +207,202 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.hf_dataset = hf_dataset
obj.episode_data_index = episode_data_index
obj.stats = stats
obj.info = info
obj.info = info if info is not None else {}
obj.videos_dir = videos_dir
return obj
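A minimal sketch of `from_preloaded` in action, mirroring how the tests further below build datasets entirely in memory (the key names and sizes are made up):

```python
import torch
from datasets import Dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import hf_transform_to_torch

# Build a tiny in-memory Hugging Face dataset and wrap it without touching the filesystem or the hub.
hf_dataset = Dataset.from_dict({"a": torch.rand(10), "index": torch.arange(10)})
hf_dataset.set_transform(hf_transform_to_torch)

dataset = LeRobotDataset.from_preloaded(repo_id="debug", hf_dataset=hf_dataset)
print(len(dataset), dataset[0]["a"])
```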
class MultiLeRobotDataset(torch.utils.data.Dataset):
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
structure of `LeRobotDataset`.
"""
def __init__(
self,
repo_ids: list[str],
version: str | None = CODEBASE_VERSION,
root: Path | None = DATA_DIR,
split: str = "train",
transform: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
):
super().__init__()
self.repo_ids = repo_ids
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
LeRobotDataset(
repo_id,
version=version,
root=root,
split=split,
delta_timestamps=delta_timestamps,
transform=transform,
)
for repo_id in repo_ids
]
# Check that some properties are consistent across datasets. Note: We may relax some of these
# consistency requirements in future iterations of this class.
for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
if dataset.info != self._datasets[0].info:
raise ValueError(
f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
"not yet supported."
)
if set(dataset.features) != set(self._datasets[0].features):
# Use a warning here as we don't want to explicitly block this sort of inconsistency.
logging.warning(
f"Detected a mismatch in dataset features between {self.repo_ids[0]} and {repo_id}."
)
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
# restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function.
self.disabled_data_keys = set()
intersection_data_keys = set(self._datasets[0].hf_dataset.features)
for dataset in self._datasets:
intersection_data_keys.intersection_update(dataset.hf_dataset.features)
if len(intersection_data_keys) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. The "
"multi-dataset functionality currently only keeps common keys."
)
for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_data_keys.update(extra_keys)
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
self.stats = consolidate_stats(self._datasets)
@property
def repo_id_to_index(self):
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
"""
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
@property
def repo_index_to_id(self):
"""Return the inverse mapping if repo_id_to_index."""
return {v: k for k, v in self.repo_id_to_index}
@property
def fps(self) -> int:
"""Frames per second used during data collection.
NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].info["fps"]
@property
def video(self) -> bool:
"""Returns True if this dataset loads video frames from mp4 files.
Returns False if it only loads images from png files.
NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
"""
return self._datasets[0].info.get("video", False)
@property
def features(self) -> datasets.Features:
# Take a copy of the first dataset's features so that deleting disabled keys below does not mutate
# the underlying Hugging Face dataset.
features = dict(self._datasets[0].hf_dataset.features)
for data_key in self.disabled_data_keys:
if data_key in features:
del features[data_key]
return datasets.Features(features)
@property
def camera_keys(self) -> list[str]:
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.hf_dataset.features.items():
if key in self.disabled_data_keys:
pass
if isinstance(feats, (datasets.Image, VideoFrame)):
keys.append(key)
return keys
@property
def video_frame_keys(self) -> list[str]:
"""Keys to access video frames that requires to be decoded into images.
Note: It is empty if the dataset contains images only,
or equal to `self.cameras` if the dataset contains videos only,
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
"""
video_frame_keys = []
for key, feats in self.hf_dataset.features.items():
if key in self.disabled_data_keys:
pass
if isinstance(feats, VideoFrame):
video_frame_keys.append(key)
return video_frame_keys
@property
def num_samples(self) -> int:
"""Number of samples/frames."""
return sum(d.num_samples for d in self._datasets)
@property
def num_episodes(self) -> int:
"""Number of episodes."""
return sum(d.num_episodes for d in self._datasets)
@property
def tolerance_s(self) -> float:
"""Tolerance in seconds used to discard loaded frames when their timestamps
are not close enough from the requested frames. It is only used when `delta_timestamps`
is provided or when loading video frames from mp4 files.
"""
# 1e-4 to account for possible numerical error
return 1 / self.fps - 1e-4
def __len__(self):
return self.num_samples
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
if idx >= len(self):
raise IndexError(f"Index {idx} out of bounds.")
# Determine which dataset to get an item from based on the index.
start_idx = 0
dataset_idx = 0
for dataset in self._datasets:
if idx >= start_idx + dataset.num_samples:
start_idx += dataset.num_samples
dataset_idx += 1
continue
break
# Note: this `else` belongs to the `for` loop; it only runs if no `break` happened, i.e. the index
# was not found in any sub-dataset (which should be impossible given the bounds check above).
else:
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
item = self._datasets[dataset_idx][idx - start_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
for data_key in self.disabled_data_keys:
if data_key in item:
del item[data_key]
return item
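To make the index arithmetic above concrete, here is a standalone illustration of the global-to-local mapping (the sub-dataset sizes are made up):

```python
# Suppose three sub-datasets with 5, 3 and 4 samples respectively (12 samples in total).
num_samples = [5, 3, 4]

def locate(idx: int) -> tuple[int, int]:
    """Map a global index to (dataset_index, local_index), mirroring the loop in __getitem__."""
    start_idx = 0
    for dataset_idx, n in enumerate(num_samples):
        if idx < start_idx + n:
            return dataset_idx, idx - start_idx
        start_idx += n
    raise IndexError(f"Index {idx} out of bounds.")

assert locate(0) == (0, 0)   # first sample of the first dataset
assert locate(5) == (1, 0)   # first sample of the second dataset
assert locate(11) == (2, 3)  # last sample of the third dataset
```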
def __repr__(self):
return (
f"{self.__class__.__name__}(\n"
f" Repository IDs: '{self.repo_ids}',\n"
f" Version: '{self.version}',\n"
f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
f" Recorded Frames per Second: {self.fps},\n"
f" Camera Keys: {self.camera_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.transform},\n"
f")"
)
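A hedged end-to-end sketch of the class above: construct it from two repo ids and feed it to a standard PyTorch DataLoader. The repo ids are illustrative; they need to point at datasets with matching metadata (fps, video flag, etc.):

```python
import torch
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset

dataset = MultiLeRobotDataset(
    ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"]  # illustrative
)
print(dataset)                   # uses the __repr__ defined above
print(dataset.repo_id_to_index)  # e.g. {"lerobot/aloha_sim_insertion_human": 0, ...}

dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
batch = next(iter(dataloader))
# Every item carries the extra "dataset_index" tensor inserted by __getitem__.
print(batch["dataset_index"])
```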

View File

@ -59,7 +59,7 @@ def unflatten_dict(d, sep="/"):
return outdict
def hf_transform_to_torch(items_dict):
def hf_transform_to_torch(items_dict: dict[str, list]):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
a channel last representation (h w c) of uint8 type, to a torch image representation
@ -73,6 +73,8 @@ def hf_transform_to_torch(items_dict):
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
# video frame will be processed downstream
pass
elif first_item is None:
pass
else:
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
return items_dict
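For reference, `hf_transform_to_torch` is meant to be installed on a Hugging Face dataset with `set_transform`, as the tests further below do. A minimal sketch with made-up columns:

```python
from datasets import Dataset
from lerobot.common.datasets.utils import hf_transform_to_torch

hf_dataset = Dataset.from_dict({"action": [[0.1, 0.2], [0.3, 0.4]], "episode_index": [0, 0]})
hf_dataset.set_transform(hf_transform_to_torch)

# Items now come out as torch tensors (PIL images would come out as channel-first float32 in [0, 1],
# while video frame dicts and None values pass through untouched).
item = hf_dataset[0]
print(type(item["action"]), item["action"])
```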
@ -317,14 +319,18 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
return episode_data_index
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
"""
Reset the `episode_index` of the provided HuggingFace Dataset.
def reset_episode_index(hf_dataset: datasets.Dataset, start_index: int = 0) -> datasets.Dataset:
"""Reset the `episode_index` of the provided HuggingFace Dataset.
`episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
`episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
This brings the `episode_index` to the required format.
Args:
hf_dataset: Dataset for which we are resetting the indexing.
start_index: The episode index to start with for the new indexing. For most use cases in LeRobot this
should be left as 0.
"""
if len(hf_dataset) == 0:
return hf_dataset
@ -337,7 +343,8 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
example["episode_index"] = episode_idx_to_reset_idx_mapping[example["episode_index"].item()]
return example
hf_dataset = hf_dataset.map(modify_ep_idx_func)
hf_dataset = hf_dataset.map(modify_ep_idx_func, input_columns=["episode_index"])
return hf_dataset
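As a standalone illustration of what the new `start_index` argument does (without invoking the Hugging Face machinery): episode indices {3, 7} with `start_index=1` are remapped to {1, 2}:

```python
# Plain-Python sketch of the remapping performed by reset_episode_index.
episode_index = [3, 3, 7, 7]
start_index = 1

mapping = {ep: new_ep + start_index for new_ep, ep in enumerate(sorted(set(episode_index)))}
print([mapping[ep] for ep in episode_index])  # [1, 1, 2, 2]
```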

View File

@ -23,6 +23,10 @@ use_amp: false
# `seed` is used for training (eg: model initialization, dataset shuffling)
# AND for the evaluation environments.
seed: ???
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
# keys common to all the datasets are kept. Each dataset gets an additional transform that inserts the
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
# datasets are provided.
dataset_repo_id: lerobot/pusht
training:

View File

@ -71,9 +71,9 @@ import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
from lerobot.common.datasets.utils import flatten_dict

View File

@ -16,7 +16,6 @@
import logging
import time
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
from pprint import pformat
@ -28,6 +27,7 @@ from termcolor import colored
from torch.cuda.amp import GradScaler
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
@ -280,6 +280,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("make_dataset")
offline_dataset = make_dataset(cfg)
if isinstance(offline_dataset, MultiLeRobotDataset):
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
f"{pformat(offline_dataset.repo_id_to_index , indent=2)}"
)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
@ -330,7 +335,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
max_episodes_rendered=4,
start_seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True)
if cfg.wandb.enable:
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
logging.info("Resume training")
@ -362,7 +367,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
dl_iter = cycle(dataloader)
policy.train()
is_offline = True
for _ in range(step, cfg.training.offline_steps):
if step == 0:
logging.info("Start offline training on a fixed dataset")
@ -382,7 +386,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
)
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
@ -390,41 +394,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
step += 1
logging.info("End of offline training")
if cfg.training.online_steps == 0:
if cfg.training.eval_freq > 0:
eval_env.close()
return
# create an env dedicated to online episodes collection from policy rollout
online_training_env = make_env(cfg, n_envs=1)
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)
online_dataset.hf_dataset = {}
online_dataset.episode_data_index = {}
# create dataloader for online training
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
weights = [1.0] * len(concat_dataset)
sampler = torch.utils.data.WeightedRandomSampler(
weights, num_samples=len(concat_dataset), replacement=True
)
dataloader = torch.utils.data.DataLoader(
concat_dataset,
num_workers=4,
batch_size=cfg.training.batch_size,
sampler=sampler,
pin_memory=device.type != "cpu",
drop_last=False,
)
logging.info("End of online training")
if cfg.training.eval_freq > 0:
eval_env.close()
online_training_env.close()
eval_env.close()
logging.info("End of training")
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")

View File

@ -23,7 +23,7 @@ If you know that your change will break backward compatibility, you should write
doesn't need to be merged into the `main` branch. Then you need to run this script and update the test artifacts.
Example usage:
`python tests/scripts/save_dataset_to_safetensors.py`
`DATA_DIR=tests/data python tests/scripts/save_dataset_to_safetensors.py`
"""
import shutil

View File

@ -25,21 +25,20 @@ from datasets import Dataset
from safetensors.torch import load_file
import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
)
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import (
from lerobot.common.datasets.compute_stats import (
compute_stats,
consolidate_stats,
get_stats_einops_patterns,
)
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import (
flatten_dict,
hf_transform_to_torch,
load_previous_and_future_frames,
unflatten_dict,
)
from lerobot.common.utils.utils import init_hydra_config
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
@ -315,3 +314,30 @@ def test_backward_compatibility(repo_id):
# i = dataset.episode_data_index["to"][-1].item()
# load_and_compare(i - 2)
# load_and_compare(i - 1)
def test_consolidate_stats():
"""Makes 3 basic datasets and checks that aggregate stats are computed correctly."""
with seeded_context(0):
data_a = torch.rand(30, dtype=torch.float32)
data_b = torch.rand(20, dtype=torch.float32)
data_c = torch.rand(20, dtype=torch.float32)
hf_dataset_1 = Dataset.from_dict(
{"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)}
)
hf_dataset_1.set_transform(hf_transform_to_torch)
hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)})
hf_dataset_2.set_transform(hf_transform_to_torch)
hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)})
hf_dataset_3.set_transform(hf_transform_to_torch)
dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1)
dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0)
dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2)
dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0)
dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3)
dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0)
stats = consolidate_stats([dataset_1, dataset_2, dataset_3])
for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True):
for agg_fn in ["mean", "min", "max"]:
assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))