"""
This script demonstrates the use of the `LeRobotDataset` class for handling and processing
robotic datasets from Hugging Face.

It illustrates how to load datasets, manipulate them, and apply transformations suitable for
machine learning tasks in PyTorch.

Features included in this script:
- Loading a dataset and accessing its properties.
- Filtering data by episode number.
- Converting tensor data for visualization.
- Saving video files from dataset frames.
- Using advanced dataset features like timestamp-based frame selection.
- Demonstrating compatibility with PyTorch DataLoader for batch processing.

The script ends with examples of how to batch process data using PyTorch's DataLoader.
"""

from pathlib import Path

import imageio
import torch

import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

print("List of available datasets", lerobot.available_datasets)
# >>> ['lerobot/aloha_sim_insertion_human', 'lerobot/aloha_sim_insertion_scripted',
# 'lerobot/aloha_sim_transfer_cube_human', 'lerobot/aloha_sim_transfer_cube_scripted',
# 'lerobot/pusht', 'lerobot/xarm_lift_medium']

repo_id = "lerobot/pusht"

# You can easily load a dataset from a Hugging Face repository
dataset = LeRobotDataset(repo_id)

# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset
# (see https://huggingface.co/docs/datasets/index for more information).
# TODO(rcadene): update to make the print pretty
print(f"{dataset=}")
print(f"{dataset.hf_dataset=}")

# and provides additional utilities for robotics and compatibility with PyTorch
print(f"number of samples/frames: {dataset.num_samples=}")
print(f"number of episodes: {dataset.num_episodes=}")
print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.image_keys=}")

# Access frame indexes associated to the first episode
episode_index = 0
from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item()

# LeRobot datasets actually subclass PyTorch datasets, so you can do everything you know and
# love from working with the latter, like iterating through the dataset.
# Here we grab all the image frames.
frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]

# Video frames are now float32 in range [0, 1], channel-first (c, h, w) to follow PyTorch convention.
# To visualize them, we convert to uint8 in range [0, 255]
frames = [(frame * 255).type(torch.uint8) for frame in frames]
# and to channel-last (h, w, c).
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]

# Finally, we save the frames to an mp4 video for visualization.
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)

# For many machine learning applications we need to load the history of past observations
# or trajectories of future actions. Our datasets can load previous and future frames for
# each key/modality, using timestamp differences with the current loaded frame. For instance:
delta_timestamps = {
    # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
    "observation.image": [-1, -0.5, -0.20, 0],
    # loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame
    "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
    # loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
    "action": [t / dataset.fps for t in range(64)],
}
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
print(f"{dataset[0]['observation.image'].shape=}")  # (4, c, h, w)
print(f"{dataset[0]['observation.state'].shape=}")  # (8, c)
print(f"{dataset[0]['action'].shape=}")  # (64, c)

# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers
# because they are just PyTorch datasets.
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=0,
    batch_size=32,
    shuffle=True,
)
for batch in dataloader:
    print(f"{batch['observation.image'].shape=}")  # (32, 4, c, h, w)
    print(f"{batch['observation.state'].shape=}")  # (32, 8, c)
    print(f"{batch['action'].shape=}")  # (32, 64, c)
    break