This commit is contained in:
Akshay Kashyap 2024-05-13 13:42:09 -07:00
parent 91f77e278b
commit 59aa172c17
1 changed file with 24 additions and 15 deletions

View File

@ -290,14 +290,14 @@ class DiffusionModel(nn.Module):
class SpatialSoftmax(nn.Module): class SpatialSoftmax(nn.Module):
""" """
Spatial Soft Argmax layer described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al. Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
(https://arxiv.org/pdf/1509.06113) (https://arxiv.org/pdf/1509.06113)
A minimal port of the robomimic implementation. A minimal port of the robomimic implementation.
At a high level, this operation takes 2D feature maps (from a convnet/ViT/etc.) and returns At a high level, this takes 2D feature maps (from a convnet/ViT/etc.) and returns the "center of mass"
the "center of mass" of activations of each channel, i.e., spatial keypoints for your policy to focus on. of activations of each channel, i.e., spatial keypoints for the policy to focus on.
Example: take feature maps of size (512x10x12). We will generate a grid of normalized coordinates (10x12x2): Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
----------------------------------------------------- -----------------------------------------------------
| (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) | | (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
| (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) | | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
@ -307,11 +307,18 @@ class SpatialSoftmax(nn.Module):
We apply channel-wise softmax over the activations (512x120) and compute dot product with the coordinates (120x2) We apply channel-wise softmax over the activations (512x120) and compute dot product with the coordinates (120x2)
to get expected points of maximal activation (512x2). to get expected points of maximal activation (512x2).
Optionally, can also learn a mapping from the feature maps to a lower dimensional space (num_kp < in_c) before computing the argmax Optionally, can also learn a linear mapping from the feature maps to a lower/higher dimensional space using a conv1x1
if we'd like to focus on a smaller number of keypoints. before computing the softmax.
""" """
def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False): def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False):
"""
Args:
input_shape (list): (C, H, W) input feature map shape.
num_kp (int): number of keypoints to output. If None, output will have the same number of channels as input.
temperature (float): temperature for softmax normalization.
learnable_temperature (bool): whether to learn the temperature parameter.
"""
super().__init__() super().__init__()
assert len(input_shape) == 3 assert len(input_shape) == 3
@ -328,16 +335,18 @@ class SpatialSoftmax(nn.Module):
_pos_x, _pos_y = torch.meshgrid( _pos_x, _pos_y = torch.meshgrid(
torch.linspace(-1.0, 1.0, self._in_w), torch.linspace(-1.0, 1.0, self._in_h), indexing="xy" torch.linspace(-1.0, 1.0, self._in_w), torch.linspace(-1.0, 1.0, self._in_h), indexing="xy"
) )
# Register as buffers so they are moved to the correct device, etc. _pos_x = _pos_x.reshape(self._in_h * self._in_w, 1)
self.register_buffer( _pos_y = _pos_y.reshape(self._in_h * self._in_w, 1)
"pos_grid", # Register as buffer so it's moved to the correct device, etc.
torch.cat( self.register_buffer("pos_grid", torch.cat([_pos_x, _pos_y], dim=1))
[_pos_x.reshape(self._in_h * self._in_w, 1), _pos_y.reshape(self._in_h * self._in_w, 1)],
dim=1,
).float(),
)
def forward(self, features: Tensor) -> Tensor: def forward(self, features: Tensor) -> Tensor:
"""
Args:
features: (B, C, H, W) input feature maps.
Returns:
(B, K, 2) (x, y) keypoint coordinates in image space, normalized to [-1, 1].
"""
if self.nets is not None: if self.nets is not None:
features = self.nets(features) features = self.nets(features)
@ -346,7 +355,7 @@ class SpatialSoftmax(nn.Module):
# 2d softmax normalization # 2d softmax normalization
attention = F.softmax(features / self.temperature, dim=-1) attention = F.softmax(features / self.temperature, dim=-1)
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
expected_xy = torch.matmul(attention, self.pos_grid) expected_xy = attention @ self.pos_grid
# reshape to [B, K, 2] # reshape to [B, K, 2]
feature_keypoints = expected_xy.view(-1, self._out_c, 2) feature_keypoints = expected_xy.view(-1, self._out_c, 2)