diff --git a/examples/1_load_lerobot_dataset.py b/examples/1_load_lerobot_dataset.py
index 2647078c..fb3a4749 100644
--- a/examples/1_load_lerobot_dataset.py
+++ b/examples/1_load_lerobot_dataset.py
@@ -14,53 +14,92 @@ The script ends with examples of how to batch process data using PyTorch's DataLoader.
 """
 
 # TODO(aliberts, rcadene): Update this script with the new v2 api
-from pathlib import Path
 from pprint import pprint
 
-import imageio
 import torch
+from huggingface_hub import HfApi
 
 import lerobot
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
 
+# We ported a number of existing datasets ourselves; use this to see the list:
 print("List of available datasets:")
 pprint(lerobot.available_datasets)
 
-# Let's take one for this example
-repo_id = "lerobot/pusht"
+# You can also browse through the datasets created/ported by the community on the hub using the hub API:
+hub_api = HfApi()
+repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
+pprint(repo_ids)
 
-# You can easily load a dataset from a Hugging Face repository
+# Or simply explore them in your web browser directly at:
+# https://huggingface.co/datasets?other=LeRobot
+
+# Let's take this one for our example
+repo_id = "aliberts/koch_tutorial"
+# We can have a look and fetch its metadata to know more about it:
+ds_meta = LeRobotDatasetMetadata(repo_id)
+
+# By instantiating just this class, you can quickly access useful information about the content and the
+# structure of the dataset without downloading the actual data yet (only metadata files, which are
+# lightweight).
+print(f"Total number of episodes: {ds_meta.total_episodes}")
+print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
+print(f"Frames per second used during data collection: {ds_meta.fps}")
+print(f"Robot type: {ds_meta.robot_type}")
+print(f"Keys to access images from cameras: {ds_meta.camera_keys=}\n")
+
+print("Tasks:")
+print(ds_meta.tasks)
+print("Features:")
+pprint(ds_meta.features)
+
+# You can also get a short summary by simply printing the object:
+print(ds_meta)
+
+# You can then load the actual dataset from the hub.
+# Either load any subset of episodes:
+dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])
+
+# And see how many frames you have:
+print(f"Selected episodes: {dataset.episodes}")
+print(f"Number of episodes selected: {dataset.num_episodes}")
+print(f"Number of frames selected: {dataset.num_frames}")
+
+# Or simply load the entire dataset:
 dataset = LeRobotDataset(repo_id)
+print(f"Number of episodes selected: {dataset.num_episodes}")
+print(f"Number of frames selected: {dataset.num_frames}")
+
+# The previous metadata class is contained in the 'meta' attribute of the dataset:
+print(dataset.meta)
 
 # LeRobotDataset actually wraps an underlying Hugging Face dataset
-# (see https://huggingface.co/docs/datasets/index for more information).
-print(dataset)
+# (see https://huggingface.co/docs/datasets for more information).
 print(dataset.hf_dataset)
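+
+# A hedged aside (assuming only the standard datasets API, nothing LeRobot-specific): the usual
+# datasets.Dataset methods also work on this underlying object, e.g. selecting the first 10 frames:
+print(dataset.hf_dataset.select(range(10)))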
 
-# And provides additional utilities for robotics and compatibility with Pytorch
-print(f"\naverage number of frames per episode: {dataset.num_frames / dataset.num_episodes:.3f}")
-print(f"frames per second used during data collection: {dataset.fps=}")
-print(f"keys to access images from cameras: {dataset.meta.camera_keys=}\n")
-
-# Access frame indexes associated to first episode
+# LeRobot datasets also subclass PyTorch datasets, so you can do everything you know and love from working
+# with the latter, like iterating through the dataset.
+# The __getitem__ method iterates over the frames of the dataset. Since our datasets are also structured by
+# episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access
+# frame indices associated with the first episode:
 episode_index = 0
 from_idx = dataset.episode_data_index["from"][episode_index].item()
 to_idx = dataset.episode_data_index["to"][episode_index].item()
 
-# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working
-# with the latter, like iterating through the dataset. Here we grab all the image frames.
-frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
+# Then we grab all the image frames from the first camera:
+camera_key = dataset.meta.camera_keys[0]
+frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]
 
-# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention. To visualize
-# them, we convert to uint8 in range [0,255]
-frames = [(frame * 255).type(torch.uint8) for frame in frames]
-# and to channel last (h,w,c).
-frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
-
-# Finally, we save the frames to a mp4 video for visualization.
-Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
-imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
+# The objects returned by the dataset are all torch.Tensors:
+print(type(frames[0]))
+print(frames[0].shape)
+# Since we're using PyTorch, the shape follows the PyTorch channel-first convention (c, h, w).
+# We can compare this shape with the information available for that feature:
+pprint(dataset.features[camera_key])
+# In particular:
+print(dataset.features[camera_key]["shape"])
+# The shape is in (h, w, c), which is a more universal format.
 
 # For many machine learning applications we need to load the history of past observations or trajectories of
 # future actions. Our datasets can load previous and future frames for each key/modality, using timestamp
 # differences with the current loaded frame. For instance:
@@ -86,6 +125,7 @@ dataloader = torch.utils.data.DataLoader(
     batch_size=32,
     shuffle=True,
 )
+
 for batch in dataloader:
     print(f"{batch['observation.image'].shape=}")  # (32,4,c,h,w)
     print(f"{batch['observation.state'].shape=}")  # (32,8,c)
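
Note on the batch shapes asserted in the last hunk ((32,4,c,h,w) and (32,8,c)): they come from the
delta_timestamps feature introduced by the comment at the end of the first hunk; the part of the script
that configures it sits between the two hunks and is not shown in this diff. A minimal sketch of that
usage, assuming LeRobotDataset's delta_timestamps keyword (the offsets below are illustrative):

    # For each frame, also load a window of past/future frames per key, expressed in seconds
    # relative to the current frame. 4 image offsets and 8 state offsets would match the
    # shapes printed in the loop above.
    delta_timestamps = {
        "observation.image": [-1, -0.5, -0.20, 0],
        "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
        "action": [t / dataset.fps for t in range(64)],
    }
    dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)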