diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py
index be6abd1b..b3d13f74 100644
--- a/examples/2_evaluate_pretrained_policy.py
+++ b/examples/2_evaluate_pretrained_policy.py
@@ -11,6 +11,7 @@ from lerobot.common.utils import init_hydra_config
 from lerobot.scripts.eval import eval
 
 # Get a pretrained policy from the hub.
+# TODO(alexander-soare): This no longer works until we upload a new model that uses the current configs.
 hub_id = "lerobot/diffusion_policy_pusht_image"
 folder = Path(snapshot_download(hub_id))
 # OR uncomment the following to evaluate a policy from the local outputs/train folder.
diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index 238f953d..0c8decc4 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -11,58 +11,58 @@ import torch
 from omegaconf import OmegaConf
 
 from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.policies.diffusion.policy import DiffusionPolicy
+from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
+from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 from lerobot.common.utils import init_hydra_config
 
 output_directory = Path("outputs/train/example_pusht_diffusion")
 os.makedirs(output_directory, exist_ok=True)
 
-overrides = [
-    "env=pusht",
-    "policy=diffusion",
-    # Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
-    "offline_steps=5000",
-    "log_freq=250",
-    "device=cuda",
-]
+# Number of offline training steps (we'll only do offline training for this example).
+# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
+training_steps = 5000
+device = torch.device("cuda")
+log_freq = 250
 
-cfg = init_hydra_config("lerobot/configs/default.yaml", overrides)
+# Set up the dataset.
+hydra_cfg = init_hydra_config("lerobot/configs/default.yaml", overrides=["env=pusht"])
+dataset = make_dataset(hydra_cfg)
 
-policy = DiffusionPolicy(
-    cfg=cfg.policy,
-    cfg_device=cfg.device,
-    cfg_noise_scheduler=cfg.noise_scheduler,
-    cfg_rgb_model=cfg.rgb_model,
-    cfg_obs_encoder=cfg.obs_encoder,
-    cfg_optimizer=cfg.optimizer,
-    cfg_ema=cfg.ema,
-    **cfg.policy,
-)
+# Set up the policy.
+# Policies are initialized with a configuration class, in this case `DiffusionConfig`.
+# For this example, no arguments need to be passed because the defaults are set up for PushT.
+# If you're doing something different, you will likely need to change at least some of the defaults.
+cfg = DiffusionConfig()
+# TODO(alexander-soare): Remove LR scheduler from the policy.
+policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps)
 policy.train()
+policy.to(device)
 
-dataset = make_dataset(cfg)
-
-# create dataloader for offline training
+# Create dataloader for offline training.
 dataloader = torch.utils.data.DataLoader(
     dataset,
     num_workers=4,
-    batch_size=cfg.policy.batch_size,
+    batch_size=cfg.batch_size,
     shuffle=True,
-    pin_memory=cfg.device != "cpu",
+    pin_memory=device != torch.device("cpu"),
     drop_last=True,
 )
 
-for step, batch in enumerate(dataloader):
-    info = policy(batch, step)
-
-    if step % cfg.log_freq == 0:
-        num_samples = (step + 1) * cfg.policy.batch_size
-        loss = info["loss"]
-        update_s = info["update_s"]
-        print(f"step:{step} samples:{num_samples} loss:{loss:.3f} update_time:{update_s:.3f}(seconds)")
-
+# Run training loop.
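The step-based loop that follows can also be written with the `cycle` helper whose docstring is updated further down in this diff. A minimal sketch of that alternative (it reuses `dataloader`, `policy`, `device`, `training_steps` and `log_freq` from the script above, and assumes `policy(batch)` returns the same info dict as in the example):

```python
from lerobot.common.datasets.utils import cycle

dl_iter = cycle(dataloader)  # dataloader-safe equivalent of itertools.cycle
for step in range(training_steps):
    batch = next(dl_iter)
    batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
    info = policy(batch)
    if step % log_freq == 0:
        print(f"step: {step} loss: {info['loss']:.3f}")
```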
+step = 0 +done = False +while not done: + for batch in dataloader: + batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()} + info = policy(batch) + if step % log_freq == 0: + print(f"step: {step} loss: {info['loss']:.3f} update_time: {info['update_s']:.3f} (seconds)") + step += 1 + if step >= training_steps: + done = True + break # Save the policy, configuration, and normalization stats for later use. policy.save(output_directory / "model.pt") -OmegaConf.save(cfg, output_directory / "config.yaml") +OmegaConf.save(hydra_cfg, output_directory / "config.yaml") torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth") diff --git a/lerobot/common/policies/diffusion/replay_buffer.py b/lerobot/common/datasets/_diffusion_policy_replay_buffer.py similarity index 99% rename from lerobot/common/policies/diffusion/replay_buffer.py rename to lerobot/common/datasets/_diffusion_policy_replay_buffer.py index 7fccf74d..1697f9fc 100644 --- a/lerobot/common/policies/diffusion/replay_buffer.py +++ b/lerobot/common/datasets/_diffusion_policy_replay_buffer.py @@ -1,3 +1,8 @@ +"""Helper code for loading PushT dataset from Diffusion Policy (https://diffusion-policy.cs.columbia.edu/) + +Copied from the original Diffusion Policy repository. +""" + from __future__ import annotations import math diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 34d92daa..47253b15 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -5,8 +5,10 @@ import numpy as np import torch import tqdm +from lerobot.common.datasets._diffusion_policy_replay_buffer import ( + ReplayBuffer as DiffusionPolicyReplayBuffer, +) from lerobot.common.datasets.utils import download_and_extract_zip, load_data_with_delta_timestamps -from lerobot.common.policies.diffusion.replay_buffer import ReplayBuffer as DiffusionPolicyReplayBuffer # as define in env SUCCESS_THRESHOLD = 0.95 # 95% coverage, diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index e67d8a04..34430ff1 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -208,6 +208,10 @@ def compute_stats(dataset, batch_size=32, max_num_samples=None): def cycle(iterable): + """The equivalent of itertools.cycle, but safe for Pytorch dataloaders. + + See https://github.com/pytorch/pytorch/issues/23900 for information on why itertools.cycle is not safe. + """ iterator = iter(iterable) while True: try: diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index 74ed270e..1b438f2d 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -26,8 +26,8 @@ class ActionChunkingTransformerConfig: image_normalization_std: Value by which to divide the input image pixels (after the mean has been subtracted). vision_backbone: Name of the torchvision resnet backbone to use for encoding images. - use_pretrained_backbone: Whether the backbone should be initialized with ImageNet, pretrained weights - from torchvision. + use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from + torchvision. replace_final_stride_with_dilation: Whether to replace the ResNet's final 2x2 stride with a dilated convolution. pre_norm: Whether to use "pre-norm" in the transformer blocks. @@ -56,7 +56,7 @@ class ActionChunkingTransformerConfig: # Inputs / output structure. 
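The configuration classes validate their inputs at construction time, so a bad setting fails fast rather than deep inside the model. A small sketch of what that looks like for `ActionChunkingTransformerConfig`, based only on the `vision_backbone` check visible in this diff (other validations exist):

```python
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig

# Valid: override the (now immutable, tuple-typed) camera_names default.
cfg = ActionChunkingTransformerConfig(camera_names=("top",))

# Invalid: a non-ResNet backbone is rejected in __post_init__.
try:
    ActionChunkingTransformerConfig(vision_backbone="vit_b_16")
except ValueError as err:
    print(err)  # `vision_backbone` must be one of the ResNet variants.
```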
n_obs_steps: int = 1 - camera_names: list[str] = field(default_factory=lambda: ["top"]) + camera_names: tuple[str] = ("top",) chunk_size: int = 100 n_action_steps: int = 100 @@ -101,7 +101,7 @@ class ActionChunkingTransformerConfig: utd: int = 1 def __post_init__(self): - """Input validation.""" + """Input validation (not exhaustive).""" if not self.vision_backbone.startswith("resnet"): raise ValueError("`vision_backbone` must be one of the ResNet variants.") if self.use_temporal_aggregation: diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 1361e071..5f2429a6 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -65,12 +65,16 @@ class ActionChunkingTransformerPolicy(nn.Module): "ActionChunkingTransformerPolicy does not handle multiple observation steps." ) - def __init__(self, cfg: ActionChunkingTransformerConfig): + def __init__(self, cfg: ActionChunkingTransformerConfig | None = None): """ - TODO(alexander-soare): Add documentation for all parameters once we have model configs established. + Args: + cfg: Policy configuration class instance or None, in which case the default instantiation of the + configuration class is used. """ super().__init__() - if getattr(cfg, "n_obs_steps", 1) != 1: + if cfg is None: + cfg = ActionChunkingTransformerConfig() + if cfg.n_obs_steps != 1: raise ValueError(self._multiple_obs_steps_not_handled_msg) self.cfg = cfg @@ -163,7 +167,8 @@ class ActionChunkingTransformerPolicy(nn.Module): @torch.no_grad def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: - """ + """Select a single action given environment observations. + This method wraps `select_actions` in order to return one action at a time for execution in the environment. It works by managing the actions in a queue and only calling `select_actions` when the queue is empty. diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py new file mode 100644 index 00000000..d8820a0b --- /dev/null +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -0,0 +1,135 @@ +from dataclasses import dataclass + + +@dataclass +class DiffusionConfig: + """Configuration class for Diffusion Policy. + + Defaults are configured for training with PushT providing proprioceptive and single camera observations. + + The parameters you will most likely need to change are the ones which depend on the environment / sensors. + Those are: `state_dim`, `action_dim` and `image_size`. + + Args: + state_dim: Dimensionality of the observation state space (excluding images). + action_dim: Dimensionality of the action space. + image_size: (H, W) size of the input images. + n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the + current step and additional steps going back). + horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`. + n_action_steps: The number of action steps to run in the environment for one invocation of the policy. + See `DiffusionPolicy.select_action` for more details. + image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in + [0, 1]) for normalization. + image_normalization_std: Value by which to divide the input image pixels (after the mean has been + subtracted). + vision_backbone: Name of the torchvision resnet backbone to use for encoding images. 
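Per the note above, the environment-dependent fields are the ones most likely to need changing. A minimal sketch with hypothetical values (the field names come from the dataclass below; the numbers are invented for a 7-DoF arm observed by a single 240x320 camera):

```python
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig

cfg = DiffusionConfig(
    state_dim=7,
    action_dim=7,
    image_size=(240, 320),
    crop_shape=(216, 288),  # must fit within image_size (checked in __post_init__)
)
```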
+ crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit + within the image size. If None, no cropping is done. + crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval + mode). + use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from + torchvision. + use_group_norm: Whether to replace batch normalization with group normalization in the backbone. + The group sizes are set to be about 16 (to be precise, feature_dim // 16). + spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax. + down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet. + You may provide a variable number of dimensions, therefore also controlling the degree of + downsampling. + kernel_size: The convolutional kernel size of the diffusion modeling Unet. + n_groups: Number of groups used in the group norm of the Unet's convolutional blocks. + diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear + network. This is the output dimension of that network, i.e., the embedding dimension. + use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning. + Bias modulation is used be default, while this parameter indicates whether to also use scale + modulation. + num_train_timesteps: Number of diffusion steps for the forward diffusion schedule. + beta_schedule: Name of the diffusion beta schedule as per DDPMScheduler from Hugging Face diffusers. + beta_start: Beta value for the first forward-diffusion step. + beta_end: Beta value for the last forward-diffusion step. + prediction_type: The type of prediction that the diffusion modeling Unet makes. Choose from "epsilon" + or "sample". These have equivalent outcomes from a latent variable modeling perspective, but + "epsilon" has been shown to work better in many deep neural network settings. + clip_sample: Whether to clip the sample to [-`clip_sample_range`, +`clip_sample_range`] for each + denoising step at inference time. WARNING: you will need to make sure your action-space is + normalized to fit within this range. + clip_sample_range: The magnitude of the clipping range as described above. + num_inference_steps: Number of reverse diffusion steps to use at inference time (steps are evenly + spaced). If not provided, this defaults to be the same as `num_train_timesteps`. + """ + + # Environment. + # Inherit these from the environment config. + state_dim: int = 2 + action_dim: int = 2 + image_size: tuple[int, int] = (96, 96) + + # Inputs / output structure. + n_obs_steps: int = 2 + horizon: int = 16 + n_action_steps: int = 8 + + # Vision preprocessing. + image_normalization_mean: tuple[float, float, float] = (0.5, 0.5, 0.5) + image_normalization_std: tuple[float, float, float] = (0.5, 0.5, 0.5) + + # Architecture / modeling. + # Vision backbone. + vision_backbone: str = "resnet18" + crop_shape: tuple[int, int] | None = (84, 84) + crop_is_random: bool = True + use_pretrained_backbone: bool = False + use_group_norm: bool = True + spatial_softmax_num_keypoints: int = 32 + # Unet. + down_dims: tuple[int, ...] = (512, 1024, 2048) + kernel_size: int = 5 + n_groups: int = 8 + diffusion_step_embed_dim: int = 128 + use_film_scale_modulation: bool = True + # Noise scheduler. 
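The crop arguments amount to standard torchvision behaviour: random crops during training, a deterministic center crop at evaluation time. A rough equivalent is sketched below (assumed behaviour; the actual preprocessing code is not part of this diff):

```python
import torch
from torchvision.transforms import CenterCrop, RandomCrop

crop_shape = (84, 84)  # the default crop_shape below


def crop(images: torch.Tensor, training: bool) -> torch.Tensor:
    """images: (B, C, H, W) in [0, 1]; random crop in training, center crop in eval."""
    transform = RandomCrop(crop_shape) if training else CenterCrop(crop_shape)
    return transform(images)


cropped = crop(torch.rand(2, 3, 96, 96), training=True)  # -> (2, 3, 84, 84)
```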
+ num_train_timesteps: int = 100 + beta_schedule: str = "squaredcos_cap_v2" + beta_start: float = 0.0001 + beta_end: float = 0.02 + prediction_type: str = "epsilon" + clip_sample: bool = True + clip_sample_range: float = 1.0 + + # Inference + num_inference_steps: int | None = None + + # --- + # TODO(alexander-soare): Remove these from the policy config. + batch_size: int = 64 + grad_clip_norm: int = 10 + lr: float = 1.0e-4 + lr_scheduler: str = "cosine" + lr_warmup_steps: int = 500 + adam_betas: tuple[float, float] = (0.95, 0.999) + adam_eps: float = 1.0e-8 + adam_weight_decay: float = 1.0e-6 + utd: int = 1 + use_ema: bool = True + ema_update_after_step: int = 0 + ema_min_alpha: float = 0.0 + ema_max_alpha: float = 0.9999 + ema_inv_gamma: float = 1.0 + ema_power: float = 0.75 + + def __post_init__(self): + """Input validation (not exhaustive).""" + if not self.vision_backbone.startswith("resnet"): + raise ValueError( + f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." + ) + if self.crop_shape[0] > self.image_size[0] or self.crop_shape[1] > self.image_size[1]: + raise ValueError( + f"`crop_shape` should fit within `image_size`. Got {self.crop_shape} for `crop_shape` and " + f"{self.image_size} for `image_size`." + ) + supported_prediction_types = ["epsilon", "sample"] + if self.prediction_type not in supported_prediction_types: + raise ValueError( + f"`prediction_type` must be one of {supported_prediction_types}. Got {self.prediction_type}." + ) diff --git a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py deleted file mode 100644 index f7432db3..00000000 --- a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py +++ /dev/null @@ -1,315 +0,0 @@ -"""Code from the original diffusion policy project. - -Notes on how to load a checkpoint from the original repository: - -In the original repository, run the eval and use a breakpoint to extract the policy weights. - -``` -torch.save(policy.state_dict(), "weights.pt") -``` - -In this repository, add a breakpoint somewhere after creating an equivalent policy and load in the weights: - -``` -loaded = torch.load("weights.pt") -aligned = {} -their_prefix = "obs_encoder.obs_nets.image.backbone" -our_prefix = "obs_encoder.key_model_map.image.backbone" -aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) -their_prefix = "obs_encoder.obs_nets.image.pool" -our_prefix = "obs_encoder.key_model_map.image.pool" -aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) -their_prefix = "obs_encoder.obs_nets.image.nets.3" -our_prefix = "obs_encoder.key_model_map.image.out" -aligned.update({our_prefix + k.removeprefix(their_prefix): v for k, v in loaded.items() if k.startswith(their_prefix)}) -aligned.update({k: v for k, v in loaded.items() if k.startswith('model.')}) -# Note: here you are loading into the ema model. -missing_keys, unexpected_keys = policy.ema_diffusion.load_state_dict(aligned, strict=False) -assert all('_dummy_variable' in k for k in missing_keys) -assert len(unexpected_keys) == 0 -``` - -Then in that same runtime you can also save the weights with the new aligned state_dict: - -``` -policy.save("weights.pt") -``` - -Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint. 
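The noise-scheduler fields above map directly onto `DDPMScheduler` from Hugging Face diffusers, as the docstring notes. A sketch of how the modeling code would presumably construct it (the modeling file itself is not shown in this diff; parameter names assume a recent diffusers release):

```python
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler

from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig

cfg = DiffusionConfig()
noise_scheduler = DDPMScheduler(
    num_train_timesteps=cfg.num_train_timesteps,
    beta_start=cfg.beta_start,
    beta_end=cfg.beta_end,
    beta_schedule=cfg.beta_schedule,
    clip_sample=cfg.clip_sample,
    clip_sample_range=cfg.clip_sample_range,
    prediction_type=cfg.prediction_type,
)
# At inference time, num_inference_steps falls back to num_train_timesteps when unset.
noise_scheduler.set_timesteps(cfg.num_inference_steps or cfg.num_train_timesteps)
```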
- -""" - -from typing import Dict - -import torch -import torch.nn.functional as F # noqa: N812 -from diffusers.schedulers.scheduling_ddpm import DDPMScheduler -from einops import reduce - -from lerobot.common.policies.diffusion.model.conditional_unet1d import ConditionalUnet1D -from lerobot.common.policies.diffusion.model.mask_generator import LowdimMaskGenerator -from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin -from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder -from lerobot.common.policies.diffusion.model.normalizer import LinearNormalizer -from lerobot.common.policies.diffusion.pytorch_utils import dict_apply - - -class BaseImagePolicy(ModuleAttrMixin): - # init accepts keyword argument shape_meta, see config/task/*_image.yaml - - def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """ - obs_dict: - str: B,To,* - return: B,Ta,Da - """ - raise NotImplementedError() - - # reset state for stateful policies - def reset(self): - pass - - # ========== training =========== - # no standard training interface except setting normalizer - def set_normalizer(self, normalizer: LinearNormalizer): - raise NotImplementedError() - - -class DiffusionUnetImagePolicy(BaseImagePolicy): - def __init__( - self, - shape_meta: dict, - noise_scheduler: DDPMScheduler, - obs_encoder: MultiImageObsEncoder, - horizon, - n_action_steps, - n_obs_steps, - num_inference_steps=None, - obs_as_global_cond=True, - diffusion_step_embed_dim=256, - down_dims=(256, 512, 1024), - kernel_size=5, - n_groups=8, - cond_predict_scale=True, - # parameters passed to step - **kwargs, - ): - super().__init__() - - # parse shapes - action_shape = shape_meta["action"]["shape"] - assert len(action_shape) == 1 - action_dim = action_shape[0] - # get feature dim - obs_feature_dim = obs_encoder.output_shape()[0] - - # create diffusion model - input_dim = action_dim + obs_feature_dim - global_cond_dim = None - if obs_as_global_cond: - input_dim = action_dim - global_cond_dim = obs_feature_dim * n_obs_steps - - model = ConditionalUnet1D( - input_dim=input_dim, - local_cond_dim=None, - global_cond_dim=global_cond_dim, - diffusion_step_embed_dim=diffusion_step_embed_dim, - down_dims=down_dims, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ) - - self.obs_encoder = obs_encoder - self.model = model - self.noise_scheduler = noise_scheduler - self.mask_generator = LowdimMaskGenerator( - action_dim=action_dim, - obs_dim=0 if obs_as_global_cond else obs_feature_dim, - max_n_obs_steps=n_obs_steps, - fix_obs_steps=True, - action_visible=False, - ) - self.horizon = horizon - self.obs_feature_dim = obs_feature_dim - self.action_dim = action_dim - self.n_action_steps = n_action_steps - self.n_obs_steps = n_obs_steps - self.obs_as_global_cond = obs_as_global_cond - self.kwargs = kwargs - - if num_inference_steps is None: - num_inference_steps = noise_scheduler.config.num_train_timesteps - self.num_inference_steps = num_inference_steps - - # ========= inference ============ - def conditional_sample( - self, - condition_data, - condition_mask, - local_cond=None, - global_cond=None, - generator=None, - # keyword arguments to scheduler.step - **kwargs, - ): - model = self.model - scheduler = self.noise_scheduler - - trajectory = torch.randn( - size=condition_data.shape, - dtype=condition_data.dtype, - device=condition_data.device, - generator=generator, - ) - - # set step values - 
scheduler.set_timesteps(self.num_inference_steps) - - for t in scheduler.timesteps: - # 1. apply conditioning - trajectory[condition_mask] = condition_data[condition_mask] - - # 2. predict model output - model_output = model(trajectory, t, local_cond=local_cond, global_cond=global_cond) - - # 3. compute previous image: x_t -> x_t-1 - trajectory = scheduler.step( - model_output, - t, - trajectory, - generator=generator, - # **kwargs # TODO(rcadene): in diffusion_policy, expected to be {} - ).prev_sample - - # finally make sure conditioning is enforced - trajectory[condition_mask] = condition_data[condition_mask] - - return trajectory - - def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """ - obs_dict: must include "obs" key - result: must include "action" key - """ - assert "past_action" not in obs_dict # not implemented yet - nobs = obs_dict - value = next(iter(nobs.values())) - bsize, n_obs_steps = value.shape[:2] - horizon = self.horizon - action_dim = self.action_dim - obs_dim = self.obs_feature_dim - assert self.n_obs_steps == n_obs_steps - - # build input - device = self.device - dtype = self.dtype - - # handle different ways of passing observation - local_cond = None - global_cond = None - if self.obs_as_global_cond: - # condition through global feature - this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:])) - nobs_features = self.obs_encoder(this_nobs) - # reshape back to B, Do - global_cond = nobs_features.reshape(bsize, -1) - # empty data for action - cond_data = torch.zeros(size=(bsize, horizon, action_dim), device=device, dtype=dtype) - cond_mask = torch.zeros_like(cond_data, dtype=torch.bool) - else: - # condition through impainting - this_nobs = dict_apply(nobs, lambda x: x[:, :n_obs_steps, ...].reshape(-1, *x.shape[2:])) - nobs_features = self.obs_encoder(this_nobs) - # reshape back to B, T, Do - nobs_features = nobs_features.reshape(bsize, n_obs_steps, -1) - cond_data = torch.zeros(size=(bsize, horizon, action_dim + obs_dim), device=device, dtype=dtype) - cond_mask = torch.zeros_like(cond_data, dtype=torch.bool) - cond_data[:, :n_obs_steps, action_dim:] = nobs_features - cond_mask[:, :n_obs_steps, action_dim:] = True - - # run sampling - nsample = self.conditional_sample( - cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond - ) - - action_pred = nsample[..., :action_dim] - # get action - start = n_obs_steps - 1 - end = start + self.n_action_steps - action = action_pred[:, start:end] - - result = {"action": action, "action_pred": action_pred} - return result - - def compute_loss(self, batch): - nobs = { - "image": batch["observation.image"], - "agent_pos": batch["observation.state"], - } - nactions = batch["action"] - batch_size = nactions.shape[0] - horizon = nactions.shape[1] - - # handle different ways of passing observation - local_cond = None - global_cond = None - trajectory = nactions - cond_data = trajectory - if self.obs_as_global_cond: - # reshape B, T, ... to B*T - this_nobs = dict_apply(nobs, lambda x: x[:, : self.n_obs_steps, ...].reshape(-1, *x.shape[2:])) - nobs_features = self.obs_encoder(this_nobs) - # reshape back to B, Do - global_cond = nobs_features.reshape(batch_size, -1) - else: - # reshape B, T, ... 
to B*T - this_nobs = dict_apply(nobs, lambda x: x.reshape(-1, *x.shape[2:])) - nobs_features = self.obs_encoder(this_nobs) - # reshape back to B, T, Do - nobs_features = nobs_features.reshape(batch_size, horizon, -1) - cond_data = torch.cat([nactions, nobs_features], dim=-1) - trajectory = cond_data.detach() - - # generate impainting mask - condition_mask = self.mask_generator(trajectory.shape) - - # Sample noise that we'll add to the images - noise = torch.randn(trajectory.shape, device=trajectory.device) - bsz = trajectory.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint( - 0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=trajectory.device - ).long() - # Add noise to the clean images according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_trajectory = self.noise_scheduler.add_noise(trajectory, noise, timesteps) - - # compute loss mask - loss_mask = ~condition_mask - - # apply conditioning - noisy_trajectory[condition_mask] = cond_data[condition_mask] - - # Predict the noise residual - pred = self.model(noisy_trajectory, timesteps, local_cond=local_cond, global_cond=global_cond) - - pred_type = self.noise_scheduler.config.prediction_type - if pred_type == "epsilon": - target = noise - elif pred_type == "sample": - target = trajectory - else: - raise ValueError(f"Unsupported prediction type {pred_type}") - - loss = F.mse_loss(pred, target, reduction="none") - loss = loss * loss_mask.type(loss.dtype) - - if "action_is_pad" in batch: - in_episode_bound = ~batch["action_is_pad"] - loss = loss * in_episode_bound[:, :, None].type(loss.dtype) - - loss = reduce(loss, "b t c -> b", "mean", b=batch_size) - loss = loss.mean() - return loss diff --git a/lerobot/common/policies/diffusion/model/conditional_unet1d.py b/lerobot/common/policies/diffusion/model/conditional_unet1d.py deleted file mode 100644 index d2971d38..00000000 --- a/lerobot/common/policies/diffusion/model/conditional_unet1d.py +++ /dev/null @@ -1,286 +0,0 @@ -import logging -from typing import Union - -import einops -import torch -import torch.nn as nn -from einops.layers.torch import Rearrange - -from lerobot.common.policies.diffusion.model.conv1d_components import Conv1dBlock, Downsample1d, Upsample1d -from lerobot.common.policies.diffusion.model.positional_embedding import SinusoidalPosEmb - -logger = logging.getLogger(__name__) - - -class ConditionalResidualBlock1D(nn.Module): - def __init__( - self, in_channels, out_channels, cond_dim, kernel_size=3, n_groups=8, cond_predict_scale=False - ): - super().__init__() - - self.blocks = nn.ModuleList( - [ - Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups), - Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups), - ] - ) - - # FiLM modulation https://arxiv.org/abs/1709.07871 - # predicts per-channel scale and bias - cond_channels = out_channels - if cond_predict_scale: - cond_channels = out_channels * 2 - self.cond_predict_scale = cond_predict_scale - self.out_channels = out_channels - self.cond_encoder = nn.Sequential( - nn.Mish(), - nn.Linear(cond_dim, cond_channels), - Rearrange("batch t -> batch t 1"), - ) - - # make sure dimensions compatible - self.residual_conv = ( - nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() - ) - - def forward(self, x, cond): - """ - x : [ batch_size x in_channels x horizon ] - cond : [ batch_size x cond_dim] - - returns: - out : [ batch_size x out_channels x horizon ] - """ - out = 
self.blocks[0](x) - embed = self.cond_encoder(cond) - if self.cond_predict_scale: - embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1) - scale = embed[:, 0, ...] - bias = embed[:, 1, ...] - out = scale * out + bias - else: - out = out + embed - out = self.blocks[1](out) - out = out + self.residual_conv(x) - return out - - -class ConditionalUnet1D(nn.Module): - def __init__( - self, - input_dim, - local_cond_dim=None, - global_cond_dim=None, - diffusion_step_embed_dim=256, - down_dims=None, - kernel_size=3, - n_groups=8, - cond_predict_scale=False, - ): - super().__init__() - if down_dims is None: - down_dims = [256, 512, 1024] - - all_dims = [input_dim] + list(down_dims) - start_dim = down_dims[0] - - dsed = diffusion_step_embed_dim - diffusion_step_encoder = nn.Sequential( - SinusoidalPosEmb(dsed), - nn.Linear(dsed, dsed * 4), - nn.Mish(), - nn.Linear(dsed * 4, dsed), - ) - cond_dim = dsed - if global_cond_dim is not None: - cond_dim += global_cond_dim - - in_out = list(zip(all_dims[:-1], all_dims[1:], strict=False)) - - local_cond_encoder = None - if local_cond_dim is not None: - _, dim_out = in_out[0] - dim_in = local_cond_dim - local_cond_encoder = nn.ModuleList( - [ - # down encoder - ConditionalResidualBlock1D( - dim_in, - dim_out, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - # up encoder - ConditionalResidualBlock1D( - dim_in, - dim_out, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - ] - ) - - mid_dim = all_dims[-1] - self.mid_modules = nn.ModuleList( - [ - ConditionalResidualBlock1D( - mid_dim, - mid_dim, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - ConditionalResidualBlock1D( - mid_dim, - mid_dim, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - ] - ) - - down_modules = nn.ModuleList([]) - for ind, (dim_in, dim_out) in enumerate(in_out): - is_last = ind >= (len(in_out) - 1) - down_modules.append( - nn.ModuleList( - [ - ConditionalResidualBlock1D( - dim_in, - dim_out, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - ConditionalResidualBlock1D( - dim_out, - dim_out, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - Downsample1d(dim_out) if not is_last else nn.Identity(), - ] - ) - ) - - up_modules = nn.ModuleList([]) - for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): - is_last = ind >= (len(in_out) - 1) - up_modules.append( - nn.ModuleList( - [ - ConditionalResidualBlock1D( - dim_out * 2, - dim_in, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - ConditionalResidualBlock1D( - dim_in, - dim_in, - cond_dim=cond_dim, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - ), - Upsample1d(dim_in) if not is_last else nn.Identity(), - ] - ) - ) - - final_conv = nn.Sequential( - Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size), - nn.Conv1d(start_dim, input_dim, 1), - ) - - self.diffusion_step_encoder = diffusion_step_encoder - self.local_cond_encoder = local_cond_encoder - self.up_modules = up_modules - self.down_modules = down_modules - self.final_conv = final_conv - - logger.info("number of parameters: %e", sum(p.numel() for p in 
self.parameters())) - - def forward( - self, - sample: torch.Tensor, - timestep: Union[torch.Tensor, float, int], - local_cond=None, - global_cond=None, - **kwargs, - ): - """ - x: (B,T,input_dim) - timestep: (B,) or int, diffusion step - local_cond: (B,T,local_cond_dim) - global_cond: (B,global_cond_dim) - output: (B,T,input_dim) - """ - sample = einops.rearrange(sample, "b h t -> b t h") - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) - elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - global_feature = self.diffusion_step_encoder(timesteps) - - if global_cond is not None: - global_feature = torch.cat([global_feature, global_cond], axis=-1) - - # encode local features - h_local = [] - if local_cond is not None: - local_cond = einops.rearrange(local_cond, "b h t -> b t h") - resnet, resnet2 = self.local_cond_encoder - x = resnet(local_cond, global_feature) - h_local.append(x) - x = resnet2(local_cond, global_feature) - h_local.append(x) - - x = sample - h = [] - for idx, (resnet, resnet2, downsample) in enumerate(self.down_modules): - x = resnet(x, global_feature) - if idx == 0 and len(h_local) > 0: - x = x + h_local[0] - x = resnet2(x, global_feature) - h.append(x) - x = downsample(x) - - for mid_module in self.mid_modules: - x = mid_module(x, global_feature) - - for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules): - x = torch.cat((x, h.pop()), dim=1) - x = resnet(x, global_feature) - # The correct condition should be: - # if idx == (len(self.up_modules)-1) and len(h_local) > 0: - # However this change will break compatibility with published checkpoints. - # Therefore it is left as a comment. 
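The FiLM conditioning used by `ConditionalResidualBlock1D` above (bias always, plus a scale when `cond_predict_scale` / `use_film_scale_modulation` is enabled) reduces to a few lines. A stripped-down sketch, separate from the deleted class:

```python
import torch
from torch import nn


class FiLM(nn.Module):
    """Predict per-channel (scale, bias) from a conditioning vector and modulate features."""

    def __init__(self, cond_dim: int, channels: int, use_scale: bool = True):
        super().__init__()
        self.use_scale = use_scale
        self.proj = nn.Linear(cond_dim, 2 * channels if use_scale else channels)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T) features, cond: (B, cond_dim), e.g. a diffusion step embedding.
        out = self.proj(cond).unsqueeze(-1)  # (B, 2C or C, 1)
        if self.use_scale:
            scale, bias = out.chunk(2, dim=1)
            return scale * x + bias
        return x + out
```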
- if idx == len(self.up_modules) and len(h_local) > 0: - x = x + h_local[1] - x = resnet2(x, global_feature) - x = upsample(x) - - x = self.final_conv(x) - - x = einops.rearrange(x, "b t h -> b h t") - return x diff --git a/lerobot/common/policies/diffusion/model/conv1d_components.py b/lerobot/common/policies/diffusion/model/conv1d_components.py deleted file mode 100644 index 3c21eaf6..00000000 --- a/lerobot/common/policies/diffusion/model/conv1d_components.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch.nn as nn - -# from einops.layers.torch import Rearrange - - -class Downsample1d(nn.Module): - def __init__(self, dim): - super().__init__() - self.conv = nn.Conv1d(dim, dim, 3, 2, 1) - - def forward(self, x): - return self.conv(x) - - -class Upsample1d(nn.Module): - def __init__(self, dim): - super().__init__() - self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) - - def forward(self, x): - return self.conv(x) - - -class Conv1dBlock(nn.Module): - """ - Conv1d --> GroupNorm --> Mish - """ - - def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): - super().__init__() - - self.block = nn.Sequential( - nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), - # Rearrange('batch channels horizon -> batch channels 1 horizon'), - nn.GroupNorm(n_groups, out_channels), - # Rearrange('batch channels 1 horizon -> batch channels horizon'), - nn.Mish(), - ) - - def forward(self, x): - return self.block(x) - - -# def test(): -# cb = Conv1dBlock(256, 128, kernel_size=3) -# x = torch.zeros((1,256,16)) -# o = cb(x) diff --git a/lerobot/common/policies/diffusion/model/crop_randomizer.py b/lerobot/common/policies/diffusion/model/crop_randomizer.py deleted file mode 100644 index 2e60f353..00000000 --- a/lerobot/common/policies/diffusion/model/crop_randomizer.py +++ /dev/null @@ -1,294 +0,0 @@ -import torch -import torch.nn as nn -import torchvision.transforms.functional as ttf - -import lerobot.common.policies.diffusion.model.tensor_utils as tu - - -class CropRandomizer(nn.Module): - """ - Randomly sample crops at input, and then average across crop features at output. - """ - - def __init__( - self, - input_shape, - crop_height, - crop_width, - num_crops=1, - pos_enc=False, - ): - """ - Args: - input_shape (tuple, list): shape of input (not including batch dimension) - crop_height (int): crop height - crop_width (int): crop width - num_crops (int): number of random crops to take - pos_enc (bool): if True, add 2 channels to the output to encode the spatial - location of the cropped pixels in the source image - """ - super().__init__() - - assert len(input_shape) == 3 # (C, H, W) - assert crop_height < input_shape[1] - assert crop_width < input_shape[2] - - self.input_shape = input_shape - self.crop_height = crop_height - self.crop_width = crop_width - self.num_crops = num_crops - self.pos_enc = pos_enc - - def output_shape_in(self, input_shape=None): - """ - Function to compute output shape from inputs to this module. Corresponds to - the @forward_in operation, where raw inputs (usually observation modalities) - are passed in. - - Args: - input_shape (iterable of int): shape of input. Does not include batch dimension. - Some modules may not need this argument, if their output does not depend - on the size of the input, or if they assume fixed size input. 
- - Returns: - out_shape ([int]): list of integers corresponding to output shape - """ - - # outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because - # the number of crops are reshaped into the batch dimension, increasing the batch - # size from B to B * N - out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0] - return [out_c, self.crop_height, self.crop_width] - - def output_shape_out(self, input_shape=None): - """ - Function to compute output shape from inputs to this module. Corresponds to - the @forward_out operation, where processed inputs (usually encoded observation - modalities) are passed in. - - Args: - input_shape (iterable of int): shape of input. Does not include batch dimension. - Some modules may not need this argument, if their output does not depend - on the size of the input, or if they assume fixed size input. - - Returns: - out_shape ([int]): list of integers corresponding to output shape - """ - - # since the forward_out operation splits [B * N, ...] -> [B, N, ...] - # and then pools to result in [B, ...], only the batch dimension changes, - # and so the other dimensions retain their shape. - return list(input_shape) - - def forward_in(self, inputs): - """ - Samples N random crops for each input in the batch, and then reshapes - inputs to [B * N, ...]. - """ - assert len(inputs.shape) >= 3 # must have at least (C, H, W) dimensions - if self.training: - # generate random crops - out, _ = sample_random_image_crops( - images=inputs, - crop_height=self.crop_height, - crop_width=self.crop_width, - num_crops=self.num_crops, - pos_enc=self.pos_enc, - ) - # [B, N, ...] -> [B * N, ...] - return tu.join_dimensions(out, 0, 1) - else: - # take center crop during eval - out = ttf.center_crop(img=inputs, output_size=(self.crop_height, self.crop_width)) - if self.num_crops > 1: - B, C, H, W = out.shape # noqa: N806 - out = out.unsqueeze(1).expand(B, self.num_crops, C, H, W).reshape(-1, C, H, W) - # [B * N, ...] - return out - - def forward_out(self, inputs): - """ - Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N - to result in shape [B, ...] to make sure the network output is consistent with - what would have happened if there were no randomization. - """ - if self.num_crops <= 1: - return inputs - else: - batch_size = inputs.shape[0] // self.num_crops - out = tu.reshape_dimensions( - inputs, begin_axis=0, end_axis=0, target_dims=(batch_size, self.num_crops) - ) - return out.mean(dim=1) - - def forward(self, inputs): - return self.forward_in(inputs) - - def __repr__(self): - """Pretty print network.""" - header = "{}".format(str(self.__class__.__name__)) - msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format( - self.input_shape, self.crop_height, self.crop_width, self.num_crops - ) - return msg - - -def crop_image_from_indices(images, crop_indices, crop_height, crop_width): - """ - Crops images at the locations specified by @crop_indices. Crops will be - taken across all channels. - - Args: - images (torch.Tensor): batch of images of shape [..., C, H, W] - - crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where - N is the number of crops to take per image and each entry corresponds - to the pixel height and width of where to take the crop. Note that - the indices can also be of shape [..., 2] if only 1 crop should - be taken per image. Leading dimensions must be consistent with - @images argument. Each index specifies the top left of the crop. 
- Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where - H and W are the height and width of @images and CH and CW are - @crop_height and @crop_width. - - crop_height (int): height of crop to take - - crop_width (int): width of crop to take - - Returns: - crops (torch.Tesnor): cropped images of shape [..., C, @crop_height, @crop_width] - """ - - # make sure length of input shapes is consistent - assert crop_indices.shape[-1] == 2 - ndim_im_shape = len(images.shape) - ndim_indices_shape = len(crop_indices.shape) - assert (ndim_im_shape == ndim_indices_shape + 1) or (ndim_im_shape == ndim_indices_shape + 2) - - # maybe pad so that @crop_indices is shape [..., N, 2] - is_padded = False - if ndim_im_shape == ndim_indices_shape + 2: - crop_indices = crop_indices.unsqueeze(-2) - is_padded = True - - # make sure leading dimensions between images and indices are consistent - assert images.shape[:-3] == crop_indices.shape[:-2] - - device = images.device - image_c, image_h, image_w = images.shape[-3:] - num_crops = crop_indices.shape[-2] - - # make sure @crop_indices are in valid range - assert (crop_indices[..., 0] >= 0).all().item() - assert (crop_indices[..., 0] < (image_h - crop_height)).all().item() - assert (crop_indices[..., 1] >= 0).all().item() - assert (crop_indices[..., 1] < (image_w - crop_width)).all().item() - - # convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window. - - # 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW] - crop_ind_grid_h = torch.arange(crop_height).to(device) - crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h, size=crop_width, dim=-1) - # 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW] - crop_ind_grid_w = torch.arange(crop_width).to(device) - crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w, size=crop_height, dim=0) - # combine into shape [CH, CW, 2] - crop_in_grid = torch.cat((crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1) - - # Add above grid with the offset index of each sampled crop to get 2d indices for each crop. - # After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2] - # shape array that tells us which pixels from the corresponding source image to grab. - grid_reshape = [1] * len(crop_indices.shape[:-1]) + [crop_height, crop_width, 2] - all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(-2) + crop_in_grid.reshape(grid_reshape) - - # For using @torch.gather, convert to flat indices from 2D indices, and also - # repeat across the channel dimension. 
To get flat index of each pixel to grab for - # each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind - all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[..., 1] # shape [..., N, CH, CW] - all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c, dim=-3) # shape [..., N, C, CH, CW] - all_crop_inds = tu.flatten(all_crop_inds, begin_axis=-2) # shape [..., N, C, CH * CW] - - # Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds - images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4) - images_to_crop = tu.flatten(images_to_crop, begin_axis=-2) - crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds) - # [..., N, C, CH * CW] -> [..., N, C, CH, CW] - reshape_axis = len(crops.shape) - 1 - crops = tu.reshape_dimensions( - crops, begin_axis=reshape_axis, end_axis=reshape_axis, target_dims=(crop_height, crop_width) - ) - - if is_padded: - # undo padding -> [..., C, CH, CW] - crops = crops.squeeze(-4) - return crops - - -def sample_random_image_crops(images, crop_height, crop_width, num_crops, pos_enc=False): - """ - For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from - @images. - - Args: - images (torch.Tensor): batch of images of shape [..., C, H, W] - - crop_height (int): height of crop to take - - crop_width (int): width of crop to take - - num_crops (n): number of crops to sample - - pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial - encoding of the original source pixel locations. This means that the - output crops will contain information about where in the source image - it was sampled from. - - Returns: - crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width) - if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width) - - crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2) - """ - device = images.device - - # maybe add 2 channels of spatial encoding to the source image - source_im = images - if pos_enc: - # spatial encoding [y, x] in [0, 1] - h, w = source_im.shape[-2:] - pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w)) - pos_y = pos_y.float().to(device) / float(h) - pos_x = pos_x.float().to(device) / float(w) - position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W] - - # unsqueeze and expand to match leading dimensions -> shape [..., C, H, W] - leading_shape = source_im.shape[:-3] - position_enc = position_enc[(None,) * len(leading_shape)] - position_enc = position_enc.expand(*leading_shape, -1, -1, -1) - - # concat across channel dimension with input - source_im = torch.cat((source_im, position_enc), dim=-3) - - # make sure sample boundaries ensure crops are fully within the images - image_c, image_h, image_w = source_im.shape[-3:] - max_sample_h = image_h - crop_height - max_sample_w = image_w - crop_width - - # Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W]. - # Each gets @num_crops samples - typically this will just be the batch dimension (B), so - # we will sample [B, N] indices, but this supports having more than one leading dimension, - # or possibly no leading dimension. 
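The flat-index trick used by `crop_image_from_indices` (a pixel at `(h, w)` in a row-major `(H, W)` grid lives at flat index `h * W + w`) is easier to see on a single image with a single crop. A small self-contained check, independent of the helpers above:

```python
import torch

image = torch.arange(3 * 8 * 8, dtype=torch.float32).reshape(3, 8, 8)  # (C, H, W)
top, left, ch, cw = 2, 3, 4, 4

rows = torch.arange(top, top + ch).unsqueeze(-1)   # (ch, 1)
cols = torch.arange(left, left + cw).unsqueeze(0)  # (1, cw)
flat_idx = (rows * 8 + cols).reshape(-1)           # h * W + w for every crop pixel

crop = image.flatten(-2).index_select(-1, flat_idx).reshape(3, ch, cw)
assert torch.equal(crop, image[:, top:top + ch, left:left + cw])
```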
- # - # Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints - crop_inds_h = (max_sample_h * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long() - crop_inds_w = (max_sample_w * torch.rand(*source_im.shape[:-3], num_crops).to(device)).long() - crop_inds = torch.cat((crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)), dim=-1) # shape [..., N, 2] - - crops = crop_image_from_indices( - images=source_im, - crop_indices=crop_inds, - crop_height=crop_height, - crop_width=crop_width, - ) - - return crops, crop_inds diff --git a/lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py b/lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py deleted file mode 100644 index d1356006..00000000 --- a/lerobot/common/policies/diffusion/model/dict_of_tensor_mixin.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch -import torch.nn as nn - - -class DictOfTensorMixin(nn.Module): - def __init__(self, params_dict=None): - super().__init__() - if params_dict is None: - params_dict = nn.ParameterDict() - self.params_dict = params_dict - - @property - def device(self): - return next(iter(self.parameters())).device - - def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): - def dfs_add(dest, keys, value: torch.Tensor): - if len(keys) == 1: - dest[keys[0]] = value - return - - if keys[0] not in dest: - dest[keys[0]] = nn.ParameterDict() - dfs_add(dest[keys[0]], keys[1:], value) - - def load_dict(state_dict, prefix): - out_dict = nn.ParameterDict() - for key, value in state_dict.items(): - value: torch.Tensor - if key.startswith(prefix): - param_keys = key[len(prefix) :].split(".")[1:] - # if len(param_keys) == 0: - # import pdb; pdb.set_trace() - dfs_add(out_dict, param_keys, value.clone()) - return out_dict - - self.params_dict = load_dict(state_dict, prefix + "params_dict") - self.params_dict.requires_grad_(False) - return diff --git a/lerobot/common/policies/diffusion/model/ema_model.py b/lerobot/common/policies/diffusion/model/ema_model.py deleted file mode 100644 index 6dc128de..00000000 --- a/lerobot/common/policies/diffusion/model/ema_model.py +++ /dev/null @@ -1,84 +0,0 @@ -import torch -from torch.nn.modules.batchnorm import _BatchNorm - - -class EMAModel: - """ - Exponential Moving Average of models weights - """ - - def __init__( - self, model, update_after_step=0, inv_gamma=1.0, power=2 / 3, min_value=0.0, max_value=0.9999 - ): - """ - @crowsonkb's notes on EMA Warmup: - If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan - to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), - gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 - at 215.4k steps). - Args: - inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. - power (float): Exponential factor of EMA warmup. Default: 2/3. - min_value (float): The minimum EMA decay rate. Default: 0. - """ - - self.averaged_model = model - self.averaged_model.eval() - self.averaged_model.requires_grad_(False) - - self.update_after_step = update_after_step - self.inv_gamma = inv_gamma - self.power = power - self.min_value = min_value - self.max_value = max_value - - self.decay = 0.0 - self.optimization_step = 0 - - def get_decay(self, optimization_step): - """ - Compute the decay factor for the exponential moving average. 
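The warmup numbers quoted in the note above can be checked directly. A sketch of the decay rule (for simplicity it takes the already-offset step, whereas `get_decay` below subtracts `update_after_step` first):

```python
def ema_decay(step: int, inv_gamma: float = 1.0, power: float = 2 / 3,
              min_value: float = 0.0, max_value: float = 0.9999) -> float:
    if step <= 0:
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(min_value, min(value, max_value))


print(round(ema_decay(31_600), 4))     # ~0.999, matching the gamma=1, power=2/3 note
print(round(ema_decay(1_000_000), 4))  # 0.9999, capped by max_value
```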
- """ - step = max(0, optimization_step - self.update_after_step - 1) - value = 1 - (1 + step / self.inv_gamma) ** -self.power - - if step <= 0: - return 0.0 - - return max(self.min_value, min(value, self.max_value)) - - @torch.no_grad() - def step(self, new_model): - self.decay = self.get_decay(self.optimization_step) - - # old_all_dataptrs = set() - # for param in new_model.parameters(): - # data_ptr = param.data_ptr() - # if data_ptr != 0: - # old_all_dataptrs.add(data_ptr) - - # all_dataptrs = set() - for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=False): - for param, ema_param in zip( - module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=False - ): - # iterative over immediate parameters only. - if isinstance(param, dict): - raise RuntimeError("Dict parameter not supported") - - # data_ptr = param.data_ptr() - # if data_ptr != 0: - # all_dataptrs.add(data_ptr) - - if isinstance(module, _BatchNorm): - # skip batchnorms - ema_param.copy_(param.to(dtype=ema_param.dtype).data) - elif not param.requires_grad: - ema_param.copy_(param.to(dtype=ema_param.dtype).data) - else: - ema_param.mul_(self.decay) - ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay) - - # verify that iterating over module and then parameters is identical to parameters recursively. - # assert old_all_dataptrs == all_dataptrs - self.optimization_step += 1 diff --git a/lerobot/common/policies/diffusion/model/lr_scheduler.py b/lerobot/common/policies/diffusion/model/lr_scheduler.py deleted file mode 100644 index 084b3a36..00000000 --- a/lerobot/common/policies/diffusion/model/lr_scheduler.py +++ /dev/null @@ -1,46 +0,0 @@ -from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, Optimizer, Optional, SchedulerType, Union - - -def get_scheduler( - name: Union[str, SchedulerType], - optimizer: Optimizer, - num_warmup_steps: Optional[int] = None, - num_training_steps: Optional[int] = None, - **kwargs, -): - """ - Added kwargs vs diffuser's original implementation - - Unified API to get any scheduler from its name. - - Args: - name (`str` or `SchedulerType`): - The name of the scheduler to use. - optimizer (`torch.optim.Optimizer`): - The optimizer that will be used during training. - num_warmup_steps (`int`, *optional*): - The number of warmup steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. - num_training_steps (`int``, *optional*): - The number of training steps to do. This is not required by all schedulers (hence the argument being - optional), the function will raise an error if it's unset and the scheduler type requires it. 
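For the `lr_scheduler: "cosine"` / `lr_warmup_steps: 500` settings in `DiffusionConfig`, the wrapper being deleted here ultimately defers to diffusers' own `get_scheduler`. A usage sketch with a stand-in model (the optimizer hyperparameters mirror the config defaults; calling diffusers directly is an assumption, not the repository's exact code path):

```python
import torch
from diffusers.optimization import get_scheduler

model = torch.nn.Linear(2, 2)  # stand-in for the policy
optimizer = torch.optim.Adam(
    model.parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-8, weight_decay=1e-6
)
lr_scheduler = get_scheduler(
    "cosine", optimizer=optimizer, num_warmup_steps=500, num_training_steps=5000
)

for _ in range(5000):
    optimizer.step()     # training step elided
    lr_scheduler.step()  # one scheduler step per optimizer step
```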
- """ - name = SchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] - if name == SchedulerType.CONSTANT: - return schedule_func(optimizer, **kwargs) - - # All other schedulers require `num_warmup_steps` - if num_warmup_steps is None: - raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") - - if name == SchedulerType.CONSTANT_WITH_WARMUP: - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **kwargs) - - # All other schedulers require `num_training_steps` - if num_training_steps is None: - raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") - - return schedule_func( - optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **kwargs - ) diff --git a/lerobot/common/policies/diffusion/model/mask_generator.py b/lerobot/common/policies/diffusion/model/mask_generator.py deleted file mode 100644 index 63306dea..00000000 --- a/lerobot/common/policies/diffusion/model/mask_generator.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch - -from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin - - -class LowdimMaskGenerator(ModuleAttrMixin): - def __init__( - self, - action_dim, - obs_dim, - # obs mask setup - max_n_obs_steps=2, - fix_obs_steps=True, - # action mask - action_visible=False, - ): - super().__init__() - self.action_dim = action_dim - self.obs_dim = obs_dim - self.max_n_obs_steps = max_n_obs_steps - self.fix_obs_steps = fix_obs_steps - self.action_visible = action_visible - - @torch.no_grad() - def forward(self, shape, seed=None): - device = self.device - B, T, D = shape # noqa: N806 - assert (self.action_dim + self.obs_dim) == D - - # create all tensors on this device - rng = torch.Generator(device=device) - if seed is not None: - rng = rng.manual_seed(seed) - - # generate dim mask - dim_mask = torch.zeros(size=shape, dtype=torch.bool, device=device) - is_action_dim = dim_mask.clone() - is_action_dim[..., : self.action_dim] = True - is_obs_dim = ~is_action_dim - - # generate obs mask - if self.fix_obs_steps: - obs_steps = torch.full((B,), fill_value=self.max_n_obs_steps, device=device) - else: - obs_steps = torch.randint( - low=1, high=self.max_n_obs_steps + 1, size=(B,), generator=rng, device=device - ) - - steps = torch.arange(0, T, device=device).reshape(1, T).expand(B, T) - obs_mask = (obs_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D) - obs_mask = obs_mask & is_obs_dim - - # generate action mask - if self.action_visible: - action_steps = torch.maximum( - obs_steps - 1, torch.tensor(0, dtype=obs_steps.dtype, device=obs_steps.device) - ) - action_mask = (action_steps > steps.T).T.reshape(B, T, 1).expand(B, T, D) - action_mask = action_mask & is_action_dim - - mask = obs_mask - if self.action_visible: - mask = mask | action_mask - - return mask diff --git a/lerobot/common/policies/diffusion/model/module_attr_mixin.py b/lerobot/common/policies/diffusion/model/module_attr_mixin.py deleted file mode 100644 index 5d2cf4ea..00000000 --- a/lerobot/common/policies/diffusion/model/module_attr_mixin.py +++ /dev/null @@ -1,15 +0,0 @@ -import torch.nn as nn - - -class ModuleAttrMixin(nn.Module): - def __init__(self): - super().__init__() - self._dummy_variable = nn.Parameter() - - @property - def device(self): - return next(iter(self.parameters())).device - - @property - def dtype(self): - return next(iter(self.parameters())).dtype diff --git a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py 
b/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py deleted file mode 100644 index d724cd49..00000000 --- a/lerobot/common/policies/diffusion/model/multi_image_obs_encoder.py +++ /dev/null @@ -1,214 +0,0 @@ -import copy -from typing import Dict, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torchvision -from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax - -from lerobot.common.policies.diffusion.model.crop_randomizer import CropRandomizer -from lerobot.common.policies.diffusion.model.module_attr_mixin import ModuleAttrMixin -from lerobot.common.policies.diffusion.pytorch_utils import replace_submodules - - -class RgbEncoder(nn.Module): - """Following `VisualCore` from Robomimic 0.2.0.""" - - def __init__(self, input_shape, relu=True, pretrained=False, num_keypoints=32): - """ - input_shape: channel-first input shape (C, H, W) - resnet_name: a timm model name. - pretrained: whether to use timm pretrained weights. - relu: whether to use relu as a final step. - num_keypoints: Number of keypoints for SpatialSoftmax (default value of 32 matches PushT Image). - """ - super().__init__() - self.backbone = ResNet18Conv(input_channel=input_shape[0], pretrained=pretrained) - # Figure out the feature map shape. - with torch.inference_mode(): - feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, *input_shape))).shape[1:]) - self.pool = SpatialSoftmax(feat_map_shape, num_kp=num_keypoints) - self.out = nn.Linear(num_keypoints * 2, num_keypoints * 2) - self.relu = nn.ReLU() if relu else nn.Identity() - - def forward(self, x): - return self.relu(self.out(torch.flatten(self.pool(self.backbone(x)), start_dim=1))) - - -class MultiImageObsEncoder(ModuleAttrMixin): - def __init__( - self, - shape_meta: dict, - rgb_model: Union[nn.Module, Dict[str, nn.Module]], - resize_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None, - crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None, - random_crop: bool = True, - # replace BatchNorm with GroupNorm - use_group_norm: bool = False, - # use single rgb model for all rgb inputs - share_rgb_model: bool = False, - # renormalize rgb input with imagenet normalization - # assuming input in [0,1] - norm_mean_std: Optional[tuple[float, float]] = None, - ): - """ - Assumes rgb input: B,C,H,W - Assumes low_dim input: B,D - """ - super().__init__() - - rgb_keys = [] - low_dim_keys = [] - key_model_map = nn.ModuleDict() - key_transform_map = nn.ModuleDict() - key_shape_map = {} - - # handle sharing vision backbone - if share_rgb_model: - assert isinstance(rgb_model, nn.Module) - key_model_map["rgb"] = rgb_model - - obs_shape_meta = shape_meta["obs"] - for key, attr in obs_shape_meta.items(): - shape = tuple(attr["shape"]) - type = attr.get("type", "low_dim") - key_shape_map[key] = shape - if type == "rgb": - rgb_keys.append(key) - # configure model for this key - this_model = None - if not share_rgb_model: - if isinstance(rgb_model, dict): - # have provided model for each key - this_model = rgb_model[key] - else: - assert isinstance(rgb_model, nn.Module) - # have a copy of the rgb model - this_model = copy.deepcopy(rgb_model) - - if this_model is not None: - if use_group_norm: - this_model = replace_submodules( - root_module=this_model, - predicate=lambda x: isinstance(x, nn.BatchNorm2d), - func=lambda x: nn.GroupNorm( - num_groups=x.num_features // 16, num_channels=x.num_features - ), - ) - key_model_map[key] = this_model - - # configure resize - input_shape = shape - this_resizer = nn.Identity() - if 
resize_shape is not None: - if isinstance(resize_shape, dict): - h, w = resize_shape[key] - else: - h, w = resize_shape - this_resizer = torchvision.transforms.Resize(size=(h, w)) - input_shape = (shape[0], h, w) - - # configure randomizer - this_randomizer = nn.Identity() - if crop_shape is not None: - if isinstance(crop_shape, dict): - h, w = crop_shape[key] - else: - h, w = crop_shape - if random_crop: - this_randomizer = CropRandomizer( - input_shape=input_shape, crop_height=h, crop_width=w, num_crops=1, pos_enc=False - ) - else: - this_normalizer = torchvision.transforms.CenterCrop(size=(h, w)) - # configure normalizer - this_normalizer = nn.Identity() - if norm_mean_std is not None: - this_normalizer = torchvision.transforms.Normalize( - mean=norm_mean_std[0], std=norm_mean_std[1] - ) - - this_transform = nn.Sequential(this_resizer, this_randomizer, this_normalizer) - key_transform_map[key] = this_transform - elif type == "low_dim": - low_dim_keys.append(key) - else: - raise RuntimeError(f"Unsupported obs type: {type}") - rgb_keys = sorted(rgb_keys) - low_dim_keys = sorted(low_dim_keys) - - self.shape_meta = shape_meta - self.key_model_map = key_model_map - self.key_transform_map = key_transform_map - self.share_rgb_model = share_rgb_model - self.rgb_keys = rgb_keys - self.low_dim_keys = low_dim_keys - self.key_shape_map = key_shape_map - - def forward(self, obs_dict): - batch_size = None - features = [] - - # process lowdim input - for key in self.low_dim_keys: - data = obs_dict[key] - if batch_size is None: - batch_size = data.shape[0] - else: - assert batch_size == data.shape[0] - assert data.shape[1:] == self.key_shape_map[key] - features.append(data) - - # process rgb input - if self.share_rgb_model: - # pass all rgb obs to rgb model - imgs = [] - for key in self.rgb_keys: - img = obs_dict[key] - if batch_size is None: - batch_size = img.shape[0] - else: - assert batch_size == img.shape[0] - assert img.shape[1:] == self.key_shape_map[key] - img = self.key_transform_map[key](img) - imgs.append(img) - # (N*B,C,H,W) - imgs = torch.cat(imgs, dim=0) - # (N*B,D) - feature = self.key_model_map["rgb"](imgs) - # (N,B,D) - feature = feature.reshape(-1, batch_size, *feature.shape[1:]) - # (B,N,D) - feature = torch.moveaxis(feature, 0, 1) - # (B,N*D) - feature = feature.reshape(batch_size, -1) - features.append(feature) - else: - # run each rgb obs to independent models - for key in self.rgb_keys: - img = obs_dict[key] - if batch_size is None: - batch_size = img.shape[0] - else: - assert batch_size == img.shape[0] - assert img.shape[1:] == self.key_shape_map[key] - img = self.key_transform_map[key](img) - feature = self.key_model_map[key](img) - features.append(feature) - - # concatenate all features - result = torch.cat(features, dim=-1) - return result - - @torch.no_grad() - def output_shape(self): - example_obs_dict = {} - obs_shape_meta = self.shape_meta["obs"] - batch_size = 1 - for key, attr in obs_shape_meta.items(): - shape = tuple(attr["shape"]) - this_obs = torch.zeros((batch_size,) + shape, dtype=self.dtype, device=self.device) - example_obs_dict[key] = this_obs - example_output = self.forward(example_obs_dict) - output_shape = example_output.shape[1:] - return output_shape diff --git a/lerobot/common/policies/diffusion/model/normalizer.py b/lerobot/common/policies/diffusion/model/normalizer.py deleted file mode 100644 index 0e4d79ab..00000000 --- a/lerobot/common/policies/diffusion/model/normalizer.py +++ /dev/null @@ -1,358 +0,0 @@ -from typing import Dict, Union - -import 
numpy as np -import torch -import torch.nn as nn -import zarr - -from lerobot.common.policies.diffusion.model.dict_of_tensor_mixin import DictOfTensorMixin -from lerobot.common.policies.diffusion.pytorch_utils import dict_apply - - -class LinearNormalizer(DictOfTensorMixin): - avaliable_modes = ["limits", "gaussian"] - - @torch.no_grad() - def fit( - self, - data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array], - last_n_dims=1, - dtype=torch.float32, - mode="limits", - output_max=1.0, - output_min=-1.0, - range_eps=1e-4, - fit_offset=True, - ): - if isinstance(data, dict): - for key, value in data.items(): - self.params_dict[key] = _fit( - value, - last_n_dims=last_n_dims, - dtype=dtype, - mode=mode, - output_max=output_max, - output_min=output_min, - range_eps=range_eps, - fit_offset=fit_offset, - ) - else: - self.params_dict["_default"] = _fit( - data, - last_n_dims=last_n_dims, - dtype=dtype, - mode=mode, - output_max=output_max, - output_min=output_min, - range_eps=range_eps, - fit_offset=fit_offset, - ) - - def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor: - return self.normalize(x) - - def __getitem__(self, key: str): - return SingleFieldLinearNormalizer(self.params_dict[key]) - - def __setitem__(self, key: str, value: "SingleFieldLinearNormalizer"): - self.params_dict[key] = value.params_dict - - def _normalize_impl(self, x, forward=True): - if isinstance(x, dict): - result = {} - for key, value in x.items(): - params = self.params_dict[key] - result[key] = _normalize(value, params, forward=forward) - return result - else: - if "_default" not in self.params_dict: - raise RuntimeError("Not initialized") - params = self.params_dict["_default"] - return _normalize(x, params, forward=forward) - - def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor: - return self._normalize_impl(x, forward=True) - - def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor: - return self._normalize_impl(x, forward=False) - - def get_input_stats(self) -> Dict: - if len(self.params_dict) == 0: - raise RuntimeError("Not initialized") - if len(self.params_dict) == 1 and "_default" in self.params_dict: - return self.params_dict["_default"]["input_stats"] - - result = {} - for key, value in self.params_dict.items(): - if key != "_default": - result[key] = value["input_stats"] - return result - - def get_output_stats(self, key="_default"): - input_stats = self.get_input_stats() - if "min" in input_stats: - # no dict - return dict_apply(input_stats, self.normalize) - - result = {} - for key, group in input_stats.items(): - this_dict = {} - for name, value in group.items(): - this_dict[name] = self.normalize({key: value})[key] - result[key] = this_dict - return result - - -class SingleFieldLinearNormalizer(DictOfTensorMixin): - avaliable_modes = ["limits", "gaussian"] - - @torch.no_grad() - def fit( - self, - data: Union[torch.Tensor, np.ndarray, zarr.Array], - last_n_dims=1, - dtype=torch.float32, - mode="limits", - output_max=1.0, - output_min=-1.0, - range_eps=1e-4, - fit_offset=True, - ): - self.params_dict = _fit( - data, - last_n_dims=last_n_dims, - dtype=dtype, - mode=mode, - output_max=output_max, - output_min=output_min, - range_eps=range_eps, - fit_offset=fit_offset, - ) - - @classmethod - def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs): - obj = cls() - obj.fit(data, **kwargs) - return obj - - @classmethod - def create_manual( - cls, - scale: Union[torch.Tensor, np.ndarray], - offset: 
Union[torch.Tensor, np.ndarray], - input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]], - ): - def to_tensor(x): - if not isinstance(x, torch.Tensor): - x = torch.from_numpy(x) - x = x.flatten() - return x - - # check - for x in [offset] + list(input_stats_dict.values()): - assert x.shape == scale.shape - assert x.dtype == scale.dtype - - params_dict = nn.ParameterDict( - { - "scale": to_tensor(scale), - "offset": to_tensor(offset), - "input_stats": nn.ParameterDict(dict_apply(input_stats_dict, to_tensor)), - } - ) - return cls(params_dict) - - @classmethod - def create_identity(cls, dtype=torch.float32): - scale = torch.tensor([1], dtype=dtype) - offset = torch.tensor([0], dtype=dtype) - input_stats_dict = { - "min": torch.tensor([-1], dtype=dtype), - "max": torch.tensor([1], dtype=dtype), - "mean": torch.tensor([0], dtype=dtype), - "std": torch.tensor([1], dtype=dtype), - } - return cls.create_manual(scale, offset, input_stats_dict) - - def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: - return _normalize(x, self.params_dict, forward=True) - - def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: - return _normalize(x, self.params_dict, forward=False) - - def get_input_stats(self): - return self.params_dict["input_stats"] - - def get_output_stats(self): - return dict_apply(self.params_dict["input_stats"], self.normalize) - - def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor: - return self.normalize(x) - - -def _fit( - data: Union[torch.Tensor, np.ndarray, zarr.Array], - last_n_dims=1, - dtype=torch.float32, - mode="limits", - output_max=1.0, - output_min=-1.0, - range_eps=1e-4, - fit_offset=True, -): - assert mode in ["limits", "gaussian"] - assert last_n_dims >= 0 - assert output_max > output_min - - # convert data to torch and type - if isinstance(data, zarr.Array): - data = data[:] - if isinstance(data, np.ndarray): - data = torch.from_numpy(data) - if dtype is not None: - data = data.type(dtype) - - # convert shape - dim = 1 - if last_n_dims > 0: - dim = np.prod(data.shape[-last_n_dims:]) - data = data.reshape(-1, dim) - - # compute input stats min max mean std - input_min, _ = data.min(axis=0) - input_max, _ = data.max(axis=0) - input_mean = data.mean(axis=0) - input_std = data.std(axis=0) - - # compute scale and offset - if mode == "limits": - if fit_offset: - # unit scale - input_range = input_max - input_min - ignore_dim = input_range < range_eps - input_range[ignore_dim] = output_max - output_min - scale = (output_max - output_min) / input_range - offset = output_min - scale * input_min - offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim] - # ignore dims scaled to mean of output max and min - else: - # use this when data is pre-zero-centered. 
- assert output_max > 0 - assert output_min < 0 - # unit abs - output_abs = min(abs(output_min), abs(output_max)) - input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max)) - ignore_dim = input_abs < range_eps - input_abs[ignore_dim] = output_abs - # don't scale constant channels - scale = output_abs / input_abs - offset = torch.zeros_like(input_mean) - elif mode == "gaussian": - ignore_dim = input_std < range_eps - scale = input_std.clone() - scale[ignore_dim] = 1 - scale = 1 / scale - - offset = -input_mean * scale if fit_offset else torch.zeros_like(input_mean) - - # save - this_params = nn.ParameterDict( - { - "scale": scale, - "offset": offset, - "input_stats": nn.ParameterDict( - {"min": input_min, "max": input_max, "mean": input_mean, "std": input_std} - ), - } - ) - for p in this_params.parameters(): - p.requires_grad_(False) - return this_params - - -def _normalize(x, params, forward=True): - assert "scale" in params - if isinstance(x, np.ndarray): - x = torch.from_numpy(x) - scale = params["scale"] - offset = params["offset"] - x = x.to(device=scale.device, dtype=scale.dtype) - src_shape = x.shape - x = x.reshape(-1, scale.shape[0]) - x = x * scale + offset if forward else (x - offset) / scale - x = x.reshape(src_shape) - return x - - -def test(): - data = torch.zeros((100, 10, 9, 2)).uniform_() - data[..., 0, 0] = 0 - - normalizer = SingleFieldLinearNormalizer() - normalizer.fit(data, mode="limits", last_n_dims=2) - datan = normalizer.normalize(data) - assert datan.shape == data.shape - assert np.allclose(datan.max(), 1.0) - assert np.allclose(datan.min(), -1.0) - dataun = normalizer.unnormalize(datan) - assert torch.allclose(data, dataun, atol=1e-7) - - _ = normalizer.get_input_stats() - _ = normalizer.get_output_stats() - - normalizer = SingleFieldLinearNormalizer() - normalizer.fit(data, mode="limits", last_n_dims=1, fit_offset=False) - datan = normalizer.normalize(data) - assert datan.shape == data.shape - assert np.allclose(datan.max(), 1.0, atol=1e-3) - assert np.allclose(datan.min(), 0.0, atol=1e-3) - dataun = normalizer.unnormalize(datan) - assert torch.allclose(data, dataun, atol=1e-7) - - data = torch.zeros((100, 10, 9, 2)).uniform_() - normalizer = SingleFieldLinearNormalizer() - normalizer.fit(data, mode="gaussian", last_n_dims=0) - datan = normalizer.normalize(data) - assert datan.shape == data.shape - assert np.allclose(datan.mean(), 0.0, atol=1e-3) - assert np.allclose(datan.std(), 1.0, atol=1e-3) - dataun = normalizer.unnormalize(datan) - assert torch.allclose(data, dataun, atol=1e-7) - - # dict - data = torch.zeros((100, 10, 9, 2)).uniform_() - data[..., 0, 0] = 0 - - normalizer = LinearNormalizer() - normalizer.fit(data, mode="limits", last_n_dims=2) - datan = normalizer.normalize(data) - assert datan.shape == data.shape - assert np.allclose(datan.max(), 1.0) - assert np.allclose(datan.min(), -1.0) - dataun = normalizer.unnormalize(datan) - assert torch.allclose(data, dataun, atol=1e-7) - - _ = normalizer.get_input_stats() - _ = normalizer.get_output_stats() - - data = { - "obs": torch.zeros((1000, 128, 9, 2)).uniform_() * 512, - "action": torch.zeros((1000, 128, 2)).uniform_() * 512, - } - normalizer = LinearNormalizer() - normalizer.fit(data) - datan = normalizer.normalize(data) - dataun = normalizer.unnormalize(datan) - for key in data: - assert torch.allclose(data[key], dataun[key], atol=1e-4) - - _ = normalizer.get_input_stats() - _ = normalizer.get_output_stats() - - state_dict = normalizer.state_dict() - n = LinearNormalizer() - 
n.load_state_dict(state_dict) - datan = n.normalize(data) - dataun = n.unnormalize(datan) - for key in data: - assert torch.allclose(data[key], dataun[key], atol=1e-4) diff --git a/lerobot/common/policies/diffusion/model/positional_embedding.py b/lerobot/common/policies/diffusion/model/positional_embedding.py deleted file mode 100644 index 65fc97bd..00000000 --- a/lerobot/common/policies/diffusion/model/positional_embedding.py +++ /dev/null @@ -1,19 +0,0 @@ -import math - -import torch -import torch.nn as nn - - -class SinusoidalPosEmb(nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - - def forward(self, x): - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device) * -emb) - emb = x[:, None] * emb[None, :] - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb diff --git a/lerobot/common/policies/diffusion/model/tensor_utils.py b/lerobot/common/policies/diffusion/model/tensor_utils.py deleted file mode 100644 index df9a568a..00000000 --- a/lerobot/common/policies/diffusion/model/tensor_utils.py +++ /dev/null @@ -1,972 +0,0 @@ -""" -A collection of utilities for working with nested tensor structures consisting -of numpy arrays and torch tensors. -""" - -import collections - -import numpy as np -import torch - - -def recursive_dict_list_tuple_apply(x, type_func_dict): - """ - Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of - {data_type: function_to_apply}. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - type_func_dict (dict): a mapping from data types to the functions to be - applied for each data type. - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - assert list not in type_func_dict - assert tuple not in type_func_dict - assert dict not in type_func_dict - - if isinstance(x, (dict, collections.OrderedDict)): - new_x = collections.OrderedDict() if isinstance(x, collections.OrderedDict) else {} - for k, v in x.items(): - new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict) - return new_x - elif isinstance(x, (list, tuple)): - ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x] - if isinstance(x, tuple): - ret = tuple(ret) - return ret - else: - for t, f in type_func_dict.items(): - if isinstance(x, t): - return f(x) - else: - raise NotImplementedError("Cannot handle data type %s" % str(type(x))) - - -def map_tensor(x, func): - """ - Apply function @func to torch.Tensor objects in a nested dictionary or - list or tuple. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - func (function): function to apply to each tensor - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: func, - type(None): lambda x: x, - }, - ) - - -def map_ndarray(x, func): - """ - Apply function @func to np.ndarray objects in a nested dictionary or - list or tuple. 
- - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - func (function): function to apply to each array - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - np.ndarray: func, - type(None): lambda x: x, - }, - ) - - -def map_tensor_ndarray(x, tensor_func, ndarray_func): - """ - Apply function @tensor_func to torch.Tensor objects and @ndarray_func to - np.ndarray objects in a nested dictionary or list or tuple. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - tensor_func (function): function to apply to each tensor - ndarray_Func (function): function to apply to each array - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: tensor_func, - np.ndarray: ndarray_func, - type(None): lambda x: x, - }, - ) - - -def clone(x): - """ - Clones all torch tensors and numpy arrays in nested dictionary or list - or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.clone(), - np.ndarray: lambda x: x.copy(), - type(None): lambda x: x, - }, - ) - - -def detach(x): - """ - Detaches all torch tensors in nested dictionary or list - or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.detach(), - }, - ) - - -def to_batch(x): - """ - Introduces a leading batch dimension of 1 for all torch tensors and numpy - arrays in nested dictionary or list or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x[None, ...], - np.ndarray: lambda x: x[None, ...], - type(None): lambda x: x, - }, - ) - - -def to_sequence(x): - """ - Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy - arrays in nested dictionary or list or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x[:, None, ...], - np.ndarray: lambda x: x[:, None, ...], - type(None): lambda x: x, - }, - ) - - -def index_at_time(x, ind): - """ - Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in - nested dictionary or list or tuple and returns a new nested structure. 
- - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - ind (int): index - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x[:, ind, ...], - np.ndarray: lambda x: x[:, ind, ...], - type(None): lambda x: x, - }, - ) - - -def unsqueeze(x, dim): - """ - Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays - in nested dictionary or list or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - dim (int): dimension - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.unsqueeze(dim=dim), - np.ndarray: lambda x: np.expand_dims(x, axis=dim), - type(None): lambda x: x, - }, - ) - - -def contiguous(x): - """ - Makes all torch tensors and numpy arrays contiguous in nested dictionary or - list or tuple and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.contiguous(), - np.ndarray: lambda x: np.ascontiguousarray(x), - type(None): lambda x: x, - }, - ) - - -def to_device(x, device): - """ - Sends all torch tensors in nested dictionary or list or tuple to device - @device, and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - device (torch.Device): device to send tensors to - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x, d=device: x.to(d), - type(None): lambda x: x, - }, - ) - - -def to_tensor(x): - """ - Converts all numpy arrays in nested dictionary or list or tuple to - torch tensors (and leaves existing torch Tensors as-is), and returns - a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x, - np.ndarray: lambda x: torch.from_numpy(x), - type(None): lambda x: x, - }, - ) - - -def to_numpy(x): - """ - Converts all torch tensors in nested dictionary or list or tuple to - numpy (and leaves existing numpy arrays as-is), and returns - a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - - def f(tensor): - if tensor.is_cuda: - return tensor.detach().cpu().numpy() - else: - return tensor.detach().numpy() - - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: f, - np.ndarray: lambda x: x, - type(None): lambda x: x, - }, - ) - - -def to_list(x): - """ - Converts all torch tensors and numpy arrays in nested dictionary or list - or tuple to a list, and returns a new nested structure. Useful for - json encoding. 
- - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - - def f(tensor): - if tensor.is_cuda: - return tensor.detach().cpu().numpy().tolist() - else: - return tensor.detach().numpy().tolist() - - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: f, - np.ndarray: lambda x: x.tolist(), - type(None): lambda x: x, - }, - ) - - -def to_float(x): - """ - Converts all torch tensors and numpy arrays in nested dictionary or list - or tuple to float type entries, and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.float(), - np.ndarray: lambda x: x.astype(np.float32), - type(None): lambda x: x, - }, - ) - - -def to_uint8(x): - """ - Converts all torch tensors and numpy arrays in nested dictionary or list - or tuple to uint8 type entries, and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.byte(), - np.ndarray: lambda x: x.astype(np.uint8), - type(None): lambda x: x, - }, - ) - - -def to_torch(x, device): - """ - Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to - torch tensors on device @device and returns a new nested structure. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - device (torch.Device): device to send tensors to - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return to_device(to_float(to_tensor(x)), device) - - -def to_one_hot_single(tensor, num_class): - """ - Convert tensor to one-hot representation, assuming a certain number of total class labels. - - Args: - tensor (torch.Tensor): tensor containing integer labels - num_class (int): number of classes - - Returns: - x (torch.Tensor): tensor containing one-hot representation of labels - """ - x = torch.zeros(tensor.size() + (num_class,)).to(tensor.device) - x.scatter_(-1, tensor.unsqueeze(-1), 1) - return x - - -def to_one_hot(tensor, num_class): - """ - Convert all tensors in nested dictionary or list or tuple to one-hot representation, - assuming a certain number of total class labels. - - Args: - tensor (dict or list or tuple): a possibly nested dictionary or list or tuple - num_class (int): number of classes - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return map_tensor(tensor, func=lambda x, nc=num_class: to_one_hot_single(x, nc)) - - -def flatten_single(x, begin_axis=1): - """ - Flatten a tensor in all dimensions from @begin_axis onwards. - - Args: - x (torch.Tensor): tensor to flatten - begin_axis (int): which axis to flatten from - - Returns: - y (torch.Tensor): flattened tensor - """ - fixed_size = x.size()[:begin_axis] - _s = list(fixed_size) + [-1] - return x.reshape(*_s) - - -def flatten(x, begin_axis=1): - """ - Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards. 
- - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - begin_axis (int): which axis to flatten from - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x, b=begin_axis: flatten_single(x, begin_axis=b), - }, - ) - - -def reshape_dimensions_single(x, begin_axis, end_axis, target_dims): - """ - Reshape selected dimensions in a tensor to a target dimension. - - Args: - x (torch.Tensor): tensor to reshape - begin_axis (int): begin dimension - end_axis (int): end dimension - target_dims (tuple or list): target shape for the range of dimensions - (@begin_axis, @end_axis) - - Returns: - y (torch.Tensor): reshaped tensor - """ - assert begin_axis <= end_axis - assert begin_axis >= 0 - assert end_axis < len(x.shape) - assert isinstance(target_dims, (tuple, list)) - s = x.shape - final_s = [] - for i in range(len(s)): - if i == begin_axis: - final_s.extend(target_dims) - elif i < begin_axis or i > end_axis: - final_s.append(s[i]) - return x.reshape(*final_s) - - -def reshape_dimensions(x, begin_axis, end_axis, target_dims): - """ - Reshape selected dimensions for all tensors in nested dictionary or list or tuple - to a target dimension. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - begin_axis (int): begin dimension - end_axis (int): end dimension - target_dims (tuple or list): target shape for the range of dimensions - (@begin_axis, @end_axis) - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single( - x, begin_axis=b, end_axis=e, target_dims=t - ), - np.ndarray: lambda x, b=begin_axis, e=end_axis, t=target_dims: reshape_dimensions_single( - x, begin_axis=b, end_axis=e, target_dims=t - ), - type(None): lambda x: x, - }, - ) - - -def join_dimensions(x, begin_axis, end_axis): - """ - Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for - all tensors in nested dictionary or list or tuple. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - begin_axis (int): begin dimension - end_axis (int): end dimension - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single( - x, begin_axis=b, end_axis=e, target_dims=[-1] - ), - np.ndarray: lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single( - x, begin_axis=b, end_axis=e, target_dims=[-1] - ), - type(None): lambda x: x, - }, - ) - - -def expand_at_single(x, size, dim): - """ - Expand a tensor at a single dimension @dim by @size - - Args: - x (torch.Tensor): input tensor - size (int): size to expand - dim (int): dimension to expand - - Returns: - y (torch.Tensor): expanded tensor - """ - assert dim < x.ndimension() - assert x.shape[dim] == 1 - expand_dims = [-1] * x.ndimension() - expand_dims[dim] = size - return x.expand(*expand_dims) - - -def expand_at(x, size, dim): - """ - Expand all tensors in nested dictionary or list or tuple at a single - dimension @dim by @size. 
- - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - size (int): size to expand - dim (int): dimension to expand - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d)) - - -def unsqueeze_expand_at(x, size, dim): - """ - Unsqueeze and expand a tensor at a dimension @dim by @size. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - size (int): size to expand - dim (int): dimension to unsqueeze and expand - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - x = unsqueeze(x, dim) - return expand_at(x, size, dim) - - -def repeat_by_expand_at(x, repeats, dim): - """ - Repeat a dimension by combining expand and reshape operations. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - repeats (int): number of times to repeat the target dimension - dim (int): dimension to repeat on - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - x = unsqueeze_expand_at(x, repeats, dim + 1) - return join_dimensions(x, dim, dim + 1) - - -def named_reduce_single(x, reduction, dim): - """ - Reduce tensor at a dimension by named reduction functions. - - Args: - x (torch.Tensor): tensor to be reduced - reduction (str): one of ["sum", "max", "mean", "flatten"] - dim (int): dimension to be reduced (or begin axis for flatten) - - Returns: - y (torch.Tensor): reduced tensor - """ - assert x.ndimension() > dim - assert reduction in ["sum", "max", "mean", "flatten"] - if reduction == "flatten": - x = flatten(x, begin_axis=dim) - elif reduction == "max": - x = torch.max(x, dim=dim)[0] # [B, D] - elif reduction == "sum": - x = torch.sum(x, dim=dim) - else: - x = torch.mean(x, dim=dim) - return x - - -def named_reduce(x, reduction, dim): - """ - Reduces all tensors in nested dictionary or list or tuple at a dimension - using a named reduction function. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - reduction (str): one of ["sum", "max", "mean", "flatten"] - dim (int): dimension to be reduced (or begin axis for flatten) - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return map_tensor(x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d)) - - -def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices): - """ - This function indexes out a target dimension of a tensor in a structured way, - by allowing a different value to be selected for each member of a flat index - tensor (@indices) corresponding to a source dimension. This can be interpreted - as moving along the source dimension, using the corresponding index value - in @indices to select values for all other dimensions outside of the - source and target dimensions. A common use case is to gather values - in target dimension 1 for each batch member (target dimension 0). 
- - Args: - x (torch.Tensor): tensor to gather values for - target_dim (int): dimension to gather values along - source_dim (int): dimension to hold constant and use for gathering values - from the other dimensions - indices (torch.Tensor): flat index tensor with same shape as tensor @x along - @source_dim - - Returns: - y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out - """ - assert len(indices.shape) == 1 - assert x.shape[source_dim] == indices.shape[0] - - # unsqueeze in all dimensions except the source dimension - new_shape = [1] * x.ndimension() - new_shape[source_dim] = -1 - indices = indices.reshape(*new_shape) - - # repeat in all dimensions - but preserve shape of source dimension, - # and make sure target_dimension has singleton dimension - expand_shape = list(x.shape) - expand_shape[source_dim] = -1 - expand_shape[target_dim] = 1 - indices = indices.expand(*expand_shape) - - out = x.gather(dim=target_dim, index=indices) - return out.squeeze(target_dim) - - -def gather_along_dim_with_dim(x, target_dim, source_dim, indices): - """ - Apply @gather_along_dim_with_dim_single to all tensors in a nested - dictionary or list or tuple. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - target_dim (int): dimension to gather values along - source_dim (int): dimension to hold constant and use for gathering values - from the other dimensions - indices (torch.Tensor): flat index tensor with same shape as tensor @x along - @source_dim - - Returns: - y (dict or list or tuple): new nested dict-list-tuple - """ - return map_tensor( - x, lambda y, t=target_dim, s=source_dim, i=indices: gather_along_dim_with_dim_single(y, t, s, i) - ) - - -def gather_sequence_single(seq, indices): - """ - Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in - the batch given an index for each sequence. - - Args: - seq (torch.Tensor): tensor with leading dimensions [B, T, ...] - indices (torch.Tensor): tensor indices of shape [B] - - Return: - y (torch.Tensor): indexed tensor of shape [B, ....] - """ - return gather_along_dim_with_dim_single(seq, target_dim=1, source_dim=0, indices=indices) - - -def gather_sequence(seq, indices): - """ - Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch - for tensors with leading dimensions [B, T, ...]. - - Args: - seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors - of leading dimensions [B, T, ...] - indices (torch.Tensor): tensor indices of shape [B] - - Returns: - y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...] - """ - return gather_along_dim_with_dim(seq, target_dim=1, source_dim=0, indices=indices) - - -def pad_sequence_single(seq, padding, batched=False, pad_same=True, pad_values=None): - """ - Pad input tensor or array @seq in the time dimension (dimension 1). - - Args: - seq (np.ndarray or torch.Tensor): sequence to be padded - padding (tuple): begin and end padding, e.g. 
[1, 1] pads both begin and end of the sequence by 1 - batched (bool): if sequence has the batch dimension - pad_same (bool): if pad by duplicating - pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same - - Returns: - padded sequence (np.ndarray or torch.Tensor) - """ - assert isinstance(seq, (np.ndarray, torch.Tensor)) - assert pad_same or pad_values is not None - if pad_values is not None: - assert isinstance(pad_values, float) - repeat_func = np.repeat if isinstance(seq, np.ndarray) else torch.repeat_interleave - concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat - ones_like_func = np.ones_like if isinstance(seq, np.ndarray) else torch.ones_like - seq_dim = 1 if batched else 0 - - begin_pad = [] - end_pad = [] - - if padding[0] > 0: - pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values - begin_pad.append(repeat_func(pad, padding[0], seq_dim)) - if padding[1] > 0: - pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values - end_pad.append(repeat_func(pad, padding[1], seq_dim)) - - return concat_func(begin_pad + [seq] + end_pad, seq_dim) - - -def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None): - """ - Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1). - - Args: - seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors - of leading dimensions [B, T, ...] - padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1 - batched (bool): if sequence has the batch dimension - pad_same (bool): if pad by duplicating - pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same - - Returns: - padded sequence (dict or list or tuple) - """ - return recursive_dict_list_tuple_apply( - seq, - { - torch.Tensor: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single( - x, p, b, ps, pv - ), - np.ndarray: lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values: pad_sequence_single( - x, p, b, ps, pv - ), - type(None): lambda x: x, - }, - ) - - -def assert_size_at_dim_single(x, size, dim, msg): - """ - Ensure that array or tensor @x has size @size in dim @dim. - - Args: - x (np.ndarray or torch.Tensor): input array or tensor - size (int): size that tensors should have at @dim - dim (int): dimension to check - msg (str): text to display if assertion fails - """ - assert x.shape[dim] == size, msg - - -def assert_size_at_dim(x, size, dim, msg): - """ - Ensure that arrays and tensors in nested dictionary or list or tuple have - size @size in dim @dim. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - size (int): size that tensors should have at @dim - dim (int): dimension to check - """ - map_tensor(x, lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m)) - - -def get_shape(x): - """ - Get all shapes of arrays and tensors in nested dictionary or list or tuple. - - Args: - x (dict or list or tuple): a possibly nested dictionary or list or tuple - - Returns: - y (dict or list or tuple): new nested dict-list-tuple that contains each array or - tensor's shape - """ - return recursive_dict_list_tuple_apply( - x, - { - torch.Tensor: lambda x: x.shape, - np.ndarray: lambda x: x.shape, - type(None): lambda x: x, - }, - ) - - -def list_of_flat_dict_to_dict_of_list(list_of_dict): - """ - Helper function to go from a list of flat dictionaries to a dictionary of lists. 
- By "flat" we mean that none of the values are dictionaries, but are numpy arrays, - floats, etc. - - Args: - list_of_dict (list): list of flat dictionaries - - Returns: - dict_of_list (dict): dictionary of lists - """ - assert isinstance(list_of_dict, list) - dic = collections.OrderedDict() - for i in range(len(list_of_dict)): - for k in list_of_dict[i]: - if k not in dic: - dic[k] = [] - dic[k].append(list_of_dict[i][k]) - return dic - - -def flatten_nested_dict_list(d, parent_key="", sep="_", item_key=""): - """ - Flatten a nested dict or list to a list. - - For example, given a dict - { - a: 1 - b: { - c: 2 - } - c: 3 - } - - the function would return [(a, 1), (b_c, 2), (c, 3)] - - Args: - d (dict, list): a nested dict or list to be flattened - parent_key (str): recursion helper - sep (str): separator for nesting keys - item_key (str): recursion helper - Returns: - list: a list of (key, value) tuples - """ - items = [] - if isinstance(d, (tuple, list)): - new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key - for i, v in enumerate(d): - items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i))) - return items - elif isinstance(d, dict): - new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key - for k, v in d.items(): - assert isinstance(k, str) - items.extend(flatten_nested_dict_list(v, new_key, sep=sep, item_key=k)) - return items - else: - new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key - return [(new_key, d)] - - -def time_distributed(inputs, op, activation=None, inputs_as_kwargs=False, inputs_as_args=False, **kwargs): - """ - Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the - batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...]. - Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping - outputs to [B, T, ...]. - - Args: - inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors - of leading dimensions [B, T, ...] - op: a layer op that accepts inputs - activation: activation to apply at the output - inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op - inputs_as_args (bool) whether to feed input as a args list to the op - kwargs (dict): other kwargs to supply to the op - - Returns: - outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T]. - """ - batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2] - inputs = join_dimensions(inputs, 0, 1) - if inputs_as_kwargs: - outputs = op(**inputs, **kwargs) - elif inputs_as_args: - outputs = op(*inputs, **kwargs) - else: - outputs = op(inputs, **kwargs) - - if activation is not None: - outputs = map_tensor(outputs, activation) - outputs = reshape_dimensions(outputs, begin_axis=0, end_axis=0, target_dims=(batch_size, seq_len)) - return outputs diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py new file mode 100644 index 00000000..dfab9bb7 --- /dev/null +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -0,0 +1,717 @@ +"""Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" + +TODO(alexander-soare): + - Remove reliance on Robomimic for SpatialSoftmax. + - Remove reliance on diffusers for DDPMScheduler and LR scheduler. + - Move EMA out of policy. 
+ - Consolidate _DiffusionUnetImagePolicy into DiffusionPolicy. + - One more pass on comments and documentation. +""" + +import copy +import logging +import math +import time +from collections import deque +from itertools import chain +from typing import Callable + +import einops +import torch +import torch.nn.functional as F # noqa: N812 +import torchvision +from diffusers.optimization import get_scheduler +from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from robomimic.models.base_nets import SpatialSoftmax +from torch import Tensor, nn +from torch.nn.modules.batchnorm import _BatchNorm + +from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig +from lerobot.common.policies.utils import ( + get_device_from_parameters, + get_dtype_from_parameters, + populate_queues, +) + + +class DiffusionPolicy(nn.Module): + """ + Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" + (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy). + """ + + name = "diffusion" + + def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0): + super().__init__() + """ + Args: + cfg: Policy configuration class instance or None, in which case the default instantiation of the + configuration class is used. + """ + # TODO(alexander-soare): LR scheduler will be removed. + assert lr_scheduler_num_training_steps > 0 + if cfg is None: + cfg = DiffusionConfig() + self.cfg = cfg + + # queues are populated during rollout of the policy, they contain the n latest observations and actions + self._queues = None + + self.diffusion = _DiffusionUnetImagePolicy(cfg) + + # TODO(alexander-soare): This should probably be managed outside of the policy class. + self.ema_diffusion = None + self.ema = None + if self.cfg.use_ema: + self.ema_diffusion = copy.deepcopy(self.diffusion) + self.ema = _EMA(cfg, model=self.ema_diffusion) + + # TODO(alexander-soare): Move optimizer out of policy. + self.optimizer = torch.optim.Adam( + self.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay + ) + + # TODO(alexander-soare): Move LR scheduler out of policy. + # TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps + self.global_step = 0 + + # configure lr scheduler + self.lr_scheduler = get_scheduler( + cfg.lr_scheduler, + optimizer=self.optimizer, + num_warmup_steps=cfg.lr_warmup_steps, + num_training_steps=lr_scheduler_num_training_steps, + # pytorch assumes stepping LRScheduler every epoch + # however huggingface diffusers steps it every batch + last_epoch=self.global_step - 1, + ) + + def reset(self): + """ + Clear observation and action queues. Should be called on `env.reset()` + """ + self._queues = { + "observation.image": deque(maxlen=self.cfg.n_obs_steps), + "observation.state": deque(maxlen=self.cfg.n_obs_steps), + "action": deque(maxlen=self.cfg.n_action_steps), + } + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor], **_) -> Tensor: + """Select a single action given environment observations. + + This method handles caching a history of observations and an action trajectory generated by the + underlying diffusion model. Here's how it works: + - `n_obs_steps` steps worth of observations are cached (for the first steps, the observation is + copied `n_obs_steps` times to fill the cache). + - The diffusion model generates `horizon` steps worth of actions. 
+ - `n_action_steps` worth of actions are actually kept for execution, starting from the current step. + Schematically this looks like: + ---------------------------------------------------------------------------------------------- + (legend: o = n_obs_steps, h = horizon, a = n_action_steps) + |timestep | n-o+1 | n-o+2 | ..... | n | ..... | n+a-1 | n+a | ..... |n-o+1+h| + |observation is used | YES | YES | YES | NO | NO | NO | NO | NO | NO | + |action is generated | YES | YES | YES | YES | YES | YES | YES | YES | YES | + |action is used | NO | NO | NO | YES | YES | YES | NO | NO | NO | + ---------------------------------------------------------------------------------------------- + Note that this means we require: `n_action_steps <= horizon - n_obs_steps + 1`. Also, note that + "horizon" may not be the best name to describe what the variable actually means, because this period is + actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. + + Note: this method uses the ema model weights if self.training == False, otherwise the non-ema model + weights. + """ + assert "observation.image" in batch + assert "observation.state" in batch + assert len(batch) == 2 + + self._queues = populate_queues(self._queues, batch) + + if len(self._queues["action"]) == 0: + # stack n latest observations from the queue + batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} + if not self.training and self.ema_diffusion is not None: + actions = self.ema_diffusion.generate_actions(batch) + else: + actions = self.diffusion.generate_actions(batch) + self._queues["action"].extend(actions.transpose(0, 1)) + + action = self._queues["action"].popleft() + return action + + def forward(self, batch, **_): + start_time = time.time() + + self.diffusion.train() + + loss = self.diffusion.compute_loss(batch) + loss.backward() + + grad_norm = torch.nn.utils.clip_grad_norm_( + self.diffusion.parameters(), + self.cfg.grad_clip_norm, + error_if_nonfinite=False, + ) + + self.optimizer.step() + self.optimizer.zero_grad() + self.lr_scheduler.step() + + if self.ema is not None: + self.ema.step(self.diffusion) + + info = { + "loss": loss.item(), + "grad_norm": float(grad_norm), + "lr": self.lr_scheduler.get_last_lr()[0], + "update_s": time.time() - start_time, + } + + return info + + def save(self, fp): + torch.save(self.state_dict(), fp) + + def load(self, fp): + d = torch.load(fp) + missing_keys, unexpected_keys = self.load_state_dict(d, strict=False) + if len(missing_keys) > 0: + assert all(k.startswith("ema_diffusion.") for k in missing_keys) + logging.warning( + "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found."
+ ) + assert len(unexpected_keys) == 0 + + +class _DiffusionUnetImagePolicy(nn.Module): + def __init__(self, cfg: DiffusionConfig): + super().__init__() + self.cfg = cfg + + self.rgb_encoder = _RgbEncoder(cfg) + self.unet = _ConditionalUnet1D( + cfg, global_cond_dim=(cfg.action_dim + self.rgb_encoder.feature_dim) * cfg.n_obs_steps + ) + + self.noise_scheduler = DDPMScheduler( + num_train_timesteps=cfg.num_train_timesteps, + beta_start=cfg.beta_start, + beta_end=cfg.beta_end, + beta_schedule=cfg.beta_schedule, + variance_type="fixed_small", + clip_sample=cfg.clip_sample, + clip_sample_range=cfg.clip_sample_range, + prediction_type=cfg.prediction_type, + ) + + if cfg.num_inference_steps is None: + self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps + else: + self.num_inference_steps = cfg.num_inference_steps + + # ========= inference ============ + def conditional_sample( + self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None + ) -> Tensor: + device = get_device_from_parameters(self) + dtype = get_dtype_from_parameters(self) + + # Sample prior. + sample = torch.randn( + size=(batch_size, self.cfg.horizon, self.cfg.action_dim), + dtype=dtype, + device=device, + generator=generator, + ) + + self.noise_scheduler.set_timesteps(self.num_inference_steps) + + for t in self.noise_scheduler.timesteps: + # Predict model output. + model_output = self.unet( + sample, + torch.full(sample.shape[:1], t, dtype=torch.long, device=sample.device), + global_cond=global_cond, + ) + # Compute previous image: x_t -> x_t-1 + sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample + + return sample + + def generate_actions(self, batch: dict[str, Tensor]) -> Tensor: + """ + This function expects `batch` to have (at least): + { + "observation.state": (B, n_obs_steps, state_dim) + "observation.image": (B, n_obs_steps, C, H, W) + } + """ + assert set(batch).issuperset({"observation.state", "observation.image"}) + batch_size, n_obs_steps = batch["observation.state"].shape[:2] + assert n_obs_steps == self.cfg.n_obs_steps + + # Extract image feature (first combine batch and sequence dims). + img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) + # Separate batch and sequence dims. + img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size) + # Concatenate state and image features then flatten to (B, global_cond_dim). + global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) + + # run sampling + sample = self.conditional_sample(batch_size, global_cond=global_cond) + + # `horizon` steps worth of actions (from the first observation). + actions = sample[..., : self.cfg.action_dim] + # Extract `n_action_steps` steps worth of actions (from the current observation). + start = n_obs_steps - 1 + end = start + self.cfg.n_action_steps + actions = actions[:, start:end] + + return actions + + def compute_loss(self, batch: dict[str, Tensor]) -> Tensor: + """ + This function expects `batch` to have (at least): + { + "observation.state": (B, n_obs_steps, state_dim) + "observation.image": (B, n_obs_steps, C, H, W) + "action": (B, horizon, action_dim) + "action_is_pad": (B, horizon) + } + """ + # Input validation. 
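+        # The asserts below check that the batch provides the expected keys and that its temporal
+        # dimensions match the config. The rest of this method is a standard DDPM training step: sample a
+        # random diffusion timestep t, corrupt the ground-truth action trajectory with Gaussian noise via
+        # the scheduler (roughly x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, up to the
+        # scheduler's exact parameterization), then regress the network output against either the noise
+        # ("epsilon") or the clean trajectory ("sample"), masking out padded actions before averaging.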
+ assert set(batch).issuperset({"observation.state", "observation.image", "action", "action_is_pad"}) + batch_size, n_obs_steps = batch["observation.state"].shape[:2] + horizon = batch["action"].shape[1] + assert horizon == self.cfg.horizon + assert n_obs_steps == self.cfg.n_obs_steps + + # Extract image feature (first combine batch and sequence dims). + img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ...")) + # Separate batch and sequence dims. + img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size) + # Concatenate state and image features then flatten to (B, global_cond_dim). + global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1) + + trajectory = batch["action"] + + # Forward diffusion. + # Sample noise to add to the trajectory. + eps = torch.randn(trajectory.shape, device=trajectory.device) + # Sample a random noising timestep for each item in the batch. + timesteps = torch.randint( + low=0, + high=self.noise_scheduler.config.num_train_timesteps, + size=(trajectory.shape[0],), + device=trajectory.device, + ).long() + # Add noise to the clean trajectories according to the noise magnitude at each timestep. + noisy_trajectory = self.noise_scheduler.add_noise(trajectory, eps, timesteps) + + # Run the denoising network (that might denoise the trajectory, or attempt to predict the noise). + pred = self.unet(noisy_trajectory, timesteps, global_cond=global_cond) + + # Compute the loss. + # The target is either the original trajectory, or the noise. + if self.cfg.prediction_type == "epsilon": + target = eps + elif self.cfg.prediction_type == "sample": + target = batch["action"] + else: + raise ValueError(f"Unsupported prediction type {self.cfg.prediction_type}") + + loss = F.mse_loss(pred, target, reduction="none") + + # Mask loss wherever the action is padded with copies (edges of the dataset trajectory). + if "action_is_pad" in batch: + in_episode_bound = ~batch["action_is_pad"] + loss = loss * in_episode_bound.unsqueeze(-1) + + return loss.mean() + + +class _RgbEncoder(nn.Module): + """Encodes an RGB image into a 1D feature vector. + + Includes the ability to normalize and crop the image first. + """ + + def __init__(self, cfg: DiffusionConfig): + super().__init__() + # Set up optional preprocessing. + if all(v == 1.0 for v in chain(cfg.image_normalization_mean, cfg.image_normalization_std)): + self.normalizer = nn.Identity() + else: + self.normalizer = torchvision.transforms.Normalize( + mean=cfg.image_normalization_mean, std=cfg.image_normalization_std + ) + if cfg.crop_shape is not None: + self.do_crop = True + # Always use center crop for eval + self.center_crop = torchvision.transforms.CenterCrop(cfg.crop_shape) + if cfg.crop_is_random: + self.maybe_random_crop = torchvision.transforms.RandomCrop(cfg.crop_shape) + else: + self.maybe_random_crop = self.center_crop + else: + self.do_crop = False + + # Set up backbone. + backbone_model = getattr(torchvision.models, cfg.vision_backbone)( + pretrained=cfg.use_pretrained_backbone + ) + # Note: This assumes that the layer4 feature map is children()[-3] + # TODO(alexander-soare): Use a safer alternative. + self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2])) + if cfg.use_group_norm: + if cfg.use_pretrained_backbone: + raise ValueError( + "You can't replace BatchNorm in a pretrained model without ruining the weights!"
+ ) + self.backbone = _replace_submodules( + root_module=self.backbone, + predicate=lambda x: isinstance(x, nn.BatchNorm2d), + func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features), + ) + + # Set up pooling and final layers. + # Use a dry run to get the feature map shape. + with torch.inference_mode(): + feat_map_shape = tuple(self.backbone(torch.zeros(size=(1, 3, *cfg.image_size))).shape[1:]) + self.pool = SpatialSoftmax(feat_map_shape, num_kp=cfg.spatial_softmax_num_keypoints) + self.feature_dim = cfg.spatial_softmax_num_keypoints * 2 + self.out = nn.Linear(cfg.spatial_softmax_num_keypoints * 2, self.feature_dim) + self.relu = nn.ReLU() + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: (B, C, H, W) image tensor with pixel values in [0, 1]. + Returns: + (B, D) image feature. + """ + # Preprocess: normalize and maybe crop (if it was set up in the __init__). + x = self.normalizer(x) + if self.do_crop: + if self.training: # noqa: SIM108 + x = self.maybe_random_crop(x) + else: + # Always use center crop for eval. + x = self.center_crop(x) + # Extract backbone feature. + x = torch.flatten(self.pool(self.backbone(x)), start_dim=1) + # Final linear layer with non-linearity. + x = self.relu(self.out(x)) + return x + + +def _replace_submodules( + root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module] +) -> nn.Module: + """ + Args: + root_module: The module for which the submodules need to be replaced + predicate: Takes a module as an argument and must return True if the that module is to be replaced. + func: Takes a module as an argument and returns a new module to replace it with. + Returns: + The root module with its submodules replaced. + """ + if predicate(root_module): + return func(root_module) + + replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] + for *parents, k in replace_list: + parent_module = root_module + if len(parents) > 0: + parent_module = root_module.get_submodule(".".join(parents)) + if isinstance(parent_module, nn.Sequential): + src_module = parent_module[int(k)] + else: + src_module = getattr(parent_module, k) + tgt_module = func(src_module) + if isinstance(parent_module, nn.Sequential): + parent_module[int(k)] = tgt_module + else: + setattr(parent_module, k, tgt_module) + # verify that all BN are replaced + assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)) + return root_module + + +class _SinusoidalPosEmb(nn.Module): + """1D sinusoidal positional embeddings as in Attention is All You Need.""" + + def __init__(self, dim: int): + super().__init__() + self.dim = dim + + def forward(self, x: Tensor) -> Tensor: + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x.unsqueeze(-1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class _Conv1dBlock(nn.Module): + """Conv1d --> GroupNorm --> Mish""" + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.block = nn.Sequential( + nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), + nn.GroupNorm(n_groups, out_channels), + nn.Mish(), + ) + + def forward(self, x): + return self.block(x) + + +class _ConditionalUnet1D(nn.Module): + """A 1D convolutional UNet with FiLM modulation for conditioning. 
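+
+    FiLM conditioning (https://arxiv.org/abs/1709.07871) means that each residual block's features are
+    modulated by a per-channel bias (and optionally a scale) computed from the conditioning vector, which
+    here is the concatenation of the diffusion step embedding and the flattened observation features.
+    As an illustrative example of the shapes involved (with hypothetical values action_dim=2 and
+    horizon=16): the forward pass maps a noisy action trajectory of shape (B, 16, 2), a (B,) timestep
+    tensor and a (B, global_cond_dim) conditioning vector to a prediction of shape (B, 16, 2).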
+ + Note: this removes local conditioning as compared to the original diffusion policy code. + """ + + def __init__(self, cfg: DiffusionConfig, global_cond_dim: int): + super().__init__() + + self.cfg = cfg + + # Encoder for the diffusion timestep. + self.diffusion_step_encoder = nn.Sequential( + _SinusoidalPosEmb(cfg.diffusion_step_embed_dim), + nn.Linear(cfg.diffusion_step_embed_dim, cfg.diffusion_step_embed_dim * 4), + nn.Mish(), + nn.Linear(cfg.diffusion_step_embed_dim * 4, cfg.diffusion_step_embed_dim), + ) + + # The FiLM conditioning dimension. + cond_dim = cfg.diffusion_step_embed_dim + global_cond_dim + + # In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we + # just reverse these. + in_out = [(cfg.action_dim, cfg.down_dims[0])] + list( + zip(cfg.down_dims[:-1], cfg.down_dims[1:], strict=True) + ) + + # Unet encoder. + common_res_block_kwargs = { + "cond_dim": cond_dim, + "kernel_size": cfg.kernel_size, + "n_groups": cfg.n_groups, + "use_film_scale_modulation": cfg.use_film_scale_modulation, + } + self.down_modules = nn.ModuleList([]) + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (len(in_out) - 1) + self.down_modules.append( + nn.ModuleList( + [ + _ConditionalResidualBlock1D(dim_in, dim_out, **common_res_block_kwargs), + _ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs), + # Downsample as long as it is not the last block. + nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(), + ] + ) + ) + + # Processing in the middle of the auto-encoder. + self.mid_modules = nn.ModuleList( + [ + _ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs), + _ConditionalResidualBlock1D(cfg.down_dims[-1], cfg.down_dims[-1], **common_res_block_kwargs), + ] + ) + + # Unet decoder. + self.up_modules = nn.ModuleList([]) + for ind, (dim_out, dim_in) in enumerate(reversed(in_out[1:])): + is_last = ind >= (len(in_out) - 1) + self.up_modules.append( + nn.ModuleList( + [ + # dim_in * 2, because it takes the encoder's skip connection as well + _ConditionalResidualBlock1D(dim_in * 2, dim_out, **common_res_block_kwargs), + _ConditionalResidualBlock1D(dim_out, dim_out, **common_res_block_kwargs), + # Upsample as long as it is not the last block. + nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(), + ] + ) + ) + + self.final_conv = nn.Sequential( + _Conv1dBlock(cfg.down_dims[0], cfg.down_dims[0], kernel_size=cfg.kernel_size), + nn.Conv1d(cfg.down_dims[0], cfg.action_dim, 1), + ) + + def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor: + """ + Args: + x: (B, T, input_dim) tensor for input to the Unet. + timestep: (B,) tensor of (timestep_we_are_denoising_from - 1). + global_cond: (B, global_cond_dim) + output: (B, T, input_dim) + Returns: + (B, T, input_dim) diffusion model prediction. + """ + # For 1D convolutions we'll need feature dimension first. + x = einops.rearrange(x, "b t d -> b d t") + + timesteps_embed = self.diffusion_step_encoder(timestep) + + # If there is a global conditioning feature, concatenate it to the timestep embedding. + if global_cond is not None: + global_feature = torch.cat([timesteps_embed, global_cond], axis=-1) + else: + global_feature = timesteps_embed + + # Run encoder, keeping track of skip features to pass to the decoder. 
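+        # Note on shapes: `x` is (B, D, T) here. Each encoder stage except the last halves T with a strided
+        # Conv1d, and the decoder concatenates the stored skip features back along the channel dimension.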
+ encoder_skip_features: list[Tensor] = [] + for resnet, resnet2, downsample in self.down_modules: + x = resnet(x, global_feature) + x = resnet2(x, global_feature) + encoder_skip_features.append(x) + x = downsample(x) + + for mid_module in self.mid_modules: + x = mid_module(x, global_feature) + + # Run decoder, using the skip features from the encoder. + for resnet, resnet2, upsample in self.up_modules: + x = torch.cat((x, encoder_skip_features.pop()), dim=1) + x = resnet(x, global_feature) + x = resnet2(x, global_feature) + x = upsample(x) + + x = self.final_conv(x) + + x = einops.rearrange(x, "b d t -> b t d") + return x + + +class _ConditionalResidualBlock1D(nn.Module): + """ResNet style 1D convolutional block with FiLM modulation for conditioning.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + cond_dim: int, + kernel_size: int = 3, + n_groups: int = 8, + # Set to True to do scale modulation with FiLM as well as bias modulation (defaults to False meaning + # FiLM just modulates bias). + use_film_scale_modulation: bool = False, + ): + super().__init__() + + self.use_film_scale_modulation = use_film_scale_modulation + self.out_channels = out_channels + + self.conv1 = _Conv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups) + + # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale. + cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels + self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels)) + + self.conv2 = _Conv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups) + + # A final convolution for dimension matching the residual (if needed). + self.residual_conv = ( + nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity() + ) + + def forward(self, x: Tensor, cond: Tensor) -> Tensor: + """ + Args: + x: (B, in_channels, T) + cond: (B, cond_dim) + Returns: + (B, out_channels, T) + """ + out = self.conv1(x) + + # Get condition embedding. Unsqueeze for broadcasting to `out`, resulting in (B, out_channels, 1). + cond_embed = self.cond_encoder(cond).unsqueeze(-1) + if self.use_film_scale_modulation: + # Treat the embedding as a list of scales and biases. + scale = cond_embed[:, : self.out_channels] + bias = cond_embed[:, self.out_channels :] + out = scale * out + bias + else: + # Treat the embedding as biases. + out = out + cond_embed + + out = self.conv2(out) + out = out + self.residual_conv(x) + return out + + +class _EMA: + """ + Exponential Moving Average of models weights + """ + + def __init__(self, cfg: DiffusionConfig, model: nn.Module): + """ + @crowsonkb's notes on EMA Warmup: + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). + Args: + inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. + power (float): Exponential factor of EMA warmup. Default: 2/3. + min_alpha (float): The minimum EMA decay rate. Default: 0. 
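+
+        As a rough guide with the values used in the default diffusion config (inv_gamma=1, power=0.75,
+        max_alpha=0.9999): the decay is approximately 1 - step^(-0.75), i.e. ~0.97 after 100 steps,
+        ~0.999 after 10K steps, and clamped to 0.9999 from roughly 215K steps onwards.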
+ """ + + self.averaged_model = model + self.averaged_model.eval() + self.averaged_model.requires_grad_(False) + + self.update_after_step = cfg.ema_update_after_step + self.inv_gamma = cfg.ema_inv_gamma + self.power = cfg.ema_power + self.min_alpha = cfg.ema_min_alpha + self.max_alpha = cfg.ema_max_alpha + + self.alpha = 0.0 + self.optimization_step = 0 + + def get_decay(self, optimization_step): + """ + Compute the decay factor for the exponential moving average. + """ + step = max(0, optimization_step - self.update_after_step - 1) + value = 1 - (1 + step / self.inv_gamma) ** -self.power + + if step <= 0: + return 0.0 + + return max(self.min_alpha, min(value, self.max_alpha)) + + @torch.no_grad() + def step(self, new_model): + self.alpha = self.get_decay(self.optimization_step) + + for module, ema_module in zip(new_model.modules(), self.averaged_model.modules(), strict=True): + # Iterate over immediate parameters only. + for param, ema_param in zip( + module.parameters(recurse=False), ema_module.parameters(recurse=False), strict=True + ): + if isinstance(param, dict): + raise RuntimeError("Dict parameter not supported") + if isinstance(module, _BatchNorm) or not param.requires_grad: + # Copy BatchNorm parameters, and non-trainable parameters directly. + ema_param.copy_(param.to(dtype=ema_param.dtype).data) + else: + ema_param.mul_(self.alpha) + ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.alpha) + + self.optimization_step += 1 diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py deleted file mode 100644 index 9785358b..00000000 --- a/lerobot/common/policies/diffusion/policy.py +++ /dev/null @@ -1,195 +0,0 @@ -import copy -import logging -import time -from collections import deque - -import hydra -import torch -from torch import nn - -from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy -from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler -from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder, RgbEncoder -from lerobot.common.policies.utils import populate_queues -from lerobot.common.utils import get_safe_torch_device - - -class DiffusionPolicy(nn.Module): - name = "diffusion" - - def __init__( - self, - cfg, - cfg_device, - cfg_noise_scheduler, - cfg_rgb_model, - cfg_obs_encoder, - cfg_optimizer, - cfg_ema, - shape_meta: dict, - horizon, - n_action_steps, - n_obs_steps, - num_inference_steps=None, - obs_as_global_cond=True, - diffusion_step_embed_dim=256, - down_dims=(256, 512, 1024), - kernel_size=5, - n_groups=8, - cond_predict_scale=True, - # parameters passed to step - **kwargs, - ): - super().__init__() - self.cfg = cfg - self.n_obs_steps = n_obs_steps - self.n_action_steps = n_action_steps - # queues are populated during rollout of the policy, they contain the n latest observations and actions - self._queues = None - - noise_scheduler = hydra.utils.instantiate(cfg_noise_scheduler) - rgb_model_input_shape = copy.deepcopy(shape_meta.obs.image.shape) - if cfg_obs_encoder.crop_shape is not None: - rgb_model_input_shape[1:] = cfg_obs_encoder.crop_shape - rgb_model = RgbEncoder(input_shape=rgb_model_input_shape, **cfg_rgb_model) - obs_encoder = MultiImageObsEncoder( - rgb_model=rgb_model, - **cfg_obs_encoder, - ) - - self.diffusion = DiffusionUnetImagePolicy( - shape_meta=shape_meta, - noise_scheduler=noise_scheduler, - obs_encoder=obs_encoder, - horizon=horizon, - n_action_steps=n_action_steps, - 
n_obs_steps=n_obs_steps, - num_inference_steps=num_inference_steps, - obs_as_global_cond=obs_as_global_cond, - diffusion_step_embed_dim=diffusion_step_embed_dim, - down_dims=down_dims, - kernel_size=kernel_size, - n_groups=n_groups, - cond_predict_scale=cond_predict_scale, - # parameters passed to step - **kwargs, - ) - - self.device = get_safe_torch_device(cfg_device) - self.diffusion.to(self.device) - - self.ema_diffusion = None - self.ema = None - if self.cfg.use_ema: - self.ema_diffusion = copy.deepcopy(self.diffusion) - self.ema = hydra.utils.instantiate( - cfg_ema, - model=self.ema_diffusion, - ) - - self.optimizer = hydra.utils.instantiate( - cfg_optimizer, - params=self.diffusion.parameters(), - ) - - # TODO(rcadene): modify lr scheduler so that it doesnt depend on epochs but steps - self.global_step = 0 - - # configure lr scheduler - self.lr_scheduler = get_scheduler( - cfg.lr_scheduler, - optimizer=self.optimizer, - num_warmup_steps=cfg.lr_warmup_steps, - num_training_steps=cfg.offline_steps, - # pytorch assumes stepping LRScheduler every epoch - # however huggingface diffusers steps it every batch - last_epoch=self.global_step - 1, - ) - - def reset(self): - """ - Clear observation and action queues. Should be called on `env.reset()` - """ - self._queues = { - "observation.image": deque(maxlen=self.n_obs_steps), - "observation.state": deque(maxlen=self.n_obs_steps), - "action": deque(maxlen=self.n_action_steps), - } - - @torch.no_grad() - def select_action(self, batch, step): - """ - Note: this uses the ema model weights if self.training == False, otherwise the non-ema model weights. - """ - # TODO(rcadene): remove unused step - del step - assert "observation.image" in batch - assert "observation.state" in batch - assert len(batch) == 2 - - self._queues = populate_queues(self._queues, batch) - - if len(self._queues["action"]) == 0: - # stack n latest observations from the queue - batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch} - - obs_dict = { - "image": batch["observation.image"], - "agent_pos": batch["observation.state"], - } - if self.training: - out = self.diffusion.predict_action(obs_dict) - else: - out = self.ema_diffusion.predict_action(obs_dict) - self._queues["action"].extend(out["action"].transpose(0, 1)) - - action = self._queues["action"].popleft() - return action - - def forward(self, batch, step): - start_time = time.time() - - self.diffusion.train() - - loss = self.diffusion.compute_loss(batch) - loss.backward() - - grad_norm = torch.nn.utils.clip_grad_norm_( - self.diffusion.parameters(), - self.cfg.grad_clip_norm, - error_if_nonfinite=False, - ) - - self.optimizer.step() - self.optimizer.zero_grad() - self.lr_scheduler.step() - - if self.ema is not None: - self.ema.step(self.diffusion) - - info = { - "loss": loss.item(), - "grad_norm": float(grad_norm), - "lr": self.lr_scheduler.get_last_lr()[0], - "update_s": time.time() - start_time, - } - - # TODO(rcadene): remove hardcoding - # in diffusion_policy, len(dataloader) is 168 for a batch_size of 64 - if step % 168 == 0: - self.global_step += 1 - - return info - - def save(self, fp): - torch.save(self.state_dict(), fp) - - def load(self, fp): - d = torch.load(fp) - missing_keys, unexpected_keys = self.load_state_dict(d, strict=False) - if len(missing_keys) > 0: - assert all(k.startswith("ema_diffusion.") for k in missing_keys) - logging.warning( - "DiffusionPolicy.load expected ema parameters in loaded state dict but none were found." 
- ) - assert len(unexpected_keys) == 0 diff --git a/lerobot/common/policies/diffusion/pytorch_utils.py b/lerobot/common/policies/diffusion/pytorch_utils.py deleted file mode 100644 index ed5dc23a..00000000 --- a/lerobot/common/policies/diffusion/pytorch_utils.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Callable, Dict - -import torch -import torch.nn as nn -import torchvision - - -def get_resnet(name, weights=None, **kwargs): - """ - name: resnet18, resnet34, resnet50 - weights: "IMAGENET1K_V1", "r3m" - """ - # load r3m weights - if (weights == "r3m") or (weights == "R3M"): - return get_r3m(name=name, **kwargs) - - func = getattr(torchvision.models, name) - resnet = func(weights=weights, **kwargs) - resnet.fc = torch.nn.Identity() - return resnet - - -def get_r3m(name, **kwargs): - """ - name: resnet18, resnet34, resnet50 - """ - import r3m - - r3m.device = "cpu" - model = r3m.load_r3m(name) - r3m_model = model.module - resnet_model = r3m_model.convnet - resnet_model = resnet_model.to("cpu") - return resnet_model - - -def dict_apply( - x: Dict[str, torch.Tensor], func: Callable[[torch.Tensor], torch.Tensor] -) -> Dict[str, torch.Tensor]: - result = {} - for key, value in x.items(): - if isinstance(value, dict): - result[key] = dict_apply(value, func) - else: - result[key] = func(value) - return result - - -def replace_submodules( - root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module] -) -> nn.Module: - """ - predicate: Return true if the module is to be replaced. - func: Return new module to use. - """ - if predicate(root_module): - return func(root_module) - - bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] - for *parent, k in bn_list: - parent_module = root_module - if len(parent) > 0: - parent_module = root_module.get_submodule(".".join(parent)) - if isinstance(parent_module, nn.Sequential): - src_module = parent_module[int(k)] - else: - src_module = getattr(parent_module, k) - tgt_module = func(src_module) - if isinstance(parent_module, nn.Sequential): - parent_module[int(k)] = tgt_module - else: - setattr(parent_module, k, tgt_module) - # verify that all BN are replaced - bn_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] - assert len(bn_list) == 0 - return root_module diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 80ae27da..b5b5f861 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -1,61 +1,61 @@ import inspect -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf from lerobot.common.utils import get_safe_torch_device -def make_policy(cfg): - if cfg.policy.name == "tdmpc": +def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg): + expected_kwargs = set(inspect.signature(policy_cfg_class).parameters) + assert set(hydra_cfg.policy).issuperset( + expected_kwargs + ), f"Hydra config is missing arguments: {set(expected_kwargs).difference(hydra_cfg.policy)}" + policy_cfg = policy_cfg_class( + **{ + k: v + for k, v in OmegaConf.to_container(hydra_cfg.policy, resolve=True).items() + if k in expected_kwargs + } + ) + return policy_cfg + + +def make_policy(hydra_cfg: DictConfig): + if hydra_cfg.policy.name == "tdmpc": from lerobot.common.policies.tdmpc.policy import TDMPCPolicy policy = TDMPCPolicy( - cfg.policy, n_obs_steps=cfg.n_obs_steps, n_action_steps=cfg.n_action_steps, device=cfg.device + hydra_cfg.policy, + 
n_obs_steps=hydra_cfg.n_obs_steps, + n_action_steps=hydra_cfg.n_action_steps, + device=hydra_cfg.device, ) - elif cfg.policy.name == "diffusion": - from lerobot.common.policies.diffusion.policy import DiffusionPolicy + elif hydra_cfg.policy.name == "diffusion": + from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig + from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy - policy = DiffusionPolicy( - cfg=cfg.policy, - cfg_device=cfg.device, - cfg_noise_scheduler=cfg.noise_scheduler, - cfg_rgb_model=cfg.rgb_model, - cfg_obs_encoder=cfg.obs_encoder, - cfg_optimizer=cfg.optimizer, - cfg_ema=cfg.ema, - # n_obs_steps=cfg.n_obs_steps, - # n_action_steps=cfg.n_action_steps, - **cfg.policy, - ) - elif cfg.policy.name == "act": + policy_cfg = _policy_cfg_from_hydra_cfg(DiffusionConfig, hydra_cfg) + policy = DiffusionPolicy(policy_cfg, hydra_cfg.offline_steps) + policy.to(get_safe_torch_device(hydra_cfg.device)) + elif hydra_cfg.policy.name == "act": from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy - expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters) - assert set(cfg.policy).issuperset( - expected_kwargs - ), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}" - policy_cfg = ActionChunkingTransformerConfig( - **{ - k: v - for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items() - if k in expected_kwargs - } - ) + policy_cfg = _policy_cfg_from_hydra_cfg(ActionChunkingTransformerConfig, hydra_cfg) policy = ActionChunkingTransformerPolicy(policy_cfg) - policy.to(get_safe_torch_device(cfg.device)) + policy.to(get_safe_torch_device(hydra_cfg.device)) else: - raise ValueError(cfg.policy.name) + raise ValueError(hydra_cfg.policy.name) - if cfg.policy.pretrained_model_path: + if hydra_cfg.policy.pretrained_model_path: # TODO(rcadene): hack for old pretrained models from fowm - if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path: - if "offline" in cfg.policy.pretrained_model_path: + if hydra_cfg.policy.name == "tdmpc" and "fowm" in hydra_cfg.policy.pretrained_model_path: + if "offline" in hydra_cfg.policy.pretrained_model_path: policy.step[0] = 25000 - elif "final" in cfg.policy.pretrained_model_path: + elif "final" in hydra_cfg.policy.pretrained_model_path: policy.step[0] = 100000 else: raise NotImplementedError() - policy.load(cfg.policy.pretrained_model_path) + policy.load(hydra_cfg.policy.pretrained_model_path) return policy diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py index b0503fe4..b23c1336 100644 --- a/lerobot/common/policies/utils.py +++ b/lerobot/common/policies/utils.py @@ -1,3 +1,7 @@ +import torch +from torch import nn + + def populate_queues(queues, batch): for key in batch: if len(queues[key]) != queues[key].maxlen: @@ -8,3 +12,19 @@ def populate_queues(queues, batch): # add latest observation to the queue queues[key].append(batch[key]) return queues + + +def get_device_from_parameters(module: nn.Module) -> torch.device: + """Get a module's device by checking one of its parameters. + + Note: assumes that all parameters have the same device + """ + return next(iter(module.parameters())).device + + +def get_dtype_from_parameters(module: nn.Module) -> torch.dtype: + """Get a module's parameter dtype by checking one of its parameters. 
+ + Note: assumes that all parameters have the same dtype. + """ + return next(iter(module.parameters())).dtype diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py index 373a3bbc..81b3d986 100644 --- a/lerobot/common/utils.py +++ b/lerobot/common/utils.py @@ -11,6 +11,7 @@ from omegaconf import DictConfig def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device: + """Given a string, return a torch.device with checks on whether the device is available.""" match cfg_device: case "cuda": assert torch.cuda.is_available() diff --git a/lerobot/configs/policy/act.yaml b/lerobot/configs/policy/act.yaml index bd883613..5dd70d71 100644 --- a/lerobot/configs/policy/act.yaml +++ b/lerobot/configs/policy/act.yaml @@ -18,7 +18,7 @@ policy: pretrained_model_path: # Environment. - # Inherit these from the environment. + # Inherit these from the environment config. state_dim: ??? action_dim: ??? diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml index 811ee824..44746dfc 100644 --- a/lerobot/configs/policy/diffusion.yaml +++ b/lerobot/configs/policy/diffusion.yaml @@ -1,17 +1,5 @@ # @package _global_ -shape_meta: - # acceptable types: rgb, low_dim - obs: - image: - shape: [3, 96, 96] - type: rgb - agent_pos: - shape: [2] - type: low_dim - action: - shape: [2] - seed: 100000 horizon: 16 n_obs_steps: 2 @@ -19,7 +7,6 @@ n_action_steps: 8 dataset_obs_steps: ${n_obs_steps} past_action_visible: False keypoint_visible_rate: 1.0 -obs_as_global_cond: True eval_episodes: 50 eval_freq: 5000 @@ -34,76 +21,70 @@ offline_prioritized_sampler: true policy: name: diffusion - shape_meta: ${shape_meta} + pretrained_model_path: - horizon: ${horizon} + # Environment. + # Inherit these from the environment config. + state_dim: ??? + action_dim: ??? + image_size: + - ${env.image_size} # height + - ${env.image_size} # width + + # Inputs / output structure. n_obs_steps: ${n_obs_steps} + horizon: ${horizon} n_action_steps: ${n_action_steps} - num_inference_steps: 100 - obs_as_global_cond: ${obs_as_global_cond} - # crop_shape: null - diffusion_step_embed_dim: 128 + + # Vision preprocessing. + image_normalization_mean: [0.5, 0.5, 0.5] + image_normalization_std: [0.5, 0.5, 0.5] + + # Architecture / modeling. + # Vision backbone. + vision_backbone: resnet18 + crop_shape: [84, 84] + crop_is_random: True + use_pretrained_backbone: false + use_group_norm: True + spatial_softmax_num_keypoints: 32 + # Unet. down_dims: [512, 1024, 2048] kernel_size: 5 n_groups: 8 - cond_predict_scale: True - - pretrained_model_path: - - batch_size: 64 - - per_alpha: 0.6 - per_beta: 0.4 - - balanced_sampling: false - utd: 1 - offline_steps: ${offline_steps} - use_ema: true - lr_scheduler: cosine - lr_warmup_steps: 500 - grad_clip_norm: 10 - - delta_timestamps: - observation.image: [-0.1, 0] - observation.state: [-0.1, 0] - action: [-0.1, 0, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.0, 1.1, 1.2, 1.3, 1.4] - -noise_scheduler: - _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler + diffusion_step_embed_dim: 128 + use_film_scale_modulation: True + # Noise scheduler. 
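+  # These follow the argument names of diffusers' DDPMScheduler.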
num_train_timesteps: 100 + beta_schedule: squaredcos_cap_v2 beta_start: 0.0001 beta_end: 0.02 - beta_schedule: squaredcos_cap_v2 - variance_type: fixed_small # Yilun's paper uses fixed_small_log instead, but easy to cause Nan - clip_sample: True # required when predict_epsilon=False - prediction_type: epsilon # or sample + prediction_type: epsilon # epsilon / sample + clip_sample: True + clip_sample_range: 1.0 -obs_encoder: - shape_meta: ${shape_meta} - # resize_shape: null - crop_shape: [84, 84] - # constant center crop - random_crop: True - use_group_norm: True - share_rgb_model: False - norm_mean_std: [0.5, 0.5] # for PushT the original impl normalizes to [-1, 1] (maybe not the case for robomimic envs) + # Inference + num_inference_steps: 100 -rgb_model: - pretrained: false - num_keypoints: 32 - relu: true - -ema: - _target_: lerobot.common.policies.diffusion.model.ema_model.EMAModel - update_after_step: 0 - inv_gamma: 1.0 - power: 0.75 - min_value: 0.0 - max_value: 0.9999 - -optimizer: - _target_: torch.optim.AdamW + # --- + # TODO(alexander-soare): Remove these from the policy config. + batch_size: 64 + grad_clip_norm: 10 lr: 1.0e-4 - betas: [0.95, 0.999] - eps: 1.0e-8 - weight_decay: 1.0e-6 + lr_scheduler: cosine + lr_warmup_steps: 500 + adam_betas: [0.95, 0.999] + adam_eps: 1.0e-8 + adam_weight_decay: 1.0e-6 + utd: 1 + use_ema: true + ema_update_after_step: 0 + ema_min_alpha: 0.0 + ema_max_alpha: 0.9999 + ema_inv_gamma: 1.0 + ema_power: 0.75 + + delta_timestamps: + observation.image: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]" + observation.state: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1)]" + action: "[i / ${fps} for i in range(1 - ${n_obs_steps}, 1 - ${n_obs_steps} + ${policy.horizon})]" diff --git a/tests/test_available.py b/tests/test_available.py index b25a921f..373cc1a7 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -19,7 +19,7 @@ from lerobot.common.datasets.aloha import AlohaDataset from lerobot.common.datasets.pusht import PushtDataset from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy -from lerobot.common.policies.diffusion.policy import DiffusionPolicy +from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.tdmpc.policy import TDMPCPolicy diff --git a/tests/test_examples.py b/tests/test_examples.py index 4263e452..c510eb1e 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,8 +1,8 @@ from pathlib import Path -def _find_and_replace(text: str, finds: list[str], replaces: list[str]) -> str: - for f, r in zip(finds, replaces): +def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str: + for f, r in finds_and_replaces: assert f in text text = text.replace(f, r) return text @@ -29,14 +29,19 @@ def test_examples_3_and_2(): with open(path, "r") as file: file_contents = file.read() - # Do less steps and use CPU. + # Do less steps, use smaller batch, use CPU, and don't complicate things with dataloader workers. file_contents = _find_and_replace( file_contents, - ['"offline_steps=5000"', '"device=cuda"'], - ['"offline_steps=1"', '"device=cpu"'], + [ + ("training_steps = 5000", "training_steps = 1"), + ("num_workers=4", "num_workers=0"), + ('device = torch.device("cuda")', 'device = torch.device("cpu")'), + ("batch_size=cfg.batch_size", "batch_size=1"), + ], ) - exec(file_contents) + # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249. 
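+    # (Without an explicit globals dict, exec() places the script's top-level names in the caller's locals,
+    # which the dict comprehension's own scope cannot see, leading to a NameError.)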
+ exec(file_contents, {}) for file_name in ["model.pt", "stats.pth", "config.yaml"]: assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists() @@ -50,20 +55,15 @@ def test_examples_3_and_2(): file_contents = _find_and_replace( file_contents, [ - '"eval_episodes=10"', - '"rollout_batch_size=10"', - '"device=cuda"', - '# folder = Path("outputs/train/example_pusht_diffusion")', - 'hub_id = "lerobot/diffusion_policy_pusht_image"', - "folder = Path(snapshot_download(hub_id)", - ], - [ - '"eval_episodes=1"', - '"rollout_batch_size=1"', - '"device=cpu"', - 'folder = Path("outputs/train/example_pusht_diffusion")', - "", - "", + ('"eval_episodes=10"', '"eval_episodes=1"'), + ('"rollout_batch_size=10"', '"rollout_batch_size=1"'), + ('"device=cuda"', '"device=cpu"'), + ( + '# folder = Path("outputs/train/example_pusht_diffusion")', + 'folder = Path("outputs/train/example_pusht_diffusion")', + ), + ('hub_id = "lerobot/diffusion_policy_pusht_image"', ""), + ("folder = Path(snapshot_download(hub_id)", ""), ], )