Fix (Now loading all frames is possible)

This commit is contained in:
Remi Cadene 2025-04-14 14:47:18 +00:00
parent 6c4d122198
commit c2a05a1fde
4 changed files with 85 additions and 56 deletions

View File

@@ -36,7 +36,6 @@ from lerobot.common.constants import HF_LEROBOT_HOME
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
 from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
 from lerobot.common.datasets.utils import (
-    DEFAULT_DATA_PATH,
     DEFAULT_EPISODES_PATH,
     DEFAULT_FEATURES,
     DEFAULT_IMAGE_PATH,
@@ -48,12 +47,10 @@ from lerobot.common.datasets.utils import (
     create_lerobot_dataset_card,
     embed_images,
     flatten_dict,
-    get_chunk_file_indices,
     get_delta_indices,
     get_features_from_robot,
     get_hf_dataset_size_in_mb,
     get_hf_features_from_features,
-    get_latest_video_path,
     get_parquet_num_frames,
     get_safe_version,
     get_video_duration_in_s,
@@ -143,8 +140,8 @@ class LeRobotDatasetMetadata:
         return Path(fpath)

     def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
-        chunk_idx = self.episodes[f"{vid_key}/chunk_index"][ep_index]
-        file_idx = self.episodes[f"{vid_key}/file_index"][ep_index]
+        chunk_idx = self.episodes[f"videos/{vid_key}/chunk_index"][ep_index]
+        file_idx = self.episodes[f"videos/{vid_key}/file_index"][ep_index]
         fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
         return Path(fpath)
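
Note that throughout this commit, per-episode video locations in the episodes metadata move from bare `{vid_key}/...` columns to namespaced `videos/{vid_key}/...` columns, mirroring the existing `data/...` and `meta/episodes/...` prefixes. Purely as an illustration (the camera key and values below are made up), an episodes record now carries columns along these lines:

    # Illustrative only: hypothetical episodes record with the namespaced columns
    # touched by this commit ("data/...", "meta/episodes/...", "videos/<camera>/...").
    episode_record = {
        "episode_index": 12,
        "data/chunk_index": 0,
        "data/file_index": 3,
        "meta/episodes/chunk_index": 0,
        "meta/episodes/file_index": 1,
        "videos/observation.images.top/chunk_index": 0,
        "videos/observation.images.top/file_index": 2,
        "videos/observation.images.top/from_timestamp": 142.5,
        "videos/observation.images.top/to_timestamp": 151.2,
    }
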
@@ -243,7 +240,7 @@ class LeRobotDatasetMetadata:
             self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks)
         else:
             new_tasks = [task for task in tasks if task not in self.tasks.index]
-            new_task_indices = range(len(self.tasks), len(new_tasks))
+            new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
             for task_idx, task in zip(new_task_indices, new_tasks, strict=False):
                 self.tasks.loc[task] = task_idx
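
The second argument of `range` is a stop value, not a count, so the old `range(len(self.tasks), len(new_tasks))` is empty whenever the task table is already at least as long as the batch of new tasks, and those tasks silently received no index. A minimal standalone sketch of the corrected bookkeeping (task strings are made up):

    existing_tasks = ["pick the cube", "place the cube"]  # already indexed as 0 and 1
    new_tasks = ["stack the cube"]                        # one unseen task

    # Old: range(2, 1) is empty, so the new task never gets an index.
    # New: range(2, 2 + 1) yields 2, the next free task_index.
    new_task_indices = range(len(existing_tasks), len(existing_tasks) + len(new_tasks))
    for task_idx, task in zip(new_task_indices, new_tasks, strict=False):
        print(task_idx, task)  # -> 2 stack the cube
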
@@ -277,8 +274,13 @@ class LeRobotDatasetMetadata:
             df["meta/episodes/file_index"] = [file_idx]
         else:
             # Retrieve information from the latest parquet file
-            latest_ep = self.episodes.with_format(columns=["chunk_index", "file_index"])[-1]
-            chunk_idx, file_idx = latest_ep["chunk_index"], latest_ep["file_index"]
+            latest_ep = self.episodes.with_format(
+                columns=["meta/episodes/chunk_index", "meta/episodes/file_index"]
+            )[-1]
+            chunk_idx, file_idx = (
+                latest_ep["meta/episodes/chunk_index"],
+                latest_ep["meta/episodes/file_index"],
+            )

             latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
             latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
@@ -294,7 +296,7 @@ class LeRobotDatasetMetadata:
                 df["meta/episodes/chunk_index"] = [chunk_idx]
                 df["meta/episodes/file_index"] = [file_idx]
                 latest_df = pd.read_parquet(latest_path)
-                latest_df = pd.concat([latest_df, df], ignore_index=True)
+                df = pd.concat([latest_df, df], ignore_index=True)

         # Write the resulting dataframe from RAM to disk
         path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
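
`df` is the variable written to disk a few lines below, so the merged frame now actually reaches the parquet file; previously the merge landed in `latest_df` and was discarded. A minimal sketch of the append-to-latest-file pattern (column names and values are made up):

    import pandas as pd

    # Hypothetical rows already stored in the latest episodes parquet file.
    latest_df = pd.DataFrame({"episode_index": [0, 1], "length": [210, 190]})
    df = pd.DataFrame({"episode_index": [2], "length": [250]})  # new episode row

    # Assign the result back to `df`, the frame that gets written out afterwards.
    df = pd.concat([latest_df, df], ignore_index=True)
    print(df)  # three rows: episodes 0, 1 and 2
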
@@ -314,14 +316,6 @@ class LeRobotDatasetMetadata:
         episode_stats: dict[str, dict],
         episode_metadata: dict,
     ) -> None:
-        # Update info
-        self.info["total_episodes"] += 1
-        self.info["total_frames"] += episode_length
-        self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
-        if len(self.video_keys) > 0:
-            self.update_video_info()
-        write_info(self.info, self.root)
-
         episode_dict = {
             "episode_index": episode_index,
             "tasks": episode_tasks,
@@ -331,6 +325,14 @@ class LeRobotDatasetMetadata:
         episode_dict.update(flatten_dict({"stats": episode_stats}))
         self._save_episode_metadata(episode_dict)

+        # Update info
+        self.info["total_episodes"] += 1
+        self.info["total_frames"] += episode_length
+        self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
+        if len(self.video_keys) > 0:
+            self.update_video_info()
+        write_info(self.info, self.root)
+
         self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats

         # TODO: write stats
@@ -747,7 +749,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             # Episodes are stored sequentially on a single mp4 to reduce the number of files.
             # Thus we load the start timestamp of the episode on this mp4 and,
             # shift the query timestamp accordingly.
-            from_timestamp = self.meta.episodes[f"{vid_key}/from_timestamp"][ep_idx]
+            from_timestamp = self.meta.episodes[f"videos/{vid_key}/from_timestamp"][ep_idx]
             shifted_query_ts = [from_timestamp + ts for ts in query_ts]

             video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
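
Because several episodes share one mp4, the timestamps requested for an episode are relative to that episode and must be offset by the episode's start time inside the concatenated file before decoding. A small numeric sketch (values are made up):

    # Hypothetical start of the episode inside the shared mp4, i.e. the value of
    # the "videos/<camera>/from_timestamp" column for this episode.
    from_timestamp = 142.5
    query_ts = [0.0, 0.5, 1.0]  # timestamps relative to the episode, in seconds

    # Shift into the time base of the mp4 file before decoding frames.
    shifted_query_ts = [from_timestamp + ts for ts in query_ts]
    print(shifted_query_ts)  # [142.5, 143.0, 143.5]
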
@@ -977,7 +979,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             latest_ep = self.meta.episodes.with_format(columns=["data/chunk_index", "data/file_index"])[-1]
             chunk_idx, file_idx = latest_ep["data/chunk_index"], latest_ep["data/file_index"]
-            latest_path = self.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
+            latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
             latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
             latest_num_frames = get_parquet_num_frames(latest_path)
@@ -1018,30 +1020,52 @@ class LeRobotDataset(torch.utils.data.Dataset):
         ep_size_in_mb = get_video_size_in_mb(ep_path)
         ep_duration_in_s = get_video_duration_in_s(ep_path)

-        # Access latest video file information
-        latest_path = get_latest_video_path(self.root / "videos", video_key)
-        latest_size_in_mb = get_video_size_in_mb(latest_path)
-        latest_duration_in_s = get_video_duration_in_s(latest_path)
-        chunk_idx, file_idx = get_chunk_file_indices(latest_path)
-
-        if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
-            # Move temporary episode video to a new video file in the dataset
-            chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
-            new_path = self.meta.video_path.format(
-                video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
-            )
-            new_path.parent.mkdir(parents=True, exist_ok=True)
-            ep_path.replace(new_path)
-        else:
-            # Update latest video file
-            concat_video_files([latest_path, ep_path], self.root, video_key, chunk_idx, file_idx)
+        if self.meta.episodes is None:
+            # Initialize indices for a new dataset made of the first episode data
+            chunk_idx, file_idx = 0, 0
+            latest_duration_in_s = 0
+            new_path = self.root / self.meta.video_path.format(
+                video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
+            )
+            new_path.parent.mkdir(parents=True, exist_ok=True)
+            shutil.move(str(ep_path), str(new_path))
+        else:
+            # Retrieve information from the latest video file
+            latest_ep = self.meta.episodes.with_format(
+                columns=[f"videos/{video_key}/chunk_index", f"videos/{video_key}/file_index"]
+            )[-1]
+            chunk_idx, file_idx = (
+                latest_ep[f"videos/{video_key}/chunk_index"],
+                latest_ep[f"videos/{video_key}/file_index"],
+            )
+
+            latest_path = self.root / self.meta.video_path.format(
+                video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
+            )
+            latest_size_in_mb = get_video_size_in_mb(latest_path)
+            latest_duration_in_s = get_video_duration_in_s(latest_path)
+
+            if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
+                # Move temporary episode video to a new video file in the dataset
+                chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
+                new_path = self.root / self.meta.video_path.format(
+                    video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
+                )
+                new_path.parent.mkdir(parents=True, exist_ok=True)
+                shutil.move(str(ep_path), str(new_path))
+            else:
+                # Update latest video file
+                concat_video_files([latest_path, ep_path], self.root, video_key, chunk_idx, file_idx)
+
+        # Remove temporary directory
+        shutil.rmtree(str(ep_path.parent))

         metadata = {
             "episode_index": episode_index,
-            f"{video_key}/chunk_index": chunk_idx,
-            f"{video_key}/file_index": file_idx,
-            f"{video_key}/from_timestamp": latest_duration_in_s,
-            f"{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
+            f"videos/{video_key}/chunk_index": chunk_idx,
+            f"videos/{video_key}/file_index": file_idx,
+            f"videos/{video_key}/from_timestamp": latest_duration_in_s,
+            f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
         }
         return metadata
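
The rewritten branch first handles a brand-new dataset (no episodes metadata yet), and otherwise chooses between concatenating onto the latest video file and rolling over to a new chunk/file index once the per-file size budget would be exceeded. A minimal sketch of that decision rule with made-up numbers:

    files_size_in_mb = 500   # hypothetical per-file budget (self.meta.files_size_in_mb)
    latest_size_in_mb = 480  # current size of the latest mp4 for this camera
    ep_size_in_mb = 35       # size of the freshly encoded episode video

    if latest_size_in_mb + ep_size_in_mb >= files_size_in_mb:
        # Budget exceeded: move the episode video to a new file (and possibly a new chunk).
        action = "start a new video file"
    else:
        # Still within budget: concatenate the episode onto the latest mp4.
        action = "append to the latest video file"
    print(action)  # -> start a new video file
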
@@ -1099,7 +1123,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
         since video encoding with ffmpeg is already using multithreading.
         """
-        temp_path = Path(tempfile.mkdtemp()) / f"{video_key}_{episode_index:3d}.mp4"
+        temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4"
         img_dir = self._get_image_file_dir(episode_index, video_key)
         encode_video_frames(img_dir, temp_path, self.fps, overwrite=True)
         return temp_path
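
Two small details in this change: `tempfile.mkdtemp(dir=self.root)` keeps the temporary video under the dataset root, so the later move stays on the same filesystem, and `:03d` zero-pads the episode index instead of space-padding it. For example (the `cam_` prefix is illustrative):

    episode_index = 7
    print(f"cam_{episode_index:3d}.mp4")   # 'cam_  7.mp4' (space padded)
    print(f"cam_{episode_index:03d}.mp4")  # 'cam_007.mp4' (zero padded)
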

View File

@@ -17,6 +17,7 @@ import contextlib
 import importlib.resources
 import json
 import logging
+import shutil
 import subprocess
 import tempfile
 from collections.abc import Iterator
@@ -155,17 +156,15 @@ def get_video_size_in_mb(mp4_path: Path):
     return file_size_mb


-def concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx):
+def concat_video_files(paths_to_cat, root, video_key, chunk_idx, file_idx):
+    tmp_dir = Path(tempfile.mkdtemp(dir=root))
     # Create a text file with the list of files to concatenate
-    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
-        temp_file_path = f.name
+    path_concat_video_files = tmp_dir / "concat_video_files.txt"
+    with open(path_concat_video_files, "w") as f:
         for ep_path in paths_to_cat:
             f.write(f"file '{str(ep_path)}'\n")

-    output_path = new_root / DEFAULT_VIDEO_PATH.format(
-        video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
-    )
-    output_path.parent.mkdir(parents=True, exist_ok=True)
+    path_tmp_output = tmp_dir / "tmp_output.mp4"
     command = [
         "ffmpeg",
         "-y",
@@ -174,13 +173,19 @@ def concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx):
         "-safe",
         "0",
         "-i",
-        str(temp_file_path),
+        str(path_concat_video_files),
         "-c",
         "copy",
-        str(output_path),
+        str(path_tmp_output),
     ]
     subprocess.run(command, check=True)
-    Path(temp_file_path).unlink()
+
+    output_path = root / DEFAULT_VIDEO_PATH.format(
+        video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
+    )
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    shutil.move(str(path_tmp_output), str(output_path))
+    shutil.rmtree(str(tmp_dir))


 def get_video_duration_in_s(mp4_file: Path):
@@ -192,7 +197,7 @@ def get_video_duration_in_s(mp4_file: Path):
         "format=duration",
         "-of",
         "default=noprint_wrappers=1:nokey=1",
-        mp4_file,
+        str(mp4_file),
     ]
     result = subprocess.run(
         command,
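
For reference, the concatenation in `concat_video_files` still goes through ffmpeg's concat demuxer with stream copy; roughly, its work amounts to the following standalone sketch (paths are illustrative and ffmpeg must be on PATH):

    import subprocess
    import tempfile
    from pathlib import Path

    # Hypothetical inputs: two mp4 segments to concatenate without re-encoding.
    paths_to_cat = [Path("file-000.mp4"), Path("episode_007.mp4")]
    tmp_dir = Path(tempfile.mkdtemp())

    # The concat demuxer reads a text file with one "file '<path>'" entry per line.
    list_file = tmp_dir / "concat_video_files.txt"
    list_file.write_text("".join(f"file '{p}'\n" for p in paths_to_cat))

    # "-c copy" stream-copies audio/video, so the concat is fast and lossless.
    output = tmp_dir / "tmp_output.mp4"
    command = ["ffmpeg", "-y", "-f", "concat", "-safe", "0",
               "-i", str(list_file), "-c", "copy", str(output)]
    subprocess.run(command, check=True)
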

View File

@@ -235,10 +235,10 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key):
         ep_duration_in_s = get_video_duration_in_s(ep_path)
         ep_metadata = {
             "episode_index": ep_idx,
-            f"{video_key}/chunk_index": chunk_idx,
-            f"{video_key}/file_index": file_idx,
-            f"{video_key}/from_timestamp": duration_in_s,
-            f"{video_key}/to_timestamp": duration_in_s + ep_duration_in_s,
+            f"videos/{video_key}/chunk_index": chunk_idx,
+            f"videos/{video_key}/file_index": file_idx,
+            f"videos/{video_key}/from_timestamp": duration_in_s,
+            f"videos/{video_key}/to_timestamp": duration_in_s + ep_duration_in_s,
         }
         size_in_mb += ep_size_in_mb
         duration_in_s += ep_duration_in_s

View File

@@ -235,8 +235,8 @@ def episodes_factory(tasks_factory, stats_factory):
         }
         if video_keys is not None:
             for video_key in video_keys:
-                d[f"{video_key}/chunk_index"] = []
-                d[f"{video_key}/file_index"] = []
+                d[f"videos/{video_key}/chunk_index"] = []
+                d[f"videos/{video_key}/file_index"] = []
         for stats_key in flatten_dict({"stats": stats_factory(features)}):
             d[stats_key] = []
@@ -261,8 +261,8 @@ def episodes_factory(tasks_factory, stats_factory):
             if video_keys is not None:
                 for video_key in video_keys:
-                    d[f"{video_key}/chunk_index"].append(0)
-                    d[f"{video_key}/file_index"].append(0)
+                    d[f"videos/{video_key}/chunk_index"].append(0)
+                    d[f"videos/{video_key}/file_index"].append(0)

             # Add stats columns like "stats/action/max"
             for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():