Rework LeRobotDataset.__init__
commit 096824b5ff
parent 2d75b93ba0
@@ -15,6 +15,7 @@
 # limitations under the License.
 import logging
 import os
+from itertools import accumulate
 from pathlib import Path
 from typing import Callable
 
@@ -24,27 +25,27 @@ import torch.utils
 
 from lerobot.common.datasets.compute_stats import aggregate_stats
 from lerobot.common.datasets.utils import (
-    calculate_episode_data_index,
-    load_episode_data_index,
+    download_episodes,
+    get_hub_safe_version,
     load_hf_dataset,
     load_info,
     load_previous_and_future_frames,
     load_stats,
-    load_videos,
-    reset_episode_index,
+    load_tasks,
 )
 from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
 
 # For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
-CODEBASE_VERSION = "v1.6"
-DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
+CODEBASE_VERSION = "v2.0"
+LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
 
 
 class LeRobotDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         repo_id: str,
-        root: Path | None = DATA_DIR,
+        root: Path | None = None,
+        episodes: list[int] | None = None,
         split: str = "train",
         image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
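With this hunk, the default `root` falls back to a per-repo cache under `LEROBOT_HOME` (overridable through the `LEROBOT_HOME` environment variable), and the new `episodes` argument selects a subset of episodes. A minimal usage sketch, assuming a dataset already converted to the v2.0 format; the repo_id is purely illustrative:

```python
from pathlib import Path

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Default root resolves to ~/.cache/huggingface/lerobot/<repo_id>
# (or $LEROBOT_HOME/<repo_id> if the environment variable is set).
dataset = LeRobotDataset("lerobot/pusht", episodes=[0, 5, 7])

# An explicit local directory can still be passed through `root`.
dataset = LeRobotDataset("lerobot/pusht", root=Path("/tmp/pusht"), episodes=[0])
```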
@@ -52,49 +53,89 @@ class LeRobotDataset(torch.utils.data.Dataset):
     ):
         super().__init__()
         self.repo_id = repo_id
-        self.root = root
+        self.root = root if root is not None else LEROBOT_HOME / repo_id
         self.split = split
         self.image_transforms = image_transforms
         self.delta_timestamps = delta_timestamps
-        # load data from hub or locally when root is provided
-        # TODO(rcadene, aliberts): implement faster transfer
-        # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
-        self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
-        if split == "train":
-            self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
-        else:
-            self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
-            self.hf_dataset = reset_episode_index(self.hf_dataset)
-        self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
-        self.info = load_info(repo_id, CODEBASE_VERSION, root)
-        if self.video:
-            self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
+        self.episodes = episodes
         self.video_backend = video_backend if video_backend is not None else "pyav"
 
+        # Load metadata
+        self.root.mkdir(exist_ok=True, parents=True)
+        self._version = get_hub_safe_version(repo_id, CODEBASE_VERSION)
+        self.info = load_info(repo_id, self._version, self.root)
+        self.stats = load_stats(repo_id, self._version, self.root)
+        self.tasks = load_tasks(repo_id, self._version, self.root)
+
+        # Load actual data
+        download_episodes(
+            repo_id,
+            self._version,
+            self.root,
+            self.data_path,
+            self.video_keys,
+            self.num_episodes,
+            self.episodes,
+            self.videos_path,
+        )
+        self.hf_dataset = load_hf_dataset(self.root, self.data_path, self.total_episodes, self.episodes)
+        self.episode_data_index = self.get_episode_data_index()
+
+        # TODO(aliberts):
+        # - [ ] Update __get_item__
+        # - [ ] Add self.consolidate() for:
+        #   - [ ] Sanity checks (episodes num, shapes, files, etc.)
+        #   - [ ] Update episode_index (arg update=True)
+        #   - [ ] Update info.json (arg update=True)
+
+        # TODO(aliberts): remove (deprecated)
+        # if split == "train":
+        #     self.episode_data_index = load_episode_data_index(self.episodes, self.episode_list)
+        # else:
+        #     self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
+        #     self.hf_dataset = reset_episode_index(self.hf_dataset)
+        # if self.video:
+        #     self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
+
+    @property
+    def data_path(self) -> str:
+        """Formattable string for the parquet files."""
+        return self.info["data_path"]
+
+    @property
+    def videos_path(self) -> str | None:
+        """Formattable string for the video files."""
+        return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
+
+    @property
+    def episode_dicts(self) -> list[dict]:
+        """List of dictionaries containing information about each episode, indexed by episode_index."""
+        return self.info["episodes"]
+
     @property
     def fps(self) -> int:
         """Frames per second used during data collection."""
         return self.info["fps"]
 
     @property
-    def video(self) -> bool:
-        """Returns True if this dataset loads video frames from mp4 files.
-        Returns False if it only loads images from png files.
-        """
-        return self.info.get("video", False)
+    def keys(self) -> list[str]:
+        """Keys to access non-image data (state, actions etc.)."""
+        return self.info["keys"]
 
     @property
-    def features(self) -> datasets.Features:
-        return self.hf_dataset.features
+    def image_keys(self) -> list[str]:
+        """Keys to access visual modalities stored as images."""
+        return self.info["image_keys"]
+
+    @property
+    def video_keys(self) -> list[str]:
+        """Keys to access visual modalities stored as videos."""
+        return self.info["video_keys"]
 
     @property
     def camera_keys(self) -> list[str]:
-        """Keys to access image and video stream from cameras."""
-        keys = []
-        for key, feats in self.hf_dataset.features.items():
-            if isinstance(feats, (datasets.Image, VideoFrame)):
-                keys.append(key)
-        return keys
+        """Keys to access image and video streams from cameras."""
+        return self.image_keys + self.video_keys
 
     @property
     def video_frame_keys(self) -> list[str]:
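The class no longer infers its structure from `hf_dataset.features`; the new properties read everything from `meta/info.json`. A sketch of the kind of payload those accessors expect; the field values and path templates below are invented for illustration, only the dictionary keys mirror the accessors above:

```python
info = {
    "fps": 10,
    "total_episodes": 2,
    "data_path": "data/train-episode_{episode_index:06d}-of-{total_episodes:06d}.parquet",
    "videos": {"videos_path": "videos/{video_key}_episode_{episode_index:06d}.mp4"},
    "keys": ["observation.state", "action"],
    "image_keys": [],
    "video_keys": ["observation.images.top"],
    "episodes": [{"episode_index": 0, "length": 120}, {"episode_index": 1, "length": 95}],
}

# data_path / videos_path are formattable templates later consumed by
# download_episodes and load_hf_dataset:
parquet_file = info["data_path"].format(episode_index=0, total_episodes=2)
video_file = info["videos"]["videos_path"].format(
    video_key="observation.images.top", episode_index=0
)
```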
@@ -117,8 +158,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
     @property
     def num_episodes(self) -> int:
-        """Number of episodes."""
-        return len(self.hf_dataset.unique("episode_index"))
+        """Number of episodes selected."""
+        return len(self.episodes) if self.episodes is not None else self.total_episodes
 
+    @property
+    def total_episodes(self) -> int:
+        """Total number of episodes available."""
+        return self.info["total_episodes"]
+
     @property
     def tolerance_s(self) -> float:
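With an episode subset selected, the two counters diverge; a small illustration using made-up numbers (the total would come from `info["total_episodes"]`):

```python
selected = [0, 5, 7]  # hypothetical episodes argument
total = 50            # hypothetical info["total_episodes"]

num_episodes = len(selected) if selected is not None else total
assert num_episodes == 3  # total_episodes stays at 50
```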
@@ -129,6 +175,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # 1e-4 to account for possible numerical error
         return 1 / self.fps - 1e-4
 
+    @property
+    def shapes(self) -> dict:
+        """Shapes for the different features."""
+        return self.info.get("shapes")
+
+    def get_episode_data_index(self) -> dict[str, torch.Tensor]:
+        episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(self.episode_dicts)}
+        if self.episodes is not None:
+            episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in self.episodes}
+
+        cumulative_lengths = list(accumulate(episode_lengths.values()))
+        return {
+            "from": torch.LongTensor([0] + cumulative_lengths[:-1]),
+            "to": torch.LongTensor(cumulative_lengths),
+        }
+
     def __len__(self):
         return self.num_samples
 
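`get_episode_data_index` replaces the previous safetensors-based `episode_data_index` with a cumulative sum over the per-episode lengths declared in `info.json`. A standalone sketch of the same computation, with made-up lengths:

```python
from itertools import accumulate

import torch

episode_lengths = {0: 50, 1: 30, 2: 20}
cumulative_lengths = list(accumulate(episode_lengths.values()))  # [50, 80, 100]

episode_data_index = {
    "from": torch.LongTensor([0] + cumulative_lengths[:-1]),  # frames 0, 50, 80
    "to": torch.LongTensor(cumulative_lengths),               # frames 50, 80, 100
}
```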
@@ -147,7 +209,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if self.video:
             item = load_from_videos(
                 item,
-                self.video_frame_keys,
+                self.video_keys,
                 self.videos_dir,
                 self.tolerance_s,
                 self.video_backend,
@@ -225,7 +287,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
     def __init__(
         self,
         repo_ids: list[str],
-        root: Path | None = DATA_DIR,
+        root: Path | None = LEROBOT_HOME,
         split: str = "train",
        image_transforms: Callable | None = None,
         delta_timestamps: dict[list[float]] | None = None,
@@ -14,7 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-import re
 import warnings
 from functools import cache
 from pathlib import Path
@@ -22,10 +21,9 @@ from typing import Dict
 
 import datasets
 import torch
-from datasets import load_dataset, load_from_disk
+from datasets import load_dataset
 from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
 from PIL import Image as PILImage
-from safetensors.torch import load_file
 from torchvision import transforms
 
 DATASET_CARD_TEMPLATE = """
@@ -96,7 +94,14 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
 
 
 @cache
-def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
+def get_hub_safe_version(repo_id: str, version: str) -> str:
+    num_version = float(version.strip("v"))
+    if num_version < 2:
+        raise ValueError(
+            f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
+            format with v2.0 that is not backward compatible. Please use our conversion script
+            first (convert_dataset_16_to_20.py) to convert your dataset to this new format."""
+        )
     api = HfApi()
     dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
     branches = [b.name for b in dataset_info.branches]
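The hunks from here on touch the dataset utilities module (presumably lerobot/common/datasets/utils.py, given the functions defined). `get_hub_safe_version` now refuses anything older than v2.0 before touching the hub; a brief sketch of the intended call, with an illustrative repo_id:

```python
from lerobot.common.datasets.utils import get_hub_safe_version

# CODEBASE_VERSION is "v2.0", so this resolves the matching branch/tag on the hub.
hub_version = get_hub_safe_version("lerobot/pusht", "v2.0")  # illustrative repo_id

# Requesting anything below v2.0 raises ValueError pointing at the
# conversion script (convert_dataset_16_to_20.py).
```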
@@ -116,56 +121,27 @@ def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
     return version
 
 
-def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
+def load_hf_dataset(
+    local_dir: Path,
+    data_path: str,
+    total_episodes: int,
+    episodes: list[int] | None = None,
+    split: str = "train",
+) -> datasets.Dataset:
     """hf_dataset contains all the observations, states, actions, rewards, etc."""
-    if root is not None:
-        hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
-        # TODO(rcadene): clean this which enables getting a subset of dataset
-        if split != "train":
-            if "%" in split:
-                raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
-            match_from = re.search(r"train\[(\d+):\]", split)
-            match_to = re.search(r"train\[:(\d+)\]", split)
-            if match_from:
-                from_frame_index = int(match_from.group(1))
-                hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
-            elif match_to:
-                to_frame_index = int(match_to.group(1))
-                hf_dataset = hf_dataset.select(range(to_frame_index))
-            else:
-                raise ValueError(
-                    f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
-                )
+    if episodes is None:
+        path = str(local_dir / "data")
+        hf_dataset = load_dataset("parquet", data_dir=path, split=split)
     else:
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
+        files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
+        files = [str(local_dir / fpath) for fpath in files]
+        hf_dataset = load_dataset("parquet", data_files=files, split=split)
 
     hf_dataset.set_transform(hf_transform_to_torch)
     return hf_dataset
 
 
-def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
-    """episode_data_index contains the range of indices for each episode
-
-    Example:
-    ```python
-    from_id = episode_data_index["from"][episode_id].item()
-    to_id = episode_data_index["to"][episode_id].item()
-    episode_frames = [dataset[i] for i in range(from_id, to_id)]
-    ```
-    """
-    if root is not None:
-        path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
-    else:
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        path = hf_hub_download(
-            repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
-        )
-
-    return load_file(path)
-
-
-def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
+def load_stats(repo_id: str, version: str, local_dir: Path) -> dict[str, dict[str, torch.Tensor]]:
     """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
 
     Example:
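When a subset of episodes is requested, `load_hf_dataset` resolves one parquet file per episode from the `data_path` template and hands the list to `datasets.load_dataset`. A sketch of the file-list construction alone, with an invented template:

```python
data_path = "data/train-episode_{episode_index:06d}-of-{total_episodes:06d}.parquet"  # invented template
total_episodes = 100
episodes = [3, 12]

files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
# ['data/train-episode_000003-of-000100.parquet',
#  'data/train-episode_000012-of-000100.parquet']
```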
@@ -173,47 +149,84 @@ def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
     normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
     ```
     """
-    if root is not None:
-        path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
-    else:
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        path = hf_hub_download(
-            repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
-        )
+    fpath = hf_hub_download(
+        repo_id, filename="meta/stats.json", local_dir=local_dir, repo_type="dataset", revision=version
+    )
+    with open(fpath) as f:
+        stats = json.load(f)
 
-    stats = load_file(path)
+    stats = flatten_dict(stats)
+    stats = {key: torch.tensor(value) for key, value in stats.items()}
     return unflatten_dict(stats)
 
 
-def load_info(repo_id, version, root) -> dict:
-    """info contains useful information regarding the dataset that are not stored elsewhere
+def load_info(repo_id: str, version: str, local_dir: Path) -> dict:
+    """info contains structural information about the dataset. It should be the reference and
+    act as the 'source of truth' for what's inside the dataset.
 
     Example:
     ```python
     print("frame per second used to collect the video", info["fps"])
     ```
     """
-    if root is not None:
-        path = Path(root) / repo_id / "meta_data" / "info.json"
-    else:
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)
-
-    with open(path) as f:
-        info = json.load(f)
-    return info
+    fpath = hf_hub_download(
+        repo_id, filename="meta/info.json", local_dir=local_dir, repo_type="dataset", revision=version
+    )
+    with open(fpath) as f:
+        return json.load(f)
 
 
-def load_videos(repo_id, version, root) -> Path:
-    if root is not None:
-        path = Path(root) / repo_id / "videos"
-    else:
-        # TODO(rcadene): we download the whole repo here. see if we can avoid this
-        safe_version = get_hf_dataset_safe_version(repo_id, version)
-        repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
-        path = Path(repo_dir) / "videos"
-
-    return path
+def load_tasks(repo_id: str, version: str, local_dir: Path) -> dict:
+    """tasks contains all the tasks of the dataset, indexed by their task_index.
+
+    Example:
+    ```json
+    {
+        "0": "Pick the Lego block and drop it in the box on the right."
+    }
+    ```
+    """
+    fpath = hf_hub_download(
+        repo_id, filename="meta/tasks.json", local_dir=local_dir, repo_type="dataset", revision=version
+    )
+    with open(fpath) as f:
+        return json.load(f)
+
+
+def download_episodes(
+    repo_id: str,
+    version: str,
+    local_dir: Path,
+    data_path: str,
+    video_keys: list,
+    total_episodes: int,
+    episodes: list[int] | None = None,
+    videos_path: str | None = None,
+) -> None:
+    """Downloads the dataset from the given 'repo_id' at the provided 'version'. If 'episodes' is given, this
+    will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
+    dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
+    in 'local_dir', they won't be downloaded again.
+
+    Note: Currently, if you're running this code offline but you already have the files in 'local_dir',
+    snapshot_download will still fail. This behavior will be fixed in an upcoming update of huggingface_hub.
+    """
+    # TODO(rcadene, aliberts): implement faster transfer
+    # https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
+    files = None
+    if episodes is not None:
+        files = [data_path.format(episode_index=ep_idx, total_episodes=total_episodes) for ep_idx in episodes]
+        if len(video_keys) > 0:
+            video_files = [
+                videos_path.format(video_key=vid_key, episode_index=ep_idx)
+                for vid_key in video_keys
+                for ep_idx in episodes
+            ]
+            files += video_files
+
+    snapshot_download(
+        repo_id, repo_type="dataset", revision=version, local_dir=local_dir, allow_patterns=files
+    )
 
 
 def load_previous_and_future_frames(
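`download_episodes` leans on snapshot_download's `allow_patterns` to fetch only the parquet and mp4 files of the selected episodes (or the whole repo when `files` is None). A rough equivalent of a filtered download, with an illustrative repo, directory and patterns:

```python
from pathlib import Path

from huggingface_hub import snapshot_download

local_dir = Path("~/.cache/huggingface/lerobot/lerobot/pusht").expanduser()  # illustrative
snapshot_download(
    "lerobot/pusht",
    repo_type="dataset",
    revision="v2.0",
    local_dir=local_dir,
    # None would download the whole repo; these patterns restrict to episode 3 only.
    allow_patterns=["data/*episode_000003*", "videos/*episode_000003*"],
)
```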