move and ref

Akshay Kashyap 2024-05-13 16:10:39 -04:00
parent 10c75790c1
commit 91f77e278b
1 changed file with 65 additions and 48 deletions

@@ -139,54 +139,6 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
raise ValueError(f"Unsupported noise scheduler type {name}")
class SpatialSoftmax(nn.Module):
"""
Implementation of the Spatial Soft Argmax layer described in "Deep Spatial Autoencoders
for Visuomotor Learning" by Finn et al. (https://arxiv.org/pdf/1509.06113)
Meant to be a minimal port of the robomimic implementation.
"""
def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False):
super().__init__()
assert len(input_shape) == 3
self._in_c, self._in_h, self._in_w = input_shape
if num_kp is not None:
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
self._out_c = num_kp
else:
self.nets = None
self._out_c = self._in_c
self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)
_pos_x, _pos_y = torch.meshgrid(
torch.linspace(-1.0, 1.0, self._in_w), torch.linspace(-1.0, 1.0, self._in_h), indexing="xy"
)
# Register as buffers so they are moved to the correct device, etc.
self.register_buffer("pos_x", _pos_x.reshape(1, self._in_h * self._in_w))
self.register_buffer("pos_y", _pos_y.reshape(1, self._in_h * self._in_w))
def forward(self, features: Tensor) -> Tensor:
if self.nets is not None:
features = self.nets(features)
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
features = features.reshape(-1, self._in_h * self._in_w)
# 2d softmax normalization
attention = F.softmax(features / self.temperature, dim=-1)
# [1, H * W] x [B * K, H * W] -> [B * K, 1] for spatial coordinate mean in x and y dimensions
expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True)
expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True)
# stack to [B * K, 2]
expected_xy = torch.cat([expected_x, expected_y], 1)
# reshape to [B, K, 2]
feature_keypoints = expected_xy.view(-1, self._out_c, 2)
return feature_keypoints
class DiffusionModel(nn.Module):
def __init__(self, config: DiffusionConfig):
super().__init__()
@@ -336,6 +288,71 @@ class DiffusionModel(nn.Module):
        return loss.mean()


class SpatialSoftmax(nn.Module):
    """
    Spatial Soft Argmax layer described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
    (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.

    At a high level, this operation takes 2D feature maps (from a convnet/ViT/etc.) and returns the
    "center of mass" of activations of each channel, i.e., spatial keypoints for your policy to focus on.

    Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
    -----------------------------------------------------
    | (-1., -1.)   | (-0.82, -1.)   | ... | (1., -1.)   |
    | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
    | ...          | ...            | ... | ...         |
    | (-1., 1.)    | (-0.82, 1.)    | ... | (1., 1.)    |
    -----------------------------------------------------
    We apply a channel-wise softmax over the activations (512x120) and compute the dot product with the
    coordinates (120x2) to get the expected points of maximal activation (512x2).

    Optionally, a learnable 1x1 convolution can first map the feature maps to a lower-dimensional space
    (num_kp < in_c) before computing the argmax, if we'd like to focus on a smaller number of keypoints.
    """
    def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False):
        super().__init__()

        assert len(input_shape) == 3
        self._in_c, self._in_h, self._in_w = input_shape

        if num_kp is not None:
            self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
            self._out_c = num_kp
        else:
            self.nets = None
            self._out_c = self._in_c

        self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)

        _pos_x, _pos_y = torch.meshgrid(
            torch.linspace(-1.0, 1.0, self._in_w), torch.linspace(-1.0, 1.0, self._in_h), indexing="xy"
        )
        # Register as a buffer so it is moved to the correct device, etc.
        self.register_buffer(
            "pos_grid",
            torch.cat(
                [_pos_x.reshape(self._in_h * self._in_w, 1), _pos_y.reshape(self._in_h * self._in_w, 1)],
                dim=1,
            ).float(),
        )

    def forward(self, features: Tensor) -> Tensor:
        if self.nets is not None:
            features = self.nets(features)
        # [B, K, H, W] -> [B * K, H * W] where K is the number of keypoints
        features = features.reshape(-1, self._in_h * self._in_w)
        # 2d softmax normalization
        attention = F.softmax(features / self.temperature, dim=-1)
        # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
        expected_xy = torch.matmul(attention, self.pos_grid)
        # reshape to [B, K, 2]
        feature_keypoints = expected_xy.view(-1, self._out_c, 2)
        return feature_keypoints
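
Illustrative aside, not part of the diff: a minimal usage sketch of the refactored module, assuming the file's existing imports (torch, nn, F, Tensor); the batch size, feature-map shape, and num_kp=32 here are made up for the example. The single matmul against the stacked pos_grid buffer yields the same expected x and y coordinates as the per-axis sums over pos_x and pos_y that it replaces.

feats = torch.randn(4, 512, 10, 12)  # [B, C, H, W] feature maps from a vision backbone
ss = SpatialSoftmax(input_shape=(512, 10, 12), num_kp=32)
keypoints = ss(feats)  # [4, 32, 2]: 32 (x, y) keypoints per image
assert keypoints.shape == (4, 32, 2)
# Each keypoint is a convex combination of grid coordinates, so it lies in [-1, 1].
assert keypoints.min() >= -1.0 and keypoints.max() <= 1.0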

class DiffusionRgbEncoder(nn.Module):
    """Encodes an RGB image into a 1D feature vector.