Fix advanced example 2

This commit is contained in:
Simon Alibert 2024-11-03 19:03:15 +01:00
parent fde29e0167
commit a6762ec316
1 changed file with 15 additions and 21 deletions

View File

@ -15,7 +15,7 @@ from pathlib import Path
import torch import torch
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
device = torch.device("cuda") device = torch.device("cuda")
@ -42,26 +42,20 @@ delta_timestamps = {
} }
# Load the last 10% of episodes of the dataset as a validation set. # Load the last 10% of episodes of the dataset as a validation set.
# - Load full dataset # - Load dataset metadata
full_dataset = LeRobotDataset("lerobot/pusht", split="train") dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
# - Calculate train and val subsets # - Calculate train and val episodes
num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100) total_episodes = dataset_metadata.total_episodes
num_val_episodes = full_dataset.num_episodes - num_train_episodes episodes = list(range(dataset_metadata.total_episodes))
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}") num_train_episodes = math.floor(total_episodes * 90 / 100)
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}") train_episodes = episodes[:num_train_episodes]
print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}") val_episodes = episodes[num_train_episodes:]
# - Get first frame index of the validation set print(f"Number of episodes in full dataset: {total_episodes}")
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item() print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
# - Load frames subset belonging to validation set using the `split` argument. print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
# It utilizes the `datasets` library's syntax for slicing datasets. # - Load train and val datasets
# For more information on the Slice API, please see: train_dataset = LeRobotDataset("lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps)
# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)
train_dataset = LeRobotDataset(
"lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
)
val_dataset = LeRobotDataset(
"lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
)
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}") print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}") print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")