don't delete observation.image key

This commit is contained in:
Alexander Soare 2024-05-08 12:58:14 +01:00
parent 8357dae26a
commit c3a5cbb0b6
1 changed files with 3 additions and 2 deletions

View File

@ -89,8 +89,9 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
assert "action" in batch assert "action" in batch
assert "action_is_pad" in batch assert "action_is_pad" in batch
image_key = next(iter(image_keys)) image_key = next(iter(image_keys))
batch["observation.image"] = batch[image_key] if image_key != "observation.image":
del batch[image_key] batch["observation.image"] = batch[image_key]
del batch[image_key]
@torch.no_grad @torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor]) -> Tensor: