diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index c4035908..f2d312d7 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -11,14 +11,16 @@ from lerobot.common.datasets.utils import (
     load_stats,
     load_videos,
 )
-from lerobot.common.datasets.video_utils import load_from_videos
+from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
+
+CODEBASE_VERSION = "v1.2"
 
 
 class LeRobotDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         repo_id: str,
-        version: str | None = "v1.1",
+        version: str | None = CODEBASE_VERSION,
         root: Path | None = None,
         split: str = "train",
         transform: callable = None,
@@ -49,6 +51,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def video(self) -> int:
         return self.info.get("video", False)
 
+    @property
+    def features(self) -> datasets.Features:
+        return self.hf_dataset.features
+
     @property
     def image_keys(self) -> list[str]:
         image_keys = []
@@ -61,7 +67,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def video_frame_keys(self):
         video_frame_keys = []
         for key, feats in self.hf_dataset.features.items():
-            if isinstance(feats, datasets.Value) and feats.id == "video_frame":
+            if isinstance(feats, VideoFrame):
                 video_frame_keys.append(key)
         return video_frame_keys
 
@@ -95,3 +101,34 @@ class LeRobotDataset(torch.utils.data.Dataset):
             item = self.transform(item)
 
         return item
+
+    @classmethod
+    def from_preloaded(
+        cls,
+        repo_id: str,
+        version: str | None = CODEBASE_VERSION,
+        root: Path | None = None,
+        split: str = "train",
+        transform: callable = None,
+        delta_timestamps: dict[list[float]] | None = None,
+        # additional preloaded attributes
+        hf_dataset=None,
+        episode_data_index=None,
+        stats=None,
+        info=None,
+        videos_dir=None,
+    ):
+        # create an empty object of type LeRobotDataset
+        obj = cls.__new__(cls)
+        obj.repo_id = repo_id
+        obj.version = version
+        obj.root = root
+        obj.split = split
+        obj.transform = transform
+        obj.delta_timestamps = delta_timestamps
+        obj.hf_dataset = hf_dataset
+        obj.episode_data_index = episode_data_index
+        obj.stats = stats
+        obj.info = info
+        obj.videos_dir = videos_dir
+        return obj
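For orientation, a minimal usage sketch of the new `from_preloaded` constructor and `features` property (not part of the patch; `push_dataset_to_hub.py` below calls it the same way — the repo id and the tiny in-memory dataset here are made up):

```python
from pathlib import Path

import torch
from datasets import Dataset, Features, Value

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import hf_transform_to_torch

# Stand-in for the output of one of the raw-format converters below.
hf_dataset = Dataset.from_dict(
    {"index": [0, 1], "episode_index": [0, 0], "timestamp": [0.0, 0.1]},
    features=Features(
        {"index": Value("int64"), "episode_index": Value("int64"), "timestamp": Value("float32")}
    ),
)
hf_dataset.set_transform(hf_transform_to_torch)

dataset = LeRobotDataset.from_preloaded(
    repo_id="lerobot/example",  # illustrative id
    hf_dataset=hf_dataset,
    episode_data_index={"from": torch.tensor([0]), "to": torch.tensor([2])},
    info={"fps": 10, "video": False},
    videos_dir=Path("data/lerobot/example/videos"),
)
print(dataset.features)  # proxies hf_dataset.features via the new property
```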
diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
index 47ef2dc0..3ca9a2b1 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py
@@ -16,7 +16,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
 from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
 )
-from lerobot.common.datasets.video_utils import encode_video_frames
+from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
 
 
 def check_format(raw_dir) -> bool:
@@ -77,14 +77,17 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
                 save_images_concurrently(imgs_array, tmp_imgs_dir)
 
                 # encode images to a mp4 video
-                video_path = out_dir / "videos" / f"{img_key}_episode_{ep_idx:06d}.mp4"
+                fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
+                video_path = out_dir / "videos" / fname
                 encode_video_frames(tmp_imgs_dir, video_path, fps)
 
                 # clean temporary images directory
                 shutil.rmtree(tmp_imgs_dir)
 
                 # store the episode idx
-                ep_dict[img_key] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
+                ep_dict[img_key] = [
+                    {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
+                ]
             else:
                 ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
 
@@ -120,7 +123,7 @@ def to_hf_dataset(data_dict, video) -> Dataset:
     image_keys = [key for key in data_dict if "observation.images." in key]
     for image_key in image_keys:
         if video:
-            features[image_key] = Value(dtype="int64", id="video")
+            features[image_key] = VideoFrame()
         else:
             features[image_key] = Image()
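In video mode, each image key now stores one small `{"path", "timestamp"}` dict per frame instead of an episode-index tensor. A sketch of the resulting rows (file name and fps are illustrative):

```python
fps = 50
num_frames = 3
fname = "observation.images.top_episode_000000.mp4"  # illustrative episode video

ep_dict = {
    "observation.images.top": [
        {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
    ]
}
# ep_dict["observation.images.top"][1]
# -> {'path': 'videos/observation.images.top_episode_000000.mp4', 'timestamp': 0.02}
```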
diff --git a/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py b/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py
new file mode 100644
index 00000000..a1fd0750
--- /dev/null
+++ b/lerobot/common/datasets/push_dataset_to_hub/compute_stats.py
@@ -0,0 +1,143 @@
+from copy import deepcopy
+from math import ceil
+
+import datasets
+import einops
+import torch
+import tqdm
+from datasets import Image
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.video_utils import VideoFrame
+
+
+def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset):
+    """These einops patterns will be used to aggregate batches and compute statistics.
+
+    Note: We assume the images are in channel first format
+    """
+
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=0,
+        batch_size=2,
+        shuffle=False,
+    )
+    batch = next(iter(dataloader))
+
+    stats_patterns = {}
+    for key, feats_type in dataset.features.items():
+        # sanity check that tensors are not float64
+        assert batch[key].dtype != torch.float64
+
+        if isinstance(feats_type, (VideoFrame, Image)):
+            # sanity check that images are channel first
+            _, c, h, w = batch[key].shape
+            assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
+
+            # sanity check that images are float32 in range [0,1]
+            assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
+            assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
+            assert batch[key].min() >= 0, f"expect pixels greater than 0, but instead {batch[key].min()=}"
+
+            stats_patterns[key] = "b c h w -> c 1 1"
+        elif batch[key].ndim == 2:
+            stats_patterns[key] = "b c -> c "
+        elif batch[key].ndim == 1:
+            stats_patterns[key] = "b -> 1"
+        else:
+            raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
+
+    return stats_patterns
+
+
+def compute_stats(dataset: LeRobotDataset | datasets.Dataset, batch_size=32, max_num_samples=None):
+    if max_num_samples is None:
+        max_num_samples = len(dataset)
+
+    stats_patterns = get_stats_einops_patterns(dataset)
+
+    # mean and std will be computed incrementally while max and min will track the running value.
+    mean, std, max, min = {}, {}, {}, {}
+    for key in stats_patterns:
+        mean[key] = torch.tensor(0.0).float()
+        std[key] = torch.tensor(0.0).float()
+        max[key] = torch.tensor(-float("inf")).float()
+        min[key] = torch.tensor(float("inf")).float()
+
+    def create_seeded_dataloader(dataset, batch_size, seed):
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        dataloader = torch.utils.data.DataLoader(
+            dataset,
+            num_workers=0,
+            batch_size=batch_size,
+            shuffle=True,
+            drop_last=False,
+            generator=generator,
+        )
+        return dataloader
+
+    # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
+    # surprises when rerunning the sampler.
+    first_batch = None
+    running_item_count = 0  # for online mean computation
+    dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
+    for i, batch in enumerate(
+        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
+    ):
+        this_batch_size = len(batch["index"])
+        running_item_count += this_batch_size
+        if first_batch is None:
+            first_batch = deepcopy(batch)
+        for key, pattern in stats_patterns.items():
+            batch[key] = batch[key].float()
+            # Numerically stable update step for mean computation.
+            batch_mean = einops.reduce(batch[key], pattern, "mean")
+            # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
+            # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
+            # and x is the current batch mean. Some rearrangement is then required to avoid risking
+            # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
+            # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
+            mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
+            max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
+            min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
+
+        if i == ceil(max_num_samples / batch_size) - 1:
+            break
+
+    first_batch_ = None
+    running_item_count = 0  # for online std computation
+    dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
+    for i, batch in enumerate(
+        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
+    ):
+        this_batch_size = len(batch["index"])
+        running_item_count += this_batch_size
+        # Sanity check to make sure the batches are still in the same order as before.
+        if first_batch_ is None:
+            first_batch_ = deepcopy(batch)
+            for key in stats_patterns:
+                assert torch.equal(first_batch_[key], first_batch[key])
+        for key, pattern in stats_patterns.items():
+            batch[key] = batch[key].float()
+            # Numerically stable update step for mean computation (where the mean is over squared
+            # residuals). See notes in the mean computation loop above.
+            batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
+            std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
+
+        if i == ceil(max_num_samples / batch_size) - 1:
+            break
+
+    for key in stats_patterns:
+        std[key] = torch.sqrt(std[key])
+
+    stats = {}
+    for key in stats_patterns:
+        stats[key] = {
+            "mean": mean[key],
+            "std": std[key],
+            "max": max[key],
+            "min": min[key],
+        }
+    return stats
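The incremental statistics above rely on the update rule x̄ₙ = x̄ₙ₋₁ + Bₙ·(xₙ − x̄ₙ₋₁)/Nₙ spelled out in the comments. A small self-contained check of that rule (not part of the patch):

```python
import torch

torch.manual_seed(0)
batches = [torch.randn(4), torch.randn(7), torch.randn(3)]  # uneven batch sizes, like drop_last=False

running_mean = torch.tensor(0.0)
running_count = 0
for batch in batches:
    batch_size = len(batch)
    running_count += batch_size
    # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
    running_mean = running_mean + batch_size * (batch.mean() - running_mean) / running_count

assert torch.allclose(running_mean, torch.cat(batches).mean())
```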
diff --git a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py
index 0fae71bc..8e20c296 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py
@@ -14,7 +14,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
 from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
 )
-from lerobot.common.datasets.video_utils import encode_video_frames
+from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
 
 
 def check_format(raw_dir):
@@ -131,14 +131,17 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
             save_images_concurrently(imgs_array, tmp_imgs_dir)
 
             # encode images to a mp4 video
-            video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
+            fname = f"observation.image_episode_{ep_idx:06d}.mp4"
+            video_path = out_dir / "videos" / fname
             encode_video_frames(tmp_imgs_dir, video_path, fps)
 
             # clean temporary images directory
             shutil.rmtree(tmp_imgs_dir)
 
             # store the episode index
-            ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
+            ep_dict["observation.image"] = [
+                {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
+            ]
         else:
             ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]
 
@@ -172,7 +175,7 @@ def to_hf_dataset(data_dict, video):
     features = {}
 
     if video:
-        features["observation.image"] = Value(dtype="int64", id="video")
+        features["observation.image"] = VideoFrame()
     else:
         features["observation.image"] = Image()
diff --git a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py
index 1b624b5d..2b52282c 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py
@@ -16,7 +16,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
 from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
 )
-from lerobot.common.datasets.video_utils import encode_video_frames
+from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
 
 
 def check_format(raw_dir) -> bool:
@@ -149,7 +149,7 @@ def to_hf_dataset(data_dict, video):
     features = {}
 
     if video:
-        features["observation.image"] = Value(dtype="int64", id="video")
+        features["observation.image"] = VideoFrame()
     else:
         features["observation.image"] = Image()
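All of the converters now build their `features` dict the same way; a hedged sketch of what `to_hf_dataset` produces in video mode (column values are dummies):

```python
from datasets import Dataset, Features, Image, Value

from lerobot.common.datasets.video_utils import VideoFrame

video = True
features = Features(
    {
        "observation.image": VideoFrame() if video else Image(),
        "episode_index": Value(dtype="int64"),
    }
)
hf_dataset = Dataset.from_dict(
    {
        "observation.image": [{"path": "videos/observation.image_episode_000000.mp4", "timestamp": 0.0}],
        "episode_index": [0],
    },
    features=features,
)
```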
diff --git a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py
index da80073f..6078dec6 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py
@@ -14,7 +14,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
 from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
 )
-from lerobot.common.datasets.video_utils import encode_video_frames
+from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
 
 
 def check_format(raw_dir):
@@ -80,14 +80,17 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
             save_images_concurrently(imgs_array, tmp_imgs_dir)
 
             # encode images to a mp4 video
-            video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
+            fname = f"observation.image_episode_{ep_idx:06d}.mp4"
+            video_path = out_dir / "videos" / fname
             encode_video_frames(tmp_imgs_dir, video_path, fps)
 
             # clean temporary images directory
             shutil.rmtree(tmp_imgs_dir)
 
             # store the episode index
-            ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
+            ep_dict["observation.image"] = [
+                {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
+            ]
         else:
             ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]
 
@@ -120,7 +123,7 @@ def to_hf_dataset(data_dict, video):
     features = {}
 
     if video:
-        features["observation.image"] = Value(dtype="int64", id="video")
+        features["observation.image"] = VideoFrame()
     else:
         features["observation.image"] = Image()
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 70d5913c..5b161afe 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -1,13 +1,9 @@
 import json
-from copy import deepcopy
-from math import ceil
 from pathlib import Path
 
 import datasets
-import einops
 import torch
-import tqdm
-from datasets import Image, load_dataset, load_from_disk
+from datasets import load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download, snapshot_download
 from PIL import Image as PILImage
 from safetensors.torch import load_file
@@ -57,6 +53,9 @@ def hf_transform_to_torch(items_dict):
         if isinstance(first_item, PILImage.Image):
             to_tensor = transforms.ToTensor()
             items_dict[key] = [to_tensor(img) for img in items_dict[key]]
+        elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
+            # video frame will be processed downstream
+            pass
         else:
             items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
     return items_dict
@@ -223,138 +222,6 @@ def load_previous_and_future_frames(
     return item
 
 
-def get_stats_einops_patterns(hf_dataset):
-    """These einops patterns will be used to aggregate batches and compute statistics.
-
-    Note: We assume the images of `hf_dataset` are in channel first format
-    """
-
-    dataloader = torch.utils.data.DataLoader(
-        hf_dataset,
-        num_workers=0,
-        batch_size=2,
-        shuffle=False,
-    )
-    batch = next(iter(dataloader))
-
-    stats_patterns = {}
-    for key, feats_type in hf_dataset.features.items():
-        # sanity check that tensors are not float64
-        assert batch[key].dtype != torch.float64
-
-        if isinstance(feats_type, Image):
-            # sanity check that images are channel first
-            _, c, h, w = batch[key].shape
-            assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
-
-            # sanity check that images are float32 in range [0,1]
-            assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
-            assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
-            assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
-
-            stats_patterns[key] = "b c h w -> c 1 1"
-        elif batch[key].ndim == 2:
-            stats_patterns[key] = "b c -> c "
-        elif batch[key].ndim == 1:
-            stats_patterns[key] = "b -> 1"
-        else:
-            raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
-
-    return stats_patterns
-
-
-def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
-    if max_num_samples is None:
-        max_num_samples = len(hf_dataset)
-
-    stats_patterns = get_stats_einops_patterns(hf_dataset)
-
-    # mean and std will be computed incrementally while max and min will track the running value.
-    mean, std, max, min = {}, {}, {}, {}
-    for key in stats_patterns:
-        mean[key] = torch.tensor(0.0).float()
-        std[key] = torch.tensor(0.0).float()
-        max[key] = torch.tensor(-float("inf")).float()
-        min[key] = torch.tensor(float("inf")).float()
-
-    def create_seeded_dataloader(hf_dataset, batch_size, seed):
-        generator = torch.Generator()
-        generator.manual_seed(seed)
-        dataloader = torch.utils.data.DataLoader(
-            hf_dataset,
-            num_workers=4,
-            batch_size=batch_size,
-            shuffle=True,
-            drop_last=False,
-            generator=generator,
-        )
-        return dataloader
-
-    # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
-    # surprises when rerunning the sampler.
-    first_batch = None
-    running_item_count = 0  # for online mean computation
-    dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
-    for i, batch in enumerate(
-        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
-    ):
-        this_batch_size = len(batch["index"])
-        running_item_count += this_batch_size
-        if first_batch is None:
-            first_batch = deepcopy(batch)
-        for key, pattern in stats_patterns.items():
-            batch[key] = batch[key].float()
-            # Numerically stable update step for mean computation.
-            batch_mean = einops.reduce(batch[key], pattern, "mean")
-            # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
-            # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
-            # and x is the current batch mean. Some rearrangement is then required to avoid risking
-            # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
-            # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
-            mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
-            max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
-            min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
-
-        if i == ceil(max_num_samples / batch_size) - 1:
-            break
-
-    first_batch_ = None
-    running_item_count = 0  # for online std computation
-    dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
-    for i, batch in enumerate(
-        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
-    ):
-        this_batch_size = len(batch["index"])
-        running_item_count += this_batch_size
-        # Sanity check to make sure the batches are still in the same order as before.
-        if first_batch_ is None:
-            first_batch_ = deepcopy(batch)
-            for key in stats_patterns:
-                assert torch.equal(first_batch_[key], first_batch[key])
-        for key, pattern in stats_patterns.items():
-            batch[key] = batch[key].float()
-            # Numerically stable update step for mean computation (where the mean is over squared
-            # residuals).See notes in the mean computation loop above.
-            batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
-            std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
-
-        if i == ceil(max_num_samples / batch_size) - 1:
-            break
-
-    for key in stats_patterns:
-        std[key] = torch.sqrt(std[key])
-
-    stats = {}
-    for key in stats_patterns:
-        stats[key] = {
-            "mean": mean[key],
-            "std": std[key],
-            "max": max[key],
-            "min": min[key],
-        }
-    return stats
-
-
 def cycle(iterable):
     """The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
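The `hf_transform_to_torch` change leaves per-frame dicts alone so `load_from_videos` can decode them later, while everything else still becomes tensors. A quick sketch:

```python
import torch

from lerobot.common.datasets.utils import hf_transform_to_torch

items = {
    "observation.state": [[0.1, 0.2], [0.3, 0.4]],
    "observation.image": [
        {"path": "videos/observation.image_episode_000000.mp4", "timestamp": 0.00},
        {"path": "videos/observation.image_episode_000000.mp4", "timestamp": 0.02},
    ],
}
out = hf_transform_to_torch(items)
assert isinstance(out["observation.state"][0], torch.Tensor)  # converted
assert isinstance(out["observation.image"][0], dict)          # passed through for later decoding
```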
diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py
index 0b9be70e..00399bfa 100644
--- a/lerobot/common/datasets/video_utils.py
+++ b/lerobot/common/datasets/video_utils.py
@@ -1,25 +1,43 @@
-import itertools
+import logging
 import subprocess
+from dataclasses import dataclass, field
 from pathlib import Path
+from typing import Any, ClassVar
 
+import pyarrow as pa
 import torch
 import torchvision
+from datasets.features.features import register_feature
 
 
 def load_from_videos(item, video_frame_keys, videos_dir):
+    # since video path already contains "videos" (e.g. videos_dir="data/videos", path="videos/episode_0.mp4")
+    data_dir = videos_dir.parent
+
     for key in video_frame_keys:
         ep_idx = item["episode_index"]
-        video_path = videos_dir / key / f"episode_{ep_idx:06d}.mp4"
+        video_path = data_dir / key / f"episode_{ep_idx:06d}.mp4"
 
-        if f"{key}_timestamp" in item:
+        if isinstance(item[key], list):
             # load multiple frames at once
-            timestamps = item[f"{key}_timestamp"]
-            item[key] = decode_video_frames_torchvision(video_path, timestamps)
+            timestamps = [frame["timestamp"] for frame in item[key]]
+            paths = [frame["path"] for frame in item[key]]
+            if len(set(paths)) > 1:
+                raise NotImplementedError("All video paths are expected to be the same for now.")
+            video_path = data_dir / paths[0]
+
+            frames = decode_video_frames_torchvision(video_path, timestamps)
+            assert len(frames) == len(timestamps)
+
+            item[key] = frames
         else:
             # load one frame
-            timestamps = [item["timestamp"]]
+            timestamps = [item[key]["timestamp"]]
+            video_path = data_dir / item[key]["path"]
+
             frames = decode_video_frames_torchvision(video_path, timestamps)
             assert len(frames) == 1
+
             item[key] = frames[0]
 
     return item
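A sketch of what `load_from_videos` receives and returns (assuming the referenced mp4 exists under `videos_dir`; with `delta_timestamps`, an item typically carries a list of frame dicts per video key):

```python
from pathlib import Path

from lerobot.common.datasets.video_utils import load_from_videos

videos_dir = Path("data/lerobot/example/videos")  # illustrative layout

item = {
    "episode_index": 0,
    "observation.image": [
        {"path": "videos/observation.image_episode_000000.mp4", "timestamp": 0.00},
        {"path": "videos/observation.image_episode_000000.mp4", "timestamp": 0.02},
    ],
}
# Replaces the list of frame dicts with a stacked tensor of decoded frames.
item = load_from_videos(item, video_frame_keys=["observation.image"], videos_dir=videos_dir)
```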
videos_dir="data/videos", path="videos/episode_0.mp4") + data_dir = videos_dir.parent + for key in video_frame_keys: ep_idx = item["episode_index"] - video_path = videos_dir / key / f"episode_{ep_idx:06d}.mp4" + video_path = data_dir / key / f"episode_{ep_idx:06d}.mp4" - if f"{key}_timestamp" in item: + if isinstance(item[key], list): # load multiple frames at once - timestamps = item[f"{key}_timestamp"] - item[key] = decode_video_frames_torchvision(video_path, timestamps) + timestamps = [frame["timestamp"] for frame in item[key]] + paths = [frame["path"] for frame in item[key]] + if len(set(paths)) == 1: + raise NotImplementedError("All video paths are expected to be the same for now.") + video_path = data_dir / paths[0] + + frames = decode_video_frames_torchvision(video_path, timestamps) + assert len(frames) == len(timestamps) + + item[key] = frames else: # load one frame - timestamps = [item["timestamp"]] + timestamps = [item[key]["timestamp"]] + video_path = data_dir / item[key]["path"] + frames = decode_video_frames_torchvision(video_path, timestamps) assert len(frames) == 1 + item[key] = frames[0] return item @@ -36,6 +54,8 @@ def decode_video_frames_torchvision( and all subsequent frames until reaching the requested frame. The number of key frames in a video can be adjusted during encoding to take into account decoding time and video size in bytes. """ + video_path = str(video_path) + # set backend if device == "cpu": # explicitely use pyav @@ -52,10 +72,13 @@ def decode_video_frames_torchvision( # set a video stream reader # TODO(rcadene): also load audio stream at the same time - reader = torchvision.io.VideoReader(str(video_path), "video") + reader = torchvision.io.VideoReader(video_path, "video") - # sanity preprocessing (e.g. 3.60000003 -> 3.6) - timestamps = [round(ts, 4) for ts in timestamps] + def round_timestamp(ts): + # sanity preprocessing (e.g. 3.60000003 -> 3.6000, 0.0666666667 -> 0.0667) + return round(ts, 4) + + timestamps = [round_timestamp(ts) for ts in timestamps] # set the first and last requested timestamps # Note: previous timestamps are usually loaded, since we need to access the previous key frame @@ -64,10 +87,11 @@ def decode_video_frames_torchvision( # access key frame of first requested frame, and load all frames until last requested frame # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek + reader.seek(first_ts) frames = [] - for frame in itertools.takewhile(lambda x: x["pts"] <= last_ts, reader.seek(first_ts)): + for frame in reader: # get timestamp of the loaded frame - ts = frame["pts"] + ts = round_timestamp(frame["pts"]) # if the loaded frame is not among the requested frames, we dont add it to the list of output frames is_frame_requested = ts in timestamps @@ -78,7 +102,15 @@ def decode_video_frames_torchvision( log = f"frame loaded at timestamp={ts:.4f}" if is_frame_requested: log += " requested" - print(log) + logging.info(log) + + if len(timestamps) == len(frames): + break + + # hard stop + assert ( + frame["pts"] >= last_ts + ), f"Not enough frames have been loaded in [{first_ts}, {last_ts}]. {len(timestamps)} expected, but only {len(frames)} loaded." 
@@ -95,10 +127,38 @@ def encode_video_frames(imgs_dir: Path, video_path: Path, fps: int):
     video_path.parent.mkdir(parents=True, exist_ok=True)
 
     ffmpeg_cmd = (
-        f"ffmpeg -r {fps} -f image2 "
+        f"ffmpeg -r {fps} "
+        "-f image2 "
+        "-loglevel error "
         f"-i {str(imgs_dir / 'frame_%06d.png')} "
         "-vcodec libx264 "
        "-pix_fmt yuv444p "
         f"{str(video_path)}"
     )
     subprocess.run(ffmpeg_cmd.split(" "), check=True)
+
+
+@dataclass
+class VideoFrame:
+    # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
+    """
+    Provides a type for a dataset containing video frames.
+
+    Example:
+
+    ```python
+    data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
+    features = {"image": VideoFrame()}
+    Dataset.from_dict(data_dict, features=Features(features))
+    ```
+    """
+
+    pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()})
+    _type: str = field(default="VideoFrame", init=False, repr=False)
+
+    def __call__(self):
+        return self.pa_type
+
+
+# to make it available in HuggingFace `datasets`
+register_feature(VideoFrame, "VideoFrame")
diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py
index 2561a6bf..f4b8ad29 100644
--- a/lerobot/scripts/push_dataset_to_hub.py
+++ b/lerobot/scripts/push_dataset_to_hub.py
@@ -60,8 +60,10 @@ import torch
 from huggingface_hub import HfApi
 from safetensors.torch import save_file
 
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
-from lerobot.common.datasets.utils import compute_stats, flatten_dict
+from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
+from lerobot.common.datasets.utils import flatten_dict
 
 
 def get_from_raw_to_lerobot_format_fn(raw_format):
@@ -131,13 +133,15 @@ def push_dataset_to_hub(
     video: bool,
     debug: bool,
 ):
+    repo_id = f"{community_id}/{dataset_id}"
+
     raw_dir = data_dir / f"{dataset_id}_raw"
 
-    out_dir = data_dir / community_id / dataset_id
+    out_dir = data_dir / repo_id
     meta_data_dir = out_dir / "meta_data"
     videos_dir = out_dir / "videos"
 
-    tests_out_dir = tests_data_dir / community_id / dataset_id
+    tests_out_dir = tests_data_dir / repo_id
     tests_meta_data_dir = tests_out_dir / "meta_data"
 
     if out_dir.exists():
@@ -159,7 +163,15 @@ def push_dataset_to_hub(
     # convert dataset from original raw format to LeRobot format
     hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
 
-    stats = compute_stats(hf_dataset)
+    lerobot_dataset = LeRobotDataset.from_preloaded(
+        repo_id=repo_id,
+        version=revision,
+        hf_dataset=hf_dataset,
+        episode_data_index=episode_data_index,
+        info=info,
+        videos_dir=videos_dir,
+    )
+    stats = compute_stats(lerobot_dataset)
 
     if save_to_disk:
         hf_dataset = hf_dataset.with_format(None)  # to remove transforms that cant be saved
diff --git a/poetry.lock b/poetry.lock
index 0a75bd7e..fce56ea6 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2711,6 +2711,44 @@ files = [
     {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"},
 ]
 
+[[package]]
+name = "pyav"
+version = "12.0.5"
+description = "Pythonic bindings for FFmpeg's libraries."
+optional = false +python-versions = ">=3.9" +files = [ + {file = "pyav-12.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f19129d01d6be826ccf9b16151b0f52d954c8a797bd0fe3b84664f42c55070e2"}, + {file = "pyav-12.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c4d6bf60a86cd73d7b195e7e3b6a386771f64524db72604242acc50beeaa7b62"}, + {file = "pyav-12.0.5-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fc4521f2f8f48e0d30d5a83d898a7059bad49cbcc51cff299df00d554c6cbf26"}, + {file = "pyav-12.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67eacfa977ac669ee3c9952955bce57ad3e93c3c24a686986b7c80e748fcfdd4"}, + {file = "pyav-12.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:2a8503ba2464fb2a0a23bdb0ac1743942063f7cf2eb55b5d2477567b33acfc3d"}, + {file = "pyav-12.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ac20eb76aeec143d571615c2dcd831976a68fc198b9d53b878b26be175a6499b"}, + {file = "pyav-12.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2110c813aa9b0f2cac979367d69f95cfe94fc1bcef28e2c58cee56bf7f26de34"}, + {file = "pyav-12.0.5-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426807ce868b7e56effd7f6bb5092a9101e92ecfbadc3849691faf0bab32c21"}, + {file = "pyav-12.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5bb08a9f2efe5673bf4c1cf8a809062490de7babafd50c0d5b78894d6c288054"}, + {file = "pyav-12.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:684edd212f876061e191361f92c7120d6bf43ba3f312f5b56acf3afc8d8333f6"}, + {file = "pyav-12.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:795b3624c8eab6bb8d530d88afcdba744cbb5f8f89d36d3da0265dc388772bde"}, + {file = "pyav-12.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7f083314a92352ceb13b736a71504dea05534aab912ea5f341c4382482395eb3"}, + {file = "pyav-12.0.5-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f832618f9bd2f219cec5683939ae76c474ef993b682a67815d8ffb0b377fc17"}, + {file = "pyav-12.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f315cc0d0f87b53ae6de71df29fbae3cd4bfa995029129000ff9d66886e3bcbe"}, + {file = "pyav-12.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:c8be9e573183a02e88c09ee9fcee8463c3b79625ff905ae96e05f1a282fe4b13"}, + {file = "pyav-12.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c3d11e789115704a0a14805f3cb1d9459b9ab03efeb24bb28b8ee1b25a52ce6d"}, + {file = "pyav-12.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:820bf8ebc82960fd2ae8c1cf1a6d09f6a84abd492d38c4580c37fed082130a22"}, + {file = "pyav-12.0.5-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eed90bc92f3e9d92ef0119e0e424fd1c58db8b186128e9b9cd9ed0da0360bf13"}, + {file = "pyav-12.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4f8b5fa78779acea93c986ab8afaaae6a71e3995dceff87d8a969c3a2b8c55c"}, + {file = "pyav-12.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d8a73d93e3d0377591b08dae057ba8e87211b4a05e6a59a9c90b51b801ce64ea"}, + {file = "pyav-12.0.5-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8ad7bc5215b15f9da4990d74b4bf4d4dbf93cd61caf42e8b06d84fa1c960e864"}, + {file = "pyav-12.0.5-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4ca5db3bc68f572f0fe5d316183725270edefa61ddb4032ebda5cd7751e09020"}, + {file = "pyav-12.0.5-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c1d86d38b90e13250f62a258b90d6641957dab9bc069cbd4929bc7d3d017ec7"}, + {file = 
"pyav-12.0.5-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ccf267724fe1472add37968ff3768e4e5629c125c1c79af957b366fbad3d2e59"}, + {file = "pyav-12.0.5-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f7519a05b19123e074e67248ed0f5672df752852cc43505f721ec2db9f80813c"}, + {file = "pyav-12.0.5-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1ce141031338974567bc1e0504a5355449c61756626a07e3a43ded37a71afe39"}, + {file = "pyav-12.0.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02f77d361ef728483ffe9430391ee554257c5c0872da8a2276275636226b3a85"}, + {file = "pyav-12.0.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:647ebc369b1c7bfbdae626048e4d59265c3ab3ceb2e571ac83ddbbeaa70abb22"}, + {file = "pyav-12.0.5.tar.gz", hash = "sha256:fc65bcb72f3f8040c47a5b5a8025b535c71dcb16f1c8f9ff9bb3bf3af17ac09a"}, +] + [[package]] name = "pycparser" version = "2.22" @@ -4267,4 +4305,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "fab42b4be590cb2007934cd8f5a218f1f3da4f0b42cdff7e7724af518888d7b4" +content-hash = "32584053533829448b806a26a3f57712d4758f738778e67409c2e10a0bd6a0fd" diff --git a/pyproject.toml b/pyproject.toml index a0b318b7..3cf9754f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ pytest = {version = "^8.1.0", optional = true} pytest-cov = {version = "^5.0.0", optional = true} datasets = "^2.19.0" imagecodecs = { version = "^2024.1.1", optional = true } +pyav = "^12.0.5" [tool.poetry.extras]