diff --git a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
new file mode 100644
index 00000000..3c12d53a
--- /dev/null
+++ b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
@@ -0,0 +1,246 @@
+from typing import Dict
+
+import torch
+import torch.nn.functional as F  # noqa: N812
+from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+from einops import reduce
+
+from diffusion_policy.common.pytorch_util import dict_apply
+from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D
+from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator
+from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
+from diffusion_policy.policy.base_image_policy import BaseImagePolicy
+
+
+class DiffusionUnetImagePolicy(BaseImagePolicy):
+    def __init__(
+        self,
+        shape_meta: dict,
+        noise_scheduler: DDPMScheduler,
+        obs_encoder: MultiImageObsEncoder,
+        horizon,
+        n_action_steps,
+        n_obs_steps,
+        num_inference_steps=None,
+        obs_as_global_cond=True,
+        diffusion_step_embed_dim=256,
+        down_dims=(256, 512, 1024),
+        kernel_size=5,
+        n_groups=8,
+        cond_predict_scale=True,
+        # parameters passed to step
+        **kwargs,
+    ):
+        super().__init__()
+
+        # parse shapes
+        action_shape = shape_meta["action"]["shape"]
+        assert len(action_shape) == 1
+        action_dim = action_shape[0]
+        # get feature dim
+        obs_feature_dim = obs_encoder.output_shape()[0]
+
+        # create diffusion model
+        input_dim = action_dim + obs_feature_dim
+        global_cond_dim = None
+        if obs_as_global_cond:
+            input_dim = action_dim
+            global_cond_dim = obs_feature_dim * n_obs_steps
+
+        model = ConditionalUnet1D(
+            input_dim=input_dim,
+            local_cond_dim=None,
+            global_cond_dim=global_cond_dim,
+            diffusion_step_embed_dim=diffusion_step_embed_dim,
+            down_dims=down_dims,
+            kernel_size=kernel_size,
+            n_groups=n_groups,
+            cond_predict_scale=cond_predict_scale,
+        )
+
+        self.obs_encoder = obs_encoder
+        self.model = model
+        self.noise_scheduler = noise_scheduler
+        self.mask_generator = LowdimMaskGenerator(
+            action_dim=action_dim,
+            obs_dim=0 if obs_as_global_cond else obs_feature_dim,
+            max_n_obs_steps=n_obs_steps,
+            fix_obs_steps=True,
+            action_visible=False,
+        )
+        self.horizon = horizon
+        self.obs_feature_dim = obs_feature_dim
+        self.action_dim = action_dim
+        self.n_action_steps = n_action_steps
+        self.n_obs_steps = n_obs_steps
+        self.obs_as_global_cond = obs_as_global_cond
+        self.kwargs = kwargs
+
+        if num_inference_steps is None:
+            num_inference_steps = noise_scheduler.config.num_train_timesteps
+        self.num_inference_steps = num_inference_steps
+
+    # ========= inference ============
+    def conditional_sample(
+        self,
+        condition_data,
+        condition_mask,
+        local_cond=None,
+        global_cond=None,
+        generator=None,
+        # keyword arguments to scheduler.step
+        **kwargs,
+    ):
+        model = self.model
+        scheduler = self.noise_scheduler
+
+        trajectory = torch.randn(
+            size=condition_data.shape,
+            dtype=condition_data.dtype,
+            device=condition_data.device,
+            generator=generator,
+        )
+
+        # set step values
+        scheduler.set_timesteps(self.num_inference_steps)
+
+        for t in scheduler.timesteps:
+            # 1. apply conditioning
+            trajectory[condition_mask] = condition_data[condition_mask]
+
+            # 2. predict model output
+            model_output = model(trajectory, t, local_cond=local_cond, global_cond=global_cond)
+
+            # 3. compute previous image: x_t -> x_t-1
+            trajectory = scheduler.step(
+                model_output,
+                t,
+                trajectory,
+                generator=generator,
+                # **kwargs  # TODO(rcadene): in diffusion_policy, expected to be {}
+            ).prev_sample
+
+        # finally make sure conditioning is enforced
+        trajectory[condition_mask] = condition_data[condition_mask]
+
+        return trajectory
+
+    def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """
+        obs_dict: must include "obs" key
+        result: must include "action" key
+        """
+        assert "past_action" not in obs_dict  # not implemented yet
+        nobs = obs_dict
+        value = next(iter(nobs.values()))
+        bsize, n_obs_steps = value.shape[:2]
+        horizon = self.horizon
+        action_dim = self.action_dim
+        obs_dim = self.obs_feature_dim
+        assert self.n_obs_steps == n_obs_steps
+
+        # build input
+        device = self.device
+        dtype = self.dtype
+
+        # handle different ways of passing observation
+        local_cond = None
+        global_cond = None
+        if self.obs_as_global_cond:
+            # condition through global feature
+            this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:]))
+            nobs_features = self.obs_encoder(this_nobs)
+            # reshape back to B, Do
+            global_cond = nobs_features.reshape(bsize, -1)
+            # empty data for action
+            cond_data = torch.zeros(size=(bsize, horizon, action_dim), device=device, dtype=dtype)
+            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
+        else:
+            # condition through impainting
+            this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:]))
+            nobs_features = self.obs_encoder(this_nobs)
+            # reshape back to B, T, Do
+            nobs_features = nobs_features.reshape(bsize, n_obs_steps, -1)
+            cond_data = torch.zeros(size=(bsize, horizon, action_dim + obs_dim), device=device, dtype=dtype)
+            cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
+            cond_data[:, :n_obs_steps, action_dim:] = nobs_features
+            cond_mask[:, :n_obs_steps, action_dim:] = True
+
+        # run sampling
+        nsample = self.conditional_sample(
+            cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond, **self.kwargs
+        )
+
+        action_pred = nsample[..., :action_dim]
+
+        # get action
+        start = n_obs_steps - 1
+        end = start + self.n_action_steps
+        action = action_pred[:, start:end]
+
+        result = {"action": action, "action_pred": action_pred}
+        return result
+
+    def compute_loss(self, batch):
+        assert "valid_mask" not in batch
+        nobs = batch["obs"]
+        nactions = batch["action"]
+        batch_size = nactions.shape[0]
+        horizon = nactions.shape[1]
+
+        # handle different ways of passing observation
+        local_cond = None
+        global_cond = None
+        trajectory = nactions
+        cond_data = trajectory
+        if self.obs_as_global_cond:
+            # reshape B, T, ... to B*T
+            this_nobs = dict_apply(nobs, lambda x: x[:, : self.n_obs_steps, ...].reshape(-1, *x.shape[2:]))
+            nobs_features = self.obs_encoder(this_nobs)
+            # reshape back to B, Do
+            global_cond = nobs_features.reshape(batch_size, -1)
+        else:
+            # reshape B, T, ... to B*T
+            this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:]))
+            nobs_features = self.obs_encoder(this_nobs)
+            # reshape back to B, T, Do
+            nobs_features = nobs_features.reshape(batch_size, horizon, -1)
+            cond_data = torch.cat([nactions, nobs_features], dim=-1)
+            trajectory = cond_data.detach()
+
+        # generate impainting mask
+        condition_mask = self.mask_generator(trajectory.shape)
+
+        # Sample noise that we'll add to the images
+        noise = torch.randn(trajectory.shape, device=trajectory.device)
+        bsz = trajectory.shape[0]
+        # Sample a random timestep for each image
+        timesteps = torch.randint(
+            0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=trajectory.device
+        ).long()
+        # Add noise to the clean images according to the noise magnitude at each timestep
+        # (this is the forward diffusion process)
+        noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps)
+
+        # compute loss mask
+        loss_mask = ~condition_mask
+
+        # apply conditioning
+        noisy_trajectory[condition_mask] = cond_data[condition_mask]
+
+        # Predict the noise residual
+        pred = self.model(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond)
+
+        pred_type = self.noise_scheduler.config.prediction_type
+        if pred_type == "epsilon":
+            target = noise
+        elif pred_type == "sample":
+            target = trajectory
+        else:
+            raise ValueError(f"Unsupported prediction type {pred_type}")
+
+        loss = F.mse_loss(pred, target, reduction="none")
+        loss = loss * loss_mask.type(loss.dtype)
+        loss = reduce(loss, "b ... -> b (...)", "mean")
+        loss = loss.mean()
+        return loss
diff --git a/lerobot/common/policies/diffusion/multi_image_obs_encoder.py b/lerobot/common/policies/diffusion/multi_image_obs_encoder.py
new file mode 100644
index 00000000..e52f147f
--- /dev/null
+++ b/lerobot/common/policies/diffusion/multi_image_obs_encoder.py
@@ -0,0 +1,189 @@
+import copy
+from typing import Dict, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torchvision
+
+from diffusion_policy.common.pytorch_util import replace_submodules
+from diffusion_policy.model.common.module_attr_mixin import ModuleAttrMixin
+from diffusion_policy.model.vision.crop_randomizer import CropRandomizer
+
+
+class MultiImageObsEncoder(ModuleAttrMixin):
+    def __init__(
+        self,
+        shape_meta: dict,
+        rgb_model: Union[nn.Module, Dict[str, nn.Module]],
+        resize_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
+        crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
+        random_crop: bool = True,
+        # replace BatchNorm with GroupNorm
+        use_group_norm: bool = False,
+        # use single rgb model for all rgb inputs
+        share_rgb_model: bool = False,
+        # renormalize rgb input with imagenet normalization
+        # assuming input in [0,1]
+        imagenet_norm: bool = False,
+    ):
+        """
+        Assumes rgb input: B,C,H,W
+        Assumes low_dim input: B,D
+        """
+        super().__init__()
+
+        rgb_keys = []
+        low_dim_keys = []
+        key_model_map = nn.ModuleDict()
+        key_transform_map = nn.ModuleDict()
+        key_shape_map = {}
+
+        # handle sharing vision backbone
+        if share_rgb_model:
+            assert isinstance(rgb_model, nn.Module)
+            key_model_map["rgb"] = rgb_model
+
+        obs_shape_meta = shape_meta["obs"]
+        for key, attr in obs_shape_meta.items():
+            shape = tuple(attr["shape"])
+            type = attr.get("type", "low_dim")
+            key_shape_map[key] = shape
+            if type == "rgb":
+                rgb_keys.append(key)
+                # configure model for this key
+                this_model = None
+                if not share_rgb_model:
+                    if isinstance(rgb_model, dict):
+                        # have provided model for each key
+                        this_model = rgb_model[key]
+                    else:
+                        assert isinstance(rgb_model, nn.Module)
+                        # have a copy of the rgb model
+                        this_model = copy.deepcopy(rgb_model)
+
+                if this_model is not None:
+                    if use_group_norm:
+                        this_model = replace_submodules(
+                            root_module=this_model,
+                            predicate=lambda x: isinstance(x, nn.BatchNorm2d),
+                            func=lambda x: nn.GroupNorm(
+                                num_groups=x.num_features // 16, num_channels=x.num_features
+                            ),
+                        )
+                    key_model_map[key] = this_model
+
+                # configure resize
+                input_shape = shape
+                this_resizer = nn.Identity()
+                if resize_shape is not None:
+                    if isinstance(resize_shape, dict):
+                        h, w = resize_shape[key]
+                    else:
+                        h, w = resize_shape
+                    this_resizer = torchvision.transforms.Resize(size=(h, w))
+                    input_shape = (shape[0], h, w)
+
+                # configure randomizer
+                this_randomizer = nn.Identity()
+                if crop_shape is not None:
+                    if isinstance(crop_shape, dict):
+                        h, w = crop_shape[key]
+                    else:
+                        h, w = crop_shape
+                    if random_crop:
+                        this_randomizer = CropRandomizer(
+                            input_shape=input_shape, crop_height=h, crop_width=w, num_crops=1, pos_enc=False
+                        )
+                    else:
+                        this_normalizer = torchvision.transforms.CenterCrop(size=(h, w))
+                # configure normalizer
+                this_normalizer = nn.Identity()
+                if imagenet_norm:
+                    # TODO(rcadene): move normalizer to dataset and env
+                    this_normalizer = torchvision.transforms.Normalize(
+                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                    )
+
+                this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer)
+                key_transform_map[key] = this_transform
+            elif type == "low_dim":
+                low_dim_keys.append(key)
+            else:
+                raise RuntimeError(f"Unsupported obs type: {type}")
+        rgb_keys = sorted(rgb_keys)
+        low_dim_keys = sorted(low_dim_keys)
+
+        self.shape_meta = shape_meta
+        self.key_model_map = key_model_map
+        self.key_transform_map = key_transform_map
+        self.share_rgb_model = share_rgb_model
+        self.rgb_keys = rgb_keys
+        self.low_dim_keys = low_dim_keys
+        self.key_shape_map = key_shape_map
+
+    def forward(self, obs_dict):
+        batch_size = None
+        features = []
+        # process rgb input
+        if self.share_rgb_model:
+            # pass all rgb obs to rgb model
+            imgs = []
+            for key in self.rgb_keys:
+                img = obs_dict[key]
+                if batch_size is None:
+                    batch_size = img.shape[0]
+                else:
+                    assert batch_size == img.shape[0]
+                assert img.shape[1:] == self.key_shape_map[key]
+                img = self.key_transform_map[key](img)
+                imgs.append(img)
+            # (N*B,C,H,W)
+            imgs = torch.cat(imgs, dim=0)
+            # (N*B,D)
+            feature = self.key_model_map["rgb"](imgs)
+            # (N,B,D)
+            feature = feature.reshape(-1, batch_size, *feature.shape[1:])
+            # (B,N,D)
+            feature = torch.moveaxis(feature, 0, 1)
+            # (B,N*D)
+            feature = feature.reshape(batch_size, -1)
+            features.append(feature)
+        else:
+            # run each rgb obs to independent models
+            for key in self.rgb_keys:
+                img = obs_dict[key]
+                if batch_size is None:
+                    batch_size = img.shape[0]
+                else:
+                    assert batch_size == img.shape[0]
+                assert img.shape[1:] == self.key_shape_map[key]
+                img = self.key_transform_map[key](img)
+                feature = self.key_model_map[key](img)
+                features.append(feature)
+
+        # process lowdim input
+        for key in self.low_dim_keys:
+            data = obs_dict[key]
+            if batch_size is None:
+                batch_size = data.shape[0]
+            else:
+                assert batch_size == data.shape[0]
+            assert data.shape[1:] == self.key_shape_map[key]
+            features.append(data)
+
+        # concatenate all features
+        result = torch.cat(features, dim=-1)
+        return result
+
+    @torch.no_grad()
+    def output_shape(self):
+        example_obs_dict = {}
+        obs_shape_meta = self.shape_meta["obs"]
+        batch_size = 1
+        for key, attr in obs_shape_meta.items():
+            shape = tuple(attr["shape"])
+            this_obs = torch.zeros((batch_size,) + shape, dtype=self.dtype, device=self.device)
+            example_obs_dict[key] = this_obs
+        example_output = self.forward(example_obs_dict)
+        output_shape = example_output.shape[1:]
+        return output_shape
diff --git a/lerobot/common/policies/diffusion.py b/lerobot/common/policies/diffusion/policy.py
similarity index 89%
rename from lerobot/common/policies/diffusion.py
rename to lerobot/common/policies/diffusion/policy.py
index 50e72b23..a484c65a 100644
--- a/lerobot/common/policies/diffusion.py
+++ b/lerobot/common/policies/diffusion/policy.py
@@ -1,6 +1,7 @@
 import copy
 import time
 
