diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 3115160f..7c7debf1 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -16,7 +16,6 @@ import torchvision
 from diffusers.schedulers.scheduling_ddim import DDIMScheduler
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from huggingface_hub import PyTorchModelHubMixin
-from robomimic.models.base_nets import SpatialSoftmax
 from torch import Tensor, nn
 
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
@@ -140,6 +139,54 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMSche
     raise ValueError(f"Unsupported noise scheduler type {name}")
 
 
+class SpatialSoftmax(nn.Module):
+    """
+    Implementation of the Spatial Soft Argmax layer described in "Deep Spatial Autoencoders
+    for Visuomotor Learning" by Finn et al. (https://arxiv.org/pdf/1509.06113)
+
+    Meant to be a minimal port of the robomimic implementation.
+    """
+
+    def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False):
+        super().__init__()
+
+        assert len(input_shape) == 3
+        self._in_c, self._in_h, self._in_w = input_shape
+
+        if num_kp is not None:
+            self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
+            self._out_c = num_kp
+        else:
+            self.nets = None
+            self._out_c = self._in_c
+
+        self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)
+        _pos_x, _pos_y = torch.meshgrid(
+            torch.linspace(-1.0, 1.0, self._in_w), torch.linspace(-1.0, 1.0, self._in_h), indexing="xy"
+        )
+        # Register as buffers so they are moved to the correct device, etc.
+        self.register_buffer("pos_x", _pos_x.reshape(1, self._in_h * self._in_w))
+        self.register_buffer("pos_y", _pos_y.reshape(1, self._in_h * self._in_w))
+
+    def forward(self, features: Tensor) -> Tensor:
+        if self.nets is not None:
+            features = self.nets(features)
+
+        # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
+        features = features.reshape(-1, self._in_h * self._in_w)
+        # 2d softmax normalization
+        attention = F.softmax(features / self.temperature, dim=-1)
+        # [1, H * W] x [B * K, H * W] -> [B * K, 1] for spatial coordinate mean in x and y dimensions
+        expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True)
+        expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True)
+        # stack to [B * K, 2]
+        expected_xy = torch.cat([expected_x, expected_y], 1)
+        # reshape to [B, K, 2]
+        feature_keypoints = expected_xy.view(-1, self._out_c, 2)
+
+        return feature_keypoints
+
+
 class DiffusionModel(nn.Module):
     def __init__(self, config: DiffusionConfig):
         super().__init__()
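
For reference, here is a minimal usage sketch (not part of the patch) of the ported SpatialSoftmax: it maps a [B, C, H, W] feature map to [B, K, 2] keypoint coordinates in [-1, 1]. The batch size, channel count, spatial resolution, and keypoint count below are illustrative assumptions, not values taken from the policy config.

# Minimal usage sketch; all shapes are assumed for illustration only.
import torch

from lerobot.common.policies.diffusion.modeling_diffusion import SpatialSoftmax

# Dummy batch of 2 feature maps: 64 channels at 8x8 spatial resolution.
features = torch.randn(2, 64, 8, 8)

# A 1x1 conv first projects the 64 channels down to 32 keypoints, then the
# softmax-weighted expectation over the (x, y) position grids is taken per keypoint.
pool = SpatialSoftmax(input_shape=(64, 8, 8), num_kp=32)

keypoints = pool(features)
print(keypoints.shape)  # torch.Size([2, 32, 2]): one (x, y) expectation per keypoint, each in [-1, 1]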