rm temperature and edit docstrings
This commit is contained in:
parent
09b983f2ff
commit
1c343f4d06
|
@ -1,7 +1,6 @@
|
|||
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
|
||||
|
||||
TODO(alexander-soare):
|
||||
- Remove reliance on Robomimic for SpatialSoftmax.
|
||||
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
|
||||
"""
|
||||
|
||||
|
@ -307,17 +306,15 @@ class SpatialSoftmax(nn.Module):
|
|||
We apply channel-wise softmax over the activations (512x120) and compute dot product with the coordinates (120x2)
|
||||
to get expected points of maximal activation (512x2).
|
||||
|
||||
Optionally, when num_kp != None, can learn a linear mapping from the feature maps to a lower/higher dimensional
|
||||
space using a conv1x1 before computing the softmax.
|
||||
Optionally, when num_kp != None, can learn a linear mapping (of shape (num_kp, H, W)) from the input feature maps
|
||||
using a conv1x1 before computing softmax.
|
||||
"""
|
||||
|
||||
def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False):
|
||||
def __init__(self, input_shape, num_kp=None):
|
||||
"""
|
||||
Args:
|
||||
input_shape (list): (C, H, W) input feature map shape.
|
||||
num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
|
||||
temperature (float): temperature for softmax normalization.
|
||||
learnable_temperature (bool): whether to learn the temperature parameter.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
|
@ -331,7 +328,6 @@ class SpatialSoftmax(nn.Module):
|
|||
self.nets = None
|
||||
self._out_c = self._in_c
|
||||
|
||||
self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)
|
||||
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
|
||||
# and causes a small degradation in pc_success of pre-trained models.
|
||||
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
|
||||
|
@ -353,7 +349,7 @@ class SpatialSoftmax(nn.Module):
|
|||
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
|
||||
features = features.reshape(-1, self._in_h * self._in_w)
|
||||
# 2d softmax normalization
|
||||
attention = F.softmax(features / self.temperature, dim=-1)
|
||||
attention = F.softmax(features, dim=-1)
|
||||
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
|
||||
expected_xy = attention @ self.pos_grid
|
||||
# reshape to [B, K, 2]
|
||||
|
|
Loading…
Reference in New Issue