Allow dataset creation without robot

Author: Simon Alibert  2024-10-24 00:13:21 +02:00
parent 0d77be90ee
commit 60865e8980
3 changed files with 66 additions and 23 deletions

lerobot/common/datasets/image_writer.py  View File

@@ -54,7 +54,7 @@ class ImageWriter:
     """

     def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
-        self.dir = write_dir / "images"
+        self.dir = write_dir
         self.dir.mkdir(parents=True, exist_ok=True)
         self.image_path = DEFAULT_IMAGE_PATH
         self.num_processes = num_processes

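This moves the "images" subdirectory out of ImageWriter's control: the writer now uses write_dir exactly as given, and the caller appends "images" itself (see the LeRobotDataset hunk below). A minimal sketch of the new contract, with a hypothetical path:

    from pathlib import Path
    # assuming ImageWriter is imported from the patched module
    # The caller now owns the directory layout; ImageWriter no longer
    # appends "images" on its own.
    writer = ImageWriter(write_dir=Path("data/my_dataset") / "images")
    # frames are then written under data/my_dataset/images/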
lerobot/common/datasets/lerobot_dataset.py  View File

@@ -35,6 +35,7 @@ from lerobot.common.datasets.utils import (
     INFO_PATH,
     STATS_PATH,
     TASKS_PATH,
+    _get_info_from_robot,
     append_jsonl,
     check_delta_timestamps,
     check_timestamps_sync,
@@ -683,7 +684,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )
         self.image_writer = ImageWriter(
-            write_dir=self.root,
+            write_dir=self.root / "images",
             num_processes=num_processes,
             num_threads=num_threads,
         )
@@ -734,6 +735,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )
         # TODO(aliberts)
+        # - [ ] add video info in info.json
         # Sanity checks:
         # - [ ] shapes
         # - [ ] ep_lenghts
@@ -744,8 +746,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
         cls,
         repo_id: str,
         fps: int,
-        robot: Robot,
         root: Path | None = None,
+        robot: Robot | None = None,
+        robot_type: str | None = None,
+        keys: list[str] | None = None,
+        image_keys: list[str] | None = None,
+        video_keys: list[str] = None,
+        shapes: dict | None = None,
+        names: dict | None = None,
         tolerance_s: float = 1e-4,
         image_writer_processes: int = 0,
         image_writer_threads_per_camera: int = 0,
@@ -757,25 +765,40 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.repo_id = repo_id
         obj.root = root if root is not None else LEROBOT_HOME / repo_id
         obj.tolerance_s = tolerance_s
-        if not all(cam.fps == fps for cam in robot.cameras.values()):
-            logging.warning(
-                f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
-                "In this case, frames from lower fps cameras will be repeated to fill in the blanks"
-            )
-        obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
-        obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot, use_videos)
-        write_json(obj.info, obj.root / INFO_PATH)
-        # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
-        obj.episode_buffer = obj._create_episode_buffer()
-        obj.image_writer = None
-        if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
-            obj.start_image_writter(
-                image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
-            )
+        obj.image_writer = None
+
+        if robot is not None:
+            robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos)
+            if not all(cam.fps == fps for cam in robot.cameras.values()):
+                logging.warning(
+                    f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
+                    "In this case, frames from lower fps cameras will be repeated to fill in the blanks"
+                )
+            if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
+                obj.start_image_writter(
+                    image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
+                )
+        elif (
+            robot_type is None
+            or keys is None
+            or image_keys is None
+            or video_keys is None
+            or shapes is None
+            or names is None
+        ):
+            raise ValueError()
+
+        if len(video_keys) > 0 and not use_videos:
+            raise ValueError
+
+        obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
+        obj.info = create_empty_dataset_info(
+            CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names
+        )
+        write_json(obj.info, obj.root / INFO_PATH)
+
+        # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
+        obj.episode_buffer = obj._create_episode_buffer()
         # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
         # is used to know when certain operations are need (for instance, computing dataset statistics). In
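Taken together, `create` now has two entry points: pass a connected `robot` and let `_get_info_from_robot` derive the metadata, or pass no robot and spell out all six schema fields yourself (leaving any of them as None trips the ValueError above). A hedged sketch of the robot-less call; the repo id, key names, and shapes are illustrative, not from this commit:

    dataset = LeRobotDataset.create(
        repo_id="user/example_task",   # hypothetical
        fps=30,
        robot=None,
        robot_type="custom_arm",       # without a robot, the schema is hand-written
        keys=["observation.state", "action"],
        image_keys=[],
        video_keys=["observation.images.cam"],
        shapes={
            "observation.state": 2,
            "action": 2,
            "observation.images.cam": {"width": 640, "height": 480, "channels": 3},
        },
        names={
            "observation.state": ["shoulder", "elbow"],
            "action": ["shoulder", "elbow"],
        },
    )

Since video_keys is non-empty, this sketch assumes use_videos (a parameter in the elided part of the signature) defaults to True; otherwise the `raise ValueError` guard above fires.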

lerobot/common/datasets/utils.py  View File

@@ -193,7 +193,7 @@ def load_episode_dicts(local_dir: Path) -> dict:
     return list(reader)


-def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use_videos: bool = True) -> dict:
+def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]:
     shapes = {key: len(names) for key, names in robot.names.items()}
     camera_shapes = {}
     for key, cam in robot.cameras.items():
@@ -203,10 +203,30 @@ def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use
             "height": cam.height,
             "channels": cam.channels,
         }
+    keys = list(robot.names)
+    image_keys = [] if use_videos else list(camera_shapes)
+    video_keys = list(camera_shapes) if use_videos else []
+    shapes = {**shapes, **camera_shapes}
+    names = robot.names
+    robot_type = robot.robot_type
+
+    return robot_type, keys, image_keys, video_keys, shapes, names
+
+
+def create_empty_dataset_info(
+    codebase_version: str,
+    fps: int,
+    robot_type: str,
+    keys: list[str],
+    image_keys: list[str],
+    video_keys: list[str],
+    shapes: dict,
+    names: dict,
+) -> dict:
     return {
         "codebase_version": codebase_version,
         "data_path": DEFAULT_PARQUET_PATH,
-        "robot_type": robot.robot_type,
+        "robot_type": robot_type,
         "total_episodes": 0,
         "total_frames": 0,
         "total_tasks": 0,
@@ -215,12 +235,12 @@
         "chunks_size": DEFAULT_CHUNK_SIZE,
         "fps": fps,
         "splits": {},
-        "keys": list(robot.names),
-        "video_keys": list(camera_shapes) if use_videos else [],
-        "image_keys": [] if use_videos else list(camera_shapes),
-        "shapes": {**shapes, **camera_shapes},
-        "names": robot.names,
-        "videos": {"videos_path": DEFAULT_VIDEO_PATH} if use_videos else None,
+        "keys": keys,
+        "video_keys": video_keys,
+        "image_keys": image_keys,
+        "shapes": shapes,
+        "names": names,
+        "videos": {"videos_path": DEFAULT_VIDEO_PATH} if len(video_keys) > 0 else None,
     }
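One detail worth noting in the last hunk: the "videos" entry now keys off `len(video_keys) > 0` rather than the `use_videos` flag, so the same function works whether the keys came from a robot or were passed by hand. A quick illustrative check (all values hypothetical):

    info = create_empty_dataset_info(
        codebase_version="v2.0",   # stand-in for CODEBASE_VERSION
        fps=30,
        robot_type="custom_arm",
        keys=["observation.state", "action"],
        image_keys=["observation.images.cam"],   # frames kept as images, no video encoding
        video_keys=[],
        shapes={"observation.state": 2, "action": 2,
                "observation.images.cam": {"width": 640, "height": 480, "channels": 3}},
        names={"observation.state": ["shoulder", "elbow"],
               "action": ["shoulder", "elbow"]},
    )
    assert info["videos"] is None   # empty video_keys -> no videos section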