From 657b27cc8f6c2a3db54b8e643e647b17f27eb25a Mon Sep 17 00:00:00 2001 From: Cadene Date: Thu, 11 Apr 2024 12:59:09 +0000 Subject: [PATCH] fix load_data_with_delta_timestamps and add tests --- lerobot/common/datasets/utils.py | 80 +++++++++++++++++--------------- tests/test_datasets.py | 48 ++++++++++++++++++- 2 files changed, 89 insertions(+), 39 deletions(-) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index cf8caa46..e67d8a04 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -34,52 +34,56 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool: return False -def euclidean_distance_matrix(mat0, mat1): - # Compute the square of the distance matrix - sq0 = torch.sum(mat0**2, dim=1, keepdim=True) - sq1 = torch.sum(mat1**2, dim=1, keepdim=True) - distance_sq = sq0 + sq1.transpose(0, 1) - 2 * mat0 @ mat1.transpose(0, 1) - - # Taking the square root to get the euclidean distance - distance = torch.sqrt(torch.clamp(distance_sq, min=0)) - return distance - - -def is_contiguously_true_or_false(bool_vector): - assert bool_vector.ndim == 1 - assert bool_vector.dtype == torch.bool - - # Compare each element with its neighbor to find changes - changes = bool_vector[1:] != bool_vector[:-1] - - # Count the number of changes - num_changes = changes.sum().item() - - # If there's more than one change, the list is not contiguous - return num_changes <= 1 - - # examples = [ - # ([True, False, True, False, False, False], False), - # ([True, True, True, False, False, False], True), - # ([False, False, False, False, False, False], True) - # ] - # for bool_list, expected in examples: - # result = is_contiguously_true_or_false(bool_list) - - def load_data_with_delta_timestamps( - data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode + data_dict: dict[torch.Tensor], + data_ids_per_episode: dict[torch.Tensor], + delta_timestamps: list[float], + key: str, + current_ts: float, + episode: int, + tol: float = 0.04, ): + """ + Given a current timestamp (e.g. current_ts=0.6) and a list of timestamps differences (e.g. delta_timestamps=[-0.8, -0.2, 0, 0.2]), + this function compute the query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames of the specified modality (e.g. key="observation.image"). + + Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. tol=0.04), this function raises an AssertionError. + When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp, + the violation of the tolerance doesnt raise an AssertionError, and the function populates a boolean array indicating which frames are outside of the episode range. + For instance, this boolean array is useful during batched training to not supervise actions associated to timestamps coming after the end of the episode, + or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode. + + Parameters: + - data_dict (dict): A dictionary containing the data, where each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action"). + - data_ids_per_episode (dict): A dictionary where keys are episode identifiers and values are lists of indices corresponding to frames associated with each episode. + - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible key to be retrieved. These deltas are added to the current_ts to form the query timestamps. + - key (str): The key specifying which data modality is to be retrieved from the data_dict. + - current_ts (float): The current timestamp to which the delta timestamps are added to form the query timestamps. + - episode (int): The identifier of the episode from which frames are to be retrieved. + - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04. + + Returns: + - tuple: A tuple containing two elements: + - The first element is the data retrieved from the specified modality based on the closest match to the query timestamps. + - The second element is a boolean array indicating which frames were considered as padding (True if the distance to the closest timestamp was greater than the tolerance level). + + Raises: + - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection. + """ # get indices of the frames associated to the episode, and their timestamps ep_data_ids = data_ids_per_episode[episode] ep_timestamps = data_dict["timestamp"][ep_data_ids] + # we make the assumption that the timestamps are sorted + ep_first_ts = ep_timestamps[0] + ep_last_ts = ep_timestamps[-1] + # get timestamps used as query to retrieve data of previous/future frames delta_ts = delta_timestamps[key] query_ts = current_ts + torch.tensor(delta_ts) # compute distances between each query timestamp and all timestamps of all the frames belonging to the episode - dist = euclidean_distance_matrix(query_ts[:, None], ep_timestamps[:, None]) + dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1) min_, argmin_ = dist.min(1) # get the indices of the data that are closest to the query timestamps @@ -91,11 +95,11 @@ def load_data_with_delta_timestamps( # TODO(rcadene): synchronize timestamps + interpolation if needed - tol = 0.04 is_pad = min_ > tol - assert is_contiguously_true_or_false(is_pad), ( - f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=})." + # check violated query timestamps are all outside the episode range + assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), ( + f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range." "This might be due to synchronization issues with timestamps during data collection." ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 9b32ea25..d56e8252 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -4,7 +4,7 @@ import einops import pytest import torch -from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns +from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns, is_contiguously_true_or_false, load_data_with_delta_timestamps from lerobot.common.datasets.xarm import XarmDataset from lerobot.common.transforms import Prod from lerobot.common.utils import init_hydra_config @@ -142,3 +142,49 @@ def test_compute_stats(): # assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"]) # assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"]) # assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"]) + + +def test_load_data_with_delta_timestamps_within_tolerance(): + data_dict = { + "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]), + "index": torch.tensor([0, 1, 2, 3, 4]), + } + data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])} + delta_timestamps = {"index": [-0.2, 0, 0.24]} + key = "index" + current_ts = 0.3 + episode = 0 + tol = 0.04 + data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol) + assert not is_pad.any(), "Unexpected padding detected" + assert torch.equal(data, torch.tensor([0, 2, 4])), "Data does not match expected values" + +def test_load_data_with_delta_timestamps_outside_tolerance_inside_episode_range(): + data_dict = { + "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]), + "index": torch.tensor([0, 1, 2, 3, 4]), + } + data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])} + delta_timestamps = {"index": [-0.2, 0, 0.14, 0.2]} + key = "index" + current_ts = 0.3 + episode = 0 + tol = 0.03 + with pytest.raises(AssertionError): + load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol) + +def test_load_data_with_delta_timestamps_outside_tolerance_outside_episode_range(): + data_dict = { + "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]), + "index": torch.tensor([0, 1, 2, 3, 4]), + } + data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])} + delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]} + key = "index" + current_ts = 0.3 + episode = 0 + tol = 0.04 + data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol) + assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), "Padding does not match expected values" + assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values" +