Apply suggestions from code review
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
parent
8ca0bc5e88
commit
3966079f09
|
@ -40,7 +40,7 @@ delta_timestamps = {
|
||||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Load the last 10% episodes of the dataset as a validation set.
|
# Load the last 10% of episodes of the dataset as a validation set.
|
||||||
# - Load full dataset
|
# - Load full dataset
|
||||||
full_dataset = LeRobotDataset("lerobot/pusht", split="train")
|
full_dataset = LeRobotDataset("lerobot/pusht", split="train")
|
||||||
# - Calculate train and val subsets
|
# - Calculate train and val subsets
|
||||||
|
@ -48,7 +48,7 @@ num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
|
||||||
num_val_episodes = full_dataset.num_episodes - num_train_episodes
|
num_val_episodes = full_dataset.num_episodes - num_train_episodes
|
||||||
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
|
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
|
||||||
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
|
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
|
||||||
print(f"Number of episodes in validation dataset (10% reset): {num_val_episodes}")
|
print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
|
||||||
# - Get first frame index of the validation set
|
# - Get first frame index of the validation set
|
||||||
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
|
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
|
||||||
# - Load frames subset belonging to validation set using the `split` argument.
|
# - Load frames subset belonging to validation set using the `split` argument.
|
||||||
|
|
|
@ -84,12 +84,12 @@ def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
|
||||||
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
|
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
|
||||||
# TODO(rcadene): clean this which enables getting a subset of dataset
|
# TODO(rcadene): clean this which enables getting a subset of dataset
|
||||||
if split != "train":
|
if split != "train":
|
||||||
match = match = re.search(r"train\[(\d+):\]", split)
|
match = re.search(r"train\[(\d+):\]", split)
|
||||||
if match:
|
if match:
|
||||||
from_frame_index = int(match.group(1))
|
from_frame_index = int(match.group(1))
|
||||||
hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
|
hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
|
||||||
else:
|
else:
|
||||||
raise ValueError(split)
|
raise ValueError(split), '`split` should either be "train" or of the form "train[{int}:]"'
|
||||||
else:
|
else:
|
||||||
hf_dataset = load_dataset(repo_id, revision=version, split=split)
|
hf_dataset = load_dataset(repo_id, revision=version, split=split)
|
||||||
hf_dataset.set_transform(hf_transform_to_torch)
|
hf_dataset.set_transform(hf_transform_to_torch)
|
||||||
|
|
Loading…
Reference in New Issue