Remove/comment obsolete tests

This commit is contained in:
Simon Alibert 2024-10-31 21:43:29 +01:00
parent ab23a4fd27
commit 443a9eec88
1 changed file with 82 additions and 147 deletions


@@ -16,7 +16,6 @@
 import json
 import logging
 from copy import deepcopy
-from itertools import chain
 from pathlib import Path
 
 import einops
@@ -30,15 +29,13 @@ import lerobot
 from lerobot.common.datasets.compute_stats import (
     aggregate_stats,
     compute_stats,
-    get_stats_einops_patterns,
 )
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.utils import (
     create_branch,
     flatten_dict,
     hf_transform_to_torch,
-    load_previous_and_future_frames,
     unflatten_dict,
 )
 from lerobot.common.utils.utils import init_hydra_config, seeded_context
@@ -72,6 +69,7 @@ def test_same_attributes_defined(dataset_create, dataset_init):
     assert init_attr == create_attr, "Attribute sets do not match between __init__ and .create()"
 
 
+@pytest.mark.skip("TODO after v2 migration / removing hydra")
 @pytest.mark.parametrize(
     "env_name, repo_id, policy_name",
     lerobot.env_dataset_policy_triplets
@@ -143,162 +141,97 @@ def test_factory(env_name, repo_id, policy_name):
         assert key in item, f"{key}"
 
 
-# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
-def test_multilerobotdataset_frames():
-    """Check that all dataset frames are incorporated."""
-    # Note: use the image variants of the dataset to make the test approx 3x faster.
-    # Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
-    # logic that wouldn't be caught with two repo IDs.
-    repo_ids = [
-        "lerobot/aloha_sim_insertion_human_image",
-        "lerobot/aloha_sim_transfer_cube_human_image",
-        "lerobot/aloha_sim_insertion_scripted_image",
-    ]
-    sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
-    dataset = MultiLeRobotDataset(repo_ids)
-    assert len(dataset) == sum(len(d) for d in sub_datasets)
-    assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
-    assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
-    # Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
-    # check they match.
-    expected_dataset_indices = []
-    for i, sub_dataset in enumerate(sub_datasets):
-        expected_dataset_indices.extend([i] * len(sub_dataset))
-    for expected_dataset_index, sub_dataset_item, dataset_item in zip(
-        expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
-    ):
-        dataset_index = dataset_item.pop("dataset_index")
-        assert dataset_index == expected_dataset_index
-        assert sub_dataset_item.keys() == dataset_item.keys()
-        for k in sub_dataset_item:
-            assert torch.equal(sub_dataset_item[k], dataset_item[k])
+# # TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
+# def test_multilerobotdataset_frames():
+#     """Check that all dataset frames are incorporated."""
+#     # Note: use the image variants of the dataset to make the test approx 3x faster.
+#     # Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
+#     # logic that wouldn't be caught with two repo IDs.
+#     repo_ids = [
+#         "lerobot/aloha_sim_insertion_human_image",
+#         "lerobot/aloha_sim_transfer_cube_human_image",
+#         "lerobot/aloha_sim_insertion_scripted_image",
+#     ]
+#     sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
+#     dataset = MultiLeRobotDataset(repo_ids)
+#     assert len(dataset) == sum(len(d) for d in sub_datasets)
+#     assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
+#     assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
+#     # Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
+#     # check they match.
+#     expected_dataset_indices = []
+#     for i, sub_dataset in enumerate(sub_datasets):
+#         expected_dataset_indices.extend([i] * len(sub_dataset))
+#     for expected_dataset_index, sub_dataset_item, dataset_item in zip(
+#         expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
+#     ):
+#         dataset_index = dataset_item.pop("dataset_index")
+#         assert dataset_index == expected_dataset_index
+#         assert sub_dataset_item.keys() == dataset_item.keys()
+#         for k in sub_dataset_item:
+#             assert torch.equal(sub_dataset_item[k], dataset_item[k])
 
 
-def test_compute_stats_on_xarm():
-    """Check that the statistics are computed correctly according to the stats_patterns property.
-    We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
-    because we are working with a small dataset).
-    """
-    dataset = LeRobotDataset("lerobot/xarm_lift_medium")
-    # reduce size of dataset sample on which stats compute is tested to 10 frames
-    dataset.hf_dataset = dataset.hf_dataset.select(range(10))
-    # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
-    # computation of the statistics. While doing this, we also make sure it works when we don't divide the
-    # dataset into even batches.
-    computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25), num_workers=0)
-    # get einops patterns to aggregate batches and compute statistics
-    stats_patterns = get_stats_einops_patterns(dataset)
-    # get all frames from the dataset in the same dtype and range as during compute_stats
-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        num_workers=0,
-        batch_size=len(dataset),
-        shuffle=False,
-    )
-    full_batch = next(iter(dataloader))
-    # compute stats based on all frames from the dataset without any batching
-    expected_stats = {}
-    for k, pattern in stats_patterns.items():
-        full_batch[k] = full_batch[k].float()
-        expected_stats[k] = {}
-        expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
-        expected_stats[k]["std"] = torch.sqrt(
-            einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
-        )
-        expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
-        expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")
-    # test computed stats match expected stats
-    for k in stats_patterns:
-        assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"])
-        assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"])
-        assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
-        assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
-    # load stats used during training which are expected to match the ones returned by computed_stats
-    loaded_stats = dataset.stats  # noqa: F841
-    # TODO(rcadene): we can't test this because expected_stats is computed on a subset
-    # # test loaded stats match expected stats
-    # for k in stats_patterns:
-    #     assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
-    #     assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
-    #     assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
-    #     assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
+# TODO(aliberts, rcadene): Refactor and move this to a tests/test_compute_stats.py
+# def test_compute_stats_on_xarm():
+#     """Check that the statistics are computed correctly according to the stats_patterns property.
+#     We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
+#     because we are working with a small dataset).
+#     """
+#     dataset = LeRobotDataset("lerobot/xarm_lift_medium")
+#     # reduce size of dataset sample on which stats compute is tested to 10 frames
+#     dataset.hf_dataset = dataset.hf_dataset.select(range(10))
+#     # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
+#     # computation of the statistics. While doing this, we also make sure it works when we don't divide the
+#     # dataset into even batches.
+#     computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25), num_workers=0)
+#     # get einops patterns to aggregate batches and compute statistics
+#     stats_patterns = get_stats_einops_patterns(dataset)
+#     # get all frames from the dataset in the same dtype and range as during compute_stats
+#     dataloader = torch.utils.data.DataLoader(
+#         dataset,
+#         num_workers=0,
+#         batch_size=len(dataset),
+#         shuffle=False,
+#     )
+#     full_batch = next(iter(dataloader))
+#     # compute stats based on all frames from the dataset without any batching
+#     expected_stats = {}
+#     for k, pattern in stats_patterns.items():
+#         full_batch[k] = full_batch[k].float()
+#         expected_stats[k] = {}
+#         expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
+#         expected_stats[k]["std"] = torch.sqrt(
+#             einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
+#         )
+#         expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
+#         expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")
+#     # test computed stats match expected stats
+#     for k in stats_patterns:
+#         assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"])
+#         assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"])
+#         assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
+#         assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
+#     # load stats used during training which are expected to match the ones returned by computed_stats
+#     loaded_stats = dataset.stats  # noqa: F841
+#     # TODO(rcadene): we can't test this because expected_stats is computed on a subset
+#     # # test loaded stats match expected stats
+#     # for k in stats_patterns:
+#     #     assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
+#     #     assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
+#     #     assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
+#     #     assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
 
 
-def test_load_previous_and_future_frames_within_tolerance():
-    hf_dataset = Dataset.from_dict(
-        {
-            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
-            "index": [0, 1, 2, 3, 4],
-            "episode_index": [0, 0, 0, 0, 0],
-        }
-    )
-    hf_dataset.set_transform(hf_transform_to_torch)
-    episode_data_index = {
-        "from": torch.tensor([0]),
-        "to": torch.tensor([5]),
-    }
-    delta_timestamps = {"index": [-0.2, 0, 0.139]}
-    tol = 0.04
-    item = hf_dataset[2]
-    item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
-    data, is_pad = item["index"], item["index_is_pad"]
-    assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
-    assert not is_pad.any(), "Unexpected padding detected"
-
-
-def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range():
-    hf_dataset = Dataset.from_dict(
-        {
-            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
-            "index": [0, 1, 2, 3, 4],
-            "episode_index": [0, 0, 0, 0, 0],
-        }
-    )
-    hf_dataset.set_transform(hf_transform_to_torch)
-    episode_data_index = {
-        "from": torch.tensor([0]),
-        "to": torch.tensor([5]),
-    }
-    delta_timestamps = {"index": [-0.2, 0, 0.141]}
-    tol = 0.04
-    item = hf_dataset[2]
-    with pytest.raises(AssertionError):
-        load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
-
-
-def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
-    hf_dataset = Dataset.from_dict(
-        {
-            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
-            "index": [0, 1, 2, 3, 4],
-            "episode_index": [0, 0, 0, 0, 0],
-        }
-    )
-    hf_dataset.set_transform(hf_transform_to_torch)
-    episode_data_index = {
-        "from": torch.tensor([0]),
-        "to": torch.tensor([5]),
-    }
-    delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
-    tol = 0.04
-    item = hf_dataset[2]
-    item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
-    data, is_pad = item["index"], item["index_is_pad"]
-    assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
-    assert torch.equal(
-        is_pad, torch.tensor([True, False, False, True, True])
-    ), "Padding does not match expected values"
 
 
 def test_flatten_unflatten_dict():
@@ -324,6 +257,7 @@
     assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
 
 
+@pytest.mark.skip("TODO after v2 migration / removing hydra")
 @pytest.mark.parametrize(
     "repo_id",
     [
@@ -395,6 +329,7 @@ def test_backward_compatibility(repo_id):
     # load_and_compare(i - 1)
 
 
+@pytest.mark.skip("TODO after v2 migration / removing hydra")
 def test_aggregate_stats():
     """Makes 3 basic datasets and checks that aggregate stats are computed correctly."""
     with seeded_context(0):
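
For reference, here is a minimal, self-contained sketch of the property the commented-out test_compute_stats_on_xarm verifies: per-channel statistics accumulated over uneven batches should match a single pass over all frames. This sketch is not part of the commit and is not lerobot's compute_stats implementation; the toy tensor shape, the batch size of 4, and the "b c h w -> c 1 1" einops pattern are illustrative assumptions.

import einops
import torch

torch.manual_seed(0)
frames = torch.rand(10, 3, 32, 32)  # 10 fake image frames: (batch, channels, height, width)

# Single pass over the full data, mirroring the "expected_stats" side of the test.
expected_mean = einops.reduce(frames, "b c h w -> c 1 1", "mean")
expected_std = torch.sqrt(einops.reduce((frames - expected_mean) ** 2, "b c h w -> c 1 1", "mean"))

# Batched pass with a batch size that leaves a partial final batch (4, 4, 2),
# the batched-computation concern called out in the test's comments.
num_frames = frames.shape[0]
batched_mean = torch.zeros(3, 1, 1)
for batch in frames.split(4):
    # Weight each batch's mean by its share of the total frames.
    batched_mean += einops.reduce(batch, "b c h w -> c 1 1", "mean") * (batch.shape[0] / num_frames)

batched_var = torch.zeros(3, 1, 1)
for batch in frames.split(4):
    sq_dev = (batch - batched_mean) ** 2
    batched_var += einops.reduce(sq_dev, "b c h w -> c 1 1", "mean") * (batch.shape[0] / num_frames)
batched_std = torch.sqrt(batched_var)

assert torch.allclose(batched_mean, expected_mean)
assert torch.allclose(batched_std, expected_std)

Weighting each batch's statistics by its share of the frames is what keeps the partial final batch from skewing the result; the original test exercises the same concern by choosing a batch size smaller than the dataset.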