Add download_metadata, move default paths

Simon Alibert 2024-10-18 14:48:34 +02:00
parent e7355ba595
commit 1a51505ec6
2 changed files with 68 additions and 64 deletions

lerobot/common/datasets/lerobot_dataset.py

@ -31,9 +31,7 @@ from lerobot.common.datasets.utils import (
get_episode_data_index,
get_hub_safe_version,
load_hf_dataset,
load_info,
load_stats,
load_tasks,
load_metadata,
)
from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
@ -41,6 +39,12 @@ from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_
CODEBASE_VERSION = "v2.0"
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
DEFAULT_CHUNK_SIZE = 1000
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
DEFAULT_PARQUET_PATH = (
"data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
)
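For illustration, this is how the two new path templates resolve once formatted; the camera key used below is a hypothetical example, not taken from this commit:

    DEFAULT_VIDEO_PATH.format(
        episode_chunk=0, video_key="observation.images.cam_high", episode_index=42
    )
    # -> "videos/chunk-000/observation.images.cam_high/episode_000042.mp4"
    DEFAULT_PARQUET_PATH.format(episode_chunk=0, episode_index=42, total_episodes=100)
    # -> "data/chunk-000/train-00042-of-00100.parquet"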
class LeRobotDataset(torch.utils.data.Dataset):
def __init__(
@ -70,7 +74,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
Instantiating this class with this 'repo_id' will download the dataset from that address and load
it, provided your dataset is compliant with codebase_version v2.0. If your dataset was created
before this new format, you will be prompted to convert it using our conversion script from v1.6
to v2.0, which you can find at [TODO(aliberts): move conversion script & add location here].
to v2.0, which you can find at lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py.
2. Your dataset already exists on your local disk in the 'root' folder:
This is typically the case when you recorded your dataset locally and you may or may not have
@ -139,7 +143,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
timestamps is separated from the next by 1/fps +/- tolerance_s. This also applies to frames
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
multiples of 1/fps. Defaults to 1e-4.
download_videos (bool, optional): Flag to download the videos. Defaults to True.
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to
True.
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
"""
@ -157,9 +163,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Load metadata
self.root.mkdir(exist_ok=True, parents=True)
self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
self.info = load_info(repo_id, self._version, self.root)
self.stats = load_stats(repo_id, self._version, self.root)
self.tasks = load_tasks(repo_id, self._version, self.root)
self.download_metadata()
self.info, self.episode_dicts, self.stats, self.tasks = load_metadata(self.root)
# Load actual data
self.download_episodes()
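As a rough sketch of the new metadata flow (the file names under meta/ below are assumptions, not confirmed by this diff), load_metadata could simply read the files that download_metadata fetched into the local root:

    import json
    from pathlib import Path

    def load_metadata(local_dir: Path) -> tuple[dict, list[dict], dict, dict]:
        """Hypothetical sketch: file names under meta/ are assumed, not confirmed."""
        meta = local_dir / "meta"
        info = json.loads((meta / "info.json").read_text())
        episode_dicts = json.loads((meta / "episodes.json").read_text())
        stats = json.loads((meta / "stats.json").read_text())
        tasks = json.loads((meta / "tasks.json").read_text())
        return info, episode_dicts, stats, tasks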
@ -185,6 +190,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
# - [ ] Update episode_index (arg update=True)
# - [ ] Update info.json (arg update=True)
def download_metadata(self) -> None:
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self._version,
local_dir=self.root,
allow_patterns="meta/",
)
def download_episodes(self) -> None:
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
@ -227,11 +241,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""Formattable string for the video files."""
return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
@property
def episode_dicts(self) -> list[dict]:
"""List of dictionary containing information for each episode, indexed by episode_index."""
return self.info["episodes"]
@property
def fps(self) -> int:
"""Frames per second used during data collection."""
@ -254,7 +263,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property
def camera_keys(self) -> list[str]:
"""Keys to access image and video streams from cameras (regardless of their storage method)."""
"""Keys to access visual modalities (regardless of their storage method)."""
return self.image_keys + self.video_keys
@property
@ -277,6 +286,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""Total number of episodes available."""
return self.info["total_episodes"]
@property
def total_chunks(self) -> int:
"""Total number of chunks (groups of episodes)."""
return self.info["total_chunks"]
@property
def chunks_size(self) -> int:
"""Max number of episodes per chunk."""
return self.info["chunks_size"]
@property
def shapes(self) -> dict:
"""Shapes for the different features."""
@ -397,42 +416,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
@classmethod
def from_preloaded(
def create(
cls,
repo_id: str = "from_preloaded",
repo_id: str,
root: Path | None = None,
split: str = "train",
transform: callable = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
# additional preloaded attributes
hf_dataset=None,
episode_data_index=None,
stats=None,
info=None,
videos_dir=None,
video_backend=None,
tolerance_s: float = 1e-4,
video_backend: str | None = None,
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.
It is especially useful when converting raw data into LeRobotDataset before saving the dataset
on the filesystem or uploading to the hub.
Note: Meta-data attributes like `repo_id`, `version`, `root`, etc are optional and potentially
meaningless depending on the downstream usage of the return dataset.
"""
"""Create a LeRobot Dataset from scratch in order to record data."""
# create an empty object of type LeRobotDataset
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = root
obj.split = split
obj.image_transforms = transform
obj.delta_timestamps = delta_timestamps
obj.hf_dataset = hf_dataset
obj.episode_data_index = episode_data_index
obj.stats = stats
obj.info = info if info is not None else {}
obj.videos_dir = videos_dir
obj.video_backend = video_backend if video_backend is not None else "pyav"
obj.root = root if root is not None else LEROBOT_HOME / repo_id
# obj.episodes = None
# obj.image_transforms = None
# obj.delta_timestamps = None
# obj.episode_data_index = episode_data_index
# obj.stats = stats
# obj.info = info if info is not None else {}
# obj.videos_dir = videos_dir
# obj.video_backend = video_backend if video_backend is not None else "pyav"
return obj
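A hedged sketch of how the renamed classmethod might be called; since most attributes are still commented out at this stage, the returned object is not yet fully usable for recording (argument values are illustrative):

    dataset = LeRobotDataset.create(
        repo_id="my_user/my_new_dataset",  # hypothetical repo_id
        root=None,                         # falls back to LEROBOT_HOME / repo_id
    )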

lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py

@ -120,6 +120,11 @@ from huggingface_hub.errors import EntryNotFoundError
from PIL import Image
from safetensors.torch import load_file
from lerobot.common.datasets.lerobot_dataset import (
DEFAULT_CHUNK_SIZE,
DEFAULT_PARQUET_PATH,
DEFAULT_VIDEO_PATH,
)
from lerobot.common.datasets.utils import create_branch, flatten_dict, get_hub_safe_version, unflatten_dict
from lerobot.common.utils.utils import init_hydra_config
from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
@ -127,15 +132,8 @@ from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
V16 = "v1.6"
V20 = "v2.0"
EPISODE_CHUNK_SIZE = 1000
GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
PARQUET_CHUNK_PATH = (
"data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
)
VIDEO_CHUNK_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
@ -269,15 +267,15 @@ def split_parquet_by_episodes(
table = dataset.remove_columns(keys["video"])._data.table
episode_lengths = []
for ep_chunk in range(total_chunks):
ep_chunk_start = EPISODE_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(EPISODE_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
chunk_dir = "/".join(PARQUET_CHUNK_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table))
output_file = output_dir / PARQUET_CHUNK_PATH.format(
output_file = output_dir / DEFAULT_PARQUET_PATH.format(
episode_chunk=ep_chunk, episode_index=ep_idx, total_episodes=total_episodes
)
pq.write_table(ep_table, output_file)
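The chunk_dir expression above keeps only the directory part of the template before formatting it; a small sketch of what it evaluates to (episode_chunk=2 is illustrative):

    chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=2)
    # drops "train-{...}.parquet", leaving "data/chunk-{episode_chunk:03d}" -> "data/chunk-002"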
@ -323,16 +321,16 @@ def move_videos(
video_dirs = sorted(work_dir.glob("videos*/"))
for ep_chunk in range(total_chunks):
ep_chunk_start = EPISODE_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(EPISODE_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
for vid_key in video_keys:
chunk_dir = "/".join(VIDEO_CHUNK_PATH.split("/")[:-1]).format(
chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format(
episode_chunk=ep_chunk, video_key=vid_key
)
(work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
for ep_idx in range(ep_chunk_start, ep_chunk_end):
target_path = VIDEO_CHUNK_PATH.format(
target_path = DEFAULT_VIDEO_PATH.format(
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
)
video_file = VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
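For reference, this rename moves videos from the flat v1.6 file naming to the chunked v2.0 layout; with a hypothetical camera key and episode 7:

    old_path = VIDEO_FILE.format(video_key="observation.images.cam_high", episode_index=7)
    # -> "observation.images.cam_high_episode_000007.mp4"
    new_path = DEFAULT_VIDEO_PATH.format(
        episode_chunk=0, video_key="observation.images.cam_high", episode_index=7
    )
    # -> "videos/chunk-000/observation.images.cam_high/episode_000007.mp4"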
@ -476,11 +474,12 @@ def _get_video_info(video_path: Path | str) -> dict:
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
hub_api = HfApi()
videos_info_dict = {"videos_path": VIDEO_CHUNK_PATH}
videos_info_dict = {"videos_path": DEFAULT_VIDEO_PATH}
# Assumes first episode
video_files = [
VIDEO_CHUNK_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0) for vid_key in video_keys
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
for vid_key in video_keys
]
hub_api.snapshot_download(
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
@ -587,8 +586,8 @@ def convert_dataset(
total_episodes = len(episode_indices)
assert episode_indices == list(range(total_episodes))
total_videos = total_episodes * len(keys["video"])
total_chunks = total_episodes // EPISODE_CHUNK_SIZE
if total_episodes % EPISODE_CHUNK_SIZE != 0:
total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
if total_episodes % DEFAULT_CHUNK_SIZE != 0:
total_chunks += 1
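The two lines above are a ceiling division; for instance, 2500 episodes with DEFAULT_CHUNK_SIZE = 1000 yield 3 chunks:

    total_chunks = 2500 // 1000      # 2
    if 2500 % 1000 != 0:
        total_chunks += 1            # -> 3, same as math.ceil(2500 / 1000)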
# Tasks
@ -670,14 +669,14 @@ def convert_dataset(
# Assemble metadata v2.0
metadata_v2_0 = {
"codebase_version": V20,
"data_path": PARQUET_CHUNK_PATH,
"data_path": DEFAULT_PARQUET_PATH,
"robot_type": robot_type,
"total_episodes": total_episodes,
"total_frames": len(dataset),
"total_tasks": len(tasks),
"total_videos": total_videos,
"total_chunks": total_chunks,
"chunks_size": EPISODE_CHUNK_SIZE,
"chunks_size": DEFAULT_CHUNK_SIZE,
"fps": metadata_v1["fps"],
"splits": {"train": f"0:{total_episodes}"},
"keys": keys["sequence"],