Add local_files_only, encode_videos, fix bugs to pass tests (WIP)

parent e991a31061
commit a805458c7e
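
A minimal sketch of the new loading path, assuming the class keeps its usual lerobot.common.datasets.lerobot_dataset import path; the repo id and root below are placeholders. With local_files_only=True the Hub version lookup is skipped and snapshot_download() is asked to work from the local cache only:

    from pathlib import Path

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # Load one already-recorded episode purely from disk, without touching the Hub.
    dataset = LeRobotDataset(
        "lerobot/my_dataset",                   # placeholder repo id
        root=Path("data/lerobot/my_dataset"),   # placeholder local root
        episodes=[0],
        local_files_only=True,
    )
    print(dataset.num_samples)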
lerobot/common/datasets/lerobot_dataset.py

@@ -17,6 +17,7 @@ import json
 import logging
 import os
 import shutil
+from functools import cached_property
 from pathlib import Path
 from typing import Callable
 
@@ -30,20 +31,32 @@ from huggingface_hub import snapshot_download, upload_folder
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
 from lerobot.common.datasets.image_writer import ImageWriter
 from lerobot.common.datasets.utils import (
+    EPISODES_PATH,
+    INFO_PATH,
+    TASKS_PATH,
     append_jsonl,
     check_delta_timestamps,
     check_timestamps_sync,
+    check_version_compatibility,
     create_branch,
     create_empty_dataset_info,
+    flatten_dict,
     get_delta_indices,
     get_episode_data_index,
     get_hub_safe_version,
     hf_transform_to_torch,
-    load_metadata,
+    load_episode_dicts,
+    load_info,
+    load_stats,
+    load_tasks,
     unflatten_dict,
     write_json,
 )
-from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
+from lerobot.common.datasets.video_utils import (
+    VideoFrame,
+    decode_video_frames_torchvision,
+    encode_video_frames,
+)
 from lerobot.common.robot_devices.robots.utils import Robot
 
 # For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
@@ -61,6 +74,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         delta_timestamps: dict[list[float]] | None = None,
         tolerance_s: float = 1e-4,
         download_videos: bool = True,
+        local_files_only: bool = False,
         video_backend: str | None = None,
         image_writer: ImageWriter | None = None,
     ):
@@ -162,21 +176,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.delta_timestamps = delta_timestamps
         self.episodes = episodes
         self.tolerance_s = tolerance_s
-        self.download_videos = download_videos
         self.video_backend = video_backend if video_backend is not None else "pyav"
         self.image_writer = image_writer
         self.delta_indices = None
         self.consolidated = True
         self.episode_buffer = {}
+        self.local_files_only = local_files_only
 
         # Load metadata
         self.root.mkdir(exist_ok=True, parents=True)
-        self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
         self.pull_from_repo(allow_patterns="meta/")
-        self.info, self.episode_dicts, self.stats, self.tasks = load_metadata(self.root)
+        self.info = load_info(self.root)
+        self.stats = load_stats(self.root)
+        self.tasks = load_tasks(self.root)
+        self.episode_dicts = load_episode_dicts(self.root)
 
+        # Check version
+        check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
+
         # Load actual data
-        self.download_episodes()
+        self.download_episodes(download_videos)
         self.hf_dataset = self.load_hf_dataset()
         self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
 
@@ -199,6 +218,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # - [ ] Update episode_index (arg update=True)
         # - [ ] Update info.json (arg update=True)
 
+    @cached_property
+    def _hub_version(self) -> str | None:
+        return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
+
+    @property
+    def _version(self) -> str:
+        """Codebase version used to create this dataset."""
+        return self.info["codebase_version"]
+
     def push_to_repo(self, push_videos: bool = True) -> None:
         if not self.consolidated:
             raise RuntimeError(
@@ -225,13 +253,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
         snapshot_download(
             self.repo_id,
             repo_type="dataset",
-            revision=self._version,
+            revision=self._hub_version,
             local_dir=self.root,
             allow_patterns=allow_patterns,
             ignore_patterns=ignore_patterns,
+            local_files_only=self.local_files_only,
         )
 
-    def download_episodes(self) -> None:
+    def download_episodes(self, download_videos: bool = True) -> None:
         """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
         will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
         dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
@@ -240,10 +269,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # TODO(rcadene, aliberts): implement faster transfer
         # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
         files = None
-        ignore_patterns = None if self.download_videos else "videos/"
+        ignore_patterns = None if download_videos else "videos/"
         if self.episodes is not None:
             files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes]
-            if len(self.video_keys) > 0 and self.download_videos:
+            if len(self.video_keys) > 0 and download_videos:
                 video_files = [
                     self.get_video_file_path(ep_idx, vid_key)
                     for vid_key in self.video_keys
@@ -495,7 +524,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             item = {**video_frames, **item}
 
         if self.image_transforms is not None:
-            image_keys = self.camera_keys if self.download_videos else self.image_keys
+            image_keys = self.camera_keys
             for cam in image_keys:
                 item[cam] = self.image_transforms(item[cam])
 
@@ -521,6 +550,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             "timestamp": [],
             "next.done": [],
             **{key: [] for key in self.keys},
+            **{key: [] for key in self.image_keys},
         }
 
     def add_frame(self, frame: dict) -> None:
@@ -553,6 +583,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 image=frame[cam_key],
                 file_path=img_path,
             )
+            if cam_key in self.image_keys:
+                self.episode_buffer[cam_key].append(str(img_path))
 
     def add_episode(self, task: str, encode_videos: bool = False) -> None:
         """
@@ -574,6 +606,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.episode_buffer["next.done"][-1] = True
 
         for key in self.episode_buffer:
+            if key in self.image_keys:
+                continue
             if key in self.keys:
                 self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
             elif key == "episode_index":
@@ -583,11 +617,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
             else:
                 self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
 
+        self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length)
         self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
         self._save_episode_table(episode_index)
 
-        if encode_videos:
-            pass  # TODO
+        if encode_videos and len(self.video_keys) > 0:
+            self.encode_videos()
 
         # Reset the buffer
         self.episode_buffer = self._create_episode_buffer()
@@ -614,7 +649,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 "task_index": task_index,
                 "task": task,
            }
-            append_jsonl(task_dict, self.root / "meta/tasks.jsonl")
+            append_jsonl(task_dict, self.root / TASKS_PATH)
 
         chunk = self.get_episode_chunk(episode_index)
         if chunk >= self.total_chunks:
@@ -622,22 +657,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
         self.info["total_videos"] += len(self.video_keys)
-        write_json(self.info, self.root / "meta/info.json")
+        write_json(self.info, self.root / INFO_PATH)
 
         episode_dict = {
             "episode_index": episode_index,
             "tasks": [task],
             "length": episode_length,
         }
-        append_jsonl(episode_dict, self.root / "meta/episodes.jsonl")
+        self.episode_dicts.append(episode_dict)
+        append_jsonl(episode_dict, self.root / EPISODES_PATH)
 
     def delete_episode(self) -> None:
         episode_index = self.episode_buffer["episode_index"]
         if self.image_writer is not None:
             for cam_key in self.camera_keys:
-                cam_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
-                if cam_dir.is_dir():
-                    shutil.rmtree(cam_dir)
+                img_dir = self.image_writer.get_episode_dir(episode_index, cam_key, return_str=False)
+                if img_dir.is_dir():
+                    shutil.rmtree(img_dir)
 
         # Reset the buffer
         self.episode_buffer = self._create_episode_buffer()
@@ -653,27 +689,54 @@ class LeRobotDataset(torch.utils.data.Dataset):
             updated_file_name = self.get_data_file_path(ep_idx)
             current_file_name.rename(updated_file_name)
 
+    def _remove_image_writer(self) -> None:
+        if self.image_writer is not None:
+            self.image_writer = None
+
+    def encode_videos(self) -> None:
+        # Use ffmpeg to convert frames stored as png into mp4 videos
+        for episode_index in range(self.num_episodes):
+            for key in self.video_keys:
+                # TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
+                # to call self.image_writer here
+                tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
+                video_path = self.get_video_file_path(episode_index, key, return_str=False)
+                if video_path.is_file():
+                    # Skip if video is already encoded. Could be the case when resuming data recording.
+                    continue
+                # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
+                # since video encoding with ffmpeg is already using multithreading.
+                encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
+                shutil.rmtree(tmp_imgs_dir)
+
     def consolidate(self, run_compute_stats: bool = True) -> None:
         self._update_data_file_names()
+        self.hf_dataset = self.load_hf_dataset()
+        self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
+        check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
+
+        if len(self.video_keys) > 0:
+            self.encode_videos()
+
         if run_compute_stats:
             logging.info("Computing dataset statistics")
-            self.hf_dataset = self.load_hf_dataset()
+            self._remove_image_writer()
             self.stats = compute_stats(self)
-            serialized_stats = {key: value.tolist() for key, value in self.stats.items()}
+            serialized_stats = flatten_dict(self.stats)
+            serialized_stats = {key: value.tolist() for key, value in serialized_stats.items()}
             serialized_stats = unflatten_dict(serialized_stats)
             write_json(serialized_stats, self.root / "meta/stats.json")
+            self.consolidated = True
         else:
             logging.warning("Skipping computation of the dataset statistics.")
 
-        self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
-        pass  # TODO
+        # TODO(aliberts)
         # Sanity checks:
         # - [ ] shapes
         # - [ ] ep_lenghts
         # - [ ] number of files
         # - [ ] names of files (e.g. parquet 00000-of-00001 and 00001-of-00002)
         # - [ ] no remaining self.image_writer.dir
-        self.consolidated = True
 
     @classmethod
     def create(
@@ -691,7 +754,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
         obj.root = root if root is not None else LEROBOT_HOME / repo_id
-        obj._version = CODEBASE_VERSION
         obj.tolerance_s = tolerance_s
         obj.image_writer = image_writer
 
@@ -702,21 +764,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )
 
         obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
-        obj.info = create_empty_dataset_info(obj._version, fps, robot, use_videos)
-        write_json(obj.info, obj.root / "meta/info.json")
+        obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot, use_videos)
+        write_json(obj.info, obj.root / INFO_PATH)
 
         # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
         obj.episode_buffer = obj._create_episode_buffer()
 
-        # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk.
-        # It is used to know when certain operations are need (for instance, computing dataset statistics).
-        # In order to be able to push the dataset to the hub, it needs to be consolidation first.
+        # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
+        # is used to know when certain operations are need (for instance, computing dataset statistics). In
+        # order to be able to push the dataset to the hub, it needs to be consolidated first by calling
+        # self.consolidate().
         obj.consolidated = True
 
+        obj.local_files_only = True
+        obj.download_videos = False
+
         obj.episodes = None
         obj.hf_dataset = None
         obj.image_transforms = None
         obj.delta_timestamps = None
+        obj.delta_indices = None
         obj.episode_data_index = None
         obj.video_backend = video_backend if video_backend is not None else "pyav"
         return obj

lerobot/common/datasets/utils.py

@@ -30,6 +30,12 @@ from torchvision import transforms
 from lerobot.common.robot_devices.robots.utils import Robot
 
 DEFAULT_CHUNK_SIZE = 1000  # Max number of episodes per chunk
 
+INFO_PATH = "meta/info.json"
+EPISODES_PATH = "meta/episodes.jsonl"
+STATS_PATH = "meta/stats.json"
+TASKS_PATH = "meta/tasks.jsonl"
+
 DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
 DEFAULT_PARQUET_PATH = (
     "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
@@ -104,6 +110,32 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
     return items_dict
 
 
+def _get_major_minor(version: str) -> tuple[int]:
+    split = version.strip("v").split(".")
+    return int(split[0]), int(split[1])
+
+
+def check_version_compatibility(
+    repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
+) -> None:
+    current_major, _ = _get_major_minor(current_version)
+    major_to_check, _ = _get_major_minor(version_to_check)
+    if major_to_check < current_major and enforce_breaking_major:
+        raise ValueError(
+            f"""The dataset you requested ({repo_id}) is in {version_to_check} format. We introduced a new
+            format with v2.0 that is not backward compatible. Please use our conversion script
+            first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
+        )
+    elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
+        warnings.warn(
+            f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
+            codebase. The current codebase version is {current_version}. You should be fine since
+            backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
+            Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
+            stacklevel=1,
+        )
+
+
 def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
     num_version = float(version.strip("v"))
     if num_version < 2 and enforce_v2:
@@ -131,30 +163,28 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
     return version
 
 
-def load_metadata(local_dir: Path) -> tuple[dict | list]:
-    """Loads metadata files from a dataset."""
-    info_path = local_dir / "meta/info.json"
-    episodes_path = local_dir / "meta/episodes.jsonl"
-    stats_path = local_dir / "meta/stats.json"
-    tasks_path = local_dir / "meta/tasks.jsonl"
-
-    with open(info_path) as f:
-        info = json.load(f)
-
-    with jsonlines.open(episodes_path, "r") as reader:
-        episode_dicts = list(reader)
-
-    with open(stats_path) as f:
-        stats = json.load(f)
-
-    with jsonlines.open(tasks_path, "r") as reader:
-        tasks = list(reader)
-
-    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
-    stats = unflatten_dict(stats)
-    tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
-
-    return info, episode_dicts, stats, tasks
+def load_info(local_dir: Path) -> dict:
+    with open(local_dir / INFO_PATH) as f:
+        return json.load(f)
+
+
+def load_stats(local_dir: Path) -> dict:
+    with open(local_dir / STATS_PATH) as f:
+        stats = json.load(f)
+    stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
+    return unflatten_dict(stats)
+
+
+def load_tasks(local_dir: Path) -> dict:
+    with jsonlines.open(local_dir / TASKS_PATH, "r") as reader:
+        tasks = list(reader)
+    return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+
+
+def load_episode_dicts(local_dir: Path) -> dict:
+    with jsonlines.open(local_dir / EPISODES_PATH, "r") as reader:
+        return list(reader)
 
 
 def create_empty_dataset_info(codebase_version: str, fps: int, robot: Robot, use_videos: bool = True) -> dict:
@@ -229,7 +259,7 @@ def check_timestamps_sync(
     # Track original indices before masking
     original_indices = torch.arange(len(diffs))
     filtered_indices = original_indices[mask]
-    outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance).squeeze()
+    outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance)  # .squeeze()
     outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
     episode_indices = torch.stack(hf_dataset["episode_index"])
 

lerobot/common/datasets/video_utils.py

@@ -126,8 +126,8 @@ def decode_video_frames_torchvision(
 
 
 def encode_video_frames(
-    imgs_dir: Path,
-    video_path: Path,
+    imgs_dir: Path | str,
+    video_path: Path | str,
     fps: int,
     vcodec: str = "libsvtav1",
     pix_fmt: str = "yuv420p",

lerobot/scripts/control_robot.py

@@ -194,19 +194,17 @@ def record(
     pretrained_policy_name_or_path: str | None = None,
     policy_overrides: List[str] | None = None,
     fps: int | None = None,
-    warmup_time_s=2,
-    episode_time_s=10,
-    reset_time_s=5,
-    num_episodes=50,
-    video=True,
-    run_compute_stats=True,
-    push_to_hub=True,
-    tags=None,
-    num_image_writer_processes=0,
-    num_image_writer_threads_per_camera=4,
-    force_override=False,
-    display_cameras=True,
-    play_sounds=True,
+    warmup_time_s: int | float = 2,
+    episode_time_s: int | float = 10,
+    reset_time_s: int | float = 5,
+    num_episodes: int = 50,
+    video: bool = True,
+    run_compute_stats: bool = True,
+    push_to_hub: bool = True,
+    num_image_writer_processes: int = 0,
+    num_image_writer_threads_per_camera: int = 4,
+    display_cameras: bool = True,
+    play_sounds: bool = True,
 ) -> LeRobotDataset:
     # TODO(rcadene): Add option to record logs
     listener = None
@@ -234,12 +232,18 @@ def record(
 
     # Create empty dataset or load existing saved episodes
     sanity_check_dataset_name(repo_id, policy)
-    image_writer = ImageWriter(
-        write_dir=root,
-        num_processes=num_image_writer_processes,
-        num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
-    )
-    dataset = LeRobotDataset.create(repo_id, fps, robot, root=root, image_writer=image_writer)
+    if len(robot.cameras) > 0:
+        image_writer = ImageWriter(
+            write_dir=root,
+            num_processes=num_image_writer_processes,
+            num_threads=num_image_writer_threads_per_camera * robot.num_cameras,
+        )
+    else:
+        image_writer = None
+
+    dataset = LeRobotDataset.create(
+        repo_id, fps, robot, root=root, image_writer=image_writer, use_videos=video
+    )
 
     if not robot.is_connected:
         robot.connect()
@@ -307,6 +311,7 @@ def record(
     log_say("Stop recording", play_sounds, blocking=True)
     stop_recording(robot, listener, display_cameras)
 
-    logging.info("Waiting for image writer to terminate...")
-    dataset.image_writer.stop()
+    if dataset.image_writer is not None:
+        logging.info("Waiting for image writer to terminate...")
+        dataset.image_writer.stop()
 
@@ -322,27 +327,28 @@ def record(
 
 @safe_disconnect
 def replay(
-    robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
+    robot: Robot,
+    root: Path,
+    repo_id: str,
+    episode: int,
+    fps: int | None = None,
+    play_sounds: bool = True,
+    local_files_only: bool = True,
 ):
     # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
     # TODO(rcadene): Add option to record logs
-    local_dir = Path(root) / repo_id
-    if not local_dir.exists():
-        raise ValueError(local_dir)
 
-    dataset = LeRobotDataset(repo_id, root=root)
-    items = dataset.hf_dataset.select_columns("action")
-    from_idx = dataset.episode_data_index["from"][episode].item()
-    to_idx = dataset.episode_data_index["to"][episode].item()
+    dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
+    actions = dataset.hf_dataset.select_columns("action")
 
     if not robot.is_connected:
         robot.connect()
 
     log_say("Replaying episode", play_sounds, blocking=True)
-    for idx in range(from_idx, to_idx):
+    for idx in range(dataset.num_samples):
         start_episode_t = time.perf_counter()
 
-        action = items[idx]["action"]
+        action = actions[idx]["action"]
         robot.send_action(action)
 
         dt_s = time.perf_counter() - start_episode_t
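
The new check_version_compatibility helper in lerobot/common/datasets/utils.py can be exercised on its own; a small sketch with made-up version strings and a placeholder dataset name:

    from lerobot.common.datasets.utils import check_version_compatibility

    # Same major version, older minor: emits a warning and returns normally.
    check_version_compatibility("lerobot/my_dataset", version_to_check="v2.0", current_version="v2.1")

    # Older major version: raises ValueError pointing at convert_dataset_v1_to_v2.py.
    try:
        check_version_compatibility("lerobot/my_dataset", version_to_check="v1.6", current_version="v2.0")
    except ValueError as err:
        print(err)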