From c4da6891719da44139e321cf13327a8773822dde Mon Sep 17 00:00:00 2001
From: Remi
Date: Mon, 20 May 2024 18:30:11 +0200
Subject: [PATCH] Hot fix to compute validation loss example test (#200)

Co-authored-by: Alexander Soare
---
 examples/4_calculate_validation_loss.py | 27 +++++++++++++++++++++----
 lerobot/common/datasets/utils.py        | 27 ++++++++++++++++++++++++-
 tests/test_examples.py                  |  2 +-
 3 files changed, 50 insertions(+), 6 deletions(-)

diff --git a/examples/4_calculate_validation_loss.py b/examples/4_calculate_validation_loss.py
index 285184d2..1428014b 100644
--- a/examples/4_calculate_validation_loss.py
+++ b/examples/4_calculate_validation_loss.py
@@ -8,6 +8,7 @@ especially in the context of imitation learning. The most reliable approach is t
 on the target environment, whether that be in simulation or the real world.
 """
 
+import math
 from pathlib import Path
 
 import torch
@@ -39,11 +40,29 @@ delta_timestamps = {
     "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
 }
 
-# Load the last 10 episodes of the dataset as a validation set.
-# The `split` argument utilizes the `datasets` library's syntax for slicing datasets.
-# For more information on the Slice API, please see:
+# Load the last 10% of episodes of the dataset as a validation set.
+# - Load full dataset
+full_dataset = LeRobotDataset("lerobot/pusht", split="train")
+# - Calculate train and val subsets
+num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
+num_val_episodes = full_dataset.num_episodes - num_train_episodes
+print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
+print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
+print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
+# - Get first frame index of the validation set
+first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
+# - Load frames subset belonging to validation set using the `split` argument.
+#   It utilizes the `datasets` library's syntax for slicing datasets.
+#   For more information on the Slice API, please see:
 # https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
-val_dataset = LeRobotDataset("lerobot/pusht", split="train[24342:]", delta_timestamps=delta_timestamps)
+train_dataset = LeRobotDataset(
+    "lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
+)
+val_dataset = LeRobotDataset(
+    "lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
+)
+print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
+print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
 
 # Create dataloader for evaluation.
 val_dataloader = torch.utils.data.DataLoader(
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index 207ccf7c..86fef8d4 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import re
 from pathlib import Path
 from typing import Dict
 
@@ -80,7 +81,23 @@ def hf_transform_to_torch(items_dict):
 def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
     """hf_dataset contains all the observations, states, actions, rewards, etc."""
     if root is not None:
-        hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
+        hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
+        # TODO(rcadene): clean this which enables getting a subset of dataset
+        if split != "train":
+            if "%" in split:
+                raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
+            match_from = re.search(r"train\[(\d+):\]", split)
+            match_to = re.search(r"train\[:(\d+)\]", split)
+            if match_from:
+                from_frame_index = int(match_from.group(1))
+                hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
+            elif match_to:
+                to_frame_index = int(match_to.group(1))
+                hf_dataset = hf_dataset.select(range(to_frame_index))
+            else:
+                raise ValueError(
+                    f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
+                )
     else:
         hf_dataset = load_dataset(repo_id, revision=version, split=split)
     hf_dataset.set_transform(hf_transform_to_torch)
@@ -273,6 +290,12 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
             "to": [3, 7, 12]
         }
     """
+    if len(hf_dataset) == 0:
+        episode_data_index = {
+            "from": torch.tensor([]),
+            "to": torch.tensor([]),
+        }
+        return episode_data_index
     for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
         if episode_idx != current_episode:
             # We encountered a new episode, so we append its starting location to the "from" list
@@ -303,6 +326,8 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
 
     This brings the `episode_index` to the required format.
     """
+    if len(hf_dataset) == 0:
+        return hf_dataset
     unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist()
     episode_idx_to_reset_idx_mapping = {
         ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs)
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 9881e3fa..a0c60b7e 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -111,7 +111,7 @@ def test_examples_2_through_4():
                 '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
                 'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
             ),
-            ('split="train[24342:]"', 'split="train[-1:]"'),
+            ('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'),
             ("num_workers=4", "num_workers=0"),
             ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
             ("batch_size=64", "batch_size=1"),
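
For context, the slicing rules that the lerobot/common/datasets/utils.py hunk adds to load_hf_dataset can be exercised in isolation. The sketch below is an illustrative re-implementation of just that parsing step, not code from the repository; parse_train_split and the frame count are hypothetical names/values used only to show how a split string maps to the frame range that would be passed to datasets.Dataset.select.

import re

def parse_train_split(split: str, num_frames: int) -> range:
    # Hypothetical helper mirroring the rules the patch adds to load_hf_dataset:
    # "train" keeps every frame, "train[INT:]" keeps frames from INT onward,
    # "train[:INT]" keeps frames before INT, and percentage splits are rejected.
    if split == "train":
        return range(num_frames)
    if "%" in split:
        raise NotImplementedError(f"Percentage-based splits are not supported ({split}).")
    match_from = re.search(r"train\[(\d+):\]", split)
    match_to = re.search(r"train\[:(\d+)\]", split)
    if match_from:
        return range(int(match_from.group(1)), num_frames)
    if match_to:
        return range(int(match_to.group(1)))
    raise ValueError(f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"')

# Example with an arbitrary frame count; 24342 is the frame index used by the old example.
print(parse_train_split("train[24342:]", num_frames=25000))  # range(24342, 25000)
print(parse_train_split("train[:24342]", num_frames=25000))  # range(0, 24342)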