Enhance example 1

This commit is contained in:
Simon Alibert 2024-05-04 20:57:31 +02:00
parent 292da1f4fe
commit 0981350449
2 changed files with 45 additions and 26 deletions

View File

@ -14,6 +14,7 @@ The script ends with examples of how to batch process data using PyTorch's DataL
""" """
from pathlib import Path from pathlib import Path
from pprint import pprint
import imageio import imageio
import torch import torch
@ -21,39 +22,36 @@ import torch
import lerobot import lerobot
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
print("List of available datasets", lerobot.available_datasets) print("List of available datasets:")
# # >>> ['lerobot/aloha_sim_insertion_human', 'lerobot/aloha_sim_insertion_scripted', pprint(lerobot.available_datasets)
# # 'lerobot/aloha_sim_transfer_cube_human', 'lerobot/aloha_sim_transfer_cube_scripted',
# # 'lerobot/pusht', 'lerobot/xarm_lift_medium']
# Let's take one for this example
repo_id = "lerobot/pusht" repo_id = "lerobot/pusht"
# You can easily load a dataset from a Hugging Face repositery # You can easily load a dataset from a Hugging Face repository
dataset = LeRobotDataset(repo_id) dataset = LeRobotDataset(repo_id)
# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information). # LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset
# TODO(rcadene): update to make the print pretty # (see https://huggingface.co/docs/datasets/index for more information).
print(f"{dataset=}") print(dataset)
print(f"{dataset.hf_dataset=}") print(dataset.hf_dataset)
# and provides additional utilities for robotics and compatibility with pytorch # And provides additional utilities for robotics and compatibility with Pytorch
print(f"number of samples/frames: {dataset.num_samples=}") print(f"\naverage number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
print(f"number of episodes: {dataset.num_episodes=}")
print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
print(f"frames per second used during data collection: {dataset.fps=}") print(f"frames per second used during data collection: {dataset.fps=}")
print(f"keys to access images from cameras: {dataset.image_keys=}") print(f"keys to access images from cameras: {dataset.image_keys=}\n")
# Access frame indexes associated to first episode # Access frame indexes associated to first episode
episode_index = 0 episode_index = 0
from_idx = dataset.episode_data_index["from"][episode_index].item() from_idx = dataset.episode_data_index["from"][episode_index].item()
to_idx = dataset.episode_data_index["to"][episode_index].item() to_idx = dataset.episode_data_index["to"][episode_index].item()
# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working with the latter, like iterating through the dataset. # LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working
# Here we grab all the image frames. # with the latter, like iterating through the dataset. Here we grab all the image frames.
frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)] frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention. # Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention. To visualize
# To visualize them, we convert to uint8 range [0,255] # them, we convert to uint8 in range [0,255]
frames = [(frame * 255).type(torch.uint8) for frame in frames] frames = [(frame * 255).type(torch.uint8) for frame in frames]
# and to channel last (h,w,c). # and to channel last (h,w,c).
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames] frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
@ -62,9 +60,9 @@ frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True) Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps) imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
# For many machine learning applications we need to load the history of past observations or trajectories of future actions. # For many machine learning applications we need to load the history of past observations or trajectories of
# Our datasets can load previous and future frames for each key/modality, # future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
# using timestamps differences with the current loaded frame. For instance: # differences with the current loaded frame. For instance:
delta_timestamps = { delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
"observation.image": [-1, -0.5, -0.20, 0], "observation.image": [-1, -0.5, -0.20, 0],
@ -74,12 +72,12 @@ delta_timestamps = {
"action": [t / dataset.fps for t in range(64)], "action": [t / dataset.fps for t in range(64)],
} }
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps) dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
print(f"{dataset[0]['observation.image'].shape=}") # (4,c,h,w) print(f"\n{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
print(f"{dataset[0]['observation.state'].shape=}") # (8,c) print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
print(f"{dataset[0]['action'].shape=}") # (64,c) print(f"{dataset[0]['action'].shape=}\n") # (64,c)
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers # Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
# because they are just PyTorch datasets. # PyTorch datasets.
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
dataset, dataset,
num_workers=0, num_workers=0,

View File

@ -47,6 +47,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property @property
def fps(self) -> int: def fps(self) -> int:
"""Frames per second used during data collection."""
return self.info["fps"] return self.info["fps"]
@property @property
@ -62,6 +63,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property @property
def image_keys(self) -> list[str]: def image_keys(self) -> list[str]:
"""Keys to access images from cameras."""
image_keys = [] image_keys = []
for key, feats in self.hf_dataset.features.items(): for key, feats in self.hf_dataset.features.items():
if isinstance(feats, datasets.Image): if isinstance(feats, datasets.Image):
@ -69,7 +71,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
return image_keys + self.video_frame_keys return image_keys + self.video_frame_keys
@property @property
def video_frame_keys(self): def video_frame_keys(self) -> list[str]:
"""Keys to access video frames from cameras."""
video_frame_keys = [] video_frame_keys = []
for key, feats in self.hf_dataset.features.items(): for key, feats in self.hf_dataset.features.items():
if isinstance(feats, VideoFrame): if isinstance(feats, VideoFrame):
@ -78,10 +81,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property @property
def num_samples(self) -> int: def num_samples(self) -> int:
"""Number of samples/frames."""
return len(self.hf_dataset) return len(self.hf_dataset)
@property @property
def num_episodes(self) -> int: def num_episodes(self) -> int:
"""Number of episodes."""
return len(self.hf_dataset.unique("episode_index")) return len(self.hf_dataset.unique("episode_index"))
@property @property
@ -121,6 +126,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
return item return item
def __repr__(self):
return (
f"{self.__class__.__name__}(\n"
f" Repository ID: '{self.repo_id}',\n"
f" Version: '{self.version}',\n"
f" Split: '{self.split}',\n"
f" Number of Samples: {self.num_samples},\n"
f" Number of Episodes: {self.num_episodes},\n"
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
f" Recorded Frames per Second: {self.fps},\n"
f" Image Keys: {self.image_keys},\n"
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
f" Transformations: {self.transform},\n"
f")"
)
@classmethod @classmethod
def from_preloaded( def from_preloaded(
cls, cls,