simplified action stacking part

This commit is contained in:
jayLEE0301 2024-06-05 11:20:13 -04:00
parent c4c5977f37
commit 8c775c94fc
1 changed files with 11 additions and 16 deletions

View File

@ -272,6 +272,14 @@ class VQBeTModel(nn.Module):
# bin prediction head / offset prediction head part of VQ-BeT
self.action_head = VQBeTHead(config)
num_tokens = self.config.n_action_pred_token + self.config.action_chunk_size - 1
self.register_buffer(
"select_target_actions_indices",
torch.row_stack(
[torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]
),
)
def discretize(self, discretize_step, actions):
    """Delegate to the action head's ``discretize`` and return its result.

    Both arguments are forwarded unchanged, in the same order, to
    ``self.action_head.discretize``; whatever that call returns is
    returned as-is.
    """
    # NOTE(review): the meaning of `discretize_step` and `actions` is
    # defined by the head's implementation, which is not visible here.
    head = self.action_head
    return head.discretize(discretize_step, actions)
@ -320,22 +328,9 @@ class VQBeTModel(nn.Module):
return pred_action["predicted_action"][:, n_obs_steps-1, :].reshape(batch_size, self.config.action_chunk_size, -1)
# otherwise, it calculates the overall loss (bin prediction loss and offset loss)
else:
action = batch["action"]
n, total_w, act_dim = action.shape
act_w = self.config.action_chunk_size
num_token = total_w + 1 - act_w
output_shape = (n, num_token, act_w, act_dim)
output = torch.empty(output_shape).to(action.device)
for i in range(num_token):
output[:, i, :, :] = action[:, i : i + act_w, :]
action = output
loss_dict = self.action_head.loss_fn(
pred_action,
action,
reduction="mean",
)
return pred_action, loss_dict
output = batch["action"][:, self.select_target_actions_indices]
loss = self.action_head.loss_fn(pred_action, output, reduction="mean")
return pred_action, loss
class VQBeTHead(nn.Module):