fix load_data_with_delta_timestamps and add tests
This commit is contained in:
parent
9229226522
commit
657b27cc8f
|
@ -34,52 +34,56 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def euclidean_distance_matrix(mat0, mat1):
    """
    Compute the pairwise euclidean distances between the rows of two 2D tensors.

    Parameters:
    - mat0 (torch.Tensor): Tensor of shape (n, d), one point per row.
    - mat1 (torch.Tensor): Tensor of shape (m, d), one point per row.

    Returns:
    - torch.Tensor: Tensor of shape (n, m) whose entry (i, j) is the distance between mat0[i] and mat1[j].
    """
    # Subtract before squaring. The expanded form |a|^2 + |b|^2 - 2 a.b loses most of its
    # significant digits in float32 when the coordinates are large compared to the distance
    # (e.g. timestamps around 1000 s that are 0.04 s apart), which makes small distances unreliable.
    # The price is an intermediate tensor of shape (n, m, d).
    diff = mat0[:, None, :] - mat1[None, :, :]

    # Sum of squares is non-negative by construction, so no clamping is needed before the square root
    return torch.sqrt(torch.sum(diff**2, dim=-1))
|
|
||||||
|
|
||||||
|
|
||||||
def is_contiguously_true_or_false(bool_vector):
    """
    Tell whether a 1D boolean tensor switches value at most once along its length.

    Parameters:
    - bool_vector (torch.Tensor): A 1D tensor of dtype torch.bool.

    Returns:
    - bool: True if the vector contains zero or one switch between True and False, False otherwise.

    Raises:
    - AssertionError: If the input is not 1D or is not of dtype torch.bool.
    """
    assert bool_vector.ndim == 1
    assert bool_vector.dtype == torch.bool

    # A transition is a position whose value differs from the next one
    transitions = torch.logical_xor(bool_vector[:-1], bool_vector[1:])

    # Two or more transitions mean the values are interleaved rather than grouped
    return int(torch.count_nonzero(transitions)) <= 1
