Add local_files_only, encode_videos, fix bugs to pass tests (WIP)

This commit is contained in:
Simon Alibert 2024-10-22 19:57:52 +02:00
parent e991a31061
commit a805458c7e
4 changed files with 183 additions and 80 deletions

View File

@ -17,6 +17,7 @@ import json
import logging
import os
import shutil
from functools import cached_property
from pathlib import Path
from typing import Callable
@ -30,20 +31,32 @@ from huggingface_hub import snapshot_download, upload_folder
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
from lerobot.common.datasets.image_writer import ImageWriter
from lerobot.common.datasets.utils import (
EPISODES_PATH,
INFO_PATH,
TASKS_PATH,
append_jsonl,
check_delta_timestamps,
check_timestamps_sync,
check_version_compatibility,
create_branch,
create_empty_dataset_info,
flatten_dict,
get_delta_indices,
get_episode_data_index,
get_hub_safe_version,
hf_transform_to_torch,
load_metadata,
load_episode_dicts,
load_info,
load_stats,
load_tasks,
unflatten_dict,
write_json,
)
from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
from lerobot.common.datasets.video_utils import (
VideoFrame,
decode_video_frames_torchvision,
encode_video_frames,
)
from lerobot.common.robot_devices.robots.utils import Robot
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
@ -61,6 +74,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
delta_timestamps: dict[list[float]] | None = None,
tolerance_s: float = 1e-4,
download_videos: bool = True,
local_files_only: bool = False,
video_backend: str | None = None,
image_writer: ImageWriter | None = None,
):
@ -162,21 +176,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
self.download_videos = download_videos
self.video_backend = video_backend if video_backend is not None else "pyav"
self.image_writer = image_writer
self.delta_indices = None
self.consolidated = True
self.episode_buffer = {}
self.local_files_only = local_files_only
# Load metadata
self.root.mkdir(exist_ok=True, parents=True)
self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
self.pull_from_repo(allow_patterns="meta/")
self.info, self.episode_dicts, self.stats, self.tasks = load_metadata(self.root)
self.info = load_info(self.root)
self.stats = load_stats(self.root)
self.tasks = load_tasks(self.root)
self.episode_dicts = load_episode_dicts(self.root)
# Check version
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
# Load actual data
self.download_episodes()
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
@ -199,6 +218,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
# - [ ] Update episode_index (arg update=True)
# - [ ] Update info.json (arg update=True)
@cached_property
def _hub_version(self) -> str | None:
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
@property
def _version(self) -> str:
"""Codebase version used to create this dataset."""
return self.info["codebase_version"]
def push_to_repo(self, push_videos: bool = True) -> None:
if not self.consolidated:
raise RuntimeError(
@ -225,13 +253,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self._version,
revision=self._hub_version,
local_dir=self.root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
local_files_only=self.local_files_only,
)
def download_episodes(self) -> None:
def download_episodes(self, download_videos: bool = True) -> None:
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
@ -240,10 +269,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(rcadene, aliberts): implement faster transfer
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
files = None
ignore_patterns = None if self.download_videos else "videos/"
ignore_patterns = None if download_videos else "videos/"
if self.episodes is not None:
files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes]
if len(self.video_keys) > 0 and self.download_videos:
if len(self.video_keys) > 0 and download_videos:
video_files = [
self.get_video_file_path(ep_idx, vid_key)
for vid_key in self.video_keys
@ -495,7 +524,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = {**video_frames, **item}
if self.image_transforms is not None:
image_keys = self.camera_keys if self.download_videos else self.image_keys
image_keys = self.camera_keys
for cam in image_keys:
item[cam] = self.image_transforms(item[cam])
@ -521,6 +550,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
"timestamp": [],
"next.done": [],
**{key: [] for key in self.keys},
**{key: [] for key in self.image_keys},
}
def add_frame(self, frame: dict) -> None:
@ -553,6 +583,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
image=frame[cam_key],
file_path=img_path,
)
if cam_key in self.image_keys:
self.episode_buffer[cam_key].append(str(img_path))
def add_episode(self, task: str, encode_videos: bool = False) -> None:
"""
@ -574,6 +606,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer["next.done"][-1] = True
for key in self.episode_buffer:
if key in self.image_keys:
continue
if key in self.keys:
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
elif key == "episode_index":
@ -583,11 +617,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length)
self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
self._save_episode_table(episode_index)
if encode_videos:
pass # TODO
if encode_videos and len(self.video_keys) > 0:
self.encode_videos()
# Reset the buffer
self.episode_buffer = self._create_episode_buffer()
@ -614,7 +649,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
"task_index": task_index,
"task": task,
}
append_jsonl(task_dict, self.root / "meta/tasks.jsonl")
append_jsonl(task_dict, self.root / TASKS_PATH)
chunk = self.get_episode_chunk(episode_index)
if chunk >= self.total_chunks:
@ -622,22 +657,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
self.info["total_videos"] += len(self.video_keys)
write_json(self.info, self.root / "meta/info.json")
write_json(self.info, self.root / INFO_PATH)
episode_dict = {
"episode_index": episode_index,
"tasks": [task],
"length": episode_length,
}
append_jsonl(episode_dict, self.root / "meta/episodes.jsonl")
self.episode_dicts.append(episode_dict)
append_jsonl(episode_dict, self.root / EPISODES_PATH)
def delete_episode(self) -> None:
episode_index = self.episode_buffer["episode_index"]
if self.image_writer is not None:
for cam_key in self.camera_keys:
cam_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
if cam_dir.is_dir():
shutil.rmtree(cam_dir)
img_dir = self.image_writer.get_episode_dir(episode_index, cam_key, return_str=False)
if img_dir.is_dir():
shutil.rmtree(img_dir)
# Reset the buffer
self.episode_buffer = self._create_episode_buffer()
@ -653,27 +689,54 @@ class LeRobotDataset(torch.utils.data.Dataset):
updated_file_name = self.get_data_file_path(ep_idx)
current_file_name.rename(updated_file_name)
def _remove_image_writer(self) -> None:
if self.image_writer is not None:
self.image_writer = None
def encode_videos(self) -> None:
# Use ffmpeg to convert frames stored as png into mp4 videos
for episode_index in range(self.num_episodes):
for key in self.video_keys:
# TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
# to call self.image_writer here
tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
video_path = self.get_video_file_path(episode_index, key, return_str=False)
if video_path.is_file():
# Skip if video is already encoded. Could be the case when resuming data recording.
continue
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
# since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
shutil.rmtree(tmp_imgs_dir)
def consolidate(self, run_compute_stats: bool = True) -> None:
self._update_data_file_names()
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
if len(self.video_keys) > 0:
self.encode_videos()
if run_compute_stats:
logging.info("Computing dataset statistics")
self.hf_dataset = self.load_hf_dataset()
self._remove_image_writer()
self.stats = compute_stats(self)
serialized_stats = {key: value.tolist() for key, value in self.stats.items()}
serialized_stats = flatten_dict(self.stats)
serialized_stats = {key: value.tolist() for key, value in serialized_stats.items()}
serialized_stats = unflatten_dict(serialized_stats)
write_json(serialized_stats, self.root / "meta/stats.json")
self.consolidated = True
else:
logging.warning("Skipping computation of the dataset statistics.")
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
pass # TODO
# TODO(aliberts)
# Sanity checks:
# - [ ] shapes
# - [ ] ep_lenghts
# - [ ] number of files
# - [ ] names of files (e.g. parquet 00000-of-00001 and 00001-of-00002)
# - [ ] no remaining self.image_writer.dir
self.consolidated = True
@classmethod
def create(
@ -691,7 +754,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = root if root is not None else LEROBOT_HOME / repo_id
obj._version = CODEBASE_VERSION
obj.tolerance_s = tolerance_s
obj.image_writer = image_writer
@ -702,21 +764,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
write_json(obj.info, obj.root / "meta/info.json")
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot, use_videos)
write_json(obj.info, obj.root / INFO_PATH)
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj._create_episode_buffer()
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk.
# It is used to know when certain operations are need (for instance, computing dataset statistics).
# In order to be able to push the dataset to the hub, it needs to be consolidation first.
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
# is used to know when certain operations are need (for instance, computing dataset statistics). In
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling
# self.consolidate().
obj.consolidated = True
obj.local_files_only = True
obj.download_videos = False
obj.episodes = None
obj.hf_dataset = None
obj.image_transforms = None
obj.delta_timestamps = None
obj.delta_indices = None
obj.episode_data_index = None
obj.video_backend = video_backend if video_backend is not None else "pyav"
return obj

View File

@ -30,6 +30,12 @@ from torchvision import transforms
from lerobot.common.robot_devices.robots.utils import Robot
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
INFO_PATH = "meta/info.json"
EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = (
"data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
@ -104,6 +110,32 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
return items_dict
def _get_major_minor(version: str) -> tuple[int]:
split = version.strip("v").split(".")
return int(split[0]), int(split[1])
def check_version_compatibility(
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
) -> None:
current_major, _ = _get_major_minor(current_version)
major_to_check, _ = _get_major_minor(version_to_check)
if major_to_check < current_major and enforce_breaking_major:
raise ValueError(
f"""The dataset you requested ({repo_id}) is in {version_to_check} format. We introduced a new
format with v2.0 that is not backward compatible. Please use our conversion script
first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
)
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
warnings.warn(
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
codebase. The current codebase version is {current_version}. You should be fine since
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
stacklevel=1,
)
def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
num_version = float(version.strip("v"))
if num_version < 2 and enforce_v2:
@ -131,30 +163,28 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
return version
def load_metadata(local_dir: Path) -> tuple[dict | list]:
"""Loads metadata files from a dataset."""
info_path = local_dir / "meta/info.json"
episodes_path = local_dir / "meta/episodes.jsonl"
stats_path = local_dir / "meta/stats.json"
tasks_path = local_dir / "meta/tasks.jsonl"
def load_info(local_dir: Path) -> dict:
with open(local_dir / INFO_PATH) as f:
return json.load(f)
with open(info_path) as f:
info = json.load(f)
with jsonlines.open(episodes_path, "r") as reader:
episode_dicts = list(reader)
with open(stats_path) as f:
def load_stats(local_dir: Path) -> dict:
with open(local_dir / STATS_PATH) as f:
stats = json.load(f)
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
with jsonlines.open(tasks_path, "r") as reader:
def load_tasks(local_dir: Path) -> dict:
with jsonlines.open(local_dir / TASKS_PATH, "r") as reader:
tasks = list(reader)
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
stats = unflatten_dict(stats)
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
return info, episode_dicts, stats, tasks
def load_episode_dicts(local_dir: Path) -> dict:
with jsonlines.open(local_dir / EPISODES_PATH, "r") as reader:
return list(reader)
def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use_videos: bool = True) -> dict:
@ -229,7 +259,7 @@ def check_timestamps_sync(
# Track original indices before masking
original_indices = torch.arange(len(diffs))
filtered_indices = original_indices[mask]
outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance).squeeze()
outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance) # .squeeze()
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
episode_indices = torch.stack(hf_dataset["episode_index"])

