Update example 1

This commit is contained in:
Simon Alibert 2024-11-20 18:22:46 +01:00
parent 93d9bf83c2
commit 36b9b60a0e
1 changed files with 14 additions and 12 deletions

View File

@ -3,10 +3,9 @@ This script demonstrates the use of `LeRobotDataset` class for handling and proc
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch. It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
Features included in this script: Features included in this script:
- Loading a dataset and accessing its properties. - Viewing a dataset's metadata and exploring its properties.
- Filtering data by episode number. - Loading an existing dataset from the hub or a subset of it.
- Converting tensor data for visualization. - Accessing frames by episode number.
- Saving video files from dataset frames.
- Using advanced dataset features like timestamp-based frame selection. - Using advanced dataset features like timestamp-based frame selection.
- Demonstrating compatibility with PyTorch DataLoader for batch processing. - Demonstrating compatibility with PyTorch DataLoader for batch processing.
@ -35,7 +34,7 @@ pprint(repo_ids)
# https://huggingface.co/datasets?other=LeRobot # https://huggingface.co/datasets?other=LeRobot
# Let's take this one for this example # Let's take this one for this example
repo_id = "aliberts/koch_tutorial" repo_id = "lerobot/aloha_mobile_cabinet"
# We can have a look and fetch its metadata to know more about it: # We can have a look and fetch its metadata to know more about it:
ds_meta = LeRobotDatasetMetadata(repo_id) ds_meta = LeRobotDatasetMetadata(repo_id)
@ -106,16 +105,19 @@ print(dataset.features[camera_key]["shape"])
# differences with the current loaded frame. For instance: # differences with the current loaded frame. For instance:
delta_timestamps = { delta_timestamps = {
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
"observation.image": [-1, -0.5, -0.20, 0], camera_key: [-1, -0.5, -0.20, 0],
# loads 6 state vectors: 1.5 seconds before, 1 second before, 500 ms, 200 ms, 100 ms, and current frame # loads 6 state vectors: 1.5 seconds before, 1 second before, 500 ms, 200 ms, 100 ms, and current frame
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0], "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future # loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
"action": [t / dataset.fps for t in range(64)], "action": [t / dataset.fps for t in range(64)],
} }
# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that, when added to any
# timestamp, they still yield a valid timestamp.
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps) dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
print(f"\n{dataset[0]['observation.image'].shape=}") # (4,c,h,w) print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
print(f"{dataset[0]['observation.state'].shape=}") # (8,c) print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
print(f"{dataset[0]['action'].shape=}\n") # (64,c) print(f"{dataset[0]['action'].shape=}\n") # (64, c)
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just # Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
# PyTorch datasets. # PyTorch datasets.
@ -127,7 +129,7 @@ dataloader = torch.utils.data.DataLoader(
) )
for batch in dataloader: for batch in dataloader:
print(f"{batch['observation.image'].shape=}") # (32,4,c,h,w) print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
print(f"{batch['observation.state'].shape=}") # (32,8,c) print(f"{batch['observation.state'].shape=}") # (32, 6, c)
print(f"{batch['action'].shape=}") # (32,64,c) print(f"{batch['action'].shape=}") # (32, 64, c)
break break