"""
This script demonstrates how to use the `PushtDataset` class to load and process robotics datasets from Hugging Face.
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.

Features included in this script:
- Loading a dataset and accessing its properties.
- Filtering data by episode number.
- Converting tensor data for visualization.
- Saving video files from dataset frames.
- Using advanced dataset features like timestamp-based frame selection.
- Demonstrating compatibility with PyTorch DataLoader for batch processing.

The script ends with an example of how to batch process data using PyTorch's DataLoader.

To try a different Hugging Face dataset, you can replace:
```python
dataset = PushtDataset()
```
with one of these (remember to import the corresponding class):
```python
dataset = XarmDataset()
dataset = AlohaDataset("aloha_sim_insertion_human")
dataset = AlohaDataset("aloha_sim_insertion_scripted")
dataset = AlohaDataset("aloha_sim_transfer_cube_human")
dataset = AlohaDataset("aloha_sim_transfer_cube_scripted")
```
Make the same change to the later `PushtDataset(delta_timestamps=delta_timestamps)` call. Other datasets may
expose different keys (the camera keys are printed via `dataset.image_keys`), so hard-coded keys such as
`"observation.image"` and `"observation.state"` may need to be updated as well.
"""

from pathlib import Path

import imageio
import torch

from lerobot.common.datasets.pusht import PushtDataset

# TODO(rcadene): List available datasets and their dataset ids (e.g. PushtDataset, AlohaDataset(dataset_id="aloha_sim_insertion_human"))
# print("List of available datasets", lerobot.available_datasets)
# # >>> ['aloha_sim_insertion_human', 'aloha_sim_insertion_scripted',
# # 'aloha_sim_transfer_cube_human', 'aloha_sim_transfer_cube_scripted',
# # 'pusht', 'xarm_lift_medium']

# Directory that receives the exported episode video.
OUTPUT_DIR = Path("outputs/examples/2_load_lerobot_dataset")
# Episode whose frames are selected and exported to a video.
EPISODE_ID = 5


def _to_video_frame(frame: torch.Tensor):
    """Convert a channel-first (c, h, w) image tensor into a channel-last (h, w, c) numpy array.

    Float tensors are explicitly rescaled to uint8 so the video writer does not have to
    guess the pixel range (imageio otherwise applies an implicit lossy conversion and warns).
    NOTE(review): this assumes float frames hold intensities in [0, 1]; confirm against the
    dataset's image transform. Non-float tensors keep their dtype.
    """
    if frame.is_floating_point():
        frame = (frame * 255).round().clamp(0, 255).to(torch.uint8)
    return frame.permute((1, 2, 0)).numpy()


def main() -> None:
    """Load a LeRobot dataset, export one episode to mp4, and show timestamp-based and batched loading."""
    # You can easily load datasets from LeRobot
    dataset = PushtDataset()

    # These datasets encapsulate their associated Hugging Face dataset
    print(f"{dataset=}")
    print(f"{dataset.hf_dataset=}")

    # and provide additional utilities for robotics and compatibility with pytorch
    print(f"number of samples/frames: {dataset.num_samples=}")
    print(f"number of episodes: {dataset.num_episodes=}")
    print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
    print(f"frames per second used during data collection: {dataset.fps=}")
    print(f"keys to access images from cameras: {dataset.image_keys=}")

    # It's still possible to select the frames belonging to one episode by filtering the Hugging Face dataset
    dataset.hf_dataset = dataset.hf_dataset.filter(lambda frame: frame["episode_id"] == EPISODE_ID)

    # and load all frames from one of the available cameras
    frames = [sample["observation.image"] for sample in dataset]

    # but frames are now channel first (c,h,w) to follow pytorch convention,
    # to view them, we convert to channel last (h,w,c) uint8 arrays
    frames = [_to_video_frame(frame) for frame in frames]

    # and finally save them to a mp4 video
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    imageio.mimsave(str(OUTPUT_DIR / f"episode_{EPISODE_ID}.mp4"), frames, fps=dataset.fps)

    # Importantly, our datasets can load previous and future frames for each key/modality,
    # using timestamps differences with the current loaded frame. For instance:
    delta_timestamps = {
        # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
        "observation.image": [-1, -0.5, -0.20, 0],
        # loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame
        "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
        # loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
        "action": [t / dataset.fps for t in range(64)],
    }
    dataset = PushtDataset(delta_timestamps=delta_timestamps)
    print(f"{dataset[0]['observation.image'].shape=}")  # (4,c,h,w)
    print(f"{dataset[0]['observation.state'].shape=}")  # (8,state_dim)
    print(f"{dataset[0]['action'].shape=}")  # (64,action_dim)

    # Finally, our datasets are fully compatible with pytorch dataloaders and samplers
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=4,
        batch_size=32,
        shuffle=True,
    )
    # Only the first batch is inspected.
    batch = next(iter(dataloader))
    print(f"{batch['observation.image'].shape=}")  # (32,4,c,h,w)
    print(f"{batch['observation.state'].shape=}")  # (32,8,state_dim)
    print(f"{batch['action'].shape=}")  # (32,64,action_dim)


# With num_workers > 0, DataLoader workers re-import this script whenever the "spawn"
# start method is used (the default on macOS and Windows). Guarding the entry point keeps
# each worker from re-running the whole example and failing during bootstrapping.
if __name__ == "__main__":
    main()