From 976a197f9851be45906bbbd158cb0ca058e223d2 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Thu, 11 Apr 2024 17:51:35 +0100 Subject: [PATCH] backup wip --- .../_diffusion_policy_replay_buffer.py} | 5 + lerobot/common/datasets/pusht.py | 4 +- lerobot/common/policies/act/policy.py | 11 +- .../diffusion/diffusion_unet_image_policy.py | 315 ------ .../diffusion/model/conditional_unet1d.py | 401 ++++---- .../diffusion/model/conv1d_components.py | 47 - .../diffusion/model/crop_randomizer.py | 294 ------ .../diffusion/model/dict_of_tensor_mixin.py | 41 - .../model/diffusion_unet_image_policy.py | 220 ++++ .../policies/diffusion/model/ema_model.py | 13 - .../policies/diffusion/model/lr_scheduler.py | 46 - .../diffusion/model/mask_generator.py | 65 -- .../diffusion/model/module_attr_mixin.py | 15 - .../model/multi_image_obs_encoder.py | 214 ---- .../policies/diffusion/model/normalizer.py | 358 ------- .../diffusion/model/positional_embedding.py | 19 - .../policies/diffusion/model/rgb_encoder.py | 147 +++ .../policies/diffusion/model/tensor_utils.py | 972 ------------------ lerobot/common/policies/diffusion/policy.py | 78 +- .../policies/diffusion/pytorch_utils.py | 76 -- lerobot/common/policies/factory.py | 2 - lerobot/common/policies/utils.py | 22 + lerobot/common/utils.py | 1 + lerobot/configs/policy/diffusion.yaml | 24 +- lerobot/scripts/eval.py | 2 +- lerobot/scripts/train.py | 2 +- 26 files changed, 661 insertions(+), 2733 deletions(-) rename lerobot/common/{policies/diffusion/replay_buffer.py => datasets/_diffusion_policy_replay_buffer.py} (99%) delete mode 100644 lerobot/common/policies/diffusion/diffusion_unet_image_policy.py delete mode 100644 lerobot/common/policies/diffusion/model/conv1d_components.py delete mode 100644 lerobot/common/policies/diffusion/model/crop_randomizer.py delete mode 100644 lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py create mode 100644 lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py delete mode 100644 lerobot/common/policies/diffusion/model/lr_scheduler.py delete mode 100644 lerobot/common/policies/diffusion/model/mask_generator.py delete mode 100644 lerobot/common/policies/diffusion/model/module_attr_mixin.py delete mode 100644 lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py delete mode 100644 lerobot/common/policies/diffusion/model/normalizer.py delete mode 100644 lerobot/common/policies/diffusion/model/positional_embedding.py create mode 100644 lerobot/common/policies/diffusion/model/rgb_encoder.py delete mode 100644 lerobot/common/policies/diffusion/model/tensor_utils.py delete mode 100644 lerobot/common/policies/diffusion/pytorch_utils.py diff --git a/lerobot/common/policies/diffusion/replay_buffer.py b/lerobot/common/datasets/_diffusion_policy_replay_buffer.py similarity index 99% rename from lerobot/common/policies/diffusion/replay_buffer.py rename to lerobot/common/datasets/_diffusion_policy_replay_buffer.py index 7fccf74d..1697f9fc 100644 --- a/lerobot/common/policies/diffusion/replay_buffer.py +++ b/lerobot/common/datasets/_diffusion_policy_replay_buffer.py @@ -1,3 +1,8 @@ +"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/) + +Copied from the original Diffusion Policy repository. +""" + from __future__ import annotations import math diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index b468637e..9af6f3a1 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -8,8 +8,10 @@ import torch import tqdm from gym_pusht.envs.pusht import pymunk_to_shapely +from lerobot.common.datasets._diffusion_policy_replay_buffer import ( + ReplayBuffer as DiffusionPolicyReplayBuffer, +) from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps -from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer # as define in env SUCCESS_THRESHOLD = 0.95 # 95% coverage, diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py index 25b814ed..821b0196 100644 --- a/lerobot/common/policies/act/policy.py +++ b/lerobot/common/policies/act/policy.py @@ -176,7 +176,8 @@ class ActionChunkingTransformerPolicy(nn.Module): if self.n_action_steps is not None: self._action_queue = deque([], maxlen=self.n_action_steps) - def select_action(self, batch: dict[str, Tensor], *_, **__) -> Tensor: + @torch.no_grad + def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: """ This method wraps `select_actions` in order to return one action at a time for execution in the environment. It works by managing the actions in a queue and only calling `select_actions` when the @@ -188,7 +189,7 @@ class ActionChunkingTransformerPolicy(nn.Module): self._action_queue.extend(self.select_actions(batch).transpose(0, 1)) return self._action_queue.popleft() - @torch.no_grad() + @torch.no_grad def select_actions(self, batch: dict[str, Tensor]) -> Tensor: """Use the action chunking transformer to generate a sequence of actions.""" self.eval() @@ -223,8 +224,6 @@ class ActionChunkingTransformerPolicy(nn.Module): { "observation.state": (B, 1, J) OR (B, J) tensor of robot states (joint configuration). "observation.images.top": (B, 1, C, H, W) OR (B, C, H, W) tensor of images. - "action": (B, H, J) tensor of actions (positional target for robot joint configuration) - "action_is_pad": (B, H) mask for whether the actions are padding outside of the episode bounds. } """ if add_obs_steps_dim: @@ -244,7 +243,8 @@ class ActionChunkingTransformerPolicy(nn.Module): # Note: no squeeze is required for "observation.images.top" because then we'd have to unsqueeze to get # the image index dimension. - def update(self, batch, *_, **__) -> dict: + def update(self, batch, **_) -> dict: + """Run the model in train mode, compute the loss, and do an optimization step.""" start_time = time.time() self._preprocess_batch(batch) @@ -278,6 +278,7 @@ class ActionChunkingTransformerPolicy(nn.Module): return info def forward(self, batch: dict[str, Tensor], return_loss: bool = False) -> dict | Tensor: + """A forward pass through the DNN part of this policy with optional loss computation.""" images = self.image_normalizer(batch["observation.images.top"]) if return_loss: # training time diff --git a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py deleted file mode 100644 index f7432db3..00000000 --- a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py +++ /dev/null @@ -1,315 +0,0 @@ -"""Code from the original diffusion policy project. - -Notes on how to load a checkpoint from the original repository: - -In the original repository, run the eval and use a breakpoint to extract the policy weights. - -``` -torch.save(policy.state_dict(), "weights.pt") -``` - -In this repository, add a breakpoint somewhere after creating an equivalent policy and load in the weights: - -``` -loaded = torch.load("weights.pt") -aligned = {} -their_prefix = "obs_encoder.obs_nets.image.backbone" -our_prefix = "obs_encoder.key_model_map.image.backbone" -aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) -their_prefix = "obs_encoder.obs_nets.image.pool" -our_prefix = "obs_encoder.key_model_map.image.pool" -aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) -their_prefix = "obs_encoder.obs_nets.image.nets.3" -our_prefix = "obs_encoder.key_model_map.image.out" -aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) -aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')}) -# Note: here you are loading into the ema model. -missing_keys, unexpected_keys = policy.ema_diffusion.load_state_dict(aligned, strict=False) -assert all('_dummy_variable' in k for k in missing_keys) -assert len(unexpected_keys) == 0 -``` - -Then in that same runtime you can also save the weights with the new aligned state_dict: - -``` -policy.save("weights.pt") -``` - -Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint. - -""" - -from typing import Dict - -import torch -import torch.nn.functional as F # noqa: N812 -from diffusers.schedulers.scheduling_ddpm import DDPMScheduler -from einops import reduce - -from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D -from lerobot.common.policies.diffusion.model.mask_generator import LowdimMaskGenerator -from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin -from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder -from lerobot.common.policies.diffusion.model.normalizer import LinearNormalizer -from lerobot.common.policies.diffusion.pytorch_utils import dict_apply - - -class BaseImagePolicy(ModuleAttrMixin): - # init accepts keyword argument shape_meta, see config/task/*_image.yaml - - def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """ - obs_dict: - str: B,To,* - return: B,Ta,Da - """ - raise NotImplementedError() - - # reset state for stateful policies - def reset(self): - pass - - # ========== training =========== - # no standard training interface except setting normalizer - def set_normalizer(self, normalizer: LinearNormalizer): - raise NotImplementedError() - - -class DiffusionUnetImagePolicy(BaseImagePolicy): - def __init__( - self, - shape_meta: dict, - noise_scheduler: DDPMScheduler, - obs_encoder: MultiImageObsEncoder, - horizon, - n_action_steps, - n_obs_steps, - num_inference_steps=None, - obs_as_global_cond=True, - diffusion_step_embed_dim=256, - down_dims=(256, 512, 1024), - kernel_size=5, - n_groups=8, - cond_predict_scale=True, - # parameters passed to step - **kwargs, - ): - super().__init__() - - # parse shapes - action_shape = shape_meta["action"]["shape"] - assert len(action_shape) == 1 - action_dim = action_shape[0] - # get feature dim - obs_feature_dim = obs_encoder.output_shape()[0] - - # create diffusion model - input_dim = action_dim + obs_feature_dim - global_cond_dim = None - if obs_as_global_cond: - input_dim = action_dim - global_cond_dim = obs_feature_dim * n_obs_steps - - model = ConditionalUnet1D( - input_dim=input_dim, - local_cond_dim=None, - global_cond_dim=global_cond_dim, - diffusion_step_embed_dim=diffusion_step_embed_dim, - down_dims=down_dims, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ) - - self.obs_encoder = obs_encoder - self.model = model - self.noise_scheduler = noise_scheduler - self.mask_generator = LowdimMaskGenerator( - action_dim=action_dim, - obs_dim=0 if obs_as_global_cond else obs_feature_dim, - max_n_obs_steps=n_obs_steps, - fix_obs_steps=True, - action_visible=False, - ) - self.horizon = horizon - self.obs_feature_dim = obs_feature_dim - self.action_dim = action_dim - self.n_action_steps = n_action_steps - self.n_obs_steps = n_obs_steps - self.obs_as_global_cond = obs_as_global_cond - self.kwargs = kwargs - - if num_inference_steps is None: - num_inference_steps = noise_scheduler.config.num_train_timesteps - self.num_inference_steps = num_inference_steps - - # ========= inference ============ - def conditional_sample( - self, - condition_data, - condition_mask, - local_cond=None, - global_cond=None, - generator=None, - # keyword arguments to scheduler.step - **kwargs, - ): - model = self.model - scheduler = self.noise_scheduler - - trajectory = torch.randn( - size=condition_data.shape, - dtype=condition_data.dtype, - device=condition_data.device, - generator=generator, - ) - - # set step values - scheduler.set_timesteps(self.num_inference_steps) - - for t in scheduler.timesteps: - # 1. apply conditioning - trajectory[condition_mask] = condition_data[condition_mask] - - # 2. predict model output - model_output = model(trajectory, t, local_cond=local_cond, global_cond=global_cond) - - # 3. compute previous image: x_t -> x_t-1 - trajectory = scheduler.step( - model_output, - t, - trajectory, - generator=generator, - # **kwargs # TODO(rcadene): in diffusion_policy, expected to be {} - ).prev_sample - - # finally make sure conditioning is enforced - trajectory[condition_mask] = condition_data[condition_mask] - - return trajectory - - def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """ - obs_dict: must include "obs" key - result: must include "action" key - """ - assert "past_action" not in obs_dict # not implemented yet - nobs = obs_dict - value = next(iter(nobs.values())) - bsize, n_obs_steps = value.shape[:2] - horizon = self.horizon - action_dim = self.action_dim - obs_dim = self.obs_feature_dim - assert self.n_obs_steps == n_obs_steps - - # build input - device = self.device - dtype = self.dtype - - # handle different ways of passing observation - local_cond = None - global_cond = None - if self.obs_as_global_cond: - # condition through global feature - this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:])) - nobs_features = self.obs_encoder(this_nobs) - # reshape back to B, Do - global_cond = nobs_features.reshape(bsize, -1) - # empty data for action - cond_data = torch.zeros(size=(bsize, horizon, action_dim), device=device, dtype=dtype) - cond_mask = torch.zeros_like(cond_data, dtype=torch.bool) - else: - # condition through impainting - this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:])) - nobs_features = self.obs_encoder(this_nobs) - # reshape back to B, T, Do - nobs_features = nobs_features.reshape(bsize, n_obs_steps, -1) - cond_data = torch.zeros(size=(bsize, horizon, action_dim + obs_dim), device=device, dtype=dtype) - cond_mask = torch.zeros_like(cond_data, dtype=torch.bool) - cond_data[:, :n_obs_steps, action_dim:] = nobs_features - cond_mask[:, :n_obs_steps, action_dim:] = True - - # run sampling - nsample = self.conditional_sample( - cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond - ) - - action_pred = nsample[..., :action_dim] - # get action - start = n_obs_steps - 1 - end = start + self.n_action_steps - action = action_pred[:, start:end] - - result = {"action": action, "action_pred": action_pred} - return result - - def compute_loss(self, batch): - nobs = { - "image": batch["observation.image"], - "agent_pos": batch["observation.state"], - } - nactions = batch["action"] - batch_size = nactions.shape[0] - horizon = nactions.shape[1] - - # handle different ways of passing observation - local_cond = None - global_cond = None - trajectory = nactions - cond_data = trajectory - if self.obs_as_global_cond: - # reshape B, T, ... to B*T - this_nobs = dict_apply(nobs, lambda x: x[:, : self.n_obs_steps, ...].reshape(-1, *x.shape[2:])) - nobs_features = self.obs_encoder(this_nobs) - # reshape back to B, Do - global_cond = nobs_features.reshape(batch_size, -1) - else: - # reshape B, T, ... to B*T - this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:])) - nobs_features = self.obs_encoder(this_nobs) - # reshape back to B, T, Do - nobs_features = nobs_features.reshape(batch_size, horizon, -1) - cond_data = torch.cat([nactions, nobs_features], dim=-1) - trajectory = cond_data.detach() - - # generate impainting mask - condition_mask = self.mask_generator(trajectory.shape) - - # Sample noise that we'll add to the images - noise = torch.randn(trajectory.shape, device=trajectory.device) - bsz = trajectory.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint( - 0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=trajectory.device - ).long() - # Add noise to the clean images according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps) - - # compute loss mask - loss_mask = ~condition_mask - - # apply conditioning - noisy_trajectory[condition_mask] = cond_data[condition_mask] - - # Predict the noise residual - pred = self.model(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond) - - pred_type = self.noise_scheduler.config.prediction_type - if pred_type == "epsilon": - target = noise - elif pred_type == "sample": - target = trajectory - else: - raise ValueError(f"Unsupported prediction type {pred_type}") - - loss = F.mse_loss(pred, target, reduction="none") - loss = loss * loss_mask.type(loss.dtype) - - if "action_is_pad" in batch: - in_episode_bound = ~batch["action_is_pad"] - loss = loss * in_episode_bound[:, :, None].type(loss.dtype) - - loss = reduce(loss, "b t c -> b", "mean", b=batch_size) - loss = loss.mean() - return loss diff --git a/lerobot/common/policies/diffusion/model/conditional_unet1d.py b/lerobot/common/policies/diffusion/model/conditional_unet1d.py index d2971d38..5c43d488 100644 --- a/lerobot/common/policies/diffusion/model/conditional_unet1d.py +++ b/lerobot/common/policies/diffusion/model/conditional_unet1d.py @@ -1,286 +1,307 @@ import logging -from typing import Union +import math import einops import torch import torch.nn as nn -from einops.layers.torch import Rearrange - -from lerobot.common.policies.diffusion.model.conv1d_components import Conv1dBlock, Downsample1d, Upsample1d -from lerobot.common.policies.diffusion.model.positional_embedding import SinusoidalPosEmb +from torch import Tensor logger = logging.getLogger(__name__) -class ConditionalResidualBlock1D(nn.Module): +class _SinusoidalPosEmb(nn.Module): + # TODO(now): consolidate? + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class _Conv1dBlock(nn.Module): + """Conv1d --> GroupNorm --> Mish""" + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.block = nn.Sequential( + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), + nn.GroupNorm(n_groups, out_channels), + nn.Mish(), + ) + + def forward(self, x): + return self.block(x) + + +class _ConditionalResidualBlock1D(nn.Module): + """ResNet style 1D convolutional block with FiLM modulation for conditioning.""" + def __init__( - self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8, cond_predict_scale=False + self, + in_channels: int, + out_channels: int, + cond_dim: int, + kernel_size: int = 3, + n_groups: int = 8, + # Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning + # FiLM just modulates bias). + film_scale_modulation: bool = False, ): super().__init__() - self.blocks = nn.ModuleList( - [ - Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), - Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), - ] - ) - - # FiLM modulation https://arxiv.org/abs/1709.07871 - # predicts per-channel scale and bias - cond_channels = out_channels - if cond_predict_scale: - cond_channels = out_channels * 2 - self.cond_predict_scale = cond_predict_scale + self.film_scale_modulation = film_scale_modulation self.out_channels = out_channels - self.cond_encoder = nn.Sequential( - nn.Mish(), - nn.Linear(cond_dim, cond_channels), - Rearrange("batch t -> batch t 1"), - ) - # make sure dimensions compatible + self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups) + + # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale. + cond_channels = out_channels * 2 if film_scale_modulation else out_channels + self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels)) + + self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups) + + # A final convolution for dimension matching the residual (if needed). self.residual_conv = ( nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() ) - def forward(self, x, cond): + def forward(self, x: Tensor, cond: Tensor) -> Tensor: """ - x : [ batch_size x in_channels x horizon ] - cond : [ batch_size x cond_dim] + Args: + x: (B, in_channels, T) + cond: (B, cond_dim) + Returns: + (B, out_channels, T) + """ + out = self.conv1(x) - returns: - out : [ batch_size x out_channels x horizon ] - """ - out = self.blocks[0](x) - embed = self.cond_encoder(cond) - if self.cond_predict_scale: - embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1) - scale = embed[:, 0, ...] - bias = embed[:, 1, ...] + # Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1). + cond_embed = self.cond_encoder(cond).unsqueeze(-1) + if self.film_scale_modulation: + # Treat the embedding as a list of scales and biases. + scale = cond_embed[:, : self.out_channels] + bias = cond_embed[:, self.out_channels :] out = scale * out + bias else: - out = out + embed - out = self.blocks[1](out) + # Treat the embedding as biases. + out = out + cond_embed + + out = self.conv2(out) out = out + self.residual_conv(x) return out class ConditionalUnet1D(nn.Module): + """A 1D convolutional UNet with FiLM modulation for conditioning. + + Two types of conditioning can be applied: + - Global: Conditioning information that is aggregated over the whole observation window. This is + incorporated via the FiLM technique in the residual convolution blocks of the Unet's encoder/decoder. + - Local: Conditioning information for each timestep in the observation window. This is incorporated + by encoding the information via 1D convolutions and adding the resulting embeddings to the inputs and + outputs of the Unet's encoder/decoder. + """ + def __init__( self, - input_dim, - local_cond_dim=None, - global_cond_dim=None, - diffusion_step_embed_dim=256, - down_dims=None, - kernel_size=3, - n_groups=8, - cond_predict_scale=False, + input_dim: int, + local_cond_dim: int | None = None, + global_cond_dim: int | None = None, + diffusion_step_embed_dim: int = 256, + down_dims: int | None = None, + kernel_size: int = 3, + n_groups: int = 8, + film_scale_modulation: bool = False, ): super().__init__() + if down_dims is None: down_dims = [256, 512, 1024] - all_dims = [input_dim] + list(down_dims) - start_dim = down_dims[0] - - dsed = diffusion_step_embed_dim - diffusion_step_encoder = nn.Sequential( - SinusoidalPosEmb(dsed), - nn.Linear(dsed, dsed * 4), + # Encoder for the diffusion timestep. + self.diffusion_step_encoder = nn.Sequential( + _SinusoidalPosEmb(diffusion_step_embed_dim), + nn.Linear(diffusion_step_embed_dim, diffusion_step_embed_dim * 4), nn.Mish(), - nn.Linear(dsed * 4, dsed), + nn.Linear(diffusion_step_embed_dim * 4, diffusion_step_embed_dim), ) - cond_dim = dsed + + # The FiLM conditioning dimension. + cond_dim = diffusion_step_embed_dim if global_cond_dim is not None: cond_dim += global_cond_dim - in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False)) - - local_cond_encoder = None + self.local_cond_down_encoder = None + self.local_cond_up_encoder = None if local_cond_dim is not None: - _, dim_out = in_out[0] - dim_in = local_cond_dim - local_cond_encoder = nn.ModuleList( - [ - # down encoder - ConditionalResidualBlock1D( - dim_in, - dim_out, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - # up encoder - ConditionalResidualBlock1D( - dim_in, - dim_out, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - ] + # Encoder for the local conditioning. The output gets added to the Unet encoder input. + self.local_cond_down_encoder = _ConditionalResidualBlock1D( + local_cond_dim, + down_dims[0], + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + film_scale_modulation=film_scale_modulation, + ) + # Encoder for the local conditioning. The output gets added to the Unet encoder output. + self.local_cond_up_encoder = _ConditionalResidualBlock1D( + local_cond_dim, + down_dims[0], + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + film_scale_modulation=film_scale_modulation, ) - mid_dim = all_dims[-1] + # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we + # just reverse these. + in_out = [(input_dim, down_dims[0])] + list(zip(down_dims[:-1], down_dims[1:], strict=True)) + + # Unet encoder. + self.down_modules = nn.ModuleList([]) + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (len(in_out) - 1) + self.down_modules.append( + nn.ModuleList( + [ + _ConditionalResidualBlock1D( + dim_in, + dim_out, + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + film_scale_modulation=film_scale_modulation, + ), + _ConditionalResidualBlock1D( + dim_out, + dim_out, + cond_dim=cond_dim, + kernel_size=kernel_size, + n_groups=n_groups, + film_scale_modulation=film_scale_modulation, + ), + # Downsample as long as it is not the last block. + nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(), + ] + ) + ) + + # Processing in the middle of the auto-encoder. self.mid_modules = nn.ModuleList( [ - ConditionalResidualBlock1D( - mid_dim, - mid_dim, + _ConditionalResidualBlock1D( + down_dims[-1], + down_dims[-1], cond_dim=cond_dim, kernel_size=kernel_size, n_groups=n_groups, - cond_predict_scale=cond_predict_scale, + film_scale_modulation=film_scale_modulation, ), - ConditionalResidualBlock1D( - mid_dim, - mid_dim, + _ConditionalResidualBlock1D( + down_dims[-1], + down_dims[-1], cond_dim=cond_dim, kernel_size=kernel_size, n_groups=n_groups, - cond_predict_scale=cond_predict_scale, + film_scale_modulation=film_scale_modulation, ), ] ) - down_modules = nn.ModuleList([]) - for ind, (dim_in, dim_out) in enumerate(in_out): + # Unet decoder. + self.up_modules = nn.ModuleList([]) + for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])): is_last = ind >= (len(in_out) - 1) - down_modules.append( + self.up_modules.append( nn.ModuleList( [ - ConditionalResidualBlock1D( - dim_in, + _ConditionalResidualBlock1D( + dim_in * 2, # x2 as it takes the encoder's skip connection as well dim_out, cond_dim=cond_dim, kernel_size=kernel_size, n_groups=n_groups, - cond_predict_scale=cond_predict_scale, + film_scale_modulation=film_scale_modulation, ), - ConditionalResidualBlock1D( + _ConditionalResidualBlock1D( dim_out, dim_out, cond_dim=cond_dim, kernel_size=kernel_size, n_groups=n_groups, - cond_predict_scale=cond_predict_scale, + film_scale_modulation=film_scale_modulation, ), - Downsample1d(dim_out) if not is_last else nn.Identity(), + # Upsample as long as it is not the last block. + nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(), ] ) ) - up_modules = nn.ModuleList([]) - for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): - is_last = ind >= (len(in_out) - 1) - up_modules.append( - nn.ModuleList( - [ - ConditionalResidualBlock1D( - dim_out * 2, - dim_in, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - ConditionalResidualBlock1D( - dim_in, - dim_in, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - Upsample1d(dim_in) if not is_last else nn.Identity(), - ] - ) - ) - - final_conv = nn.Sequential( - Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), - nn.Conv1d(start_dim, input_dim, 1), + self.final_conv = nn.Sequential( + _Conv1dBlock(down_dims[0], down_dims[0], kernel_size=kernel_size), + nn.Conv1d(down_dims[0], input_dim, 1), ) - self.diffusion_step_encoder = diffusion_step_encoder - self.local_cond_encoder = local_cond_encoder - self.up_modules = up_modules - self.down_modules = down_modules - self.final_conv = final_conv - - logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) - - def forward( - self, - sample: torch.Tensor, - timestep: Union[torch.Tensor, float, int], - local_cond=None, - global_cond=None, - **kwargs, - ): + def forward(self, x: Tensor, timestep: Tensor | int, local_cond=None, global_cond=None) -> Tensor: """ - x: (B,T,input_dim) - timestep: (B,) or int, diffusion step - local_cond: (B,T,local_cond_dim) - global_cond: (B,global_cond_dim) - output: (B,T,input_dim) + Args: + x: (B, T, input_dim) tensor for input to the Unet. + timestep: (B,) tensor of (timestep_we_are_denoising_from - 1). + local_cond: (B, T, local_cond_dim) + global_cond: (B, global_cond_dim) + output: (B, T, input_dim) + Returns: + (B, T, input_dim) """ - sample = einops.rearrange(sample, "b h t -> b t h") - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - global_feature = self.diffusion_step_encoder(timesteps) - - if global_cond is not None: - global_feature = torch.cat([global_feature, global_cond], axis=-1) - - # encode local features - h_local = [] + # For 1D convolutions we'll need feature dimension first. + x = einops.rearrange(x, "b t d -> b d t") if local_cond is not None: - local_cond = einops.rearrange(local_cond, "b h t -> b t h") - resnet, resnet2 = self.local_cond_encoder - x = resnet(local_cond, global_feature) - h_local.append(x) - x = resnet2(local_cond, global_feature) - h_local.append(x) + if self.local_cond_down_encoder is None or self.local_cond_up_encoder is None: + raise ValueError( + "`local_cond` was provided but the relevant encoders weren't built at initialization." + ) + local_cond = einops.rearrange(local_cond, "b t d -> b d t") - x = sample - h = [] + timesteps_embed = self.diffusion_step_encoder(timestep) + + # If there is a global conditioning feature, concatenate it to the timestep embedding. + if global_cond is not None: + global_feature = torch.cat([timesteps_embed, global_cond], axis=-1) + else: + global_feature = timesteps_embed + + encoder_skip_features: list[Tensor] = [] for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): x = resnet(x, global_feature) - if idx == 0 and len(h_local) > 0: - x = x + h_local[0] + if idx == 0 and local_cond is not None: + x = x + self.local_cond_down_encoder(local_cond, global_feature) x = resnet2(x, global_feature) - h.append(x) + encoder_skip_features.append(x) x = downsample(x) for mid_module in self.mid_modules: x = mid_module(x, global_feature) for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): - x = torch.cat((x, h.pop()), dim=1) + x = torch.cat((x, encoder_skip_features.pop()), dim=1) x = resnet(x, global_feature) - # The correct condition should be: - # if idx == (len(self.up_modules)-1) and len(h_local) > 0: - # However this change will break compatibility with published checkpoints. - # Therefore it is left as a comment. - if idx == len(self.up_modules) and len(h_local) > 0: - x = x + h_local[1] + # Note: The condition in the original implementation is: + # if idx == len(self.up_modules) and local_cond is not None: + # But as they mention in their comments, this is incorrect. We use the correct condition here. + if idx == (len(self.up_modules) - 1) and local_cond is not None: + x = x + self.local_cond_up_encoder(local_cond, global_feature) x = resnet2(x, global_feature) x = upsample(x) x = self.final_conv(x) - x = einops.rearrange(x, "b t h -> b h t") + x = einops.rearrange(x, "b d t -> b t d") return x diff --git a/lerobot/common/policies/diffusion/model/conv1d_components.py b/lerobot/common/policies/diffusion/model/conv1d_components.py deleted file mode 100644 index 3c21eaf6..00000000 --- a/lerobot/common/policies/diffusion/model/conv1d_components.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch.nn as nn - -# from einops.layers.torch import Rearrange - - -class Downsample1d(nn.Module): - def __init__(self, dim): - super().__init__() - self.conv = nn.Conv1d(dim, dim, 3, 2, 1) - - def forward(self, x): - return self.conv(x) - - -class Upsample1d(nn.Module): - def __init__(self, dim): - super().__init__() - self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) - - def forward(self, x): - return self.conv(x) - - -class Conv1dBlock(nn.Module): - """ - Conv1d --> GroupNorm --> Mish - """ - - def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): - super().__init__() - - self.block = nn.Sequential( - nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), - # Rearrange('batch channels horizon -> batch channels 1 horizon'), - nn.GroupNorm(n_groups, out_channels), - # Rearrange('batch channels 1 horizon -> batch channels horizon'), - nn.Mish(), - ) - - def forward(self, x): - return self.block(x) - - -# def test(): -# cb = Conv1dBlock(256, 128, kernel_size=3) -# x = torch.zeros((1,256,16)) -# o = cb(x) diff --git a/lerobot/common/policies/diffusion/model/crop_randomizer.py b/lerobot/common/policies/diffusion/model/crop_randomizer.py deleted file mode 100644 index 2e60f353..00000000 --- a/lerobot/common/policies/diffusion/model/crop_randomizer.py +++ /dev/null @@ -1,294 +0,0 @@ -import torch -import torch.nn as nn -import torchvision.transforms.functional as ttf - -import lerobot.common.policies.diffusion.model.tensor_utils as tu - - -class CropRandomizer(nn.Module): - """ - Randomly sample crops at input, and then average across crop features at output. - """ - - def __init__( - self, - input_shape, - crop_height, - crop_width, - num_crops=1, - pos_enc=False, - ): - """ - Args: - input_shape (tuple, list): shape of input (not including batch dimension) - crop_height (int): crop height - crop_width (int): crop width - num_crops (int): number of random crops to take - pos_enc (bool): if True, add 2 channels to the output to encode the spatial - location of the cropped pixels in the source image - """ - super().__init__() - - assert len(input_shape) == 3 # (C, H, W) - assert crop_height < input_shape[1] - assert crop_width < input_shape[2] - - self.input_shape = input_shape - self.crop_height = crop_height - self.crop_width = crop_width - self.num_crops = num_crops - self.pos_enc = pos_enc - - def output_shape_in(self, input_shape=None): - """ - Function to compute output shape from inputs to this module. Corresponds to - the @forward_in operation, where raw inputs (usually observation modalities) - are passed in. - - Args: - input_shape (iterable of int): shape of input. Does not include batch dimension. - Some modules may not need this argument, if their output does not depend - on the size of the input, or if they assume fixed size input. - - Returns: - out_shape ([int]): list of integers corresponding to output shape - """ - - # outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because - # the number of crops are reshaped into the batch dimension, increasing the batch - # size from B to B * N - out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0] - return [out_c, self.crop_height, self.crop_width] - - def output_shape_out(self, input_shape=None): - """ - Function to compute output shape from inputs to this module. Corresponds to - the @forward_out operation, where processed inputs (usually encoded observation - modalities) are passed in. - - Args: - input_shape (iterable of int): shape of input. Does not include batch dimension. - Some modules may not need this argument, if their output does not depend - on the size of the input, or if they assume fixed size input. - - Returns: - out_shape ([int]): list of integers corresponding to output shape - """ - - # since the forward_out operation splits [B * N, ...] -> [B, N, ...] - # and then pools to result in [B, ...], only the batch dimension changes, - # and so the other dimensions retain their shape. - return list(input_shape) - - def forward_in(self, inputs): - """ - Samples N random crops for each input in the batch, and then reshapes - inputs to [B * N, ...]. - """ - assert len(inputs.shape) >= 3 # must have at least (C, H, W) dimensions - if self.training: - # generate random crops - out, _ = sample_random_image_crops( - images=inputs, - crop_height=self.crop_height, - crop_width=self.crop_width, - num_crops=self.num_crops, - pos_enc=self.pos_enc, - ) - # [B, N, ...] -> [B * N, ...] - return tu.join_dimensions(out, 0, 1) - else: - # take center crop during eval - out = ttf.center_crop(img=inputs, output_size=(self.crop_height, self.crop_width)) - if self.num_crops > 1: - B, C, H, W = out.shape # noqa: N806 - out = out.unsqueeze(1).expand(B, self.num_crops, C, H, W).reshape(-1, C, H, W) - # [B * N, ...] - return out - - def forward_out(self, inputs): - """ - Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N - to result in shape [B, ...] to make sure the network output is consistent with - what would have happened if there were no randomization. - """ - if self.num_crops <= 1: - return inputs - else: - batch_size = inputs.shape[0] // self.num_crops - out = tu.reshape_dimensions( - inputs, begin_axis=0, end_axis=0, target_dims=(batch_size, self.num_crops) - ) - return out.mean(dim=1) - - def forward(self, inputs): - return self.forward_in(inputs) - - def __repr__(self): - """Pretty print network.""" - header = "{}".format(str(self.__class__.__name__)) - msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format( - self.input_shape, self.crop_height, self.crop_width, self.num_crops - ) - return msg - - -def crop_image_from_indices(images, crop_indices, crop_height, crop_width): - """ - Crops images at the locations specified by @crop_indices. Crops will be - taken across all channels. - - Args: - images (torch.Tensor): batch of images of shape [..., C, H, W] - - crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where - N is the number of crops to take per image and each entry corresponds - to the pixel height and width of where to take the crop. Note that - the indices can also be of shape [..., 2] if only 1 crop should - be taken per image. Leading dimensions must be consistent with - @images argument. Each index specifies the top left of the crop. - Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where - H and W are the height and width of @images and CH and CW are - @crop_height and @crop_width. - - crop_height (int): height of crop to take - - crop_width (int): width of crop to take - - Returns: - crops (torch.Tesnor): cropped images of shape [..., C, @crop_height, @crop_width] - """ - - # make sure length of input shapes is consistent - assert crop_indices.shape[-1] == 2 - ndim_im_shape = len(images.shape) - ndim_indices_shape = len(crop_indices.shape) - assert (ndim_im_shape == ndim_indices_shape + 1) or (ndim_im_shape == ndim_indices_shape + 2) - - # maybe pad so that @crop_indices is shape [..., N, 2] - is_padded = False - if ndim_im_shape == ndim_indices_shape + 2: - crop_indices = crop_indices.unsqueeze(-2) - is_padded = True - - # make sure leading dimensions between images and indices are consistent - assert images.shape[:-3] == crop_indices.shape[:-2] - - device = images.device - image_c, image_h, image_w = images.shape[-3:] - num_crops = crop_indices.shape[-2] - - # make sure @crop_indices are in valid range - assert (crop_indices[..., 0] >= 0).all().item() - assert (crop_indices[..., 0] < (image_h - crop_height)).all().item() - assert (crop_indices[..., 1] >= 0).all().item() - assert (crop_indices[..., 1] < (image_w - crop_width)).all().item() - - # convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window. - - # 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW] - crop_ind_grid_h = torch.arange(crop_height).to(device) - crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h, size=crop_width, dim=-1) - # 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW] - crop_ind_grid_w = torch.arange(crop_width).to(device) - crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w, size=crop_height, dim=0) - # combine into shape [CH, CW, 2] - crop_in_grid = torch.cat((crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1) - - # Add above grid with the offset index of each sampled crop to get 2d indices for each crop. - # After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2] - # shape array that tells us which pixels from the corresponding source image to grab. - grid_reshape = [1] * len(crop_indices.shape[:-1]) + [crop_height, crop_width, 2] - all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(-2) + crop_in_grid.reshape(grid_reshape) - - # For using @torch.gather, convert to flat indices from 2D indices, and also - # repeat across the channel dimension. To get flat index of each pixel to grab for - # each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind - all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[..., 1] # shape [..., N, CH, CW] - all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c, dim=-3) # shape [..., N, C, CH, CW] - all_crop_inds = tu.flatten(all_crop_inds, begin_axis=-2) # shape [..., N, C, CH * CW] - - # Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds - images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4) - images_to_crop = tu.flatten(images_to_crop, begin_axis=-2) - crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds) - # [..., N, C, CH * CW] -> [..., N, C, CH, CW] - reshape_axis = len(crops.shape) - 1 - crops = tu.reshape_dimensions( - crops, begin_axis=reshape_axis, end_axis=reshape_axis, target_dims=(crop_height, crop_width) - ) - - if is_padded: - # undo padding -> [..., C, CH, CW] - crops = crops.squeeze(-4) - return crops - - -def sample_random_image_crops(images, crop_height, crop_width, num_crops, pos_enc=False): - """ - For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from - @images. - - Args: - images (torch.Tensor): batch of images of shape [..., C, H, W] - - crop_height (int): height of crop to take - - crop_width (int): width of crop to take - - num_crops (n): number of crops to sample - - pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial - encoding of the original source pixel locations. This means that the - output crops will contain information about where in the source image - it was sampled from. - - Returns: - crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width) - if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width) - - crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2) - """ - device = images.device - - # maybe add 2 channels of spatial encoding to the source image - source_im = images - if pos_enc: - # spatial encoding [y, x] in [0, 1] - h, w = source_im.shape[-2:] - pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w)) - pos_y = pos_y.float().to(device) / float(h) - pos_x = pos_x.float().to(device) / float(w) - position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W] - - # unsqueeze and expand to match leading dimensions -> shape [..., C, H, W] - leading_shape = source_im.shape[:-3] - position_enc = position_enc[(None,) * len(leading_shape)] - position_enc = position_enc.expand(*leading_shape, -1, -1, -1) - - # concat across channel dimension with input - source_im = torch.cat((source_im, position_enc), dim=-3) - - # make sure sample boundaries ensure crops are fully within the images - image_c, image_h, image_w = source_im.shape[-3:] - max_sample_h = image_h - crop_height - max_sample_w = image_w - crop_width - - # Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W]. - # Each gets @num_crops samples - typically this will just be the batch dimension (B), so - # we will sample [B, N] indices, but this supports having more than one leading dimension, - # or possibly no leading dimension. - # - # Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints - crop_inds_h = (max_sample_h * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long() - crop_inds_w = (max_sample_w * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long() - crop_inds = torch.cat((crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)), dim=-1) # shape [..., N, 2] - - crops = crop_image_from_indices( - images=source_im, - crop_indices=crop_inds, - crop_height=crop_height, - crop_width=crop_width, - ) - - return crops, crop_inds diff --git a/lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py b/lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py deleted file mode 100644 index d1356006..00000000 --- a/lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch -import torch.nn as nn - - -class DictOfTensorMixin(nn.Module): - def __init__(self, params_dict=None): - super().__init__() - if params_dict is None: - params_dict = nn.ParameterDict() - self.params_dict = params_dict - - @property - def device(self): - return next(iter(self.parameters())).device - - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - def dfs_add(dest, keys, value: torch.Tensor): - if len(keys) == 1: - dest[keys[0]] = value - return - - if keys[0] not in dest: - dest[keys[0]] = nn.ParameterDict() - dfs_add(dest[keys[0]], keys[1:], value) - - def load_dict(state_dict, prefix): - out_dict = nn.ParameterDict() - for key, value in state_dict.items(): - value: torch.Tensor - if key.startswith(prefix): - param_keys = key[len(prefix) :].split(".")[1:] - # if len(param_keys) == 0: - # import pdb; pdb.set_trace() - dfs_add(out_dict, param_keys, value.clone()) - return out_dict - - self.params_dict = load_dict(state_dict, prefix + "params_dict") - self.params_dict.requires_grad_(False) - return diff --git a/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py new file mode 100644 index 00000000..b6b78925 --- /dev/null +++ b/lerobot/common/policies/diffusion/model/diffusion_unet_image_policy.py @@ -0,0 +1,220 @@ +import einops +import torch +import torch.nn.functional as F # noqa: N812 +from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from torch import Tensor, nn + +from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D +from lerobot.common.policies.diffusion.model.rgb_encoder import RgbEncoder +from lerobot.common.policies.utils import get_device_from_parameters, get_dtype_from_parameters + + +class DiffusionUnetImagePolicy(nn.Module): + """ + TODO(now): Add DDIM scheduler. + + Changes: TODO(now) + - Use single image encoder for now instead of generic obs_encoder. We may generalize again when/if + needed. Code for a general observation encoder can be found at: + https://github.com/huggingface/lerobot/blob/920e0d118b493e4cc3058a9b1b764f38ae145d8e/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py + - Uses the observation as global conditioning for the Unet by default. + - Does not do any inpainting (which would be applicable if the observation were not used to condition + the Unet). + """ + + def __init__( + self, + cfg, + shape_meta: dict, + noise_scheduler: DDPMScheduler, + horizon, + n_action_steps, + n_obs_steps, + num_inference_steps=None, + diffusion_step_embed_dim=256, + down_dims=(256, 512, 1024), + kernel_size=5, + n_groups=8, + film_scale_modulation=True, + ): + super().__init__() + action_shape = shape_meta["action"]["shape"] + assert len(action_shape) == 1 + action_dim = action_shape[0] + + self.rgb_encoder = RgbEncoder(input_shape=shape_meta.obs.image.shape, **cfg.rgb_encoder) + + self.unet = ConditionalUnet1D( + input_dim=action_dim, + global_cond_dim=(action_dim + self.rgb_encoder.feature_dim) * n_obs_steps, + diffusion_step_embed_dim=diffusion_step_embed_dim, + down_dims=down_dims, + kernel_size=kernel_size, + n_groups=n_groups, + film_scale_modulation=film_scale_modulation, + ) + + self.noise_scheduler = noise_scheduler + self.horizon = horizon + self.action_dim = action_dim + self.n_action_steps = n_action_steps + self.n_obs_steps = n_obs_steps + + if num_inference_steps is None: + num_inference_steps = noise_scheduler.config.num_train_timesteps + + self.num_inference_steps = num_inference_steps + + # ========= inference ============ + def conditional_sample( + self, + condition_data, + inpainting_mask, + local_cond=None, + global_cond=None, + generator=None, + ): + model = self.unet + scheduler = self.noise_scheduler + + trajectory = torch.randn( + size=condition_data.shape, + dtype=condition_data.dtype, + device=condition_data.device, + generator=generator, + ) + + # set step values + scheduler.set_timesteps(self.num_inference_steps) + + for t in scheduler.timesteps: + # 1. apply conditioning + trajectory[inpainting_mask] = condition_data[inpainting_mask] + + # 2. predict model output + model_output = model( + trajectory, + torch.full(trajectory.shape[:1], t, dtype=torch.long, device=trajectory.device), + local_cond=local_cond, + global_cond=global_cond, + ) + + # 3. compute previous image: x_t -> x_t-1 + trajectory = scheduler.step( + model_output, + t, + trajectory, + generator=generator, + ).prev_sample + + # finally make sure conditioning is enforced + trajectory[inpainting_mask] = condition_data[inpainting_mask] + + return trajectory + + def predict_action(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """ + This function expects `batch` to have (at least): + { + "observation.state": (B, n_obs_steps, state_dim) + "observation.image": (B, n_obs_steps, C, H, W) + } + """ + assert set(batch).issuperset({"observation.state", "observation.image"}) + batch_size, n_obs_steps = batch["observation.state"].shape[:2] + assert n_obs_steps == self.n_obs_steps + assert self.n_obs_steps == n_obs_steps + + # build input + device = get_device_from_parameters(self) + dtype = get_dtype_from_parameters(self) + + # Extract image feature (first combine batch and sequence dims). + img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) + # Separate batch and sequence dims. + img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size) + # Concatenate state and image features then flatten to (B, global_cond_dim). + global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) + # reshape back to B, Do + # empty data for action + cond_data = torch.zeros(size=(batch_size, self.horizon, self.action_dim), device=device, dtype=dtype) + cond_mask = torch.zeros_like(cond_data, dtype=torch.bool) + + # run sampling + nsample = self.conditional_sample(cond_data, cond_mask, global_cond=global_cond) + + # `horizon` steps worth of actions (from the first observation). + action = nsample[..., : self.action_dim] + # Extract `n_action_steps` steps worth of action (from the current observation). + start = n_obs_steps - 1 + end = start + self.n_action_steps + action = action[:, start:end] + + return action + + def compute_loss(self, batch: dict[str, Tensor]) -> Tensor: + """ + This function expects `batch` to have (at least): + { + "observation.state": (B, n_obs_steps, state_dim) + "observation.image": (B, n_obs_steps, C, H, W) + "action": (B, horizon, action_dim) + "action_is_pad": (B, horizon) # TODO(now) maybe this is (B, horizon, 1) + } + """ + assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"}) + batch_size, n_obs_steps = batch["observation.state"].shape[:2] + horizon = batch["action"].shape[1] + assert horizon == self.horizon + assert n_obs_steps == self.n_obs_steps + assert self.n_obs_steps == n_obs_steps + + # handle different ways of passing observation + local_cond = None + global_cond = None + trajectory = batch["action"] + cond_data = trajectory + + # Extract image feature (first combine batch and sequence dims). + img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) + # Separate batch and sequence dims. + img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size) + # Concatenate state and image features then flatten to (B, global_cond_dim). + global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) + + # Sample noise that we'll add to the images + noise = torch.randn(trajectory.shape, device=trajectory.device) + # Sample a random timestep for each image + timesteps = torch.randint( + 0, + self.noise_scheduler.config.num_train_timesteps, + (trajectory.shape[0],), + device=trajectory.device, + ).long() + # Add noise to the clean images according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps) + + # Apply inpainting. TODO(now): implement? + inpainting_mask = torch.zeros_like(trajectory, dtype=bool) + noisy_trajectory[inpainting_mask] = cond_data[inpainting_mask] + + # Predict the noise residual + pred = self.unet(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond) + + pred_type = self.noise_scheduler.config.prediction_type + if pred_type == "epsilon": + target = noise + elif pred_type == "sample": + target = trajectory + else: + raise ValueError(f"Unsupported prediction type {pred_type}") + + loss = F.mse_loss(pred, target, reduction="none") + loss = loss * (~inpainting_mask) + + if "action_is_pad" in batch: + in_episode_bound = ~batch["action_is_pad"] + loss = loss * in_episode_bound[:, :, None].type(loss.dtype) + + return loss.mean() diff --git a/lerobot/common/policies/diffusion/model/ema_model.py b/lerobot/common/policies/diffusion/model/ema_model.py index 6dc128de..3cb1dfbd 100644 --- a/lerobot/common/policies/diffusion/model/ema_model.py +++ b/lerobot/common/policies/diffusion/model/ema_model.py @@ -51,13 +51,6 @@ class EMAModel: def step(self, new_model): self.decay = self.get_decay(self.optimization_step) - # old_all_dataptrs = set() - # for param in new_model.parameters(): - # data_ptr = param.data_ptr() - # if data_ptr != 0: - # old_all_dataptrs.add(data_ptr) - - # all_dataptrs = set() for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=False): for param, ema_param in zip( module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=False @@ -66,10 +59,6 @@ class EMAModel: if isinstance(param, dict): raise RuntimeError("Dict parameter not supported") - # data_ptr = param.data_ptr() - # if data_ptr != 0: - # all_dataptrs.add(data_ptr) - if isinstance(module, _BatchNorm): # skip batchnorms ema_param.copy_(param.to(dtype=ema_param.dtype).data) @@ -79,6 +68,4 @@ class EMAModel: ema_param.mul_(self.decay) ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay) - # verify that iterating over module and then parameters is identical to parameters recursively. - # assert old_all_dataptrs == all_dataptrs self.optimization_step += 1 diff --git a/lerobot/common/policies/diffusion/model/lr_scheduler.py b/lerobot/common/policies/diffusion/model/lr_scheduler.py deleted file mode 100644 index 084b3a36..00000000 --- a/lerobot/common/policies/diffusion/model/lr_scheduler.py +++ /dev/null @@ -1,46 +0,0 @@ -from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, Optimizer, Optional, SchedulerType, Union - - -def get_scheduler( - name: Union[str, SchedulerType], - optimizer: Optimizer, - num_warmup_steps: Optional[int] = None, - num_training_steps: Optional[int] = None, - **kwargs, -): - """ - Added kwargs vs diffuser's original implementation - - Unified API to get any scheduler from its name. - - Args: - name (`str` or `SchedulerType`): - The name of the scheduler to use. - optimizer (`torch.optim.Optimizer`): - The optimizer that will be used during training. - num_warmup_steps (`int`, *optional*): - The number of warmup steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. - num_training_steps (`int``, *optional*): - The number of training steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. - """ - name = SchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] - if name == SchedulerType.CONSTANT: - return schedule_func(optimizer, **kwargs) - - # All other schedulers require `num_warmup_steps` - if num_warmup_steps is None: - raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") - - if name == SchedulerType.CONSTANT_WITH_WARMUP: - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs) - - # All other schedulers require `num_training_steps` - if num_training_steps is None: - raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") - - return schedule_func( - optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs - ) diff --git a/lerobot/common/policies/diffusion/model/mask_generator.py b/lerobot/common/policies/diffusion/model/mask_generator.py deleted file mode 100644 index 63306dea..00000000 --- a/lerobot/common/policies/diffusion/model/mask_generator.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch - -from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin - - -class LowdimMaskGenerator(ModuleAttrMixin): - def __init__( - self, - action_dim, - obs_dim, - # obs mask setup - max_n_obs_steps=2, - fix_obs_steps=True, - # action mask - action_visible=False, - ): - super().__init__() - self.action_dim = action_dim - self.obs_dim = obs_dim - self.max_n_obs_steps = max_n_obs_steps - self.fix_obs_steps = fix_obs_steps - self.action_visible = action_visible - - @torch.no_grad() - def forward(self, shape, seed=None): - device = self.device - B, T, D = shape # noqa: N806 - assert (self.action_dim + self.obs_dim) == D - - # create all tensors on this device - rng = torch.Generator(device=device) - if seed is not None: - rng = rng.manual_seed(seed) - - # generate dim mask - dim_mask = torch.zeros(size=shape, dtype=torch.bool, device=device) - is_action_dim = dim_mask.clone() - is_action_dim[..., : self.action_dim] = True - is_obs_dim = ~is_action_dim - - # generate obs mask - if self.fix_obs_steps: - obs_steps = torch.full((B,), fill_value=self.max_n_obs_steps, device=device) - else: - obs_steps = torch.randint( - low=1, high=self.max_n_obs_steps + 1, size=(B,), generator=rng, device=device - ) - - steps = torch.arange(0, T, device=device).reshape(1, T).expand(B, T) - obs_mask = (obs_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D) - obs_mask = obs_mask & is_obs_dim - - # generate action mask - if self.action_visible: - action_steps = torch.maximum( - obs_steps - 1, torch.tensor(0, dtype=obs_steps.dtype, device=obs_steps.device) - ) - action_mask = (action_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D) - action_mask = action_mask & is_action_dim - - mask = obs_mask - if self.action_visible: - mask = mask | action_mask - - return mask diff --git a/lerobot/common/policies/diffusion/model/module_attr_mixin.py b/lerobot/common/policies/diffusion/model/module_attr_mixin.py deleted file mode 100644 index 5d2cf4ea..00000000 --- a/lerobot/common/policies/diffusion/model/module_attr_mixin.py +++ /dev/null @@ -1,15 +0,0 @@ -import torch.nn as nn - - -class ModuleAttrMixin(nn.Module): - def __init__(self): - super().__init__() - self._dummy_variable = nn.Parameter() - - @property - def device(self): - return next(iter(self.parameters())).device - - @property - def dtype(self): - return next(iter(self.parameters())).dtype diff --git a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py deleted file mode 100644 index d724cd49..00000000 --- a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py +++ /dev/null @@ -1,214 +0,0 @@ -import copy -from typing import Dict, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torchvision -from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax - -from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer -from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin -from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules - - -class RgbEncoder(nn.Module): - """Following `VisualCore` from Robomimic 0.2.0.""" - - def __init__(self, input_shape, relu=True, pretrained=False, num_keypoints=32): - """ - input_shape: channel-first input shape (C, H, W) - resnet_name: a timm model name. - pretrained: whether to use timm pretrained weights. - relu: whether to use relu as a final step. - num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image). - """ - super().__init__() - self.backbone = ResNet18Conv(input_channel=input_shape[0], pretrained=pretrained) - # Figure out the feature map shape. - with torch.inference_mode(): - feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:]) - self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints) - self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2) - self.relu = nn.ReLU() if relu else nn.Identity() - - def forward(self, x): - return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))) - - -class MultiImageObsEncoder(ModuleAttrMixin): - def __init__( - self, - shape_meta: dict, - rgb_model: Union[nn.Module, Dict[str, nn.Module]], - resize_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None, - crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None, - random_crop: bool = True, - # replace BatchNorm with GroupNorm - use_group_norm: bool = False, - # use single rgb model for all rgb inputs - share_rgb_model: bool = False, - # renormalize rgb input with imagenet normalization - # assuming input in [0,1] - norm_mean_std: Optional[tuple[float, float]] = None, - ): - """ - Assumes rgb input: B,C,H,W - Assumes low_dim input: B,D - """ - super().__init__() - - rgb_keys = [] - low_dim_keys = [] - key_model_map = nn.ModuleDict() - key_transform_map = nn.ModuleDict() - key_shape_map = {} - - # handle sharing vision backbone - if share_rgb_model: - assert isinstance(rgb_model, nn.Module) - key_model_map["rgb"] = rgb_model - - obs_shape_meta = shape_meta["obs"] - for key, attr in obs_shape_meta.items(): - shape = tuple(attr["shape"]) - type = attr.get("type", "low_dim") - key_shape_map[key] = shape - if type == "rgb": - rgb_keys.append(key) - # configure model for this key - this_model = None - if not share_rgb_model: - if isinstance(rgb_model, dict): - # have provided model for each key - this_model = rgb_model[key] - else: - assert isinstance(rgb_model, nn.Module) - # have a copy of the rgb model - this_model = copy.deepcopy(rgb_model) - - if this_model is not None: - if use_group_norm: - this_model = replace_submodules( - root_module=this_model, - predicate=lambda x: isinstance(x, nn.BatchNorm2d), - func=lambda x: nn.GroupNorm( - num_groups=x.num_features // 16, num_channels=x.num_features - ), - ) - key_model_map[key] = this_model - - # configure resize - input_shape = shape - this_resizer = nn.Identity() - if resize_shape is not None: - if isinstance(resize_shape, dict): - h, w = resize_shape[key] - else: - h, w = resize_shape - this_resizer = torchvision.transforms.Resize(size=(h, w)) - input_shape = (shape[0], h, w) - - # configure randomizer - this_randomizer = nn.Identity() - if crop_shape is not None: - if isinstance(crop_shape, dict): - h, w = crop_shape[key] - else: - h, w = crop_shape - if random_crop: - this_randomizer = CropRandomizer( - input_shape=input_shape, crop_height=h, crop_width=w, num_crops=1, pos_enc=False - ) - else: - this_normalizer = torchvision.transforms.CenterCrop(size=(h, w)) - # configure normalizer - this_normalizer = nn.Identity() - if norm_mean_std is not None: - this_normalizer = torchvision.transforms.Normalize( - mean=norm_mean_std[0], std=norm_mean_std[1] - ) - - this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer) - key_transform_map[key] = this_transform - elif type == "low_dim": - low_dim_keys.append(key) - else: - raise RuntimeError(f"Unsupported obs type: {type}") - rgb_keys = sorted(rgb_keys) - low_dim_keys = sorted(low_dim_keys) - - self.shape_meta = shape_meta - self.key_model_map = key_model_map - self.key_transform_map = key_transform_map - self.share_rgb_model = share_rgb_model - self.rgb_keys = rgb_keys - self.low_dim_keys = low_dim_keys - self.key_shape_map = key_shape_map - - def forward(self, obs_dict): - batch_size = None - features = [] - - # process lowdim input - for key in self.low_dim_keys: - data = obs_dict[key] - if batch_size is None: - batch_size = data.shape[0] - else: - assert batch_size == data.shape[0] - assert data.shape[1:] == self.key_shape_map[key] - features.append(data) - - # process rgb input - if self.share_rgb_model: - # pass all rgb obs to rgb model - imgs = [] - for key in self.rgb_keys: - img = obs_dict[key] - if batch_size is None: - batch_size = img.shape[0] - else: - assert batch_size == img.shape[0] - assert img.shape[1:] == self.key_shape_map[key] - img = self.key_transform_map[key](img) - imgs.append(img) - # (N*B,C,H,W) - imgs = torch.cat(imgs, dim=0) - # (N*B,D) - feature = self.key_model_map["rgb"](imgs) - # (N,B,D) - feature = feature.reshape(-1, batch_size, *feature.shape[1:]) - # (B,N,D) - feature = torch.moveaxis(feature, 0, 1) - # (B,N*D) - feature = feature.reshape(batch_size, -1) - features.append(feature) - else: - # run each rgb obs to independent models - for key in self.rgb_keys: - img = obs_dict[key] - if batch_size is None: - batch_size = img.shape[0] - else: - assert batch_size == img.shape[0] - assert img.shape[1:] == self.key_shape_map[key] - img = self.key_transform_map[key](img) - feature = self.key_model_map[key](img) - features.append(feature) - - # concatenate all features - result = torch.cat(features, dim=-1) - return result - - @torch.no_grad() - def output_shape(self): - example_obs_dict = {} - obs_shape_meta = self.shape_meta["obs"] - batch_size = 1 - for key, attr in obs_shape_meta.items(): - shape = tuple(attr["shape"]) - this_obs = torch.zeros((batch_size,) + shape, dtype=self.dtype, device=self.device) - example_obs_dict[key] = this_obs - example_output = self.forward(example_obs_dict) - output_shape = example_output.shape[1:] - return output_shape diff --git a/lerobot/common/policies/diffusion/model/normalizer.py b/lerobot/common/policies/diffusion/model/normalizer.py deleted file mode 100644 index 0e4d79ab..00000000 --- a/lerobot/common/policies/diffusion/model/normalizer.py +++ /dev/null @@ -1,358 +0,0 @@ -from typing import Dict, Union - -import numpy as np -import torch -import torch.nn as nn -import zarr - -from lerobot.common.policies.diffusion.model.dict_of_tensor_mixin import DictOfTensorMixin -from lerobot.common.policies.diffusion.pytorch_utils import dict_apply - - -class LinearNormalizer(DictOfTensorMixin): - avaliable_modes = ["limits", "gaussian"] - - @torch.no_grad() - def fit( - self, - data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array], - last_n_dims=1, - dtype=torch.float32, - mode="limits", - output_max=1.0, - output_min=-1.0, - range_eps=1e-4, - fit_offset=True, - ): - if isinstance(data, dict): - for key, value in data.items(): - self.params_dict[key] = _fit( - value, - last_n_dims=last_n_dims, - dtype=dtype, - mode=mode, - output_max=output_max, - output_min=output_min, - range_eps=range_eps, - fit_offset=fit_offset, - ) - else: - self.params_dict["_default"] = _fit( - data, - last_n_dims=last_n_dims, - dtype=dtype, - mode=mode, - output_max=output_max, - output_min=output_min, - range_eps=range_eps, - fit_offset=fit_offset, - ) - - def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor: - return self.normalize(x) - - def __getitem__(self, key: str): - return SingleFieldLinearNormalizer(self.params_dict[key]) - - def __setitem__(self, key: str, value: "SingleFieldLinearNormalizer"): - self.params_dict[key] = value.params_dict - - def _normalize_impl(self, x, forward=True): - if isinstance(x, dict): - result = {} - for key, value in x.items(): - params = self.params_dict[key] - result[key] = _normalize(value, params, forward=forward) - return result - else: - if "_default" not in self.params_dict: - raise RuntimeError("Not initialized") - params = self.params_dict["_default"] - return _normalize(x, params, forward=forward) - - def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor: - return self._normalize_impl(x, forward=True) - - def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor: - return self._normalize_impl(x, forward=False) - - def get_input_stats(self) -> Dict: - if len(self.params_dict) == 0: - raise RuntimeError("Not initialized") - if len(self.params_dict) == 1 and "_default" in self.params_dict: - return self.params_dict["_default"]["input_stats"] - - result = {} - for key, value in self.params_dict.items(): - if key != "_default": - result[key] = value["input_stats"] - return result - - def get_output_stats(self, key="_default"): - input_stats = self.get_input_stats() - if "min" in input_stats: - # no dict - return dict_apply(input_stats, self.normalize) - - result = {} - for key, group in input_stats.items(): - this_dict = {} - for name, value in group.items(): - this_dict[name] = self.normalize({key: value})[key] - result[key] = this_dict - return result - - -class SingleFieldLinearNormalizer(DictOfTensorMixin): - avaliable_modes = ["limits", "gaussian"] - - @torch.no_grad() - def fit( - self, - data: Union[torch.Tensor, np.ndarray, zarr.Array], - last_n_dims=1, - dtype=torch.float32, - mode="limits", - output_max=1.0, - output_min=-1.0, - range_eps=1e-4, - fit_offset=True, - ): - self.params_dict = _fit( - data, - last_n_dims=last_n_dims, - dtype=dtype, - mode=mode, - output_max=output_max, - output_min=output_min, - range_eps=range_eps, - fit_offset=fit_offset, - ) - - @classmethod - def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs): - obj = cls() - obj.fit(data, **kwargs) - return obj - - @classmethod - def create_manual( - cls, - scale: Union[torch.Tensor, np.ndarray], - offset: Union[torch.Tensor, np.ndarray], - input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]], - ): - def to_tensor(x): - if not isinstance(x, torch.Tensor): - x = torch.from_numpy(x) - x = x.flatten() - return x - - # check - for x in [offset] + list(input_stats_dict.values()): - assert x.shape == scale.shape - assert x.dtype == scale.dtype - - params_dict = nn.ParameterDict( - { - "scale": to_tensor(scale), - "offset": to_tensor(offset), - "input_stats": nn.ParameterDict(dict_apply(input_stats_dict, to_tensor)), - } - ) - return cls(params_dict) - - @classmethod - def create_identity(cls, dtype=torch.float32): - scale = torch.tensor([1], dtype=dtype) - offset = torch.tensor([0], dtype=dtype) - input_stats_dict = { - "min": torch.tensor([-1], dtype=dtype), - "max": torch.tensor([1], dtype=dtype), - "mean": torch.tensor([0], dtype=dtype), - "std": torch.tensor([1], dtype=dtype), - } - return cls.create_manual(scale, offset, input_stats_dict) - - def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: - return _normalize(x, self.params_dict, forward=True) - - def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: - return _normalize(x, self.params_dict, forward=False) - - def get_input_stats(self): - return self.params_dict["input_stats"] - - def get_output_stats(self): - return dict_apply(self.params_dict["input_stats"], self.normalize) - - def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: - return self.normalize(x) - - -def _fit( - data: Union[torch.Tensor, np.ndarray, zarr.Array], - last_n_dims=1, - dtype=torch.float32, - mode="limits", - output_max=1.0, - output_min=-1.0, - range_eps=1e-4, - fit_offset=True, -): - assert mode in ["limits", "gaussian"] - assert last_n_dims >= 0 - assert output_max > output_min - - # convert data to torch and type - if isinstance(data, zarr.Array): - data = data[:] - if isinstance(data, np.ndarray): - data = torch.from_numpy(data) - if dtype is not None: - data = data.type(dtype) - - # convert shape - dim = 1 - if last_n_dims > 0: - dim = np.prod(data.shape[-last_n_dims:]) - data = data.reshape(-1, dim) - - # compute input stats min max mean std - input_min, _ = data.min(axis=0) - input_max, _ = data.max(axis=0) - input_mean = data.mean(axis=0) - input_std = data.std(axis=0) - - # compute scale and offset - if mode == "limits": - if fit_offset: - # unit scale - input_range = input_max - input_min - ignore_dim = input_range < range_eps - input_range[ignore_dim] = output_max - output_min - scale = (output_max - output_min) / input_range - offset = output_min - scale * input_min - offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim] - # ignore dims scaled to mean of output max and min - else: - # use this when data is pre-zero-centered. - assert output_max > 0 - assert output_min < 0 - # unit abs - output_abs = min(abs(output_min), abs(output_max)) - input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max)) - ignore_dim = input_abs < range_eps - input_abs[ignore_dim] = output_abs - # don't scale constant channels - scale = output_abs / input_abs - offset = torch.zeros_like(input_mean) - elif mode == "gaussian": - ignore_dim = input_std < range_eps - scale = input_std.clone() - scale[ignore_dim] = 1 - scale = 1 / scale - - offset = -input_mean * scale if fit_offset else torch.zeros_like(input_mean) - - # save - this_params = nn.ParameterDict( - { - "scale": scale, - "offset": offset, - "input_stats": nn.ParameterDict( - {"min": input_min, "max": input_max, "mean": input_mean, "std": input_std} - ), - } - ) - for p in this_params.parameters(): - p.requires_grad_(False) - return this_params - - -def _normalize(x, params, forward=True): - assert "scale" in params - if isinstance(x, np.ndarray): - x = torch.from_numpy(x) - scale = params["scale"] - offset = params["offset"] - x = x.to(device=scale.device, dtype=scale.dtype) - src_shape = x.shape - x = x.reshape(-1, scale.shape[0]) - x = x * scale + offset if forward else (x - offset) / scale - x = x.reshape(src_shape) - return x - - -def test(): - data = torch.zeros((100, 10, 9, 2)).uniform_() - data[..., 0, 0] = 0 - - normalizer = SingleFieldLinearNormalizer() - normalizer.fit(data, mode="limits", last_n_dims=2) - datan = normalizer.normalize(data) - assert datan.shape == data.shape - assert np.allclose(datan.max(), 1.0) - assert np.allclose(datan.min(), -1.0) - dataun = normalizer.unnormalize(datan) - assert torch.allclose(data, dataun, atol=1e-7) - - _ = normalizer.get_input_stats() - _ = normalizer.get_output_stats() - - normalizer = SingleFieldLinearNormalizer() - normalizer.fit(data, mode="limits", last_n_dims=1, fit_offset=False) - datan = normalizer.normalize(data) - assert datan.shape == data.shape - assert np.allclose(datan.max(), 1.0, atol=1e-3) - assert np.allclose(datan.min(), 0.0, atol=1e-3) - dataun = normalizer.unnormalize(datan) - assert torch.allclose(data, dataun, atol=1e-7) - - data = torch.zeros((100, 10, 9, 2)).uniform_() - normalizer = SingleFieldLinearNormalizer() - normalizer.fit(data, mode="gaussian", last_n_dims=0) - datan = normalizer.normalize(data) - assert datan.shape == data.shape - assert np.allclose(datan.mean(), 0.0, atol=1e-3) - assert np.allclose(datan.std(), 1.0, atol=1e-3) - dataun = normalizer.unnormalize(datan) - assert torch.allclose(data, dataun, atol=1e-7) - - # dict - data = torch.zeros((100, 10, 9, 2)).uniform_() - data[..., 0, 0] = 0 - - normalizer = LinearNormalizer() - normalizer.fit(data, mode="limits", last_n_dims=2) - datan = normalizer.normalize(data) - assert datan.shape == data.shape - assert np.allclose(datan.max(), 1.0) - assert np.allclose(datan.min(), -1.0) - dataun = normalizer.unnormalize(datan) - assert torch.allclose(data, dataun, atol=1e-7) - - _ = normalizer.get_input_stats() - _ = normalizer.get_output_stats() - - data = { - "obs": torch.zeros((1000, 128, 9, 2)).uniform_() * 512, - "action": torch.zeros((1000, 128, 2)).uniform_() * 512, - } - normalizer = LinearNormalizer() - normalizer.fit(data) - datan = normalizer.normalize(data) - dataun = normalizer.unnormalize(datan) - for key in data: - assert torch.allclose(data[key], dataun[key], atol=1e-4) - - _ = normalizer.get_input_stats() - _ = normalizer.get_output_stats() - - state_dict = normalizer.state_dict() - n = LinearNormalizer() - n.load_state_dict(state_dict) - datan = n.normalize(data) - dataun = n.unnormalize(datan) - for key in data: - assert torch.allclose(data[key], dataun[key], atol=1e-4) diff --git a/lerobot/common/policies/diffusion/model/positional_embedding.py b/lerobot/common/policies/diffusion/model/positional_embedding.py deleted file mode 100644 index 65fc97bd..00000000 --- a/lerobot/common/policies/diffusion/model/positional_embedding.py +++ /dev/null @@ -1,19 +0,0 @@ -import math - -import torch -import torch.nn as nn - - -class SinusoidalPosEmb(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, x): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = x[:, None] * emb[None, :] - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb diff --git a/lerobot/common/policies/diffusion/model/rgb_encoder.py b/lerobot/common/policies/diffusion/model/rgb_encoder.py new file mode 100644 index 00000000..2a5edf45 --- /dev/null +++ b/lerobot/common/policies/diffusion/model/rgb_encoder.py @@ -0,0 +1,147 @@ +from typing import Callable + +import torch +import torchvision +from robomimic.models.base_nets import SpatialSoftmax +from torch import Tensor, nn +from torchvision.transforms import CenterCrop, RandomCrop + + +class RgbEncoder(nn.Module): + """Encoder an RGB image into a 1D feature vector. + + Includes the ability to normalize and crop the image first. + """ + + def __init__( + self, + input_shape: tuple[int, int, int], + norm_mean_std: tuple[float, float] = [1.0, 1.0], + crop_shape: tuple[int, int] | None = None, + random_crop: bool = False, + backbone_name: str = "resnet18", + pretrained_backbone: bool = False, + use_group_norm: bool = False, + relu: bool = True, + num_keypoints: int = 32, + ): + """ + Args: + input_shape: channel-first input shape (C, H, W) + norm_mean_std: mean and standard deviation used for image normalization. Images are normalized as + (image - mean) / std. + crop_shape: (H, W) shape to crop to (must fit within the input shape). If not provided, no + cropping is done. + random_crop: Whether the crop should be random at training time (it's always a center crop in + eval mode). + backbone_name: The name of one of the available resnet models from torchvision (eg resnet18). + pretrained_backbone: whether to use timm pretrained weights. + use_group_norm: Whether to replace batch normalization with group normalization in the backbone. + The group sizes are set to be about 16 (to be precise, feature_dim // 16). + relu: whether to use relu as a final step. + num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image). + """ + super().__init__() + if input_shape[0] != 3: + raise ValueError("Only RGB images are handled") + if not backbone_name.startswith("resnet"): + raise ValueError( + "Only resnet is supported for now (because of the assumption that 'layer4' is the output layer)" + ) + + # Set up optional preprocessing. + if norm_mean_std == [1.0, 1.0]: + self.normalizer = nn.Identity() + else: + self.normalizer = torchvision.transforms.Normalize(mean=norm_mean_std[0], std=norm_mean_std[1]) + + if crop_shape is not None: + self.do_crop = True + self.center_crop = CenterCrop(crop_shape) # always use center crop for eval + if random_crop: + self.maybe_random_crop = RandomCrop(crop_shape) + else: + self.maybe_random_crop = self.center_crop + else: + self.do_crop = False + + # Set up backbone. + backbone_model = getattr(torchvision.models, backbone_name)(pretrained=pretrained_backbone) + # Note: This assumes that the layer4 feature map is children()[-3] + # TODO(alexander-soare): Use a safer alternative. + self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2])) + if use_group_norm: + if pretrained_backbone: + raise ValueError( + "You can't replace BatchNorm in a pretrained model without ruining the weights!" + ) + self.backbone = _replace_submodules( + root_module=self.backbone, + predicate=lambda x: isinstance(x, nn.BatchNorm2d), + func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features), + ) + + # Set up pooling and final layers. + # Use a dry run to get the feature map shape. + with torch.inference_mode(): + feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:]) + self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints) + self.feature_dim = num_keypoints * 2 + self.out = nn.Linear(num_keypoints * 2, self.feature_dim) + self.maybe_relu = nn.ReLU() if relu else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: (B, C, H, W) image tensor with pixel values in [0, 1]. + Returns: + (B, D) image feature. + """ + # Preprocess: normalize and maybe crop (if it was set up in the __init__). + x = self.normalizer(x) + if self.do_crop: + if self.training: # noqa: SIM108 + x = self.maybe_random_crop(x) + else: + # Always use center crop for eval. + x = self.center_crop(x) + # Extract backbone feature. + x = torch.flatten(self.pool(self.backbone(x)), start_dim=1) + # Final linear layer. + x = self.out(x) + # Maybe a final non-linearity. + x = self.maybe_relu(x) + return x + + +def _replace_submodules( + root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module] +) -> nn.Module: + """ + Args: + root_module: The module for which the submodules need to be replaced + predicate: Takes a module as an argument and must return True if the that module is to be replaced. + func: Takes a module as an argument and returns a new module to replace it with. + Returns: + The root module with its submodules replaced. + """ + if predicate(root_module): + return func(root_module) + + replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] + for *parents, k in replace_list: + parent_module = root_module + if len(parents) > 0: + parent_module = root_module.get_submodule(".".join(parents)) + if isinstance(parent_module, nn.Sequential): + src_module = parent_module[int(k)] + else: + src_module = getattr(parent_module, k) + tgt_module = func(src_module) + if isinstance(parent_module, nn.Sequential): + parent_module[int(k)] = tgt_module + else: + setattr(parent_module, k, tgt_module) + # verify that all BN are replaced + assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)) + return root_module diff --git a/lerobot/common/policies/diffusion/model/tensor_utils.py b/lerobot/common/policies/diffusion/model/tensor_utils.py deleted file mode 100644 index df9a568a..00000000 --- a/lerobot/common/policies/diffusion/model/tensor_utils.py +++ /dev/null @@ -1,972 +0,0 @@ -""" -A collection of utilities for working with nested tensor structures consisting -of numpy arrays and torch tensors. -""" - -import collections - -import numpy as np -import torch - - -def recursive_dict_list_tuple_apply(x, type_func_dict): - """ - Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of - {data_type: function_to_apply}. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - type_func_dict (dict): a mapping from data types to the functions to be - applied for each data type. - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - assert list not in type_func_dict - assert tuple not in type_func_dict - assert dict not in type_func_dict - - if isinstance(x, (dict, collections.OrderedDict)): - new_x = collections.OrderedDict() if isinstance(x, collections.OrderedDict) else {} - for k, v in x.items(): - new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict) - return new_x - elif isinstance(x, (list, tuple)): - ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x] - if isinstance(x, tuple): - ret = tuple(ret) - return ret - else: - for t, f in type_func_dict.items(): - if isinstance(x, t): - return f(x) - else: - raise NotImplementedError("Cannot handle data type %s" % str(type(x))) - - -def map_tensor(x, func): - """ - Apply function @func to torch.Tensor objects in a nested dictionary or - list or tuple. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - func (function): function to apply to each tensor - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: func, - type(None): lambda x: x, - }, - ) - - -def map_ndarray(x, func): - """ - Apply function @func to np.ndarray objects in a nested dictionary or - list or tuple. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - func (function): function to apply to each array - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - np.ndarray: func, - type(None): lambda x: x, - }, - ) - - -def map_tensor_ndarray(x, tensor_func, ndarray_func): - """ - Apply function @tensor_func to torch.Tensor objects and @ndarray_func to - np.ndarray objects in a nested dictionary or list or tuple. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - tensor_func (function): function to apply to each tensor - ndarray_Func (function): function to apply to each array - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: tensor_func, - np.ndarray: ndarray_func, - type(None): lambda x: x, - }, - ) - - -def clone(x): - """ - Clones all torch tensors and numpy arrays in nested dictionary or list - or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.clone(), - np.ndarray: lambda x: x.copy(), - type(None): lambda x: x, - }, - ) - - -def detach(x): - """ - Detaches all torch tensors in nested dictionary or list - or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.detach(), - }, - ) - - -def to_batch(x): - """ - Introduces a leading batch dimension of 1 for all torch tensors and numpy - arrays in nested dictionary or list or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x[None, ...], - np.ndarray: lambda x: x[None, ...], - type(None): lambda x: x, - }, - ) - - -def to_sequence(x): - """ - Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy - arrays in nested dictionary or list or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x[:, None, ...], - np.ndarray: lambda x: x[:, None, ...], - type(None): lambda x: x, - }, - ) - - -def index_at_time(x, ind): - """ - Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in - nested dictionary or list or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - ind (int): index - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x[:, ind, ...], - np.ndarray: lambda x: x[:, ind, ...], - type(None): lambda x: x, - }, - ) - - -def unsqueeze(x, dim): - """ - Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays - in nested dictionary or list or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - dim (int): dimension - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.unsqueeze(dim=dim), - np.ndarray: lambda x: np.expand_dims(x, axis=dim), - type(None): lambda x: x, - }, - ) - - -def contiguous(x): - """ - Makes all torch tensors and numpy arrays contiguous in nested dictionary or - list or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.contiguous(), - np.ndarray: lambda x: np.ascontiguousarray(x), - type(None): lambda x: x, - }, - ) - - -def to_device(x, device): - """ - Sends all torch tensors in nested dictionary or list or tuple to device - @device, and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - device (torch.Device): device to send tensors to - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x, d=device: x.to(d), - type(None): lambda x: x, - }, - ) - - -def to_tensor(x): - """ - Converts all numpy arrays in nested dictionary or list or tuple to - torch tensors (and leaves existing torch Tensors as-is), and returns - a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x, - np.ndarray: lambda x: torch.from_numpy(x), - type(None): lambda x: x, - }, - ) - - -def to_numpy(x): - """ - Converts all torch tensors in nested dictionary or list or tuple to - numpy (and leaves existing numpy arrays as-is), and returns - a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - - def f(tensor): - if tensor.is_cuda: - return tensor.detach().cpu().numpy() - else: - return tensor.detach().numpy() - - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: f, - np.ndarray: lambda x: x, - type(None): lambda x: x, - }, - ) - - -def to_list(x): - """ - Converts all torch tensors and numpy arrays in nested dictionary or list - or tuple to a list, and returns a new nested structure. Useful for - json encoding. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - - def f(tensor): - if tensor.is_cuda: - return tensor.detach().cpu().numpy().tolist() - else: - return tensor.detach().numpy().tolist() - - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: f, - np.ndarray: lambda x: x.tolist(), - type(None): lambda x: x, - }, - ) - - -def to_float(x): - """ - Converts all torch tensors and numpy arrays in nested dictionary or list - or tuple to float type entries, and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.float(), - np.ndarray: lambda x: x.astype(np.float32), - type(None): lambda x: x, - }, - ) - - -def to_uint8(x): - """ - Converts all torch tensors and numpy arrays in nested dictionary or list - or tuple to uint8 type entries, and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.byte(), - np.ndarray: lambda x: x.astype(np.uint8), - type(None): lambda x: x, - }, - ) - - -def to_torch(x, device): - """ - Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to - torch tensors on device @device and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - device (torch.Device): device to send tensors to - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return to_device(to_float(to_tensor(x)), device) - - -def to_one_hot_single(tensor, num_class): - """ - Convert tensor to one-hot representation, assuming a certain number of total class labels. - - Args: - tensor (torch.Tensor): tensor containing integer labels - num_class (int): number of classes - - Returns: - x (torch.Tensor): tensor containing one-hot representation of labels - """ - x = torch.zeros(tensor.size() + (num_class,)).to(tensor.device) - x.scatter_(-1, tensor.unsqueeze(-1), 1) - return x - - -def to_one_hot(tensor, num_class): - """ - Convert all tensors in nested dictionary or list or tuple to one-hot representation, - assuming a certain number of total class labels. - - Args: - tensor (dict or list or tuple): a possibly nested dictionary or list or tuple - num_class (int): number of classes - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return map_tensor(tensor, func=lambda x, nc=num_class: to_one_hot_single(x, nc)) - - -def flatten_single(x, begin_axis=1): - """ - Flatten a tensor in all dimensions from @begin_axis onwards. - - Args: - x (torch.Tensor): tensor to flatten - begin_axis (int): which axis to flatten from - - Returns: - y (torch.Tensor): flattened tensor - """ - fixed_size = x.size()[:begin_axis] - _s = list(fixed_size) + [-1] - return x.reshape(*_s) - - -def flatten(x, begin_axis=1): - """ - Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - begin_axis (int): which axis to flatten from - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b), - }, - ) - - -def reshape_dimensions_single(x, begin_axis, end_axis, target_dims): - """ - Reshape selected dimensions in a tensor to a target dimension. - - Args: - x (torch.Tensor): tensor to reshape - begin_axis (int): begin dimension - end_axis (int): end dimension - target_dims (tuple or list): target shape for the range of dimensions - (@begin_axis, @end_axis) - - Returns: - y (torch.Tensor): reshaped tensor - """ - assert begin_axis <= end_axis - assert begin_axis >= 0 - assert end_axis < len(x.shape) - assert isinstance(target_dims, (tuple, list)) - s = x.shape - final_s = [] - for i in range(len(s)): - if i == begin_axis: - final_s.extend(target_dims) - elif i < begin_axis or i > end_axis: - final_s.append(s[i]) - return x.reshape(*final_s) - - -def reshape_dimensions(x, begin_axis, end_axis, target_dims): - """ - Reshape selected dimensions for all tensors in nested dictionary or list or tuple - to a target dimension. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - begin_axis (int): begin dimension - end_axis (int): end dimension - target_dims (tuple or list): target shape for the range of dimensions - (@begin_axis, @end_axis) - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single( - x, begin_axis=b, end_axis=e, target_dims=t - ), - np.ndarray: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single( - x, begin_axis=b, end_axis=e, target_dims=t - ), - type(None): lambda x: x, - }, - ) - - -def join_dimensions(x, begin_axis, end_axis): - """ - Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for - all tensors in nested dictionary or list or tuple. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - begin_axis (int): begin dimension - end_axis (int): end dimension - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single( - x, begin_axis=b, end_axis=e, target_dims=[-1] - ), - np.ndarray: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single( - x, begin_axis=b, end_axis=e, target_dims=[-1] - ), - type(None): lambda x: x, - }, - ) - - -def expand_at_single(x, size, dim): - """ - Expand a tensor at a single dimension @dim by @size - - Args: - x (torch.Tensor): input tensor - size (int): size to expand - dim (int): dimension to expand - - Returns: - y (torch.Tensor): expanded tensor - """ - assert dim < x.ndimension() - assert x.shape[dim] == 1 - expand_dims = [-1] * x.ndimension() - expand_dims[dim] = size - return x.expand(*expand_dims) - - -def expand_at(x, size, dim): - """ - Expand all tensors in nested dictionary or list or tuple at a single - dimension @dim by @size. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - size (int): size to expand - dim (int): dimension to expand - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d)) - - -def unsqueeze_expand_at(x, size, dim): - """ - Unsqueeze and expand a tensor at a dimension @dim by @size. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - size (int): size to expand - dim (int): dimension to unsqueeze and expand - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - x = unsqueeze(x, dim) - return expand_at(x, size, dim) - - -def repeat_by_expand_at(x, repeats, dim): - """ - Repeat a dimension by combining expand and reshape operations. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - repeats (int): number of times to repeat the target dimension - dim (int): dimension to repeat on - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - x = unsqueeze_expand_at(x, repeats, dim + 1) - return join_dimensions(x, dim, dim + 1) - - -def named_reduce_single(x, reduction, dim): - """ - Reduce tensor at a dimension by named reduction functions. - - Args: - x (torch.Tensor): tensor to be reduced - reduction (str): one of ["sum", "max", "mean", "flatten"] - dim (int): dimension to be reduced (or begin axis for flatten) - - Returns: - y (torch.Tensor): reduced tensor - """ - assert x.ndimension() > dim - assert reduction in ["sum", "max", "mean", "flatten"] - if reduction == "flatten": - x = flatten(x, begin_axis=dim) - elif reduction == "max": - x = torch.max(x, dim=dim)[0] # [B, D] - elif reduction == "sum": - x = torch.sum(x, dim=dim) - else: - x = torch.mean(x, dim=dim) - return x - - -def named_reduce(x, reduction, dim): - """ - Reduces all tensors in nested dictionary or list or tuple at a dimension - using a named reduction function. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - reduction (str): one of ["sum", "max", "mean", "flatten"] - dim (int): dimension to be reduced (or begin axis for flatten) - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return map_tensor(x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d)) - - -def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices): - """ - This function indexes out a target dimension of a tensor in a structured way, - by allowing a different value to be selected for each member of a flat index - tensor (@indices) corresponding to a source dimension. This can be interpreted - as moving along the source dimension, using the corresponding index value - in @indices to select values for all other dimensions outside of the - source and target dimensions. A common use case is to gather values - in target dimension 1 for each batch member (target dimension 0). - - Args: - x (torch.Tensor): tensor to gather values for - target_dim (int): dimension to gather values along - source_dim (int): dimension to hold constant and use for gathering values - from the other dimensions - indices (torch.Tensor): flat index tensor with same shape as tensor @x along - @source_dim - - Returns: - y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out - """ - assert len(indices.shape) == 1 - assert x.shape[source_dim] == indices.shape[0] - - # unsqueeze in all dimensions except the source dimension - new_shape = [1] * x.ndimension() - new_shape[source_dim] = -1 - indices = indices.reshape(*new_shape) - - # repeat in all dimensions - but preserve shape of source dimension, - # and make sure target_dimension has singleton dimension - expand_shape = list(x.shape) - expand_shape[source_dim] = -1 - expand_shape[target_dim] = 1 - indices = indices.expand(*expand_shape) - - out = x.gather(dim=target_dim, index=indices) - return out.squeeze(target_dim) - - -def gather_along_dim_with_dim(x, target_dim, source_dim, indices): - """ - Apply @gather_along_dim_with_dim_single to all tensors in a nested - dictionary or list or tuple. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - target_dim (int): dimension to gather values along - source_dim (int): dimension to hold constant and use for gathering values - from the other dimensions - indices (torch.Tensor): flat index tensor with same shape as tensor @x along - @source_dim - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return map_tensor( - x, lambda y, t=target_dim, s=source_dim, i=indices: gather_along_dim_with_dim_single(y, t, s, i) - ) - - -def gather_sequence_single(seq, indices): - """ - Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in - the batch given an index for each sequence. - - Args: - seq (torch.Tensor): tensor with leading dimensions [B, T, ...] - indices (torch.Tensor): tensor indices of shape [B] - - Return: - y (torch.Tensor): indexed tensor of shape [B, ....] - """ - return gather_along_dim_with_dim_single(seq, target_dim=1, source_dim=0, indices=indices) - - -def gather_sequence(seq, indices): - """ - Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch - for tensors with leading dimensions [B, T, ...]. - - Args: - seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors - of leading dimensions [B, T, ...] - indices (torch.Tensor): tensor indices of shape [B] - - Returns: - y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...] - """ - return gather_along_dim_with_dim(seq, target_dim=1, source_dim=0, indices=indices) - - -def pad_sequence_single(seq, padding, batched=False, pad_same=True, pad_values=None): - """ - Pad input tensor or array @seq in the time dimension (dimension 1). - - Args: - seq (np.ndarray or torch.Tensor): sequence to be padded - padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1 - batched (bool): if sequence has the batch dimension - pad_same (bool): if pad by duplicating - pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same - - Returns: - padded sequence (np.ndarray or torch.Tensor) - """ - assert isinstance(seq, (np.ndarray, torch.Tensor)) - assert pad_same or pad_values is not None - if pad_values is not None: - assert isinstance(pad_values, float) - repeat_func = np.repeat if isinstance(seq, np.ndarray) else torch.repeat_interleave - concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat - ones_like_func = np.ones_like if isinstance(seq, np.ndarray) else torch.ones_like - seq_dim = 1 if batched else 0 - - begin_pad = [] - end_pad = [] - - if padding[0] > 0: - pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values - begin_pad.append(repeat_func(pad, padding[0], seq_dim)) - if padding[1] > 0: - pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values - end_pad.append(repeat_func(pad, padding[1], seq_dim)) - - return concat_func(begin_pad + [seq] + end_pad, seq_dim) - - -def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None): - """ - Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1). - - Args: - seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors - of leading dimensions [B, T, ...] - padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1 - batched (bool): if sequence has the batch dimension - pad_same (bool): if pad by duplicating - pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same - - Returns: - padded sequence (dict or list or tuple) - """ - return recursive_dict_list_tuple_apply( - seq, - { - torch.Tensor: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single( - x, p, b, ps, pv - ), - np.ndarray: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single( - x, p, b, ps, pv - ), - type(None): lambda x: x, - }, - ) - - -def assert_size_at_dim_single(x, size, dim, msg): - """ - Ensure that array or tensor @x has size @size in dim @dim. - - Args: - x (np.ndarray or torch.Tensor): input array or tensor - size (int): size that tensors should have at @dim - dim (int): dimension to check - msg (str): text to display if assertion fails - """ - assert x.shape[dim] == size, msg - - -def assert_size_at_dim(x, size, dim, msg): - """ - Ensure that arrays and tensors in nested dictionary or list or tuple have - size @size in dim @dim. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - size (int): size that tensors should have at @dim - dim (int): dimension to check - """ - map_tensor(x, lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m)) - - -def get_shape(x): - """ - Get all shapes of arrays and tensors in nested dictionary or list or tuple. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple that contains each array or - tensor's shape - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.shape, - np.ndarray: lambda x: x.shape, - type(None): lambda x: x, - }, - ) - - -def list_of_flat_dict_to_dict_of_list(list_of_dict): - """ - Helper function to go from a list of flat dictionaries to a dictionary of lists. - By "flat" we mean that none of the values are dictionaries, but are numpy arrays, - floats, etc. - - Args: - list_of_dict (list): list of flat dictionaries - - Returns: - dict_of_list (dict): dictionary of lists - """ - assert isinstance(list_of_dict, list) - dic = collections.OrderedDict() - for i in range(len(list_of_dict)): - for k in list_of_dict[i]: - if k not in dic: - dic[k] = [] - dic[k].append(list_of_dict[i][k]) - return dic - - -def flatten_nested_dict_list(d, parent_key="", sep="_", item_key=""): - """ - Flatten a nested dict or list to a list. - - For example, given a dict - { - a: 1 - b: { - c: 2 - } - c: 3 - } - - the function would return [(a, 1), (b_c, 2), (c, 3)] - - Args: - d (dict, list): a nested dict or list to be flattened - parent_key (str): recursion helper - sep (str): separator for nesting keys - item_key (str): recursion helper - Returns: - list: a list of (key, value) tuples - """ - items = [] - if isinstance(d, (tuple, list)): - new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key - for i, v in enumerate(d): - items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i))) - return items - elif isinstance(d, dict): - new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key - for k, v in d.items(): - assert isinstance(k, str) - items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=k)) - return items - else: - new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key - return [(new_key, d)] - - -def time_distributed(inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs): - """ - Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the - batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...]. - Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping - outputs to [B, T, ...]. - - Args: - inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors - of leading dimensions [B, T, ...] - op: a layer op that accepts inputs - activation: activation to apply at the output - inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op - inputs_as_args (bool) whether to feed input as a args list to the op - kwargs (dict): other kwargs to supply to the op - - Returns: - outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T]. - """ - batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2] - inputs = join_dimensions(inputs, 0, 1) - if inputs_as_kwargs: - outputs = op(**inputs, **kwargs) - elif inputs_as_args: - outputs = op(*inputs, **kwargs) - else: - outputs = op(inputs, **kwargs) - - if activation is not None: - outputs = map_tensor(outputs, activation) - outputs = reshape_dimensions(outputs, begin_axis=0, end_axis=0, target_dims=(batch_size, seq_len)) - return outputs diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py index 9785358b..fca89d46 100644 --- a/lerobot/common/policies/diffusion/policy.py +++ b/lerobot/common/policies/diffusion/policy.py @@ -5,11 +5,10 @@ from collections import deque import hydra import torch -from torch import nn +from diffusers.optimization import get_scheduler +from torch import Tensor, nn -from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy -from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler -from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder +from lerobot.common.policies.diffusion.model.diffusion_unet_image_policy import DiffusionUnetImagePolicy from lerobot.common.policies.utils import populate_queues from lerobot.common.utils import get_safe_torch_device @@ -22,8 +21,6 @@ class DiffusionPolicy(nn.Module): cfg, cfg_device, cfg_noise_scheduler, - cfg_rgb_model, - cfg_obs_encoder, cfg_optimizer, cfg_ema, shape_meta: dict, @@ -31,53 +28,43 @@ class DiffusionPolicy(nn.Module): n_action_steps, n_obs_steps, num_inference_steps=None, - obs_as_global_cond=True, diffusion_step_embed_dim=256, down_dims=(256, 512, 1024), kernel_size=5, n_groups=8, - cond_predict_scale=True, - # parameters passed to step - **kwargs, + film_scale_modulation=True, + **_, ): super().__init__() self.cfg = cfg self.n_obs_steps = n_obs_steps self.n_action_steps = n_action_steps + # queues are populated during rollout of the policy, they contain the n latest observations and actions self._queues = None + # TODO(now): In-house this. noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) - rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape) - if cfg_obs_encoder.crop_shape is not None: - rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape - rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model) - obs_encoder = MultiImageObsEncoder( - rgb_model=rgb_model, - **cfg_obs_encoder, - ) self.diffusion = DiffusionUnetImagePolicy( + cfg, shape_meta=shape_meta, noise_scheduler=noise_scheduler, - obs_encoder=obs_encoder, horizon=horizon, n_action_steps=n_action_steps, n_obs_steps=n_obs_steps, num_inference_steps=num_inference_steps, - obs_as_global_cond=obs_as_global_cond, diffusion_step_embed_dim=diffusion_step_embed_dim, down_dims=down_dims, kernel_size=kernel_size, n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - # parameters passed to step - **kwargs, + film_scale_modulation=film_scale_modulation, ) self.device = get_safe_torch_device(cfg_device) self.diffusion.to(self.device) + # TODO(alexander-soare): This should probably be managed outside of the policy class. self.ema_diffusion = None self.ema = None if self.cfg.use_ema: @@ -116,42 +103,45 @@ class DiffusionPolicy(nn.Module): "action": deque(maxlen=self.n_action_steps), } - @torch.no_grad() - def select_action(self, batch, step): + def forward(self, batch: dict[str, Tensor], **_) -> Tensor: + """A forward pass through the DNN part of this policy with optional loss computation.""" + return self.select_action(batch) + + @torch.no_grad + def select_action(self, batch, **_): """ Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights. + # TODO(now): Handle a batch """ - # TODO(rcadene): remove unused step - del step assert "observation.image" in batch assert "observation.state" in batch - assert len(batch) == 2 + assert len(batch) == 2 # TODO(now): Does this not have a batch dim? self._queues = populate_queues(self._queues, batch) if len(self._queues["action"]) == 0: # stack n latest observations from the queue batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} - - obs_dict = { - "image": batch["observation.image"], - "agent_pos": batch["observation.state"], - } - if self.training: - out = self.diffusion.predict_action(obs_dict) - else: - out = self.ema_diffusion.predict_action(obs_dict) - self._queues["action"].extend(out["action"].transpose(0, 1)) + actions = self._generate_actions(batch) + self._queues["action"].extend(actions.transpose(0, 1)) action = self._queues["action"].popleft() return action - def forward(self, batch, step): + def _generate_actions(self, batch): + if not self.training and self.ema_diffusion is not None: + return self.ema_diffusion.predict_action(batch) + else: + return self.diffusion.predict_action(batch) + + def update(self, batch, **_): + """Run the model in train mode, compute the loss, and do an optimization step.""" start_time = time.time() self.diffusion.train() - loss = self.diffusion.compute_loss(batch) + loss = self.compute_loss(batch) + loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_( @@ -174,13 +164,11 @@ class DiffusionPolicy(nn.Module): "update_s": time.time() - start_time, } - # TODO(rcadene): remove hardcoding - # in diffusion_policy, len(dataloader) is 168 for a batch_size of 64 - if step % 168 == 0: - self.global_step += 1 - return info + def compute_loss(self, batch: dict[str, Tensor]) -> Tensor: + return self.diffusion.compute_loss(batch) + def save(self, fp): torch.save(self.state_dict(), fp) diff --git a/lerobot/common/policies/diffusion/pytorch_utils.py b/lerobot/common/policies/diffusion/pytorch_utils.py deleted file mode 100644 index ed5dc23a..00000000 --- a/lerobot/common/policies/diffusion/pytorch_utils.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Callable, Dict - -import torch -import torch.nn as nn -import torchvision - - -def get_resnet(name, weights=None, **kwargs): - """ - name: resnet18, resnet34, resnet50 - weights: "IMAGENET1K_V1", "r3m" - """ - # load r3m weights - if (weights == "r3m") or (weights == "R3M"): - return get_r3m(name=name, **kwargs) - - func = getattr(torchvision.models, name) - resnet = func(weights=weights, **kwargs) - resnet.fc = torch.nn.Identity() - return resnet - - -def get_r3m(name, **kwargs): - """ - name: resnet18, resnet34, resnet50 - """ - import r3m - - r3m.device = "cpu" - model = r3m.load_r3m(name) - r3m_model = model.module - resnet_model = r3m_model.convnet - resnet_model = resnet_model.to("cpu") - return resnet_model - - -def dict_apply( - x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor] -) -> Dict[str, torch.Tensor]: - result = {} - for key, value in x.items(): - if isinstance(value, dict): - result[key] = dict_apply(value, func) - else: - result[key] = func(value) - return result - - -def replace_submodules( - root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module] -) -> nn.Module: - """ - predicate: Return true if the module is to be replaced. - func: Return new module to use. - """ - if predicate(root_module): - return func(root_module) - - bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] - for *parent, k in bn_list: - parent_module = root_module - if len(parent) > 0: - parent_module = root_module.get_submodule(".".join(parent)) - if isinstance(parent_module, nn.Sequential): - src_module = parent_module[int(k)] - else: - src_module = getattr(parent_module, k) - tgt_module = func(src_module) - if isinstance(parent_module, nn.Sequential): - parent_module[int(k)] = tgt_module - else: - setattr(parent_module, k, tgt_module) - # verify that all BN are replaced - bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] - assert len(bn_list) == 0 - return root_module diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index a1cbea9a..325b5608 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -12,8 +12,6 @@ def make_policy(cfg): cfg=cfg.policy, cfg_device=cfg.device, cfg_noise_scheduler=cfg.noise_scheduler, - cfg_rgb_model=cfg.rgb_model, - cfg_obs_encoder=cfg.obs_encoder, cfg_optimizer=cfg.optimizer, cfg_ema=cfg.ema, # n_obs_steps=cfg.n_obs_steps, diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py index b0503fe4..9d4b42f0 100644 --- a/lerobot/common/policies/utils.py +++ b/lerobot/common/policies/utils.py @@ -1,3 +1,7 @@ +import torch +from torch import nn + + def populate_queues(queues, batch): for key in batch: if len(queues[key]) != queues[key].maxlen: @@ -8,3 +12,21 @@ def populate_queues(queues, batch): # add latest observation to the queue queues[key].append(batch[key]) return queues + + +def get_device_from_parameters(module: nn.Module) -> torch.device: + """Get a module's device by checking one of its parameters. + + Note: assumes that all parameters have the same device + TODO(now): Add test. + """ + return next(iter(module.parameters())).device + + +def get_dtype_from_parameters(module: nn.Module) -> torch.dtype: + """Get a module's parameter dtype by checking one of its parameters. + + Note: assumes that all parameters have the same dtype. + TODO(now): Add test. + """ + return next(iter(module.parameters())).dtype diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py index e3e22832..28270bac 100644 --- a/lerobot/common/utils.py +++ b/lerobot/common/utils.py @@ -11,6 +11,7 @@ from omegaconf import DictConfig def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device: + """Given a string, return a torch.device with checks on whether the device is available.""" match cfg_device: case "cuda": assert torch.cuda.is_available() diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 811ee824..005b0517 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -19,7 +19,6 @@ n_action_steps: 8 dataset_obs_steps: ${n_obs_steps} past_action_visible: False keypoint_visible_rate: 1.0 -obs_as_global_cond: True eval_episodes: 50 eval_freq: 5000 @@ -40,13 +39,12 @@ policy: n_obs_steps: ${n_obs_steps} n_action_steps: ${n_action_steps} num_inference_steps: 100 - obs_as_global_cond: ${obs_as_global_cond} # crop_shape: null diffusion_step_embed_dim: 128 down_dims: [512, 1024, 2048] kernel_size: 5 n_groups: 8 - cond_predict_scale: True + film_scale_modulation: True pretrained_model_path: @@ -68,6 +66,16 @@ policy: observation.state: [-0.1, 0] action: [-0.1, 0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0, 1.1, 1.2, 1.3, 1.4] + rgb_encoder: + backbone_name: resnet18 + pretrained_backbone: false + use_group_norm: True + num_keypoints: 32 + relu: true + norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs) + crop_shape: [84, 84] + random_crop: True + noise_scheduler: _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler num_train_timesteps: 100 @@ -78,16 +86,6 @@ noise_scheduler: clip_sample: True # required when predict_epsilon=False prediction_type: epsilon # or sample -obs_encoder: - shape_meta: ${shape_meta} - # resize_shape: null - crop_shape: [84, 84] - # constant center crop - random_crop: True - use_group_norm: True - share_rgb_model: False - norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs) - rgb_model: pretrained: false num_keypoints: 32 diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 535ac935..06459a85 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -121,7 +121,7 @@ def eval_policy( # get the next action for the environment with torch.inference_mode(): - action = policy.select_action(observation, step) + action = policy.select_action(observation, step=step) # apply inverse transform to unnormalize the action action = postprocess_action(action, transform) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index caaf5182..dd3da978 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -213,7 +213,7 @@ def train(cfg: dict, out_dir=None, job_name=None): for key in batch: batch[key] = batch[key].to(cfg.device, non_blocking=True) - train_info = policy(batch, step) + train_info = policy.update(batch, step=step) # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? if step % cfg.log_freq == 0: