add write_stats, change names, add some typing

Simon Alibert 2024-10-23 11:38:07 +02:00
parent fb73cdb9a4
commit a2a8538ac9
5 changed files with 33 additions and 26 deletions

View File

@@ -91,9 +91,9 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset
)
if isinstance(cfg.dataset_repo_id, str):
# TODO (aliberts): add 'episodes' arg from config after removing hydra
dataset = LeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
video_backend=cfg.video_backend,
@@ -101,7 +101,6 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset
else:
dataset = MultiLeRobotDataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=cfg.training.get("delta_timestamps"),
image_transforms=image_transforms,
video_backend=cfg.video_backend,
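With the `split` argument gone from these calls, the factory now forwards only the remaining keyword arguments. A minimal sketch of instantiating the dataset directly, assuming illustrative values for the repo id, timestamps and backend (only the keyword names mirror the hunk above):

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset(
    "lerobot/pusht",                              # example repo id
    delta_timestamps={"action": [0.0, 0.1]},      # optional, same role as cfg.training.get("delta_timestamps")
    image_transforms=None,
    video_backend="pyav",                         # assumed default backend
)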

View File

@@ -21,7 +21,7 @@ import torch
import tqdm
from PIL import Image
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
def safe_stop_image_writer(func):
@@ -54,7 +54,8 @@ class ImageWriter:
"""
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
self.dir = write_dir
self.dir = write_dir / "images"
self.dir.mkdir(parents=True, exist_ok=True)
self.image_path = DEFAULT_IMAGE_PATH
self.num_processes = num_processes
self.num_threads = self.num_threads_per_process = num_threads
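Since the writer now appends "images" to `write_dir` while the path template drops its "images/" prefix, the on-disk layout should be unchanged. A small sketch of how the final frame path is assembled, with an example image key and indices:

from pathlib import Path

write_dir = Path("/tmp/my_dataset")   # example location
image_dir = write_dir / "images"      # what ImageWriter now creates and writes into
relative = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png".format(
    image_key="observation.image", episode_index=0, frame_index=42
)
print(image_dir / relative)
# /tmp/my_dataset/images/observation.image/episode_000000/frame_000042.png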

View File

@@ -33,6 +33,7 @@ from lerobot.common.datasets.image_writer import ImageWriter
from lerobot.common.datasets.utils import (
EPISODES_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
append_jsonl,
check_delta_timestamps,
@@ -40,7 +41,6 @@ from lerobot.common.datasets.utils import (
check_version_compatibility,
create_branch,
create_empty_dataset_info,
flatten_dict,
get_delta_indices,
get_episode_data_index,
get_hub_safe_version,
@@ -49,8 +49,8 @@ from lerobot.common.datasets.utils import (
load_info,
load_stats,
load_tasks,
unflatten_dict,
write_json,
write_stats,
)
from lerobot.common.datasets.video_utils import (
VideoFrame,
@@ -227,11 +227,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""Codebase version used to create this dataset."""
return self.info["codebase_version"]
def push_to_repo(self, push_videos: bool = True) -> None:
def push_to_hub(self, push_videos: bool = True) -> None:
if not self.consolidated:
raise RuntimeError(
"You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet."
"Please use the '.consolidate()' method first."
"Please call the dataset 'consolidate()' method first."
)
ignore_patterns = ["images/"]
if not push_videos:
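A minimal usage sketch of the renamed method, assuming `dataset` is a LeRobotDataset whose recording has finished; it must be consolidated before the upload, otherwise the RuntimeError above is raised:

dataset.consolidate(run_compute_stats=True)
dataset.push_to_hub(push_videos=True)  # push_videos=False presumably excludes the video files from the upload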
@@ -675,7 +675,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Reset the buffer
self.episode_buffer = self._create_episode_buffer()
def _remove_image_writer(self) -> None:
def read_mode(self) -> None:
"""Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first."""
# TODO(aliberts, rcadene): find better api/interface for this.
if self.image_writer is not None:
self.image_writer = None
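A sketch of the renamed `read_mode()` in context, assuming `dataset` already holds recorded episodes; dropping the image writer before spawning workers is the requirement stated in the docstring, and the DataLoader arguments are just examples:

import torch

dataset.read_mode()  # detaches the image writer so DataLoader workers can be created
loader = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=4)
batch = next(iter(loader))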
@@ -693,9 +695,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
# since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
shutil.rmtree(tmp_imgs_dir)
def consolidate(self, run_compute_stats: bool = True) -> None:
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
@@ -703,14 +704,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
if len(self.video_keys) > 0:
self.encode_videos()
if not keep_image_files:
shutil.rmtree(self.image_writer.dir)
if run_compute_stats:
logging.info("Computing dataset statistics")
self._remove_image_writer()
self.read_mode()
self.stats = compute_stats(self)
serialized_stats = flatten_dict(self.stats)
serialized_stats = {key: value.tolist() for key, value in serialized_stats.items()}
serialized_stats = unflatten_dict(serialized_stats)
write_json(serialized_stats, self.root / "meta/stats.json")
write_stats(self.stats, self.root / STATS_PATH)
self.consolidated = True
else:
logging.warning("Skipping computation of the dataset statistics.")
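Putting the new pieces together, a hedged sketch of consolidation; `run_compute_stats` and `keep_image_files` come from the hunk above, and the destination is assumed to be what STATS_PATH resolves to (the previously hard-coded "meta/stats.json"):

dataset.consolidate(run_compute_stats=True, keep_image_files=False)
# - the raw PNG frames are deleted unless keep_image_files=True
# - stats are serialized by write_stats() into <root>/meta/stats.json (assumed STATS_PATH)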
@@ -784,8 +784,8 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def __init__(
self,
repo_ids: list[str],
root: Path | None = LEROBOT_HOME,
split: str = "train",
root: Path | None = None,
episodes: dict | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
video_backend: str | None = None,
@@ -797,8 +797,8 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self._datasets = [
LeRobotDataset(
repo_id,
root=root,
split=split,
root=root / repo_id if root is not None else None,
episodes=episodes[repo_id] if episodes is not None else None,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
video_backend=video_backend,
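A construction sketch for the updated MultiLeRobotDataset signature; the repo ids, cache location and episode indices are placeholders, only the parameter names and the per-repo `root / repo_id` behaviour follow the hunk above:

from pathlib import Path
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset

multi_dataset = MultiLeRobotDataset(
    repo_ids=["lerobot/aloha_mobile_cabinet", "lerobot/pusht"],
    root=Path("~/.cache/lerobot").expanduser(),   # each dataset is loaded from root/<repo_id>
    episodes={
        "lerobot/aloha_mobile_cabinet": None,     # None = all episodes
        "lerobot/pusht": [0, 1, 2],               # subset of episodes
    },
)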
@@ -834,7 +834,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self.disabled_data_keys.update(extra_keys)
self.root = root
self.split = split
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets)
@@ -948,7 +947,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
return (
f"{self.__class__.__name__}(\n"
f" Repository IDs: '{self.repo_ids}',\n"
f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"

View File

@@ -48,7 +48,7 @@ This dataset was created using [LeRobot](https://github.com/huggingface/lerobot)
"""
def flatten_dict(d, parent_key="", sep="/"):
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
For example:
@@ -67,7 +67,7 @@ def flatten_dict(d, parent_key="", sep="/"):
return dict(items)
def unflatten_dict(d, sep="/"):
def unflatten_dict(d: dict, sep: str = "/") -> dict:
outdict = {}
for key, value in d.items():
parts = key.split(sep)
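A round-trip sketch of the two helpers that just gained type hints, using a made-up nested stats-like dict:

from lerobot.common.datasets.utils import flatten_dict, unflatten_dict

nested = {"observation.image": {"mean": [0.5, 0.5, 0.5]}}
flat = flatten_dict(nested)            # {"observation.image/mean": [0.5, 0.5, 0.5]}
assert unflatten_dict(flat) == nested  # unflatten_dict inverts flatten_dict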
@@ -92,6 +92,12 @@ def append_jsonl(data: dict, fpath: Path) -> None:
writer.write(data)
def write_stats(stats: dict[str, torch.Tensor | dict], fpath: Path) -> None:
serialized_stats = {key: value.tolist() for key, value in flatten_dict(stats).items()}
serialized_stats = unflatten_dict(serialized_stats)
write_json(serialized_stats, fpath)
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
to torch tensors. Importantly, images are converted from PIL, which corresponds to
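A usage sketch of the new write_stats helper; the tensor values and output path are examples, and write_json is assumed to expect an existing parent directory:

from pathlib import Path
import torch
from lerobot.common.datasets.utils import write_stats

stats = {
    "observation.image": {"mean": torch.tensor([0.485, 0.456, 0.406]),
                          "std": torch.tensor([0.229, 0.224, 0.225])},
}
Path("meta").mkdir(exist_ok=True)            # ensure the destination folder exists
write_stats(stats, Path("meta/stats.json"))  # tensors become lists, nesting is preserved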

View File

@@ -315,11 +315,14 @@ def record(
logging.info("Waiting for image writer to terminate...")
dataset.image_writer.stop()
if run_compute_stats:
logging.info("Computing dataset statistics")
dataset.consolidate(run_compute_stats)
# lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
if push_to_hub:
dataset.push_to_repo()
dataset.push_to_hub()
log_say("Exiting", play_sounds)
return dataset