Merge remote-tracking branch 'upstream/alexander-soare/qol_patches_for_eval' into refactor_tdmpc

commit 4bafbe9009

@@ -11,7 +11,7 @@ body:
     id: system-info
     attributes:
       label: System Info
-      description: If needed, you can share your lerobot configuration with us by running `python -m lerobot.commands.env` and copy-pasting its outputs below
+      description: If needed, you can share your lerobot configuration with us by running `python -m lerobot.scripts.display_sys_info` and copy-pasting its outputs below
       render: Shell
       placeholder: lerobot version, OS, python version, numpy version, torch version, and lerobot's configuration
     validations:

@@ -117,11 +117,9 @@ jobs:
       # run tests & coverage
       #----------------------------------------------
      - name: Run tests
-        env:
-          LEROBOT_TESTS_DEVICE: cpu
        run: |
          source .venv/bin/activate
-          pytest --cov=./lerobot --cov-report=xml tests
+          pytest -v --cov=./lerobot --cov-report=xml tests

      # TODO(aliberts): Link with HF Codecov account
      # - name: Upload coverage reports to Codecov with GitHub Action

@@ -1,4 +1,4 @@
-exclude: ^(data/|tests/)
+exclude: ^(data/|tests/data)
 default_language_version:
   python: python3.10
 repos:

@@ -65,6 +65,26 @@ A good feature request addresses the following points:
 If your issue is well written we're already 80% of the way there by the time you
 post it.

+## Adding new policies, datasets or environments
+
+Look at our implementations for [datasets](./lerobot/common/datasets/), [policies](./lerobot/common/policies/),
+environments ([aloha](https://github.com/huggingface/gym-aloha),
+[xarm](https://github.com/huggingface/gym-xarm),
+[pusht](https://github.com/huggingface/gym-pusht))
+and follow the same api design.
+
+When implementing a new dataset class (e.g. `AlohaDataset`) follow these steps:
+- Update `available_datasets` in `lerobot/__init__.py`
+- Copy it in the required `available_datasets` class attribute
+
+When implementing a new environment (e.g. `gym_aloha`), follow these steps:
+- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py`
+
+When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
+- Update `available_policies` in `lerobot/__init__.py`
+- Set the required `name` class attribute.
+- Update variables in `tests/test_available.py` by importing your new Policy class
+
 ## Submitting a pull request (PR)

 Before writing code, we strongly advise you to search through the existing PRs or

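For orientation, here is a minimal sketch of the kind of dataset class the steps above describe; the class name, dataset id and attribute values are made up for illustration and are not part of the actual codebase:

```python
# Hypothetical skeleton only: `MyRobotDataset` and "myrobot_pick_place" are illustrative names.
import torch


class MyRobotDataset(torch.utils.data.Dataset):
    # Copied from lerobot/__init__.py (the entry added to `available_datasets`)
    available_datasets = ["myrobot_pick_place"]
    fps = 30
    image_keys = ["observation.image"]

    def __init__(self, dataset_id: str = "myrobot_pick_place", split: str = "train"):
        self.dataset_id = dataset_id
        self.split = split

    def __len__(self):
        return 0  # no data in this sketch

    def __getitem__(self, idx):
        raise NotImplementedError
```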
@@ -200,13 +200,13 @@ def download_and_upload_pusht(root, root_tests, dataset_id="pusht", fps=10):
         "episode_data_index_to": Value(dtype="int64", id=None),
     }
     features = Features(features)
-    dataset = Dataset.from_dict(data_dict, features=features)
-    dataset = dataset.with_format("torch")
+    hf_dataset = Dataset.from_dict(data_dict, features=features)
+    hf_dataset = hf_dataset.with_format("torch")

     num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
-    dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
-    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
-    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
+    hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
+    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
+    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")


 def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):

@@ -311,13 +311,13 @@ def download_and_upload_xarm(root, root_tests, dataset_id, fps=15):
         "episode_data_index_to": Value(dtype="int64", id=None),
     }
     features = Features(features)
-    dataset = Dataset.from_dict(data_dict, features=features)
-    dataset = dataset.with_format("torch")
+    hf_dataset = Dataset.from_dict(data_dict, features=features)
+    hf_dataset = hf_dataset.with_format("torch")

     num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
-    dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
-    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
-    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
+    hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
+    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
+    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")


 def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):

@@ -460,13 +460,13 @@ def download_and_upload_aloha(root, root_tests, dataset_id, fps=50):
         "episode_data_index_to": Value(dtype="int64", id=None),
     }
     features = Features(features)
-    dataset = Dataset.from_dict(data_dict, features=features)
-    dataset = dataset.with_format("torch")
+    hf_dataset = Dataset.from_dict(data_dict, features=features)
+    hf_dataset = hf_dataset.with_format("torch")

     num_items_first_ep = ep_dicts[0]["frame_id"].shape[0]
-    dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
-    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
-    dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")
+    hf_dataset.select(range(num_items_first_ep)).save_to_disk(f"{root_tests}/{dataset_id}/train")
+    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True)
+    hf_dataset.push_to_hub(f"lerobot/{dataset_id}", token=True, revision="v1.0")


 if __name__ == "__main__":

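The three conversion functions above push each dataset to the hub twice: once to the default branch and once under a `v1.0` revision. A sketch of loading a pinned revision back, using the PushT dataset from this same script as the example:

```python
from datasets import load_dataset

# Load the pinned "v1.0" revision pushed above; with_format("torch") makes
# __getitem__ return torch tensors, mirroring the dataset classes later in this diff.
hf_dataset = load_dataset("lerobot/pusht", revision="v1.0", split="train")
hf_dataset = hf_dataset.with_format("torch")
```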
@@ -0,0 +1,59 @@
+"""
+This script demonstrates the visualization of various robotic datasets from Hugging Face hub.
+It covers the steps from loading the datasets, filtering specific episodes, and converting the frame data to MP4 videos.
+Importantly, the dataset format is agnostic to any deep learning library and doesn't require using `lerobot` functions.
+It is compatible with pytorch, jax, numpy, etc.
+
+As an example, this script saves frames of episode number 5 of the PushT dataset to a mp4 video and saves the result here:
+`outputs/examples/1_visualize_hugging_face_datasets/episode_5.mp4`
+
+This script supports several Hugging Face datasets, among which:
+1. [Pusht](https://huggingface.co/datasets/lerobot/pusht)
+2. [Xarm Lift Medium](https://huggingface.co/datasets/lerobot/xarm_lift_medium)
+3. [Aloha Sim Insertion Human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human)
+4. [Aloha Sim Insertion Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_scripted)
+5. [Aloha Sim Transfer Cube Human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human)
+6. [Aloha Sim Transfer Cube Scripted](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted)
+
+To try a different Hugging Face dataset, you can replace this line:
+```python
+hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
+```
+by one of these:
+```python
+hf_dataset, fps = load_dataset("lerobot/xarm_lift_medium", split="train"), 15
+hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_human", split="train"), 50
+hf_dataset, fps = load_dataset("lerobot/aloha_sim_insertion_scripted", split="train"), 50
+hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_human", split="train"), 50
+hf_dataset, fps = load_dataset("lerobot/aloha_sim_transfer_cube_scripted", split="train"), 50
+```
+"""
+
+from pathlib import Path
+
+import imageio
+from datasets import load_dataset
+
+# TODO(rcadene): list available datasets on lerobot page using `datasets`
+
+# download/load hugging face dataset in pyarrow format
+hf_dataset, fps = load_dataset("lerobot/pusht", split="train"), 10
+
+# display name of dataset and its features
+print(f"{hf_dataset=}")
+print(f"{hf_dataset.features=}")
+
+# display useful statistics about frames and episodes, which are sequences of frames from the same video
+print(f"number of frames: {len(hf_dataset)=}")
+print(f"number of episodes: {len(hf_dataset.unique('episode_id'))=}")
+print(f"average number of frames per episode: {len(hf_dataset) / len(hf_dataset.unique('episode_id')):.3f}")
+
+# select the frames belonging to episode number 5
+hf_dataset = hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
+
+# load all frames of episode 5 in RAM in PIL format
+frames = hf_dataset["observation.image"]
+
+# save episode frames to a mp4 video
+Path("outputs/examples/1_load_hugging_face_dataset").mkdir(parents=True, exist_ok=True)
+imageio.mimsave("outputs/examples/1_load_hugging_face_dataset/episode_5.mp4", frames, fps=fps)

@@ -1,20 +0,0 @@
-import os
-from pathlib import Path
-
-import lerobot
-from lerobot.common.datasets.pusht import PushtDataset
-from lerobot.scripts.visualize_dataset import render_dataset
-
-print(lerobot.available_datasets)
-# >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted', 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted', 'pusht', 'xarm_lift_medium']
-
-# TODO(rcadene): remove DATA_DIR
-dataset = PushtDataset("pusht", root=Path(os.environ.get("DATA_DIR")))
-
-video_paths = render_dataset(
-    dataset,
-    out_dir="outputs/visualize_dataset/example",
-    max_num_episodes=1,
-)
-print(video_paths)
-# ['outputs/visualize_dataset/example/episode_0.mp4']

@@ -0,0 +1,98 @@
+"""
+This script demonstrates the use of the PushtDataset class for handling and processing robotic datasets from Hugging Face.
+It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
+
+Features included in this script:
+- Loading a dataset and accessing its properties.
+- Filtering data by episode number.
+- Converting tensor data for visualization.
+- Saving video files from dataset frames.
+- Using advanced dataset features like timestamp-based frame selection.
+- Demonstrating compatibility with PyTorch DataLoader for batch processing.
+
+The script ends with examples of how to batch process data using PyTorch's DataLoader.
+
+To try a different Hugging Face dataset, you can replace:
+```python
+dataset = PushtDataset()
+```
+by one of these:
+```python
+dataset = XarmDataset()
+dataset = AlohaDataset("aloha_sim_insertion_human")
+dataset = AlohaDataset("aloha_sim_insertion_scripted")
+dataset = AlohaDataset("aloha_sim_transfer_cube_human")
+dataset = AlohaDataset("aloha_sim_transfer_cube_scripted")
+```
+"""
+
+from pathlib import Path
+
+import imageio
+import torch
+
+from lerobot.common.datasets.pusht import PushtDataset
+
+# TODO(rcadene): List available datasets and their dataset ids (e.g. PushtDataset, AlohaDataset(dataset_id="aloha_sim_insertion_human"))
+# print("List of available datasets", lerobot.available_datasets)
+# # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted',
+# # 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted',
+# # 'pusht', 'xarm_lift_medium']
+
+
+# You can easily load datasets from LeRobot
+dataset = PushtDataset()
+
+# All LeRobot datasets are actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
+print(f"{dataset=}")
+print(f"{dataset.hf_dataset=}")
+
+# and provide additional utilities for robotics and compatibility with pytorch
+print(f"number of samples/frames: {dataset.num_samples=}")
+print(f"number of episodes: {dataset.num_episodes=}")
+print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
+print(f"frames per second used during data collection: {dataset.fps=}")
+print(f"keys to access images from cameras: {dataset.image_keys=}")
+
+# While the LeRobot dataset adds helpers for working within our library, we still expose the underlying Hugging Face dataset. It may be freely replaced or modified in place. Here we use the filtering to keep only frames from episode 5.
+dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_id"] == 5)
+
+# LeRobot datasets actually subclass PyTorch datasets. So you can do everything you know and love from working with the latter, for example: iterating through the dataset. Here we grab all the image frames.
+frames = [sample["observation.image"] for sample in dataset]
+
+# but frames are now channel first to follow pytorch convention,
+# to view them, we convert to channel last
+frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
+
+# and finally save them to a mp4 video
+Path("outputs/examples/2_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
+imageio.mimsave("outputs/examples/2_load_lerobot_dataset/episode_5.mp4", frames, fps=dataset.fps)
+
+# For many machine learning applications we need to load histories of past observations, or trajectories of future actions. Our datasets can load previous and future frames for each key/modality,
+# using timestamps differences with the current loaded frame. For instance:
+delta_timestamps = {
+    # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
+    "observation.image": [-1, -0.5, -0.20, 0],
+    # loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame
+    "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
+    # loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
+    "action": [t / dataset.fps for t in range(64)],
+}
+dataset = PushtDataset(delta_timestamps=delta_timestamps)
+print(f"{dataset[0]['observation.image'].shape=}")  # (4,c,h,w)
+print(f"{dataset[0]['observation.state'].shape=}")  # (8,c)
+print(f"{dataset[0]['action'].shape=}")  # (64,c)
+
+# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers
+# because they are just PyTorch datasets.
+dataloader = torch.utils.data.DataLoader(
+    dataset,
+    num_workers=4,
+    batch_size=32,
+    shuffle=True,
+)
+for batch in dataloader:
+    print(f"{batch['observation.image'].shape=}")  # (32,4,c,h,w)
+    print(f"{batch['observation.state'].shape=}")  # (32,8,c)
+    print(f"{batch['action'].shape=}")  # (32,64,c)
+    break

@@ -7,7 +7,7 @@ from pathlib import Path

 from huggingface_hub import snapshot_download

-from lerobot.common.utils import init_hydra_config
+from lerobot.common.utils.utils import init_hydra_config
 from lerobot.scripts.eval import eval

 # Get a pretrained policy from the hub.

@@ -13,7 +13,7 @@ from omegaconf import OmegaConf
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
-from lerobot.common.utils import init_hydra_config
+from lerobot.common.utils.utils import init_hydra_config

 output_directory = Path("outputs/train/example_pusht_diffusion")
 os.makedirs(output_directory, exist_ok=True)

@@ -7,16 +7,22 @@ Example:
     import lerobot
     print(lerobot.available_envs)
     print(lerobot.available_tasks_per_env)
-    print(lerobot.available_datasets_per_env)
     print(lerobot.available_datasets)
     print(lerobot.available_policies)
+    print(lerobot.available_policies_per_env)
     ```

-When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
-- Set the required class attributes: `available_datasets`.
-- Set the required class attributes: `name`.
-- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
-- Update variables in `tests/test_available.py` by importing your new class
+When implementing a new dataset class (e.g. `AlohaDataset`) follow these steps:
+- Update `available_datasets` in `lerobot/__init__.py`
+- Set the required `available_datasets` class attribute using the previously updated `lerobot.available_datasets`
+
+When implementing a new environment (e.g. `gym_aloha`), follow these steps:
+- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py`
+
+When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
+- Update `available_policies` in `lerobot/__init__.py`
+- Set the required `name` class attribute.
+- Update variables in `tests/test_available.py` by importing your new Policy class
 """

 from lerobot.__version__ import __version__  # noqa: F401
@@ -36,7 +42,7 @@ available_tasks_per_env = {
     "xarm": ["XarmLift-v0"],
 }

-available_datasets_per_env = {
+available_datasets = {
     "aloha": [
         "aloha_sim_insertion_human",
         "aloha_sim_insertion_scripted",
@@ -47,10 +53,23 @@ available_datasets_per_env = {
     "xarm": ["xarm_lift_medium"],
 }

-available_datasets = [dataset for env in available_envs for dataset in available_datasets_per_env[env]]
-
 available_policies = [
     "act",
     "diffusion",
     "tdmpc",
 ]
+
+available_policies_per_env = {
+    "aloha": ["act"],
+    "pusht": ["diffusion"],
+    "xarm": ["tdmpc"],
+}
+
+env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
+env_dataset_pairs = [(env, dataset) for env, datasets in available_datasets.items() for dataset in datasets]
+env_dataset_policy_triplets = [
+    (env, dataset, policy)
+    for env, datasets in available_datasets.items()
+    for dataset in datasets
+    for policy in available_policies_per_env[env]
+]

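The new `env_dataset_pairs` and `env_dataset_policy_triplets` lists are plain lists of tuples meant for enumeration (e.g. test parametrization in `tests/test_available.py`). A small sketch of iterating over them, assuming `lerobot` is installed with the registries defined above:

```python
import lerobot

# every (environment, dataset) combination declared in available_datasets
for env, dataset in lerobot.env_dataset_pairs:
    print(env, dataset)

# every (environment, dataset, policy) combination, via available_policies_per_env
for env, dataset, policy in lerobot.env_dataset_policy_triplets:
    print(env, dataset, policy)
```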
@@ -14,6 +14,7 @@ class AlohaDataset(torch.utils.data.Dataset):
     https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted
     """

+    # Copied from lerobot/__init__.py
     available_datasets = [
         "aloha_sim_insertion_human",
         "aloha_sim_insertion_scripted",
@@ -40,32 +41,33 @@ class AlohaDataset(torch.utils.data.Dataset):
         self.transform = transform
         self.delta_timestamps = delta_timestamps
         if self.root is not None:
-            self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
+            self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
         else:
-            self.data_dict = load_dataset(
+            self.hf_dataset = load_dataset(
                 f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
             )
-        self.data_dict = self.data_dict.with_format("torch")
+        self.hf_dataset = self.hf_dataset.with_format("torch")

     @property
     def num_samples(self) -> int:
-        return len(self.data_dict)
+        return len(self.hf_dataset)

     @property
     def num_episodes(self) -> int:
-        return len(self.data_dict.unique("episode_id"))
+        return len(self.hf_dataset.unique("episode_id"))

     def __len__(self):
         return self.num_samples

     def __getitem__(self, idx):
-        item = self.data_dict[idx]
+        item = self.hf_dataset[idx]

         if self.delta_timestamps is not None:
             item = load_previous_and_future_frames(
                 item,
-                self.data_dict,
+                self.hf_dataset,
                 self.delta_timestamps,
+                tol=1 / self.fps - 1e-4,  # 1e-4 to account for possible numerical error
             )

         # convert images from channel last (PIL) to channel first (pytorch)

@@ -17,13 +17,14 @@ class PushtDataset(torch.utils.data.Dataset):
     If `None`, no shift is applied to current timestamp and the data from the current frame is loaded.
     """

+    # Copied from lerobot/__init__.py
     available_datasets = ["pusht"]
     fps = 10
     image_keys = ["observation.image"]

     def __init__(
         self,
-        dataset_id: str,
+        dataset_id: str = "pusht",
         version: str | None = "v1.0",
         root: Path | None = None,
         split: str = "train",
@@ -38,32 +39,33 @@ class PushtDataset(torch.utils.data.Dataset):
         self.transform = transform
         self.delta_timestamps = delta_timestamps
         if self.root is not None:
-            self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
+            self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
         else:
-            self.data_dict = load_dataset(
+            self.hf_dataset = load_dataset(
                 f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
             )
-        self.data_dict = self.data_dict.with_format("torch")
+        self.hf_dataset = self.hf_dataset.with_format("torch")

     @property
     def num_samples(self) -> int:
-        return len(self.data_dict)
+        return len(self.hf_dataset)

     @property
     def num_episodes(self) -> int:
-        return len(self.data_dict.unique("episode_id"))
+        return len(self.hf_dataset.unique("episode_id"))

     def __len__(self):
         return self.num_samples

     def __getitem__(self, idx):
-        item = self.data_dict[idx]
+        item = self.hf_dataset[idx]

         if self.delta_timestamps is not None:
             item = load_previous_and_future_frames(
                 item,
-                self.data_dict,
+                self.hf_dataset,
                 self.delta_timestamps,
+                tol=1 / self.fps - 1e-4,  # 1e-4 to account for possible numerical error
             )

         # convert images from channel last (PIL) to channel first (pytorch)

@@ -1,6 +1,7 @@
 from copy import deepcopy
 from math import ceil

+import datasets
 import einops
 import torch
 import tqdm
@@ -8,31 +9,41 @@ import tqdm

 def load_previous_and_future_frames(
     item: dict[str, torch.Tensor],
-    data_dict: dict[str, torch.Tensor],
+    hf_dataset: datasets.Dataset,
     delta_timestamps: dict[str, list[float]],
-    tol: float = 0.04,
+    tol: float,
 ) -> dict[torch.Tensor]:
     """
-    Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}),
-    this function computes for each given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset.
+    Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
+    some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
+    given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset.

-    Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. tol=0.04), this function raises an AssertionError.
-    When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp,
-    the violation of the tolerance doesnt raise an AssertionError, and the function populates a boolean array indicating which frames are outside of the episode range.
-    For instance, this boolean array is useful during batched training to not supervise actions associated to timestamps coming after the end of the episode,
-    or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode.
+    Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
+    raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
+    the last available timestamp, the violation of the tolerance doesn't raise an AssertionError, and the function
+    populates a boolean array indicating which frames are outside of the episode range. For instance, this boolean array
+    is useful during batched training to not supervise actions associated to timestamps coming after the end of the
+    episode, or to pad the observations in a specific way. Note that by default the observation frames before the start
+    of the episode are the same as the first frame of the episode.

     Parameters:
-    - item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
-    - data_dict (dict): A dictionary containing the full dataset. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
-    - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be retrieved. These deltas are added to the item timestamp to form the query timestamps.
-    - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04.
+    - item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key
+      corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
+    - hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
+      modality (e.g., "timestamp", "observation.image", "action").
+    - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
+      retrieved. These deltas are added to the item timestamp to form the query timestamps.
+    - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query
+      timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the
+      smallest expected inter-frame period, but large enough to account for jitter.

     Returns:
-    - The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for each modality (e.g. "observation.image_is_pad").
+    - The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for
+      each modality (e.g. "observation.image_is_pad").

     Raises:
-    - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
+    - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization
+      issues with timestamps during data collection.
     """
     # get indices of the frames associated to the episode, and their timestamps
     ep_data_id_from = item["episode_data_index_from"].item()
@@ -40,7 +51,7 @@ def load_previous_and_future_frames(
     ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)

     # load timestamps
-    ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
+    ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]

     # we make the assumption that the timestamps are sorted
     ep_first_ts = ep_timestamps[0]
@@ -70,7 +81,7 @@ def load_previous_and_future_frames(
         data_ids = ep_data_ids[argmin_]

         # load frames modality
-        item[key] = data_dict.select_columns(key)[data_ids][key]
+        item[key] = hf_dataset.select_columns(key)[data_ids][key]
         item[f"{key}_is_pad"] = is_pad

     return item

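The dataset classes in this change pass `tol=1 / self.fps - 1e-4`, i.e. a tolerance just under the expected inter-frame period. A standalone illustration of that matching rule (not the helper's actual implementation, just the tolerance logic in isolation):

```python
import torch

fps = 10
tol = 1 / fps - 1e-4  # slightly under the 0.1 s inter-frame period

ep_timestamps = torch.tensor([0.0, 0.1, 0.2, 0.3])  # timestamps of one episode
query_ts = torch.tensor([-0.1, 0.25])  # one query before the episode, one inside it

# closest frame for each query, and whether the gap exceeds the tolerance
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
min_, argmin_ = dist.min(dim=1)
is_pad = min_ > tol  # out-of-range queries are flagged as padding rather than errors
print(argmin_.tolist(), is_pad.tolist())  # [0, 2] [True, False]
```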
@@ -11,15 +11,14 @@ class XarmDataset(torch.utils.data.Dataset):
     https://huggingface.co/datasets/lerobot/xarm_lift_medium
     """

-    available_datasets = [
-        "xarm_lift_medium",
-    ]
+    # Copied from lerobot/__init__.py
+    available_datasets = ["xarm_lift_medium"]
     fps = 15
     image_keys = ["observation.image"]

     def __init__(
         self,
-        dataset_id: str,
+        dataset_id: str = "xarm_lift_medium",
         version: str | None = "v1.0",
         root: Path | None = None,
         split: str = "train",
@@ -34,32 +33,33 @@ class XarmDataset(torch.utils.data.Dataset):
         self.transform = transform
         self.delta_timestamps = delta_timestamps
         if self.root is not None:
-            self.data_dict = load_from_disk(Path(self.root) / self.dataset_id / self.split)
+            self.hf_dataset = load_from_disk(Path(self.root) / self.dataset_id / self.split)
         else:
-            self.data_dict = load_dataset(
+            self.hf_dataset = load_dataset(
                 f"lerobot/{self.dataset_id}", revision=self.version, split=self.split
             )
-        self.data_dict = self.data_dict.with_format("torch")
+        self.hf_dataset = self.hf_dataset.with_format("torch")

     @property
     def num_samples(self) -> int:
-        return len(self.data_dict)
+        return len(self.hf_dataset)

     @property
     def num_episodes(self) -> int:
-        return len(self.data_dict.unique("episode_id"))
+        return len(self.hf_dataset.unique("episode_id"))

     def __len__(self):
         return self.num_samples

     def __getitem__(self, idx):
-        item = self.data_dict[idx]
+        item = self.hf_dataset[idx]

         if self.delta_timestamps is not None:
             item = load_previous_and_future_frames(
                 item,
-                self.data_dict,
+                self.hf_dataset,
                 self.delta_timestamps,
+                tol=1 / self.fps - 1e-4,  # 1e-4 to account for possible numerical error
             )

         # convert images from channel last (PIL) to channel first (pytorch)

@@ -2,7 +2,7 @@ import inspect

 from omegaconf import DictConfig, OmegaConf

-from lerobot.common.utils import get_safe_torch_device
+from lerobot.common.utils.utils import get_safe_torch_device


 def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):

@@ -11,7 +11,7 @@ import torch.nn as nn

 import lerobot.common.policies.tdmpc.helper as h
 from lerobot.common.policies.utils import populate_queues
-from lerobot.common.utils import get_safe_torch_device
+from lerobot.common.utils.utils import get_safe_torch_device

 FIRST_FRAME = 0

@@ -0,0 +1,44 @@
+import importlib
+import logging
+
+
+def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool:
+    """Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py
+    Check if the package spec exists and grab its version to avoid importing a local directory.
+    **Note:** this doesn't work for all packages.
+    """
+    package_exists = importlib.util.find_spec(pkg_name) is not None
+    package_version = "N/A"
+    if package_exists:
+        try:
+            # Primary method to get the package version
+            package_version = importlib.metadata.version(pkg_name)
+        except importlib.metadata.PackageNotFoundError:
+            # Fallback method: Only for "torch" and versions containing "dev"
+            if pkg_name == "torch":
+                try:
+                    package = importlib.import_module(pkg_name)
+                    temp_version = getattr(package, "__version__", "N/A")
+                    # Check if the version contains "dev"
+                    if "dev" in temp_version:
+                        package_version = temp_version
+                        package_exists = True
+                    else:
+                        package_exists = False
+                except ImportError:
+                    # If the package can't be imported, it's not available
+                    package_exists = False
+            else:
+                # For packages other than "torch", don't attempt the fallback and set as not available
+                package_exists = False
+    logging.debug(f"Detected {pkg_name} version: {package_version}")
+    if return_version:
+        return package_exists, package_version
+    else:
+        return package_exists
+
+
+_torch_available, _torch_version = is_package_available("torch", return_version=True)
+_gym_xarm_available = is_package_available("gym_xarm")
+_gym_aloha_available = is_package_available("gym_aloha")
+_gym_pusht_available = is_package_available("gym_pusht")

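A usage sketch for the new helper, assuming the file lands as `lerobot/common/utils/import_utils.py` (the module path is not shown in this diff):

```python
# Assumed module path; adjust if the file lives elsewhere.
from lerobot.common.utils.import_utils import _gym_pusht_available, is_package_available

torch_ok, torch_version = is_package_available("torch", return_version=True)
print(f"torch available: {torch_ok} (version {torch_version})")

# Gate an optional simulation dependency on the precomputed flag.
if _gym_pusht_available:
    import gym_pusht  # noqa: F401
```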
@@ -15,7 +15,7 @@ cuda_version = torch._C._cuda_getCompiledVersion() if torch.version.cuda is not


 # TODO(aliberts): refactor into an actual command `lerobot env`
-def get_env_info() -> dict:
+def display_sys_info() -> dict:
     """Run this to get basic system info to help for tracking issues & bugs."""
     info = {
         "`lerobot` version": version,
@@ -40,4 +40,4 @@ def format_dict(d: dict) -> str:


 if __name__ == "__main__":
-    get_env_info()
+    display_sys_info()

@@ -44,13 +44,14 @@ import torch
 from datasets import Dataset
 from huggingface_hub import snapshot_download
 from PIL import Image as PILImage
+from tqdm import trange

 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import postprocess_action, preprocess_observation
 from lerobot.common.logger import log_output_dir
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
+from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed


 def write_video(video_path, stacked_frames, fps):
@@ -64,8 +65,12 @@ def eval_policy(
     video_dir: Path = None,
     # TODO(rcadene): make it possible to overwrite fps? we should use env.fps
     transform: callable = None,
+    return_episode_data: bool = False,
     seed=None,
 ):
+    """
+    set `return_episode_data` to return a Hugging Face dataset object in an "episodes" key of the return dict.
+    """
     fps = env.unwrapped.metadata["render_fps"]

     if policy is not None:
@@ -118,9 +123,12 @@ def eval_policy(

     done = torch.tensor([False for _ in env.envs])
     step = 0
+    max_steps = env.envs[0]._max_episode_steps
+    progbar = trange(max_steps, desc=f"Running eval with {max_steps} steps (maximum) per rollout.")
     while not done.all():
         # format from env keys to lerobot keys
         observation = preprocess_observation(observation)
+        if return_episode_data:
             observations.append(deepcopy(observation))

         # apply transform to normalize the observations
@@ -167,13 +175,16 @@ def eval_policy(
             successes.append(success)

         step += 1
+        progbar.update()

     env.close()

     # add the last observation when the env is done
+    if return_episode_data:
         observation = preprocess_observation(observation)
         observations.append(deepcopy(observation))

+    if return_episode_data:
         new_obses = {}
         for key in observations[0].keys():  # noqa: SIM118
             new_obses[key] = torch.stack([obs[key] for obs in observations], dim=1)
@@ -209,6 +220,7 @@ def eval_policy(

         # TODO(rcadene): We need to add a missing last frame which is the observation
         # of a done state. it is critical to have this frame for tdmpc to predict a "done observation/state"
+        if return_episode_data:
             ep_dict = {
                 "action": actions[ep_id, :num_frames],
                 "episode_id": torch.tensor([ep_id] * num_frames),
@@ -226,6 +238,7 @@ def eval_policy(
         idx_from += num_frames

     # similar logic is implemented in dataset preprocessing
+    if return_episode_data:
         data_dict = {}
         keys = ep_dicts[0].keys()
         for key in keys:
@@ -242,7 +255,7 @@ def eval_policy(

         data_dict["index"] = torch.arange(0, total_frames, 1)

-        data_dict = Dataset.from_dict(data_dict).with_format("torch")
+        hf_dataset = Dataset.from_dict(data_dict).with_format("torch")

     if max_episodes_rendered > 0:
         batch_stacked_frames = np.stack(ep_frames, 1)  # (b, t, *)
@@ -250,7 +263,7 @@ def eval_policy(
         for stacked_frames, done_index in zip(
             batch_stacked_frames, done_indices.flatten().tolist(), strict=False
         ):
-            if episode_counter >= num_episodes:
+            if episode_counter >= max_episodes_rendered:
                 continue
             video_dir.mkdir(parents=True, exist_ok=True)
             video_path = video_dir / f"eval_episode_{episode_counter}.mp4"
@@ -293,8 +306,9 @@ def eval_policy(
             "eval_s": time.time() - start,
             "eval_ep_s": (time.time() - start) / num_episodes,
         },
-        "episodes": data_dict,
     }
+    if return_episode_data:
+        info["episodes"] = hf_dataset
     if max_episodes_rendered > 0:
         info["videos"] = videos
     return info
@@ -333,6 +347,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
         max_episodes_rendered=10,
         video_dir=Path(out_dir) / "eval",
         transform=transform,
+        return_episode_data=False,
         seed=cfg.seed,
     )
     print(info["aggregated"])
@@ -340,7 +355,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
     # Save info
     with open(Path(out_dir) / "eval_info.json", "w") as f:
         # remove pytorch tensors which are not serializable to save the evaluation results only
-        del info["episodes"]
         del info["videos"]
         json.dump(info, f, indent=2)

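The episode data returned by `eval_policy` when `return_episode_data=True` is assembled with the same `Dataset.from_dict(...).with_format("torch")` pattern visible above. A standalone sketch of that pattern with dummy columns (the column names mirror `eval_policy`; the values are made up):

```python
from datasets import Dataset

data_dict = {
    "action": [[0.0, 0.0]] * 6,
    "episode_id": [0, 0, 0, 1, 1, 1],
    "index": list(range(6)),
}
hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
print(hf_dataset[0])  # values come back as torch tensors
```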
@@ -2,17 +2,18 @@ import logging
 from copy import deepcopy
 from pathlib import Path

+import datasets
 import hydra
 import torch
 from datasets import concatenate_datasets
-from datasets.utils.logging import disable_progress_bar
+from datasets.utils import disable_progress_bars, enable_progress_bars

 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import (
+from lerobot.common.utils.utils import (
     format_big_number,
     get_safe_torch_device,
     init_logging,
@@ -130,15 +131,40 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
     return -(n_off * pc_on) / (n_on * (pc_on - 1))


-def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_online_samples):
-    first_episode_id = data_dict.select_columns("episode_id")[0]["episode_id"].item()
-    first_index = data_dict.select_columns("index")[0]["index"].item()
+def add_episodes_inplace(
+    online_dataset: torch.utils.data.Dataset,
+    concat_dataset: torch.utils.data.ConcatDataset,
+    sampler: torch.utils.data.WeightedRandomSampler,
+    hf_dataset: datasets.Dataset,
+    pc_online_samples: float,
+):
+    """
+    Modifies the online_dataset, concat_dataset, and sampler in place by integrating
+    new episodes from hf_dataset into the online_dataset, updating the concatenated
+    dataset's structure and adjusting the sampling strategy based on the specified
+    percentage of online samples.
+
+    Parameters:
+    - online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
+    - concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
+      offline and online datasets, used for sampling purposes.
+    - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
+      reflect changes in the dataset sizes and specified sampling weights.
+    - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
+    - pc_online_samples (float): The target percentage of samples that should come from
+      the online dataset during sampling operations.
+
+    Raises:
+    - AssertionError: If the first episode_id or index in hf_dataset is not 0
+    """
+    first_episode_id = hf_dataset.select_columns("episode_id")[0]["episode_id"].item()
+    first_index = hf_dataset.select_columns("index")[0]["index"].item()
     assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
     assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"

     if len(online_dataset) == 0:
         # initialize online dataset
-        online_dataset.data_dict = data_dict
+        online_dataset.hf_dataset = hf_dataset
     else:
         # find episode index and data frame indices according to previous episode in online_dataset
         start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
@@ -152,11 +178,12 @@ def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_online_samples):
             example["episode_data_index_to"] += start_index
             return example

-        disable_progress_bar()  # map has a tqdm progress bar
-        data_dict = data_dict.map(shift_indices)
+        disable_progress_bars()  # map has a tqdm progress bar
+        hf_dataset = hf_dataset.map(shift_indices)
+        enable_progress_bars()

         # extend online dataset
-        online_dataset.data_dict = concatenate_datasets([online_dataset.data_dict, data_dict])
+        online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])

     # update the concatenated dataset length used during sampling
     concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
@@ -275,7 +302,7 @@ def train(cfg: dict, out_dir=None, job_name=None):

     # create an empty online dataset similar to offline dataset
     online_dataset = deepcopy(offline_dataset)
-    online_dataset.data_dict = {}
+    online_dataset.hf_dataset = {}

     # create dataloader for online training
     concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
@@ -304,12 +331,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
             rollout_env,
             policy,
             transform=offline_dataset.transform,
+            return_episode_data=True,
             seed=cfg.seed,
         )

         online_pc_sampling = cfg.get("demo_schedule", 0.5)
         add_episodes_inplace(
-            eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling
+            online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
         )

         for _ in range(cfg.policy.utd):

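The `calculate_online_sample_weight` helper shown above can be checked in isolation: assuming offline frames keep a sampling weight of 1 (an assumption, not shown in this diff), the returned weight makes the expected fraction of online draws equal to `pc_on`:

```python
def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
    # copied verbatim from the training script above
    return -(n_off * pc_on) / (n_on * (pc_on - 1))

n_off, n_on, pc_on = 10_000, 500, 0.5
w = calculate_online_sample_weight(n_off, n_on, pc_on)
online_fraction = (n_on * w) / (n_on * w + n_off * 1.0)
print(w, online_fraction)  # 20.0 0.5 -> matches the requested pc_on
```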
@@ -9,7 +9,7 @@ import torch

 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.logger import log_output_dir
-from lerobot.common.utils import init_logging
+from lerobot.common.utils.utils import init_logging

 NUM_EPISODES_TO_RENDER = 50
 MAX_NUM_STEPS = 1000

@@ -1,63 +0,0 @@
-"""
-This script is designed to facilitate the creation of a subset of an existing dataset by selecting a specific number of frames from the original dataset.
-This subset can then be used for running quick unit tests.
-The script takes an input directory containing the original dataset and an output directory where the subset of the dataset will be saved.
-Additionally, the number of frames to include in the subset can be specified.
-The script ensures that the subset is a representative sample of the original dataset by copying the specified number of frames and retaining the structure and format of the data.
-
-Usage:
-    Run the script with the following command, specifying the path to the input data directory,
-    the path to the output data directory, and optionally the number of frames to include in the subset dataset:
-
-    `python tests/scripts/mock_dataset.py --in-data-dir path/to/input_data --out-data-dir path/to/output_data`
-
-Example:
-    `python tests/scripts/mock_dataset.py --in-data-dir data/pusht --out-data-dir tests/data/pusht`
-"""
-
-import argparse
-import shutil
-
-from pathlib import Path
-
-import torch
-
-
-def mock_dataset(in_data_dir, out_data_dir, num_frames):
-    in_data_dir = Path(in_data_dir)
-    out_data_dir = Path(out_data_dir)
-    out_data_dir.mkdir(exist_ok=True, parents=True)
-
-    # copy the first `n` frames for each data key so that we have real data
-    in_data_dict = torch.load(in_data_dir / "data_dict.pth")
-    out_data_dict = {key: in_data_dict[key][:num_frames].clone() for key in in_data_dict}
-    torch.save(out_data_dict, out_data_dir / "data_dict.pth")
-
-    # recreate data_ids_per_episode that corresponds to the subset
-    episodes = in_data_dict["episode"][:num_frames].tolist()
-    data_ids_per_episode = {}
-    for idx, ep_id in enumerate(episodes):
-        if ep_id not in data_ids_per_episode:
-            data_ids_per_episode[ep_id] = []
-        data_ids_per_episode[ep_id].append(idx)
-    for ep_id in data_ids_per_episode:
-        data_ids_per_episode[ep_id] = torch.tensor(data_ids_per_episode[ep_id])
-    torch.save(data_ids_per_episode, out_data_dir / "data_ids_per_episode.pth")
-
-    # copy the full statistics of dataset since it's small
-    in_stats_path = in_data_dir / "stats.pth"
-    out_stats_path = out_data_dir / "stats.pth"
-    shutil.copy(in_stats_path, out_stats_path)
-
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser(description="Create a dataset with a subset of frames for quick testing.")
-
-    parser.add_argument("--in-data-dir", type=str, help="Path to input data")
-    parser.add_argument("--out-data-dir", type=str, help="Path to save the output data")
-    parser.add_argument("--num-frames", type=int, default=50, help="Number of frames to copy over")
-
-    args = parser.parse_args()
-
-    mock_dataset(args.in_data_dir, args.out_data_dir, args.num_frames)
@@ -1,53 +1,60 @@
-"""
-This test verifies that all environments, datasets, policies listed in `lerobot/__init__.py` can be sucessfully
-imported and that their class attributes (eg. `available_datasets`, `name`, `available_tasks`) are valid.
-
-When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
-- Set the required class attributes: `available_datasets`.
-- Set the required class attributes: `name`.
-- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
-- Update variables in `tests/test_available.py` by importing your new class
-"""
-
 import importlib
-import pytest
-import lerobot
-import gymnasium as gym

-from lerobot.common.datasets.xarm import XarmDataset
+import gymnasium as gym
+import pytest
+
+import lerobot
 from lerobot.common.datasets.aloha import AlohaDataset
 from lerobot.common.datasets.pusht import PushtDataset
+from lerobot.common.datasets.xarm import XarmDataset
 from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
 from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
+from tests.utils import require_env


-def test_available():
+@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs)
+@require_env
+def test_available_env_task(env_name: str, task_name: list):
+    """
+    This test verifies that all environments listed in `lerobot/__init__.py` can
+    be sucessfully imported — if they're installed — and that their
+    `available_tasks_per_env` are valid.
+    """
+    package_name = f"gym_{env_name}"
+    importlib.import_module(package_name)
+    gym_handle = f"{package_name}/{task_name}"
+    assert gym_handle in gym.envs.registry, gym_handle
+
+
+@pytest.mark.parametrize(
+    "env_name, dataset_class",
+    [
+        ("aloha", AlohaDataset),
+        ("pusht", PushtDataset),
+        ("xarm", XarmDataset),
+    ],
+)
+def test_available_datasets(env_name, dataset_class):
+    """
+    This test verifies that the class attribute `available_datasets` for all
+    dataset classes is consistent with those listed in `lerobot/__init__.py`.
+    """
+    available_env_datasets = lerobot.available_datasets[env_name]
+    assert set(available_env_datasets) == set(
+        dataset_class.available_datasets
+    ), f"{env_name=} {available_env_datasets=}"
+
+
+def test_available_policies():
+    """
+    This test verifies that the class attribute `name` for all policies is
+    consistent with those listed in `lerobot/__init__.py`.
+    """
     policy_classes = [
         ActionChunkingTransformerPolicy,
         DiffusionPolicy,
         TDMPCPolicy,
     ]
-
-    dataset_class_per_env = {
-        "aloha": AlohaDataset,
-        "pusht": PushtDataset,
-        "xarm": XarmDataset,
-    }

     policies = [pol_cls.name for pol_cls in policy_classes]
     assert set(policies) == set(lerobot.available_policies), policies
-
-    for env_name in lerobot.available_envs:
-        for task_name in lerobot.available_tasks_per_env[env_name]:
-            package_name = f"gym_{env_name}"
-            importlib.import_module(package_name)
-            gym_handle = f"{package_name}/{task_name}"
-            assert gym_handle in gym.envs.registry.keys(), gym_handle
-
-            dataset_class = dataset_class_per_env[env_name]
-            available_datasets = lerobot.available_datasets_per_env[env_name]
-            assert set(available_datasets) == set(dataset_class.available_datasets), f"{env_name=} {available_datasets=}"
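The new parametrizations lean on `lerobot.env_task_pairs` instead of hard-coded lists. A plausible way such pairs could be derived from the `available_envs` and `available_tasks_per_env` variables mentioned in the removed docstring is sketched below; the real definition in `lerobot/__init__.py` may differ, and the env and task names are only those visible in the old parametrize lists.

# Hypothetical sketch of building (env_name, task_name) pairs for parametrization;
# the actual definition in lerobot/__init__.py may differ.
available_envs = ["aloha", "pusht", "xarm"]
available_tasks_per_env = {
    "aloha": ["AlohaInsertion-v0", "AlohaTransferCube-v0"],
    "pusht": ["PushT-v0"],
    "xarm": ["XarmLift-v0"],
}

# one tuple per registered task, e.g. ("aloha", "AlohaInsertion-v0")
env_task_pairs = [
    (env, task)
    for env in available_envs
    for task in available_tasks_per_env[env]
]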
@@ -1,33 +1,35 @@
+import logging
 import os
 from pathlib import Path

 import einops
 import pytest
 import torch

-from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns, load_previous_and_future_frames
-from lerobot.common.transforms import Prod
-from lerobot.common.utils import init_hydra_config
-import logging
-from lerobot.common.datasets.factory import make_dataset
 from datasets import Dataset
-from .utils import DEVICE, DEFAULT_CONFIG_PATH

-@pytest.mark.parametrize(
-    "env_name,dataset_id,policy_name",
-    [
-        ("xarm", "xarm_lift_medium", "tdmpc"),
-        ("pusht", "pusht", "diffusion"),
-        ("aloha", "aloha_sim_insertion_human", "act"),
-        ("aloha", "aloha_sim_insertion_scripted", "act"),
-        ("aloha", "aloha_sim_transfer_cube_human", "act"),
-        ("aloha", "aloha_sim_transfer_cube_scripted", "act"),
-    ],
+import lerobot
+from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.utils import (
+    compute_stats,
+    get_stats_einops_patterns,
+    load_previous_and_future_frames,
 )
+from lerobot.common.transforms import Prod
+from lerobot.common.utils.utils import init_hydra_config
+
+from .utils import DEFAULT_CONFIG_PATH, DEVICE


+@pytest.mark.parametrize("env_name, dataset_id, policy_name", lerobot.env_dataset_policy_triplets)
 def test_factory(env_name, dataset_id, policy_name):
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,
-        overrides=[f"env={env_name}", f"dataset_id={dataset_id}", f"policy={policy_name}", f"device={DEVICE}"]
+        overrides=[
+            f"env={env_name}",
+            f"dataset_id={dataset_id}",
+            f"policy={policy_name}",
+            f"device={DEVICE}",
+        ],
     )
     dataset = make_dataset(cfg)
     delta_timestamps = dataset.delta_timestamps
@@ -50,7 +52,7 @@ def test_factory(env_name, dataset_id, policy_name):
            keys_ndim_required.append(
                (key, 3, True),
            )
-            assert dataset.data_dict[key].dtype == torch.uint8, f"{key}"
+            assert dataset.hf_dataset[key].dtype == torch.uint8, f"{key}"

     # test number of dimensions
     for key, ndim, required in keys_ndim_required:
@@ -80,14 +82,13 @@ def test_factory(env_name, dataset_id, policy_name):
             # test c,h,w
             assert item[key].shape[0] == 3, f"{key}"

     if delta_timestamps is not None:
         # test missing keys in delta_timestamps
         for key in delta_timestamps:
             assert key in item, f"{key}"


-def test_compute_stats():
+def test_compute_stats_on_xarm():
     """Check that the statistics are computed correctly according to the stats_patterns property.

     We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
@@ -95,14 +96,14 @@ def test_compute_stats():
     """
     from lerobot.common.datasets.xarm import XarmDataset

-    DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
+    data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None

     # get transform to convert images from uint8 [0,255] to float32 [0,1]
     transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)

     dataset = XarmDataset(
         dataset_id="xarm_lift_medium",
-        root=DATA_DIR,
+        root=data_dir,
         transform=transform,
     )

@@ -121,16 +122,18 @@ def test_compute_stats():
         batch_size=len(dataset),
         shuffle=False,
     )
-    data_dict = next(iter(dataloader))
+    hf_dataset = next(iter(dataloader))

     # compute stats based on all frames from the dataset without any batching
     expected_stats = {}
     for k, pattern in stats_patterns.items():
         expected_stats[k] = {}
-        expected_stats[k]["mean"] = einops.reduce(data_dict[k], pattern, "mean")
-        expected_stats[k]["std"] = torch.sqrt(einops.reduce((data_dict[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean"))
-        expected_stats[k]["min"] = einops.reduce(data_dict[k], pattern, "min")
-        expected_stats[k]["max"] = einops.reduce(data_dict[k], pattern, "max")
+        expected_stats[k]["mean"] = einops.reduce(hf_dataset[k], pattern, "mean")
+        expected_stats[k]["std"] = torch.sqrt(
+            einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
+        )
+        expected_stats[k]["min"] = einops.reduce(hf_dataset[k], pattern, "min")
+        expected_stats[k]["max"] = einops.reduce(hf_dataset[k], pattern, "max")

     # test computed stats match expected stats
     for k in stats_patterns:
@@ -153,49 +156,57 @@ def test_compute_stats():


 def test_load_previous_and_future_frames_within_tolerance():
-    data_dict = Dataset.from_dict({
-        "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
-        "index": [0, 1, 2, 3, 4],
-        "episode_data_index_from": [0, 0, 0, 0, 0],
-        "episode_data_index_to": [5, 5, 5, 5, 5],
-    })
-    data_dict = data_dict.with_format("torch")
-    item = data_dict[2]
+    hf_dataset = Dataset.from_dict(
+        {
+            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
+            "index": [0, 1, 2, 3, 4],
+            "episode_data_index_from": [0, 0, 0, 0, 0],
+            "episode_data_index_to": [5, 5, 5, 5, 5],
+        }
+    )
+    hf_dataset = hf_dataset.with_format("torch")
+    item = hf_dataset[2]
     delta_timestamps = {"index": [-0.2, 0, 0.139]}
     tol = 0.04
-    item = load_previous_and_future_frames(item, data_dict, delta_timestamps, tol)
+    item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
     data, is_pad = item["index"], item["index_is_pad"]
     assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
     assert not is_pad.any(), "Unexpected padding detected"


 def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range():
-    data_dict = Dataset.from_dict({
-        "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
-        "index": [0, 1, 2, 3, 4],
-        "episode_data_index_from": [0, 0, 0, 0, 0],
-        "episode_data_index_to": [5, 5, 5, 5, 5],
-    })
-    data_dict = data_dict.with_format("torch")
-    item = data_dict[2]
+    hf_dataset = Dataset.from_dict(
+        {
+            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
+            "index": [0, 1, 2, 3, 4],
+            "episode_data_index_from": [0, 0, 0, 0, 0],
+            "episode_data_index_to": [5, 5, 5, 5, 5],
+        }
+    )
+    hf_dataset = hf_dataset.with_format("torch")
+    item = hf_dataset[2]
     delta_timestamps = {"index": [-0.2, 0, 0.141]}
     tol = 0.04
     with pytest.raises(AssertionError):
-        load_previous_and_future_frames(item, data_dict, delta_timestamps, tol)
+        load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)


 def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
-    data_dict = Dataset.from_dict({
-        "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
-        "index": [0, 1, 2, 3, 4],
-        "episode_data_index_from": [0, 0, 0, 0, 0],
-        "episode_data_index_to": [5, 5, 5, 5, 5],
-    })
-    data_dict = data_dict.with_format("torch")
-    item = data_dict[2]
+    hf_dataset = Dataset.from_dict(
+        {
+            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
+            "index": [0, 1, 2, 3, 4],
+            "episode_data_index_from": [0, 0, 0, 0, 0],
+            "episode_data_index_to": [5, 5, 5, 5, 5],
+        }
+    )
+    hf_dataset = hf_dataset.with_format("torch")
+    item = hf_dataset[2]
     delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
     tol = 0.04
-    item = load_previous_and_future_frames(item, data_dict, delta_timestamps, tol)
+    item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
     data, is_pad = item["index"], item["index_is_pad"]
     assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
-    assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), "Padding does not match expected values"
+    assert torch.equal(
+        is_pad, torch.tensor([True, False, False, True, True])
+    ), "Padding does not match expected values"
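To see why the within-tolerance test expects indices [0, 2, 3] with no padding, it helps to do the arithmetic the test encodes: frames are 0.1 s apart, the queried item is index 2 (timestamp 0.3 s), and each delta is added to that timestamp before snapping to the closest available frame within `tol`. The snippet below is a hand-computed sketch of that logic, not the `load_previous_and_future_frames` implementation.

# Hand-computed sketch of the within-tolerance case above (not the library code).
timestamps = [0.1, 0.2, 0.3, 0.4, 0.5]  # one frame every 0.1 s
query_ts = timestamps[2]                # item 2 -> 0.3 s
delta_timestamps = [-0.2, 0, 0.139]
tol = 0.04

matched = []
for delta in delta_timestamps:
    target = query_ts + delta           # 0.1, 0.3, 0.439
    # snap to the closest available frame
    idx = min(range(len(timestamps)), key=lambda i: abs(timestamps[i] - target))
    assert abs(timestamps[idx] - target) < tol  # gaps of ~0.0, 0.0, 0.039 are all inside tol
    matched.append(idx)

assert matched == [0, 2, 3]

With a delta of 0.141 instead, the last target lands 0.041 s away from the nearest frame, just outside the 0.04 tolerance, which is why the second test expects an AssertionError.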
@@ -1,49 +1,37 @@
 import importlib

+import gymnasium as gym
 import pytest
 import torch
-from lerobot.common.datasets.factory import make_dataset
-import gymnasium as gym
 from gymnasium.utils.env_checker import check_env

+import lerobot
+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.factory import make_env
-from lerobot.common.utils import init_hydra_config

 from lerobot.common.envs.utils import preprocess_observation
+from lerobot.common.utils.utils import init_hydra_config

-from .utils import DEVICE, DEFAULT_CONFIG_PATH
+from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env

+OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]


-@pytest.mark.parametrize(
-    "env_name, task, obs_type",
-    [
-        # ("AlohaInsertion-v0", "state"),
-        ("aloha", "AlohaInsertion-v0", "pixels"),
-        ("aloha", "AlohaInsertion-v0", "pixels_agent_pos"),
-        ("aloha", "AlohaTransferCube-v0", "pixels"),
-        ("aloha", "AlohaTransferCube-v0", "pixels_agent_pos"),
-        ("xarm", "XarmLift-v0", "state"),
-        ("xarm", "XarmLift-v0", "pixels"),
-        ("xarm", "XarmLift-v0", "pixels_agent_pos"),
-        ("pusht", "PushT-v0", "state"),
-        ("pusht", "PushT-v0", "pixels"),
-        ("pusht", "PushT-v0", "pixels_agent_pos"),
-    ],
-)
-def test_env(env_name, task, obs_type):
+@pytest.mark.parametrize("obs_type", OBS_TYPES)
+@pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs)
+@require_env
+def test_env(env_name, env_task, obs_type):
+    if env_name == "aloha" and obs_type == "state":
+        pytest.skip("`state` observations not available for aloha")
     package_name = f"gym_{env_name}"
     importlib.import_module(package_name)
-    env = gym.make(f"{package_name}/{task}", obs_type=obs_type)
+    env = gym.make(f"{package_name}/{env_task}", obs_type=obs_type)
     check_env(env.unwrapped, skip_render_check=True)
     env.close()

-@pytest.mark.parametrize(
-    "env_name",
-    [
-        "pusht",
-        "xarm",
-        "aloha",
-    ],
-)
+@pytest.mark.parametrize("env_name", lerobot.available_envs)
+@require_env
 def test_factory(env_name):
     cfg = init_hydra_config(
         DEFAULT_CONFIG_PATH,
@@ -1,3 +1,4 @@
+import subprocess
 from pathlib import Path


@@ -8,25 +9,31 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
     return text


+def _run_script(path):
+    subprocess.run(["python", path], check=True)
+
+
 def test_example_1():
-    path = "examples/1_visualize_dataset.py"
-    with open(path, "r") as file:
-        file_contents = file.read()
-    exec(file_contents)
-
-    assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
+    path = "examples/1_load_hugging_face_dataset.py"
+    _run_script(path)
+    assert Path("outputs/examples/1_load_hugging_face_dataset/episode_5.mp4").exists()


-def test_examples_3_and_2():
+def test_example_2():
+    path = "examples/2_load_lerobot_dataset.py"
+    _run_script(path)
+    assert Path("outputs/examples/2_load_lerobot_dataset/episode_5.mp4").exists()
+
+
+def test_examples_4_and_3():
     """
     Train a model with example 3, check the outputs.
     Evaluate the trained model with example 2, check the outputs.
     """

-    path = "examples/3_train_policy.py"
+    path = "examples/4_train_policy.py"

-    with open(path, "r") as file:
+    with open(path) as file:
         file_contents = file.read()

     # Do less steps, use smaller batch, use CPU, and don't complicate things with dataloader workers.

@@ -46,9 +53,9 @@ def test_examples_3_and_2():
     for file_name in ["model.pt", "stats.pth", "config.yaml"]:
         assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()

-    path = "examples/2_evaluate_pretrained_policy.py"
+    path = "examples/3_evaluate_pretrained_policy.py"

-    with open(path, "r") as file:
+    with open(path) as file:
         file_contents = file.read()

     # Do less evals, use CPU, and use the local model.

@@ -67,4 +74,4 @@ def test_examples_3_and_2():
         ],
     )

-    assert Path(f"outputs/train/example_pusht_diffusion").exists()
+    assert Path("outputs/train/example_pusht_diffusion").exists()
@@ -1,16 +1,18 @@
 import pytest
 import torch

+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
+from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import postprocess_action, preprocess_observation
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.policy_protocol import Policy
-from lerobot.common.envs.factory import make_env
-from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.utils import init_hydra_config
-from .utils import DEVICE, DEFAULT_CONFIG_PATH
+from lerobot.common.utils.utils import init_hydra_config
+
+from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env


+# TODO(aliberts): refactor using lerobot/__init__.py variables
 @pytest.mark.parametrize(
     "env_name,policy_name,extra_overrides",
     [
@@ -21,10 +23,9 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
         ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
         ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
         ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
-        # TODO(aliberts): xarm not working with diffusion
-        # ("xarm", "diffusion", []),
     ],
 )
+@require_env
 def test_policy(env_name, policy_name, extra_overrides):
     """
     Tests:
@@ -1,6 +1,37 @@
-import os
+import pytest
+import torch
+
+from lerobot.common.utils.import_utils import is_package_available

 # Pass this as the first argument to init_hydra_config.
 DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"

-DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda")
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def require_env(func):
+    """
+    Decorator that skips the test if the required environment package is not installed.
+    As it need 'env_name' in args, it also checks whether it is provided as an argument.
+    """
+    from functools import wraps
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # Determine if 'env_name' is provided and extract its value
+        arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
+        if "env_name" in arg_names:
+            # Get the index of 'env_name' and retrieve the value from args
+            index = arg_names.index("env_name")
+            env_name = args[index] if len(args) > index else kwargs.get("env_name")
+        else:
+            raise ValueError("Function does not have 'env_name' as an argument.")
+
+        # Perform the package check
+        package_name = f"gym_{env_name}"
+        if not is_package_available(package_name):
+            pytest.skip(f"gym-{env_name} not installed")
+
+        return func(*args, **kwargs)
+
+    return wrapper
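As a usage reminder, the decorator only needs `env_name` to appear among the test's arguments so the skip logic can find it, which is how the tests earlier in this commit consume it. The sketch below is illustrative (the test name and parameter values are assumptions), mirroring the decorator order used in the diff.

# Minimal usage sketch for require_env; the test body and values are illustrative.
import importlib

import pytest

from tests.utils import require_env


@pytest.mark.parametrize("env_name", ["aloha", "pusht", "xarm"])
@require_env
def test_import_env(env_name):
    # require_env reads `env_name` from the call arguments and skips the test
    # when the matching gym_<env_name> package is not installed.
    importlib.import_module(f"gym_{env_name}")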