diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 4769a2bf..785b68e5 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -67,6 +67,7 @@ class AlohaDataset(torch.utils.data.Dataset): item, self.hf_dataset, self.delta_timestamps, + tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error ) # convert images from channel last (PIL) to channel first (pytorch) diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index c5c06bf6..2879c177 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -65,6 +65,7 @@ class PushtDataset(torch.utils.data.Dataset): item, self.hf_dataset, self.delta_timestamps, + tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error ) # convert images from channel last (PIL) to channel first (pytorch) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index f4699cc5..50c50856 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -11,29 +11,39 @@ def load_previous_and_future_frames( item: dict[str, torch.Tensor], hf_dataset: datasets.Dataset, delta_timestamps: dict[str, list[float]], - tol: float = 0.04, + tol: float, ) -> dict[torch.Tensor]: """ - Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), - this function computes for each given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset. + Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of + some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each + given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset. - Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. tol=0.04), this function raises an AssertionError. - When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp, - the violation of the tolerance doesnt raise an AssertionError, and the function populates a boolean array indicating which frames are outside of the episode range. - For instance, this boolean array is useful during batched training to not supervise actions associated to timestamps coming after the end of the episode, - or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode. + Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function + raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after + the last available timestamp, the violation of the tolerance doesnt raise an AssertionError, and the function + populates a boolean array indicating which frames are outside of the episode range. For instance, this boolean array + is useful during batched training to not supervise actions associated to timestamps coming after the end of the + episode, or to pad the observations in a specific way. Note that by default the observation frames before the start + of the episode are the same as the first frame of the episode. Parameters: - - item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action"). - - hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action"). - - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be retrieved. These deltas are added to the item timestamp to form the query timestamps. - - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04. + - item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key + corresponds to a different modality (e.g., "timestamp", "observation.image", "action"). + - hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different + modality (e.g., "timestamp", "observation.image", "action"). + - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be + retrieved. These deltas are added to the item timestamp to form the query timestamps. + - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query + timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the + smallest expected inter-frame period, but large enough to account for jitter. Returns: - - The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for each modality (e.g. "observation.image_is_pad"). + - The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for + each modality (e.g. "observation.image_is_pad"). Raises: - - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection. + - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization + issues with timestamps during data collection. """ # get indices of the frames associated to the episode, and their timestamps ep_data_id_from = item["episode_data_index_from"].item() diff --git a/lerobot/common/datasets/xarm.py b/lerobot/common/datasets/xarm.py index 711ff642..385b7d99 100644 --- a/lerobot/common/datasets/xarm.py +++ b/lerobot/common/datasets/xarm.py @@ -59,6 +59,7 @@ class XarmDataset(torch.utils.data.Dataset): item, self.hf_dataset, self.delta_timestamps, + tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error ) # convert images from channel last (PIL) to channel first (pytorch)