nit comments

This commit is contained in:
Akshay Kashyap 2024-05-13 18:47:17 -04:00
parent eaecb5fc57
commit 985cc4f0cf
1 changed file with 3 additions and 4 deletions

View File

@@ -294,7 +294,7 @@ class SpatialSoftmax(nn.Module):
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al. Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
(https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation. (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
At a high level, this takes 2D feature maps (from a convnet/ViT/etc.) and returns the "center of mass" At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass"
of activations of each channel, i.e., spatial keypoints for the policy to focus on. of activations of each channel, i.e., spatial keypoints for the policy to focus on.
Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2): Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2):
@@ -332,13 +332,12 @@ class SpatialSoftmax(nn.Module):
self._out_c = self._in_c self._out_c = self._in_c
self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature) self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)
# we could use torch.linspace directly but that seems to behave slightly differently than numpy # we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and cause a small degradation in pc_success of pre-trained models. # and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)) pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float() pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float() pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
# register as buffer so it's moved to the correct device, etc. # register as buffer so it's moved to the correct device.
self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1)) self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
def forward(self, features: Tensor) -> Tensor: def forward(self, features: Tensor) -> Tensor: