feat: enable use of multiple RGB encoders per camera in diffusion policy (#484)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>

parent 172809a502
commit 538455a965
@@ -67,6 +67,7 @@ class DiffusionConfig:
         use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
             The group sizes are set to be about 16 (to be precise, feature_dim // 16).
         spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
+        use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view.
         down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
             You may provide a variable number of dimensions, therefore also controlling the degree of
             downsampling.
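The `feature_dim // 16` sizing in the docstring above is easy to mis-read, so here is a minimal sketch of the substitution it describes. `replace_bn_with_gn` is a hypothetical helper written for illustration; the repository's actual replacement logic may be structured differently.

```python
from torch import nn


def replace_bn_with_gn(module: nn.Module) -> nn.Module:
    """Recursively swap BatchNorm2d for GroupNorm (hypothetical helper, for illustration).

    num_groups = num_features // 16, so each group holds roughly 16 channels.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            num_channels = child.num_features
            # Guard against tiny layers where num_channels // 16 would be 0.
            num_groups = max(num_channels // 16, 1)
            setattr(module, name, nn.GroupNorm(num_groups=num_groups, num_channels=num_channels))
        else:
            replace_bn_with_gn(child)
    return module
```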
@@ -130,6 +131,7 @@ class DiffusionConfig:
     pretrained_backbone_weights: str | None = None
     use_group_norm: bool = True
     spatial_softmax_num_keypoints: int = 32
+    use_separate_rgb_encoder_per_camera: bool = False
     # Unet.
     down_dims: tuple[int, ...] = (512, 1024, 2048)
     kernel_size: int = 5
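With the new field defaulting to False, existing configs are unaffected and opting in is a one-line change. A minimal usage sketch, assuming the configuration class is importable from its usual module path:

```python
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig

# Default (False) keeps the single shared encoder; True builds one encoder per camera view.
config = DiffusionConfig(use_separate_rgb_encoder_per_camera=True)
```

Nothing else changes from the caller's perspective; the branching happens inside DiffusionModel, as the next two hunks show.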
@@ -182,8 +182,13 @@ class DiffusionModel(nn.Module):
         self._use_env_state = False
         if num_images > 0:
             self._use_images = True
-            self.rgb_encoder = DiffusionRgbEncoder(config)
-            global_cond_dim += self.rgb_encoder.feature_dim * num_images
+            if self.config.use_separate_rgb_encoder_per_camera:
+                encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
+                self.rgb_encoder = nn.ModuleList(encoders)
+                global_cond_dim += encoders[0].feature_dim * num_images
+            else:
+                self.rgb_encoder = DiffusionRgbEncoder(config)
+                global_cond_dim += self.rgb_encoder.feature_dim * num_images
         if "observation.environment_state" in config.input_shapes:
             self._use_env_state = True
             global_cond_dim += config.input_shapes["observation.environment_state"][0]
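Both branches leave an attribute named `self.rgb_encoder` in place and contribute the same amount to `global_cond_dim`, since every encoder built from the same config reports the same `feature_dim`. A self-contained sketch of that invariant, with `ToyRgbEncoder` as a hypothetical stand-in for `DiffusionRgbEncoder`:

```python
import torch
from torch import nn


class ToyRgbEncoder(nn.Module):
    """Hypothetical stand-in for DiffusionRgbEncoder: maps (b, c, h, w) -> (b, feature_dim)."""

    def __init__(self, feature_dim: int = 64):
        super().__init__()
        self.feature_dim = feature_dim
        self.net = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.LazyLinear(feature_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


num_images = 2  # two camera views
shared = ToyRgbEncoder()
separate = nn.ModuleList([ToyRgbEncoder() for _ in range(num_images)])

# The image contribution to global_cond_dim is identical in both branches:
# feature_dim * num_images, regardless of whether the encoders share weights.
assert shared.feature_dim * num_images == separate[0].feature_dim * num_images
```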
@@ -239,16 +244,32 @@ class DiffusionModel(nn.Module):
         """Encode image features and concatenate them all together along with the state vector."""
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
         global_cond_feats = [batch["observation.state"]]
-        # Extract image feature (first combine batch, sequence, and camera index dims).
+        # Extract image features.
         if self._use_images:
-            img_features = self.rgb_encoder(
-                einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
-            )
-            # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
-            # feature dim (effectively concatenating the camera features).
-            img_features = einops.rearrange(
-                img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
-            )
+            if self.config.use_separate_rgb_encoder_per_camera:
+                # Combine batch and sequence dims while rearranging to make the camera index dimension first.
+                images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
+                img_features_list = torch.cat(
+                    [
+                        encoder(images)
+                        for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
+                    ]
+                )
+                # Separate batch and sequence dims back out. The camera index dim gets absorbed into the
+                # feature dim (effectively concatenating the camera features).
+                img_features = einops.rearrange(
+                    img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
+                )
+            else:
+                # Combine batch, sequence, and "which camera" dims before passing to shared encoder.
+                img_features = self.rgb_encoder(
+                    einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
+                )
+                # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
+                # feature dim (effectively concatenating the camera features).
+                img_features = einops.rearrange(
+                    img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
+                )
         global_cond_feats.append(img_features)

         if self._use_env_state:
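The two einops patterns in this hunk are the subtle part: the camera dim is moved to the front so each encoder sees only its own view, the per-camera outputs are stacked along dim 0 by torch.cat, and the final rearrange folds the camera dim into the feature dim. A small shape sanity check of that plumbing; `encoders` here are hypothetical flatten-and-project stand-ins, and `zip(..., strict=True)` additionally guards against a mismatch between the number of encoders and the camera dimension.

```python
import einops
import torch

b, s, n = 2, 3, 4  # batch, observation steps, cameras
c, h, w, feature_dim = 3, 8, 8, 16
images = torch.randn(b, s, n, c, h, w)

# Hypothetical encoders: any modules mapping (batch, c, h, w) -> (batch, feature_dim).
encoders = [
    torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(c * h * w, feature_dim))
    for _ in range(n)
]

# Separate-encoder path: put the camera dim first so each encoder sees only its own view.
images_per_camera = einops.rearrange(images, "b s n ... -> n (b s) ...")
img_features_list = torch.cat(
    [encoder(imgs) for encoder, imgs in zip(encoders, images_per_camera, strict=True)]
)  # (n * b * s, feature_dim): torch.cat stacks along dim 0
img_features = einops.rearrange(
    img_features_list, "(n b s) ... -> b s (n ...)", b=b, s=s
)

# The camera dim is absorbed into the feature dim, matching the shared-encoder layout.
assert img_features.shape == (b, s, n * feature_dim)
```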