Remove dataset `consolidate` (#752)
This commit is contained in:
parent
6fe42a72db
commit
969ef745a2
|
@ -200,8 +200,6 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
|||
|
||||
dataset.save_episode()
|
||||
|
||||
dataset.consolidate()
|
||||
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub()
|
||||
hub_api = HfApi()
|
||||
|
|
|
@ -39,7 +39,6 @@ from lerobot.common.datasets.utils import (
|
|||
append_jsonlines,
|
||||
backward_compatible_episodes_stats,
|
||||
check_delta_timestamps,
|
||||
check_frame_features,
|
||||
check_timestamps_sync,
|
||||
check_version_compatibility,
|
||||
create_empty_dataset_info,
|
||||
|
@ -55,6 +54,8 @@ from lerobot.common.datasets.utils import (
|
|||
load_info,
|
||||
load_stats,
|
||||
load_tasks,
|
||||
validate_episode_buffer,
|
||||
validate_frame,
|
||||
write_episode,
|
||||
write_episode_stats,
|
||||
write_info,
|
||||
|
@ -256,6 +257,9 @@ class LeRobotDatasetMetadata:
|
|||
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
self.info["total_videos"] += len(self.video_keys)
|
||||
if len(self.video_keys) > 0:
|
||||
self.update_video_info()
|
||||
|
||||
write_info(self.info, self.root)
|
||||
|
||||
episode_dict = {
|
||||
|
@ -270,7 +274,7 @@ class LeRobotDatasetMetadata:
|
|||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
|
||||
write_episode_stats(episode_index, episode_stats, self.root)
|
||||
|
||||
def write_video_info(self) -> None:
|
||||
def update_video_info(self) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
|
@ -280,8 +284,6 @@ class LeRobotDatasetMetadata:
|
|||
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
|
||||
self.info["features"][key]["info"] = get_video_info(video_path)
|
||||
|
||||
write_json(self.info, self.root / INFO_PATH)
|
||||
|
||||
def __repr__(self):
|
||||
feature_keys = list(self.features)
|
||||
return (
|
||||
|
@ -506,9 +508,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
# Available stats implies all videos have been encoded and dataset is iterable
|
||||
self.consolidated = self.meta.stats is not None
|
||||
|
||||
def push_to_hub(
|
||||
self,
|
||||
branch: str | None = None,
|
||||
|
@ -519,13 +518,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
allow_patterns: list[str] | str | None = None,
|
||||
**card_kwargs,
|
||||
) -> None:
|
||||
if not self.consolidated:
|
||||
logging.warning(
|
||||
"You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. "
|
||||
"Consolidating first."
|
||||
)
|
||||
self.consolidate()
|
||||
|
||||
ignore_patterns = ["images/"]
|
||||
if not push_videos:
|
||||
ignore_patterns.append("videos/")
|
||||
|
@ -779,7 +771,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
if isinstance(frame[name], torch.Tensor):
|
||||
frame[name] = frame[name].numpy()
|
||||
|
||||
check_frame_features(frame, self.features)
|
||||
validate_frame(frame, self.features)
|
||||
|
||||
if self.episode_buffer is None:
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
@ -815,41 +807,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
def save_episode(self, encode_videos: bool = True, episode_data: dict | None = None) -> None:
|
||||
def save_episode(self, episode_data: dict | None = None) -> None:
|
||||
"""
|
||||
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
|
||||
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
|
||||
the hub.
|
||||
This will save to disk the current episode in self.episode_buffer.
|
||||
|
||||
Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
|
||||
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
|
||||
time for video encoding.
|
||||
Args:
|
||||
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
|
||||
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
||||
None.
|
||||
"""
|
||||
if not episode_data:
|
||||
episode_buffer = self.episode_buffer
|
||||
|
||||
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
|
||||
|
||||
# size and task are special cases that won't be added to hf_dataset
|
||||
episode_length = episode_buffer.pop("size")
|
||||
tasks = episode_buffer.pop("task")
|
||||
episode_tasks = list(set(tasks))
|
||||
|
||||
episode_index = episode_buffer["episode_index"]
|
||||
if episode_index != self.meta.total_episodes:
|
||||
# TODO(aliberts): Add option to use existing episode_index
|
||||
raise NotImplementedError(
|
||||
"You might have manually provided the episode_buffer with an episode_index that doesn't "
|
||||
"match the total number of episodes already in the dataset. This is not supported for now."
|
||||
)
|
||||
|
||||
if episode_length == 0:
|
||||
raise ValueError(
|
||||
"You must add one or several frames with `add_frame` before calling `add_episode`."
|
||||
)
|
||||
|
||||
if not set(episode_buffer.keys()) == set(self.features):
|
||||
raise ValueError(
|
||||
f"Features from `episode_buffer` don't match the ones in `self.features`: '{set(episode_buffer.keys())}' vs '{set(self.features)}'"
|
||||
)
|
||||
|
||||
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
|
||||
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
|
||||
|
@ -875,16 +851,29 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
|
||||
|
||||
if encode_videos and len(self.meta.video_keys) > 0:
|
||||
if len(self.meta.video_keys) > 0:
|
||||
video_paths = self.encode_episode_videos(episode_index)
|
||||
for key in self.meta.video_keys:
|
||||
episode_buffer[key] = video_paths[key]
|
||||
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
|
||||
video_files = list(self.root.rglob("*.mp4"))
|
||||
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
||||
|
||||
parquet_files = list(self.root.rglob("*.parquet"))
|
||||
assert len(parquet_files) == self.num_episodes
|
||||
|
||||
# delete images
|
||||
img_dir = self.root / "images"
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(self.root / "images")
|
||||
|
||||
if not episode_data: # Reset the buffer
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
self.consolidated = False
|
||||
|
||||
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
|
||||
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
|
||||
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
|
||||
|
@ -959,28 +948,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
return video_paths
|
||||
|
||||
def consolidate(self, keep_image_files: bool = False) -> None:
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
|
||||
if len(self.meta.video_keys) > 0:
|
||||
self.encode_videos()
|
||||
self.meta.write_video_info()
|
||||
|
||||
if not keep_image_files:
|
||||
img_dir = self.root / "images"
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(self.root / "images")
|
||||
|
||||
video_files = list(self.root.rglob("*.mp4"))
|
||||
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
|
||||
|
||||
parquet_files = list(self.root.rglob("*.parquet"))
|
||||
assert len(parquet_files) == self.num_episodes
|
||||
|
||||
self.consolidated = True
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
|
@ -1019,12 +986,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
||||
obj.episode_buffer = obj.create_episode_buffer()
|
||||
|
||||
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
|
||||
# is used to know when certain operations are need (for instance, computing dataset statistics). In
|
||||
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling
|
||||
# self.consolidate().
|
||||
obj.consolidated = True
|
||||
|
||||
obj.episodes = None
|
||||
obj.hf_dataset = None
|
||||
obj.image_transforms = None
|
||||
|
|
|
@ -644,25 +644,25 @@ class IterableNamespace(SimpleNamespace):
|
|||
return vars(self).keys()
|
||||
|
||||
|
||||
def check_frame_features(frame: dict, features: dict):
|
||||
def validate_frame(frame: dict, features: dict):
|
||||
optional_features = {"timestamp"}
|
||||
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
|
||||
actual_features = set(frame.keys())
|
||||
|
||||
error_message = check_features_presence(actual_features, expected_features, optional_features)
|
||||
error_message = validate_features_presence(actual_features, expected_features, optional_features)
|
||||
|
||||
if "task" in frame:
|
||||
error_message += check_feature_string("task", frame["task"])
|
||||
error_message += validate_feature_string("task", frame["task"])
|
||||
|
||||
common_features = actual_features & (expected_features | optional_features)
|
||||
for name in common_features - {"task"}:
|
||||
error_message += check_feature_dtype_and_shape(name, features[name], frame[name])
|
||||
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
|
||||
|
||||
if error_message:
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
def check_features_presence(
|
||||
def validate_features_presence(
|
||||
actual_features: set[str], expected_features: set[str], optional_features: set[str]
|
||||
):
|
||||
error_message = ""
|
||||
|
@ -679,20 +679,22 @@ def check_features_presence(
|
|||
return error_message
|
||||
|
||||
|
||||
def check_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
|
||||
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
|
||||
expected_dtype = feature["dtype"]
|
||||
expected_shape = feature["shape"]
|
||||
if is_valid_numpy_dtype_string(expected_dtype):
|
||||
return check_feature_numpy_array(name, expected_dtype, expected_shape, value)
|
||||
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
|
||||
elif expected_dtype in ["image", "video"]:
|
||||
return check_feature_image_or_video(name, expected_shape, value)
|
||||
return validate_feature_image_or_video(name, expected_shape, value)
|
||||
elif expected_dtype == "string":
|
||||
return check_feature_string(name, value)
|
||||
return validate_feature_string(name, value)
|
||||
else:
|
||||
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
|
||||
|
||||
|
||||
def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray):
|
||||
def validate_feature_numpy_array(
|
||||
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
|
||||
):
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_dtype = value.dtype
|
||||
|
@ -709,7 +711,7 @@ def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: li
|
|||
return error_message
|
||||
|
||||
|
||||
def check_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
|
||||
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
|
||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
|
@ -725,7 +727,33 @@ def check_feature_image_or_video(name: str, expected_shape: list[str], value: np
|
|||
return error_message
|
||||
|
||||
|
||||
def check_feature_string(name: str, value: str):
|
||||
def validate_feature_string(name: str, value: str):
|
||||
if not isinstance(value, str):
|
||||
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
|
||||
return ""
|
||||
|
||||
|
||||
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
|
||||
if "size" not in episode_buffer:
|
||||
raise ValueError("size key not found in episode_buffer")
|
||||
|
||||
if "task" not in episode_buffer:
|
||||
raise ValueError("task key not found in episode_buffer")
|
||||
|
||||
if episode_buffer["episode_index"] != total_episodes:
|
||||
# TODO(aliberts): Add option to use existing episode_index
|
||||
raise NotImplementedError(
|
||||
"You might have manually provided the episode_buffer with an episode_index that doesn't "
|
||||
"match the total number of episodes already in the dataset. This is not supported for now."
|
||||
)
|
||||
|
||||
if episode_buffer["size"] == 0:
|
||||
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
|
||||
|
||||
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
|
||||
if not buffer_keys == set(features):
|
||||
raise ValueError(
|
||||
f"Features from `episode_buffer` don't match the ones in `features`."
|
||||
f"In episode_buffer not in features: {buffer_keys - set(features)}"
|
||||
f"In features not in episode_buffer: {set(features) - buffer_keys}"
|
||||
)
|
||||
|
|
|
@ -299,8 +299,6 @@ def record(
|
|||
log_say("Stop recording", cfg.play_sounds, blocking=True)
|
||||
stop_recording(robot, listener, cfg.display_cameras)
|
||||
|
||||
dataset.consolidate()
|
||||
|
||||
if cfg.push_to_hub:
|
||||
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
|
||||
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
import random
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
from unittest.mock import patch
|
||||
|
||||
import datasets
|
||||
|
@ -17,7 +19,6 @@ from lerobot.common.datasets.utils import (
|
|||
get_hf_features_from_features,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from tests.fixtures.constants import (
|
||||
DEFAULT_FPS,
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
|
@ -28,6 +29,10 @@ from tests.fixtures.constants import (
|
|||
)
|
||||
|
||||
|
||||
class LeRobotDatasetFactory(Protocol):
|
||||
def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
|
||||
|
||||
|
||||
def get_task_index(task_dicts: dict, task: str) -> int:
|
||||
tasks = {d["task_index"]: d["task"] for d in task_dicts.values()}
|
||||
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
|
||||
|
@ -358,7 +363,7 @@ def lerobot_dataset_factory(
|
|||
hf_dataset_factory,
|
||||
mock_snapshot_download_factory,
|
||||
lerobot_dataset_metadata_factory,
|
||||
):
|
||||
) -> LeRobotDatasetFactory:
|
||||
def _create_lerobot_dataset(
|
||||
root: Path,
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
|
@ -430,17 +435,5 @@ def lerobot_dataset_factory(
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def empty_lerobot_dataset_factory():
|
||||
def _create_empty_lerobot_dataset(
|
||||
root: Path,
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
fps: int = DEFAULT_FPS,
|
||||
robot: Robot | None = None,
|
||||
robot_type: str | None = None,
|
||||
features: dict | None = None,
|
||||
) -> LeRobotDataset:
|
||||
return LeRobotDataset.create(
|
||||
repo_id=repo_id, fps=fps, root=root, robot=robot, robot_type=robot_type, features=features
|
||||
)
|
||||
|
||||
return _create_empty_lerobot_dataset
|
||||
def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
|
||||
return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)
|
||||
|
|
|
@ -184,8 +184,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
|
|||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
|
||||
dataset.save_episode(encode_videos=False)
|
||||
dataset.consolidate()
|
||||
dataset.save_episode()
|
||||
|
||||
assert len(dataset) == 1
|
||||
assert dataset[0]["task"] == "Dummy task"
|
||||
|
@ -197,8 +196,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
|
|||
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
|
||||
dataset.save_episode(encode_videos=False)
|
||||
dataset.consolidate()
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2])
|
||||
|
||||
|
@ -207,8 +205,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
|
|||
features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
|
||||
dataset.save_episode(encode_videos=False)
|
||||
dataset.consolidate()
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2, 4])
|
||||
|
||||
|
@ -217,8 +214,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
|
|||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
|
||||
dataset.save_episode(encode_videos=False)
|
||||
dataset.consolidate()
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
|
||||
|
||||
|
@ -227,8 +223,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
|
|||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
|
||||
dataset.save_episode(encode_videos=False)
|
||||
dataset.consolidate()
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
|
||||
|
||||
|
@ -237,8 +232,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
|
|||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
|
||||
dataset.save_episode(encode_videos=False)
|
||||
dataset.consolidate()
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
|
||||
|
||||
|
@ -247,8 +241,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
|
|||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
|
||||
dataset.save_episode(encode_videos=False)
|
||||
dataset.consolidate()
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["state"].ndim == 0
|
||||
|
||||
|
@ -257,8 +250,7 @@ def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
|
|||
features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
|
||||
dataset.save_episode(encode_videos=False)
|
||||
dataset.consolidate()
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["caption"] == "Dummy caption"
|
||||
|
||||
|
@ -287,14 +279,13 @@ def test_add_frame_image_wrong_range(image_dataset):
|
|||
dataset = image_dataset
|
||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"})
|
||||
with pytest.raises(FileNotFoundError):
|
||||
dataset.save_episode(encode_videos=False)
|
||||
dataset.save_episode()
|
||||
|
||||
|
||||
def test_add_frame_image(image_dataset):
|
||||
dataset = image_dataset
|
||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
dataset.save_episode(encode_videos=False)
|
||||
dataset.consolidate()
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
|
||||
|
@ -302,8 +293,7 @@ def test_add_frame_image(image_dataset):
|
|||
def test_add_frame_image_h_w_c(image_dataset):
|
||||
dataset = image_dataset
|
||||
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
|
||||
dataset.save_episode(encode_videos=False)
|
||||
dataset.consolidate()
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
|
||||
|
@ -312,8 +302,7 @@ def test_add_frame_image_uint8(image_dataset):
|
|||
dataset = image_dataset
|
||||
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
||||
dataset.add_frame({"image": image, "task": "Dummy task"})
|
||||
dataset.save_episode(encode_videos=False)
|
||||
dataset.consolidate()
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
|
||||
|
@ -322,8 +311,7 @@ def test_add_frame_image_pil(image_dataset):
|
|||
dataset = image_dataset
|
||||
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
||||
dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
|
||||
dataset.save_episode(encode_videos=False)
|
||||
dataset.consolidate()
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
|
||||
|
@ -338,7 +326,6 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
|
|||
# - [ ] test various attributes & state from init and create
|
||||
# - [ ] test init with episodes and check num_frames
|
||||
# - [ ] test add_episode
|
||||
# - [ ] test consolidate
|
||||
# - [ ] test push_to_hub
|
||||
# - [ ] test smaller methods
|
||||
|
||||
|
|
Loading…
Reference in New Issue