Add download_metadata, move default paths

Simon Alibert 2024-10-18 14:48:34 +02:00
parent e7355ba595
commit 1a51505ec6
2 changed files with 68 additions and 64 deletions

lerobot/common/datasets/lerobot_dataset.py

@@ -31,9 +31,7 @@ from lerobot.common.datasets.utils import (
     get_episode_data_index,
     get_hub_safe_version,
     load_hf_dataset,
-    load_info,
-    load_stats,
-    load_tasks,
+    load_metadata,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
@@ -41,6 +39,12 @@ from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
 CODEBASE_VERSION = "v2.0"
 LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
+DEFAULT_CHUNK_SIZE = 1000
+DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
+DEFAULT_PARQUET_PATH = (
+    "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
+)
 
 class LeRobotDataset(torch.utils.data.Dataset):
     def __init__(
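For concreteness, here is how the two new path templates resolve once formatted; the camera key and indices below are made-up example values, not part of the commit:

    video_path = DEFAULT_VIDEO_PATH.format(
        episode_chunk=0, video_key="observation.images.cam_high", episode_index=42
    )
    # -> "videos/chunk-000/observation.images.cam_high/episode_000042.mp4"
    parquet_path = DEFAULT_PARQUET_PATH.format(episode_chunk=0, episode_index=42, total_episodes=100)
    # -> "data/chunk-000/train-00042-of-00100.parquet"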
@@ -70,7 +74,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             Instantiating this class with this 'repo_id' will download the dataset from that address and load
             it, pending your dataset is compliant with codebase_version v2.0. If your dataset has been created
             before this new format, you will be prompted to convert it using our conversion script from v1.6
-            to v2.0, which you can find at [TODO(aliberts): move conversion script & add location here].
+            to v2.0, which you can find at lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py.
 
         2. Your dataset already exists on your local disk in the 'root' folder:
             This is typically the case when you recorded your dataset locally and you may or may not have
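As the docstring describes, pointing the class at a hub repo_id is enough to download and load a v2.0 dataset; a minimal usage sketch (the repo name is only an example):

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # Downloads into LEROBOT_HOME (or `root` if provided), then loads metadata and data.
    dataset = LeRobotDataset("lerobot/aloha_static_cups_open", download_videos=True)
    print(dataset.total_episodes, dataset.fps)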
@@ -139,7 +143,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
                 decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
                 multiples of 1/fps. Defaults to 1e-4.
-            download_videos (bool, optional): Flag to download the videos. Defaults to True.
+            download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
+                video files are already present on local disk, they won't be downloaded again. Defaults to
+                True.
             video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
                 a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
         """
@@ -157,9 +163,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Load metadata
         self.root.mkdir(exist_ok=True, parents=True)
         self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
-        self.info = load_info(repo_id, self._version, self.root)
-        self.stats = load_stats(repo_id, self._version, self.root)
-        self.tasks = load_tasks(repo_id, self._version, self.root)
+        self.download_metadata()
+        self.info, self.episode_dicts, self.stats, self.tasks = load_metadata(self.root)
 
         # Load actual data
         self.download_episodes()
@@ -185,6 +190,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # - [ ] Update episode_index (arg update=True)
         # - [ ] Update info.json (arg update=True)
 
+    def download_metadata(self) -> None:
+        snapshot_download(
+            self.repo_id,
+            repo_type="dataset",
+            revision=self._version,
+            local_dir=self.root,
+            allow_patterns="meta/",
+        )
+
     def download_episodes(self) -> None:
         """Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
         will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
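For reference, the allow_patterns mechanism used by download_metadata can be reproduced with huggingface_hub directly; a standalone sketch with a placeholder repo and cache path:

    from pathlib import Path
    from huggingface_hub import snapshot_download

    root = Path("~/.cache/huggingface/lerobot/lerobot/aloha_static_cups_open").expanduser()
    snapshot_download(
        "lerobot/aloha_static_cups_open",  # placeholder repo_id
        repo_type="dataset",
        revision="v2.0",
        local_dir=root,
        allow_patterns="meta/",  # fetch only the metadata files, not data/ or videos/
    )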
@@ -227,11 +241,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """Formattable string for the video files."""
         return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
 
-    @property
-    def episode_dicts(self) -> list[dict]:
-        """List of dictionary containing information for each episode, indexed by episode_index."""
-        return self.info["episodes"]
-
     @property
     def fps(self) -> int:
         """Frames per second used during data collection."""
@@ -254,7 +263,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
     @property
     def camera_keys(self) -> list[str]:
-        """Keys to access image and video streams from cameras (regardless of their storage method)."""
+        """Keys to access visual modalities (regardless of their storage method)."""
         return self.image_keys + self.video_keys
 
     @property
@@ -277,6 +286,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """Total number of episodes available."""
         return self.info["total_episodes"]
 
+    @property
+    def total_chunks(self) -> int:
+        """Total number of chunks (groups of episodes)."""
+        return self.info["total_chunks"]
+
+    @property
+    def chunks_size(self) -> int:
+        """Max number of episodes per chunk."""
+        return self.info["chunks_size"]
+
     @property
     def shapes(self) -> dict:
         """Shapes for the different features."""
@@ -397,42 +416,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )
 
     @classmethod
-    def from_preloaded(
+    def create(
         cls,
-        repo_id: str = "from_preloaded",
+        repo_id: str,
         root: Path | None = None,
-        split: str = "train",
-        transform: callable = None,
+        image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
-        # additional preloaded attributes
-        hf_dataset=None,
-        episode_data_index=None,
-        stats=None,
-        info=None,
-        videos_dir=None,
-        video_backend=None,
+        tolerance_s: float = 1e-4,
+        video_backend: str | None = None,
     ) -> "LeRobotDataset":
-        """Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.
-
-        It is especially useful when converting raw data into LeRobotDataset before saving the dataset
-        on the filesystem or uploading to the hub.
-
-        Note: Meta-data attributes like `repo_id`, `version`, `root`, etc are optional and potentially
-        meaningless depending on the downstream usage of the return dataset.
-        """
+        """Create a LeRobot Dataset from scratch in order to record data."""
         # create an empty object of type LeRobotDataset
         obj = cls.__new__(cls)
         obj.repo_id = repo_id
-        obj.root = root
-        obj.split = split
-        obj.image_transforms = transform
-        obj.delta_timestamps = delta_timestamps
-        obj.hf_dataset = hf_dataset
-        obj.episode_data_index = episode_data_index
-        obj.stats = stats
-        obj.info = info if info is not None else {}
-        obj.videos_dir = videos_dir
-        obj.video_backend = video_backend if video_backend is not None else "pyav"
+        obj.root = root if root is not None else LEROBOT_HOME / repo_id
+        # obj.episodes = None
+        # obj.image_transforms = None
+        # obj.delta_timestamps = None
+        # obj.episode_data_index = episode_data_index
+        # obj.stats = stats
+        # obj.info = info if info is not None else {}
+        # obj.videos_dir = videos_dir
+        # obj.video_backend = video_backend if video_backend is not None else "pyav"
         return obj
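The create() factory is still partially stubbed at this point (most attributes remain commented out), but the new signature already fixes the intended call shape; a hedged sketch with a placeholder repo_id:

    dataset = LeRobotDataset.create(
        repo_id="my_user/my_new_dataset",  # placeholder
        tolerance_s=1e-4,
    )
    print(dataset.root)  # LEROBOT_HOME / repo_id when no root is given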

lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py

@@ -120,6 +120,11 @@ from huggingface_hub.errors import EntryNotFoundError
 from PIL import Image
 from safetensors.torch import load_file
 
+from lerobot.common.datasets.lerobot_dataset import (
+    DEFAULT_CHUNK_SIZE,
+    DEFAULT_PARQUET_PATH,
+    DEFAULT_VIDEO_PATH,
+)
 from lerobot.common.datasets.utils import create_branch, flatten_dict, get_hub_safe_version, unflatten_dict
 from lerobot.common.utils.utils import init_hydra_config
 from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
@@ -127,15 +132,8 @@ from lerobot.scripts.push_dataset_to_hub import push_dataset_card_to_hub
 V16 = "v1.6"
 V20 = "v2.0"
 
-EPISODE_CHUNK_SIZE = 1000
 GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
 VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
-PARQUET_CHUNK_PATH = (
-    "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
-)
-VIDEO_CHUNK_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
 
 
 def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
@@ -269,15 +267,15 @@ def split_parquet_by_episodes(
     table = dataset.remove_columns(keys["video"])._data.table
     episode_lengths = []
     for ep_chunk in range(total_chunks):
-        ep_chunk_start = EPISODE_CHUNK_SIZE * ep_chunk
-        ep_chunk_end = min(EPISODE_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
-        chunk_dir = "/".join(PARQUET_CHUNK_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
+        ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
+        ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
+        chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
         (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
         for ep_idx in range(ep_chunk_start, ep_chunk_end):
             ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
             episode_lengths.insert(ep_idx, len(ep_table))
-            output_file = output_dir / PARQUET_CHUNK_PATH.format(
+            output_file = output_dir / DEFAULT_PARQUET_PATH.format(
                 episode_chunk=ep_chunk, episode_index=ep_idx, total_episodes=total_episodes
             )
             pq.write_table(ep_table, output_file)
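To make the path handling above concrete, here is how chunk_dir and output_file resolve for one episode (the numbers are illustrative):

    chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=1)
    # -> "data/chunk-001"
    output_file = DEFAULT_PARQUET_PATH.format(episode_chunk=1, episode_index=1042, total_episodes=1500)
    # -> "data/chunk-001/train-01042-of-01500.parquet"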
@@ -323,16 +321,16 @@ def move_videos(
     video_dirs = sorted(work_dir.glob("videos*/"))
     for ep_chunk in range(total_chunks):
-        ep_chunk_start = EPISODE_CHUNK_SIZE * ep_chunk
-        ep_chunk_end = min(EPISODE_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
+        ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
+        ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
         for vid_key in video_keys:
-            chunk_dir = "/".join(VIDEO_CHUNK_PATH.split("/")[:-1]).format(
+            chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format(
                 episode_chunk=ep_chunk, video_key=vid_key
             )
             (work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
             for ep_idx in range(ep_chunk_start, ep_chunk_end):
-                target_path = VIDEO_CHUNK_PATH.format(
+                target_path = DEFAULT_VIDEO_PATH.format(
                     episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
                 )
                 video_file = VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
@@ -476,11 +474,12 @@ def _get_video_info(video_path: Path | str) -> dict:
 
 def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
     hub_api = HfApi()
-    videos_info_dict = {"videos_path": VIDEO_CHUNK_PATH}
+    videos_info_dict = {"videos_path": DEFAULT_VIDEO_PATH}
     # Assumes first episode
     video_files = [
-        VIDEO_CHUNK_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0) for vid_key in video_keys
+        DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
+        for vid_key in video_keys
     ]
     hub_api.snapshot_download(
         repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
@@ -587,8 +586,8 @@ def convert_dataset(
     total_episodes = len(episode_indices)
     assert episode_indices == list(range(total_episodes))
     total_videos = total_episodes * len(keys["video"])
-    total_chunks = total_episodes // EPISODE_CHUNK_SIZE
-    if total_episodes % EPISODE_CHUNK_SIZE != 0:
+    total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
+    if total_episodes % DEFAULT_CHUNK_SIZE != 0:
         total_chunks += 1
 
     # Tasks
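The two changed lines implement a ceiling division over the chunk size; for example:

    DEFAULT_CHUNK_SIZE = 1000
    for total_episodes in (999, 1000, 1001, 2500):
        total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
        if total_episodes % DEFAULT_CHUNK_SIZE != 0:
            total_chunks += 1
        print(total_episodes, total_chunks)  # 999 -> 1, 1000 -> 1, 1001 -> 2, 2500 -> 3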
@@ -670,14 +669,14 @@ def convert_dataset(
     # Assemble metadata v2.0
     metadata_v2_0 = {
         "codebase_version": V20,
-        "data_path": PARQUET_CHUNK_PATH,
+        "data_path": DEFAULT_PARQUET_PATH,
         "robot_type": robot_type,
         "total_episodes": total_episodes,
         "total_frames": len(dataset),
         "total_tasks": len(tasks),
         "total_videos": total_videos,
         "total_chunks": total_chunks,
-        "chunks_size": EPISODE_CHUNK_SIZE,
+        "chunks_size": DEFAULT_CHUNK_SIZE,
         "fps": metadata_v1["fps"],
         "splits": {"train": f"0:{total_episodes}"},
         "keys": keys["sequence"],