Let VqVae inherit from nn.Module
parent 00d2422710
commit db4e5e7ec1
@@ -765,7 +765,7 @@ class EncoderMLP(nn.Module):
         return state


-class VqVae:
+class VqVae(nn.Module):
     def __init__(
         self,
         input_dim_h=10,  # length of action chunk
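Why this matters: once `VqVae` subclasses `nn.Module`, every submodule assigned in `__init__` is registered automatically, so a single `.to(device)`, `.state_dict()`, or `.eval()` call covers the encoder, decoder, and quantizer together. A minimal sketch of that behavior (toy module, not the committed code):

```python
import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 4)  # registered as a submodule
        self.decoder = nn.Linear(4, 8)  # registered as a submodule

toy = Toy().to("cpu")            # one call moves every parameter
print(sorted(toy.state_dict()))  # ['decoder.bias', 'decoder.weight', 'encoder.bias', 'encoder.weight']
```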
@@ -774,11 +774,11 @@ class VqVae:
         vqvae_n_embed=32,
         vqvae_groups=4,
         eval=True,
         device="cuda",
         load_dir=None,
         encoder_loss_multiplier=1.0,
         act_scale=1.0,
     ):
+        super(VqVae, self).__init__()
         self.n_latent_dims = n_latent_dims
         self.input_dim_h = input_dim_h
         self.input_dim_w = input_dim_w
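The new `super(VqVae, self).__init__()` call must run before any submodule is assigned to `self`, because `nn.Module.__setattr__` depends on registries that the base constructor creates. A toy reproduction of the failure mode when it is omitted:

```python
from torch import nn

class Broken(nn.Module):
    def __init__(self):
        # super().__init__() deliberately missing
        self.encoder = nn.Linear(8, 4)

try:
    Broken()
except AttributeError as err:
    print(err)  # cannot assign module before Module.__init__() call
```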
@@ -786,7 +786,6 @@ class VqVae:
         self.vqvae_n_embed = vqvae_n_embed
-        self.vqvae_lr = 1e-3
         self.vqvae_groups = vqvae_groups
         self.device = device
         self.encoder_loss_multiplier = encoder_loss_multiplier
         self.act_scale = act_scale

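Dropping the hard-coded `self.vqvae_lr = 1e-3` means the learning rate presumably lives with whoever builds the optimizer, the usual pattern for an `nn.Module`; since all weights are now reachable through `parameters()`, the wiring is one line (sketch with a stand-in module; the `lr` value is only illustrative):

```python
import torch
from torch import nn

model = nn.Linear(4, 4)  # stand-in for a VqVae instance
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
```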
@@ -799,25 +798,23 @@ class VqVae:
             dim=self.n_latent_dims,
             num_quantizers=discrete_cfg["groups"],
             codebook_size=self.vqvae_n_embed,
-        ).to(self.device)
+        )
         self.embedding_dim = self.n_latent_dims

-        self.vq_layer.device = device
-
         if self.input_dim_h == 1:
             self.encoder = EncoderMLP(
                 input_dim=input_dim_w, output_dim=n_latent_dims
-            ).to(self.device)
+            )
             self.decoder = EncoderMLP(
                 input_dim=n_latent_dims, output_dim=input_dim_w
-            ).to(self.device)
+            )
         else:
             self.encoder = EncoderMLP(
                 input_dim=input_dim_w * self.input_dim_h, output_dim=n_latent_dims
-            ).to(self.device)
+            )
             self.decoder = EncoderMLP(
                 input_dim=n_latent_dims, output_dim=input_dim_w * self.input_dim_h
-            ).to(self.device)
+            )


         if load_dir is not None:
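With registration handled by `nn.Module`, the five per-submodule `.to(self.device)` calls above become redundant: the caller can place the whole model in one step, and the `self.vq_layer.device` bookkeeping is no longer needed. A usage sketch, assuming the `VqVae` from this diff is in scope:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vqvae = VqVae(eval=False)  # assumes the class defined in this file
vqvae.to(device)           # moves encoder, decoder, and vq_layer together
```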
@@ -828,15 +825,28 @@ class VqVae:
             self.load_state_dict(state_dict)

         if eval:
-            self.vq_layer.eval()
+            self.eval()
         else:
-            self.vq_layer.train()
+            self.train()

+    def eval(self):
+        self.training = False
+        self.vq_layer.eval()
+        self.encoder.eval()
+        self.decoder.eval()
+
+    def train(self, mode=True):
+        if mode:
+            if self.discretized:
+                pass
+            else:
+                self.training = True
+                self.vq_layer.train()
+                self.decoder.train()
+                self.encoder.train()
+        else:
+            self.eval()
+
     def draw_logits_forward(self, encoding_logits):
         z_embed = self.vq_layer.draw_logits_forward(encoding_logits)
         return z_embed
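These overrides replace the stock `nn.Module.train`/`eval` so that a model flagged as `discretized` silently ignores requests to re-enter training mode. Two details differ from the base class: `nn.Module.train(mode)` returns `self` for chaining, and it recurses into children automatically. A variant built on the base implementation could keep both properties (hypothetical sketch, not the committed code):

```python
from torch import nn

class Freezable(nn.Module):
    """Toy module whose train() is a no-op once discretized."""

    def __init__(self):
        super().__init__()
        self.discretized = False
        self.layer = nn.Linear(4, 4)

    def train(self, mode=True):
        if mode and self.discretized:
            return self             # refuse to unfreeze
        return super().train(mode)  # recurses into children, returns self

m = Freezable().eval()  # nn.Module.eval() just calls train(False)
m.discretized = True
m.train()               # ignored: model stays in eval mode
print(m.training)       # False
```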
@@ -856,12 +866,12 @@ class VqVae:

     def preprocess(self, state):
         if not torch.is_tensor(state):
-            state = torch.FloatTensor(state.copy()).to(self.device)
+            state = torch.FloatTensor(state.copy())
         if self.input_dim_h == 1:
             state = state.squeeze(-2)  # state.squeeze(-1)
         else:
             state = einops.rearrange(state, "N T A -> N (T A)")
-        return state.to(self.device)
+        return state

     def get_code(self, state, required_recon=False):
         state = state / self.act_scale
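`preprocess` is now device-neutral: it only fixes dtype and shape and leaves placement to the caller. A standalone sketch of the reshaping path for action chunks (`input_dim_h > 1`), with illustrative shapes:

```python
import numpy as np
import torch
import einops

actions = np.random.randn(16, 10, 9).astype(np.float32)  # (N, T, A) chunk
state = torch.FloatTensor(actions.copy())                 # still on CPU
state = einops.rearrange(state, "N T A -> N (T A)")       # flatten time into features
print(state.shape)  # torch.Size([16, 90])
```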
@@ -914,18 +924,11 @@ class VqVae:
         )
         return rep_loss, metric

-    def state_dict(self):
-        return {
-            "encoder": self.encoder.state_dict(),
-            "decoder": self.decoder.state_dict(),
-            "vq_embedding": self.vq_layer.state_dict(),
-        }
-
-    def load_state_dict(self, state_dict):
-        self.encoder.load_state_dict(state_dict["encoder"])
-        self.decoder.load_state_dict(state_dict["decoder"])
-        self.vq_layer.load_state_dict(state_dict["vq_embedding"])
-        self.vq_layer.eval()
+    def load_state_dict(self, *args, **kwargs):
+        super(VqVae, self).state_dict(self, *args, **kwargs)
+        self.eval()
         self.discretized = True
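One thing to flag in the committed override: the body calls `super(VqVae, self).state_dict(self, *args, **kwargs)`, i.e. the *saving* method, with `self` passed an extra time, so nothing is actually loaded. Given the surrounding code, the intent is presumably the base-class loader; a sketch of that reading (assumption, not the committed code):

```python
def load_state_dict(self, *args, **kwargs):
    # Delegate to nn.Module's loader, then freeze the quantizer for good.
    super(VqVae, self).load_state_dict(*args, **kwargs)
    self.eval()
    self.discretized = True
```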
@@ -938,8 +941,6 @@ def init_vqvae(config):
         vqvae_n_embed=config["vqvae_n_embed"],
         vqvae_groups=config["vqvae_groups"],
         eval=False,
-        device=config["device"],
-        # encoder_loss_multiplier=0.033,
     )

     return vqvae_model
@@ -956,7 +957,7 @@ def pretrain_vqvae(vqvae_model, discretize_step, actions):
     actions = torch.cat(slices, dim=0)


-    actions = actions.to(vqvae_model.device)
+    actions = actions.to(get_device_from_parameters(vqvae_model))

     loss, metric = vqvae_model.vqvae_forward(
         actions
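Reading the device off the parameters instead of a stored `vqvae_model.device` attribute stays correct even after later `.to(...)` moves. If `get_device_from_parameters` is not already defined elsewhere in the codebase, it is presumably something like:

```python
import torch
from torch import nn

def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Infer a module's device from its first registered parameter."""
    return next(iter(module.parameters())).device
```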
@@ -1089,13 +1090,11 @@ class ResidualVQ(nn.Module):

     def draw_logits_forward(self, encoding_logits):
         # encoding_indices : dim1 = batch_size dim2 = 4 (number of groups) dim3 = vq dict size (header)
-        encoding_logits = encoding_logits.to(self.device)
+        encoding_logits = encoding_logits
         bs = encoding_logits.shape[0]
-        quantized = torch.zeros((bs, self.codebooks.shape[-1])).to(self.device)
+        quantized = torch.zeros((bs, self.codebooks.shape[-1]))
         for q in range(encoding_logits.shape[1]):
-            quantized += torch.matmul(encoding_logits[:, q], self.codebooks[q]).to(
-                self.device
-            )
+            quantized += torch.matmul(encoding_logits[:, q], self.codebooks[q])
         return quantized

     def forward(
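Caveat on the new `draw_logits_forward`: `torch.zeros((bs, ...))` allocates on the CPU by default, so if `encoding_logits` and the codebooks sit on a GPU, the `+=` will raise a device-mismatch error. A device-agnostic variant would pin the accumulator to the input (hypothetical helper, not the committed code):

```python
import torch

def mix_codebooks(encoding_logits: torch.Tensor, codebooks: torch.Tensor) -> torch.Tensor:
    """Blend codebook vectors with per-group logits: (B, G, K) x (G, K, D) -> (B, D)."""
    bs = encoding_logits.shape[0]
    quantized = torch.zeros(
        (bs, codebooks.shape[-1]),
        device=encoding_logits.device,  # follow the input, not the default device
        dtype=encoding_logits.dtype,
    )
    for q in range(encoding_logits.shape[1]):
        quantized += encoding_logits[:, q] @ codebooks[q]
    return quantized

# Example: 4 groups, codebook of 32 vectors, 128-dim embeddings
logits = torch.randn(2, 4, 32)
books = torch.randn(4, 32, 128)
print(mix_codebooks(logits, books).shape)  # torch.Size([2, 128])
```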