Add closest timestamp matching
This commit is contained in:
parent
88ff197453
commit
69acb6d266
|
@ -79,6 +79,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
def num_episodes(self) -> int:
|
||||
return len(self.hf_dataset.unique("episode_index"))
|
||||
|
||||
@property
|
||||
def tolerance_s(self) -> float:
|
||||
# to account for possible numerical error
|
||||
return 1e-4
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
|
@ -91,11 +96,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.hf_dataset,
|
||||
self.episode_data_index,
|
||||
self.delta_timestamps,
|
||||
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
|
||||
self.tolerance_s,
|
||||
)
|
||||
|
||||
if self.video:
|
||||
item = load_from_videos(item, self.video_frame_keys, self.videos_dir)
|
||||
item = load_from_videos(
|
||||
item,
|
||||
self.video_frame_keys,
|
||||
self.videos_dir,
|
||||
self.tolerance_s,
|
||||
)
|
||||
|
||||
if self.transform is not None:
|
||||
item = self.transform(item)
|
||||
|
|
|
@ -30,7 +30,7 @@ def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset):
|
|||
# sanity check that tensors are not float64
|
||||
assert batch[key].dtype != torch.float64
|
||||
|
||||
if isinstance(feats_type, VideoFrame, Image):
|
||||
if isinstance(feats_type, (VideoFrame, Image)):
|
||||
# sanity check that images are channel first
|
||||
_, c, h, w = batch[key].shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
|
||||
|
@ -70,7 +70,7 @@ def compute_stats(dataset: LeRobotDataset | datasets.Dataset, batch_size=32, max
|
|||
generator.manual_seed(seed)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=0,
|
||||
num_workers=8,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
drop_last=False,
|
||||
|
|
|
@ -140,12 +140,13 @@ def load_previous_and_future_frames(
|
|||
hf_dataset: datasets.Dataset,
|
||||
episode_data_index: dict[str, torch.Tensor],
|
||||
delta_timestamps: dict[str, list[float]],
|
||||
tol: float,
|
||||
tolerance_s: float,
|
||||
) -> dict[torch.Tensor]:
|
||||
"""
|
||||
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
|
||||
some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
|
||||
given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset.
|
||||
given modality (e.g. "observation.image") a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest
|
||||
frames in the dataset.
|
||||
|
||||
Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
|
||||
raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
|
||||
|
@ -164,7 +165,7 @@ def load_previous_and_future_frames(
|
|||
They indicate the start index and end index of each episode in the dataset.
|
||||
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
|
||||
retrieved. These deltas are added to the item timestamp to form the query timestamps.
|
||||
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query
|
||||
- tolerance_s (float, optional): The tolerance level (in seconds) used to determine if a data point is close enough to the query
|
||||
timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the
|
||||
smallest expected inter-frame period, but large enough to account for jitter.
|
||||
|
||||
|
@ -202,11 +203,11 @@ def load_previous_and_future_frames(
|
|||
|
||||
# TODO(rcadene): synchronize timestamps + interpolation if needed
|
||||
|
||||
is_pad = min_ > tol
|
||||
is_pad = min_ > tolerance_s
|
||||
|
||||
# check violated query timestamps are all outside the episode range
|
||||
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
)
|
||||
|
||||
|
|
|
@ -10,41 +10,40 @@ import torchvision
|
|||
from datasets.features.features import register_feature
|
||||
|
||||
|
||||
def load_from_videos(item, video_frame_keys, videos_dir):
|
||||
def load_from_videos(
|
||||
item: dict[str, torch.Tensor], video_frame_keys: list[str], videos_dir: Path, tolerance_s: float
|
||||
):
|
||||
# since video path already contains "videos" (e.g. videos_dir="data/videos", path="videos/episode_0.mp4")
|
||||
data_dir = videos_dir.parent
|
||||
|
||||
for key in video_frame_keys:
|
||||
ep_idx = item["episode_index"]
|
||||
video_path = data_dir / key / f"episode_{ep_idx:06d}.mp4"
|
||||
|
||||
if isinstance(item[key], list):
|
||||
# load multiple frames at once
|
||||
# load multiple frames at once (expected when delta_timestamps is not None)
|
||||
timestamps = [frame["timestamp"] for frame in item[key]]
|
||||
paths = [frame["path"] for frame in item[key]]
|
||||
if len(set(paths)) == 1:
|
||||
raise NotImplementedError("All video paths are expected to be the same for now.")
|
||||
video_path = data_dir / paths[0]
|
||||
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps)
|
||||
assert len(frames) == len(timestamps)
|
||||
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s)
|
||||
item[key] = frames
|
||||
else:
|
||||
# load one frame
|
||||
timestamps = [item[key]["timestamp"]]
|
||||
video_path = data_dir / item[key]["path"]
|
||||
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps)
|
||||
assert len(frames) == 1
|
||||
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s)
|
||||
item[key] = frames[0]
|
||||
|
||||
return item
|
||||
|
||||
|
||||
def decode_video_frames_torchvision(
|
||||
video_path: str, timestamps: list[float], device: str = "cpu", log_loaded_timestamps: bool = False
|
||||
video_path: str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
device: str = "cpu",
|
||||
log_loaded_timestamps: bool = False,
|
||||
):
|
||||
"""Loads frames associated to the requested timestamps of a video
|
||||
|
||||
|
@ -85,40 +84,50 @@ def decode_video_frames_torchvision(
|
|||
first_ts = timestamps[0]
|
||||
last_ts = timestamps[-1]
|
||||
|
||||
# access key frame of first requested frame, and load all frames until last requested frame
|
||||
# access closest key frame of the first requested frame
|
||||
# Note: closest key frame timestamp is usally smaller than `first_ts` (e.g. key frame can be the first frame of the video)
|
||||
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
|
||||
reader.seek(first_ts)
|
||||
frames = []
|
||||
|
||||
# load all frames until last requested frame
|
||||
loaded_frames = []
|
||||
loaded_ts = []
|
||||
for frame in reader:
|
||||
# get timestamp of the loaded frame
|
||||
ts = round_timestamp(frame["pts"])
|
||||
|
||||
# if the loaded frame is not among the requested frames, we dont add it to the list of output frames
|
||||
is_frame_requested = ts in timestamps
|
||||
if is_frame_requested:
|
||||
frames.append(frame["data"])
|
||||
|
||||
current_ts = frame["pts"]
|
||||
if log_loaded_timestamps:
|
||||
log = f"frame loaded at timestamp={ts:.4f}"
|
||||
if is_frame_requested:
|
||||
log += " requested"
|
||||
logging.info(log)
|
||||
|
||||
if len(timestamps) == len(frames):
|
||||
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
loaded_frames.append(frame["data"])
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
|
||||
# hard stop
|
||||
assert (
|
||||
frame["pts"] >= last_ts
|
||||
), f"Not enough frames have been loaded in [{first_ts}, {last_ts}]. {len(timestamps)} expected, but only {len(frames)} loaded."
|
||||
query_ts = torch.tensor(timestamps)
|
||||
loaded_ts = torch.tensor(loaded_ts)
|
||||
|
||||
frames = torch.stack(frames)
|
||||
# compute distances between each query timestamp and timestamps of all loaded frames
|
||||
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
closest_ts = loaded_ts[argmin_]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"{closest_ts=}")
|
||||
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
frames = frames.type(torch.float32) / 255
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
|
||||
assert len(timestamps) == frames.shape[0]
|
||||
return frames
|
||||
assert len(timestamps) == len(closest_frames)
|
||||
return closest_frames
|
||||
|
||||
|
||||
def encode_video_frames(imgs_dir: Path, video_path: Path, fps: int):
|
||||
|
|
Loading…
Reference in New Issue