Cadene 2024-04-30 17:41:18 +00:00
parent f1935d9ca8
commit 88ff197453
11 changed files with 339 additions and 172 deletions

View File

@@ -11,14 +11,16 @@ from lerobot.common.datasets.utils import (
     load_stats,
     load_videos,
 )
-from lerobot.common.datasets.video_utils import load_from_videos
+from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
+
+CODEBASE_VERSION = "v1.2"
 
 
 class LeRobotDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         repo_id: str,
-        version: str | None = "v1.1",
+        version: str | None = CODEBASE_VERSION,
         root: Path | None = None,
         split: str = "train",
         transform: callable = None,
@@ -49,6 +51,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def video(self) -> int:
         return self.info.get("video", False)
 
+    @property
+    def features(self) -> datasets.Features:
+        return self.hf_dataset.features
+
     @property
     def image_keys(self) -> list[str]:
         image_keys = []
@@ -61,7 +67,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def video_frame_keys(self):
         video_frame_keys = []
         for key, feats in self.hf_dataset.features.items():
-            if isinstance(feats, datasets.Value) and feats.id == "video_frame":
+            if isinstance(feats, VideoFrame):
                 video_frame_keys.append(key)
         return video_frame_keys
@@ -95,3 +101,34 @@ class LeRobotDataset(torch.utils.data.Dataset):
             item = self.transform(item)
 
         return item
+
+    @classmethod
+    def from_preloaded(
+        cls,
+        repo_id: str,
+        version: str | None = CODEBASE_VERSION,
+        root: Path | None = None,
+        split: str = "train",
+        transform: callable = None,
+        delta_timestamps: dict[list[float]] | None = None,
+        # additional preloaded attributes
+        hf_dataset=None,
+        episode_data_index=None,
+        stats=None,
+        info=None,
+        videos_dir=None,
+    ):
+        # create an empty object of type LeRobotDataset
+        obj = cls.__new__(cls)
+        obj.repo_id = repo_id
+        obj.version = version
+        obj.root = root
+        obj.split = split
+        obj.transform = transform
+        obj.delta_timestamps = delta_timestamps
+        obj.hf_dataset = hf_dataset
+        obj.episode_data_index = episode_data_index
+        obj.stats = stats
+        obj.info = info
+        obj.videos_dir = videos_dir
+        return obj
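Note: for context on how this constructor is meant to be used (see the `push_dataset_to_hub` change further down), here is a minimal, hedged sketch that wraps already-converted, in-memory data without touching the hub; the repo id and toy columns below are made up for illustration.

```python
from datasets import Dataset

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Toy in-memory dataset standing in for the output of a raw-format conversion script.
hf_dataset = Dataset.from_dict({"index": [0, 1], "episode_index": [0, 0]})

dataset = LeRobotDataset.from_preloaded(
    repo_id="lerobot/example",  # hypothetical repo id
    hf_dataset=hf_dataset,
    episode_data_index={"from": [0], "to": [2]},
    info={"fps": 30, "video": False},
)
print(dataset.video)     # False, read from the preloaded `info` dict
print(dataset.features)  # the underlying Hugging Face features, via the new property
```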

View File

@@ -16,7 +16,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
 from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
 )
-from lerobot.common.datasets.video_utils import encode_video_frames
+from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
 
 
 def check_format(raw_dir) -> bool:
@@ -77,14 +77,17 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
                 save_images_concurrently(imgs_array, tmp_imgs_dir)
 
                 # encode images to a mp4 video
-                video_path = out_dir / "videos" / f"{img_key}_episode_{ep_idx:06d}.mp4"
+                fname = f"observation.image_episode_{ep_idx:06d}.mp4"
+                video_path = out_dir / "videos" / fname
                 encode_video_frames(tmp_imgs_dir, video_path, fps)
 
                 # clean temporary images directory
                 shutil.rmtree(tmp_imgs_dir)
 
                 # store the episode idx
-                ep_dict[img_key] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
+                ep_dict["observation.image"] = [
+                    {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
+                ]
             else:
                 ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
@@ -120,7 +123,7 @@ def to_hf_dataset(data_dict, video) -> Dataset:
     image_keys = [key for key in data_dict if "observation.images." in key]
     for image_key in image_keys:
         if video:
-            features[image_key] = Value(dtype="int64", id="video")
+            features[image_key] = VideoFrame()
         else:
             features[image_key] = Image()

View File

@@ -0,0 +1,143 @@
+from copy import deepcopy
+from math import ceil
+
+import datasets
+import einops
+import torch
+import tqdm
+from datasets import Image
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.video_utils import VideoFrame
+
+
+def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset):
+    """These einops patterns will be used to aggregate batches and compute statistics.
+
+    Note: We assume the images are in channel first format
+    """
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=0,
+        batch_size=2,
+        shuffle=False,
+    )
+    batch = next(iter(dataloader))
+
+    stats_patterns = {}
+    for key, feats_type in dataset.features.items():
+        # sanity check that tensors are not float64
+        assert batch[key].dtype != torch.float64
+
+        if isinstance(feats_type, (VideoFrame, Image)):
+            # sanity check that images are channel first
+            _, c, h, w = batch[key].shape
+            assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
+
+            # sanity check that images are float32 in range [0,1]
+            assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
+            assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
+            assert batch[key].min() >= 0, f"expect pixels greater than 0, but instead {batch[key].min()=}"
+
+            stats_patterns[key] = "b c h w -> c 1 1"
+        elif batch[key].ndim == 2:
+            stats_patterns[key] = "b c -> c "
+        elif batch[key].ndim == 1:
+            stats_patterns[key] = "b -> 1"
+        else:
+            raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
+
+    return stats_patterns
+
+
+def compute_stats(dataset: LeRobotDataset | datasets.Dataset, batch_size=32, max_num_samples=None):
+    if max_num_samples is None:
+        max_num_samples = len(dataset)
+
+    stats_patterns = get_stats_einops_patterns(dataset)
+
+    # mean and std will be computed incrementally while max and min will track the running value.
+    mean, std, max, min = {}, {}, {}, {}
+    for key in stats_patterns:
+        mean[key] = torch.tensor(0.0).float()
+        std[key] = torch.tensor(0.0).float()
+        max[key] = torch.tensor(-float("inf")).float()
+        min[key] = torch.tensor(float("inf")).float()
+
+    def create_seeded_dataloader(dataset, batch_size, seed):
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        dataloader = torch.utils.data.DataLoader(
+            dataset,
+            num_workers=0,
+            batch_size=batch_size,
+            shuffle=True,
+            drop_last=False,
+            generator=generator,
+        )
+        return dataloader
+
+    # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
+    # surprises when rerunning the sampler.
+    first_batch = None
+    running_item_count = 0  # for online mean computation
+    dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
+    for i, batch in enumerate(
+        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
+    ):
+        this_batch_size = len(batch["index"])
+        running_item_count += this_batch_size
+        if first_batch is None:
+            first_batch = deepcopy(batch)
+        for key, pattern in stats_patterns.items():
+            batch[key] = batch[key].float()
+            # Numerically stable update step for mean computation.
+            batch_mean = einops.reduce(batch[key], pattern, "mean")
+            # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
+            # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
+            # and x is the current batch mean. Some rearrangement is then required to avoid risking
+            # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
+            # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
+            mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
+            max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
+            min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
+
+        if i == ceil(max_num_samples / batch_size) - 1:
+            break
+
+    first_batch_ = None
+    running_item_count = 0  # for online std computation
+    dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
+    for i, batch in enumerate(
+        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
+    ):
+        this_batch_size = len(batch["index"])
+        running_item_count += this_batch_size
+        # Sanity check to make sure the batches are still in the same order as before.
+        if first_batch_ is None:
+            first_batch_ = deepcopy(batch)
+            for key in stats_patterns:
+                assert torch.equal(first_batch_[key], first_batch[key])
+        for key, pattern in stats_patterns.items():
+            batch[key] = batch[key].float()
+            # Numerically stable update step for mean computation (where the mean is over squared
+            # residuals). See notes in the mean computation loop above.
+            batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
+            std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
+
+        if i == ceil(max_num_samples / batch_size) - 1:
+            break
+
+    for key in stats_patterns:
+        std[key] = torch.sqrt(std[key])
+
+    stats = {}
+    for key in stats_patterns:
+        stats[key] = {
+            "mean": mean[key],
+            "std": std[key],
+            "max": max[key],
+            "min": min[key],
+        }
+
+    return stats
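Note: a rough usage sketch of the new module (the repo id, batch size and sample cap below are illustrative, not prescribed by this commit).

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats

# Any LeRobotDataset works; "lerobot/pusht" is used here purely as an example.
dataset = LeRobotDataset("lerobot/pusht")
stats = compute_stats(dataset, batch_size=32, max_num_samples=1000)

# One entry per feature, each with "mean", "std", "max" and "min" tensors.
print(stats["observation.state"]["mean"])
```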

View File

@@ -14,7 +14,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
 from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
 )
-from lerobot.common.datasets.video_utils import encode_video_frames
+from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
 
 
 def check_format(raw_dir):
@@ -131,14 +131,17 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
             save_images_concurrently(imgs_array, tmp_imgs_dir)
 
             # encode images to a mp4 video
-            video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
+            fname = f"observation.image_episode_{ep_idx:06d}.mp4"
+            video_path = out_dir / "videos" / fname
             encode_video_frames(tmp_imgs_dir, video_path, fps)
 
             # clean temporary images directory
             shutil.rmtree(tmp_imgs_dir)
 
             # store the episode index
-            ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
+            ep_dict["observation.image"] = [
+                {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
+            ]
         else:
             ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]
@@ -172,7 +175,7 @@ def to_hf_dataset(data_dict, video):
     features = {}
 
     if video:
-        features["observation.image"] = Value(dtype="int64", id="video")
+        features["observation.image"] = VideoFrame()
     else:
         features["observation.image"] = Image()

View File

@@ -16,7 +16,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
 from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
 )
-from lerobot.common.datasets.video_utils import encode_video_frames
+from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
 
 
 def check_format(raw_dir) -> bool:
@@ -149,7 +149,7 @@ def to_hf_dataset(data_dict, video):
     features = {}
 
     if video:
-        features["observation.image"] = Value(dtype="int64", id="video")
+        features["observation.image"] = VideoFrame()
     else:
         features["observation.image"] = Image()

View File

@@ -14,7 +14,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
 from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
 )
-from lerobot.common.datasets.video_utils import encode_video_frames
+from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
 
 
 def check_format(raw_dir):
@@ -80,14 +80,17 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
             save_images_concurrently(imgs_array, tmp_imgs_dir)
 
             # encode images to a mp4 video
-            video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
+            fname = f"observation.image_episode_{ep_idx:06d}.mp4"
+            video_path = out_dir / "videos" / fname
             encode_video_frames(tmp_imgs_dir, video_path, fps)
 
             # clean temporary images directory
             shutil.rmtree(tmp_imgs_dir)
 
             # store the episode index
-            ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
+            ep_dict["observation.image"] = [
+                {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
+            ]
         else:
             ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]
@@ -120,7 +123,7 @@ def to_hf_dataset(data_dict, video):
     features = {}
 
     if video:
-        features["observation.image"] = Value(dtype="int64", id="video")
+        features["observation.image"] = VideoFrame()
     else:
         features["observation.image"] = Image()

View File

@@ -1,13 +1,9 @@
 import json
-from copy import deepcopy
-from math import ceil
 from pathlib import Path
 
 import datasets
-import einops
 import torch
-import tqdm
-from datasets import Image, load_dataset, load_from_disk
+from datasets import load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download, snapshot_download
 from PIL import Image as PILImage
 from safetensors.torch import load_file
@@ -57,6 +53,9 @@ def hf_transform_to_torch(items_dict):
         if isinstance(first_item, PILImage.Image):
             to_tensor = transforms.ToTensor()
             items_dict[key] = [to_tensor(img) for img in items_dict[key]]
+        elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
+            # video frame will be processed downstream
+            pass
         else:
             items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
     return items_dict
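Note: a small sketch of the inputs `hf_transform_to_torch` now receives (paths and timestamps are toy values). Numeric columns are still turned into tensors, while video-frame dicts are passed through untouched and only decoded later by `load_from_videos`.

```python
import torch

from lerobot.common.datasets.utils import hf_transform_to_torch

items = {
    "action": [[0.1, 0.2], [0.3, 0.4]],
    "observation.image": [
        {"path": "videos/observation.image_episode_000000.mp4", "timestamp": 0.0},
        {"path": "videos/observation.image_episode_000000.mp4", "timestamp": 1 / 30},
    ],
}
out = hf_transform_to_torch(items)
assert isinstance(out["action"][0], torch.Tensor)
# video-frame dicts are left as-is for the downstream decoding step
assert out["observation.image"][0]["timestamp"] == 0.0
```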
@@ -223,138 +222,6 @@ def load_previous_and_future_frames(
     return item
 
 
-def get_stats_einops_patterns(hf_dataset):
-    """These einops patterns will be used to aggregate batches and compute statistics.
-
-    Note: We assume the images of `hf_dataset` are in channel first format
-    """
-    dataloader = torch.utils.data.DataLoader(
-        hf_dataset,
-        num_workers=0,
-        batch_size=2,
-        shuffle=False,
-    )
-    batch = next(iter(dataloader))
-
-    stats_patterns = {}
-    for key, feats_type in hf_dataset.features.items():
-        # sanity check that tensors are not float64
-        assert batch[key].dtype != torch.float64
-
-        if isinstance(feats_type, Image):
-            # sanity check that images are channel first
-            _, c, h, w = batch[key].shape
-            assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
-
-            # sanity check that images are float32 in range [0,1]
-            assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
-            assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
-            assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
-
-            stats_patterns[key] = "b c h w -> c 1 1"
-        elif batch[key].ndim == 2:
-            stats_patterns[key] = "b c -> c "
-        elif batch[key].ndim == 1:
-            stats_patterns[key] = "b -> 1"
-        else:
-            raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
-
-    return stats_patterns
-
-
-def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
-    if max_num_samples is None:
-        max_num_samples = len(hf_dataset)
-
-    stats_patterns = get_stats_einops_patterns(hf_dataset)
-
-    # mean and std will be computed incrementally while max and min will track the running value.
-    mean, std, max, min = {}, {}, {}, {}
-    for key in stats_patterns:
-        mean[key] = torch.tensor(0.0).float()
-        std[key] = torch.tensor(0.0).float()
-        max[key] = torch.tensor(-float("inf")).float()
-        min[key] = torch.tensor(float("inf")).float()
-
-    def create_seeded_dataloader(hf_dataset, batch_size, seed):
-        generator = torch.Generator()
-        generator.manual_seed(seed)
-        dataloader = torch.utils.data.DataLoader(
-            hf_dataset,
-            num_workers=4,
-            batch_size=batch_size,
-            shuffle=True,
-            drop_last=False,
-            generator=generator,
-        )
-        return dataloader
-
-    # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
-    # surprises when rerunning the sampler.
-    first_batch = None
-    running_item_count = 0  # for online mean computation
-    dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
-    for i, batch in enumerate(
-        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
-    ):
-        this_batch_size = len(batch["index"])
-        running_item_count += this_batch_size
-        if first_batch is None:
-            first_batch = deepcopy(batch)
-        for key, pattern in stats_patterns.items():
-            batch[key] = batch[key].float()
-            # Numerically stable update step for mean computation.
-            batch_mean = einops.reduce(batch[key], pattern, "mean")
-            # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
-            # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
-            # and x is the current batch mean. Some rearrangement is then required to avoid risking
-            # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
-            # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
-            mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
-            max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
-            min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
-
-        if i == ceil(max_num_samples / batch_size) - 1:
-            break
-
-    first_batch_ = None
-    running_item_count = 0  # for online std computation
-    dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
-    for i, batch in enumerate(
-        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
-    ):
-        this_batch_size = len(batch["index"])
-        running_item_count += this_batch_size
-        # Sanity check to make sure the batches are still in the same order as before.
-        if first_batch_ is None:
-            first_batch_ = deepcopy(batch)
-            for key in stats_patterns:
-                assert torch.equal(first_batch_[key], first_batch[key])
-        for key, pattern in stats_patterns.items():
-            batch[key] = batch[key].float()
-            # Numerically stable update step for mean computation (where the mean is over squared
-            # residuals). See notes in the mean computation loop above.
-            batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
-            std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
-
-        if i == ceil(max_num_samples / batch_size) - 1:
-            break
-
-    for key in stats_patterns:
-        std[key] = torch.sqrt(std[key])
-
-    stats = {}
-    for key in stats_patterns:
-        stats[key] = {
-            "mean": mean[key],
-            "std": std[key],
-            "max": max[key],
-            "min": min[key],
-        }
-    return stats
-
-
 def cycle(iterable):
     """The equivalent of itertools.cycle, but safe for Pytorch dataloaders.

View File

@@ -1,25 +1,43 @@
-import itertools
+import logging
 import subprocess
+from dataclasses import dataclass, field
 from pathlib import Path
+from typing import Any, ClassVar
 
+import pyarrow as pa
 import torch
 import torchvision
+from datasets.features.features import register_feature
 
 
 def load_from_videos(item, video_frame_keys, videos_dir):
+    # since video path already contains "videos" (e.g. videos_dir="data/videos", path="videos/episode_0.mp4")
+    data_dir = videos_dir.parent
+
     for key in video_frame_keys:
         ep_idx = item["episode_index"]
-        video_path = videos_dir / key / f"episode_{ep_idx:06d}.mp4"
+        video_path = data_dir / key / f"episode_{ep_idx:06d}.mp4"
 
-        if f"{key}_timestamp" in item:
+        if isinstance(item[key], list):
             # load multiple frames at once
-            timestamps = item[f"{key}_timestamp"]
-            item[key] = decode_video_frames_torchvision(video_path, timestamps)
+            timestamps = [frame["timestamp"] for frame in item[key]]
+            paths = [frame["path"] for frame in item[key]]
+            if len(set(paths)) > 1:
+                raise NotImplementedError("All video paths are expected to be the same for now.")
+            video_path = data_dir / paths[0]
+            frames = decode_video_frames_torchvision(video_path, timestamps)
+            assert len(frames) == len(timestamps)
+            item[key] = frames
         else:
             # load one frame
-            timestamps = [item["timestamp"]]
+            timestamps = [item[key]["timestamp"]]
+            video_path = data_dir / item[key]["path"]
             frames = decode_video_frames_torchvision(video_path, timestamps)
             assert len(frames) == 1
             item[key] = frames[0]
 
     return item
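Note: a usage sketch of the reworked `load_from_videos` (the dataset root, key and path below are illustrative and assume the mp4 exists on disk); everything it needs now comes from the `{"path", "timestamp"}` dicts written by the conversion scripts.

```python
from pathlib import Path

from lerobot.common.datasets.video_utils import load_from_videos

item = {
    "episode_index": 0,
    # single-frame case; with delta_timestamps this would be a list of such dicts
    "observation.image": {"path": "videos/observation.image_episode_000000.mp4", "timestamp": 0.0},
}

# `videos_dir` ends in "videos"; its parent is used because the stored paths
# already carry the "videos/" prefix.
videos_dir = Path("data/lerobot/example/videos")
item = load_from_videos(item, ["observation.image"], videos_dir)
# item["observation.image"] is now a decoded, channel-first frame tensor
```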
@@ -36,6 +54,8 @@ def decode_video_frames_torchvision(
     and all subsequent frames until reaching the requested frame. The number of key frames in a video
     can be adjusted during encoding to take into account decoding time and video size in bytes.
     """
+    video_path = str(video_path)
+
     # set backend
     if device == "cpu":
         # explicitely use pyav
@@ -52,10 +72,13 @@ def decode_video_frames_torchvision(
     # set a video stream reader
     # TODO(rcadene): also load audio stream at the same time
-    reader = torchvision.io.VideoReader(str(video_path), "video")
+    reader = torchvision.io.VideoReader(video_path, "video")
 
-    # sanity preprocessing (e.g. 3.60000003 -> 3.6)
-    timestamps = [round(ts, 4) for ts in timestamps]
+    def round_timestamp(ts):
+        # sanity preprocessing (e.g. 3.60000003 -> 3.6000, 0.0666666667 -> 0.0667)
+        return round(ts, 4)
+
+    timestamps = [round_timestamp(ts) for ts in timestamps]
 
     # set the first and last requested timestamps
     # Note: previous timestamps are usually loaded, since we need to access the previous key frame
@@ -64,10 +87,11 @@ def decode_video_frames_torchvision(
     # access key frame of first requested frame, and load all frames until last requested frame
     # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
+    reader.seek(first_ts)
+
     frames = []
-    for frame in itertools.takewhile(lambda x: x["pts"] <= last_ts, reader.seek(first_ts)):
+    for frame in reader:
         # get timestamp of the loaded frame
-        ts = frame["pts"]
+        ts = round_timestamp(frame["pts"])
 
         # if the loaded frame is not among the requested frames, we dont add it to the list of output frames
         is_frame_requested = ts in timestamps
@@ -78,7 +102,15 @@ def decode_video_frames_torchvision(
         log = f"frame loaded at timestamp={ts:.4f}"
         if is_frame_requested:
             log += " requested"
-        print(log)
+        logging.info(log)
+
+        if len(timestamps) == len(frames):
+            break
+
+    # hard stop
+    assert (
+        frame["pts"] >= last_ts
+    ), f"Not enough frames have been loaded in [{first_ts}, {last_ts}]. {len(timestamps)} expected, but only {len(frames)} loaded."
 
     frames = torch.stack(frames)
@@ -95,10 +127,38 @@ def encode_video_frames(imgs_dir: Path, video_path: Path, fps: int):
     video_path.parent.mkdir(parents=True, exist_ok=True)
 
     ffmpeg_cmd = (
-        f"ffmpeg -r {fps} -f image2 "
+        f"ffmpeg -r {fps} "
+        "-f image2 "
+        "-loglevel error "
         f"-i {str(imgs_dir / 'frame_%06d.png')} "
         "-vcodec libx264 "
         "-pix_fmt yuv444p "
         f"{str(video_path)}"
     )
     subprocess.run(ffmpeg_cmd.split(" "), check=True)
+
+
+@dataclass
+class VideoFrame:
+    # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
+    """
+    Provides a type for a dataset containing video frames.
+
+    Example:
+
+    ```python
+    data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
+    features = {"image": VideoFrame()}
+    Dataset.from_dict(data_dict, features=Features(features))
+    ```
+    """
+
+    pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()})
+    _type: str = field(default="VideoFrame", init=False, repr=False)
+
+    def __call__(self):
+        return self.pa_type
+
+
+# to make it available in HuggingFace `datasets`
+register_feature(VideoFrame, "VideoFrame")
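Note: building on the docstring example above, a short, hedged sketch of how a `VideoFrame` column flows through `datasets` and is later discovered by `LeRobotDataset.video_frame_keys` (the path and timestamp are toy values).

```python
from datasets import Dataset, Features

from lerobot.common.datasets.video_utils import VideoFrame

# Mirrors what to_hf_dataset builds when video=True.
features = Features({"observation.image": VideoFrame()})
hf_dataset = Dataset.from_dict(
    {"observation.image": [{"path": "videos/episode_0.mp4", "timestamp": 0.0}]},
    features=features,
)

# This is the check LeRobotDataset.video_frame_keys relies on.
video_keys = [key for key, feats in hf_dataset.features.items() if isinstance(feats, VideoFrame)]
print(video_keys)  # ["observation.image"]
```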

View File

@@ -60,8 +60,10 @@ import torch
 from huggingface_hub import HfApi
 from safetensors.torch import save_file
 
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
-from lerobot.common.datasets.utils import compute_stats, flatten_dict
+from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
+from lerobot.common.datasets.utils import flatten_dict
 
 
 def get_from_raw_to_lerobot_format_fn(raw_format):
@@ -131,13 +133,15 @@ def push_dataset_to_hub(
     video: bool,
     debug: bool,
 ):
+    repo_id = f"{community_id}/{dataset_id}"
+
     raw_dir = data_dir / f"{dataset_id}_raw"
 
-    out_dir = data_dir / community_id / dataset_id
+    out_dir = data_dir / repo_id
     meta_data_dir = out_dir / "meta_data"
     videos_dir = out_dir / "videos"
 
-    tests_out_dir = tests_data_dir / community_id / dataset_id
+    tests_out_dir = tests_data_dir / repo_id
     tests_meta_data_dir = tests_out_dir / "meta_data"
 
     if out_dir.exists():
@@ -159,7 +163,15 @@ def push_dataset_to_hub(
     # convert dataset from original raw format to LeRobot format
     hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
 
-    stats = compute_stats(hf_dataset)
+    lerobot_dataset = LeRobotDataset.from_preloaded(
+        repo_id=repo_id,
+        version=revision,
+        hf_dataset=hf_dataset,
+        episode_data_index=episode_data_index,
+        info=info,
+        videos_dir=videos_dir,
+    )
+    stats = compute_stats(lerobot_dataset)
 
     if save_to_disk:
         hf_dataset = hf_dataset.with_format(None)  # to remove transforms that cant be saved

poetry.lock (generated)
View File

@@ -2711,6 +2711,44 @@ files = [
     {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"},
 ]
 
[[package]]
name = "pyav"
version = "12.0.5"
description = "Pythonic bindings for FFmpeg's libraries."
optional = false
python-versions = ">=3.9"
files = [
{file = "pyav-12.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f19129d01d6be826ccf9b16151b0f52d954c8a797bd0fe3b84664f42c55070e2"},
{file = "pyav-12.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c4d6bf60a86cd73d7b195e7e3b6a386771f64524db72604242acc50beeaa7b62"},
{file = "pyav-12.0.5-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fc4521f2f8f48e0d30d5a83d898a7059bad49cbcc51cff299df00d554c6cbf26"},
{file = "pyav-12.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67eacfa977ac669ee3c9952955bce57ad3e93c3c24a686986b7c80e748fcfdd4"},
{file = "pyav-12.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:2a8503ba2464fb2a0a23bdb0ac1743942063f7cf2eb55b5d2477567b33acfc3d"},
{file = "pyav-12.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ac20eb76aeec143d571615c2dcd831976a68fc198b9d53b878b26be175a6499b"},
{file = "pyav-12.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2110c813aa9b0f2cac979367d69f95cfe94fc1bcef28e2c58cee56bf7f26de34"},
{file = "pyav-12.0.5-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426807ce868b7e56effd7f6bb5092a9101e92ecfbadc3849691faf0bab32c21"},
{file = "pyav-12.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5bb08a9f2efe5673bf4c1cf8a809062490de7babafd50c0d5b78894d6c288054"},
{file = "pyav-12.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:684edd212f876061e191361f92c7120d6bf43ba3f312f5b56acf3afc8d8333f6"},
{file = "pyav-12.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:795b3624c8eab6bb8d530d88afcdba744cbb5f8f89d36d3da0265dc388772bde"},
{file = "pyav-12.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7f083314a92352ceb13b736a71504dea05534aab912ea5f341c4382482395eb3"},
{file = "pyav-12.0.5-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f832618f9bd2f219cec5683939ae76c474ef993b682a67815d8ffb0b377fc17"},
{file = "pyav-12.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f315cc0d0f87b53ae6de71df29fbae3cd4bfa995029129000ff9d66886e3bcbe"},
{file = "pyav-12.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:c8be9e573183a02e88c09ee9fcee8463c3b79625ff905ae96e05f1a282fe4b13"},
{file = "pyav-12.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c3d11e789115704a0a14805f3cb1d9459b9ab03efeb24bb28b8ee1b25a52ce6d"},
{file = "pyav-12.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:820bf8ebc82960fd2ae8c1cf1a6d09f6a84abd492d38c4580c37fed082130a22"},
{file = "pyav-12.0.5-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eed90bc92f3e9d92ef0119e0e424fd1c58db8b186128e9b9cd9ed0da0360bf13"},
{file = "pyav-12.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4f8b5fa78779acea93c986ab8afaaae6a71e3995dceff87d8a969c3a2b8c55c"},
{file = "pyav-12.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d8a73d93e3d0377591b08dae057ba8e87211b4a05e6a59a9c90b51b801ce64ea"},
{file = "pyav-12.0.5-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8ad7bc5215b15f9da4990d74b4bf4d4dbf93cd61caf42e8b06d84fa1c960e864"},
{file = "pyav-12.0.5-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4ca5db3bc68f572f0fe5d316183725270edefa61ddb4032ebda5cd7751e09020"},
{file = "pyav-12.0.5-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c1d86d38b90e13250f62a258b90d6641957dab9bc069cbd4929bc7d3d017ec7"},
{file = "pyav-12.0.5-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ccf267724fe1472add37968ff3768e4e5629c125c1c79af957b366fbad3d2e59"},
{file = "pyav-12.0.5-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f7519a05b19123e074e67248ed0f5672df752852cc43505f721ec2db9f80813c"},
{file = "pyav-12.0.5-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1ce141031338974567bc1e0504a5355449c61756626a07e3a43ded37a71afe39"},
{file = "pyav-12.0.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02f77d361ef728483ffe9430391ee554257c5c0872da8a2276275636226b3a85"},
{file = "pyav-12.0.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:647ebc369b1c7bfbdae626048e4d59265c3ab3ceb2e571ac83ddbbeaa70abb22"},
{file = "pyav-12.0.5.tar.gz", hash = "sha256:fc65bcb72f3f8040c47a5b5a8025b535c71dcb16f1c8f9ff9bb3bf3af17ac09a"},
]
 [[package]]
 name = "pycparser"
 version = "2.22"
@@ -4267,4 +4305,4 @@ xarm = ["gym-xarm"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "fab42b4be590cb2007934cd8f5a218f1f3da4f0b42cdff7e7724af518888d7b4"
+content-hash = "32584053533829448b806a26a3f57712d4758f738778e67409c2e10a0bd6a0fd"

View File

@@ -56,6 +56,7 @@ pytest = {version = "^8.1.0", optional = true}
 pytest-cov = {version = "^5.0.0", optional = true}
 datasets = "^2.19.0"
 imagecodecs = { version = "^2024.1.1", optional = true }
+pyav = "^12.0.5"
 
 [tool.poetry.extras]