From 633115d861f05ed0dbc71ac70790fd2b6610a527 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Fri, 31 May 2024 09:03:28 +0100 Subject: [PATCH 1/2] Fix chaining in MultiLerobotDataset (#233) --- lerobot/common/datasets/lerobot_dataset.py | 1 + tests/test_datasets.py | 9 ++++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index a87c3ee8..58ae51b1 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -371,6 +371,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): if idx >= start_idx + dataset.num_samples: start_idx += dataset.num_samples dataset_idx += 1 + continue break else: raise AssertionError("We expect the loop to break out as long as the index is within bounds.") diff --git a/tests/test_datasets.py b/tests/test_datasets.py index dac18c14..da0ae755 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -114,10 +114,17 @@ def test_factory(env_name, repo_id, policy_name): assert key in item, f"{key}" +# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds. def test_multilerobotdataset_frames(): """Check that all dataset frames are incorporated.""" # Note: use the image variants of the dataset to make the test approx 3x faster. - repo_ids = ["lerobot/aloha_sim_insertion_human_image", "lerobot/aloha_sim_transfer_cube_human_image"] + # Note: We really do need three repo_ids here as at some point this caught an issue with the chaining + # logic that wouldn't be caught with two repo IDs. + repo_ids = [ + "lerobot/aloha_sim_insertion_human_image", + "lerobot/aloha_sim_transfer_cube_human_image", + "lerobot/aloha_sim_insertion_scripted_image", + ] sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids] dataset = MultiLeRobotDataset(repo_ids) assert len(dataset) == sum(len(d) for d in sub_datasets) From 83f4f7f7e83d8f0115463c7dd1b8e0b0da863dc2 Mon Sep 17 00:00:00 2001 From: Radek Osmulski Date: Fri, 31 May 2024 18:19:01 +1000 Subject: [PATCH 2/2] Add precision param to format_big_number (#232) --- lerobot/common/utils/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index 696999ad..c429efbd 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -120,13 +120,13 @@ def init_logging(): logging.getLogger().addHandler(console_handler) -def format_big_number(num): +def format_big_number(num, precision=0): suffixes = ["", "K", "M", "B", "T", "Q"] divisor = 1000.0 for suffix in suffixes: if abs(num) < divisor: - return f"{num:.0f}{suffix}" + return f"{num:.{precision}f}{suffix}" num /= divisor return num