WIP

commit 88ff197453
parent f1935d9ca8
@@ -11,14 +11,16 @@ from lerobot.common.datasets.utils import (
     load_stats,
     load_videos,
 )
-from lerobot.common.datasets.video_utils import load_from_videos
+from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
 
+CODEBASE_VERSION = "v1.2"
+
 
 class LeRobotDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         repo_id: str,
-        version: str | None = "v1.1",
+        version: str | None = CODEBASE_VERSION,
         root: Path | None = None,
         split: str = "train",
         transform: callable = None,
@@ -49,6 +51,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def video(self) -> int:
         return self.info.get("video", False)
 
+    @property
+    def features(self) -> datasets.Features:
+        return self.hf_dataset.features
+
     @property
     def image_keys(self) -> list[str]:
         image_keys = []
@@ -61,7 +67,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def video_frame_keys(self):
         video_frame_keys = []
         for key, feats in self.hf_dataset.features.items():
-            if isinstance(feats, datasets.Value) and feats.id == "video_frame":
+            if isinstance(feats, VideoFrame):
                 video_frame_keys.append(key)
         return video_frame_keys
 
@@ -95,3 +101,34 @@ class LeRobotDataset(torch.utils.data.Dataset):
             item = self.transform(item)
 
         return item
+
+    @classmethod
+    def from_preloaded(
+        cls,
+        repo_id: str,
+        version: str | None = CODEBASE_VERSION,
+        root: Path | None = None,
+        split: str = "train",
+        transform: callable = None,
+        delta_timestamps: dict[list[float]] | None = None,
+        # additional preloaded attributes
+        hf_dataset=None,
+        episode_data_index=None,
+        stats=None,
+        info=None,
+        videos_dir=None,
+    ):
+        # create an empty object of type LeRobotDataset
+        obj = cls.__new__(cls)
+        obj.repo_id = repo_id
+        obj.version = version
+        obj.root = root
+        obj.split = split
+        obj.transform = transform
+        obj.delta_timestamps = delta_timestamps
+        obj.hf_dataset = hf_dataset
+        obj.episode_data_index = episode_data_index
+        obj.stats = stats
+        obj.info = info
+        obj.videos_dir = videos_dir
+        return obj
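
Context for the new classmethod: `from_preloaded` builds a `LeRobotDataset` around objects that already live in memory, so `push_dataset_to_hub` (changed at the bottom of this commit) can compute stats on a freshly converted dataset without a hub round trip. A minimal sketch of the intended call; the values are illustrative, not part of the commit:

    from pathlib import Path

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # hf_dataset, episode_data_index and info come from a raw-format converter
    dataset = LeRobotDataset.from_preloaded(
        repo_id="lerobot/pusht",  # illustrative repo id
        hf_dataset=hf_dataset,
        episode_data_index=episode_data_index,
        info=info,
        videos_dir=Path("data/lerobot/pusht/videos"),
    )
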
@@ -16,7 +16,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
 from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
 )
-from lerobot.common.datasets.video_utils import encode_video_frames
+from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
 
 
 def check_format(raw_dir) -> bool:
@@ -77,14 +77,17 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
                 save_images_concurrently(imgs_array, tmp_imgs_dir)
 
                 # encode images to a mp4 video
-                video_path = out_dir / "videos" / f"{img_key}_episode_{ep_idx:06d}.mp4"
+                fname = f"observation.image_episode_{ep_idx:06d}.mp4"
+                video_path = out_dir / "videos" / fname
                 encode_video_frames(tmp_imgs_dir, video_path, fps)
 
                 # clean temporary images directory
                 shutil.rmtree(tmp_imgs_dir)
 
                 # store the episode idx
-                ep_dict[img_key] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
+                ep_dict["observation.image"] = [
+                    {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
+                ]
             else:
                 ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
 
@@ -120,7 +123,7 @@ def to_hf_dataset(data_dict, video) -> Dataset:
     image_keys = [key for key in data_dict if "observation.images." in key]
     for image_key in image_keys:
         if video:
-            features[image_key] = Value(dtype="int64", id="video")
+            features[image_key] = VideoFrame()
         else:
             features[image_key] = Image()
 
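
`VideoFrame` is the new feature type defined in `video_utils.py` at the end of this commit: instead of a decoded image, a video column stores a `{"path", "timestamp"}` reference per frame. A sketch of declaring such a column, assuming the feature has been registered; names and values are illustrative:

    from datasets import Dataset, Features

    from lerobot.common.datasets.video_utils import VideoFrame

    data_dict = {"observation.image": [{"path": "videos/observation.image_episode_000000.mp4", "timestamp": 0.0}]}
    features = Features({"observation.image": VideoFrame()})
    hf_dataset = Dataset.from_dict(data_dict, features=features)
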
@@ -0,0 +1,143 @@
+from copy import deepcopy
+from math import ceil
+
+import datasets
+import einops
+import torch
+import tqdm
+from datasets import Image
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.video_utils import VideoFrame
+
+
+def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset):
+    """These einops patterns will be used to aggregate batches and compute statistics.
+
+    Note: We assume the images are in channel first format
+    """
+
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=0,
+        batch_size=2,
+        shuffle=False,
+    )
+    batch = next(iter(dataloader))
+
+    stats_patterns = {}
+    for key, feats_type in dataset.features.items():
+        # sanity check that tensors are not float64
+        assert batch[key].dtype != torch.float64
+
+        if isinstance(feats_type, (VideoFrame, Image)):
+            # sanity check that images are channel first
+            _, c, h, w = batch[key].shape
+            assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
+
+            # sanity check that images are float32 in range [0,1]
+            assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
+            assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
+            assert batch[key].min() >= 0, f"expect pixels greater than 0, but instead {batch[key].min()=}"
+
+            stats_patterns[key] = "b c h w -> c 1 1"
+        elif batch[key].ndim == 2:
+            stats_patterns[key] = "b c -> c "
+        elif batch[key].ndim == 1:
+            stats_patterns[key] = "b -> 1"
+        else:
+            raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
+
+    return stats_patterns
+
+
+def compute_stats(dataset: LeRobotDataset | datasets.Dataset, batch_size=32, max_num_samples=None):
+    if max_num_samples is None:
+        max_num_samples = len(dataset)
+
+    stats_patterns = get_stats_einops_patterns(dataset)
+
+    # mean and std will be computed incrementally while max and min will track the running value.
+    mean, std, max, min = {}, {}, {}, {}
+    for key in stats_patterns:
+        mean[key] = torch.tensor(0.0).float()
+        std[key] = torch.tensor(0.0).float()
+        max[key] = torch.tensor(-float("inf")).float()
+        min[key] = torch.tensor(float("inf")).float()
+
+    def create_seeded_dataloader(dataset, batch_size, seed):
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        dataloader = torch.utils.data.DataLoader(
+            dataset,
+            num_workers=0,
+            batch_size=batch_size,
+            shuffle=True,
+            drop_last=False,
+            generator=generator,
+        )
+        return dataloader
+
+    # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
+    # surprises when rerunning the sampler.
+    first_batch = None
+    running_item_count = 0  # for online mean computation
+    dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
+    for i, batch in enumerate(
+        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
+    ):
+        this_batch_size = len(batch["index"])
+        running_item_count += this_batch_size
+        if first_batch is None:
+            first_batch = deepcopy(batch)
+        for key, pattern in stats_patterns.items():
+            batch[key] = batch[key].float()
+            # Numerically stable update step for mean computation.
+            batch_mean = einops.reduce(batch[key], pattern, "mean")
+            # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
+            # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
+            # and x is the current batch mean. Some rearrangement is then required to avoid risking
+            # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
+            # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
+            mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
+            max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
+            min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
+
+        if i == ceil(max_num_samples / batch_size) - 1:
+            break
+
+    first_batch_ = None
+    running_item_count = 0  # for online std computation
+    dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
+    for i, batch in enumerate(
+        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
+    ):
+        this_batch_size = len(batch["index"])
+        running_item_count += this_batch_size
+        # Sanity check to make sure the batches are still in the same order as before.
+        if first_batch_ is None:
+            first_batch_ = deepcopy(batch)
+            for key in stats_patterns:
+                assert torch.equal(first_batch_[key], first_batch[key])
+        for key, pattern in stats_patterns.items():
+            batch[key] = batch[key].float()
+            # Numerically stable update step for mean computation (where the mean is over squared
+            # residuals). See notes in the mean computation loop above.
+            batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
+            std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
+
+        if i == ceil(max_num_samples / batch_size) - 1:
+            break
+
+    for key in stats_patterns:
+        std[key] = torch.sqrt(std[key])
+
+    stats = {}
+    for key in stats_patterns:
+        stats[key] = {
+            "mean": mean[key],
+            "std": std[key],
+            "max": max[key],
+            "min": min[key],
+        }
+    return stats
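
A usage sketch for the new module; the repo id and feature key are illustrative:

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
    from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats

    dataset = LeRobotDataset("lerobot/pusht")  # illustrative repo id
    stats = compute_stats(dataset, batch_size=32, max_num_samples=1000)
    # per-key tensors: stats["observation.state"]["mean"], ["std"], ["min"], ["max"]
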
@@ -14,7 +14,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
 from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
 )
-from lerobot.common.datasets.video_utils import encode_video_frames
+from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
 
 
 def check_format(raw_dir):
@@ -131,14 +131,17 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
            save_images_concurrently(imgs_array, tmp_imgs_dir)
 
            # encode images to a mp4 video
-           video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
+           fname = f"observation.image_episode_{ep_idx:06d}.mp4"
+           video_path = out_dir / "videos" / fname
            encode_video_frames(tmp_imgs_dir, video_path, fps)
 
            # clean temporary images directory
            shutil.rmtree(tmp_imgs_dir)
 
            # store the episode index
-           ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
+           ep_dict["observation.image"] = [
+               {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
+           ]
        else:
            ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]
 
@@ -172,7 +175,7 @@ def to_hf_dataset(data_dict, video):
     features = {}
 
     if video:
-        features["observation.image"] = Value(dtype="int64", id="video")
+        features["observation.image"] = VideoFrame()
     else:
         features["observation.image"] = Image()
 
@@ -16,7 +16,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
 from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
 )
-from lerobot.common.datasets.video_utils import encode_video_frames
+from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
 
 
 def check_format(raw_dir) -> bool:
@@ -149,7 +149,7 @@ def to_hf_dataset(data_dict, video):
     features = {}
 
     if video:
-        features["observation.image"] = Value(dtype="int64", id="video")
+        features["observation.image"] = VideoFrame()
     else:
         features["observation.image"] = Image()
 
@@ -14,7 +14,7 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episod
 from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
 )
-from lerobot.common.datasets.video_utils import encode_video_frames
+from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
 
 
 def check_format(raw_dir):
@@ -80,14 +80,17 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
            save_images_concurrently(imgs_array, tmp_imgs_dir)
 
            # encode images to a mp4 video
-           video_path = out_dir / "videos" / f"observation.image_episode_{ep_idx:06d}.mp4"
+           fname = f"observation.image_episode_{ep_idx:06d}.mp4"
+           video_path = out_dir / "videos" / fname
            encode_video_frames(tmp_imgs_dir, video_path, fps)
 
            # clean temporary images directory
            shutil.rmtree(tmp_imgs_dir)
 
            # store the episode index
-           ep_dict["observation.image"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int)
+           ep_dict["observation.image"] = [
+               {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
+           ]
        else:
            ep_dict["observation.image"] = [PILImage.fromarray(x) for x in imgs_array]
 
@@ -120,7 +123,7 @@ def to_hf_dataset(data_dict, video):
     features = {}
 
     if video:
-        features["observation.image"] = Value(dtype="int64", id="video")
+        features["observation.image"] = VideoFrame()
     else:
         features["observation.image"] = Image()
 
@@ -1,13 +1,9 @@
 import json
-from copy import deepcopy
-from math import ceil
 from pathlib import Path
 
 import datasets
-import einops
 import torch
-import tqdm
-from datasets import Image, load_dataset, load_from_disk
+from datasets import load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download, snapshot_download
 from PIL import Image as PILImage
 from safetensors.torch import load_file
@@ -57,6 +53,9 @@ def hf_transform_to_torch(items_dict):
         if isinstance(first_item, PILImage.Image):
             to_tensor = transforms.ToTensor()
             items_dict[key] = [to_tensor(img) for img in items_dict[key]]
+        elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
+            # video frame will be processed downstream
+            pass
         else:
             items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
     return items_dict
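
With the new branch, video-frame references pass through the transform untouched while everything else becomes tensors. A sketch; keys and values are illustrative:

    from lerobot.common.datasets.utils import hf_transform_to_torch

    items_dict = {
        "observation.image": [{"path": "videos/episode_000000.mp4", "timestamp": 0.0}],
        "action": [[0.1, 0.2, 0.3]],
    }
    items_dict = hf_transform_to_torch(items_dict)
    # "observation.image" still holds the {"path", "timestamp"} dict, to be decoded
    # later by load_from_videos; "action" is now a list of torch tensors
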
@@ -223,138 +222,6 @@ def load_previous_and_future_frames(
     return item
 
 
-def get_stats_einops_patterns(hf_dataset):
-    """These einops patterns will be used to aggregate batches and compute statistics.
-
-    Note: We assume the images of `hf_dataset` are in channel first format
-    """
-
-    dataloader = torch.utils.data.DataLoader(
-        hf_dataset,
-        num_workers=0,
-        batch_size=2,
-        shuffle=False,
-    )
-    batch = next(iter(dataloader))
-
-    stats_patterns = {}
-    for key, feats_type in hf_dataset.features.items():
-        # sanity check that tensors are not float64
-        assert batch[key].dtype != torch.float64
-
-        if isinstance(feats_type, Image):
-            # sanity check that images are channel first
-            _, c, h, w = batch[key].shape
-            assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
-
-            # sanity check that images are float32 in range [0,1]
-            assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
-            assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
-            assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
-
-            stats_patterns[key] = "b c h w -> c 1 1"
-        elif batch[key].ndim == 2:
-            stats_patterns[key] = "b c -> c "
-        elif batch[key].ndim == 1:
-            stats_patterns[key] = "b -> 1"
-        else:
-            raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
-
-    return stats_patterns
-
-
-def compute_stats(hf_dataset, batch_size=32, max_num_samples=None):
-    if max_num_samples is None:
-        max_num_samples = len(hf_dataset)
-
-    stats_patterns = get_stats_einops_patterns(hf_dataset)
-
-    # mean and std will be computed incrementally while max and min will track the running value.
-    mean, std, max, min = {}, {}, {}, {}
-    for key in stats_patterns:
-        mean[key] = torch.tensor(0.0).float()
-        std[key] = torch.tensor(0.0).float()
-        max[key] = torch.tensor(-float("inf")).float()
-        min[key] = torch.tensor(float("inf")).float()
-
-    def create_seeded_dataloader(hf_dataset, batch_size, seed):
-        generator = torch.Generator()
-        generator.manual_seed(seed)
-        dataloader = torch.utils.data.DataLoader(
-            hf_dataset,
-            num_workers=4,
-            batch_size=batch_size,
-            shuffle=True,
-            drop_last=False,
-            generator=generator,
-        )
-        return dataloader
-
-    # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
-    # surprises when rerunning the sampler.
-    first_batch = None
-    running_item_count = 0  # for online mean computation
-    dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
-    for i, batch in enumerate(
-        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
-    ):
-        this_batch_size = len(batch["index"])
-        running_item_count += this_batch_size
-        if first_batch is None:
-            first_batch = deepcopy(batch)
-        for key, pattern in stats_patterns.items():
-            batch[key] = batch[key].float()
-            # Numerically stable update step for mean computation.
-            batch_mean = einops.reduce(batch[key], pattern, "mean")
-            # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
-            # the update step, N is the running item count, B is this batch size, x̄ is the running mean,
-            # and x is the current batch mean. Some rearrangement is then required to avoid risking
-            # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
-            # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
-            mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
-            max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
-            min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
-
-        if i == ceil(max_num_samples / batch_size) - 1:
-            break
-
-    first_batch_ = None
-    running_item_count = 0  # for online std computation
-    dataloader = create_seeded_dataloader(hf_dataset, batch_size, seed=1337)
-    for i, batch in enumerate(
-        tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
-    ):
-        this_batch_size = len(batch["index"])
-        running_item_count += this_batch_size
-        # Sanity check to make sure the batches are still in the same order as before.
-        if first_batch_ is None:
-            first_batch_ = deepcopy(batch)
-            for key in stats_patterns:
-                assert torch.equal(first_batch_[key], first_batch[key])
-        for key, pattern in stats_patterns.items():
-            batch[key] = batch[key].float()
-            # Numerically stable update step for mean computation (where the mean is over squared
-            # residuals).See notes in the mean computation loop above.
-            batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
-            std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
-
-        if i == ceil(max_num_samples / batch_size) - 1:
-            break
-
-    for key in stats_patterns:
-        std[key] = torch.sqrt(std[key])
-
-    stats = {}
-    for key in stats_patterns:
-        stats[key] = {
-            "mean": mean[key],
-            "std": std[key],
-            "max": max[key],
-            "min": min[key],
-        }
-    return stats
-
-
 def cycle(iterable):
     """The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
 
@@ -1,25 +1,43 @@
 import itertools
 import logging
 import subprocess
+from dataclasses import dataclass, field
 from pathlib import Path
+from typing import Any, ClassVar
 
+import pyarrow as pa
 import torch
 import torchvision
+from datasets.features.features import register_feature
 
 
 def load_from_videos(item, video_frame_keys, videos_dir):
+    # since video path already contains "videos" (e.g. videos_dir="data/videos", path="videos/episode_0.mp4")
+    data_dir = videos_dir.parent
+
     for key in video_frame_keys:
         ep_idx = item["episode_index"]
-        video_path = videos_dir / key / f"episode_{ep_idx:06d}.mp4"
+        video_path = data_dir / key / f"episode_{ep_idx:06d}.mp4"
 
-        if f"{key}_timestamp" in item:
+        if isinstance(item[key], list):
             # load multiple frames at once
-            timestamps = item[f"{key}_timestamp"]
-            item[key] = decode_video_frames_torchvision(video_path, timestamps)
+            timestamps = [frame["timestamp"] for frame in item[key]]
+            paths = [frame["path"] for frame in item[key]]
+            if len(set(paths)) > 1:
+                raise NotImplementedError("All video paths are expected to be the same for now.")
+            video_path = data_dir / paths[0]
+
+            frames = decode_video_frames_torchvision(video_path, timestamps)
+            assert len(frames) == len(timestamps)
+
+            item[key] = frames
         else:
             # load one frame
-            timestamps = [item["timestamp"]]
+            timestamps = [item[key]["timestamp"]]
+            video_path = data_dir / item[key]["path"]
+
+            frames = decode_video_frames_torchvision(video_path, timestamps)
+            assert len(frames) == 1
+
+            item[key] = frames[0]
 
     return item
@@ -36,6 +54,8 @@ def decode_video_frames_torchvision(
     and all subsequent frames until reaching the requested frame. The number of key frames in a video
     can be adjusted during encoding to take into account decoding time and video size in bytes.
     """
+    video_path = str(video_path)
+
     # set backend
     if device == "cpu":
         # explicitly use pyav
@@ -52,10 +72,13 @@ def decode_video_frames_torchvision(
 
     # set a video stream reader
     # TODO(rcadene): also load audio stream at the same time
-    reader = torchvision.io.VideoReader(str(video_path), "video")
+    reader = torchvision.io.VideoReader(video_path, "video")
 
-    # sanity preprocessing (e.g. 3.60000003 -> 3.6)
-    timestamps = [round(ts, 4) for ts in timestamps]
+    def round_timestamp(ts):
+        # sanity preprocessing (e.g. 3.60000003 -> 3.6000, 0.0666666667 -> 0.0667)
+        return round(ts, 4)
+
+    timestamps = [round_timestamp(ts) for ts in timestamps]
 
     # set the first and last requested timestamps
     # Note: previous timestamps are usually loaded, since we need to access the previous key frame
@@ -64,10 +87,11 @@ def decode_video_frames_torchvision(
 
     # access key frame of first requested frame, and load all frames until last requested frame
     # for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
+    reader.seek(first_ts)
     frames = []
-    for frame in itertools.takewhile(lambda x: x["pts"] <= last_ts, reader.seek(first_ts)):
+    for frame in reader:
         # get timestamp of the loaded frame
-        ts = frame["pts"]
+        ts = round_timestamp(frame["pts"])
 
         # if the loaded frame is not among the requested frames, we don't add it to the list of output frames
         is_frame_requested = ts in timestamps
@@ -78,7 +102,15 @@ def decode_video_frames_torchvision(
             log = f"frame loaded at timestamp={ts:.4f}"
             if is_frame_requested:
                 log += " requested"
-            print(log)
+            logging.info(log)
+
+        if len(timestamps) == len(frames):
+            break
+
+    # hard stop
+    assert (
+        frame["pts"] >= last_ts
+    ), f"Not enough frames have been loaded in [{first_ts}, {last_ts}]. {len(timestamps)} expected, but only {len(frames)} loaded."
 
     frames = torch.stack(frames)
@@ -95,10 +127,38 @@ def encode_video_frames(imgs_dir: Path, video_path: Path, fps: int):
     video_path.parent.mkdir(parents=True, exist_ok=True)
 
     ffmpeg_cmd = (
-        f"ffmpeg -r {fps} -f image2 "
+        f"ffmpeg -r {fps} "
+        "-f image2 "
+        "-loglevel error "
         f"-i {str(imgs_dir / 'frame_%06d.png')} "
         "-vcodec libx264 "
         "-pix_fmt yuv444p "
         f"{str(video_path)}"
     )
     subprocess.run(ffmpeg_cmd.split(" "), check=True)
+
+
+@dataclass
+class VideoFrame:
+    # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
+    """
+    Provides a type for a dataset containing video frames.
+
+    Example:
+
+    ```python
+    data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
+    features = {"image": VideoFrame()}
+    Dataset.from_dict(data_dict, features=Features(features))
+    ```
+    """
+
+    pa_type: ClassVar[Any] = pa.struct({"path": pa.string(), "timestamp": pa.float32()})
+    _type: str = field(default="VideoFrame", init=False, repr=False)
+
+    def __call__(self):
+        return self.pa_type
+
+
+# to make it available in HuggingFace `datasets`
+register_feature(VideoFrame, "VideoFrame")
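
The expected round trip through the two helpers, sketched with illustrative paths (frames must be named frame_%06d.png, matching the ffmpeg input pattern above):

    from pathlib import Path

    from lerobot.common.datasets.video_utils import decode_video_frames_torchvision, encode_video_frames

    imgs_dir = Path("data/tmp_images")  # contains frame_000000.png, frame_000001.png, ...
    video_path = Path("data/videos/observation.image_episode_000000.mp4")
    encode_video_frames(imgs_dir, video_path, fps=10)

    # decode the frames at timestamps 0.0s and 0.1s back out of the mp4
    frames = decode_video_frames_torchvision(video_path, timestamps=[0.0, 0.1])
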
@@ -60,8 +60,10 @@ import torch
 from huggingface_hub import HfApi
 from safetensors.torch import save_file
 
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
-from lerobot.common.datasets.utils import compute_stats, flatten_dict
+from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
+from lerobot.common.datasets.utils import flatten_dict
 
 
 def get_from_raw_to_lerobot_format_fn(raw_format):
@@ -131,13 +133,15 @@ def push_dataset_to_hub(
     video: bool,
     debug: bool,
 ):
+    repo_id = f"{community_id}/{dataset_id}"
+
     raw_dir = data_dir / f"{dataset_id}_raw"
 
-    out_dir = data_dir / community_id / dataset_id
+    out_dir = data_dir / repo_id
     meta_data_dir = out_dir / "meta_data"
     videos_dir = out_dir / "videos"
 
-    tests_out_dir = tests_data_dir / community_id / dataset_id
+    tests_out_dir = tests_data_dir / repo_id
     tests_meta_data_dir = tests_out_dir / "meta_data"
 
     if out_dir.exists():
@@ -159,7 +163,15 @@ def push_dataset_to_hub(
     # convert dataset from original raw format to LeRobot format
     hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(raw_dir, out_dir, fps, video, debug)
 
-    stats = compute_stats(hf_dataset)
+    lerobot_dataset = LeRobotDataset.from_preloaded(
+        repo_id=repo_id,
+        version=revision,
+        hf_dataset=hf_dataset,
+        episode_data_index=episode_data_index,
+        info=info,
+        videos_dir=videos_dir,
+    )
+    stats = compute_stats(lerobot_dataset)
 
     if save_to_disk:
         hf_dataset = hf_dataset.with_format(None)  # to remove transforms that can't be saved
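
Why `compute_stats` now takes the `LeRobotDataset` wrapper instead of the raw `hf_dataset`: for video datasets the underlying rows only hold frame references, and `get_stats_einops_patterns` asserts float32 channel-first images, so the stats pass has to go through `LeRobotDataset.__getitem__`, which decodes frames via `load_from_videos`. A sketch of the difference; values are illustrative:

    row = hf_dataset[0]
    row["observation.image"]   # {"path": "videos/...mp4", "timestamp": 0.0}, not a tensor
    item = lerobot_dataset[0]
    item["observation.image"]  # float32 channel-first tensor decoded from the mp4
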
@@ -2711,6 +2711,44 @@ files = [
     {file = "pyarrow_hotfix-0.6.tar.gz", hash = "sha256:79d3e030f7ff890d408a100ac16d6f00b14d44a502d7897cd9fc3e3a534e9945"},
 ]
 
+[[package]]
+name = "pyav"
+version = "12.0.5"
+description = "Pythonic bindings for FFmpeg's libraries."
+optional = false
+python-versions = ">=3.9"
+files = [
+    {file = "pyav-12.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f19129d01d6be826ccf9b16151b0f52d954c8a797bd0fe3b84664f42c55070e2"},
+    {file = "pyav-12.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c4d6bf60a86cd73d7b195e7e3b6a386771f64524db72604242acc50beeaa7b62"},
+    {file = "pyav-12.0.5-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fc4521f2f8f48e0d30d5a83d898a7059bad49cbcc51cff299df00d554c6cbf26"},
+    {file = "pyav-12.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67eacfa977ac669ee3c9952955bce57ad3e93c3c24a686986b7c80e748fcfdd4"},
+    {file = "pyav-12.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:2a8503ba2464fb2a0a23bdb0ac1743942063f7cf2eb55b5d2477567b33acfc3d"},
+    {file = "pyav-12.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ac20eb76aeec143d571615c2dcd831976a68fc198b9d53b878b26be175a6499b"},
+    {file = "pyav-12.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2110c813aa9b0f2cac979367d69f95cfe94fc1bcef28e2c58cee56bf7f26de34"},
+    {file = "pyav-12.0.5-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6426807ce868b7e56effd7f6bb5092a9101e92ecfbadc3849691faf0bab32c21"},
+    {file = "pyav-12.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5bb08a9f2efe5673bf4c1cf8a809062490de7babafd50c0d5b78894d6c288054"},
+    {file = "pyav-12.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:684edd212f876061e191361f92c7120d6bf43ba3f312f5b56acf3afc8d8333f6"},
+    {file = "pyav-12.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:795b3624c8eab6bb8d530d88afcdba744cbb5f8f89d36d3da0265dc388772bde"},
+    {file = "pyav-12.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7f083314a92352ceb13b736a71504dea05534aab912ea5f341c4382482395eb3"},
+    {file = "pyav-12.0.5-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f832618f9bd2f219cec5683939ae76c474ef993b682a67815d8ffb0b377fc17"},
+    {file = "pyav-12.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f315cc0d0f87b53ae6de71df29fbae3cd4bfa995029129000ff9d66886e3bcbe"},
+    {file = "pyav-12.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:c8be9e573183a02e88c09ee9fcee8463c3b79625ff905ae96e05f1a282fe4b13"},
+    {file = "pyav-12.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c3d11e789115704a0a14805f3cb1d9459b9ab03efeb24bb28b8ee1b25a52ce6d"},
+    {file = "pyav-12.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:820bf8ebc82960fd2ae8c1cf1a6d09f6a84abd492d38c4580c37fed082130a22"},
+    {file = "pyav-12.0.5-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eed90bc92f3e9d92ef0119e0e424fd1c58db8b186128e9b9cd9ed0da0360bf13"},
+    {file = "pyav-12.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4f8b5fa78779acea93c986ab8afaaae6a71e3995dceff87d8a969c3a2b8c55c"},
+    {file = "pyav-12.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d8a73d93e3d0377591b08dae057ba8e87211b4a05e6a59a9c90b51b801ce64ea"},
+    {file = "pyav-12.0.5-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:8ad7bc5215b15f9da4990d74b4bf4d4dbf93cd61caf42e8b06d84fa1c960e864"},
+    {file = "pyav-12.0.5-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4ca5db3bc68f572f0fe5d316183725270edefa61ddb4032ebda5cd7751e09020"},
+    {file = "pyav-12.0.5-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c1d86d38b90e13250f62a258b90d6641957dab9bc069cbd4929bc7d3d017ec7"},
+    {file = "pyav-12.0.5-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:ccf267724fe1472add37968ff3768e4e5629c125c1c79af957b366fbad3d2e59"},
+    {file = "pyav-12.0.5-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:f7519a05b19123e074e67248ed0f5672df752852cc43505f721ec2db9f80813c"},
+    {file = "pyav-12.0.5-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1ce141031338974567bc1e0504a5355449c61756626a07e3a43ded37a71afe39"},
+    {file = "pyav-12.0.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02f77d361ef728483ffe9430391ee554257c5c0872da8a2276275636226b3a85"},
+    {file = "pyav-12.0.5-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:647ebc369b1c7bfbdae626048e4d59265c3ab3ceb2e571ac83ddbbeaa70abb22"},
+    {file = "pyav-12.0.5.tar.gz", hash = "sha256:fc65bcb72f3f8040c47a5b5a8025b535c71dcb16f1c8f9ff9bb3bf3af17ac09a"},
+]
+
 [[package]]
 name = "pycparser"
 version = "2.22"
@@ -4267,4 +4305,4 @@ xarm = ["gym-xarm"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "fab42b4be590cb2007934cd8f5a218f1f3da4f0b42cdff7e7724af518888d7b4"
+content-hash = "32584053533829448b806a26a3f57712d4758f738778e67409c2e10a0bd6a0fd"
@@ -56,6 +56,7 @@ pytest = {version = "^8.1.0", optional = true}
 pytest-cov = {version = "^5.0.0", optional = true}
 datasets = "^2.19.0"
 imagecodecs = { version = "^2024.1.1", optional = true }
+pyav = "^12.0.5"
 
 
 [tool.poetry.extras]