Refactor dataset features
This commit is contained in:
parent 757ea175d3
commit aed9f4036a
@@ -30,11 +30,11 @@ from huggingface_hub import snapshot_download, upload_folder
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
from lerobot.common.datasets.image_writer import ImageWriter
from lerobot.common.datasets.utils import (
DEFAULT_FEATURES,
EPISODES_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
_get_info_from_robot,
append_jsonlines,
check_delta_timestamps,
check_timestamps_sync,
@@ -43,6 +43,7 @@ from lerobot.common.datasets.utils import (
create_empty_dataset_info,
get_delta_indices,
get_episode_data_index,
get_features_from_robot,
get_hub_safe_version,
hf_transform_to_torch,
load_episodes,
@@ -116,7 +117,7 @@ class LeRobotDatasetMetadata:

def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
ep_chunk = self.get_episode_chunk(ep_index)
fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
return Path(fpath)

def get_episode_chunk(self, ep_index: int) -> int:
@@ -128,15 +129,20 @@ class LeRobotDatasetMetadata:
return self.info["data_path"]

@property
def videos_path(self) -> str | None:
def video_path(self) -> str | None:
"""Formattable string for the video files."""
return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
return self.info["video_path"]

@property
def fps(self) -> int:
"""Frames per second used during data collection."""
return self.info["fps"]

@property
def features(self) -> dict[str, dict]:
""""""
return self.info["features"]

@property
def keys(self) -> list[str]:
"""Keys to access non-image data (state, actions etc.)."""
@@ -145,22 +151,27 @@ class LeRobotDatasetMetadata:
@property
def image_keys(self) -> list[str]:
"""Keys to access visual modalities stored as images."""
return self.info["image_keys"]
return [key for key, ft in self.features.items() if ft["dtype"] == "image"]

@property
def video_keys(self) -> list[str]:
"""Keys to access visual modalities stored as videos."""
return self.info["video_keys"]
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]

@property
def camera_keys(self) -> list[str]:
"""Keys to access visual modalities (regardless of their storage method)."""
return self.image_keys + self.video_keys
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]

@property
def names(self) -> dict[list[str]]:
def names(self) -> dict[str, list[str]]:
"""Names of the various dimensions of vector modalities."""
return self.info["names"]
return {key: ft["names"] for key, ft in self.features.items()}

@property
def shapes(self) -> dict:
"""Shapes for the different features."""
return {key: tuple(ft["shape"]) for key, ft in self.features.items()}

@property
def total_episodes(self) -> int:
@@ -187,11 +198,6 @@ class LeRobotDatasetMetadata:
"""Max number of episodes per chunk."""
return self.info["chunks_size"]

@property
def shapes(self) -> dict:
"""Shapes for the different features."""
return self.info["shapes"]

@property
def task_to_task_index(self) -> dict:
return {task: task_idx for task_idx, task in self.tasks.items()}
@@ -253,45 +259,33 @@ class LeRobotDatasetMetadata:
root: Path | None = None,
robot: Robot | None = None,
robot_type: str | None = None,
keys: list[str] | None = None,
image_keys: list[str] | None = None,
video_keys: list[str] = None,
shapes: dict | None = None,
names: dict | None = None,
features: dict | None = None,
use_videos: bool = True,
) -> "LeRobotDatasetMetadata":
"""Creates metadata for a LeRobotDataset."""
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = root if root is not None else LEROBOT_HOME / repo_id
obj.image_writer = None

if robot is not None:
robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos)
features = get_features_from_robot(robot)
robot_type = robot.robot_type
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
)
elif (
robot_type is None
or keys is None
or image_keys is None
or video_keys is None
or shapes is None
or names is None
):
elif robot_type is None or features is None:
raise ValueError(
"Dataset info (robot_type, keys, shapes...) must either come from a Robot or explicitly passed upon creation."
"Dataset features must either come from a Robot or explicitly passed upon creation."
)

if len(video_keys) > 0 and not use_videos:
raise ValueError()
else:
features = {**features, **DEFAULT_FEATURES}

obj.tasks, obj.stats, obj.episodes = {}, {}, []
obj.info = create_empty_dataset_info(
CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names
)
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
obj.local_files_only = True
return obj
@@ -509,6 +503,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
hf_dataset = load_dataset("parquet", data_files=files, split="train")

hf_dataset.set_transform(hf_transform_to_torch)
# return hf_dataset.with_format("torch") TODO
return hf_dataset

@property
@@ -662,8 +657,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
"task_index": None,
"frame_index": [],
"timestamp": [],
"next.done": [],
**{key: [] for key in self.meta.keys},
**{key: [] for key in self.meta.features},
**{key: [] for key in self.meta.image_keys},
}
@@ -845,7 +839,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
@classmethod
def create(
cls,
metadata: LeRobotDatasetMetadata,
repo_id: str,
fps: int,
root: Path | None = None,
robot: Robot | None = None,
robot_type: str | None = None,
features: dict | None = None,
use_videos: bool = True,
tolerance_s: float = 1e-4,
image_writer_processes: int = 0,
image_writer_threads: int = 0,
@@ -853,7 +853,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from scratch in order to record data."""
obj = cls.__new__(cls)
obj.meta = metadata
obj.meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=fps,
root=root,
robot=robot,
robot_type=robot_type,
features=features,
use_videos=use_videos,
)
obj.repo_id = obj.meta.repo_id
obj.root = obj.meta.root
obj.local_files_only = obj.meta.local_files_only
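With the two create() hunks above, a recording dataset is now bootstrapped from a single features dict rather than separate keys/shapes/names arguments. A minimal sketch of the new call, assuming `robot` is an already configured robot instance; the repo_id, robot_type and feature values below are illustrative, not from this commit:

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # Features inferred from the robot via get_features_from_robot(robot).
    dataset = LeRobotDataset.create(
        repo_id="my_user/my_dataset",  # illustrative repo_id
        fps=30,
        robot=robot,  # assumed to be a configured ManipulatorRobot
        use_videos=True,
    )

    # Or pass robot_type and an explicit features dict, which is merged with DEFAULT_FEATURES.
    dataset = LeRobotDataset.create(
        repo_id="my_user/my_dataset",
        fps=30,
        robot_type="koch",  # illustrative robot_type
        features={
            "action": {"dtype": "float32", "shape": (6,), "names": [f"motor_{i}" for i in range(6)]},
            "observation.state": {"dtype": "float32", "shape": (6,), "names": [f"motor_{i}" for i in range(6)]},
        },
        use_videos=False,
    )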
@@ -48,6 +48,14 @@ This dataset was created using [LeRobot](https://github.com/huggingface/lerobot)

"""

DEFAULT_FEATURES = {
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
"frame_index": {"dtype": "int64", "shape": (1,), "names": None},
"episode_index": {"dtype": "int64", "shape": (1,), "names": None},
"index": {"dtype": "int64", "shape": (1,), "names": None},
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
}


def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
@@ -214,39 +222,25 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
return version


def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]:
shapes = {key: len(names) for key, names in robot.names.items()}
camera_shapes = {}
for key, cam in robot.cameras.items():
video_key = f"observation.images.{key}"
camera_shapes[video_key] = {
"width": cam.width,
"height": cam.height,
"channels": cam.channels,
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
camera_ft = {}
if robot.cameras:
camera_ft = {
key: {"dtype": "video" if use_videos else "image", **ft}
for key, ft in robot.camera_features.items()
}
keys = list(robot.names)
image_keys = [] if use_videos else list(camera_shapes)
video_keys = list(camera_shapes) if use_videos else []
shapes = {**shapes, **camera_shapes}
names = robot.names
robot_type = robot.robot_type

return robot_type, keys, image_keys, video_keys, shapes, names
return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}


def create_empty_dataset_info(
codebase_version: str,
fps: int,
robot_type: str,
keys: list[str],
image_keys: list[str],
video_keys: list[str],
shapes: dict,
names: dict,
features: dict,
use_videos: bool,
) -> dict:
return {
"codebase_version": codebase_version,
"data_path": DEFAULT_PARQUET_PATH,
"robot_type": robot_type,
"total_episodes": 0,
"total_frames": 0,
@@ -256,12 +250,9 @@ def create_empty_dataset_info(
"chunks_size": DEFAULT_CHUNK_SIZE,
"fps": fps,
"splits": {},
"keys": keys,
"video_keys": video_keys,
"image_keys": image_keys,
"shapes": shapes,
"names": names,
"videos": {"videos_path": DEFAULT_VIDEO_PATH} if len(video_keys) > 0 else None,
"data_path": DEFAULT_PARQUET_PATH,
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
"features": features,
}
@@ -400,6 +391,12 @@ def create_lerobot_dataset_card(
tags: list | None = None, text: str | None = None, info: dict | None = None
) -> DatasetCard:
card = DatasetCard(DATASET_CARD_TEMPLATE)
card.data.configs = [
{
"config_name": "default",
"data_files": "data/*/*.parquet",
}
]
card.data.task_categories = ["robotics"]
card.data.tags = ["LeRobot"]
if tags is not None:
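For orientation, the dict returned by the new create_empty_dataset_info() replaces the old keys/shapes/names/videos entries with a single "features" mapping plus a "video_path" template. A rough sketch of a freshly created info payload, with illustrative values (the exact path templates, chunk size and any extra counters come from constants not shown in this diff):

    info = {
        "codebase_version": "v2.0",
        "robot_type": "koch",  # illustrative
        "total_episodes": 0,
        "total_frames": 0,
        "chunks_size": 1000,  # assumed DEFAULT_CHUNK_SIZE value
        "fps": 30,
        "splits": {},
        "data_path": "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet",  # assumed DEFAULT_PARQUET_PATH
        "video_path": "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4",  # assumed DEFAULT_VIDEO_PATH
        "features": {
            "action": {"dtype": "float32", "shape": (6,), "names": [f"motor_{i}" for i in range(6)]},
            "observation.state": {"dtype": "float32", "shape": (6,), "names": [f"motor_{i}" for i in range(6)]},
            "observation.images.laptop": {"dtype": "video", "shape": (640, 480, 3), "names": ["width", "height", "channels"]},
            "timestamp": {"dtype": "float32", "shape": (1,), "names": None},
            # ...remaining DEFAULT_FEATURES entries (frame_index, episode_index, index, task_index)
        },
    }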
@@ -106,6 +106,7 @@ import json
import math
import shutil
import subprocess
import tempfile
import warnings
from pathlib import Path
@@ -137,9 +138,8 @@ from lerobot.common.datasets.utils import (
)
from lerobot.common.datasets.video_utils import (
VideoFrame, # noqa: F401
get_image_shapes,
get_image_pixel_channels,
get_video_info,
get_video_shapes,
)
from lerobot.common.utils.utils import init_hydra_config
@@ -202,21 +202,37 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
torch.testing.assert_close(stats_json[key], stats[key])


def get_keys(dataset: Dataset) -> dict[str, list]:
sequence_keys, image_keys, video_keys = [], [], []
def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = None) -> dict[str, list]:
features = {}
for key, ft in dataset.features.items():
if isinstance(ft, datasets.Value):
dtype = ft.dtype
shape = (1,)
names = None
if isinstance(ft, datasets.Sequence):
sequence_keys.append(key)
assert isinstance(ft.feature, datasets.Value)
dtype = ft.feature.dtype
shape = (ft.length,)
names = robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
assert len(names) == shape[0]
elif isinstance(ft, datasets.Image):
image_keys.append(key)
dtype = "image"
image = dataset[0][key] # Assuming first row
channels = get_image_pixel_channels(image)
shape = (image.width, image.height, channels)
names = ["width", "height", "channel"]
elif ft._type == "VideoFrame":
video_keys.append(key)
dtype = "video"
shape = None # Add shape later
names = ["width", "height", "channel"]

return {
"sequence": sequence_keys,
"image": image_keys,
"video": video_keys,
}
features[key] = {
"dtype": dtype,
"shape": shape,
"names": names,
}

return features


def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
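As a rough illustration (column names and sizes are hypothetical), get_features_from_hf_dataset() would map a v1 dataset with a 6-dim float32 "observation.state" Sequence column and an "observation.images.top" VideoFrame column to something like:

    {
        "observation.state": {
            "dtype": "float32",
            "shape": (6,),
            "names": ["motor_0", "motor_1", "motor_2", "motor_3", "motor_4", "motor_5"],  # fallback when no robot_config
        },
        "observation.images.top": {
            "dtype": "video",
            "shape": None,  # filled in later from the video files
            "names": ["width", "height", "channel"],
        },
        "episode_index": {"dtype": "int64", "shape": (1,), "names": None},
        # ...and likewise for the other plain Value columns (timestamp, frame_index, index)
    }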
@@ -259,17 +275,15 @@ def add_task_index_from_tasks_col(

def split_parquet_by_episodes(
dataset: Dataset,
keys: dict[str, list],
total_episodes: int,
total_chunks: int,
output_dir: Path,
) -> list:
table = dataset.remove_columns(keys["video"])._data.table
table = dataset.data.table
episode_lengths = []
for ep_chunk in range(total_chunks):
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)

chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end):
@@ -396,27 +410,22 @@ def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:


def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
hub_api = HfApi()
videos_info_dict = {"videos_path": DEFAULT_VIDEO_PATH}

# Assumes first episode
video_files = [
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
for vid_key in video_keys
]
hub_api = HfApi()
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
)
videos_info_dict = {}
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)

return videos_info_dict


def get_generic_motor_names(sequence_shapes: dict) -> dict:
return {key: [f"motor_{i}" for i in range(length)] for key, length in sequence_shapes.items()}


def convert_dataset(
repo_id: str,
local_dir: Path,
@@ -443,7 +452,8 @@ def convert_dataset(

metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
keys = get_keys(dataset)
features = get_features_from_hf_dataset(dataset, robot_config)
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]

if single_task and "language_instruction" in dataset.column_names:
warnings.warn(
@@ -457,7 +467,7 @@ def convert_dataset(
episode_indices = sorted(dataset.unique("episode_index"))
total_episodes = len(episode_indices)
assert episode_indices == list(range(total_episodes))
total_videos = total_episodes * len(keys["video"])
total_videos = total_episodes * len(video_keys)
total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
if total_episodes % DEFAULT_CHUNK_SIZE != 0:
total_chunks += 1
@@ -470,7 +480,6 @@ def convert_dataset(
elif tasks_path:
tasks_by_episodes = load_json(tasks_path)
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
# tasks = list(set(tasks_by_episodes.values()))
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
elif tasks_col:
@@ -481,56 +490,50 @@ def convert_dataset(
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
write_jsonlines(tasks, v20_dir / TASKS_PATH)

# Shapes
sequence_shapes = {key: dataset.features[key].length for key in keys["sequence"]}
image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {}
features["task_index"] = {
"dtype": "int64",
"shape": (1,),
"names": None,
}

# Videos
if len(keys["video"]) > 0:
if video_keys:
assert metadata_v1.get("video", False)
tmp_video_dir = local_dir / "videos" / V20 / repo_id
tmp_video_dir.mkdir(parents=True, exist_ok=True)
dataset = dataset.remove_columns(video_keys)
clean_gitattr = Path(
hub_api.hf_hub_download(
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
)
).absolute()
move_videos(
repo_id, keys["video"], total_episodes, total_chunks, tmp_video_dir, clean_gitattr, branch
)
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"], branch=branch)
video_shapes = get_video_shapes(videos_info, keys["video"])
for img_key in keys["video"]:
assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
with tempfile.TemporaryDirectory() as tmp_video_dir:
move_videos(
repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
)
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
for key in video_keys:
features[key]["shape"] = (
videos_info[key].pop("video.width"),
videos_info[key].pop("video.height"),
videos_info[key].pop("video.channels"),
)
features[key]["video_info"] = videos_info[key]
assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
if "encoding" in metadata_v1:
assert videos_info[img_key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
else:
assert metadata_v1.get("video", 0) == 0
videos_info = None
video_shapes = {}

# Split data into 1 parquet file by episode
episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, total_chunks, v20_dir)
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)

# Names
if robot_config is not None:
robot_type = robot_config["robot_type"]
names = robot_config["names"]
if "observation.effort" in keys["sequence"]:
names["observation.effort"] = names["observation.state"]
if "observation.velocity" in keys["sequence"]:
names["observation.velocity"] = names["observation.state"]
repo_tags = [robot_type]
else:
robot_type = "unknown"
names = get_generic_motor_names(sequence_shapes)
repo_tags = None

assert set(names) == set(keys["sequence"])
for key in sequence_shapes:
assert len(names[key]) == sequence_shapes[key]

# Episodes
episodes = [
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
@@ -541,7 +544,6 @@ def convert_dataset(
# Assemble metadata v2.0
metadata_v2_0 = {
"codebase_version": V20,
"data_path": DEFAULT_PARQUET_PATH,
"robot_type": robot_type,
"total_episodes": total_episodes,
"total_frames": len(dataset),
@@ -551,15 +553,13 @@ def convert_dataset(
"chunks_size": DEFAULT_CHUNK_SIZE,
"fps": metadata_v1["fps"],
"splits": {"train": f"0:{total_episodes}"},
"keys": keys["sequence"],
"video_keys": keys["video"],
"image_keys": keys["image"],
"shapes": {**sequence_shapes, **video_shapes, **image_shapes},
"names": names,
"videos": videos_info,
"data_path": DEFAULT_PARQUET_PATH,
"video_path": DEFAULT_VIDEO_PATH if video_keys else None,
"features": features,
}
write_json(metadata_v2_0, v20_dir / INFO_PATH)
convert_stats_to_json(v1x_dir, v20_dir)
card = create_lerobot_dataset_card(tags=repo_tags, info=metadata_v2_0)

with contextlib.suppress(EntryNotFoundError):
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
@@ -585,28 +585,11 @@ def convert_dataset(
revision=branch,
)

card = create_lerobot_dataset_card(tags=repo_tags, info=metadata_v2_0)
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch)

if not test_branch:
create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")

# TODO:
# - [X] Add shapes
# - [X] Add keys
# - [X] Add paths
# - [X] convert stats.json
# - [X] Add task.json
# - [X] Add names
# - [X] Add robot_type
# - [X] Add splits
# - [X] Push properly to branch v2.0 and delete v1.6 stuff from that branch
# - [X] Handle multitask datasets
# - [X] Handle hf hub repo limits (add chunks logic)
# - [X] Add test-branch
# - [X] Use jsonlines for episodes
# - [X] Add sanity checks (encoding, shapes)


def main():
parser = argparse.ArgumentParser()
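After the conversion loop above fills in the shapes from the downloaded videos, a video key's entry in features ends up looking roughly like the following; the key and numeric values are illustrative, and only the video.fps and video.pix_fmt entries shown in this diff are listed for the leftover video_info:

    features["observation.images.top"] = {  # illustrative camera key
        "dtype": "video",
        "shape": (640, 480, 3),  # popped from video.width / video.height / video.channels
        "names": ["width", "height", "channel"],
        "video_info": {  # whatever get_video_info() returned, minus the popped keys
            "video.fps": 30.0,
            "video.pix_fmt": "yuv420p",
        },
    }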
@@ -25,7 +25,6 @@ from typing import Any, ClassVar
import pyarrow as pa
import torch
import torchvision
from datasets import Dataset
from datasets.features.features import register_feature
from PIL import Image
@@ -292,33 +291,6 @@ def get_video_info(video_path: Path | str) -> dict:
return video_info


def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
video_shapes = {}
for img_key in video_keys:
channels = get_video_pixel_channels(videos_info[img_key]["video.pix_fmt"])
video_shapes[img_key] = {
"width": videos_info[img_key]["video.width"],
"height": videos_info[img_key]["video.height"],
"channels": channels,
}

return video_shapes


def get_image_shapes(dataset: Dataset, image_keys: list) -> dict:
image_shapes = {}
for img_key in image_keys:
image = dataset[0][img_key] # Assuming first row
channels = get_image_pixel_channels(image)
image_shapes[img_key] = {
"width": image.width,
"height": image.height,
"channels": channels,
}

return image_shapes


def get_video_pixel_channels(pix_fmt: str) -> int:
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
return 1
@@ -226,13 +226,42 @@ class ManipulatorRobot:
self.is_connected = False
self.logs = {}

action_names = [f"{arm}_{motor}" for arm, bus in self.leader_arms.items() for motor in bus.motors]
state_names = [f"{arm}_{motor}" for arm, bus in self.follower_arms.items() for motor in bus.motors]
self.names = {
"action": action_names,
"observation.state": state_names,
def get_motor_names(self, arm: dict[str, MotorsBus]) -> list:
return [f"{arm}_{motor}" for arm, bus in arm.items() for motor in bus.motors]

@property
def camera_features(self) -> dict:
cam_ft = {}
for cam_key, cam in self.cameras.items():
key = f"observation.images.{cam_key}"
cam_ft[key] = {
"shape": (cam.width, cam.height, cam.channels),
"names": ["width", "height", "channels"],
"info": None,
}
return cam_ft

@property
def motor_features(self) -> dict:
action_names = self.get_motor_names(self.leader_arms)
state_names = self.get_motor_names(self.leader_arms)
return {
"action": {
"dtype": "float32",
"shape": (len(action_names),),
"names": action_names,
},
"observation.state": {
"dtype": "float32",
"shape": (len(state_names),),
"names": state_names,
},
}

@property
def features(self):
return {**self.motor_features, **self.camera_features}

@property
def has_camera(self):
return len(self.cameras) > 0
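A rough sketch of what these new properties evaluate to, assuming a hypothetical robot with two motors named shoulder_pan and gripper on an arm called "main" and one 640x480 camera called "laptop" (all names and sizes illustrative):

    robot.motor_features
    # {"action": {"dtype": "float32", "shape": (2,), "names": ["main_shoulder_pan", "main_gripper"]},
    #  "observation.state": {"dtype": "float32", "shape": (2,), "names": ["main_shoulder_pan", "main_gripper"]}}

    robot.camera_features
    # {"observation.images.laptop": {"shape": (640, 480, 3), "names": ["width", "height", "channels"], "info": None}}

    robot.features
    # the two dicts merged; this is what get_features_from_robot() builds on.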
@@ -11,6 +11,7 @@ def get_arm_id(name, arm_type):
class Robot(Protocol):
# TODO(rcadene, aliberts): Add unit test checking the protocol is implemented in the corresponding classes
robot_type: str
features: dict

def connect(self): ...
def run_calibration(self): ...
@@ -105,7 +105,7 @@ from pathlib import Path
from typing import List

# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.robot_devices.control_utils import (
control_loop,
has_method,
@@ -234,15 +234,12 @@ def record(

# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
dataset_metadata = LeRobotDatasetMetadata.create(
dataset = LeRobotDataset.create(
repo_id,
fps,
root=root,
robot=robot,
use_videos=video,
)
dataset = LeRobotDataset.create(
dataset_metadata,
image_writer_processes=num_image_writer_processes,
image_writer_threads=num_image_writer_threads_per_camera,
)