[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
parent 4e2dc91e59
commit e8126dc3d6

@@ -707,9 +707,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         item = {}
         for vid_key, query_ts in query_timestamps.items():
             video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
-            frames = decode_video_frames_torchcodec(
-                video_path, query_ts, self.tolerance_s
-            )
+            frames = decode_video_frames_torchcodec(video_path, query_ts, self.tolerance_s)
             item[vid_key] = frames.squeeze(0)
 
         return item
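Note: the hunk above is purely cosmetic; the hook joined a call that fits on one line. For context, a minimal sketch (with made-up shapes, not from this commit) of why the call site follows up with squeeze(0): the decoder returns a batched (N, C, H, W) tensor, and dropping the leading dim yields a single CHW frame when exactly one timestamp was queried.

    import torch

    # stand-in for the tensor decode_video_frames_torchcodec would return
    # for a single query timestamp (shapes here are hypothetical)
    frames = torch.rand(1, 3, 480, 640)
    frame = frames.squeeze(0)  # (3, 480, 640): batch dim of size 1 removed
    assert frame.shape == (3, 480, 640)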
@@ -29,6 +29,7 @@ from datasets.features.features import register_feature
 from PIL import Image
 from torchcodec.decoders import VideoDecoder
 
+
 def decode_video_frames_torchvision(
     video_path: Path | str,
     timestamps: list[float],
@@ -125,7 +126,8 @@ def decode_video_frames_torchvision(
 
     assert len(timestamps) == len(closest_frames)
     return closest_frames
 
+
 def decode_video_frames_torchcodec(
     video_path: Path | str,
     timestamps: list[float],
@@ -142,26 +144,26 @@ def decode_video_frames_torchcodec(
     # get metadata for frame information
     metadata = decoder.metadata
     average_fps = metadata.average_fps
 
     # convert timestamps to frame indices
     frame_indices = [round(ts * average_fps) for ts in timestamps]
 
     # retrieve frames based on indices
     frames_batch = decoder.get_frames_at(indices=frame_indices)
 
-    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds):
+    for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
         loaded_frames.append(frame)
         loaded_ts.append(pts.item())
         if log_loaded_timestamps:
             logging.info(f"Frame loaded at timestamp={pts:.4f}")
 
     query_ts = torch.tensor(timestamps)
     loaded_ts = torch.tensor(loaded_ts)
 
     # compute distances between each query timestamp and loaded timestamps
     dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
     min_, argmin_ = dist.min(1)
 
     is_within_tol = min_ < tolerance_s
     assert is_within_tol.all(), (
         f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
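Note: adding strict=False is the only edit in this hunk that touches call semantics at all. It pins down zip's long-standing truncating behavior explicitly, which is what lint rules such as ruff's B905 ask for on Python 3.10+. A minimal sketch of the difference:

    # strict=False (the default) truncates to the shorter iterable;
    # strict=True raises ValueError on a length mismatch.
    pairs = list(zip([1, 2, 3], ["a", "b"], strict=False))
    assert pairs == [(1, "a"), (2, "b")]  # third element silently dropped

    try:
        list(zip([1, 2, 3], ["a", "b"], strict=True))
    except ValueError:
        pass  # the mismatch is surfaced instead of hidden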
@@ -172,20 +174,21 @@ def decode_video_frames_torchcodec(
         f"\nloaded timestamps: {loaded_ts}"
         f"\nvideo: {video_path}"
     )
 
     # get closest frames to the query timestamps
     closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
     closest_ts = loaded_ts[argmin_]
 
     if log_loaded_timestamps:
         logging.info(f"{closest_ts=}")
 
     # convert to float32 in [0,1] range (channel first)
     closest_frames = closest_frames.type(torch.float32) / 255
 
     assert len(timestamps) == len(closest_frames)
     return closest_frames
 
+
 def encode_video_frames(
     imgs_dir: Path | str,
     video_path: Path | str,
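Note: the nearest-frame matching exercised by the context lines above can be tried standalone. A minimal sketch with made-up timestamps, assuming only torch, of how torch.cdist pairs each query timestamp with its closest decoded frame and how the tolerance check reads:

    import torch

    query_ts = torch.tensor([0.10, 0.50])         # hypothetical query times (s)
    loaded_ts = torch.tensor([0.05, 0.12, 0.52])  # hypothetical decoded-frame times (s)

    # L1 distance between every (query, loaded) pair; min over loaded times
    dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
    min_, argmin_ = dist.min(1)

    assert argmin_.tolist() == [1, 2]  # 0.10 -> 0.12, 0.50 -> 0.52
    assert (min_ < 0.04).all()         # both matches within a 0.04 s tolerance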