Refactor datasets into LeRobotDataset (#91)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Remi 2024-04-25 12:23:12 +02:00 committed by GitHub
parent e760e4cd63
commit 659c69a1c0
90 changed files with 167 additions and 352 deletions

View File

@ -73,15 +73,14 @@ environments ([aloha](https://github.com/huggingface/gym-aloha),
[pusht](https://github.com/huggingface/gym-pusht))
and follow the same API design.
When implementing a new dataset class (e.g. `AlohaDataset`) follow these steps:
- Update `available_datasets` in `lerobot/__init__.py`
- Copy it in the required `available_datasets` class attribute
When implementing a new dataset loadable with `LeRobotDataset`, follow these steps:
- Update `available_datasets_per_env` in `lerobot/__init__.py`
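For the dataset step, the change to `lerobot/__init__.py` is just a new entry in the per-environment mapping. A minimal sketch, where `lerobot/my_new_dataset` is a hypothetical repo id:
```python
# lerobot/__init__.py (sketch): register the dataset under its matching environment.
available_datasets_per_env = {
    "aloha": [
        "lerobot/aloha_sim_insertion_human",
        # ... existing entries ...
        "lerobot/my_new_dataset",  # hypothetical new dataset
    ],
    # "pusht": [...] and "xarm": [...] stay unchanged
}
```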
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py`
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
- Update `available_policies` in `lerobot/__init__.py`
- Update `available_policies` and `available_policies_per_env` in `lerobot/__init__.py`
- Set the required `name` class attribute.
- Update variables in `tests/test_available.py` by importing your new Policy class
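For the `name` class attribute mentioned above, a minimal sketch of what a policy class is expected to expose; the class itself is hypothetical, and real policies also implement the full training/inference API:
```python
import torch

class MyPolicy(torch.nn.Module):
    # Must match the string added to `available_policies` in lerobot/__init__.py.
    name = "my_policy"  # hypothetical
```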

View File

@ -118,30 +118,7 @@ wandb login
### Visualize datasets
You can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities:
```python
""" Copy pasted from `examples/1_visualize_dataset.py` """
import os
from pathlib import Path
import lerobot
from lerobot.common.datasets.aloha import AlohaDataset
from lerobot.scripts.visualize_dataset import render_dataset
print(lerobot.available_datasets)
# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
# TODO(rcadene): remove DATA_DIR
dataset = AlohaDataset("pusht", root=Path(os.environ.get("DATA_DIR")))
video_paths = render_dataset(
    dataset,
    out_dir="outputs/visualize_dataset/example",
    max_num_episodes=1,
)
print(video_paths)
# ['outputs/visualize_dataset/example/episode_0.mp4']
```
Check out [examples](./examples) to see how you can import our dataset class, download the data from the HuggingFace hub and use our rendering utilities.
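For instance, a minimal sketch of loading a dataset from the hub and inspecting it (the rendering utilities themselves are demonstrated in the examples):
```python
import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

print(lerobot.available_datasets)
# >>> ['lerobot/aloha_sim_insertion_human', ..., 'lerobot/pusht', ...]

dataset = LeRobotDataset("lerobot/pusht")
print(f"{dataset.num_episodes=} {dataset.num_samples=} {dataset.fps=} {dataset.image_keys=}")
```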
Or you can achieve the same result by executing our script from the command line:
```bash
@ -153,7 +130,7 @@ hydra.run.dir=outputs/visualize_dataset/example
### Evaluate a pretrained policy
Check out [example 2](./examples/2_evaluate_pretrained_policy.py) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation.
Check out [examples](./examples) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation.
Or you can achieve the same result by executing our script from the command line:
```bash
@ -176,24 +153,30 @@ See `python lerobot/scripts/eval.py --help` for more instructions.
### Train your own policy
You can import our dataset, environment, policy classes, and use our training utilities (if some data is missing, it will be automatically downloaded from HuggingFace hub): check out [example 3](./examples/3_train_policy.py). After you run this, you may want to revisit [example 2](./examples/2_evaluate_pretrained_policy.py) to evaluate your training output!
Check out [examples](./examples) to see how you can start training a model on a dataset, which will be automatically downloaded if needed.
In general, you can use our training script to easily train any policy on any environment:
```bash
python lerobot/scripts/train.py \
env=aloha \
task=sim_insertion \
dataset_id=aloha_sim_insertion_scripted \
dataset.repo_id=lerobot/aloha_sim_insertion_scripted \
policy=act \
hydra.run.dir=outputs/train/aloha_act
```
After training, you may want to revisit model evaluation to change the evaluation settings. In fact, during training, every checkpoint is already evaluated, but only on a small number of episodes for efficiency. Check out [examples](./examples) to evaluate any model checkpoint on more episodes to increase statistical significance.
## Contribute
If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md).
### Add a new dataset
```python
# TODO(rcadene, AdilZouitine): rewrite this section
```
To add a dataset to the hub, first login and use a token generated from [huggingface settings](https://huggingface.co/settings/tokens) with write access:
```bash
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
@ -255,6 +238,10 @@ python tests/scripts/mock_dataset.py --in-data-dir data/$DATASET --out-data-dir
### Add a pretrained policy
```python
# TODO(rcadene, alexander-soare): rewrite this section
```
Once you have trained a policy you may upload it to the HuggingFace hub.
Firstly, make sure you have a model repository set up on the hub. The hub ID looks like HF_USER/REPO_NAME.
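One way to do this from Python is with the `huggingface_hub` client; a sketch, where the repo name and checkpoint folder are placeholders:
```python
from huggingface_hub import HfApi

api = HfApi()
api.create_repo("HF_USER/REPO_NAME", repo_type="model", exist_ok=True)
api.upload_folder(
    repo_id="HF_USER/REPO_NAME",
    repo_type="model",
    folder_path="outputs/train/aloha_act",  # wherever your trained checkpoint files live
)
```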

View File

@ -23,6 +23,7 @@ from lerobot.common.datasets.utils import compute_stats, flatten_dict, hf_transf
def download_and_upload(root, revision, dataset_id):
# TODO(rcadene, adilzouitine): add community_id/user_id (e.g. "lerobot", "cadene") or repo_id (e.g. "lerobot/pusht")
if "pusht" in dataset_id:
download_and_upload_pusht(root, revision, dataset_id)
elif "xarm" in dataset_id:
@ -149,11 +150,11 @@ def push_to_hub(hf_dataset, episode_data_index, info, stats, root, revision, dat
# copy in tests folder, the first episode and the meta_data directory
num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
hf_dataset.select(range(num_items_first_ep)).with_format("torch").save_to_disk(
f"tests/data/{dataset_id}/train"
f"tests/data/lerobot/{dataset_id}/train"
)
if Path(f"tests/data/{dataset_id}/meta_data").exists():
shutil.rmtree(f"tests/data/{dataset_id}/meta_data")
shutil.copytree(meta_data_dir, f"tests/data/{dataset_id}/meta_data")
if Path(f"tests/data/lerobot/{dataset_id}/meta_data").exists():
shutil.rmtree(f"tests/data/lerobot/{dataset_id}/meta_data")
shutil.copytree(meta_data_dir, f"tests/data/lerobot/{dataset_id}/meta_data")
def download_and_upload_pusht(root, revision, dataset_id="pusht", fps=10):

View File

@ -1,5 +1,5 @@
"""
This script demonstrates the use of the PushtDataset class for handling and processing robotic datasets from Hugging Face.
This script demonstrates the use of the `LeRobotDataset` class for handling and processing robotic datasets from Hugging Face.
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
Features included in this script:
@ -11,22 +11,6 @@ Features included in this script:
- Demonstrating compatibility with PyTorch DataLoader for batch processing.
The script ends with examples of how to batch process data using PyTorch's DataLoader.
To try a different Hugging Face dataset, you can replace:
```python
dataset = PushtDataset()
```
by one of these:
```python
dataset = XarmDataset("xarm_lift_medium")
dataset = XarmDataset("xarm_lift_medium_replay")
dataset = XarmDataset("xarm_push_medium")
dataset = XarmDataset("xarm_push_medium_replay")
dataset = AlohaDataset("aloha_sim_insertion_human")
dataset = AlohaDataset("aloha_sim_insertion_scripted")
dataset = AlohaDataset("aloha_sim_transfer_cube_human")
dataset = AlohaDataset("aloha_sim_transfer_cube_scripted")
```
"""
from pathlib import Path
@ -34,31 +18,33 @@ from pathlib import Path
import imageio
import torch
from lerobot.common.datasets.pusht import PushtDataset
import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
# TODO(rcadene): List available datasets and their dataset ids (e.g. PushtDataset, AlohaDataset(dataset_id="aloha_sim_insertion_human"))
# print("List of available datasets", lerobot.available_datasets)
# # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted',
# # 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted',
# # 'pusht', 'xarm_lift_medium']
print("List of available datasets", lerobot.available_datasets)
# # >>> ['lerobot/aloha_sim_insertion_human', 'lerobot/aloha_sim_insertion_scripted',
# # 'lerobot/aloha_sim_transfer_cube_human', 'lerobot/aloha_sim_transfer_cube_scripted',
# # 'lerobot/pusht', 'lerobot/xarm_lift_medium']
repo_id = "lerobot/pusht"
# You can easily load datasets from LeRobot
dataset = PushtDataset()
# You can easily load a dataset from a Hugging Face repository
dataset = LeRobotDataset(repo_id)
# All LeRobot datasets are actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
# TODO(rcadene): update to make the print pretty
print(f"{dataset=}")
print(f"{dataset.hf_dataset=}")
# and provide additional utilities for robotics and compatibility with pytorch
# and provides additional utilities for robotics and compatibility with pytorch
print(f"number of samples/frames: {dataset.num_samples=}")
print(f"number of episodes: {dataset.num_episodes=}")
print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.image_keys=}")
# While the LeRobot dataset adds helpers for working within our library, we still expose the underlying Hugging Face dataset. It may be freely replaced or modified in place. Here we use filtering to keep only frames from episode 5.
# While the LeRobotDataset adds helpers for working within our library, we still expose the underlying Hugging Face dataset.
# It may be freely replaced or modified in place. Here we use filtering to keep only frames from episode 5.
# TODO(rcadene): remove this example of accessing hf_dataset
dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_index"] == 5)
@ -85,7 +71,7 @@ delta_timestamps = {
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
"action": [t / dataset.fps for t in range(64)],
}
dataset = PushtDataset(delta_timestamps=delta_timestamps)
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
print(f"{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
print(f"{dataset[0]['action'].shape=}") # (64,c)

View File

@ -8,31 +8,25 @@ Example:
print(lerobot.available_envs)
print(lerobot.available_tasks_per_env)
print(lerobot.available_datasets)
print(lerobot.available_datasets_per_env)
print(lerobot.available_policies)
print(lerobot.available_policies_per_env)
```
When implementing a new dataset class (e.g. `AlohaDataset`) follow these steps:
- Update `available_datasets` in `lerobot/__init__.py`
- Set the required `available_datasets` class attribute using the previously updated `lerobot.available_datasets`
When implementing a new dataset loadable with `LeRobotDataset`, follow these steps:
- Update `available_datasets_per_env` in `lerobot/__init__.py`
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py`
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
- Update `available_policies` in `lerobot/__init__.py`
- Update `available_policies` and `available_policies_per_env` in `lerobot/__init__.py`
- Set the required `name` class attribute.
- Update variables in `tests/test_available.py` by importing your new Policy class
"""
from lerobot.__version__ import __version__ # noqa: F401
available_envs = [
    "aloha",
    "pusht",
    "xarm",
]
available_tasks_per_env = {
    "aloha": [
        "AlohaInsertion-v0",
@ -41,22 +35,24 @@ available_tasks_per_env = {
    "pusht": ["PushT-v0"],
    "xarm": ["XarmLift-v0"],
}
available_envs = list(available_tasks_per_env.keys())
available_datasets = {
available_datasets_per_env = {
    "aloha": [
        "aloha_sim_insertion_human",
        "aloha_sim_insertion_scripted",
        "aloha_sim_transfer_cube_human",
        "aloha_sim_transfer_cube_scripted",
        "lerobot/aloha_sim_insertion_human",
        "lerobot/aloha_sim_insertion_scripted",
        "lerobot/aloha_sim_transfer_cube_human",
        "lerobot/aloha_sim_transfer_cube_scripted",
    ],
    "pusht": ["pusht"],
    "pusht": ["lerobot/pusht"],
    "xarm": [
        "xarm_lift_medium",
        "xarm_lift_medium_replay",
        "xarm_push_medium",
        "xarm_push_medium_replay",
        "lerobot/xarm_lift_medium",
        "lerobot/xarm_lift_medium_replay",
        "lerobot/xarm_push_medium",
        "lerobot/xarm_push_medium_replay",
    ],
}
available_datasets = [dataset for datasets in available_datasets_per_env.values() for dataset in datasets]
available_policies = [
    "act",
@ -71,10 +67,12 @@ available_policies_per_env = {
}
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
env_dataset_pairs = [(env, dataset) for env, datasets in available_datasets.items() for dataset in datasets]
env_dataset_pairs = [
    (env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
]
env_dataset_policy_triplets = [
    (env, dataset, policy)
    for env, datasets in available_datasets.items()
    for env, datasets in available_datasets_per_env.items()
    for dataset in datasets
    for policy in available_policies_per_env[env]
]

View File

@ -1,78 +0,0 @@
from pathlib import Path
import torch
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class AlohaDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human
https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human
https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted
"""
# Copied from lerobot/__init__.py
available_datasets = [
"aloha_sim_insertion_human",
"aloha_sim_insertion_scripted",
"aloha_sim_transfer_cube_human",
"aloha_sim_transfer_cube_scripted",
]
fps = 50
image_keys = ["observation.images.top"]
def __init__(
self,
dataset_id: str,
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
return len(self.hf_dataset)
@property
def num_episodes(self) -> int:
return len(self.hf_dataset.unique("episode_index"))
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
item = self.hf_dataset[idx]
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
if self.transform is not None:
item = self.transform(item)
return item

View File

@ -1,9 +1,12 @@
import logging
import os
from pathlib import Path
import torch
from omegaconf import OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
@ -11,22 +14,10 @@ def make_dataset(
cfg,
split="train",
):
if cfg.env.name == "xarm":
from lerobot.common.datasets.xarm import XarmDataset
clsfunc = XarmDataset
elif cfg.env.name == "pusht":
from lerobot.common.datasets.pusht import PushtDataset
clsfunc = PushtDataset
elif cfg.env.name == "aloha":
from lerobot.common.datasets.aloha import AlohaDataset
clsfunc = AlohaDataset
else:
raise ValueError(cfg.env.name)
if cfg.env.name not in cfg.dataset.repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({cfg.dataset.repo_id=}) and your environment ({cfg.env.name=})."
)
delta_timestamps = cfg.policy.get("delta_timestamps")
if delta_timestamps is not None:
@ -36,8 +27,8 @@ def make_dataset(
# TODO(rcadene): add data augmentations
dataset = clsfunc(
dataset_id=cfg.dataset_id,
dataset = LeRobotDataset(
cfg.dataset.repo_id,
split=split,
root=DATA_DIR,
delta_timestamps=delta_timestamps,

View File

@ -1,36 +1,21 @@
from pathlib import Path
import datasets
import torch
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_info,
load_previous_and_future_frames,
load_stats,
)
class XarmDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/xarm_lift_medium
https://huggingface.co/datasets/lerobot/xarm_lift_medium_replay
https://huggingface.co/datasets/lerobot/xarm_push_medium
https://huggingface.co/datasets/lerobot/xarm_push_medium_replay
"""
# Copied from lerobot/__init__.py
available_datasets = [
"xarm_lift_medium",
"xarm_lift_medium_replay",
"xarm_push_medium",
"xarm_push_medium_replay",
]
fps = 15
image_keys = ["observation.image"]
class LeRobotDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_id: str,
repo_id: str,
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
@ -38,16 +23,25 @@ class XarmDataset(torch.utils.data.Dataset):
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.repo_id = repo_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
self.hf_dataset = load_hf_dataset(repo_id, version, root, split)
self.episode_data_index = load_episode_data_index(repo_id, version, root)
self.stats = load_stats(repo_id, version, root)
self.info = load_info(repo_id, version, root)
@property
def fps(self) -> int:
return self.info["fps"]
@property
def image_keys(self) -> list[str]:
return [key for key, feats in self.hf_dataset.features.items() if isinstance(feats, datasets.Image)]
@property
def num_samples(self) -> int:

View File

@ -1,76 +0,0 @@
from pathlib import Path
import torch
from lerobot.common.datasets.utils import (
load_episode_data_index,
load_hf_dataset,
load_previous_and_future_frames,
load_stats,
)
class PushtDataset(torch.utils.data.Dataset):
"""
https://huggingface.co/datasets/lerobot/pusht
Arguments
----------
delta_timestamps : dict[list[float]] | None, optional
Loads data from frames with a shift in timestamps with a different strategy for each data key (e.g. state, action or image)
If `None`, no shift is applied to current timestamp and the data from the current frame is loaded.
"""
# Copied from lerobot/__init__.py
available_datasets = ["pusht"]
fps = 10
image_keys = ["observation.image"]
def __init__(
self,
dataset_id: str = "pusht",
version: str | None = "v1.1",
root: Path | None = None,
split: str = "train",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
):
super().__init__()
self.dataset_id = dataset_id
self.version = version
self.root = root
self.split = split
self.transform = transform
self.delta_timestamps = delta_timestamps
# load data from hub or locally when root is provided
self.hf_dataset = load_hf_dataset(dataset_id, version, root, split)
self.episode_data_index = load_episode_data_index(dataset_id, version, root)
self.stats = load_stats(dataset_id, version, root)
@property
def num_samples(self) -> int:
return len(self.hf_dataset)
@property
def num_episodes(self) -> int:
return len(self.episode_data_index["from"])
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
item = self.hf_dataset[idx]
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(
item,
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
)
if self.transform is not None:
item = self.transform(item)
return item

View File

@ -1,3 +1,4 @@
import json
from copy import deepcopy
from math import ceil
from pathlib import Path
@ -15,7 +16,7 @@ from torchvision import transforms
def flatten_dict(d, parent_key="", sep="/"):
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
For example:
```
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
@ -61,19 +62,17 @@ def hf_transform_to_torch(items_dict):
return items_dict
def load_hf_dataset(dataset_id, version, root, split) -> datasets.Dataset:
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None:
hf_dataset = load_from_disk(str(Path(root) / dataset_id / split))
hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
else:
# TODO(rcadene): remove dataset_id everywhere and use repo_id instead
repo_id = f"lerobot/{dataset_id}"
hf_dataset = load_dataset(repo_id, revision=version, split=split)
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor]:
def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
"""episode_data_index contains the range of indices for each episode
Example:
@ -84,9 +83,8 @@ def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor
```
"""
if root is not None:
path = Path(root) / dataset_id / "meta_data" / "episode_data_index.safetensors"
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
else:
repo_id = f"lerobot/{dataset_id}"
path = hf_hub_download(
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=version
)
@ -94,7 +92,7 @@ def load_episode_data_index(dataset_id, version, root) -> dict[str, torch.Tensor
return load_file(path)
def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
Example:
@ -103,15 +101,32 @@ def load_stats(dataset_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
```
"""
if root is not None:
path = Path(root) / dataset_id / "meta_data" / "stats.safetensors"
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
else:
repo_id = f"lerobot/{dataset_id}"
path = hf_hub_download(repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=version)
stats = load_file(path)
return unflatten_dict(stats)
def load_info(repo_id, version, root) -> dict:
"""info contains useful information regarding the dataset that are not stored elsewhere
Example:
```python
print("frame per second used to collect the video", info["fps"])
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "info.json"
else:
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=version)
with open(path) as f:
info = json.load(f)
return info
def load_previous_and_future_frames(
item: dict[str, torch.Tensor],
hf_dataset: datasets.Dataset,

View File

@ -26,7 +26,8 @@ fps: ???
offline_prioritized_sampler: true
dataset_id: ???
dataset:
repo_id: ???
n_action_steps: ???
n_obs_steps: ???

View File

@ -10,7 +10,8 @@ online_steps: 25000
fps: 50
dataset_id: aloha_sim_insertion_human
dataset:
repo_id: lerobot/aloha_sim_insertion_human
env:
name: aloha

View File

@ -10,7 +10,8 @@ online_steps: 25000
fps: 10
dataset_id: pusht
dataset:
repo_id: lerobot/pusht
env:
name: pusht

View File

@ -9,7 +9,8 @@ online_steps: 25000
fps: 15
dataset_id: xarm_lift_medium
dataset:
repo_id: lerobot/xarm_lift_medium
env:
name: xarm
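In code, these new `dataset.repo_id` keys are what the hydra overrides target; a sketch following the pattern used in the tests below, where the default config path is an assumption:
```python
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.utils.utils import init_hydra_config

cfg = init_hydra_config(
    "lerobot/configs/default.yaml",  # assumed location of the default config
    overrides=[
        "env=pusht",
        "policy=diffusion",
        "dataset.repo_id=lerobot/pusht",
    ],
)
dataset = make_dataset(cfg)
```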

View File

@ -16,22 +16,18 @@ from pathlib import Path
from safetensors.torch import save_file
from lerobot.common.datasets.pusht import PushtDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def save_dataset_to_safetensors(output_dir, dataset_id="pusht"):
data_dir = Path(output_dir) / dataset_id
def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
data_dir = Path(output_dir) / repo_id
if data_dir.exists():
shutil.rmtree(data_dir)
data_dir.mkdir(parents=True, exist_ok=True)
# TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
dataset = PushtDataset(
dataset_id=dataset_id,
split="train",
)
dataset = LeRobotDataset(repo_id)
# save 2 first frames of first episode
i = dataset.episode_data_index["from"][0].item()

View File

@ -4,9 +4,6 @@ import gymnasium as gym
import pytest
import lerobot
from lerobot.common.datasets.aloha import AlohaDataset
from lerobot.common.datasets.pusht import PushtDataset
from lerobot.common.datasets.xarm import XarmDataset
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
@ -27,25 +24,6 @@ def test_available_env_task(env_name: str, task_name: list):
assert gym_handle in gym.envs.registry, gym_handle
@pytest.mark.parametrize(
"env_name, dataset_class",
[
("aloha", AlohaDataset),
("pusht", PushtDataset),
("xarm", XarmDataset),
],
)
def test_available_datasets(env_name, dataset_class):
"""
This test verifies that the class attribute `available_datasets` for all
dataset classes is consistent with those listed in `lerobot/__init__.py`.
"""
available_env_datasets = lerobot.available_datasets[env_name]
assert set(available_env_datasets) == set(
dataset_class.available_datasets
), f"{env_name=} {available_env_datasets=}"
def test_available_policies():
"""
This test verifies that the class attribute `name` for all policies is
@ -58,3 +36,12 @@ def test_available_policies():
]
policies = [pol_cls.name for pol_cls in policy_classes]
assert set(policies) == set(lerobot.available_policies), policies
def test_print():
print(lerobot.available_envs)
print(lerobot.available_tasks_per_env)
print(lerobot.available_datasets)
print(lerobot.available_datasets_per_env)
print(lerobot.available_policies)
print(lerobot.available_policies_per_env)

View File

@ -12,7 +12,7 @@ from safetensors.torch import load_file
import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.pusht import PushtDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import (
compute_stats,
flatten_dict,
@ -26,13 +26,13 @@ from lerobot.common.utils.utils import init_hydra_config
from .utils import DEFAULT_CONFIG_PATH, DEVICE
@pytest.mark.parametrize("env_name, dataset_id, policy_name", lerobot.env_dataset_policy_triplets)
def test_factory(env_name, dataset_id, policy_name):
@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
def test_factory(env_name, repo_id, policy_name):
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[
f"env={env_name}",
f"dataset_id={dataset_id}",
f"dataset.repo_id={repo_id}",
f"policy={policy_name}",
f"device={DEVICE}",
],
@ -94,14 +94,13 @@ def test_compute_stats_on_xarm():
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
because we are working with a small dataset).
"""
# TODO(rcadene): Reduce size of dataset sample on which stats compute is tested
from lerobot.common.datasets.xarm import XarmDataset
dataset = XarmDataset(
dataset_id="xarm_lift_medium",
root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
dataset = LeRobotDataset(
"lerobot/xarm_lift_medium", root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
)
# reduce size of dataset sample on which stats compute is tested to 10 frames
dataset.hf_dataset = dataset.hf_dataset.select(range(10))
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
# dataset into even batches.
@ -241,16 +240,16 @@ def test_flatten_unflatten_dict():
def test_backward_compatibility():
"""This tests artifacts have been generated by `tests/scripts/save_dataset_to_safetensors.py`."""
# TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id)
dataset_id = "pusht"
data_dir = Path("tests/data/save_dataset_to_safetensors") / dataset_id
dataset = PushtDataset(
dataset_id=dataset_id,
split="train",
repo_id = "lerobot/pusht"
dataset = LeRobotDataset(
repo_id,
root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None,
)
data_dir = Path("tests/data/save_dataset_to_safetensors") / repo_id
def load_and_compare(i):
new_frame = dataset[i]
old_frame = load_file(data_dir / f"frame_{i}.safetensors")

View File

@ -19,10 +19,22 @@ from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
("xarm", "tdmpc", ["policy.mpc=true"]),
("pusht", "tdmpc", ["policy.mpc=false"]),
("pusht", "diffusion", []),
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_human"]),
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_human"]),
(
"aloha",
"act",
["env.task=AlohaInsertion-v0", "dataset.repo_id=lerobot/aloha_sim_insertion_scripted"],
),
(
"aloha",
"act",
["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_human"],
),
(
"aloha",
"act",
["env.task=AlohaTransferCube-v0", "dataset.repo_id=lerobot/aloha_sim_transfer_cube_scripted"],
),
],
)
@require_env

View File

@ -7,12 +7,12 @@ from .utils import DEFAULT_CONFIG_PATH
@pytest.mark.parametrize(
"dataset_id",
"repo_id",
[
"aloha_sim_insertion_human",
"lerobot/aloha_sim_insertion_human",
],
)
def test_visualize_dataset(tmpdir, dataset_id):
def test_visualize_dataset(tmpdir, repo_id):
# TODO(rcadene): this test might fail with other datasets/policies/envs, since visualize_dataset
# doesn't support multiple timesteps, which requires delta_timestamps to be None for images.
cfg = init_hydra_config(
@ -20,7 +20,7 @@ def test_visualize_dataset(tmpdir, dataset_id):
overrides=[
"policy=act",
"env=aloha",
f"dataset_id={dataset_id}",
f"dataset.repo_id={repo_id}",
],
)
video_paths = visualize_dataset(cfg, out_dir=tmpdir)