SpatialSoftmax from modeling_diffusion.py

This commit is contained in:
jayLEE0301 2024-06-03 17:43:11 -04:00
parent fd95f2e89b
commit a33fbd4d44
1 changed files with 69 additions and 1 deletions

View File

@ -19,7 +19,6 @@ from torch.cuda.amp import autocast
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from huggingface_hub import PyTorchModelHubMixin
from robomimic.models.base_nets import SpatialSoftmax
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
@ -124,6 +123,75 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
return loss
class SpatialSoftmax(nn.Module):
"""
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
(https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
-----------------------------------------------------
| (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
| (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
| ... | ... | ... | ... |
| (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) |
-----------------------------------------------------
This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
product with the coordinates (120x2) to get expected points of maximal activation (512x2).
The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable
linear mapping (in_channels, H, W) -> (num_kp, H, W).
"""
def __init__(self, input_shape, num_kp=None):
"""
Args:
input_shape (list): (C, H, W) input feature map shape.
num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
"""
super().__init__()
assert len(input_shape) == 3
self._in_c, self._in_h, self._in_w = input_shape
if num_kp is not None:
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
self._out_c = num_kp
else:
self.nets = None
self._out_c = self._in_c
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device.
self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
def forward(self, features: Tensor) -> Tensor:
"""
Args:
features: (B, C, H, W) input feature maps.
Returns:
(B, K, 2) image-space coordinates of keypoints.
"""
if self.nets is not None:
features = self.nets(features)
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
features = features.reshape(-1, self._in_h * self._in_w)
# 2d softmax normalization
attention = F.softmax(features, dim=-1)
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
expected_xy = attention @ self.pos_grid
# reshape to [B, K, 2]
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
return feature_keypoints
class VQBeTModel(nn.Module):
"""VQ-BeT: The underlying neural network for VQ-BeT