Add write_stats, change names, add some typing

This commit is contained in:
Simon Alibert 2024-10-23 11:38:07 +02:00
parent fb73cdb9a4
commit a2a8538ac9
5 changed files with 33 additions and 26 deletions

View File

@ -91,9 +91,9 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
) )
if isinstance(cfg.dataset_repo_id, str): if isinstance(cfg.dataset_repo_id, str):
# TODO (aliberts): add 'episodes' arg from config after removing hydra
dataset = LeRobotDataset( dataset = LeRobotDataset(
cfg.dataset_repo_id, cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"), delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms, image_transforms=image_transforms,
video_backend=cfg.video_backend, video_backend=cfg.video_backend,
@ -101,7 +101,6 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
else: else:
dataset = MultiLeRobotDataset( dataset = MultiLeRobotDataset(
cfg.dataset_repo_id, cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"), delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms, image_transforms=image_transforms,
video_backend=cfg.video_backend, video_backend=cfg.video_backend,

View File

@ -21,7 +21,7 @@ import torch
import tqdm import tqdm
from PIL import Image from PIL import Image
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png" DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
def safe_stop_image_writer(func): def safe_stop_image_writer(func):
@ -54,7 +54,8 @@ class ImageWriter:
""" """
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1): def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
self.dir = write_dir self.dir = write_dir / "images"
self.dir.mkdir(parents=True, exist_ok=True)
self.image_path = DEFAULT_IMAGE_PATH self.image_path = DEFAULT_IMAGE_PATH
self.num_processes = num_processes self.num_processes = num_processes
self.num_threads = self.num_threads_per_process = num_threads self.num_threads = self.num_threads_per_process = num_threads

View File

@ -33,6 +33,7 @@ from lerobot.common.datasets.image_writer import ImageWriter
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
EPISODES_PATH, EPISODES_PATH,
INFO_PATH, INFO_PATH,
STATS_PATH,
TASKS_PATH, TASKS_PATH,
append_jsonl, append_jsonl,
check_delta_timestamps, check_delta_timestamps,
@ -40,7 +41,6 @@ from lerobot.common.datasets.utils import (
check_version_compatibility, check_version_compatibility,
create_branch, create_branch,
create_empty_dataset_info, create_empty_dataset_info,
flatten_dict,
get_delta_indices, get_delta_indices,
get_episode_data_index, get_episode_data_index,
get_hub_safe_version, get_hub_safe_version,
@ -49,8 +49,8 @@ from lerobot.common.datasets.utils import (
load_info, load_info,
load_stats, load_stats,
load_tasks, load_tasks,
unflatten_dict,
write_json, write_json,
write_stats,
) )
from lerobot.common.datasets.video_utils import ( from lerobot.common.datasets.video_utils import (
VideoFrame, VideoFrame,
@ -227,11 +227,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""Codebase version used to create this dataset.""" """Codebase version used to create this dataset."""
return self.info["codebase_version"] return self.info["codebase_version"]
def push_to_repo(self, push_videos: bool = True) -> None: def push_to_hub(self, push_videos: bool = True) -> None:
if not self.consolidated: if not self.consolidated:
raise RuntimeError( raise RuntimeError(
"You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet." "You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet."
"Please use the '.consolidate()' method first." "Please call the dataset 'consolidate()' method first."
) )
ignore_patterns = ["images/"] ignore_patterns = ["images/"]
if not push_videos: if not push_videos:
@ -675,7 +675,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Reset the buffer # Reset the buffer
self.episode_buffer = self._create_episode_buffer() self.episode_buffer = self._create_episode_buffer()
def _remove_image_writer(self) -> None: def read_mode(self) -> None:
"""Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first."""
# TODO(aliberts, rcadene): find better api/interface for this.
if self.image_writer is not None: if self.image_writer is not None:
self.image_writer = None self.image_writer = None
@ -693,9 +695,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
# since video encoding with ffmpeg is already using multithreading. # since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True) encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
shutil.rmtree(tmp_imgs_dir)
def consolidate(self, run_compute_stats: bool = True) -> None: def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
self.hf_dataset = self.load_hf_dataset() self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts) self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s) check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
@ -703,14 +704,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
if len(self.video_keys) > 0: if len(self.video_keys) > 0:
self.encode_videos() self.encode_videos()
if not keep_image_files:
shutil.rmtree(self.image_writer.dir)
if run_compute_stats: if run_compute_stats:
logging.info("Computing dataset statistics") self.read_mode()
self._remove_image_writer()
self.stats = compute_stats(self) self.stats = compute_stats(self)
serialized_stats = flatten_dict(self.stats) write_stats(self.stats, self.root / STATS_PATH)
serialized_stats = {key: value.tolist() for key, value in serialized_stats.items()}
serialized_stats = unflatten_dict(serialized_stats)
write_json(serialized_stats, self.root / "meta/stats.json")
self.consolidated = True self.consolidated = True
else: else:
logging.warning("Skipping computation of the dataset statistics.") logging.warning("Skipping computation of the dataset statistics.")
@ -784,8 +784,8 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def __init__( def __init__(
self, self,
repo_ids: list[str], repo_ids: list[str],
root: Path | None = LEROBOT_HOME, root: Path | None = None,
split: str = "train", episodes: dict | None = None,
image_transforms: Callable | None = None, image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None, delta_timestamps: dict[list[float]] | None = None,
video_backend: str | None = None, video_backend: str | None = None,
@ -797,8 +797,8 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self._datasets = [ self._datasets = [
LeRobotDataset( LeRobotDataset(
repo_id, repo_id,
root=root, root=root / repo_id if root is not None else None,
split=split, episodes=episodes[repo_id] if episodes is not None else None,
delta_timestamps=delta_timestamps, delta_timestamps=delta_timestamps,
image_transforms=image_transforms, image_transforms=image_transforms,
video_backend=video_backend, video_backend=video_backend,
@ -834,7 +834,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self.disabled_data_keys.update(extra_keys) self.disabled_data_keys.update(extra_keys)
self.root = root self.root = root
self.split = split
self.image_transforms = image_transforms self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets) self.stats = aggregate_stats(self._datasets)
@ -948,7 +947,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
return ( return (
f"{self.__class__.__name__}(\n" f"{self.__class__.__name__}(\n"
f" Repository IDs: '{self.repo_ids}',\n" f" Repository IDs: '{self.repo_ids}',\n"
f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n" f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n" f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n" f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"

View File

@ -48,7 +48,7 @@ This dataset was created using [LeRobot](https://github.com/huggingface/lerobot)
""" """
def flatten_dict(d, parent_key="", sep="/"): def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator. """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
For example: For example:
@ -67,7 +67,7 @@ def flatten_dict(d, parent_key="", sep="/"):
return dict(items) return dict(items)
def unflatten_dict(d, sep="/"): def unflatten_dict(d: dict, sep: str = "/") -> dict:
outdict = {} outdict = {}
for key, value in d.items(): for key, value in d.items():
parts = key.split(sep) parts = key.split(sep)
@ -92,6 +92,12 @@ def append_jsonl(data: dict, fpath: Path) -> None:
writer.write(data) writer.write(data)
def write_stats(stats: dict[str, torch.Tensor | dict], fpath: Path) -> None:
    """Serialize a (possibly nested) statistics dictionary and write it to `fpath`.

    The nested dictionary is first flattened with `flatten_dict`, every leaf
    value is converted with its `.tolist()` method, the nesting is rebuilt with
    `unflatten_dict`, and the result is handed to `write_json`.

    Args:
        stats: Nested mapping whose leaves expose `.tolist()` (annotated as
            `torch.Tensor`).
        fpath: Destination path passed through to `write_json`.
    """
    flat_stats = flatten_dict(stats)

    # Convert each leaf with `.tolist()` before handing the result to `write_json`.
    listified = {}
    for name, leaf in flat_stats.items():
        listified[name] = leaf.tolist()

    write_json(unflatten_dict(listified), fpath)
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow) """Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to to torch tensors. Importantly, images are converted from PIL, which corresponds to

View File

@ -315,11 +315,14 @@ def record(
logging.info("Waiting for image writer to terminate...") logging.info("Waiting for image writer to terminate...")
dataset.image_writer.stop() dataset.image_writer.stop()
if run_compute_stats:
logging.info("Computing dataset statistics")
dataset.consolidate(run_compute_stats) dataset.consolidate(run_compute_stats)
# lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds) # lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
if push_to_hub: if push_to_hub:
dataset.push_to_repo() dataset.push_to_hub()
log_say("Exiting", play_sounds) log_say("Exiting", play_sounds)
return dataset return dataset