Apply suggestions from code review
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
parent
8ca0bc5e88
commit
3966079f09
|
@ -40,7 +40,7 @@ delta_timestamps = {
|
|||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
|
||||
# Load the last 10% episodes of the dataset as a validation set.
|
||||
# Load the last 10% of episodes of the dataset as a validation set.
|
||||
# - Load full dataset
|
||||
full_dataset = LeRobotDataset("lerobot/pusht", split="train")
|
||||
# - Calculate train and val subsets
|
||||
|
@ -48,7 +48,7 @@ num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
|
|||
num_val_episodes = full_dataset.num_episodes - num_train_episodes
|
||||
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
|
||||
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
|
||||
print(f"Number of episodes in validation dataset (10% reset): {num_val_episodes}")
|
||||
print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
|
||||
# - Get first frame index of the validation set
|
||||
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
|
||||
# - Load frames subset belonging to validation set using the `split` argument.
|
||||
|
|
|
@ -84,12 +84,12 @@ def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
|
|||
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
|
||||
# TODO(rcadene): clean this which enables getting a subset of dataset
|
||||
if split != "train":
|
||||
match = match = re.search(r"train\[(\d+):\]", split)
|
||||
match = re.search(r"train\[(\d+):\]", split)
|
||||
if match:
|
||||
from_frame_index = int(match.group(1))
|
||||
hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
|
||||
else:
|
||||
raise ValueError(split)
|
||||
raise ValueError(split), '`split` should either be "train" or of the form "train[{int}:]"'
|
||||
else:
|
||||
hf_dataset = load_dataset(repo_id, revision=version, split=split)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
|
Loading…
Reference in New Issue