add comments, add primary_code_loss_weight
This commit is contained in:
parent
71ec76fe2a
commit
547d3c3ea1
|
@ -55,7 +55,8 @@ class VQBeTConfig:
|
||||||
dropout: Dropout rate for GPT
|
dropout: Dropout rate for GPT
|
||||||
mlp_hidden_dim: Size of hidden dimensions of offset header / bin prediction headers parts of VQ-BeT
|
mlp_hidden_dim: Size of hidden dimensions of offset header / bin prediction headers parts of VQ-BeT
|
||||||
offset_loss_weight: A constant that is multiplied to the offset loss
|
offset_loss_weight: A constant that is multiplied to the offset loss
|
||||||
secondary_code_loss_weight: A constant that is multiplied to the secondary loss
|
primary_code_loss_weight: A constant that is multiplied to the primary code prediction loss
|
||||||
|
secondary_code_loss_weight: A constant that is multiplied to the secondary code prediction loss
|
||||||
bet_softmax_temperature: Sampling temperature of code for rollout with VQ-BeT
|
bet_softmax_temperature: Sampling temperature of code for rollout with VQ-BeT
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -110,6 +111,7 @@ class VQBeTConfig:
|
||||||
dropout: float = 0.1
|
dropout: float = 0.1
|
||||||
mlp_hidden_dim: int = 1024
|
mlp_hidden_dim: int = 1024
|
||||||
offset_loss_weight: float = 10000.
|
offset_loss_weight: float = 10000.
|
||||||
|
primary_code_loss_weight: float = 5.0
|
||||||
secondary_code_loss_weight: float = 0.5
|
secondary_code_loss_weight: float = 0.5
|
||||||
bet_softmax_temperature: float = 0.1
|
bet_softmax_temperature: float = 0.1
|
||||||
|
|
||||||
|
|
|
@ -1,20 +1,16 @@
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Callable, List, Optional
|
from typing import Callable, List
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from itertools import zip_longest
|
|
||||||
from random import randrange
|
from random import randrange
|
||||||
import math
|
import math
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from dataclasses import dataclass
|
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
from einops import rearrange, repeat, reduce, pack, unpack
|
from einops import rearrange, repeat, reduce, pack, unpack
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F
|
||||||
import torchvision
|
import torchvision
|
||||||
from torch import Tensor, nn, einsum
|
from torch import Tensor, nn, einsum
|
||||||
import torch.distributed as distributed
|
import torch.distributed as distributed
|
||||||
|
@ -123,7 +119,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
if not self.check_discretized():
|
if not self.check_discretized():
|
||||||
loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
|
loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
|
||||||
return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
|
return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
|
||||||
|
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin ped header / offset header parts.
|
||||||
_, loss = self.vqbet(batch, rollout=False)
|
_, loss = self.vqbet(batch, rollout=False)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
@ -159,6 +155,8 @@ class VQBeTModel(nn.Module):
|
||||||
|
|
||||||
Phase 2.
|
Phase 2.
|
||||||
|
|
||||||
|
timestep {t-n+1} timestep {t-n+2} timestep {t}
|
||||||
|
┌─────┴─────┐ ┌─────┴─────┐ ┌─────┴─────┐
|
||||||
|
|
||||||
o_{t-n+1} o_{t-n+2} ... o_{t}
|
o_{t-n+1} o_{t-n+2} ... o_{t}
|
||||||
│ │ │
|
│ │ │
|
||||||
|
@ -191,10 +189,10 @@ class VQBeTModel(nn.Module):
|
||||||
|
|
||||||
self.rgb_encoder = VQBeTRgbEncoder(config)
|
self.rgb_encoder = VQBeTRgbEncoder(config)
|
||||||
|
|
||||||
|
# This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
|
||||||
|
self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
|
||||||
|
|
||||||
# action token and EOS token
|
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
|
||||||
self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) # Batch, Timestep, Data type, GPT input dim
|
|
||||||
|
|
||||||
self.state_projector = MLP(
|
self.state_projector = MLP(
|
||||||
config.output_shapes["action"][0],
|
config.output_shapes["action"][0],
|
||||||
hidden_channels=[self.config.gpt_input_dim]
|
hidden_channels=[self.config.gpt_input_dim]
|
||||||
|
@ -203,7 +201,10 @@ class VQBeTModel(nn.Module):
|
||||||
self.rgb_encoder.feature_dim,
|
self.rgb_encoder.feature_dim,
|
||||||
hidden_channels=[self.config.gpt_input_dim]
|
hidden_channels=[self.config.gpt_input_dim]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# GPT part of VQ-BeT
|
||||||
self.policy = GPT(config)
|
self.policy = GPT(config)
|
||||||
|
# bin prediction header / offset prediction header part of VQ-BeT
|
||||||
self.action_head = VQBeTHead(config)
|
self.action_head = VQBeTHead(config)
|
||||||
|
|
||||||
def discretize(self, discretize_step, actions):
|
def discretize(self, discretize_step, actions):
|
||||||
|
@ -220,7 +221,7 @@ class VQBeTModel(nn.Module):
|
||||||
# Separate batch and sequence dims.
|
# Separate batch and sequence dims.
|
||||||
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
|
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
|
||||||
|
|
||||||
|
# image observation feature, state feature, and action query token are grouped together with the same timestpe to form a group, which is listed in order to be entered into GPT sequentially.
|
||||||
observation_feature = torch.cat([
|
observation_feature = torch.cat([
|
||||||
torch.unsqueeze(self.obs_projector(img_features), dim=2),
|
torch.unsqueeze(self.obs_projector(img_features), dim=2),
|
||||||
torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2),
|
torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2),
|
||||||
|
@ -230,24 +231,28 @@ class VQBeTModel(nn.Module):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
len_additional_action_token = self.config.n_action_pred_token-1
|
len_additional_action_token = self.config.n_action_pred_token-1
|
||||||
action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)
|
action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)
|
||||||
|
|
||||||
|
# add additional action query tokens for predicting future action chunks
|
||||||
observation_feature = torch.cat([observation_feature, action_token], dim=1)
|
observation_feature = torch.cat([observation_feature, action_token], dim=1)
|
||||||
|
|
||||||
|
|
||||||
# get action features
|
# get action features (pass through GPT)
|
||||||
features = self.policy(observation_feature)
|
features = self.policy(observation_feature)
|
||||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (self.config.gpt_num_obs_mode+1) + self.config.gpt_num_obs_mode
|
historical_act_pred_index = np.arange(0, n_obs_steps) * (self.config.gpt_num_obs_mode+1) + self.config.gpt_num_obs_mode
|
||||||
|
|
||||||
|
# only extract the output tokens at the position of action query
|
||||||
features = torch.cat([
|
features = torch.cat([
|
||||||
features[:, historical_act_pred_index],
|
features[:, historical_act_pred_index],
|
||||||
features[:, -len_additional_action_token:]
|
features[:, -len_additional_action_token:]
|
||||||
], dim=1)
|
], dim=1)
|
||||||
# action head
|
# pass through action head
|
||||||
pred_action = self.action_head(
|
pred_action = self.action_head(
|
||||||
features,
|
features,
|
||||||
)
|
)
|
||||||
|
# if rollout, VQ-BeT don't calculate loss
|
||||||
if rollout:
|
if rollout:
|
||||||
return pred_action["predicted_action"][:, n_obs_steps-1, :].reshape(batch_size, self.config.n_action_pred_chunk, -1)
|
return pred_action["predicted_action"][:, n_obs_steps-1, :].reshape(batch_size, self.config.n_action_pred_chunk, -1)
|
||||||
|
# else, it calculate overall loss (bin prediction loss, and offset loss)
|
||||||
else:
|
else:
|
||||||
action = batch["action"]
|
action = batch["action"]
|
||||||
n, total_w, act_dim = action.shape
|
n, total_w, act_dim = action.shape
|
||||||
|
@ -270,7 +275,21 @@ class VQBeTModel(nn.Module):
|
||||||
class VQBeTHead(nn.Module):
|
class VQBeTHead(nn.Module):
|
||||||
def __init__(self, config: VQBeTConfig):
|
def __init__(self, config: VQBeTConfig):
|
||||||
"""
|
"""
|
||||||
TODO: add explanation for each value.
|
VQBeTHead takes output of GPT layers, and pass the feature through bin prediction head (`self.map_to_cbet_preds_bin`), and offset prediction head (`self.map_to_cbet_preds_offset`)
|
||||||
|
|
||||||
|
self.map_to_cbet_preds_bin: outputs probability of each code (for each layer).
|
||||||
|
The input dimension of `self.map_to_cbet_preds_bin` is same with the output of GPT,
|
||||||
|
and the output dimension of `self.map_to_cbet_preds_bin` is `self.config.vqvae_groups * self.config.vqvae_n_embed`, where
|
||||||
|
`self.config.vqvae_groups` is number of RVQ layers, and
|
||||||
|
`self.config.vqvae_n_embed` is codebook size of RVQ.
|
||||||
|
|
||||||
|
self.map_to_cbet_preds_offset: output the predicted offsets for all the codes in all the layers.
|
||||||
|
The input dimension of ` self.map_to_cbet_preds_offset` is same with the output of GPT,
|
||||||
|
and the output dimension of ` self.map_to_cbet_preds_offset` is `self.config.vqvae_groups * self.config.vqvae_n_embed * config.n_action_pred_chunk * config.output_shapes["action"][0]`, where
|
||||||
|
`self.config.vqvae_groups` is number of RVQ layers,
|
||||||
|
`self.config.vqvae_n_embed` is codebook size of RVQ,
|
||||||
|
`config.n_action_pred_chunk is action chunk size of each token, and
|
||||||
|
`config.output_shapes["action"][0]` is the dimension of action
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -301,6 +320,7 @@ class VQBeTHead(nn.Module):
|
||||||
self.vqvae_model.device = get_device_from_parameters(self)
|
self.vqvae_model.device = get_device_from_parameters(self)
|
||||||
|
|
||||||
loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
|
loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
|
||||||
|
# if we updated RVQ more than `discretize_step` steps,
|
||||||
if self.vqvae_model.discretized:
|
if self.vqvae_model.discretized:
|
||||||
print("Finished discretizing action data!")
|
print("Finished discretizing action data!")
|
||||||
self.vqvae_model.eval()
|
self.vqvae_model.eval()
|
||||||
|
@ -309,7 +329,10 @@ class VQBeTHead(nn.Module):
|
||||||
return loss, n_different_codes, n_different_combinations
|
return loss, n_different_codes, n_different_combinations
|
||||||
|
|
||||||
def forward(self, x, **kwargs):
|
def forward(self, x, **kwargs):
|
||||||
|
# N is the batch size, and T is number of action query tokens, which are process through same GPT
|
||||||
N, T, _ = x.shape
|
N, T, _ = x.shape
|
||||||
|
# we calculate N and T side parallely. Thus, the dimensions would be
|
||||||
|
# (batch size * number of action query tokens, action chunk size, action dimension)
|
||||||
x = einops.rearrange(x, "N T WA -> (N T) WA")
|
x = einops.rearrange(x, "N T WA -> (N T) WA")
|
||||||
|
|
||||||
cbet_logits = self.map_to_cbet_preds_bin(x)
|
cbet_logits = self.map_to_cbet_preds_bin(x)
|
||||||
|
@ -334,23 +357,27 @@ class VQBeTHead(nn.Module):
|
||||||
sampled_centers,
|
sampled_centers,
|
||||||
)
|
)
|
||||||
# Use advanced indexing to sample the values
|
# Use advanced indexing to sample the values
|
||||||
sampled_offsets = cbet_offsets[indices] # NT, G, W, A(?) or NT, G, A
|
sampled_offsets = cbet_offsets[indices]
|
||||||
|
# Extract the only offsets corresponding to the sampled codes.
|
||||||
sampled_offsets = sampled_offsets.sum(dim=1)
|
sampled_offsets = sampled_offsets.sum(dim=1)
|
||||||
|
# Get the centroids of each layer to pass it through RVQ decoder
|
||||||
centers = self.vqvae_model.draw_code_forward(sampled_centers).view(
|
centers = self.vqvae_model.draw_code_forward(sampled_centers).view(
|
||||||
NT, -1, self.config.vqvae_embedding_dim
|
NT, -1, self.config.vqvae_embedding_dim
|
||||||
)
|
)
|
||||||
return_decoder_input = einops.rearrange(
|
return_decoder_input = einops.rearrange(
|
||||||
centers.clone().detach(), "NT 1 D -> NT D"
|
centers.clone().detach(), "NT 1 D -> NT D"
|
||||||
)
|
)
|
||||||
|
# pass the centroids through decoder to get actions.
|
||||||
decoded_action = (
|
decoded_action = (
|
||||||
self.vqvae_model.get_action_from_latent(return_decoder_input)
|
self.vqvae_model.get_action_from_latent(return_decoder_input)
|
||||||
.clone()
|
.clone()
|
||||||
.detach()
|
.detach()
|
||||||
) # NT, A
|
)
|
||||||
|
# reshaped extracted offset to match with decoded centroids
|
||||||
sampled_offsets = einops.rearrange(
|
sampled_offsets = einops.rearrange(
|
||||||
sampled_offsets, "NT (W A) -> NT W A", W=self.config.n_action_pred_chunk
|
sampled_offsets, "NT (W A) -> NT W A", W=self.config.n_action_pred_chunk
|
||||||
)
|
)
|
||||||
|
# add offset and decoded centroids
|
||||||
predicted_action = decoded_action + sampled_offsets
|
predicted_action = decoded_action + sampled_offsets
|
||||||
predicted_action = einops.rearrange(
|
predicted_action = einops.rearrange(
|
||||||
predicted_action,
|
predicted_action,
|
||||||
|
@ -368,7 +395,16 @@ class VQBeTHead(nn.Module):
|
||||||
}
|
}
|
||||||
|
|
||||||
def loss_fn(self, pred, target, **kwargs):
|
def loss_fn(self, pred, target, **kwargs):
|
||||||
# Rename the inputs for clarity.
|
"""
|
||||||
|
for given ground truth action values (target), and prediction (pred) this function calculates the overall loss.
|
||||||
|
|
||||||
|
predicted_action: predicted action chunk (offset + decoded centroids)
|
||||||
|
sampled_centers: sampled centroids (code of RVQ)
|
||||||
|
decoded_action: decoded action, which is produced by passing sampled_centers through RVQ decoder
|
||||||
|
NT: batch size * T
|
||||||
|
T: number of action query tokens, which are process through same GPT
|
||||||
|
cbet_logits: probability of all codes in each layer
|
||||||
|
"""
|
||||||
action_seq = target
|
action_seq = target
|
||||||
predicted_action = pred["predicted_action"]
|
predicted_action = pred["predicted_action"]
|
||||||
sampled_centers = pred["sampled_centers"]
|
sampled_centers = pred["sampled_centers"]
|
||||||
|
@ -383,7 +419,7 @@ class VQBeTHead(nn.Module):
|
||||||
|
|
||||||
action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
|
action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
|
||||||
# Figure out the loss for the actions.
|
# Figure out the loss for the actions.
|
||||||
# First, we need to find the closest cluster center for each action.
|
# First, we need to find the closest cluster center for each ground truth action.
|
||||||
state_vq, action_bins = self.vqvae_model.get_code(
|
state_vq, action_bins = self.vqvae_model.get_code(
|
||||||
action_seq
|
action_seq
|
||||||
) # action_bins: NT, G
|
) # action_bins: NT, G
|
||||||
|
@ -392,18 +428,21 @@ class VQBeTHead(nn.Module):
|
||||||
if action_seq.ndim == 2:
|
if action_seq.ndim == 2:
|
||||||
action_seq = action_seq.unsqueeze(0)
|
action_seq = action_seq.unsqueeze(0)
|
||||||
|
|
||||||
|
# offset loss is L1 distance between the predicted action and ground truth action
|
||||||
offset_loss = torch.nn.L1Loss()(action_seq, predicted_action)
|
offset_loss = torch.nn.L1Loss()(action_seq, predicted_action)
|
||||||
|
|
||||||
|
# calculate primary code prediction loss
|
||||||
cbet_loss1 = self._criterion( # F.cross_entropy
|
cbet_loss1 = self._criterion( # F.cross_entropy
|
||||||
cbet_logits[:, 0, :],
|
cbet_logits[:, 0, :],
|
||||||
action_bins[:, 0],
|
action_bins[:, 0],
|
||||||
)
|
)
|
||||||
|
# calculate secondary code prediction loss
|
||||||
cbet_loss2 = self._criterion( # F.cross_entropy
|
cbet_loss2 = self._criterion( # F.cross_entropy
|
||||||
cbet_logits[:, 1, :],
|
cbet_logits[:, 1, :],
|
||||||
action_bins[:, 1],
|
action_bins[:, 1],
|
||||||
)
|
)
|
||||||
cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.config.secondary_code_loss_weight
|
# add all the prediction loss
|
||||||
|
cbet_loss = cbet_loss1 * self.config.primary_code_loss_weight + cbet_loss2 * self.config.secondary_code_loss_weight
|
||||||
|
|
||||||
equal_primary_code_rate = torch.sum(
|
equal_primary_code_rate = torch.sum(
|
||||||
(action_bins[:, 0] == sampled_centers[:, 0]).int()
|
(action_bins[:, 0] == sampled_centers[:, 0]).int()
|
||||||
|
@ -416,21 +455,9 @@ class VQBeTHead(nn.Module):
|
||||||
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T),
|
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T),
|
||||||
einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T),
|
einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T),
|
||||||
)
|
)
|
||||||
vq_action_error = (
|
vq_action_error = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T))).mean()
|
||||||
abs(
|
offset_action_error = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T))).mean()
|
||||||
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T)
|
action_error_max = (abs(einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T))).max()
|
||||||
)
|
|
||||||
).mean()
|
|
||||||
offset_action_error = (
|
|
||||||
abs(
|
|
||||||
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)
|
|
||||||
)
|
|
||||||
).mean()
|
|
||||||
action_error_max = (
|
|
||||||
abs(
|
|
||||||
einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)
|
|
||||||
)
|
|
||||||
).max()
|
|
||||||
|
|
||||||
loss = cbet_loss + self.config.offset_loss_weight * offset_loss
|
loss = cbet_loss + self.config.offset_loss_weight * offset_loss
|
||||||
|
|
||||||
|
@ -506,35 +533,30 @@ class VQBeTOptimizer:
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
self.optimizing_step +=1
|
self.optimizing_step +=1
|
||||||
|
# pretraining VQ-VAE (phase 1)
|
||||||
if self.optimizing_step < self.discretize_step:
|
if self.optimizing_step < self.discretize_step:
|
||||||
# pretraining VQ-VAE
|
|
||||||
self.vqvae_optimizer.step()
|
self.vqvae_optimizer.step()
|
||||||
|
# training BeT (phase 2)
|
||||||
else:
|
else:
|
||||||
# training BeT
|
self.encoder_optimizer.step()
|
||||||
if self.optimizing_step < 0.6 * self.offline_steps:
|
self.bet_optimizer1.step()
|
||||||
self.encoder_optimizer.step()
|
self.bet_optimizer2.step()
|
||||||
self.bet_optimizer1.step()
|
self.bet_optimizer3.step()
|
||||||
self.bet_optimizer2.step()
|
|
||||||
self.bet_optimizer3.step()
|
|
||||||
else:
|
|
||||||
self.bet_optimizer3.step()
|
|
||||||
|
|
||||||
def zero_grad(self):
|
def zero_grad(self):
|
||||||
|
# pretraining VQ-VAE (phase 1)
|
||||||
if self.optimizing_step < self.discretize_step:
|
if self.optimizing_step < self.discretize_step:
|
||||||
# pretraining VQ-VAE
|
|
||||||
self.vqvae_optimizer.zero_grad()
|
self.vqvae_optimizer.zero_grad()
|
||||||
|
# training BeT (phase 2)
|
||||||
else:
|
else:
|
||||||
# training BeT
|
self.encoder_optimizer.zero_grad()
|
||||||
if self.optimizing_step < 0.6 * self.offline_steps:
|
self.bet_optimizer1.zero_grad()
|
||||||
self.encoder_optimizer.zero_grad()
|
self.bet_optimizer2.zero_grad()
|
||||||
self.bet_optimizer1.zero_grad()
|
self.bet_optimizer3.zero_grad()
|
||||||
self.bet_optimizer2.zero_grad()
|
|
||||||
self.bet_optimizer3.zero_grad()
|
|
||||||
else:
|
|
||||||
self.bet_optimizer3.zero_grad()
|
|
||||||
|
|
||||||
class VQBeTScheduler:
|
class VQBeTScheduler:
|
||||||
def __init__(self, optimizer, cfg):
|
def __init__(self, optimizer, cfg):
|
||||||
|
# VQ-BeT use scheduler only for rgb encoder. Since we took rgb encoder part from diffusion policy, we also follow the same scheduler from it.
|
||||||
from diffusers.optimization import get_scheduler
|
from diffusers.optimization import get_scheduler
|
||||||
self.discretize_step = cfg.training.discretize_step
|
self.discretize_step = cfg.training.discretize_step
|
||||||
self.optimizing_step = 0
|
self.optimizing_step = 0
|
||||||
|
@ -663,6 +685,11 @@ class VqVae(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, config: VQBeTConfig,
|
self, config: VQBeTConfig,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
VQ-VAE is composed of three parts: encoder, vq_layer, and decoder.
|
||||||
|
Encoder and decoder are MLPs consisting of an input, output layer, and hidden layer, respectively.
|
||||||
|
The vq_layer uses residual VQs.
|
||||||
|
"""
|
||||||
|
|
||||||
super(VqVae, self).__init__()
|
super(VqVae, self).__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -676,24 +703,14 @@ class VqVae(nn.Module):
|
||||||
codebook_size=config.vqvae_n_embed,
|
codebook_size=config.vqvae_n_embed,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.config.n_action_pred_chunk == 1:
|
self.encoder = MLP(
|
||||||
self.encoder = MLP(
|
in_channels=self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk,
|
||||||
in_channels=self.config.output_shapes["action"][0],
|
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
|
||||||
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
|
)
|
||||||
)
|
self.decoder = MLP(
|
||||||
self.decoder = MLP(
|
in_channels=config.vqvae_embedding_dim,
|
||||||
in_channels=config.vqvae_embedding_dim,
|
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk],
|
||||||
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0]],
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.encoder = MLP(
|
|
||||||
in_channels=self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk,
|
|
||||||
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
|
|
||||||
)
|
|
||||||
self.decoder = MLP(
|
|
||||||
in_channels=config.vqvae_embedding_dim,
|
|
||||||
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk],
|
|
||||||
)
|
|
||||||
|
|
||||||
self.train()
|
self.train()
|
||||||
|
|
||||||
|
@ -704,6 +721,11 @@ class VqVae(nn.Module):
|
||||||
self.decoder.eval()
|
self.decoder.eval()
|
||||||
|
|
||||||
def train(self, mode=True):
|
def train(self, mode=True):
|
||||||
|
"""
|
||||||
|
This function forces the RVQ to no longer update when action discretization is complete.
|
||||||
|
Since VQs are partly updated via the EMA method, simply passing data through them can cause unintended modifications.
|
||||||
|
Therefore, we use function overriding to prevent RVQs from being updated during the training of VQ-BeT after discretization completes.
|
||||||
|
"""
|
||||||
if mode:
|
if mode:
|
||||||
if self.discretized:
|
if self.discretized:
|
||||||
pass
|
pass
|
||||||
|
@ -736,7 +758,7 @@ class VqVae(nn.Module):
|
||||||
if not torch.is_tensor(state):
|
if not torch.is_tensor(state):
|
||||||
state = torch.FloatTensor(state.copy())
|
state = torch.FloatTensor(state.copy())
|
||||||
if self.config.n_action_pred_chunk == 1:
|
if self.config.n_action_pred_chunk == 1:
|
||||||
state = state.squeeze(-2) # state.squeeze(-1)
|
state = state.squeeze(-2)
|
||||||
else:
|
else:
|
||||||
state = einops.rearrange(state, "N T A -> N (T A)")
|
state = einops.rearrange(state, "N T A -> N (T A)")
|
||||||
return state
|
return state
|
||||||
|
@ -764,7 +786,6 @@ class VqVae(nn.Module):
|
||||||
torch.swapaxes(recon_state_ae, -2, -1),
|
torch.swapaxes(recon_state_ae, -2, -1),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# econ_from_code = self.draw_code_forward(vq_code)
|
|
||||||
return state_vq, vq_code
|
return state_vq, vq_code
|
||||||
|
|
||||||
def vqvae_forward(self, state):
|
def vqvae_forward(self, state):
|
||||||
|
@ -815,7 +836,7 @@ def pretrain_vqvae(vqvae_model, discretize_step, actions):
|
||||||
|
|
||||||
loss, metric = vqvae_model.vqvae_forward(
|
loss, metric = vqvae_model.vqvae_forward(
|
||||||
actions
|
actions
|
||||||
) # N T D
|
)
|
||||||
n_different_codes = len(torch.unique(metric[2]))
|
n_different_codes = len(torch.unique(metric[2]))
|
||||||
n_different_combinations = len(torch.unique(metric[2], dim=0))
|
n_different_combinations = len(torch.unique(metric[2], dim=0))
|
||||||
vqvae_model.optimized_steps += 1
|
vqvae_model.optimized_steps += 1
|
||||||
|
@ -837,6 +858,29 @@ def round_up_multiple(num, mult):
|
||||||
|
|
||||||
|
|
||||||
class ResidualVQ(nn.Module):
|
class ResidualVQ(nn.Module):
|
||||||
|
"""
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2020 Phil Wang
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
"""
|
||||||
"""Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
|
"""Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# @package _global_
|
# @package _global_
|
||||||
|
|
||||||
# Defaults for training for the PushT dataset as per https://github.com/real-stanford/diffusion_policy.
|
# Defaults for training for the PushT dataset.
|
||||||
|
|
||||||
seed: 100000
|
seed: 100000
|
||||||
dataset_repo_id: lerobot/pusht
|
dataset_repo_id: lerobot/pusht
|
||||||
|
@ -100,5 +100,6 @@ policy:
|
||||||
dropout: 0.1
|
dropout: 0.1
|
||||||
mlp_hidden_dim: 1024
|
mlp_hidden_dim: 1024
|
||||||
offset_loss_weight: 10000.
|
offset_loss_weight: 10000.
|
||||||
|
primary_code_loss_weight: 5.0
|
||||||
secondary_code_loss_weight: 0.5
|
secondary_code_loss_weight: 0.5
|
||||||
bet_softmax_temperature: 0.1
|
bet_softmax_temperature: 0.1
|
Loading…
Reference in New Issue