Add `MultiLerobotDataset` for training with multiple `LeRobotDataset`s (#229)
This commit is contained in:
parent
265b0ec44d
commit
111cd58f8a
|
@ -16,17 +16,15 @@
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from math import ceil
|
from math import ceil
|
||||||
|
|
||||||
import datasets
|
|
||||||
import einops
|
import einops
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from datasets import Image
|
from datasets import Image
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|
||||||
from lerobot.common.datasets.video_utils import VideoFrame
|
from lerobot.common.datasets.video_utils import VideoFrame
|
||||||
|
|
||||||
|
|
||||||
def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0):
|
def get_stats_einops_patterns(dataset, num_workers=0):
|
||||||
"""These einops patterns will be used to aggregate batches and compute statistics.
|
"""These einops patterns will be used to aggregate batches and compute statistics.
|
||||||
|
|
||||||
Note: We assume the images are in channel first format
|
Note: We assume the images are in channel first format
|
||||||
|
@ -66,9 +64,8 @@ def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_wo
|
||||||
return stats_patterns
|
return stats_patterns
|
||||||
|
|
||||||
|
|
||||||
def compute_stats(
|
def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None):
|
||||||
dataset: LeRobotDataset | datasets.Dataset, batch_size=32, num_workers=16, max_num_samples=None
|
"""Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
|
||||||
):
|
|
||||||
if max_num_samples is None:
|
if max_num_samples is None:
|
||||||
max_num_samples = len(dataset)
|
max_num_samples = len(dataset)
|
||||||
|
|
||||||
|
@ -159,3 +156,54 @@ def compute_stats(
|
||||||
"min": min[key],
|
"min": min[key],
|
||||||
}
|
}
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
|
||||||
|
"""Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
|
||||||
|
|
||||||
|
The final stats will have the union of all data keys from each of the datasets.
|
||||||
|
|
||||||
|
The final stats will have the union of all data keys from each of the datasets. For instance:
|
||||||
|
- new_max = max(max_dataset_0, max_dataset_1, ...)
|
||||||
|
- new_min = min(min_dataset_0, min_dataset_1, ...)
|
||||||
|
- new_mean = (mean of all data)
|
||||||
|
- new_std = (std of all data)
|
||||||
|
"""
|
||||||
|
data_keys = set()
|
||||||
|
for dataset in ls_datasets:
|
||||||
|
data_keys.update(dataset.stats.keys())
|
||||||
|
stats = {k: {} for k in data_keys}
|
||||||
|
for data_key in data_keys:
|
||||||
|
for stat_key in ["min", "max"]:
|
||||||
|
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
|
||||||
|
stats[data_key][stat_key] = einops.reduce(
|
||||||
|
torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
|
||||||
|
"n ... -> ...",
|
||||||
|
stat_key,
|
||||||
|
)
|
||||||
|
total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
|
||||||
|
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
|
||||||
|
# dataset, then divide by total_samples to get the overall "mean".
|
||||||
|
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
|
||||||
|
# numerical overflow!
|
||||||
|
stats[data_key]["mean"] = sum(
|
||||||
|
d.stats[data_key]["mean"] * (d.num_samples / total_samples)
|
||||||
|
for d in ls_datasets
|
||||||
|
if data_key in d.stats
|
||||||
|
)
|
||||||
|
# The derivation for standard deviation is a little more involved but is much in the same spirit as
|
||||||
|
# the computation of the mean.
|
||||||
|
# Given two sets of data where the statistics are known:
|
||||||
|
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
|
||||||
|
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
|
||||||
|
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
|
||||||
|
# numerical overflow!
|
||||||
|
stats[data_key]["std"] = torch.sqrt(
|
||||||
|
sum(
|
||||||
|
(d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
|
||||||
|
* (d.num_samples / total_samples)
|
||||||
|
for d in ls_datasets
|
||||||
|
if data_key in d.stats
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return stats
|
|
@ -16,9 +16,9 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import ListConfig, OmegaConf
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||||
|
|
||||||
|
|
||||||
def resolve_delta_timestamps(cfg):
|
def resolve_delta_timestamps(cfg):
|
||||||
|
@ -35,11 +35,27 @@ def resolve_delta_timestamps(cfg):
|
||||||
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
|
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
|
||||||
|
|
||||||
|
|
||||||
def make_dataset(
|
def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
|
||||||
cfg,
|
"""
|
||||||
split="train",
|
Args:
|
||||||
):
|
cfg: A Hydra config as per the LeRobot config scheme.
|
||||||
if cfg.env.name not in cfg.dataset_repo_id:
|
split: Select the data subset used to create an instance of LeRobotDataset.
|
||||||
|
All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
|
||||||
|
Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
|
||||||
|
slicer in the hugging face datasets:
|
||||||
|
https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
|
||||||
|
As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
|
||||||
|
`split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`.
|
||||||
|
Returns:
|
||||||
|
The LeRobotDataset.
|
||||||
|
"""
|
||||||
|
if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
|
||||||
|
raise ValueError(
|
||||||
|
"Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
|
||||||
|
"strings to load multiple datasets."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(cfg.dataset_repo_id, str) and cfg.env.name not in cfg.dataset_repo_id:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
|
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
|
||||||
f"environment ({cfg.env.name=})."
|
f"environment ({cfg.env.name=})."
|
||||||
|
@ -49,11 +65,16 @@ def make_dataset(
|
||||||
|
|
||||||
# TODO(rcadene): add data augmentations
|
# TODO(rcadene): add data augmentations
|
||||||
|
|
||||||
|
if isinstance(cfg.dataset_repo_id, str):
|
||||||
dataset = LeRobotDataset(
|
dataset = LeRobotDataset(
|
||||||
cfg.dataset_repo_id,
|
cfg.dataset_repo_id,
|
||||||
split=split,
|
split=split,
|
||||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
dataset = MultiLeRobotDataset(
|
||||||
|
cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.get("override_dataset_stats"):
|
if cfg.get("override_dataset_stats"):
|
||||||
for key, stats_dict in cfg.override_dataset_stats.items():
|
for key, stats_dict in cfg.override_dataset_stats.items():
|
||||||
|
|
|
@ -13,12 +13,16 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
|
import torch.utils
|
||||||
|
|
||||||
|
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
calculate_episode_data_index,
|
calculate_episode_data_index,
|
||||||
load_episode_data_index,
|
load_episode_data_index,
|
||||||
|
@ -42,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
version: str | None = CODEBASE_VERSION,
|
version: str | None = CODEBASE_VERSION,
|
||||||
root: Path | None = DATA_DIR,
|
root: Path | None = DATA_DIR,
|
||||||
split: str = "train",
|
split: str = "train",
|
||||||
transform: callable = None,
|
transform: Callable | None = None,
|
||||||
delta_timestamps: dict[list[float]] | None = None,
|
delta_timestamps: dict[list[float]] | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -171,7 +175,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_preloaded(
|
def from_preloaded(
|
||||||
cls,
|
cls,
|
||||||
repo_id: str,
|
repo_id: str = "from_preloaded",
|
||||||
version: str | None = CODEBASE_VERSION,
|
version: str | None = CODEBASE_VERSION,
|
||||||
root: Path | None = None,
|
root: Path | None = None,
|
||||||
split: str = "train",
|
split: str = "train",
|
||||||
|
@ -183,7 +187,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
stats=None,
|
stats=None,
|
||||||
info=None,
|
info=None,
|
||||||
videos_dir=None,
|
videos_dir=None,
|
||||||
):
|
) -> "LeRobotDataset":
|
||||||
|
"""Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.
|
||||||
|
|
||||||
|
It is especially useful when converting raw data into LeRobotDataset before saving the dataset
|
||||||
|
on the filesystem or uploading to the hub.
|
||||||
|
|
||||||
|
Note: Meta-data attributes like `repo_id`, `version`, `root`, etc are optional and potentially
|
||||||
|
meaningless depending on the downstream usage of the return dataset.
|
||||||
|
"""
|
||||||
# create an empty object of type LeRobotDataset
|
# create an empty object of type LeRobotDataset
|
||||||
obj = cls.__new__(cls)
|
obj = cls.__new__(cls)
|
||||||
obj.repo_id = repo_id
|
obj.repo_id = repo_id
|
||||||
|
@ -195,6 +207,192 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
obj.hf_dataset = hf_dataset
|
obj.hf_dataset = hf_dataset
|
||||||
obj.episode_data_index = episode_data_index
|
obj.episode_data_index = episode_data_index
|
||||||
obj.stats = stats
|
obj.stats = stats
|
||||||
obj.info = info
|
obj.info = info if info is not None else {}
|
||||||
obj.videos_dir = videos_dir
|
obj.videos_dir = videos_dir
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
class MultiLeRobotDataset(torch.utils.data.Dataset):
    """A dataset consisting of multiple underlying `LeRobotDataset`s.

    The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
    structure of `LeRobotDataset`.
    """

    def __init__(
        self,
        repo_ids: list[str],
        version: str | None = CODEBASE_VERSION,
        root: Path | None = DATA_DIR,
        split: str = "train",
        transform: Callable | None = None,
        delta_timestamps: dict[list[float]] | None = None,
    ):
        """Create one `LeRobotDataset` per repo id and expose them as a single concatenated dataset.

        Args:
            repo_ids: Repo ids of the datasets to load. Their order defines the "dataset_index" that
                `__getitem__` adds to every item.
            version: Forwarded to every underlying `LeRobotDataset`.
            root: Forwarded to every underlying `LeRobotDataset`.
            split: Forwarded to every underlying `LeRobotDataset`.
            transform: Forwarded to every underlying `LeRobotDataset`.
            delta_timestamps: Forwarded to every underlying `LeRobotDataset`.

        Raises:
            ValueError: If `repo_ids` is empty, or if the `info` of the datasets differ.
            RuntimeError: If the datasets have no data key in common.
        """
        super().__init__()
        if len(repo_ids) == 0:
            raise ValueError("Expected at least one repo id to build a MultiLeRobotDataset.")
        self.repo_ids = repo_ids
        # Construct the underlying datasets. Every argument, including `transform` and
        # `delta_timestamps`, is forwarded as-is to each of them.
        self._datasets = [
            LeRobotDataset(
                repo_id,
                version=version,
                root=root,
                split=split,
                delta_timestamps=delta_timestamps,
                transform=transform,
            )
            for repo_id in repo_ids
        ]
        # Check that some properties are consistent across datasets. Note: We may relax some of these
        # consistency requirements in future iterations of this class.
        for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
            if dataset.info != self._datasets[0].info:
                raise ValueError(
                    f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
                    "not yet supported."
                )
        # Disable any data keys that are not common across all of the datasets. Note: we may relax this
        # restriction in future iterations of this class. For now, this is necessary at least for being able
        # to use PyTorch's default DataLoader collate function.
        self.disabled_data_keys = set()
        intersection_data_keys = set(self._datasets[0].hf_dataset.features)
        for dataset in self._datasets:
            intersection_data_keys.intersection_update(dataset.hf_dataset.features)
        if len(intersection_data_keys) == 0:
            raise RuntimeError(
                "Multiple datasets were provided but they had no keys common to all of them. The "
                "multi-dataset functionality currently only keeps common keys."
            )
        for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
            extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys)
            # Only warn about datasets that actually lose keys.
            if not extra_keys:
                continue
            logging.warning(
                f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
                "other datasets."
            )
            self.disabled_data_keys.update(extra_keys)

        self.version = version
        self.root = root
        self.split = split
        self.transform = transform
        self.delta_timestamps = delta_timestamps
        self.stats = aggregate_stats(self._datasets)

    @property
    def repo_id_to_index(self):
        """Return a mapping from dataset repo_id to a dataset index automatically created by this class.

        This index is incorporated as a data key in the dictionary returned by `__getitem__`.
        """
        return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}

    @property
    def repo_index_to_id(self):
        """Return the inverse mapping of repo_id_to_index."""
        return {v: k for k, v in self.repo_id_to_index.items()}

    @property
    def fps(self) -> int:
        """Frames per second used during data collection.

        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
        """
        return self._datasets[0].info["fps"]

    @property
    def video(self) -> bool:
        """Returns True if this dataset loads video frames from mp4 files.

        Returns False if it only loads images from png files.

        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
        """
        return self._datasets[0].info.get("video", False)

    @property
    def features(self) -> datasets.Features:
        """Features of the sub-datasets merged into one dict, without the disabled data keys."""
        features = {}
        for dataset in self._datasets:
            features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
        return features

    @property
    def camera_keys(self) -> list[str]:
        """Keys to access image and video stream from cameras."""
        keys = []
        for key, feats in self.features.items():
            if isinstance(feats, (datasets.Image, VideoFrame)):
                keys.append(key)
        return keys

    @property
    def video_frame_keys(self) -> list[str]:
        """Keys to access video frames that requires to be decoded into images.

        Note: It is empty if the dataset contains images only,
        or equal to `self.cameras` if the dataset contains videos only,
        or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
        """
        video_frame_keys = []
        for key, feats in self.features.items():
            if isinstance(feats, VideoFrame):
                video_frame_keys.append(key)
        return video_frame_keys

    @property
    def num_samples(self) -> int:
        """Number of samples/frames."""
        return sum(d.num_samples for d in self._datasets)

    @property
    def num_episodes(self) -> int:
        """Number of episodes."""
        return sum(d.num_episodes for d in self._datasets)

    @property
    def tolerance_s(self) -> float:
        """Tolerance in seconds used to discard loaded frames when their timestamps
        are not close enough from the requested frames. It is only used when `delta_timestamps`
        is provided or when loading video frames from mp4 files.
        """
        # 1e-4 to account for possible numerical error
        return 1 / self.fps - 1e-4

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return the item at position `idx` of the concatenated datasets.

        The item comes from the sub-dataset that contains `idx`. A "dataset_index" tensor identifying
        that sub-dataset is added to it, and the disabled data keys are removed from it.

        Raises:
            IndexError: If `idx` is outside of `[-len(self), len(self))`.
        """
        requested_idx = idx
        num_samples = len(self)
        # Negative indices count from the end, as for any Python sequence.
        if idx < 0:
            idx += num_samples
        if not 0 <= idx < num_samples:
            raise IndexError(f"Index {requested_idx} out of bounds.")
        # Determine which dataset to get an item from based on the index: walk through the datasets,
        # skipping every one that ends before `idx`, and stop at the first one that contains it.
        start_idx = 0
        for dataset_idx, dataset in enumerate(self._datasets):
            if idx < start_idx + dataset.num_samples:
                break
            start_idx += dataset.num_samples
        else:
            raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
        item = self._datasets[dataset_idx][idx - start_idx]
        item["dataset_index"] = torch.tensor(dataset_idx)
        for data_key in self.disabled_data_keys:
            if data_key in item:
                del item[data_key]
        return item

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(\n"
            f"  Repository IDs: '{self.repo_ids}',\n"
            f"  Version: '{self.version}',\n"
            f"  Split: '{self.split}',\n"
            f"  Number of Samples: {self.num_samples},\n"
            f"  Number of Episodes: {self.num_episodes},\n"
            f"  Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
            f"  Recorded Frames per Second: {self.fps},\n"
            f"  Camera Keys: {self.camera_keys},\n"
            f"  Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
            f"  Transformations: {self.transform},\n"
            f")"
        )
