Refactor dataset features

This commit is contained in:
Simon Alibert 2024-11-05 13:10:43 +01:00
parent 757ea175d3
commit aed9f4036a
7 changed files with 172 additions and 185 deletions

View File

@ -30,11 +30,11 @@ from huggingface_hub import snapshot_download, upload_folder
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
from lerobot.common.datasets.image_writer import ImageWriter from lerobot.common.datasets.image_writer import ImageWriter
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
DEFAULT_FEATURES,
EPISODES_PATH, EPISODES_PATH,
INFO_PATH, INFO_PATH,
STATS_PATH, STATS_PATH,
TASKS_PATH, TASKS_PATH,
_get_info_from_robot,
append_jsonlines, append_jsonlines,
check_delta_timestamps, check_delta_timestamps,
check_timestamps_sync, check_timestamps_sync,
@ -43,6 +43,7 @@ from lerobot.common.datasets.utils import (
create_empty_dataset_info, create_empty_dataset_info,
get_delta_indices, get_delta_indices,
get_episode_data_index, get_episode_data_index,
get_features_from_robot,
get_hub_safe_version, get_hub_safe_version,
hf_transform_to_torch, hf_transform_to_torch,
load_episodes, load_episodes,
@ -116,7 +117,7 @@ class LeRobotDatasetMetadata:
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
    """Return the path of the video file for one episode and one video key.

    The `video_path` template is filled in with the episode's chunk, the
    video key and the episode index.
    """
    chunk = self.get_episode_chunk(ep_index)
    template = self.video_path
    return Path(template.format(episode_chunk=chunk, video_key=vid_key, episode_index=ep_index))
def get_episode_chunk(self, ep_index: int) -> int: def get_episode_chunk(self, ep_index: int) -> int:
@ -128,15 +129,20 @@ class LeRobotDatasetMetadata:
return self.info["data_path"] return self.info["data_path"]
@property @property
def videos_path(self) -> str | None: def video_path(self) -> str | None:
"""Formattable string for the video files.""" """Formattable string for the video files."""
return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None return self.info["video_path"]
@property @property
def fps(self) -> int: def fps(self) -> int:
"""Frames per second used during data collection.""" """Frames per second used during data collection."""
return self.info["fps"] return self.info["fps"]
@property
def features(self) -> dict[str, dict]:
    """Feature specifications of the dataset, read from the "features" entry of `info`."""
    return self.info["features"]
@property @property
def keys(self) -> list[str]: def keys(self) -> list[str]:
"""Keys to access non-image data (state, actions etc.).""" """Keys to access non-image data (state, actions etc.)."""
@ -145,22 +151,27 @@ class LeRobotDatasetMetadata:
@property
def image_keys(self) -> list[str]:
    """Keys to access visual modalities stored as images."""
    selected = []
    for name, spec in self.features.items():
        if spec["dtype"] == "image":
            selected.append(name)
    return selected
@property
def video_keys(self) -> list[str]:
    """Keys to access visual modalities stored as videos."""
    selected = []
    for name, spec in self.features.items():
        if spec["dtype"] == "video":
            selected.append(name)
    return selected
@property
def camera_keys(self) -> list[str]:
    """Keys to access visual modalities (regardless of their storage method)."""
    visual_dtypes = ("video", "image")
    return [name for name, spec in self.features.items() if spec["dtype"] in visual_dtypes]
@property
def names(self) -> dict[str, list[str]]:
    """Names of the various dimensions of vector modalities."""
    names_by_key = {}
    for name, spec in self.features.items():
        names_by_key[name] = spec["names"]
    return names_by_key
@property
def shapes(self) -> dict:
    """Shapes for the different features."""
    shapes_by_key = {}
    for name, spec in self.features.items():
        shapes_by_key[name] = tuple(spec["shape"])
    return shapes_by_key
@property @property
def total_episodes(self) -> int: def total_episodes(self) -> int:
@ -187,11 +198,6 @@ class LeRobotDatasetMetadata:
"""Max number of episodes per chunk.""" """Max number of episodes per chunk."""
return self.info["chunks_size"] return self.info["chunks_size"]
@property
def shapes(self) -> dict:
"""Shapes for the different features."""
return self.info["shapes"]
@property @property
def task_to_task_index(self) -> dict: def task_to_task_index(self) -> dict:
return {task: task_idx for task_idx, task in self.tasks.items()} return {task: task_idx for task_idx, task in self.tasks.items()}
@ -253,45 +259,33 @@ class LeRobotDatasetMetadata:
root: Path | None = None, root: Path | None = None,
robot: Robot | None = None, robot: Robot | None = None,
robot_type: str | None = None, robot_type: str | None = None,
keys: list[str] | None = None, features: dict | None = None,
image_keys: list[str] | None = None,
video_keys: list[str] = None,
shapes: dict | None = None,
names: dict | None = None,
use_videos: bool = True, use_videos: bool = True,
) -> "LeRobotDatasetMetadata": ) -> "LeRobotDatasetMetadata":
"""Creates metadata for a LeRobotDataset.""" """Creates metadata for a LeRobotDataset."""
obj = cls.__new__(cls) obj = cls.__new__(cls)
obj.repo_id = repo_id obj.repo_id = repo_id
obj.root = root if root is not None else LEROBOT_HOME / repo_id obj.root = root if root is not None else LEROBOT_HOME / repo_id
obj.image_writer = None
if robot is not None: if robot is not None:
robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos) features = get_features_from_robot(robot)
robot_type = robot.robot_type
if not all(cam.fps == fps for cam in robot.cameras.values()): if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning( logging.warning(
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset." f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
"In this case, frames from lower fps cameras will be repeated to fill in the blanks" "In this case, frames from lower fps cameras will be repeated to fill in the blanks."
) )
elif ( elif robot_type is None or features is None:
robot_type is None
or keys is None
or image_keys is None
or video_keys is None
or shapes is None
or names is None
):
raise ValueError( raise ValueError(
"Dataset info (robot_type, keys, shapes...) must either come from a Robot or explicitly passed upon creation." "Dataset features must either come from a Robot or explicitly passed upon creation."
) )
else:
if len(video_keys) > 0 and not use_videos: features = {**features, **DEFAULT_FEATURES}
raise ValueError()
obj.tasks, obj.stats, obj.episodes = {}, {}, [] obj.tasks, obj.stats, obj.episodes = {}, {}, []
obj.info = create_empty_dataset_info( obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names if len(obj.video_keys) > 0 and not use_videos:
) raise ValueError()
write_json(obj.info, obj.root / INFO_PATH) write_json(obj.info, obj.root / INFO_PATH)
obj.local_files_only = True obj.local_files_only = True
return obj return obj
@ -509,6 +503,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
hf_dataset = load_dataset("parquet", data_files=files, split="train") hf_dataset = load_dataset("parquet", data_files=files, split="train")
hf_dataset.set_transform(hf_transform_to_torch) hf_dataset.set_transform(hf_transform_to_torch)
# return hf_dataset.with_format("torch") TODO
return hf_dataset return hf_dataset
@property @property
@ -662,8 +657,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
"task_index": None, "task_index": None,
"frame_index": [], "frame_index": [],
"timestamp": [], "timestamp": [],
"next.done": [], **{key: [] for key in self.meta.features},
**{key: [] for key in self.meta.keys},
**{key: [] for key in self.meta.image_keys}, **{key: [] for key in self.meta.image_keys},
} }
@ -845,7 +839,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
@classmethod @classmethod
def create( def create(
cls, cls,
metadata: LeRobotDatasetMetadata, repo_id: str,
fps: int,
root: Path | None = None,
robot: Robot | None = None,
robot_type: str | None = None,
features: dict | None = None,
use_videos: bool = True,
tolerance_s: float = 1e-4, tolerance_s: float = 1e-4,
image_writer_processes: int = 0, image_writer_processes: int = 0,
image_writer_threads: int = 0, image_writer_threads: int = 0,
@ -853,7 +853,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
) -> "LeRobotDataset": ) -> "LeRobotDataset":
"""Create a LeRobot Dataset from scratch in order to record data.""" """Create a LeRobot Dataset from scratch in order to record data."""
obj = cls.__new__(cls) obj = cls.__new__(cls)
obj.meta = metadata obj.meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=fps,
root=root,
robot=robot,
robot_type=robot_type,
features=features,
use_videos=use_videos,
)
obj.repo_id = obj.meta.repo_id obj.repo_id = obj.meta.repo_id
obj.root = obj.meta.root obj.root = obj.meta.root
obj.local_files_only = obj.meta.local_files_only obj.local_files_only = obj.meta.local_files_only

