From c3a5cbb0b6f4daf87136ac3a7d1c8eb7130bb622 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 8 May 2024 12:58:14 +0100 Subject: [PATCH] don't delete observation.image key --- lerobot/common/policies/diffusion/modeling_diffusion.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index e5d099da..1ed9d391 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -89,8 +89,9 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin): assert "action" in batch assert "action_is_pad" in batch image_key = next(iter(image_keys)) - batch["observation.image"] = batch[image_key] - del batch[image_key] + if image_key != "observation.image": + batch["observation.image"] = batch[image_key] + del batch[image_key] @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: