Merge remote-tracking branch 'origin/user/aliberts/2025_02_10_dataset_v2.1' into user/rcadene/2025_02_19_port_openx

commit b520941cd9
@@ -3,8 +3,10 @@ from pathlib import Path
 import numpy as np
 import torch
+from huggingface_hub import HfApi
 
-from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
+from lerobot.common.constants import HF_LEROBOT_HOME
+from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
 from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
 
 PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
 
@@ -134,8 +136,8 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
     if mode not in ["video", "image", "keypoints"]:
         raise ValueError(mode)
 
-    if (LEROBOT_HOME / repo_id).exists():
-        shutil.rmtree(LEROBOT_HOME / repo_id)
+    if (HF_LEROBOT_HOME / repo_id).exists():
+        shutil.rmtree(HF_LEROBOT_HOME / repo_id)
 
     if not raw_dir.exists():
         download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")

@@ -198,10 +200,10 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
 
         dataset.save_episode()
 
-    dataset.consolidate()
-
     if push_to_hub:
         dataset.push_to_hub()
+        hub_api = HfApi()
+        hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
 
 
 if __name__ == "__main__":
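
In place of the removed `dataset.consolidate()` step, the port script now pushes and then tags the hub repo with the dataset codebase version. A minimal sketch of that tail end, assuming `dataset` is a `LeRobotDataset` already populated via `add_frame`/`save_episode` and that `repo_id` names its hub repository:

    from huggingface_hub import HfApi

    from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION

    repo_id = "lerobot/pusht"  # illustrative; mirrors the script above

    # save_episode() already validated, encoded, and checked everything,
    # so the dataset can be uploaded directly.
    dataset.push_to_hub()

    # Tag the dataset repo so consumers can resolve the matching format version.
    hub_api = HfApi()
    hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
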
@@ -1,4 +1,9 @@
 # keys
+import os
+from pathlib import Path
+
+from huggingface_hub.constants import HF_HOME
+
 OBS_ENV = "observation.environment_state"
 OBS_ROBOT = "observation.state"
 OBS_IMAGE = "observation.image"

@@ -15,3 +20,13 @@ TRAINING_STEP = "training_step.json"
 OPTIMIZER_STATE = "optimizer_state.safetensors"
 OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
 SCHEDULER_STATE = "scheduler_state.json"
+
+# cache dir
+default_cache_path = Path(HF_HOME) / "lerobot"
+HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
+
+if "LEROBOT_HOME" in os.environ:
+    raise ValueError(
+        f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
+        "'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
+    )
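
The new constant is resolved once at import time: `HF_LEROBOT_HOME` defaults to `$HF_HOME/lerobot` and can be overridden through the `HF_LEROBOT_HOME` environment variable, while a leftover `LEROBOT_HOME` variable now raises. A minimal sketch of the resolution behavior (the path is illustrative):

    import os
    from pathlib import Path

    # Must be set before importing lerobot.common.constants, since the
    # module reads the environment at import time.
    os.environ["HF_LEROBOT_HOME"] = "/data/lerobot_cache"  # illustrative path
    os.environ.pop("LEROBOT_HOME", None)  # a leftover LEROBOT_HOME raises ValueError

    from lerobot.common.constants import HF_LEROBOT_HOME

    assert HF_LEROBOT_HOME == Path("/data/lerobot_cache")
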
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import os
 import shutil
 from pathlib import Path
 from typing import Callable

@@ -29,6 +28,7 @@ from huggingface_hub import HfApi, snapshot_download
 from huggingface_hub.constants import REPOCARD_NAME
 from packaging import version
 
+from lerobot.common.constants import HF_LEROBOT_HOME
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
 from lerobot.common.datasets.utils import (

@@ -39,7 +39,6 @@ from lerobot.common.datasets.utils import (
     append_jsonlines,
     backward_compatible_episodes_stats,
     check_delta_timestamps,
-    check_frame_features,
     check_timestamps_sync,
     check_version_compatibility,
     create_empty_dataset_info,

@@ -55,6 +54,8 @@ from lerobot.common.datasets.utils import (
     load_info,
     load_stats,
     load_tasks,
+    validate_episode_buffer,
+    validate_frame,
     write_episode,
     write_episode_stats,
     write_info,

@@ -71,7 +72,6 @@ from lerobot.common.robot_devices.robots.utils import Robot
 
 # For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
 CODEBASE_VERSION = "v2.1"
-LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
 
 
 class LeRobotDatasetMetadata:

@@ -84,7 +84,7 @@ class LeRobotDatasetMetadata:
     ):
         self.repo_id = repo_id
         self.revision = revision if revision else CODEBASE_VERSION
-        self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
+        self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
 
         try:
             if force_cache_sync:
@@ -257,6 +257,9 @@ class LeRobotDatasetMetadata:
 
         self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
         self.info["total_videos"] += len(self.video_keys)
+        if len(self.video_keys) > 0:
+            self.update_video_info()
+
         write_info(self.info, self.root)
 
         episode_dict = {

@@ -271,7 +274,7 @@ class LeRobotDatasetMetadata:
         self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
         write_episode_stats(episode_index, episode_stats, self.root)
 
-    def write_video_info(self) -> None:
+    def update_video_info(self) -> None:
         """
         Warning: this function writes info from first episode videos, implicitly assuming that all videos have
         been encoded the same way. Also, this means it assumes the first episode exists.

@@ -281,8 +284,6 @@ class LeRobotDatasetMetadata:
             video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
             self.info["features"][key]["info"] = get_video_info(video_path)
 
-        write_json(self.info, self.root / INFO_PATH)
-
     def __repr__(self):
         feature_keys = list(self.features)
         return (

@@ -308,7 +309,7 @@ class LeRobotDatasetMetadata:
         """Creates metadata for a LeRobotDataset."""
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
-        obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
+        obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
 
         obj.root.mkdir(parents=True, exist_ok=False)
 

@@ -463,7 +464,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """
         super().__init__()
         self.repo_id = repo_id
-        self.root = Path(root) if root else LEROBOT_HOME / repo_id
+        self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
         self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
         self.episodes = episodes

@@ -507,9 +508,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
             check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
             self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
 
-        # Available stats implies all videos have been encoded and dataset is iterable
-        self.consolidated = self.meta.stats is not None
-
     def push_to_hub(
         self,
         branch: str | None = None,

@@ -520,13 +518,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         allow_patterns: list[str] | str | None = None,
         **card_kwargs,
     ) -> None:
-        if not self.consolidated:
-            logging.warning(
-                "You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. "
-                "Consolidating first."
-            )
-            self.consolidate()
-
         ignore_patterns = ["images/"]
         if not push_videos:
             ignore_patterns.append("videos/")

@@ -780,7 +771,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             if isinstance(frame[name], torch.Tensor):
                 frame[name] = frame[name].numpy()
 
-        check_frame_features(frame, self.features)
+        validate_frame(frame, self.features)
 
         if self.episode_buffer is None:
             self.episode_buffer = self.create_episode_buffer()
@@ -816,41 +807,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         self.episode_buffer["size"] += 1
 
-    def save_episode(self, encode_videos: bool = True, episode_data: dict | None = None) -> None:
+    def save_episode(self, episode_data: dict | None = None) -> None:
         """
-        This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
-        disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
-        the hub.
+        This will save to disk the current episode in self.episode_buffer.
 
-        Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
-        you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
-        time for video encoding.
+        Args:
+            episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
+                save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
+                None.
         """
         if not episode_data:
             episode_buffer = self.episode_buffer
 
+        validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
+
         # size and task are special cases that won't be added to hf_dataset
         episode_length = episode_buffer.pop("size")
         tasks = episode_buffer.pop("task")
+        episode_tasks = list(set(tasks))
 
         episode_index = episode_buffer["episode_index"]
-        if episode_index != self.meta.total_episodes:
-            # TODO(aliberts): Add option to use existing episode_index
-            raise NotImplementedError(
-                "You might have manually provided the episode_buffer with an episode_index that doesn't "
-                "match the total number of episodes already in the dataset. This is not supported for now."
-            )
-
-        if episode_length == 0:
-            raise ValueError(
-                "You must add one or several frames with `add_frame` before calling `add_episode`."
-            )
-
-        if not set(episode_buffer.keys()) == set(self.features):
-            raise ValueError(
-                f"Features from `episode_buffer` don't match the ones in `self.features`: '{set(episode_buffer.keys())}' vs '{set(self.features)}'"
-            )
-
         episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
         episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
@@ -876,16 +851,29 @@ class LeRobotDataset(torch.utils.data.Dataset):
         ep_stats = compute_episode_stats(episode_buffer, self.features)
         self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
 
-        if encode_videos and len(self.meta.video_keys) > 0:
+        if len(self.meta.video_keys) > 0:
             video_paths = self.encode_episode_videos(episode_index)
             for key in self.meta.video_keys:
                 episode_buffer[key] = video_paths[key]
 
+        self.hf_dataset = self.load_hf_dataset()
+        self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
+        check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
+
+        video_files = list(self.root.rglob("*.mp4"))
+        assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
+
+        parquet_files = list(self.root.rglob("*.parquet"))
+        assert len(parquet_files) == self.num_episodes
+
+        # delete images
+        img_dir = self.root / "images"
+        if img_dir.is_dir():
+            shutil.rmtree(self.root / "images")
+
         if not episode_data:  # Reset the buffer
             self.episode_buffer = self.create_episode_buffer()
 
-        self.consolidated = False
-
     def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
         episode_dict = {key: episode_buffer[key] for key in self.hf_features}
         ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")

@@ -960,28 +948,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         return video_paths
 
-    def consolidate(self, keep_image_files: bool = False) -> None:
-        self.hf_dataset = self.load_hf_dataset()
-        self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
-        check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
-
-        if len(self.meta.video_keys) > 0:
-            self.encode_videos()
-            self.meta.write_video_info()
-
-        if not keep_image_files:
-            img_dir = self.root / "images"
-            if img_dir.is_dir():
-                shutil.rmtree(self.root / "images")
-
-        video_files = list(self.root.rglob("*.mp4"))
-        assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
-
-        parquet_files = list(self.root.rglob("*.parquet"))
-        assert len(parquet_files) == self.num_episodes
-
-        self.consolidated = True
-
     @classmethod
     def create(
         cls,

@@ -1020,12 +986,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
         obj.episode_buffer = obj.create_episode_buffer()
 
-        # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
-        # is used to know when certain operations are need (for instance, computing dataset statistics). In
-        # order to be able to push the dataset to the hub, it needs to be consolidated first by calling
-        # self.consolidate().
-        obj.consolidated = True
-
         obj.episodes = None
         obj.hf_dataset = None
         obj.image_transforms = None

@@ -1056,7 +1016,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
     ):
         super().__init__()
         self.repo_ids = repo_ids
-        self.root = Path(root) if root else LEROBOT_HOME
+        self.root = Path(root) if root else HF_LEROBOT_HOME
         self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
         # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
         # are handled by this class.
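
Every `LEROBOT_HOME` fallback above becomes `HF_LEROBOT_HOME`, so the local cache layout is unchanged apart from its base directory. A sketch of where a dataset lands under the new default, assuming it can be fetched from the hub and no explicit `root` is passed:

    from lerobot.common.constants import HF_LEROBOT_HOME
    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # Default: cached under HF_LEROBOT_HOME/<repo_id>, i.e.
    # ~/.cache/huggingface/lerobot/lerobot/pusht with a stock HF_HOME.
    dataset = LeRobotDataset("lerobot/pusht")
    assert dataset.root == HF_LEROBOT_HOME / "lerobot/pusht"

    # An explicit root still wins over the cache default:
    # dataset = LeRobotDataset("lerobot/pusht", root="/data/pusht")
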
@@ -644,25 +644,25 @@ class IterableNamespace(SimpleNamespace):
         return vars(self).keys()
 
 
-def check_frame_features(frame: dict, features: dict):
+def validate_frame(frame: dict, features: dict):
     optional_features = {"timestamp"}
     expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
     actual_features = set(frame.keys())
 
-    error_message = check_features_presence(actual_features, expected_features, optional_features)
+    error_message = validate_features_presence(actual_features, expected_features, optional_features)
 
     if "task" in frame:
-        error_message += check_feature_string("task", frame["task"])
+        error_message += validate_feature_string("task", frame["task"])
 
     common_features = actual_features & (expected_features | optional_features)
     for name in common_features - {"task"}:
-        error_message += check_feature_dtype_and_shape(name, features[name], frame[name])
+        error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
 
     if error_message:
         raise ValueError(error_message)
 
 
-def check_features_presence(
+def validate_features_presence(
     actual_features: set[str], expected_features: set[str], optional_features: set[str]
 ):
     error_message = ""

@@ -679,20 +679,22 @@ def check_features_presence(
     return error_message
 
 
-def check_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
+def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
     expected_dtype = feature["dtype"]
     expected_shape = feature["shape"]
     if is_valid_numpy_dtype_string(expected_dtype):
-        return check_feature_numpy_array(name, expected_dtype, expected_shape, value)
+        return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
     elif expected_dtype in ["image", "video"]:
-        return check_feature_image_or_video(name, expected_shape, value)
+        return validate_feature_image_or_video(name, expected_shape, value)
     elif expected_dtype == "string":
-        return check_feature_string(name, value)
+        return validate_feature_string(name, value)
     else:
         raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
 
 
-def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray):
+def validate_feature_numpy_array(
+    name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
+):
     error_message = ""
     if isinstance(value, np.ndarray):
         actual_dtype = value.dtype

@@ -709,7 +711,7 @@ def check_feature_numpy_array(name: str, expected_dtype: str, expected_shape: li
     return error_message
 
 
-def check_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
+def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
     # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
     error_message = ""
     if isinstance(value, np.ndarray):

@@ -725,7 +727,33 @@ def check_feature_image_or_video(name: str, expected_shape: list[str], value: np
     return error_message
 
 
-def check_feature_string(name: str, value: str):
+def validate_feature_string(name: str, value: str):
     if not isinstance(value, str):
         return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
     return ""
+
+
+def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
+    if "size" not in episode_buffer:
+        raise ValueError("size key not found in episode_buffer")
+
+    if "task" not in episode_buffer:
+        raise ValueError("task key not found in episode_buffer")
+
+    if episode_buffer["episode_index"] != total_episodes:
+        # TODO(aliberts): Add option to use existing episode_index
+        raise NotImplementedError(
+            "You might have manually provided the episode_buffer with an episode_index that doesn't "
+            "match the total number of episodes already in the dataset. This is not supported for now."
+        )
+
+    if episode_buffer["size"] == 0:
+        raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
+
+    buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
+    if not buffer_keys == set(features):
+        raise ValueError(
+            f"Features from `episode_buffer` don't match the ones in `features`."
+            f"In episode_buffer not in features: {buffer_keys - set(features)}"
+            f"In features not in episode_buffer: {set(features) - buffer_keys}"
+        )
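
The renamed validators keep the accumulate-then-raise pattern: each `validate_feature_*` helper returns an error string (empty when the value conforms), and `validate_frame` raises a single `ValueError` carrying everything that failed. A small sketch against a feature spec in the same format the tests use:

    import numpy as np

    from lerobot.common.datasets.utils import validate_frame

    features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}

    # Wrong dtype and a missing "task" entry: both problems are reported
    # in one ValueError instead of failing on the first check.
    bad_frame = {"state": np.zeros(2, dtype=np.int64)}
    try:
        validate_frame(bad_frame, features)
    except ValueError as e:
        print(e)

    good_frame = {"state": np.zeros(2, dtype=np.float32), "task": "Dummy task"}
    validate_frame(good_frame, features)  # passes silently
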
@ -29,8 +29,9 @@ LOCAL_DIR = Path("data/")
|
|||
|
||||
def batch_convert():
|
||||
status = {}
|
||||
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
logfile = LOCAL_DIR / "conversion_log_v21.txt"
|
||||
for num, repo_id in available_datasets:
|
||||
for num, repo_id in enumerate(available_datasets):
|
||||
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
|
||||
print("---------------------------------------------------------")
|
||||
try:
|
||||
|
|
|
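
The loop header here is a bug fix rather than a style change: `available_datasets` is a list of `repo_id` strings, so the old `for num, repo_id in available_datasets` would try to unpack each string and fail; `enumerate` supplies the counter the progress message needs. In isolation:

    available_datasets = ["lerobot/pusht", "lerobot/aloha_static_coffee"]  # illustrative

    for num, repo_id in enumerate(available_datasets):
        print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
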
@@ -2,10 +2,10 @@
 This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
 2.1. It will:
 
-- Generates per-episodes stats and writes them in `episodes_stats.jsonl`
+- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
 - Check consistency between these new stats and the old ones.
-- Removes the deprecated `stats.json` (by default)
-- Updates codebase_version in `info.json`
+- Remove the deprecated `stats.json`.
+- Update codebase_version in `info.json`.
 - Push this new version to the hub on the 'main' branch and tags it with "v2.1".
 
 Usage:

@@ -80,19 +80,20 @@ if __name__ == "__main__":
         "--repo-id",
         type=str,
         required=True,
-        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
+        help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
+        "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
     )
     parser.add_argument(
         "--branch",
         type=str,
         default=None,
-        help="Repo branch to push your dataset (defaults to the main branch)",
+        help="Repo branch to push your dataset. Defaults to the main branch.",
     )
     parser.add_argument(
         "--num-workers",
         type=int,
         default=4,
-        help="Number of workers for parallelizing compute",
+        help="Number of workers for parallelizing stats compute. Defaults to 4.",
     )
 
     args = parser.parse_args()
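
Given the arguments defined above, a conversion run looks something like the following; the script filename is an assumption (only `--repo-id` is required, `--branch` and `--num-workers` are optional):

    python convert_dataset_v20_to_v21.py --repo-id lerobot/pusht --num-workers 8
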
@@ -299,8 +299,6 @@ def record(
     log_say("Stop recording", cfg.play_sounds, blocking=True)
     stop_recording(robot, listener, cfg.display_cameras)
 
-    dataset.consolidate()
-
     if cfg.push_to_hub:
        dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
 
@@ -1,6 +1,6 @@
-from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
+from lerobot.common.constants import HF_LEROBOT_HOME
 
-LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
+LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing"
 DUMMY_REPO_ID = "dummy/repo"
 DUMMY_ROBOT_TYPE = "dummy_robot"
 DUMMY_MOTOR_FEATURES = {
@@ -1,5 +1,7 @@
 import random
+from functools import partial
 from pathlib import Path
+from typing import Protocol
 from unittest.mock import patch
 
 import datasets

@@ -17,7 +19,6 @@ from lerobot.common.datasets.utils import (
     get_hf_features_from_features,
     hf_transform_to_torch,
 )
-from lerobot.common.robot_devices.robots.utils import Robot
 from tests.fixtures.constants import (
     DEFAULT_FPS,
     DUMMY_CAMERA_FEATURES,

@@ -28,6 +29,10 @@ from tests.fixtures.constants import (
 )
 
 
+class LeRobotDatasetFactory(Protocol):
+    def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
+
+
 def get_task_index(task_dicts: dict, task: str) -> int:
     tasks = {d["task_index"]: d["task"] for d in task_dicts.values()}
     task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}

@@ -358,7 +363,7 @@ def lerobot_dataset_factory(
     hf_dataset_factory,
     mock_snapshot_download_factory,
     lerobot_dataset_metadata_factory,
-):
+) -> LeRobotDatasetFactory:
     def _create_lerobot_dataset(
         root: Path,
         repo_id: str = DUMMY_REPO_ID,

@@ -430,17 +435,5 @@ def lerobot_dataset_factory(
 
 
 @pytest.fixture(scope="session")
-def empty_lerobot_dataset_factory():
-    def _create_empty_lerobot_dataset(
-        root: Path,
-        repo_id: str = DUMMY_REPO_ID,
-        fps: int = DEFAULT_FPS,
-        robot: Robot | None = None,
-        robot_type: str | None = None,
-        features: dict | None = None,
-    ) -> LeRobotDataset:
-        return LeRobotDataset.create(
-            repo_id=repo_id, fps=fps, root=root, robot=robot, robot_type=robot_type, features=features
-        )
-
-    return _create_empty_lerobot_dataset
+def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
+    return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)
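
`functools.partial` only pins defaults for `repo_id` and `fps`; keyword arguments passed by a test still override them, which is why call sites like `empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)` below keep working. A standalone sketch of that behavior with a dummy function:

    from functools import partial


    def create(repo_id: str, fps: int, root=None, features=None):
        return (repo_id, fps, root, features)


    factory = partial(create, repo_id="dummy/repo", fps=30)

    # Pinned keywords act as defaults...
    assert factory(root="/tmp/test") == ("dummy/repo", 30, "/tmp/test", None)
    # ...and can still be overridden per call.
    assert factory(repo_id="other/repo") == ("other/repo", 30, None, None)
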
@@ -184,8 +184,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert len(dataset) == 1
     assert dataset[0]["task"] == "Dummy task"

@@ -197,8 +196,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["state"].shape == torch.Size([2])
 

@@ -207,8 +205,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["state"].shape == torch.Size([2, 4])
 

@@ -217,8 +214,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
 

@@ -227,8 +223,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
 

@@ -237,8 +232,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
 

@@ -247,8 +241,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["state"].ndim == 0
 

@@ -257,8 +250,7 @@ def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
     features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
     dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["caption"] == "Dummy caption"
 

@@ -287,14 +279,13 @@ def test_add_frame_image_wrong_range(image_dataset):
     dataset = image_dataset
     dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"})
     with pytest.raises(FileNotFoundError):
-        dataset.save_episode(encode_videos=False)
+        dataset.save_episode()
 
 
 def test_add_frame_image(image_dataset):
     dataset = image_dataset
     dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
 

@@ -302,8 +293,7 @@ def test_add_frame_image(image_dataset):
 def test_add_frame_image_h_w_c(image_dataset):
     dataset = image_dataset
     dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
 

@@ -312,8 +302,7 @@ def test_add_frame_image_uint8(image_dataset):
     dataset = image_dataset
     image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
     dataset.add_frame({"image": image, "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
 

@@ -322,8 +311,7 @@ def test_add_frame_image_pil(image_dataset):
     dataset = image_dataset
     image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
     dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
-    dataset.save_episode(encode_videos=False)
-    dataset.consolidate()
+    dataset.save_episode()
 
     assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
 

@@ -338,7 +326,6 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
 # - [ ] test various attributes & state from init and create
 # - [ ] test init with episodes and check num_frames
 # - [ ] test add_episode
-# - [ ] test consolidate
 # - [ ] test push_to_hub
 # - [ ] test smaller methods
 

@@ -581,9 +568,9 @@ def test_create_branch():
 
 def test_dataset_feature_with_forward_slash_raises_error():
     # make sure dir does not exist
-    from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
+    from lerobot.common.constants import HF_LEROBOT_HOME
 
-    dataset_dir = LEROBOT_HOME / "lerobot/test/with/slash"
+    dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
     # make sure does not exist
     if dataset_dir.exists():
         dataset_dir.rmdir()