Add missing continue
This commit is contained in:
parent
57fb5fe8a6
commit
6751385e5b
|
@ -371,6 +371,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||
if idx >= start_idx + dataset.num_samples:
|
||||
start_idx += dataset.num_samples
|
||||
dataset_idx += 1
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
||||
|
|
|
@ -114,10 +114,17 @@ def test_factory(env_name, repo_id, policy_name):
|
|||
assert key in item, f"{key}"
|
||||
|
||||
|
||||
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
|
||||
def test_multilerobotdataset_frames():
|
||||
"""Check that all dataset frames are incorporated."""
|
||||
# Note: use the image variants of the dataset to make the test approx 3x faster.
|
||||
repo_ids = ["lerobot/aloha_sim_insertion_human_image", "lerobot/aloha_sim_transfer_cube_human_image"]
|
||||
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
|
||||
# logic that wouldn't be caught with two repo IDs.
|
||||
repo_ids = [
|
||||
"lerobot/aloha_sim_insertion_human_image",
|
||||
"lerobot/aloha_sim_transfer_cube_human_image",
|
||||
"lerobot/aloha_sim_insertion_scripted_image",
|
||||
]
|
||||
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
|
||||
dataset = MultiLeRobotDataset(repo_ids)
|
||||
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
||||
|
|
Loading…
Reference in New Issue