Merge remote-tracking branch 'upstream/main'

This commit is contained in:
Alexander Soare 2024-05-08 18:10:07 +01:00
commit 3883eba708
4 changed files with 45 additions and 11 deletions

View File

@ -51,6 +51,7 @@ class DiffusionConfig:
use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning. use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
Bias modulation is used by default, while this parameter indicates whether to also use scale Bias modulation is used by default, while this parameter indicates whether to also use scale
modulation. modulation.
noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"].
num_train_timesteps: Number of diffusion steps for the forward diffusion schedule. num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers. beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
beta_start: Beta value for the first forward-diffusion step. beta_start: Beta value for the first forward-diffusion step.
@ -110,6 +111,7 @@ class DiffusionConfig:
diffusion_step_embed_dim: int = 128 diffusion_step_embed_dim: int = 128
use_film_scale_modulation: bool = True use_film_scale_modulation: bool = True
# Noise scheduler. # Noise scheduler.
noise_scheduler_type: str = "DDPM"
num_train_timesteps: int = 100 num_train_timesteps: int = 100
beta_schedule: str = "squaredcos_cap_v2" beta_schedule: str = "squaredcos_cap_v2"
beta_start: float = 0.0001 beta_start: float = 0.0001
@ -144,3 +146,9 @@ class DiffusionConfig:
raise ValueError( raise ValueError(
f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}." f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
) )
supported_noise_schedulers = ["DDPM", "DDIM"]
if self.noise_scheduler_type not in supported_noise_schedulers:
raise ValueError(
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
f"Got {self.noise_scheduler_type}."
)

View File

@ -13,6 +13,7 @@ import einops
import torch import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
import torchvision import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from huggingface_hub import PyTorchModelHubMixin from huggingface_hub import PyTorchModelHubMixin
from robomimic.models.base_nets import SpatialSoftmax from robomimic.models.base_nets import SpatialSoftmax
@ -126,6 +127,19 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
return {"loss": loss} return {"loss": loss}
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
"""
Factory for noise scheduler instances of the requested type. All kwargs are passed
to the scheduler.
"""
if name == "DDPM":
return DDPMScheduler(**kwargs)
elif name == "DDIM":
return DDIMScheduler(**kwargs)
else:
raise ValueError(f"Unsupported noise scheduler type {name}")
class DiffusionModel(nn.Module): class DiffusionModel(nn.Module):
def __init__(self, config: DiffusionConfig): def __init__(self, config: DiffusionConfig):
super().__init__() super().__init__()
@ -138,12 +152,12 @@ class DiffusionModel(nn.Module):
* config.n_obs_steps, * config.n_obs_steps,
) )
self.noise_scheduler = DDPMScheduler( self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
num_train_timesteps=config.num_train_timesteps, num_train_timesteps=config.num_train_timesteps,
beta_start=config.beta_start, beta_start=config.beta_start,
beta_end=config.beta_end, beta_end=config.beta_end,
beta_schedule=config.beta_schedule, beta_schedule=config.beta_schedule,
variance_type="fixed_small",
clip_sample=config.clip_sample, clip_sample=config.clip_sample,
clip_sample_range=config.clip_sample_range, clip_sample_range=config.clip_sample_range,
prediction_type=config.prediction_type, prediction_type=config.prediction_type,
@ -315,11 +329,13 @@ class DiffusionRgbEncoder(nn.Module):
# Set up pooling and final layers. # Set up pooling and final layers.
# Use a dry run to get the feature map shape. # Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.input_shapes` and it should use the
# height and width from `config.crop_shape`.
dummy_input = torch.zeros(size=(1, config.input_shapes["observation.image"][0], *config.crop_shape))
with torch.inference_mode(): with torch.inference_mode():
feat_map_shape = tuple( dummy_feature_map = self.backbone(dummy_input)
self.backbone(torch.zeros(size=(1, *config.input_shapes["observation.image"]))).shape[1:] feature_map_shape = tuple(dummy_feature_map.shape[1:])
) self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.pool = SpatialSoftmax(feat_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2 self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim) self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
self.relu = nn.ReLU() self.relu = nn.ReLU()

View File

@ -1,4 +1,5 @@
import inspect import inspect
import logging
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
@ -8,9 +9,10 @@ from lerobot.common.utils.utils import get_safe_torch_device
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg): def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
expected_kwargs = set(inspect.signature(policy_cfg_class).parameters) expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
assert set(hydra_cfg.policy).issuperset( if not set(hydra_cfg.policy).issuperset(expected_kwargs):
expected_kwargs logging.warning(
), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}" f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
)
policy_cfg = policy_cfg_class( policy_cfg = policy_cfg_class(
**{ **{
k: v k: v
@ -62,11 +64,18 @@ def make_policy(
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name) policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
if pretrained_policy_name_or_path is None: if pretrained_policy_name_or_path is None:
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg) # Make a fresh policy.
policy = policy_cls(policy_cfg, dataset_stats) policy = policy_cls(policy_cfg, dataset_stats)
else: else:
policy = policy_cls.from_pretrained(pretrained_policy_name_or_path) # Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with pretrained
# weights which are then loaded into a fresh policy with the desired config. This PR in huggingface_hub should
# make it possible to avoid the hack: https://github.com/huggingface/huggingface_hub/pull/2274.
policy = policy_cls(policy_cfg)
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
policy.to(get_safe_torch_device(hydra_cfg.device)) policy.to(get_safe_torch_device(hydra_cfg.device))

View File

@ -85,6 +85,7 @@ policy:
diffusion_step_embed_dim: 128 diffusion_step_embed_dim: 128
use_film_scale_modulation: True use_film_scale_modulation: True
# Noise scheduler. # Noise scheduler.
noise_scheduler_type: DDPM
num_train_timesteps: 100 num_train_timesteps: 100
beta_schedule: squaredcos_cap_v2 beta_schedule: squaredcos_cap_v2
beta_start: 0.0001 beta_start: 0.0001