fix load_data_with_delta_timestamps and add tests

This commit is contained in:
Cadene 2024-04-11 12:59:09 +00:00
parent 9229226522
commit 657b27cc8f
2 changed files with 89 additions and 39 deletions

View File

@ -34,52 +34,56 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
return False
def euclidean_distance_matrix(mat0, mat1):
# Compute the square of the distance matrix
sq0 = torch.sum(mat0**2, dim=1, keepdim=True)
sq1 = torch.sum(mat1**2, dim=1, keepdim=True)
distance_sq = sq0 + sq1.transpose(0, 1) - 2 * mat0 @ mat1.transpose(0, 1)
# Taking the square root to get the euclidean distance
distance = torch.sqrt(torch.clamp(distance_sq, min=0))
return distance
def is_contiguously_true_or_false(bool_vector):
assert bool_vector.ndim == 1
assert bool_vector.dtype == torch.bool
# Compare each element with its neighbor to find changes
changes = bool_vector[1:] != bool_vector[:-1]
# Count the number of changes
num_changes = changes.sum().item()
# If there's more than one change, the list is not contiguous
return num_changes <= 1
# examples = [
# ([True, False, True, False, False, False], False),
# ([True, True, True, False, False, False], True),
# ([False, False, False, False, False, False], True)
# ]
# for bool_list, expected in examples:
# result = is_contiguously_true_or_false(bool_list)
def load_data_with_delta_timestamps(
data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode
data_dict: dict[torch.Tensor],
data_ids_per_episode: dict[torch.Tensor],
delta_timestamps: list[float],
key: str,
current_ts: float,
episode: int,
tol: float = 0.04,
):
"""
Given a current timestamp (e.g. current_ts=0.6) and a list of timestamps differences (e.g. delta_timestamps=[-0.8, -0.2, 0, 0.2]),
this function compute the query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames of the specified modality (e.g. key="observation.image").
Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. tol=0.04), this function raises an AssertionError.
When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp,
the violation of the tolerance doesnt raise an AssertionError, and the function populates a boolean array indicating which frames are outside of the episode range.
For instance, this boolean array is useful during batched training to not supervise actions associated to timestamps coming after the end of the episode,
or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode.
Parameters:
- data_dict (dict): A dictionary containing the data, where each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- data_ids_per_episode (dict): A dictionary where keys are episode identifiers and values are lists of indices corresponding to frames associated with each episode.
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible key to be retrieved. These deltas are added to the current_ts to form the query timestamps.
- key (str): The key specifying which data modality is to be retrieved from the data_dict.
- current_ts (float): The current timestamp to which the delta timestamps are added to form the query timestamps.
- episode (int): The identifier of the episode from which frames are to be retrieved.
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04.
Returns:
- tuple: A tuple containing two elements:
- The first element is the data retrieved from the specified modality based on the closest match to the query timestamps.
- The second element is a boolean array indicating which frames were considered as padding (True if the distance to the closest timestamp was greater than the tolerance level).
Raises:
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
"""
# get indices of the frames associated to the episode, and their timestamps
ep_data_ids = data_ids_per_episode[episode]
ep_timestamps = data_dict["timestamp"][ep_data_ids]
# we make the assumption that the timestamps are sorted
ep_first_ts = ep_timestamps[0]
ep_last_ts = ep_timestamps[-1]
# get timestamps used as query to retrieve data of previous/future frames
delta_ts = delta_timestamps[key]
query_ts = current_ts + torch.tensor(delta_ts)
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
dist = euclidean_distance_matrix(query_ts[:, None], ep_timestamps[:, None])
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
min_, argmin_ = dist.min(1)
# get the indices of the data that are closest to the query timestamps
@ -91,11 +95,11 @@ def load_data_with_delta_timestamps(
# TODO(rcadene): synchronize timestamps + interpolation if needed
tol = 0.04
is_pad = min_ > tol
assert is_contiguously_true_or_false(is_pad), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=})."
# check violated query timestamps are all outside the episode range
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
"This might be due to synchronization issues with timestamps during data collection."
)

View File

@ -4,7 +4,7 @@ import einops
import pytest
import torch
from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns
from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns, is_contiguously_true_or_false, load_data_with_delta_timestamps
from lerobot.common.datasets.xarm import XarmDataset
from lerobot.common.transforms import Prod
from lerobot.common.utils import init_hydra_config
@ -142,3 +142,49 @@ def test_compute_stats():
# assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
# assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
def test_load_data_with_delta_timestamps_within_tolerance():
data_dict = {
"timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
"index": torch.tensor([0, 1, 2, 3, 4]),
}
data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
delta_timestamps = {"index": [-0.2, 0, 0.24]}
key = "index"
current_ts = 0.3
episode = 0
tol = 0.04
data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
assert not is_pad.any(), "Unexpected padding detected"
assert torch.equal(data, torch.tensor([0, 2, 4])), "Data does not match expected values"
def test_load_data_with_delta_timestamps_outside_tolerance_inside_episode_range():
data_dict = {
"timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
"index": torch.tensor([0, 1, 2, 3, 4]),
}
data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
delta_timestamps = {"index": [-0.2, 0, 0.14, 0.2]}
key = "index"
current_ts = 0.3
episode = 0
tol = 0.03
with pytest.raises(AssertionError):
load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
def test_load_data_with_delta_timestamps_outside_tolerance_outside_episode_range():
data_dict = {
"timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
"index": torch.tensor([0, 1, 2, 3, 4]),
}
data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
key = "index"
current_ts = 0.3
episode = 0
tol = 0.04
data, is_pad = load_data_with_delta_timestamps(data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol)
assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), "Padding does not match expected values"
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"