Move ImageWriter creation inside the dataset
parent 0098bd264e
commit 0d77be90ee
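This commit makes LeRobotDataset responsible for constructing its own ImageWriter. LeRobotDataset.create() now takes image_writer_processes and image_writer_threads_per_camera counts instead of a pre-built ImageWriter instance, and the new start_image_writter() / stop_image_writter() methods own the writer's lifecycle, including stopping it before compute_stats() so the dataset remains pickleable.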
@@ -177,11 +177,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.episodes = episodes
         self.tolerance_s = tolerance_s
         self.video_backend = video_backend if video_backend is not None else "pyav"
-        self.image_writer = image_writer
         self.delta_indices = None
-        self.consolidated = True
-        self.episode_buffer = {}
         self.local_files_only = local_files_only
+        self.consolidated = True
+
+        # Unused attributes
+        self.image_writer = None
+        self.episode_buffer = {}
 
         # Load metadata
         self.root.mkdir(exist_ok=True, parents=True)
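After this change, a dataset loaded from disk through __init__ carries no writer: image_writer and episode_buffer start out as inert placeholders (the "Unused attributes" group) and are only populated when recording through LeRobotDataset.create().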
@@ -626,8 +628,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.consolidated = False
 
     def _save_episode_table(self, episode_index: int) -> None:
-        features = self.features
-        ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=features, split="train")
+        ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.features, split="train")
         ep_table = ep_dataset._data.table
         ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
         ep_data_path.parent.mkdir(parents=True, exist_ok=True)
@@ -675,10 +676,25 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Reset the buffer
         self.episode_buffer = self._create_episode_buffer()
 
-    def read_mode(self) -> None:
-        """Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first."""
-        # TODO(aliberts, rcadene): find better api/interface for this.
+    def start_image_writter(self, num_processes: int = 0, num_threads: int = 1) -> None:
+        if isinstance(self.image_writer, ImageWriter):
+            logging.warning(
+                "You are starting a new ImageWriter that is replacing an already existing one in the dataset."
+            )
+
+        self.image_writer = ImageWriter(
+            write_dir=self.root,
+            num_processes=num_processes,
+            num_threads=num_threads,
+        )
+
+    def stop_image_writter(self) -> None:
+        """
+        Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
+        remove the image_writer in order for the LeRobotDataset object to be pickleable and parallelized.
+        """
         if self.image_writer is not None:
+            self.image_writer.stop()
             self.image_writer = None
 
     def encode_videos(self) -> None:
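The stop_image_writter() docstring above is the crux of the pickling constraint, so a small illustration may help. This is a minimal sketch, not code from this commit: compute_stats_over_workers is a hypothetical helper, and the worker and batch numbers are arbitrary.

from torch.utils.data import DataLoader

def compute_stats_over_workers(dataset, num_workers: int = 8):
    """Hypothetical helper mirroring what consolidate() does around compute_stats()."""
    # The ImageWriter holds live thread/process handles, which cannot be
    # pickled when DataLoader workers copy the dataset; drop it first.
    dataset.stop_image_writter()
    loader = DataLoader(dataset, batch_size=32, num_workers=num_workers)
    for batch in loader:
        ...  # accumulate per-feature statistics here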
@@ -708,20 +724,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
             shutil.rmtree(self.image_writer.dir)
 
         if run_compute_stats:
-            self.read_mode()
+            self.stop_image_writter()
             self.stats = compute_stats(self)
             write_stats(self.stats, self.root / STATS_PATH)
             self.consolidated = True
         else:
-            logging.warning("Skipping computation of the dataset statistics.")
+            logging.warning(
+                "Skipping computation of the dataset statistics, dataset is not fully consolidated."
+            )
 
         # TODO(aliberts)
         # Sanity checks:
         # - [ ] shapes
         # - [ ] ep_lengths
         # - [ ] number of files
-        # - [ ] names of files (e.g. parquet 00000-of-00001 and 00001-of-00002)
-        # - [ ] no remaining self.image_writer.dir
 
     @classmethod
     def create(
@@ -731,7 +747,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         robot: Robot,
         root: Path | None = None,
         tolerance_s: float = 1e-4,
-        image_writer: ImageWriter | None = None,
+        image_writer_processes: int = 0,
+        image_writer_threads_per_camera: int = 0,
         use_videos: bool = True,
         video_backend: str | None = None,
     ) -> "LeRobotDataset":
@@ -740,7 +757,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.repo_id = repo_id
         obj.root = root if root is not None else LEROBOT_HOME / repo_id
         obj.tolerance_s = tolerance_s
-        obj.image_writer = image_writer
 
         if not all(cam.fps == fps for cam in robot.cameras.values()):
             logging.warning(
@@ -755,20 +771,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
         obj.episode_buffer = obj._create_episode_buffer()
 
+        obj.image_writer = None
+        if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
+            obj.start_image_writter(
+                image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
+            )
+
         # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
         # is used to know when certain operations are needed (for instance, computing dataset statistics). In
         # order to be able to push the dataset to the hub, it needs to be consolidated first by calling
         # self.consolidate().
         obj.consolidated = True
 
-        obj.local_files_only = True
-        obj.download_videos = False
-
         obj.episodes = None
         obj.hf_dataset = None
         obj.image_transforms = None
         obj.delta_timestamps = None
         obj.delta_indices = None
+        obj.local_files_only = True
         obj.episode_data_index = None
         obj.video_backend = video_backend if video_backend is not None else "pyav"
         return obj
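Note the scaling in the hunk above: start_image_writter() receives image_writer_threads_per_camera * robot.num_cameras as its total thread count, so the writer pool grows with the camera count (e.g. 3 cameras at 4 threads per camera yield num_threads=12). If the robot has no cameras, or both counts are left at 0, no writer is started at all.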
@@ -105,7 +105,6 @@ from pathlib import Path
 from typing import List
 
 # from safetensors.torch import load_file, save_file
-from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.robot_devices.control_utils import (
     control_loop,
@@ -232,17 +231,14 @@ def record(
 
     # Create empty dataset or load existing saved episodes
     sanity_check_dataset_name(repo_id, policy)
-    if len(robot.cameras) > 0:
-        image_writer = ImageWriter(
-            write_dir=root,
-            num_processes=num_image_writer_processes,
-            num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
-        )
-    else:
-        image_writer = None
-
     dataset = LeRobotDataset.create(
-        repo_id, fps, robot, root=root, image_writer=image_writer, use_videos=video
+        repo_id,
+        fps,
+        robot,
+        root=root,
+        image_writer_processes=num_image_writer_processes,
+        image_writer_threads_per_camera=num_image_writer_threads_per_camera,
+        use_videos=video,
     )
 
     if not robot.is_connected:
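Combined with the dropped ImageWriter import above, record() no longer touches the writer directly: it just forwards the process and thread counts, and the dataset owns the writer's construction, use, and teardown.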