commit df7d938ad2
Merge 7e681bcc0f into b568de35ad
@@ -736,8 +736,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if len(self.meta.video_keys) > 0:
             current_ts = item["timestamp"].item()
             query_timestamps = self._get_query_timestamps(current_ts, query_indices)
-            video_frames = self._query_videos(query_timestamps, ep_idx)
-            item = {**video_frames, **item}
+            if self.video_backend != "torchcodec-gpu":
+                video_frames = self._query_videos(query_timestamps, ep_idx)
+                item = {**video_frames, **item}
+            else:
+                item["query_timestamps"] = query_timestamps
 
         if self.image_transforms is not None:
             image_keys = self.meta.camera_keys
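
For context, a minimal usage sketch (not part of this commit) of how the deferred-decoding branch behaves; it assumes LeRobotDataset accepts a video_backend argument and uses "lerobot/pusht" purely as an illustrative repo_id.

# Hypothetical sketch, not from the diff: with video_backend="torchcodec-gpu",
# __getitem__ skips CPU-side video decoding and returns the raw query timestamps
# so the frames can be decoded later on the GPU (see the train-loop hunk below).
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/pusht", video_backend="torchcodec-gpu")  # repo_id is illustrative
item = dataset[0]
# item["query_timestamps"] maps each video key to the timestamps still to be decoded;
# with any other backend, item already contains the decoded frames for those keys.
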
@@ -27,6 +27,7 @@ from torch.optim import Optimizer
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.sampler import EpisodeAwareSampler
 from lerobot.common.datasets.utils import cycle
+from lerobot.common.datasets.video_utils import decode_video_frames_torchcodec
 from lerobot.common.envs.factory import make_env
 from lerobot.common.optim.factory import make_optimizer_and_scheduler
 from lerobot.common.policies.factory import make_policy
@@ -203,6 +204,27 @@ def train(cfg: TrainPipelineConfig):
     for _ in range(step, cfg.steps):
         start_time = time.perf_counter()
         batch = next(dl_iter)
+
+        if dataset.video_backend == "torchcodec-gpu":
+            # make sure we have a cuda device
+            assert torch.cuda.is_available(), (
+                "CUDA device not available. Please run on a machine with a GPU "
+                "to enable CUDA decoding when using `video_backend='torchcodec-gpu'`."
+            )
+            # add cuda decoding
+            for vid_key, timestamps_list in batch["query_timestamps"].items():
+                frames_list = []
+                # convert list of scalar tensors to a tensor of shape [T, B]
+                query_ts = torch.stack(timestamps_list).T  # convert to shape: [B, T]
+                for i in range(query_ts.shape[0]):
+                    ep_idx = batch["episode_index"][i]
+                    timestamps = query_ts[i].tolist()
+                    video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, vid_key)
+                    frames = decode_video_frames_torchcodec(
+                        video_path, timestamps, dataset.tolerance_s, device="cuda"
+                    )
+                    frames_list.append(frames.squeeze(0))
+                batch[vid_key] = torch.stack(frames_list)
         train_tracker.dataloading_s = time.perf_counter() - start_time
 
         for key in batch:
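
A small, self-contained illustration (not from the commit) of why the stack-and-transpose above yields shape [B, T]: PyTorch's default collate turns each sample's list of T timestamps into a list of T tensors, each of shape [B]. The video key "observation.image" is only an example.

# Hypothetical illustration, not part of the diff: shape of "query_timestamps"
# after default collation, matching the stack/transpose in the hunk above.
import torch
from torch.utils.data import default_collate

# Two samples (B=2), each with T=3 query timestamps for one video key.
samples = [
    {"query_timestamps": {"observation.image": [0.0, 0.1, 0.2]}},
    {"query_timestamps": {"observation.image": [1.0, 1.1, 1.2]}},
]
batch = default_collate(samples)
ts_list = batch["query_timestamps"]["observation.image"]  # list of T tensors, each of shape [B]
query_ts = torch.stack(ts_list).T                          # shape [B, T]
print(query_ts.shape)  # torch.Size([2, 3])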