Validate features during `add_frame` + Add 2D-to-5D + Add string (#720)

Remi 2025-02-14 19:59:48 +01:00 committed by GitHub
parent 9d6886dd08
commit 7c2bbee613
8 changed files with 448 additions and 53 deletions
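For orientation, here is a minimal sketch of the behavior this commit adds (the repo id, root path, and feature names are illustrative; `LeRobotDataset.create` and `add_frame` are the APIs touched below). Frames are now validated eagerly inside `add_frame`, so a malformed frame fails with one aggregated ValueError instead of surfacing later during `save_episode`:

import numpy as np
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Illustrative features dict: a 2D array feature and a string feature,
# both newly supported by this commit.
features = {
    "state_2d": {"dtype": "float32", "shape": (2, 4), "names": None},
    "caption": {"dtype": "string", "shape": (1,), "names": None},
}
dataset = LeRobotDataset.create(repo_id="user/dummy", fps=30, root="/tmp/dummy", features=features)

# Accepted: dtypes and shapes match, and the mandatory "task" is present.
dataset.add_frame({
    "state_2d": np.zeros((2, 4), dtype=np.float32),
    "caption": "Dummy caption",
    "task": "Dummy task",
})

# Rejected during add_frame: wrong dtype plus missing features, reported
# together in a single aggregated ValueError.
dataset.add_frame({"state_2d": np.zeros((2, 4), dtype=np.float64)})  # raises ValueError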

View File

@@ -38,22 +38,40 @@ def safe_stop_image_writer(func):
     return wrapper


-def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
+def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
     # TODO(aliberts): handle 1 channel and 4 for depth images
-    if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
+    if image_array.ndim != 3:
+        raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
+
+    if image_array.shape[0] == 3:
         # Transpose from pytorch convention (C, H, W) to (H, W, C)
         image_array = image_array.transpose(1, 2, 0)
+    elif image_array.shape[-1] != 3:
+        raise NotImplementedError(
+            f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
+        )
+
     if image_array.dtype != np.uint8:
-        # Assume the image is in [0, 1] range for floating-point data
-        image_array = np.clip(image_array, 0, 1)
+        if range_check:
+            max_ = image_array.max().item()
+            min_ = image_array.min().item()
+            if max_ > 1.0 or min_ < 0.0:
+                raise ValueError(
+                    "The image data type is float, which requires values in the range [0.0, 1.0]. "
+                    f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
+                    "provide a uint8 image with values in the range [0, 255]."
+                )
+
         image_array = (image_array * 255).astype(np.uint8)
+
     return PIL.Image.fromarray(image_array)


 def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
     try:
         if isinstance(image, np.ndarray):
-            img = image_array_to_image(image)
+            img = image_array_to_pil_image(image)
         elif isinstance(image, PIL.Image.Image):
             img = image
         else:
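A short usage sketch of the renamed helper and its new range check (array values are illustrative):

import numpy as np
from lerobot.common.datasets.image_writer import image_array_to_pil_image

image_array_to_pil_image(np.random.rand(96, 128, 3))        # float in [0.0, 1.0]: scaled to uint8
image_array_to_pil_image(np.random.rand(3, 96, 128))        # channel-first: transposed to (H, W, C)
image_array_to_pil_image(np.random.rand(96, 128, 3) * 255)  # float outside [0.0, 1.0]: raises ValueError
# With range_check=False the check is skipped, but the array is still multiplied
# by 255 and cast to uint8, so the caller must guarantee the [0.0, 1.0] range.
image_array_to_pil_image(np.random.rand(96, 128, 3), range_check=False)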

View File

@@ -39,6 +39,7 @@ from lerobot.common.datasets.utils import (
     TASKS_PATH,
     append_jsonlines,
     check_delta_timestamps,
+    check_frame_features,
     check_timestamps_sync,
     check_version_compatibility,
     create_branch,
@@ -724,10 +725,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
         temporary directory nothing is written to disk. To save those frames, the 'save_episode()' method
         then needs to be called.
         """
-        # TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
-        # check the dtype and shape matches, etc.
-        if "task" not in frame:
-            raise ValueError("The mandatory feature 'task' wasn't found in `frame` dictionnary.")
+        # Convert torch to numpy if needed
+        for name in frame:
+            if isinstance(frame[name], torch.Tensor):
+                frame[name] = frame[name].numpy()
+
+        check_frame_features(frame, self.features)

         if self.episode_buffer is None:
             self.episode_buffer = self.create_episode_buffer()
@@ -757,8 +760,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 self._save_image(frame[key], img_path)
                 self.episode_buffer[key].append(str(img_path))
             else:
-                item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
-                self.episode_buffer[key].append(item)
+                self.episode_buffer[key].append(frame[key])

         self.episode_buffer["size"] += 1
@@ -815,12 +817,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             # are processed separately by storing image path and frame info as meta data
             if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
                 continue
-            elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
-                episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
-            elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
-                episode_buffer[key] = np.stack(episode_buffer[key])
-            else:
-                raise ValueError(key)
+            episode_buffer[key] = np.stack(episode_buffer[key])

         self._wait_image_writer()
         self._save_episode_table(episode_buffer, episode_index)
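The `np.stack` simplification above works because, after validation, every buffered value for a given key shares one dtype and shape, so a single stack handles features of any dimensionality; a minimal sketch with illustrative buffer contents:

import numpy as np

frames_1d = [np.zeros(2, dtype=np.float32) for _ in range(3)]     # three frames of shape (2,)
frames_2d = [np.zeros((2, 4), dtype=np.float32) for _ in range(3)]

assert np.stack(frames_1d).shape == (3, 2)      # stacking prepends the frame axis
assert np.stack(frames_2d).shape == (3, 2, 4)   # identical code path for higher-dimensional features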

View File

@@ -35,6 +35,7 @@ from PIL import Image as PILImage
 from torchvision import transforms

 from lerobot.common.robot_devices.robots.utils import Robot
+from lerobot.common.utils.utils import is_valid_numpy_dtype_string
 from lerobot.configs.types import DictLike, FeatureType, PolicyFeature

 DEFAULT_CHUNK_SIZE = 1000  # Max number of episodes per chunk
@@ -203,7 +204,7 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
         elif first_item is None:
             pass
         else:
-            items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
+            items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
     return items_dict
@@ -285,11 +286,20 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
             hf_features[key] = datasets.Image()
         elif ft["shape"] == (1,):
             hf_features[key] = datasets.Value(dtype=ft["dtype"])
-        else:
-            assert len(ft["shape"]) == 1
+        elif len(ft["shape"]) == 1:
             hf_features[key] = datasets.Sequence(
                 length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
             )
+        elif len(ft["shape"]) == 2:
+            hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
+        elif len(ft["shape"]) == 3:
+            hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
+        elif len(ft["shape"]) == 4:
+            hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
+        elif len(ft["shape"]) == 5:
+            hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
+        else:
+            raise ValueError(f"Corresponding feature is not valid: {ft}")

     return datasets.Features(hf_features)
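A sketch of the resulting shape-to-feature mapping (the feature dicts are illustrative; the Array2D through Array5D types come from the Hugging Face datasets library):

from lerobot.common.datasets.utils import get_hf_features_from_features

features = {
    "state": {"dtype": "float32", "shape": (6,), "names": None},
    "voxels": {"dtype": "float32", "shape": (2, 4, 3), "names": None},
}
hf_features = get_hf_features_from_features(features)
# hf_features["state"]  -> datasets.Sequence(length=6, feature=datasets.Value("float32"))
# hf_features["voxels"] -> datasets.Array3D(shape=(2, 4, 3), dtype="float32")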
@@ -606,3 +616,90 @@ class IterableNamespace(SimpleNamespace):

     def keys(self):
         return vars(self).keys()
+
+
+def check_frame_features(frame: dict, features: dict):
+    optional_features = {"timestamp"}
+    expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
+    actual_features = set(frame.keys())
+
+    error_message = check_features_presence(actual_features, expected_features, optional_features)
+
+    if "task" in frame:
+        error_message += check_feature_string("task", frame["task"])
+
+    common_features = actual_features & (expected_features | optional_features)
+    for name in common_features - {"task"}:
+        error_message += check_feature_dtype_and_shape(name, features[name], frame[name])
+
+    if error_message:
+        raise ValueError(error_message)
+
+
+def check_features_presence(
+    actual_features: set[str], expected_features: set[str], optional_features: set[str]
+):
+    error_message = ""
+    missing_features = expected_features - actual_features
+    extra_features = actual_features - (expected_features | optional_features)
+
+    if missing_features or extra_features:
+        error_message += "Feature mismatch in `frame` dictionary:\n"
+        if missing_features:
+            error_message += f"Missing features: {missing_features}\n"
+        if extra_features:
+            error_message += f"Extra features: {extra_features}\n"
+
+    return error_message
+
+
+def check_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
+    expected_dtype = feature["dtype"]
+    expected_shape = feature["shape"]
+    if is_valid_numpy_dtype_string(expected_dtype):
+        return check_feature_numpy_array(name, expected_dtype, expected_shape, value)
+    elif expected_dtype in ["image", "video"]:
+        return check_feature_image_or_video(name, expected_shape, value)
+    elif expected_dtype == "string":
+        return check_feature_string(name, value)
+    else:
+        raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
+
+
+def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray):
+    error_message = ""
+    if isinstance(value, np.ndarray):
+        actual_dtype = value.dtype
+        actual_shape = value.shape
+
+        if actual_dtype != np.dtype(expected_dtype):
+            error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
+
+        if actual_shape != expected_shape:
+            error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
+    else:
+        error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
+
+    return error_message
+
+
+def check_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
+    # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
+    error_message = ""
+    if isinstance(value, np.ndarray):
+        actual_shape = value.shape
+        c, h, w = expected_shape
+        if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
+            error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
+    elif isinstance(value, PILImage.Image):
+        pass
+    else:
+        error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"

+    return error_message
+
+
+def check_feature_string(name: str, value: str):
+    if not isinstance(value, str):
+        return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
+    return ""

View File

@@ -21,6 +21,7 @@ from copy import copy
 from datetime import datetime, timezone
 from pathlib import Path

+import numpy as np
 import torch
@@ -200,5 +201,18 @@ def get_channel_first_image_shape(image_shape: tuple) -> tuple:
     return shape


-def has_method(cls: object, method_name: str):
+def has_method(cls: object, method_name: str) -> bool:
     return hasattr(cls, method_name) and callable(getattr(cls, method_name))
+
+
+def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
+    """
+    Return True if a given string can be converted to a numpy dtype.
+    """
+    try:
+        # Attempt to convert the string to a numpy dtype
+        np.dtype(dtype_str)
+        return True
+    except TypeError:
+        # If a TypeError is raised, the string is not a valid dtype
+        return False
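This helper routes feature dtypes in `check_feature_dtype_and_shape`: strings numpy understands go down the array-checking path, while dataset-specific dtypes ('image', 'video', 'string') fall through to their own checks. A few illustrative calls:

from lerobot.common.utils.utils import is_valid_numpy_dtype_string

assert is_valid_numpy_dtype_string("float32")
assert is_valid_numpy_dtype_string("uint8")
assert not is_valid_numpy_dtype_string("video")   # np.dtype("video") raises TypeError
assert not is_valid_numpy_dtype_string("string")  # likewise, so string features skip the array checks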

View File

@@ -27,3 +27,5 @@ DUMMY_VIDEO_INFO = {
     "video.is_depth_map": False,
     "has_audio": False,
 }
+
+DUMMY_CHW = (3, 96, 128)
+DUMMY_HWC = (96, 128, 3)
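These two constants describe the same test image in channel-first and channel-last layout; a minimal sketch of the relation, mirroring the transpose in `image_array_to_pil_image`:

import numpy as np

chw_image = np.zeros((3, 96, 128), dtype=np.float32)  # DUMMY_CHW: (channels, height, width)
hwc_image = chw_image.transpose(1, 2, 0)              # DUMMY_HWC: (height, width, channels)
assert hwc_image.shape == (96, 128, 3)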

View File

@@ -17,6 +17,7 @@ from lerobot.common.datasets.utils import (
     get_hf_features_from_features,
     hf_transform_to_torch,
 )
+from lerobot.common.robot_devices.robots.utils import Robot
 from tests.fixtures.constants import (
     DEFAULT_FPS,
     DUMMY_CAMERA_FEATURES,
@@ -394,3 +395,20 @@ def lerobot_dataset_factory(
         return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)

     return _create_lerobot_dataset
+
+
+@pytest.fixture(scope="session")
+def empty_lerobot_dataset_factory():
+    def _create_empty_lerobot_dataset(
+        root: Path,
+        repo_id: str = DUMMY_REPO_ID,
+        fps: int = DEFAULT_FPS,
+        robot: Robot | None = None,
+        robot_type: str | None = None,
+        features: dict | None = None,
+    ):
+        return LeRobotDataset.create(
+            repo_id=repo_id, fps=fps, root=root, robot=robot, robot_type=robot_type, features=features
+        )
+
+    return _create_empty_lerobot_dataset

View File

@@ -15,15 +15,18 @@
 # limitations under the License.
 import json
 import logging
+import re
 from copy import deepcopy
 from itertools import chain
 from pathlib import Path

 import einops
+import numpy as np
 import pytest
 import torch
 from datasets import Dataset
 from huggingface_hub import HfApi
+from PIL import Image
 from safetensors.torch import load_file

 import lerobot
@@ -33,6 +36,7 @@ from lerobot.common.datasets.compute_stats import (
     get_stats_einops_patterns,
 )
 from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.image_writer import image_array_to_pil_image
 from lerobot.common.datasets.lerobot_dataset import (
     LeRobotDataset,
     MultiLeRobotDataset,
@@ -49,11 +53,27 @@ from lerobot.common.robot_devices.robots.utils import make_robot
 from lerobot.common.utils.random_utils import seeded_context
 from lerobot.configs.default import DatasetConfig
 from lerobot.configs.train import TrainPipelineConfig
-from tests.fixtures.constants import DUMMY_REPO_ID
+from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
 from tests.utils import DEVICE, require_x86_64_kernel


-def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
+@pytest.fixture
+def image_dataset(tmp_path, empty_lerobot_dataset_factory):
+    features = {
+        "image": {
+            "dtype": "image",
+            "shape": DUMMY_CHW,
+            "names": [
+                "channels",
+                "height",
+                "width",
+            ],
+        }
+    }
+    return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+
+
+def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
     """
     Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
     objects have the same sets of attributes defined.
@@ -76,14 +96,14 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
     assert init_attr == create_attr


-def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
+def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
     kwargs = {
         "repo_id": DUMMY_REPO_ID,
         "total_episodes": 10,
         "total_frames": 400,
         "episodes": [2, 5, 6],
     }
-    dataset = lerobot_dataset_factory(root=tmp_path, **kwargs)
+    dataset = lerobot_dataset_factory(root=tmp_path / "test", **kwargs)

     assert dataset.repo_id == kwargs["repo_id"]
     assert dataset.meta.total_episodes == kwargs["total_episodes"]
@@ -93,28 +113,243 @@ def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
     assert dataset.num_frames == len(dataset)


-def test_add_frame_no_task(tmp_path):
-    features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
-    dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
-    with pytest.raises(ValueError, match="The mandatory feature 'task' wasn't found in `frame` dictionnary."):
-        dataset.add_frame({"1d": torch.randn(1)})
+def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    with pytest.raises(
+        ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n"
+    ):
+        dataset.add_frame({"state": torch.randn(1)})


-def test_add_frame(tmp_path):
-    features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
-    dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
-    dataset.add_frame({"1d": torch.randn(1), "task": "dummy"})
+def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    with pytest.raises(
+        ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
+    ):
+        dataset.add_frame({"task": "Dummy task"})
+
+
+def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    with pytest.raises(
+        ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
+    ):
+        dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
+
+
+def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    with pytest.raises(
+        ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
+    ):
+        dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})
+
+
+def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    with pytest.raises(
+        ValueError,
+        match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
+    ):
+        dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
+
+
+def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'float'>' provided instead.\n"
+        ),
+    ):
+        dataset.add_frame({"state": 1.0, "task": "Dummy task"})
+
+
+def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    with pytest.raises(
+        ValueError,
+        match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
+    ):
+        dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})
+
+
+def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'numpy.float32'>' provided instead.\n"
+        ),
+    ):
+        dataset.add_frame({"state": np.float32(1.0), "task": "Dummy task"})
+
+
+def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"state": torch.randn(1), "task": "dummy"})
     dataset.save_episode(encode_videos=False)
     dataset.consolidate(run_compute_stats=False)

     assert len(dataset) == 1
     assert dataset[0]["task"] == "dummy"
     assert dataset[0]["task_index"] == 0
+    assert dataset[0]["state"].ndim == 0
+
+
+def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"state": torch.randn(2), "task": "dummy"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["state"].shape == torch.Size([2])
+
+
+def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"state": torch.randn(2, 4), "task": "dummy"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["state"].shape == torch.Size([2, 4])
+
+
+def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "dummy"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
+
+
+def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "dummy"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
+
+
+def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "dummy"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
+
+
+def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["state"].ndim == 0
+
+
+def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
+    features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["caption"] == "Dummy caption"
+
+
+def test_add_frame_image_wrong_shape(image_dataset):
+    dataset = image_dataset
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "The feature 'image' of shape '(3, 128, 96)' does not have the expected shape '(3, 96, 128)' or '(96, 128, 3)'.\n"
+        ),
+    ):
+        c, h, w = DUMMY_CHW
+        dataset.add_frame({"image": torch.randn(c, w, h), "task": "Dummy task"})
+
+
+def test_add_frame_image_wrong_range(image_dataset):
+    """This test will display the following error message from a thread:
+    ```
+    Error writing image ...test_add_frame_image_wrong_ran0/test/images/image/episode_000000/frame_000000.png:
+    The image data type is float, which requires values in the range [0.0, 1.0]. However, the provided range is [0.009678772038470007, 254.9776492089887].
+    Please adjust the range or provide a uint8 image with values in the range [0, 255]
+    ```
+    Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`.
+    """
+    dataset = image_dataset
+    dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"})
+    with pytest.raises(FileNotFoundError):
+        dataset.save_episode(encode_videos=False)
+
+
+def test_add_frame_image(image_dataset):
+    dataset = image_dataset
+    dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
+
+
+def test_add_frame_image_h_w_c(image_dataset):
+    dataset = image_dataset
+    dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
+
+
+def test_add_frame_image_uint8(image_dataset):
+    dataset = image_dataset
+    image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
+    dataset.add_frame({"image": image, "task": "Dummy task"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
+
+
+def test_add_frame_image_pil(image_dataset):
+    dataset = image_dataset
+    image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
+    dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
+
+
+def test_image_array_to_pil_image_wrong_range_float_0_255():
+    image = np.random.rand(*DUMMY_HWC) * 255
+    with pytest.raises(ValueError):
+        image_array_to_pil_image(image)


 # TODO(aliberts):
 # - [ ] test various attributes & state from init and create
 # - [ ] test init with episodes and check num_frames
-# - [ ] test add_frame
 # - [ ] test add_episode
 # - [ ] test consolidate
 # - [ ] test push_to_hub

View File

@@ -9,10 +9,11 @@ from PIL import Image

 from lerobot.common.datasets.image_writer import (
     AsyncImageWriter,
-    image_array_to_image,
+    image_array_to_pil_image,
     safe_stop_image_writer,
     write_image,
 )
+from tests.fixtures.constants import DUMMY_HWC

 DUMMY_IMAGE = "test_image.png"
@@ -48,49 +49,62 @@ def test_zero_threads():
         AsyncImageWriter(num_processes=0, num_threads=0)


-def test_image_array_to_image_rgb(img_array_factory):
+def test_image_array_to_pil_image_float_array_wrong_range_0_255():
+    image = np.random.rand(*DUMMY_HWC) * 255
+    with pytest.raises(ValueError):
+        image_array_to_pil_image(image)
+
+
+def test_image_array_to_pil_image_float_array_wrong_range_neg_1_1():
+    image = np.random.rand(*DUMMY_HWC) * 2 - 1
+    with pytest.raises(ValueError):
+        image_array_to_pil_image(image)
+
+
+def test_image_array_to_pil_image_rgb(img_array_factory):
     img_array = img_array_factory(100, 100)
-    result_image = image_array_to_image(img_array)
+    result_image = image_array_to_pil_image(img_array)
     assert isinstance(result_image, Image.Image)
     assert result_image.size == (100, 100)
     assert result_image.mode == "RGB"


-def test_image_array_to_image_pytorch_format(img_array_factory):
+def test_image_array_to_pil_image_pytorch_format(img_array_factory):
     img_array = img_array_factory(100, 100).transpose(2, 0, 1)
-    result_image = image_array_to_image(img_array)
+    result_image = image_array_to_pil_image(img_array)
     assert isinstance(result_image, Image.Image)
     assert result_image.size == (100, 100)
     assert result_image.mode == "RGB"


-@pytest.mark.skip("TODO: implement")
-def test_image_array_to_image_single_channel(img_array_factory):
+def test_image_array_to_pil_image_single_channel(img_array_factory):
     img_array = img_array_factory(channels=1)
-    result_image = image_array_to_image(img_array)
-    assert isinstance(result_image, Image.Image)
-    assert result_image.size == (100, 100)
-    assert result_image.mode == "L"
+    with pytest.raises(NotImplementedError):
+        image_array_to_pil_image(img_array)


-def test_image_array_to_image_float_array(img_array_factory):
+def test_image_array_to_pil_image_4_channels(img_array_factory):
+    img_array = img_array_factory(channels=4)
+    with pytest.raises(NotImplementedError):
+        image_array_to_pil_image(img_array)
+
+
+def test_image_array_to_pil_image_float_array(img_array_factory):
     img_array = img_array_factory(dtype=np.float32)
-    result_image = image_array_to_image(img_array)
+    result_image = image_array_to_pil_image(img_array)
     assert isinstance(result_image, Image.Image)
     assert result_image.size == (100, 100)
     assert result_image.mode == "RGB"
     assert np.array(result_image).dtype == np.uint8


-def test_image_array_to_image_out_of_bounds_float():
-    # Float array with values out of [0, 1]
-    img_array = np.random.uniform(-1, 2, size=(100, 100, 3)).astype(np.float32)
-    result_image = image_array_to_image(img_array)
+def test_image_array_to_pil_image_uint8_array(img_array_factory):
+    img_array = img_array_factory(dtype=np.uint8)
+    result_image = image_array_to_pil_image(img_array)
     assert isinstance(result_image, Image.Image)
     assert result_image.size == (100, 100)
     assert result_image.mode == "RGB"
     assert np.array(result_image).dtype == np.uint8
-    assert np.array(result_image).min() >= 0 and np.array(result_image).max() <= 255


 def test_write_image_numpy(tmp_path, img_array_factory):