delete some short functions

jayLEE0301 2024-06-05 13:47:06 -04:00
parent 975da28461
commit 1778dee9ab
1 changed file with 23 additions and 47 deletions
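The diff below removes duplicated one-line helpers (`exists`, `default`, `round_up_multiple`, `l2norm`) from the vector-quantization code and inlines their bodies at each call site. A minimal sketch of the equivalences being applied, with the helper definitions copied from the deleted lines and the inline forms matching the added ones:

from math import ceil
import torch.nn.functional as F

# removed helpers, as defined in the deleted lines
def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def round_up_multiple(num, mult):
    return ceil(num / mult) * mult

def l2norm(t):
    return F.normalize(t, p=2, dim=-1)

# inline replacements used throughout the diff
# exists(x)               ->  (x is not None)
# default(x, d)           ->  x if (x is not None) else d
# round_up_multiple(n, m) ->  ceil(n / m) * m
# l2norm(t)               ->  F.normalize(t, p=2, dim=-1)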


@@ -937,17 +937,6 @@ def pretrain_vqvae(vqvae_model, discretize_step, actions):
return loss, n_different_codes, n_different_combinations
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def round_up_multiple(num, mult):
return ceil(num / mult) * mult
class ResidualVQ(nn.Module):
"""
@@ -991,7 +980,7 @@ class ResidualVQ(nn.Module):
):
super().__init__()
assert heads == 1, "residual vq is not compatible with multi-headed codes"
codebook_dim = default(codebook_dim, dim)
codebook_dim = codebook_dim if (codebook_dim is not None) else dim
codebook_input_dim = codebook_dim * heads
requires_projection = codebook_input_dim != dim
@@ -1095,13 +1084,13 @@ class ResidualVQ(nn.Module):
num_quant, quant_dropout_multiple_of, return_loss, device = (
self.num_quantizers,
self.quantize_dropout_multiple_of,
exists(indices),
(indices is not None),
x.device,
)
x = self.project_in(x)
assert not (self.accept_image_fmap and exists(indices))
assert not (self.accept_image_fmap and (indices is not None))
quantized_out = 0.0
residual = x
@@ -1129,9 +1118,7 @@ class ResidualVQ(nn.Module):
if quant_dropout_multiple_of != 1:
rand_quantize_dropout_index = (
round_up_multiple(
rand_quantize_dropout_index + 1, quant_dropout_multiple_of
)
ceil((rand_quantize_dropout_index + 1) / quant_dropout_multiple_of) * quant_dropout_multiple_of
- 1
)
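The quantizer-dropout index is rounded up to a multiple of `quantize_dropout_multiple_of`; the inlined arithmetic is identical to the removed `round_up_multiple` helper. A small standalone check of that equivalence, using hypothetical index and multiple values:

from math import ceil

def round_up_multiple(num, mult):
    return ceil(num / mult) * mult

# the inlined expression matches the old helper for any positive multiple
for idx in range(10):
    for mult in (1, 2, 4):
        inlined = ceil((idx + 1) / mult) * mult - 1
        assert inlined == round_up_multiple(idx + 1, mult) - 1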
@@ -1251,7 +1238,7 @@ class VectorQuantize(nn.Module):
self.heads = heads
self.separate_codebook_per_head = separate_codebook_per_head
codebook_dim = default(codebook_dim, dim)
codebook_dim = codebook_dim if (codebook_dim is not None) else dim
codebook_input_dim = codebook_dim * heads
requires_projection = codebook_input_dim != dim
@@ -1293,7 +1280,7 @@ class VectorQuantize(nn.Module):
straight_through=straight_through,
)
if not exists(sync_codebook):
if sync_codebook is None:
sync_codebook = (
distributed.is_initialized() and distributed.get_world_size() > 1
)
@@ -1328,7 +1315,7 @@ class VectorQuantize(nn.Module):
self.in_place_codebook_optimizer = (
in_place_codebook_optimizer(self._codebook.parameters())
if exists(in_place_codebook_optimizer)
if (in_place_codebook_optimizer is not None)
else None
)
@@ -1385,7 +1372,7 @@ class VectorQuantize(nn.Module):
only_one = x.ndim == 2
if only_one:
assert not exists(mask)
assert mask is None
x = rearrange(x, "b d -> b 1 d")
shape, device, heads, is_multiheaded, codebook_size, return_loss = (
@@ -1394,11 +1381,11 @@ class VectorQuantize(nn.Module):
self.heads,
self.heads > 1,
self.codebook_size,
exists(indices),
(indices is not None),
)
need_transpose = not self.channel_last and not self.accept_image_fmap
should_inplace_optimize = exists(self.in_place_codebook_optimizer)
should_inplace_optimize = (self.in_place_codebook_optimizer is not None)
# rearrange inputs
@@ -1438,7 +1425,7 @@ class VectorQuantize(nn.Module):
# one step in-place update
if should_inplace_optimize and self.training and not freeze_codebook:
if exists(mask):
if (mask is not None):
loss = F.mse_loss(quantize, x.detach(), reduction="none")
loss_mask = mask
@@ -1530,7 +1517,7 @@ class VectorQuantize(nn.Module):
if self.training:
if self.commitment_weight > 0:
if self.commitment_use_cross_entropy_loss:
if exists(mask):
if (mask is not None):
ce_loss_mask = mask
if is_multiheaded:
ce_loss_mask = repeat(ce_loss_mask, "b n -> b n h", h=heads)
@@ -1539,7 +1526,7 @@ class VectorQuantize(nn.Module):
commit_loss = calculate_ce_loss(embed_ind)
else:
if exists(mask):
if (mask is not None):
# with variable-length sequences
commit_loss = F.mse_loss(commit_quantize, x, reduction="none")
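When a padding mask is supplied, the commitment loss is computed element-wise (`reduction="none"`) so that padded positions can be excluded before averaging. The reduction itself falls outside this hunk; the masked mean below is an assumption for illustration, not necessarily the exact reduction used in the file:

import torch
import torch.nn.functional as F

def masked_commit_loss(commit_quantize, x, mask):
    # element-wise MSE, as in the hunk above; shapes (b, n, d)
    loss = F.mse_loss(commit_quantize, x, reduction="none")
    # hypothetical reduction: mean over unmasked (valid) positions only
    return loss[mask].mean()  # mask: (b, n) bool

b, n, d = 2, 5, 8
x = torch.randn(b, n, d)
q = torch.randn(b, n, d)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]], dtype=torch.bool)
print(masked_commit_loss(q, x, mask))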
@@ -1573,7 +1560,7 @@ class VectorQuantize(nn.Module):
num_codes = codebook.shape[-2]
if (
exists(self.orthogonal_reg_max_codes)
(self.orthogonal_reg_max_codes is not None)
and num_codes > self.orthogonal_reg_max_codes
):
rand_ids = torch.randperm(num_codes, device=device)[
@@ -1609,19 +1596,13 @@ class VectorQuantize(nn.Module):
# if masking, only return quantized for where mask has True
if exists(mask):
if (mask is not None):
quantize = torch.where(
rearrange(mask, "... -> ... 1"), quantize, orig_input
)
return quantize, embed_ind, loss
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def noop(*args, **kwargs):
@@ -1631,11 +1612,6 @@ def noop(*args, **kwargs):
def identity(t):
return t
def l2norm(t):
return F.normalize(t, p=2, dim=-1)
def cdist(x, y):
x2 = reduce(x**2, "b n d -> b n", "sum")
y2 = reduce(y**2, "b n d -> b n", "sum")
@@ -1860,7 +1836,7 @@ def batched_embedding(indices, embeds):
def orthogonal_loss_fn(t):
# eq (2) from https://arxiv.org/abs/2112.00384
h, n = t.shape[:2]
normed_codes = l2norm(t)
normed_codes = F.normalize(t, p=2, dim=-1)
cosine_sim = einsum("h i d, h j d -> h i j", normed_codes, normed_codes)
return (cosine_sim**2).sum() / (h * n**2) - (1 / n)
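Inside the orthogonal regularization term (eq. (2) of https://arxiv.org/abs/2112.00384), the `l2norm` helper is replaced by a direct `F.normalize` call. A self-contained version of the function as it reads after this hunk, assuming `einsum` refers to `torch.einsum`:

import torch
import torch.nn.functional as F

def orthogonal_loss_fn(t):
    # t: (heads, num_codes, dim) codebook
    h, n = t.shape[:2]
    normed_codes = F.normalize(t, p=2, dim=-1)  # unit-norm codes
    cosine_sim = torch.einsum("h i d, h j d -> h i j", normed_codes, normed_codes)
    # penalize pairwise similarity; the -1/n term removes the diagonal's contribution
    return (cosine_sim**2).sum() / (h * n**2) - (1 / n)

codes = torch.randn(2, 16, 8)
print(orthogonal_loss_fn(codes))  # scalar regularization term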
@@ -1906,7 +1882,7 @@ class EuclideanCodebook(nn.Module):
self.kmeans_iters = kmeans_iters
self.eps = eps
self.threshold_ema_dead_code = threshold_ema_dead_code
self.reset_cluster_size = default(reset_cluster_size, threshold_ema_dead_code)
self.reset_cluster_size = reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code
assert callable(gumbel_sample)
self.gumbel_sample = gumbel_sample
@@ -1960,7 +1936,7 @@ class EuclideanCodebook(nn.Module):
if self.initted:
return
if exists(mask):
if (mask is not None):
c = data.shape[0]
data = rearrange(data[mask], "(c n) d -> c n d", c=c)
@@ -1988,7 +1964,7 @@ class EuclideanCodebook(nn.Module):
if needs_init:
self.register_buffer(buffer_name + "_needs_init", torch.Tensor([False]))
if not exists(old_value) or needs_init:
if old_value is None or needs_init:
self.register_buffer(buffer_name, new_value.detach())
return
@@ -2022,7 +1998,7 @@ class EuclideanCodebook(nn.Module):
data = rearrange(data, "h ... d -> h (...) d")
if exists(mask):
if (mask is not None):
c = data.shape[0]
data = rearrange(data[mask], "(c n) d -> c n d", c=c)
@@ -2098,7 +2074,7 @@ class EuclideanCodebook(nn.Module):
@autocast(enabled=False)
def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False):
needs_codebook_dim = x.ndim < 4
sample_codebook_temp = default(sample_codebook_temp, self.sample_codebook_temp)
sample_codebook_temp = sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp
x = x.float()
@@ -2108,7 +2084,7 @@ class EuclideanCodebook(nn.Module):
dtype = x.dtype
flatten, ps = pack_one(x, "h * d")
if exists(mask):
if (mask is not None):
mask = repeat(
mask,
"b n -> c (b h n)",
@@ -2150,7 +2126,7 @@ class EuclideanCodebook(nn.Module):
codebook_std / batch_std
) + self.codebook_mean
if exists(mask):
if (mask is not None):
embed_onehot[~mask] = 0.0
cluster_size = embed_onehot.sum(dim=1)
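Zeroing the one-hot rows of masked (padded) positions means they contribute nothing to the per-code cluster counts that drive the EMA codebook update. A minimal sketch with hypothetical shapes:

import torch

# embed_onehot: (heads, num_tokens, codebook_size) one-hot code assignments
# mask:         (heads, num_tokens) bool, False marks padded positions
embed_onehot = torch.zeros(1, 6, 4)
embed_onehot[0, torch.arange(6), torch.tensor([0, 1, 1, 2, 3, 3])] = 1.0
mask = torch.tensor([[True, True, True, True, False, False]])

embed_onehot[~mask] = 0.0                # padded positions are dropped from the counts
cluster_size = embed_onehot.sum(dim=1)   # per-head code usage over valid tokens only
print(cluster_size)                      # tensor([[1., 2., 1., 0.]])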