Fix (Now loading all frames is possible)

This commit is contained in:
Remi Cadene 2025-04-14 14:47:18 +00:00
parent 6c4d122198
commit c2a05a1fde
4 changed files with 85 additions and 56 deletions

View File

@ -36,7 +36,6 @@ from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.common.datasets.utils import (
DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH,
DEFAULT_FEATURES,
DEFAULT_IMAGE_PATH,
@ -48,12 +47,10 @@ from lerobot.common.datasets.utils import (
create_lerobot_dataset_card,
embed_images,
flatten_dict,
get_chunk_file_indices,
get_delta_indices,
get_features_from_robot,
get_hf_dataset_size_in_mb,
get_hf_features_from_features,
get_latest_video_path,
get_parquet_num_frames,
get_safe_version,
get_video_duration_in_s,
@ -143,8 +140,8 @@ class LeRobotDatasetMetadata:
return Path(fpath)
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
chunk_idx = self.episodes[f"{vid_key}/chunk_index"][ep_index]
file_idx = self.episodes[f"{vid_key}/file_index"][ep_index]
chunk_idx = self.episodes[f"videos/{vid_key}/chunk_index"][ep_index]
file_idx = self.episodes[f"videos/{vid_key}/file_index"][ep_index]
fpath = self.video_path.format(video_key=vid_key, chunk_index=chunk_idx, file_index=file_idx)
return Path(fpath)
@ -243,7 +240,7 @@ class LeRobotDatasetMetadata:
self.tasks = pd.DataFrame({"task_index": task_indices}, index=tasks)
else:
new_tasks = [task for task in tasks if task not in self.tasks.index]
new_task_indices = range(len(self.tasks), len(new_tasks))
new_task_indices = range(len(self.tasks), len(self.tasks) + len(new_tasks))
for task_idx, task in zip(new_task_indices, new_tasks, strict=False):
self.tasks.loc[task] = task_idx
@ -277,8 +274,13 @@ class LeRobotDatasetMetadata:
df["meta/episodes/file_index"] = [file_idx]
else:
# Retrieve information from the latest parquet file
latest_ep = self.episodes.with_format(columns=["chunk_index", "file_index"])[-1]
chunk_idx, file_idx = latest_ep["chunk_index"], latest_ep["file_index"]
latest_ep = self.episodes.with_format(
columns=["meta/episodes/chunk_index", "meta/episodes/file_index"]
)[-1]
chunk_idx, file_idx = (
latest_ep["meta/episodes/chunk_index"],
latest_ep["meta/episodes/file_index"],
)
latest_path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
@ -294,7 +296,7 @@ class LeRobotDatasetMetadata:
df["meta/episodes/chunk_index"] = [chunk_idx]
df["meta/episodes/file_index"] = [file_idx]
latest_df = pd.read_parquet(latest_path)
latest_df = pd.concat([latest_df, df], ignore_index=True)
df = pd.concat([latest_df, df], ignore_index=True)
# Write the resulting dataframe from RAM to disk
path = self.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
@ -314,14 +316,6 @@ class LeRobotDatasetMetadata:
episode_stats: dict[str, dict],
episode_metadata: dict,
) -> None:
# Update info
self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
if len(self.video_keys) > 0:
self.update_video_info()
write_info(self.info, self.root)
episode_dict = {
"episode_index": episode_index,
"tasks": episode_tasks,
@ -331,6 +325,14 @@ class LeRobotDatasetMetadata:
episode_dict.update(flatten_dict({"stats": episode_stats}))
self._save_episode_metadata(episode_dict)
# Update info
self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
if len(self.video_keys) > 0:
self.update_video_info()
write_info(self.info, self.root)
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
# TODO: write stats
@ -747,7 +749,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Episodes are stored sequentially on a single mp4 to reduce the number of files.
# Thus we load the start timestamp of the episode on this mp4 and,
# shift the query timestamp accordingly.
from_timestamp = self.meta.episodes[f"{vid_key}/from_timestamp"][ep_idx]
from_timestamp = self.meta.episodes[f"videos/{vid_key}/from_timestamp"][ep_idx]
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
@ -977,7 +979,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
latest_ep = self.meta.episodes.with_format(columns=["data/chunk_index", "data/file_index"])[-1]
chunk_idx, file_idx = latest_ep["data/chunk_index"], latest_ep["data/file_index"]
latest_path = self.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
latest_size_in_mb = get_parquet_file_size_in_mb(latest_path)
latest_num_frames = get_parquet_num_frames(latest_path)
@ -1018,30 +1020,52 @@ class LeRobotDataset(torch.utils.data.Dataset):
ep_size_in_mb = get_video_size_in_mb(ep_path)
ep_duration_in_s = get_video_duration_in_s(ep_path)
# Access latest video file information
latest_path = get_latest_video_path(self.root / "videos", video_key)
if self.meta.episodes is None:
# Initialize indices for a new dataset made of the first episode data
chunk_idx, file_idx = 0, 0
latest_duration_in_s = 0
new_path = self.root / self.meta.video_path.format(
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
)
new_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(ep_path), str(new_path))
else:
# Retrieve information from the latest video file
latest_ep = self.meta.episodes.with_format(
columns=[f"videos/{video_key}/chunk_index", f"videos/{video_key}/file_index"]
)[-1]
chunk_idx, file_idx = (
latest_ep[f"videos/{video_key}/chunk_index"],
latest_ep[f"videos/{video_key}/file_index"],
)
latest_path = self.root / self.meta.video_path.format(
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
)
latest_size_in_mb = get_video_size_in_mb(latest_path)
latest_duration_in_s = get_video_duration_in_s(latest_path)
chunk_idx, file_idx = get_chunk_file_indices(latest_path)
if latest_size_in_mb + ep_size_in_mb >= self.meta.files_size_in_mb:
# Move temporary episode video to a new video file in the dataset
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
new_path = self.meta.video_path.format(
new_path = self.root / self.meta.video_path.format(
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
)
new_path.parent.mkdir(parents=True, exist_ok=True)
ep_path.replace(new_path)
shutil.move(str(ep_path), str(new_path))
else:
# Update latest video file
concat_video_files([latest_path, ep_path], self.root, video_key, chunk_idx, file_idx)
# Remove temporary directory
shutil.rmtree(str(ep_path.parent))
metadata = {
"episode_index": episode_index,
f"{video_key}/chunk_index": chunk_idx,
f"{video_key}/file_index": file_idx,
f"{video_key}/from_timestamp": latest_duration_in_s,
f"{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
f"videos/{video_key}/chunk_index": chunk_idx,
f"videos/{video_key}/file_index": file_idx,
f"videos/{video_key}/from_timestamp": latest_duration_in_s,
f"videos/{video_key}/to_timestamp": latest_duration_in_s + ep_duration_in_s,
}
return metadata
@ -1099,7 +1123,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since video encoding with ffmpeg is already using multithreading.
"""
temp_path = Path(tempfile.mkdtemp()) / f"{video_key}_{episode_index:3d}.mp4"
temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4"
img_dir = self._get_image_file_dir(episode_index, video_key)
encode_video_frames(img_dir, temp_path, self.fps, overwrite=True)
return temp_path

View File

@ -17,6 +17,7 @@ import contextlib
import importlib.resources
import json
import logging
import shutil
import subprocess
import tempfile
from collections.abc import Iterator
@ -155,17 +156,15 @@ def get_video_size_in_mb(mp4_path: Path):
return file_size_mb
def concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx):
def concat_video_files(paths_to_cat, root, video_key, chunk_idx, file_idx):
tmp_dir = Path(tempfile.mkdtemp(dir=root))
# Create a text file with the list of files to concatenate
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") as f:
temp_file_path = f.name
path_concat_video_files = tmp_dir / "concat_video_files.txt"
with open(path_concat_video_files, "w") as f:
for ep_path in paths_to_cat:
f.write(f"file '{str(ep_path)}'\n")
output_path = new_root / DEFAULT_VIDEO_PATH.format(
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
)
output_path.parent.mkdir(parents=True, exist_ok=True)
path_tmp_output = tmp_dir / "tmp_output.mp4"
command = [
"ffmpeg",
"-y",
@ -174,13 +173,19 @@ def concat_video_files(paths_to_cat, new_root, video_key, chunk_idx, file_idx):
"-safe",
"0",
"-i",
str(temp_file_path),
str(path_concat_video_files),
"-c",
"copy",
str(output_path),
str(path_tmp_output),
]
subprocess.run(command, check=True)
Path(temp_file_path).unlink()
output_path = root / DEFAULT_VIDEO_PATH.format(
video_key=video_key, chunk_index=chunk_idx, file_index=file_idx
)
output_path.parent.mkdir(parents=True, exist_ok=True)
shutil.move(str(path_tmp_output), str(output_path))
shutil.rmtree(str(tmp_dir))
def get_video_duration_in_s(mp4_file: Path):
@ -192,7 +197,7 @@ def get_video_duration_in_s(mp4_file: Path):
"format=duration",
"-of",
"default=noprint_wrappers=1:nokey=1",
mp4_file,
str(mp4_file),
]
result = subprocess.run(
command,

View File

@ -235,10 +235,10 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key):
ep_duration_in_s = get_video_duration_in_s(ep_path)
ep_metadata = {
"episode_index": ep_idx,
f"{video_key}/chunk_index": chunk_idx,
f"{video_key}/file_index": file_idx,
f"{video_key}/from_timestamp": duration_in_s,
f"{video_key}/to_timestamp": duration_in_s + ep_duration_in_s,
f"videos/{video_key}/chunk_index": chunk_idx,
f"videos/{video_key}/file_index": file_idx,
f"videos/{video_key}/from_timestamp": duration_in_s,
f"videos/{video_key}/to_timestamp": duration_in_s + ep_duration_in_s,
}
size_in_mb += ep_size_in_mb
duration_in_s += ep_duration_in_s

View File

@ -235,8 +235,8 @@ def episodes_factory(tasks_factory, stats_factory):
}
if video_keys is not None:
for video_key in video_keys:
d[f"{video_key}/chunk_index"] = []
d[f"{video_key}/file_index"] = []
d[f"videos/{video_key}/chunk_index"] = []
d[f"videos/{video_key}/file_index"] = []
for stats_key in flatten_dict({"stats": stats_factory(features)}):
d[stats_key] = []
@ -261,8 +261,8 @@ def episodes_factory(tasks_factory, stats_factory):
if video_keys is not None:
for video_key in video_keys:
d[f"{video_key}/chunk_index"].append(0)
d[f"{video_key}/file_index"].append(0)
d[f"videos/{video_key}/chunk_index"].append(0)
d[f"videos/{video_key}/file_index"].append(0)
# Add stats columns like "stats/action/max"
for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():