Fix advanced example 2
This commit is contained in:
parent
fde29e0167
commit
a6762ec316
|
@ -15,7 +15,7 @@ from pathlib import Path
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||||
|
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
|
@ -42,26 +42,20 @@ delta_timestamps = {
|
||||||
}
|
}
|
||||||
|
|
||||||
# Load the last 10% of episodes of the dataset as a validation set.
|
# Load the last 10% of episodes of the dataset as a validation set.
|
||||||
# - Load full dataset
|
# - Load dataset metadata
|
||||||
full_dataset = LeRobotDataset("lerobot/pusht", split="train")
|
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
|
||||||
# - Calculate train and val subsets
|
# - Calculate train and val episodes
|
||||||
num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
|
total_episodes = dataset_metadata.total_episodes
|
||||||
num_val_episodes = full_dataset.num_episodes - num_train_episodes
|
episodes = list(range(dataset_metadata.total_episodes))
|
||||||
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
|
num_train_episodes = math.floor(total_episodes * 90 / 100)
|
||||||
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
|
train_episodes = episodes[:num_train_episodes]
|
||||||
print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
|
val_episodes = episodes[num_train_episodes:]
|
||||||
# - Get first frame index of the validation set
|
print(f"Number of episodes in full dataset: {total_episodes}")
|
||||||
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
|
print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
|
||||||
# - Load frames subset belonging to validation set using the `split` argument.
|
print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
|
||||||
# It utilizes the `datasets` library's syntax for slicing datasets.
|
# - Load train and val datasets
|
||||||
# For more information on the Slice API, please see:
|
train_dataset = LeRobotDataset("lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps)
|
||||||
# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
|
val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)
|
||||||
train_dataset = LeRobotDataset(
|
|
||||||
"lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
|
|
||||||
)
|
|
||||||
val_dataset = LeRobotDataset(
|
|
||||||
"lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
|
|
||||||
)
|
|
||||||
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
|
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
|
||||||
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
|
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue