From 547d3c3ea1ff5e0f4258791328253abbae5ccecb Mon Sep 17 00:00:00 2001
From: jayLEE0301
Date: Fri, 24 May 2024 18:06:12 -0400
Subject: [PATCH] add comments, add primary_code_loss_weight

---
 .../policies/vqbet/configuration_vqbet.py   |   4 +-
 .../common/policies/vqbet/modeling_vqbet.py | 198 +++++++++++-------
 lerobot/configs/policy/vqbet.yaml           |   3 +-
 3 files changed, 126 insertions(+), 79 deletions(-)

diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py
index 477bb789..34002911 100644
--- a/lerobot/common/policies/vqbet/configuration_vqbet.py
+++ b/lerobot/common/policies/vqbet/configuration_vqbet.py
@@ -55,7 +55,8 @@ class VQBeTConfig:
         dropout: Dropout rate for GPT
         mlp_hidden_dim: Size of hidden dimensions of offset header / bin prediction headers parts of VQ-BeT
         offset_loss_weight: A constant that is multiplied to the offset loss
-        secondary_code_loss_weight: A constant that is multiplied to the secondary loss
+        primary_code_loss_weight: A constant that is multiplied to the primary code prediction loss
+        secondary_code_loss_weight: A constant that is multiplied to the secondary code prediction loss
         bet_softmax_temperature: Sampling temperature of code for rollout with VQ-BeT
     """
@@ -110,6 +111,7 @@ class VQBeTConfig:
     dropout: float = 0.1
     mlp_hidden_dim: int = 1024
     offset_loss_weight: float = 10000.
+    primary_code_loss_weight: float = 5.0
     secondary_code_loss_weight: float = 0.5
     bet_softmax_temperature: float = 0.1

diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py
index 8ea6353c..c5ab1797 100644
--- a/lerobot/common/policies/vqbet/modeling_vqbet.py
+++ b/lerobot/common/policies/vqbet/modeling_vqbet.py
@@ -1,20 +1,16 @@
-import os
-from pathlib import Path
 from collections import deque
-from typing import Callable, List, Optional
+from typing import Callable, List
 from functools import partial
-from itertools import zip_longest
 from random import randrange
 import math
 from math import ceil
-from dataclasses import dataclass
 import warnings
 import einops
 from einops import rearrange, repeat, reduce, pack, unpack
 import numpy as np
 import torch
-import torch.nn.functional as F  # noqa: N812
+import torch.nn.functional as F
 import torchvision
 from torch import Tensor, nn, einsum
 import torch.distributed as distributed
@@ -123,7 +119,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         if not self.check_discretized():
             loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
             return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
-
+        # if the Residual VQ is already trained, VQ-BeT trains its GPT and its bin prediction head / offset prediction head parts.
         _, loss = self.vqbet(batch, rollout=False)
         return loss
@@ -159,6 +155,8 @@ class VQBeTModel(nn.Module):
     Phase 2.

+      timestep {t-n+1}    timestep {t-n+2}         timestep {t}
+   ┌─────┴─────┐       ┌─────┴─────┐          ┌─────┴─────┐
     o_{t-n+1}           o_{t-n+2}       ...       o_{t}
         │                   │                       │
@@ -191,10 +189,10 @@ class VQBeTModel(nn.Module):

         self.rgb_encoder = VQBeTRgbEncoder(config)

+        # This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the diagram above.
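+        # It is a learnable parameter with shape (1, 1, gpt_input_dim); `forward` repeats it over the batch and token positions before concatenating it with the observation features.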
+        self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
-        # action token and EOS token
-        self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))  # Batch, Timestep, Data type, GPT input dim
-
+        # To input state and observation features into the GPT layers, we first project the features to match the GPT input dimension.
         self.state_projector = MLP(
             config.output_shapes["action"][0], hidden_channels=[self.config.gpt_input_dim]
@@ -203,7 +201,10 @@ class VQBeTModel(nn.Module):
             self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
         )
+
+        # GPT part of VQ-BeT
         self.policy = GPT(config)
+        # bin prediction head / offset prediction head part of VQ-BeT
         self.action_head = VQBeTHead(config)

     def discretize(self, discretize_step, actions):
@@ -220,7 +221,7 @@ class VQBeTModel(nn.Module):
         # Separate batch and sequence dims.
         img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
-
+        # The image observation feature, state feature, and action query token from the same timestep form a group, and the groups are listed in temporal order to be fed into GPT sequentially.
         observation_feature = torch.cat([
             torch.unsqueeze(self.obs_projector(img_features), dim=2),
             torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2),
@@ -230,24 +231,28 @@ class VQBeTModel(nn.Module):
             raise NotImplementedError
         len_additional_action_token = self.config.n_action_pred_token-1
         action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)
-
+
+        # add additional action query tokens for predicting future action chunks
         observation_feature = torch.cat([observation_feature, action_token], dim=1)

-        # get action features
+        # get action features (pass through GPT)
         features = self.policy(observation_feature)
         historical_act_pred_index = np.arange(0, n_obs_steps) * (self.config.gpt_num_obs_mode+1) + self.config.gpt_num_obs_mode
+
+        # extract only the output tokens at the positions of the action query tokens
         features = torch.cat([
             features[:, historical_act_pred_index],
             features[:, -len_additional_action_token:]
         ], dim=1)
-        # action head
+        # pass through action head
         pred_action = self.action_head(
             features,
         )
-
+        # during rollout, VQ-BeT doesn't calculate the loss
         if rollout:
             return pred_action["predicted_action"][:, n_obs_steps-1, :].reshape(batch_size, self.config.n_action_pred_chunk, -1)
+        # otherwise, it calculates the overall loss (bin prediction loss and offset loss)
         else:
             action = batch["action"]
             n, total_w, act_dim = action.shape
@@ -270,7 +275,21 @@ class VQBeTModel(nn.Module):
 class VQBeTHead(nn.Module):
     def __init__(self, config: VQBeTConfig):
         """
-        TODO: add explanation for each value.
+        VQBeTHead takes the output of the GPT layers and passes it through the bin prediction head (`self.map_to_cbet_preds_bin`) and the offset prediction head (`self.map_to_cbet_preds_offset`).
+
+        self.map_to_cbet_preds_bin: outputs the probability of each code (for each layer).
+        The input dimension of `self.map_to_cbet_preds_bin` is the same as the output dimension of GPT,
+        and the output dimension of `self.map_to_cbet_preds_bin` is `self.config.vqvae_groups * self.config.vqvae_n_embed`, where
+        `self.config.vqvae_groups` is the number of RVQ layers, and
+        `self.config.vqvae_n_embed` is the codebook size of the RVQ.
+
+        self.map_to_cbet_preds_offset: outputs the predicted offsets for all the codes in all the layers.
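+        (Offsets are predicted for every code so that, at inference time, the offsets matching the sampled codes can simply be gathered by indexing; see `forward` below.)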
+        The input dimension of `self.map_to_cbet_preds_offset` is the same as the output dimension of GPT,
+        and the output dimension of `self.map_to_cbet_preds_offset` is `self.config.vqvae_groups * self.config.vqvae_n_embed * config.n_action_pred_chunk * config.output_shapes["action"][0]`, where
+        `self.config.vqvae_groups` is the number of RVQ layers,
+        `self.config.vqvae_n_embed` is the codebook size of the RVQ,
+        `config.n_action_pred_chunk` is the action chunk size of each token, and
+        `config.output_shapes["action"][0]` is the dimension of the action space.
         """

         super().__init__()
@@ -301,6 +320,7 @@ class VQBeTHead(nn.Module):
         self.vqvae_model.device = get_device_from_parameters(self)
         loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
+        # if the RVQ has been updated for more than `discretize_step` steps, we freeze it and stop discretizing
         if self.vqvae_model.discretized:
             print("Finished discretizing action data!")
             self.vqvae_model.eval()
@@ -309,7 +329,10 @@ class VQBeTHead(nn.Module):
         return loss, n_different_codes, n_different_combinations

     def forward(self, x, **kwargs):
+        # N is the batch size, and T is the number of action query tokens, which are processed through the same GPT
         N, T, _ = x.shape
+        # the N and T dimensions are processed in parallel, so the tensors below have shapes like
+        # (batch size * number of action query tokens, action chunk size, action dimension)
         x = einops.rearrange(x, "N T WA -> (N T) WA")

         cbet_logits = self.map_to_cbet_preds_bin(x)
@@ -334,23 +357,27 @@ class VQBeTHead(nn.Module):
             sampled_centers,
         )
         # Use advanced indexing to sample the values
-        sampled_offsets = cbet_offsets[indices]  # NT, G, W, A(?) or NT, G, A
-
+        sampled_offsets = cbet_offsets[indices]
+        # Extract only the offsets corresponding to the sampled codes, and sum them over the RVQ layers.
         sampled_offsets = sampled_offsets.sum(dim=1)
+        # Get the centroids of each layer to pass them through the RVQ decoder
         centers = self.vqvae_model.draw_code_forward(sampled_centers).view(
             NT, -1, self.config.vqvae_embedding_dim
         )
         return_decoder_input = einops.rearrange(
             centers.clone().detach(), "NT 1 D -> NT D"
         )
+        # pass the centroids through the decoder to get the actions
         decoded_action = (
             self.vqvae_model.get_action_from_latent(return_decoder_input)
             .clone()
             .detach()
-        )  # NT, A
+        )
+        # reshape the extracted offsets to match the decoded centroids
         sampled_offsets = einops.rearrange(
             sampled_offsets, "NT (W A) -> NT W A", W=self.config.n_action_pred_chunk
         )
+        # add the offsets to the decoded centroids
         predicted_action = decoded_action + sampled_offsets
         predicted_action = einops.rearrange(
             predicted_action,
@@ -368,7 +395,16 @@ class VQBeTHead(nn.Module):
         }

     def loss_fn(self, pred, target, **kwargs):
-        # Rename the inputs for clarity.
+        """
+        Given the ground truth action values (target) and the predictions (pred), this function calculates the overall loss.
+
+        predicted_action: predicted action chunks (offsets + decoded centroids)
+        sampled_centers: sampled centroids (codes of the RVQ)
+        decoded_action: decoded actions, produced by passing sampled_centers through the RVQ decoder
+        NT: batch size * T
+        T: number of action query tokens, which are processed through the same GPT
+        cbet_logits: probability of all the codes in each layer
+        """
         action_seq = target
         predicted_action = pred["predicted_action"]
         sampled_centers = pred["sampled_centers"]
@@ -383,7 +419,7 @@ class VQBeTHead(nn.Module):
         action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")

         # Figure out the loss for the actions.
-        # First, we need to find the closest cluster center for each action.
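+        # The overall loss is a weighted sum of the two code prediction (cross-entropy) losses computed below and an L1 offset loss, using the weights from the config.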
+        # First, we need to find the closest cluster center for each ground truth action.
         state_vq, action_bins = self.vqvae_model.get_code(
             action_seq
         )  # action_bins: NT, G
@@ -392,18 +428,21 @@ class VQBeTHead(nn.Module):
         if action_seq.ndim == 2:
             action_seq = action_seq.unsqueeze(0)

+        # the offset loss is the L1 distance between the predicted actions and the ground truth actions
         offset_loss = torch.nn.L1Loss()(action_seq, predicted_action)
-
+        # calculate the primary code prediction loss
         cbet_loss1 = self._criterion(  # F.cross_entropy
             cbet_logits[:, 0, :],
             action_bins[:, 0],
         )
+        # calculate the secondary code prediction loss
         cbet_loss2 = self._criterion(  # F.cross_entropy
             cbet_logits[:, 1, :],
             action_bins[:, 1],
         )
-        cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.config.secondary_code_loss_weight
+        # weighted sum of the primary and secondary code prediction losses
+        cbet_loss = cbet_loss1 * self.config.primary_code_loss_weight + cbet_loss2 * self.config.secondary_code_loss_weight

         equal_primary_code_rate = torch.sum(
             (action_bins[:, 0] == sampled_centers[:, 0]).int()
@@ -416,21 +455,9 @@ class VQBeTHead(nn.Module):
             einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T),
             einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T),
         )
-        vq_action_error = (
-            abs(
-                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)
-                - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T)
-            )
-        ).mean()
-        offset_action_error = (
-            abs(
-                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)
-                - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)
-            )
-        ).mean()
-        action_error_max = (
-            abs(
-                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)
-                - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)
-            )
-        ).max()
+        vq_action_error = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T))).mean()
+        offset_action_error = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T))).mean()
+        action_error_max = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T))).max()

         loss = cbet_loss + self.config.offset_loss_weight * offset_loss
@@ -506,35 +533,30 @@ class VQBeTOptimizer:
     def step(self):
         self.optimizing_step +=1
+        # pretraining VQ-VAE (phase 1)
         if self.optimizing_step < self.discretize_step:
-            # pretraining VQ-VAE
             self.vqvae_optimizer.step()
+        # training BeT (phase 2)
         else:
-            # training BeT
-            if self.optimizing_step < 0.6 * self.offline_steps:
-                self.encoder_optimizer.step()
-                self.bet_optimizer1.step()
-                self.bet_optimizer2.step()
-                self.bet_optimizer3.step()
-            else:
-                self.bet_optimizer3.step()
+            self.encoder_optimizer.step()
+            self.bet_optimizer1.step()
+            self.bet_optimizer2.step()
+            self.bet_optimizer3.step()

     def zero_grad(self):
+        # pretraining VQ-VAE (phase 1)
         if self.optimizing_step < self.discretize_step:
-            # pretraining VQ-VAE
             self.vqvae_optimizer.zero_grad()
+        # training BeT (phase 2)
         else:
-            # training BeT
-            if self.optimizing_step < 0.6 * self.offline_steps:
-                self.encoder_optimizer.zero_grad()
-                self.bet_optimizer1.zero_grad()
-                self.bet_optimizer2.zero_grad()
-                self.bet_optimizer3.zero_grad()
-            else:
-                self.bet_optimizer3.zero_grad()

 class VQBeTScheduler:
     def __init__(self, optimizer, cfg):
+        # VQ-BeT uses a learning rate scheduler only for the rgb encoder.
+        # Since the rgb encoder part is taken from Diffusion Policy, we also follow the same scheduler.
         from diffusers.optimization import get_scheduler
         self.discretize_step = cfg.training.discretize_step
         self.optimizing_step = 0
@@ -663,6 +685,11 @@ class VqVae(nn.Module):
     def __init__(
         self, config: VQBeTConfig,
     ):
+        """
+        VQ-VAE is composed of three parts: an encoder, a vq_layer, and a decoder.
+        The encoder and decoder are MLPs, each consisting of an input layer, hidden layers, and an output layer.
+        The vq_layer uses residual VQs.
+        """
         super(VqVae, self).__init__()
         self.config = config
@@ -676,24 +703,14 @@ class VqVae(nn.Module):
             codebook_size=config.vqvae_n_embed,
         )

-        if self.config.n_action_pred_chunk == 1:
-            self.encoder = MLP(
-                in_channels=self.config.output_shapes["action"][0],
-                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
-            )
-            self.decoder = MLP(
-                in_channels=config.vqvae_embedding_dim,
-                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0]],
-            )
-        else:
-            self.encoder = MLP(
-                in_channels=self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk,
-                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
-            )
-            self.decoder = MLP(
-                in_channels=config.vqvae_embedding_dim,
-                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk],
-            )
+        self.encoder = MLP(
+            in_channels=self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk,
+            hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
+        )
+        self.decoder = MLP(
+            in_channels=config.vqvae_embedding_dim,
+            hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk],
+        )

         self.train()
@@ -704,6 +721,11 @@ class VqVae(nn.Module):
         self.decoder.eval()

     def train(self, mode=True):
+        """
+        This function forces the RVQ to stop updating once action discretization is complete.
+        Since the VQs are partly updated via the EMA method, simply passing data through them can cause unintended modifications.
+        Therefore, we override `train` to prevent the RVQs from being updated during VQ-BeT training after discretization completes.
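+        Concretely, once `self.discretized` is True, calling `train()` leaves the RVQ in eval mode instead of switching it back to training mode.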
+        """
         if mode:
             if self.discretized:
                 pass
@@ -736,7 +758,7 @@ class VqVae(nn.Module):
         if not torch.is_tensor(state):
             state = torch.FloatTensor(state.copy())
         if self.config.n_action_pred_chunk == 1:
-            state = state.squeeze(-2)  # state.squeeze(-1)
+            state = state.squeeze(-2)
         else:
             state = einops.rearrange(state, "N T A -> N (T A)")
         return state
@@ -764,7 +786,6 @@ class VqVae(nn.Module):
                 torch.swapaxes(recon_state_ae, -2, -1),
             )
         else:
-            # econ_from_code = self.draw_code_forward(vq_code)
             return state_vq, vq_code

     def vqvae_forward(self, state):
@@ -815,7 +836,7 @@ def pretrain_vqvae(vqvae_model, discretize_step, actions):
         loss, metric = vqvae_model.vqvae_forward(
             actions
-        )  # N T D
+        )
         n_different_codes = len(torch.unique(metric[2]))
         n_different_combinations = len(torch.unique(metric[2], dim=0))
         vqvae_model.optimized_steps += 1
@@ -837,6 +858,29 @@ def round_up_multiple(num, mult):

 class ResidualVQ(nn.Module):
+    """
+    MIT License
+
+    Copyright (c) 2020 Phil Wang
+
+    Permission is hereby granted, free of charge, to any person obtaining a copy
+    of this software and associated documentation files (the "Software"), to deal
+    in the Software without restriction, including without limitation the rights
+    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+    copies of the Software, and to permit persons to whom the Software is
+    furnished to do so, subject to the following conditions:
+
+    The above copyright notice and this permission notice shall be included in all
+    copies or substantial portions of the Software.
+
+    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+    SOFTWARE.
+    """
     """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""

     def __init__(
diff --git a/lerobot/configs/policy/vqbet.yaml b/lerobot/configs/policy/vqbet.yaml
index 524f1a21..8956ad4f 100644
--- a/lerobot/configs/policy/vqbet.yaml
+++ b/lerobot/configs/policy/vqbet.yaml
@@ -1,6 +1,6 @@
 # @package _global_

-# Defaults for training for the PushT dataset as per https://github.com/real-stanford/diffusion_policy.
+# Defaults for training for the PushT dataset.

 seed: 100000
 dataset_repo_id: lerobot/pusht
@@ -100,5 +100,6 @@ policy:
   dropout: 0.1
   mlp_hidden_dim: 1024
   offset_loss_weight: 10000.
+  primary_code_loss_weight: 5.0
   secondary_code_loss_weight: 0.5
   bet_softmax_temperature: 0.1
\ No newline at end of file
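
Note (not part of the patch): below is a minimal, self-contained sketch of how the new `primary_code_loss_weight` enters the overall VQ-BeT loss. All tensors and shape values are dummy stand-ins for the real model outputs (NT = batch size * number of action query tokens, G = number of RVQ layers, C = codebook size, W = action chunk size, A = action dimension):

    import torch
    import torch.nn.functional as F

    NT, G, C, W, A = 8, 2, 16, 1, 2

    cbet_logits = torch.randn(NT, G, C)         # per-layer code logits from the bin prediction head
    action_bins = torch.randint(0, C, (NT, G))  # ground truth code indices from the frozen RVQ
    predicted_action = torch.randn(NT, W, A)    # decoded centroids + predicted offsets
    action_seq = torch.randn(NT, W, A)          # ground truth action chunks

    # weights as set in vqbet.yaml by this patch
    primary_code_loss_weight = 5.0
    secondary_code_loss_weight = 0.5
    offset_loss_weight = 10000.0

    # primary / secondary code prediction losses (one cross-entropy per RVQ layer)
    cbet_loss1 = F.cross_entropy(cbet_logits[:, 0, :], action_bins[:, 0])
    cbet_loss2 = F.cross_entropy(cbet_logits[:, 1, :], action_bins[:, 1])
    cbet_loss = cbet_loss1 * primary_code_loss_weight + cbet_loss2 * secondary_code_loss_weight

    # offset loss: L1 distance between predicted and ground truth actions
    offset_loss = F.l1_loss(predicted_action, action_seq)

    loss = cbet_loss + offset_loss_weight * offset_loss

Before this patch, the primary weight was the hard-coded constant `5` in `loss_fn`; the change only moves it into the config so it can be tuned alongside `secondary_code_loss_weight` and `offset_loss_weight`.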