diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 2d3dae60..d284c38c 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -10,6 +10,7 @@ from collections import deque
 from typing import Callable
 
 import einops
+import numpy as np
 import torch
 import torch.nn.functional as F  # noqa: N812
 import torchvision
@@ -291,8 +292,7 @@ class DiffusionModel(nn.Module):
 class SpatialSoftmax(nn.Module):
     """
     Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
-    (https://arxiv.org/pdf/1509.06113)
-    A minimal port of the robomimic implementation.
+    (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
 
     At a high level, this takes 2D feature maps (from a convnet/ViT/etc.) and returns the "center of mass" of
     activations of each channel, i.e., spatial keypoints for the policy to focus on.
@@ -307,7 +307,7 @@ class SpatialSoftmax(nn.Module):
     We apply channel-wise softmax over the activations (512x120) and compute dot product with the coordinates
     (120x2) to get expected points of maximal activation (512x2).
 
-    Optionally, can also learn a linear mapping from the feature maps to a lower/higher dimensional space using a conv1x1
+    Optionally, when num_kp != None, can learn a linear mapping from the feature maps to a lower/higher dimensional space using a conv1x1
     before computing the softmax.
     """
 
@@ -332,13 +332,17 @@ class SpatialSoftmax(nn.Module):
             self._out_c = self._in_c
 
         self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)
-        _pos_x, _pos_y = torch.meshgrid(
-            torch.linspace(-1.0, 1.0, self._in_w), torch.linspace(-1.0, 1.0, self._in_h), indexing="xy"
+
+        # we could use torch.linspace directly but that seems to behave slightly differently than numpy
+        # and cause a small degradation in pc_success of pre-trained models.
+        pos_x, pos_y = np.meshgrid(
+            np.linspace(-1., 1., self._in_w),
+            np.linspace(-1., 1., self._in_h)
         )
-        _pos_x = _pos_x.reshape(self._in_h * self._in_w, 1)
-        _pos_y = _pos_y.reshape(self._in_h * self._in_w, 1)
-        # Register as buffer so it's moved to the correct device, etc.
-        self.register_buffer("pos_grid", torch.cat([_pos_x, _pos_y], dim=1))
+        pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
+        pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
+        # register as buffer so it's moved to the correct device, etc.
+        self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
 
     def forward(self, features: Tensor) -> Tensor:
         """
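
For reviewers, here is a minimal standalone sketch of what the patch constructs and of the forward computation the docstring describes. This is not the repo's code: the sizes, variable names, and batch dimension are made up for illustration (they follow the docstring's 512x10x12 example), and temperature scaling is omitted. It also includes a quick check of the numpy-vs-torch `linspace` discrepancy the new comment refers to.

```python
import numpy as np
import torch
import torch.nn.functional as F

# Hypothetical feature-map sizes, matching the docstring's 512x10x12 example.
in_c, in_h, in_w = 512, 10, 12

# The patch's grid: np.meshgrid defaults to "xy" indexing, so each output has
# shape (in_h, in_w); flattening and stacking gives a (H*W, 2) coordinate grid.
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, in_w), np.linspace(-1.0, 1.0, in_h))
pos_grid = torch.cat(
    [
        torch.from_numpy(pos_x.reshape(in_h * in_w, 1)).float(),
        torch.from_numpy(pos_y.reshape(in_h * in_w, 1)).float(),
    ],
    dim=1,
)  # (120, 2)

# The torch grid the patch removes: np.linspace computes in float64 and is cast
# to float32 afterwards, while torch.linspace computes in float32 directly, so
# the two grids can disagree in the last bits.
pos_x_t, _ = torch.meshgrid(
    torch.linspace(-1.0, 1.0, in_w), torch.linspace(-1.0, 1.0, in_h), indexing="xy"
)
print((torch.from_numpy(pos_x).float() - pos_x_t).abs().max())

# The forward pass per the docstring: channel-wise softmax over the 120 spatial
# positions, then a dot product with the coordinates to get each channel's
# expected point of maximal activation.
features = torch.randn(4, in_c, in_h, in_w)  # (B, C, H, W)
attention = F.softmax(features.reshape(-1, in_h * in_w), dim=-1)  # (B*C, 120)
keypoints = (attention @ pos_grid).reshape(4, in_c, 2)  # (B, C, 2)
```

Since the softmax output multiplies this grid directly, last-bit differences in the grid propagate into the keypoints; matching numpy bit-for-bit is presumably what restores the pre-trained models' pc_success, given that the robomimic implementation this module ports also builds its grid with numpy.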