add new video decoder method
This commit is contained in:
parent c6bcfb3539
commit cae49528ee
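
Instead of decoding video frames inside LeRobotDataset.__getitem__, the dataset now only attaches each sample's video file paths and query timestamps; a new custom_collate_fn in the training script decodes the requested frames on CPU while the batch is assembled. The previous __getitem__ is kept as __getitem2__ for reference.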
@@ -652,6 +652,45 @@ class LeRobotDataset(torch.utils.data.Dataset):
        item = self.hf_dataset[idx]
        ep_idx = item["episode_index"].item()

        query_indices = None
        if self.delta_indices is not None:
            current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
            query_indices, padding = self._get_query_indices(idx, current_ep_idx)
            query_result = self._query_hf_dataset(query_indices)
            item = {**item, **padding}
            for key, val in query_result.items():
                item[key] = val

        if len(self.meta.video_keys) > 0:
            current_ts = item["timestamp"].item()
            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
            # Original behavior: decode the frames here, inside the dataloader
            # worker, which keeps the workers CPU-bound on video decoding.
            # video_frames = self._query_videos(query_timestamps, ep_idx)
            # item = {**video_frames, **item}

            # jade: instead of decoding the video here, only attach metadata
            # (video paths & query timestamps) so the collate fn can decode.
            item["video_paths"] = {
                vid_key: self.root / self.meta.get_video_file_path(ep_idx, vid_key)
                for vid_key in query_timestamps.keys()
            }
            item["query_timestamps"] = query_timestamps

        if self.image_transforms is not None:
            image_keys = self.meta.camera_keys
            for cam in image_keys:
                # Video frames are no longer decoded here, so only transform
                # camera keys that are actually present in the item.
                if cam in item:
                    item[cam] = self.image_transforms(item[cam])

        # Add task as a string
        task_idx = item["task_index"].item()
        item["task"] = self.meta.tasks[task_idx]

        return item

    # Original implementation, kept for reference while the
    # decode-in-collate path is validated.
    def __getitem2__(self, idx) -> dict:
        item = self.hf_dataset[idx]
        ep_idx = item["episode_index"].item()

        query_indices = None
        if self.delta_indices is not None:
            current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
@@ -677,7 +716,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
        item["task"] = self.meta.tasks[task_idx]

        return item

    def __repr__(self):
        feature_keys = list(self.features)
        return (
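For context, the new contract between __getitem__ and the collate function for a single-camera dataset looks roughly like this (the key name, path, and timestamp values are illustrative):

    item = dataset[0]
    item["video_paths"]       # {"observation.images.top": Path(".../episode_000000.mp4")}
    item["query_timestamps"]  # {"observation.images.top": [2.533]}
    # No decoded frames under the camera keys yet; custom_collate_fn (below)
    # opens each path and decodes the frames at the requested timestamps.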
@@ -23,7 +23,7 @@ import torch
from termcolor import colored
from torch.amp import GradScaler
from torch.optim import Optimizer

from pathlib import Path
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
@@ -51,6 +51,60 @@ from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.eval import eval_policy
from lerobot.common.datasets.video_utils import (
    decode_video_frames_torchvision,
)
# Used by the custom collate fn defined below.
from torchcodec.decoders import VideoDecoder

def custom_collate_fn(batch):
    """
    Custom collate function that decodes videos on CPU during batching,
    so the main process no longer blocks on video decoding.
    Ensures the batch format otherwise remains unchanged.
    """
    batched_frames = {}  # Decoded video tensors, keyed by video key
    final_batch = {}  # The rest of the batch

    # Initialize final_batch with all original keys, skipping the
    # video-metadata fields that only exist to drive decoding here.
    for key in batch[0].keys():
        if key not in ["video_paths", "query_timestamps"]:
            final_batch[key] = [item[key] for item in batch]

    # Decode the requested frames for each sample.
    for item in batch:
        if "video_paths" in item and "query_timestamps" in item:
            for vid_key, video_path in item["video_paths"].items():
                # torchcodec alternative (see the sketch after this function):
                # decoder = VideoDecoder(str(video_path), device="cpu")
                # frames = decoder.get_frames_played_at(item["query_timestamps"][vid_key]).data.float() / 255
                timestamps = item["query_timestamps"][vid_key]
                frames = decode_video_frames_torchvision(
                    video_path=Path(video_path),
                    timestamps=timestamps,
                    tolerance_s=0.02,  # Adjust tolerance if needed
                    backend="pyav",  # Default backend (modify if needed)
                    log_loaded_timestamps=False,
                )

                if vid_key not in batched_frames:
                    batched_frames[vid_key] = []
                batched_frames[vid_key].append(frames)

    # Stack per-sample frame tensors into batch tensors.
    for key in batched_frames:
        batched_frames[key] = torch.stack(batched_frames[key])

    for key in final_batch:
        if isinstance(final_batch[key][0], torch.Tensor):
            final_batch[key] = torch.stack(final_batch[key])

    # Merge the decoded frames back into the batch under their own video keys
    # (e.g. "observation.images.top") instead of hard-coding a single camera.
    for key, frames in batched_frames.items():
        final_batch[key] = frames

    return final_batch


def update_policy(
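The torchcodec path is imported but left commented out in favor of decode_video_frames_torchvision. For reference, a minimal sketch of what it would look like, assuming torchcodec's documented VideoDecoder API (get_frames_played_at returns a FrameBatch whose .data is a uint8 NCHW tensor):

    from torchcodec.decoders import VideoDecoder

    decoder = VideoDecoder(str(video_path), device="cpu")  # CPU decoding
    frame_batch = decoder.get_frames_played_at(timestamps)  # presentation times in seconds
    frames = frame_batch.data.float() / 255  # float32 in [0, 1], shape (N, C, H, W)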
@@ -182,12 +236,11 @@ def train(cfg: TrainPipelineConfig):
        shuffle=shuffle,
        sampler=sampler,
        pin_memory=device.type != "cpu",
        collate_fn=custom_collate_fn,
        drop_last=False,
    )
    dl_iter = cycle(dataloader)

    policy.train()

    train_metrics = {
        "loss": AverageMeter("loss", ":.3f"),
        "grad_norm": AverageMeter("grdn", ":.3f"),
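A quick way to sanity-check the new collate path in isolation, before running a full training step (the dataset construction and camera key are illustrative):

    batch = custom_collate_fn([dataset[0], dataset[1]])
    # Decoded frames are merged back under their camera keys:
    print(batch["observation.images.top"].shape)  # e.g. (2, T, C, H, W) for T query timestamps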
@@ -205,7 +258,6 @@
        start_time = time.perf_counter()
        batch = next(dl_iter)
        train_tracker.dataloading_s = time.perf_counter() - start_time

        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device, non_blocking=True)
@@ -231,6 +283,7 @@

        if is_log_step:
            logging.info(train_tracker)
            if wandb_logger:
                wandb_log_dict = train_tracker.to_dict()
                if output_dict: