add write_stats, changes names, add some typing

This commit is contained in:
parent fb73cdb9a4
commit a2a8538ac9
@@ -91,9 +91,9 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset
    )

    if isinstance(cfg.dataset_repo_id, str):
        # TODO (aliberts): add 'episodes' arg from config after removing hydra
        dataset = LeRobotDataset(
            cfg.dataset_repo_id,
            split=split,
            delta_timestamps=cfg.training.get("delta_timestamps"),
            image_transforms=image_transforms,
            video_backend=cfg.video_backend,
@@ -101,7 +101,6 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset
    else:
        dataset = MultiLeRobotDataset(
            cfg.dataset_repo_id,
            split=split,
            delta_timestamps=cfg.training.get("delta_timestamps"),
            image_transforms=image_transforms,
            video_backend=cfg.video_backend,
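For context, the dispatch these two hunks touch is unchanged: a single repo id builds a LeRobotDataset, a list of repo ids builds a MultiLeRobotDataset. A minimal sketch of that branch with the hydra config access stripped out (the function name and simplified arguments are illustrative, not the repository code):

    # Minimal sketch, not the repository code: a single repo_id string yields a
    # LeRobotDataset, a list of repo_ids yields a MultiLeRobotDataset.
    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset

    def make_dataset_sketch(repo_id_or_ids, split: str = "train", delta_timestamps: dict | None = None):
        if isinstance(repo_id_or_ids, str):
            return LeRobotDataset(repo_id_or_ids, split=split, delta_timestamps=delta_timestamps)
        return MultiLeRobotDataset(repo_id_or_ids, split=split, delta_timestamps=delta_timestamps)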
@@ -21,7 +21,7 @@ import torch
import tqdm
from PIL import Image

-DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
+DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"


def safe_stop_image_writer(func):
@@ -54,7 +54,8 @@ class ImageWriter:
    """

    def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
-        self.dir = write_dir
+        self.dir = write_dir / "images"
+        self.dir.mkdir(parents=True, exist_ok=True)
        self.image_path = DEFAULT_IMAGE_PATH
        self.num_processes = num_processes
        self.num_threads = self.num_threads_per_process = num_threads
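The two changes above move the "images/" prefix out of the per-frame template and into the writer's base directory, which is now created eagerly. A small sketch of how the pieces compose into a frame path (the frame_path helper is hypothetical; the template string is the one from the diff):

    from pathlib import Path

    DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"

    def frame_path(write_dir: Path, image_key: str, episode_index: int, frame_index: int) -> Path:
        # The writer roots everything under <write_dir>/images, then applies the template.
        rel = DEFAULT_IMAGE_PATH.format(
            image_key=image_key, episode_index=episode_index, frame_index=frame_index
        )
        return write_dir / "images" / rel

    print(frame_path(Path("/tmp/my_dataset"), "observation.images.top", 3, 12))
    # /tmp/my_dataset/images/observation.images.top/episode_000003/frame_000012.png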
@@ -33,6 +33,7 @@ from lerobot.common.datasets.image_writer import ImageWriter
from lerobot.common.datasets.utils import (
    EPISODES_PATH,
    INFO_PATH,
+    STATS_PATH,
    TASKS_PATH,
    append_jsonl,
    check_delta_timestamps,
@@ -40,7 +41,6 @@ from lerobot.common.datasets.utils import (
    check_version_compatibility,
    create_branch,
    create_empty_dataset_info,
-    flatten_dict,
    get_delta_indices,
    get_episode_data_index,
    get_hub_safe_version,
@@ -49,8 +49,8 @@ from lerobot.common.datasets.utils import (
    load_info,
    load_stats,
    load_tasks,
-    unflatten_dict,
    write_json,
+    write_stats,
)
from lerobot.common.datasets.video_utils import (
    VideoFrame,
@@ -227,11 +227,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
        """Codebase version used to create this dataset."""
        return self.info["codebase_version"]

-    def push_to_repo(self, push_videos: bool = True) -> None:
+    def push_to_hub(self, push_videos: bool = True) -> None:
        if not self.consolidated:
            raise RuntimeError(
                "You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet."
-                "Please use the '.consolidate()' method first."
+                "Please call the dataset 'consolidate()' method first."
            )
        ignore_patterns = ["images/"]
        if not push_videos:
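With the rename, the expected call order from user code is consolidate first, then push. A hedged usage sketch for a locally recorded dataset (construction elided):

    # `dataset` is assumed to be a locally recorded LeRobotDataset (construction elided).
    # push_to_hub() raises the RuntimeError above unless consolidate() has run first.
    dataset.consolidate(run_compute_stats=True)
    dataset.push_to_hub(push_videos=True)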
@@ -675,7 +675,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
        # Reset the buffer
        self.episode_buffer = self._create_episode_buffer()

-    def _remove_image_writer(self) -> None:
+    def read_mode(self) -> None:
+        """Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first."""
+        # TODO(aliberts, rcadene): find better api/interface for this.
        if self.image_writer is not None:
            self.image_writer = None
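The added docstring explains when the renamed method matters: call it before wrapping the dataset in a multi-worker DataLoader, presumably so the image writer's worker processes and threads are not carried into the loader workers. A hedged sketch of that ordering (dataset construction elided):

    import torch

    # `dataset` is assumed to be a LeRobotDataset that was used for recording and
    # therefore still holds an ImageWriter. Drop it before spawning worker processes.
    dataset.read_mode()
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=4, shuffle=True)
    for batch in loader:
        pass  # training step would go here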
@@ -693,9 +695,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
            # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
            # since video encoding with ffmpeg is already using multithreading.
            encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
            shutil.rmtree(tmp_imgs_dir)

-    def consolidate(self, run_compute_stats: bool = True) -> None:
+    def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
        self.hf_dataset = self.load_hf_dataset()
        self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
        check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
@@ -703,14 +704,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
        if len(self.video_keys) > 0:
            self.encode_videos()

+        if not keep_image_files:
+            shutil.rmtree(self.image_writer.dir)

        if run_compute_stats:
            logging.info("Computing dataset statistics")
-            self._remove_image_writer()
+            self.read_mode()
            self.stats = compute_stats(self)
-            serialized_stats = flatten_dict(self.stats)
-            serialized_stats = {key: value.tolist() for key, value in serialized_stats.items()}
-            serialized_stats = unflatten_dict(serialized_stats)
-            write_json(serialized_stats, self.root / "meta/stats.json")
+            write_stats(self.stats, self.root / STATS_PATH)
            self.consolidated = True
        else:
            logging.warning("Skipping computation of the dataset statistics.")
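Together with the previous hunk, consolidation now encodes videos, optionally deletes the raw frames, and writes stats through the new helper. A hedged sketch of the resulting call for a freshly recorded dataset:

    # `dataset` is assumed to be a freshly recorded LeRobotDataset (construction elided).
    # keep_image_files=True keeps the raw PNG frames instead of deleting <root>/images
    # after video encoding; stats end up in meta/stats.json via write_stats().
    dataset.consolidate(run_compute_stats=True, keep_image_files=True)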
@@ -784,8 +784,8 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        repo_ids: list[str],
-        root: Path | None = LEROBOT_HOME,
        split: str = "train",
+        root: Path | None = None,
        episodes: dict | None = None,
        image_transforms: Callable | None = None,
        delta_timestamps: dict[list[float]] | None = None,
        video_backend: str | None = None,
@@ -797,8 +797,8 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
        self._datasets = [
            LeRobotDataset(
                repo_id,
-                root=root,
                split=split,
+                root=root / repo_id if root is not None else None,
                episodes=episodes[repo_id] if episodes is not None else None,
                delta_timestamps=delta_timestamps,
                image_transforms=image_transforms,
                video_backend=video_backend,
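The new root handling gives each sub-dataset its own directory under the shared root instead of pointing every repo at the same path. A standalone sketch of that derivation (repo ids and cache location are placeholder assumptions):

    from pathlib import Path

    repo_ids = ["lerobot/pusht", "lerobot/aloha_static_cups_open"]  # placeholder repo ids
    root = Path("~/.cache/huggingface/lerobot").expanduser()        # assumed local root

    # Mirrors the expression in the hunk: each sub-dataset gets its own root / repo_id,
    # or None when no shared root was provided.
    sub_roots = [root / repo_id if root is not None else None for repo_id in repo_ids]
    print(sub_roots)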
@@ -834,7 +834,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
        self.disabled_data_keys.update(extra_keys)

        self.root = root
        self.split = split
        self.image_transforms = image_transforms
        self.delta_timestamps = delta_timestamps
        self.stats = aggregate_stats(self._datasets)
@@ -948,7 +947,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
        return (
            f"{self.__class__.__name__}(\n"
            f" Repository IDs: '{self.repo_ids}',\n"
            f" Split: '{self.split}',\n"
            f" Number of Samples: {self.num_samples},\n"
            f" Number of Episodes: {self.num_episodes},\n"
            f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
@@ -48,7 +48,7 @@ This dataset was created using [LeRobot](https://github.com/huggingface/lerobot)
"""


-def flatten_dict(d, parent_key="", sep="/"):
+def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
    """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.

    For example:
@@ -67,7 +67,7 @@ def flatten_dict(d, parent_key="", sep="/"):
    return dict(items)


-def unflatten_dict(d, sep="/"):
+def unflatten_dict(d: dict, sep: str = "/") -> dict:
    outdict = {}
    for key, value in d.items():
        parts = key.split(sep)
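The added annotations do not change behaviour; the pair remains a separator-based round trip, as the docstring above describes. A short usage example of that contract:

    from lerobot.common.datasets.utils import flatten_dict, unflatten_dict

    nested = {"observation": {"state": {"mean": [0.0, 0.0], "std": [1.0, 1.0]}}}
    flat = flatten_dict(nested)
    # {'observation/state/mean': [0.0, 0.0], 'observation/state/std': [1.0, 1.0]}
    assert unflatten_dict(flat) == nested  # lossless round trip with the default "/" separator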
@@ -92,6 +92,12 @@ def append_jsonl(data: dict, fpath: Path) -> None:
        writer.write(data)


+def write_stats(stats: dict[str, torch.Tensor | dict], fpath: Path) -> None:
+    serialized_stats = {key: value.tolist() for key, value in flatten_dict(stats).items()}
+    serialized_stats = unflatten_dict(serialized_stats)
+    write_json(serialized_stats, fpath)
+
+
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
    """Get a transform function that convert items from Hugging Face dataset (pyarrow)
    to torch tensors. Importantly, images are converted from PIL, which corresponds to
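The new write_stats helper captures the flatten, .tolist(), unflatten, JSON chain that consolidate() previously inlined (see the hunk at -703 above). A hedged example with a tiny stats dict (the feature keys are illustrative):

    from pathlib import Path

    import torch

    from lerobot.common.datasets.utils import write_stats

    stats = {
        "observation.state": {"mean": torch.zeros(2), "std": torch.ones(2)},
        "action": {"min": torch.tensor([-1.0]), "max": torch.tensor([1.0])},
    }
    # Tensors are flattened, converted to plain lists, unflattened back, then dumped as JSON.
    write_stats(stats, Path("/tmp/stats.json"))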
@@ -315,11 +315,14 @@ def record(
        logging.info("Waiting for image writer to terminate...")
        dataset.image_writer.stop()

+    if run_compute_stats:
+        logging.info("Computing dataset statistics")

+    dataset.consolidate(run_compute_stats)

    # lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
    if push_to_hub:
-        dataset.push_to_repo()
+        dataset.push_to_hub()

    log_say("Exiting", play_sounds)
    return dataset
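The tail of record() now drives the dataset object directly instead of the old create_lerobot_dataset helper. A hedged sketch of the sequence, keeping only the calls from this hunk (the None guard on image_writer is an assumption):

    # Sketch of the end-of-recording sequence as wired up in the hunk above.
    # `dataset`, `run_compute_stats`, `push_to_hub` and `play_sounds` come from record()'s scope.
    if dataset.image_writer is not None:
        dataset.image_writer.stop()          # flush pending frames to disk
    dataset.consolidate(run_compute_stats)   # encode videos, optionally compute and write stats
    if push_to_hub:
        dataset.push_to_hub()                # only valid after consolidation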