add comments, add primary_code_loss_weight
parent 71ec76fe2a
commit 547d3c3ea1
@@ -55,7 +55,8 @@ class VQBeTConfig:
         dropout: Dropout rate for GPT.
         mlp_hidden_dim: Size of the hidden dimensions of the offset header / bin prediction header parts of VQ-BeT.
         offset_loss_weight: A constant that is multiplied with the offset loss.
-        secondary_code_loss_weight: A constant that is multiplied to the secondary loss
+        primary_code_loss_weight: A constant that is multiplied with the primary code prediction loss.
+        secondary_code_loss_weight: A constant that is multiplied with the secondary code prediction loss.
         bet_softmax_temperature: Sampling temperature of the code for rollout with VQ-BeT.
     """
@@ -110,6 +111,7 @@ class VQBeTConfig:
     dropout: float = 0.1
     mlp_hidden_dim: int = 1024
     offset_loss_weight: float = 10000.
+    primary_code_loss_weight: float = 5.0
     secondary_code_loss_weight: float = 0.5
     bet_softmax_temperature: float = 0.1
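A minimal sketch of how these three weights combine (the scalar values below are dummies standing in for the real per-layer cross-entropy and L1 offset terms computed in VQBeTHead.loss_fn):

# Hedged illustration with the default weights; the CE and L1 values are made up.
primary_code_loss_weight = 5.0
secondary_code_loss_weight = 0.5
offset_loss_weight = 10000.0

primary_ce, secondary_ce, offset_l1 = 2.3, 2.5, 1e-4
cbet_loss = primary_code_loss_weight * primary_ce + secondary_code_loss_weight * secondary_ce
loss = cbet_loss + offset_loss_weight * offset_l1  # 11.5 + 1.25 + 1.0 = 13.75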
@@ -1,20 +1,16 @@
 import os
 from pathlib import Path
 from collections import deque
-from typing import Callable, List, Optional
+from typing import Callable, List
 from functools import partial
 from itertools import zip_longest
 from random import randrange
 import math
 from math import ceil
 from dataclasses import dataclass
 import warnings

 import einops
 from einops import rearrange, repeat, reduce, pack, unpack
 import numpy as np
 import torch
-import torch.nn.functional as F  # noqa: N812
+import torch.nn.functional as F
 import torchvision
 from torch import Tensor, nn, einsum
 import torch.distributed as distributed
@@ -123,7 +119,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         if not self.check_discretized():
             loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
             return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}

+        # If the Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction header / offset header parts.
         _, loss = self.vqbet(batch, rollout=False)

         return loss
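A hedged usage sketch of the rollout flag handled further down in this file (assuming `model` is a trained VQBeTModel and `batch` is the dict of observation tensors it expects):

import torch

# rollout=True returns only the predicted action chunk and skips the loss;
# rollout=False (training) also returns the loss dict, as in the snippet above.
with torch.no_grad():
    action_chunk = model(batch, rollout=True)
    # shape: (batch_size, n_action_pred_chunk, action_dim)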
@@ -159,6 +155,8 @@ class VQBeTModel(nn.Module):

     Phase 2.

+      timestep {t-n+1}       timestep {t-n+2}             timestep {t}
+    ┌─────┴─────┐          ┌─────┴─────┐                ┌─────┴─────┐
       o_{t-n+1}              o_{t-n+2}          ...        o_{t}
           │                      │                           │
@@ -191,10 +189,10 @@ class VQBeTModel(nn.Module):

         self.rgb_encoder = VQBeTRgbEncoder(config)

-        # This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
-        self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
+        # action token and EOS token
+        self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))  # Batch, Timestep, Data type, GPT input dim

         # To input state and observation features into the GPT layers, we first project the features to match the GPT input size.
         self.state_projector = MLP(
             config.output_shapes["action"][0],
             hidden_channels=[self.config.gpt_input_dim]
@@ -203,7 +201,10 @@ class VQBeTModel(nn.Module):
             self.rgb_encoder.feature_dim,
             hidden_channels=[self.config.gpt_input_dim]
         )
+
+        # GPT part of VQ-BeT
         self.policy = GPT(config)
+        # bin prediction header / offset prediction header part of VQ-BeT
         self.action_head = VQBeTHead(config)

     def discretize(self, discretize_step, actions):
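To illustrate the projection step above, a small sketch with assumed dimensions (a plain nn.Linear stands in for the MLP, and all sizes here are made up):

import torch
from torch import nn

B, T, state_dim, rgb_feat_dim, gpt_input_dim = 8, 5, 2, 512, 512
state_projector = nn.Linear(state_dim, gpt_input_dim)  # stand-in for the MLP
obs_projector = nn.Linear(rgb_feat_dim, gpt_input_dim)

state_tokens = state_projector(torch.randn(B, T, state_dim))  # (B, T, gpt_input_dim)
img_tokens = obs_projector(torch.randn(B, T, rgb_feat_dim))   # (B, T, gpt_input_dim)
# Both modalities now share the GPT input width and can be interleaved into one sequence.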
@@ -220,7 +221,7 @@ class VQBeTModel(nn.Module):
         # Separate batch and sequence dims.
         img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)

-
+        # The image observation feature, state feature, and action query token from the same timestep form a group, and the groups are listed in order to be fed into GPT sequentially.
         observation_feature = torch.cat([
             torch.unsqueeze(self.obs_projector(img_features), dim=2),
             torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2),
@@ -231,23 +232,27 @@ class VQBeTModel(nn.Module):
         len_additional_action_token = self.config.n_action_pred_token-1
         action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)

+        # add additional action query tokens for predicting future action chunks
         observation_feature = torch.cat([observation_feature, action_token], dim=1)

-        # get action features
+        # get action features (pass through GPT)
         features = self.policy(observation_feature)
         historical_act_pred_index = np.arange(0, n_obs_steps) * (self.config.gpt_num_obs_mode+1) + self.config.gpt_num_obs_mode

+        # only extract the output tokens at the positions of the action query tokens
         features = torch.cat([
             features[:, historical_act_pred_index],
             features[:, -len_additional_action_token:]
         ], dim=1)
-        # action head
+        # pass through the action head
         pred_action = self.action_head(
             features,
         )

+        # during rollout, VQ-BeT does not calculate the loss
         if rollout:
             return pred_action["predicted_action"][:, n_obs_steps-1, :].reshape(batch_size, self.config.n_action_pred_chunk, -1)
+        # otherwise, it calculates the overall loss (bin prediction loss and offset loss)
         else:
             action = batch["action"]
             n, total_w, act_dim = action.shape
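A worked example of the historical_act_pred_index arithmetic, with assumed values n_obs_steps=3 and gpt_num_obs_mode=2 (image + state): each timestep occupies gpt_num_obs_mode + 1 slots and the action query sits in the last slot of its group.

import numpy as np

n_obs_steps, gpt_num_obs_mode = 3, 2  # assumed toy values
idx = np.arange(0, n_obs_steps) * (gpt_num_obs_mode + 1) + gpt_num_obs_mode
print(idx)  # [2 5 8] -> sequence positions of the action query token per timestep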
@@ -270,7 +275,21 @@ class VQBeTModel(nn.Module):
 class VQBeTHead(nn.Module):
     def __init__(self, config: VQBeTConfig):
         """
-        TODO: add explanation for each value.
+        VQBeTHead takes the output of the GPT layers and passes the features through the bin prediction head (`self.map_to_cbet_preds_bin`) and the offset prediction head (`self.map_to_cbet_preds_offset`).
+
+        self.map_to_cbet_preds_bin: outputs the probability of each code (for each layer).
+            The input dimension of `self.map_to_cbet_preds_bin` is the same as the output of GPT,
+            and the output dimension of `self.map_to_cbet_preds_bin` is `self.config.vqvae_groups * self.config.vqvae_n_embed`, where
+            `self.config.vqvae_groups` is the number of RVQ layers, and
+            `self.config.vqvae_n_embed` is the codebook size of RVQ.
+
+        self.map_to_cbet_preds_offset: outputs the predicted offsets for all the codes in all the layers.
+            The input dimension of `self.map_to_cbet_preds_offset` is the same as the output of GPT,
+            and the output dimension of `self.map_to_cbet_preds_offset` is `self.config.vqvae_groups * self.config.vqvae_n_embed * config.n_action_pred_chunk * config.output_shapes["action"][0]`, where
+            `self.config.vqvae_groups` is the number of RVQ layers,
+            `self.config.vqvae_n_embed` is the codebook size of RVQ,
+            `config.n_action_pred_chunk` is the action chunk size of each token, and
+            `config.output_shapes["action"][0]` is the dimension of the action.
         """

         super().__init__()
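As a worked example of the two output sizes described in the docstring (all values below are assumed; the real ones come from the config):

vqvae_groups = 2          # number of RVQ layers
vqvae_n_embed = 16        # codebook size per layer
n_action_pred_chunk = 5   # action chunk size
action_dim = 2            # config.output_shapes["action"][0]

bin_head_out = vqvae_groups * vqvae_n_embed                                         # 32 logits
offset_head_out = vqvae_groups * vqvae_n_embed * n_action_pred_chunk * action_dim   # 320 offsets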
@@ -301,6 +320,7 @@ class VQBeTHead(nn.Module):
         self.vqvae_model.device = get_device_from_parameters(self)

         loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
+        # if we have updated the RVQ for more than `discretize_step` steps, stop discretizing
         if self.vqvae_model.discretized:
             print("Finished discretizing action data!")
             self.vqvae_model.eval()
@@ -309,7 +329,10 @@ class VQBeTHead(nn.Module):
         return loss, n_different_codes, n_different_combinations

     def forward(self, x, **kwargs):
+        # N is the batch size, and T is the number of action query tokens, which are processed through the same GPT
         N, T, _ = x.shape
+        # we process the N and T dimensions in parallel. Thus, the dimensions become
+        # (batch size * number of action query tokens, action chunk size, action dimension)
         x = einops.rearrange(x, "N T WA -> (N T) WA")

         cbet_logits = self.map_to_cbet_preds_bin(x)
@@ -334,23 +357,27 @@ class VQBeTHead(nn.Module):
             sampled_centers,
         )
-        # Use advanced indexing to sample the values
-        sampled_offsets = cbet_offsets[indices]  # NT, G, W, A(?) or NT, G, A
+        sampled_offsets = cbet_offsets[indices]
+        # Extract only the offsets corresponding to the sampled codes.
         sampled_offsets = sampled_offsets.sum(dim=1)
+        # Get the centroids of each layer to pass them through the RVQ decoder.
         centers = self.vqvae_model.draw_code_forward(sampled_centers).view(
             NT, -1, self.config.vqvae_embedding_dim
         )
         return_decoder_input = einops.rearrange(
             centers.clone().detach(), "NT 1 D -> NT D"
         )
+        # Pass the centroids through the decoder to get actions.
         decoded_action = (
             self.vqvae_model.get_action_from_latent(return_decoder_input)
             .clone()
             .detach()
-        )  # NT, A
+        )
+        # Reshape the extracted offsets to match the decoded centroids.
         sampled_offsets = einops.rearrange(
             sampled_offsets, "NT (W A) -> NT W A", W=self.config.n_action_pred_chunk
         )
+        # Add the offsets to the decoded centroids.
         predicted_action = decoded_action + sampled_offsets
         predicted_action = einops.rearrange(
             predicted_action,
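The indexing step above can be illustrated with toy sizes (all names and dimensions here are assumed: NT flattened tokens, G RVQ layers, C codes, chunk size W, action dim A; how `indices` is built in the real code may differ):

import torch

NT, G, C, W, A = 4, 2, 16, 5, 2
cbet_offsets = torch.randn(G, C, W * A)          # per-layer, per-code offset table
sampled_centers = torch.randint(0, C, (NT, G))   # sampled code index per layer
indices = (torch.arange(G).expand(NT, G), sampled_centers)
sampled_offsets = cbet_offsets[indices]          # (NT, G, W*A): offsets of the sampled codes
sampled_offsets = sampled_offsets.sum(dim=1)     # sum over RVQ layers -> (NT, W*A)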
@@ -368,7 +395,16 @@ class VQBeTHead(nn.Module):
         }

     def loss_fn(self, pred, target, **kwargs):
-        # Rename the inputs for clarity.
+        """
+        For given ground truth action values (target) and predictions (pred), this function calculates the overall loss.
+
+        predicted_action: predicted action chunk (offset + decoded centroids)
+        sampled_centers: sampled centroids (code of RVQ)
+        decoded_action: decoded action, produced by passing sampled_centers through the RVQ decoder
+        NT: batch size * T
+        T: number of action query tokens, which are processed through the same GPT
+        cbet_logits: probability of all codes in each layer
+        """
         action_seq = target
         predicted_action = pred["predicted_action"]
         sampled_centers = pred["sampled_centers"]
@@ -383,7 +419,7 @@ class VQBeTHead(nn.Module):

         action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
         # Figure out the loss for the actions.
-        # First, we need to find the closest cluster center for each action.
+        # First, we need to find the closest cluster center for each ground truth action.
         state_vq, action_bins = self.vqvae_model.get_code(
             action_seq
         )  # action_bins: NT, G
@@ -392,18 +428,21 @@ class VQBeTHead(nn.Module):
         if action_seq.ndim == 2:
             action_seq = action_seq.unsqueeze(0)

+        # the offset loss is the L1 distance between the predicted action and the ground truth action
         offset_loss = torch.nn.L1Loss()(action_seq, predicted_action)

+        # calculate the primary code prediction loss
         cbet_loss1 = self._criterion(  # F.cross_entropy
             cbet_logits[:, 0, :],
             action_bins[:, 0],
         )
+        # calculate the secondary code prediction loss
         cbet_loss2 = self._criterion(  # F.cross_entropy
             cbet_logits[:, 1, :],
             action_bins[:, 1],
         )
-        cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.config.secondary_code_loss_weight
+        # add all the prediction losses
+        cbet_loss = cbet_loss1 * self.config.primary_code_loss_weight + cbet_loss2 * self.config.secondary_code_loss_weight

         equal_primary_code_rate = torch.sum(
             (action_bins[:, 0] == sampled_centers[:, 0]).int()
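A hedged sketch of the two cross-entropy terms with toy sizes (`self._criterion` is F.cross_entropy per the inline comment; NT and the codebook size are assumed):

import torch
import torch.nn.functional as F

NT, C = 4, 16
cbet_logits = torch.randn(NT, 2, C)         # logits for the primary/secondary RVQ layer
action_bins = torch.randint(0, C, (NT, 2))  # ground-truth codes from the frozen RVQ
loss1 = F.cross_entropy(cbet_logits[:, 0, :], action_bins[:, 0])
loss2 = F.cross_entropy(cbet_logits[:, 1, :], action_bins[:, 1])
cbet_loss = 5.0 * loss1 + 0.5 * loss2       # primary/secondary weights from the config defaults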
@@ -416,21 +455,9 @@ class VQBeTHead(nn.Module):
             einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T),
             einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T),
         )
-        vq_action_error = (
-            abs(
-                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T)
-            )
-        ).mean()
-        offset_action_error = (
-            abs(
-                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)
-            )
-        ).mean()
-        action_error_max = (
-            abs(
-                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)
-            )
-        ).max()
+        vq_action_error = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T))).mean()
+        offset_action_error = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T))).mean()
+        action_error_max = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T))).max()

         loss = cbet_loss + self.config.offset_loss_weight * offset_loss
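A toy check of the "(N T) W A -> N T W A" reshape used in the metrics above (sizes assumed: N=3, T=2, W=2, A=2):

import torch
from einops import rearrange

x = torch.arange(24.0).reshape(6, 2, 2)         # (N*T, W, A)
y = rearrange(x, "(N T) W A -> N T W A", T=2)   # (3, 2, 2, 2)
assert torch.equal(y, x.reshape(3, 2, 2, 2))    # pure reshape, no data movement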
@@ -506,35 +533,30 @@ class VQBeTOptimizer:

     def step(self):
         self.optimizing_step += 1
+        # pretraining VQ-VAE (phase 1)
         if self.optimizing_step < self.discretize_step:
-            # pretraining VQ-VAE
             self.vqvae_optimizer.step()
+        # training BeT (phase 2)
         else:
-            # training BeT
             if self.optimizing_step < 0.6 * self.offline_steps:
                 self.encoder_optimizer.step()
                 self.bet_optimizer1.step()
                 self.bet_optimizer2.step()
                 self.bet_optimizer3.step()
             else:
                 self.bet_optimizer3.step()

     def zero_grad(self):
+        # pretraining VQ-VAE (phase 1)
         if self.optimizing_step < self.discretize_step:
-            # pretraining VQ-VAE
             self.vqvae_optimizer.zero_grad()
+        # training BeT (phase 2)
         else:
-            # training BeT
             if self.optimizing_step < 0.6 * self.offline_steps:
                 self.encoder_optimizer.zero_grad()
                 self.bet_optimizer1.zero_grad()
                 self.bet_optimizer2.zero_grad()
                 self.bet_optimizer3.zero_grad()
             else:
                 self.bet_optimizer3.zero_grad()

 class VQBeTScheduler:
     def __init__(self, optimizer, cfg):
+        # VQ-BeT uses a scheduler only for the rgb encoder. Since we took the rgb encoder from diffusion policy, we follow the same scheduler as well.
         from diffusers.optimization import get_scheduler
         self.discretize_step = cfg.training.discretize_step
         self.optimizing_step = 0
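A hedged sketch (hypothetical class and parameter names) of the two-phase schedule above; the real optimizer additionally drops the encoder and the first two BeT optimizers after 60% of the offline steps:

class TwoPhaseOptimizer:
    def __init__(self, vqvae_opt, bet_opts, discretize_step):
        self.vqvae_opt = vqvae_opt
        self.bet_opts = bet_opts
        self.discretize_step = discretize_step
        self.optimizing_step = 0

    def step(self):
        self.optimizing_step += 1
        if self.optimizing_step < self.discretize_step:
            self.vqvae_opt.step()        # phase 1: only the RVQ learns
        else:
            for opt in self.bet_opts:    # phase 2: GPT + heads train, RVQ stays frozen
                opt.step()

    def zero_grad(self):
        if self.optimizing_step < self.discretize_step:
            self.vqvae_opt.zero_grad()
        else:
            for opt in self.bet_opts:
                opt.zero_grad()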
@@ -663,6 +685,11 @@ class VqVae(nn.Module):
     def __init__(
         self, config: VQBeTConfig,
     ):
+        """
+        VQ-VAE is composed of three parts: an encoder, a vq_layer, and a decoder.
+        The encoder and decoder are MLPs, each consisting of an input layer, hidden layers, and an output layer.
+        The vq_layer uses residual VQs.
+        """

         super(VqVae, self).__init__()
         self.config = config
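A minimal sketch of the encoder -> quantizer -> decoder composition described in the docstring, where a plain nearest-neighbour codebook lookup stands in for the real residual VQ layer (all dimensions assumed):

import torch
from torch import nn

encoder = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 16))
decoder = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 10))
codebook = nn.Embedding(32, 16)                 # 32 codes, 16-dim latents

z = encoder(torch.randn(4, 10))                 # encode a flattened action chunk
dists = torch.cdist(z, codebook.weight)         # nearest-code quantization
code = dists.argmin(dim=-1)                     # discrete code index per sample
recon = decoder(codebook(code))                 # decode the quantized latent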
@@ -676,16 +703,6 @@ class VqVae(nn.Module):
             codebook_size=config.vqvae_n_embed,
         )

-        if self.config.n_action_pred_chunk == 1:
-            self.encoder = MLP(
-                in_channels=self.config.output_shapes["action"][0],
-                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
-            )
-            self.decoder = MLP(
-                in_channels=config.vqvae_embedding_dim,
-                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0]],
-            )
-        else:
+        self.encoder = MLP(
+            in_channels=self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk,
+            hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
@@ -704,6 +721,11 @@ class VqVae(nn.Module):
         self.decoder.eval()

     def train(self, mode=True):
+        """
+        This function forces the RVQ to no longer update once action discretization is complete.
+        Since VQs are partly updated via the EMA method, simply passing data through them can cause unintended modifications.
+        Therefore, we override this function to prevent the RVQs from being updated during the training of VQ-BeT after discretization completes.
+        """
         if mode:
             if self.discretized:
                 pass
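A minimal sketch of the override pattern the docstring describes (class name hypothetical): once `discretized` is set, `train(True)` becomes a no-op, so the EMA codebook updates stop.

from torch import nn

class FreezableVQ(nn.Module):
    def __init__(self):
        super().__init__()
        self.discretized = False

    def train(self, mode=True):
        if mode and self.discretized:
            return self           # stay in eval mode after discretization finishes
        return super().train(mode)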
@@ -736,7 +758,7 @@ class VqVae(nn.Module):
         if not torch.is_tensor(state):
             state = torch.FloatTensor(state.copy())
         if self.config.n_action_pred_chunk == 1:
-            state = state.squeeze(-2)  # state.squeeze(-1)
+            state = state.squeeze(-2)
         else:
             state = einops.rearrange(state, "N T A -> N (T A)")
         return state
@@ -764,7 +786,6 @@ class VqVae(nn.Module):
                 torch.swapaxes(recon_state_ae, -2, -1),
             )
         else:
-            # econ_from_code = self.draw_code_forward(vq_code)
             return state_vq, vq_code

     def vqvae_forward(self, state):
@@ -815,7 +836,7 @@ def pretrain_vqvae(vqvae_model, discretize_step, actions):

     loss, metric = vqvae_model.vqvae_forward(
         actions
-    )  # N T D
+    )
     n_different_codes = len(torch.unique(metric[2]))
     n_different_combinations = len(torch.unique(metric[2], dim=0))
     vqvae_model.optimized_steps += 1
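A toy example of the two codebook-usage metrics computed above (values assumed):

import torch

vq_code = torch.tensor([[0, 1], [0, 1], [2, 1]])              # (N, G): code per RVQ layer
n_different_codes = len(torch.unique(vq_code))                 # 3 distinct code indices
n_different_combinations = len(torch.unique(vq_code, dim=0))   # 2 distinct rows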
@@ -837,6 +858,29 @@ def round_up_multiple(num, mult):


 class ResidualVQ(nn.Module):
+    """
+    MIT License
+
+    Copyright (c) 2020 Phil Wang
+
+    Permission is hereby granted, free of charge, to any person obtaining a copy
+    of this software and associated documentation files (the "Software"), to deal
+    in the Software without restriction, including without limitation the rights
+    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+    copies of the Software, and to permit persons to whom the Software is
+    furnished to do so, subject to the following conditions:
+
+    The above copyright notice and this permission notice shall be included in all
+    copies or substantial portions of the Software.
+
+    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+    SOFTWARE.
+    """
     """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""

     def __init__(
@@ -1,6 +1,6 @@
 # @package _global_

-# Defaults for training for the PushT dataset as per https://github.com/real-stanford/diffusion_policy.
+# Defaults for training for the PushT dataset.

 seed: 100000
 dataset_repo_id: lerobot/pusht
@@ -100,5 +100,6 @@ policy:
   dropout: 0.1
   mlp_hidden_dim: 1024
   offset_loss_weight: 10000.
+  primary_code_loss_weight: 5.0
   secondary_code_loss_weight: 0.5
   bet_softmax_temperature: 0.1