don't delete observation.image key
This commit is contained in:
parent
8357dae26a
commit
c3a5cbb0b6
|
@ -89,8 +89,9 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
assert "action" in batch
|
assert "action" in batch
|
||||||
assert "action_is_pad" in batch
|
assert "action_is_pad" in batch
|
||||||
image_key = next(iter(image_keys))
|
image_key = next(iter(image_keys))
|
||||||
batch["observation.image"] = batch[image_key]
|
if image_key != "observation.image":
|
||||||
del batch[image_key]
|
batch["observation.image"] = batch[image_key]
|
||||||
|
del batch[image_key]
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
|
|
Loading…
Reference in New Issue