Add closest timestamp matching

This commit is contained in:
Cadene 2024-05-01 13:53:04 +00:00
parent 88ff197453
commit 69acb6d266
4 changed files with 65 additions and 45 deletions

View File

@ -79,6 +79,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
def num_episodes(self) -> int:
    """Return the number of distinct `episode_index` values found in `self.hf_dataset`."""
    return len(self.hf_dataset.unique("episode_index"))
@property
def tolerance_s(self) -> float:
    """Return the tolerance (in seconds) that this dataset passes to its frame-loading helpers."""
    # to account for possible numerical error
    return 1e-4
def __len__(self):
    """Return the dataset length, which is `self.num_samples`."""
    return self.num_samples
@ -91,11 +96,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.hf_dataset,
self.episode_data_index,
self.delta_timestamps,
tol=1 / self.fps - 1e-4, # 1e-4 to account for possible numerical error
self.tolerance_s,
)
if self.video:
item = load_from_videos(item, self.video_frame_keys, self.videos_dir)
item = load_from_videos(
item,
self.video_frame_keys,
self.videos_dir,
self.tolerance_s,
)
if self.transform is not None:
item = self.transform(item)

View File

@ -30,7 +30,7 @@ def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset):
# sanity check that tensors are not float64
assert batch[key].dtype != torch.float64
if isinstance(feats_type, VideoFrame, Image):
if isinstance(feats_type, (VideoFrame, Image)):
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
@ -70,7 +70,7 @@ def compute_stats(dataset: LeRobotDataset | datasets.Dataset, batch_size=32, max
generator.manual_seed(seed)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
num_workers=8,
batch_size=batch_size,
shuffle=True,
drop_last=False,

View File

@ -140,12 +140,13 @@ def load_previous_and_future_frames(
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
delta_timestamps: dict[str, list[float]],
tol: float,
tolerance_s: float,
) -> dict[torch.Tensor]:
"""
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset.
given modality (e.g. "observation.image") a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest
frames in the dataset.
Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
@ -164,7 +165,7 @@ def load_previous_and_future_frames(
They indicate the start index and end index of each episode in the dataset.
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
retrieved. These deltas are added to the item timestamp to form the query timestamps.
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query
- tolerance_s (float, optional): The tolerance level (in seconds) used to determine if a data point is close enough to the query
timestamp by asserting `tolerance_s > difference`. It is suggested to set `tolerance_s` to a smaller value than the
smallest expected inter-frame period, but large enough to account for jitter.
@ -202,11 +203,11 @@ def load_previous_and_future_frames(
# TODO(rcadene): synchronize timestamps + interpolation if needed
is_pad = min_ > tol
is_pad = min_ > tolerance_s
# check violated query timestamps are all outside the episode range
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
"This might be due to synchronization issues with timestamps during data collection."
)

View File

@ -10,41 +10,40 @@ import torchvision
from datasets.features.features import register_feature
def load_from_videos(
    item: dict[str, torch.Tensor], video_frame_keys: list[str], videos_dir: Path, tolerance_s: float
):
    """Replace the video frame references stored in `item` by the decoded frames.

    For each key in `video_frame_keys`, `item[key]` is either a single mapping with "path" and
    "timestamp" entries (one frame), or a list of such mappings (several frames, expected when
    delta_timestamps is not None). The referenced frames are decoded with
    `decode_video_frames_torchvision` and written back into `item[key]`.

    Args:
        item: Dataset item whose video entries are replaced in place.
        video_frame_keys: Keys of `item` that hold video frame references.
        videos_dir: Directory of the videos; its parent is used as the root that the stored
            "path" values are resolved against.
        tolerance_s: Tolerance (in seconds) forwarded to `decode_video_frames_torchvision`.

    Returns:
        The same `item` object, with each video key holding the decoded frames (the full result
        for a list entry, the first decoded frame for a single entry).

    Raises:
        NotImplementedError: If the frames listed under one key point to more than one video file.
    """
    # since video path already contains "videos" (e.g. videos_dir="data/videos", path="videos/episode_0.mp4")
    data_dir = videos_dir.parent

    for key in video_frame_keys:
        if isinstance(item[key], list):
            # load multiple frames at once (expected when delta_timestamps is not None)
            timestamps = [frame["timestamp"] for frame in item[key]]
            paths = [frame["path"] for frame in item[key]]
            # Only a single video file per key is supported: reject the case where the
            # requested frames are spread over several distinct paths.
            if len(set(paths)) > 1:
                raise NotImplementedError("All video paths are expected to be the same for now.")
            video_path = data_dir / paths[0]

            frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s)
            item[key] = frames
        else:
            # load one frame
            timestamps = [item[key]["timestamp"]]
            video_path = data_dir / item[key]["path"]

            frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s)
            item[key] = frames[0]

    return item
def decode_video_frames_torchvision(
video_path: str, timestamps: list[float], device: str = "cpu", log_loaded_timestamps: bool = False
video_path: str,
timestamps: list[float],
tolerance_s: float,
device: str = "cpu",
log_loaded_timestamps: bool = False,
):
"""Loads frames associated to the requested timestamps of a video
@ -85,40 +84,50 @@ def decode_video_frames_torchvision(
first_ts = timestamps[0]
last_ts = timestamps[-1]
# access key frame of first requested frame, and load all frames until last requested frame
# access closest key frame of the first requested frame
# Note: closest key frame timestamp is usually smaller than `first_ts` (e.g. key frame can be the first frame of the video)
# for details on what `seek` is doing see: https://pyav.basswood-io.com/docs/stable/api/container.html?highlight=inputcontainer#av.container.InputContainer.seek
reader.seek(first_ts)
frames = []
# load all frames until last requested frame
loaded_frames = []
loaded_ts = []
for frame in reader:
# get timestamp of the loaded frame
ts = round_timestamp(frame["pts"])
# if the loaded frame is not among the requested frames, we dont add it to the list of output frames
is_frame_requested = ts in timestamps
if is_frame_requested:
frames.append(frame["data"])
current_ts = frame["pts"]
if log_loaded_timestamps:
log = f"frame loaded at timestamp={ts:.4f}"
if is_frame_requested:
log += " requested"
logging.info(log)
if len(timestamps) == len(frames):
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
loaded_frames.append(frame["data"])
loaded_ts.append(current_ts)
if current_ts >= last_ts:
break
# hard stop
assert (
frame["pts"] >= last_ts
), f"Not enough frames have been loaded in [{first_ts}, {last_ts}]. {len(timestamps)} expected, but only {len(frames)} loaded."
query_ts = torch.tensor(timestamps)
loaded_ts = torch.tensor(loaded_ts)
frames = torch.stack(frames)
# compute distances between each query timestamp and timestamps of all loaded frames
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
closest_ts = loaded_ts[argmin_]
if log_loaded_timestamps:
logging.info(f"{closest_ts=}")
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
frames = frames.type(torch.float32) / 255
closest_frames = closest_frames.type(torch.float32) / 255
assert len(timestamps) == frames.shape[0]
return frames
assert len(timestamps) == len(closest_frames)
return closest_frames
def encode_video_frames(imgs_dir: Path, video_path: Path, fps: int):