nit
This commit is contained in:
parent
91f77e278b
commit
59aa172c17
|
@ -290,14 +290,14 @@ class DiffusionModel(nn.Module):
|
|||
|
||||
class SpatialSoftmax(nn.Module):
|
||||
"""
|
||||
Spatial Soft Argmax layer described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
|
||||
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
|
||||
(https://arxiv.org/pdf/1509.06113)
|
||||
A minimal port of the robomimic implementation.
|
||||
|
||||
At a high level, this operation takes 2D feature maps (from a convnet/ViT/etc.) and returns
|
||||
the "center of mass" of activations of each channel, i.e., spatial keypoints for your policy to focus on.
|
||||
At a high level, this takes 2D feature maps (from a convnet/ViT/etc.) and returns the "center of mass"
|
||||
of activations of each channel, i.e., spatial keypoints for the policy to focus on.
|
||||
|
||||
Example: take feature maps of size (512x10x12). We will generate a grid of normalized coordinates (10x12x2):
|
||||
Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
|
||||
-----------------------------------------------------
|
||||
| (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
|
||||
| (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
|
||||
|
@ -307,11 +307,18 @@ class SpatialSoftmax(nn.Module):
|
|||
We apply channel-wise softmax over the activations (512x120) and compute dot product with the coordinates (120x2)
|
||||
to get expected points of maximal activation (512x2).
|
||||
|
||||
Optionally, can also learn a mapping from the feature maps to a lower dimensional space (num_kp < in_c) before computing the argmax
|
||||
if we'd like to focus on a smaller number of keypoints.
|
||||
Optionally, can also learn a linear mapping from the feature maps to a lower/higher dimensional space using a conv1x1
|
||||
before computing the softmax.
|
||||
"""
|
||||
|
||||
def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False):
|
||||
"""
|
||||
Args:
|
||||
input_shape (list): (C, H, W) input feature map shape.
|
||||
num_kp (int): number of keypoints to output. If None, output will have the same number of channels as input.
|
||||
temperature (float): temperature for softmax normalization.
|
||||
learnable_temperature (bool): whether to learn the temperature parameter.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
assert len(input_shape) == 3
|
||||
|
@ -328,16 +335,18 @@ class SpatialSoftmax(nn.Module):
|
|||
_pos_x, _pos_y = torch.meshgrid(
|
||||
torch.linspace(-1.0, 1.0, self._in_w), torch.linspace(-1.0, 1.0, self._in_h), indexing="xy"
|
||||
)
|
||||
# Register as buffers so they are moved to the correct device, etc.
|
||||
self.register_buffer(
|
||||
"pos_grid",
|
||||
torch.cat(
|
||||
[_pos_x.reshape(self._in_h * self._in_w, 1), _pos_y.reshape(self._in_h * self._in_w, 1)],
|
||||
dim=1,
|
||||
).float(),
|
||||
)
|
||||
_pos_x = _pos_x.reshape(self._in_h * self._in_w, 1)
|
||||
_pos_y = _pos_y.reshape(self._in_h * self._in_w, 1)
|
||||
# Register as buffer so it's moved to the correct device, etc.
|
||||
self.register_buffer("pos_grid", torch.cat([_pos_x, _pos_y], dim=1))
|
||||
|
||||
def forward(self, features: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
features: (B, C, H, W) input feature maps.
|
||||
Returns:
|
||||
(B, K, 2) image-space coordinates of keypoints.
|
||||
"""
|
||||
if self.nets is not None:
|
||||
features = self.nets(features)
|
||||
|
||||
|
@ -346,7 +355,7 @@ class SpatialSoftmax(nn.Module):
|
|||
# 2d softmax normalization
|
||||
attention = F.softmax(features / self.temperature, dim=-1)
|
||||
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
|
||||
expected_xy = torch.matmul(attention, self.pos_grid)
|
||||
expected_xy = attention @ self.pos_grid
|
||||
# reshape to [B, K, 2]
|
||||
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
|
||||
|
||||
|
|
Loading…
Reference in New Issue