Minimal port of SpatialSoftmax
This commit is contained in:
parent
89c6be84ca
commit
10c75790c1
|
@ -16,7 +16,6 @@ import torchvision
|
|||
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from robomimic.models.base_nets import SpatialSoftmax
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
|
@ -140,6 +139,54 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMSche
|
|||
raise ValueError(f"Unsupported noise scheduler type {name}")
|
||||
|
||||
|
||||
class SpatialSoftmax(nn.Module):
|
||||
"""
|
||||
Implementation of the Spatial Soft Argmax layer described in "Deep Spatial Autoencoders
|
||||
for Visuomotor Learning" by Finn et al. (https://arxiv.org/pdf/1509.06113)
|
||||
|
||||
Meant to be a minimal port of the robomimic implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False):
|
||||
super().__init__()
|
||||
|
||||
assert len(input_shape) == 3
|
||||
self._in_c, self._in_h, self._in_w = input_shape
|
||||
|
||||
if num_kp is not None:
|
||||
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
|
||||
self._out_c = num_kp
|
||||
else:
|
||||
self.nets = None
|
||||
self._out_c = self._in_c
|
||||
|
||||
self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)
|
||||
_pos_x, _pos_y = torch.meshgrid(
|
||||
torch.linspace(-1.0, 1.0, self._in_w), torch.linspace(-1.0, 1.0, self._in_h), indexing="xy"
|
||||
)
|
||||
# Register as buffers so they are moved to the correct device, etc.
|
||||
self.register_buffer("pos_x", _pos_x.reshape(1, self._in_h * self._in_w))
|
||||
self.register_buffer("pos_y", _pos_y.reshape(1, self._in_h * self._in_w))
|
||||
|
||||
def forward(self, features: Tensor) -> Tensor:
|
||||
if self.nets is not None:
|
||||
features = self.nets(features)
|
||||
|
||||
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
|
||||
features = features.reshape(-1, self._in_h * self._in_w)
|
||||
# 2d softmax normalization
|
||||
attention = F.softmax(features / self.temperature, dim=-1)
|
||||
# [1, H * W] x [B * K, H * W] -> [B * K, 1] for spatial coordinate mean in x and y dimensions
|
||||
expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True)
|
||||
expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True)
|
||||
# stack to [B * K, 2]
|
||||
expected_xy = torch.cat([expected_x, expected_y], 1)
|
||||
# reshape to [B, K, 2]
|
||||
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
|
||||
|
||||
return feature_keypoints
|
||||
|
||||
|
||||
class DiffusionModel(nn.Module):
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue