diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index 6ef955dd..f7a2900f 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -736,8 +736,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if len(self.meta.video_keys) > 0:
             current_ts = item["timestamp"].item()
             query_timestamps = self._get_query_timestamps(current_ts, query_indices)
-            video_frames = self._query_videos(query_timestamps, ep_idx)
-            item = {**video_frames, **item}
+            if self.video_backend != "torchcodec-gpu":
+                video_frames = self._query_videos(query_timestamps, ep_idx)
+                item = {**video_frames, **item}
+            else:
+                item["query_timestamps"] = query_timestamps

         if self.image_transforms is not None:
             image_keys = self.meta.camera_keys
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index f2b1e29e..a439e811 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -27,6 +27,7 @@ from torch.optim import Optimizer
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.sampler import EpisodeAwareSampler
 from lerobot.common.datasets.utils import cycle
+from lerobot.common.datasets.video_utils import decode_video_frames_torchcodec
 from lerobot.common.envs.factory import make_env
 from lerobot.common.optim.factory import make_optimizer_and_scheduler
 from lerobot.common.policies.factory import make_policy
@@ -203,6 +204,27 @@ def train(cfg: TrainPipelineConfig):
     for _ in range(step, cfg.steps):
         start_time = time.perf_counter()
         batch = next(dl_iter)
+
+        if dataset.video_backend == "torchcodec-gpu":
+            # GPU decoding requires a CUDA device
+            assert torch.cuda.is_available(), (
+                "CUDA device not available. Please run on a machine with a GPU "
+                "to enable CUDA decoding when using `video_backend='torchcodec-gpu'`."
+            )
+            # decode the queried frames on the GPU for each camera key
+            for vid_key, timestamps_list in batch["query_timestamps"].items():
+                frames_list = []
+                # collate yields a list of T tensors of shape [B]; stack into [T, B]
+                query_ts = torch.stack(timestamps_list).T  # transpose to shape [B, T]
+                for i in range(query_ts.shape[0]):
+                    ep_idx = batch["episode_index"][i]
+                    timestamps = query_ts[i].tolist()
+                    video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, vid_key)
+                    frames = decode_video_frames_torchcodec(
+                        video_path, timestamps, dataset.tolerance_s, device="cuda"
+                    )
+                    frames_list.append(frames.squeeze(0))
+                batch[vid_key] = torch.stack(frames_list)
         train_tracker.dataloading_s = time.perf_counter() - start_time

         for key in batch:
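
Note on the stack/transpose step in the train.py hunk above: the following is a minimal standalone sketch, not part of the diff, of how PyTorch's default_collate turns each sample's per-key list of T float timestamps into a list of T tensors of shape [B]. That collate behavior is why the training loop stacks and transposes before indexing per sample. The camera key name and timestamp values are hypothetical.

import torch
from torch.utils.data import default_collate

# Two samples (B=2), each carrying T=2 query timestamps for one
# hypothetical camera key, mimicking what LeRobotDataset.__getitem__
# returns when video_backend == "torchcodec-gpu".
samples = [
    {"query_timestamps": {"observation.images.top": [0.00, 0.05]}},
    {"query_timestamps": {"observation.images.top": [0.10, 0.15]}},
]

batch = default_collate(samples)
timestamps_list = batch["query_timestamps"]["observation.images.top"]
# default_collate transposes the per-sample lists into T tensors of shape [B]:
# [tensor([0.00, 0.10]), tensor([0.05, 0.15])]

query_ts = torch.stack(timestamps_list).T  # [T, B] -> [B, T]
# query_ts[i].tolist() recovers sample i's original timestamps:
# [[0.00, 0.05], [0.10, 0.15]]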