View File

@ -126,8 +126,8 @@ def decode_video_frames_torchvision(
def encode_video_frames(
imgs_dir: Path,
video_path: Path,
imgs_dir: Path | str,
video_path: Path | str,
fps: int,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",

View File

@ -194,19 +194,17 @@ def record(
pretrained_policy_name_or_path: str | None = None,
policy_overrides: List[str] | None = None,
fps: int | None = None,
warmup_time_s=2,
episode_time_s=10,
reset_time_s=5,
num_episodes=50,
video=True,
run_compute_stats=True,
push_to_hub=True,
tags=None,
num_image_writer_processes=0,
num_image_writer_threads_per_camera=4,
force_override=False,
display_cameras=True,
play_sounds=True,
warmup_time_s: int | float = 2,
episode_time_s: int | float = 10,
reset_time_s: int | float = 5,
num_episodes: int = 50,
video: bool = True,
run_compute_stats: bool = True,
push_to_hub: bool = True,
num_image_writer_processes: int = 0,
num_image_writer_threads_per_camera: int = 4,
display_cameras: bool = True,
play_sounds: bool = True,
) -> LeRobotDataset:
# TODO(rcadene): Add option to record logs
listener = None
@ -234,12 +232,18 @@ def record(
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
image_writer = ImageWriter(
write_dir=root,
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
if len(robot.cameras) > 0:
image_writer = ImageWriter(
write_dir=root,
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
)
else:
image_writer = None
dataset = LeRobotDataset.create(
repo_id, fps, robot, root=root, image_writer=image_writer, use_videos=video
)
dataset = LeRobotDataset.create(repo_id, fps, robot, root=root, image_writer=image_writer)
if not robot.is_connected:
robot.connect()
@ -307,8 +311,9 @@ def record(
log_say("Stop recording", play_sounds, blocking=True)
stop_recording(robot, listener, display_cameras)
logging.info("Waiting for image writer to terminate...")
dataset.image_writer.stop()
if dataset.image_writer is not None:
logging.info("Waiting for image writer to terminate...")
dataset.image_writer.stop()
dataset.consolidate(run_compute_stats)
@ -322,27 +327,28 @@ def record(
@safe_disconnect
def replay(
robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
robot: Robot,
root: Path,
repo_id: str,
episode: int,
fps: int | None = None,
play_sounds: bool = True,
local_files_only: bool = True,
):
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
# TODO(rcadene): Add option to record logs
local_dir = Path(root) / repo_id
if not local_dir.exists():
raise ValueError(local_dir)
dataset = LeRobotDataset(repo_id, root=root)
items = dataset.hf_dataset.select_columns("action")
from_idx = dataset.episode_data_index["from"][episode].item()
to_idx = dataset.episode_data_index["to"][episode].item()
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
actions = dataset.hf_dataset.select_columns("action")
if not robot.is_connected:
robot.connect()
log_say("Replaying episode", play_sounds, blocking=True)
for idx in range(from_idx, to_idx):
for idx in range(dataset.num_samples):
start_episode_t = time.perf_counter()
action = items[idx]["action"]
action = actions[idx]["action"]
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t