[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot] 2025-03-03 06:33:38 +00:00
parent 4e2dc91e59
commit e8126dc3d6
2 changed files with 16 additions and 15 deletions

View File

@ -707,9 +707,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
item = {} item = {}
for vid_key, query_ts in query_timestamps.items(): for vid_key, query_ts in query_timestamps.items():
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key) video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames_torchcodec( frames = decode_video_frames_torchcodec(video_path, query_ts, self.tolerance_s)
video_path, query_ts, self.tolerance_s
)
item[vid_key] = frames.squeeze(0) item[vid_key] = frames.squeeze(0)
return item return item

View File

@ -29,6 +29,7 @@ from datasets.features.features import register_feature
from PIL import Image from PIL import Image
from torchcodec.decoders import VideoDecoder from torchcodec.decoders import VideoDecoder
def decode_video_frames_torchvision( def decode_video_frames_torchvision(
video_path: Path | str, video_path: Path | str,
timestamps: list[float], timestamps: list[float],
@ -125,7 +126,8 @@ def decode_video_frames_torchvision(
assert len(timestamps) == len(closest_frames) assert len(timestamps) == len(closest_frames)
return closest_frames return closest_frames
def decode_video_frames_torchcodec( def decode_video_frames_torchcodec(
video_path: Path | str, video_path: Path | str,
timestamps: list[float], timestamps: list[float],
@ -142,26 +144,26 @@ def decode_video_frames_torchcodec(
# get metadata for frame information # get metadata for frame information
metadata = decoder.metadata metadata = decoder.metadata
average_fps = metadata.average_fps average_fps = metadata.average_fps
# convert timestamps to frame indices # convert timestamps to frame indices
frame_indices = [round(ts * average_fps) for ts in timestamps] frame_indices = [round(ts * average_fps) for ts in timestamps]
# retrieve frames based on indices # retrieve frames based on indices
frames_batch = decoder.get_frames_at(indices=frame_indices) frames_batch = decoder.get_frames_at(indices=frame_indices)
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds): for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
loaded_frames.append(frame) loaded_frames.append(frame)
loaded_ts.append(pts.item()) loaded_ts.append(pts.item())
if log_loaded_timestamps: if log_loaded_timestamps:
logging.info(f"Frame loaded at timestamp={pts:.4f}") logging.info(f"Frame loaded at timestamp={pts:.4f}")
query_ts = torch.tensor(timestamps) query_ts = torch.tensor(timestamps)
loaded_ts = torch.tensor(loaded_ts) loaded_ts = torch.tensor(loaded_ts)
# compute distances between each query timestamp and loaded timestamps # compute distances between each query timestamp and loaded timestamps
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1) dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
min_, argmin_ = dist.min(1) min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s is_within_tol = min_ < tolerance_s
assert is_within_tol.all(), ( assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})." f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
@ -172,20 +174,21 @@ def decode_video_frames_torchcodec(
f"\nloaded timestamps: {loaded_ts}" f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}" f"\nvideo: {video_path}"
) )
# get closest frames to the query timestamps # get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_]) closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
closest_ts = loaded_ts[argmin_] closest_ts = loaded_ts[argmin_]
if log_loaded_timestamps: if log_loaded_timestamps:
logging.info(f"{closest_ts=}") logging.info(f"{closest_ts=}")
# convert to float32 in [0,1] range (channel first) # convert to float32 in [0,1] range (channel first)
closest_frames = closest_frames.type(torch.float32) / 255 closest_frames = closest_frames.type(torch.float32) / 255
assert len(timestamps) == len(closest_frames) assert len(timestamps) == len(closest_frames)
return closest_frames return closest_frames
def encode_video_frames( def encode_video_frames(
imgs_dir: Path | str, imgs_dir: Path | str,
video_path: Path | str, video_path: Path | str,