rm temperature and edit docstrings

This commit is contained in:
Akshay Kashyap 2024-05-14 10:44:29 -04:00
parent 09b983f2ff
commit 1c343f4d06
1 changed file with 4 additions and 8 deletions

View File

@@ -1,7 +1,6 @@
"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" """Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
TODO(alexander-soare): TODO(alexander-soare):
- Remove reliance on Robomimic for SpatialSoftmax.
- Remove reliance on diffusers for DDPMScheduler and LR scheduler. - Remove reliance on diffusers for DDPMScheduler and LR scheduler.
""" """
@@ -307,17 +306,15 @@ class SpatialSoftmax(nn.Module):
We apply channel-wise softmax over the activations (512x120) and compute dot product with the coordinates (120x2) We apply channel-wise softmax over the activations (512x120) and compute dot product with the coordinates (120x2)
to get expected points of maximal activation (512x2). to get expected points of maximal activation (512x2).
Optionally, when num_kp != None, can learn a linear mapping from the feature maps to a lower/higher dimensional Optionally, when num_kp != None, can learn a linear mapping (of shape (num_kp, H, W)) from the input feature maps
space using a conv1x1 before computing the softmax. using a conv1x1 before computing softmax.
""" """
def __init__(self, input_shape, num_kp=None, temperature=1.0, learnable_temperature=False): def __init__(self, input_shape, num_kp=None):
""" """
Args: Args:
input_shape (list): (C, H, W) input feature map shape. input_shape (list): (C, H, W) input feature map shape.
num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input. num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input.
temperature (float): temperature for softmax normalization.
learnable_temperature (bool): whether to learn the temperature parameter.
""" """
super().__init__() super().__init__()
@@ -331,7 +328,6 @@ class SpatialSoftmax(nn.Module):
self.nets = None self.nets = None
self._out_c = self._in_c self._out_c = self._in_c
self.temperature = nn.Parameter(torch.tensor(temperature), requires_grad=learnable_temperature)
# we could use torch.linspace directly but that seems to behave slightly differently than numpy # we could use torch.linspace directly but that seems to behave slightly differently than numpy
# and causes a small degradation in pc_success of pre-trained models. # and causes a small degradation in pc_success of pre-trained models.
pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)) pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h))
@@ -353,7 +349,7 @@ class SpatialSoftmax(nn.Module):
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
features = features.reshape(-1, self._in_h * self._in_w) features = features.reshape(-1, self._in_h * self._in_w)
# 2d softmax normalization # 2d softmax normalization
attention = F.softmax(features / self.temperature, dim=-1) attention = F.softmax(features, dim=-1)
# [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions
expected_xy = attention @ self.pos_grid expected_xy = attention @ self.pos_grid
# reshape to [B, K, 2] # reshape to [B, K, 2]