Merge branch 'main' into add_drop_last_keyframes_sampler_approach
This commit is contained in:
commit
e383086c68
|
@ -371,6 +371,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||
if idx >= start_idx + dataset.num_samples:
|
||||
start_idx += dataset.num_samples
|
||||
dataset_idx += 1
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
||||
|
|
|
@ -120,13 +120,13 @@ def init_logging():
|
|||
logging.getLogger().addHandler(console_handler)
|
||||
|
||||
|
||||
def format_big_number(num):
|
||||
def format_big_number(num, precision=0):
|
||||
suffixes = ["", "K", "M", "B", "T", "Q"]
|
||||
divisor = 1000.0
|
||||
|
||||
for suffix in suffixes:
|
||||
if abs(num) < divisor:
|
||||
return f"{num:.0f}{suffix}"
|
||||
return f"{num:.{precision}f}{suffix}"
|
||||
num /= divisor
|
||||
|
||||
return num
|
||||
|
|
|
@ -114,10 +114,17 @@ def test_factory(env_name, repo_id, policy_name):
|
|||
assert key in item, f"{key}"
|
||||
|
||||
|
||||
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
|
||||
def test_multilerobotdataset_frames():
|
||||
"""Check that all dataset frames are incorporated."""
|
||||
# Note: use the image variants of the dataset to make the test approx 3x faster.
|
||||
repo_ids = ["lerobot/aloha_sim_insertion_human_image", "lerobot/aloha_sim_transfer_cube_human_image"]
|
||||
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
|
||||
# logic that wouldn't be caught with two repo IDs.
|
||||
repo_ids = [
|
||||
"lerobot/aloha_sim_insertion_human_image",
|
||||
"lerobot/aloha_sim_transfer_cube_human_image",
|
||||
"lerobot/aloha_sim_insertion_scripted_image",
|
||||
]
|
||||
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
|
||||
dataset = MultiLeRobotDataset(repo_ids)
|
||||
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
||||
|
|
Loading…
Reference in New Issue