Apply suggestions from code review

Simon Alibert 2024-11-25 12:44:12 +01:00
parent f56d769dfb
commit 23f6c875b5
15 changed files with 69 additions and 155 deletions

View File

@@ -77,7 +77,7 @@ print(dataset.hf_dataset)
# LeRobot datasets also subclass PyTorch datasets so you can do everything you know and love from working
# with the latter, like iterating through the dataset.
# The __get_item__ iterates over the frames of the dataset. Since our datasets are also structured by
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
# episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access
# frame indices associated to the first episode:
episode_index = 0
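
For context, here is a minimal sketch of the pattern this hunk documents: using episode_data_index to iterate through one episode's frames via __getitem__. The import path and repo_id are assumptions for illustration; the attribute names follow the snippet above.

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # import path assumed

dataset = LeRobotDataset("lerobot/pusht")  # repo_id is illustrative
episode_index = 0
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()
for i in range(from_idx, to_idx):
    frame = dataset[i]  # __getitem__ returns a dict of tensors for this frame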

View File

@@ -1,7 +1,7 @@
"""
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
transforms are applied to the observation images before they are returned in the dataset's __get_item__.
transforms are applied to the observation images before they are returned in the dataset's __getitem__.
"""
from pathlib import Path
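
A hedged sketch of what the docstring describes: torchvision v2 transforms passed at dataset creation and applied to observation images inside __getitem__. The image_transforms keyword comes from this PR's LeRobotDataset constructor; the import path and repo_id are assumptions.

from torchvision.transforms import v2
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # import path assumed

# A couple of standard torchvision v2 transforms chained together.
transforms = v2.Compose([
    v2.ColorJitter(brightness=0.5, contrast=0.5),
    v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
])

# Transforms are applied to observation images before frames are returned.
dataset = LeRobotDataset("lerobot/aloha_static_screw_driver", image_transforms=transforms)
frame = dataset[0]  # images in this dict have the transforms applied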

View File

@@ -8,7 +8,6 @@ especially in the context of imitation learning. The most reliable approach is t
on the target environment, whether that be in simulation or the real world.
"""
# TODO(aliberts, rcadene): Update this script with the new v2 api
import math
from pathlib import Path

View File

@@ -170,25 +170,28 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
"""
data_keys = set()
for dataset in ls_datasets:
data_keys.update(dataset.stats.keys())
data_keys.update(dataset.meta.stats.keys())
stats = {k: {} for k in data_keys}
for data_key in data_keys:
for stat_key in ["min", "max"]:
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
stats[data_key][stat_key] = einops.reduce(
torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
torch.stack(
[ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats],
dim=0,
),
"n ... -> ...",
stat_key,
)
total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.stats)
total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats)
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
# dataset, then divide by total_samples to get the overall "mean".
# NOTE: the brackets around (d.num_frames / total_samples) are needed to minimize the risk of
# numerical overflow!
stats[data_key]["mean"] = sum(
d.stats[data_key]["mean"] * (d.num_frames / total_samples)
d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples)
for d in ls_datasets
if data_key in d.stats
if data_key in d.meta.stats
)
# The derivation for standard deviation is a little more involved but is much in the same spirit as
# the computation of the mean.
@@ -199,102 +202,13 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
# numerical overflow!
stats[data_key]["std"] = torch.sqrt(
sum(
(d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
(
d.meta.stats[data_key]["std"] ** 2
+ (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2
)
* (d.num_frames / total_samples)
for d in ls_datasets
if data_key in d.stats
if data_key in d.meta.stats
)
)
return stats
# TODO(aliberts): refactor stats in save_episodes
# import numpy as np
# from lerobot.common.datasets.utils import load_image_as_numpy
# def aggregate_stats_v2(stats_list: list) -> dict:
# """Aggregate stats from multiple compute_stats outputs into a single set of stats.
# The final stats will have the union of all data keys from each of the stats dicts.
# For instance:
# - new_min = min(min_dataset_0, min_dataset_1, ...)
# - new_max = max(max_dataset_0, max_dataset_1, ...)
# - new_mean = (mean of all data, weighted by counts)
# - new_std = (std of all data)
# """
# data_keys = set(key for stats in stats_list for key in stats.keys())
# aggregated_stats = {key: {} for key in data_keys}
# for key in data_keys:
# # Collect stats for the current key from all datasets where it exists
# stats_with_key = [stats[key] for stats in stats_list if key in stats]
# # Aggregate 'min' and 'max' using np.minimum and np.maximum
# aggregated_stats[key]['min'] = np.minimum.reduce([s['min'] for s in stats_with_key])
# aggregated_stats[key]['max'] = np.maximum.reduce([s['max'] for s in stats_with_key])
# # Extract means, variances (std^2), and counts
# means = np.array([s['mean'] for s in stats_with_key])
# variances = np.array([s['std']**2 for s in stats_with_key])
# counts = np.array([s['count'] for s in stats_with_key])
# # Ensure counts can broadcast with means/variances if they have additional dimensions
# counts = counts.reshape(-1, *[1]*(means.ndim - 1))
# # Compute total counts
# total_count = counts.sum(axis=0)
# # Compute the weighted mean
# weighted_means = means * counts
# total_mean = weighted_means.sum(axis=0) / total_count
# # Compute the variance using the parallel algorithm
# delta_means = means - total_mean
# weighted_variances = (variances + delta_means**2) * counts
# total_variance = weighted_variances.sum(axis=0) / total_count
# # Store the aggregated stats
# aggregated_stats[key]['mean'] = total_mean
# aggregated_stats[key]['std'] = np.sqrt(total_variance)
# aggregated_stats[key]['count'] = total_count
# return aggregated_stats
# def compute_episode_stats(episode_buffer: dict, features: dict, episode_length: int, image_sampling: int = 10) -> dict:
# stats = {}
# for key, data in episode_buffer.items():
# if features[key]["dtype"] in ["image", "video"]:
# stats[key] = compute_image_stats(data, sampling=image_sampling)
# else:
# axes_to_reduce = 0 # Compute stats over the first axis
# stats[key] = {
# "min": np.min(data, axis=axes_to_reduce),
# "max": np.max(data, axis=axes_to_reduce),
# "mean": np.mean(data, axis=axes_to_reduce),
# "std": np.std(data, axis=axes_to_reduce),
# "count": episode_length,
# }
# return stats
# def compute_image_stats(image_paths: list[str], sampling: int = 10) -> dict:
# images = []
# samples = range(0, len(image_paths), sampling)
# for idx in samples:
# path = image_paths[idx]
# img = load_image_as_numpy(path, channel_first=True)
# images.append(img)
# images = np.stack(images)
# axes_to_reduce = (0, 2, 3) # keep channel dim
# image_stats = {
# "min": np.min(images, axis=axes_to_reduce, keepdims=True),
# "max": np.max(images, axis=axes_to_reduce, keepdims=True),
# "mean": np.mean(images, axis=axes_to_reduce, keepdims=True),
# "std": np.std(images, axis=axes_to_reduce, keepdims=True)
# }
# for key in image_stats: # squeeze batch dim
# image_stats[key] = np.squeeze(image_stats[key], axis=0)
# return image_stats
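
To make the aggregation formulas above concrete, here is a small plain-NumPy check (no LeRobot API involved, all names illustrative) that the weighted mean and the parallel-variance formula used in aggregate_stats reproduce the statistics of the concatenated data:

import numpy as np

a = np.random.randn(100) * 2.0 + 1.0
b = np.random.randn(300) * 0.5 - 3.0
parts = [a, b]

counts = np.array([p.size for p in parts])
means = np.array([p.mean() for p in parts])
variances = np.array([p.var() for p in parts])

total = counts.sum()
# new_mean = sum of per-dataset means weighted by their share of the samples
agg_mean = np.sum(means * (counts / total))
# new_std via the parallel algorithm: weighted (variance + squared mean shift)
agg_std = np.sqrt(np.sum((variances + (means - agg_mean) ** 2) * (counts / total)))

full = np.concatenate(parts)
assert np.isclose(agg_mean, full.mean())
assert np.isclose(agg_std, full.std())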

View File

@@ -63,7 +63,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
print(f"Error writing image {fpath}: {e}")
def worker_thread_process(queue: queue.Queue):
def worker_thread_loop(queue: queue.Queue):
while True:
item = queue.get()
if item is None:
@@ -77,7 +77,7 @@ def worker_thread_process(queue: queue.Queue):
def worker_process(queue: queue.Queue, num_threads: int):
threads = []
for _ in range(num_threads):
t = threading.Thread(target=worker_thread_process, args=(queue,))
t = threading.Thread(target=worker_thread_loop, args=(queue,))
t.daemon = True
t.start()
threads.append(t)
@@ -115,7 +115,7 @@ class AsyncImageWriter:
# Use threading
self.queue = queue.Queue()
for _ in range(self.num_threads):
t = threading.Thread(target=worker_thread_process, args=(self.queue,))
t = threading.Thread(target=worker_thread_loop, args=(self.queue,))
t.daemon = True
t.start()
self.threads.append(t)
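
For readers unfamiliar with the pattern being renamed here, this is a generic, self-contained sketch of the queue plus daemon-thread loop that AsyncImageWriter uses (it is not the LeRobot implementation; write_image is replaced by a print placeholder):

import queue
import threading

def worker_thread_loop(q: queue.Queue) -> None:
    while True:
        item = q.get()
        if item is None:  # sentinel: one per worker signals shutdown
            q.task_done()
            break
        image, fpath = item
        print(f"writing {fpath}")  # a real writer would call write_image(image, fpath)
        q.task_done()

q = queue.Queue()
threads = []
for _ in range(4):
    t = threading.Thread(target=worker_thread_loop, args=(q,))
    t.daemon = True
    t.start()
    threads.append(t)

q.put((b"fake image bytes", "frame_000000.png"))
q.put((b"fake image bytes", "frame_000001.png"))

for _ in threads:  # shutdown: push one sentinel per worker
    q.put(None)
for t in threads:
    t.join()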

View File

@@ -427,12 +427,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
super().__init__()
self.repo_id = repo_id
self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
self.root = Path(root) if root else LEROBOT_HOME / repo_id
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
self.video_backend = video_backend if video_backend is not None else "pyav"
self.video_backend = video_backend if video_backend else "pyav"
self.delta_indices = None
self.local_files_only = local_files_only
@@ -473,10 +473,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
**card_kwargs,
) -> None:
if not self.consolidated:
raise RuntimeError(
"You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet."
"Please call the dataset 'consolidate()' method first."
logging.warning(
"You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. "
"Consolidating first."
)
self.consolidate()
ignore_patterns = ["images/"]
if not push_videos:
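
A minimal, generic sketch of the behavior this hunk switches to, warning and consolidating on the user's behalf instead of raising; the consolidated flag and consolidate() call are the names used in the hunk above, everything else is illustrative:

import logging

def push_with_consolidation_guard(dataset) -> None:
    if not dataset.consolidated:
        logging.warning(
            "You are trying to upload to the hub a LeRobotDataset that has not been "
            "consolidated yet. Consolidating first."
        )
        dataset.consolidate()
    # ... proceed with the actual upload ...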
@@ -750,7 +751,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_index = episode_buffer["episode_index"]
if episode_index != self.meta.total_episodes:
# TODO(aliberts): Add option to use existing episode_index
raise NotImplementedError()
raise NotImplementedError(
"You might have manually provided the episode_buffer with an episode_index that doesn't "
"match the total number of episodes in the dataset. This is not supported for now."
)
if episode_length == 0:
raise ValueError(
@@ -818,7 +822,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
if isinstance(self.image_writer, AsyncImageWriter):
logging.warning(
"You are starting a new AsyncImageWriter that is replacing an already exising one in the dataset."
"You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
)
self.image_writer = AsyncImageWriter(
@@ -965,56 +969,56 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def __init__(
self,
repo_ids: list[str],
root: Path | None = None,
root: str | Path | None = None,
episodes: dict | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
tolerances_s: dict | None = None,
download_videos: bool = True,
local_files_only: bool = False,
video_backend: str | None = None,
):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
self._datasets = [
LeRobotDataset(
repo_id,
root=root / repo_id if root is not None else None,
episodes=episodes[repo_id] if episodes is not None else None,
delta_timestamps=delta_timestamps,
root=self.root / repo_id,
episodes=episodes[repo_id] if episodes else None,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
local_files_only=local_files_only,
video_backend=video_backend,
)
for repo_id in repo_ids
]
# Check that some properties are consistent across datasets. Note: We may relax some of these
# consistency requirements in future iterations of this class.
for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
if dataset.meta.info != self._datasets[0].meta.info:
raise ValueError(
f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
"not yet supported."
)
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
# restriction in future iterations of this class. For now, this is necessary at least for being able
# to use PyTorch's default DataLoader collate function.
self.disabled_data_keys = set()
intersection_data_keys = set(self._datasets[0].hf_dataset.features)
for dataset in self._datasets:
intersection_data_keys.intersection_update(dataset.hf_dataset.features)
if len(intersection_data_keys) == 0:
self.disabled_features = set()
intersection_features = set(self._datasets[0].features)
for ds in self._datasets:
intersection_features.intersection_update(ds.features)
if len(intersection_features) == 0:
raise RuntimeError(
"Multiple datasets were provided but they had no keys common to all of them. The "
"multi-dataset functionality currently only keeps common keys."
"Multiple datasets were provided but they had no keys common to all of them. "
"The multi-dataset functionality currently only keeps common keys."
)
for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys)
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
extra_keys = set(ds.features).difference(intersection_features)
logging.warning(
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
"other datasets."
)
self.disabled_data_keys.update(extra_keys)
self.disabled_features.update(extra_keys)
self.root = root
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets)
@@ -1054,9 +1058,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def features(self) -> datasets.Features:
features = {}
for dataset in self._datasets:
features.update(
{k: v for k, v in dataset.hf_features.items() if k not in self.disabled_data_keys}
)
features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features})
return features
@property
@@ -1120,7 +1122,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
item = self._datasets[dataset_idx][idx - start_idx]
item["dataset_index"] = torch.tensor(dataset_idx)
for data_key in self.disabled_data_keys:
for data_key in self.disabled_features:
if data_key in item:
del item[data_key]
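
To illustrate the renamed disabled_features logic with plain sets (feature names here are made up, and this is not the MultiLeRobotDataset API itself): only keys common to every dataset are kept, and the rest are stripped from returned items.

features_per_dataset = [
    {"action", "observation.state", "observation.images.top"},
    {"action", "observation.state", "observation.images.wrist"},
]
common = set(features_per_dataset[0])
for feats in features_per_dataset[1:]:
    common &= feats                      # keep only keys present in every dataset
disabled = set().union(*features_per_dataset) - common

item = {"action": 1.0, "observation.state": 2.0, "observation.images.wrist": 3.0}
item = {k: v for k, v in item.items() if k not in disabled}
# item == {"action": 1.0, "observation.state": 2.0}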

View File

@@ -14,8 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import textwrap
import warnings
from itertools import accumulate
from pathlib import Path
from pprint import pformat
@@ -212,8 +212,8 @@ class BackwardCompatibilityError(Exception):
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
If you encounter a problem, contact LeRobot maintainers on Discord ('https://discord.com/invite/s3KuuzsPFb')
or open an issue on GitHub.
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
""")
super().__init__(message)
@@ -226,12 +226,11 @@ def check_version_compatibility(
if major_to_check < current_major and enforce_breaking_major:
raise BackwardCompatibilityError(repo_id, version_to_check)
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
warnings.warn(
logging.warning(
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
codebase. The current codebase version is {current_version}. You should be fine since
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
stacklevel=1,
)
@@ -245,13 +244,12 @@ def get_hub_safe_version(repo_id: str, version: str) -> str:
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
raise BackwardCompatibilityError(repo_id, version)
warnings.warn(
logging.warning(
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
codebase. The following versions are available: {branches}.
The requested version ('{version}') is not found. You should be fine since
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
stacklevel=1,
)
if "main" not in branches:
raise ValueError(f"Version 'main' not found on {repo_id}")

View File

@@ -15,6 +15,8 @@
# limitations under the License.
"""
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.
Note: Since the original Aloha datasets don't use shadow motors, you need to comment those out in
lerobot/configs/robot/aloha.yaml before running this script.
"""

View File

@@ -103,11 +103,11 @@ import argparse
import contextlib
import filecmp
import json
import logging
import math
import shutil
import subprocess
import tempfile
import warnings
from pathlib import Path
import datasets
@@ -461,9 +461,8 @@ def convert_dataset(
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
if single_task and "language_instruction" in dataset.column_names:
warnings.warn(
logging.warning(
"'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.",
stacklevel=1,
)
single_task = None
tasks_col = "language_instruction"
@@ -642,7 +641,7 @@ def main():
parser.add_argument(
"--license",
type=str,
default="mit",
default="apache-2.0",
help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.",
)
parser.add_argument(

View File

@@ -17,7 +17,7 @@ from lerobot.common.datasets.utils import (
get_hf_features_from_features,
hf_transform_to_torch,
)
from tests.fixtures.defaults import (
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
DUMMY_MOTOR_FEATURES,

View File

@@ -5,7 +5,7 @@ import pytest
from huggingface_hub.utils import filter_repo_objects
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
from tests.fixtures.defaults import LEROBOT_TEST_DIR
from tests.fixtures.constants import LEROBOT_TEST_DIR
@pytest.fixture(scope="session")

View File

@@ -44,7 +44,7 @@ from lerobot.common.datasets.utils import (
unflatten_dict,
)
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.fixtures.defaults import DUMMY_REPO_ID
from tests.fixtures.constants import DUMMY_REPO_ID
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, make_robot

View File

@@ -9,7 +9,7 @@ from lerobot.common.datasets.utils import (
get_delta_indices,
hf_transform_to_torch,
)
from tests.fixtures.defaults import DUMMY_MOTOR_FEATURES
from tests.fixtures.constants import DUMMY_MOTOR_FEATURES
@pytest.fixture(scope="module")

View File

@@ -21,7 +21,7 @@ from pathlib import Path
import pytest
from tests.fixtures.defaults import DUMMY_REPO_ID
from tests.fixtures.constants import DUMMY_REPO_ID
from tests.utils import require_package