This commit is contained in:
Jade Choghari 2025-03-28 17:54:28 +00:00 committed by GitHub
commit df7d938ad2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 2 deletions

View File

@ -736,8 +736,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
if len(self.meta.video_keys) > 0:
current_ts = item["timestamp"].item()
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item}
if self.video_backend != "torchcodec-gpu":
video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item}
else:
item["query_timestamps"] = query_timestamps
if self.image_transforms is not None:
image_keys = self.meta.camera_keys

View File

@ -27,6 +27,7 @@ from torch.optim import Optimizer
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.datasets.video_utils import decode_video_frames_torchcodec
from lerobot.common.envs.factory import make_env
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.factory import make_policy
@ -203,6 +204,27 @@ def train(cfg: TrainPipelineConfig):
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
batch = next(dl_iter)
if dataset.video_backend == "torchcodec-gpu":
# make sure we have a cuda device
assert torch.cuda.is_available(), (
"CUDA device not available. Please run on a machine with a GPU "
"to enable CUDA decoding when using `video_backend='torchcodec-gpu'`."
)
# decode the queried video frames on the GPU and add them to the batch under each video key
for vid_key, timestamps_list in batch["query_timestamps"].items():
frames_list = []
# stack the list of T per-timestep tensors (each of shape [B]) into a single tensor of shape [T, B]
query_ts = torch.stack(timestamps_list).T # convert to shape: [B, T]
for i in range(query_ts.shape[0]):
ep_idx = batch["episode_index"][i]
timestamps = query_ts[i].tolist()
video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, vid_key)
frames = decode_video_frames_torchcodec(
video_path, timestamps, dataset.tolerance_s, device="cuda"
)
frames_list.append(frames.squeeze(0))
batch[vid_key] = torch.stack(frames_list)
train_tracker.dataloading_s = time.perf_counter() - start_time
for key in batch: