use np instead of torch
This commit is contained in:
parent
1bfddb3619
commit
c1bc8410b4
|
@ -10,6 +10,7 @@ from collections import deque
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
import torchvision
|
import torchvision
|
||||||
|
@ -291,8 +292,7 @@ class DiffusionModel(nn.Module):
|
||||||
class SpatialSoftmax(nn.Module):
|
class SpatialSoftmax(nn.Module):
|
||||||
"""
|
"""
|
||||||
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
|
Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al.
|
||||||
(https://arxiv.org/pdf/1509.06113)
|
(https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation.
|
||||||
A minimal port of the robomimic implementation.
|
|
||||||
|
|
||||||
At a high level, this takes 2D feature maps (from a convnet/ViT/etc.) and returns the "center of mass"
|
At a high level, this takes 2D feature maps (from a convnet/ViT/etc.) and returns the "center of mass"
|
||||||
of activations of each channel, i.e., spatial keypoints for the policy to focus on.
|
of activations of each channel, i.e., spatial keypoints for the policy to focus on.
|
||||||
|
@ -307,7 +307,7 @@ class SpatialSoftmax(nn.Module):
|
||||||
We apply channel-wise softmax over the activations (512x120) and compute dot product with the coordinates (120x2)
|
We apply channel-wise softmax over the activations (512x120) and compute dot product with the coordinates (120x2)
|
||||||
to get expected points of maximal activation (512x2).
|
to get expected points of maximal activation (512x2).
|
||||||
|
|
||||||
Optionally, can also learn a linear mapping from the feature maps to a lower/higher dimensional space using a conv1x1
|
Optionally, when num_kp != None, can learn a linear mapping from the feature maps to a lower/higher dimensional space using a conv1x1
|
||||||
before computing the softmax.
|
before computing the softmax.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -332,13 +332,17 @@ class SpatialSoftmax(nn.Module):
|
||||||
self._out_c = self._in_c
|
self._out_c = self._in_c
|
||||||
|
|
||||||
self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)
|
self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)
|
||||||
_pos_x, _pos_y = torch.meshgrid(
|
|
||||||
torch.linspace(-1.0, 1.0, self._in_w), torch.linspace(-1.0, 1.0, self._in_h), indexing="xy"
|
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
|
||||||
|
# and cause a small degradation in pc_success of pre-trained models.
|
||||||
|
pos_x, pos_y = np.meshgrid(
|
||||||
|
np.linspace(-1., 1., self._in_w),
|
||||||
|
np.linspace(-1., 1., self._in_h)
|
||||||
)
|
)
|
||||||
_pos_x = _pos_x.reshape(self._in_h * self._in_w, 1)
|
pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float()
|
||||||
_pos_y = _pos_y.reshape(self._in_h * self._in_w, 1)
|
pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float()
|
||||||
# Register as buffer so it's moved to the correct device, etc.
|
# register as buffer so it's moved to the correct device, etc.
|
||||||
self.register_buffer("pos_grid", torch.cat([_pos_x, _pos_y], dim=1))
|
self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1))
|
||||||
|
|
||||||
def forward(self, features: Tensor) -> Tensor:
|
def forward(self, features: Tensor) -> Tensor:
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue