don't delete observation.image key

This commit is contained in:
Alexander Soare 2024-05-08 12:58:14 +01:00
parent 8357dae26a
commit c3a5cbb0b6
1 changed files with 3 additions and 2 deletions

View File

@ -89,6 +89,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
assert "action" in batch
assert "action_is_pad" in batch
image_key = next(iter(image_keys))
if image_key != "observation.image":
batch["observation.image"] = batch[image_key]
del batch[image_key]