add new video decoder method

Jade Choghari 2025-02-20 21:13:49 +01:00
parent c6bcfb3539
commit cae49528ee
2 changed files with 96 additions and 5 deletions

View File

@@ -652,6 +652,45 @@ class LeRobotDataset(torch.utils.data.Dataset):
        item = self.hf_dataset[idx]
        ep_idx = item["episode_index"].item()

        query_indices = None
        if self.delta_indices is not None:
            current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
            query_indices, padding = self._get_query_indices(idx, current_ep_idx)
            query_result = self._query_hf_dataset(query_indices)
            item = {**item, **padding}
            for key, val in query_result.items():
                item[key] = val

        if len(self.meta.video_keys) > 0:
            current_ts = item["timestamp"].item()
            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
            # Previously, frames were decoded right here in __getitem__:
            # video_frames = self._query_videos(query_timestamps, ep_idx)
            # item = {**video_frames, **item}
            # Instead, return only the video paths and query timestamps, and
            # defer the actual decoding to the dataloader's collate function,
            # which runs in CPU worker processes.
            item["video_paths"] = {
                vid_key: self.root / self.meta.get_video_file_path(ep_idx, vid_key)
                for vid_key in query_timestamps.keys()
            }
            item["query_timestamps"] = query_timestamps

        if self.image_transforms is not None:
            # NOTE: with deferred decoding, video camera keys are not present
            # in `item` at this point; transforms on video frames would need
            # to move into the collate function as well.
            image_keys = self.meta.camera_keys
            for cam in image_keys:
                item[cam] = self.image_transforms(item[cam])

        # Add task as a string
        task_idx = item["task_index"].item()
        item["task"] = self.meta.tasks[task_idx]

        return item

    def __getitem2__(self, idx) -> dict:
        # Original implementation, kept for reference while the deferred
        # decoding path above is evaluated.
        item = self.hf_dataset[idx]
        ep_idx = item["episode_index"].item()

        query_indices = None
        if self.delta_indices is not None:
            current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
@@ -677,7 +716,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
        item["task"] = self.meta.tasks[task_idx]

        return item

    def __repr__(self):
        feature_keys = list(self.features)
        return (
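For orientation, here is a minimal sketch of what a sampled item looks like under the deferred-decoding path. The repo id and camera key are illustrative placeholders, not taken from this commit:

# Sketch: inspect one item from the modified dataset.
# "lerobot/pusht" and "observation.image" are illustrative; substitute the
# dataset and camera key you actually use.
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht")
item = dataset[0]

# Video keys no longer map to decoded frame tensors. The item instead carries
# what the collate function needs in order to decode later:
print(item["video_paths"])       # {"observation.image": Path(".../episode_000000.mp4")}
print(item["query_timestamps"])  # {"observation.image": [0.0]}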

View File

@@ -23,7 +23,7 @@ import torch
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
from pathlib import Path
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
@@ -51,6 +51,60 @@ from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.eval import eval_policy
from lerobot.common.datasets.video_utils import decode_video_frames_torchvision

# torchcodec offers an alternative video decoder; its code path is kept
# commented out inside custom_collate_fn below.
# from torchcodec.decoders import VideoDecoder

def custom_collate_fn(batch):
    """
    Custom collate function that decodes videos on CPU inside the dataloader
    workers, so the rest of the training loop sees the same batch format as
    before.
    """
    batched_frames = {}  # vid_key -> list of decoded frame tensors
    final_batch = {}  # everything else from the batch

    # Initialize final_batch with all original keys, skipping the video
    # metadata added by __getitem__.
    for key in batch[0].keys():
        if key not in ["video_paths", "query_timestamps"]:
            final_batch[key] = [item[key] for item in batch]

    # Decode the requested frames for each item in the batch.
    for item in batch:
        if "video_paths" in item and "query_timestamps" in item:
            for vid_key, video_path in item["video_paths"].items():
                timestamps = item["query_timestamps"][vid_key]
                # Alternative decoder (torchcodec), currently disabled in
                # favor of the torchvision/pyav path below:
                # decoder = VideoDecoder(str(video_path), device="cpu")
                # frames = decoder.get_frames_played_at(timestamps).data.float() / 255
                frames = decode_video_frames_torchvision(
                    video_path=Path(video_path),
                    timestamps=timestamps,
                    tolerance_s=0.02,  # adjust tolerance if needed
                    backend="pyav",  # default backend (modify if needed)
                    log_loaded_timestamps=False,
                )
                batched_frames.setdefault(vid_key, []).append(frames)

    # Convert lists to tensors where possible.
    for key in batched_frames:
        batched_frames[key] = torch.stack(batched_frames[key])
    for key in final_batch:
        if isinstance(final_batch[key][0], torch.Tensor):
            final_batch[key] = torch.stack(final_batch[key])

    # Store each decoded video under its own key (previously hard-coded to
    # "observation.images.top"), so the batch looks as if the frames had been
    # decoded in __getitem__.
    final_batch.update(batched_frames)

    return final_batch
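To sanity-check the collate function outside the full training pipeline, a sketch along these lines can be used; the repo id, batch size, and worker count are illustrative assumptions, and the wiring mirrors the dataloader change in the hunk below:

# Sketch: run custom_collate_fn through a DataLoader and inspect one batch.
import torch
from torch.utils.data import DataLoader
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht")  # illustrative dataset id
dataloader = DataLoader(
    dataset,
    batch_size=8,
    num_workers=4,  # frame decoding now runs in CPU worker processes
    collate_fn=custom_collate_fn,
)

batch = next(iter(dataloader))
for key, value in batch.items():
    if isinstance(value, torch.Tensor):
        # Video keys come back stacked, with batch then time leading dims.
        print(key, tuple(value.shape))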
def update_policy(
@@ -182,12 +236,11 @@ def train(cfg: TrainPipelineConfig):
        shuffle=shuffle,
        sampler=sampler,
        pin_memory=device.type != "cpu",
        collate_fn=custom_collate_fn,
        drop_last=False,
    )
    dl_iter = cycle(dataloader)

    policy.train()

    train_metrics = {
        "loss": AverageMeter("loss", ":.3f"),
        "grad_norm": AverageMeter("grdn", ":.3f"),
@@ -205,7 +258,6 @@ def train(cfg: TrainPipelineConfig):
        start_time = time.perf_counter()
        batch = next(dl_iter)
        train_tracker.dataloading_s = time.perf_counter() - start_time

        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device, non_blocking=True)
@@ -231,6 +283,7 @@ def train(cfg: TrainPipelineConfig):
        if is_log_step:
            logging.info(train_tracker)
            if wandb_logger:
                wandb_log_dict = train_tracker.to_dict()
                if output_dict: