Merge remote-tracking branch 'origin/main' into user/aliberts/2024_05_07_remove_upper_bounds

This commit is contained in:
Simon Alibert 2024-05-08 19:09:34 +02:00
commit 974e6cf93b
3 changed files with 31 additions and 6 deletions

View File

@ -51,6 +51,7 @@ class DiffusionConfig:
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
Bias modulation is used be default, while this parameter indicates whether to also use scale
modulation.
noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
beta_start: Beta value for the first forward-diffusion step.
@ -110,6 +111,7 @@ class DiffusionConfig:
diffusion_step_embed_dim: int = 128
use_film_scale_modulation: bool = True
# Noise scheduler.
noise_scheduler_type: str = "DDPM"
num_train_timesteps: int = 100
beta_schedule: str = "squaredcos_cap_v2"
beta_start: float = 0.0001
@ -144,3 +146,9 @@ class DiffusionConfig:
raise ValueError(
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
)
supported_noise_schedulers = ["DDPM", "DDIM"]
if self.noise_scheduler_type not in supported_noise_schedulers:
raise ValueError(
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
f"Got {self.noise_scheduler_type}."
)

View File

@ -13,6 +13,7 @@ import einops
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from huggingface_hub import PyTorchModelHubMixin
from robomimic.models.base_nets import SpatialSoftmax
@ -126,6 +127,19 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
return {"loss": loss}
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
"""
Factory for noise scheduler instances of the requested type. All kwargs are passed
to the scheduler.
"""
if name == "DDPM":
return DDPMScheduler(**kwargs)
elif name == "DDIM":
return DDIMScheduler(**kwargs)
else:
raise ValueError(f"Unsupported noise scheduler type {name}")
class DiffusionModel(nn.Module):
def __init__(self, config: DiffusionConfig):
super().__init__()
@ -138,12 +152,12 @@ class DiffusionModel(nn.Module):
* config.n_obs_steps,
)
self.noise_scheduler = DDPMScheduler(
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
num_train_timesteps=config.num_train_timesteps,
beta_start=config.beta_start,
beta_end=config.beta_end,
beta_schedule=config.beta_schedule,
variance_type="fixed_small",
clip_sample=config.clip_sample,
clip_sample_range=config.clip_sample_range,
prediction_type=config.prediction_type,
@ -315,11 +329,13 @@ class DiffusionRgbEncoder(nn.Module):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.input_shapes` and it should use the
# height and width from `config.crop_shape`.
dummy_input = torch.zeros(size=(1, config.input_shapes["observation.image"][0], *config.crop_shape))
with torch.inference_mode():
feat_map_shape = tuple(
self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:]
)
self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
dummy_feature_map = self.backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:])
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU()

View File

@ -85,6 +85,7 @@ policy:
diffusion_step_embed_dim: 128
use_film_scale_modulation: True
# Noise scheduler.
noise_scheduler_type: DDPM
num_train_timesteps: 100
beta_schedule: squaredcos_cap_v2
beta_start: 0.0001