Akshay Kashyap 2024-05-13 13:42:09 -07:00
parent 91f77e278b
commit 59aa172c17
1 changed file with 24 additions and 15 deletions


@@ -290,14 +290,14 @@ class DiffusionModel(nn.Module):
 class SpatialSoftmax(nn.Module):
     """
-    Spatial Soft Argmax layer described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
+    Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
     (https://arxiv.org/pdf/1509.06113)
     A minimal port of the robomimic implementation.
-    At a high level, this operation takes 2D feature maps (from a convnet/ViT/etc.) and returns
-    the "center of mass" of activations of each channel, i.e., spatial keypoints for your policy to focus on.
+    At a high level, this takes 2D feature maps (from a convnet/ViT/etc.) and returns the "center of mass"
+    of activations of each channel, i.e., spatial keypoints for the policy to focus on.
-    Example: take feature maps of size (512x10x12). We will generate a grid of normalized coordinates (10x12x2):
+    Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
     -----------------------------------------------------
     | (-1., -1.)   | (-0.82, -1.)   | ... | (1., -1.)   |
     | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
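
As a quick aside (not part of the diff): the grid sketched in the docstring's table can be reproduced with torch.meshgrid, which is how the constructor builds it further down. A minimal sketch, using the 10x12 sizes from the example:

    import torch

    h, w = 10, 12  # feature map height and width from the docstring's example
    pos_x, pos_y = torch.meshgrid(
        torch.linspace(-1.0, 1.0, w), torch.linspace(-1.0, 1.0, h), indexing="xy"
    )
    grid = torch.stack([pos_x, pos_y], dim=-1)  # (10, 12, 2)
    grid[0, 0]  # tensor([-1., -1.])
    grid[0, 1]  # tensor([-0.8182, -1.]) -> the "-0.82" in the table above
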
@@ -307,11 +307,18 @@ class SpatialSoftmax(nn.Module):
     We apply channel-wise softmax over the activations (512x120) and compute dot product with the coordinates (120x2)
     to get expected points of maximal activation (512x2).
-    Optionally, can also learn a mapping from the feature maps to a lower dimensional space (num_kp < in_c) before computing the argmax
-    if we'd like to focus on a smaller number of keypoints.
+    Optionally, can also learn a linear mapping from the feature maps to a lower/higher dimensional space using a conv1x1
+    before computing the softmax.
     """

     def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False):
+        """
+        Args:
+            input_shape (list): (C, H, W) input feature map shape.
+            num_kp (int): number of keypoints to output. If None, output will have the same number of channels as input.
+            temperature (float): temperature for softmax normalization.
+            learnable_temperature (bool): whether to learn the temperature parameter.
+        """
         super().__init__()
         assert len(input_shape) == 3
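
To make the docstring's shapes concrete, here is a small sketch of the computation with illustrative sizes. The conv1x1 is assumed to be an nn.Conv2d with kernel_size=1 (consistent with the robomimic design the docstring references), and the random pos_grid is a stand-in for the real normalized grid:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    c, h, w, num_kp = 512, 10, 12, 32
    feat = torch.randn(1, c, h, w)
    # Optional linear mapping from 512 channels to num_kp keypoint channels.
    proj = nn.Conv2d(c, num_kp, kernel_size=1)
    feat = proj(feat)                                        # (1, 32, 10, 12)
    attention = F.softmax(feat.reshape(-1, h * w), dim=-1)   # (32, 120)
    pos_grid = torch.rand(h * w, 2)  # stand-in for the normalized grid
    expected_xy = attention @ pos_grid                       # (32, 2)
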
@@ -328,16 +335,18 @@ class SpatialSoftmax(nn.Module):
         _pos_x, _pos_y = torch.meshgrid(
             torch.linspace(-1.0, 1.0, self._in_w), torch.linspace(-1.0, 1.0, self._in_h), indexing="xy"
         )
-        # Register as buffers so they are moved to the correct device, etc.
-        self.register_buffer(
-            "pos_grid",
-            torch.cat(
-                [_pos_x.reshape(self._in_h * self._in_w, 1), _pos_y.reshape(self._in_h * self._in_w, 1)],
-                dim=1,
-            ).float(),
-        )
+        _pos_x = _pos_x.reshape(self._in_h * self._in_w, 1)
+        _pos_y = _pos_y.reshape(self._in_h * self._in_w, 1)
+        # Register as buffer so it's moved to the correct device, etc.
+        self.register_buffer("pos_grid", torch.cat([_pos_x, _pos_y], dim=1))

     def forward(self, features: Tensor) -> Tensor:
+        """
+        Args:
+            features: (B, C, H, W) input feature maps.
+        Returns:
+            (B, K, 2) image-space coordinates of keypoints.
+        """
         if self.nets is not None:
             features = self.nets(features)
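
The register_buffer rewrite above preserves the old behavior: buffers travel with .to()/.cuda() and are serialized in state_dict, but are invisible to the optimizer. The .float() cast could be dropped because, under default dtype settings, torch.linspace already returns float32. A small illustration of the buffer semantics (not part of the diff):

    import torch
    import torch.nn as nn

    m = nn.Module()
    m.register_buffer("pos_grid", torch.zeros(120, 2))
    list(m.parameters())   # [] -> never updated by the optimizer
    m.state_dict().keys()  # odict_keys(['pos_grid']) -> still saved/loaded
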
@@ -346,7 +355,7 @@ class SpatialSoftmax(nn.Module):
         # 2d softmax normalization
         attention = F.softmax(features / self.temperature, dim=-1)
         # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
-        expected_xy = torch.matmul(attention, self.pos_grid)
+        expected_xy = attention @ self.pos_grid
         # reshape to [B, K, 2]
         feature_keypoints = expected_xy.view(-1, self._out_c, 2)
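
Note that attention @ self.pos_grid is equivalent to the old torch.matmul call; the @ operator dispatches to the same matrix multiply. Putting the pieces together, a hypothetical usage of the class as it stands after this commit, with the signature taken from the __init__ in the diff:

    import torch

    net = SpatialSoftmax(input_shape=(512, 10, 12), num_kp=32)
    feats = torch.randn(8, 512, 10, 12)  # (B, C, H, W) feature maps
    kps = net(feats)                     # (8, 32, 2) keypoints in [-1, 1]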