Minimal port of SpatialSoftmax

This commit is contained in:
Akshay Kashyap 2024-05-13 15:06:45 -04:00
parent 89c6be84ca
commit 10c75790c1
1 changed file with 48 additions and 1 deletion

View File

@ -16,7 +16,6 @@ import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from huggingface_hub import PyTorchModelHubMixin
from robomimic.models.base_nets import SpatialSoftmax
from torch import Tensor, nn
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
@ -140,6 +139,54 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMSche
raise ValueError(f"Unsupported noise scheduler type {name}")
class SpatialSoftmax(nn.Module):
"""
Implementation of the Spatial Soft Argmax layer described in "Deep Spatial Autoencoders
for Visuomotor Learning" by Finn et al. (https://arxiv.org/pdf/1509.06113)
Meant to be a minimal port of the robomimic implementation.
"""
def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False):
super().__init__()
assert len(input_shape) == 3
self._in_c, self._in_h, self._in_w = input_shape
if num_kp is not None:
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
self._out_c = num_kp
else:
self.nets = None
self._out_c = self._in_c
self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)
_pos_x, _pos_y = torch.meshgrid(
torch.linspace(-1.0, 1.0, self._in_w), torch.linspace(-1.0, 1.0, self._in_h), indexing="xy"
)
# Register as buffers so they are moved to the correct device, etc.
self.register_buffer("pos_x", _pos_x.reshape(1, self._in_h * self._in_w))
self.register_buffer("pos_y", _pos_y.reshape(1, self._in_h * self._in_w))
def forward(self, features: Tensor) -> Tensor:
if self.nets is not None:
features = self.nets(features)
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
features = features.reshape(-1, self._in_h * self._in_w)
# 2d softmax normalization
attention = F.softmax(features / self.temperature, dim=-1)
# [1, H * W] x [B * K, H * W] -> [B * K, 1] for spatial coordinate mean in x and y dimensions
expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True)
expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True)
# stack to [B * K, 2]
expected_xy = torch.cat([expected_x, expected_y], 1)
# reshape to [B, K, 2]
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
return feature_keypoints
class DiffusionModel(nn.Module):
def __init__(self, config: DiffusionConfig):
super().__init__()