+import einops
 import hydra
 import torch
 import torch.nn as nn
@@ -8,8 +9,9 @@ from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 
 from diffusion_policy.model.common.lr_scheduler import get_scheduler
 from diffusion_policy.model.vision.model_getter import get_resnet
-from diffusion_policy.model.vision.multi_image_obs_encoder import MultiImageObsEncoder
-from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy
+
+from .diffusion_unet_image_policy import DiffusionUnetImagePolicy
+from .multi_image_obs_encoder import MultiImageObsEncoder
 
 FIRST_ACTION = 0
 
@@ -99,10 +101,15 @@ class DiffusionPolicy(nn.Module):
         # TODO(rcadene): remove unused step_count
         del step_count
 
+        # TODO(rcadene): remove unsqueeze hack...
+        if observation["image"].ndim == 3:
+            observation["image"] = observation["image"].unsqueeze(0)
+            observation["state"] = observation["state"].unsqueeze(0)
+
         obs_dict = {
-            # c h w -> b t c h w (b=1, t=1)
-            "image": observation["image"][None, None, ...],
-            "agent_pos": observation["state"][None, None, ...],
+            # TODO(rcadene): hack to add temporal dim
+            "image": einops.rearrange(observation["image"], "b c h w -> b 1 c h w"),
+            "agent_pos": einops.rearrange(observation["state"], "b c -> b 1 c"),
         }
 
         out = self.diffusion.predict_action(obs_dict)
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 9d5afe35..3ce207f0 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -4,7 +4,7 @@ def make_policy(cfg):
 
         policy = TDMPC(cfg.policy)
     elif cfg.policy.name == "diffusion":
-        from lerobot.common.policies.diffusion import DiffusionPolicy
+        from lerobot.common.policies.diffusion.policy import DiffusionPolicy
 
         policy = DiffusionPolicy(
             cfg=cfg.policy,
diff --git a/lerobot/common/policies/tdmpc.py b/lerobot/common/policies/tdmpc.py
index b56e45df..64908d62 100644
--- a/lerobot/common/policies/tdmpc.py
+++ b/lerobot/common/policies/tdmpc.py
@@ -138,9 +138,6 @@ class TDMPC(nn.Module):
             "state": observation["state"].contiguous(),
         }
         action = self.act(obs, t0=t0, step=self.step.item())
-
-        # TODO(rcadene): hack to postprocess action (e.g. unnormalize)
-        # action = action * self.action_std + self.action_mean
         return action
 
     @torch.no_grad()
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index cd1fe15e..a537835e 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -147,7 +147,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
     env = make_env(cfg, transform=offline_buffer._transform)
 
     logging.info("make_policy")
-    policy = make_policy(cfg, transform=offline_buffer._transform)
+    policy = make_policy(cfg)
 
     td_policy = TensorDictModule(
         policy,