Remove obsolete code

This commit is contained in:
Simon Alibert 2024-10-31 21:43:57 +01:00
parent 443a9eec88
commit 5ea7c78237
1 changed files with 0 additions and 94 deletions

View File

@ -368,100 +368,6 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
return delta_indices
# TODO(aliberts): remove
def load_previous_and_future_frames(
item: dict[str, torch.Tensor],
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
delta_timestamps: dict[str, list[float]],
tolerance_s: float,
) -> dict[torch.Tensor]:
"""
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
given modality (e.g. "observation.image") a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest
frames in the dataset.
Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
the last available timestamp, the violation of the tolerance doesnt raise an AssertionError, and the function
populates a boolean array indicating which frames are outside of the episode range. For instance, this boolean array
is useful during batched training to not supervise actions associated to timestamps coming after the end of the
episode, or to pad the observations in a specific way. Note that by default the observation frames before the start
of the episode are the same as the first frame of the episode.
Parameters:
- item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key
corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
modality (e.g., "timestamp", "observation.image", "action").
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
They indicate the start index and end index of each episode in the dataset.
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
retrieved. These deltas are added to the item timestamp to form the query timestamps.
- tolerance_s (float, optional): The tolerance level (in seconds) used to determine if a data point is close enough to the query
timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the
smallest expected inter-frame period, but large enough to account for jitter.
Returns:
- The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for
each modality (e.g. "observation.image_is_pad").
Raises:
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization
issues with timestamps during data collection.
"""
# get indices of the frames associated to the episode, and their timestamps
ep_id = item["episode_index"].item()
ep_data_id_from = episode_data_index["from"][ep_id].item()
ep_data_id_to = episode_data_index["to"][ep_id].item()
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
# load timestamps
ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
ep_timestamps = torch.stack(ep_timestamps)
# we make the assumption that the timestamps are sorted
ep_first_ts = ep_timestamps[0]
ep_last_ts = ep_timestamps[-1]
current_ts = item["timestamp"].item()
for key in delta_timestamps:
# get timestamps used as query to retrieve data of previous/future frames
delta_ts = delta_timestamps[key]
query_ts = current_ts + torch.tensor(delta_ts)
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
min_, argmin_ = dist.min(1)
# TODO(rcadene): synchronize timestamps + interpolation if needed
is_pad = min_ > tolerance_s
# check violated query timestamps are all outside the episode range
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
"This might be due to synchronization issues with timestamps during data collection."
)
# get dataset indices corresponding to frames to be loaded
data_ids = ep_data_ids[argmin_]
# load frames modality
item[key] = hf_dataset.select_columns(key)[data_ids][key]
if isinstance(item[key][0], dict) and "path" in item[key][0]:
# video mode where frame are expressed as dict of path and timestamp
item[key] = item[key]
else:
item[key] = torch.stack(item[key])
item[f"{key}_is_pad"] = is_pad
return item
# TODO(aliberts): remove
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
"""