Allow dataset creation without robot
This commit is contained in:
parent
0d77be90ee
commit
60865e8980
|
@ -54,7 +54,7 @@ class ImageWriter:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
|
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
|
||||||
self.dir = write_dir / "images"
|
self.dir = write_dir
|
||||||
self.dir.mkdir(parents=True, exist_ok=True)
|
self.dir.mkdir(parents=True, exist_ok=True)
|
||||||
self.image_path = DEFAULT_IMAGE_PATH
|
self.image_path = DEFAULT_IMAGE_PATH
|
||||||
self.num_processes = num_processes
|
self.num_processes = num_processes
|
||||||
|
|
|
@ -35,6 +35,7 @@ from lerobot.common.datasets.utils import (
|
||||||
INFO_PATH,
|
INFO_PATH,
|
||||||
STATS_PATH,
|
STATS_PATH,
|
||||||
TASKS_PATH,
|
TASKS_PATH,
|
||||||
|
_get_info_from_robot,
|
||||||
append_jsonl,
|
append_jsonl,
|
||||||
check_delta_timestamps,
|
check_delta_timestamps,
|
||||||
check_timestamps_sync,
|
check_timestamps_sync,
|
||||||
|
@ -683,7 +684,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.image_writer = ImageWriter(
|
self.image_writer = ImageWriter(
|
||||||
write_dir=self.root,
|
write_dir=self.root / "images",
|
||||||
num_processes=num_processes,
|
num_processes=num_processes,
|
||||||
num_threads=num_threads,
|
num_threads=num_threads,
|
||||||
)
|
)
|
||||||
|
@ -734,6 +735,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(aliberts)
|
# TODO(aliberts)
|
||||||
|
# - [ ] add video info in info.json
|
||||||
# Sanity checks:
|
# Sanity checks:
|
||||||
# - [ ] shapes
|
# - [ ] shapes
|
||||||
# - [ ] ep_lengths
|
# - [ ] ep_lengths
|
||||||
|
@ -744,8 +746,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
cls,
|
cls,
|
||||||
repo_id: str,
|
repo_id: str,
|
||||||
fps: int,
|
fps: int,
|
||||||
robot: Robot,
|
|
||||||
root: Path | None = None,
|
root: Path | None = None,
|
||||||
|
robot: Robot | None = None,
|
||||||
|
robot_type: str | None = None,
|
||||||
|
keys: list[str] | None = None,
|
||||||
|
image_keys: list[str] | None = None,
|
||||||
|
video_keys: list[str] = None,
|
||||||
|
shapes: dict | None = None,
|
||||||
|
names: dict | None = None,
|
||||||
tolerance_s: float = 1e-4,
|
tolerance_s: float = 1e-4,
|
||||||
image_writer_processes: int = 0,
|
image_writer_processes: int = 0,
|
||||||
image_writer_threads_per_camera: int = 0,
|
image_writer_threads_per_camera: int = 0,
|
||||||
|
@ -757,26 +765,41 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
obj.repo_id = repo_id
|
obj.repo_id = repo_id
|
||||||
obj.root = root if root is not None else LEROBOT_HOME / repo_id
|
obj.root = root if root is not None else LEROBOT_HOME / repo_id
|
||||||
obj.tolerance_s = tolerance_s
|
obj.tolerance_s = tolerance_s
|
||||||
|
obj.image_writer = None
|
||||||
|
|
||||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
if robot is not None:
|
||||||
logging.warning(
|
robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos)
|
||||||
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
|
logging.warning(
|
||||||
)
|
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
||||||
|
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
|
||||||
|
)
|
||||||
|
if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
|
||||||
|
obj.start_image_writter(
|
||||||
|
image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
|
||||||
|
)
|
||||||
|
elif (
|
||||||
|
robot_type is None
|
||||||
|
or keys is None
|
||||||
|
or image_keys is None
|
||||||
|
or video_keys is None
|
||||||
|
or shapes is None
|
||||||
|
or names is None
|
||||||
|
):
|
||||||
|
raise ValueError()
|
||||||
|
|
||||||
|
if len(video_keys) > 0 and not use_videos:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
|
obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
|
||||||
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot, use_videos)
|
obj.info = create_empty_dataset_info(
|
||||||
|
CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names
|
||||||
|
)
|
||||||
write_json(obj.info, obj.root / INFO_PATH)
|
write_json(obj.info, obj.root / INFO_PATH)
|
||||||
|
|
||||||
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
||||||
obj.episode_buffer = obj._create_episode_buffer()
|
obj.episode_buffer = obj._create_episode_buffer()
|
||||||
|
|
||||||
obj.image_writer = None
|
|
||||||
if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
|
|
||||||
obj.start_image_writter(
|
|
||||||
image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
|
|
||||||
)
|
|
||||||
|
|
||||||
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
|
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
|
||||||
# is used to know when certain operations are needed (for instance, computing dataset statistics). In
|
# is used to know when certain operations are needed (for instance, computing dataset statistics). In
|
||||||
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling
|
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling
|
||||||
|
|
|
@ -193,7 +193,7 @@ def load_episode_dicts(local_dir: Path) -> dict:
|
||||||
return list(reader)
|
return list(reader)
|
||||||
|
|
||||||
|
|
||||||
def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use_videos: bool = True) -> dict:
|
def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]:
|
||||||
shapes = {key: len(names) for key, names in robot.names.items()}
|
shapes = {key: len(names) for key, names in robot.names.items()}
|
||||||
camera_shapes = {}
|
camera_shapes = {}
|
||||||
for key, cam in robot.cameras.items():
|
for key, cam in robot.cameras.items():
|
||||||
|
@ -203,10 +203,30 @@ def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use
|
||||||
"height": cam.height,
|
"height": cam.height,
|
||||||
"channels": cam.channels,
|
"channels": cam.channels,
|
||||||
}
|
}
|
||||||
|
keys = list(robot.names)
|
||||||
|
image_keys = [] if use_videos else list(camera_shapes)
|
||||||
|
video_keys = list(camera_shapes) if use_videos else []
|
||||||
|
shapes = {**shapes, **camera_shapes}
|
||||||
|
names = robot.names
|
||||||
|
robot_type = robot.robot_type
|
||||||
|
|
||||||
|
return robot_type, keys, image_keys, video_keys, shapes, names
|
||||||
|
|
||||||
|
|
||||||
|
def create_empty_dataset_info(
|
||||||
|
codebase_version: str,
|
||||||
|
fps: int,
|
||||||
|
robot_type: str,
|
||||||
|
keys: list[str],
|
||||||
|
image_keys: list[str],
|
||||||
|
video_keys: list[str],
|
||||||
|
shapes: dict,
|
||||||
|
names: dict,
|
||||||
|
) -> dict:
|
||||||
return {
|
return {
|
||||||
"codebase_version": codebase_version,
|
"codebase_version": codebase_version,
|
||||||
"data_path": DEFAULT_PARQUET_PATH,
|
"data_path": DEFAULT_PARQUET_PATH,
|
||||||
"robot_type": robot.robot_type,
|
"robot_type": robot_type,
|
||||||
"total_episodes": 0,
|
"total_episodes": 0,
|
||||||
"total_frames": 0,
|
"total_frames": 0,
|
||||||
"total_tasks": 0,
|
"total_tasks": 0,
|
||||||
|
@ -215,12 +235,12 @@ def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use
|
||||||
"chunks_size": DEFAULT_CHUNK_SIZE,
|
"chunks_size": DEFAULT_CHUNK_SIZE,
|
||||||
"fps": fps,
|
"fps": fps,
|
||||||
"splits": {},
|
"splits": {},
|
||||||
"keys": list(robot.names),
|
"keys": keys,
|
||||||
"video_keys": list(camera_shapes) if use_videos else [],
|
"video_keys": video_keys,
|
||||||
"image_keys": [] if use_videos else list(camera_shapes),
|
"image_keys": image_keys,
|
||||||
"shapes": {**shapes, **camera_shapes},
|
"shapes": shapes,
|
||||||
"names": robot.names,
|
"names": names,
|
||||||
"videos": {"videos_path": DEFAULT_VIDEO_PATH} if use_videos else None,
|
"videos": {"videos_path": DEFAULT_VIDEO_PATH} if len(video_keys) > 0 else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue