Move ImageWriter creation inside the dataset

Simon Alibert 2024-10-23 23:12:44 +02:00
parent 0098bd264e
commit 0d77be90ee
2 changed files with 44 additions and 28 deletions


@@ -177,11 +177,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episodes = episodes
self.tolerance_s = tolerance_s
self.video_backend = video_backend if video_backend is not None else "pyav"
self.image_writer = image_writer
self.delta_indices = None
self.consolidated = True
self.episode_buffer = {}
self.local_files_only = local_files_only
self.consolidated = True
# Unused attributes
self.image_writer = None
self.episode_buffer = {}
# Load metadata
self.root.mkdir(exist_ok=True, parents=True)
@@ -626,8 +628,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.consolidated = False
def _save_episode_table(self, episode_index: int) -> None:
features = self.features
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=features, split="train")
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.features, split="train")
ep_table = ep_dataset._data.table
ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
@@ -675,10 +676,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Reset the buffer
self.episode_buffer = self._create_episode_buffer()
def read_mode(self) -> None:
"""Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first."""
# TODO(aliberts, rcadene): find better api/interface for this.
def start_image_writter(self, num_processes: int = 0, num_threads: int = 1) -> None:
if isinstance(self.image_writer, ImageWriter):
logging.warning(
"You are starting a new ImageWriter that is replacing an already exising one in the dataset."
)
self.image_writer = ImageWriter(
write_dir=self.root,
num_processes=num_processes,
num_threads=num_threads,
)
def stop_image_writter(self) -> None:
"""
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
"""
if self.image_writer is not None:
self.image_writer.stop()
self.image_writer = None
def encode_videos(self) -> None:
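For context, a minimal sketch of how this start/stop pair is meant to be used (only the start_image_writter/stop_image_writter/image_writer names are taken from this diff; the helper function, batch size, and worker count are hypothetical). A running ImageWriter holds background processes and threads, so it has to be released before the dataset is handed to a multi-worker DataLoader, which needs to pickle the dataset object:

    from torch.utils.data import DataLoader

    def make_dataloader(dataset, batch_size=8, num_workers=4):
        # `dataset` is assumed to be a LeRobotDataset whose writer may still be
        # running (e.g. it was created with image_writer_threads_per_camera > 0).
        if dataset.image_writer is not None:
            # A live ImageWriter keeps process/thread handles that cannot be
            # pickled, so stop it before the DataLoader spawns its workers.
            dataset.stop_image_writter()
        return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True)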
@@ -708,20 +724,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
shutil.rmtree(self.image_writer.dir)
if run_compute_stats:
self.read_mode()
self.stop_image_writter()
self.stats = compute_stats(self)
write_stats(self.stats, self.root / STATS_PATH)
self.consolidated = True
else:
logging.warning("Skipping computation of the dataset statistics.")
logging.warning(
"Skipping computation of the dataset statistics, dataset is not fully consolidated."
)
# TODO(aliberts)
# Sanity checks:
# - [ ] shapes
# - [ ] ep_lengths
# - [ ] number of files
# - [ ] names of files (e.g. parquet 00000-of-00001 and 00001-of-00002)
# - [ ] no remaining self.image_writer.dir
@classmethod
def create(
@@ -731,7 +747,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
robot: Robot,
root: Path | None = None,
tolerance_s: float = 1e-4,
image_writer: ImageWriter | None = None,
image_writer_processes: int = 0,
image_writer_threads_per_camera: int = 0,
use_videos: bool = True,
video_backend: str | None = None,
) -> "LeRobotDataset":
@@ -740,7 +757,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.repo_id = repo_id
obj.root = root if root is not None else LEROBOT_HOME / repo_id
obj.tolerance_s = tolerance_s
obj.image_writer = image_writer
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(
@@ -755,20 +771,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj._create_episode_buffer()
obj.image_writer = None
if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
obj.start_image_writter(
image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
)
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
# is used to know when certain operations are needed (for instance, computing dataset statistics). In
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling
# self.consolidate().
obj.consolidated = True
obj.local_files_only = True
obj.download_videos = False
obj.episodes = None
obj.hf_dataset = None
obj.image_transforms = None
obj.delta_timestamps = None
obj.delta_indices = None
obj.local_files_only = True
obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else "pyav"
return obj
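Taken together, these changes mean callers no longer construct an ImageWriter themselves: LeRobotDataset.create() now takes process/thread counts and starts the writer internally whenever the robot has cameras and at least one of the counts is non-zero. A rough sketch of the intended flow, assuming an already-configured Robot with cameras and that consolidate() exposes the run_compute_stats flag used in the hunk above (the repo id, fps, and thread count are placeholders):

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # `robot` is assumed to be a connected Robot instance with at least one camera.
    dataset = LeRobotDataset.create(
        "user/my_dataset",
        30,  # fps
        robot,
        image_writer_processes=0,           # 0 -> threads only, no extra processes
        image_writer_threads_per_camera=4,  # scaled by robot.num_cameras internally
        use_videos=True,
    )

    # ... record episodes here; frames are written to disk asynchronously ...

    # Consolidation stops the image writer, optionally computes dataset statistics,
    # and marks the instance as in sync with the files on disk so it can be pushed.
    dataset.consolidate(run_compute_stats=True)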


@@ -105,7 +105,6 @@ from pathlib import Path
from typing import List
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.image_writer import ImageWriter
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.robot_devices.control_utils import (
control_loop,
@@ -232,17 +231,14 @@ def record(
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
if len(robot.cameras) > 0:
image_writer = ImageWriter(
write_dir=root,
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
)
else:
image_writer = None
dataset = LeRobotDataset.create(
repo_id, fps, robot, root=root, image_writer=image_writer, use_videos=video
repo_id,
fps,
robot,
root=root,
image_writer_processes=num_image_writer_processes,
image_writer_threads_per_camera=num_image_writer_threads_per_camera,
use_videos=video,
)
if not robot.is_connected: