From 9316cf46ef4f7e2473d1d3e605d6fe1da4e6310d Mon Sep 17 00:00:00 2001
From: Simon Alibert
Date: Sun, 20 Oct 2024 14:00:19 +0200
Subject: [PATCH] Add file paths

---
 lerobot/common/datasets/lerobot_dataset.py | 70 ++++++++++++++++------
 lerobot/common/datasets/utils.py           | 43 +++----------
 2 files changed, 60 insertions(+), 53 deletions(-)

diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index cda0412f..43d8708d 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -22,6 +22,7 @@ from typing import Callable
 import datasets
 import torch
 import torch.utils
+from datasets import load_dataset
 from huggingface_hub import snapshot_download
 
 from lerobot.common.datasets.compute_stats import aggregate_stats
@@ -32,7 +33,7 @@ from lerobot.common.datasets.utils import (
     get_delta_indices,
     get_episode_data_index,
     get_hub_safe_version,
-    load_hf_dataset,
+    hf_transform_to_torch,
     load_metadata,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
@@ -100,7 +101,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
     │   ├── episodes.jsonl
     │   ├── info.json
     │   ├── stats.json
-    │   └── tasks.json
+    │   └── tasks.jsonl
     └── videos (optional)
         ├── chunk-000
         │   ├── observation.images.laptop
@@ -160,12 +161,12 @@
         # Load metadata
         self.root.mkdir(exist_ok=True, parents=True)
         self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
-        self.download_metadata()
+        self.pull_from_repo(allow_patterns="meta/")
         self.info, self.episode_dicts, self.stats, self.tasks = load_metadata(self.root)
 
         # Load actual data
         self.download_episodes()
-        self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
+        self.hf_dataset = self.load_hf_dataset()
         self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
 
         # Check timestamps
@@ -187,13 +188,18 @@
     # - [ ] Update episode_index (arg update=True)
     # - [ ] Update info.json (arg update=True)
 
-    def download_metadata(self) -> None:
+    def pull_from_repo(
+        self,
+        allow_patterns: list[str] | str | None = None,
+        ignore_patterns: list[str] | str | None = None,
+    ) -> None:
         snapshot_download(
             self.repo_id,
             repo_type="dataset",
             revision=self._version,
             local_dir=self.root,
-            allow_patterns="meta/",
+            allow_patterns=allow_patterns,
+            ignore_patterns=ignore_patterns,
         )
 
     def download_episodes(self) -> None:
@@ -207,26 +213,46 @@
         files = None
         ignore_patterns = None if self.download_videos else "videos/"
         if self.episodes is not None:
-            files = [
-                self.data_path.format(episode_index=ep_idx, total_episodes=self.total_episodes)
-                for ep_idx in self.episodes
-            ]
+            files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes]
             if len(self.video_keys) > 0 and self.download_videos:
                 video_files = [
-                    self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
+                    self.get_video_file_path(ep_idx, vid_key)
                     for vid_key in self.video_keys
                     for ep_idx in self.episodes
                 ]
                 files += video_files
 
-        snapshot_download(
-            self.repo_id,
-            repo_type="dataset",
-            revision=self._version,
-            local_dir=self.root,
-            allow_patterns=files,
-            ignore_patterns=ignore_patterns,
+        self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
+
+    def load_hf_dataset(self) -> datasets.Dataset:
+        """hf_dataset contains all the observations, states, actions, rewards, etc."""
+        if self.episodes is None:
+            path = str(self.root / "data")
+            hf_dataset = load_dataset("parquet", data_dir=path, split="train")
+        else:
+            files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes]
+            hf_dataset = load_dataset("parquet", data_files=files, split="train")
+
+        hf_dataset.set_transform(hf_transform_to_torch)
+        return hf_dataset
+
+    def get_data_file_path(self, ep_index: int, return_str: bool = True) -> str | Path:
+        ep_chunk = self.get_episode_chunk(ep_index)
+        fpath = self.data_path.format(
+            episode_chunk=ep_chunk, episode_index=ep_index, total_episodes=self.total_episodes
         )
+        return str(self.root / fpath) if return_str else self.root / fpath
+
+    def get_video_file_path(self, ep_index: int, vid_key: str, return_str: bool = True) -> str | Path:
+        ep_chunk = self.get_episode_chunk(ep_index)
+        fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
+        return str(self.root / fpath) if return_str else self.root / fpath
+
+    def get_episode_chunk(self, ep_index: int) -> int:
+        ep_chunk = ep_index // self.chunks_size
+        if ep_index > 0 and ep_index % self.chunks_size == 0:
+            ep_chunk -= 1
+        return ep_chunk
 
     @property
     def data_path(self) -> str:
@@ -355,7 +381,7 @@
         """
         item = {}
         for vid_key, query_ts in query_timestamps.items():
-            video_path = self.root / self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
+            video_path = self.root / self.get_video_file_path(ep_idx, vid_key)
             frames = decode_video_frames_torchvision(
                 video_path, query_ts, self.tolerance_s, self.video_backend
             )
@@ -436,6 +462,12 @@
         obj.write_info()
         obj.fps = fps
 
+        if not all(cam.fps == fps for cam in robot.cameras.values()):
+            logging.warning(
+                f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset. "
+                "In this case, frames from lower fps cameras will be repeated to fill in the blanks."
+            )
+
         # obj.episodes = None
         # obj.image_transforms = None
         # obj.delta_timestamps = None
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index c80838e6..90bb35c1 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -23,7 +23,6 @@ from typing import Dict
 import datasets
 import jsonlines
 import torch
-from datasets import load_dataset
 from huggingface_hub import DatasetCard, HfApi
 from PIL import Image as PILImage
 from torchvision import transforms
@@ -87,15 +86,6 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
         if isinstance(first_item, PILImage.Image):
             to_tensor = transforms.ToTensor()
             items_dict[key] = [to_tensor(img) for img in items_dict[key]]
-        # TODO(aliberts): remove this part as we'll be using task_index
-        elif isinstance(first_item, str):
-            # TODO (michel-aractingi): add str2embedding via language tokenizer
-            # For now we leave this part up to the user to choose how to address
-            # language conditioned tasks
-            pass
-        elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
-            # video frame will be processed downstream
-            pass
         elif first_item is None:
             pass
         else:
@@ -130,32 +120,12 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
     return version
 
 
-def load_hf_dataset(
-    local_dir: Path,
-    data_path: str,
-    total_episodes: int,
-    episodes: list[int] | None = None,
-    split: str = "train",
-) -> datasets.Dataset:
-    """hf_dataset contains all the observations, states, actions, rewards, etc."""
-    if episodes is None:
-        path = str(local_dir / "data")
-        hf_dataset = load_dataset("parquet", data_dir=path, split=split)
-    else:
-        files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
-        files = [str(local_dir / fpath) for fpath in files]
-        hf_dataset = load_dataset("parquet", data_files=files, split=split)
-
-    hf_dataset.set_transform(hf_transform_to_torch)
-    return hf_dataset
-
-
 def load_metadata(local_dir: Path) -> tuple[dict | list]:
     """Loads metadata files from a dataset."""
-    info_path = local_dir / "meta/info.jsonl"
+    info_path = local_dir / "meta/info.json"
     episodes_path = local_dir / "meta/episodes.jsonl"
     stats_path = local_dir / "meta/stats.json"
-    tasks_path = local_dir / "meta/tasks.json"
+    tasks_path = local_dir / "meta/tasks.jsonl"
 
     with open(info_path) as f:
         info = json.load(f)
@@ -499,12 +469,17 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None):
     api.create_branch(repo_id, repo_type=repo_type, branch=branch)
 
 
-def create_lerobot_dataset_card(tags: list | None = None, text: str | None = None) -> DatasetCard:
+def create_lerobot_dataset_card(
+    tags: list | None = None, text: str | None = None, info: dict | None = None
+) -> DatasetCard:
     card = DatasetCard(DATASET_CARD_TEMPLATE)
     card.data.task_categories = ["robotics"]
     card.data.tags = ["LeRobot"]
     if tags is not None:
         card.data.tags += tags
     if text is not None:
-        card.text += text
+        card.text += f"{text}\n"
+    if info is not None:
+        card.text += "[meta/info.json](meta/info.json)\n"
+        card.text += f"```json\n{json.dumps(info, indent=4)}\n```"
    return card
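
Note on the new path scheme: the `data_path` / `videos_path` templates and `chunks_size` used by the helpers added above come from `meta/info.json` and are not shown in this diff. Below is a rough standalone sketch of how an episode index resolves to a chunked parquet path; the template string, chunk size, and root directory are illustrative assumptions, not values confirmed by this patch:

    from pathlib import Path

    # Assumed values, normally read from meta/info.json; not part of this patch.
    DATA_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
    CHUNKS_SIZE = 1000

    def get_episode_chunk(ep_index: int) -> int:
        # Mirrors the helper added above: an episode index that is an exact
        # multiple of CHUNKS_SIZE stays in the previous chunk (1000 -> chunk 0).
        ep_chunk = ep_index // CHUNKS_SIZE
        if ep_index > 0 and ep_index % CHUNKS_SIZE == 0:
            ep_chunk -= 1
        return ep_chunk

    root = Path("/tmp/lerobot_demo")  # stand-in for the local dataset root
    fpath = DATA_PATH.format(episode_chunk=get_episode_chunk(1042), episode_index=1042)
    print(root / fpath)  # /tmp/lerobot_demo/data/chunk-001/episode_001042.parquet

`get_video_file_path` resolves the same way, with the `videos_path` template additionally keyed by `video_key`.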