Add file paths

Simon Alibert 2024-10-20 14:00:19 +02:00
parent ac3798bd62
commit 9316cf46ef
2 changed files with 60 additions and 53 deletions

lerobot/common/datasets/lerobot_dataset.py

@@ -22,6 +22,7 @@ from typing import Callable
import datasets
import torch
import torch.utils
from datasets import load_dataset
from huggingface_hub import snapshot_download
from lerobot.common.datasets.compute_stats import aggregate_stats
@@ -32,7 +33,7 @@ from lerobot.common.datasets.utils import (
    get_delta_indices,
    get_episode_data_index,
    get_hub_safe_version,
    load_hf_dataset,
    hf_transform_to_torch,
    load_metadata,
)
from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
@@ -100,7 +101,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
            episodes.jsonl
            info.json
            stats.json
            tasks.json
            tasks.jsonl
        videos (optional)
            chunk-000
                observation.images.laptop
@@ -160,12 +161,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
        # Load metadata
        self.root.mkdir(exist_ok=True, parents=True)
        self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
        self.download_metadata()
        self.pull_from_repo(allow_patterns="meta/")
        self.info, self.episode_dicts, self.stats, self.tasks = load_metadata(self.root)

        # Load actual data
        self.download_episodes()
        self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
        self.hf_dataset = self.load_hf_dataset()
        self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)

        # Check timestamps
@@ -187,13 +188,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
    # - [ ] Update episode_index (arg update=True)
    # - [ ] Update info.json (arg update=True)

    def download_metadata(self) -> None:
    def pull_from_repo(
        self,
        allow_patterns: list[str] | str | None = None,
        ignore_patterns: list[str] | str | None = None,
    ) -> None:
        snapshot_download(
            self.repo_id,
            repo_type="dataset",
            revision=self._version,
            local_dir=self.root,
            allow_patterns="meta/",
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
        )
    def download_episodes(self) -> None:
@@ -207,26 +213,46 @@ class LeRobotDataset(torch.utils.data.Dataset):
        files = None
        ignore_patterns = None if self.download_videos else "videos/"
        if self.episodes is not None:
            files = [
                self.data_path.format(episode_index=ep_idx, total_episodes=self.total_episodes)
                for ep_idx in self.episodes
            ]
            files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes]
            if len(self.video_keys) > 0 and self.download_videos:
                video_files = [
                    self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
                    self.get_video_file_path(ep_idx, vid_key)
                    for vid_key in self.video_keys
                    for ep_idx in self.episodes
                ]
                files += video_files

        snapshot_download(
            self.repo_id,
            repo_type="dataset",
            revision=self._version,
            local_dir=self.root,
            allow_patterns=files,
            ignore_patterns=ignore_patterns,
        self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
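Editor's note: with this change, every hub download funnels through the one `pull_from_repo` wrapper, and callers only decide which patterns to pass. A minimal usage sketch, with `ds` standing in for an already-constructed `LeRobotDataset` (the patterns are the ones appearing in the diff above):

```python
# Sketch only: `ds` is assumed to be a constructed LeRobotDataset instance.
ds.pull_from_repo(allow_patterns="meta/")     # metadata only, as done in __init__
ds.pull_from_repo(ignore_patterns="videos/")  # whole dataset minus the video files
```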
    def load_hf_dataset(self) -> datasets.Dataset:
        """hf_dataset contains all the observations, states, actions, rewards, etc."""
        if self.episodes is None:
            path = str(self.root / "data")
            hf_dataset = load_dataset("parquet", data_dir=path, split="train")
        else:
            files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes]
            hf_dataset = load_dataset("parquet", data_files=files, split="train")

        hf_dataset.set_transform(hf_transform_to_torch)
        return hf_dataset
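Editor's note: the method relies on the parquet builder of `datasets.load_dataset`, which takes either a directory or an explicit file list. An isolated sketch of the two branches (the paths here are hypothetical, not from this repo):

```python
from datasets import load_dataset

# Whole dataset: point the parquet builder at the local data/ directory.
full = load_dataset("parquet", data_dir="path/to/root/data", split="train")

# Subset: pass explicit parquet files, one per requested episode.
subset = load_dataset(
    "parquet",
    data_files=["path/to/root/data/chunk-000/episode_000000.parquet"],
    split="train",
)
```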
    def get_data_file_path(self, ep_index: int, return_str: bool = True) -> str | Path:
        ep_chunk = self.get_episode_chunk(ep_index)
        fpath = self.data_path.format(
            episode_chunk=ep_chunk, episode_index=ep_index, total_episodes=self.total_episodes
        )
        return str(self.root / fpath) if return_str else self.root / fpath

    def get_video_file_path(self, ep_index: int, vid_key: str, return_str: bool = True) -> str | Path:
        ep_chunk = self.get_episode_chunk(ep_index)
        fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
        return str(self.root / fpath) if return_str else self.root / fpath
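Editor's note: `data_path` and `videos_path` are format-string templates read from `meta/info.json`; their actual values are not shown in this diff. Assuming a chunked template of the shape the new helpers imply, path resolution looks like this (note that `str.format` silently ignores unused keyword arguments, which is why passing `total_episodes` unconditionally is safe):

```python
# Hypothetical template; the real value comes from meta/info.json.
data_path = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"

fpath = data_path.format(episode_chunk=0, episode_index=42, total_episodes=50)
print(fpath)  # data/chunk-000/episode_000042.parquet
```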
    def get_episode_chunk(self, ep_index: int) -> int:
        ep_chunk = ep_index // self.chunks_size
        if ep_index > 0 and ep_index % self.chunks_size == 0:
            ep_chunk -= 1
        return ep_chunk
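Editor's note: the boundary handling here is easy to misread. An episode index that is an exact multiple of `chunks_size` is pulled back into the previous chunk. A standalone check of the rule exactly as written (a default `chunks_size` of 1000 is assumed for illustration):

```python
# Re-statement of the method above with chunks_size made an argument.
def get_episode_chunk(ep_index: int, chunks_size: int = 1000) -> int:
    ep_chunk = ep_index // chunks_size
    if ep_index > 0 and ep_index % chunks_size == 0:
        ep_chunk -= 1
    return ep_chunk

assert get_episode_chunk(0) == 0
assert get_episode_chunk(999) == 0
assert get_episode_chunk(1000) == 0  # exact multiple stays in the previous chunk
assert get_episode_chunk(1001) == 1
```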
    @property
    def data_path(self) -> str:
@@ -355,7 +381,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
        """
        item = {}
        for vid_key, query_ts in query_timestamps.items():
            video_path = self.root / self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
            video_path = self.root / self.get_video_file_path(ep_idx, vid_key)
            frames = decode_video_frames_torchvision(
                video_path, query_ts, self.tolerance_s, self.video_backend
            )
@@ -436,6 +462,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
        obj.write_info()

        obj.fps = fps
        if not all(cam.fps == fps for cam in robot.cameras):
            logging.warning(
                f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset. "
                "In this case, frames from lower-fps cameras will be repeated to fill in the blanks."
            )

        # obj.episodes = None
        # obj.image_transforms = None
        # obj.delta_timestamps = None

lerobot/common/datasets/utils.py

@@ -23,7 +23,6 @@ from typing import Dict
import datasets
import jsonlines
import torch
from datasets import load_dataset
from huggingface_hub import DatasetCard, HfApi
from PIL import Image as PILImage
from torchvision import transforms
@@ -87,15 +86,6 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
        if isinstance(first_item, PILImage.Image):
            to_tensor = transforms.ToTensor()
            items_dict[key] = [to_tensor(img) for img in items_dict[key]]
        # TODO(aliberts): remove this part as we'll be using task_index
        elif isinstance(first_item, str):
            # TODO (michel-aractingi): add str2embedding via language tokenizer
            # For now we leave this part up to the user to choose how to address
            # language conditioned tasks
            pass
        elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
            # video frame will be processed downstream
            pass
        elif first_item is None:
            pass
        else:
@@ -130,32 +120,12 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
    return version


def load_hf_dataset(
    local_dir: Path,
    data_path: str,
    total_episodes: int,
    episodes: list[int] | None = None,
    split: str = "train",
) -> datasets.Dataset:
    """hf_dataset contains all the observations, states, actions, rewards, etc."""
    if episodes is None:
        path = str(local_dir / "data")
        hf_dataset = load_dataset("parquet", data_dir=path, split=split)
    else:
        files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
        files = [str(local_dir / fpath) for fpath in files]
        hf_dataset = load_dataset("parquet", data_files=files, split=split)

    hf_dataset.set_transform(hf_transform_to_torch)
    return hf_dataset
def load_metadata(local_dir: Path) -> tuple[dict | list]:
    """Loads metadata files from a dataset."""
    info_path = local_dir / "meta/info.jsonl"
    info_path = local_dir / "meta/info.json"
    episodes_path = local_dir / "meta/episodes.jsonl"
    stats_path = local_dir / "meta/stats.json"
    tasks_path = local_dir / "meta/tasks.json"
    tasks_path = local_dir / "meta/tasks.jsonl"

    with open(info_path) as f:
        info = json.load(f)
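Editor's note: the extension fixes matter here: `info.json` and `stats.json` are plain JSON, while `episodes.jsonl` and `tasks.jsonl` are JSON Lines and are read with the `jsonlines` package imported above. A standalone sketch of the split (the local root and the record shape are illustrative assumptions):

```python
import json
import jsonlines
from pathlib import Path

root = Path("path/to/dataset")  # hypothetical local root

with open(root / "meta/info.json") as f:
    info = json.load(f)

with jsonlines.open(root / "meta/tasks.jsonl") as reader:
    tasks = list(reader)  # one dict per line, e.g. {"task_index": 0, "task": "..."}
```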
@@ -499,12 +469,17 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None):
    api.create_branch(repo_id, repo_type=repo_type, branch=branch)


def create_lerobot_dataset_card(tags: list | None = None, text: str | None = None) -> DatasetCard:
def create_lerobot_dataset_card(
    tags: list | None = None, text: str | None = None, info: dict | None = None
) -> DatasetCard:
    card = DatasetCard(DATASET_CARD_TEMPLATE)
    card.data.task_categories = ["robotics"]
    card.data.tags = ["LeRobot"]
    if tags is not None:
        card.data.tags += tags
    if text is not None:
        card.text += text
        card.text += f"{text}\n"
    if info is not None:
        card.text += "[meta/info.json](meta/info.json)\n"
        card.text += f"```json\n{json.dumps(info, indent=4)}\n```"
    return card
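Editor's note: the new `info` parameter embeds the dataset's `meta/info.json` directly into the dataset card. A usage sketch with hypothetical values (the tag and text are placeholders, not from this commit):

```python
info = {"codebase_version": "v2.0", "fps": 30}  # stand-in for the loaded meta/info.json
card = create_lerobot_dataset_card(
    tags=["so100"],  # hypothetical extra tag
    text="This dataset was recorded for demonstration purposes.",
    info=info,
)
# card.text now ends with a link to meta/info.json followed by the
# pretty-printed JSON in a fenced block.
```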