add comments, add primary_code_loss_weight

This commit is contained in:
jayLEE0301 2024-05-24 18:06:12 -04:00
parent 71ec76fe2a
commit 547d3c3ea1
3 changed files with 126 additions and 79 deletions

View File

@ -55,7 +55,8 @@ class VQBeTConfig:
dropout: Dropout rate for GPT
mlp_hidden_dim: Size of hidden dimensions of offset header / bin prediction headers parts of VQ-BeT
offset_loss_weight: A constant that is multiplied to the offset loss
secondary_code_loss_weight: A constant that is multiplied to the secondary loss
primary_code_loss_weight: A constant that multiplies the primary code prediction loss
secondary_code_loss_weight: A constant that multiplies the secondary code prediction loss
bet_softmax_temperature: Sampling temperature of code for rollout with VQ-BeT
"""
@ -110,6 +111,7 @@ class VQBeTConfig:
dropout: float = 0.1
mlp_hidden_dim: int = 1024
offset_loss_weight: float = 10000.
primary_code_loss_weight: float = 5.0
secondary_code_loss_weight: float = 0.5
bet_softmax_temperature: float = 0.1
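For reference, a minimal sketch of how these three weights combine the loss terms, mirroring the weighted sum used in `VQBeTHead.loss_fn` later in this commit (function and argument names here are illustrative):

# Minimal sketch (illustrative names): how the configured weights combine the loss terms.
def combine_losses(cbet_loss1, cbet_loss2, offset_loss, config):
    # primary / secondary code prediction losses are cross-entropy terms over the RVQ codes
    cbet_loss = (
        cbet_loss1 * config.primary_code_loss_weight
        + cbet_loss2 * config.secondary_code_loss_weight
    )
    # the L1 offset loss is scaled by offset_loss_weight and added on top
    return cbet_loss + config.offset_loss_weight * offset_loss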

View File

@ -1,20 +1,16 @@
import os
from pathlib import Path
from collections import deque
from typing import Callable, List, Optional
from typing import Callable, List
from functools import partial
from itertools import zip_longest
from random import randrange
import math
from math import ceil
from dataclasses import dataclass
import warnings
import einops
from einops import rearrange, repeat, reduce, pack, unpack
import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torch.nn.functional as F
import torchvision
from torch import Tensor, nn, einsum
import torch.distributed as distributed
@ -123,7 +119,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
if not self.check_discretized():
loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
# if the Residual VQ is already trained, VQ-BeT trains its GPT and the bin prediction head / offset head parts.
_, loss = self.vqbet(batch, rollout=False)
return loss
@ -159,6 +155,8 @@ class VQBeTModel(nn.Module):
Phase 2.
timestep {t-n+1} timestep {t-n+2} timestep {t}
o_{t-n+1} o_{t-n+2} ... o_{t}
@ -191,10 +189,10 @@ class VQBeTModel(nn.Module):
self.rgb_encoder = VQBeTRgbEncoder(config)
# This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
# action token and EOS token
self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) # Batch, Timestep, Data type, GPT input dim
# To input state and observation features into GPT layers, we first project the features to match the GPT input dimension.
self.state_projector = MLP(
config.output_shapes["action"][0],
hidden_channels=[self.config.gpt_input_dim]
@ -203,7 +201,10 @@ class VQBeTModel(nn.Module):
self.rgb_encoder.feature_dim,
hidden_channels=[self.config.gpt_input_dim]
)
# GPT part of VQ-BeT
self.policy = GPT(config)
# bin prediction head / offset prediction head part of VQ-BeT
self.action_head = VQBeTHead(config)
def discretize(self, discretize_step, actions):
@ -220,7 +221,7 @@ class VQBeTModel(nn.Module):
# Separate batch and sequence dims.
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
# image observation features, state features, and action query tokens from the same timestep are grouped together, and the groups are ordered so they can be fed into the GPT sequentially.
observation_feature = torch.cat([
torch.unsqueeze(self.obs_projector(img_features), dim=2),
torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2),
@ -230,24 +231,28 @@ class VQBeTModel(nn.Module):
raise NotImplementedError
len_additional_action_token = self.config.n_action_pred_token-1
action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)
# add additional action query tokens for predicting future action chunks
observation_feature = torch.cat([observation_feature, action_token], dim=1)
# get action features
# get action features (pass through GPT)
features = self.policy(observation_feature)
historical_act_pred_index = np.arange(0, n_obs_steps) * (self.config.gpt_num_obs_mode+1) + self.config.gpt_num_obs_mode
# extract only the output tokens at the positions of the action query tokens
features = torch.cat([
features[:, historical_act_pred_index],
features[:, -len_additional_action_token:]
], dim=1)
# action head
# pass through action head
pred_action = self.action_head(
features,
)
# if rollout, VQ-BeT doesn't calculate the loss
if rollout:
return pred_action["predicted_action"][:, n_obs_steps-1, :].reshape(batch_size, self.config.n_action_pred_chunk, -1)
# else, it calculates the overall loss (bin prediction loss and offset loss)
else:
action = batch["action"]
n, total_w, act_dim = action.shape
@ -270,7 +275,21 @@ class VQBeTModel(nn.Module):
class VQBeTHead(nn.Module):
def __init__(self, config: VQBeTConfig):
"""
TODO: add explanation for each value.
VQBeTHead takes the output of the GPT layers and passes the features through the bin prediction head (`self.map_to_cbet_preds_bin`) and the offset prediction head (`self.map_to_cbet_preds_offset`).
self.map_to_cbet_preds_bin: outputs the probability of each code (for each layer).
The input dimension of `self.map_to_cbet_preds_bin` is the same as the GPT output dimension,
and its output dimension is `self.config.vqvae_groups * self.config.vqvae_n_embed`, where
`self.config.vqvae_groups` is the number of RVQ layers, and
`self.config.vqvae_n_embed` is the codebook size of the RVQ.
self.map_to_cbet_preds_offset: outputs the predicted offsets for all codes in all layers.
The input dimension of `self.map_to_cbet_preds_offset` is the same as the GPT output dimension,
and its output dimension is `self.config.vqvae_groups * self.config.vqvae_n_embed * config.n_action_pred_chunk * config.output_shapes["action"][0]`, where
`self.config.vqvae_groups` is the number of RVQ layers,
`self.config.vqvae_n_embed` is the codebook size of the RVQ,
`config.n_action_pred_chunk` is the action chunk size of each token, and
`config.output_shapes["action"][0]` is the action dimension.
"""
super().__init__()
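A minimal sketch of the two heads with the dimensions described in the docstring above, assuming plain linear layers (the actual heads may be MLPs) and a hypothetical `gpt_output_dim` config field:

import torch.nn as nn

def build_cbet_heads(config):
    # Sketch only (assumption: single linear layers; the real heads may be deeper).
    # `config.gpt_output_dim` is a hypothetical field name for the GPT output size.
    num_codes = config.vqvae_groups * config.vqvae_n_embed
    bin_head = nn.Linear(config.gpt_output_dim, num_codes)
    offset_head = nn.Linear(
        config.gpt_output_dim,
        num_codes * config.n_action_pred_chunk * config.output_shapes["action"][0],
    )
    return bin_head, offset_head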
@ -301,6 +320,7 @@ class VQBeTHead(nn.Module):
self.vqvae_model.device = get_device_from_parameters(self)
loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
# if the RVQ has been updated for more than `discretize_step` steps, we finish discretization and freeze the RVQ
if self.vqvae_model.discretized:
print("Finished discretizing action data!")
self.vqvae_model.eval()
@ -309,7 +329,10 @@ class VQBeTHead(nn.Module):
return loss, n_different_codes, n_different_combinations
def forward(self, x, **kwargs):
# N is the batch size, and T is the number of action query tokens, which are processed through the same GPT
N, T, _ = x.shape
# we process the N and T dimensions in parallel. Thus, the resulting dimensions would be
# (batch size * number of action query tokens, action chunk size, action dimension)
x = einops.rearrange(x, "N T WA -> (N T) WA")
cbet_logits = self.map_to_cbet_preds_bin(x)
@ -334,23 +357,27 @@ class VQBeTHead(nn.Module):
sampled_centers,
)
# Use advanced indexing to sample the values
sampled_offsets = cbet_offsets[indices] # NT, G, W, A(?) or NT, G, A
sampled_offsets = cbet_offsets[indices]
# Extract only the offsets corresponding to the sampled codes.
sampled_offsets = sampled_offsets.sum(dim=1)
# Get the centroids of each layer to pass through the RVQ decoder
centers = self.vqvae_model.draw_code_forward(sampled_centers).view(
NT, -1, self.config.vqvae_embedding_dim
)
return_decoder_input = einops.rearrange(
centers.clone().detach(), "NT 1 D -> NT D"
)
# pass the centroids through the decoder to get the actions.
decoded_action = (
self.vqvae_model.get_action_from_latent(return_decoder_input)
.clone()
.detach()
) # NT, A
)
# reshape the extracted offsets to match the decoded centroids
sampled_offsets = einops.rearrange(
sampled_offsets, "NT (W A) -> NT W A", W=self.config.n_action_pred_chunk
)
# add the offsets to the decoded centroids
predicted_action = decoded_action + sampled_offsets
predicted_action = einops.rearrange(
predicted_action,
@ -368,7 +395,16 @@ class VQBeTHead(nn.Module):
}
def loss_fn(self, pred, target, **kwargs):
# Rename the inputs for clarity.
"""
For given ground-truth action values (target) and predictions (pred), this function calculates the overall loss.
predicted_action: predicted action chunk (offset + decoded centroids)
sampled_centers: sampled centroids (codes of RVQ)
decoded_action: decoded action, which is produced by passing sampled_centers through the RVQ decoder
NT: batch size * T
T: number of action query tokens, which are processed through the same GPT
cbet_logits: predicted probabilities over all codes in each layer
"""
action_seq = target
predicted_action = pred["predicted_action"]
sampled_centers = pred["sampled_centers"]
@ -383,7 +419,7 @@ class VQBeTHead(nn.Module):
action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
# Figure out the loss for the actions.
# First, we need to find the closest cluster center for each action.
# First, we need to find the closest cluster center for each ground truth action.
state_vq, action_bins = self.vqvae_model.get_code(
action_seq
) # action_bins: NT, G
@ -392,18 +428,21 @@ class VQBeTHead(nn.Module):
if action_seq.ndim == 2:
action_seq = action_seq.unsqueeze(0)
# the offset loss is the L1 distance between the predicted action and the ground-truth action
offset_loss = torch.nn.L1Loss()(action_seq, predicted_action)
# calculate primary code prediction loss
cbet_loss1 = self._criterion( # F.cross_entropy
cbet_logits[:, 0, :],
action_bins[:, 0],
)
# calculate secondary code prediction loss
cbet_loss2 = self._criterion( # F.cross_entropy
cbet_logits[:, 1, :],
action_bins[:, 1],
)
cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.config.secondary_code_loss_weight
# combine the primary and secondary code prediction losses with their weights
cbet_loss = cbet_loss1 * self.config.primary_code_loss_weight + cbet_loss2 * self.config.secondary_code_loss_weight
equal_primary_code_rate = torch.sum(
(action_bins[:, 0] == sampled_centers[:, 0]).int()
@ -416,21 +455,9 @@ class VQBeTHead(nn.Module):
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T),
einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T),
)
vq_action_error = (
abs(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T)
)
).mean()
offset_action_error = (
abs(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)
)
).mean()
action_error_max = (
abs(
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)
)
).max()
vq_action_error = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T))).mean()
offset_action_error = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T))).mean()
action_error_max = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T))).max()
loss = cbet_loss + self.config.offset_loss_weight * offset_loss
@ -506,35 +533,30 @@ class VQBeTOptimizer:
def step(self):
self.optimizing_step +=1
# pretraining VQ-VAE (phase 1)
if self.optimizing_step < self.discretize_step:
# pretraining VQ-VAE
self.vqvae_optimizer.step()
# training BeT (phase 2)
else:
# training BeT
if self.optimizing_step < 0.6 * self.offline_steps:
self.encoder_optimizer.step()
self.bet_optimizer1.step()
self.bet_optimizer2.step()
self.bet_optimizer3.step()
else:
self.bet_optimizer3.step()
self.encoder_optimizer.step()
self.bet_optimizer1.step()
self.bet_optimizer2.step()
self.bet_optimizer3.step()
def zero_grad(self):
# pretraining VQ-VAE (phase 1)
if self.optimizing_step < self.discretize_step:
# pretraining VQ-VAE
self.vqvae_optimizer.zero_grad()
# training BeT (phase 2)
else:
# training BeT
if self.optimizing_step < 0.6 * self.offline_steps:
self.encoder_optimizer.zero_grad()
self.bet_optimizer1.zero_grad()
self.bet_optimizer2.zero_grad()
self.bet_optimizer3.zero_grad()
else:
self.bet_optimizer3.zero_grad()
self.encoder_optimizer.zero_grad()
self.bet_optimizer1.zero_grad()
self.bet_optimizer2.zero_grad()
self.bet_optimizer3.zero_grad()
class VQBeTScheduler:
def __init__(self, optimizer, cfg):
# VQ-BeT uses a scheduler only for the rgb encoder. Since we took the rgb encoder from the diffusion policy, we also follow the same scheduler.
from diffusers.optimization import get_scheduler
self.discretize_step = cfg.training.discretize_step
self.optimizing_step = 0
@ -663,6 +685,11 @@ class VqVae(nn.Module):
def __init__(
self, config: VQBeTConfig,
):
"""
VQ-VAE is composed of three parts: an encoder, a vq_layer, and a decoder.
The encoder and decoder are MLPs, each consisting of an input layer, hidden layers, and an output layer.
The vq_layer uses residual VQs.
"""
super(VqVae, self).__init__()
self.config = config
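As a rough sketch of how these three parts interact in the forward pass (the return signature of the RVQ layer and the L1 reconstruction term are assumptions, not the exact code in this file):

# Simplified encode -> quantize -> decode flow (assumed RVQ return signature and L1 reconstruction).
def vqvae_forward_sketch(self, action_chunk):
    latent = self.encoder(action_chunk)                # MLP encoder
    quantized, codes, vq_loss = self.vq_layer(latent)  # residual VQ quantization
    recon = self.decoder(quantized)                    # MLP decoder
    recon_loss = (recon - action_chunk).abs().mean()   # reconstruction term
    return recon_loss + vq_loss, codes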
@ -676,24 +703,14 @@ class VqVae(nn.Module):
codebook_size=config.vqvae_n_embed,
)
if self.config.n_action_pred_chunk == 1:
self.encoder = MLP(
in_channels=self.config.output_shapes["action"][0],
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
)
self.decoder = MLP(
in_channels=config.vqvae_embedding_dim,
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0]],
)
else:
self.encoder = MLP(
in_channels=self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk,
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
)
self.decoder = MLP(
in_channels=config.vqvae_embedding_dim,
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk],
)
self.encoder = MLP(
in_channels=self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk,
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
)
self.decoder = MLP(
in_channels=config.vqvae_embedding_dim,
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk],
)
self.train()
@ -704,6 +721,11 @@ class VqVae(nn.Module):
self.decoder.eval()
def train(self, mode=True):
"""
This function prevents the RVQ from being updated once action discretization is complete.
Since VQs are partly updated via the EMA method, simply passing data through them can cause unintended modifications.
Therefore, we override `train()` to keep the RVQs from being updated during VQ-BeT training after discretization completes.
"""
if mode:
if self.discretized:
pass
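A self-contained sketch of this overriding pattern (an illustrative module, not the actual VqVae class): once `discretized` is set, calls to `train()` no longer switch the module back to training mode, so the EMA-updated VQ stays frozen.

import torch.nn as nn

class FreezableVQ(nn.Module):
    # Illustrative example of the train() override described above, not the actual VqVae class.
    def __init__(self):
        super().__init__()
        self.vq_layer = nn.Identity()  # stand-in for the EMA-updated residual VQ
        self.discretized = False

    def train(self, mode=True):
        # after discretization, ignore requests to switch back to training mode
        if mode and self.discretized:
            return self
        return super().train(mode)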
@ -736,7 +758,7 @@ class VqVae(nn.Module):
if not torch.is_tensor(state):
state = torch.FloatTensor(state.copy())
if self.config.n_action_pred_chunk == 1:
state = state.squeeze(-2) # state.squeeze(-1)
state = state.squeeze(-2)
else:
state = einops.rearrange(state, "N T A -> N (T A)")
return state
@ -764,7 +786,6 @@ class VqVae(nn.Module):
torch.swapaxes(recon_state_ae, -2, -1),
)
else:
# econ_from_code = self.draw_code_forward(vq_code)
return state_vq, vq_code
def vqvae_forward(self, state):
@ -815,7 +836,7 @@ def pretrain_vqvae(vqvae_model, discretize_step, actions):
loss, metric = vqvae_model.vqvae_forward(
actions
) # N T D
)
n_different_codes = len(torch.unique(metric[2]))
n_different_combinations = len(torch.unique(metric[2], dim=0))
vqvae_model.optimized_steps += 1
@ -837,6 +858,29 @@ def round_up_multiple(num, mult):
class ResidualVQ(nn.Module):
"""
MIT License
Copyright (c) 2020 Phil Wang
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
"""Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
def __init__(

View File

@ -1,6 +1,6 @@
# @package _global_
# Defaults for training for the PushT dataset as per https://github.com/real-stanford/diffusion_policy.
# Defaults for training for the PushT dataset.
seed: 100000
dataset_repo_id: lerobot/pusht
@ -100,5 +100,6 @@ policy:
dropout: 0.1
mlp_hidden_dim: 1024
offset_loss_weight: 10000.
primary_code_loss_weight: 5.0
secondary_code_loss_weight: 0.5
bet_softmax_temperature: 0.1