diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index 64804d8f..012efddd 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -11,54 +11,54 @@ import torch
 from omegaconf import OmegaConf
 
 from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.utils import cycle
+from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 from lerobot.common.utils import init_hydra_config
 
 output_directory = Path("outputs/train/example_pusht_diffusion")
 os.makedirs(output_directory, exist_ok=True)
 
-overrides = [
-    "env=pusht",
-    "policy=diffusion",
-    # Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
-    "offline_steps=5000",
-    "log_freq=250",
-    "device=cuda",
-]
-
-cfg = init_hydra_config("lerobot/configs/default.yaml", overrides)
-
-policy = DiffusionPolicy(
-    cfg=cfg.policy,
-    cfg_device=cfg.device,
-    cfg_noise_scheduler=cfg.noise_scheduler,
-    cfg_optimizer=cfg.optimizer,
-    cfg_ema=cfg.ema,
-    **cfg.policy,
-)
-policy.train()
+# Number of offline training steps (we'll only do offline training for this example).
+# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
+training_steps = 5000
+device = torch.device("cuda")
+log_freq = 250
 
+# Set up the dataset.
+cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"])
 dataset = make_dataset(cfg)
 
-# create dataloader for offline training
+# Set up the policy.
+# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
+# For this example, no arguments need to be passed because the defaults are set up for PushT.
+# If you're doing something different, you will likely need to change at least some of the defaults.
+cfg = DiffusionConfig()
+# TODO(alexander-soare): Remove LR scheduler from the policy.
+policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps)
+policy.train()
+policy.to(device)
+
+# Create dataloader for offline training.
 dataloader = torch.utils.data.DataLoader(
     dataset,
     num_workers=4,
-    batch_size=cfg.policy.batch_size,
+    batch_size=cfg.batch_size,
     shuffle=True,
-    pin_memory=cfg.device != "cpu",
+    pin_memory=device != torch.device("cpu"),
     drop_last=True,
 )
 
-for step, batch in enumerate(dataloader):
-    info = policy(batch, step)
-
-    if step % cfg.log_freq == 0:
-        num_samples = (step + 1) * cfg.policy.batch_size
+# Run training loop.
+dataloader = cycle(dataloader)
+for step in range(training_steps):
+    batch = {k: v.to(device, non_blocking=True) for k, v in next(dataloader).items()}
+    info = policy(batch)
+    if step % log_freq == 0:
+        num_samples = (step + 1) * cfg.batch_size
         loss = info["loss"]
         update_s = info["update_s"]
-        print(f"step:{step} samples:{num_samples} loss:{loss:.3f} update_time:{update_s:.3f}(seconds)")
-
+        print(f"step: {step} samples: {num_samples} loss: {loss:.3f} update_time: {update_s:.3f} (seconds)")
 
 # Save the policy, configuration, and normalization stats for later use.
 policy.save(output_directory / "model.pt")
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index e67d8a04..34430ff1 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -208,6 +208,10 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None):
 
 
 def cycle(iterable):
+    """The equivalent of itertools.cycle, but safe for PyTorch dataloaders.
+
+    See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe.
+    """
     iterator = iter(iterable)
     while True:
         try:
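
For reference, a quick sketch of how the new helper behaves, using a plain list where the training script
passes a torch DataLoader (the restart-on-StopIteration behavior is inferred from the docstring above):

    from lerobot.common.datasets.utils import cycle

    # `cycle` recreates the iterator once it is exhausted, rather than replaying cached
    # items the way itertools.cycle does (which is what trips up dataloader workers).
    dataloader = cycle([1, 2, 3])  # a plain list stands in for a DataLoader here
    print([next(dataloader) for _ in range(5)])  # [1, 2, 3, 1, 2]
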
diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py
index 72d35eb3..1b438f2d 100644
--- a/lerobot/common/policies/act/configuration_act.py
+++ b/lerobot/common/policies/act/configuration_act.py
@@ -26,8 +26,8 @@ class ActionChunkingTransformerConfig:
         image_normalization_std: Value by which to divide the input image pixels (after the mean has been
             subtracted).
         vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
-        use_pretrained_backbone: Whether the backbone should be initialized with ImageNet, pretrained weights
-            from torchvision.
+        use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
+            torchvision.
         replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated
             convolution.
         pre_norm: Whether to use "pre-norm" in the transformer blocks.
diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py
index 272c80ea..d8820a0b 100644
--- a/lerobot/common/policies/diffusion/configuration_diffusion.py
+++ b/lerobot/common/policies/diffusion/configuration_diffusion.py
@@ -13,9 +13,49 @@ class DiffusionConfig:
     Args:
         state_dim: Dimensionality of the observation state space (excluding images).
         action_dim: Dimensionality of the action space.
+        image_size: (H, W) size of the input images.
         n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
             current step and additional steps going back).
-        horizon: Diffusion model action prediction horizon as detailed in the main policy documentation.
+        horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
+        n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
+            See `DiffusionPolicy.select_action` for more details.
+        image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
+            [0, 1]) for normalization.
+        image_normalization_std: Value by which to divide the input image pixels (after the mean has been
+            subtracted).
+        vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
+        crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
+            within the image size. If None, no cropping is done.
+        crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
+            mode).
+        use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
+            torchvision.
+        use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
+            The group sizes are set to be about 16 (to be precise, feature_dim // 16).
+        spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
+        down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
+            You may provide a variable number of dimensions, therefore also controlling the degree of
+            downsampling.
+        kernel_size: The convolutional kernel size of the diffusion modeling Unet.
+        n_groups: Number of groups used in the group norm of the Unet's convolutional blocks.
+        diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear
+            network. This is the output dimension of that network, i.e., the embedding dimension.
+        use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning.
+            Bias modulation is used by default, while this parameter indicates whether to also use scale
+            modulation.
+        num_train_timesteps: Number of diffusion steps for the forward diffusion schedule.
+        beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers.
+        beta_start: Beta value for the first forward-diffusion step.
+        beta_end: Beta value for the last forward-diffusion step.
+        prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from "epsilon"
+            or "sample". These have equivalent outcomes from a latent variable modeling perspective, but
+            "epsilon" has been shown to work better in many deep neural network settings.
+        clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each
+            denoising step at inference time. WARNING: you will need to make sure your action space is
+            normalized to fit within this range.
+        clip_sample_range: The magnitude of the clipping range as described above.
+        num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly
+            spaced). If not provided, this defaults to be the same as `num_train_timesteps`.
     """
 
     # Environment.
@@ -36,7 +76,7 @@ class DiffusionConfig:
     # Architecture / modeling.
     # Vision backbone.
     vision_backbone: str = "resnet18"
-    crop_shape: tuple[int, int] = (84, 84)
+    crop_shape: tuple[int, int] | None = (84, 84)
     crop_is_random: bool = True
     use_pretrained_backbone: bool = False
     use_group_norm: bool = True
@@ -46,18 +86,18 @@ class DiffusionConfig:
     kernel_size: int = 5
     n_groups: int = 8
     diffusion_step_embed_dim: int = 128
-    film_scale_modulation: bool = True
+    use_film_scale_modulation: bool = True
     # Noise scheduler.
     num_train_timesteps: int = 100
     beta_schedule: str = "squaredcos_cap_v2"
     beta_start: float = 0.0001
     beta_end: float = 0.02
-    variance_type: str = "fixed_small"
     prediction_type: str = "epsilon"
-    clip_sample: True
+    clip_sample: bool = True
+    clip_sample_range: float = 1.0
     # Inference
-    num_inference_steps: int = 100
+    num_inference_steps: int | None = None
 
     # ---
     # TODO(alexander-soare): Remove these from the policy config.
@@ -72,12 +112,24 @@ class DiffusionConfig:
     utd: int = 1
     use_ema: bool = True
     ema_update_after_step: int = 0
-    ema_min_rate: float = 0.0
-    ema_max_rate: float = 0.9999
+    ema_min_alpha: float = 0.0
+    ema_max_alpha: float = 0.9999
     ema_inv_gamma: float = 1.0
     ema_power: float = 0.75
 
     def __post_init__(self):
         """Input validation (not exhaustive)."""
         if not self.vision_backbone.startswith("resnet"):
-            raise ValueError("`vision_backbone` must be one of the ResNet variants.")
+            raise ValueError(
+                f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
+            )
+        if self.crop_shape[0] > self.image_size[0] or self.crop_shape[1] > self.image_size[1]:
+            raise ValueError(
+                f"`crop_shape` should fit within `image_size`. Got {self.crop_shape} for `crop_shape` and "
+                f"{self.image_size} for `image_size`."
+            )
+        supported_prediction_types = ["epsilon", "sample"]
+        if self.prediction_type not in supported_prediction_types:
+            raise ValueError(
+                f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}."
+            )
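
Taken together with the docstring, construction and validation now look like this (a minimal sketch; the
error message is copied from `__post_init__` above):

    from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig

    cfg = DiffusionConfig()  # defaults are set up for PushT
    try:
        DiffusionConfig(prediction_type="velocity")
    except ValueError as e:
        print(e)  # `prediction_type` must be one of ['epsilon', 'sample']. Got velocity.
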
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 4853dbcf..a95effb2 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -1,8 +1,10 @@
 """
 TODO(alexander-soare):
   - Remove reliance on Robomimic for SpatialSoftmax.
-  - Remove reliance on diffusers for DDPMScheduler.
+  - Remove reliance on diffusers for DDPMScheduler and LR scheduler.
   - Move EMA out of policy.
+  - Consolidate _DiffusionUnetImagePolicy into DiffusionPolicy.
+  - One more pass on comments and documentation.
 """
 
 import copy
@@ -10,10 +12,10 @@ import logging
 import math
 import time
 from collections import deque
+from itertools import chain
 from typing import Callable
 
 import einops
-import hydra
 import torch
 import torch.nn.functional as F  # noqa: N812
 import torchvision
@@ -23,12 +25,12 @@ from robomimic.models.base_nets import SpatialSoftmax
 from torch import Tensor, nn
 from torch.nn.modules.batchnorm import _BatchNorm
 
+from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.utils import (
     get_device_from_parameters,
     get_dtype_from_parameters,
     populate_queues,
 )
-from lerobot.common.utils import get_safe_torch_device
 
 logger = logging.getLogger(__name__)
 
@@ -41,69 +43,29 @@ class DiffusionPolicy(nn.Module):
 
     name = "diffusion"
 
-    def __init__(
-        self,
-        cfg,
-        cfg_device,
-        cfg_noise_scheduler,
-        cfg_optimizer,
-        cfg_ema,
-        shape_meta: dict,
-        horizon,
-        n_action_steps,
-        n_obs_steps,
-        num_inference_steps=None,
-        diffusion_step_embed_dim=256,
-        down_dims=(256, 512, 1024),
-        kernel_size=5,
-        n_groups=8,
-        film_scale_modulation=True,
-        **_,
-    ):
+    def __init__(self, cfg: DiffusionConfig, lr_scheduler_num_training_steps: int):
         super().__init__()
         self.cfg = cfg
-        self.n_obs_steps = n_obs_steps
-        self.n_action_steps = n_action_steps
 
         # queues are populated during rollout of the policy, they contain the n latest observations and actions
         self._queues = None
 
-        noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler)
-
-        self.diffusion = _DiffusionUnetImagePolicy(
-            cfg,
-            shape_meta=shape_meta,
-            noise_scheduler=noise_scheduler,
-            horizon=horizon,
-            n_action_steps=n_action_steps,
-            n_obs_steps=n_obs_steps,
-            num_inference_steps=num_inference_steps,
-            diffusion_step_embed_dim=diffusion_step_embed_dim,
-            down_dims=down_dims,
-            kernel_size=kernel_size,
-            n_groups=n_groups,
-            film_scale_modulation=film_scale_modulation,
-        )
-
-        self.device = get_safe_torch_device(cfg_device)
-        self.diffusion.to(self.device)
+        self.diffusion = _DiffusionUnetImagePolicy(cfg)
 
         # TODO(alexander-soare): This should probably be managed outside of the policy class.
         self.ema_diffusion = None
         self.ema = None
         if self.cfg.use_ema:
             self.ema_diffusion = copy.deepcopy(self.diffusion)
-            self.ema = hydra.utils.instantiate(
-                cfg_ema,
-                model=self.ema_diffusion,
-            )
+            self.ema = _EMA(cfg, model=self.ema_diffusion)
 
-        self.optimizer = hydra.utils.instantiate(
-            cfg_optimizer,
-            params=self.diffusion.parameters(),
+        # TODO(alexander-soare): Move optimizer out of policy.
+        self.optimizer = torch.optim.Adam(
+            self.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
         )
 
-        # TODO(rcadene): modify lr scheduler so that it doesnt depend on epochs but steps
+        # TODO(alexander-soare): Move LR scheduler out of policy.
+        # TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
         self.global_step = 0
 
         # configure lr scheduler
@@ -111,7 +73,7 @@ class DiffusionPolicy(nn.Module):
             cfg.lr_scheduler,
             optimizer=self.optimizer,
             num_warmup_steps=cfg.lr_warmup_steps,
-            num_training_steps=cfg.offline_steps,
+            num_training_steps=lr_scheduler_num_training_steps,
             # pytorch assumes stepping LRScheduler every epoch
             # however huggingface diffusers steps it every batch
             last_epoch=self.global_step - 1,
@@ -122,9 +84,9 @@
         Clear observation and action queues. Should be called on `env.reset()`
         """
         self._queues = {
-            "observation.image": deque(maxlen=self.n_obs_steps),
-            "observation.state": deque(maxlen=self.n_obs_steps),
-            "action": deque(maxlen=self.n_action_steps),
+            "observation.image": deque(maxlen=self.cfg.n_obs_steps),
+            "observation.state": deque(maxlen=self.cfg.n_obs_steps),
+            "action": deque(maxlen=self.cfg.n_action_steps),
         }
 
     @torch.no_grad
@@ -138,11 +100,13 @@
         - The diffusion model generates `horizon` steps worth of actions.
         - `n_action_steps` worth of actions are actually kept for execution, starting from the current step.
         Schematically this looks like:
+            ----------------------------------------------------------------------------------------------
             (legend: o = n_obs_steps, h = horizon, a = n_action_steps)
             |timestep            | n-o+1 | n-o+2 | ..... | n     | ..... | n+a-1 | n+a   | ..... |n-o+1+h|
-            |observation is used | YES   | YES   | ..... | NO    | NO    | NO    | NO    | NO    | NO    |
+            |observation is used | YES   | YES   | YES   | NO    | NO    | NO    | NO    | NO    | NO    |
             |action is generated | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   | YES   |
             |action is used      | NO    | NO    | NO    | YES   | YES   | YES   | NO    | NO    | NO    |
+            ----------------------------------------------------------------------------------------------
         Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that
         "horizon" may not be the best name to describe what the variable actually means, because this period
         is actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
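
To make the schematic concrete, here is the window arithmetic with illustrative values (the slicing mirrors
`generate_actions` further down):

    n_obs_steps, horizon, n_action_steps = 2, 16, 8
    assert n_action_steps <= horizon - n_obs_steps + 1

    # The current step sits at index n_obs_steps - 1 within the generated horizon.
    start = n_obs_steps - 1
    end = start + n_action_steps
    print(start, end)  # 1 9 -> of the 16 generated actions, actions[:, 1:9] are executed
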
@@ -213,57 +177,41 @@ class _DiffusionUnetImagePolicy(nn.Module):
-    def __init__(
-        self,
-        cfg,
-        shape_meta: dict,
-        noise_scheduler: DDPMScheduler,
-        horizon,
-        n_action_steps,
-        n_obs_steps,
-        num_inference_steps=None,
-        diffusion_step_embed_dim=256,
-        down_dims=(256, 512, 1024),
-        kernel_size=5,
-        n_groups=8,
-        film_scale_modulation=True,
-    ):
+    def __init__(self, cfg: DiffusionConfig):
         super().__init__()
 
-        action_shape = shape_meta["action"]["shape"]
-        assert len(action_shape) == 1
-        action_dim = action_shape[0]
-
-        self.rgb_encoder = _RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg.rgb_encoder)
+        self.cfg = cfg
 
+        self.rgb_encoder = _RgbEncoder(cfg)
         self.unet = _ConditionalUnet1D(
-            input_dim=action_dim,
-            global_cond_dim=(action_dim + self.rgb_encoder.feature_dim) * n_obs_steps,
-            diffusion_step_embed_dim=diffusion_step_embed_dim,
-            down_dims=down_dims,
-            kernel_size=kernel_size,
-            n_groups=n_groups,
-            film_scale_modulation=film_scale_modulation,
+            cfg, global_cond_dim=(cfg.action_dim + self.rgb_encoder.feature_dim) * cfg.n_obs_steps
         )
-
-        self.noise_scheduler = noise_scheduler
-        self.horizon = horizon
-        self.action_dim = action_dim
-        self.n_action_steps = n_action_steps
-        self.n_obs_steps = n_obs_steps
+        self.noise_scheduler = DDPMScheduler(
+            num_train_timesteps=cfg.num_train_timesteps,
+            beta_start=cfg.beta_start,
+            beta_end=cfg.beta_end,
+            beta_schedule=cfg.beta_schedule,
+            variance_type="fixed_small",
+            clip_sample=cfg.clip_sample,
+            clip_sample_range=cfg.clip_sample_range,
+            prediction_type=cfg.prediction_type,
+        )
 
-        if num_inference_steps is None:
-            num_inference_steps = noise_scheduler.config.num_train_timesteps
-
-        self.num_inference_steps = num_inference_steps
+        if cfg.num_inference_steps is None:
+            self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
+        else:
+            self.num_inference_steps = cfg.num_inference_steps
 
     # ========= inference  ============
-    def conditional_sample(self, batch_size, global_cond=None, generator=None):
+    def conditional_sample(
+        self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
+    ) -> Tensor:
         device = get_device_from_parameters(self)
         dtype = get_dtype_from_parameters(self)
 
         # Sample prior.
         sample = torch.randn(
-            size=(batch_size, self.horizon, self.action_dim),
+            size=(batch_size, self.cfg.horizon, self.cfg.action_dim),
             dtype=dtype,
             device=device,
             generator=generator,
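
For reference, the construction above amounts to the following with the config defaults (a sketch assuming
the Hugging Face diffusers package is installed):

    from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

    noise_scheduler = DDPMScheduler(
        num_train_timesteps=100,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="squaredcos_cap_v2",
        variance_type="fixed_small",  # hard-coded now that the config field is gone
        clip_sample=True,
        clip_sample_range=1.0,
        prediction_type="epsilon",
    )
    # With cfg.num_inference_steps=None, inference falls back to the training step count.
    print(noise_scheduler.config.num_train_timesteps)  # 100
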
@@ -283,7 +231,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
 
         return sample
 
-    def generate_actions(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+    def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
         """
         This function expects `batch` to have (at least):
         {
@@ -293,8 +241,7 @@ class _DiffusionUnetImagePolicy(nn.Module):
         """
         assert set(batch).issuperset({"observation.state", "observation.image"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
-        assert n_obs_steps == self.n_obs_steps
-        assert self.n_obs_steps == n_obs_steps
+        assert n_obs_steps == self.cfg.n_obs_steps
 
         # Extract image feature (first combine batch and sequence dims).
         img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
@@ -307,13 +254,13 @@ class _DiffusionUnetImagePolicy(nn.Module):
 
         sample = self.conditional_sample(batch_size, global_cond=global_cond)
 
         # `horizon` steps worth of actions (from the first observation).
-        action = sample[..., : self.action_dim]
+        actions = sample[..., : self.cfg.action_dim]
         # Extract `n_action_steps` steps worth of actions (from the current observation).
         start = n_obs_steps - 1
-        end = start + self.n_action_steps
-        action = action[:, start:end]
+        end = start + self.cfg.n_action_steps
+        actions = actions[:, start:end]
 
-        return action
+        return actions
 
     def compute_loss(self, batch: dict[str, Tensor]) -> Tensor:
         """
@@ -329,9 +276,8 @@ class _DiffusionUnetImagePolicy(nn.Module):
         """
         assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
         horizon = batch["action"].shape[1]
-        assert horizon == self.horizon
-        assert n_obs_steps == self.n_obs_steps
-        assert self.n_obs_steps == n_obs_steps
+        assert horizon == self.cfg.horizon
+        assert n_obs_steps == self.cfg.n_obs_steps
 
         # Extract image feature (first combine batch and sequence dims).
         img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
@@ -359,14 +305,13 @@ class _DiffusionUnetImagePolicy(nn.Module):
         pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond)
 
         # Compute the loss.
-        # The targe is either the original trajectory, or the noise.
-        pred_type = self.noise_scheduler.config.prediction_type
-        if pred_type == "epsilon":
+        # The target is either the original trajectory, or the noise.
+        if self.cfg.prediction_type == "epsilon":
             target = eps
-        elif pred_type == "sample":
+        elif self.cfg.prediction_type == "sample":
             target = batch["action"]
         else:
-            raise ValueError(f"Unsupported prediction type {pred_type}")
+            raise ValueError(f"Unsupported prediction type {self.cfg.prediction_type}")
 
         loss = F.mse_loss(pred, target, reduction="none")
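
A minimal sketch of the forward-noising step that `compute_loss` builds on (shapes are illustrative, and it
assumes the diffusers package):

    import torch
    from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

    noise_scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2")
    trajectory = torch.randn(8, 16, 2)  # (batch, horizon, action_dim)
    eps = torch.randn_like(trajectory)
    timesteps = torch.randint(0, 100, (8,))
    noisy_trajectory = noise_scheduler.add_noise(trajectory, eps, timesteps)
    # With prediction_type="epsilon" the Unet is trained to recover `eps` from this;
    # with "sample" it is trained to recover `trajectory` directly.
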
@@ -384,64 +329,35 @@ class _RgbEncoder(nn.Module):
     Includes the ability to normalize and crop the image first.
     """
 
-    def __init__(
-        self,
-        input_shape: tuple[int, int, int],
-        norm_mean_std: tuple[float, float] = [1.0, 1.0],
-        crop_shape: tuple[int, int] | None = None,
-        random_crop: bool = False,
-        backbone_name: str = "resnet18",
-        pretrained_backbone: bool = False,
-        use_group_norm: bool = False,
-        num_keypoints: int = 32,
-    ):
-        """
-        Args:
-            input_shape: channel-first input shape (C, H, W)
-            norm_mean_std: mean and standard deviation used for image normalization. Images are normalized as
-                (image - mean) / std.
-            crop_shape: (H, W) shape to crop to (must fit within the input shape). If not provided, no
-                cropping is done.
-            random_crop: Whether the crop should be random at training time (it's always a center crop in
-                eval mode).
-            backbone_name: The name of one of the available resnet models from torchvision (eg resnet18).
-            pretrained_backbone: whether to use timm pretrained weights.
-            use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
-                The group sizes are set to be about 16 (to be precise, feature_dim // 16).
-            num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image).
-        """
+    def __init__(self, cfg: DiffusionConfig):
         super().__init__()
-        if input_shape[0] != 3:
-            raise ValueError("Only RGB images are handled")
-        if not backbone_name.startswith("resnet"):
-            raise ValueError(
-                "Only resnet is supported for now (because of the assumption that 'layer4' is the output layer)"
-            )
-
         # Set up optional preprocessing.
-        if norm_mean_std == [1.0, 1.0]:
+        if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)):
             self.normalizer = nn.Identity()
         else:
-            self.normalizer = torchvision.transforms.Normalize(mean=norm_mean_std[0], std=norm_mean_std[1])
-
-        if crop_shape is not None:
+            self.normalizer = torchvision.transforms.Normalize(
+                mean=cfg.image_normalization_mean, std=cfg.image_normalization_std
+            )
+        if cfg.crop_shape is not None:
             self.do_crop = True
             # Always use center crop for eval
-            self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
-            if random_crop:
-                self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
+            self.center_crop = torchvision.transforms.CenterCrop(cfg.crop_shape)
+            if cfg.crop_is_random:
+                self.maybe_random_crop = torchvision.transforms.RandomCrop(cfg.crop_shape)
             else:
                 self.maybe_random_crop = self.center_crop
         else:
             self.do_crop = False
 
         # Set up backbone.
-        backbone_model = getattr(torchvision.models, backbone_name)(pretrained=pretrained_backbone)
+        backbone_model = getattr(torchvision.models, cfg.vision_backbone)(
+            pretrained=cfg.use_pretrained_backbone
+        )
         # Note: This assumes that the layer4 feature map is children()[-3]
         # TODO(alexander-soare): Use a safer alternative.
         self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2]))
-        if use_group_norm:
-            if pretrained_backbone:
+        if cfg.use_group_norm:
+            if cfg.use_pretrained_backbone:
                 raise ValueError(
                     "You can't replace BatchNorm in a pretrained model without ruining the weights!"
                 )
@@ -454,10 +370,10 @@ class _RgbEncoder(nn.Module):
         # Set up pooling and final layers.
         # Use a dry run to get the feature map shape.
         with torch.inference_mode():
-            feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:])
-        self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints)
-        self.feature_dim = num_keypoints * 2
-        self.out = nn.Linear(num_keypoints * 2, self.feature_dim)
+            feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, 3, *cfg.image_size))).shape[1:])
+        self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints)
+        self.feature_dim = cfg.spatial_softmax_num_keypoints * 2
+        self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim)
         self.relu = nn.ReLU()
 
     def forward(self, x: Tensor) -> Tensor:
@@ -516,16 +432,18 @@ def _replace_submodules(
 
 
 class _SinusoidalPosEmb(nn.Module):
-    def __init__(self, dim):
+    """1D sinusoidal positional embeddings as in Attention is All You Need."""
+
+    def __init__(self, dim: int):
         super().__init__()
         self.dim = dim
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         device = x.device
         half_dim = self.dim // 2
         emb = math.log(10000) / (half_dim - 1)
         emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
-        emb = x[:, None] * emb[None, :]
+        emb = x.unsqueeze(-1) * emb.unsqueeze(0)
         emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
         return emb
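
The embedding can be sanity-checked standalone; this mirrors the forward pass above for a handful of
timesteps:

    import math

    import torch

    dim = 8
    x = torch.tensor([0.0, 1.0, 50.0])  # e.g. diffusion timesteps
    half_dim = dim // 2
    scale = math.log(10000) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim) * -scale)
    emb = torch.cat(((x.unsqueeze(-1) * freqs).sin(), (x.unsqueeze(-1) * freqs).cos()), dim=-1)
    print(emb.shape)  # torch.Size([3, 8])
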
@@ -549,92 +467,46 @@ class _Conv1dBlock(nn.Module):
 
 
 class _ConditionalUnet1D(nn.Module):
     """A 1D convolutional UNet with FiLM modulation for conditioning.
 
-    Two types of conditioning can be applied:
-    - Global: Conditioning information that is aggregated over the whole observation window. This is
-      incorporated via the FiLM technique in the residual convolution blocks of the Unet's encoder/decoder.
-    - Local: Conditioning information for each timestep in the observation window. This is incorporated
-      by encoding the information via 1D convolutions and adding the resulting embeddings to the inputs and
-      outputs of the Unet's encoder/decoder.
+    Note: this removes local conditioning as compared to the original diffusion policy code.
     """
 
-    def __init__(
-        self,
-        input_dim: int,
-        local_cond_dim: int | None = None,
-        global_cond_dim: int | None = None,
-        diffusion_step_embed_dim: int = 256,
-        down_dims: int | None = None,
-        kernel_size: int = 3,
-        n_groups: int = 8,
-        film_scale_modulation: bool = False,
-    ):
+    def __init__(self, cfg: DiffusionConfig, global_cond_dim: int):
         super().__init__()
-        if down_dims is None:
-            down_dims = [256, 512, 1024]
+        self.cfg = cfg
 
         # Encoder for the diffusion timestep.
         self.diffusion_step_encoder = nn.Sequential(
-            _SinusoidalPosEmb(diffusion_step_embed_dim),
-            nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4),
+            _SinusoidalPosEmb(cfg.diffusion_step_embed_dim),
+            nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4),
             nn.Mish(),
-            nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim),
+            nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim),
         )
 
         # The FiLM conditioning dimension.
-        cond_dim = diffusion_step_embed_dim
-        if global_cond_dim is not None:
-            cond_dim += global_cond_dim
-
-        self.local_cond_down_encoder = None
-        self.local_cond_up_encoder = None
-        if local_cond_dim is not None:
-            # Encoder for the local conditioning. The output gets added to the Unet encoder input.
-            self.local_cond_down_encoder = _ConditionalResidualBlock1D(
-                local_cond_dim,
-                down_dims[0],
-                cond_dim=cond_dim,
-                kernel_size=kernel_size,
-                n_groups=n_groups,
-                film_scale_modulation=film_scale_modulation,
-            )
-            # Encoder for the local conditioning. The output gets added to the Unet encoder output.
-            self.local_cond_up_encoder = _ConditionalResidualBlock1D(
-                local_cond_dim,
-                down_dims[0],
-                cond_dim=cond_dim,
-                kernel_size=kernel_size,
-                n_groups=n_groups,
-                film_scale_modulation=film_scale_modulation,
-            )
+        cond_dim = cfg.diffusion_step_embed_dim + global_cond_dim
 
         # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
         # just reverse these.
-        in_out = [(input_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True))
+        in_out = [(cfg.action_dim, cfg.down_dims[0])] + list(
+            zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True)
+        )
 
         # Unet encoder.
+        common_res_block_kwargs = {
+            "cond_dim": cond_dim,
+            "kernel_size": cfg.kernel_size,
+            "n_groups": cfg.n_groups,
+            "use_film_scale_modulation": cfg.use_film_scale_modulation,
+        }
         self.down_modules = nn.ModuleList([])
         for ind, (dim_in, dim_out) in enumerate(in_out):
             is_last = ind >= (len(in_out) - 1)
             self.down_modules.append(
                 nn.ModuleList(
                     [
-                        _ConditionalResidualBlock1D(
-                            dim_in,
-                            dim_out,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            film_scale_modulation=film_scale_modulation,
-                        ),
-                        _ConditionalResidualBlock1D(
-                            dim_out,
-                            dim_out,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            film_scale_modulation=film_scale_modulation,
-                        ),
+                        _ConditionalResidualBlock1D(dim_in, dim_out, **common_res_block_kwargs),
+                        _ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
                         # Downsample as long as it is not the last block.
                         nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(),
                     ]
                 )
             )
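
A worked example of the channel bookkeeping above (illustrative numbers, not necessarily the config
defaults):

    action_dim = 2
    down_dims = (256, 512, 1024)
    in_out = [(action_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True))
    print(in_out)  # [(2, 256), (256, 512), (512, 1024)] -- one (in, out) pair per encoder stage
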
@@ -644,22 +516,8 @@ class _ConditionalUnet1D(nn.Module):
 
         # Processing in the middle of the auto-encoder.
         self.mid_modules = nn.ModuleList(
             [
-                _ConditionalResidualBlock1D(
-                    down_dims[-1],
-                    down_dims[-1],
-                    cond_dim=cond_dim,
-                    kernel_size=kernel_size,
-                    n_groups=n_groups,
-                    film_scale_modulation=film_scale_modulation,
-                ),
-                _ConditionalResidualBlock1D(
-                    down_dims[-1],
-                    down_dims[-1],
-                    cond_dim=cond_dim,
-                    kernel_size=kernel_size,
-                    n_groups=n_groups,
-                    film_scale_modulation=film_scale_modulation,
-                ),
+                _ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
+                _ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs),
             ]
         )
 
@@ -670,22 +528,9 @@ class _ConditionalUnet1D(nn.Module):
             self.up_modules.append(
                 nn.ModuleList(
                     [
-                        _ConditionalResidualBlock1D(
-                            dim_in * 2,  # x2 as it takes the encoder's skip connection as well
-                            dim_out,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            film_scale_modulation=film_scale_modulation,
-                        ),
-                        _ConditionalResidualBlock1D(
-                            dim_out,
-                            dim_out,
-                            cond_dim=cond_dim,
-                            kernel_size=kernel_size,
-                            n_groups=n_groups,
-                            film_scale_modulation=film_scale_modulation,
-                        ),
+                        # dim_in * 2, because it takes the encoder's skip connection as well
+                        _ConditionalResidualBlock1D(dim_in * 2, dim_out, **common_res_block_kwargs),
+                        _ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs),
                         # Upsample as long as it is not the last block.
                         nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(),
                     ]
@@ -693,29 +538,22 @@ class _ConditionalUnet1D(nn.Module):
             )
 
         self.final_conv = nn.Sequential(
-            _Conv1dBlock(down_dims[0], down_dims[0], kernel_size=kernel_size),
-            nn.Conv1d(down_dims[0], input_dim, 1),
+            _Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size),
+            nn.Conv1d(cfg.down_dims[0], cfg.action_dim, 1),
         )
 
-    def forward(self, x: Tensor, timestep: Tensor | int, local_cond=None, global_cond=None) -> Tensor:
+    def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:
         """
         Args:
             x: (B, T, input_dim) tensor for input to the Unet.
             timestep: (B,) tensor of (timestep_we_are_denoising_from - 1).
-            local_cond: (B, T, local_cond_dim)
             global_cond: (B, global_cond_dim)
             output: (B, T, input_dim)
         Returns:
-            (B, T, input_dim)
+            (B, T, input_dim) diffusion model prediction.
         """
         # For 1D convolutions we'll need feature dimension first.
         x = einops.rearrange(x, "b t d -> b d t")
-        if local_cond is not None:
-            if self.local_cond_down_encoder is None or self.local_cond_up_encoder is None:
-                raise ValueError(
-                    "`local_cond` was provided but the relevant encoders weren't built at initialization."
-                )
-            local_cond = einops.rearrange(local_cond, "b t d -> b d t")
 
         timesteps_embed = self.diffusion_step_encoder(timestep)
 
@@ -725,11 +563,10 @@ class _ConditionalUnet1D(nn.Module):
         else:
             global_feature = timesteps_embed
 
+        # Run encoder, keeping track of skip features to pass to the decoder.
         encoder_skip_features: list[Tensor] = []
-        for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules):
+        for resnet, resnet2, downsample in self.down_modules:
             x = resnet(x, global_feature)
-            if idx == 0 and local_cond is not None:
-                x = x + self.local_cond_down_encoder(local_cond, global_feature)
             x = resnet2(x, global_feature)
             encoder_skip_features.append(x)
             x = downsample(x)
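
The `dim_in * 2` above comes from concatenating the encoder's skip feature onto the channel dimension, as
the decoder loop below does; a quick check:

    import torch

    x = torch.randn(4, 256, 8)     # decoder feature: (batch, channels, time)
    skip = torch.randn(4, 256, 8)  # matching encoder skip feature
    print(torch.cat((x, skip), dim=1).shape)  # torch.Size([4, 512, 8])
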
@@ -737,14 +574,10 @@ class _ConditionalUnet1D(nn.Module):
         for mid_module in self.mid_modules:
             x = mid_module(x, global_feature)
 
-        for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
+        # Run decoder, using the skip features from the encoder.
+        for resnet, resnet2, upsample in self.up_modules:
             x = torch.cat((x, encoder_skip_features.pop()), dim=1)
             x = resnet(x, global_feature)
-            # Note: The condition in the original implementation is:
-            # if idx == len(self.up_modules) and local_cond is not None:
-            # But as they mention in their comments, this is incorrect. We use the correct condition here.
-            if idx == (len(self.up_modules) - 1) and local_cond is not None:
-                x = x + self.local_cond_up_encoder(local_cond, global_feature)
             x = resnet2(x, global_feature)
             x = upsample(x)
 
@@ -766,17 +599,17 @@ class _ConditionalResidualBlock1D(nn.Module):
         n_groups: int = 8,
         # Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning
         # FiLM just modulates bias).
-        film_scale_modulation: bool = False,
+        use_film_scale_modulation: bool = False,
     ):
         super().__init__()
 
-        self.film_scale_modulation = film_scale_modulation
+        self.use_film_scale_modulation = use_film_scale_modulation
         self.out_channels = out_channels
 
         self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups)
 
         # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale.
-        cond_channels = out_channels * 2 if film_scale_modulation else out_channels
+        cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels
         self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels))
 
         self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups)
@@ -798,7 +631,7 @@ class _ConditionalResidualBlock1D(nn.Module):
 
         # Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1).
         cond_embed = self.cond_encoder(cond).unsqueeze(-1)
-        if self.film_scale_modulation:
+        if self.use_film_scale_modulation:
             # Treat the embedding as a list of scales and biases.
             scale = cond_embed[:, : self.out_channels]
             bias = cond_embed[:, self.out_channels :]
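
A minimal FiLM sketch matching the block above, with the conditioning embedding split into per-channel
scale and bias (shapes illustrative):

    import torch

    batch_size, out_channels, time_steps = 4, 6, 10
    out = torch.randn(batch_size, out_channels, time_steps)
    cond_embed = torch.randn(batch_size, out_channels * 2, 1)  # cond encoder output, unsqueezed
    scale = cond_embed[:, :out_channels]
    bias = cond_embed[:, out_channels:]
    out = scale * out + bias  # with use_film_scale_modulation=False, only the bias is added
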
""" self.averaged_model = model self.averaged_model.eval() self.averaged_model.requires_grad_(False) - self.update_after_step = update_after_step - self.inv_gamma = inv_gamma - self.power = power - self.min_value = min_value - self.max_value = max_value + self.update_after_step = cfg.ema_update_after_step + self.inv_gamma = cfg.ema_inv_gamma + self.power = cfg.ema_power + self.min_alpha = cfg.ema_min_alpha + self.max_alpha = cfg.ema_max_alpha self.alpha = 0.0 self.optimization_step = 0 @@ -855,7 +686,7 @@ class _EMA: if step <= 0: return 0.0 - return max(self.min_value, min(value, self.max_value)) + return max(self.min_alpha, min(value, self.max_alpha)) @torch.no_grad() def step(self, new_model): diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 0b185e86..b5b5f861 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -9,7 +9,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg): expected_kwargs = set(inspect.signature(policy_cfg_class).parameters) assert set(hydra_cfg.policy).issuperset( expected_kwargs - ), f"Hydra config is missing arguments: {set(hydra_cfg.policy).difference(expected_kwargs)}" + ), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}" policy_cfg = policy_cfg_class( **{ k: v @@ -35,7 +35,7 @@ def make_policy(hydra_cfg: DictConfig): from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg) - policy = DiffusionPolicy(policy_cfg) + policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps) policy.to(get_safe_torch_device(hydra_cfg.device)) elif hydra_cfg.policy.name == "act": from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index c8bdb0c4..44746dfc 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -44,7 +44,7 @@ policy: # Vision backbone. vision_backbone: resnet18 crop_shape: [84, 84] - random_crop: True + crop_is_random: True use_pretrained_backbone: false use_group_norm: True spatial_softmax_num_keypoints: 32 @@ -53,15 +53,15 @@ policy: kernel_size: 5 n_groups: 8 diffusion_step_embed_dim: 128 - film_scale_modulation: True + use_film_scale_modulation: True # Noise scheduler. 
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 0b185e86..b5b5f861 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -9,7 +9,7 @@ def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
     expected_kwargs = set(inspect.signature(policy_cfg_class).parameters)
     assert set(hydra_cfg.policy).issuperset(
         expected_kwargs
-    ), f"Hydra config is missing arguments: {set(hydra_cfg.policy).difference(expected_kwargs)}"
+    ), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}"
     policy_cfg = policy_cfg_class(
         **{
             k: v
@@ -35,7 +35,7 @@ def make_policy(hydra_cfg: DictConfig):
         from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 
         policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg)
-        policy = DiffusionPolicy(policy_cfg)
+        policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps)
         policy.to(get_safe_torch_device(hydra_cfg.device))
     elif hydra_cfg.policy.name == "act":
         from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index c8bdb0c4..44746dfc 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -44,7 +44,7 @@ policy:
   # Vision backbone.
   vision_backbone: resnet18
   crop_shape: [84, 84]
-  random_crop: True
+  crop_is_random: True
   use_pretrained_backbone: false
   use_group_norm: True
   spatial_softmax_num_keypoints: 32
@@ -53,15 +53,15 @@ policy:
   kernel_size: 5
   n_groups: 8
   diffusion_step_embed_dim: 128
-  film_scale_modulation: True
+  use_film_scale_modulation: True
   # Noise scheduler.
   num_train_timesteps: 100
   beta_schedule: squaredcos_cap_v2
   beta_start: 0.0001
   beta_end: 0.02
-  variance_type: fixed_small
   prediction_type: epsilon # epsilon / sample
   clip_sample: True
+  clip_sample_range: 1.0
 
   # Inference
   num_inference_steps: 100
@@ -79,12 +79,12 @@ policy:
   utd: 1
   use_ema: true
   ema_update_after_step: 0
-  ema_min_rate: 0.0
-  ema_max_rate: 0.9999
+  ema_min_alpha: 0.0
+  ema_max_alpha: 0.9999
   ema_inv_gamma: 1.0
   ema_power: 0.75
 
 delta_timestamps:
-  observation.images: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
+  observation.image: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
   observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]"
   action: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1 - ${n_obs_steps} + ${policy.horizon})]"
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 4263e452..83fdad5e 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -1,8 +1,8 @@
 from pathlib import Path
 
 
-def _find_and_replace(text: str, finds: list[str], replaces: list[str]) -> str:
-    for f, r in zip(finds, replaces):
+def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
+    for f, r in finds_and_replaces:
         assert f in text
         text = text.replace(f, r)
     return text
@@ -32,8 +32,10 @@ def test_examples_3_and_2():
     # Do less steps and use CPU.
     file_contents = _find_and_replace(
         file_contents,
-        ['"offline_steps=5000"', '"device=cuda"'],
-        ['"offline_steps=1"', '"device=cpu"'],
+        [
+            ("training_steps = 5000", "training_steps = 1"),
+            ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
+        ],
     )
 
     exec(file_contents)
@@ -50,20 +52,15 @@ def test_examples_3_and_2():
     file_contents = _find_and_replace(
         file_contents,
         [
-            '"eval_episodes=10"',
-            '"rollout_batch_size=10"',
-            '"device=cuda"',
-            '# folder = Path("outputs/train/example_pusht_diffusion")',
-            'hub_id = "lerobot/diffusion_policy_pusht_image"',
-            "folder = Path(snapshot_download(hub_id)",
-        ],
-        [
-            '"eval_episodes=1"',
-            '"rollout_batch_size=1"',
-            '"device=cpu"',
-            'folder = Path("outputs/train/example_pusht_diffusion")',
-            "",
-            "",
+            ('"eval_episodes=10"', '"eval_episodes=1"'),
+            ('"rollout_batch_size=10"', '"rollout_batch_size=1"'),
+            ('"device=cuda"', '"device=cpu"'),
+            (
+                '# folder = Path("outputs/train/example_pusht_diffusion")',
+                'folder = Path("outputs/train/example_pusht_diffusion")',
+            ),
+            ('hub_id = "lerobot/diffusion_policy_pusht_image"', ""),
+            ("folder = Path(snapshot_download(hub_id)", ""),
         ],
     )
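
Usage sketch for the reworked test helper: pairs keep each find next to its replacement, so mismatched list
lengths can no longer slip through `zip` unnoticed (`_find_and_replace` as defined above):

    text = 'training_steps = 5000\ndevice = torch.device("cuda")'
    text = _find_and_replace(
        text,
        [
            ("training_steps = 5000", "training_steps = 1"),
            ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
        ],
    )
    print(text)  # training_steps = 1 / device = torch.device("cpu")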