From 09813504498c3d52fdf9f9118a8c2b97107d6b53 Mon Sep 17 00:00:00 2001
From: Simon Alibert
Date: Sat, 4 May 2024 20:57:31 +0200
Subject: [PATCH] Enhance example 1

---
 examples/1_load_lerobot_dataset.py         | 48 +++++++++++-----------
 lerobot/common/datasets/lerobot_dataset.py | 23 ++++++++++-
 2 files changed, 45 insertions(+), 26 deletions(-)

diff --git a/examples/1_load_lerobot_dataset.py b/examples/1_load_lerobot_dataset.py
index f86199c5..c5f172ca 100644
--- a/examples/1_load_lerobot_dataset.py
+++ b/examples/1_load_lerobot_dataset.py
@@ -14,6 +14,7 @@ The script ends with examples of how to batch process data using PyTorch's DataL
 """
 
 from pathlib import Path
+from pprint import pprint
 
 import imageio
 import torch
@@ -21,39 +22,36 @@ import torch
 import lerobot
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 
-print("List of available datasets", lerobot.available_datasets)
-# # >>> ['lerobot/aloha_sim_insertion_human', 'lerobot/aloha_sim_insertion_scripted',
-# # 'lerobot/aloha_sim_transfer_cube_human', 'lerobot/aloha_sim_transfer_cube_scripted',
-# # 'lerobot/pusht', 'lerobot/xarm_lift_medium']
+print("List of available datasets:")
+pprint(lerobot.available_datasets)
 
+# Let's take one for this example
 repo_id = "lerobot/pusht"
 
-# You can easily load a dataset from a Hugging Face repositery
+# You can easily load a dataset from a Hugging Face repository
 dataset = LeRobotDataset(repo_id)
 
-# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
-# TODO(rcadene): update to make the print pretty
-print(f"{dataset=}")
-print(f"{dataset.hf_dataset=}")
+# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset
+# (see https://huggingface.co/docs/datasets/index for more information).
+print(dataset)
+print(dataset.hf_dataset)
 
-# and provides additional utilities for robotics and compatibility with pytorch
-print(f"number of samples/frames: {dataset.num_samples=}")
-print(f"number of episodes: {dataset.num_episodes=}")
-print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
+# And provides additional utilities for robotics and compatibility with PyTorch
+print(f"\naverage number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
 print(f"frames per second used during data collection: {dataset.fps=}")
-print(f"keys to access images from cameras: {dataset.image_keys=}")
+print(f"keys to access images from cameras: {dataset.image_keys=}\n")
 
 # Access frame indexes associated to first episode
 episode_index = 0
 from_idx = dataset.episode_data_index["from"][episode_index].item()
 to_idx = dataset.episode_data_index["to"][episode_index].item()
 
-# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working with the latter, like iterating through the dataset.
-# Here we grab all the image frames.
+# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working
+# with the latter, like iterating through the dataset. Here we grab all the image frames.
 frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
 
-# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention.
-# To visualize them, we convert to uint8 range [0,255]
+# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow the PyTorch convention. To visualize
+# them, we convert to uint8 in range [0,255]
 frames = [(frame * 255).type(torch.uint8) for frame in frames]
 # and to channel last (h,w,c).
 frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
@@ -62,9 +60,9 @@ frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
 Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
 imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
 
-# For many machine learning applications we need to load the history of past observations or trajectories of future actions.
-# Our datasets can load previous and future frames for each key/modality,
-# using timestamps differences with the current loaded frame. For instance:
+# For many machine learning applications, we need to load the history of past observations or trajectories of
+# future actions. Our datasets can load previous and future frames for each key/modality, using timestamp
+# differences with the current loaded frame. For instance:
 delta_timestamps = {
     # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
     "observation.image": [-1, -0.5, -0.20, 0],
@@ -74,12 +72,12 @@ delta_timestamps = {
     "action": [t / dataset.fps for t in range(64)],
 }
 dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
-print(f"{dataset[0]['observation.image'].shape=}")  # (4,c,h,w)
+print(f"\n{dataset[0]['observation.image'].shape=}")  # (4,c,h,w)
 print(f"{dataset[0]['observation.state'].shape=}")  # (8,c)
-print(f"{dataset[0]['action'].shape=}")  # (64,c)
+print(f"{dataset[0]['action'].shape=}\n")  # (64,c)
 
-# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers
-# because they are just PyTorch datasets.
+# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
+# PyTorch datasets.
 dataloader = torch.utils.data.DataLoader(
     dataset,
     num_workers=0,
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index c8cfbd8e..d17c4307 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -47,6 +47,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
     @property
     def fps(self) -> int:
+        """Frames per second used during data collection."""
         return self.info["fps"]
 
     @property
@@ -62,6 +63,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
     @property
     def image_keys(self) -> list[str]:
+        """Keys to access images from cameras."""
         image_keys = []
         for key, feats in self.hf_dataset.features.items():
             if isinstance(feats, datasets.Image):
@@ -69,7 +71,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return image_keys + self.video_frame_keys
 
     @property
-    def video_frame_keys(self):
+    def video_frame_keys(self) -> list[str]:
+        """Keys to access video frames from cameras."""
         video_frame_keys = []
         for key, feats in self.hf_dataset.features.items():
             if isinstance(feats, VideoFrame):
@@ -78,10 +81,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
     @property
     def num_samples(self) -> int:
+        """Number of samples/frames."""
         return len(self.hf_dataset)
 
     @property
     def num_episodes(self) -> int:
+        """Number of episodes."""
         return len(self.hf_dataset.unique("episode_index"))
 
     @property
@@ -121,6 +126,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         return item
 
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(\n"
+            f"  Repository ID: '{self.repo_id}',\n"
+            f"  Version: '{self.version}',\n"
+            f"  Split: '{self.split}',\n"
+            f"  Number of Samples: {self.num_samples},\n"
+            f"  Number of Episodes: {self.num_episodes},\n"
+            f"  Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
+            f"  Recorded Frames per Second: {self.fps},\n"
+            f"  Image Keys: {self.image_keys},\n"
+            f"  Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
+            f"  Transformations: {self.transform},\n"
+            f")"
+        )
+
     @classmethod
     def from_preloaded(
         cls,
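
Note: with the new __repr__, the print(dataset) call near the top of the example renders a readable summary instead of the default object repr. A rough sketch of the layout it produces, where angle-bracketed placeholders stand in for values that depend on the loaded dataset (they are not taken from this patch):

    print(dataset)
    # LeRobotDataset(
    #   Repository ID: 'lerobot/pusht',
    #   Version: '<version>',
    #   Split: '<split>',
    #   Number of Samples: <num_samples>,
    #   Number of Episodes: <num_episodes>,
    #   Type: video (.mp4) or image (.png),
    #   Recorded Frames per Second: <fps>,
    #   Image Keys: [<image keys>],
    #   Video Frame Keys: [<video frame keys>] or N/A,
    #   Transformations: <transform>,
    # )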
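
Note: the diff context above cuts off inside the DataLoader call, so here is a minimal end-to-end sketch of how the enhanced example's dataloader might be consumed. The batch_size and shuffle values are illustrative assumptions, not part of this patch; the keys and expected leading dimensions follow from the delta_timestamps configured in the example.

    import torch

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

    # Build the dataset twice, as the example does: once to read fps, then
    # again with the temporal queries attached.
    dataset = LeRobotDataset("lerobot/pusht")
    delta_timestamps = {
        # 4 images: 1 s, 500 ms, and 200 ms before the current frame, plus the frame itself
        "observation.image": [-1, -0.5, -0.20, 0],
        # 64 future actions, one per frame
        "action": [t / dataset.fps for t in range(64)],
    }
    dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=0,
        batch_size=32,  # assumed for illustration, not taken from this diff
        shuffle=True,   # assumed for illustration, not taken from this diff
    )
    batch = next(iter(dataloader))
    print(batch["observation.image"].shape)  # expected: torch.Size([32, 4, c, h, w])
    print(batch["action"].shape)             # expected: torch.Size([32, 64, c])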