|
||||||
|
|
|
@ -59,7 +59,7 @@ def unflatten_dict(d, sep="/"):
|
||||||
return outdict
|
return outdict
|
||||||
|
|
||||||
|
|
||||||
def hf_transform_to_torch(items_dict):
|
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
||||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
||||||
a channel last representation (h w c) of uint8 type, to a torch image representation
|
a channel last representation (h w c) of uint8 type, to a torch image representation
|
||||||
|
@ -73,6 +73,8 @@ def hf_transform_to_torch(items_dict):
|
||||||
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
|
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
|
||||||
# video frame will be processed downstream
|
# video frame will be processed downstream
|
||||||
pass
|
pass
|
||||||
|
elif first_item is None:
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
|
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
|
||||||
return items_dict
|
return items_dict
|
||||||
|
@ -318,8 +320,7 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
|
||||||
|
|
||||||
|
|
||||||
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
|
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
|
||||||
"""
|
"""Reset the `episode_index` of the provided HuggingFace Dataset.
|
||||||
Reset the `episode_index` of the provided HuggingFace Dataset.
|
|
||||||
|
|
||||||
`episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
|
`episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
|
||||||
`episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
|
`episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
|
||||||
|
@ -338,6 +339,7 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
|
||||||
return example
|
return example
|
||||||
|
|
||||||
hf_dataset = hf_dataset.map(modify_ep_idx_func)
|
hf_dataset = hf_dataset.map(modify_ep_idx_func)
|
||||||
|
|
||||||
return hf_dataset
|
return hf_dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,10 @@ use_amp: false
|
||||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||||
# AND for the evaluation environments.
|
# AND for the evaluation environments.
|
||||||
seed: ???
|
seed: ???
|
||||||
|
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
|
||||||
|
# keys common between the datasets are kept. Each dataset gets an additional transform that inserts the
|
||||||
|
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
|
||||||
|
# datasets are provided.
|
||||||
dataset_repo_id: lerobot/pusht
|
dataset_repo_id: lerobot/pusht
|
||||||
|
|
||||||
training:
|
training:
|
||||||
|
|
|
@ -71,9 +71,9 @@ import torch
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
from lerobot.common.datasets.compute_stats import compute_stats
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
|
|
||||||
from lerobot.common.datasets.utils import flatten_dict
|
from lerobot.common.datasets.utils import flatten_dict
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from copy import deepcopy
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
|
||||||
|
@ -28,6 +27,7 @@ from termcolor import colored
|
||||||
from torch.cuda.amp import GradScaler
|
from torch.cuda.amp import GradScaler
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
|
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
|
||||||
from lerobot.common.datasets.utils import cycle
|
from lerobot.common.datasets.utils import cycle
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.logger import Logger, log_output_dir
|
from lerobot.common.logger import Logger, log_output_dir
|
||||||
|
@ -280,6 +280,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
|
|
||||||
logging.info("make_dataset")
|
logging.info("make_dataset")
|
||||||
offline_dataset = make_dataset(cfg)
|
offline_dataset = make_dataset(cfg)
|
||||||
|
if isinstance(offline_dataset, MultiLeRobotDataset):
|
||||||
|
logging.info(
|
||||||
|
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||||
|
f"{pformat(offline_dataset.repo_id_to_index , indent=2)}"
|
||||||
|
)
|
||||||
|
|
||||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||||
|
@ -330,7 +335,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
max_episodes_rendered=4,
|
max_episodes_rendered=4,
|
||||||
start_seed=cfg.seed,
|
start_seed=cfg.seed,
|
||||||
)
|
)
|
||||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
|
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True)
|
||||||
if cfg.wandb.enable:
|
if cfg.wandb.enable:
|
||||||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||||
logging.info("Resume training")
|
logging.info("Resume training")
|
||||||
|
@ -362,7 +367,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
dl_iter = cycle(dataloader)
|
dl_iter = cycle(dataloader)
|
||||||
|
|
||||||
policy.train()
|
policy.train()
|
||||||
is_offline = True
|
|
||||||
for _ in range(step, cfg.training.offline_steps):
|
for _ in range(step, cfg.training.offline_steps):
|
||||||
if step == 0:
|
if step == 0:
|
||||||
logging.info("Start offline training on a fixed dataset")
|
logging.info("Start offline training on a fixed dataset")
|
||||||
|
@ -382,7 +386,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
)
|
)
|
||||||
|
|
||||||
if step % cfg.training.log_freq == 0:
|
if step % cfg.training.log_freq == 0:
|
||||||
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
|
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
|
||||||
|
|
||||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
||||||
# so we pass in step + 1.
|
# so we pass in step + 1.
|
||||||
|
@ -390,41 +394,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
|
|
||||||
step += 1
|
step += 1
|
||||||
|
|
||||||
logging.info("End of offline training")
|
|
||||||
|
|
||||||
if cfg.training.online_steps == 0:
|
|
||||||
if cfg.training.eval_freq > 0:
|
|
||||||
eval_env.close()
|
eval_env.close()
|
||||||
return
|
logging.info("End of training")
|
||||||
|
|
||||||
# create an env dedicated to online episodes collection from policy rollout
|
|
||||||
online_training_env = make_env(cfg, n_envs=1)
|
|
||||||
|
|
||||||
# create an empty online dataset similar to offline dataset
|
|
||||||
online_dataset = deepcopy(offline_dataset)
|
|
||||||
online_dataset.hf_dataset = {}
|
|
||||||
online_dataset.episode_data_index = {}
|
|
||||||
|
|
||||||
# create dataloader for online training
|
|
||||||
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
|
|
||||||
weights = [1.0] * len(concat_dataset)
|
|
||||||
sampler = torch.utils.data.WeightedRandomSampler(
|
|
||||||
weights, num_samples=len(concat_dataset), replacement=True
|
|
||||||
)
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
|
||||||
concat_dataset,
|
|
||||||
num_workers=4,
|
|
||||||
batch_size=cfg.training.batch_size,
|
|
||||||
sampler=sampler,
|
|
||||||
pin_memory=device.type != "cpu",
|
|
||||||
drop_last=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info("End of online training")
|
|
||||||
|
|
||||||
if cfg.training.eval_freq > 0:
|
|
||||||
eval_env.close()
|
|
||||||
online_training_env.close()
|
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||||
|
|
|
@ -25,26 +25,34 @@ from datasets import Dataset
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
import lerobot
|
import lerobot
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.compute_stats import (
|
||||||
from lerobot.common.datasets.lerobot_dataset import (
|
aggregate_stats,
|
||||||
LeRobotDataset,
|
|
||||||
)
|
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import (
|
|
||||||
compute_stats,
|
compute_stats,
|
||||||
get_stats_einops_patterns,
|
get_stats_einops_patterns,
|
||||||
)
|
)
|
||||||
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
flatten_dict,
|
flatten_dict,
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
load_previous_and_future_frames,
|
load_previous_and_future_frames,
|
||||||
unflatten_dict,
|
unflatten_dict,
|
||||||
)
|
)
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
|
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
|
@pytest.mark.parametrize(
|
||||||
|
"env_name, repo_id, policy_name",
|
||||||
|
lerobot.env_dataset_policy_triplets
|
||||||
|
+ [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
|
||||||
|
)
|
||||||
def test_factory(env_name, repo_id, policy_name):
|
def test_factory(env_name, repo_id, policy_name):
|
||||||
|
"""
|
||||||
|
Tests that:
|
||||||
|
- we can create a dataset with the factory.
|
||||||
|
- for a commonly used set of data keys, the data dimensions are correct.
|
||||||
|
"""
|
||||||
cfg = init_hydra_config(
|
cfg = init_hydra_config(
|
||||||
DEFAULT_CONFIG_PATH,
|
DEFAULT_CONFIG_PATH,
|
||||||
overrides=[
|
overrides=[
|
||||||
|
@ -315,3 +323,31 @@ def test_backward_compatibility(repo_id):
|
||||||
# i = dataset.episode_data_index["to"][-1].item()
|
# i = dataset.episode_data_index["to"][-1].item()
|
||||||
# load_and_compare(i - 2)
|
# load_and_compare(i - 2)
|
||||||
# load_and_compare(i - 1)
|
# load_and_compare(i - 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_stats():
    """Build 3 small datasets with partially overlapping keys and verify the aggregated stats.

    The aggregated mean/min/max/std of every key must match the statistics computed directly on the
    full data that was split across the datasets.
    """
    with seeded_context(0):
        data_a = torch.rand(30, dtype=torch.float32)
        data_b = torch.rand(20, dtype=torch.float32)
        data_c = torch.rand(20, dtype=torch.float32)

    # One entry per sub-dataset: "a" is spread over all three, "b" over the first two and "c" over the
    # first and the last one.
    columns_per_dataset = [
        {"a": data_a[:10], "b": data_b[:10], "c": data_c[:10]},
        {"a": data_a[10:20], "b": data_b[10:]},
        {"a": data_a[20:], "c": data_c[10:]},
    ]

    sub_datasets = []
    for number, columns in enumerate(columns_per_dataset, start=1):
        hf_dataset = Dataset.from_dict({**columns, "index": torch.arange(10)})
        hf_dataset.set_transform(hf_transform_to_torch)
        dataset = LeRobotDataset.from_preloaded(f"d{number}", hf_dataset=hf_dataset)
        dataset.stats = compute_stats(dataset, batch_size=len(hf_dataset), num_workers=0)
        sub_datasets.append(dataset)

    stats = aggregate_stats(sub_datasets)

    full_data = {"a": data_a, "b": data_b, "c": data_c}
    for data_key, data in full_data.items():
        for agg_fn in ("mean", "min", "max"):
            assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
        assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))
|
||||||
|
|
Loading…
Reference in New Issue