From 91f77e278b17fe7344b746e0cf92355837925095 Mon Sep 17 00:00:00 2001
From: Akshay Kashyap
Date: Mon, 13 May 2024 16:10:39 -0400
Subject: [PATCH] move and ref

---
 .../policies/diffusion/modeling_diffusion.py | 113 ++++++++++--------
 1 file changed, 65 insertions(+), 48 deletions(-)

diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 7c7debf1..4ba7a70f 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -139,54 +139,6 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMSche
     raise ValueError(f"Unsupported noise scheduler type {name}")
 
 
-class SpatialSoftmax(nn.Module):
-    """
-    Implementation of the Spatial Soft Argmax layer described in "Deep Spatial Autoencoders
-    for Visuomotor Learning" by Finn et al. (https://arxiv.org/pdf/1509.06113)
-
-    Meant to be a minimal port of the robomimic implementation.
-    """
-
-    def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False):
-        super().__init__()
-
-        assert len(input_shape) == 3
-        self._in_c, self._in_h, self._in_w = input_shape
-
-        if num_kp is not None:
-            self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
-            self._out_c = num_kp
-        else:
-            self.nets = None
-            self._out_c = self._in_c
-
-        self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)
-        _pos_x, _pos_y = torch.meshgrid(
-            torch.linspace(-1.0, 1.0, self._in_w), torch.linspace(-1.0, 1.0, self._in_h), indexing="xy"
-        )
-        # Register as buffers so they are moved to the correct device, etc.
-        self.register_buffer("pos_x", _pos_x.reshape(1, self._in_h * self._in_w))
-        self.register_buffer("pos_y", _pos_y.reshape(1, self._in_h * self._in_w))
-
-    def forward(self, features: Tensor) -> Tensor:
-        if self.nets is not None:
-            features = self.nets(features)
-
-        # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
-        features = features.reshape(-1, self._in_h * self._in_w)
-        # 2d softmax normalization
-        attention = F.softmax(features / self.temperature, dim=-1)
-        # [1, H * W] x [B * K, H * W] -> [B * K, 1] for spatial coordinate mean in x and y dimensions
-        expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True)
-        expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True)
-        # stack to [B * K, 2]
-        expected_xy = torch.cat([expected_x, expected_y], 1)
-        # reshape to [B, K, 2]
-        feature_keypoints = expected_xy.view(-1, self._out_c, 2)
-
-        return feature_keypoints
-
-
 class DiffusionModel(nn.Module):
     def __init__(self, config: DiffusionConfig):
         super().__init__()
@@ -336,6 +288,71 @@ class DiffusionModel(nn.Module):
         return loss.mean()
 
 
+class SpatialSoftmax(nn.Module):
+    """
+    Spatial Soft Argmax layer described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn
+    et al. (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
+
+    At a high level, this operation takes 2D feature maps (from a convnet/ViT/etc.) and returns the
+    "center of mass" of activations of each channel, i.e., spatial keypoints for the policy to focus on.
+
+    Example: take feature maps of size (512x10x12). We will generate a grid of normalized coordinates
+    (10x12x2):
+    -----------------------------------------------------
+    | (-1., -1.)   | (-0.82, -1.)   | ... | (1., -1.)   |
+    | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) |
+    | ...          | ...            | ... | ...         |
+    | (-1., 1.)    | (-0.82, 1.)    | ... | (1., 1.)    |
+    -----------------------------------------------------
+    We apply channel-wise softmax over the activations (512x120) and compute the dot product with the
+    coordinates (120x2) to get the expected points of maximal activation (512x2).
+
+    Optionally, a mapping from the feature maps to a lower-dimensional space (num_kp < in_c) can be
+    learned before computing the soft argmax, if we'd like to focus on a smaller number of keypoints.
+    """
+
+    def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False):
+        super().__init__()
+
+        assert len(input_shape) == 3
+        self._in_c, self._in_h, self._in_w = input_shape
+
+        if num_kp is not None:
+            self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
+            self._out_c = num_kp
+        else:
+            self.nets = None
+            self._out_c = self._in_c
+
+        self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)
+        _pos_x, _pos_y = torch.meshgrid(
+            torch.linspace(-1.0, 1.0, self._in_w), torch.linspace(-1.0, 1.0, self._in_h), indexing="xy"
+        )
+        # Register as a buffer so it is moved to the correct device, etc.
+        self.register_buffer(
+            "pos_grid",
+            torch.cat(
+                [_pos_x.reshape(self._in_h * self._in_w, 1), _pos_y.reshape(self._in_h * self._in_w, 1)],
+                dim=1,
+            ).float(),
+        )
+
+    def forward(self, features: Tensor) -> Tensor:
+        if self.nets is not None:
+            features = self.nets(features)
+
+        # [B, K, H, W] -> [B * K, H * W] where K is the number of keypoints
+        features = features.reshape(-1, self._in_h * self._in_w)
+        # 2d softmax normalization
+        attention = F.softmax(features / self.temperature, dim=-1)
+        # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for the spatial coordinate mean in x and y dimensions
+        expected_xy = torch.matmul(attention, self.pos_grid)
+        # reshape to [B, K, 2]
+        feature_keypoints = expected_xy.view(-1, self._out_c, 2)
+
+        return feature_keypoints
+
+
 class DiffusionRgbEncoder(nn.Module):
     """Encoder an RGB image into a 1D feature vector.