diff --git a/examples/port_datasets/pusht_zarr.py b/examples/port_datasets/pusht_zarr.py
index 622fbd14..53855ec8 100644
--- a/examples/port_datasets/pusht_zarr.py
+++ b/examples/port_datasets/pusht_zarr.py
@@ -3,8 +3,10 @@ from pathlib import Path
 
 import numpy as np
 import torch
+from huggingface_hub import HfApi
 
-from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
+from lerobot.common.constants import HF_LEROBOT_HOME
+from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
 from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
 
 PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
@@ -134,8 +136,8 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
     if mode not in ["video", "image", "keypoints"]:
         raise ValueError(mode)
 
-    if (LEROBOT_HOME / repo_id).exists():
-        shutil.rmtree(LEROBOT_HOME / repo_id)
+    if (HF_LEROBOT_HOME / repo_id).exists():
+        shutil.rmtree(HF_LEROBOT_HOME / repo_id)
 
     if not raw_dir.exists():
         download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
@@ -198,10 +200,10 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
 
         dataset.save_episode()
 
-    dataset.consolidate()
-
     if push_to_hub:
         dataset.push_to_hub()
+        hub_api = HfApi()
+        hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
 
 
 if __name__ == "__main__":
diff --git a/lerobot/common/constants.py b/lerobot/common/constants.py
index 34da4ac0..d0c9845a 100644
--- a/lerobot/common/constants.py
+++ b/lerobot/common/constants.py
@@ -1,4 +1,9 @@
 # keys
+import os
+from pathlib import Path
+
+from huggingface_hub.constants import HF_HOME
+
 OBS_ENV = "observation.environment_state"
 OBS_ROBOT = "observation.state"
 OBS_IMAGE = "observation.image"
@@ -15,3 +20,13 @@ TRAINING_STEP = "training_step.json"
 OPTIMIZER_STATE = "optimizer_state.safetensors"
 OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
 SCHEDULER_STATE = "scheduler_state.json"
+
+# cache dir
+default_cache_path = Path(HF_HOME) / "lerobot"
+HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
+
+if "LEROBOT_HOME" in os.environ:
+    raise ValueError(
+        f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
+        "'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
+    )
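For readers unfamiliar with the new constant: `HF_LEROBOT_HOME` now follows the Hugging Face cache-directory convention instead of the old hard-coded `~/.cache/huggingface/lerobot` default. A minimal sketch of how the path resolves, assuming a default `HF_HOME` and no `HF_LEROBOT_HOME` override (the print is only illustrative):

import os
from pathlib import Path

from huggingface_hub.constants import HF_HOME  # typically ~/.cache/huggingface

# Same resolution logic as the block added to constants.py above.
default_cache_path = Path(HF_HOME) / "lerobot"
cache_dir = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
print(cache_dir)  # e.g. ~/.cache/huggingface/lerobot; datasets are cached under <cache_dir>/<repo_id>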
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index dfdb3618..81280e68 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import logging
-import os
 import shutil
 from pathlib import Path
 from typing import Callable
@@ -29,6 +28,7 @@ from huggingface_hub import HfApi, snapshot_download
 from huggingface_hub.constants import REPOCARD_NAME
 from packaging import version
 
+from lerobot.common.constants import HF_LEROBOT_HOME
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
 from lerobot.common.datasets.utils import (
@@ -39,7 +39,6 @@ from lerobot.common.datasets.utils import (
     append_jsonlines,
     backward_compatible_episodes_stats,
     check_delta_timestamps,
-    check_frame_features,
     check_timestamps_sync,
     check_version_compatibility,
     create_empty_dataset_info,
@@ -55,6 +54,8 @@ from lerobot.common.datasets.utils import (
     load_info,
     load_stats,
     load_tasks,
+    validate_episode_buffer,
+    validate_frame,
     write_episode,
     write_episode_stats,
     write_info,
@@ -71,7 +72,6 @@ from lerobot.common.robot_devices.robots.utils import Robot
 
 # For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
 CODEBASE_VERSION = "v2.1"
-LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
 
 
 class LeRobotDatasetMetadata:
@@ -84,7 +84,7 @@ class LeRobotDatasetMetadata:
     ):
         self.repo_id = repo_id
         self.revision = revision if revision else CODEBASE_VERSION
-        self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
+        self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
 
         try:
             if force_cache_sync:
@@ -257,6 +257,9 @@ class LeRobotDatasetMetadata:
         self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
         self.info["total_videos"] += len(self.video_keys)
 
+        if len(self.video_keys) > 0:
+            self.update_video_info()
+
         write_info(self.info, self.root)
 
         episode_dict = {
@@ -271,7 +274,7 @@ class LeRobotDatasetMetadata:
         self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
         write_episode_stats(episode_index, episode_stats, self.root)
 
-    def write_video_info(self) -> None:
+    def update_video_info(self) -> None:
         """
         Warning: this function writes info from first episode videos, implicitly assuming that all videos have
         been encoded the same way. Also, this means it assumes the first episode exists.
@@ -281,8 +284,6 @@ class LeRobotDatasetMetadata:
             video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
             self.info["features"][key]["info"] = get_video_info(video_path)
 
-        write_json(self.info, self.root / INFO_PATH)
-
     def __repr__(self):
         feature_keys = list(self.features)
         return (
@@ -308,7 +309,7 @@ class LeRobotDatasetMetadata:
         """Creates metadata for a LeRobotDataset."""
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
-        obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
+        obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
 
         obj.root.mkdir(parents=True, exist_ok=False)
 
@@ -463,7 +464,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """
         super().__init__()
         self.repo_id = repo_id
-        self.root = Path(root) if root else LEROBOT_HOME / repo_id
+        self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
         self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
         self.episodes = episodes
@@ -507,9 +508,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
             check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
             self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
 
-        # Available stats implies all videos have been encoded and dataset is iterable
-        self.consolidated = self.meta.stats is not None
-
     def push_to_hub(
         self,
         branch: str | None = None,
@@ -520,13 +518,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         allow_patterns: list[str] | str | None = None,
         **card_kwargs,
     ) -> None:
-        if not self.consolidated:
-            logging.warning(
-                "You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. "
-                "Consolidating first."
-            )
-            self.consolidate()
-
         ignore_patterns = ["images/"]
         if not push_videos:
             ignore_patterns.append("videos/")
@@ -780,7 +771,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             if isinstance(frame[name], torch.Tensor):
                 frame[name] = frame[name].numpy()
 
-        check_frame_features(frame, self.features)
+        validate_frame(frame, self.features)
 
         if self.episode_buffer is None:
             self.episode_buffer = self.create_episode_buffer()
@@ -816,41 +807,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         self.episode_buffer["size"] += 1
 
-    def save_episode(self, encode_videos: bool = True, episode_data: dict | None = None) -> None:
+    def save_episode(self, episode_data: dict | None = None) -> None:
         """
-        This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
-        disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
-        the hub.
+        This will save to disk the current episode in self.episode_buffer.
 
-        Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
-        you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
-        time for video encoding.
+        Args:
+            episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
+                save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
+                None.
         """
         if not episode_data:
             episode_buffer = self.episode_buffer
 
+        validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
+
         # size and task are special cases that won't be added to hf_dataset
         episode_length = episode_buffer.pop("size")
         tasks = episode_buffer.pop("task")
         episode_tasks = list(set(tasks))
         episode_index = episode_buffer["episode_index"]
 
-        if episode_index != self.meta.total_episodes:
-            # TODO(aliberts): Add option to use existing episode_index
-            raise NotImplementedError(
-                "You might have manually provided the episode_buffer with an episode_index that doesn't "
-                "match the total number of episodes already in the dataset. This is not supported for now."
-            )
-
-        if episode_length == 0:
-            raise ValueError(
-                "You must add one or several frames with `add_frame` before calling `add_episode`."
-            )
-
-        if not set(episode_buffer.keys()) == set(self.features):
-            raise ValueError(
-                f"Features from `episode_buffer` don't match the ones in `self.features`: '{set(episode_buffer.keys())}' vs '{set(self.features)}'"
-            )
-
         episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
         episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
@@ -876,16 +851,29 @@ class LeRobotDataset(torch.utils.data.Dataset):
         ep_stats = compute_episode_stats(episode_buffer, self.features)
         self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
 
-        if encode_videos and len(self.meta.video_keys) > 0:
+        if len(self.meta.video_keys) > 0:
             video_paths = self.encode_episode_videos(episode_index)
             for key in self.meta.video_keys:
                 episode_buffer[key] = video_paths[key]
 
+        self.hf_dataset = self.load_hf_dataset()
+        self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
+        check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
+
+        video_files = list(self.root.rglob("*.mp4"))
+        assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
+
+        parquet_files = list(self.root.rglob("*.parquet"))
+        assert len(parquet_files) == self.num_episodes
+
+        # delete images
+        img_dir = self.root / "images"
+        if img_dir.is_dir():
+            shutil.rmtree(self.root / "images")
+
         if not episode_data:  # Reset the buffer
             self.episode_buffer = self.create_episode_buffer()
 
-        self.consolidated = False
-
     def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
         episode_dict = {key: episode_buffer[key] for key in self.hf_features}
         ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
@@ -960,28 +948,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         return video_paths
 
-    def consolidate(self, keep_image_files: bool = False) -> None:
-        self.hf_dataset = self.load_hf_dataset()
-        self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
-        check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
-
-        if len(self.meta.video_keys) > 0:
-            self.encode_videos()
-            self.meta.write_video_info()
-
-        if not keep_image_files:
-            img_dir = self.root / "images"
-            if img_dir.is_dir():
-                shutil.rmtree(self.root / "images")
-
-        video_files = list(self.root.rglob("*.mp4"))
-        assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
-
-        parquet_files = list(self.root.rglob("*.parquet"))
-        assert len(parquet_files) == self.num_episodes
-
-        self.consolidated = True
-
     @classmethod
     def create(
         cls,
@@ -1020,12 +986,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
         obj.episode_buffer = obj.create_episode_buffer()
 
-        # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
-        # is used to know when certain operations are need (for instance, computing dataset statistics). In
-        # order to be able to push the dataset to the hub, it needs to be consolidated first by calling
-        # self.consolidate().
-        obj.consolidated = True
-
         obj.episodes = None
         obj.hf_dataset = None
         obj.image_transforms = None
@@ -1056,7 +1016,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
     ):
         super().__init__()
         self.repo_ids = repo_ids
-        self.root = Path(root) if root else LEROBOT_HOME
+        self.root = Path(root) if root else HF_LEROBOT_HOME
         self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
         # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
         # are handled by this class.
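With `consolidate()` gone, `save_episode()` now encodes videos, reloads the hf_dataset, checks timestamp sync and removes temporary images on its own, so a recording script reduces to the sketch below (the repo id, feature spec and fps are illustrative placeholders, not part of the patch):

import torch

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Create an empty dataset; "user/my_dataset" and the feature spec are made-up examples.
dataset = LeRobotDataset.create(
    repo_id="user/my_dataset",
    fps=30,
    features={"state": {"dtype": "float32", "shape": (2,), "names": None}},
)
for _ in range(10):
    # Each frame is checked by validate_frame(); keys must match the declared features plus "task".
    dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
dataset.save_episode()  # writes parquet, encodes videos if any, verifies timestamps, cleans images/
dataset.push_to_hub()   # no separate consolidation step is needed anymore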
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index c9b0c345..6125480c 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -644,25 +644,25 @@ class IterableNamespace(SimpleNamespace):
         return vars(self).keys()
 
 
-def check_frame_features(frame: dict, features: dict):
+def validate_frame(frame: dict, features: dict):
     optional_features = {"timestamp"}
     expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
     actual_features = set(frame.keys())
 
-    error_message = check_features_presence(actual_features, expected_features, optional_features)
+    error_message = validate_features_presence(actual_features, expected_features, optional_features)
 
     if "task" in frame:
-        error_message += check_feature_string("task", frame["task"])
+        error_message += validate_feature_string("task", frame["task"])
 
     common_features = actual_features & (expected_features | optional_features)
     for name in common_features - {"task"}:
-        error_message += check_feature_dtype_and_shape(name, features[name], frame[name])
+        error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
 
     if error_message:
         raise ValueError(error_message)
 
 
-def check_features_presence(
+def validate_features_presence(
     actual_features: set[str], expected_features: set[str], optional_features: set[str]
 ):
     error_message = ""
@@ -679,20 +679,22 @@ def check_features_presence(
     return error_message
 
 
-def check_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
+def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
     expected_dtype = feature["dtype"]
     expected_shape = feature["shape"]
     if is_valid_numpy_dtype_string(expected_dtype):
-        return check_feature_numpy_array(name, expected_dtype, expected_shape, value)
+        return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
     elif expected_dtype in ["image", "video"]:
-        return check_feature_image_or_video(name, expected_shape, value)
+        return validate_feature_image_or_video(name, expected_shape, value)
     elif expected_dtype == "string":
-        return check_feature_string(name, value)
+        return validate_feature_string(name, value)
     else:
         raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
 
 
-def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray):
+def validate_feature_numpy_array(
+    name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
+):
     error_message = ""
     if isinstance(value, np.ndarray):
         actual_dtype = value.dtype
@@ -709,7 +711,7 @@ def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: li
     return error_message
 
 
-def check_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
+def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
     # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
     error_message = ""
     if isinstance(value, np.ndarray):
@@ -725,7 +727,33 @@ def check_feature_image_or_video(name: str, expected_shape: list[str], value: np
     return error_message
 
 
-def check_feature_string(name: str, value: str):
+def validate_feature_string(name: str, value: str):
     if not isinstance(value, str):
         return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
     return ""
+
+
+def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
+    if "size" not in episode_buffer:
+        raise ValueError("size key not found in episode_buffer")
+
+    if "task" not in episode_buffer:
+        raise ValueError("task key not found in episode_buffer")
+
+    if episode_buffer["episode_index"] != total_episodes:
+        # TODO(aliberts): Add option to use existing episode_index
+        raise NotImplementedError(
+            "You might have manually provided the episode_buffer with an episode_index that doesn't "
+            "match the total number of episodes already in the dataset. This is not supported for now."
+        )
+
+    if episode_buffer["size"] == 0:
+        raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
+
+    buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
+    if not buffer_keys == set(features):
+        raise ValueError(
+            f"Features from `episode_buffer` don't match the ones in `features`.\n"
+            f"In episode_buffer not in features: {buffer_keys - set(features)}\n"
+            f"In features not in episode_buffer: {set(features) - buffer_keys}"
+        )
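As a reference point, the renamed validators accumulate every problem into one error message before raising. A small sketch of the kind of mismatch `validate_frame()` reports (the feature spec and values below are made up for illustration):

import numpy as np

from lerobot.common.datasets.utils import validate_frame

features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
frame = {"state": np.zeros(3, dtype=np.float32), "task": "Dummy task"}
validate_frame(frame, features)  # raises ValueError describing the shape mismatch ((3,) vs (2,))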
diff --git a/lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py
index 624827bd..cee9da16 100644
--- a/lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py
+++ b/lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py
@@ -29,8 +29,9 @@ LOCAL_DIR = Path("data/")
 
 def batch_convert():
     status = {}
+    LOCAL_DIR.mkdir(parents=True, exist_ok=True)
     logfile = LOCAL_DIR / "conversion_log_v21.txt"
-    for num, repo_id in available_datasets:
+    for num, repo_id in enumerate(available_datasets):
        print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
        print("---------------------------------------------------------")
        try:
diff --git a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
index d52a0a10..f55c13c1 100644
--- a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
+++ b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py
@@ -2,10 +2,10 @@
 This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
 2.1. It will:
 
-- Generates per-episodes stats and writes them in `episodes_stats.jsonl`
+- Generate per-episode stats and write them in `episodes_stats.jsonl`
 - Check consistency between these new stats and the old ones.
-- Removes the deprecated `stats.json` (by default)
-- Updates codebase_version in `info.json`
+- Remove the deprecated `stats.json`.
+- Update codebase_version in `info.json`.
 - Push this new version to the hub on the 'main' branch and tags it with "v2.1".
 
 Usage:
@@ -80,19 +80,20 @@ if __name__ == "__main__":
         "--repo-id",
         type=str,
         required=True,
-        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
+        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
+        "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
     )
     parser.add_argument(
         "--branch",
         type=str,
         default=None,
-        help="Repo branch to push your dataset (defaults to the main branch)",
+        help="Repo branch to push your dataset. Defaults to the main branch.",
     )
     parser.add_argument(
         "--num-workers",
         type=int,
         default=4,
-        help="Number of workers for parallelizing compute",
+        help="Number of workers for parallelizing stats compute. Defaults to 4.",
     )
 
     args = parser.parse_args()
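The script's docstring carries the canonical Usage section (elided above); based on the argparse options in this hunk, an invocation would look roughly like the following (the repo id is only an example):

python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
    --repo-id=lerobot/pusht \
    --num-workers=4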
diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py
index dee2792d..e6103271 100644
--- a/lerobot/scripts/control_robot.py
+++ b/lerobot/scripts/control_robot.py
@@ -299,8 +299,6 @@ def record(
     log_say("Stop recording", cfg.play_sounds, blocking=True)
     stop_recording(robot, listener, cfg.display_cameras)
 
-    dataset.consolidate()
-
     if cfg.push_to_hub:
         dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
 
diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py
index 7d80d2b7..3201dcf2 100644
--- a/tests/fixtures/constants.py
+++ b/tests/fixtures/constants.py
@@ -1,6 +1,6 @@
-from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
+from lerobot.common.constants import HF_LEROBOT_HOME
 
-LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
+LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing"
 DUMMY_REPO_ID = "dummy/repo"
 DUMMY_ROBOT_TYPE = "dummy_robot"
 DUMMY_MOTOR_FEATURES = {
diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py
index 811e29b7..e3604591 100644
--- a/tests/fixtures/dataset_factories.py
+++ b/tests/fixtures/dataset_factories.py
@@ -1,5 +1,7 @@
 import random
+from functools import partial
 from pathlib import Path
+from typing import Protocol
 from unittest.mock import patch
 
 import datasets
@@ -17,7 +19,6 @@ from lerobot.common.datasets.utils import (
     get_hf_features_from_features,
     hf_transform_to_torch,
 )
-from lerobot.common.robot_devices.robots.utils import Robot
 from tests.fixtures.constants import (
     DEFAULT_FPS,
     DUMMY_CAMERA_FEATURES,
@@ -28,6 +29,10 @@ from tests.fixtures.constants import (
 )
 
 
+class LeRobotDatasetFactory(Protocol):
+    def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
+
+
 def get_task_index(task_dicts: dict, task: str) -> int:
     tasks = {d["task_index"]: d["task"] for d in task_dicts.values()}
     task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
@@ -358,7 +363,7 @@ def lerobot_dataset_factory(
     hf_dataset_factory,
     mock_snapshot_download_factory,
     lerobot_dataset_metadata_factory,
-):
+) -> LeRobotDatasetFactory:
     def _create_lerobot_dataset(
         root: Path,
         repo_id: str = DUMMY_REPO_ID,
@@ -430,17 +435,5 @@ def lerobot_dataset_factory(
 
 
 @pytest.fixture(scope="session")
-def empty_lerobot_dataset_factory():
-    def _create_empty_lerobot_dataset(
-        root: Path,
-        repo_id: str = DUMMY_REPO_ID,
-        fps: int = DEFAULT_FPS,
-        robot: Robot | None = None,
-        robot_type: str | None = None,
-        features: dict | None = None,
-    ) -> LeRobotDataset:
-        return LeRobotDataset.create(
-            repo_id=repo_id, fps=fps, root=root, robot=robot, robot_type=robot_type, features=features
-        )
-
-    return _create_empty_lerobot_dataset
+def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
+    return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 3e8b531d..61b68aa8 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -184,8 +184,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert len(dataset) == 1
     assert dataset[0]["task"] == "Dummy task"
@@ -197,8 +196,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["state"].shape == torch.Size([2])
 
@@ -207,8 +205,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["state"].shape == torch.Size([2, 4])
 
@@ -217,8 +214,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
 
@@ -227,8 +223,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
 
@@ -237,8 +232,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
 
@@ -247,8 +241,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["state"].ndim == 0
 
@@ -257,8 +250,7 @@ def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
     features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["caption"] == "Dummy caption"
 
@@ -287,14 +279,13 @@ def test_add_frame_image_wrong_range(image_dataset):
     dataset = image_dataset
     dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"})
     with pytest.raises(FileNotFoundError):
-        dataset.save_episode(encode_videos=False)
+        dataset.save_episode()
 
 
 def test_add_frame_image(image_dataset):
     dataset = image_dataset
     dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
 
@@ -302,8 +293,7 @@ def test_add_frame_image(image_dataset):
 def test_add_frame_image_h_w_c(image_dataset):
     dataset = image_dataset
     dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
 
@@ -312,8 +302,7 @@ def test_add_frame_image_uint8(image_dataset):
     dataset = image_dataset
     image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
     dataset.add_frame({"image": image, "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
 
@@ -322,8 +311,7 @@ def test_add_frame_image_pil(image_dataset):
     dataset = image_dataset
     image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
     dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
 
@@ -338,7 +326,6 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
 # - [ ] test various attributes & state from init and create
 # - [ ] test init with episodes and check num_frames
 # - [ ] test add_episode
-# - [ ] test consolidate
 # - [ ] test push_to_hub
 # - [ ] test smaller methods
 
@@ -581,9 +568,9 @@ def test_create_branch():
 
 def test_dataset_feature_with_forward_slash_raises_error():
     # make sure dir does not exist
-    from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
+    from lerobot.common.constants import HF_LEROBOT_HOME
 
-    dataset_dir = LEROBOT_HOME / "lerobot/test/with/slash"
+    dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
     # make sure does not exist
     if dataset_dir.exists():
         dataset_dir.rmdir()