Move ImageWriter creation inside the dataset
This commit is contained in:
parent
0098bd264e
commit
0d77be90ee
|
@ -177,11 +177,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
self.image_writer = image_writer
|
||||
self.delta_indices = None
|
||||
self.consolidated = True
|
||||
self.episode_buffer = {}
|
||||
self.local_files_only = local_files_only
|
||||
self.consolidated = True
|
||||
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
self.episode_buffer = {}
|
||||
|
||||
# Load metadata
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
@ -626,8 +628,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.consolidated = False
|
||||
|
||||
def _save_episode_table(self, episode_index: int) -> None:
|
||||
features = self.features
|
||||
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=features, split="train")
|
||||
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.features, split="train")
|
||||
ep_table = ep_dataset._data.table
|
||||
ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
|
||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
@ -675,10 +676,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
# Reset the buffer
|
||||
self.episode_buffer = self._create_episode_buffer()
|
||||
|
||||
def read_mode(self) -> None:
|
||||
"""Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first."""
|
||||
# TODO(aliberts, rcadene): find better api/interface for this.
|
||||
def start_image_writter(self, num_processes: int = 0, num_threads: int = 1) -> None:
|
||||
if isinstance(self.image_writer, ImageWriter):
|
||||
logging.warning(
|
||||
"You are starting a new ImageWriter that is replacing an already exising one in the dataset."
|
||||
)
|
||||
|
||||
self.image_writer = ImageWriter(
|
||||
write_dir=self.root,
|
||||
num_processes=num_processes,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
def stop_image_writter(self) -> None:
|
||||
"""
|
||||
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
|
||||
remove the image_write in order for the LeRobotDataset object to be pickleable and parallelized.
|
||||
"""
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.stop()
|
||||
self.image_writer = None
|
||||
|
||||
def encode_videos(self) -> None:
|
||||
|
@ -708,20 +724,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
shutil.rmtree(self.image_writer.dir)
|
||||
|
||||
if run_compute_stats:
|
||||
self.read_mode()
|
||||
self.stop_image_writter()
|
||||
self.stats = compute_stats(self)
|
||||
write_stats(self.stats, self.root / STATS_PATH)
|
||||
self.consolidated = True
|
||||
else:
|
||||
logging.warning("Skipping computation of the dataset statistics.")
|
||||
logging.warning(
|
||||
"Skipping computation of the dataset statistics, dataset is not fully consolidated."
|
||||
)
|
||||
|
||||
# TODO(aliberts)
|
||||
# Sanity checks:
|
||||
# - [ ] shapes
|
||||
# - [ ] ep_lenghts
|
||||
# - [ ] number of files
|
||||
# - [ ] names of files (e.g. parquet 00000-of-00001 and 00001-of-00002)
|
||||
# - [ ] no remaining self.image_writer.dir
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
|
@ -731,7 +747,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
robot: Robot,
|
||||
root: Path | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
image_writer: ImageWriter | None = None,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads_per_camera: int = 0,
|
||||
use_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
) -> "LeRobotDataset":
|
||||
|
@ -740,7 +757,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
obj.repo_id = repo_id
|
||||
obj.root = root if root is not None else LEROBOT_HOME / repo_id
|
||||
obj.tolerance_s = tolerance_s
|
||||
obj.image_writer = image_writer
|
||||
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warning(
|
||||
|
@ -755,20 +771,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
||||
obj.episode_buffer = obj._create_episode_buffer()
|
||||
|
||||
obj.image_writer = None
|
||||
if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
|
||||
obj.start_image_writter(
|
||||
image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
|
||||
)
|
||||
|
||||
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
|
||||
# is used to know when certain operations are need (for instance, computing dataset statistics). In
|
||||
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling
|
||||
# self.consolidate().
|
||||
obj.consolidated = True
|
||||
|
||||
obj.local_files_only = True
|
||||
obj.download_videos = False
|
||||
|
||||
obj.episodes = None
|
||||
obj.hf_dataset = None
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
obj.local_files_only = True
|
||||
obj.episode_data_index = None
|
||||
obj.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
return obj
|
||||
|
|
|
@ -105,7 +105,6 @@ from pathlib import Path
|
|||
from typing import List
|
||||
|
||||
# from safetensors.torch import load_file, save_file
|
||||
from lerobot.common.datasets.image_writer import ImageWriter
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
control_loop,
|
||||
|
@ -232,17 +231,14 @@ def record(
|
|||
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(repo_id, policy)
|
||||
if len(robot.cameras) > 0:
|
||||
image_writer = ImageWriter(
|
||||
write_dir=root,
|
||||
num_processes=num_image_writer_processes,
|
||||
num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
|
||||
)
|
||||
else:
|
||||
image_writer = None
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id, fps, robot, root=root, image_writer=image_writer, use_videos=video
|
||||
repo_id,
|
||||
fps,
|
||||
robot,
|
||||
root=root,
|
||||
image_writer_processes=num_image_writer_processes,
|
||||
image_writer_threads_per_camera=num_image_writer_threads_per_camera,
|
||||
use_videos=video,
|
||||
)
|
||||
|
||||
if not robot.is_connected:
|
||||
|
|
Loading…
Reference in New Issue