use np instead of torch

Akshay Kashyap 2024-05-13 15:35:10 -07:00
parent 1bfddb3619
commit c1bc8410b4
1 changed file with 13 additions and 9 deletions


@@ -10,6 +10,7 @@ from collections import deque
 from typing import Callable
 
 import einops
+import numpy as np
 import torch
 import torch.nn.functional as F  # noqa: N812
 import torchvision
@@ -291,8 +292,7 @@ class DiffusionModel(nn.Module):
 class SpatialSoftmax(nn.Module):
     """
     Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
-    (https://arxiv.org/pdf/1509.06113)
-    A minimal port of the robomimic implementation.
+    (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
 
     At a high level, this takes 2D feature maps (from a convnet/ViT/etc.) and returns the "center of mass"
     of activations of each channel, i.e., spatial keypoints for the policy to focus on.
@@ -307,7 +307,7 @@ class SpatialSoftmax(nn.Module):
     We apply channel-wise softmax over the activations (512x120) and compute dot product with the coordinates (120x2)
     to get expected points of maximal activation (512x2).
 
-    Optionally, can also learn a linear mapping from the feature maps to a lower/higher dimensional space using a conv1x1
+    Optionally, when num_kp != None, can learn a linear mapping from the feature maps to a lower/higher dimensional space using a conv1x1
     before computing the softmax.
     """
@@ -332,13 +332,17 @@ class SpatialSoftmax(nn.Module):
             self._out_c = self._in_c
 
         self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)
 
-        _pos_x, _pos_y = torch.meshgrid(
-            torch.linspace(-1.0, 1.0, self._in_w), torch.linspace(-1.0, 1.0, self._in_h), indexing="xy"
-        )
-        _pos_x = _pos_x.reshape(self._in_h * self._in_w, 1)
-        _pos_y = _pos_y.reshape(self._in_h * self._in_w, 1)
-        # Register as buffer so it's moved to the correct device, etc.
-        self.register_buffer("pos_grid", torch.cat([_pos_x, _pos_y], dim=1))
+        # we could use torch.linspace directly but that seems to behave slightly differently than numpy
+        # and cause a small degradation in pc_success of pre-trained models.
+        pos_x, pos_y = np.meshgrid(
+            np.linspace(-1., 1., self._in_w),
+            np.linspace(-1., 1., self._in_h)
+        )
+        pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
+        pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
+        # register as buffer so it's moved to the correct device, etc.
+        self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
 
     def forward(self, features: Tensor) -> Tensor:
         """