remove unused parts (some lines in def get_code and def preprocess in class VqVae)
parent cde1804bc7
commit 2f0e601d9f
@@ -793,17 +793,8 @@ class VqVae(nn.Module):
         else:
             return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
 
-    def preprocess(self, state):
-        if not torch.is_tensor(state):
-            state = torch.FloatTensor(state.copy())
-        if self.config.action_chunk_size == 1:
-            state = state.squeeze(-2)
-        else:
-            state = einops.rearrange(state, "N T A -> N (T A)")
-        return state
-
-    def get_code(self, state, required_recon=False):
-        state = self.preprocess(state)
+    def get_code(self, state):
+        state = einops.rearrange(state, "N T A -> N (T A)")
         with torch.no_grad():
             state_rep = self.encoder(state)
             state_rep_shape = state_rep.shape[:-1]
@@ -812,23 +803,10 @@ class VqVae(nn.Module):
             state_vq = state_rep_flat.view(*state_rep_shape, -1)
             vq_code = vq_code.view(*state_rep_shape, -1)
-            vq_loss_state = torch.sum(vq_loss_state)
-            if required_recon:
-                recon_state = self.decoder(state_vq)
-                recon_state_ae = self.decoder(state_rep)
-                if self.config.action_chunk_size == 1:
-                    return state_vq, vq_code, recon_state, recon_state_ae
-                else:
-                    return (
-                        state_vq,
-                        vq_code,
-                        torch.swapaxes(recon_state, -2, -1),
-                        torch.swapaxes(recon_state_ae, -2, -1),
-                    )
-            else:
-                return state_vq, vq_code
+            return state_vq, vq_code
 
     def vqvae_forward(self, state):
-        state = self.preprocess(state)
+        state = einops.rearrange(state, "N T A -> N (T A)")
         state_rep = self.encoder(state)
         state_rep_shape = state_rep.shape[:-1]
         state_rep_flat = state_rep.view(state_rep.size(0), -1, state_rep.size(1))
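For context (not part of the commit), a minimal sketch of the reshaping that the removed preprocess helper used to perform and that get_code and vqvae_forward now do inline. The N, T, A values and the actions variable below are made up for illustration; with T == 1 the flattened (N, T*A) shape equals the (N, A) shape the old squeeze(-2) branch produced, which is presumably why a single rearrange could replace both branches.

import torch
import einops

# Hypothetical shapes for illustration only: N chunks of T actions with A dims each.
N, T, A = 4, 5, 7
actions = torch.randn(N, T, A)

# The rearrange now inlined in get_code()/vqvae_forward(),
# standing in for the removed preprocess() helper.
flat = einops.rearrange(actions, "N T A -> N (T A)")
assert flat.shape == (N, T * A)

# With an action chunk of length 1, flattening yields the same (N, A)
# shape that the old squeeze(-2) branch produced.
single = torch.randn(N, 1, A)
assert einops.rearrange(single, "N T A -> N (T A)").shape == (N, A)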