add new video decoder method

parent: c6bcfb3539
commit: cae49528ee

@@ -652,6 +652,45 @@ class LeRobotDataset(torch.utils.data.Dataset):
         item = self.hf_dataset[idx]
         ep_idx = item["episode_index"].item()
 
+        query_indices = None
+        if self.delta_indices is not None:
+            current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
+            query_indices, padding = self._get_query_indices(idx, current_ep_idx)
+            query_result = self._query_hf_dataset(query_indices)
+            item = {**item, **padding}
+            for key, val in query_result.items():
+                item[key] = val
+        if len(self.meta.video_keys) > 0:
+            current_ts = item["timestamp"].item()
+            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
+            # TODO: make sure what is returned here is all the info needed downstream
+            # (query_timestamps, episode); maybe split decoding across devices (e.g. 30% CPU, rest GPU)
+            # video_frames = self._query_videos(query_timestamps, ep_idx)
+            # item = {**video_frames, **item}
+
+            # jade: instead of decoding the video here, return the video path & timestamps
+            # hack: only add the metadata; decoding is deferred to the collate fn
+            item["video_paths"] = {
+                vid_key: self.root / self.meta.get_video_file_path(ep_idx, vid_key)
+                for vid_key in query_timestamps.keys()
+            }
+            item["query_timestamps"] = query_timestamps
+
+        if self.image_transforms is not None:
+            breakpoint()  # debug stop: video frames are no longer decoded at this point
+            image_keys = self.meta.camera_keys
+            for cam in image_keys:
+                item[cam] = self.image_transforms(item[cam])
+
+        # Add task as a string
+        task_idx = item["task_index"].item()
+        item["task"] = self.meta.tasks[task_idx]
+
+        return item
+    def __getitem2__(self, idx) -> dict:
+        item = self.hf_dataset[idx]
+        ep_idx = item["episode_index"].item()
+
         query_indices = None
         if self.delta_indices is not None:
             current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
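For clarity, a minimal sketch of what a dataset item looks like once this hunk is applied (the repo id is a placeholder; any video-backed dataset works). The video keys now carry decode metadata instead of decoded frames:

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht")  # placeholder repo id
item = dataset[0]

# item["video_paths"]      -> {vid_key: path to the episode's .mp4 on disk}
# item["query_timestamps"] -> {vid_key: list of float seconds to decode later}
for vid_key, path in item["video_paths"].items():
    print(vid_key, path, item["query_timestamps"][vid_key])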
@@ -677,7 +716,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         item["task"] = self.meta.tasks[task_idx]
 
         return item
 
     def __repr__(self):
         feature_keys = list(self.features)
         return (
@@ -23,7 +23,7 @@ import torch
 from termcolor import colored
 from torch.amp import GradScaler
 from torch.optim import Optimizer
+from pathlib import Path
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.sampler import EpisodeAwareSampler
 from lerobot.common.datasets.utils import cycle
@@ -51,6 +51,60 @@ from lerobot.common.utils.wandb_utils import WandBLogger
 from lerobot.configs import parser
 from lerobot.configs.train import TrainPipelineConfig
 from lerobot.scripts.eval import eval_policy
+from lerobot.common.datasets.video_utils import (
+    decode_video_frames_torchvision,
+)
+# let's define a custom collate fn
+from torchcodec.decoders import VideoDecoder
+
+def custom_collate_fn(batch):
+    """
+    Custom collate function that decodes videos on CPU.
+    Ensures the batch format otherwise remains unchanged.
+    """
+    batched_frames = {}  # dictionary to hold the decoded video tensors
+    final_batch = {}  # dictionary to hold the rest of the batch
+
+    # Initialize final_batch with all original keys (except the video metadata)
+    for key in batch[0].keys():
+        if key not in ["video_paths", "query_timestamps"]:  # skip video-related fields
+            final_batch[key] = [item[key] for item in batch]
+
+    # Decode the videos
+    for item in batch:
+        if "video_paths" in item and "query_timestamps" in item:
+            for vid_key, video_path in item["video_paths"].items():
+                decoder = VideoDecoder(str(video_path), device="cpu")  # CPU decoding (unused while the line below stays commented out)
+                # frames = decoder.get_frames_played_at(item["query_timestamps"][vid_key]).data.float() / 255
+                timestamps = item["query_timestamps"][vid_key]
+                frames = decode_video_frames_torchvision(
+                    video_path=Path(video_path),
+                    timestamps=timestamps,
+                    tolerance_s=0.02,  # adjust tolerance if needed
+                    backend="pyav",  # default backend (modify if needed)
+                    log_loaded_timestamps=False,
+                )
+
+                if vid_key not in batched_frames:
+                    batched_frames[vid_key] = []
+                batched_frames[vid_key].append(frames)
+
+    # Convert lists to tensors where possible
+    for key in batched_frames:
+        batched_frames[key] = torch.stack(batched_frames[key])  # stack per-item frame tensors
+
+    for key in final_batch:
+        if isinstance(final_batch[key][0], torch.Tensor):
+            final_batch[key] = torch.stack(final_batch[key])
+
+    # Fix: ensure the video frames end up as a single tensor instead of a dictionary.
+    # FIXME: the batch key is hard-coded; this must change.
+    if len(batched_frames) == 1:
+        final_batch["observation.images.top"] = list(batched_frames.values())[0]  # direct tensor
+    else:
+        final_batch["observation.images.top"] = batched_frames  # keep the dict if there are multiple cameras
+
+    return final_batch
 
 
 def update_policy(
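Since the commit imports torchcodec's VideoDecoder but leaves its call commented out, here is a minimal sketch of that alternative decode path, assuming torchcodec is installed (the file path and timestamps are placeholders):

from torchcodec.decoders import VideoDecoder

decoder = VideoDecoder("episode_000000.mp4", device="cpu")  # placeholder path
# Return the frames displayed at the given presentation times (in seconds).
frame_batch = decoder.get_frames_played_at(seconds=[0.0, 0.033, 0.066])
frames = frame_batch.data.float() / 255  # uint8 (N, C, H, W) -> float in [0, 1]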
@@ -182,12 +236,11 @@ def train(cfg: TrainPipelineConfig):
         shuffle=shuffle,
         sampler=sampler,
         pin_memory=device.type != "cpu",
+        collate_fn=custom_collate_fn,
         drop_last=False,
     )
     dl_iter = cycle(dataloader)
 
     policy.train()
 
     train_metrics = {
         "loss": AverageMeter("loss", ":.3f"),
         "grad_norm": AverageMeter("grdn", ":.3f"),
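Design note: because decoding now happens in custom_collate_fn, it runs inside the DataLoader worker processes, so CPU video decoding can overlap with the GPU training step. A standalone sketch of the wiring, with placeholder values:

import torch

dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=4,  # each worker decodes the videos for its own batches
    batch_size=8,
    shuffle=True,
    pin_memory=True,
    collate_fn=custom_collate_fn,  # turns video_paths/query_timestamps into frame tensors
    drop_last=False,
)
batch = next(iter(dataloader))  # includes "observation.images.top" as a stacked tensor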
@@ -205,7 +258,6 @@ def train(cfg: TrainPipelineConfig):
         start_time = time.perf_counter()
         batch = next(dl_iter)
         train_tracker.dataloading_s = time.perf_counter() - start_time
 
         for key in batch:
             if isinstance(batch[key], torch.Tensor):
                 batch[key] = batch[key].to(device, non_blocking=True)
@@ -231,6 +283,7 @@ def train(cfg: TrainPipelineConfig):
 
         if is_log_step:
             logging.info(train_tracker)
+            breakpoint()  # debug stop at the log step
             if wandb_logger:
                 wandb_log_dict = train_tracker.to_dict()
                 if output_dict: