Add missing continue

This commit is contained in:
Alexander Soare 2024-05-31 08:48:18 +01:00
parent 57fb5fe8a6
commit 6751385e5b
2 changed files with 9 additions and 1 deletions

View File

@ -371,6 +371,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
if idx >= start_idx + dataset.num_samples: if idx >= start_idx + dataset.num_samples:
start_idx += dataset.num_samples start_idx += dataset.num_samples
dataset_idx += 1 dataset_idx += 1
continue
break break
else: else:
raise AssertionError("We expect the loop to break out as long as the index is within bounds.") raise AssertionError("We expect the loop to break out as long as the index is within bounds.")

View File

@ -114,10 +114,17 @@ def test_factory(env_name, repo_id, policy_name):
assert key in item, f"{key}" assert key in item, f"{key}"
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
def test_multilerobotdataset_frames(): def test_multilerobotdataset_frames():
"""Check that all dataset frames are incorporated.""" """Check that all dataset frames are incorporated."""
# Note: use the image variants of the dataset to make the test approx 3x faster. # Note: use the image variants of the dataset to make the test approx 3x faster.
repo_ids = ["lerobot/aloha_sim_insertion_human_image", "lerobot/aloha_sim_transfer_cube_human_image"] # Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
# logic that wouldn't be caught with two repo IDs.
repo_ids = [
"lerobot/aloha_sim_insertion_human_image",
"lerobot/aloha_sim_transfer_cube_human_image",
"lerobot/aloha_sim_insertion_scripted_image",
]
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids] sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
dataset = MultiLeRobotDataset(repo_ids) dataset = MultiLeRobotDataset(repo_ids)
assert len(dataset) == sum(len(d) for d in sub_datasets) assert len(dataset) == sum(len(d) for d in sub_datasets)