Add local_files_only, encode_videos, fix bugs to pass tests (WIP)

Simon Alibert 2024-10-22 19:57:52 +02:00
parent e991a31061
commit a805458c7e
4 changed files with 183 additions and 80 deletions
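A minimal sketch of how the new local_files_only flag is meant to be used, based on the constructor change and the replay() call in this diff (not part of the commit; repo_id, root and episode indices are illustrative, and the import path is assumed from the file layout):

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # Load an already-recorded dataset entirely from disk: the Hub version lookup is
    # skipped and snapshot_download resolves files from the local cache only.
    dataset = LeRobotDataset(
        "my_user/my_dataset",            # illustrative repo_id
        root="data/my_user/my_dataset",  # illustrative local root
        episodes=[0, 1],                 # optional subset of episodes
        download_videos=False,
        local_files_only=True,
    )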

lerobot/common/datasets/lerobot_dataset.py

@@ -17,6 +17,7 @@ import json
 import logging
 import os
 import shutil
+from functools import cached_property
 from pathlib import Path
 from typing import Callable
@@ -30,20 +31,32 @@ from huggingface_hub import snapshot_download, upload_folder
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
 from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.utils import (
+    EPISODES_PATH,
+    INFO_PATH,
+    TASKS_PATH,
     append_jsonl,
     check_delta_timestamps,
     check_timestamps_sync,
+    check_version_compatibility,
     create_branch,
     create_empty_dataset_info,
+    flatten_dict,
     get_delta_indices,
     get_episode_data_index,
     get_hub_safe_version,
     hf_transform_to_torch,
-    load_metadata,
+    load_episode_dicts,
+    load_info,
+    load_stats,
+    load_tasks,
     unflatten_dict,
     write_json,
 )
-from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
+from lerobot.common.datasets.video_utils import (
+    VideoFrame,
+    decode_video_frames_torchvision,
+    encode_video_frames,
+)
 from lerobot.common.robot_devices.robots.utils import Robot
 
 # For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
@@ -61,6 +74,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         delta_timestamps: dict[list[float]] | None = None,
         tolerance_s: float = 1e-4,
         download_videos: bool = True,
+        local_files_only: bool = False,
         video_backend: str | None = None,
         image_writer: ImageWriter | None = None,
     ):
@@ -162,21 +176,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.delta_timestamps = delta_timestamps
         self.episodes = episodes
         self.tolerance_s = tolerance_s
-        self.download_videos = download_videos
         self.video_backend = video_backend if video_backend is not None else "pyav"
         self.image_writer = image_writer
         self.delta_indices = None
         self.consolidated = True
         self.episode_buffer = {}
+        self.local_files_only = local_files_only
 
         # Load metadata
         self.root.mkdir(exist_ok=True, parents=True)
-        self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
         self.pull_from_repo(allow_patterns="meta/")
-        self.info, self.episode_dicts, self.stats, self.tasks = load_metadata(self.root)
+        self.info = load_info(self.root)
+        self.stats = load_stats(self.root)
+        self.tasks = load_tasks(self.root)
+        self.episode_dicts = load_episode_dicts(self.root)
 
+        # Check version
+        check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
 
         # Load actual data
-        self.download_episodes()
+        self.download_episodes(download_videos)
         self.hf_dataset = self.load_hf_dataset()
         self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
@@ -199,6 +218,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # - [ ] Update episode_index (arg update=True)
         # - [ ] Update info.json (arg update=True)
 
+    @cached_property
+    def _hub_version(self) -> str | None:
+        return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
+
+    @property
+    def _version(self) -> str:
+        """Codebase version used to create this dataset."""
+        return self.info["codebase_version"]
+
     def push_to_repo(self, push_videos: bool = True) -> None:
         if not self.consolidated:
             raise RuntimeError(
@@ -225,13 +253,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
         snapshot_download(
             self.repo_id,
             repo_type="dataset",
-            revision=self._version,
+            revision=self._hub_version,
             local_dir=self.root,
             allow_patterns=allow_patterns,
             ignore_patterns=ignore_patterns,
+            local_files_only=self.local_files_only,
         )
 
-    def download_episodes(self) -> None:
+    def download_episodes(self, download_videos: bool = True) -> None:
         """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
         will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
         dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
@@ -240,10 +269,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # TODO(rcadene, aliberts): implement faster transfer
         # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
         files = None
-        ignore_patterns = None if self.download_videos else "videos/"
+        ignore_patterns = None if download_videos else "videos/"
         if self.episodes is not None:
             files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes]
-            if len(self.video_keys) > 0 and self.download_videos:
+            if len(self.video_keys) > 0 and download_videos:
                 video_files = [
                     self.get_video_file_path(ep_idx, vid_key)
                     for vid_key in self.video_keys
@@ -495,7 +524,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             item = {**video_frames, **item}
 
         if self.image_transforms is not None:
-            image_keys = self.camera_keys if self.download_videos else self.image_keys
+            image_keys = self.camera_keys
             for cam in image_keys:
                 item[cam] = self.image_transforms(item[cam])
@@ -521,6 +550,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             "timestamp": [],
             "next.done": [],
             **{key: [] for key in self.keys},
+            **{key: [] for key in self.image_keys},
         }
 
     def add_frame(self, frame: dict) -> None:
@@ -553,6 +583,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 image=frame[cam_key],
                 file_path=img_path,
             )
+            if cam_key in self.image_keys:
+                self.episode_buffer[cam_key].append(str(img_path))
 
     def add_episode(self, task: str, encode_videos: bool = False) -> None:
         """
@@ -574,6 +606,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.episode_buffer["next.done"][-1] = True
 
         for key in self.episode_buffer:
+            if key in self.image_keys:
+                continue
             if key in self.keys:
                 self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
             elif key == "episode_index":
@@ -583,11 +617,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
             else:
                 self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
 
+        self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length)
         self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
         self._save_episode_table(episode_index)
 
-        if encode_videos:
-            pass  # TODO
+        if encode_videos and len(self.video_keys) > 0:
+            self.encode_videos()
 
         # Reset the buffer
         self.episode_buffer = self._create_episode_buffer()
@@ -614,7 +649,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 "task_index": task_index,
                 "task": task,
             }
-            append_jsonl(task_dict, self.root / "meta/tasks.jsonl")
+            append_jsonl(task_dict, self.root / TASKS_PATH)
 
         chunk = self.get_episode_chunk(episode_index)
         if chunk >= self.total_chunks:
@@ -622,22 +657,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
         self.info["total_videos"] += len(self.video_keys)
-        write_json(self.info, self.root / "meta/info.json")
+        write_json(self.info, self.root / INFO_PATH)
 
         episode_dict = {
             "episode_index": episode_index,
             "tasks": [task],
             "length": episode_length,
         }
-        append_jsonl(episode_dict, self.root / "meta/episodes.jsonl")
+        self.episode_dicts.append(episode_dict)
+        append_jsonl(episode_dict, self.root / EPISODES_PATH)
 
     def delete_episode(self) -> None:
         episode_index = self.episode_buffer["episode_index"]
         if self.image_writer is not None:
             for cam_key in self.camera_keys:
-                cam_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
-                if cam_dir.is_dir():
-                    shutil.rmtree(cam_dir)
+                img_dir = self.image_writer.get_episode_dir(episode_index, cam_key, return_str=False)
+                if img_dir.is_dir():
+                    shutil.rmtree(img_dir)
 
         # Reset the buffer
         self.episode_buffer = self._create_episode_buffer()
@@ -653,27 +689,54 @@ class LeRobotDataset(torch.utils.data.Dataset):
             updated_file_name = self.get_data_file_path(ep_idx)
             current_file_name.rename(updated_file_name)
 
+    def _remove_image_writer(self) -> None:
+        if self.image_writer is not None:
+            self.image_writer = None
+
+    def encode_videos(self) -> None:
+        # Use ffmpeg to convert frames stored as png into mp4 videos
+        for episode_index in range(self.num_episodes):
+            for key in self.video_keys:
+                # TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
+                # to call self.image_writer here
+                tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
+                video_path = self.get_video_file_path(episode_index, key, return_str=False)
+                if video_path.is_file():
+                    # Skip if video is already encoded. Could be the case when resuming data recording.
+                    continue
+                # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
+                # since video encoding with ffmpeg is already using multithreading.
+                encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
+                shutil.rmtree(tmp_imgs_dir)
+
     def consolidate(self, run_compute_stats: bool = True) -> None:
         self._update_data_file_names()
+        self.hf_dataset = self.load_hf_dataset()
+        self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
+        check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
+
+        if len(self.video_keys) > 0:
+            self.encode_videos()
+
         if run_compute_stats:
             logging.info("Computing dataset statistics")
-            self.hf_dataset = self.load_hf_dataset()
+            self._remove_image_writer()
             self.stats = compute_stats(self)
-            serialized_stats = {key: value.tolist() for key, value in self.stats.items()}
+            serialized_stats = flatten_dict(self.stats)
+            serialized_stats = {key: value.tolist() for key, value in serialized_stats.items()}
             serialized_stats = unflatten_dict(serialized_stats)
             write_json(serialized_stats, self.root / "meta/stats.json")
+            self.consolidated = True
         else:
             logging.warning("Skipping computation of the dataset statistics.")
-        self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
+            # TODO(aliberts)
+            pass  # TODO
 
         # Sanity checks:
         # - [ ] shapes
         # - [ ] ep_lenghts
         # - [ ] number of files
         # - [ ] names of files (e.g. parquet 00000-of-00001 and 00001-of-00002)
         # - [ ] no remaining self.image_writer.dir
-        self.consolidated = True
 
     @classmethod
     def create(
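The new encode_videos() method above delegates the PNG-to-MP4 conversion to encode_video_frames() from video_utils (whose signature is also touched further down in this diff). As a rough sketch of what that conversion amounts to with ffmpeg — only the vcodec and pix_fmt defaults come from this diff, the frame-filename pattern and remaining flags are assumptions:

    import subprocess
    from pathlib import Path

    def encode_pngs_to_mp4(imgs_dir: Path, video_path: Path, fps: int) -> None:
        # Assemble the png frames written by the ImageWriter into a single mp4 file.
        video_path.parent.mkdir(parents=True, exist_ok=True)
        subprocess.run(
            [
                "ffmpeg",
                "-y",                                    # overwrite, mirroring overwrite=True
                "-framerate", str(fps),                  # frames were captured at the dataset fps
                "-i", str(imgs_dir / "frame_%06d.png"),  # assumed frame naming pattern
                "-vcodec", "libsvtav1",                  # default vcodec in encode_video_frames
                "-pix_fmt", "yuv420p",                   # default pix_fmt in encode_video_frames
                str(video_path),
            ],
            check=True,
        )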
@@ -691,7 +754,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
         obj.root = root if root is not None else LEROBOT_HOME / repo_id
-        obj._version = CODEBASE_VERSION
         obj.tolerance_s = tolerance_s
         obj.image_writer = image_writer
@@ -702,21 +764,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )
         obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
-        obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
-        write_json(obj.info, obj.root / "meta/info.json")
+        obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot, use_videos)
+        write_json(obj.info, obj.root / INFO_PATH)
 
         # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
         obj.episode_buffer = obj._create_episode_buffer()
 
-        # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk.
-        # It is used to know when certain operations are need (for instance, computing dataset statistics).
-        # In order to be able to push the dataset to the hub, it needs to be consolidation first.
+        # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
+        # is used to know when certain operations are need (for instance, computing dataset statistics). In
+        # order to be able to push the dataset to the hub, it needs to be consolidated first by calling
+        # self.consolidate().
         obj.consolidated = True
+
+        obj.local_files_only = True
+        obj.download_videos = False
         obj.episodes = None
         obj.hf_dataset = None
         obj.image_transforms = None
         obj.delta_timestamps = None
+        obj.delta_indices = None
         obj.episode_data_index = None
         obj.video_backend = video_backend if video_backend is not None else "pyav"
         return obj
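Taken together, the methods in this file sketch the intended recording flow. A hedged outline (not part of the commit; robot, image_writer and the frame dicts are assumed to come from the recording script shown further down, and the task text is illustrative):

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset.create("my_user/my_dataset", fps=30, robot=robot, image_writer=image_writer)

    for frame in recorded_frames:                 # hypothetical iterable of observation/action dicts
        dataset.add_frame(frame)
    dataset.add_episode(task="pick up the cube", encode_videos=True)

    dataset.consolidate(run_compute_stats=True)   # reloads data, checks timestamps, encodes videos, computes stats
    dataset.push_to_repo()                        # raises if the dataset was not consolidated first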

lerobot/common/datasets/utils.py

@@ -30,6 +30,12 @@ from torchvision import transforms
 from lerobot.common.robot_devices.robots.utils import Robot
 
 DEFAULT_CHUNK_SIZE = 1000  # Max number of episodes per chunk
+
+INFO_PATH = "meta/info.json"
+EPISODES_PATH = "meta/episodes.jsonl"
+STATS_PATH = "meta/stats.json"
+TASKS_PATH = "meta/tasks.jsonl"
+
 DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
 DEFAULT_PARQUET_PATH = (
     "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
@@ -104,6 +110,32 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     return items_dict
 
 
+def _get_major_minor(version: str) -> tuple[int]:
+    split = version.strip("v").split(".")
+    return int(split[0]), int(split[1])
+
+
+def check_version_compatibility(
+    repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
+) -> None:
+    current_major, _ = _get_major_minor(current_version)
+    major_to_check, _ = _get_major_minor(version_to_check)
+    if major_to_check < current_major and enforce_breaking_major:
+        raise ValueError(
+            f"""The dataset you requested ({repo_id}) is in {version_to_check} format. We introduced a new
+            format with v2.0 that is not backward compatible. Please use our conversion script
+            first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
+        )
+    elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
+        warnings.warn(
+            f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
+            codebase. The current codebase version is {current_version}. You should be fine since
+            backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
+            Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
+            stacklevel=1,
+        )
+
+
 def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
     num_version = float(version.strip("v"))
     if num_version < 2 and enforce_v2:
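A quick illustration of how the new check_version_compatibility() behaves for a few hypothetical version strings (the repo_id is illustrative):

    # Raises ValueError: v1.x datasets must first be converted with convert_dataset_v1_to_v2.py.
    check_version_compatibility("my_user/my_dataset", "v1.6", "v2.0")

    # Only warns: same major version, older minor version, backward compatibility is kept.
    check_version_compatibility("my_user/my_dataset", "v2.0", "v2.1")

    # Silent: versions match.
    check_version_compatibility("my_user/my_dataset", "v2.1", "v2.1")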
@@ -131,30 +163,28 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
     return version
 
 
-def load_metadata(local_dir: Path) -> tuple[dict | list]:
-    """Loads metadata files from a dataset."""
-    info_path = local_dir / "meta/info.json"
-    episodes_path = local_dir / "meta/episodes.jsonl"
-    stats_path = local_dir / "meta/stats.json"
-    tasks_path = local_dir / "meta/tasks.jsonl"
-
-    with open(info_path) as f:
-        info = json.load(f)
-
-    with jsonlines.open(episodes_path, "r") as reader:
-        episode_dicts = list(reader)
-
-    with open(stats_path) as f:
-        stats = json.load(f)
-
-    with jsonlines.open(tasks_path, "r") as reader:
-        tasks = list(reader)
-
-    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
-    stats = unflatten_dict(stats)
-    tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
-
-    return info, episode_dicts, stats, tasks
+def load_info(local_dir: Path) -> dict:
+    with open(local_dir / INFO_PATH) as f:
+        return json.load(f)
+
+
+def load_stats(local_dir: Path) -> dict:
+    with open(local_dir / STATS_PATH) as f:
+        stats = json.load(f)
+    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
+    return unflatten_dict(stats)
+
+
+def load_tasks(local_dir: Path) -> dict:
+    with jsonlines.open(local_dir / TASKS_PATH, "r") as reader:
+        tasks = list(reader)
+    return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+
+
+def load_episode_dicts(local_dir: Path) -> dict:
+    with jsonlines.open(local_dir / EPISODES_PATH, "r") as reader:
+        return list(reader)
 
 
 def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use_videos: bool = True) -> dict:
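For reference, the two jsonl files read by load_tasks() and load_episode_dicts() hold one JSON object per line, matching the dicts written by add_episode() in lerobot_dataset.py (values are illustrative):

    meta/tasks.jsonl
        {"task_index": 0, "task": "pick up the cube"}

    meta/episodes.jsonl
        {"episode_index": 0, "tasks": ["pick up the cube"], "length": 412}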
@@ -229,7 +259,7 @@ def check_timestamps_sync(
     # Track original indices before masking
     original_indices = torch.arange(len(diffs))
     filtered_indices = original_indices[mask]
-    outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance).squeeze()
+    outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance)  # .squeeze()
     outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
     episode_indices = torch.stack(hf_dataset["episode_index"])

lerobot/common/datasets/video_utils.py

@@ -126,8 +126,8 @@ def decode_video_frames_torchvision(
 def encode_video_frames(
-    imgs_dir: Path,
-    video_path: Path,
+    imgs_dir: Path | str,
+    video_path: Path | str,
     fps: int,
     vcodec: str = "libsvtav1",
     pix_fmt: str = "yuv420p",

lerobot/scripts/control_robot.py

@@ -194,19 +194,17 @@ def record(
     pretrained_policy_name_or_path: str | None = None,
     policy_overrides: List[str] | None = None,
     fps: int | None = None,
-    warmup_time_s=2,
-    episode_time_s=10,
-    reset_time_s=5,
-    num_episodes=50,
-    video=True,
-    run_compute_stats=True,
-    push_to_hub=True,
-    tags=None,
-    num_image_writer_processes=0,
-    num_image_writer_threads_per_camera=4,
-    force_override=False,
-    display_cameras=True,
-    play_sounds=True,
+    warmup_time_s: int | float = 2,
+    episode_time_s: int | float = 10,
+    reset_time_s: int | float = 5,
+    num_episodes: int = 50,
+    video: bool = True,
+    run_compute_stats: bool = True,
+    push_to_hub: bool = True,
+    num_image_writer_processes: int = 0,
+    num_image_writer_threads_per_camera: int = 4,
+    display_cameras: bool = True,
+    play_sounds: bool = True,
 ) -> LeRobotDataset:
     # TODO(rcadene): Add option to record logs
     listener = None
@@ -234,12 +232,18 @@
     # Create empty dataset or load existing saved episodes
     sanity_check_dataset_name(repo_id, policy)
-    image_writer = ImageWriter(
-        write_dir=root,
-        num_processes=num_image_writer_processes,
-        num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
-    )
-    dataset = LeRobotDataset.create(repo_id, fps, robot, root=root, image_writer=image_writer)
+    if len(robot.cameras) > 0:
+        image_writer = ImageWriter(
+            write_dir=root,
+            num_processes=num_image_writer_processes,
+            num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
+        )
+    else:
+        image_writer = None
+
+    dataset = LeRobotDataset.create(
+        repo_id, fps, robot, root=root, image_writer=image_writer, use_videos=video
+    )
 
     if not robot.is_connected:
         robot.connect()
@@ -307,8 +311,9 @@
     log_say("Stop recording", play_sounds, blocking=True)
     stop_recording(robot, listener, display_cameras)
 
-    logging.info("Waiting for image writer to terminate...")
-    dataset.image_writer.stop()
+    if dataset.image_writer is not None:
+        logging.info("Waiting for image writer to terminate...")
+        dataset.image_writer.stop()
 
     dataset.consolidate(run_compute_stats)
@@ -322,27 +327,28 @@
 
 @safe_disconnect
 def replay(
-    robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
+    robot: Robot,
+    root: Path,
+    repo_id: str,
+    episode: int,
+    fps: int | None = None,
+    play_sounds: bool = True,
+    local_files_only: bool = True,
 ):
     # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
     # TODO(rcadene): Add option to record logs
-    local_dir = Path(root) / repo_id
-    if not local_dir.exists():
-        raise ValueError(local_dir)
-
-    dataset = LeRobotDataset(repo_id, root=root)
-    items = dataset.hf_dataset.select_columns("action")
-    from_idx = dataset.episode_data_index["from"][episode].item()
-    to_idx = dataset.episode_data_index["to"][episode].item()
+    dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
+    actions = dataset.hf_dataset.select_columns("action")
 
     if not robot.is_connected:
         robot.connect()
 
     log_say("Replaying episode", play_sounds, blocking=True)
-    for idx in range(from_idx, to_idx):
+    for idx in range(dataset.num_samples):
         start_episode_t = time.perf_counter()
-        action = items[idx]["action"]
+        action = actions[idx]["action"]
         robot.send_action(action)
 
         dt_s = time.perf_counter() - start_episode_t
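With the new replay() signature, replaying a locally recorded episode looks roughly like this (argument values are illustrative):

    replay(
        robot,
        root="data",                   # dataset root on disk
        repo_id="my_user/my_dataset",  # illustrative repo_id
        episode=0,
        fps=30,
        local_files_only=True,         # read the episode from disk instead of the Hub
    )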