diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 2210b8fe..435ae2af 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -1,7 +1,6 @@
 """Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
 
 TODO(alexander-soare):
-  - Remove reliance on Robomimic for SpatialSoftmax.
   - Remove reliance on diffusers for DDPMScheduler and LR scheduler.
 """
 
@@ -307,17 +306,15 @@ class SpatialSoftmax(nn.Module):
     We apply channel-wise softmax over the activations (512x120) and compute dot product with the coordinates
     (120x2) to get expected points of maximal activation (512x2).
 
-    Optionally, when num_kp != None, can learn a linear mapping from the feature maps to a lower/higher dimensional
-    space using a conv1x1 before computing the softmax.
+    Optionally, when num_kp != None, can learn a linear mapping (of shape (num_kp, H, W)) from the input feature maps
+    using a conv1x1 before computing softmax.
     """
 
-    def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False):
+    def __init__(self, input_shape, num_kp=None):
         """
         Args:
            input_shape (list): (C, H, W) input feature map shape.
            num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
-           temperature (float): temperature for softmax normalization.
-           learnable_temperature (bool): whether to learn the temperature parameter.
         """
         super().__init__()
 
@@ -331,7 +328,6 @@ class SpatialSoftmax(nn.Module):
             self.nets = None
             self._out_c = self._in_c
 
-        self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)
         # we could use torch.linspace directly but that seems to behave slightly differently than numpy
         # and causes a small degradation in pc_success of pre-trained models.
         pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
@@ -353,7 +349,7 @@ class SpatialSoftmax(nn.Module):
         # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
         features = features.reshape(-1, self._in_h * self._in_w)
         # 2d softmax normalization
-        attention = F.softmax(features / self.temperature, dim=-1)
+        attention = F.softmax(features, dim=-1)
         # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
         expected_xy = attention @ self.pos_grid
         # reshape to [B, K, 2]
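For reviewers, below is a minimal, self-contained sketch (not part of the patch) of the spatial soft-argmax that the second and fourth hunks touch, using the 512x10x12 shapes implied by the docstring example. Tensor names (`features`, `pos_grid`, `attention`, `expected_xy`) mirror the patched code; the batch size and the final assertion are illustrative, and the equivalence claim assumes checkpoints used the default `temperature=1.0` with `learnable_temperature=False`.

```python
# Sketch of the spatial soft-argmax described in the SpatialSoftmax docstring (assumed shapes).
import torch
import torch.nn.functional as F

B, K, H, W = 2, 512, 10, 12  # 512 feature maps of size 10x12 -> H * W = 120 spatial locations
features = torch.randn(B, K, H, W)

# Normalized (x, y) coordinate grid, flattened to [H * W, 2] like self.pos_grid in the patched code.
pos_x, pos_y = torch.meshgrid(
    torch.linspace(-1.0, 1.0, W), torch.linspace(-1.0, 1.0, H), indexing="xy"
)
pos_grid = torch.stack([pos_x.reshape(-1), pos_y.reshape(-1)], dim=1)  # [H * W, 2]

# Channel-wise softmax over spatial locations, then expected (x, y) per channel.
attention = F.softmax(features.reshape(-1, H * W), dim=-1)  # [B * K, H * W]
expected_xy = (attention @ pos_grid).reshape(B, K, 2)       # [B, K, 2] keypoints

# Dividing the logits by the default temperature of 1.0 is a no-op, which is why removing
# the parameter should leave outputs of such checkpoints unchanged.
assert torch.allclose(
    expected_xy,
    (F.softmax(features.reshape(-1, H * W) / 1.0, dim=-1) @ pos_grid).reshape(B, K, 2),
)
```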