Add file paths
parent ac3798bd62, commit 9316cf46ef

lerobot/common/datasets/lerobot_dataset.py
@@ -22,6 +22,7 @@ from typing import Callable
 import datasets
 import torch
 import torch.utils
+from datasets import load_dataset
 from huggingface_hub import snapshot_download

 from lerobot.common.datasets.compute_stats import aggregate_stats
@@ -32,7 +33,7 @@ from lerobot.common.datasets.utils import (
     get_delta_indices,
     get_episode_data_index,
     get_hub_safe_version,
-    load_hf_dataset,
+    hf_transform_to_torch,
     load_metadata,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, decode_video_frames_torchvision
@@ -100,7 +101,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         │   ├── episodes.jsonl
         │   ├── info.json
         │   ├── stats.json
-        │   └── tasks.json
+        │   └── tasks.jsonl
         └── videos (optional)
             ├── chunk-000
             │   ├── observation.images.laptop
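The only layout change in this tree is tasks.json becoming tasks.jsonl, i.e. one JSON object per line instead of a single JSON document, matching how episodes.jsonl is already stored. A minimal sketch of the difference when reading the file (paths and field names are illustrative, not taken from this diff):

import json

import jsonlines

# Old style: a single JSON document holding every task.
with open("meta/tasks.json") as f:
    tasks = json.load(f)

# New style: one JSON object per line, read with the jsonlines package
# that utils.py already imports.
with jsonlines.open("meta/tasks.jsonl") as reader:
    tasks = list(reader)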
@@ -160,12 +161,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Load metadata
         self.root.mkdir(exist_ok=True, parents=True)
         self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
-        self.download_metadata()
+        self.pull_from_repo(allow_patterns="meta/")
         self.info, self.episode_dicts, self.stats, self.tasks = load_metadata(self.root)

         # Load actual data
         self.download_episodes()
-        self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
+        self.hf_dataset = self.load_hf_dataset()
         self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)

         # Check timestamps
@@ -187,13 +188,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # - [ ] Update episode_index (arg update=True)
         # - [ ] Update info.json (arg update=True)

-    def download_metadata(self) -> None:
+    def pull_from_repo(
+        self,
+        allow_patterns: list[str] | str | None = None,
+        ignore_patterns: list[str] | str | None = None,
+    ) -> None:
         snapshot_download(
             self.repo_id,
             repo_type="dataset",
             revision=self._version,
             local_dir=self.root,
-            allow_patterns="meta/",
+            allow_patterns=allow_patterns,
+            ignore_patterns=ignore_patterns,
         )

     def download_episodes(self) -> None:
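pull_from_repo generalizes the old metadata-only download into a pattern-filtered wrapper around huggingface_hub's snapshot_download: allow_patterns selects what to fetch and ignore_patterns excludes files such as videos. A rough sketch of how the call sites in this diff use it (ds stands for an already constructed LeRobotDataset and is assumed for illustration):

# ds is an already constructed LeRobotDataset (assumed for illustration).
ds.pull_from_repo(allow_patterns="meta/")      # metadata files only, as done in __init__
ds.pull_from_repo(ignore_patterns="videos/")   # everything except videos, as when download_videos is disabled

download_episodes below builds a per-episode allow_patterns list with get_data_file_path / get_video_file_path and routes it through the same method.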
@@ -207,26 +213,46 @@ class LeRobotDataset(torch.utils.data.Dataset):
         files = None
         ignore_patterns = None if self.download_videos else "videos/"
         if self.episodes is not None:
-            files = [
-                self.data_path.format(episode_index=ep_idx, total_episodes=self.total_episodes)
-                for ep_idx in self.episodes
-            ]
+            files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes]
             if len(self.video_keys) > 0 and self.download_videos:
                 video_files = [
-                    self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
+                    self.get_video_file_path(ep_idx, vid_key)
                     for vid_key in self.video_keys
                     for ep_idx in self.episodes
                 ]
                 files += video_files

-        snapshot_download(
-            self.repo_id,
-            repo_type="dataset",
-            revision=self._version,
-            local_dir=self.root,
-            allow_patterns=files,
-            ignore_patterns=ignore_patterns,
-        )
+        self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
+
+    def load_hf_dataset(self) -> datasets.Dataset:
+        """hf_dataset contains all the observations, states, actions, rewards, etc."""
+        if self.episodes is None:
+            path = str(self.root / "data")
+            hf_dataset = load_dataset("parquet", data_dir=path, split="train")
+        else:
+            files = [self.get_data_file_path(ep_idx) for ep_idx in self.episodes]
+            hf_dataset = load_dataset("parquet", data_files=files, split="train")
+
+        hf_dataset.set_transform(hf_transform_to_torch)
+        return hf_dataset
+
+    def get_data_file_path(self, ep_index: int, return_str: bool = True) -> str | Path:
+        ep_chunk = self.get_episode_chunk(ep_index)
+        fpath = self.data_path.format(
+            episode_chunk=ep_chunk, episode_index=ep_index, total_episodes=self.total_episodes
+        )
+        return str(self.root / fpath) if return_str else self.root / fpath
+
+    def get_video_file_path(self, ep_index: int, vid_key: str, return_str: bool = True) -> str | Path:
+        ep_chunk = self.get_episode_chunk(ep_index)
+        fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
+        return str(self.root / fpath) if return_str else self.root / fpath
+
+    def get_episode_chunk(self, ep_index: int) -> int:
+        ep_chunk = ep_index // self.chunks_size
+        if ep_index > 0 and ep_index % self.chunks_size == 0:
+            ep_chunk -= 1
+        return ep_chunk

     @property
     def data_path(self) -> str:
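get_data_file_path and get_video_file_path resolve an episode index to its chunked location by filling the data_path / videos_path templates with an episode_chunk derived from chunks_size. A standalone sketch of that arithmetic, using made-up template strings and chunk size (the real values come from the dataset metadata, not from this diff):

# Illustrative templates and chunk size, shaped like the format() calls above expect.
DATA_PATH = "data/chunk-{episode_chunk:03d}/train-{episode_index:05d}-of-{total_episodes:05d}.parquet"
VIDEOS_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
CHUNKS_SIZE = 1000


def get_episode_chunk(ep_index: int) -> int:
    # Same rule as the method above: CHUNKS_SIZE episodes per chunk folder.
    ep_chunk = ep_index // CHUNKS_SIZE
    if ep_index > 0 and ep_index % CHUNKS_SIZE == 0:
        ep_chunk -= 1
    return ep_chunk


print(DATA_PATH.format(episode_chunk=get_episode_chunk(1002), episode_index=1002, total_episodes=2000))
# data/chunk-001/train-01002-of-02000.parquet
print(VIDEOS_PATH.format(episode_chunk=get_episode_chunk(3), video_key="observation.images.laptop", episode_index=3))
# videos/chunk-000/observation.images.laptop/episode_000003.mp4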
@@ -355,7 +381,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """
         item = {}
         for vid_key, query_ts in query_timestamps.items():
-            video_path = self.root / self.videos_path.format(video_key=vid_key, episode_index=ep_idx)
+            video_path = self.root / self.get_video_file_path(ep_idx, vid_key)
             frames = decode_video_frames_torchvision(
                 video_path, query_ts, self.tolerance_s, self.video_backend
             )
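The only change here is that the video path now goes through get_video_file_path; the decoding call itself is untouched. For reference, a hedged sketch of decoding a few frames at query timestamps, reusing the call shape shown above (the path, timestamps and backend string are illustrative):

from lerobot.common.datasets.video_utils import decode_video_frames_torchvision

video_path = "videos/chunk-000/observation.images.laptop/episode_000003.mp4"  # illustrative
query_ts = [0.0, 0.033, 0.066]  # timestamps (in seconds) to fetch
tolerance_s = 1e-4              # max allowed gap between requested and decoded timestamps

frames = decode_video_frames_torchvision(video_path, query_ts, tolerance_s, "pyav")
# one decoded frame per requested timestamp, in query order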
@@ -436,6 +462,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.write_info()
         obj.fps = fps

+        if not all(cam.fps == fps for cam in robot.cameras):
+            logging.warn(
+                f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
+                "In this case, frames from lower fps cameras will be repeated to fill in the blanks"
+            )
+
         # obj.episodes = None
         # obj.image_transforms = None
         # obj.delta_timestamps = None

lerobot/common/datasets/utils.py
@@ -23,7 +23,6 @@ from typing import Dict
 import datasets
 import jsonlines
 import torch
-from datasets import load_dataset
 from huggingface_hub import DatasetCard, HfApi
 from PIL import Image as PILImage
 from torchvision import transforms
@@ -87,15 +86,6 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
         if isinstance(first_item, PILImage.Image):
             to_tensor = transforms.ToTensor()
             items_dict[key] = [to_tensor(img) for img in items_dict[key]]
-        # TODO(aliberts): remove this part as we'll be using task_index
-        elif isinstance(first_item, str):
-            # TODO (michel-aractingi): add str2embedding via language tokenizer
-            # For now we leave this part up to the user to choose how to address
-            # language conditioned tasks
-            pass
-        elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
-            # video frame will be processed downstream
-            pass
         elif first_item is None:
             pass
         else:
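After this hunk, hf_transform_to_torch only needs to handle PIL images (via torchvision's ToTensor), None values, and the default conversion in the else branch; the string-task and video-frame branches are dropped. A small sketch of registering the transform on a Hugging Face dataset (the column names and values are illustrative):

import datasets

from lerobot.common.datasets.utils import hf_transform_to_torch

hf_dataset = datasets.Dataset.from_dict(
    {
        "observation.state": [[0.0, 0.1], [0.2, 0.3]],  # per-frame state vectors
        "episode_index": [0, 0],
    }
)
hf_dataset.set_transform(hf_transform_to_torch)  # fields come back as torch tensors on access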
@@ -130,32 +120,12 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
     return version


-def load_hf_dataset(
-    local_dir: Path,
-    data_path: str,
-    total_episodes: int,
-    episodes: list[int] | None = None,
-    split: str = "train",
-) -> datasets.Dataset:
-    """hf_dataset contains all the observations, states, actions, rewards, etc."""
-    if episodes is None:
-        path = str(local_dir / "data")
-        hf_dataset = load_dataset("parquet", data_dir=path, split=split)
-    else:
-        files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
-        files = [str(local_dir / fpath) for fpath in files]
-        hf_dataset = load_dataset("parquet", data_files=files, split=split)
-
-    hf_dataset.set_transform(hf_transform_to_torch)
-    return hf_dataset
-
-
 def load_metadata(local_dir: Path) -> tuple[dict | list]:
     """Loads metadata files from a dataset."""
-    info_path = local_dir / "meta/info.jsonl"
+    info_path = local_dir / "meta/info.json"
     episodes_path = local_dir / "meta/episodes.jsonl"
     stats_path = local_dir / "meta/stats.json"
-    tasks_path = local_dir / "meta/tasks.json"
+    tasks_path = local_dir / "meta/tasks.jsonl"

     with open(info_path) as f:
         info = json.load(f)
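With the standalone load_hf_dataset removed from utils, parquet loading now lives in the LeRobotDataset.load_hf_dataset method shown earlier; under the hood both branches are plain datasets.load_dataset calls over parquet files. A minimal standalone sketch of the two loading modes (the local paths are illustrative):

from datasets import load_dataset

# Whole dataset: point at the data/ directory and let datasets pick up every parquet file.
full = load_dataset("parquet", data_dir="/path/to/dataset/data", split="train")

# Subset: list the per-episode parquet files explicitly.
files = [
    "/path/to/dataset/data/chunk-000/train-00000-of-00050.parquet",
    "/path/to/dataset/data/chunk-000/train-00003-of-00050.parquet",
]
subset = load_dataset("parquet", data_files=files, split="train")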
@@ -499,12 +469,17 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None):
     api.create_branch(repo_id, repo_type=repo_type, branch=branch)


-def create_lerobot_dataset_card(tags: list | None = None, text: str | None = None) -> DatasetCard:
+def create_lerobot_dataset_card(
+    tags: list | None = None, text: str | None = None, info: dict | None = None
+) -> DatasetCard:
     card = DatasetCard(DATASET_CARD_TEMPLATE)
     card.data.task_categories = ["robotics"]
     card.data.tags = ["LeRobot"]
     if tags is not None:
         card.data.tags += tags
     if text is not None:
-        card.text += text
+        card.text += f"{text}\n"
+    if info is not None:
+        card.text += "[meta/info.json](meta/info.json)\n"
+        card.text += f"```json\n{json.dumps(info, indent=4)}\n```"
     return card
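create_lerobot_dataset_card can now embed the dataset's info.json in the generated README. A short sketch of the call, assuming the module path identified above (the info dict content is a made-up minimal example):

from lerobot.common.datasets.utils import create_lerobot_dataset_card

info = {"codebase_version": "v2.0", "fps": 30, "total_episodes": 50}  # illustrative content
card = create_lerobot_dataset_card(
    tags=["aloha"],
    text="This dataset was recorded with a hypothetical setup.",
    info=info,
)
# card.text now ends with a [meta/info.json](meta/info.json) link followed by the
# info dict rendered in a fenced json block, as appended by the function above.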