Add local_files_only, encode_videos, fix bugs to pass tests (WIP)
commit a805458c7e (parent e991a31061)
@@ -17,6 +17,7 @@ import json
import logging
import os
import shutil
from functools import cached_property
from pathlib import Path
from typing import Callable
@@ -30,20 +31,32 @@ from huggingface_hub import snapshot_download, upload_folder
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
from lerobot.common.datasets.image_writer import ImageWriter
from lerobot.common.datasets.utils import (
    EPISODES_PATH,
    INFO_PATH,
    TASKS_PATH,
    append_jsonl,
    check_delta_timestamps,
    check_timestamps_sync,
    check_version_compatibility,
    create_branch,
    create_empty_dataset_info,
    flatten_dict,
    get_delta_indices,
    get_episode_data_index,
    get_hub_safe_version,
    hf_transform_to_torch,
    load_metadata,
    load_episode_dicts,
    load_info,
    load_stats,
    load_tasks,
    unflatten_dict,
    write_json,
)
from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
from lerobot.common.datasets.video_utils import (
    VideoFrame,
    decode_video_frames_torchvision,
    encode_video_frames,
)
from lerobot.common.robot_devices.robots.utils import Robot

# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
@@ -61,6 +74,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        delta_timestamps: dict[list[float]] | None = None,
        tolerance_s: float = 1e-4,
        download_videos: bool = True,
        local_files_only: bool = False,
        video_backend: str | None = None,
        image_writer: ImageWriter | None = None,
    ):
@@ -162,21 +176,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
        self.delta_timestamps = delta_timestamps
        self.episodes = episodes
        self.tolerance_s = tolerance_s
        self.download_videos = download_videos
        self.video_backend = video_backend if video_backend is not None else "pyav"
        self.image_writer = image_writer
        self.delta_indices = None
        self.consolidated = True
        self.episode_buffer = {}
        self.local_files_only = local_files_only

        # Load metadata
        self.root.mkdir(exist_ok=True, parents=True)
        self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
        self.pull_from_repo(allow_patterns="meta/")
        self.info, self.episode_dicts, self.stats, self.tasks = load_metadata(self.root)
        self.info = load_info(self.root)
        self.stats = load_stats(self.root)
        self.tasks = load_tasks(self.root)
        self.episode_dicts = load_episode_dicts(self.root)

        # Check version
        check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)

        # Load actual data
        self.download_episodes()
        self.download_episodes(download_videos)
        self.hf_dataset = self.load_hf_dataset()
        self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
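With the new local_files_only flag, a dataset already on disk can be opened without going back to the Hub. A hedged usage sketch (repo id and root are placeholders, and the import path assumes the class lives in lerobot.common.datasets.lerobot_dataset):

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # Resolve everything from local files: no Hub version lookup, and
    # snapshot_download runs with local_files_only=True; video files are skipped.
    dataset = LeRobotDataset(
        "user/my_dataset",            # hypothetical repo_id
        root="data/user/my_dataset",  # hypothetical local root
        download_videos=False,
        local_files_only=True,
    )
    print(dataset.num_episodes, dataset.fps)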
@@ -199,6 +218,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
        # - [ ] Update episode_index (arg update=True)
        # - [ ] Update info.json (arg update=True)

    @cached_property
    def _hub_version(self) -> str | None:
        return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)

    @property
    def _version(self) -> str:
        """Codebase version used to create this dataset."""
        return self.info["codebase_version"]

    def push_to_repo(self, push_videos: bool = True) -> None:
        if not self.consolidated:
            raise RuntimeError(
@@ -225,13 +253,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
        snapshot_download(
            self.repo_id,
            repo_type="dataset",
            revision=self._version,
            revision=self._hub_version,
            local_dir=self.root,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            local_files_only=self.local_files_only,
        )

    def download_episodes(self) -> None:
    def download_episodes(self, download_videos: bool = True) -> None:
        """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
        will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
        dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
@@ -240,10 +269,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
        # TODO(rcadene, aliberts): implement faster transfer
        # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
        files = None
        ignore_patterns = None if self.download_videos else "videos/"
        ignore_patterns = None if download_videos else "videos/"
        if self.episodes is not None:
            files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes]
            if len(self.video_keys) > 0 and self.download_videos:
            if len(self.video_keys) > 0 and download_videos:
                video_files = [
                    self.get_video_file_path(ep_idx, vid_key)
                    for vid_key in self.video_keys
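A hedged sketch of the reworked selective download (the repo id is a placeholder; episodes is assumed to be a constructor argument, as the assignment self.episodes = episodes above suggests):

    # Fetch only the parquet files for episodes 0 and 5; "videos/" is ignored entirely.
    dataset = LeRobotDataset("user/my_dataset", episodes=[0, 5], download_videos=False)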
@@ -495,7 +524,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        item = {**video_frames, **item}

        if self.image_transforms is not None:
            image_keys = self.camera_keys if self.download_videos else self.image_keys
            image_keys = self.camera_keys
            for cam in image_keys:
                item[cam] = self.image_transforms(item[cam])
@@ -521,6 +550,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
            "timestamp": [],
            "next.done": [],
            **{key: [] for key in self.keys},
            **{key: [] for key in self.image_keys},
        }

    def add_frame(self, frame: dict) -> None:
@@ -553,6 +583,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
                image=frame[cam_key],
                file_path=img_path,
            )
            if cam_key in self.image_keys:
                self.episode_buffer[cam_key].append(str(img_path))

    def add_episode(self, task: str, encode_videos: bool = False) -> None:
        """
@@ -574,6 +606,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
        self.episode_buffer["next.done"][-1] = True

        for key in self.episode_buffer:
            if key in self.image_keys:
                continue
            if key in self.keys:
                self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
            elif key == "episode_index":
@@ -583,11 +617,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
            else:
                self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])

        self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length)
        self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
        self._save_episode_table(episode_index)

        if encode_videos:
            pass  # TODO
        if encode_videos and len(self.video_keys) > 0:
            self.encode_videos()

        # Reset the buffer
        self.episode_buffer = self._create_episode_buffer()
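A hedged sketch of the per-episode recording flow with the new encode_videos flag (the frame dict and task string are placeholders):

    # Frames are buffered (and images written via the ImageWriter) one control step at a time.
    dataset.add_frame(frame)   # frame: dict holding camera images, state/action keys, etc.

    # At the end of the episode, flush the buffer to parquet/metadata; with encode_videos=True,
    # the buffered png frames for the video keys are encoded to mp4 right away.
    dataset.add_episode(task="pick up the cube", encode_videos=True)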
@@ -614,7 +649,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
            "task_index": task_index,
            "task": task,
        }
        append_jsonl(task_dict, self.root / "meta/tasks.jsonl")
        append_jsonl(task_dict, self.root / TASKS_PATH)

        chunk = self.get_episode_chunk(episode_index)
        if chunk >= self.total_chunks:
@@ -622,22 +657,23 @@ class LeRobotDataset(torch.utils.data.Dataset):

        self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
        self.info["total_videos"] += len(self.video_keys)
        write_json(self.info, self.root / "meta/info.json")
        write_json(self.info, self.root / INFO_PATH)

        episode_dict = {
            "episode_index": episode_index,
            "tasks": [task],
            "length": episode_length,
        }
        append_jsonl(episode_dict, self.root / "meta/episodes.jsonl")
        self.episode_dicts.append(episode_dict)
        append_jsonl(episode_dict, self.root / EPISODES_PATH)

    def delete_episode(self) -> None:
        episode_index = self.episode_buffer["episode_index"]
        if self.image_writer is not None:
            for cam_key in self.camera_keys:
                cam_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
                if cam_dir.is_dir():
                    shutil.rmtree(cam_dir)
                img_dir = self.image_writer.get_episode_dir(episode_index, cam_key, return_str=False)
                if img_dir.is_dir():
                    shutil.rmtree(img_dir)

        # Reset the buffer
        self.episode_buffer = self._create_episode_buffer()
@@ -653,27 +689,54 @@ class LeRobotDataset(torch.utils.data.Dataset):
            updated_file_name = self.get_data_file_path(ep_idx)
            current_file_name.rename(updated_file_name)

    def _remove_image_writer(self) -> None:
        if self.image_writer is not None:
            self.image_writer = None

    def encode_videos(self) -> None:
        # Use ffmpeg to convert frames stored as png into mp4 videos
        for episode_index in range(self.num_episodes):
            for key in self.video_keys:
                # TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
                # to call self.image_writer here
                tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
                video_path = self.get_video_file_path(episode_index, key, return_str=False)
                if video_path.is_file():
                    # Skip if video is already encoded. Could be the case when resuming data recording.
                    continue
                # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
                # since video encoding with ffmpeg is already using multithreading.
                encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
                shutil.rmtree(tmp_imgs_dir)

    def consolidate(self, run_compute_stats: bool = True) -> None:
        self._update_data_file_names()
        self.hf_dataset = self.load_hf_dataset()
        self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
        check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)

        if len(self.video_keys) > 0:
            self.encode_videos()

        if run_compute_stats:
            logging.info("Computing dataset statistics")
            self.hf_dataset = self.load_hf_dataset()
            self._remove_image_writer()
            self.stats = compute_stats(self)
            serialized_stats = {key: value.tolist() for key, value in self.stats.items()}
            serialized_stats = flatten_dict(self.stats)
            serialized_stats = {key: value.tolist() for key, value in serialized_stats.items()}
            serialized_stats = unflatten_dict(serialized_stats)
            write_json(serialized_stats, self.root / "meta/stats.json")
            self.consolidated = True
        else:
            logging.warning("Skipping computation of the dataset statistics.")

        self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
        pass  # TODO
        # TODO(aliberts)
        # Sanity checks:
        # - [ ] shapes
        # - [ ] ep_lenghts
        # - [ ] number of files
        # - [ ] names of files (e.g. parquet 00000-of-00001 and 00001-of-00002)
        # - [ ] no remaining self.image_writer.dir
        self.consolidated = True

    @classmethod
    def create(
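A hedged sketch of the consolidation step that must happen before uploading (method names follow the diff above):

    # Rename parquet files, encode any remaining videos, compute statistics into meta/stats.json,
    # and mark the instance as consolidated; only then can the dataset be pushed.
    dataset.consolidate(run_compute_stats=True)
    dataset.push_to_repo(push_videos=True)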
@@ -691,7 +754,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
        obj = cls.__new__(cls)
        obj.repo_id = repo_id
        obj.root = root if root is not None else LEROBOT_HOME / repo_id
        obj._version = CODEBASE_VERSION
        obj.tolerance_s = tolerance_s
        obj.image_writer = image_writer
@@ -702,21 +764,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
        )

        obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
        obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
        write_json(obj.info, obj.root / "meta/info.json")
        obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot, use_videos)
        write_json(obj.info, obj.root / INFO_PATH)

        # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
        obj.episode_buffer = obj._create_episode_buffer()

        # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk.
        # It is used to know when certain operations are need (for instance, computing dataset statistics).
        # In order to be able to push the dataset to the hub, it needs to be consolidation first.
        # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
        # is used to know when certain operations are need (for instance, computing dataset statistics). In
        # order to be able to push the dataset to the hub, it needs to be consolidated first by calling
        # self.consolidate().
        obj.consolidated = True

        obj.local_files_only = True
        obj.download_videos = False

        obj.episodes = None
        obj.hf_dataset = None
        obj.image_transforms = None
        obj.delta_timestamps = None
        obj.delta_indices = None
        obj.episode_data_index = None
        obj.video_backend = video_backend if video_backend is not None else "pyav"
        return obj

@@ -30,6 +30,12 @@ from torchvision import transforms
from lerobot.common.robot_devices.robots.utils import Robot

DEFAULT_CHUNK_SIZE = 1000  # Max number of episodes per chunk

INFO_PATH = "meta/info.json"
EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
TASKS_PATH = "meta/tasks.jsonl"

DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = (
    "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
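A hedged illustration of how these path templates expand (the video key and indices are made up):

    DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key="observation.images.cam_high", episode_index=25)
    # -> "videos/chunk-000/observation.images.cam_high/episode_000025.mp4"

    DEFAULT_PARQUET_PATH.format(episode_chunk=0, episode_index=25, total_episodes=50)
    # -> "data/chunk-000/train-00025-of-00050.parquet"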
@@ -104,6 +110,32 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
    return items_dict


def _get_major_minor(version: str) -> tuple[int]:
    split = version.strip("v").split(".")
    return int(split[0]), int(split[1])


def check_version_compatibility(
    repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
) -> None:
    current_major, _ = _get_major_minor(current_version)
    major_to_check, _ = _get_major_minor(version_to_check)
    if major_to_check < current_major and enforce_breaking_major:
        raise ValueError(
            f"""The dataset you requested ({repo_id}) is in {version_to_check} format. We introduced a new
            format with v2.0 that is not backward compatible. Please use our conversion script
            first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
        )
    elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
        warnings.warn(
            f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
            codebase. The current codebase version is {current_version}. You should be fine since
            backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
            Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
            stacklevel=1,
        )


def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
    num_version = float(version.strip("v"))
    if num_version < 2 and enforce_v2:
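A hedged sketch of how the version helpers behave (the repo id is a placeholder):

    _get_major_minor("v1.6")  # -> (1, 6)
    _get_major_minor("v2.0")  # -> (2, 0)

    # Older major than the running codebase: raises ValueError and points to the conversion script.
    check_version_compatibility("user/my_dataset", version_to_check="v1.6", current_version="v2.0")

    # Same major, older minor: only warns, loading proceeds.
    check_version_compatibility("user/my_dataset", version_to_check="v2.0", current_version="v2.1")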
@@ -131,30 +163,28 @@ get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
    return version


def load_metadata(local_dir: Path) -> tuple[dict | list]:
    """Loads metadata files from a dataset."""
    info_path = local_dir / "meta/info.json"
    episodes_path = local_dir / "meta/episodes.jsonl"
    stats_path = local_dir / "meta/stats.json"
    tasks_path = local_dir / "meta/tasks.jsonl"
def load_info(local_dir: Path) -> dict:
    with open(local_dir / INFO_PATH) as f:
        return json.load(f)

    with open(info_path) as f:
        info = json.load(f)

    with jsonlines.open(episodes_path, "r") as reader:
        episode_dicts = list(reader)

    with open(stats_path) as f:
def load_stats(local_dir: Path) -> dict:
    with open(local_dir / STATS_PATH) as f:
        stats = json.load(f)
    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
    return unflatten_dict(stats)

    with jsonlines.open(tasks_path, "r") as reader:

def load_tasks(local_dir: Path) -> dict:
    with jsonlines.open(local_dir / TASKS_PATH, "r") as reader:
        tasks = list(reader)

    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
    stats = unflatten_dict(stats)
    tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
    return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}

    return info, episode_dicts, stats, tasks


def load_episode_dicts(local_dir: Path) -> dict:
    with jsonlines.open(local_dir / EPISODES_PATH, "r") as reader:
        return list(reader)


def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use_videos: bool = True) -> dict:
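A hedged sketch of the split-out loaders that together replace load_metadata() (the directory is a placeholder):

    from pathlib import Path

    local_dir = Path("data/user/my_dataset")          # hypothetical dataset root
    info = load_info(local_dir)                       # dict parsed from meta/info.json
    stats = load_stats(local_dir)                     # nested dict of torch tensors from meta/stats.json
    tasks = load_tasks(local_dir)                     # {task_index: task_string} from meta/tasks.jsonl
    episode_dicts = load_episode_dicts(local_dir)     # list of dicts from meta/episodes.jsonl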
@@ -229,7 +259,7 @@ def check_timestamps_sync(
    # Track original indices before masking
    original_indices = torch.arange(len(diffs))
    filtered_indices = original_indices[mask]
    outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance).squeeze()
    outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance)  # .squeeze()
    outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
    episode_indices = torch.stack(hf_dataset["episode_index"])

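Dropping .squeeze() presumably fixes the case where exactly one timestamp is out of tolerance: squeezing then collapses the index tensor to 0-d. A minimal illustration of the shape difference:

    import torch

    mask = torch.tensor([True, False, True])  # a single element outside tolerance
    torch.nonzero(~mask).shape                # torch.Size([1, 1]) -> still usable as an index batch
    torch.nonzero(~mask).squeeze().shape      # torch.Size([])     -> 0-d scalar, shape is lost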
@@ -126,8 +126,8 @@ def decode_video_frames_torchvision(


def encode_video_frames(
    imgs_dir: Path,
    video_path: Path,
    imgs_dir: Path | str,
    video_path: Path | str,
    fps: int,
    vcodec: str = "libsvtav1",
    pix_fmt: str = "yuv420p",

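With the widened type hints, plain strings should now be accepted alongside Path objects. A hedged usage sketch (paths and fps are made up; the overwrite keyword mirrors the call in LeRobotDataset.encode_videos above):

    from lerobot.common.datasets.video_utils import encode_video_frames

    # Encode a directory of png frames into an mp4 at 30 fps.
    encode_video_frames("tmp_images/episode_000000/cam_high", "videos/episode_000000_cam_high.mp4", 30, overwrite=True)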
@@ -194,19 +194,17 @@ def record(
    pretrained_policy_name_or_path: str | None = None,
    policy_overrides: List[str] | None = None,
    fps: int | None = None,
    warmup_time_s=2,
    episode_time_s=10,
    reset_time_s=5,
    num_episodes=50,
    video=True,
    run_compute_stats=True,
    push_to_hub=True,
    tags=None,
    num_image_writer_processes=0,
    num_image_writer_threads_per_camera=4,
    force_override=False,
    display_cameras=True,
    play_sounds=True,
    warmup_time_s: int | float = 2,
    episode_time_s: int | float = 10,
    reset_time_s: int | float = 5,
    num_episodes: int = 50,
    video: bool = True,
    run_compute_stats: bool = True,
    push_to_hub: bool = True,
    num_image_writer_processes: int = 0,
    num_image_writer_threads_per_camera: int = 4,
    display_cameras: bool = True,
    play_sounds: bool = True,
) -> LeRobotDataset:
    # TODO(rcadene): Add option to record logs
    listener = None
@@ -234,12 +232,18 @@ def record(

    # Create empty dataset or load existing saved episodes
    sanity_check_dataset_name(repo_id, policy)
    image_writer = ImageWriter(
        write_dir=root,
        num_processes=num_image_writer_processes,
        num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
    if len(robot.cameras) > 0:
        image_writer = ImageWriter(
            write_dir=root,
            num_processes=num_image_writer_processes,
            num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
        )
    else:
        image_writer = None

    dataset = LeRobotDataset.create(
        repo_id, fps, robot, root=root, image_writer=image_writer, use_videos=video
    )
    dataset = LeRobotDataset.create(repo_id, fps, robot, root=root, image_writer=image_writer)

    if not robot.is_connected:
        robot.connect()
@@ -307,8 +311,9 @@ def record(
    log_say("Stop recording", play_sounds, blocking=True)
    stop_recording(robot, listener, display_cameras)

    logging.info("Waiting for image writer to terminate...")
    dataset.image_writer.stop()
    if dataset.image_writer is not None:
        logging.info("Waiting for image writer to terminate...")
        dataset.image_writer.stop()

    dataset.consolidate(run_compute_stats)
@@ -322,27 +327,28 @@

@safe_disconnect
def replay(
    robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
    robot: Robot,
    root: Path,
    repo_id: str,
    episode: int,
    fps: int | None = None,
    play_sounds: bool = True,
    local_files_only: bool = True,
):
    # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
    # TODO(rcadene): Add option to record logs
    local_dir = Path(root) / repo_id
    if not local_dir.exists():
        raise ValueError(local_dir)

    dataset = LeRobotDataset(repo_id, root=root)
    items = dataset.hf_dataset.select_columns("action")
    from_idx = dataset.episode_data_index["from"][episode].item()
    to_idx = dataset.episode_data_index["to"][episode].item()
    dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
    actions = dataset.hf_dataset.select_columns("action")

    if not robot.is_connected:
        robot.connect()

    log_say("Replaying episode", play_sounds, blocking=True)
    for idx in range(from_idx, to_idx):
    for idx in range(dataset.num_samples):
        start_episode_t = time.perf_counter()

        action = items[idx]["action"]
        action = actions[idx]["action"]
        robot.send_action(action)

        dt_s = time.perf_counter() - start_episode_t
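A hedged sketch of the new replay() signature in use (assumes a connected robot object; paths and ids are placeholders):

    replay(
        robot,
        root=Path("data"),           # hypothetical local root
        repo_id="user/my_dataset",   # hypothetical repo id
        episode=0,
        fps=30,
        local_files_only=True,       # read the episode from disk, no Hub access
    )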