Enhance example 1
This commit is contained in:
parent
292da1f4fe
commit
0981350449
|
@ -14,6 +14,7 @@ The script ends with examples of how to batch process data using PyTorch's DataL
|
|||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
import imageio
|
||||
import torch
|
||||
|
@ -21,39 +22,36 @@ import torch
|
|||
import lerobot
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
print("List of available datasets", lerobot.available_datasets)
|
||||
# # >>> ['lerobot/aloha_sim_insertion_human', 'lerobot/aloha_sim_insertion_scripted',
|
||||
# # 'lerobot/aloha_sim_transfer_cube_human', 'lerobot/aloha_sim_transfer_cube_scripted',
|
||||
# # 'lerobot/pusht', 'lerobot/xarm_lift_medium']
|
||||
print("List of available datasets:")
|
||||
pprint(lerobot.available_datasets)
|
||||
|
||||
# Let's take one for this example
|
||||
repo_id = "lerobot/pusht"
|
||||
|
||||
# You can easily load a dataset from a Hugging Face repositery
|
||||
# You can easily load a dataset from a Hugging Face repository
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset (see https://huggingface.co/docs/datasets/index for more information).
|
||||
# TODO(rcadene): update to make the print pretty
|
||||
print(f"{dataset=}")
|
||||
print(f"{dataset.hf_dataset=}")
|
||||
# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset
|
||||
# (see https://huggingface.co/docs/datasets/index for more information).
|
||||
print(dataset)
|
||||
print(dataset.hf_dataset)
|
||||
|
||||
# and provides additional utilities for robotics and compatibility with pytorch
|
||||
print(f"number of samples/frames: {dataset.num_samples=}")
|
||||
print(f"number of episodes: {dataset.num_episodes=}")
|
||||
print(f"average number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
|
||||
# And provides additional utilities for robotics and compatibility with Pytorch
|
||||
print(f"\naverage number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
|
||||
print(f"frames per second used during data collection: {dataset.fps=}")
|
||||
print(f"keys to access images from cameras: {dataset.image_keys=}")
|
||||
print(f"keys to access images from cameras: {dataset.image_keys=}\n")
|
||||
|
||||
# Access frame indexes associated to first episode
|
||||
episode_index = 0
|
||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
||||
|
||||
# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working with the latter, like iterating through the dataset.
|
||||
# Here we grab all the image frames.
|
||||
# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working
|
||||
# with the latter, like iterating through the dataset. Here we grab all the image frames.
|
||||
frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
|
||||
|
||||
# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention.
|
||||
# To visualize them, we convert to uint8 range [0,255]
|
||||
# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention. To visualize
|
||||
# them, we convert to uint8 in range [0,255]
|
||||
frames = [(frame * 255).type(torch.uint8) for frame in frames]
|
||||
# and to channel last (h,w,c).
|
||||
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
|
||||
|
@ -62,9 +60,9 @@ frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
|
|||
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
|
||||
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
|
||||
|
||||
# For many machine learning applications we need to load the history of past observations or trajectories of future actions.
|
||||
# Our datasets can load previous and future frames for each key/modality,
|
||||
# using timestamps differences with the current loaded frame. For instance:
|
||||
# For many machine learning applications we need to load the history of past observations or trajectories of
|
||||
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
|
||||
# differences with the current loaded frame. For instance:
|
||||
delta_timestamps = {
|
||||
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
||||
"observation.image": [-1, -0.5, -0.20, 0],
|
||||
|
@ -74,12 +72,12 @@ delta_timestamps = {
|
|||
"action": [t / dataset.fps for t in range(64)],
|
||||
}
|
||||
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||
print(f"{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
|
||||
print(f"\n{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
|
||||
print(f"{dataset[0]['action'].shape=}") # (64,c)
|
||||
print(f"{dataset[0]['action'].shape=}\n") # (64,c)
|
||||
|
||||
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers
|
||||
# because they are just PyTorch datasets.
|
||||
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
|
||||
# PyTorch datasets.
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
|
|
|
@ -47,6 +47,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection."""
|
||||
return self.info["fps"]
|
||||
|
||||
@property
|
||||
|
@ -62,6 +63,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
@property
|
||||
def image_keys(self) -> list[str]:
|
||||
"""Keys to access images from cameras."""
|
||||
image_keys = []
|
||||
for key, feats in self.hf_dataset.features.items():
|
||||
if isinstance(feats, datasets.Image):
|
||||
|
@ -69,7 +71,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
return image_keys + self.video_frame_keys
|
||||
|
||||
@property
|
||||
def video_frame_keys(self):
|
||||
def video_frame_keys(self) -> list[str]:
|
||||
"""Keys to access video frames from cameras."""
|
||||
video_frame_keys = []
|
||||
for key, feats in self.hf_dataset.features.items():
|
||||
if isinstance(feats, VideoFrame):
|
||||
|
@ -78,10 +81,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
"""Number of samples/frames."""
|
||||
return len(self.hf_dataset)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
"""Number of episodes."""
|
||||
return len(self.hf_dataset.unique("episode_index"))
|
||||
|
||||
@property
|
||||
|
@ -121,6 +126,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}(\n"
|
||||
f" Repository ID: '{self.repo_id}',\n"
|
||||
f" Version: '{self.version}',\n"
|
||||
f" Split: '{self.split}',\n"
|
||||
f" Number of Samples: {self.num_samples},\n"
|
||||
f" Number of Episodes: {self.num_episodes},\n"
|
||||
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
|
||||
f" Recorded Frames per Second: {self.fps},\n"
|
||||
f" Image Keys: {self.image_keys},\n"
|
||||
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
||||
f" Transformations: {self.transform},\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_preloaded(
|
||||
cls,
|
||||
|
|
Loading…
Reference in New Issue