|
|
||||||
|
|
||||||
# examples = [
|
|
||||||
# ([True, False, True, False, False, False], False),
|
|
||||||
# ([True, True, True, False, False, False], True),
|
|
||||||
# ([False, False, False, False, False, False], True)
|
|
||||||
# ]
|
|
||||||
# for bool_list, expected in examples:
|
|
||||||
# result = is_contiguously_true_or_false(bool_list)
|
|
||||||
|
|
||||||
|
|
||||||
def load_data_with_delta_timestamps(
|
def load_data_with_delta_timestamps(
|
||||||
data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode
|
data_dict: dict[torch.Tensor],
|
||||||
|
data_ids_per_episode: dict[torch.Tensor],
|
||||||
|
delta_timestamps: list[float],
|
||||||
|
key: str,
|
||||||
|
current_ts: float,
|
||||||
|
episode: int,
|
||||||
|
tol: float = 0.04,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Given a current timestamp (e.g. current_ts=0.6) and a list of timestamp differences (e.g. delta_timestamps=[-0.8, -0.2, 0, 0.2]),
|
||||||
|
this function computes the query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames of the specified modality (e.g. key="observation.image").
|
||||||
|
|
||||||
|
Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. tol=0.04), this function raises an AssertionError.
|
||||||
|
When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp,
|
||||||
|
the violation of the tolerance doesn't raise an AssertionError, and the function populates a boolean array indicating which frames are outside of the episode range.
|
||||||
|
For instance, this boolean array is useful during batched training to not supervise actions associated to timestamps coming after the end of the episode,
|
||||||
|
or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- data_dict (dict): A dictionary containing the data, where each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
||||||
|
- data_ids_per_episode (dict): A dictionary where keys are episode identifiers and values are lists of indices corresponding to frames associated with each episode.
|
||||||
|
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible key to be retrieved. These deltas are added to the current_ts to form the query timestamps.
|
||||||
|
- key (str): The key specifying which data modality is to be retrieved from the data_dict.
|
||||||
|
- current_ts (float): The current timestamp to which the delta timestamps are added to form the query timestamps.
|
||||||
|
- episode (int): The identifier of the episode from which frames are to be retrieved.
|
||||||
|
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- tuple: A tuple containing two elements:
|
||||||
|
- The first element is the data retrieved from the specified modality based on the closest match to the query timestamps.
|
||||||
|
- The second element is a boolean array indicating which frames were considered as padding (True if the distance to the closest timestamp was greater than the tolerance level).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
|
||||||
|
"""
|
||||||
# get indices of the frames associated to the episode, and their timestamps
|
# get indices of the frames associated to the episode, and their timestamps
|
||||||
ep_data_ids = data_ids_per_episode[episode]
|
ep_data_ids = data_ids_per_episode[episode]
|
||||||
ep_timestamps = data_dict["timestamp"][ep_data_ids]
|
ep_timestamps = data_dict["timestamp"][ep_data_ids]
|
||||||
|
|
||||||
|
# we make the assumption that the timestamps are sorted
|
||||||
|
ep_first_ts = ep_timestamps[0]
|
||||||
|
ep_last_ts = ep_timestamps[-1]
|
||||||
|
|
||||||
# get timestamps used as query to retrieve data of previous/future frames
|
# get timestamps used as query to retrieve data of previous/future frames
|
||||||
delta_ts = delta_timestamps[key]
|
delta_ts = delta_timestamps[key]
|
||||||
query_ts = current_ts + torch.tensor(delta_ts)
|
query_ts = current_ts + torch.tensor(delta_ts)
|
||||||
|
|
||||||
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
|
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
|
||||||
dist = euclidean_distance_matrix(query_ts[:, None], ep_timestamps[:, None])
|
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
|
||||||
min_, argmin_ = dist.min(1)
|
min_, argmin_ = dist.min(1)
|
||||||
|
|
||||||
# get the indices of the data that are closest to the query timestamps
|
# get the indices of the data that are closest to the query timestamps
|
||||||
|
@ -91,11 +95,11 @@ def load_data_with_delta_timestamps(
|
||||||
|
|
||||||
# TODO(rcadene): synchronize timestamps + interpolation if needed
|
# TODO(rcadene): synchronize timestamps + interpolation if needed
|
||||||
|
|
||||||
tol = 0.04
|
|
||||||
is_pad = min_ > tol
|
is_pad = min_ > tol
|
||||||
|
|
||||||
assert is_contiguously_true_or_false(is_pad), (
|
# check violated query timestamps are all outside the episode range
|
||||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=})."
|
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
|
||||||
|
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
|
||||||
"This might be due to synchronization issues with timestamps during data collection."
|
"This might be due to synchronization issues with timestamps during data collection."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ import einops
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns
|
from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns, is_contiguously_true_or_false, load_data_with_delta_timestamps
|
||||||
from lerobot.common.datasets.xarm import XarmDataset
|
from lerobot.common.datasets.xarm import XarmDataset
|
||||||
from lerobot.common.transforms import Prod
|
from lerobot.common.transforms import Prod
|
||||||
from lerobot.common.utils import init_hydra_config
|
from lerobot.common.utils import init_hydra_config
|
||||||
|
@ -142,3 +142,49 @@ def test_compute_stats():
|
||||||
# assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
|
# assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
|
||||||
# assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
|
# assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
|
||||||
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
|
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_data_with_delta_timestamps_within_tolerance():
    """Every query timestamp has a frame within the tolerance, so nothing is reported as padding."""
    data_dict = {
        "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
        "index": torch.tensor([0, 1, 2, 3, 4]),
    }
    data_ids_per_episode = {0: torch.tensor([0, 1, 2, 3, 4])}
    # The last query (0.3 + 0.23 = 0.53) is 0.03 away from the frame at 0.5, clearly inside the tolerance.
    # A delta of 0.24 would put the distance exactly on the 0.04 boundary, where float32 rounding
    # of the query timestamp decides whether the frame is flagged as padding.
    delta_timestamps = {"index": [-0.2, 0, 0.23]}
    key = "index"
    current_ts = 0.3
    episode = 0
    tol = 0.04
    data, is_pad = load_data_with_delta_timestamps(
        data_dict, data_ids_per_episode, delta_timestamps, key, current_ts, episode, tol
    )
    assert not is_pad.any(), "Unexpected padding detected"
    assert torch.equal(data, torch.tensor([0, 2, 4])), "Data does not match expected values"
|
||||||
|
|
||||||
|
def test_load_data_with_delta_timestamps_outside_tolerance_inside_episode_range():
    """A query inside the episode range with no frame within the tolerance must raise an AssertionError."""
    frame_ids = torch.tensor([0, 1, 2, 3, 4])
    call_kwargs = {
        "data_dict": {
            "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
            "index": torch.tensor([0, 1, 2, 3, 4]),
        },
        "data_ids_per_episode": {0: frame_ids},
        # 0.3 + 0.14 = 0.44 is 0.04 away from the closest frame (0.4), beyond tol=0.03
        "delta_timestamps": {"index": [-0.2, 0, 0.14, 0.2]},
        "key": "index",
        "current_ts": 0.3,
        "episode": 0,
        "tol": 0.03,
    }
    with pytest.raises(AssertionError):
        load_data_with_delta_timestamps(**call_kwargs)
|
||||||
|
|
||||||
|
def test_load_data_with_delta_timestamps_outside_tolerance_outside_episode_range():
    """Queries before the first or after the last frame are flagged as padding instead of raising."""
    frame_ids = torch.tensor([0, 1, 2, 3, 4])
    call_kwargs = {
        "data_dict": {
            "timestamp": torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5]),
            "index": torch.tensor([0, 1, 2, 3, 4]),
        },
        "data_ids_per_episode": {0: frame_ids},
        # query timestamps are 0.0, 0.06, 0.3, 0.56 and 0.6 for an episode spanning 0.1 to 0.5
        "delta_timestamps": {"index": [-0.3, -0.24, 0, 0.26, 0.3]},
        "key": "index",
        "current_ts": 0.3,
        "episode": 0,
        "tol": 0.04,
    }
    data, is_pad = load_data_with_delta_timestamps(**call_kwargs)

    expected_is_pad = torch.tensor([True, False, False, True, True])
    expected_data = torch.tensor([0, 0, 2, 4, 4])
    assert torch.equal(is_pad, expected_is_pad), "Padding does not match expected values"
    assert torch.equal(data, expected_data), "Data does not match expected values"
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue