SpatialSoftmax from modeling_diffusion.py
This commit is contained in:
parent
fd95f2e89b
commit
a33fbd4d44
|
@ -19,7 +19,6 @@ from torch.cuda.amp import autocast
|
|||
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from robomimic.models.base_nets import SpatialSoftmax
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
|
||||
|
@ -124,6 +123,75 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
|
||||
return loss
|
||||
|
||||
class SpatialSoftmax(nn.Module):
|
||||
"""
|
||||
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
|
||||
(https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
|
||||
|
||||
At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
|
||||
of activations of each channel, i.e., keypoints in the image space for the policy to focus on.
|
||||
|
||||
Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
|
||||
-----------------------------------------------------
|
||||
| (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) |
|
||||
| (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
|
||||
| ... | ... | ... | ... |
|
||||
| (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) |
|
||||
-----------------------------------------------------
|
||||
This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot
|
||||
product with the coordinates (120x2) to get expected points of maximal activation (512x2).
|
||||
|
||||
The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally
|
||||
provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable
|
||||
linear mapping (in_channels, H, W) -> (num_kp, H, W).
|
||||
"""
|
||||
|
||||
def __init__(self, input_shape, num_kp=None):
|
||||
"""
|
||||
Args:
|
||||
input_shape (list): (C, H, W) input feature map shape.
|
||||
num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
assert len(input_shape) == 3
|
||||
self._in_c, self._in_h, self._in_w = input_shape
|
||||
|
||||
if num_kp is not None:
|
||||
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
|
||||
self._out_c = num_kp
|
||||
else:
|
||||
self.nets = None
|
||||
self._out_c = self._in_c
|
||||
|
||||
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
|
||||
# and causes a small degradation in pc_success of pre-trained models.
|
||||
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
|
||||
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
|
||||
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
|
||||
# register as buffer so it's moved to the correct device.
|
||||
self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
|
||||
|
||||
def forward(self, features: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
features: (B, C, H, W) input feature maps.
|
||||
Returns:
|
||||
(B, K, 2) image-space coordinates of keypoints.
|
||||
"""
|
||||
if self.nets is not None:
|
||||
features = self.nets(features)
|
||||
|
||||
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
|
||||
features = features.reshape(-1, self._in_h * self._in_w)
|
||||
# 2d softmax normalization
|
||||
attention = F.softmax(features, dim=-1)
|
||||
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
|
||||
expected_xy = attention @ self.pos_grid
|
||||
# reshape to [B, K, 2]
|
||||
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
|
||||
|
||||
return feature_keypoints
|
||||
|
||||
class VQBeTModel(nn.Module):
|
||||
"""VQ-BeT: The underlying neural network for VQ-BeT
|
||||
|
|
Loading…
Reference in New Issue