remove unused parts (some lines in def get_code and def preprocessin class VqVae)

This commit is contained in:
jayLEE0301 2024-06-08 17:18:08 -04:00
parent cde1804bc7
commit 2f0e601d9f
1 changed files with 4 additions and 26 deletions

View File

@ -793,17 +793,8 @@ class VqVae(nn.Module):
else:
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
def preprocess(self, state):
if not torch.is_tensor(state):
state = torch.FloatTensor(state.copy())
if self.config.action_chunk_size == 1:
state = state.squeeze(-2)
else:
state = einops.rearrange(state, "N T A -> N (T A)")
return state
def get_code(self, state, required_recon=False):
state = self.preprocess(state)
def get_code(self, state):
state = einops.rearrange(state, "N T A -> N (T A)")
with torch.no_grad():
state_rep = self.encoder(state)
state_rep_shape = state_rep.shape[:-1]
@ -812,23 +803,10 @@ class VqVae(nn.Module):
state_vq = state_rep_flat.view(*state_rep_shape, -1)
vq_code = vq_code.view(*state_rep_shape, -1)
vq_loss_state = torch.sum(vq_loss_state)
if required_recon:
recon_state = self.decoder(state_vq)
recon_state_ae = self.decoder(state_rep)
if self.config.action_chunk_size == 1:
return state_vq, vq_code, recon_state, recon_state_ae
else:
return (
state_vq,
vq_code,
torch.swapaxes(recon_state, -2, -1),
torch.swapaxes(recon_state_ae, -2, -1),
)
else:
return state_vq, vq_code
return state_vq, vq_code
def vqvae_forward(self, state):
state = self.preprocess(state)
state = einops.rearrange(state, "N T A -> N (T A)")
state_rep = self.encoder(state)
state_rep_shape = state_rep.shape[:-1]
state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))