diff --git a/examples/port_datasets/pusht_zarr.py b/examples/port_datasets/pusht_zarr.py index 4ff361cb..53855ec8 100644 --- a/examples/port_datasets/pusht_zarr.py +++ b/examples/port_datasets/pusht_zarr.py @@ -200,8 +200,6 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T dataset.save_episode() - dataset.consolidate() - if push_to_hub: dataset.push_to_hub() hub_api = HfApi() diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index d4224b7e..81280e68 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -39,7 +39,6 @@ from lerobot.common.datasets.utils import ( append_jsonlines, backward_compatible_episodes_stats, check_delta_timestamps, - check_frame_features, check_timestamps_sync, check_version_compatibility, create_empty_dataset_info, @@ -55,6 +54,8 @@ from lerobot.common.datasets.utils import ( load_info, load_stats, load_tasks, + validate_episode_buffer, + validate_frame, write_episode, write_episode_stats, write_info, @@ -256,6 +257,9 @@ class LeRobotDatasetMetadata: self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} self.info["total_videos"] += len(self.video_keys) + if len(self.video_keys) > 0: + self.update_video_info() + write_info(self.info, self.root) episode_dict = { @@ -270,7 +274,7 @@ class LeRobotDatasetMetadata: self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats write_episode_stats(episode_index, episode_stats, self.root) - def write_video_info(self) -> None: + def update_video_info(self) -> None: """ Warning: this function writes info from first episode videos, implicitly assuming that all videos have been encoded the same way. Also, this means it assumes the first episode exists. @@ -280,8 +284,6 @@ class LeRobotDatasetMetadata: video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key) self.info["features"][key]["info"] = get_video_info(video_path) - write_json(self.info, self.root / INFO_PATH) - def __repr__(self): feature_keys = list(self.features) return ( @@ -506,9 +508,6 @@ class LeRobotDataset(torch.utils.data.Dataset): check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) - # Available stats implies all videos have been encoded and dataset is iterable - self.consolidated = self.meta.stats is not None - def push_to_hub( self, branch: str | None = None, @@ -519,13 +518,6 @@ class LeRobotDataset(torch.utils.data.Dataset): allow_patterns: list[str] | str | None = None, **card_kwargs, ) -> None: - if not self.consolidated: - logging.warning( - "You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. " - "Consolidating first." - ) - self.consolidate() - ignore_patterns = ["images/"] if not push_videos: ignore_patterns.append("videos/") @@ -779,7 +771,7 @@ class LeRobotDataset(torch.utils.data.Dataset): if isinstance(frame[name], torch.Tensor): frame[name] = frame[name].numpy() - check_frame_features(frame, self.features) + validate_frame(frame, self.features) if self.episode_buffer is None: self.episode_buffer = self.create_episode_buffer() @@ -815,41 +807,25 @@ class LeRobotDataset(torch.utils.data.Dataset): self.episode_buffer["size"] += 1 - def save_episode(self, encode_videos: bool = True, episode_data: dict | None = None) -> None: + def save_episode(self, episode_data: dict | None = None) -> None: """ - This will save to disk the current episode in self.episode_buffer. Note that since it affects files on - disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to - the hub. + This will save to disk the current episode in self.episode_buffer. - Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise, - you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend - time for video encoding. + Args: + episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will + save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to + None. """ if not episode_data: episode_buffer = self.episode_buffer + validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features) + # size and task are special cases that won't be added to hf_dataset episode_length = episode_buffer.pop("size") tasks = episode_buffer.pop("task") episode_tasks = list(set(tasks)) - episode_index = episode_buffer["episode_index"] - if episode_index != self.meta.total_episodes: - # TODO(aliberts): Add option to use existing episode_index - raise NotImplementedError( - "You might have manually provided the episode_buffer with an episode_index that doesn't " - "match the total number of episodes already in the dataset. This is not supported for now." - ) - - if episode_length == 0: - raise ValueError( - "You must add one or several frames with `add_frame` before calling `add_episode`." - ) - - if not set(episode_buffer.keys()) == set(self.features): - raise ValueError( - f"Features from `episode_buffer` don't match the ones in `self.features`: '{set(episode_buffer.keys())}' vs '{set(self.features)}'" - ) episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length) episode_buffer["episode_index"] = np.full((episode_length,), episode_index) @@ -875,16 +851,29 @@ class LeRobotDataset(torch.utils.data.Dataset): ep_stats = compute_episode_stats(episode_buffer, self.features) self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats) - if encode_videos and len(self.meta.video_keys) > 0: + if len(self.meta.video_keys) > 0: video_paths = self.encode_episode_videos(episode_index) for key in self.meta.video_keys: episode_buffer[key] = video_paths[key] + self.hf_dataset = self.load_hf_dataset() + self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) + check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s) + + video_files = list(self.root.rglob("*.mp4")) + assert len(video_files) == self.num_episodes * len(self.meta.video_keys) + + parquet_files = list(self.root.rglob("*.parquet")) + assert len(parquet_files) == self.num_episodes + + # delete images + img_dir = self.root / "images" + if img_dir.is_dir(): + shutil.rmtree(self.root / "images") + if not episode_data: # Reset the buffer self.episode_buffer = self.create_episode_buffer() - self.consolidated = False - def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None: episode_dict = {key: episode_buffer[key] for key in self.hf_features} ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train") @@ -959,28 +948,6 @@ class LeRobotDataset(torch.utils.data.Dataset): return video_paths - def consolidate(self, keep_image_files: bool = False) -> None: - self.hf_dataset = self.load_hf_dataset() - self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) - check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s) - - if len(self.meta.video_keys) > 0: - self.encode_videos() - self.meta.write_video_info() - - if not keep_image_files: - img_dir = self.root / "images" - if img_dir.is_dir(): - shutil.rmtree(self.root / "images") - - video_files = list(self.root.rglob("*.mp4")) - assert len(video_files) == self.num_episodes * len(self.meta.video_keys) - - parquet_files = list(self.root.rglob("*.parquet")) - assert len(parquet_files) == self.num_episodes - - self.consolidated = True - @classmethod def create( cls, @@ -1019,12 +986,6 @@ class LeRobotDataset(torch.utils.data.Dataset): # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer obj.episode_buffer = obj.create_episode_buffer() - # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It - # is used to know when certain operations are need (for instance, computing dataset statistics). In - # order to be able to push the dataset to the hub, it needs to be consolidated first by calling - # self.consolidate(). - obj.consolidated = True - obj.episodes = None obj.hf_dataset = None obj.image_transforms = None diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index c9b0c345..6125480c 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -644,25 +644,25 @@ class IterableNamespace(SimpleNamespace): return vars(self).keys() -def check_frame_features(frame: dict, features: dict): +def validate_frame(frame: dict, features: dict): optional_features = {"timestamp"} expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"} actual_features = set(frame.keys()) - error_message = check_features_presence(actual_features, expected_features, optional_features) + error_message = validate_features_presence(actual_features, expected_features, optional_features) if "task" in frame: - error_message += check_feature_string("task", frame["task"]) + error_message += validate_feature_string("task", frame["task"]) common_features = actual_features & (expected_features | optional_features) for name in common_features - {"task"}: - error_message += check_feature_dtype_and_shape(name, features[name], frame[name]) + error_message += validate_feature_dtype_and_shape(name, features[name], frame[name]) if error_message: raise ValueError(error_message) -def check_features_presence( +def validate_features_presence( actual_features: set[str], expected_features: set[str], optional_features: set[str] ): error_message = "" @@ -679,20 +679,22 @@ def check_features_presence( return error_message -def check_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str): +def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str): expected_dtype = feature["dtype"] expected_shape = feature["shape"] if is_valid_numpy_dtype_string(expected_dtype): - return check_feature_numpy_array(name, expected_dtype, expected_shape, value) + return validate_feature_numpy_array(name, expected_dtype, expected_shape, value) elif expected_dtype in ["image", "video"]: - return check_feature_image_or_video(name, expected_shape, value) + return validate_feature_image_or_video(name, expected_shape, value) elif expected_dtype == "string": - return check_feature_string(name, value) + return validate_feature_string(name, value) else: raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.") -def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray): +def validate_feature_numpy_array( + name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray +): error_message = "" if isinstance(value, np.ndarray): actual_dtype = value.dtype @@ -709,7 +711,7 @@ def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: li return error_message -def check_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image): +def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image): # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. error_message = "" if isinstance(value, np.ndarray): @@ -725,7 +727,33 @@ def check_feature_image_or_video(name: str, expected_shape: list[str], value: np return error_message -def check_feature_string(name: str, value: str): +def validate_feature_string(name: str, value: str): if not isinstance(value, str): return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" return "" + + +def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict): + if "size" not in episode_buffer: + raise ValueError("size key not found in episode_buffer") + + if "task" not in episode_buffer: + raise ValueError("task key not found in episode_buffer") + + if episode_buffer["episode_index"] != total_episodes: + # TODO(aliberts): Add option to use existing episode_index + raise NotImplementedError( + "You might have manually provided the episode_buffer with an episode_index that doesn't " + "match the total number of episodes already in the dataset. This is not supported for now." + ) + + if episode_buffer["size"] == 0: + raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.") + + buffer_keys = set(episode_buffer.keys()) - {"task", "size"} + if not buffer_keys == set(features): + raise ValueError( + f"Features from `episode_buffer` don't match the ones in `features`." + f"In episode_buffer not in features: {buffer_keys - set(features)}" + f"In features not in episode_buffer: {set(features) - buffer_keys}" + ) diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index dee2792d..e6103271 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -299,8 +299,6 @@ def record( log_say("Stop recording", cfg.play_sounds, blocking=True) stop_recording(robot, listener, cfg.display_cameras) - dataset.consolidate() - if cfg.push_to_hub: dataset.push_to_hub(tags=cfg.tags, private=cfg.private) diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index 811e29b7..e3604591 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -1,5 +1,7 @@ import random +from functools import partial from pathlib import Path +from typing import Protocol from unittest.mock import patch import datasets @@ -17,7 +19,6 @@ from lerobot.common.datasets.utils import ( get_hf_features_from_features, hf_transform_to_torch, ) -from lerobot.common.robot_devices.robots.utils import Robot from tests.fixtures.constants import ( DEFAULT_FPS, DUMMY_CAMERA_FEATURES, @@ -28,6 +29,10 @@ from tests.fixtures.constants import ( ) +class LeRobotDatasetFactory(Protocol): + def __call__(self, *args, **kwargs) -> LeRobotDataset: ... + + def get_task_index(task_dicts: dict, task: str) -> int: tasks = {d["task_index"]: d["task"] for d in task_dicts.values()} task_to_task_index = {task: task_idx for task_idx, task in tasks.items()} @@ -358,7 +363,7 @@ def lerobot_dataset_factory( hf_dataset_factory, mock_snapshot_download_factory, lerobot_dataset_metadata_factory, -): +) -> LeRobotDatasetFactory: def _create_lerobot_dataset( root: Path, repo_id: str = DUMMY_REPO_ID, @@ -430,17 +435,5 @@ def lerobot_dataset_factory( @pytest.fixture(scope="session") -def empty_lerobot_dataset_factory(): - def _create_empty_lerobot_dataset( - root: Path, - repo_id: str = DUMMY_REPO_ID, - fps: int = DEFAULT_FPS, - robot: Robot | None = None, - robot_type: str | None = None, - features: dict | None = None, - ) -> LeRobotDataset: - return LeRobotDataset.create( - repo_id=repo_id, fps=fps, root=root, robot=robot, robot_type=robot_type, features=features - ) - - return _create_empty_lerobot_dataset +def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory: + return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 6d358eea..61b68aa8 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -184,8 +184,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"}) - dataset.save_episode(encode_videos=False) - dataset.consolidate() + dataset.save_episode() assert len(dataset) == 1 assert dataset[0]["task"] == "Dummy task" @@ -197,8 +196,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"}) - dataset.save_episode(encode_videos=False) - dataset.consolidate() + dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2]) @@ -207,8 +205,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"}) - dataset.save_episode(encode_videos=False) - dataset.consolidate() + dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2, 4]) @@ -217,8 +214,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"}) - dataset.save_episode(encode_videos=False) - dataset.consolidate() + dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2, 4, 3]) @@ -227,8 +223,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"}) - dataset.save_episode(encode_videos=False) - dataset.consolidate() + dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5]) @@ -237,8 +232,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"}) - dataset.save_episode(encode_videos=False) - dataset.consolidate() + dataset.save_episode() assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1]) @@ -247,8 +241,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory): features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"}) - dataset.save_episode(encode_videos=False) - dataset.consolidate() + dataset.save_episode() assert dataset[0]["state"].ndim == 0 @@ -257,8 +250,7 @@ def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory): features = {"caption": {"dtype": "string", "shape": (1,), "names": None}} dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"}) - dataset.save_episode(encode_videos=False) - dataset.consolidate() + dataset.save_episode() assert dataset[0]["caption"] == "Dummy caption" @@ -287,14 +279,13 @@ def test_add_frame_image_wrong_range(image_dataset): dataset = image_dataset dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"}) with pytest.raises(FileNotFoundError): - dataset.save_episode(encode_videos=False) + dataset.save_episode() def test_add_frame_image(image_dataset): dataset = image_dataset dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) - dataset.save_episode(encode_videos=False) - dataset.consolidate() + dataset.save_episode() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -302,8 +293,7 @@ def test_add_frame_image(image_dataset): def test_add_frame_image_h_w_c(image_dataset): dataset = image_dataset dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"}) - dataset.save_episode(encode_videos=False) - dataset.consolidate() + dataset.save_episode() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -312,8 +302,7 @@ def test_add_frame_image_uint8(image_dataset): dataset = image_dataset image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) dataset.add_frame({"image": image, "task": "Dummy task"}) - dataset.save_episode(encode_videos=False) - dataset.consolidate() + dataset.save_episode() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -322,8 +311,7 @@ def test_add_frame_image_pil(image_dataset): dataset = image_dataset image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"}) - dataset.save_episode(encode_videos=False) - dataset.consolidate() + dataset.save_episode() assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) @@ -338,7 +326,6 @@ def test_image_array_to_pil_image_wrong_range_float_0_255(): # - [ ] test various attributes & state from init and create # - [ ] test init with episodes and check num_frames # - [ ] test add_episode -# - [ ] test consolidate # - [ ] test push_to_hub # - [ ] test smaller methods