From 7d1542cae17d4cc66f3787bb9a0865638a6590be Mon Sep 17 00:00:00 2001 From: Seungjae Lee <30570922+jayLEE0301@users.noreply.github.com> Date: Wed, 26 Jun 2024 16:55:02 +0900 Subject: [PATCH] Add VQ-BeT (#166) --- lerobot/__init__.py | 3 +- lerobot/common/policies/factory.py | 5 + .../policies/vqbet/configuration_vqbet.py | 149 ++ .../common/policies/vqbet/modeling_vqbet.py | 932 +++++++++++ lerobot/common/policies/vqbet/vqbet_utils.py | 1444 +++++++++++++++++ lerobot/configs/policy/vqbet.yaml | 104 ++ lerobot/scripts/train.py | 5 + tests/test_available.py | 2 + tests/test_policies.py | 1 + 9 files changed, 2644 insertions(+), 1 deletion(-) create mode 100644 lerobot/common/policies/vqbet/configuration_vqbet.py create mode 100644 lerobot/common/policies/vqbet/modeling_vqbet.py create mode 100644 lerobot/common/policies/vqbet/vqbet_utils.py create mode 100644 lerobot/configs/policy/vqbet.yaml diff --git a/lerobot/__init__.py b/lerobot/__init__.py index a5a90fb4..3963fe61 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -134,12 +134,13 @@ available_policies = [ "act", "diffusion", "tdmpc", + "vqbet", ] # keys and values refer to yaml files available_policies_per_env = { "aloha": ["act"], - "pusht": ["diffusion"], + "pusht": ["diffusion", "vqbet"], "xarm": ["tdmpc"], "dora_aloha_real": ["act_real"], } diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 4c124b61..124e8c68 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -55,6 +55,11 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]: from lerobot.common.policies.act.modeling_act import ACTPolicy return ACTPolicy, ACTConfig + elif name == "vqbet": + from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig + from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy + + return VQBeTPolicy, VQBeTConfig else: raise NotImplementedError(f"Policy with name {name} is not implemented.") diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py new file mode 100644 index 00000000..9b2d6a7e --- /dev/null +++ b/lerobot/common/policies/vqbet/configuration_vqbet.py @@ -0,0 +1,149 @@ +from dataclasses import dataclass, field + + +@dataclass +class VQBeTConfig: + """Configuration class for VQ-BeT. + + Defaults are configured for training with PushT providing proprioceptive and single camera observations. + + The parameters you will most likely need to change are the ones which depend on the environment / sensors. + Those are: `input_shapes` and `output_shapes`. + + Notes on the inputs and outputs: + - "observation.state" is required as an input key. + - At least one key starting with "observation.image is required as an input. + - If there are multiple keys beginning with "observation.image" they are treated as multiple camera + views. Right now we only support all images having the same shape. + - "action" is required as an output key. + + Args: + n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the + current step and additional steps going back). + n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts. + action_chunk_size: Action chunk size of each action prediction token. + input_shapes: A dictionary defining the shapes of the input data for the policy. + The key represents the input data name, and the value is a list indicating the dimensions + of the corresponding data. 
For example, "observation.image" refers to an input from + a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution. + Importantly, shapes doesnt include batch dimension or temporal dimension. + output_shapes: A dictionary defining the shapes of the output data for the policy. + The key represents the output data name, and the value is a list indicating the dimensions + of the corresponding data. For example, "action" refers to an output shape of [14], indicating + 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. + input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), + and the value specifies the normalization mode to apply. The two available modes are "mean_std" + which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a + [-1, 1] range. + output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the + original scale. Note that this is also used for normalizing the training targets. + vision_backbone: Name of the torchvision resnet backbone to use for encoding images. + crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit + within the image size. If None, no cropping is done. + crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval + mode). + pretrained_backbone_weights: Pretrained weights from torchvision to initalize the backbone. + `None` means no pretrained weights. + use_group_norm: Whether to replace batch normalization with group normalization in the backbone. + The group sizes are set to be about 16 (to be precise, feature_dim // 16). + spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax. + n_vqvae_training_steps: Number of optimization steps for training Residual VQ. + vqvae_n_embed: Number of embedding vectors in the RVQ dictionary (each layer). + vqvae_embedding_dim: Dimension of each embedding vector in the RVQ dictionary. + vqvae_enc_hidden_dim: Size of hidden dimensions of Encoder / Decoder part of Residaul VQ-VAE + gpt_block_size: Max block size of minGPT (should be larger than the number of input tokens) + gpt_input_dim: Size of output input of GPT. This is also used as the dimension of observation features. + gpt_output_dim: Size of output dimension of GPT. This is also used as a input dimension of offset / bin prediction headers. + gpt_n_layer: Number of layers of GPT + gpt_n_head: Number of headers of GPT + gpt_hidden_dim: Size of hidden dimensions of GPT + dropout: Dropout rate for GPT + mlp_hidden_dim: Size of hidden dimensions of offset header / bin prediction headers parts of VQ-BeT + offset_loss_weight: A constant that is multiplied to the offset loss + primary_code_loss_weight: A constant that is multiplied to the primary code prediction loss + secondary_code_loss_weight: A constant that is multiplied to the secondary code prediction loss + bet_softmax_temperature: Sampling temperature of code for rollout with VQ-BeT + sequentially_select: Whether select code of primary / secondary as sequentially (pick primary code, + and then select secodnary code), or at the same time. + """ + + # Inputs / output structure. 
+ n_obs_steps: int = 5 + n_action_pred_token: int = 3 + action_chunk_size: int = 5 + + input_shapes: dict[str, list[int]] = field( + default_factory=lambda: { + "observation.image": [3, 96, 96], + "observation.state": [2], + } + ) + output_shapes: dict[str, list[int]] = field( + default_factory=lambda: { + "action": [2], + } + ) + + # Normalization / Unnormalization + input_normalization_modes: dict[str, str] = field( + default_factory=lambda: { + "observation.image": "mean_std", + "observation.state": "min_max", + } + ) + output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"}) + + # Architecture / modeling. + # Vision backbone. + vision_backbone: str = "resnet18" + crop_shape: tuple[int, int] | None = (84, 84) + crop_is_random: bool = True + pretrained_backbone_weights: str | None = None + use_group_norm: bool = True + spatial_softmax_num_keypoints: int = 32 + # VQ-VAE + n_vqvae_training_steps: int = 20000 + vqvae_n_embed: int = 16 + vqvae_embedding_dim: int = 256 + vqvae_enc_hidden_dim: int = 128 + # VQ-BeT + gpt_block_size: int = 500 + gpt_input_dim: int = 512 + gpt_output_dim: int = 512 + gpt_n_layer: int = 8 + gpt_n_head: int = 8 + gpt_hidden_dim: int = 512 + dropout: float = 0.1 + mlp_hidden_dim: int = 1024 + offset_loss_weight: float = 10000.0 + primary_code_loss_weight: float = 5.0 + secondary_code_loss_weight: float = 0.5 + bet_softmax_temperature: float = 0.1 + sequentially_select: bool = False + + def __post_init__(self): + """Input validation (not exhaustive).""" + if not self.vision_backbone.startswith("resnet"): + raise ValueError( + f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}." + ) + image_keys = {k for k in self.input_shapes if k.startswith("observation.image")} + if self.crop_shape is not None: + for image_key in image_keys: + if ( + self.crop_shape[0] > self.input_shapes[image_key][1] + or self.crop_shape[1] > self.input_shapes[image_key][2] + ): + raise ValueError( + f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} " + f"for `crop_shape` and {self.input_shapes[image_key]} for " + "`input_shapes[{image_key}]`." + ) + # Check that all input images have the same shape. + first_image_key = next(iter(image_keys)) + for image_key in image_keys: + if self.input_shapes[image_key] != self.input_shapes[first_image_key]: + raise ValueError( + f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we " + "expect all image shapes to match." 
+ ) diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py new file mode 100644 index 00000000..ef2b4f2a --- /dev/null +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -0,0 +1,932 @@ +import math +import warnings +from collections import deque +from typing import Callable, List + +import einops +import numpy as np +import torch +import torch.nn.functional as F # noqa: N812 +import torchvision +from huggingface_hub import PyTorchModelHubMixin +from torch import Tensor, nn +from torch.optim.lr_scheduler import LambdaLR + +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.utils import get_device_from_parameters, populate_queues +from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig +from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ + +# ruff: noqa: N806 + + +class VQBeTPolicy(nn.Module, PyTorchModelHubMixin): + """ + VQ-BeT Policy as per "Behavior Generation with Latent Actions" + """ + + name = "vqbet" + + def __init__( + self, + config: VQBeTConfig | None = None, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. + """ + super().__init__() + if config is None: + config = VQBeTConfig() + self.config = config + self.normalize_inputs = Normalize( + config.input_shapes, config.input_normalization_modes, dataset_stats + ) + self.normalize_targets = Normalize( + config.output_shapes, config.output_normalization_modes, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_shapes, config.output_normalization_modes, dataset_stats + ) + + self.vqbet = VQBeTModel(config) + + self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + + self.reset() + + def reset(self): + """ + Clear observation and action queues. Should be called on `env.reset()` + queues are populated during rollout of the policy, they contain the n latest observations and actions + """ + self._queues = { + "observation.images": deque(maxlen=self.config.n_obs_steps), + "observation.state": deque(maxlen=self.config.n_obs_steps), + "action": deque(maxlen=self.config.action_chunk_size), + } + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations. + + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. + """ + + batch = self.normalize_inputs(batch) + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + # Note: It's important that this happens after stacking the images into a single key. 
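[Illustrative sketch, not part of the diff: the receding-horizon queueing that the rest of `select_action` implements. The chunk is recomputed only when the queue is empty and one action is popped per environment step; `predict_chunk` is a toy stand-in for `self.vqbet(batch, rollout=True)`.]

```python
from collections import deque

import torch

action_chunk_size = 5
queue = deque(maxlen=action_chunk_size)

def predict_chunk(batch_size=1, action_dim=2):
    # Stand-in for the policy call; returns (batch, action_chunk_size, action_dim).
    return torch.zeros(batch_size, action_chunk_size, action_dim)

for _ in range(12):  # environment steps
    if len(queue) == 0:
        actions = predict_chunk()
        # The queue stores (action_chunk_size, batch, action_dim), hence the transpose.
        queue.extend(actions.transpose(0, 1))
    action = queue.popleft()  # one (batch, action_dim) action per step
```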
+ self._queues = populate_queues(self._queues, batch) + + if not self.vqbet.action_head.vqvae_model.discretized.item(): + warnings.warn( + "To evaluate in the environment, your VQ-BeT model should contain a pretrained Residual VQ.", + stacklevel=1, + ) + + if len(self._queues["action"]) == 0: + batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} + actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size] + + # the dimension of returned action is (batch_size, action_chunk_size, action_dim) + actions = self.unnormalize_outputs({"action": actions})["action"] + # since the data in the action queue's dimension is (action_chunk_size, batch_size, action_dim), we transpose the action and fill the queue + self._queues["action"].extend(actions.transpose(0, 1)) + + action = self._queues["action"].popleft() + return action + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + """Run the batch through the model and compute the loss for training or validation.""" + batch = self.normalize_inputs(batch) + batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + batch = self.normalize_targets(batch) + # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181) + if not self.vqbet.action_head.vqvae_model.discretized.item(): + # loss: total loss of training RVQ + # n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`. + # n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree). + loss, n_different_codes, n_different_combinations, recon_l1_error = ( + self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"]) + ) + return { + "loss": loss, + "n_different_codes": n_different_codes, + "n_different_combinations": n_different_combinations, + "recon_l1_error": recon_l1_error, + } + # if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts. + _, loss_dict = self.vqbet(batch, rollout=False) + + return loss_dict + + +class SpatialSoftmax(nn.Module): + """ + Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al. + (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation. + + At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass" + of activations of each channel, i.e., keypoints in the image space for the policy to focus on. + + Example: take feature maps of size (512x10x12). We generate a grid of normalized coordinates (10x12x2): + ----------------------------------------------------- + | (-1., -1.) | (-0.82, -1.) | ... | (1., -1.) | + | (-1., -0.78) | (-0.82, -0.78) | ... | (1., -0.78) | + | ... | ... | ... | ... | + | (-1., 1.) | (-0.82, 1.) | ... | (1., 1.) | + ----------------------------------------------------- + This is achieved by applying channel-wise softmax over the activations (512x120) and computing the dot + product with the coordinates (120x2) to get expected points of maximal activation (512x2). 
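[A toy-shape version of the 512x10x12 example above (softmax over H*W, then a matmul with the normalized coordinate grid); shapes here are arbitrary and the grid is built with torch rather than numpy, purely for illustration.]

```python
import torch
import torch.nn.functional as F

B, C, H, W = 1, 4, 10, 12
feat = torch.randn(B, C, H, W)

# Normalized (x, y) grid of shape (H*W, 2), matching the table above.
ys, xs = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W), indexing="ij")
grid = torch.stack([xs.reshape(-1), ys.reshape(-1)], dim=1)

attn = F.softmax(feat.reshape(B * C, H * W), dim=-1)  # channel-wise softmax
keypoints = (attn @ grid).reshape(B, C, 2)            # expected (x, y) per channel
```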
+ + The example above results in 512 keypoints (corresponding to the 512 input channels). We can optionally + provide num_kp != None to control the number of keypoints. This is achieved by a first applying a learnable + linear mapping (in_channels, H, W) -> (num_kp, H, W). + """ + + def __init__(self, input_shape, num_kp=None): + """ + Args: + input_shape (list): (C, H, W) input feature map shape. + num_kp (int): number of keypoints in output. If None, output will have the same number of channels as input. + """ + super().__init__() + + assert len(input_shape) == 3 + self._in_c, self._in_h, self._in_w = input_shape + + if num_kp is not None: + self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1) + self._out_c = num_kp + else: + self.nets = None + self._out_c = self._in_c + + # we could use torch.linspace directly but that seems to behave slightly differently than numpy + # and causes a small degradation in pc_success of pre-trained models. + pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)) + pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float() + pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float() + # register as buffer so it's moved to the correct device. + self.register_buffer("pos_grid", torch.cat([pos_x, pos_y], dim=1)) + + def forward(self, features: Tensor) -> Tensor: + """ + Args: + features: (B, C, H, W) input feature maps. + Returns: + (B, K, 2) image-space coordinates of keypoints. + """ + if self.nets is not None: + features = self.nets(features) + + # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints + features = features.reshape(-1, self._in_h * self._in_w) + # 2d softmax normalization + attention = F.softmax(features, dim=-1) + # [B * K, H * W] x [H * W, 2] -> [B * K, 2] for spatial coordinate mean in x and y dimensions + expected_xy = attention @ self.pos_grid + # reshape to [B, K, 2] + feature_keypoints = expected_xy.view(-1, self._out_c, 2) + + return feature_keypoints + + +class VQBeTModel(nn.Module): + """VQ-BeT: The underlying neural network for VQ-BeT + + Note: In this code we use the terms `rgb_encoder`, 'policy', `action_head`. The meanings are as follows. + - The `rgb_encoder` process rgb-style image observations to one-dimensional embedding vectors + - A `policy` is a minGPT architecture, that takes observation sequences and action query tokens to generate `features`. + - These `features` pass through the action head, which passes through the code prediction, offset prediction head, + and finally generates a prediction for the action chunks. + + -------------------------------** legend **------------------------------- + │ n = n_obs_steps, p = n_action_pred_token, c = action_chunk_size) │ + │ o_{t} : visual observation at timestep {t} │ + │ s_{t} : state observation at timestep {t} │ + │ a_{t} : action at timestep {t} │ + │ A_Q : action_query_token │ + -------------------------------------------------------------------------- + + + Training Phase 1. Discretize action using Residual VQ (for config.n_vqvae_training_steps steps) + + + ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ + │ │ │ │ │ │ + │ RVQ encoder │ ─► │ Residual │ ─► │ RVQ Decoder │ + │ (a_{t}~a_{t+p}) │ │ Code Quantizer │ │ │ + │ │ │ │ │ │ + └─────────────────┘ └─────────────────┘ └─────────────────┘ + + Training Phase 2. + + timestep {t-n+1} timestep {t-n+2} timestep {t} + ┌─────┴─────┐ ┌─────┴─────┐ ┌─────┴─────┐ + + o_{t-n+1} o_{t-n+2} ... 
o_{t} + │ │ │ + │ s_{t-n+1} │ s_{t-n+2} ... │ s_{t} p + │ │ │ │ │ │ ┌───────┴───────┐ + │ │ A_Q │ │ A_Q ... │ │ A_Q ... A_Q + │ │ │ │ │ │ │ │ │ │ + ┌───▼─────▼─────▼─────▼─────▼─────▼─────────────────▼─────▼─────▼───────────────▼───┐ + │ │ + │ GPT │ => policy + │ │ + └───────────────▼─────────────────▼─────────────────────────────▼───────────────▼───┘ + │ │ │ │ + ┌───┴───┐ ┌───┴───┐ ┌───┴───┐ ┌───┴───┐ + code offset code offset code offset code offset + ▼ │ ▼ │ ▼ │ ▼ │ => action_head + RVQ Decoder │ RVQ Decoder │ RVQ Decoder │ RVQ Decoder │ + └── + ──┘ └── + ──┘ └── + ──┘ └── + ──┘ + ▼ ▼ ▼ ▼ + action chunk action chunk action chunk action chunk + a_{t-n+1} ~ a_{t-n+2} ~ a_{t} ~ ... a_{t+p-1} ~ + a_{t-n+c} a_{t-n+c+1} a_{t+c-1} a_{t+p+c-1} + + ▼ + ONLY this chunk is used in rollout! + """ + + def __init__(self, config: VQBeTConfig): + super().__init__() + self.config = config + + self.rgb_encoder = VQBeTRgbEncoder(config) + self.num_images = len([k for k in config.input_shapes if k.startswith("observation.image")]) + # This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above. + # Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results. + self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) + + # To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT. + self.state_projector = MLP( + config.output_shapes["action"][0], hidden_channels=[self.config.gpt_input_dim] + ) + self.rgb_feature_projector = MLP( + self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim] + ) + + # GPT part of VQ-BeT + self.policy = GPT(config) + # bin prediction head / offset prediction head part of VQ-BeT + self.action_head = VQBeTHead(config) + + num_tokens = self.config.n_action_pred_token + self.config.action_chunk_size - 1 + self.register_buffer( + "select_target_actions_indices", + torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]), + ) + + def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor: + # Input validation. + assert set(batch).issuperset({"observation.state", "observation.images"}) + batch_size, n_obs_steps = batch["observation.state"].shape[:2] + assert n_obs_steps == self.config.n_obs_steps + + # Extract image feature (first combine batch and sequence dims). + img_features = self.rgb_encoder( + einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") + ) + # Separate batch and sequence dims. + img_features = einops.rearrange( + img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images + ) + + # Arrange prior and current observation step tokens as shown in the class docstring. + # First project features to token dimension. + rgb_tokens = self.rgb_feature_projector( + img_features + ) # (batch, obs_step, number of different cameras, projection dims) + input_tokens = [rgb_tokens[:, :, i] for i in range(rgb_tokens.size(2))] + input_tokens.append( + self.state_projector(batch["observation.state"]) + ) # (batch, obs_step, projection dims) + input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps)) + # Interleave tokens by stacking and rearranging. 
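[A toy-shape version of the stack/rearrange performed just below, assuming a single camera; illustrative only.]

```python
import torch
from einops import rearrange

b, n, d = 2, 5, 512            # batch, n_obs_steps, gpt_input_dim
rgb = torch.randn(b, n, d)     # projected image features (single camera)
state = torch.randn(b, n, d)   # projected proprioceptive state
a_q = torch.randn(b, n, d)     # repeated action query token

tokens = torch.stack([rgb, state, a_q], dim=2)      # (b, n, 3, d)
tokens = rearrange(tokens, "b n t d -> b (n t) d")  # (b, 3*n, d)
# Resulting sequence order: [rgb_1, state_1, A_Q_1, rgb_2, state_2, A_Q_2, ...]
```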
+ input_tokens = torch.stack(input_tokens, dim=2) + input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d") + + len_additional_action_token = self.config.n_action_pred_token - 1 + future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1) + + # add additional action query tokens for predicting future action chunks + input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1) + + # get action features (pass through GPT) + features = self.policy(input_tokens) + # len(self.config.input_shapes) is the number of different observation modes. this line gets the index of action prompt tokens. + historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_shapes) + 1) + len( + self.config.input_shapes + ) + + # only extract the output tokens at the position of action query: + # Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models, mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://arxiv.org/pdf/2206.11251). + # Thus, it predict historical action sequence, in addition to current and future actions (predicting future actions : optional). + features = torch.cat( + [features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1 + ) + # pass through action head + action_head_output = self.action_head(features) + # if rollout, VQ-BeT don't calculate loss + if rollout: + return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape( + batch_size, self.config.action_chunk_size, -1 + ) + # else, it calculate overall loss (bin prediction loss, and offset loss) + else: + output = batch["action"][:, self.select_target_actions_indices] + loss = self.action_head.loss_fn(action_head_output, output, reduction="mean") + return action_head_output, loss + + +class VQBeTHead(nn.Module): + def __init__(self, config: VQBeTConfig): + """ + VQBeTHead takes output of GPT layers, and pass the feature through bin prediction head (`self.map_to_cbet_preds_bin`), and offset prediction head (`self.map_to_cbet_preds_offset`) + + self.map_to_cbet_preds_bin: outputs probability of each code (for each layer). + The input dimension of `self.map_to_cbet_preds_bin` is same with the output of GPT, + and the output dimension of `self.map_to_cbet_preds_bin` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed`. + if the agent select the code sequentially, we use self.map_to_cbet_preds_primary_bin and self.map_to_cbet_preds_secondary_bin instead of self._map_to_cbet_preds_bin. + + self.map_to_cbet_preds_offset: output the predicted offsets for all the codes in all the layers. + The input dimension of ` self.map_to_cbet_preds_offset` is same with the output of GPT, + and the output dimension of ` self.map_to_cbet_preds_offset` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.output_shapes["action"][0]`. 
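        Worked example with the PushT defaults above (2 RVQ layers, `vqvae_n_embed=16`, `action_chunk_size=5`, 2-D actions):

        ```python
        vqvae_num_layers, vqvae_n_embed, action_chunk_size, action_dim = 2, 16, 5, 2
        bin_head_out = vqvae_num_layers * vqvae_n_embed                                       # 32 logits
        offset_head_out = vqvae_num_layers * vqvae_n_embed * action_chunk_size * action_dim  # 320 offsets
        ```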
+ """ + + super().__init__() + self.config = config + # init vqvae + self.vqvae_model = VqVae(config) + if config.sequentially_select: + self.map_to_cbet_preds_primary_bin = MLP( + in_channels=config.gpt_output_dim, + hidden_channels=[self.config.vqvae_n_embed], + ) + self.map_to_cbet_preds_secondary_bin = MLP( + in_channels=config.gpt_output_dim + self.config.vqvae_n_embed, + hidden_channels=[self.config.vqvae_n_embed], + ) + else: + self.map_to_cbet_preds_bin = MLP( + in_channels=config.gpt_output_dim, + hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed], + ) + self.map_to_cbet_preds_offset = MLP( + in_channels=config.gpt_output_dim, + hidden_channels=[ + self.vqvae_model.vqvae_num_layers + * self.config.vqvae_n_embed + * config.action_chunk_size + * config.output_shapes["action"][0], + ], + ) + # loss + self._focal_loss_fn = FocalLoss(gamma=2.0) + + def discretize(self, n_vqvae_training_steps, actions): + # Resize the action sequence data to fit the action chunk size using a sliding window approach. + actions = torch.cat( + [ + actions[:, j : j + self.config.action_chunk_size, :] + for j in range(actions.shape[1] + 1 - self.config.action_chunk_size) + ], + dim=0, + ) + # `actions` is a tensor of shape (new_batch, action_chunk_size, action_dim) where new_batch is the number of possible chunks created from the original sequences using the sliding window. + + loss, metric = self.vqvae_model.vqvae_forward(actions) + n_different_codes = sum( + [len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)] + ) + n_different_combinations = len(torch.unique(metric[2], dim=0)) + recon_l1_error = metric[0].detach().cpu().item() + self.vqvae_model.optimized_steps += 1 + # if we updated RVQ more than `n_vqvae_training_steps` steps, we freeze the RVQ part. + if self.vqvae_model.optimized_steps >= n_vqvae_training_steps: + self.vqvae_model.discretized = torch.tensor(True) + self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True) + print("Finished discretizing action data!") + self.vqvae_model.eval() + for param in self.vqvae_model.vq_layer.parameters(): + param.requires_grad = False + return loss, n_different_codes, n_different_combinations, recon_l1_error + + def forward(self, x, **kwargs): + # N is the batch size, and T is number of action query tokens, which are process through same GPT + N, T, _ = x.shape + # we calculate N and T side parallely. 
Thus, the dimensions would be + # (batch size * number of action query tokens, action chunk size, action dimension) + x = einops.rearrange(x, "N T WA -> (N T) WA") + + # sample offsets + cbet_offsets = self.map_to_cbet_preds_offset(x) + cbet_offsets = einops.rearrange( + cbet_offsets, + "(NT) (G C WA) -> (NT) G C WA", + G=self.vqvae_model.vqvae_num_layers, + C=self.config.vqvae_n_embed, + ) + # if self.config.sequentially_select is True, bin prediction head first sample the primary code, and then sample secondary code + if self.config.sequentially_select: + cbet_primary_logits = self.map_to_cbet_preds_primary_bin(x) + + # select primary bin first + cbet_primary_probs = torch.softmax( + cbet_primary_logits / self.config.bet_softmax_temperature, dim=-1 + ) + NT, choices = cbet_primary_probs.shape + sampled_primary_centers = einops.rearrange( + torch.multinomial(cbet_primary_probs.view(-1, choices), num_samples=1), + "(NT) 1 -> NT", + NT=NT, + ) + + cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin( + torch.cat( + (x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)), + axis=1, + ) + ) + cbet_secondary_probs = torch.softmax( + cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1 + ) + sampled_secondary_centers = einops.rearrange( + torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1), + "(NT) 1 -> NT", + NT=NT, + ) + sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1) + cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1) + # if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once. + else: + cbet_logits = self.map_to_cbet_preds_bin(x) + cbet_logits = einops.rearrange( + cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers + ) + cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1) + NT, G, choices = cbet_probs.shape + sampled_centers = einops.rearrange( + torch.multinomial(cbet_probs.view(-1, choices), num_samples=1), + "(NT G) 1 -> NT G", + NT=NT, + ) + + device = get_device_from_parameters(self) + indices = ( + torch.arange(NT, device=device).unsqueeze(1), + torch.arange(self.vqvae_model.vqvae_num_layers, device=device).unsqueeze(0), + sampled_centers, + ) + # Use advanced indexing to sample the values (Extract the only offsets corresponding to the sampled codes.) + sampled_offsets = cbet_offsets[indices] + # Then, sum the offsets over the RVQ layers to get a net offset for the bin prediction + sampled_offsets = sampled_offsets.sum(dim=1) + with torch.no_grad(): + # Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder + return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach() + # pass the centroids through decoder to get actions. 
+ decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach() + # reshaped extracted offset to match with decoded centroids + sampled_offsets = einops.rearrange( + sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size + ) + # add offset and decoded centroids + predicted_action = decoded_action + sampled_offsets + predicted_action = einops.rearrange( + predicted_action, + "(N T) W A -> N T (W A)", + N=N, + T=T, + W=self.config.action_chunk_size, + ) + + return { + "cbet_logits": cbet_logits, + "predicted_action": predicted_action, + "sampled_centers": sampled_centers, + "decoded_action": decoded_action, + } + + def loss_fn(self, pred, target, **kwargs): + """ + for given ground truth action values (target), and prediction (pred) this function calculates the overall loss. + + predicted_action: predicted action chunk (offset + decoded centroids) + sampled_centers: sampled centroids (code of RVQ) + decoded_action: decoded action, which is produced by passing sampled_centers through RVQ decoder + NT: batch size * T + T: number of action query tokens, which are process through same GPT + cbet_logits: probability of all codes in each layer + """ + action_seq = target + predicted_action = pred["predicted_action"] + sampled_centers = pred["sampled_centers"] + decoded_action = pred["decoded_action"] + NT = predicted_action.shape[0] * predicted_action.shape[1] + + cbet_logits = pred["cbet_logits"] + + predicted_action = einops.rearrange( + predicted_action, "N T (W A) -> (N T) W A", W=self.config.action_chunk_size + ) + + action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A") + # Figure out the loss for the actions. + # First, we need to find the closest cluster center for each ground truth action. + with torch.no_grad(): + state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G + + # Now we can compute the loss. 
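[In symbols, the loss assembled below, where FL is the focal loss with gamma=2, the lambdas are `primary_code_loss_weight`, `secondary_code_loss_weight` and `offset_loss_weight` from the config, c1/c2 are the ground-truth RVQ codes from `get_code`, and a-hat is the predicted action (decoded centroids plus offsets):]

$$
\mathcal{L} \;=\; \lambda_{1}\,\mathrm{FL}(\ell_{1}, c_{1}) \;+\; \lambda_{2}\,\mathrm{FL}(\ell_{2}, c_{2}) \;+\; \lambda_{\text{offset}}\,\lVert a - \hat{a} \rVert_{1},
\qquad
\mathrm{FL}(\ell, c) = -(1 - p_{c})^{\gamma}\log p_{c},\;\; p_{c} = \mathrm{softmax}(\ell)_{c}
$$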
+ + # offset loss is L1 distance between the predicted action and ground truth action + offset_loss = F.l1_loss(action_seq, predicted_action) + + # calculate primary code prediction loss + cbet_loss1 = self._focal_loss_fn( + cbet_logits[:, 0, :], + action_bins[:, 0], + ) + # calculate secondary code prediction loss + cbet_loss2 = self._focal_loss_fn( + cbet_logits[:, 1, :], + action_bins[:, 1], + ) + # add all the prediction loss + cbet_loss = ( + cbet_loss1 * self.config.primary_code_loss_weight + + cbet_loss2 * self.config.secondary_code_loss_weight + ) + + equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT) + equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT) + + action_mse_error = torch.mean((action_seq - predicted_action) ** 2) + vq_action_error = torch.mean(torch.abs(action_seq - decoded_action)) + offset_action_error = torch.mean(torch.abs(action_seq - predicted_action)) + action_error_max = torch.max(torch.abs(action_seq - predicted_action)) + + loss = cbet_loss + self.config.offset_loss_weight * offset_loss + + loss_dict = { + "loss": loss, + "classification_loss": cbet_loss.detach().cpu().item(), + "offset_loss": offset_loss.detach().cpu().item(), + "equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(), + "equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(), + "vq_action_error": vq_action_error.detach().cpu().item(), + "offset_action_error": offset_action_error.detach().cpu().item(), + "action_error_max": action_error_max.detach().cpu().item(), + "action_mse_error": action_mse_error.detach().cpu().item(), + } + return loss_dict + + +class VQBeTOptimizer(torch.optim.Adam): + def __init__(self, policy, cfg): + vqvae_params = ( + list(policy.vqbet.action_head.vqvae_model.encoder.parameters()) + + list(policy.vqbet.action_head.vqvae_model.decoder.parameters()) + + list(policy.vqbet.action_head.vqvae_model.vq_layer.parameters()) + ) + decay_params, no_decay_params = policy.vqbet.policy.configure_parameters() + decay_params = ( + decay_params + + list(policy.vqbet.rgb_encoder.parameters()) + + list(policy.vqbet.state_projector.parameters()) + + list(policy.vqbet.rgb_feature_projector.parameters()) + + [policy.vqbet.action_token] + + list(policy.vqbet.action_head.map_to_cbet_preds_offset.parameters()) + ) + + if cfg.policy.sequentially_select: + decay_params = ( + decay_params + + list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters()) + + list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters()) + ) + else: + decay_params = decay_params + list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters()) + + optim_groups = [ + { + "params": decay_params, + "weight_decay": cfg.training.adam_weight_decay, + "lr": cfg.training.lr, + }, + { + "params": vqvae_params, + "weight_decay": 0.0001, + "lr": cfg.training.vqvae_lr, + }, + { + "params": no_decay_params, + "weight_decay": 0.0, + "lr": cfg.training.lr, + }, + ] + super().__init__( + optim_groups, + cfg.training.lr, + cfg.training.adam_betas, + cfg.training.adam_eps, + ) + + +class VQBeTScheduler(nn.Module): + def __init__(self, optimizer, cfg): + super().__init__() + n_vqvae_training_steps = cfg.training.n_vqvae_training_steps + + num_warmup_steps = cfg.training.lr_warmup_steps + num_training_steps = cfg.training.offline_steps + num_cycles = 0.5 + + def lr_lambda(current_step): + if current_step < n_vqvae_training_steps: + return float(1) + else: + current_step = 
current_step - n_vqvae_training_steps + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) + return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + + self.lr_scheduler = LambdaLR(optimizer, lr_lambda, -1) + + def step(self): + self.lr_scheduler.step() + + +class VQBeTRgbEncoder(nn.Module): + """Encode an RGB image into a 1D feature vector. + + Includes the ability to normalize and crop the image first. + + Same with DiffusionRgbEncoder from modeling_diffusion.py + """ + + def __init__(self, config: VQBeTConfig): + super().__init__() + # Set up optional preprocessing. + if config.crop_shape is not None: + self.do_crop = True + # Always use center crop for eval + self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape) + if config.crop_is_random: + self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape) + else: + self.maybe_random_crop = self.center_crop + else: + self.do_crop = False + + # Set up backbone. + backbone_model = getattr(torchvision.models, config.vision_backbone)( + weights=config.pretrained_backbone_weights + ) + # Note: This assumes that the layer4 feature map is children()[-3] + # TODO(alexander-soare): Use a safer alternative. + self.backbone = nn.Sequential(*(list(backbone_model.children())[:-2])) + if config.use_group_norm: + if config.pretrained_backbone_weights: + raise ValueError( + "You can't replace BatchNorm in a pretrained model without ruining the weights!" + ) + self.backbone = _replace_submodules( + root_module=self.backbone, + predicate=lambda x: isinstance(x, nn.BatchNorm2d), + func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features), + ) + + # Set up pooling and final layers. + # Use a dry run to get the feature map shape. + # The dummy input should take the number of image channels from `config.input_shapes` and it should + # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the + # height and width from `config.input_shapes`. + image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + assert len(image_keys) == 1 + image_key = image_keys[0] + dummy_input_h_w = ( + config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:] + ) + dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w)) + with torch.inference_mode(): + dummy_feature_map = self.backbone(dummy_input) + feature_map_shape = tuple(dummy_feature_map.shape[1:]) + self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints) + self.feature_dim = config.spatial_softmax_num_keypoints * 2 + self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim) + self.relu = nn.ReLU() + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: (B, C, H, W) image tensor with pixel values in [0, 1]. + Returns: + (B, D) image feature. + """ + # Preprocess: maybe crop (if it was set up in the __init__). + if self.do_crop: + if self.training: # noqa: SIM108 + x = self.maybe_random_crop(x) + else: + # Always use center crop for eval. + x = self.center_crop(x) + # Extract backbone feature. + x = torch.flatten(self.pool(self.backbone(x)), start_dim=1) + # Final linear layer with non-linearity. 
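[Stepping back to `VQBeTScheduler` above: the multiplier stays flat while the Residual VQ is being pre-trained, then does a linear warmup followed by a cosine decay. A standalone sketch; the warmup and total step counts here are placeholder values, not the ones from the yaml config.]

```python
import math

def lr_lambda(step, n_vqvae=20_000, warmup=500, total=200_000, num_cycles=0.5):
    if step < n_vqvae:
        return 1.0  # constant LR while the RVQ discretizes the action data
    step -= n_vqvae
    if step < warmup:
        return step / max(1, warmup)
    progress = (step - warmup) / max(1, total - warmup)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

print(lr_lambda(0), lr_lambda(20_250), round(lr_lambda(220_000), 3))  # 1.0 0.5 0.0
```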
+ x = self.relu(self.out(x)) + return x + + +def _replace_submodules( + root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module] +) -> nn.Module: + """ + Args: + root_module: The module for which the submodules need to be replaced + predicate: Takes a module as an argument and must return True if the that module is to be replaced. + func: Takes a module as an argument and returns a new module to replace it with. + Returns: + The root module with its submodules replaced. + """ + if predicate(root_module): + return func(root_module) + + replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] + for *parents, k in replace_list: + parent_module = root_module + if len(parents) > 0: + parent_module = root_module.get_submodule(".".join(parents)) + if isinstance(parent_module, nn.Sequential): + src_module = parent_module[int(k)] + else: + src_module = getattr(parent_module, k) + tgt_module = func(src_module) + if isinstance(parent_module, nn.Sequential): + parent_module[int(k)] = tgt_module + else: + setattr(parent_module, k, tgt_module) + # verify that all BN are replaced + assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)) + return root_module + + +class VqVae(nn.Module): + def __init__( + self, + config: VQBeTConfig, + ): + """ + VQ-VAE is composed of three parts: encoder, vq_layer, and decoder. + Encoder and decoder are MLPs consisting of an input, output layer, and hidden layer, respectively. + The vq_layer uses residual VQs. + + This class contains functions for training the encoder and decoder along with the residual VQ layer (for trainign phase 1), + as well as functions to help BeT training part in training phase 2. + """ + + super().__init__() + self.config = config + # 'discretized' indicates whether the Residual VQ part is trained or not. (After finishing the training, we set discretized=True) + self.register_buffer("discretized", torch.tensor(False)) + self.optimized_steps = 0 + # we use the fixed number of layers for Residual VQ across all environments. + self.vqvae_num_layers = 2 + + self.vq_layer = ResidualVQ( + dim=config.vqvae_embedding_dim, + num_quantizers=self.vqvae_num_layers, + codebook_size=config.vqvae_n_embed, + ) + + self.encoder = MLP( + in_channels=self.config.output_shapes["action"][0] * self.config.action_chunk_size, + hidden_channels=[ + config.vqvae_enc_hidden_dim, + config.vqvae_enc_hidden_dim, + config.vqvae_embedding_dim, + ], + ) + self.decoder = MLP( + in_channels=config.vqvae_embedding_dim, + hidden_channels=[ + config.vqvae_enc_hidden_dim, + config.vqvae_enc_hidden_dim, + self.config.output_shapes["action"][0] * self.config.action_chunk_size, + ], + ) + + def get_embeddings_from_code(self, encoding_indices): + # This function gets code indices as inputs, and outputs embedding vectors corresponding to the code indices. + with torch.no_grad(): + z_embed = self.vq_layer.get_codebook_vector_from_indices(encoding_indices) + # since the RVQ has multiple layers, it adds the vectors in the axis of layers to provide a vector for that code combination. + z_embed = z_embed.sum(dim=0) + return z_embed + + def get_action_from_latent(self, latent): + # given latent vector, this function outputs the decoded action. 
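[Why `get_embeddings_from_code` above sums over the layer axis: in a residual VQ the codes selected by successive layers add up to approximate the encoder embedding. A toy two-layer sketch with random codebooks, illustrative only.]

```python
import torch

codebook1 = torch.randn(4, 3)  # layer-1 codebook: 4 codes of dim 3
codebook2 = torch.randn(4, 3)  # layer-2 codebook

x = torch.randn(3)                                     # encoder embedding
i1 = torch.cdist(x[None], codebook1).argmin()          # nearest layer-1 code
residual = x - codebook1[i1]
i2 = torch.cdist(residual[None], codebook2).argmin()   # quantize the residual

x_hat = codebook1[i1] + codebook2[i2]  # summed codes approximate x
```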
+ output = self.decoder(latent) + if self.config.action_chunk_size == 1: + return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0]) + else: + return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0]) + + def get_code(self, state): + # in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181) + # this function outputs the `GT code` of given action using frozen encoder and quantization layers. (please refer to Figure 2. in the paper https://arxiv.org/pdf/2403.03181) + state = einops.rearrange(state, "N T A -> N (T A)") + with torch.no_grad(): + state_rep = self.encoder(state) + state_rep_shape = state_rep.shape[:-1] + state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1)) + state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat) + state_vq = state_rep_flat.view(*state_rep_shape, -1) + vq_code = vq_code.view(*state_rep_shape, -1) + vq_loss_state = torch.sum(vq_loss_state) + return state_vq, vq_code + + def vqvae_forward(self, state): + # This function passes the given data through Residual VQ with Encoder and Decoder. Please refer to section 3.2 in the paper https://arxiv.org/pdf/2403.03181). + state = einops.rearrange(state, "N T A -> N (T A)") + # We start with passing action (or action chunk) at:t+n through the encoder ϕ. + state_rep = self.encoder(state) + state_rep_shape = state_rep.shape[:-1] + state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1)) + # The resulting latent embedding vector x = ϕ(at:t+n) is then mapped to an embedding vector in the codebook of the RVQ layers by the nearest neighbor look-up. + state_rep_flat, vq_code, vq_loss_state = self.vq_layer(state_rep_flat) + state_vq = state_rep_flat.view(*state_rep_shape, -1) + vq_code = vq_code.view(*state_rep_shape, -1) + # since the RVQ has multiple layers, it adds the vectors in the axis of layers to provide a vector for that code combination. + vq_loss_state = torch.sum(vq_loss_state) + # Then, the discretized vector zq(x) is reconstructed as ψ(zq(x)) by passing through the decoder ψ. 
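[In symbols, the phase-1 objective computed below: an L1 reconstruction term plus the summed per-layer VQ (commitment) losses, scaled by the hard-coded factor of 5 visible a few lines down.]

$$
\mathcal{L}_{\mathrm{RVQ}} \;=\; \big\lVert a_{t:t+n} - \psi\!\big(z_q(x)\big) \big\rVert_{1} \;+\; 5 \sum_{q=1}^{N_q} \mathcal{L}^{(q)}_{\mathrm{vq}}
$$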
+ dec_out = self.decoder(state_vq) + # Calculate L1 reconstruction loss + encoder_loss = (state - dec_out).abs().mean() + # add encoder reconstruction loss and commitment loss + rep_loss = encoder_loss + vq_loss_state * 5 + + metric = ( + encoder_loss.clone().detach(), + vq_loss_state.clone().detach(), + vq_code, + rep_loss.item(), + ) + return rep_loss, metric + + +class FocalLoss(nn.Module): + """ + From https://github.com/notmahi/miniBET/blob/main/behavior_transformer/bet.py + """ + + def __init__(self, gamma: float = 0, size_average: bool = True): + super().__init__() + self.gamma = gamma + self.size_average = size_average + + def forward(self, input, target): + if len(input.shape) == 3: + N, T, _ = input.shape + logpt = F.log_softmax(input, dim=-1) + logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T) + elif len(input.shape) == 2: + logpt = F.log_softmax(input, dim=-1) + logpt = logpt.gather(-1, target.view(-1, 1)).view(-1) + pt = logpt.exp() + + loss = -1 * (1 - pt) ** self.gamma * logpt + if self.size_average: + return loss.mean() + else: + return loss.sum() + + +class MLP(torch.nn.Sequential): + def __init__( + self, + in_channels: int, + hidden_channels: List[int], + ): + layers = [] + in_dim = in_channels + for hidden_dim in hidden_channels[:-1]: + layers.append(torch.nn.Linear(in_dim, hidden_dim)) + layers.append(torch.nn.ReLU()) + in_dim = hidden_dim + + layers.append(torch.nn.Linear(in_dim, hidden_channels[-1])) + + super().__init__(*layers) diff --git a/lerobot/common/policies/vqbet/vqbet_utils.py b/lerobot/common/policies/vqbet/vqbet_utils.py new file mode 100644 index 00000000..6e96716f --- /dev/null +++ b/lerobot/common/policies/vqbet/vqbet_utils.py @@ -0,0 +1,1444 @@ +import math +from functools import partial +from math import ceil +from random import randrange +from typing import Callable + +import torch +import torch.distributed as distributed +import torch.nn.functional as F # noqa: N812 +from einops import pack, rearrange, reduce, repeat, unpack +from torch import einsum, nn +from torch.cuda.amp import autocast +from torch.optim import Optimizer + +from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig + +# ruff: noqa: N806 + +""" +This file is part of a VQ-BeT that utilizes code from the following repositories: + + - Vector Quantize PyTorch code is licensed under the MIT License: + Origianl source: https://github.com/lucidrains/vector-quantize-pytorch + + - nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch. + Original source: https://github.com/karpathy/nanoGPT + +We also made some changes to the original code to adapt it to our needs. The changes are described in the code below. +""" + +""" +This is a part for nanoGPT that utilizes code from the following repository: + + - Andrej Karpathy's nanoGPT implementation in PyTorch. 
+ Original source: https://github.com/karpathy/nanoGPT + + - The nanoGPT code is licensed under the MIT License: + + MIT License + + Copyright (c) 2022 Andrej Karpathy + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + - We've made some changes to the original code to adapt it to our needs. + + Changed variable names: + - n_head -> gpt_n_head + - n_embd -> gpt_hidden_dim + - block_size -> gpt_block_size + - n_layer -> gpt_n_layer + + + class GPT(nn.Module): + - removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained` + - changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop. + - in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads). 
+ +""" + + +class CausalSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + assert config.gpt_hidden_dim % config.gpt_n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.gpt_hidden_dim, 3 * config.gpt_hidden_dim) + # output projection + self.c_proj = nn.Linear(config.gpt_hidden_dim, config.gpt_hidden_dim) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "bias", + torch.tril(torch.ones(config.gpt_block_size, config.gpt_block_size)).view( + 1, 1, config.gpt_block_size, config.gpt_block_size + ), + ) + self.gpt_n_head = config.gpt_n_head + self.gpt_hidden_dim = config.gpt_hidden_dim + + def forward(self, x): + ( + B, + T, + C, + ) = x.size() # batch size, sequence length, embedding dimensionality (gpt_hidden_dim) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2) + k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class Block(nn.Module): + # causual self-attention block for GPT + def __init__(self, config): + super().__init__() + self.ln_1 = nn.LayerNorm(config.gpt_hidden_dim) + self.attn = CausalSelfAttention(config) + self.ln_2 = nn.LayerNorm(config.gpt_hidden_dim) + self.mlp = nn.Sequential( + nn.Linear(config.gpt_hidden_dim, 4 * config.gpt_hidden_dim), + nn.GELU(), + nn.Linear(4 * config.gpt_hidden_dim, config.gpt_hidden_dim), + nn.Dropout(config.dropout), + ) + + def forward(self, x): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class GPT(nn.Module): + """ + Original comments: + Full definition of a GPT Language Model, all of it in this single file. + References: + 1) the official GPT-2 TensorFlow implementation released by OpenAI: + https://github.com/openai/gpt-2/blob/master/src/model.py + 2) huggingface/transformers PyTorch implementation: + https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py + """ + + def __init__(self, config: VQBeTConfig): + """ + GPT model gets hyperparameters from a config object. Please refer configuration_vqbet.py for more details. 
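        A minimal sketch of the causal masking used in `CausalSelfAttention` above, with a toy sequence length:

        ```python
        import torch

        T = 4
        mask = torch.tril(torch.ones(T, T))          # same construction as `self.bias`
        att = torch.randn(T, T)                      # raw attention scores
        att = att.masked_fill(mask == 0, float("-inf"))
        att = att.softmax(dim=-1)                    # row i only attends to positions <= i
        ```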
+ """ + super().__init__() + assert config.gpt_output_dim is not None + assert config.gpt_block_size is not None + self.config = config + + self.transformer = nn.ModuleDict( + { + "wte": nn.Linear(config.gpt_input_dim, config.gpt_hidden_dim), + "wpe": nn.Embedding(config.gpt_block_size, config.gpt_hidden_dim), + "drop": nn.Dropout(config.dropout), + "h": nn.ModuleList([Block(config) for _ in range(config.gpt_n_layer)]), + "ln_f": nn.LayerNorm(config.gpt_hidden_dim), + } + ) + self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False) + # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper + self.apply(self._init_weights) + for pn, p in self.named_parameters(): + if pn.endswith("c_proj.weight"): + torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)) + + # report number of parameters + n_params = sum(p.numel() for p in self.parameters()) + print("number of parameters: {:.2f}M".format(n_params / 1e6)) + + def forward(self, input, targets=None): + device = input.device + b, t, d = input.size() + assert ( + t <= self.config.gpt_block_size + ), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}" + + # positional encodings that are added to the input embeddings + pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) + + # forward the GPT model itself + tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim) + pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim) + x = self.transformer.drop(tok_emb + pos_emb) + for block in self.transformer.h: + x = block(x) + x = self.transformer.ln_f(x) + logits = self.lm_head(x) + return logits + + def _init_weights(self, module): + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + torch.nn.init.zeros_(module.bias) + torch.nn.init.ones_(module.weight) + + def crop_block_size(self, gpt_block_size): + # model surgery to decrease the block size if necessary + # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) + # but want to use a smaller block size for some smaller, simpler model + assert gpt_block_size <= self.config.gpt_block_size + self.config.gpt_block_size = gpt_block_size + self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size]) + for block in self.transformer.h: + block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size] + + def configure_parameters(self): + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 
+ """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear,) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.named_modules(): + for pn, _p in m.named_parameters(): + fpn = "{}.{}".format(mn, pn) if mn else pn # full param name + if pn.endswith("bias"): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # validate that we considered every parameter + param_dict = dict(self.named_parameters()) + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format( + str(inter_params) + ) + assert ( + len(param_dict.keys() - union_params) == 0 + ), "parameters {} were not separated into either decay/no_decay set!".format( + str(param_dict.keys() - union_params), + ) + + decay = [param_dict[pn] for pn in sorted(decay)] + no_decay = [param_dict[pn] for pn in sorted(no_decay)] + # return the parameters that require weight decay, and the parameters that don't separately. + return decay, no_decay + + +""" +This file is a part for Residual Vector Quantization that utilizes code from the following repository: + + - Phil Wang's vector-quantize-pytorch implementation in PyTorch. + Origianl source: https://github.com/lucidrains/vector-quantize-pytorch + + - The vector-quantize-pytorch code is licensed under the MIT License: + + MIT License + + Copyright (c) 2020 Phil Wang + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + - We've made some changes to the original code to adapt it to our needs. + + class ResidualVQ(nn.Module): + - added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method: + This enables the user to save an indicator whether the codebook is frozen or not. + - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`: + This is to make the function name more descriptive. + + class VectorQuantize(nn.Module): + - removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method: + These parameters are not used in the code. 
+ - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`: + This is to make the function name more descriptive. + +""" + + +class ResidualVQ(nn.Module): + """ + Residual VQ is composed of multiple VectorQuantize layers. + + Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf + "Residual Vector Quantizer (a.k.a. multi-stage vector quantizer [36]) cascades Nq layers of VQ as follows. The unquantized input vector is + passed through a first VQ and quantization residuals are computed. The residuals are then iteratively quantized by a sequence of additional + Nq -1 vector quantizers, as described in Algorithm 1." + + + self.project_in: function for projecting input to codebook dimension + self.project_out: function for projecting codebook dimension to output dimension + self.layers: nn.ModuleList of VectorQuantize layers that contains Nq layers of VQ as described in the paper. + self.freeze_codebook: buffer to save an indicator whether the codebook is frozen or not. VQ-BeT will check this to determine whether to update the codebook or not. + """ + + def __init__( + self, + *, + dim, + num_quantizers, + codebook_dim=None, + shared_codebook=False, + heads=1, + quantize_dropout=False, + quantize_dropout_cutoff_index=0, + quantize_dropout_multiple_of=1, + accept_image_fmap=False, + **kwargs, + ): + super().__init__() + assert heads == 1, "residual vq is not compatible with multi-headed codes" + codebook_dim = codebook_dim if (codebook_dim is not None) else dim + codebook_input_dim = codebook_dim * heads + + requires_projection = codebook_input_dim != dim + self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() + self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() + + self.num_quantizers = num_quantizers + + self.accept_image_fmap = accept_image_fmap + self.layers = nn.ModuleList( + [ + VectorQuantize( + dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs + ) + for _ in range(num_quantizers) + ] + ) + + self.quantize_dropout = quantize_dropout and num_quantizers > 1 + + assert quantize_dropout_cutoff_index >= 0 + + self.register_buffer("freeze_codebook", torch.tensor(False)) + self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index + self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4 + + if not shared_codebook: + return + + first_vq, *rest_vq = self.layers + codebook = first_vq._codebook + + for vq in rest_vq: + vq._codebook = codebook + + @property + def codebooks(self): + codebooks = [layer._codebook.embed for layer in self.layers] + codebooks = torch.stack(codebooks, dim=0) + codebooks = rearrange(codebooks, "q 1 c d -> q c d") + return codebooks + + def get_codebook_vector_from_indices(self, indices): + # this function will return the codes from all codebooks across layers corresponding to the indices + batch, quantize_dim = indices.shape[0], indices.shape[-1] + + # may also receive indices in the shape of 'b h w q' (accept_image_fmap) + + indices, ps = pack([indices], "b * q") + + # because of quantize dropout, one can pass in indices that are coarse + # and the network should be able to reconstruct + + if quantize_dim < self.num_quantizers: + assert ( + self.quantize_dropout > 0.0 + ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations" + indices = F.pad(indices, (0, 
self.num_quantizers - quantize_dim), value=-1) + + # get ready for gathering + + codebooks = repeat(self.codebooks, "q c d -> q b c d", b=batch) + gather_indices = repeat(indices, "b n q -> q b n d", d=codebooks.shape[-1]) + + # take care of quantizer dropout + + mask = gather_indices == -1.0 + gather_indices = gather_indices.masked_fill( + mask, 0 + ) # have it fetch a dummy code to be masked out later + + all_codes = codebooks.gather(2, gather_indices) # gather all codes + + # mask out any codes that were dropout-ed + + all_codes = all_codes.masked_fill(mask, 0.0) + + # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension) + + (all_codes,) = unpack(all_codes, ps, "q b * d") + + return all_codes + + def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None): + """ + For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss. + First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize. + The residual value of each layer is fed to the next layer. + """ + num_quant, quant_dropout_multiple_of, return_loss, device = ( + self.num_quantizers, + self.quantize_dropout_multiple_of, + (indices is not None), + x.device, + ) + + x = self.project_in(x) + + assert not (self.accept_image_fmap and (indices is not None)) + + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + + if return_loss: + assert not torch.any( + indices == -1 + ), "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss" + ce_losses = [] + + should_quantize_dropout = self.training and self.quantize_dropout and not return_loss + + # sample a layer index at which to dropout further residual quantization + # also prepare null indices and loss + + if should_quantize_dropout: + rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant) + + if quant_dropout_multiple_of != 1: + rand_quantize_dropout_index = ( + ceil((rand_quantize_dropout_index + 1) / quant_dropout_multiple_of) + * quant_dropout_multiple_of + - 1 + ) + + null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2]) + null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long) + null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype) + + # go through the layers + + for quantizer_index, layer in enumerate(self.layers): + if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index: + all_indices.append(null_indices) + all_losses.append(null_loss) + continue + + layer_indices = None + if return_loss: + layer_indices = indices[..., quantizer_index] + + quantized, *rest = layer( + residual, + indices=layer_indices, + sample_codebook_temp=sample_codebook_temp, + freeze_codebook=self.freeze_codebook, + ) + + residual = residual - quantized.detach() + quantized_out = quantized_out + quantized + + if return_loss: + ce_loss = rest[0] + ce_losses.append(ce_loss) + continue + + embed_indices, loss = rest + + all_indices.append(embed_indices) + all_losses.append(loss) + + # project out, if needed + + quantized_out = self.project_out(quantized_out) + + # whether to early return the cross entropy loss + + if return_loss: + return quantized_out, sum(ce_losses) + + # stack all losses and indices + + all_losses, all_indices = map(partial(torch.stack, dim=-1), 
(all_losses, all_indices)) + + ret = (quantized_out, all_indices, all_losses) + + if return_all_codes: + # whether to return all codes from all codebooks across layers + all_codes = self.get_codebook_vector_from_indices(all_indices) + + # will return all codes in shape (quantizer, batch, sequence length, codebook dimension) + ret = (*ret, all_codes) + + return ret + + +class VectorQuantize(nn.Module): + def __init__( + self, + dim, + codebook_size, + codebook_dim=None, + heads=1, + separate_codebook_per_head=False, + decay=0.8, + eps=1e-5, + kmeans_init=False, + kmeans_iters=10, + sync_kmeans=True, + threshold_ema_dead_code=0, + channel_last=True, + accept_image_fmap=False, + commitment_weight=1.0, + commitment_use_cross_entropy_loss=False, + orthogonal_reg_weight=0.0, + orthogonal_reg_active_codes_only=False, + orthogonal_reg_max_codes=None, + stochastic_sample_codes=False, + sample_codebook_temp=1.0, + straight_through=False, + reinmax=False, # using reinmax for improved straight-through, assuming straight through helps at all + sync_codebook=None, + sync_affine_param=False, + ema_update=True, + learnable_codebook=False, + in_place_codebook_optimizer: Callable[ + ..., Optimizer + ] = None, # Optimizer used to update the codebook embedding if using learnable_codebook + affine_param=False, + affine_param_batch_decay=0.99, + affine_param_codebook_decay=0.9, + sync_update_v=0.0, # the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf + ): + super().__init__() + self.dim = dim + self.heads = heads + self.separate_codebook_per_head = separate_codebook_per_head + + codebook_dim = codebook_dim if (codebook_dim is not None) else dim + codebook_input_dim = codebook_dim * heads + + requires_projection = codebook_input_dim != dim + self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() + self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() + + self.eps = eps + self.commitment_weight = commitment_weight + self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss # whether to use cross entropy loss to codebook as commitment loss + + self.learnable_codebook = learnable_codebook + + has_codebook_orthogonal_loss = orthogonal_reg_weight > 0 + self.has_codebook_orthogonal_loss = has_codebook_orthogonal_loss + self.orthogonal_reg_weight = orthogonal_reg_weight + self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only + self.orthogonal_reg_max_codes = orthogonal_reg_max_codes + + assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update" + + assert 0 <= sync_update_v <= 1.0 + assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on" + + self.sync_update_v = sync_update_v + + gumbel_sample_fn = partial( + gumbel_sample, + stochastic=stochastic_sample_codes, + reinmax=reinmax, + straight_through=straight_through, + ) + + if sync_codebook is None: + sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1 + + codebook_kwargs = { + "dim": codebook_dim, + "num_codebooks": heads if separate_codebook_per_head else 1, + "codebook_size": codebook_size, + "kmeans_init": kmeans_init, + "kmeans_iters": kmeans_iters, + "sync_kmeans": sync_kmeans, + "decay": decay, + "eps": eps, + "threshold_ema_dead_code": threshold_ema_dead_code, + "use_ddp": sync_codebook, + "learnable_codebook": has_codebook_orthogonal_loss 
or learnable_codebook, + "sample_codebook_temp": sample_codebook_temp, + "gumbel_sample": gumbel_sample_fn, + "ema_update": ema_update, + } + + if affine_param: + codebook_kwargs = dict( + **codebook_kwargs, + affine_param=True, + sync_affine_param=sync_affine_param, + affine_param_batch_decay=affine_param_batch_decay, + affine_param_codebook_decay=affine_param_codebook_decay, + ) + + self._codebook = EuclideanCodebook(**codebook_kwargs) + + self.in_place_codebook_optimizer = ( + in_place_codebook_optimizer(self._codebook.parameters()) + if (in_place_codebook_optimizer is not None) + else None + ) + + self.codebook_size = codebook_size + + self.accept_image_fmap = accept_image_fmap + self.channel_last = channel_last + + @property + def codebook(self): + codebook = self._codebook.embed + + if self.separate_codebook_per_head: + return codebook + + return rearrange(codebook, "1 ... -> ...") + + @codebook.setter + def codebook(self, codes): + if not self.separate_codebook_per_head: + codes = rearrange(codes, "... -> 1 ...") + + self._codebook.embed.copy_(codes) + + def get_codebook_vector_from_indices(self, indices): + codebook = self.codebook + is_multiheaded = codebook.ndim > 2 + + if not is_multiheaded: + codes = codebook[indices] + return rearrange(codes, "... h d -> ... (h d)") + + indices, ps = pack_one(indices, "b * h") + indices = rearrange(indices, "b n h -> b h n") + + indices = repeat(indices, "b h n -> b h n d", d=codebook.shape[-1]) + codebook = repeat(codebook, "h n d -> b h n d", b=indices.shape[0]) + + codes = codebook.gather(2, indices) + codes = rearrange(codes, "b h n d -> b n (h d)") + codes = unpack_one(codes, ps, "b * d") + return codes + + def forward( + self, + x, + indices=None, + mask=None, + sample_codebook_temp=None, + freeze_codebook=False, + ): + orig_input = x + + only_one = x.ndim == 2 + + if only_one: + assert mask is None + x = rearrange(x, "b d -> b 1 d") + + shape, device, heads, is_multiheaded, _codebook_size, return_loss = ( + x.shape, + x.device, + self.heads, + self.heads > 1, + self.codebook_size, + (indices is not None), + ) + + need_transpose = not self.channel_last and not self.accept_image_fmap + should_inplace_optimize = self.in_place_codebook_optimizer is not None + + # rearrange inputs + + if self.accept_image_fmap: + height, width = x.shape[-2:] + x = rearrange(x, "b c h w -> b (h w) c") + + if need_transpose: + x = rearrange(x, "b d n -> b n d") + + # project input + + x = self.project_in(x) + + # handle multi-headed separate codebooks + + if is_multiheaded: + ein_rhs_eq = "h b n d" if self.separate_codebook_per_head else "1 (b h) n d" + x = rearrange(x, f"b n (h d) -> {ein_rhs_eq}", h=heads) + + # l2norm for cosine sim, otherwise identity + + x = self._codebook.transform_input(x) + + # codebook forward kwargs + + codebook_forward_kwargs = { + "sample_codebook_temp": sample_codebook_temp, + "mask": mask, + "freeze_codebook": freeze_codebook, + } + + # quantize + + quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs) + + # one step in-place update + + if should_inplace_optimize and self.training and not freeze_codebook: + if mask is not None: + loss = F.mse_loss(quantize, x.detach(), reduction="none") + + loss_mask = mask + if is_multiheaded: + loss_mask = repeat( + mask, + "b n -> c (b h) n", + c=loss.shape[0], + h=loss.shape[1] // mask.shape[0], + ) + + loss = loss[loss_mask].mean() + + else: + loss = F.mse_loss(quantize, x.detach()) + + loss.backward() + self.in_place_codebook_optimizer.step() + 
self.in_place_codebook_optimizer.zero_grad() + + # quantize again + + quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs) + + if self.training: + # determine code to use for commitment loss + maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity + + commit_quantize = maybe_detach(quantize) + + # straight through + + quantize = x + (quantize - x).detach() + + if self.sync_update_v > 0.0: + # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf + quantize = quantize + self.sync_update_v * (quantize - quantize.detach()) + + # function for calculating cross entropy loss to distance matrix + # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss + + def calculate_ce_loss(codes): + if not is_multiheaded: + dist_einops_eq = "1 b n l -> b l n" + elif self.separate_codebook_per_head: + dist_einops_eq = "c b n l -> b l n c" + else: + dist_einops_eq = "1 (b h) n l -> b l n h" + + ce_loss = F.cross_entropy( + rearrange(distances, dist_einops_eq, b=shape[0]), codes, ignore_index=-1 + ) + + return ce_loss + + # if returning cross entropy loss on codes that were passed in + + if return_loss: + return quantize, calculate_ce_loss(indices) + + # transform embedding indices + + if is_multiheaded: + if self.separate_codebook_per_head: + embed_ind = rearrange(embed_ind, "h b n -> b n h", h=heads) + else: + embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads) + + if self.accept_image_fmap: + embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width) + + if only_one: + embed_ind = rearrange(embed_ind, "b 1 -> b") + + # aggregate loss + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + if self.commitment_use_cross_entropy_loss: + if mask is not None: + ce_loss_mask = mask + if is_multiheaded: + ce_loss_mask = repeat(ce_loss_mask, "b n -> b n h", h=heads) + + embed_ind.masked_fill_(~ce_loss_mask, -1) + + commit_loss = calculate_ce_loss(embed_ind) + else: + if mask is not None: + # with variable lengthed sequences + commit_loss = F.mse_loss(commit_quantize, x, reduction="none") + + loss_mask = mask + if is_multiheaded: + loss_mask = repeat( + loss_mask, + "b n -> c (b h) n", + c=commit_loss.shape[0], + h=commit_loss.shape[1] // mask.shape[0], + ) + + commit_loss = commit_loss[loss_mask].mean() + else: + commit_loss = F.mse_loss(commit_quantize, x) + + loss = loss + commit_loss * self.commitment_weight + + if self.has_codebook_orthogonal_loss: + codebook = self._codebook.embed + + # only calculate orthogonal loss for the activated codes for this batch + + if self.orthogonal_reg_active_codes_only: + assert not ( + is_multiheaded and self.separate_codebook_per_head + ), "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet" + unique_code_ids = torch.unique(embed_ind) + codebook = codebook[:, unique_code_ids] + + num_codes = codebook.shape[-2] + + if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes: + rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes] + codebook = codebook[:, rand_ids] + + orthogonal_reg_loss = orthogonal_loss_fn(codebook) + loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight + + # handle multi-headed quantized embeddings + + if is_multiheaded: + if self.separate_codebook_per_head: + quantize = 
rearrange(quantize, "h b n d -> b n (h d)", h=heads) + else: + quantize = rearrange(quantize, "1 (b h) n d -> b n (h d)", h=heads) + + # project out + + quantize = self.project_out(quantize) + + # rearrange quantized embeddings + + if need_transpose: + quantize = rearrange(quantize, "b n d -> b d n") + + if self.accept_image_fmap: + quantize = rearrange(quantize, "b (h w) c -> b c h w", h=height, w=width) + + if only_one: + quantize = rearrange(quantize, "b 1 d -> b d") + + # if masking, only return quantized for where mask has True + + if mask is not None: + quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input) + + return quantize, embed_ind, loss + + +def noop(*args, **kwargs): + pass + + +def identity(t): + return t + + +def cdist(x, y): + x2 = reduce(x**2, "b n d -> b n", "sum") + y2 = reduce(y**2, "b n d -> b n", "sum") + xy = einsum("b i d, b j d -> b i j", x, y) * -2 + return (rearrange(x2, "b i -> b i 1") + rearrange(y2, "b j -> b 1 j") + xy).sqrt() + + +def log(t, eps=1e-20): + return torch.log(t.clamp(min=eps)) + + +def ema_inplace(old, new, decay): + is_mps = str(old.device).startswith("mps:") + + if not is_mps: + old.lerp_(new, 1 - decay) + else: + old.mul_(decay).add_(new * (1 - decay)) + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +def uniform_init(*shape): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def gumbel_noise(t): + noise = torch.zeros_like(t).uniform_(0, 1) + return -log(-log(noise)) + + +def gumbel_sample( + logits, + temperature=1.0, + stochastic=False, + straight_through=False, + reinmax=False, + dim=-1, + training=True, +): + dtype, size = logits.dtype, logits.shape[dim] + + if training and stochastic and temperature > 0: + sampling_logits = (logits / temperature) + gumbel_noise(logits) + else: + sampling_logits = logits + + ind = sampling_logits.argmax(dim=dim) + one_hot = F.one_hot(ind, size).type(dtype) + + assert not ( + reinmax and not straight_through + ), "reinmax can only be turned on if using straight through gumbel softmax" + + if not straight_through or temperature <= 0.0 or not training: + return ind, one_hot + + # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612 + # algorithm 2 + + if reinmax: + π0 = logits.softmax(dim=dim) + π1 = (one_hot + (logits / temperature).softmax(dim=dim)) / 2 + π1 = ((log(π1) - logits).detach() + logits).softmax(dim=1) + π2 = 2 * π1 - 0.5 * π0 + one_hot = π2 - π2.detach() + one_hot + else: + π1 = (logits / temperature).softmax(dim=dim) + one_hot = one_hot + π1 - π1.detach() + + return ind, one_hot + + +def laplace_smoothing(x, n_categories, eps=1e-5, dim=-1): + denom = x.sum(dim=dim, keepdim=True) + return (x + eps) / (denom + n_categories * eps) + + +def sample_vectors(samples, num): + num_samples, device = samples.shape[0], samples.device + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def batched_sample_vectors(samples, num): + return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0) + + +def pad_shape(shape, size, dim=0): + return [size if i == dim else s for i, s in enumerate(shape)] + + +def sample_multinomial(total_count, probs): + device = probs.device + probs = probs.cpu() + + total_count = probs.new_full((), total_count) + remainder = probs.new_ones(()) + sample = 
torch.empty_like(probs, dtype=torch.long) + + for i, p in enumerate(probs): + s = torch.binomial(total_count, p / remainder) + sample[i] = s + total_count -= s + remainder -= p + + return sample.to(device) + + +def all_gather_sizes(x, dim): + size = torch.tensor(x.shape[dim], dtype=torch.long, device=x.device) + all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())] + distributed.all_gather(all_sizes, size) + return torch.stack(all_sizes) + + +def all_gather_variably_sized(x, sizes, dim=0): + rank = distributed.get_rank() + all_x = [] + + for i, size in enumerate(sizes): + t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim)) + distributed.broadcast(t, src=i, async_op=True) + all_x.append(t) + + distributed.barrier() + return all_x + + +def sample_vectors_distributed(local_samples, num): + local_samples = rearrange(local_samples, "1 ... -> ...") + + rank = distributed.get_rank() + all_num_samples = all_gather_sizes(local_samples, dim=0) + + if rank == 0: + samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum()) + else: + samples_per_rank = torch.empty_like(all_num_samples) + + distributed.broadcast(samples_per_rank, src=0) + samples_per_rank = samples_per_rank.tolist() + + local_samples = sample_vectors(local_samples, samples_per_rank[rank]) + all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim=0) + out = torch.cat(all_samples, dim=0) + + return rearrange(out, "... -> 1 ...") + + +def batched_bincount(x, *, minlength): + batch, dtype, device = x.shape[0], x.dtype, x.device + target = torch.zeros(batch, minlength, dtype=dtype, device=device) + values = torch.ones_like(x) + target.scatter_add_(-1, x, values) + return target + + +def kmeans( + samples, + num_clusters, + num_iters=10, + sample_fn=batched_sample_vectors, + all_reduce_fn=noop, +): + num_codebooks, dim, dtype, _device = ( + samples.shape[0], + samples.shape[-1], + samples.dtype, + samples.device, + ) + + means = sample_fn(samples, num_clusters) + + for _ in range(num_iters): + dists = -torch.cdist(samples, means, p=2) + + buckets = torch.argmax(dists, dim=-1) + bins = batched_bincount(buckets, minlength=num_clusters) + all_reduce_fn(bins) + + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype=dtype) + + new_means.scatter_add_(1, repeat(buckets, "h n -> h n d", d=dim), samples) + new_means = new_means / rearrange(bins_min_clamped, "... -> ... 1") + all_reduce_fn(new_means) + + means = torch.where(rearrange(zero_mask, "... -> ... 
1"), means, new_means) + + return means, bins + + +def batched_embedding(indices, embeds): + batch, dim = indices.shape[1], embeds.shape[-1] + indices = repeat(indices, "h b n -> h b n d", d=dim) + embeds = repeat(embeds, "h c d -> h b c d", b=batch) + return embeds.gather(2, indices) + + +def orthogonal_loss_fn(t): + # eq (2) from https://arxiv.org/abs/2112.00384 + h, n = t.shape[:2] + normed_codes = F.normalize(t, p=2, dim=-1) + cosine_sim = einsum("h i d, h j d -> h i j", normed_codes, normed_codes) + return (cosine_sim**2).sum() / (h * n**2) - (1 / n) + + +class EuclideanCodebook(nn.Module): + def __init__( + self, + dim, + codebook_size, + num_codebooks=1, + kmeans_init=False, + kmeans_iters=10, + sync_kmeans=True, + decay=0.8, + eps=1e-5, + threshold_ema_dead_code=2, + reset_cluster_size=None, + use_ddp=False, + learnable_codebook=False, + gumbel_sample=gumbel_sample, + sample_codebook_temp=1.0, + ema_update=True, + affine_param=False, + sync_affine_param=False, + affine_param_batch_decay=0.99, + affine_param_codebook_decay=0.9, + ): + super().__init__() + self.transform_input = identity + + self.decay = decay + self.ema_update = ema_update + + init_fn = uniform_init if not kmeans_init else torch.zeros + embed = init_fn(num_codebooks, codebook_size, dim) + + self.codebook_size = codebook_size + self.num_codebooks = num_codebooks + + self.kmeans_iters = kmeans_iters + self.eps = eps + self.threshold_ema_dead_code = threshold_ema_dead_code + self.reset_cluster_size = ( + reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code + ) + + assert callable(gumbel_sample) + self.gumbel_sample = gumbel_sample + self.sample_codebook_temp = sample_codebook_temp + + assert not ( + use_ddp and num_codebooks > 1 and kmeans_init + ), "kmeans init is not compatible with multiple codebooks in distributed environment for now" + + self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors + self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop + self.all_reduce_fn = distributed.all_reduce if use_ddp else noop + + self.register_buffer("initted", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(num_codebooks, codebook_size)) + self.register_buffer("embed_avg", embed.clone()) + + self.learnable_codebook = learnable_codebook + if learnable_codebook: + self.embed = nn.Parameter(embed) + else: + self.register_buffer("embed", embed) + + # affine related params + + self.affine_param = affine_param + self.sync_affine_param = sync_affine_param + + if not affine_param: + return + + self.affine_param_batch_decay = affine_param_batch_decay + self.affine_param_codebook_decay = affine_param_codebook_decay + + self.register_buffer("batch_mean", None) + self.register_buffer("batch_variance", None) + + self.register_buffer("codebook_mean_needs_init", torch.Tensor([True])) + self.register_buffer("codebook_mean", torch.empty(num_codebooks, 1, dim)) + self.register_buffer("codebook_variance_needs_init", torch.Tensor([True])) + self.register_buffer("codebook_variance", torch.empty(num_codebooks, 1, dim)) + + @torch.jit.ignore + def init_embed_(self, data, mask=None): + if self.initted: + return + + if mask is not None: + c = data.shape[0] + data = rearrange(data[mask], "(c n) d -> c n d", c=c) + + embed, cluster_size = kmeans( + data, + self.codebook_size, + self.kmeans_iters, + sample_fn=self.sample_fn, + all_reduce_fn=self.kmeans_all_reduce_fn, + ) + + embed_sum = embed * 
rearrange(cluster_size, "... -> ... 1") + + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed_sum) + self.cluster_size.data.copy_(cluster_size) + self.initted.data.copy_(torch.Tensor([True])) + + @torch.jit.ignore + def update_with_decay(self, buffer_name, new_value, decay): + old_value = getattr(self, buffer_name) + + needs_init = getattr(self, buffer_name + "_needs_init", False) + + if needs_init: + self.register_buffer(buffer_name + "_needs_init", torch.Tensor([False])) + + if not (old_value is not None) or needs_init: + self.register_buffer(buffer_name, new_value.detach()) + + return + + value = old_value * decay + new_value.detach() * (1 - decay) + self.register_buffer(buffer_name, value) + + @torch.jit.ignore + def update_affine(self, data, embed, mask=None): + assert self.affine_param + + var_fn = partial(torch.var, unbiased=False) + + # calculate codebook mean and variance + + embed = rearrange(embed, "h ... d -> h (...) d") + + if self.training: + self.update_with_decay( + "codebook_mean", + reduce(embed, "h n d -> h 1 d", "mean"), + self.affine_param_codebook_decay, + ) + self.update_with_decay( + "codebook_variance", + reduce(embed, "h n d -> h 1 d", var_fn), + self.affine_param_codebook_decay, + ) + + # prepare batch data, which depends on whether it has masking + + data = rearrange(data, "h ... d -> h (...) d") + + if mask is not None: + c = data.shape[0] + data = rearrange(data[mask], "(c n) d -> c n d", c=c) + + # calculate batch mean and variance + + if not self.sync_affine_param: + self.update_with_decay( + "batch_mean", + reduce(data, "h n d -> h 1 d", "mean"), + self.affine_param_batch_decay, + ) + self.update_with_decay( + "batch_variance", + reduce(data, "h n d -> h 1 d", var_fn), + self.affine_param_batch_decay, + ) + return + + num_vectors, device, dtype = data.shape[-2], data.device, data.dtype + + # number of vectors, for denominator + + num_vectors = torch.tensor([num_vectors], device=device, dtype=dtype) + distributed.all_reduce(num_vectors) + + # calculate distributed mean + + batch_sum = reduce(data, "h n d -> h 1 d", "sum") + distributed.all_reduce(batch_sum) + batch_mean = batch_sum / num_vectors + + self.update_with_decay("batch_mean", batch_mean, self.affine_param_batch_decay) + + # calculate distributed variance + + variance_numer = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum") + distributed.all_reduce(variance_numer) + batch_variance = variance_numer / num_vectors + + self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay) + + def replace(self, batch_samples, batch_mask): + for ind, (samples, mask) in enumerate( + zip(batch_samples.unbind(dim=0), batch_mask.unbind(dim=0), strict=False) + ): + if not torch.any(mask): + continue + + sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item()) + sampled = rearrange(sampled, "1 ... -> ...") + + self.embed.data[ind][mask] = sampled + + self.cluster_size.data[ind][mask] = self.reset_cluster_size + self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "h ... d -> h (...) 
d") + self.replace(batch_samples, batch_mask=expired_codes) + + @autocast(enabled=False) + def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False): + needs_codebook_dim = x.ndim < 4 + sample_codebook_temp = ( + sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp + ) + + x = x.float() + + if needs_codebook_dim: + x = rearrange(x, "... -> 1 ...") + + flatten, ps = pack_one(x, "h * d") + + if mask is not None: + mask = repeat( + mask, + "b n -> c (b h n)", + c=flatten.shape[0], + h=flatten.shape[-2] // (mask.shape[0] * mask.shape[1]), + ) + + self.init_embed_(flatten, mask=mask) + + if self.affine_param: + self.update_affine(flatten, self.embed, mask=mask) + + embed = self.embed if self.learnable_codebook else self.embed.detach() + + if self.affine_param: + codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt() + batch_std = self.batch_variance.clamp(min=1e-5).sqrt() + embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean + + dist = -cdist(flatten, embed) + + embed_ind, embed_onehot = self.gumbel_sample( + dist, dim=-1, temperature=sample_codebook_temp, training=self.training + ) + + embed_ind = unpack_one(embed_ind, ps, "h *") + + if self.training: + unpacked_onehot = unpack_one(embed_onehot, ps, "h * c") + quantize = einsum("h b n c, h c d -> h b n d", unpacked_onehot, embed) + else: + quantize = batched_embedding(embed_ind, embed) + + if self.training and self.ema_update and not freeze_codebook: + if self.affine_param: + flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean + + if mask is not None: + embed_onehot[~mask] = 0.0 + + cluster_size = embed_onehot.sum(dim=1) + + self.all_reduce_fn(cluster_size) + ema_inplace(self.cluster_size.data, cluster_size, self.decay) + + embed_sum = einsum("h n d, h n c -> h c d", flatten, embed_onehot) + self.all_reduce_fn(embed_sum.contiguous()) + ema_inplace(self.embed_avg.data, embed_sum, self.decay) + + cluster_size = laplace_smoothing( + self.cluster_size, self.codebook_size, self.eps + ) * self.cluster_size.sum(dim=-1, keepdim=True) + + embed_normalized = self.embed_avg / rearrange(cluster_size, "... -> ... 1") + self.embed.data.copy_(embed_normalized) + self.expire_codes_(x) + + if needs_codebook_dim: + quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)) + + dist = unpack_one(dist, ps, "h * d") + + return quantize, embed_ind, dist diff --git a/lerobot/configs/policy/vqbet.yaml b/lerobot/configs/policy/vqbet.yaml new file mode 100644 index 00000000..a8b530e2 --- /dev/null +++ b/lerobot/configs/policy/vqbet.yaml @@ -0,0 +1,104 @@ +# @package _global_ + +# Defaults for training for the PushT dataset. + +seed: 100000 +dataset_repo_id: lerobot/pusht + +override_dataset_stats: + # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model? 
+ observation.image: + mean: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1) + std: [[[0.5]], [[0.5]], [[0.5]]] # (c,1,1) + # TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model + # from the original codebase, but we should remove these and train our own pretrained model + observation.state: + min: [13.456424, 32.938293] + max: [496.14618, 510.9579] + action: + min: [12.0, 25.0] + max: [511.0, 511.0] + +training: + offline_steps: 250000 + online_steps: 0 + eval_freq: 20000 + save_freq: 20000 + log_freq: 250 + save_checkpoint: true + + batch_size: 64 + grad_clip_norm: 10 + lr: 1.0e-4 + lr_scheduler: cosine + lr_warmup_steps: 500 + adam_betas: [0.95, 0.999] + adam_eps: 1.0e-8 + adam_weight_decay: 1.0e-6 + online_steps_between_rollouts: 1 + + # VQ-BeT specific + vqvae_lr: 1.0e-3 + n_vqvae_training_steps: 20000 + bet_weight_decay: 2e-4 + bet_learning_rate: 5.5e-5 + bet_betas: [0.9, 0.999] + + delta_timestamps: + observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]" + observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]" + action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.action_chunk_size} - 1)]" + +eval: + n_episodes: 50 + batch_size: 50 + +policy: + name: vqbet + + # Input / output structure. + n_obs_steps: 5 + n_action_pred_token: 7 + action_chunk_size: 5 + + input_shapes: + # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? + observation.image: [3, 96, 96] + observation.state: ["${env.state_dim}"] + output_shapes: + action: ["${env.action_dim}"] + + # Normalization / Unnormalization + input_normalization_modes: + observation.image: mean_std + observation.state: min_max + output_normalization_modes: + action: min_max + + # Architecture / modeling. + # Vision backbone. + vision_backbone: resnet18 + crop_shape: [84, 84] + crop_is_random: True + pretrained_backbone_weights: null + use_group_norm: True + spatial_softmax_num_keypoints: 32 + # VQ-VAE + n_vqvae_training_steps: ${training.n_vqvae_training_steps} + vqvae_n_embed: 16 + vqvae_embedding_dim: 256 + vqvae_enc_hidden_dim: 128 + # VQ-BeT + gpt_block_size: 500 + gpt_input_dim: 512 + gpt_output_dim: 512 + gpt_n_layer: 8 + gpt_n_head: 8 + gpt_hidden_dim: 512 + dropout: 0.1 + mlp_hidden_dim: 1024 + offset_loss_weight: 10000. 
+ primary_code_loss_weight: 5.0 + secondary_code_loss_weight: 0.5 + bet_softmax_temperature: 0.1 + sequentially_select: False diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 796881c4..4e636db8 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -88,6 +88,11 @@ def make_optimizer_and_scheduler(cfg, policy): elif policy.name == "tdmpc": optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr) lr_scheduler = None + elif cfg.policy.name == "vqbet": + from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler + + optimizer = VQBeTOptimizer(policy, cfg) + lr_scheduler = VQBeTScheduler(optimizer, cfg) else: raise NotImplementedError() diff --git a/tests/test_available.py b/tests/test_available.py index db5bd520..f4f9d4de 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -22,6 +22,7 @@ import lerobot from lerobot.common.policies.act.modeling_act import ACTPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy +from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy from tests.utils import require_env @@ -48,6 +49,7 @@ def test_available_policies(): ACTPolicy, DiffusionPolicy, TDMPCPolicy, + VQBeTPolicy, ] policies = [pol_cls.name for pol_cls in policy_classes] assert set(policies) == set(lerobot.available_policies), policies diff --git a/tests/test_policies.py b/tests/test_policies.py index fdc74751..490c25cc 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -49,6 +49,7 @@ def test_get_policy_and_config_classes(policy_name: str): [ ("xarm", "tdmpc", ["policy.use_mpc=true", "dataset_repo_id=lerobot/xarm_lift_medium"]), ("pusht", "diffusion", []), + ("pusht", "vqbet", []), ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"]), ( "aloha",