add new video decoder method

Jade Choghari 2025-02-20 21:13:49 +01:00
parent c6bcfb3539
commit cae49528ee
2 changed files with 96 additions and 5 deletions


@@ -652,6 +652,45 @@ class LeRobotDataset(torch.utils.data.Dataset):
        item = self.hf_dataset[idx]
        ep_idx = item["episode_index"].item()

        query_indices = None
        if self.delta_indices is not None:
            current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
            query_indices, padding = self._get_query_indices(idx, current_ep_idx)
            query_result = self._query_hf_dataset(query_indices)
            item = {**item, **padding}
            for key, val in query_result.items():
                item[key] = val

        if len(self.meta.video_keys) > 0:
            current_ts = item["timestamp"].item()
            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
            # Instead of eagerly decoding video frames here (which is CPU/GPU
            # heavy), return only the video paths and query timestamps so the
            # frames can be decoded later, e.g. in the dataloader's collate fn.
            # video_frames = self._query_videos(query_timestamps, ep_idx)
            # item = {**video_frames, **item}
item["video_paths"] = {
vid_key: self.root / self.meta.get_video_file_path(ep_idx, vid_key)
for vid_key in query_timestamps.keys()
}
item["query_timestamps"] = query_timestamps
        if self.image_transforms is not None:
            image_keys = self.meta.camera_keys
            for cam in image_keys:
                # Frames are no longer decoded here, so only transform camera
                # keys that are actually present in the item.
                if cam in item:
                    item[cam] = self.image_transforms(item[cam])
        # Add task as a string
        task_idx = item["task_index"].item()
        item["task"] = self.meta.tasks[task_idx]

        return item

    # Previous __getitem__, kept for reference while the collate-based
    # decoding path is validated.
    def __getitem2__(self, idx) -> dict:
        item = self.hf_dataset[idx]
        ep_idx = item["episode_index"].item()
        query_indices = None
        if self.delta_indices is not None:
            current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
@@ -677,7 +716,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
        item["task"] = self.meta.tasks[task_idx]

        return item

    def __repr__(self):
        feature_keys = list(self.features)
        return (
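
For context, this is how a consumer of the new `__getitem__` could decode frames on demand from the metadata-only item; a minimal sketch, assuming an already-constructed LeRobotDataset named `dataset` and a single "observation.images.top" video key:

from pathlib import Path
from lerobot.common.datasets.video_utils import decode_video_frames_torchvision

item = dataset[0]
# The item now carries pointers to the frames instead of decoded tensors, e.g.:
#   item["video_paths"]      -> {"observation.images.top": Path("videos/... .mp4")}
#   item["query_timestamps"] -> {"observation.images.top": [1.23]}
for vid_key, video_path in item["video_paths"].items():
    item[vid_key] = decode_video_frames_torchvision(
        video_path=Path(video_path),
        timestamps=item["query_timestamps"][vid_key],
        tolerance_s=0.02,
        backend="pyav",
    )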


@@ -23,7 +23,7 @@ import torch
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer
from pathlib import Path
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
@@ -51,6 +51,60 @@ from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.eval import eval_policy
from lerobot.common.datasets.video_utils import decode_video_frames_torchvision

# Alternative torchcodec backend, kept for experimentation (see the
# commented-out lines in custom_collate_fn below).
# from torchcodec.decoders import VideoDecoder


def custom_collate_fn(batch):
    """
    Custom collate function that decodes video frames on CPU inside the
    dataloader workers, while keeping the rest of the batch format unchanged.
    """
    batched_frames = {}  # Decoded video tensors, keyed by video/camera key.
    final_batch = {}  # Everything else in the batch.

    # Initialize final_batch with all original keys except the video metadata.
    for key in batch[0].keys():
        if key not in ["video_paths", "query_timestamps"]:
            final_batch[key] = [item[key] for item in batch]

    # Decode the requested frames for each item in the batch.
    for item in batch:
        if "video_paths" in item and "query_timestamps" in item:
            for vid_key, video_path in item["video_paths"].items():
                timestamps = item["query_timestamps"][vid_key]
                # torchcodec alternative:
                # decoder = VideoDecoder(str(video_path), device="cpu")
                # frames = decoder.get_frames_played_at(timestamps).data.float() / 255
                frames = decode_video_frames_torchvision(
                    video_path=Path(video_path),
                    timestamps=timestamps,
                    tolerance_s=0.02,  # Adjust tolerance if needed.
                    backend="pyav",  # Default backend (modify if needed).
                    log_loaded_timestamps=False,
                )
                batched_frames.setdefault(vid_key, []).append(frames)

    # Stack the per-key lists into batch tensors.
    for key in batched_frames:
        batched_frames[key] = torch.stack(batched_frames[key])
    for key in final_batch:
        if isinstance(final_batch[key][0], torch.Tensor):
            final_batch[key] = torch.stack(final_batch[key])

    # Merge decoded frames back into the batch under their own keys
    # (e.g. "observation.images.top") instead of hard-coding a single key.
    for vid_key, frames in batched_frames.items():
        final_batch[vid_key] = frames

    return final_batch
def update_policy(
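
As an aside, the torchcodec path that is commented out inside custom_collate_fn would look roughly like the following helper; a sketch, assuming torchcodec's VideoDecoder and its get_frames_played_at method, which returns a frame batch with uint8 data:

from torchcodec.decoders import VideoDecoder

def decode_with_torchcodec(video_path, timestamps):
    # Hypothetical helper mirroring the commented-out lines above.
    decoder = VideoDecoder(str(video_path), device="cpu")  # CPU decoding
    frames = decoder.get_frames_played_at(timestamps)
    # Convert uint8 frames in [0, 255] to float32 in [0, 1], matching
    # what decode_video_frames_torchvision returns.
    return frames.data.float() / 255
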
@@ -182,12 +236,11 @@ def train(cfg: TrainPipelineConfig):
        shuffle=shuffle,
        sampler=sampler,
        pin_memory=device.type != "cpu",
        collate_fn=custom_collate_fn,
        drop_last=False,
    )
    dl_iter = cycle(dataloader)

    policy.train()

    train_metrics = {
        "loss": AverageMeter("loss", ":.3f"),
        "grad_norm": AverageMeter("grdn", ":.3f"),
@@ -205,7 +258,6 @@ def train(cfg: TrainPipelineConfig):
        start_time = time.perf_counter()
        batch = next(dl_iter)
        train_tracker.dataloading_s = time.perf_counter() - start_time

        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device, non_blocking=True)
@@ -231,6 +283,7 @@ def train(cfg: TrainPipelineConfig):
        if is_log_step:
            logging.info(train_tracker)

            if wandb_logger:
                wandb_log_dict = train_tracker.to_dict()
                if output_dict:
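
Putting it together, the collate function plugs into the dataloader like this; a minimal sketch, assuming `dataset` yields the metadata-only items from the first file:

import torch
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,  # video decoding now runs in the worker processes
    pin_memory=torch.cuda.is_available(),
    collate_fn=custom_collate_fn,
    drop_last=False,
)

batch = next(iter(dataloader))
# batch["observation.images.top"] is a stacked float tensor of decoded frames.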