From 59aa172c175a231baddd45b4d872e63439ad2a2b Mon Sep 17 00:00:00 2001
From: Akshay Kashyap
Date: Mon, 13 May 2024 13:42:09 -0700
Subject: [PATCH] nit

---
 .../policies/diffusion/modeling_diffusion.py | 39 ++++++++++++-------
 1 file changed, 24 insertions(+), 15 deletions(-)

diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 4ba7a70f..2d3dae60 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -290,14 +290,14 @@ class DiffusionModel(nn.Module):
 
 class SpatialSoftmax(nn.Module):
     """
-    Spatial Soft Argmax layer described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
+    Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
     (https://arxiv.org/pdf/1509.06113)
 
     A minimal port of the robomimic implementation.
 
-    At a high level, this operation takes 2D feature maps (from a convnet/ViT/etc.) and returns
-    the "center of mass" of activations of each channel, i.e., spatial keypoints for your policy to focus on.
+    At a high level, this takes 2D feature maps (from a convnet/ViT/etc.) and returns the "center of mass"
+    of activations of each channel, i.e., spatial keypoints for the policy to focus on.
 
-    Example: take feature maps of size (512x10x12). We will generate a grid of normalized coordinates (10x12x2):
+    Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
     -----------------------------------------------------
     | (-1., -1.)   | (-0.82, -1.)   | ... | (1., -1.)   |
     | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
@@ -307,11 +307,18 @@ class SpatialSoftmax(nn.Module):
     We apply channel-wise softmax over the activations (512x120) and compute dot product with the coordinates
     (120x2) to get expected points of maximal activation (512x2).
 
-    Optionally, can also learn a mapping from the feature maps to a lower dimensional space (num_kp < in_c) before computing the argmax
-    if we'd like to focus on a smaller number of keypoints.
+    Optionally, a linear mapping from the feature maps to a lower- or higher-dimensional space (a 1x1 conv) can be
+    learned before computing the softmax.
     """
 
     def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False):
+        """
+        Args:
+            input_shape (list): (C, H, W) input feature map shape.
+            num_kp (int): number of keypoints to output. If None, output will have the same number of channels as input.
+            temperature (float): temperature for softmax normalization.
+            learnable_temperature (bool): whether to learn the temperature parameter.
+        """
         super().__init__()
 
         assert len(input_shape) == 3
@@ -328,16 +335,18 @@ class SpatialSoftmax(nn.Module):
         _pos_x, _pos_y = torch.meshgrid(
             torch.linspace(-1.0, 1.0, self._in_w), torch.linspace(-1.0, 1.0, self._in_h), indexing="xy"
         )
-        # Register as buffers so they are moved to the correct device, etc.
-        self.register_buffer(
-            "pos_grid",
-            torch.cat(
-                [_pos_x.reshape(self._in_h * self._in_w, 1), _pos_y.reshape(self._in_h * self._in_w, 1)],
-                dim=1,
-            ).float(),
-        )
+        _pos_x = _pos_x.reshape(self._in_h * self._in_w, 1)
+        _pos_y = _pos_y.reshape(self._in_h * self._in_w, 1)
+        # Register as a buffer so it is moved to the correct device, etc.
+        self.register_buffer("pos_grid", torch.cat([_pos_x, _pos_y], dim=1))
 
     def forward(self, features: Tensor) -> Tensor:
+        """
+        Args:
+            features: (B, C, H, W) input feature maps.
+        Returns:
+            (B, K, 2) image-space coordinates of keypoints.
+        """
         if self.nets is not None:
             features = self.nets(features)
 
@@ -346,7 +355,7 @@ class SpatialSoftmax(nn.Module):
         # 2d softmax normalization
         attention = F.softmax(features / self.temperature, dim=-1)
         # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
-        expected_xy = torch.matmul(attention, self.pos_grid)
+        expected_xy = attention @ self.pos_grid
         # reshape to [B, K, 2]
         feature_keypoints = expected_xy.view(-1, self._out_c, 2)
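
Note for reviewers: below is a minimal, self-contained sketch of the patched SpatialSoftmax for quick local
verification of the shapes documented above. It is illustrative only: the class name SpatialSoftmaxDemo and the
final shape check are not part of this patch, and the temperature and 1x1-conv handling are assumptions inferred
from the constructor signature (the diff above does not touch those lines).

# Minimal standalone sketch of the patched SpatialSoftmax (illustrative only).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class SpatialSoftmaxDemo(nn.Module):
    def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False):
        super().__init__()
        assert len(input_shape) == 3
        self._in_c, self._in_h, self._in_w = input_shape

        if num_kp is not None:
            # Learned 1x1 conv maps C feature channels to num_kp keypoint channels (assumption;
            # these lines are untouched by the patch).
            self.nets = nn.Conv2d(self._in_c, num_kp, kernel_size=1)
            self._out_c = num_kp
        else:
            self.nets = None
            self._out_c = self._in_c

        # Temperature handling is an assumption based on the constructor signature.
        if learnable_temperature:
            self.temperature = nn.Parameter(torch.tensor(float(temperature)))
        else:
            self.register_buffer("temperature", torch.tensor(float(temperature)))

        # Normalized (x, y) coordinate grid, flattened to (H * W, 2), as in the patch.
        _pos_x, _pos_y = torch.meshgrid(
            torch.linspace(-1.0, 1.0, self._in_w),
            torch.linspace(-1.0, 1.0, self._in_h),
            indexing="xy",
        )
        _pos_x = _pos_x.reshape(self._in_h * self._in_w, 1)
        _pos_y = _pos_y.reshape(self._in_h * self._in_w, 1)
        # Buffer so the grid follows the module across devices.
        self.register_buffer("pos_grid", torch.cat([_pos_x, _pos_y], dim=1))

    def forward(self, features: Tensor) -> Tensor:
        if self.nets is not None:
            features = self.nets(features)
        # (B, K, H, W) -> (B * K, H * W): one softmax per channel.
        features = features.reshape(-1, self._in_h * self._in_w)
        attention = F.softmax(features / self.temperature, dim=-1)
        # (B * K, H * W) @ (H * W, 2) -> (B * K, 2): expected (x, y) per channel.
        expected_xy = attention @ self.pos_grid
        # Reshape to (B, K, 2).
        return expected_xy.view(-1, self._out_c, 2)


# Shape check matching the docstring example: (512 x 10 x 12) feature maps.
layer = SpatialSoftmaxDemo((512, 10, 12), num_kp=32)
out = layer(torch.randn(4, 512, 10, 12))
print(out.shape)  # torch.Size([4, 32, 2])

Running the shape check should print torch.Size([4, 32, 2]), matching the (B, K, 2) contract documented in forward().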