Validate features during `add_frame` + Add 2D-to-5D + Add string (#720)
commit 7c2bbee613 (parent 9d6886dd08)

@@ -38,22 +38,40 @@ def safe_stop_image_writer(func):
     return wrapper


-def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
+def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
     # TODO(aliberts): handle 1 channel and 4 for depth images
-    if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
+    if image_array.ndim != 3:
+        raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
+
+    if image_array.shape[0] == 3:
         # Transpose from pytorch convention (C, H, W) to (H, W, C)
         image_array = image_array.transpose(1, 2, 0)
+    elif image_array.shape[-1] != 3:
+        raise NotImplementedError(
+            f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
+        )
+
     if image_array.dtype != np.uint8:
-        # Assume the image is in [0, 1] range for floating-point data
-        image_array = np.clip(image_array, 0, 1)
+        if range_check:
+            max_ = image_array.max().item()
+            min_ = image_array.min().item()
+            if max_ > 1.0 or min_ < 0.0:
+                raise ValueError(
+                    "The image data type is float, which requires values in the range [0.0, 1.0]. "
+                    f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
+                    "provide a uint8 image with values in the range [0, 255]."
+                )
+
         image_array = (image_array * 255).astype(np.uint8)
+
     return PIL.Image.fromarray(image_array)


 def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
     try:
         if isinstance(image, np.ndarray):
-            img = image_array_to_image(image)
+            img = image_array_to_pil_image(image)
         elif isinstance(image, PIL.Image.Image):
             img = image
         else:
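
For illustration (not part of the diff), a minimal sketch of the renamed helper's new behavior, assuming the function exactly as defined above; the array values are arbitrary:

import numpy as np
from lerobot.common.datasets.image_writer import image_array_to_pil_image

# Channel-first float image in [0, 1]: transposed to (H, W, C) and scaled to uint8.
img = image_array_to_pil_image(np.random.rand(3, 96, 128))
assert img.size == (128, 96) and img.mode == "RGB"

# Float values outside [0.0, 1.0] now raise instead of being silently clipped
# (unless range_check=False is passed).
try:
    image_array_to_pil_image(np.random.rand(96, 128, 3) * 255)
except ValueError as e:
    print(e)
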
@@ -39,6 +39,7 @@ from lerobot.common.datasets.utils import (
     TASKS_PATH,
     append_jsonlines,
     check_delta_timestamps,
+    check_frame_features,
     check_timestamps_sync,
     check_version_compatibility,
     create_branch,
@@ -724,10 +725,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
         temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
         then needs to be called.
         """
-        # TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
-        # check the dtype and shape matches, etc.
-        if "task" not in frame:
-            raise ValueError("The mandatory feature 'task' wasn't found in `frame` dictionnary.")
+        # Convert torch to numpy if needed
+        for name in frame:
+            if isinstance(frame[name], torch.Tensor):
+                frame[name] = frame[name].numpy()
+
+        check_frame_features(frame, self.features)

         if self.episode_buffer is None:
             self.episode_buffer = self.create_episode_buffer()
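
In practice, `add_frame` now fails fast: tensors are converted to numpy and the whole frame is validated against the dataset's declared features before anything is buffered. A sketch (repo_id and root below are placeholder values, and the behavior shown is the one exercised by the tests in this PR):

import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = LeRobotDataset.create(repo_id="user/dummy", fps=30, root="/tmp/dummy", features=features)

# Torch tensors are converted to numpy, then checked against `features`.
dataset.add_frame({"state": torch.randn(2), "task": "dummy"})

# A malformed frame is rejected at add_frame time instead of at save time.
dataset.add_frame({"state": torch.randn(3), "task": "dummy"})  # wrong shape: raises ValueError
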
@@ -757,8 +760,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 self._save_image(frame[key], img_path)
                 self.episode_buffer[key].append(str(img_path))
             else:
-                item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
-                self.episode_buffer[key].append(item)
+                self.episode_buffer[key].append(frame[key])

         self.episode_buffer["size"] += 1
@@ -815,12 +817,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             # are processed separately by storing image path and frame info as meta data
             if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
                 continue
-            elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
-                episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
-            elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
-                episode_buffer[key] = np.stack(episode_buffer[key])
-            else:
-                raise ValueError(key)
+            episode_buffer[key] = np.stack(episode_buffer[key])

         self._wait_image_writer()
         self._save_episode_table(episode_buffer, episode_index)
@@ -35,6 +35,7 @@ from PIL import Image as PILImage
 from torchvision import transforms

 from lerobot.common.robot_devices.robots.utils import Robot
+from lerobot.common.utils.utils import is_valid_numpy_dtype_string
 from lerobot.configs.types import DictLike, FeatureType, PolicyFeature

 DEFAULT_CHUNK_SIZE = 1000  # Max number of episodes per chunk
@@ -203,7 +204,7 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
         elif first_item is None:
             pass
         else:
-            items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
+            items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
     return items_dict
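
A quick sketch of this transform change (assuming the function as patched above): string values now pass through untouched, while everything else still becomes a tensor:

import torch
from lerobot.common.datasets.utils import hf_transform_to_torch

items = {"state": [[1.0, 2.0]], "caption": ["Dummy caption"]}
items = hf_transform_to_torch(items)
assert isinstance(items["state"][0], torch.Tensor)
assert items["caption"][0] == "Dummy caption"  # str kept as-is; torch.tensor(str) would fail
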
@@ -285,11 +286,20 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
             hf_features[key] = datasets.Image()
         elif ft["shape"] == (1,):
             hf_features[key] = datasets.Value(dtype=ft["dtype"])
-        else:
-            assert len(ft["shape"]) == 1
+        elif len(ft["shape"]) == 1:
             hf_features[key] = datasets.Sequence(
                 length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
             )
+        elif len(ft["shape"]) == 2:
+            hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
+        elif len(ft["shape"]) == 3:
+            hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
+        elif len(ft["shape"]) == 4:
+            hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
+        elif len(ft["shape"]) == 5:
+            hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
+        else:
+            raise ValueError(f"Corresponding feature is not valid: {ft}")

     return datasets.Features(hf_features)
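
With this extension, declared feature shapes of rank 2 through 5 map onto the matching `datasets.ArrayND` type (these all exist in the `datasets` library), and `string` features with shape `(1,)` become plain `Value` columns. A sketch, with arbitrary feature names:

from lerobot.common.datasets.utils import get_hf_features_from_features

features = {
    "state_2d": {"dtype": "float32", "shape": (2, 4), "names": None},
    "state_5d": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None},
    "caption": {"dtype": "string", "shape": (1,), "names": None},
}
hf_features = get_hf_features_from_features(features)
print(hf_features["state_2d"])  # -> Array2D(shape=(2, 4), dtype='float32')
print(hf_features["state_5d"])  # -> Array5D(shape=(2, 4, 3, 5, 1), dtype='float32')
print(hf_features["caption"])   # -> Value(dtype='string')
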
@@ -606,3 +616,90 @@ class IterableNamespace(SimpleNamespace):
     def keys(self):
         return vars(self).keys()
+
+
+def check_frame_features(frame: dict, features: dict):
+    optional_features = {"timestamp"}
+    expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
+    actual_features = set(frame.keys())
+
+    error_message = check_features_presence(actual_features, expected_features, optional_features)
+
+    if "task" in frame:
+        error_message += check_feature_string("task", frame["task"])
+
+    common_features = actual_features & (expected_features | optional_features)
+    for name in common_features - {"task"}:
+        error_message += check_feature_dtype_and_shape(name, features[name], frame[name])
+
+    if error_message:
+        raise ValueError(error_message)
+
+
+def check_features_presence(
+    actual_features: set[str], expected_features: set[str], optional_features: set[str]
+):
+    error_message = ""
+    missing_features = expected_features - actual_features
+    extra_features = actual_features - (expected_features | optional_features)
+
+    if missing_features or extra_features:
+        error_message += "Feature mismatch in `frame` dictionary:\n"
+        if missing_features:
+            error_message += f"Missing features: {missing_features}\n"
+        if extra_features:
+            error_message += f"Extra features: {extra_features}\n"
+
+    return error_message
+
+
+def check_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
+    expected_dtype = feature["dtype"]
+    expected_shape = feature["shape"]
+    if is_valid_numpy_dtype_string(expected_dtype):
+        return check_feature_numpy_array(name, expected_dtype, expected_shape, value)
+    elif expected_dtype in ["image", "video"]:
+        return check_feature_image_or_video(name, expected_shape, value)
+    elif expected_dtype == "string":
+        return check_feature_string(name, value)
+    else:
+        raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
+
+
+def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray):
+    error_message = ""
+    if isinstance(value, np.ndarray):
+        actual_dtype = value.dtype
+        actual_shape = value.shape
+
+        if actual_dtype != np.dtype(expected_dtype):
+            error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
+
+        if actual_shape != expected_shape:
+            error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
+    else:
+        error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
+
+    return error_message
+
+
+def check_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
+    # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
+    error_message = ""
+    if isinstance(value, np.ndarray):
+        actual_shape = value.shape
+        c, h, w = expected_shape
+        if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
+            error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
+    elif isinstance(value, PILImage.Image):
+        pass
+    else:
+        error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
+
+    return error_message
+
+
+def check_feature_string(name: str, value: str):
+    if not isinstance(value, str):
+        return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
+    return ""
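
The validators accumulate every problem into one message before raising, so a single report can cover presence, dtype, and shape at once. A sketch, assuming the helpers above:

import numpy as np
from lerobot.common.datasets.utils import check_frame_features

features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
frame = {"state": np.zeros(3, dtype=np.float16)}  # wrong dtype, wrong shape, and missing "task"

try:
    check_frame_features(frame, features)
except ValueError as e:
    print(e)
    # Feature mismatch in `frame` dictionary:
    # Missing features: {'task'}
    # The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.
    # The feature 'state' of shape '(3,)' does not have the expected shape '(2,)'.
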
@@ -21,6 +21,7 @@ from copy import copy
 from datetime import datetime, timezone
 from pathlib import Path

+import numpy as np
 import torch

@@ -200,5 +201,18 @@ def get_channel_first_image_shape(image_shape: tuple) -> tuple:
     return shape


-def has_method(cls: object, method_name: str):
+def has_method(cls: object, method_name: str) -> bool:
     return hasattr(cls, method_name) and callable(getattr(cls, method_name))
+
+
+def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
+    """
+    Return True if a given string can be converted to a numpy dtype.
+    """
+    try:
+        # Attempt to convert the string to a numpy dtype
+        np.dtype(dtype_str)
+        return True
+    except TypeError:
+        # If a TypeError is raised, the string is not a valid dtype
+        return False
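
This predicate is how `check_feature_dtype_and_shape` decides whether a declared dtype should be validated as a numpy array; "string" is not a valid numpy dtype, which is what routes string features to the dedicated string check. For example:

from lerobot.common.utils.utils import is_valid_numpy_dtype_string

assert is_valid_numpy_dtype_string("float32")
assert not is_valid_numpy_dtype_string("string")  # np.dtype("string") raises TypeError
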
@@ -27,3 +27,5 @@ DUMMY_VIDEO_INFO = {
     "video.is_depth_map": False,
     "has_audio": False,
 }
+DUMMY_CHW = (3, 96, 128)
+DUMMY_HWC = (96, 128, 3)
@@ -17,6 +17,7 @@ from lerobot.common.datasets.utils import (
     get_hf_features_from_features,
     hf_transform_to_torch,
 )
+from lerobot.common.robot_devices.robots.utils import Robot
 from tests.fixtures.constants import (
     DEFAULT_FPS,
     DUMMY_CAMERA_FEATURES,

@@ -394,3 +395,20 @@ def lerobot_dataset_factory(
         return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)

     return _create_lerobot_dataset
+
+
+@pytest.fixture(scope="session")
+def empty_lerobot_dataset_factory():
+    def _create_empty_lerobot_dataset(
+        root: Path,
+        repo_id: str = DUMMY_REPO_ID,
+        fps: int = DEFAULT_FPS,
+        robot: Robot | None = None,
+        robot_type: str | None = None,
+        features: dict | None = None,
+    ):
+        return LeRobotDataset.create(
+            repo_id=repo_id, fps=fps, root=root, robot=robot, robot_type=robot_type, features=features
+        )
+
+    return _create_empty_lerobot_dataset
@@ -15,15 +15,18 @@
 # limitations under the License.
 import json
 import logging
+import re
 from copy import deepcopy
 from itertools import chain
 from pathlib import Path

 import einops
+import numpy as np
 import pytest
 import torch
 from datasets import Dataset
 from huggingface_hub import HfApi
+from PIL import Image
 from safetensors.torch import load_file

 import lerobot
@@ -33,6 +36,7 @@ from lerobot.common.datasets.compute_stats import (
     get_stats_einops_patterns,
 )
 from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.image_writer import image_array_to_pil_image
 from lerobot.common.datasets.lerobot_dataset import (
     LeRobotDataset,
     MultiLeRobotDataset,
@@ -49,11 +53,27 @@ from lerobot.common.robot_devices.robots.utils import make_robot
 from lerobot.common.utils.random_utils import seeded_context
 from lerobot.configs.default import DatasetConfig
 from lerobot.configs.train import TrainPipelineConfig
-from tests.fixtures.constants import DUMMY_REPO_ID
+from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
 from tests.utils import DEVICE, require_x86_64_kernel


-def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
+@pytest.fixture
+def image_dataset(tmp_path, empty_lerobot_dataset_factory):
+    features = {
+        "image": {
+            "dtype": "image",
+            "shape": DUMMY_CHW,
+            "names": [
+                "channels",
+                "height",
+                "width",
+            ],
+        }
+    }
+    return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+
+
+def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
     """
     Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
     objects have the same sets of attributes defined.
@@ -76,14 +96,14 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
     assert init_attr == create_attr


-def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
+def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
     kwargs = {
         "repo_id": DUMMY_REPO_ID,
         "total_episodes": 10,
         "total_frames": 400,
         "episodes": [2, 5, 6],
     }
-    dataset = lerobot_dataset_factory(root=tmp_path, **kwargs)
+    dataset = lerobot_dataset_factory(root=tmp_path / "test", **kwargs)

     assert dataset.repo_id == kwargs["repo_id"]
     assert dataset.meta.total_episodes == kwargs["total_episodes"]
@@ -93,28 +113,243 @@ def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
     assert dataset.num_frames == len(dataset)


-def test_add_frame_no_task(tmp_path):
-    features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
-    dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
-    with pytest.raises(ValueError, match="The mandatory feature 'task' wasn't found in `frame` dictionnary."):
-        dataset.add_frame({"1d": torch.randn(1)})
-
-
-def test_add_frame(tmp_path):
-    features = {"1d": {"dtype": "float32", "shape": (1,), "names": None}}
-    dataset = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, root=tmp_path / "test", features=features)
-    dataset.add_frame({"1d": torch.randn(1), "task": "dummy"})
+def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    with pytest.raises(
+        ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n"
+    ):
+        dataset.add_frame({"state": torch.randn(1)})
+
+
+def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    with pytest.raises(
+        ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
+    ):
+        dataset.add_frame({"task": "Dummy task"})
+
+
+def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    with pytest.raises(
+        ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
+    ):
+        dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
+
+
+def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    with pytest.raises(
+        ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
+    ):
+        dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})
+
+
+def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    with pytest.raises(
+        ValueError,
+        match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
+    ):
+        dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
+
+
+def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'float'>' provided instead.\n"
+        ),
+    ):
+        dataset.add_frame({"state": 1.0, "task": "Dummy task"})
+
+
+def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    with pytest.raises(
+        ValueError,
+        match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
+    ):
+        dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})
+
+
+def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'numpy.float32'>' provided instead.\n"
+        ),
+    ):
+        dataset.add_frame({"state": np.float32(1.0), "task": "Dummy task"})
+
+
+def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"state": torch.randn(1), "task": "dummy"})
     dataset.save_episode(encode_videos=False)
     dataset.consolidate(run_compute_stats=False)

     assert len(dataset) == 1
     assert dataset[0]["task"] == "dummy"
     assert dataset[0]["task_index"] == 0
+    assert dataset[0]["state"].ndim == 0
+
+
+def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"state": torch.randn(2), "task": "dummy"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["state"].shape == torch.Size([2])
+
+
+def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"state": torch.randn(2, 4), "task": "dummy"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["state"].shape == torch.Size([2, 4])
+
+
+def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "dummy"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
+
+
+def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "dummy"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
+
+
+def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "dummy"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
+
+
+def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
+    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["state"].ndim == 0
+
+
+def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
+    features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["caption"] == "Dummy caption"
+
+
+def test_add_frame_image_wrong_shape(image_dataset):
+    dataset = image_dataset
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "The feature 'image' of shape '(3, 128, 96)' does not have the expected shape '(3, 96, 128)' or '(96, 128, 3)'.\n"
+        ),
+    ):
+        c, h, w = DUMMY_CHW
+        dataset.add_frame({"image": torch.randn(c, w, h), "task": "Dummy task"})
+
+
+def test_add_frame_image_wrong_range(image_dataset):
+    """This test will display the following error message from a thread:
+    ```
+    Error writing image ...test_add_frame_image_wrong_ran0/test/images/image/episode_000000/frame_000000.png:
+    The image data type is float, which requires values in the range [0.0, 1.0]. However, the provided range is [0.009678772038470007, 254.9776492089887].
+    Please adjust the range or provide a uint8 image with values in the range [0, 255]
+    ```
+    Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`.
+    """
+    dataset = image_dataset
+    dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"})
+    with pytest.raises(FileNotFoundError):
+        dataset.save_episode(encode_videos=False)
+
+
+def test_add_frame_image(image_dataset):
+    dataset = image_dataset
+    dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
+
+
+def test_add_frame_image_h_w_c(image_dataset):
+    dataset = image_dataset
+    dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
+
+
+def test_add_frame_image_uint8(image_dataset):
+    dataset = image_dataset
+    image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
+    dataset.add_frame({"image": image, "task": "Dummy task"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
+
+
+def test_add_frame_image_pil(image_dataset):
+    dataset = image_dataset
+    image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
+    dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
+    dataset.save_episode(encode_videos=False)
+    dataset.consolidate(run_compute_stats=False)
+
+    assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
+
+
+def test_image_array_to_pil_image_wrong_range_float_0_255():
+    image = np.random.rand(*DUMMY_HWC) * 255
+    with pytest.raises(ValueError):
+        image_array_to_pil_image(image)
+
+
 # TODO(aliberts):
 # - [ ] test various attributes & state from init and create
 # - [ ] test init with episodes and check num_frames
-# - [ ] test add_frame
 # - [ ] test add_episode
 # - [ ] test consolidate
 # - [ ] test push_to_hub
@@ -9,10 +9,11 @@ from PIL import Image

 from lerobot.common.datasets.image_writer import (
     AsyncImageWriter,
-    image_array_to_image,
+    image_array_to_pil_image,
     safe_stop_image_writer,
     write_image,
 )
+from tests.fixtures.constants import DUMMY_HWC

 DUMMY_IMAGE = "test_image.png"
@@ -48,49 +49,62 @@ def test_zero_threads():
         AsyncImageWriter(num_processes=0, num_threads=0)


-def test_image_array_to_image_rgb(img_array_factory):
+def test_image_array_to_pil_image_float_array_wrong_range_0_255():
+    image = np.random.rand(*DUMMY_HWC) * 255
+    with pytest.raises(ValueError):
+        image_array_to_pil_image(image)
+
+
+def test_image_array_to_pil_image_float_array_wrong_range_neg_1_1():
+    image = np.random.rand(*DUMMY_HWC) * 2 - 1
+    with pytest.raises(ValueError):
+        image_array_to_pil_image(image)
+
+
+def test_image_array_to_pil_image_rgb(img_array_factory):
     img_array = img_array_factory(100, 100)
-    result_image = image_array_to_image(img_array)
+    result_image = image_array_to_pil_image(img_array)
     assert isinstance(result_image, Image.Image)
     assert result_image.size == (100, 100)
     assert result_image.mode == "RGB"


-def test_image_array_to_image_pytorch_format(img_array_factory):
+def test_image_array_to_pil_image_pytorch_format(img_array_factory):
     img_array = img_array_factory(100, 100).transpose(2, 0, 1)
-    result_image = image_array_to_image(img_array)
+    result_image = image_array_to_pil_image(img_array)
     assert isinstance(result_image, Image.Image)
     assert result_image.size == (100, 100)
     assert result_image.mode == "RGB"


-@pytest.mark.skip("TODO: implement")
-def test_image_array_to_image_single_channel(img_array_factory):
+def test_image_array_to_pil_image_single_channel(img_array_factory):
     img_array = img_array_factory(channels=1)
-    result_image = image_array_to_image(img_array)
-    assert isinstance(result_image, Image.Image)
-    assert result_image.size == (100, 100)
-    assert result_image.mode == "L"
+    with pytest.raises(NotImplementedError):
+        image_array_to_pil_image(img_array)


-def test_image_array_to_image_float_array(img_array_factory):
+def test_image_array_to_pil_image_4_channels(img_array_factory):
+    img_array = img_array_factory(channels=4)
+    with pytest.raises(NotImplementedError):
+        image_array_to_pil_image(img_array)
+
+
+def test_image_array_to_pil_image_float_array(img_array_factory):
     img_array = img_array_factory(dtype=np.float32)
-    result_image = image_array_to_image(img_array)
+    result_image = image_array_to_pil_image(img_array)
     assert isinstance(result_image, Image.Image)
     assert result_image.size == (100, 100)
     assert result_image.mode == "RGB"
     assert np.array(result_image).dtype == np.uint8


-def test_image_array_to_image_out_of_bounds_float():
-    # Float array with values out of [0, 1]
-    img_array = np.random.uniform(-1, 2, size=(100, 100, 3)).astype(np.float32)
-    result_image = image_array_to_image(img_array)
+def test_image_array_to_pil_image_uint8_array(img_array_factory):
+    img_array = img_array_factory(dtype=np.float32)
+    result_image = image_array_to_pil_image(img_array)
     assert isinstance(result_image, Image.Image)
     assert result_image.size == (100, 100)
     assert result_image.mode == "RGB"
     assert np.array(result_image).dtype == np.uint8
-    assert np.array(result_image).min() >= 0 and np.array(result_image).max() <= 255


 def test_write_image_numpy(tmp_path, img_array_factory):