From 7c2bbee6136cca964383f9c1eb3e4e82ab818312 Mon Sep 17 00:00:00 2001 From: Remi Date: Fri, 14 Feb 2025 19:59:48 +0100 Subject: [PATCH] Validate features during `add_frame` + Add 2D-to-5D + Add string (#720) --- lerobot/common/datasets/image_writer.py | 28 ++- lerobot/common/datasets/lerobot_dataset.py | 21 +- lerobot/common/datasets/utils.py | 103 +++++++- lerobot/common/utils/utils.py | 16 +- tests/fixtures/constants.py | 2 + tests/fixtures/dataset_factories.py | 18 ++ tests/test_datasets.py | 263 +++++++++++++++++++-- tests/test_image_writer.py | 50 ++-- 8 files changed, 448 insertions(+), 53 deletions(-) diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py index 85dd6830..6fc0ee2f 100644 --- a/lerobot/common/datasets/image_writer.py +++ b/lerobot/common/datasets/image_writer.py @@ -38,22 +38,40 @@ def safe_stop_image_writer(func): return wrapper -def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image: +def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image: # TODO(aliberts): handle 1 channel and 4 for depth images - if image_array.ndim == 3 and image_array.shape[0] in [1, 3]: + if image_array.ndim != 3: + raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.") + + if image_array.shape[0] == 3: # Transpose from pytorch convention (C, H, W) to (H, W, C) image_array = image_array.transpose(1, 2, 0) + + elif image_array.shape[-1] != 3: + raise NotImplementedError( + f"The image has {image_array.shape[-1]} channels, but 3 is required for now." + ) + if image_array.dtype != np.uint8: - # Assume the image is in [0, 1] range for floating-point data - image_array = np.clip(image_array, 0, 1) + if range_check: + max_ = image_array.max().item() + min_ = image_array.min().item() + if max_ > 1.0 or min_ < 0.0: + raise ValueError( + "The image data type is float, which requires values in the range [0.0, 1.0]. 
" + f"However, the provided range is [{min_}, {max_}]. Please adjust the range or " + "provide a uint8 image with values in the range [0, 255]." + ) + image_array = (image_array * 255).astype(np.uint8) + return PIL.Image.fromarray(image_array) def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path): try: if isinstance(image, np.ndarray): - img = image_array_to_image(image) + img = image_array_to_pil_image(image) elif isinstance(image, PIL.Image.Image): img = image else: diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 877703d7..5c4ae68e 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -39,6 +39,7 @@ from lerobot.common.datasets.utils import ( TASKS_PATH, append_jsonlines, check_delta_timestamps, + check_frame_features, check_timestamps_sync, check_version_compatibility, create_branch, @@ -724,10 +725,12 @@ class LeRobotDataset(torch.utils.data.Dataset): temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method then needs to be called. """ - # TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch, - # check the dtype and shape matches, etc. 
- if "task" not in frame: - raise ValueError("The mandatory feature 'task' wasn't found in `frame` dictionnary.") + # Convert torch to numpy if needed + for name in frame: + if isinstance(frame[name], torch.Tensor): + frame[name] = frame[name].numpy() + + check_frame_features(frame, self.features) if self.episode_buffer is None: self.episode_buffer = self.create_episode_buffer() @@ -757,8 +760,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self._save_image(frame[key], img_path) self.episode_buffer[key].append(str(img_path)) else: - item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key] - self.episode_buffer[key].append(item) + self.episode_buffer[key].append(frame[key]) self.episode_buffer["size"] += 1 @@ -815,12 +817,7 @@ class LeRobotDataset(torch.utils.data.Dataset): # are processed separately by storing image path and frame info as meta data if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: continue - elif len(ft["shape"]) == 1 and ft["shape"][0] == 1: - episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"]) - elif len(ft["shape"]) == 1 and ft["shape"][0] > 1: - episode_buffer[key] = np.stack(episode_buffer[key]) - else: - raise ValueError(key) + episode_buffer[key] = np.stack(episode_buffer[key]) self._wait_image_writer() self._save_episode_table(episode_buffer, episode_index) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index e6ec169e..505a5492 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -35,6 +35,7 @@ from PIL import Image as PILImage from torchvision import transforms from lerobot.common.robot_devices.robots.utils import Robot +from lerobot.common.utils.utils import is_valid_numpy_dtype_string from lerobot.configs.types import DictLike, FeatureType, PolicyFeature DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk @@ -203,7 +204,7 @@ def hf_transform_to_torch(items_dict: 
dict[torch.Tensor | None]): elif first_item is None: pass else: - items_dict[key] = [torch.tensor(x) for x in items_dict[key]] + items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]] return items_dict @@ -285,11 +286,20 @@ def get_hf_features_from_features(features: dict) -> datasets.Features: hf_features[key] = datasets.Image() elif ft["shape"] == (1,): hf_features[key] = datasets.Value(dtype=ft["dtype"]) - else: - assert len(ft["shape"]) == 1 + elif len(ft["shape"]) == 1: hf_features[key] = datasets.Sequence( length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"]) ) + elif len(ft["shape"]) == 2: + hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 3: + hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 4: + hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 5: + hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"]) + else: + raise ValueError(f"Corresponding feature is not valid: {ft}") return datasets.Features(hf_features) @@ -606,3 +616,90 @@ class IterableNamespace(SimpleNamespace): def keys(self): return vars(self).keys() + + +def check_frame_features(frame: dict, features: dict): + optional_features = {"timestamp"} + expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"} + actual_features = set(frame.keys()) + + error_message = check_features_presence(actual_features, expected_features, optional_features) + + if "task" in frame: + error_message += check_feature_string("task", frame["task"]) + + common_features = actual_features & (expected_features | optional_features) + for name in common_features - {"task"}: + error_message += check_feature_dtype_and_shape(name, features[name], frame[name]) + + if error_message: + raise ValueError(error_message) + + +def check_features_presence( + actual_features: set[str], 
expected_features: set[str], optional_features: set[str] +): + error_message = "" + missing_features = expected_features - actual_features + extra_features = actual_features - (expected_features | optional_features) + + if missing_features or extra_features: + error_message += "Feature mismatch in `frame` dictionary:\n" + if missing_features: + error_message += f"Missing features: {missing_features}\n" + if extra_features: + error_message += f"Extra features: {extra_features}\n" + + return error_message + + +def check_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str): + expected_dtype = feature["dtype"] + expected_shape = feature["shape"] + if is_valid_numpy_dtype_string(expected_dtype): + return check_feature_numpy_array(name, expected_dtype, expected_shape, value) + elif expected_dtype in ["image", "video"]: + return check_feature_image_or_video(name, expected_shape, value) + elif expected_dtype == "string": + return check_feature_string(name, value) + else: + raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.") + + +def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray): + error_message = "" + if isinstance(value, np.ndarray): + actual_dtype = value.dtype + actual_shape = value.shape + + if actual_dtype != np.dtype(expected_dtype): + error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n" + + if actual_shape != expected_shape: + error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n" + else: + error_message += f"The feature '{name}' is not a 'np.ndarray'. 
Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n" + + return error_message + + +def check_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image): + # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. + error_message = "" + if isinstance(value, np.ndarray): + actual_shape = value.shape + c, h, w = expected_shape + if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): + error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" + elif isinstance(value, PILImage.Image): + pass + else: + error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n" + + return error_message + + +def check_feature_string(name: str, value: str): + if not isinstance(value, str): + return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" + return "" diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index 015d1ede..d0c12b30 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -21,6 +21,7 @@ from copy import copy from datetime import datetime, timezone from pathlib import Path +import numpy as np import torch @@ -200,5 +201,18 @@ def get_channel_first_image_shape(image_shape: tuple) -> tuple: return shape -def has_method(cls: object, method_name: str): +def has_method(cls: object, method_name: str) -> bool: return hasattr(cls, method_name) and callable(getattr(cls, method_name)) + + +def is_valid_numpy_dtype_string(dtype_str: str) -> bool: + """ + Return True if a given string can be converted to a numpy dtype. 
+ """ + try: + # Attempt to convert the string to a numpy dtype + np.dtype(dtype_str) + return True + except TypeError: + # If a TypeError is raised, the string is not a valid dtype + return False diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py index bfe6c339..7d80d2b7 100644 --- a/tests/fixtures/constants.py +++ b/tests/fixtures/constants.py @@ -27,3 +27,5 @@ DUMMY_VIDEO_INFO = { "video.is_depth_map": False, "has_audio": False, } +DUMMY_CHW = (3, 96, 128) +DUMMY_HWC = (96, 128, 3) diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index c28a1165..bdd2dc54 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -17,6 +17,7 @@ from lerobot.common.datasets.utils import ( get_hf_features_from_features, hf_transform_to_torch, ) +from lerobot.common.robot_devices.robots.utils import Robot from tests.fixtures.constants import ( DEFAULT_FPS, DUMMY_CAMERA_FEATURES, @@ -394,3 +395,20 @@ def lerobot_dataset_factory( return LeRobotDataset(repo_id=repo_id, root=root, **kwargs) return _create_lerobot_dataset + + +@pytest.fixture(scope="session") +def empty_lerobot_dataset_factory(): + def _create_empty_lerobot_dataset( + root: Path, + repo_id: str = DUMMY_REPO_ID, + fps: int = DEFAULT_FPS, + robot: Robot | None = None, + robot_type: str | None = None, + features: dict | None = None, + ): + return LeRobotDataset.create( + repo_id=repo_id, fps=fps, root=root, robot=robot, robot_type=robot_type, features=features + ) + + return _create_empty_lerobot_dataset diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 460954c1..54d92125 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -15,15 +15,18 @@ # limitations under the License. 
import json import logging +import re from copy import deepcopy from itertools import chain from pathlib import Path import einops +import numpy as np import pytest import torch from datasets import Dataset from huggingface_hub import HfApi +from PIL import Image from safetensors.torch import load_file import lerobot @@ -33,6 +36,7 @@ from lerobot.common.datasets.compute_stats import ( get_stats_einops_patterns, ) from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.image_writer import image_array_to_pil_image from lerobot.common.datasets.lerobot_dataset import ( LeRobotDataset, MultiLeRobotDataset, @@ -49,11 +53,27 @@ from lerobot.common.robot_devices.robots.utils import make_robot from lerobot.common.utils.random_utils import seeded_context from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig -from tests.fixtures.constants import DUMMY_REPO_ID +from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID from tests.utils import DEVICE, require_x86_64_kernel -def test_same_attributes_defined(lerobot_dataset_factory, tmp_path): +@pytest.fixture +def image_dataset(tmp_path, empty_lerobot_dataset_factory): + features = { + "image": { + "dtype": "image", + "shape": DUMMY_CHW, + "names": [ + "channels", + "height", + "width", + ], + } + } + return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + + +def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): """ Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated objects have the same sets of attributes defined. 
@@ -76,14 +96,14 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path): assert init_attr == create_attr -def test_dataset_initialization(lerobot_dataset_factory, tmp_path): +def test_dataset_initialization(tmp_path, lerobot_dataset_factory): kwargs = { "repo_id": DUMMY_REPO_ID, "total_episodes": 10, "total_frames": 400, "episodes": [2, 5, 6], } - dataset = lerobot_dataset_factory(root=tmp_path, **kwargs) + dataset = lerobot_dataset_factory(root=tmp_path / "test", **kwargs) assert dataset.repo_id == kwargs["repo_id"] assert dataset.meta.total_episodes == kwargs["total_episodes"] @@ -93,28 +113,243 @@ def test_dataset_initialization(lerobot_dataset_factory, tmp_path): assert dataset.num_frames == len(dataset) -def test_add_frame_no_task(tmp_path): - features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}} - dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features) - with pytest.raises(ValueError, match="The mandatory feature 'task' wasn't found in `frame` dictionnary."): - dataset.add_frame({"1d": torch.randn(1)}) +def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + with pytest.raises( + ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n" + ): + dataset.add_frame({"state": torch.randn(1)}) -def test_add_frame(tmp_path): - features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}} - dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features) - dataset.add_frame({"1d": torch.randn(1), "task": "dummy"}) +def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", 
features=features) + with pytest.raises( + ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n" + ): + dataset.add_frame({"task": "Dummy task"}) + + +def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + with pytest.raises( + ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n" + ): + dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"}) + + +def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + with pytest.raises( + ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n" + ): + dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"}) + + +def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + with pytest.raises( + ValueError, + match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"), + ): + dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"}) + + +def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + with pytest.raises( + ValueError, + match=re.escape( + "The feature 'state' is not a 'np.ndarray'. 
Expected type is 'float32', but type '<class 'float'>' provided instead.\n" + ), + ): + dataset.add_frame({"state": 1.0, "task": "Dummy task"}) + + +def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + with pytest.raises( + ValueError, + match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"), + ): + dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"}) + + +def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + with pytest.raises( + ValueError, + match=re.escape( + "The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'numpy.float32'>' provided instead.\n" + ), + ): + dataset.add_frame({"state": np.float32(1.0), "task": "Dummy task"}) + + +def test_add_frame(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + dataset.add_frame({"state": torch.randn(1), "task": "dummy"}) dataset.save_episode(encode_videos=False) dataset.consolidate(run_compute_stats=False) + assert len(dataset) == 1 assert dataset[0]["task"] == "dummy" assert dataset[0]["task_index"] == 0 + assert dataset[0]["state"].ndim == 0 + + +def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + dataset.add_frame({"state": torch.randn(2), "task": "dummy"}) + dataset.save_episode(encode_videos=False) + dataset.consolidate(run_compute_stats=False) + +
assert dataset[0]["state"].shape == torch.Size([2]) + + +def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + dataset.add_frame({"state": torch.randn(2, 4), "task": "dummy"}) + dataset.save_episode(encode_videos=False) + dataset.consolidate(run_compute_stats=False) + + assert dataset[0]["state"].shape == torch.Size([2, 4]) + + +def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "dummy"}) + dataset.save_episode(encode_videos=False) + dataset.consolidate(run_compute_stats=False) + + assert dataset[0]["state"].shape == torch.Size([2, 4, 3]) + + +def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "dummy"}) + dataset.save_episode(encode_videos=False) + dataset.consolidate(run_compute_stats=False) + + assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5]) + + +def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "dummy"}) + dataset.save_episode(encode_videos=False) + dataset.consolidate(run_compute_stats=False) + + assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1]) + + +def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory): 
+ features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"}) + dataset.save_episode(encode_videos=False) + dataset.consolidate(run_compute_stats=False) + + assert dataset[0]["state"].ndim == 0 + + +def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory): + features = {"caption": {"dtype": "string", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"}) + dataset.save_episode(encode_videos=False) + dataset.consolidate(run_compute_stats=False) + + assert dataset[0]["caption"] == "Dummy caption" + + +def test_add_frame_image_wrong_shape(image_dataset): + dataset = image_dataset + with pytest.raises( + ValueError, + match=re.escape( + "The feature 'image' of shape '(3, 128, 96)' does not have the expected shape '(3, 96, 128)' or '(96, 128, 3)'.\n" + ), + ): + c, h, w = DUMMY_CHW + dataset.add_frame({"image": torch.randn(c, w, h), "task": "Dummy task"}) + + +def test_add_frame_image_wrong_range(image_dataset): + """This test will display the following error message from a thread: + ``` + Error writing image ...test_add_frame_image_wrong_ran0/test/images/image/episode_000000/frame_000000.png: + The image data type is float, which requires values in the range [0.0, 1.0]. However, the provided range is [0.009678772038470007, 254.9776492089887]. + Please adjust the range or provide a uint8 image with values in the range [0, 255] + ``` + Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`. 
+ """ + dataset = image_dataset + dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"}) + with pytest.raises(FileNotFoundError): + dataset.save_episode(encode_videos=False) + + +def test_add_frame_image(image_dataset): + dataset = image_dataset + dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) + dataset.save_episode(encode_videos=False) + dataset.consolidate(run_compute_stats=False) + + assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) + + +def test_add_frame_image_h_w_c(image_dataset): + dataset = image_dataset + dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"}) + dataset.save_episode(encode_videos=False) + dataset.consolidate(run_compute_stats=False) + + assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) + + +def test_add_frame_image_uint8(image_dataset): + dataset = image_dataset + image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) + dataset.add_frame({"image": image, "task": "Dummy task"}) + dataset.save_episode(encode_videos=False) + dataset.consolidate(run_compute_stats=False) + + assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) + + +def test_add_frame_image_pil(image_dataset): + dataset = image_dataset + image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) + dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"}) + dataset.save_episode(encode_videos=False) + dataset.consolidate(run_compute_stats=False) + + assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) + + +def test_image_array_to_pil_image_wrong_range_float_0_255(): + image = np.random.rand(*DUMMY_HWC) * 255 + with pytest.raises(ValueError): + image_array_to_pil_image(image) # TODO(aliberts): # - [ ] test various attributes & state from init and create # - [ ] test init with episodes and check num_frames -# - [ ] test add_frame # - [ ] test add_episode # - [ ] test consolidate # - [ ] test push_to_hub diff --git a/tests/test_image_writer.py 
b/tests/test_image_writer.py index f51e86b4..c7fc11f2 100644 --- a/tests/test_image_writer.py +++ b/tests/test_image_writer.py @@ -9,10 +9,11 @@ from PIL import Image from lerobot.common.datasets.image_writer import ( AsyncImageWriter, - image_array_to_image, + image_array_to_pil_image, safe_stop_image_writer, write_image, ) +from tests.fixtures.constants import DUMMY_HWC DUMMY_IMAGE = "test_image.png" @@ -48,49 +49,62 @@ def test_zero_threads(): AsyncImageWriter(num_processes=0, num_threads=0) -def test_image_array_to_image_rgb(img_array_factory): +def test_image_array_to_pil_image_float_array_wrong_range_0_255(): + image = np.random.rand(*DUMMY_HWC) * 255 + with pytest.raises(ValueError): + image_array_to_pil_image(image) + + +def test_image_array_to_pil_image_float_array_wrong_range_neg_1_1(): + image = np.random.rand(*DUMMY_HWC) * 2 - 1 + with pytest.raises(ValueError): + image_array_to_pil_image(image) + + +def test_image_array_to_pil_image_rgb(img_array_factory): img_array = img_array_factory(100, 100) - result_image = image_array_to_image(img_array) + result_image = image_array_to_pil_image(img_array) assert isinstance(result_image, Image.Image) assert result_image.size == (100, 100) assert result_image.mode == "RGB" -def test_image_array_to_image_pytorch_format(img_array_factory): +def test_image_array_to_pil_image_pytorch_format(img_array_factory): img_array = img_array_factory(100, 100).transpose(2, 0, 1) - result_image = image_array_to_image(img_array) + result_image = image_array_to_pil_image(img_array) assert isinstance(result_image, Image.Image) assert result_image.size == (100, 100) assert result_image.mode == "RGB" -@pytest.mark.skip("TODO: implement") -def test_image_array_to_image_single_channel(img_array_factory): +def test_image_array_to_pil_image_single_channel(img_array_factory): img_array = img_array_factory(channels=1) - result_image = image_array_to_image(img_array) - assert isinstance(result_image, Image.Image) - assert result_image.size 
== (100, 100) - assert result_image.mode == "L" + with pytest.raises(NotImplementedError): + image_array_to_pil_image(img_array) -def test_image_array_to_image_float_array(img_array_factory): +def test_image_array_to_pil_image_4_channels(img_array_factory): + img_array = img_array_factory(channels=4) + with pytest.raises(NotImplementedError): + image_array_to_pil_image(img_array) + + +def test_image_array_to_pil_image_float_array(img_array_factory): img_array = img_array_factory(dtype=np.float32) - result_image = image_array_to_image(img_array) + result_image = image_array_to_pil_image(img_array) assert isinstance(result_image, Image.Image) assert result_image.size == (100, 100) assert result_image.mode == "RGB" assert np.array(result_image).dtype == np.uint8 -def test_image_array_to_image_out_of_bounds_float(): - # Float array with values out of [0, 1] - img_array = np.random.uniform(-1, 2, size=(100, 100, 3)).astype(np.float32) - result_image = image_array_to_image(img_array) +def test_image_array_to_pil_image_uint8_array(img_array_factory): + img_array = img_array_factory(dtype=np.float32) + result_image = image_array_to_pil_image(img_array) assert isinstance(result_image, Image.Image) assert result_image.size == (100, 100) assert result_image.mode == "RGB" assert np.array(result_image).dtype == np.uint8 - assert np.array(result_image).min() >= 0 and np.array(result_image).max() <= 255 def test_write_image_numpy(tmp_path, img_array_factory):