From 8ca0bc5e8893fd3faadfa8810b3962f0a3068a0e Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Mon, 20 May 2024 15:09:41 +0000 Subject: [PATCH] hot fix --- examples/4_calculate_validation_loss.py | 22 ++++++++++++++++++---- lerobot/common/datasets/utils.py | 11 ++++++++++- tests/test_examples.py | 2 +- 3 files changed, 29 insertions(+), 6 deletions(-) diff --git a/examples/4_calculate_validation_loss.py b/examples/4_calculate_validation_loss.py index 285184d2..6968c12d 100644 --- a/examples/4_calculate_validation_loss.py +++ b/examples/4_calculate_validation_loss.py @@ -8,6 +8,7 @@ especially in the context of imitation learning. The most reliable approach is t on the target environment, whether that be in simulation or the real world. """ +import math from pathlib import Path import torch @@ -39,11 +40,24 @@ delta_timestamps = { "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4], } -# Load the last 10 episodes of the dataset as a validation set. -# The `split` argument utilizes the `datasets` library's syntax for slicing datasets. -# For more information on the Slice API, please see: +# Load the last 10% episodes of the dataset as a validation set. +# - Load full dataset +full_dataset = LeRobotDataset("lerobot/pusht", split="train") +# - Calculate train and val subsets +num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100) +num_val_episodes = full_dataset.num_episodes - num_train_episodes +print(f"Number of episodes in full dataset: {full_dataset.num_episodes}") +print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}") +print(f"Number of episodes in validation dataset (10% reset): {num_val_episodes}") +# - Get first frame index of the validation set +first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item() +# - Load frames subset belonging to validation set using the `split` argument. +# It utilizes the `datasets` library's syntax for slicing datasets. +# For more information on the Slice API, please see: # https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits -val_dataset = LeRobotDataset("lerobot/pusht", split="train[24342:]", delta_timestamps=delta_timestamps) +val_dataset = LeRobotDataset( + "lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps +) # Create dataloader for evaluation. val_dataloader = torch.utils.data.DataLoader( diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 207ccf7c..519d1b2f 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +import re from pathlib import Path from typing import Dict @@ -80,7 +81,15 @@ def hf_transform_to_torch(items_dict): def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset: """hf_dataset contains all the observations, states, actions, rewards, etc.""" if root is not None: - hf_dataset = load_from_disk(str(Path(root) / repo_id / split)) + hf_dataset = load_from_disk(str(Path(root) / repo_id / "train")) + # TODO(rcadene): clean this which enables getting a subset of dataset + if split != "train": + match = match = re.search(r"train\[(\d+):\]", split) + if match: + from_frame_index = int(match.group(1)) + hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset))) + else: + raise ValueError(split) else: hf_dataset = load_dataset(repo_id, revision=version, split=split) hf_dataset.set_transform(hf_transform_to_torch) diff --git a/tests/test_examples.py b/tests/test_examples.py index 9881e3fa..a0c60b7e 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -111,7 +111,7 @@ def test_examples_2_through_4(): '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")', ), - ('split="train[24342:]"', 'split="train[-1:]"'), + ('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'), ("num_workers=4", "num_workers=0"), ('device = torch.device("cuda")', 'device = torch.device("cpu")'), ("batch_size=64", "batch_size=1"),