revision
This commit is contained in:
parent
380c31390e
commit
461c48a66c
|
@ -19,6 +19,7 @@
|
||||||
TODO(alexander-soare):
|
TODO(alexander-soare):
|
||||||
- Remove reliance on Robomimic for SpatialSoftmax.
|
- Remove reliance on Robomimic for SpatialSoftmax.
|
||||||
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
|
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
|
||||||
|
- Make compatible with multiple image keys.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
@ -85,7 +86,10 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
|
|
||||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||||
assert len(image_keys) == 1
|
if len(image_keys) != 1:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.__class__.__name__} only handles one image for now. Got image keys {image_keys}."
|
||||||
|
)
|
||||||
self.input_image_key = image_keys[0]
|
self.input_image_key = image_keys[0]
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
Loading…
Reference in New Issue