diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 9483bf0a..371c334a 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -652,6 +652,45 @@ class LeRobotDataset(torch.utils.data.Dataset):
         item = self.hf_dataset[idx]
         ep_idx = item["episode_index"].item()
 
+        query_indices = None
+        if self.delta_indices is not None:
+            current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
+            query_indices, padding = self._get_query_indices(idx, current_ep_idx)
+            query_result = self._query_hf_dataset(query_indices)
+            item = {**item, **padding}
+            for key, val in query_result.items():
+                item[key] = val
+        if len(self.meta.video_keys) > 0:
+            current_ts = item["timestamp"].item()
+            query_timestamps = self._get_query_timestamps(current_ts, query_indices)
+            # Original behavior: decode the requested frames right here, inside __getitem__:
+            # video_frames = self._query_videos(query_timestamps, ep_idx)
+            # item = {**video_frames, **item}
+
+            # Instead of decoding here, only attach the video path and the query timestamps
+            # to the item; the frames are decoded later on the CPU, in the dataloader's
+            # collate function (see custom_collate_fn in lerobot/scripts/train.py).
+            item["video_paths"] = {
+                vid_key: self.root / self.meta.get_video_file_path(ep_idx, vid_key)
+                for vid_key in query_timestamps.keys()
+            }
+            item["query_timestamps"] = query_timestamps
+
+        if self.image_transforms is not None:
+            image_keys = self.meta.camera_keys
+            for cam in image_keys:
+                if cam in item:
+                    item[cam] = self.image_transforms(item[cam])
+
+        # Add task as a string
+        task_idx = item["task_index"].item()
+        item["task"] = self.meta.tasks[task_idx]
+
+        return item
+    def __getitem2__(self, idx) -> dict:
+        item = self.hf_dataset[idx]
+        ep_idx = item["episode_index"].item()
+
         query_indices = None
         if self.delta_indices is not None:
             current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
@@ -677,7 +716,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
         item["task"] = self.meta.tasks[task_idx]
 
         return item
-
     def __repr__(self):
         feature_keys = list(self.features)
         return (
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index f3c57fe2..b52fb10f 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -23,7 +23,7 @@ import torch
 from termcolor import colored
 from torch.amp import GradScaler
 from torch.optim import Optimizer
-
+from pathlib import Path
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.sampler import EpisodeAwareSampler
 from lerobot.common.datasets.utils import cycle
@@ -51,6 +51,60 @@ from lerobot.common.utils.wandb_utils import WandBLogger
 from lerobot.configs import parser
 from lerobot.configs.train import TrainPipelineConfig
 from lerobot.scripts.eval import eval_policy
+from lerobot.common.datasets.video_utils import (
+    decode_video_frames_torchvision,
+)
+# from torchcodec.decoders import VideoDecoder  # alternative CPU decoder, see the commented-out path below
+
+
+def custom_collate_fn(batch):
+    """
+    Custom collate function that decodes video frames on the CPU.
+    Keeps the batch format otherwise unchanged.
+    """
+    batched_frames = {}  # decoded video tensors, keyed by camera/video key
+    final_batch = {}  # everything else in the batch
+
+    # Initialize final_batch with all original keys (except the video metadata)
+    for key in batch[0].keys():
+        if key not in ["video_paths", "query_timestamps"]:  # skip video-related fields
+            final_batch[key] = [item[key] for item in batch]
+
+    # Decode the requested frames for every item and every camera
+    for item in batch:
+        if "video_paths" in item and "query_timestamps" in item:
+            for vid_key, video_path in item["video_paths"].items():
+                # decoder = VideoDecoder(str(video_path), device="cpu")  # alternative: torchcodec CPU decoding
+                # frames = decoder.get_frames_played_at(item["query_timestamps"][vid_key]).data.float() / 255
+                timestamps = item["query_timestamps"][vid_key]
+                frames = decode_video_frames_torchvision(
+                    video_path=Path(video_path),
+                    timestamps=timestamps,
+                    tolerance_s=0.02,  # adjust tolerance if needed
+                    backend="pyav",  # default backend (modify if needed)
+                    log_loaded_timestamps=False,
+                )
+
+                if vid_key not in batched_frames:
+                    batched_frames[vid_key] = []
+                batched_frames[vid_key].append(frames)
+
+    # Convert the lists of decoded clips to stacked tensors
+    for key in batched_frames:
+        batched_frames[key] = torch.stack(batched_frames[key])
+
+    for key in final_batch:
+        if isinstance(final_batch[key][0], torch.Tensor):
+            final_batch[key] = torch.stack(final_batch[key])
+
+    # Expose the decoded frames as a single tensor rather than a dict of tensors.
+    # FIXME: the output key is hard-coded; it should come from the dataset's camera keys.
+    if len(batched_frames) == 1:
+        final_batch["observation.images.top"] = list(batched_frames.values())[0]  # direct tensor
+    else:
+        final_batch["observation.images.top"] = batched_frames  # keep dict if multiple cameras
+
+    return final_batch
 
 
 def update_policy(
@@ -182,12 +236,11 @@ def train(cfg: TrainPipelineConfig):
         shuffle=shuffle,
         sampler=sampler,
         pin_memory=device.type != "cpu",
+        collate_fn=custom_collate_fn,
         drop_last=False,
     )
     dl_iter = cycle(dataloader)
 
-    policy.train()
-
     train_metrics = {
         "loss": AverageMeter("loss", ":.3f"),
         "grad_norm": AverageMeter("grdn", ":.3f"),
@@ -205,7 +258,6 @@ def train(cfg: TrainPipelineConfig):
         start_time = time.perf_counter()
         batch = next(dl_iter)
         train_tracker.dataloading_s = time.perf_counter() - start_time
-
         for key in batch:
             if isinstance(batch[key], torch.Tensor):
                 batch[key] = batch[key].to(device, non_blocking=True)