don't delete observation.image key
This commit is contained in:
parent
8357dae26a
commit
c3a5cbb0b6
|
@ -89,8 +89,9 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
assert "action" in batch
|
||||
assert "action_is_pad" in batch
|
||||
image_key = next(iter(image_keys))
|
||||
batch["observation.image"] = batch[image_key]
|
||||
del batch[image_key]
|
||||
if image_key != "observation.image":
|
||||
batch["observation.image"] = batch[image_key]
|
||||
del batch[image_key]
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
|
|
Loading…
Reference in New Issue