View File

@ -48,6 +48,14 @@ This dataset was created using [LeRobot](https://github.com/huggingface/lerobot)
""" """
# Scalar, unnamed features (shape (1,), names None) that are merged into a
# dataset's feature dict alongside the robot-specific features.
DEFAULT_FEATURES = {
    "timestamp": {"dtype": "float32", "shape": (1,), "names": None},
    "frame_index": {"dtype": "int64", "shape": (1,), "names": None},
    "episode_index": {"dtype": "int64", "shape": (1,), "names": None},
    "index": {"dtype": "int64", "shape": (1,), "names": None},
    "task_index": {"dtype": "int64", "shape": (1,), "names": None},
}
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict: def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator. """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
@ -214,39 +222,25 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
return version return version
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
    """Assemble the dataset features of a robot.

    The result merges, in this order, the robot's motor features, its camera
    features (each tagged with dtype "video" or "image" depending on
    `use_videos`) and `DEFAULT_FEATURES`.
    """
    camera_features = {}
    if robot.cameras:
        visual_dtype = "video" if use_videos else "image"
        for cam_key, cam_ft in robot.camera_features.items():
            camera_features[cam_key] = {"dtype": visual_dtype, **cam_ft}

    features = dict(robot.motor_features)
    features.update(camera_features)
    features.update(DEFAULT_FEATURES)
    return features
def create_empty_dataset_info( def create_empty_dataset_info(
codebase_version: str, codebase_version: str,
fps: int, fps: int,
robot_type: str, robot_type: str,
keys: list[str], features: dict,
image_keys: list[str], use_videos: bool,
video_keys: list[str],
shapes: dict,
names: dict,
) -> dict: ) -> dict:
return { return {
"codebase_version": codebase_version, "codebase_version": codebase_version,
"data_path": DEFAULT_PARQUET_PATH,
"robot_type": robot_type, "robot_type": robot_type,
"total_episodes": 0, "total_episodes": 0,
"total_frames": 0, "total_frames": 0,
@ -256,12 +250,9 @@ def create_empty_dataset_info(
"chunks_size": DEFAULT_CHUNK_SIZE, "chunks_size": DEFAULT_CHUNK_SIZE,
"fps": fps, "fps": fps,
"splits": {}, "splits": {},
"keys": keys, "data_path": DEFAULT_PARQUET_PATH,
"video_keys": video_keys, "video_path": DEFAULT_VIDEO_PATH if use_videos else None,
"image_keys": image_keys, "features": features,
"shapes": shapes,
"names": names,
"videos": {"videos_path": DEFAULT_VIDEO_PATH} if len(video_keys) > 0 else None,
} }
@ -400,6 +391,12 @@ def create_lerobot_dataset_card(
tags: list | None = None, text: str | None = None, info: dict | None = None tags: list | None = None, text: str | None = None, info: dict | None = None
) -> DatasetCard: ) -> DatasetCard:
card = DatasetCard(DATASET_CARD_TEMPLATE) card = DatasetCard(DATASET_CARD_TEMPLATE)
card.data.configs = [
{
"config_name": "default",
"data_files": "data/*/*.parquet",
}
]
card.data.task_categories = ["robotics"] card.data.task_categories = ["robotics"]
card.data.tags = ["LeRobot"] card.data.tags = ["LeRobot"]
if tags is not None: if tags is not None:

View File

@ -106,6 +106,7 @@ import json
import math import math
import shutil import shutil
import subprocess import subprocess
import tempfile
import warnings import warnings
from pathlib import Path from pathlib import Path
@ -137,9 +138,8 @@ from lerobot.common.datasets.utils import (
) )
from lerobot.common.datasets.video_utils import ( from lerobot.common.datasets.video_utils import (
VideoFrame, # noqa: F401 VideoFrame, # noqa: F401
get_image_shapes, get_image_pixel_channels,
get_video_info, get_video_info,
get_video_shapes,
) )
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.utils.utils import init_hydra_config
@ -202,21 +202,37 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
torch.testing.assert_close(stats_json[key], stats[key]) torch.testing.assert_close(stats_json[key], stats[key])
def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = None) -> dict[str, dict]:
    """Build the feature specification of every column of a Hugging Face dataset.

    Args:
        dataset: Dataset whose `features` mapping is inspected. Its first row is
            also read to measure image features.
        robot_config: Optional config whose "names" entry maps sequence keys to
            their per-dimension names. Without it, generic `motor_{i}` names
            are generated.

    Returns:
        A dict mapping each column name to `{"dtype", "shape", "names"}`. Video
        features get `shape` set to None; the caller is expected to fill it in.

    Raises:
        ValueError: If a column has a feature type that is not handled here.
        AssertionError: If a sequence does not hold plain values, or if the
            number of names differs from the sequence length.
    """
    features = {}
    for key, ft in dataset.features.items():
        if isinstance(ft, datasets.Value):
            dtype = ft.dtype
            shape = (1,)
            names = None
        elif isinstance(ft, datasets.Sequence):
            assert isinstance(ft.feature, datasets.Value)
            dtype = ft.feature.dtype
            shape = (ft.length,)
            if robot_config:
                config_names = robot_config["names"]
                if key not in config_names and key in ("observation.effort", "observation.velocity"):
                    # The previous conversion code named these two keys after
                    # "observation.state" — NOTE(review): confirm this fallback
                    # is still wanted for configs that do not list them.
                    names = config_names["observation.state"]
                else:
                    names = config_names[key]
            else:
                names = [f"motor_{i}" for i in range(ft.length)]
            assert len(names) == shape[0]
        elif isinstance(ft, datasets.Image):
            dtype = "image"
            image = dataset[0][key]  # Assuming first row
            channels = get_image_pixel_channels(image)
            shape = (image.width, image.height, channels)
            names = ["width", "height", "channel"]
        elif getattr(ft, "_type", None) == "VideoFrame":
            dtype = "video"
            shape = None  # Add shape later
            names = ["width", "height", "channel"]
        else:
            # Fail loudly instead of silently reusing the previous column's
            # dtype/shape/names (or hitting an unbound local on the first column).
            raise ValueError(f"Unsupported feature type for '{key}': {ft!r}")

        features[key] = {
            "dtype": dtype,
            "shape": shape,
            "names": names,
        }

    return features
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]: def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
@ -259,17 +275,15 @@ def add_task_index_from_tasks_col(
def split_parquet_by_episodes( def split_parquet_by_episodes(
dataset: Dataset, dataset: Dataset,
keys: dict[str, list],
total_episodes: int, total_episodes: int,
total_chunks: int, total_chunks: int,
output_dir: Path, output_dir: Path,
) -> list: ) -> list:
table = dataset.remove_columns(keys["video"])._data.table table = dataset.data.table
episode_lengths = [] episode_lengths = []
for ep_chunk in range(total_chunks): for ep_chunk in range(total_chunks):
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes) ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk) chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True) (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end): for ep_idx in range(ep_chunk_start, ep_chunk_end):
@ -396,27 +410,22 @@ def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[st
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
    """Download one video per key from the hub and return its info.

    Only the file of chunk 0 / episode 0 is fetched for each video key (the
    first episode is assumed to be representative). The returned dict maps each
    video key to the result of `get_video_info` on the downloaded file.
    """
    # Assumes first episode
    first_episode_files = [
        DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
        for vid_key in video_keys
    ]
    HfApi().snapshot_download(
        repo_id=repo_id,
        repo_type="dataset",
        local_dir=local_dir,
        revision=branch,
        allow_patterns=first_episode_files,
    )
    return {
        vid_key: get_video_info(local_dir / vid_path)
        for vid_key, vid_path in zip(video_keys, first_episode_files, strict=True)
    }
def get_generic_motor_names(sequence_shapes: dict) -> dict:
return {key: [f"motor_{i}" for i in range(length)] for key, length in sequence_shapes.items()}
def convert_dataset( def convert_dataset(
repo_id: str, repo_id: str,
local_dir: Path, local_dir: Path,
@ -443,7 +452,8 @@ def convert_dataset(
metadata_v1 = load_json(v1x_dir / V1_INFO_PATH) metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train") dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
keys = get_keys(dataset) features = get_features_from_hf_dataset(dataset, robot_config)
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
if single_task and "language_instruction" in dataset.column_names: if single_task and "language_instruction" in dataset.column_names:
warnings.warn( warnings.warn(
@ -457,7 +467,7 @@ def convert_dataset(
episode_indices = sorted(dataset.unique("episode_index")) episode_indices = sorted(dataset.unique("episode_index"))
total_episodes = len(episode_indices) total_episodes = len(episode_indices)
assert episode_indices == list(range(total_episodes)) assert episode_indices == list(range(total_episodes))
total_videos = total_episodes * len(keys["video"]) total_videos = total_episodes * len(video_keys)
total_chunks = total_episodes // DEFAULT_CHUNK_SIZE total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
if total_episodes % DEFAULT_CHUNK_SIZE != 0: if total_episodes % DEFAULT_CHUNK_SIZE != 0:
total_chunks += 1 total_chunks += 1
@ -470,7 +480,6 @@ def convert_dataset(
elif tasks_path: elif tasks_path:
tasks_by_episodes = load_json(tasks_path) tasks_by_episodes = load_json(tasks_path)
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()} tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
# tasks = list(set(tasks_by_episodes.values()))
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes) dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()} tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
elif tasks_col: elif tasks_col:
@ -481,56 +490,50 @@ def convert_dataset(
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks} assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)] tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
write_jsonlines(tasks, v20_dir / TASKS_PATH) write_jsonlines(tasks, v20_dir / TASKS_PATH)
features["task_index"] = {
# Shapes "dtype": "int64",
sequence_shapes = {key: dataset.features[key].length for key in keys["sequence"]} "shape": (1,),
image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {} "names": None,
}
# Videos # Videos
if len(keys["video"]) > 0: if video_keys:
assert metadata_v1.get("video", False) assert metadata_v1.get("video", False)
tmp_video_dir = local_dir / "videos" / V20 / repo_id dataset = dataset.remove_columns(video_keys)
tmp_video_dir.mkdir(parents=True, exist_ok=True)
clean_gitattr = Path( clean_gitattr = Path(
hub_api.hf_hub_download( hub_api.hf_hub_download(
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes" repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
) )
).absolute() ).absolute()
move_videos( with tempfile.TemporaryDirectory() as tmp_video_dir:
repo_id, keys["video"], total_episodes, total_chunks, tmp_video_dir, clean_gitattr, branch move_videos(
) repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"], branch=branch) )
video_shapes = get_video_shapes(videos_info, keys["video"]) videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
for img_key in keys["video"]: for key in video_keys:
assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3) features[key]["shape"] = (
videos_info[key].pop("video.width"),
videos_info[key].pop("video.height"),
videos_info[key].pop("video.channels"),
)
features[key]["video_info"] = videos_info[key]
assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
if "encoding" in metadata_v1: if "encoding" in metadata_v1:
assert videos_info[img_key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"] assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
else: else:
assert metadata_v1.get("video", 0) == 0 assert metadata_v1.get("video", 0) == 0
videos_info = None videos_info = None
video_shapes = {}
# Split data into 1 parquet file by episode # Split data into 1 parquet file by episode
episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, total_chunks, v20_dir) episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
# Names
if robot_config is not None: if robot_config is not None:
robot_type = robot_config["robot_type"] robot_type = robot_config["robot_type"]
names = robot_config["names"]
if "observation.effort" in keys["sequence"]:
names["observation.effort"] = names["observation.state"]
if "observation.velocity" in keys["sequence"]:
names["observation.velocity"] = names["observation.state"]
repo_tags = [robot_type] repo_tags = [robot_type]
else: else:
robot_type = "unknown" robot_type = "unknown"
names = get_generic_motor_names(sequence_shapes)
repo_tags = None repo_tags = None
assert set(names) == set(keys["sequence"])
for key in sequence_shapes:
assert len(names[key]) == sequence_shapes[key]
# Episodes # Episodes
episodes = [ episodes = [
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]} {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
@ -541,7 +544,6 @@ def convert_dataset(
# Assemble metadata v2.0 # Assemble metadata v2.0
metadata_v2_0 = { metadata_v2_0 = {
"codebase_version": V20, "codebase_version": V20,
"data_path": DEFAULT_PARQUET_PATH,
"robot_type": robot_type, "robot_type": robot_type,
"total_episodes": total_episodes, "total_episodes": total_episodes,
"total_frames": len(dataset), "total_frames": len(dataset),
@ -551,15 +553,13 @@ def convert_dataset(
"chunks_size": DEFAULT_CHUNK_SIZE, "chunks_size": DEFAULT_CHUNK_SIZE,
"fps": metadata_v1["fps"], "fps": metadata_v1["fps"],
"splits": {"train": f"0:{total_episodes}"}, "splits": {"train": f"0:{total_episodes}"},
"keys": keys["sequence"], "data_path": DEFAULT_PARQUET_PATH,
"video_keys": keys["video"], "video_path": DEFAULT_VIDEO_PATH if video_keys else None,
"image_keys": keys["image"], "features": features,
"shapes": {**sequence_shapes, **video_shapes, **image_shapes},
"names": names,
"videos": videos_info,
} }
write_json(metadata_v2_0, v20_dir / INFO_PATH) write_json(metadata_v2_0, v20_dir / INFO_PATH)
convert_stats_to_json(v1x_dir, v20_dir) convert_stats_to_json(v1x_dir, v20_dir)
card = create_lerobot_dataset_card(tags=repo_tags, info=metadata_v2_0)
with contextlib.suppress(EntryNotFoundError): with contextlib.suppress(EntryNotFoundError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch) hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
@ -585,28 +585,11 @@ def convert_dataset(
revision=branch, revision=branch,
) )
card = create_lerobot_dataset_card(tags=repo_tags, info=metadata_v2_0)
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch) card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch)
if not test_branch: if not test_branch:
create_branch(repo_id=repo_id, branch=V20, repo_type="dataset") create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
# TODO:
# - [X] Add shapes
# - [X] Add keys
# - [X] Add paths
# - [X] convert stats.json
# - [X] Add task.json
# - [X] Add names
# - [X] Add robot_type
# - [X] Add splits
# - [X] Push properly to branch v2.0 and delete v1.6 stuff from that branch
# - [X] Handle multitask datasets
# - [X] Handle hf hub repo limits (add chunks logic)
# - [X] Add test-branch
# - [X] Use jsonlines for episodes
# - [X] Add sanity checks (encoding, shapes)
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()

View File

@ -25,7 +25,6 @@ from typing import Any, ClassVar
import pyarrow as pa import pyarrow as pa
import torch import torch
import torchvision import torchvision
from datasets import Dataset
from datasets.features.features import register_feature from datasets.features.features import register_feature
from PIL import Image from PIL import Image
@ -292,33 +291,6 @@ def get_video_info(video_path: Path | str) -> dict:
return video_info return video_info
def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
video_shapes = {}
for img_key in video_keys:
channels = get_video_pixel_channels(videos_info[img_key]["video.pix_fmt"])
video_shapes[img_key] = {
"width": videos_info[img_key]["video.width"],
"height": videos_info[img_key]["video.height"],
"channels": channels,
}
return video_shapes
def get_image_shapes(dataset: Dataset, image_keys: list) -> dict:
image_shapes = {}
for img_key in image_keys:
image = dataset[0][img_key] # Assuming first row
channels = get_image_pixel_channels(image)
image_shapes[img_key] = {
"width": image.width,
"height": image.height,
"channels": channels,
}
return image_shapes
def get_video_pixel_channels(pix_fmt: str) -> int: def get_video_pixel_channels(pix_fmt: str) -> int:
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt: if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
return 1 return 1

View File

@ -226,13 +226,42 @@ class ManipulatorRobot:
self.is_connected = False self.is_connected = False
self.logs = {} self.logs = {}
def get_motor_names(self, arm: dict[str, MotorsBus]) -> list:
    """Return "<arm name>_<motor>" for every motor of every bus in `arm`."""
    motor_names = []
    for arm_name, bus in arm.items():
        motor_names.extend(f"{arm_name}_{motor}" for motor in bus.motors)
    return motor_names
@property
def camera_features(self) -> dict:
    """Features of the robot's cameras, keyed by "observation.images.<camera key>".

    Each entry holds the camera's (width, height, channels) shape, the names
    of those dimensions, and an "info" slot set to None.
    """
    return {
        f"observation.images.{cam_key}": {
            "shape": (cam.width, cam.height, cam.channels),
            "names": ["width", "height", "channels"],
            "info": None,
        }
        for cam_key, cam in self.cameras.items()
    }
@property
def motor_features(self) -> dict:
    """Features describing the robot's motors.

    Returns a dict with two float32 vector features:
    - "action", whose dimensions are named after the leader-arm motors;
    - "observation.state", whose dimensions are named after the follower-arm motors.
    """
    action_names = self.get_motor_names(self.leader_arms)
    # State names must come from the follower arms; building them from the
    # leader arms gives wrong names/shape whenever the two sets of arms differ.
    state_names = self.get_motor_names(self.follower_arms)
    return {
        "action": {
            "dtype": "float32",
            "shape": (len(action_names),),
            "names": action_names,
        },
        "observation.state": {
            "dtype": "float32",
            "shape": (len(state_names),),
            "names": state_names,
        },
    }
@property
def features(self):
    """All features of the robot: motor features followed by camera features."""
    merged = dict(self.motor_features)
    merged.update(self.camera_features)
    return merged
@property @property
def has_camera(self): def has_camera(self):
return len(self.cameras) > 0 return len(self.cameras) > 0

View File

@ -11,6 +11,7 @@ def get_arm_id(name, arm_type):
class Robot(Protocol): class Robot(Protocol):
# TODO(rcadene, aliberts): Add unit test checking the protocol is implemented in the corresponding classes # TODO(rcadene, aliberts): Add unit test checking the protocol is implemented in the corresponding classes
robot_type: str robot_type: str
features: dict
def connect(self): ... def connect(self): ...
def run_calibration(self): ... def run_calibration(self): ...

View File

@ -105,7 +105,7 @@ from pathlib import Path
from typing import List from typing import List
# from safetensors.torch import load_file, save_file # from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.robot_devices.control_utils import ( from lerobot.common.robot_devices.control_utils import (
control_loop, control_loop,
has_method, has_method,
@ -234,15 +234,12 @@ def record(
# Create empty dataset or load existing saved episodes # Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy) sanity_check_dataset_name(repo_id, policy)
dataset_metadata = LeRobotDatasetMetadata.create( dataset = LeRobotDataset.create(
repo_id, repo_id,
fps, fps,
root=root, root=root,
robot=robot, robot=robot,
use_videos=video, use_videos=video,
)
dataset = LeRobotDataset.create(
dataset_metadata,
image_writer_processes=num_image_writer_processes, image_writer_processes=num_image_writer_processes,
image_writer_threads=num_image_writer_threads_per_camera, image_writer_threads=num_image_writer_threads_per_camera,
) )