delete some short functions

parent 975da28461
commit 1778dee9ab
@@ -937,17 +937,6 @@ def pretrain_vqvae(vqvae_model, discretize_step, actions):
     return loss, n_different_codes, n_different_combinations


-def exists(val):
-    return val is not None
-
-
-def default(val, d):
-    return val if exists(val) else d
-
-
-def round_up_multiple(num, mult):
-    return ceil(num / mult) * mult
-

 class ResidualVQ(nn.Module):
     """
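For reference, a minimal sketch of the inlining pattern this commit applies: each call to the deleted helpers is replaced by the helper's body at the call site. The variable values below are illustrative assumptions, not part of the diff.

    from math import ceil

    # deleted helpers (as they read before this commit):
    #   exists(val)              -> val is not None
    #   default(val, d)          -> val if val is not None else d
    #   round_up_multiple(n, m)  -> ceil(n / m) * m

    codebook_dim = None   # hypothetical argument value
    dim = 512             # hypothetical model dimension

    # old: codebook_dim = default(codebook_dim, dim)
    codebook_dim = codebook_dim if (codebook_dim is not None) else dim

    assert codebook_dim == 512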
@@ -991,7 +980,7 @@ class ResidualVQ(nn.Module):
     ):
         super().__init__()
         assert heads == 1, "residual vq is not compatible with multi-headed codes"
-        codebook_dim = default(codebook_dim, dim)
+        codebook_dim = codebook_dim if (codebook_dim is not None) else dim
         codebook_input_dim = codebook_dim * heads

         requires_projection = codebook_input_dim != dim
@@ -1095,13 +1084,13 @@ class ResidualVQ(nn.Module):
         num_quant, quant_dropout_multiple_of, return_loss, device = (
             self.num_quantizers,
             self.quantize_dropout_multiple_of,
-            exists(indices),
+            (indices is not None),
             x.device,
         )

         x = self.project_in(x)

-        assert not (self.accept_image_fmap and exists(indices))
+        assert not (self.accept_image_fmap and (indices is not None))

         quantized_out = 0.0
         residual = x
@@ -1129,9 +1118,7 @@ class ResidualVQ(nn.Module):

             if quant_dropout_multiple_of != 1:
                 rand_quantize_dropout_index = (
-                    round_up_multiple(
-                        rand_quantize_dropout_index + 1, quant_dropout_multiple_of
-                    )
+                    ceil((rand_quantize_dropout_index + 1) / quant_dropout_multiple_of) * quant_dropout_multiple_of
                     - 1
                 )

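The expression substituted above is just the body of the removed round_up_multiple helper; a quick standalone check of the arithmetic, with values chosen arbitrarily for illustration:

    from math import ceil

    def round_up_multiple(num, mult):
        # body of the helper removed by this commit
        return ceil(num / mult) * mult

    # the inlined form: round (index + 1) up to a multiple of mult, then step back by one
    mult = 4
    for index in range(10):
        inlined = ceil((index + 1) / mult) * mult - 1
        assert inlined == round_up_multiple(index + 1, mult) - 1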
@@ -1251,7 +1238,7 @@ class VectorQuantize(nn.Module):
         self.heads = heads
         self.separate_codebook_per_head = separate_codebook_per_head

-        codebook_dim = default(codebook_dim, dim)
+        codebook_dim = codebook_dim if (codebook_dim is not None) else dim
         codebook_input_dim = codebook_dim * heads

         requires_projection = codebook_input_dim != dim
@@ -1293,7 +1280,7 @@ class VectorQuantize(nn.Module):
             straight_through=straight_through,
         )

-        if not exists(sync_codebook):
+        if sync_codebook is None:
             sync_codebook = (
                 distributed.is_initialized() and distributed.get_world_size() > 1
             )
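When sync_codebook is left as None, the fallback above enables codebook syncing only if a distributed process group is actually running. A standalone sketch of that check, assuming `distributed` refers to `torch.distributed` as elsewhere in this file; the wrapper function name is hypothetical:

    import torch.distributed as distributed

    def should_sync_codebook(sync_codebook=None):
        # hypothetical wrapper around the fallback used in the hunk above
        if sync_codebook is None:
            # `and` short-circuits, so get_world_size() is only queried when a group exists
            sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1
        return sync_codebook

    print(should_sync_codebook())  # typically False when no process group has been initialized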
@@ -1328,7 +1315,7 @@ class VectorQuantize(nn.Module):

         self.in_place_codebook_optimizer = (
             in_place_codebook_optimizer(self._codebook.parameters())
-            if exists(in_place_codebook_optimizer)
+            if (in_place_codebook_optimizer is not None)
             else None
         )

@@ -1385,7 +1372,7 @@ class VectorQuantize(nn.Module):
         only_one = x.ndim == 2

         if only_one:
-            assert not exists(mask)
+            assert mask is None
             x = rearrange(x, "b d -> b 1 d")

         shape, device, heads, is_multiheaded, codebook_size, return_loss = (
@@ -1394,11 +1381,11 @@ class VectorQuantize(nn.Module):
             self.heads,
             self.heads > 1,
             self.codebook_size,
-            exists(indices),
+            (indices is not None),
         )

         need_transpose = not self.channel_last and not self.accept_image_fmap
-        should_inplace_optimize = exists(self.in_place_codebook_optimizer)
+        should_inplace_optimize = (self.in_place_codebook_optimizer is not None)

         # rearrange inputs

@@ -1438,7 +1425,7 @@ class VectorQuantize(nn.Module):
         # one step in-place update

         if should_inplace_optimize and self.training and not freeze_codebook:
-            if exists(mask):
+            if (mask is not None):
                 loss = F.mse_loss(quantize, x.detach(), reduction="none")

                 loss_mask = mask
@@ -1530,7 +1517,7 @@ class VectorQuantize(nn.Module):
         if self.training:
             if self.commitment_weight > 0:
                 if self.commitment_use_cross_entropy_loss:
-                    if exists(mask):
+                    if (mask is not None):
                         ce_loss_mask = mask
                         if is_multiheaded:
                             ce_loss_mask = repeat(ce_loss_mask, "b n -> b n h", h=heads)
@@ -1539,7 +1526,7 @@ class VectorQuantize(nn.Module):

                     commit_loss = calculate_ce_loss(embed_ind)
                 else:
-                    if exists(mask):
+                    if (mask is not None):
                         # with variable lengthed sequences
                         commit_loss = F.mse_loss(commit_quantize, x, reduction="none")

@@ -1573,7 +1560,7 @@ class VectorQuantize(nn.Module):
                 num_codes = codebook.shape[-2]

                 if (
-                    exists(self.orthogonal_reg_max_codes)
+                    (self.orthogonal_reg_max_codes is not None)
                     and num_codes > self.orthogonal_reg_max_codes
                 ):
                     rand_ids = torch.randperm(num_codes, device=device)[
@@ -1609,19 +1596,13 @@ class VectorQuantize(nn.Module):

         # if masking, only return quantized for where mask has True

-        if exists(mask):
+        if (mask is not None):
             quantize = torch.where(
                 rearrange(mask, "... -> ... 1"), quantize, orig_input
             )

         return quantize, embed_ind, loss


-def exists(val):
-    return val is not None
-
-
-def default(val, d):
-    return val if exists(val) else d
-
-
 def noop(*args, **kwargs):
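The masked return above keeps quantized vectors only where the mask is True and passes the original input through elsewhere. A small self-contained sketch of that torch.where pattern, with shapes and values chosen arbitrarily for illustration:

    import torch
    from einops import rearrange

    b, n, d = 2, 4, 8  # hypothetical batch, sequence length, feature dim
    quantize = torch.randn(b, n, d)
    orig_input = torch.randn(b, n, d)
    mask = torch.tensor([[True, True, False, False],
                         [True, False, True, False]])

    # broadcast the (b, n) mask to (b, n, 1) so it selects whole feature vectors
    out = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input)

    assert torch.equal(out[0, 0], quantize[0, 0])     # masked-in position: quantized
    assert torch.equal(out[0, 2], orig_input[0, 2])   # masked-out position: original input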
@@ -1631,11 +1612,6 @@ def noop(*args, **kwargs):
 def identity(t):
     return t

-
-def l2norm(t):
-    return F.normalize(t, p=2, dim=-1)
-
-
 def cdist(x, y):
     x2 = reduce(x**2, "b n d -> b n", "sum")
     y2 = reduce(y**2, "b n d -> b n", "sum")
@@ -1860,7 +1836,7 @@ def batched_embedding(indices, embeds):
 def orthogonal_loss_fn(t):
     # eq (2) from https://arxiv.org/abs/2112.00384
     h, n = t.shape[:2]
-    normed_codes = l2norm(t)
+    normed_codes = F.normalize(t, p=2, dim=-1)
     cosine_sim = einsum("h i d, h j d -> h i j", normed_codes, normed_codes)
     return (cosine_sim**2).sum() / (h * n**2) - (1 / n)

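The substituted F.normalize(t, p=2, dim=-1) is the body of the deleted l2norm helper. A quick check that it matches explicit l2 normalization over the last axis; the tensor shape below is an assumption for illustration only:

    import torch
    import torch.nn.functional as F

    t = torch.randn(4, 16, 32)  # hypothetical (heads, codes, dim) tensor

    normed = F.normalize(t, p=2, dim=-1)  # the inlined form used above
    manual = t / t.norm(dim=-1, keepdim=True).clamp(min=1e-12)

    assert torch.allclose(normed, manual, atol=1e-6)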
@@ -1906,7 +1882,7 @@ class EuclideanCodebook(nn.Module):
         self.kmeans_iters = kmeans_iters
         self.eps = eps
         self.threshold_ema_dead_code = threshold_ema_dead_code
-        self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
+        self.reset_cluster_size = reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code

         assert callable(gumbel_sample)
         self.gumbel_sample = gumbel_sample
@@ -1960,7 +1936,7 @@ class EuclideanCodebook(nn.Module):
         if self.initted:
             return

-        if exists(mask):
+        if (mask is not None):
             c = data.shape[0]
             data = rearrange(data[mask], "(c n) d -> c n d", c=c)

@@ -1988,7 +1964,7 @@ class EuclideanCodebook(nn.Module):
         if needs_init:
             self.register_buffer(buffer_name + "_needs_init", torch.Tensor([False]))

-        if not exists(old_value) or needs_init:
+        if not (old_value is not None) or needs_init:
             self.register_buffer(buffer_name, new_value.detach())

             return
@@ -2022,7 +1998,7 @@ class EuclideanCodebook(nn.Module):

         data = rearrange(data, "h ... d -> h (...) d")

-        if exists(mask):
+        if (mask is not None):
             c = data.shape[0]
             data = rearrange(data[mask], "(c n) d -> c n d", c=c)

@@ -2098,7 +2074,7 @@ class EuclideanCodebook(nn.Module):
     @autocast(enabled=False)
     def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
         needs_codebook_dim = x.ndim < 4
-        sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
+        sample_codebook_temp = sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp

         x = x.float()

@@ -2108,7 +2084,7 @@ class EuclideanCodebook(nn.Module):
         dtype = x.dtype
         flatten, ps = pack_one(x, "h * d")

-        if exists(mask):
+        if (mask is not None):
             mask = repeat(
                 mask,
                 "b n -> c (b h n)",
@@ -2150,7 +2126,7 @@ class EuclideanCodebook(nn.Module):
                     codebook_std / batch_std
                 ) + self.codebook_mean

-            if exists(mask):
+            if (mask is not None):
                 embed_onehot[~mask] = 0.0

             cluster_size = embed_onehot.sum(dim=1)