From 3966079f09bbb126a5f68706ada03cd40452c096 Mon Sep 17 00:00:00 2001 From: Remi Date: Mon, 20 May 2024 17:43:22 +0200 Subject: [PATCH] Apply suggestions from code review Co-authored-by: Alexander Soare --- examples/4_calculate_validation_loss.py | 4 ++-- lerobot/common/datasets/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/4_calculate_validation_loss.py b/examples/4_calculate_validation_loss.py index 6968c12d..c2ca7cb3 100644 --- a/examples/4_calculate_validation_loss.py +++ b/examples/4_calculate_validation_loss.py @@ -40,7 +40,7 @@ delta_timestamps = { "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4], } -# Load the last 10% episodes of the dataset as a validation set. +# Load the last 10% of episodes of the dataset as a validation set. # - Load full dataset full_dataset = LeRobotDataset("lerobot/pusht", split="train") # - Calculate train and val subsets @@ -48,7 +48,7 @@ num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100) num_val_episodes = full_dataset.num_episodes - num_train_episodes print(f"Number of episodes in full dataset: {full_dataset.num_episodes}") print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}") -print(f"Number of episodes in validation dataset (10% reset): {num_val_episodes}") +print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}") # - Get first frame index of the validation set first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item() # - Load frames subset belonging to validation set using the `split` argument. diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 519d1b2f..51765836 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -84,12 +84,12 @@ def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset: hf_dataset = load_from_disk(str(Path(root) / repo_id / "train")) # TODO(rcadene): clean this which enables getting a subset of dataset if split != "train": - match = match = re.search(r"train\[(\d+):\]", split) + match = re.search(r"train\[(\d+):\]", split) if match: from_frame_index = int(match.group(1)) hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset))) else: - raise ValueError(split) + raise ValueError(split), '`split` should either be "train" or of the form "train[{int}:]"' else: hf_dataset = load_dataset(repo_id, revision=version, split=split) hf_dataset.set_transform(hf_transform_to_torch)