Fix tolerance for delta_timestamps (#84)
Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
parent
7ad1909641
commit
8d980940a2
|
@ -67,6 +67,7 @@ class AlohaDataset(torch.utils.data.Dataset):
|
||||||
item,
|
item,
|
||||||
self.hf_dataset,
|
self.hf_dataset,
|
||||||
self.delta_timestamps,
|
self.delta_timestamps,
|
||||||
|
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
|
||||||
)
|
)
|
||||||
|
|
||||||
# convert images from channel last (PIL) to channel first (pytorch)
|
# convert images from channel last (PIL) to channel first (pytorch)
|
||||||
|
|
|
@ -65,6 +65,7 @@ class PushtDataset(torch.utils.data.Dataset):
|
||||||
item,
|
item,
|
||||||
self.hf_dataset,
|
self.hf_dataset,
|
||||||
self.delta_timestamps,
|
self.delta_timestamps,
|
||||||
|
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
|
||||||
)
|
)
|
||||||
|
|
||||||
# convert images from channel last (PIL) to channel first (pytorch)
|
# convert images from channel last (PIL) to channel first (pytorch)
|
||||||
|
|
|
@ -11,29 +11,39 @@ def load_previous_and_future_frames(
|
||||||
item: dict[str, torch.Tensor],
|
item: dict[str, torch.Tensor],
|
||||||
hf_dataset: datasets.Dataset,
|
hf_dataset: datasets.Dataset,
|
||||||
delta_timestamps: dict[str, list[float]],
|
delta_timestamps: dict[str, list[float]],
|
||||||
tol: float = 0.04,
|
tol: float,
|
||||||
) -> dict[torch.Tensor]:
|
) -> dict[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}),
|
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
|
||||||
this function computes for each given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset.
|
some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
|
||||||
|
given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset.
|
||||||
|
|
||||||
Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. tol=0.04), this function raises an AssertionError.
|
Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
|
||||||
When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp,
|
raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
|
||||||
the violation of the tolerance doesn't raise an AssertionError, and the function populates a boolean array indicating which frames are outside of the episode range.
|
the last available timestamp, the violation of the tolerance doesn't raise an AssertionError, and the function
|
||||||
For instance, this boolean array is useful during batched training to not supervise actions associated to timestamps coming after the end of the episode,
|
populates a boolean array indicating which frames are outside of the episode range. For instance, this boolean array
|
||||||
or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode.
|
is useful during batched training to not supervise actions associated to timestamps coming after the end of the
|
||||||
|
episode, or to pad the observations in a specific way. Note that by default the observation frames before the start
|
||||||
|
of the episode are the same as the first frame of the episode.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
- item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
- item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key
|
||||||
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
||||||
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be retrieved. These deltas are added to the item timestamp to form the query timestamps.
|
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
|
||||||
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04.
|
modality (e.g., "timestamp", "observation.image", "action").
|
||||||
|
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
|
||||||
|
retrieved. These deltas are added to the item timestamp to form the query timestamps.
|
||||||
|
- tol (float): The tolerance level used to determine if a data point is close enough to the query
|
||||||
|
timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the
|
||||||
|
smallest expected inter-frame period, but large enough to account for jitter.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for each modality (e.g. "observation.image_is_pad").
|
- The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for
|
||||||
|
each modality (e.g. "observation.image_is_pad").
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
|
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization
|
||||||
|
issues with timestamps during data collection.
|
||||||
"""
|
"""
|
||||||
# get indices of the frames associated to the episode, and their timestamps
|
# get indices of the frames associated to the episode, and their timestamps
|
||||||
ep_data_id_from = item["episode_data_index_from"].item()
|
ep_data_id_from = item["episode_data_index_from"].item()
|
||||||
|
|
|
@ -59,6 +59,7 @@ class XarmDataset(torch.utils.data.Dataset):
|
||||||
item,
|
item,
|
||||||
self.hf_dataset,
|
self.hf_dataset,
|
||||||
self.delta_timestamps,
|
self.delta_timestamps,
|
||||||
|
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
|
||||||
)
|
)
|
||||||
|
|
||||||
# convert images from channel last (PIL) to channel first (pytorch)
|
# convert images from channel last (PIL) to channel first (pytorch)
|
||||||
|
|
Loading…
Reference in New Issue