Update example 1
This commit is contained in:
parent
93d9bf83c2
commit
36b9b60a0e
|
@ -3,10 +3,9 @@ This script demonstrates the use of `LeRobotDataset` class for handling and proc
|
||||||
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
|
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
|
||||||
|
|
||||||
Features included in this script:
|
Features included in this script:
|
||||||
- Loading a dataset and accessing its properties.
|
- Viewing a dataset's metadata and exploring its properties.
|
||||||
- Filtering data by episode number.
|
- Loading an existing dataset from the hub or a subset of it.
|
||||||
- Converting tensor data for visualization.
|
- Accessing frames by episode number.
|
||||||
- Saving video files from dataset frames.
|
|
||||||
- Using advanced dataset features like timestamp-based frame selection.
|
- Using advanced dataset features like timestamp-based frame selection.
|
||||||
- Demonstrating compatibility with PyTorch DataLoader for batch processing.
|
- Demonstrating compatibility with PyTorch DataLoader for batch processing.
|
||||||
|
|
||||||
|
@ -35,7 +34,7 @@ pprint(repo_ids)
|
||||||
# https://huggingface.co/datasets?other=LeRobot
|
# https://huggingface.co/datasets?other=LeRobot
|
||||||
|
|
||||||
# Let's take this one for this example
|
# Let's take this one for this example
|
||||||
repo_id = "aliberts/koch_tutorial"
|
repo_id = "lerobot/aloha_mobile_cabinet"
|
||||||
# We can have a look and fetch its metadata to know more about it:
|
# We can have a look and fetch its metadata to know more about it:
|
||||||
ds_meta = LeRobotDatasetMetadata(repo_id)
|
ds_meta = LeRobotDatasetMetadata(repo_id)
|
||||||
|
|
||||||
|
@ -106,15 +105,18 @@ print(dataset.features[camera_key]["shape"])
|
||||||
# differences with the current loaded frame. For instance:
|
# differences with the current loaded frame. For instance:
|
||||||
delta_timestamps = {
|
delta_timestamps = {
|
||||||
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
||||||
"observation.image": [-1, -0.5, -0.20, 0],
|
camera_key: [-1, -0.5, -0.20, 0],
|
||||||
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
|
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
|
||||||
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
|
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
|
||||||
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
||||||
"action": [t / dataset.fps for t in range(64)],
|
"action": [t / dataset.fps for t in range(64)],
|
||||||
}
|
}
|
||||||
|
# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that added to any
|
||||||
|
# timestamp, you still get a valid timestamp.
|
||||||
|
|
||||||
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||||
print(f"\n{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
|
print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
|
||||||
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
|
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
|
||||||
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
|
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
|
||||||
|
|
||||||
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
|
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
|
||||||
|
@ -127,7 +129,7 @@ dataloader = torch.utils.data.DataLoader(
|
||||||
)
|
)
|
||||||
|
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
print(f"{batch['observation.image'].shape=}") # (32,4,c,h,w)
|
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
|
||||||
print(f"{batch['observation.state'].shape=}") # (32,8,c)
|
print(f"{batch['observation.state'].shape=}") # (32, 5, c)
|
||||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||||
break
|
break
|
||||||
|
|
Loading…
Reference in New Issue