From 5ea7c782373006d02ac8842e6b21a041a6159980 Mon Sep 17 00:00:00 2001
From: Simon Alibert <simon.alibert@huggingface.co>
Date: Thu, 31 Oct 2024 21:43:57 +0100
Subject: [PATCH] Remove obsolete code

---
 lerobot/common/datasets/utils.py | 94 --------------------------------
 1 file changed, 94 deletions(-)

diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index e5cc02f9..0e60af3f 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -368,100 +368,6 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
     return delta_indices
 
 
-# TODO(aliberts): remove
-def load_previous_and_future_frames(
-    item: dict[str, torch.Tensor],
-    hf_dataset: datasets.Dataset,
-    episode_data_index: dict[str, torch.Tensor],
-    delta_timestamps: dict[str, list[float]],
-    tolerance_s: float,
-) -> dict[torch.Tensor]:
-    """
-    Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
-    some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
-    given modality (e.g. "observation.image") a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest
-    frames in the dataset.
-
-    Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
-    raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
-    the last available timestamp, the violation of the tolerance doesnt raise an AssertionError, and the function
-    populates a boolean array indicating which frames are outside of the episode range. For instance, this boolean array
-    is useful during batched training to not supervise actions associated to timestamps coming after the end of the
-    episode, or to pad the observations in a specific way. Note that by default the observation frames before the start
-    of the episode are the same as the first frame of the episode.
-
-    Parameters:
-    - item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key
-      corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
-    - hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
-      modality (e.g., "timestamp", "observation.image", "action").
-    - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
-      They indicate the start index and end index of each episode in the dataset.
-    - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
-      retrieved. These deltas are added to the item timestamp to form the query timestamps.
-    - tolerance_s (float, optional): The tolerance level (in seconds) used to determine if a data point is close enough to the query
-      timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the
-      smallest expected inter-frame period, but large enough to account for jitter.
-
-    Returns:
-    - The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for
-      each modality (e.g. "observation.image_is_pad").
-
-    Raises:
-    - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization
-      issues with timestamps during data collection.
-    """
-    # get indices of the frames associated to the episode, and their timestamps
-    ep_id = item["episode_index"].item()
-    ep_data_id_from = episode_data_index["from"][ep_id].item()
-    ep_data_id_to = episode_data_index["to"][ep_id].item()
-    ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
-
-    # load timestamps
-    ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
-    ep_timestamps = torch.stack(ep_timestamps)
-
-    # we make the assumption that the timestamps are sorted
-    ep_first_ts = ep_timestamps[0]
-    ep_last_ts = ep_timestamps[-1]
-    current_ts = item["timestamp"].item()
-
-    for key in delta_timestamps:
-        # get timestamps used as query to retrieve data of previous/future frames
-        delta_ts = delta_timestamps[key]
-        query_ts = current_ts + torch.tensor(delta_ts)
-
-        # compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
-        dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
-        min_, argmin_ = dist.min(1)
-
-        # TODO(rcadene): synchronize timestamps + interpolation if needed
-
-        is_pad = min_ > tolerance_s
-
-        # check violated query timestamps are all outside the episode range
-        assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
-            f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
-            "This might be due to synchronization issues with timestamps during data collection."
-        )
-
-        # get dataset indices corresponding to frames to be loaded
-        data_ids = ep_data_ids[argmin_]
-
-        # load frames modality
-        item[key] = hf_dataset.select_columns(key)[data_ids][key]
-
-        if isinstance(item[key][0], dict) and "path" in item[key][0]:
-            # video mode where frame are expressed as dict of path and timestamp
-            item[key] = item[key]
-        else:
-            item[key] = torch.stack(item[key])
-
-        item[f"{key}_is_pad"] = is_pad
-
-    return item
-
-
 # TODO(aliberts): remove
 def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
     """