Let vqvae inherit from nn.module

This commit is contained in:
jayLEE0301 2024-05-09 20:01:48 -04:00
parent 00d2422710
commit db4e5e7ec1
1 changed files with 32 additions and 33 deletions

View File

@ -765,7 +765,7 @@ class EncoderMLP(nn.Module):
return state
class VqVae:
class VqVae(nn.Module):
def __init__(
self,
input_dim_h=10, # length of action chunk
@ -774,11 +774,11 @@ class VqVae:
vqvae_n_embed=32,
vqvae_groups=4,
eval=True,
device="cuda",
load_dir=None,
encoder_loss_multiplier=1.0,
act_scale=1.0,
):
super(VqVae, self).__init__()
self.n_latent_dims = n_latent_dims
self.input_dim_h = input_dim_h
self.input_dim_w = input_dim_w
@ -786,7 +786,6 @@ class VqVae:
self.vqvae_n_embed = vqvae_n_embed
self.vqvae_lr = 1e-3
self.vqvae_groups = vqvae_groups
self.device = device
self.encoder_loss_multiplier = encoder_loss_multiplier
self.act_scale = act_scale
@ -799,25 +798,23 @@ class VqVae:
dim=self.n_latent_dims,
num_quantizers=discrete_cfg["groups"],
codebook_size=self.vqvae_n_embed,
).to(self.device)
)
self.embedding_dim = self.n_latent_dims
self.vq_layer.device = device
if self.input_dim_h == 1:
self.encoder = EncoderMLP(
input_dim=input_dim_w, output_dim=n_latent_dims
).to(self.device)
)
self.decoder = EncoderMLP(
input_dim=n_latent_dims, output_dim=input_dim_w
).to(self.device)
)
else:
self.encoder = EncoderMLP(
input_dim=input_dim_w * self.input_dim_h, output_dim=n_latent_dims
).to(self.device)
)
self.decoder = EncoderMLP(
input_dim=n_latent_dims, output_dim=input_dim_w * self.input_dim_h
).to(self.device)
)
if load_dir is not None:
@ -828,15 +825,28 @@ class VqVae:
self.load_state_dict(state_dict)
if eval:
self.vq_layer.eval()
self.eval()
else:
self.vq_layer.train()
self.train()
def eval(self):
self.training = False
self.vq_layer.eval()
self.encoder.eval()
self.decoder.eval()
def train(self, mode=True):
if mode:
if self.discretized:
pass
else:
self.training = True
self.vq_layer.train()
self.decoder.train()
self.encoder.train()
else:
self.eval()
def draw_logits_forward(self, encoding_logits):
z_embed = self.vq_layer.draw_logits_forward(encoding_logits)
return z_embed
@ -856,12 +866,12 @@ class VqVae:
def preprocess(self, state):
if not torch.is_tensor(state):
state = torch.FloatTensor(state.copy()).to(self.device)
state = torch.FloatTensor(state.copy())
if self.input_dim_h == 1:
state = state.squeeze(-2) # state.squeeze(-1)
else:
state = einops.rearrange(state, "N T A -> N (T A)")
return state.to(self.device)
return state
def get_code(self, state, required_recon=False):
state = state / self.act_scale
@ -914,18 +924,11 @@ class VqVae:
)
return rep_loss, metric
def state_dict(self):
return {
"encoder": self.encoder.state_dict(),
"decoder": self.decoder.state_dict(),
"vq_embedding": self.vq_layer.state_dict(),
}
def load_state_dict(self, state_dict):
self.encoder.load_state_dict(state_dict["encoder"])
self.decoder.load_state_dict(state_dict["decoder"])
self.vq_layer.load_state_dict(state_dict["vq_embedding"])
self.vq_layer.eval()
def load_state_dict(self, *args, **kwargs):
super(VqVae, self).state_dict(self, *args, **kwargs)
self.eval()
self.discretized = True
@ -938,8 +941,6 @@ def init_vqvae(config):
vqvae_n_embed=config["vqvae_n_embed"],
vqvae_groups=config["vqvae_groups"],
eval=False,
device=config["device"],
# encoder_loss_multiplier=0.033,
)
return vqvae_model
@ -956,7 +957,7 @@ def pretrain_vqvae(vqvae_model, discretize_step, actions):
actions = torch.cat(slices, dim=0)
actions = actions.to(vqvae_model.device)
actions = actions.to(get_device_from_parameters(vqvae_model))
loss, metric = vqvae_model.vqvae_forward(
actions
@ -1089,13 +1090,11 @@ class ResidualVQ(nn.Module):
def draw_logits_forward(self, encoding_logits):
# encoding_indices : dim1 = batch_size dim2 = 4 (number of groups) dim3 = vq dict size (header)
encoding_logits = encoding_logits.to(self.device)
encoding_logits = encoding_logits
bs = encoding_logits.shape[0]
quantized = torch.zeros((bs, self.codebooks.shape[-1])).to(self.device)
quantized = torch.zeros((bs, self.codebooks.shape[-1]))
for q in range(encoding_logits.shape[1]):
quantized += torch.matmul(encoding_logits[:, q], self.codebooks[q]).to(
self.device
)
quantized += torch.matmul(encoding_logits[:, q], self.codebooks[q])
return quantized
def forward(