#!/usr/bin/env python

# Copyright 2024 Seungjae Lee and Yibin Wang and Haritheja Etukuru
# and H. Jin Kim and Nur Muhammad Mahi Shafiullah and Lerrel Pinto
# and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from functools import partial
from math import ceil
from random import randrange
from typing import Callable

import torch
import torch.distributed as distributed
import torch.nn.functional as F  # noqa: N812
from einops import pack, rearrange, reduce, repeat, unpack
from torch import einsum, nn
from torch.cuda.amp import autocast
from torch.optim import Optimizer

from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig

# ruff: noqa: N806

"""
|
|
This file is part of a VQ-BeT that utilizes code from the following repositories:
|
|
|
|
- Vector Quantize PyTorch code is licensed under the MIT License:
|
|
Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
|
|
|
- nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
|
|
Original source: https://github.com/karpathy/nanoGPT
|
|
|
|
We also made some changes to the original code to adapt it to our needs. The changes are described in the code below.
|
|
"""
|
|
|
|
"""
|
|
This is a part for nanoGPT that utilizes code from the following repository:
|
|
|
|
- Andrej Karpathy's nanoGPT implementation in PyTorch.
|
|
Original source: https://github.com/karpathy/nanoGPT
|
|
|
|
- The nanoGPT code is licensed under the MIT License:
|
|
|
|
MIT License
|
|
|
|
Copyright (c) 2022 Andrej Karpathy
|
|
|
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
of this software and associated documentation files (the "Software"), to deal
|
|
in the Software without restriction, including without limitation the rights
|
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
copies of the Software, and to permit persons to whom the Software is
|
|
furnished to do so, subject to the following conditions:
|
|
|
|
The above copyright notice and this permission notice shall be included in all
|
|
copies or substantial portions of the Software.
|
|
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
SOFTWARE.
|
|
|
|
- We've made some changes to the original code to adapt it to our needs.
|
|
|
|
Changed variable names:
|
|
- n_head -> gpt_n_head
|
|
- n_embd -> gpt_hidden_dim
|
|
- block_size -> gpt_block_size
|
|
- n_layer -> gpt_n_layer
|
|
|
|
|
|
class GPT(nn.Module):
|
|
- removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained`
|
|
- changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop.
|
|
- in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads).
|
|
|
|
"""
|
|
|
|
|
|
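# A minimal usage sketch (added comment, not part of the original nanoGPT code; it assumes the
# default values in VQBeTConfig are usable as-is). The GPT below consumes a sequence of continuous
# token embeddings of shape (batch, seq_len, gpt_input_dim) and returns features of shape
# (batch, seq_len, gpt_output_dim):
#
#     config = VQBeTConfig()
#     gpt = GPT(config)
#     tokens = torch.randn(8, config.gpt_block_size, config.gpt_input_dim)
#     features = gpt(tokens)  # (8, config.gpt_block_size, config.gpt_output_dim)
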
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.gpt_hidden_dim % config.gpt_n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.gpt_hidden_dim, 3 * config.gpt_hidden_dim)
        # output projection
        self.c_proj = nn.Linear(config.gpt_hidden_dim, config.gpt_hidden_dim)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.gpt_block_size, config.gpt_block_size)).view(
                1, 1, config.gpt_block_size, config.gpt_block_size
            ),
        )
        self.gpt_n_head = config.gpt_n_head
        self.gpt_hidden_dim = config.gpt_hidden_dim

    def forward(self, x):
        (
            B,
            T,
            C,
        ) = x.size()  # batch size, sequence length, embedding dimensionality (gpt_hidden_dim)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2)
        k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2)  # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

class Block(nn.Module):
    # causal self-attention block for GPT
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.gpt_hidden_dim)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.gpt_hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(config.gpt_hidden_dim, 4 * config.gpt_hidden_dim),
            nn.GELU(),
            nn.Linear(4 * config.gpt_hidden_dim, config.gpt_hidden_dim),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class GPT(nn.Module):
    """
    Original comments:
    Full definition of a GPT Language Model, all of it in this single file.
    References:
    1) the official GPT-2 TensorFlow implementation released by OpenAI:
    https://github.com/openai/gpt-2/blob/master/src/model.py
    2) huggingface/transformers PyTorch implementation:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
    """

    def __init__(self, config: VQBeTConfig):
        """
        The GPT model gets its hyperparameters from a config object. Please refer to configuration_vqbet.py for more details.
        """
        super().__init__()
        assert config.gpt_output_dim is not None
        assert config.gpt_block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(
            {
                "wte": nn.Linear(config.gpt_input_dim, config.gpt_hidden_dim),
                "wpe": nn.Embedding(config.gpt_block_size, config.gpt_hidden_dim),
                "drop": nn.Dropout(config.dropout),
                "h": nn.ModuleList([Block(config) for _ in range(config.gpt_n_layer)]),
                "ln_f": nn.LayerNorm(config.gpt_hidden_dim),
            }
        )
        self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False)
        # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer))

        # report number of parameters
        n_params = sum(p.numel() for p in self.parameters())
        print("number of parameters: {:.2f}M".format(n_params / 1e6))

    def forward(self, input, targets=None):
        device = input.device
        b, t, d = input.size()
        assert t <= self.config.gpt_block_size, (
            f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
        )

        # positional encodings that are added to the input embeddings
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(input)  # token embeddings of shape (b, t, gpt_hidden_dim)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, gpt_hidden_dim)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        return logits

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def crop_block_size(self, gpt_block_size):
        # model surgery to decrease the block size if necessary
        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
        # but want to use a smaller block size for some smaller, simpler model
        assert gpt_block_size <= self.config.gpt_block_size
        self.config.gpt_block_size = gpt_block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size])
        for block in self.transformer.h:
            block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]

    def configure_parameters(self):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear,)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, _p in m.named_parameters():
                fpn = "{}.{}".format(mn, pn) if mn else pn  # full param name
                if pn.endswith("bias"):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # validate that we considered every parameter
        param_dict = dict(self.named_parameters())
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
            str(inter_params)
        )
        assert len(param_dict.keys() - union_params) == 0, (
            "parameters {} were not separated into either decay/no_decay set!".format(
                str(param_dict.keys() - union_params),
            )
        )

        decay = [param_dict[pn] for pn in sorted(decay)]
        no_decay = [param_dict[pn] for pn in sorted(no_decay)]
        # return the parameters that require weight decay, and the parameters that don't, separately
        return decay, no_decay

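# A minimal sketch (added comment) of how the two parameter groups returned by
# `configure_parameters` might be fed to an external optimizer; the learning rate and weight
# decay values below are illustrative only, not the ones used by VQ-BeT:
#
#     decay_params, no_decay_params = gpt.configure_parameters()
#     optimizer = torch.optim.AdamW(
#         [
#             {"params": decay_params, "weight_decay": 1e-2},
#             {"params": no_decay_params, "weight_decay": 0.0},
#         ],
#         lr=1e-4,
#     )
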
"""
|
|
This file is a part for Residual Vector Quantization that utilizes code from the following repository:
|
|
|
|
- Phil Wang's vector-quantize-pytorch implementation in PyTorch.
|
|
Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
|
|
|
- The vector-quantize-pytorch code is licensed under the MIT License:
|
|
|
|
MIT License
|
|
|
|
Copyright (c) 2020 Phil Wang
|
|
|
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
of this software and associated documentation files (the "Software"), to deal
|
|
in the Software without restriction, including without limitation the rights
|
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
copies of the Software, and to permit persons to whom the Software is
|
|
furnished to do so, subject to the following conditions:
|
|
|
|
The above copyright notice and this permission notice shall be included in all
|
|
copies or substantial portions of the Software.
|
|
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
SOFTWARE.
|
|
|
|
- We've made some changes to the original code to adapt it to our needs.
|
|
|
|
class ResidualVQ(nn.Module):
|
|
- added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method:
|
|
This enables the user to save an indicator whether the codebook is frozen or not.
|
|
- changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
|
This is to make the function name more descriptive.
|
|
|
|
class VectorQuantize(nn.Module):
|
|
- removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method:
|
|
These parameters are not used in the code.
|
|
- changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
|
This is to make the function name more descriptive.
|
|
|
|
"""
|
|
|
|
|
|
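# The residual quantization scheme implemented by ResidualVQ below can be summarized in a few
# lines (a sketch of Algorithm 1 from https://arxiv.org/pdf/2107.03312.pdf, mirroring the
# forward pass further down; not runnable on its own):
#
#     residual = project_in(x)
#     quantized_out = 0
#     for layer in layers:                      # Nq VectorQuantize layers
#         quantized, indices, loss = layer(residual)
#         residual = residual - quantized.detach()
#         quantized_out = quantized_out + quantized
#     return project_out(quantized_out)
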
class ResidualVQ(nn.Module):
    """
    Residual VQ is composed of multiple VectorQuantize layers.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    "Residual Vector Quantizer (a.k.a. multi-stage vector quantizer [36]) cascades Nq layers of VQ as follows. The unquantized input vector is
    passed through a first VQ and quantization residuals are computed. The residuals are then iteratively quantized by a sequence of additional
    Nq - 1 vector quantizers, as described in Algorithm 1."

    self.project_in: function for projecting input to codebook dimension
    self.project_out: function for projecting codebook dimension to output dimension
    self.layers: nn.ModuleList of VectorQuantize layers that contains Nq layers of VQ as described in the paper.
    self.freeze_codebook: buffer to save an indicator of whether the codebook is frozen or not. VQ-BeT will check this to determine whether to update the codebook or not.
    """

    def __init__(
        self,
        *,
        dim,
        num_quantizers,
        codebook_dim=None,
        shared_codebook=False,
        heads=1,
        quantize_dropout=False,
        quantize_dropout_cutoff_index=0,
        quantize_dropout_multiple_of=1,
        accept_image_fmap=False,
        **kwargs,
    ):
        super().__init__()
        assert heads == 1, "residual vq is not compatible with multi-headed codes"
        codebook_dim = codebook_dim if (codebook_dim is not None) else dim
        codebook_input_dim = codebook_dim * heads

        requires_projection = codebook_input_dim != dim
        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()

        self.num_quantizers = num_quantizers

        self.accept_image_fmap = accept_image_fmap
        self.layers = nn.ModuleList(
            [
                VectorQuantize(
                    dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs
                )
                for _ in range(num_quantizers)
            ]
        )

        self.quantize_dropout = quantize_dropout and num_quantizers > 1

        assert quantize_dropout_cutoff_index >= 0

        self.register_buffer("freeze_codebook", torch.tensor(False))
        self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index
        self.quantize_dropout_multiple_of = quantize_dropout_multiple_of  # encodec paper proposes structured dropout, believe this was set to 4

        if not shared_codebook:
            return

        first_vq, *rest_vq = self.layers
        codebook = first_vq._codebook

        for vq in rest_vq:
            vq._codebook = codebook

    @property
    def codebooks(self):
        codebooks = [layer._codebook.embed for layer in self.layers]
        codebooks = torch.stack(codebooks, dim=0)
        codebooks = rearrange(codebooks, "q 1 c d -> q c d")
        return codebooks

    def get_codebook_vector_from_indices(self, indices):
        # this function will return the codes from all codebooks across layers corresponding to the indices
        batch, quantize_dim = indices.shape[0], indices.shape[-1]

        # may also receive indices in the shape of 'b h w q' (accept_image_fmap)

        indices, ps = pack([indices], "b * q")

        # because of quantize dropout, one can pass in indices that are coarse
        # and the network should be able to reconstruct

        if quantize_dim < self.num_quantizers:
            assert self.quantize_dropout > 0.0, (
                "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations"
            )
            indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1)

        # get ready for gathering

        codebooks = repeat(self.codebooks, "q c d -> q b c d", b=batch)
        gather_indices = repeat(indices, "b n q -> q b n d", d=codebooks.shape[-1])

        # take care of quantizer dropout

        mask = gather_indices == -1.0
        gather_indices = gather_indices.masked_fill(
            mask, 0
        )  # have it fetch a dummy code to be masked out later

        all_codes = codebooks.gather(2, gather_indices)  # gather all codes

        # mask out any codes that were dropout-ed

        all_codes = all_codes.masked_fill(mask, 0.0)

        # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension)

        (all_codes,) = unpack(all_codes, ps, "q b * d")

        return all_codes

    def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None):
        """
        For a given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss.
        First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize.
        The residual value of each layer is fed to the next layer.
        """
        num_quant, quant_dropout_multiple_of, return_loss, device = (
            self.num_quantizers,
            self.quantize_dropout_multiple_of,
            (indices is not None),
            x.device,
        )

        x = self.project_in(x)

        assert not (self.accept_image_fmap and (indices is not None))

        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []

        if return_loss:
            assert not torch.any(indices == -1), (
                "some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss"
            )
            ce_losses = []

        should_quantize_dropout = self.training and self.quantize_dropout and not return_loss

        # sample a layer index at which to dropout further residual quantization
        # also prepare null indices and loss

        if should_quantize_dropout:
            rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant)

            if quant_dropout_multiple_of != 1:
                rand_quantize_dropout_index = (
                    ceil((rand_quantize_dropout_index + 1) / quant_dropout_multiple_of)
                    * quant_dropout_multiple_of
                    - 1
                )

            null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2])
            null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long)
            null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype)

        # go through the layers

        for quantizer_index, layer in enumerate(self.layers):
            if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                continue

            layer_indices = None
            if return_loss:
                layer_indices = indices[..., quantizer_index]

            quantized, *rest = layer(
                residual,
                indices=layer_indices,
                sample_codebook_temp=sample_codebook_temp,
                freeze_codebook=self.freeze_codebook,
            )

            residual = residual - quantized.detach()
            quantized_out = quantized_out + quantized

            if return_loss:
                ce_loss = rest[0]
                ce_losses.append(ce_loss)
                continue

            embed_indices, loss = rest

            all_indices.append(embed_indices)
            all_losses.append(loss)

        # project out, if needed

        quantized_out = self.project_out(quantized_out)

        # whether to early return the cross entropy loss

        if return_loss:
            return quantized_out, sum(ce_losses)

        # stack all losses and indices

        all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices))

        ret = (quantized_out, all_indices, all_losses)

        if return_all_codes:
            # whether to return all codes from all codebooks across layers
            all_codes = self.get_codebook_vector_from_indices(all_indices)

            # will return all codes in shape (quantizer, batch, sequence length, codebook dimension)
            ret = (*ret, all_codes)

        return ret

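# A minimal usage sketch (added comment; the hyperparameters are illustrative only). For an
# input of shape (batch, seq_len, dim), the forward pass returns the quantized output of the
# same shape, the per-layer codebook indices of shape (batch, seq_len, num_quantizers), and the
# per-layer losses stacked along the last dimension:
#
#     rvq = ResidualVQ(dim=256, num_quantizers=2, codebook_size=512)
#     x = torch.randn(8, 10, 256)
#     quantized, indices, losses = rvq(x)
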
class VectorQuantize(nn.Module):
    def __init__(
        self,
        dim,
        codebook_size,
        codebook_dim=None,
        heads=1,
        separate_codebook_per_head=False,
        decay=0.8,
        eps=1e-5,
        kmeans_init=False,
        kmeans_iters=10,
        sync_kmeans=True,
        threshold_ema_dead_code=0,
        channel_last=True,
        accept_image_fmap=False,
        commitment_weight=1.0,
        commitment_use_cross_entropy_loss=False,
        orthogonal_reg_weight=0.0,
        orthogonal_reg_active_codes_only=False,
        orthogonal_reg_max_codes=None,
        stochastic_sample_codes=False,
        sample_codebook_temp=1.0,
        straight_through=False,
        reinmax=False,  # using reinmax for improved straight-through, assuming straight through helps at all
        sync_codebook=None,
        sync_affine_param=False,
        ema_update=True,
        learnable_codebook=False,
        in_place_codebook_optimizer: Callable[
            ..., Optimizer
        ] = None,  # Optimizer used to update the codebook embedding if using learnable_codebook
        affine_param=False,
        affine_param_batch_decay=0.99,
        affine_param_codebook_decay=0.9,
        sync_update_v=0.0,  # the v that controls optimistic vs pessimistic update for synchronous update rule (21) https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
    ):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.separate_codebook_per_head = separate_codebook_per_head

        codebook_dim = codebook_dim if (codebook_dim is not None) else dim
        codebook_input_dim = codebook_dim * heads

        requires_projection = codebook_input_dim != dim
        self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity()

        self.eps = eps
        self.commitment_weight = commitment_weight
        self.commitment_use_cross_entropy_loss = commitment_use_cross_entropy_loss  # whether to use cross entropy loss to codebook as commitment loss

        self.learnable_codebook = learnable_codebook

        has_codebook_orthogonal_loss = orthogonal_reg_weight > 0
        self.has_codebook_orthogonal_loss = has_codebook_orthogonal_loss
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update"

        assert 0 <= sync_update_v <= 1.0
        assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on"

        self.sync_update_v = sync_update_v

        gumbel_sample_fn = partial(
            gumbel_sample,
            stochastic=stochastic_sample_codes,
            reinmax=reinmax,
            straight_through=straight_through,
        )

        if sync_codebook is None:
            sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1

        codebook_kwargs = {
            "dim": codebook_dim,
            "num_codebooks": heads if separate_codebook_per_head else 1,
            "codebook_size": codebook_size,
            "kmeans_init": kmeans_init,
            "kmeans_iters": kmeans_iters,
            "sync_kmeans": sync_kmeans,
            "decay": decay,
            "eps": eps,
            "threshold_ema_dead_code": threshold_ema_dead_code,
            "use_ddp": sync_codebook,
            "learnable_codebook": has_codebook_orthogonal_loss or learnable_codebook,
            "sample_codebook_temp": sample_codebook_temp,
            "gumbel_sample": gumbel_sample_fn,
            "ema_update": ema_update,
        }

        if affine_param:
            codebook_kwargs = dict(
                **codebook_kwargs,
                affine_param=True,
                sync_affine_param=sync_affine_param,
                affine_param_batch_decay=affine_param_batch_decay,
                affine_param_codebook_decay=affine_param_codebook_decay,
            )

        self._codebook = EuclideanCodebook(**codebook_kwargs)

        self.in_place_codebook_optimizer = (
            in_place_codebook_optimizer(self._codebook.parameters())
            if (in_place_codebook_optimizer is not None)
            else None
        )

        self.codebook_size = codebook_size

        self.accept_image_fmap = accept_image_fmap
        self.channel_last = channel_last

    @property
    def codebook(self):
        codebook = self._codebook.embed

        if self.separate_codebook_per_head:
            return codebook

        return rearrange(codebook, "1 ... -> ...")

    @codebook.setter
    def codebook(self, codes):
        if not self.separate_codebook_per_head:
            codes = rearrange(codes, "... -> 1 ...")

        self._codebook.embed.copy_(codes)

    def get_codebook_vector_from_indices(self, indices):
        codebook = self.codebook
        is_multiheaded = codebook.ndim > 2

        if not is_multiheaded:
            codes = codebook[indices]
            return rearrange(codes, "... h d -> ... (h d)")

        indices, ps = pack_one(indices, "b * h")
        indices = rearrange(indices, "b n h -> b h n")

        indices = repeat(indices, "b h n -> b h n d", d=codebook.shape[-1])
        codebook = repeat(codebook, "h n d -> b h n d", b=indices.shape[0])

        codes = codebook.gather(2, indices)
        codes = rearrange(codes, "b h n d -> b n (h d)")
        codes = unpack_one(codes, ps, "b * d")
        return codes

    def forward(
        self,
        x,
        indices=None,
        mask=None,
        sample_codebook_temp=None,
        freeze_codebook=False,
    ):
        orig_input = x

        only_one = x.ndim == 2

        if only_one:
            assert mask is None
            x = rearrange(x, "b d -> b 1 d")

        shape, device, heads, is_multiheaded, _codebook_size, return_loss = (
            x.shape,
            x.device,
            self.heads,
            self.heads > 1,
            self.codebook_size,
            (indices is not None),
        )

        need_transpose = not self.channel_last and not self.accept_image_fmap
        should_inplace_optimize = self.in_place_codebook_optimizer is not None

        # rearrange inputs

        if self.accept_image_fmap:
            height, width = x.shape[-2:]
            x = rearrange(x, "b c h w -> b (h w) c")

        if need_transpose:
            x = rearrange(x, "b d n -> b n d")

        # project input

        x = self.project_in(x)

        # handle multi-headed separate codebooks

        if is_multiheaded:
            ein_rhs_eq = "h b n d" if self.separate_codebook_per_head else "1 (b h) n d"
            x = rearrange(x, f"b n (h d) -> {ein_rhs_eq}", h=heads)

        # l2norm for cosine sim, otherwise identity

        x = self._codebook.transform_input(x)

        # codebook forward kwargs

        codebook_forward_kwargs = {
            "sample_codebook_temp": sample_codebook_temp,
            "mask": mask,
            "freeze_codebook": freeze_codebook,
        }

        # quantize

        quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)

        # one step in-place update

        if should_inplace_optimize and self.training and not freeze_codebook:
            if mask is not None:
                loss = F.mse_loss(quantize, x.detach(), reduction="none")

                loss_mask = mask
                if is_multiheaded:
                    loss_mask = repeat(
                        mask,
                        "b n -> c (b h) n",
                        c=loss.shape[0],
                        h=loss.shape[1] // mask.shape[0],
                    )

                loss = loss[loss_mask].mean()

            else:
                loss = F.mse_loss(quantize, x.detach())

            loss.backward()
            self.in_place_codebook_optimizer.step()
            self.in_place_codebook_optimizer.zero_grad()

            # quantize again

            quantize, embed_ind, distances = self._codebook(x, **codebook_forward_kwargs)

        if self.training:
            # determine code to use for commitment loss
            maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity

            commit_quantize = maybe_detach(quantize)

            # straight through

            quantize = x + (quantize - x).detach()

            if self.sync_update_v > 0.0:
                # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf
                quantize = quantize + self.sync_update_v * (quantize - quantize.detach())

        # function for calculating cross entropy loss to distance matrix
        # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss

        def calculate_ce_loss(codes):
            if not is_multiheaded:
                dist_einops_eq = "1 b n l -> b l n"
            elif self.separate_codebook_per_head:
                dist_einops_eq = "c b n l -> b l n c"
            else:
                dist_einops_eq = "1 (b h) n l -> b l n h"

            ce_loss = F.cross_entropy(
                rearrange(distances, dist_einops_eq, b=shape[0]), codes, ignore_index=-1
            )

            return ce_loss

        # if returning cross entropy loss on codes that were passed in

        if return_loss:
            return quantize, calculate_ce_loss(indices)

        # transform embedding indices

        if is_multiheaded:
            if self.separate_codebook_per_head:
                embed_ind = rearrange(embed_ind, "h b n -> b n h", h=heads)
            else:
                embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads)

        if self.accept_image_fmap:
            embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width)

        if only_one:
            embed_ind = rearrange(embed_ind, "b 1 -> b")

        # aggregate loss

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                if self.commitment_use_cross_entropy_loss:
                    if mask is not None:
                        ce_loss_mask = mask
                        if is_multiheaded:
                            ce_loss_mask = repeat(ce_loss_mask, "b n -> b n h", h=heads)

                        embed_ind.masked_fill_(~ce_loss_mask, -1)

                    commit_loss = calculate_ce_loss(embed_ind)
                else:
                    if mask is not None:
                        # with variable lengthed sequences
                        commit_loss = F.mse_loss(commit_quantize, x, reduction="none")

                        loss_mask = mask
                        if is_multiheaded:
                            loss_mask = repeat(
                                loss_mask,
                                "b n -> c (b h) n",
                                c=commit_loss.shape[0],
                                h=commit_loss.shape[1] // mask.shape[0],
                            )

                        commit_loss = commit_loss[loss_mask].mean()
                    else:
                        commit_loss = F.mse_loss(commit_quantize, x)

                loss = loss + commit_loss * self.commitment_weight

            if self.has_codebook_orthogonal_loss:
                codebook = self._codebook.embed

                # only calculate orthogonal loss for the activated codes for this batch

                if self.orthogonal_reg_active_codes_only:
                    assert not (is_multiheaded and self.separate_codebook_per_head), (
                        "orthogonal regularization for only active codes not compatible with multi-headed with separate codebooks yet"
                    )
                    unique_code_ids = torch.unique(embed_ind)
                    codebook = codebook[:, unique_code_ids]

                num_codes = codebook.shape[-2]

                if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes:
                    rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes]
                    codebook = codebook[:, rand_ids]

                orthogonal_reg_loss = orthogonal_loss_fn(codebook)
                loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight

        # handle multi-headed quantized embeddings

        if is_multiheaded:
            if self.separate_codebook_per_head:
                quantize = rearrange(quantize, "h b n d -> b n (h d)", h=heads)
            else:
                quantize = rearrange(quantize, "1 (b h) n d -> b n (h d)", h=heads)

        # project out

        quantize = self.project_out(quantize)

        # rearrange quantized embeddings

        if need_transpose:
            quantize = rearrange(quantize, "b n d -> b d n")

        if self.accept_image_fmap:
            quantize = rearrange(quantize, "b (h w) c -> b c h w", h=height, w=width)

        if only_one:
            quantize = rearrange(quantize, "b 1 d -> b d")

        # if masking, only return quantized for where mask has True

        if mask is not None:
            quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input)

        return quantize, embed_ind, loss

def noop(*args, **kwargs):
    pass


def identity(t):
    return t


def cdist(x, y):
    x2 = reduce(x**2, "b n d -> b n", "sum")
    y2 = reduce(y**2, "b n d -> b n", "sum")
    xy = einsum("b i d, b j d -> b i j", x, y) * -2
    return (rearrange(x2, "b i -> b i 1") + rearrange(y2, "b j -> b 1 j") + xy).sqrt()

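# Note (added comment): `cdist` above computes pairwise Euclidean distances via the expansion
# ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 * <x_i, y_j>, followed by a square root. The
# `kmeans` helper further below uses `torch.cdist` directly instead.
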
def log(t, eps=1e-20):
    return torch.log(t.clamp(min=eps))


def ema_inplace(old, new, decay):
    is_mps = str(old.device).startswith("mps:")

    if not is_mps:
        old.lerp_(new, 1 - decay)
    else:
        old.mul_(decay).add_(new * (1 - decay))


def pack_one(t, pattern):
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]


def uniform_init(*shape):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(
    logits,
    temperature=1.0,
    stochastic=False,
    straight_through=False,
    reinmax=False,
    dim=-1,
    training=True,
):
    dtype, size = logits.dtype, logits.shape[dim]

    if training and stochastic and temperature > 0:
        sampling_logits = (logits / temperature) + gumbel_noise(logits)
    else:
        sampling_logits = logits

    ind = sampling_logits.argmax(dim=dim)
    one_hot = F.one_hot(ind, size).type(dtype)

    assert not (reinmax and not straight_through), (
        "reinmax can only be turned on if using straight through gumbel softmax"
    )

    if not straight_through or temperature <= 0.0 or not training:
        return ind, one_hot

    # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612
    # algorithm 2

    if reinmax:
        π0 = logits.softmax(dim=dim)
        π1 = (one_hot + (logits / temperature).softmax(dim=dim)) / 2
        π1 = ((log(π1) - logits).detach() + logits).softmax(dim=1)
        π2 = 2 * π1 - 0.5 * π0
        one_hot = π2 - π2.detach() + one_hot
    else:
        π1 = (logits / temperature).softmax(dim=dim)
        one_hot = one_hot + π1 - π1.detach()

    return ind, one_hot

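# Note (added comment): in the straight-through branches above, expressions of the form
# `one_hot + p - p.detach()` keep the hard one-hot sample in the forward pass while letting
# gradients flow through the soft distribution `p` in the backward pass (the straight-through
# estimator). The reinmax branch refines this with a second-order correction (Algorithm 2 of
# https://arxiv.org/abs/2304.08612).
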
def laplace_smoothing(x, n_categories, eps=1e-5, dim=-1):
    denom = x.sum(dim=dim, keepdim=True)
    return (x + eps) / (denom + n_categories * eps)


def sample_vectors(samples, num):
    num_samples, device = samples.shape[0], samples.device
    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def batched_sample_vectors(samples, num):
    return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0)


def pad_shape(shape, size, dim=0):
    return [size if i == dim else s for i, s in enumerate(shape)]


def sample_multinomial(total_count, probs):
    device = probs.device
    probs = probs.cpu()

    total_count = probs.new_full((), total_count)
    remainder = probs.new_ones(())
    sample = torch.empty_like(probs, dtype=torch.long)

    for i, p in enumerate(probs):
        s = torch.binomial(total_count, p / remainder)
        sample[i] = s
        total_count -= s
        remainder -= p

    return sample.to(device)


def all_gather_sizes(x, dim):
    size = torch.tensor(x.shape[dim], dtype=torch.long, device=x.device)
    all_sizes = [torch.empty_like(size) for _ in range(distributed.get_world_size())]
    distributed.all_gather(all_sizes, size)
    return torch.stack(all_sizes)


def all_gather_variably_sized(x, sizes, dim=0):
    rank = distributed.get_rank()
    all_x = []

    for i, size in enumerate(sizes):
        t = x if i == rank else x.new_empty(pad_shape(x.shape, size, dim))
        distributed.broadcast(t, src=i, async_op=True)
        all_x.append(t)

    distributed.barrier()
    return all_x


def sample_vectors_distributed(local_samples, num):
    local_samples = rearrange(local_samples, "1 ... -> ...")

    rank = distributed.get_rank()
    all_num_samples = all_gather_sizes(local_samples, dim=0)

    if rank == 0:
        samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum())
    else:
        samples_per_rank = torch.empty_like(all_num_samples)

    distributed.broadcast(samples_per_rank, src=0)
    samples_per_rank = samples_per_rank.tolist()

    local_samples = sample_vectors(local_samples, samples_per_rank[rank])
    all_samples = all_gather_variably_sized(local_samples, samples_per_rank, dim=0)
    out = torch.cat(all_samples, dim=0)

    return rearrange(out, "... -> 1 ...")

def batched_bincount(x, *, minlength):
    batch, dtype, device = x.shape[0], x.dtype, x.device
    target = torch.zeros(batch, minlength, dtype=dtype, device=device)
    values = torch.ones_like(x)
    target.scatter_add_(-1, x, values)
    return target


def kmeans(
    samples,
    num_clusters,
    num_iters=10,
    sample_fn=batched_sample_vectors,
    all_reduce_fn=noop,
):
    num_codebooks, dim, dtype, _device = (
        samples.shape[0],
        samples.shape[-1],
        samples.dtype,
        samples.device,
    )

    means = sample_fn(samples, num_clusters)

    for _ in range(num_iters):
        dists = -torch.cdist(samples, means, p=2)

        buckets = torch.argmax(dists, dim=-1)
        bins = batched_bincount(buckets, minlength=num_clusters)
        all_reduce_fn(bins)

        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_codebooks, num_clusters, dim, dtype=dtype)

        new_means.scatter_add_(1, repeat(buckets, "h n -> h n d", d=dim), samples)
        new_means = new_means / rearrange(bins_min_clamped, "... -> ... 1")
        all_reduce_fn(new_means)

        means = torch.where(rearrange(zero_mask, "... -> ... 1"), means, new_means)

    return means, bins


def batched_embedding(indices, embeds):
    batch, dim = indices.shape[1], embeds.shape[-1]
    indices = repeat(indices, "h b n -> h b n d", d=dim)
    embeds = repeat(embeds, "h c d -> h b c d", b=batch)
    return embeds.gather(2, indices)


def orthogonal_loss_fn(t):
    # eq (2) from https://arxiv.org/abs/2112.00384
    h, n = t.shape[:2]
    normed_codes = F.normalize(t, p=2, dim=-1)
    cosine_sim = einsum("h i d, h j d -> h i j", normed_codes, normed_codes)
    return (cosine_sim**2).sum() / (h * n**2) - (1 / n)

class EuclideanCodebook(nn.Module):
    def __init__(
        self,
        dim,
        codebook_size,
        num_codebooks=1,
        kmeans_init=False,
        kmeans_iters=10,
        sync_kmeans=True,
        decay=0.8,
        eps=1e-5,
        threshold_ema_dead_code=2,
        reset_cluster_size=None,
        use_ddp=False,
        learnable_codebook=False,
        gumbel_sample=gumbel_sample,
        sample_codebook_temp=1.0,
        ema_update=True,
        affine_param=False,
        sync_affine_param=False,
        affine_param_batch_decay=0.99,
        affine_param_codebook_decay=0.9,
    ):
        super().__init__()
        self.transform_input = identity

        self.decay = decay
        self.ema_update = ema_update

        init_fn = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(num_codebooks, codebook_size, dim)

        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks

        self.kmeans_iters = kmeans_iters
        self.eps = eps
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.reset_cluster_size = (
            reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code
        )

        assert callable(gumbel_sample)
        self.gumbel_sample = gumbel_sample
        self.sample_codebook_temp = sample_codebook_temp

        assert not (use_ddp and num_codebooks > 1 and kmeans_init), (
            "kmeans init is not compatible with multiple codebooks in distributed environment for now"
        )

        self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors
        self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop
        self.all_reduce_fn = distributed.all_reduce if use_ddp else noop

        self.register_buffer("initted", torch.Tensor([not kmeans_init]))
        self.register_buffer("cluster_size", torch.zeros(num_codebooks, codebook_size))
        self.register_buffer("embed_avg", embed.clone())

        self.learnable_codebook = learnable_codebook
        if learnable_codebook:
            self.embed = nn.Parameter(embed)
        else:
            self.register_buffer("embed", embed)

        # affine related params

        self.affine_param = affine_param
        self.sync_affine_param = sync_affine_param

        if not affine_param:
            return

        self.affine_param_batch_decay = affine_param_batch_decay
        self.affine_param_codebook_decay = affine_param_codebook_decay

        self.register_buffer("batch_mean", None)
        self.register_buffer("batch_variance", None)

        self.register_buffer("codebook_mean_needs_init", torch.Tensor([True]))
        self.register_buffer("codebook_mean", torch.empty(num_codebooks, 1, dim))
        self.register_buffer("codebook_variance_needs_init", torch.Tensor([True]))
        self.register_buffer("codebook_variance", torch.empty(num_codebooks, 1, dim))

    @torch.jit.ignore
    def init_embed_(self, data, mask=None):
        if self.initted:
            return

        if mask is not None:
            c = data.shape[0]
            data = rearrange(data[mask], "(c n) d -> c n d", c=c)

        embed, cluster_size = kmeans(
            data,
            self.codebook_size,
            self.kmeans_iters,
            sample_fn=self.sample_fn,
            all_reduce_fn=self.kmeans_all_reduce_fn,
        )

        embed_sum = embed * rearrange(cluster_size, "... -> ... 1")

        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed_sum)
        self.cluster_size.data.copy_(cluster_size)
        self.initted.data.copy_(torch.Tensor([True]))

    @torch.jit.ignore
    def update_with_decay(self, buffer_name, new_value, decay):
        old_value = getattr(self, buffer_name)

        needs_init = getattr(self, buffer_name + "_needs_init", False)

        if needs_init:
            self.register_buffer(buffer_name + "_needs_init", torch.Tensor([False]))

        if old_value is None or needs_init:
            self.register_buffer(buffer_name, new_value.detach())

            return

        value = old_value * decay + new_value.detach() * (1 - decay)
        self.register_buffer(buffer_name, value)

    @torch.jit.ignore
    def update_affine(self, data, embed, mask=None):
        assert self.affine_param

        var_fn = partial(torch.var, unbiased=False)

        # calculate codebook mean and variance

        embed = rearrange(embed, "h ... d -> h (...) d")

        if self.training:
            self.update_with_decay(
                "codebook_mean",
                reduce(embed, "h n d -> h 1 d", "mean"),
                self.affine_param_codebook_decay,
            )
            self.update_with_decay(
                "codebook_variance",
                reduce(embed, "h n d -> h 1 d", var_fn),
                self.affine_param_codebook_decay,
            )

        # prepare batch data, which depends on whether it has masking

        data = rearrange(data, "h ... d -> h (...) d")

        if mask is not None:
            c = data.shape[0]
            data = rearrange(data[mask], "(c n) d -> c n d", c=c)

        # calculate batch mean and variance

        if not self.sync_affine_param:
            self.update_with_decay(
                "batch_mean",
                reduce(data, "h n d -> h 1 d", "mean"),
                self.affine_param_batch_decay,
            )
            self.update_with_decay(
                "batch_variance",
                reduce(data, "h n d -> h 1 d", var_fn),
                self.affine_param_batch_decay,
            )
            return

        num_vectors, device, dtype = data.shape[-2], data.device, data.dtype

        # number of vectors, for denominator

        num_vectors = torch.tensor([num_vectors], device=device, dtype=dtype)
        distributed.all_reduce(num_vectors)

        # calculate distributed mean

        batch_sum = reduce(data, "h n d -> h 1 d", "sum")
        distributed.all_reduce(batch_sum)
        batch_mean = batch_sum / num_vectors

        self.update_with_decay("batch_mean", batch_mean, self.affine_param_batch_decay)

        # calculate distributed variance

        variance_number = reduce((data - batch_mean) ** 2, "h n d -> h 1 d", "sum")
        distributed.all_reduce(variance_number)
        batch_variance = variance_number / num_vectors

        self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay)

    def replace(self, batch_samples, batch_mask):
        for ind, (samples, mask) in enumerate(
            zip(batch_samples.unbind(dim=0), batch_mask.unbind(dim=0), strict=False)
        ):
            if not torch.any(mask):
                continue

            sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item())
            sampled = rearrange(sampled, "1 ... -> ...")

            self.embed.data[ind][mask] = sampled

            self.cluster_size.data[ind][mask] = self.reset_cluster_size
            self.embed_avg.data[ind][mask] = sampled * self.reset_cluster_size

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code

        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "h ... d -> h (...) d")
        self.replace(batch_samples, batch_mask=expired_codes)

    @autocast(enabled=False)
    def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
        needs_codebook_dim = x.ndim < 4
        sample_codebook_temp = (
            sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp
        )

        x = x.float()

        if needs_codebook_dim:
            x = rearrange(x, "... -> 1 ...")

        flatten, ps = pack_one(x, "h * d")

        if mask is not None:
            mask = repeat(
                mask,
                "b n -> c (b h n)",
                c=flatten.shape[0],
                h=flatten.shape[-2] // (mask.shape[0] * mask.shape[1]),
            )

        self.init_embed_(flatten, mask=mask)

        if self.affine_param:
            self.update_affine(flatten, self.embed, mask=mask)

        embed = self.embed if self.learnable_codebook else self.embed.detach()

        if self.affine_param:
            codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt()
            batch_std = self.batch_variance.clamp(min=1e-5).sqrt()
            embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean

        dist = -cdist(flatten, embed)

        embed_ind, embed_onehot = self.gumbel_sample(
            dist, dim=-1, temperature=sample_codebook_temp, training=self.training
        )

        embed_ind = unpack_one(embed_ind, ps, "h *")

        if self.training:
            unpacked_onehot = unpack_one(embed_onehot, ps, "h * c")
            quantize = einsum("h b n c, h c d -> h b n d", unpacked_onehot, embed)
        else:
            quantize = batched_embedding(embed_ind, embed)

        if self.training and self.ema_update and not freeze_codebook:
            if self.affine_param:
                flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean

            if mask is not None:
                embed_onehot[~mask] = 0.0

            cluster_size = embed_onehot.sum(dim=1)

            self.all_reduce_fn(cluster_size)
            ema_inplace(self.cluster_size.data, cluster_size, self.decay)

            embed_sum = einsum("h n d, h n c -> h c d", flatten, embed_onehot)
            self.all_reduce_fn(embed_sum.contiguous())
            ema_inplace(self.embed_avg.data, embed_sum, self.decay)

            cluster_size = laplace_smoothing(
                self.cluster_size, self.codebook_size, self.eps
            ) * self.cluster_size.sum(dim=-1, keepdim=True)

            embed_normalized = self.embed_avg / rearrange(cluster_size, "... -> ... 1")
            self.embed.data.copy_(embed_normalized)
            self.expire_codes_(x)

        if needs_codebook_dim:
            quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind))

        dist = unpack_one(dist, ps, "h * d")

        return quantize, embed_ind, dist