Fix tolerance for delta_timestamps (#84)

Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
Alexander Soare 2024-04-18 18:48:22 +01:00 committed by GitHub
parent 7ad1909641
commit 8d980940a2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 27 additions and 14 deletions

View File

@ -67,6 +67,7 @@ class AlohaDataset(torch.utils.data.Dataset):
item, item,
self.hf_dataset, self.hf_dataset,
self.delta_timestamps, self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
) )
# convert images from channel last (PIL) to channel first (pytorch) # convert images from channel last (PIL) to channel first (pytorch)

View File

@ -65,6 +65,7 @@ class PushtDataset(torch.utils.data.Dataset):
item, item,
self.hf_dataset, self.hf_dataset,
self.delta_timestamps, self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
) )
# convert images from channel last (PIL) to channel first (pytorch) # convert images from channel last (PIL) to channel first (pytorch)

View File

@ -11,29 +11,39 @@ def load_previous_and_future_frames(
item: dict[str, torch.Tensor], item: dict[str, torch.Tensor],
hf_dataset: datasets.Dataset, hf_dataset: datasets.Dataset,
delta_timestamps: dict[str, list[float]], delta_timestamps: dict[str, list[float]],
tol: float = 0.04, tol: float,
) -> dict[torch.Tensor]: ) -> dict[torch.Tensor]:
""" """
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
this function computes for each given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset. some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset.
Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. tol=0.04), this function raises an AssertionError. Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp, raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
the violation of the tolerance doesnt raise an AssertionError, and the function populates a boolean array indicating which frames are outside of the episode range. the last available timestamp, the violation of the tolerance doesnt raise an AssertionError, and the function
For instance, this boolean array is useful during batched training to not supervise actions associated to timestamps coming after the end of the episode, populates a boolean array indicating which frames are outside of the episode range. For instance, this boolean array
or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode. is useful during batched training to not supervise actions associated to timestamps coming after the end of the
episode, or to pad the observations in a specific way. Note that by default the observation frames before the start
of the episode are the same as the first frame of the episode.
Parameters: Parameters:
- item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action"). - item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action"). corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be retrieved. These deltas are added to the item timestamp to form the query timestamps. - hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04. modality (e.g., "timestamp", "observation.image", "action").
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
retrieved. These deltas are added to the item timestamp to form the query timestamps.
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query
timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the
smallest expected inter-frame period, but large enough to account for jitter.
Returns: Returns:
- The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for each modality (e.g. "observation.image_is_pad"). - The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for
each modality (e.g. "observation.image_is_pad").
Raises: Raises:
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection. - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization
issues with timestamps during data collection.
""" """
# get indices of the frames associated to the episode, and their timestamps # get indices of the frames associated to the episode, and their timestamps
ep_data_id_from = item["episode_data_index_from"].item() ep_data_id_from = item["episode_data_index_from"].item()

View File

@ -59,6 +59,7 @@ class XarmDataset(torch.utils.data.Dataset):
item, item,
self.hf_dataset, self.hf_dataset,
self.delta_timestamps, self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
) )
# convert images from channel last (PIL) to channel first (pytorch) # convert images from channel last (PIL) to channel first (pytorch)