Merge remote-tracking branch 'origin/user/aliberts/2025_02_10_dataset_v2.1' into user/rcadene/2025_02_19_port_openx

Remi Cadene 2025-02-20 17:34:13 +00:00
commit b520941cd9
10 changed files with 133 additions and 148 deletions

View File

@@ -3,8 +3,10 @@ from pathlib import Path
import numpy as np
import torch
from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
@@ -134,8 +136,8 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
if mode not in ["video", "image", "keypoints"]:
raise ValueError(mode)
if (LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
if (HF_LEROBOT_HOME / repo_id).exists():
shutil.rmtree(HF_LEROBOT_HOME / repo_id)
if not raw_dir.exists():
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
@@ -198,10 +200,10 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
dataset.save_episode()
dataset.consolidate()
if push_to_hub:
dataset.push_to_hub()
hub_api = HfApi()
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
if __name__ == "__main__":
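For reference, a minimal sketch of calling this port script's main() directly. The raw_dir and destination repo_id below are hypothetical; mode must be one of "video", "image" or "keypoints" per the check above, and the raw data is downloaded automatically when raw_dir does not exist:

from pathlib import Path

# Hypothetical local path and destination repo_id.
main(
    raw_dir=Path("data/lerobot-raw/pusht_raw"),
    repo_id="my-user/pusht",
    mode="video",
    push_to_hub=False,  # set to True to upload and tag the dataset with CODEBASE_VERSION
)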

View File

@@ -1,4 +1,9 @@
# keys
import os
from pathlib import Path
from huggingface_hub.constants import HF_HOME
OBS_ENV = "observation.environment_state"
OBS_ROBOT = "observation.state"
OBS_IMAGE = "observation.image"
@@ -15,3 +20,13 @@ TRAINING_STEP = "training_step.json"
OPTIMIZER_STATE = "optimizer_state.safetensors"
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
SCHEDULER_STATE = "scheduler_state.json"
# cache dir
default_cache_path = Path(HF_HOME) / "lerobot"
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
if "LEROBOT_HOME" in os.environ:
raise ValueError(
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
"'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
)
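As a usage sketch (the override path below is hypothetical), the new cache root can be redirected with the HF_LEROBOT_HOME environment variable, provided it is set before lerobot is first imported; datasets are then cached under <HF_LEROBOT_HOME>/<repo_id>:

import os
from pathlib import Path

# Hypothetical override; must be set before lerobot.common.constants is imported.
os.environ["HF_LEROBOT_HOME"] = str(Path("~/my_lerobot_cache").expanduser())

from lerobot.common.constants import HF_LEROBOT_HOME

print(HF_LEROBOT_HOME / "lerobot/pusht")  # where this dataset would be cached locally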

View File

@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
from pathlib import Path
from typing import Callable
@@ -29,6 +28,7 @@ from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
from packaging import version
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.common.datasets.utils import (
@@ -39,7 +39,6 @@ from lerobot.common.datasets.utils import (
append_jsonlines,
backward_compatible_episodes_stats,
check_delta_timestamps,
check_frame_features,
check_timestamps_sync,
check_version_compatibility,
create_empty_dataset_info,
@@ -55,6 +54,8 @@ from lerobot.common.datasets.utils import (
load_info,
load_stats,
load_tasks,
validate_episode_buffer,
validate_frame,
write_episode,
write_episode_stats,
write_info,
@@ -71,7 +72,6 @@ from lerobot.common.robot_devices.robots.utils import Robot
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
CODEBASE_VERSION = "v2.1"
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
class LeRobotDatasetMetadata:
@@ -84,7 +84,7 @@ class LeRobotDatasetMetadata:
):
self.repo_id = repo_id
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
try:
if force_cache_sync:
@@ -257,6 +257,9 @@ class LeRobotDatasetMetadata:
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
self.info["total_videos"] += len(self.video_keys)
if len(self.video_keys) > 0:
self.update_video_info()
write_info(self.info, self.root)
episode_dict = {
@@ -271,7 +274,7 @@ class LeRobotDatasetMetadata:
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
write_episode_stats(episode_index, episode_stats, self.root)
def write_video_info(self) -> None:
def update_video_info(self) -> None:
"""
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
@@ -281,8 +284,6 @@ class LeRobotDatasetMetadata:
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
self.info["features"][key]["info"] = get_video_info(video_path)
write_json(self.info, self.root / INFO_PATH)
def __repr__(self):
feature_keys = list(self.features)
return (
@@ -308,7 +309,7 @@ class LeRobotDatasetMetadata:
"""Creates metadata for a LeRobotDataset."""
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
obj.root.mkdir(parents=True, exist_ok=False)
@@ -463,7 +464,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
super().__init__()
self.repo_id = repo_id
self.root = Path(root) if root else LEROBOT_HOME / repo_id
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.episodes = episodes
@@ -507,9 +508,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
# Available stats implies all videos have been encoded and dataset is iterable
self.consolidated = self.meta.stats is not None
def push_to_hub(
self,
branch: str | None = None,
@@ -520,13 +518,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
allow_patterns: list[str] | str | None = None,
**card_kwargs,
) -> None:
if not self.consolidated:
logging.warning(
"You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. "
"Consolidating first."
)
self.consolidate()
ignore_patterns = ["images/"]
if not push_videos:
ignore_patterns.append("videos/")
@@ -780,7 +771,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if isinstance(frame[name], torch.Tensor):
frame[name] = frame[name].numpy()
check_frame_features(frame, self.features)
validate_frame(frame, self.features)
if self.episode_buffer is None:
self.episode_buffer = self.create_episode_buffer()
@@ -816,41 +807,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer["size"] += 1
def save_episode(self, encode_videos: bool = True, episode_data: dict | None = None) -> None:
def save_episode(self, episode_data: dict | None = None) -> None:
"""
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
the hub.
This will save to disk the current episode in self.episode_buffer.
Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
time for video encoding.
Args:
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
None.
"""
if not episode_data:
episode_buffer = self.episode_buffer
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
# size and task are special cases that won't be added to hf_dataset
episode_length = episode_buffer.pop("size")
tasks = episode_buffer.pop("task")
episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"]
if episode_index != self.meta.total_episodes:
# TODO(aliberts): Add option to use existing episode_index
raise NotImplementedError(
"You might have manually provided the episode_buffer with an episode_index that doesn't "
"match the total number of episodes already in the dataset. This is not supported for now."
)
if episode_length == 0:
raise ValueError(
"You must add one or several frames with `add_frame` before calling `add_episode`."
)
if not set(episode_buffer.keys()) == set(self.features):
raise ValueError(
f"Features from `episode_buffer` don't match the ones in `self.features`: '{set(episode_buffer.keys())}' vs '{set(self.features)}'"
)
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
@@ -876,16 +851,29 @@ class LeRobotDataset(torch.utils.data.Dataset):
ep_stats = compute_episode_stats(episode_buffer, self.features)
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
if encode_videos and len(self.meta.video_keys) > 0:
if len(self.meta.video_keys) > 0:
video_paths = self.encode_episode_videos(episode_index)
for key in self.meta.video_keys:
episode_buffer[key] = video_paths[key]
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
video_files = list(self.root.rglob("*.mp4"))
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
parquet_files = list(self.root.rglob("*.parquet"))
assert len(parquet_files) == self.num_episodes
# delete images
img_dir = self.root / "images"
if img_dir.is_dir():
shutil.rmtree(self.root / "images")
if not episode_data: # Reset the buffer
self.episode_buffer = self.create_episode_buffer()
self.consolidated = False
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
@@ -960,28 +948,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
return video_paths
def consolidate(self, keep_image_files: bool = False) -> None:
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
if len(self.meta.video_keys) > 0:
self.encode_videos()
self.meta.write_video_info()
if not keep_image_files:
img_dir = self.root / "images"
if img_dir.is_dir():
shutil.rmtree(self.root / "images")
video_files = list(self.root.rglob("*.mp4"))
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
parquet_files = list(self.root.rglob("*.parquet"))
assert len(parquet_files) == self.num_episodes
self.consolidated = True
@classmethod
def create(
cls,
@@ -1020,12 +986,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj.create_episode_buffer()
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
# is used to know when certain operations are need (for instance, computing dataset statistics). In
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling
# self.consolidate().
obj.consolidated = True
obj.episodes = None
obj.hf_dataset = None
obj.image_transforms = None
@@ -1056,7 +1016,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else LEROBOT_HOME
self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
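Taken together, these changes remove the consolidation step from the recording workflow: save_episode() now encodes videos, checks timestamp sync and cleans up temporary images itself, and push_to_hub() no longer depends on a consolidated flag. A minimal sketch of the resulting flow (hypothetical repo_id; the feature spec and calls mirror the updated tests further below):

import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = LeRobotDataset.create(repo_id="my-user/demo_dataset", fps=30, features=features)

for _ in range(10):
    dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})

dataset.save_episode()  # writes the parquet file, encodes videos (if any) and deletes temporary images
dataset.push_to_hub()   # no dataset.consolidate() call needed anymore (requires hub credentials)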

View File

@@ -644,25 +644,25 @@ class IterableNamespace(SimpleNamespace):
return vars(self).keys()
def check_frame_features(frame: dict, features: dict):
def validate_frame(frame: dict, features: dict):
optional_features = {"timestamp"}
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
actual_features = set(frame.keys())
error_message = check_features_presence(actual_features, expected_features, optional_features)
error_message = validate_features_presence(actual_features, expected_features, optional_features)
if "task" in frame:
error_message += check_feature_string("task", frame["task"])
error_message += validate_feature_string("task", frame["task"])
common_features = actual_features & (expected_features | optional_features)
for name in common_features - {"task"}:
error_message += check_feature_dtype_and_shape(name, features[name], frame[name])
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
if error_message:
raise ValueError(error_message)
def check_features_presence(
def validate_features_presence(
actual_features: set[str], expected_features: set[str], optional_features: set[str]
):
error_message = ""
@@ -679,20 +679,22 @@ def check_features_presence(
return error_message
def check_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
expected_dtype = feature["dtype"]
expected_shape = feature["shape"]
if is_valid_numpy_dtype_string(expected_dtype):
return check_feature_numpy_array(name, expected_dtype, expected_shape, value)
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
elif expected_dtype in ["image", "video"]:
return check_feature_image_or_video(name, expected_shape, value)
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "string":
return check_feature_string(name, value)
return validate_feature_string(name, value)
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray):
def validate_feature_numpy_array(
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
):
error_message = ""
if isinstance(value, np.ndarray):
actual_dtype = value.dtype
@@ -709,7 +711,7 @@ def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: li
return error_message
def check_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
# Note: The check of pixel ranges ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
@@ -725,7 +727,33 @@ def check_feature_image_or_video(name: str, expected_shape: list[str], value: np
return error_message
def check_feature_string(name: str, value: str):
def validate_feature_string(name: str, value: str):
if not isinstance(value, str):
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
return ""
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
if "size" not in episode_buffer:
raise ValueError("size key not found in episode_buffer")
if "task" not in episode_buffer:
raise ValueError("task key not found in episode_buffer")
if episode_buffer["episode_index"] != total_episodes:
# TODO(aliberts): Add option to use existing episode_index
raise NotImplementedError(
"You might have manually provided the episode_buffer with an episode_index that doesn't "
"match the total number of episodes already in the dataset. This is not supported for now."
)
if episode_buffer["size"] == 0:
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
if not buffer_keys == set(features):
raise ValueError(
f"Features from `episode_buffer` don't match the ones in `features`."
f"In episode_buffer not in features: {buffer_keys - set(features)}"
f"In features not in episode_buffer: {set(features) - buffer_keys}"
)
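A short sketch of how the renamed validators are meant to be used (hypothetical feature spec and frame values; signatures as defined above). validate_frame() aggregates all problems into a single ValueError instead of failing on the first one:

import numpy as np
from lerobot.common.datasets.utils import validate_frame

features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}

# Valid frame: "task" is required, "timestamp" is optional.
validate_frame({"state": np.zeros(2, dtype=np.float32), "task": "Pick up the cube"}, features)

# Invalid frame: wrong dtype/shape for "state" and a non-string "task" are reported together.
try:
    validate_frame({"state": np.zeros(3, dtype=np.float64), "task": 42}, features)
except ValueError as e:
    print(e)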

View File

@@ -29,8 +29,9 @@ LOCAL_DIR = Path("data/")
def batch_convert():
status = {}
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
logfile = LOCAL_DIR / "conversion_log_v21.txt"
for num, repo_id in available_datasets:
for num, repo_id in enumerate(available_datasets):
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
print("---------------------------------------------------------")
try:

View File

@@ -2,10 +2,10 @@
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
2.1. It will:
- Generates per-episodes stats and writes them in `episodes_stats.jsonl`
- Generate per-episode stats and write them in `episodes_stats.jsonl`.
- Check consistency between these new stats and the old ones.
- Removes the deprecated `stats.json` (by default)
- Updates codebase_version in `info.json`
- Remove the deprecated `stats.json`.
- Update codebase_version in `info.json`.
- Push this new version to the hub on the 'main' branch and tag it with "v2.1".
Usage:
@@ -80,19 +80,20 @@ if __name__ == "__main__":
"--repo-id",
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--branch",
type=str,
default=None,
help="Repo branch to push your dataset (defaults to the main branch)",
help="Repo branch to push your dataset. Defaults to the main branch.",
)
parser.add_argument(
"--num-workers",
type=int,
default=4,
help="Number of workers for parallelizing compute",
help="Number of workers for parallelizing stats compute. Defaults to 4.",
)
args = parser.parse_args()
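For completeness, a typical invocation of this conversion script based on the arguments above (the script path is an assumption; only --repo-id is required, and --branch and --num-workers keep their defaults otherwise):

python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id lerobot/pusht --num-workers 8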

View File

@@ -299,8 +299,6 @@ def record(
log_say("Stop recording", cfg.play_sounds, blocking=True)
stop_recording(robot, listener, cfg.display_cameras)
dataset.consolidate()
if cfg.push_to_hub:
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)

View File

@@ -1,6 +1,6 @@
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.constants import HF_LEROBOT_HOME
LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing"
DUMMY_REPO_ID = "dummy/repo"
DUMMY_ROBOT_TYPE = "dummy_robot"
DUMMY_MOTOR_FEATURES = {

View File

@@ -1,5 +1,7 @@
import random
from functools import partial
from pathlib import Path
from typing import Protocol
from unittest.mock import patch
import datasets
@@ -17,7 +19,6 @@ from lerobot.common.datasets.utils import (
get_hf_features_from_features,
hf_transform_to_torch,
)
from lerobot.common.robot_devices.robots.utils import Robot
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
@@ -28,6 +29,10 @@ from tests.fixtures.constants import (
)
class LeRobotDatasetFactory(Protocol):
def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
def get_task_index(task_dicts: dict, task: str) -> int:
tasks = {d["task_index"]: d["task"] for d in task_dicts.values()}
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
@@ -358,7 +363,7 @@ def lerobot_dataset_factory(
hf_dataset_factory,
mock_snapshot_download_factory,
lerobot_dataset_metadata_factory,
):
) -> LeRobotDatasetFactory:
def _create_lerobot_dataset(
root: Path,
repo_id: str = DUMMY_REPO_ID,
@@ -430,17 +435,5 @@ def lerobot_dataset_factory(
@pytest.fixture(scope="session")
def empty_lerobot_dataset_factory():
def _create_empty_lerobot_dataset(
root: Path,
repo_id: str = DUMMY_REPO_ID,
fps: int = DEFAULT_FPS,
robot: Robot | None = None,
robot_type: str | None = None,
features: dict | None = None,
) -> LeRobotDataset:
return LeRobotDataset.create(
repo_id=repo_id, fps=fps, root=root, robot=robot, robot_type=robot_type, features=features
)
return _create_empty_lerobot_dataset
def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)

View File

@@ -184,8 +184,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert len(dataset) == 1
assert dataset[0]["task"] == "Dummy task"
@@ -197,8 +196,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2])
@@ -207,8 +205,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4])
@@ -217,8 +214,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
@@ -227,8 +223,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
@@ -237,8 +232,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
@@ -247,8 +241,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["state"].ndim == 0
@@ -257,8 +250,7 @@ def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["caption"] == "Dummy caption"
@@ -287,14 +279,13 @@ def test_add_frame_image_wrong_range(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"})
with pytest.raises(FileNotFoundError):
dataset.save_episode(encode_videos=False)
dataset.save_episode()
def test_add_frame_image(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -302,8 +293,7 @@ def test_add_frame_image(image_dataset):
def test_add_frame_image_h_w_c(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -312,8 +302,7 @@ def test_add_frame_image_uint8(image_dataset):
dataset = image_dataset
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": image, "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -322,8 +311,7 @@ def test_add_frame_image_pil(image_dataset):
dataset = image_dataset
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
dataset.save_episode(encode_videos=False)
dataset.consolidate()
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -338,7 +326,6 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
# - [ ] test various attributes & state from init and create
# - [ ] test init with episodes and check num_frames
# - [ ] test add_episode
# - [ ] test consolidate
# - [ ] test push_to_hub
# - [ ] test smaller methods
@@ -581,9 +568,9 @@ def test_create_branch():
def test_dataset_feature_with_forward_slash_raises_error():
# make sure dir does not exist
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.constants import HF_LEROBOT_HOME
dataset_dir = LEROBOT_HOME / "lerobot/test/with/slash"
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
# make sure does not exist
if dataset_dir.exists():
dataset_dir.rmdir()