rm temperature and edit docstrings

This commit is contained in:
Akshay Kashyap 2024-05-14 10:44:29 -04:00
parent 09b983f2ff
commit 1c343f4d06
1 changed files with 4 additions and 8 deletions

View File

@ -1,7 +1,6 @@
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
TODO(alexander-soare):
- Remove reliance on Robomimic for SpatialSoftmax.
- Remove reliance on diffusers for DDPMScheduler and LR scheduler.
"""
@ -307,17 +306,15 @@ class SpatialSoftmax(nn.Module):
We apply channel-wise softmax over the activations (512x120) and compute dot product with the coordinates (120x2)
to get expected points of maximal activation (512x2).
Optionally, when num_kp != None, can learn a linear mapping from the feature maps to a lower/higher dimensional
space using a conv1x1 before computing the softmax.
Optionally, when num_kp != None, can learn a linear mapping (of shape (num_kp, H, W)) from the input feature maps
using a conv1x1 before computing softmax.
"""
def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False):
def __init__(self, input_shape, num_kp=None):
"""
Args:
input_shape (list): (C, H, W) input feature map shape.
num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
temperature (float): temperature for softmax normalization.
learnable_temperature (bool): whether to learn the temperature parameter.
"""
super().__init__()
@ -331,7 +328,6 @@ class SpatialSoftmax(nn.Module):
self.nets = None
self._out_c = self._in_c
self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)
# we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
@ -353,7 +349,7 @@ class SpatialSoftmax(nn.Module):
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
features = features.reshape(-1, self._in_h * self._in_w)
# 2d softmax normalization
attention = F.softmax(features / self.temperature, dim=-1)
attention = F.softmax(features, dim=-1)
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
expected_xy = attention @ self.pos_grid
# reshape to [B, K, 2]