Remove dataset `consolidate` (#752)

This commit is contained in:
Simon Alibert 2025-02-19 16:02:54 +01:00 committed by GitHub
parent 6fe42a72db
commit 969ef745a2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 93 additions and 128 deletions

View File

@ -200,8 +200,6 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
dataset.save_episode()
dataset.consolidate()
if push_to_hub:
dataset.push_to_hub()
hub_api = HfApi()

View File

@ -39,7 +39,6 @@ from lerobot.common.datasets.utils import (
append_jsonlines,
backward_compatible_episodes_stats,
check_delta_timestamps,
check_frame_features,
check_timestamps_sync,
check_version_compatibility,
create_empty_dataset_info,
@ -55,6 +54,8 @@ from lerobot.common.datasets.utils import (
load_info,
load_stats,
load_tasks,
validate_episode_buffer,
validate_frame,
write_episode,
write_episode_stats,
write_info,
@ -256,6 +257,9 @@ class LeRobotDatasetMetadata:
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
self.info["total_videos"] += len(self.video_keys)
if len(self.video_keys) > 0:
self.update_video_info()
write_info(self.info, self.root)
episode_dict = {
@ -270,7 +274,7 @@ class LeRobotDatasetMetadata:
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
write_episode_stats(episode_index, episode_stats, self.root)
def write_video_info(self) -> None:
def update_video_info(self) -> None:
"""
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
@ -280,8 +284,6 @@ class LeRobotDatasetMetadata:
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
self.info["features"][key]["info"] = get_video_info(video_path)
write_json(self.info, self.root / INFO_PATH)
def __repr__(self):
feature_keys = list(self.features)
return (
@ -506,9 +508,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
# Available stats implies all videos have been encoded and dataset is iterable
self.consolidated = self.meta.stats is not None
def push_to_hub(
self,
branch: str | None = None,
@ -519,13 +518,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
allow_patterns: list[str] | str | None = None,
**card_kwargs,
) -> None:
if not self.consolidated:
logging.warning(
"You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. "
"Consolidating first."
)
self.consolidate()
ignore_patterns = ["images/"]
if not push_videos:
ignore_patterns.append("videos/")
@ -779,7 +771,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if isinstance(frame[name], torch.Tensor):
frame[name] = frame[name].numpy()
check_frame_features(frame, self.features)
validate_frame(frame, self.features)
if self.episode_buffer is None:
self.episode_buffer = self.create_episode_buffer()
@ -815,41 +807,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer["size"] += 1
def save_episode(self, encode_videos: bool = True, episode_data: dict | None = None) -> None:
def save_episode(self, episode_data: dict | None = None) -> None:
"""
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
the hub.
This will save to disk the current episode in self.episode_buffer.
Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
time for video encoding.
Args:
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
None.
"""
if not episode_data:
episode_buffer = self.episode_buffer
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
# size and task are special cases that won't be added to hf_dataset
episode_length = episode_buffer.pop("size")
tasks = episode_buffer.pop("task")
episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"]
if episode_index != self.meta.total_episodes:
# TODO(aliberts): Add option to use existing episode_index
raise NotImplementedError(
"You might have manually provided the episode_buffer with an episode_index that doesn't "
"match the total number of episodes already in the dataset. This is not supported for now."
)
if episode_length == 0:
raise ValueError(
"You must add one or several frames with `add_frame` before calling `add_episode`."
)
if not set(episode_buffer.keys()) == set(self.features):
raise ValueError(
f"Features from `episode_buffer` don't match the ones in `self.features`: '{set(episode_buffer.keys())}' vs '{set(self.features)}'"
)
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
@ -875,16 +851,29 @@ class LeRobotDataset(torch.utils.data.Dataset):
ep_stats = compute_episode_stats(episode_buffer, self.features)
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
if encode_videos and len(self.meta.video_keys) > 0:
if len(self.meta.video_keys) > 0:
video_paths = self.encode_episode_videos(episode_index)
for key in self.meta.video_keys:
episode_buffer[key] = video_paths[key]
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
video_files = list(self.root.rglob("*.mp4"))
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
parquet_files = list(self.root.rglob("*.parquet"))
assert len(parquet_files) == self.num_episodes
# delete images
img_dir = self.root / "images"
if img_dir.is_dir():
shutil.rmtree(self.root / "images")
if not episode_data: # Reset the buffer
self.episode_buffer = self.create_episode_buffer()
self.consolidated = False
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
@ -959,28 +948,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
return video_paths
def consolidate(self, keep_image_files: bool = False) -> None:
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
if len(self.meta.video_keys) > 0:
self.encode_videos()
self.meta.write_video_info()
if not keep_image_files:
img_dir = self.root / "images"
if img_dir.is_dir():
shutil.rmtree(self.root / "images")
video_files = list(self.root.rglob("*.mp4"))
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
parquet_files = list(self.root.rglob("*.parquet"))
assert len(parquet_files) == self.num_episodes
self.consolidated = True
@classmethod
def create(
cls,
@ -1019,12 +986,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj.create_episode_buffer()
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
# is used to know when certain operations are needed (for instance, computing dataset statistics). In
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling
# self.consolidate().
obj.consolidated = True
obj.episodes = None
obj.hf_dataset = None
obj.image_transforms = None

View File

@ -644,25 +644,25 @@ class IterableNamespace(SimpleNamespace):
return vars(self).keys()
def check_frame_features(frame: dict, features: dict):
def validate_frame(frame: dict, features: dict):
optional_features = {"timestamp"}
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
actual_features = set(frame.keys())
error_message = check_features_presence(actual_features, expected_features, optional_features)
error_message = validate_features_presence(actual_features, expected_features, optional_features)
if "task" in frame:
error_message += check_feature_string("task", frame["task"])
error_message += validate_feature_string("task", frame["task"])
common_features = actual_features & (expected_features | optional_features)
for name in common_features - {"task"}:
error_message += check_feature_dtype_and_shape(name, features[name], frame[name])
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
if error_message:
raise ValueError(error_message)
def check_features_presence(
def validate_features_presence(
actual_features: set[str], expected_features: set[str], optional_features: set[str]
):
error_message = ""
@ -679,20 +679,22 @@ def check_features_presence(
return error_message
def check_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
expected_dtype = feature["dtype"]
expected_shape = feature["shape"]
if is_valid_numpy_dtype_string(expected_dtype):
return check_feature_numpy_array(name, expected_dtype, expected_shape, value)
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
elif expected_dtype in ["image", "video"]:
return check_feature_image_or_video(name, expected_shape, value)
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "string":
return check_feature_string(name, value)
return validate_feature_string(name, value)
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray):
def validate_feature_numpy_array(
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
):
error_message = ""
if isinstance(value, np.ndarray):
actual_dtype = value.dtype
@ -709,7 +711,7 @@ def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: li
return error_message
def check_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
@ -725,7 +727,33 @@ def check_feature_image_or_video(name: str, expected_shape: list[str], value: np
return error_message
def check_feature_string(name: str, value: str):
def validate_feature_string(name: str, value: str):
    """Return an error message if `value` is not a `str`, otherwise an empty string.

    Args:
        name: Feature name, used only to build the error message.
        value: Value to check.

    Returns:
        An empty string when `value` is a `str`; otherwise a one-line,
        newline-terminated description of the type mismatch.
    """
    if isinstance(value, str):
        return ""

    actual_type = type(value)
    return f"The feature '{name}' is expected to be of type 'str', but type '{actual_type}' provided instead.\n"
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
    """Check that an episode buffer is consistent before it is saved.

    Args:
        episode_buffer: Buffer to validate. Its "size", "task" and
            "episode_index" entries are read directly; every other key is
            compared against `features`.
        total_episodes: Value the buffer's "episode_index" must equal.
        features: Feature specification; only its keys are used here.

    Raises:
        ValueError: If "size" or "task" is missing from the buffer, if "size"
            is 0, or if the buffer keys (ignoring "task" and "size") are not
            exactly the keys of `features`.
        NotImplementedError: If "episode_index" differs from `total_episodes`.
    """
    if "size" not in episode_buffer:
        raise ValueError("size key not found in episode_buffer")

    if "task" not in episode_buffer:
        raise ValueError("task key not found in episode_buffer")

    if episode_buffer["episode_index"] != total_episodes:
        # TODO(aliberts): Add option to use existing episode_index
        raise NotImplementedError(
            "You might have manually provided the episode_buffer with an episode_index that doesn't "
            "match the total number of episodes already in the dataset. This is not supported for now."
        )

    if episode_buffer["size"] == 0:
        raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")

    # "task" and "size" are bookkeeping entries, not features, so they are left out of the comparison.
    buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
    feature_keys = set(features)
    if buffer_keys != feature_keys:
        # One statement per line: adjacent string literals are concatenated with no
        # separator, so each part must carry its own newline to stay readable.
        raise ValueError(
            "Features from `episode_buffer` don't match the ones in `features`.\n"
            f"In episode_buffer not in features: {buffer_keys - feature_keys}\n"
            f"In features not in episode_buffer: {feature_keys - buffer_keys}"
        )

View File

@ -299,8 +299,6 @@ def record(
log_say("Stop recording", cfg.play_sounds, blocking=True)
stop_recording(robot, listener, cfg.display_cameras)
dataset.consolidate()
if cfg.push_to_hub:
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)

View File

@ -1,5 +1,7 @@
import random
from functools import partial
from pathlib import Path
from typing import Protocol
from unittest.mock import patch
import datasets
@ -17,7 +19,6 @@ from lerobot.common.datasets.utils import (
get_hf_features_from_features,
hf_transform_to_torch,
)
from lerobot.common.robot_devices.robots.utils import Robot
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
@ -28,6 +29,10 @@ from tests.fixtures.constants import (
)
class LeRobotDatasetFactory(Protocol):
    # Structural type for dataset-building callables: any callable that
    # accepts arbitrary positional/keyword arguments and returns a LeRobotDataset.
    def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
def get_task_index(task_dicts: dict, task: str) -> int:
tasks = {d["task_index"]: d["task"] for d in task_dicts.values()}
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
@ -358,7 +363,7 @@ def lerobot_dataset_factory(
hf_dataset_factory,
mock_snapshot_download_factory,
lerobot_dataset_metadata_factory,
):
) -> LeRobotDatasetFactory:
def _create_lerobot_dataset(
root: Path,
repo_id: str = DUMMY_REPO_ID,
@ -430,17 +435,5 @@ def lerobot_dataset_factory(
@pytest.fixture(scope="session")
def empty_lerobot_dataset_factory():
def _create_empty_lerobot_dataset(
root: Path,
repo_id: str = DUMMY_REPO_ID,
fps: int = DEFAULT_FPS,
robot: Robot | None = None,
robot_type: str | None = None,
features: dict | None = None,
) -> LeRobotDataset:
return LeRobotDataset.create(
repo_id=repo_id, fps=fps, root=root, robot=robot, robot_type=robot_type, features=features
)
return _create_empty_lerobot_dataset
def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
    """Return `LeRobotDataset.create` with the dummy repo id and default fps pre-bound.

    The remaining arguments of `LeRobotDataset.create` are supplied by the
    caller of the returned callable; `repo_id` and `fps` can still be
    overridden by keyword.
    """
    create_with_defaults = partial(
        LeRobotDataset.create,
        repo_id=DUMMY_REPO_ID,
        fps=DEFAULT_FPS,
    )
    return create_with_defaults

View File

@ -184,8 +184,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert len(dataset) == 1
assert dataset[0]["task"] == "Dummy task"
@ -197,8 +196,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2])
@ -207,8 +205,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4])
@ -217,8 +214,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
@ -227,8 +223,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
@ -237,8 +232,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
@ -247,8 +241,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["state"].ndim == 0
@ -257,8 +250,7 @@ def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["caption"] == "Dummy caption"
@ -287,14 +279,13 @@ def test_add_frame_image_wrong_range(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"})
with pytest.raises(FileNotFoundError):
dataset.save_episode(encode_videos=False)
dataset.save_episode()
def test_add_frame_image(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@ -302,8 +293,7 @@ def test_add_frame_image(image_dataset):
def test_add_frame_image_h_w_c(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@ -312,8 +302,7 @@ def test_add_frame_image_uint8(image_dataset):
dataset = image_dataset
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": image, "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@ -322,8 +311,7 @@ def test_add_frame_image_pil(image_dataset):
dataset = image_dataset
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@ -338,7 +326,6 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
# - [ ] test various attributes & state from init and create
# - [ ] test init with episodes and check num_frames
# - [ ] test add_episode
# - [ ] test consolidate
# - [ ] test push_to_hub
# - [ ] test smaller methods