move and ref
parent 10c75790c1
commit 91f77e278b
@@ -139,54 +139,6 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMSche
    raise ValueError(f"Unsupported noise scheduler type {name}")

class SpatialSoftmax(nn.Module):
    """
    Implementation of the Spatial Soft Argmax layer described in "Deep Spatial Autoencoders
    for Visuomotor Learning" by Finn et al. (https://arxiv.org/pdf/1509.06113)

    Meant to be a minimal port of the robomimic implementation.
    """

    def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False):
        super().__init__()

        assert len(input_shape) == 3
        self._in_c, self._in_h, self._in_w = input_shape

        if num_kp is not None:
            self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
            self._out_c = num_kp
        else:
            self.nets = None
            self._out_c = self._in_c

        self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)
        _pos_x, _pos_y = torch.meshgrid(
            torch.linspace(-1.0, 1.0, self._in_w), torch.linspace(-1.0, 1.0, self._in_h), indexing="xy"
        )
        # Register as buffers so they are moved to the correct device, etc.
        self.register_buffer("pos_x", _pos_x.reshape(1, self._in_h * self._in_w))
        self.register_buffer("pos_y", _pos_y.reshape(1, self._in_h * self._in_w))

    def forward(self, features: Tensor) -> Tensor:
        if self.nets is not None:
            features = self.nets(features)

        # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
        features = features.reshape(-1, self._in_h * self._in_w)
        # 2d softmax normalization
        attention = F.softmax(features / self.temperature, dim=-1)
        # [1, H * W] x [B * K, H * W] -> [B * K, 1] for spatial coordinate mean in x and y dimensions
        expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True)
        # stack to [B * K, 2]
        expected_xy = torch.cat([expected_x, expected_y], 1)
        # reshape to [B, K, 2]
        feature_keypoints = expected_xy.view(-1, self._out_c, 2)

        return feature_keypoints


class DiffusionModel(nn.Module):
    def __init__(self, config: DiffusionConfig):
        super().__init__()
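The class removed above is re-added below in refactored form. For orientation, here is a minimal usage sketch of SpatialSoftmax (illustrative only, not part of the commit; the batch size, num_kp, and the 512x10x12 input shape are arbitrary values echoing the docstring, and the class is assumed to be importable from this module):

import torch

# Pool a batch of backbone feature maps down to 32 learned keypoints.
feature_maps = torch.randn(4, 512, 10, 12)  # [B, C, H, W]
pool = SpatialSoftmax(input_shape=(512, 10, 12), num_kp=32)
keypoints = pool(feature_maps)  # [B, num_kp, 2]
assert keypoints.shape == (4, 32, 2)
# Each keypoint is the softmax-weighted mean of the normalized coordinate grid
# for one channel, so every (x, y) coordinate lies in [-1, 1].
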
@@ -336,6 +288,71 @@ class DiffusionModel(nn.Module):
        return loss.mean()


class SpatialSoftmax(nn.Module):
    """
    Spatial Soft Argmax layer described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
    (https://arxiv.org/pdf/1509.06113)
    A minimal port of the robomimic implementation.

    At a high level, this operation takes 2D feature maps (from a convnet/ViT/etc.) and returns the
    "center of mass" of activations of each channel, i.e. spatial keypoints for your policy to focus on.

    Example: take feature maps of size (512x10x12). We will generate a grid of normalized coordinates (10x12x2):
    -----------------------------------------------------
    | (-1., -1.)   | (-0.82, -1.)   | ... | (1., -1.)   |
    | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
    | ...          | ...            | ... | ...         |
    | (-1., 1.)    | (-0.82, 1.)    | ... | (1., 1.)    |
    -----------------------------------------------------
    We apply a channel-wise softmax over the activations (512x120) and compute the dot product with the
    coordinates (120x2) to get the expected points of maximal activation (512x2).

    Optionally, the layer can also learn a mapping from the feature maps to a lower-dimensional space
    (num_kp < in_c) before computing the soft argmax, if we'd like to focus on a smaller number of keypoints.
    """

    def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False):
        super().__init__()

        assert len(input_shape) == 3
        self._in_c, self._in_h, self._in_w = input_shape

        if num_kp is not None:
            self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
            self._out_c = num_kp
        else:
            self.nets = None
            self._out_c = self._in_c

        self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)
        _pos_x, _pos_y = torch.meshgrid(
            torch.linspace(-1.0, 1.0, self._in_w), torch.linspace(-1.0, 1.0, self._in_h), indexing="xy"
        )
        # Register as buffers so they are moved to the correct device, etc.
        self.register_buffer(
            "pos_grid",
            torch.cat(
                [_pos_x.reshape(self._in_h * self._in_w, 1), _pos_y.reshape(self._in_h * self._in_w, 1)],
                dim=1,
            ).float(),
        )

    def forward(self, features: Tensor) -> Tensor:
        if self.nets is not None:
            features = self.nets(features)

        # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
        features = features.reshape(-1, self._in_h * self._in_w)
        # 2d softmax normalization
        attention = F.softmax(features / self.temperature, dim=-1)
        # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
        expected_xy = torch.matmul(attention, self.pos_grid)
        # reshape to [B, K, 2]
        feature_keypoints = expected_xy.view(-1, self._out_c, 2)

        return feature_keypoints

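The refactor does not change the math: the old pos_x and pos_y buffers of shape (1, H * W) are merged into a single pos_grid buffer of shape (H * W, 2), and the two per-axis weighted sums in forward become one matmul. A small self-contained check of that equivalence (illustrative sketch, not part of the commit; h, w, and the batch of attention weights are arbitrary):

import torch

# Check that the old two-buffer formulation and the new pos_grid matmul agree.
h, w = 10, 12
pos_x, pos_y = torch.meshgrid(
    torch.linspace(-1.0, 1.0, w), torch.linspace(-1.0, 1.0, h), indexing="xy"
)
pos_grid = torch.cat(
    [pos_x.reshape(h * w, 1), pos_y.reshape(h * w, 1)], dim=1
)  # [H * W, 2]

attention = torch.softmax(torch.randn(8, h * w), dim=-1)  # [B * K, H * W]

# Old: per-axis weighted sums against the (1, H * W) buffers, then concatenate.
old_xy = torch.cat(
    [
        torch.sum(pos_x.reshape(1, h * w) * attention, dim=1, keepdim=True),
        torch.sum(pos_y.reshape(1, h * w) * attention, dim=1, keepdim=True),
    ],
    dim=1,
)
# New: a single matmul against the merged (H * W, 2) grid.
new_xy = attention @ pos_grid

assert torch.allclose(old_xy, new_xy, atol=1e-6)

One practical upside is that the matmul avoids materializing the two (B * K, H * W) intermediate products of the elementwise version.
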
class DiffusionRgbEncoder(nn.Module):
    """Encodes an RGB image into a 1D feature vector.