From 1f13bda25be74d1290cf82665f315961710cf549 Mon Sep 17 00:00:00 2001
From: Remi
Date: Tue, 19 Nov 2024 12:31:47 +0100
Subject: [PATCH] Improve dataset v2 (#498)

---
 examples/port_datasets/pusht_zarr.py          | 246 ++++++++++++++++++
 lerobot/common/datasets/lerobot_dataset.py    |  37 ++-
 lerobot/common/datasets/utils.py              |  48 +++-
 .../datasets/v2/convert_dataset_v1_to_v2.py   |   2 +-
 lerobot/common/robot_devices/control_utils.py |  19 ++
 lerobot/scripts/control_robot.py              |  43 ++-
 lerobot/scripts/push_dataset_to_hub.py        |   8 +-
 tests/fixtures/dataset_factories.py           |   2 +-
 tests/test_control_robot.py                   |  58 ++---
 9 files changed, 393 insertions(+), 70 deletions(-)
 create mode 100644 examples/port_datasets/pusht_zarr.py

diff --git a/examples/port_datasets/pusht_zarr.py b/examples/port_datasets/pusht_zarr.py
new file mode 100644
index 00000000..742d1346
--- /dev/null
+++ b/examples/port_datasets/pusht_zarr.py
@@ -0,0 +1,246 @@
+import shutil
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
+from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
+
+
+def create_empty_dataset(repo_id, mode):
+    features = {
+        "observation.state": {
+            "dtype": "float32",
+            "shape": (2,),
+            "names": [
+                ["x", "y"],
+            ],
+        },
+        "action": {
+            "dtype": "float32",
+            "shape": (2,),
+            "names": [
+                ["x", "y"],
+            ],
+        },
+        "next.reward": {
+            "dtype": "float32",
+            "shape": (1,),
+            "names": None,
+        },
+        "next.success": {
+            "dtype": "bool",
+            "shape": (1,),
+            "names": None,
+        },
+    }
+
+    if mode == "keypoints":
+        features["observation.environment_state"] = {
+            "dtype": "float32",
+            "shape": (16,),
+            "names": [
+                "keypoints",
+            ],
+        }
+    else:
+        features["observation.image"] = {
+            "dtype": mode,
+            "shape": (3, 96, 96),
+            "names": [
+                "channel",
+                "height",
+                "width",
+            ],
+        }
+
+    dataset = LeRobotDataset.create(
+        repo_id=repo_id,
+        fps=10,
+        robot_type="2d pointer",
+        features=features,
+        image_writer_threads=4,
+    )
+    return dataset
+
+
+def load_raw_dataset(zarr_path, load_images=True):
+    try:
+        from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
+            ReplayBuffer as DiffusionPolicyReplayBuffer,
+        )
+    except ModuleNotFoundError as e:
+        print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
+        raise e
+
+    zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
+
+    env_state = zarr_data["state"][:]
+    agent_pos = env_state[:, :2]
+    block_pos = env_state[:, 2:4]
+    block_angle = env_state[:, 4]
+
+    action = zarr_data["action"][:]
+
+    image = None
+    if load_images:
+        # b h w c
+        image = zarr_data["img"]
+
+    episode_data_index = {
+        "from": np.array([0] + zarr_data.meta["episode_ends"][:-1].tolist()),
+        "to": zarr_data.meta["episode_ends"],
+    }
+
+    return image, agent_pos, block_pos, block_angle, action, episode_data_index
+
+
+def calculate_coverage(block_pos, block_angle):
+    try:
+        import pymunk
+        from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
+    except ModuleNotFoundError as e:
+        print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
+        raise e
+
+    num_frames = len(block_pos)
+
+    coverage = np.zeros((num_frames,))
+    # 8 keypoints with 2 coords each
+    keypoints = np.zeros((num_frames, 16))
+
+    # Set x, y, theta (in radians)
+    goal_pos_angle = np.array([256, 256, np.pi / 4])
+    goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
+
+    for i in range(num_frames):
+        space = pymunk.Space()
+        space.gravity = 0, 0
+        space.damping = 0
+
+        # Add walls.
+        walls = [
+            PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
+            PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
+            PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
+            PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
+        ]
+        space.add(*walls)
+
+        block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
+        goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
+        block_geom = pymunk_to_shapely(block_body, block_body.shapes)
+        intersection_area = goal_geom.intersection(block_geom).area
+        goal_area = goal_geom.area
+        coverage[i] = intersection_area / goal_area
+        keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
+
+    return coverage, keypoints
+
+
+def calculate_success(coverage, success_threshold):
+    return coverage > success_threshold
+
+
+def calculate_reward(coverage, success_threshold):
+    return np.clip(coverage / success_threshold, 0, 1)
+
+
+def populate_dataset(dataset, episode_data_index, episodes, image, state, env_state, action, reward, success):
+    if episodes is None:
+        episodes = range(len(episode_data_index["from"]))
+
+    for ep_idx in episodes:
+        from_idx = episode_data_index["from"][ep_idx]
+        to_idx = episode_data_index["to"][ep_idx]
+        num_frames = to_idx - from_idx
+
+        for frame_idx in range(num_frames):
+            i = from_idx + frame_idx
+
+            frame = {
+                "action": torch.from_numpy(action[i]),
+                # Shift reward and success by +1 until the last item of the episode
+                "next.reward": reward[i + (frame_idx < num_frames - 1)],
+                "next.success": success[i + (frame_idx < num_frames - 1)],
+            }
+
+            frame["observation.state"] = torch.from_numpy(state[i])
+
+            if env_state is not None:
+                frame["observation.environment_state"] = torch.from_numpy(env_state[i])
+
+            if image is not None:
+                frame["observation.image"] = torch.from_numpy(image[i])
+
+            dataset.add_frame(frame)
+
+        dataset.save_episode(task="Push the T-shaped blue block onto the T-shaped green target surface.")
+
+    return dataset
+
+
+def port_pusht(raw_dir, repo_id, episodes=None, mode="video", push_to_hub=True):
+    if mode not in ["video", "image", "keypoints"]:
+        raise ValueError(mode)
+
+    if (LEROBOT_HOME / repo_id).exists():
+        shutil.rmtree(LEROBOT_HOME / repo_id)
+
+    raw_dir = Path(raw_dir)
+    if not raw_dir.exists():
+        download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
+
+    image, agent_pos, block_pos, block_angle, action, episode_data_index = load_raw_dataset(
+        zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr"
+    )
+
+    # Calculate success and reward based on the overlapping area
+    # of the T-object and the T-area.
+    coverage, keypoints = calculate_coverage(block_pos, block_angle)
+    success = calculate_success(coverage, success_threshold=0.95)
+    reward = calculate_reward(coverage, success_threshold=0.95)
+
+    dataset = create_empty_dataset(repo_id, mode)
+    dataset = populate_dataset(
+        dataset,
+        episode_data_index,
+        episodes,
+        image=None if mode == "keypoints" else image,
+        state=agent_pos,
+        env_state=keypoints if mode == "keypoints" else None,
+        action=action,
+        reward=reward,
+        success=success,
+    )
+    dataset.consolidate()
+
+    if push_to_hub:
+        dataset.push_to_hub()
+
+
+if __name__ == "__main__":
+    # To try this script, modify the repo id with your own HuggingFace user (e.g. cadene/pusht)
+    repo_id = "lerobot/pusht"
+
+    episodes = None
+    # Uncomment if you want to try with a subset (episodes 0 and 1)
+    # episodes = [0, 1]
+
+    modes = ["video", "image", "keypoints"]
+    # Uncomment if you want to try with a specific mode
+    # modes = ["video"]
+    # modes = ["image"]
+    # modes = ["keypoints"]
+
+    for mode in modes:
+        # Derive a distinct repo id per mode without mutating the base repo id across iterations
+        mode_repo_id = f"{repo_id}_{mode}" if mode in ["image", "keypoints"] else repo_id
+
+        # download and load raw dataset, create LeRobotDataset, populate it, push to hub
+        port_pusht("data/lerobot-raw/pusht_raw", repo_id=mode_repo_id, mode=mode, episodes=episodes)
+
+        # Uncomment if you want to load the local dataset and explore it
+        # dataset = LeRobotDataset(repo_id=mode_repo_id, local_files_only=True)
+        # breakpoint()
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 96aac5c0..20b874b5 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -280,6 +280,8 @@ class LeRobotDatasetMetadata:
         obj.repo_id = repo_id
         obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
 
+        obj.root.mkdir(parents=True, exist_ok=False)
+
         if robot is not None:
             features = get_features_from_robot(robot, use_videos)
             robot_type = robot.robot_type
@@ -293,6 +295,7 @@ class LeRobotDatasetMetadata:
                 "Dataset features must either come from a Robot or be explicitly passed upon creation."
             )
         else:
+            # TODO(aliberts, rcadene): implement sanity check for features
             features = {**features, **DEFAULT_FEATURES}
 
         obj.tasks, obj.stats, obj.episodes = {}, {}, []
@@ -424,11 +427,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.video_backend = video_backend if video_backend is not None else "pyav"
         self.delta_indices = None
         self.local_files_only = local_files_only
-        self.consolidated = True
 
         # Unused attributes
         self.image_writer = None
-        self.episode_buffer = {}
+        self.episode_buffer = None
 
         self.root.mkdir(exist_ok=True, parents=True)
 
@@ -451,12 +453,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
             check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
             self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
 
+        # Available stats imply all videos have been encoded and the dataset is iterable
+        self.consolidated = self.meta.stats is not None
+
     def push_to_hub(
         self,
         tags: list | None = None,
         text: str | None = None,
-        license: str | None = "mit",
+        license: str | None = "apache-2.0",
         push_videos: bool = True,
+        private: bool = False,
     ) -> None:
         if not self.consolidated:
             raise RuntimeError(
@@ -468,7 +474,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if not push_videos:
             ignore_patterns.append("videos/")
 
-        create_repo(self.repo_id, repo_type="dataset", exist_ok=True)
+        create_repo(
+            repo_id=self.repo_id,
+            private=private,
+            repo_type="dataset",
+            exist_ok=True,
+        )
+
         upload_folder(
             repo_id=self.repo_id,
             folder_path=self.root,
@@ -658,7 +670,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
         return {
             "size": 0,
-            **{key: [] if key != "episode_index" else current_ep_idx for key in self.features},
+            **{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
         }
 
     def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
@@ -681,8 +693,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
         temporary directory — nothing is written to disk. To save those frames, the 'save_episode()'
         method then needs to be called.
         """
+        # TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
+        # check the dtype and shape matches, etc.
+
+        if self.episode_buffer is None:
+            self.episode_buffer = self._create_episode_buffer()
+
         frame_index = self.episode_buffer["size"]
-        timestamp = frame["timestamp"] if "timestamp" in frame else frame_index / self.fps
+        timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
         self.episode_buffer["frame_index"].append(frame_index)
         self.episode_buffer["timestamp"].append(timestamp)
 
@@ -723,6 +741,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
             # TODO(aliberts): Add option to use existing episode_index
             raise NotImplementedError()
 
+        if episode_length == 0:
+            raise ValueError(
+                "You must add one or more frames with `add_frame` before calling `save_episode`."
+            )
+
         task_index = self.meta.get_task_index(task)
 
         if not set(episode_buffer.keys()) == set(self.features):
@@ -781,7 +804,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Reset the buffer
         self.episode_buffer = self._create_episode_buffer()
 
-    def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
+    def start_image_writer(self, num_processes: int = 0, num_threads: int = 4) -> None:
         if isinstance(self.image_writer, AsyncImageWriter):
             logging.warning(
                 "You are starting a new AsyncImageWriter that is replacing an already existing one in the dataset."
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 1ad27ca9..875d5169 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import textwrap
 import warnings
 from itertools import accumulate
 from pathlib import Path
@@ -139,6 +140,8 @@ def load_info(local_dir: Path) -> dict:
 
 
 def load_stats(local_dir: Path) -> dict:
+    if not (local_dir / STATS_PATH).exists():
+        return None
     stats = load_json(local_dir / STATS_PATH)
     stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
     return unflatten_dict(stats)
@@ -186,17 +189,37 @@ def _get_major_minor(version: str) -> tuple[int]:
     return int(split[0]), int(split[1])
 
 
+class BackwardCompatibilityError(Exception):
+    def __init__(self, repo_id, version):
+        message = textwrap.dedent(f"""
+            BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
+
+            We introduced a new format since v2.0 which is not backward compatible with v1.x.
+            Please use our conversion script. Modify the following command with your own task description:
+            ```
+            python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
+                --repo-id {repo_id} \\
+                --single-task "TASK DESCRIPTION."  # <---- /!\\ Replace TASK DESCRIPTION /!\\
+            ```
+
+            A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
+            "Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
+            "Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
+            "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
+
+            If you encounter a problem, contact LeRobot maintainers on Discord ('https://discord.com/invite/s3KuuzsPFb')
+            or open an issue on GitHub.
+        """)
+        super().__init__(message)
+
+
 def check_version_compatibility(
     repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
 ) -> None:
     current_major, _ = _get_major_minor(current_version)
     major_to_check, _ = _get_major_minor(version_to_check)
     if major_to_check < current_major and enforce_breaking_major:
-        raise ValueError(
-            f"""The dataset you requested ({repo_id}) is in {version_to_check} format. We introduced a new
-            format with v2.0 that is not backward compatible. Please use our conversion script
-            first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
-        )
+        raise BackwardCompatibilityError(repo_id, version_to_check)
     elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
         warnings.warn(
             f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
@@ -207,18 +230,16 @@ def check_version_compatibility(
         )
 
 
-def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
-    num_version = float(version.strip("v"))
-    if num_version < 2 and enforce_v2:
-        raise ValueError(
-            f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
-            format with v2.0 that is not backward compatible. Please use our conversion script
-            first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
-        )
+def get_hub_safe_version(repo_id: str, version: str) -> str:
     api = HfApi()
     dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
     branches = [b.name for b in dataset_info.branches]
     if version not in branches:
+        num_version = float(version.strip("v"))
+        hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
+        if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
+            raise BackwardCompatibilityError(repo_id, version)
+
         warnings.warn(
             f"""You are trying to load a dataset from {repo_id} created with a previous version of the
            codebase. The following versions are available: {branches}.
@@ -461,6 +482,7 @@ def create_lerobot_dataset_card(
         }
     ]
     card.data.task_categories = ["robotics"]
+    card.data.license = license
     card.data.tags = ["LeRobot"]
     if license:
         card.data.license = license
diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
index d6961501..827cc1de 100644
--- a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
+++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py
@@ -441,7 +441,7 @@ def convert_dataset(
     arxiv: str | None = None,
     test_branch: str | None = None,
 ):
-    v1 = get_hub_safe_version(repo_id, V16, enforce_v2=False)
+    v1 = get_hub_safe_version(repo_id, V16)
     v1x_dir = local_dir / V16 / repo_id
     v20_dir = local_dir / V20 / repo_id
     v1x_dir.mkdir(parents=True, exist_ok=True)
diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py
index 9bcdaea3..3ede0c38 100644
--- a/lerobot/common/robot_devices/control_utils.py
+++ b/lerobot/common/robot_devices/control_utils.py
@@ -17,6 +17,7 @@ from termcolor import colored
 
 from lerobot.common.datasets.image_writer import safe_stop_image_writer
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import get_features_from_robot
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.robot_devices.robots.utils import Robot
 from lerobot.common.robot_devices.utils import busy_wait
@@ -330,3 +331,21 @@ def sanity_check_dataset_name(repo_id, policy):
         raise ValueError(
             f"Your dataset name begins with 'eval_' ({dataset_name}) but no policy is provided ({policy})."
) + + +def sanity_check_dataset_robot_compatibility(dataset, robot, fps, use_videos): + fields = [ + ("robot_type", dataset.meta.info["robot_type"], robot.robot_type), + ("fps", dataset.meta.info["fps"], fps), + ("features", dataset.features, get_features_from_robot(robot, use_videos)), + ] + + mismatches = [] + for field, dataset_value, present_value in fields: + if dataset_value != present_value: + mismatches.append(f"{field}: expected {present_value}, got {dataset_value}") + + if mismatches: + raise ValueError( + "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches) + ) diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index b8a7d25b..ad73eef4 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -115,6 +115,7 @@ from lerobot.common.robot_devices.control_utils import ( record_episode, reset_environment, sanity_check_dataset_name, + sanity_check_dataset_robot_compatibility, stop_recording, warmup_record, ) @@ -207,6 +208,9 @@ def record( num_image_writer_threads_per_camera: int = 4, display_cameras: bool = True, play_sounds: bool = True, + resume: bool = False, + # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument + local_files_only: bool = False, ) -> LeRobotDataset: # TODO(rcadene): Add option to record logs listener = None @@ -232,17 +236,29 @@ def record( f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})." ) - # Create empty dataset or load existing saved episodes - sanity_check_dataset_name(repo_id, policy) - dataset = LeRobotDataset.create( - repo_id, - fps, - root=root, - robot=robot, - use_videos=video, - image_writer_processes=num_image_writer_processes, - image_writer_threads=num_image_writer_threads_per_camera, - ) + if resume: + dataset = LeRobotDataset( + repo_id, + root=root, + local_files_only=local_files_only, + ) + dataset.start_image_writer( + num_processes=num_image_writer_processes, + num_threads=num_image_writer_threads_per_camera * len(robot.cameras), + ) + sanity_check_dataset_robot_compatibility(dataset, robot, fps, video) + else: + # Create empty dataset or load existing saved episodes + sanity_check_dataset_name(repo_id, policy) + dataset = LeRobotDataset.create( + repo_id, + fps, + root=root, + robot=robot, + use_videos=video, + image_writer_processes=num_image_writer_processes, + image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras), + ) if not robot.is_connected: robot.connect() @@ -270,8 +286,7 @@ def record( # if multi_task: # task = input("Enter your task description: ") - episode_index = dataset.episode_buffer["episode_index"] - log_say(f"Recording episode {episode_index}", play_sounds) + log_say(f"Recording episode {dataset.num_episodes}", play_sounds) record_episode( dataset=dataset, robot=robot, @@ -289,7 +304,7 @@ def record( # TODO(rcadene): add an option to enable teleoperation during reset # Skip reset for the last episode to be recorded if not events["stop_recording"] and ( - (episode_index < num_episodes - 1) or events["rerecord_episode"] + (dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"] ): log_say("Reset the environment", play_sounds) reset_environment(robot, events, reset_time_s) diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index 6eac4d0e..75542457 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -117,10 +117,14 @@ 
def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str def push_dataset_card_to_hub( - repo_id: str, revision: str | None, tags: list | None = None, text: str | None = None + repo_id: str, + revision: str | None, + tags: list | None = None, + text: str | None = None, + license: str = "apache-2.0", ): """Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub.""" - card = create_lerobot_dataset_card(tags=tags, text=text) + card = create_lerobot_dataset_card(tags=tags, text=text, license=license) card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision) diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index c773dac8..5d003e1f 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -325,7 +325,7 @@ def lerobot_dataset_metadata_factory( "lerobot.common.datasets.lerobot_dataset.snapshot_download" ) as mock_snapshot_download_patch, ): - mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version, enforce_v2=True: version + mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version mock_snapshot_download_patch.side_effect = mock_snapshot_download return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only) diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 88a4d1cd..c51ca972 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -275,22 +275,25 @@ def test_resume_record(tmpdir, request, robot_type, mock): root = Path(tmpdir) / "data" / repo_id single_task = "Do something." - dataset = record( - robot, - root, - repo_id, - single_task, - fps=1, - warmup_time_s=0, - episode_time_s=1, - num_episodes=1, - push_to_hub=False, - video=False, - display_cameras=False, - play_sounds=False, - run_compute_stats=False, - ) - assert len(dataset) == 1, "`dataset` should contain only 1 frame" + record_kwargs = { + "robot": robot, + "root": root, + "repo_id": repo_id, + "single_task": single_task, + "fps": 1, + "warmup_time_s": 0, + "episode_time_s": 1, + "push_to_hub": False, + "video": False, + "display_cameras": False, + "play_sounds": False, + "run_compute_stats": False, + "local_files_only": True, + "num_episodes": 1, + } + + dataset = record(**record_kwargs) + assert len(dataset) == 1, f"`dataset` should contain 1 frame, not {len(dataset)}" # init_dataset_return_value = {} @@ -300,22 +303,13 @@ def test_resume_record(tmpdir, request, robot_type, mock): # return init_dataset_return_value # with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset): - dataset = record( - robot, - root, - repo_id, - single_task, - fps=1, - warmup_time_s=0, - episode_time_s=1, - num_episodes=2, - push_to_hub=False, - video=False, - display_cameras=False, - play_sounds=False, - run_compute_stats=False, - ) - assert len(dataset) == 2, "`dataset` should contain only 1 frame" + + with pytest.raises(FileExistsError): + # Dataset already exists, but resume=False by default + record(**record_kwargs) + + dataset = record(**record_kwargs, resume=True) + assert len(dataset) == 2, f"`dataset` should contain 2 frames, not {len(dataset)}" # assert ( # init_dataset_return_value["num_episodes"] == 2 # ), "`init_dataset` should load the previous episode"
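
A minimal sketch of how the new porting example might be exercised for a single mode. It assumes you run from the repository root with the `gym_pusht` extra installed; the repo id is a placeholder, and `push_to_hub=False` keeps everything local:

```python
from examples.port_datasets.pusht_zarr import port_pusht
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Port only the keypoints variant, on a two-episode subset, without pushing.
port_pusht(
    "data/lerobot-raw/pusht_raw",       # raw zarr data is downloaded here on first run
    repo_id="my_user/pusht_keypoints",  # placeholder: use your own hub user
    episodes=[0, 1],
    mode="keypoints",
    push_to_hub=False,
)

# Reload the local copy; stats were written during consolidate(), so per this
# patch the dataset is marked consolidated and is directly iterable.
dataset = LeRobotDataset("my_user/pusht_keypoints", local_files_only=True)
print(len(dataset), dataset.meta.info["fps"])
```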
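A sketch of the new `resume` flow in `lerobot/scripts/control_robot.py`, mirroring the call pattern of `tests/test_control_robot.py::test_resume_record`. The robot construction is out of scope here: `robot` is assumed to be an already-connected `Robot` instance built elsewhere, and the repo id is a placeholder:

```python
from pathlib import Path

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.scripts.control_robot import record


def resume_recording(robot, repo_id: str = "my_user/my_dataset") -> LeRobotDataset:
    # Same keyword arguments as the initial recording session.
    record_kwargs = {
        "robot": robot,
        "root": Path("data") / repo_id,
        "repo_id": repo_id,
        "single_task": "Do something.",
        "fps": 30,
        "warmup_time_s": 0,
        "episode_time_s": 10,
        "num_episodes": 5,
        "push_to_hub": False,
        "video": True,
        "display_cameras": False,
        "play_sounds": False,
        "run_compute_stats": False,
        "local_files_only": True,  # reopen the local copy without querying the hub
    }
    # Without resume=True, a second run on the same repo_id fails with
    # FileExistsError, since LeRobotDatasetMetadata.create() now calls
    # root.mkdir(parents=True, exist_ok=False).
    return record(**record_kwargs, resume=True)
```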