Remove obsolete code
This commit is contained in:
parent
443a9eec88
commit
5ea7c78237
|
@ -368,100 +368,6 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
|
||||||
return delta_indices
|
return delta_indices
|
||||||
|
|
||||||
|
|
||||||
# TODO(aliberts): remove
|
|
||||||
def load_previous_and_future_frames(
    item: dict[str, torch.Tensor],
    hf_dataset: datasets.Dataset,
    episode_data_index: dict[str, torch.Tensor],
    delta_timestamps: dict[str, list[float]],
    tolerance_s: float,
) -> dict[str, torch.Tensor]:
    """
    Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
    some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
    given modality (e.g. "observation.image") a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest
    frames in the dataset.

    Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
    raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
    the last available timestamp, the violation of the tolerance doesn't raise an AssertionError, and the function
    populates a boolean array indicating which frames are outside of the episode range. For instance, this boolean array
    is useful during batched training to not supervise actions associated to timestamps coming after the end of the
    episode, or to pad the observations in a specific way. Note that by default the observation frames before the start
    of the episode are the same as the first frame of the episode.

    Parameters:
    - item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key
      corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
    - hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
      modality (e.g., "timestamp", "observation.image", "action").
    - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
      They indicate the start index and end index of each episode in the dataset.
    - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
      retrieved. These deltas are added to the item timestamp to form the query timestamps.
    - tolerance_s (float, optional): The tolerance level (in seconds) used to determine if a data point is close enough to the query
      timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the
      smallest expected inter-frame period, but large enough to account for jitter.

    Returns:
    - The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for
      each modality (e.g. "observation.image_is_pad").

    Raises:
    - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization
      issues with timestamps during data collection.
    """
    # get indices of the frames associated to the episode, and their timestamps
    ep_id = item["episode_index"].item()
    ep_data_id_from = episode_data_index["from"][ep_id].item()
    ep_data_id_to = episode_data_index["to"][ep_id].item()
    ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)

    # load timestamps
    ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
    ep_timestamps = torch.stack(ep_timestamps)

    # we make the assumption that the timestamps are sorted
    ep_first_ts = ep_timestamps[0]
    ep_last_ts = ep_timestamps[-1]
    current_ts = item["timestamp"].item()

    for key in delta_timestamps:
        # get timestamps used as query to retrieve data of previous/future frames
        delta_ts = delta_timestamps[key]
        query_ts = current_ts + torch.tensor(delta_ts)

        # compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
        dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
        min_, argmin_ = dist.min(1)

        # TODO(rcadene): synchronize timestamps + interpolation if needed

        is_pad = min_ > tolerance_s

        # check violated query timestamps are all outside the episode range
        assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
            f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
            "This might be due to synchronization issues with timestamps during data collection."
        )

        # get dataset indices corresponding to frames to be loaded
        data_ids = ep_data_ids[argmin_]

        # load frames modality
        item[key] = hf_dataset.select_columns(key)[data_ids][key]

        # video mode: frames are expressed as dicts of {"path", "timestamp"} and are kept as-is;
        # otherwise frames are tensors and get stacked along a new leading dimension.
        # (The original code had a no-op `item[key] = item[key]` branch for the video case.)
        if not (isinstance(item[key][0], dict) and "path" in item[key][0]):
            item[key] = torch.stack(item[key])

        item[f"{key}_is_pad"] = is_pad

    return item
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(aliberts): remove
|
# TODO(aliberts): remove
|
||||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
|
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue