Apply suggestions from code review

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Remi 2024-05-20 17:43:22 +02:00 committed by GitHub
parent 8ca0bc5e88
commit 3966079f09
2 changed files with 4 additions and 4 deletions

@@ -40,7 +40,7 @@ delta_timestamps = {
    "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
}
-# Load the last 10% episodes of the dataset as a validation set.
+# Load the last 10% of episodes of the dataset as a validation set.
# - Load full dataset
full_dataset = LeRobotDataset("lerobot/pusht", split="train")
# - Calculate train and val subsets
@@ -48,7 +48,7 @@ num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
num_val_episodes = full_dataset.num_episodes - num_train_episodes
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
-print(f"Number of episodes in validation dataset (10% reset): {num_val_episodes}")
+print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
# - Get first frame index of the validation set
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
# - Load frames subset belonging to validation set using the `split` argument.
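For context, a minimal sketch of the step this example is building toward: loading the validation frames through the `split` argument, assuming `split` accepts "train[{int}:]" slices as parsed in load_hf_dataset below. The import path and the final call are illustrative, not copied from the example file.

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # import path assumed

# `first_val_frame_index` comes from the snippet above; the slice string tells
# the loader to keep only frames from that index onward.
val_dataset = LeRobotDataset("lerobot/pusht", split=f"train[{first_val_frame_index}:]")
print(f"Number of frames in validation dataset: {len(val_dataset)}")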

@@ -84,12 +84,12 @@ def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
        hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
        # TODO(rcadene): clean this which enables getting a subset of dataset
        if split != "train":
-            match = match = re.search(r"train\[(\d+):\]", split)
+            match = re.search(r"train\[(\d+):\]", split)
            if match:
                from_frame_index = int(match.group(1))
                hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
            else:
-                raise ValueError(split)
+                raise ValueError(f'`split` should either be "train" or of the form "train[{{int}}:]", but is {split}')
    else:
        hf_dataset = load_dataset(repo_id, revision=version, split=split)
    hf_dataset.set_transform(hf_transform_to_torch)
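To make the split-parsing behaviour above concrete, here is a self-contained sketch; the parse_split helper is hypothetical (not part of the library) and only mirrors the regex and error fallback of the hunk.

import re

def parse_split(split: str, dataset_len: int) -> range:
    # "train" selects every frame; "train[{int}:]" selects frames from that index on.
    if split == "train":
        return range(dataset_len)
    match = re.search(r"train\[(\d+):\]", split)
    if match is None:
        raise ValueError(f'`split` should either be "train" or of the form "train[{{int}}:]", but is {split}')
    return range(int(match.group(1)), dataset_len)

assert parse_split("train", 100) == range(0, 100)
assert parse_split("train[90:]", 100) == range(90, 100)  # last 10% of a 100-frame dataset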