diff --git a/examples/advanced/2_calculate_validation_loss.py b/examples/advanced/2_calculate_validation_loss.py index b312b7d0..c260c15d 100644 --- a/examples/advanced/2_calculate_validation_loss.py +++ b/examples/advanced/2_calculate_validation_loss.py @@ -15,7 +15,7 @@ from pathlib import Path import torch from huggingface_hub import snapshot_download -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy device = torch.device("cuda") @@ -42,26 +42,20 @@ delta_timestamps = { } # Load the last 10% of episodes of the dataset as a validation set. -# - Load full dataset -full_dataset = LeRobotDataset("lerobot/pusht", split="train") -# - Calculate train and val subsets -num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100) -num_val_episodes = full_dataset.num_episodes - num_train_episodes -print(f"Number of episodes in full dataset: {full_dataset.num_episodes}") -print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}") -print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}") -# - Get first frame index of the validation set -first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item() -# - Load frames subset belonging to validation set using the `split` argument. -# It utilizes the `datasets` library's syntax for slicing datasets. -# For more information on the Slice API, please see: -# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits -train_dataset = LeRobotDataset( - "lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps -) -val_dataset = LeRobotDataset( - "lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps -) +# - Load dataset metadata +dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht") +# - Calculate train and val episodes +total_episodes = dataset_metadata.total_episodes +episodes = list(range(dataset_metadata.total_episodes)) +num_train_episodes = math.floor(total_episodes * 90 / 100) +train_episodes = episodes[:num_train_episodes] +val_episodes = episodes[num_train_episodes:] +print(f"Number of episodes in full dataset: {total_episodes}") +print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}") +print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}") +# - Load train an val datasets +train_dataset = LeRobotDataset("lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps) +val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps) print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}") print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")