This commit is contained in:
Alexander Soare 2024-05-16 13:47:58 +01:00
parent 380c31390e
commit 461c48a66c
1 changed files with 5 additions and 1 deletions

View File

@ -19,6 +19,7 @@
TODO(alexander-soare):
- Remove reliance on Robomimic for SpatialSoftmax.
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
- Make compatible with multiple image keys.
"""
import math
@ -85,7 +86,10 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
assert len(image_keys) == 1
if len(image_keys) != 1:
raise NotImplementedError(
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
)
self.input_image_key = image_keys[0]
self.reset()