simplified action stacking part
This commit is contained in:
parent
c4c5977f37
commit
8c775c94fc
|
@ -272,6 +272,14 @@ class VQBeTModel(nn.Module):
|
|||
# bin prediction head / offset prediction head part of VQ-BeT
|
||||
self.action_head = VQBeTHead(config)
|
||||
|
||||
num_tokens = self.config.n_action_pred_token + self.config.action_chunk_size - 1
|
||||
self.register_buffer(
|
||||
"select_target_actions_indices",
|
||||
torch.row_stack(
|
||||
[torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]
|
||||
),
|
||||
)
|
||||
|
||||
# Thin wrapper around the action head: forwards `discretize_step` and
# `actions` unchanged to `self.action_head.discretize` and returns that
# call's result as-is.
def discretize(self, discretize_step, actions):
|
||||
return self.action_head.discretize(discretize_step, actions)
|
||||
|
||||
|
@ -320,22 +328,9 @@ class VQBeTModel(nn.Module):
|
|||
return pred_action["predicted_action"][:, n_obs_steps-1, :].reshape(batch_size, self.config.action_chunk_size, -1)
|
||||
# otherwise, it calculates the overall loss (bin prediction loss and offset loss)
|
||||
else:
|
||||
action = batch["action"]
|
||||
n, total_w, act_dim = action.shape
|
||||
act_w = self.config.action_chunk_size
|
||||
num_token = total_w + 1 - act_w
|
||||
output_shape = (n, num_token, act_w, act_dim)
|
||||
output = torch.empty(output_shape).to(action.device)
|
||||
for i in range(num_token):
|
||||
output[:, i, :, :] = action[:, i : i + act_w, :]
|
||||
action = output
|
||||
|
||||
loss_dict = self.action_head.loss_fn(
|
||||
pred_action,
|
||||
action,
|
||||
reduction="mean",
|
||||
)
|
||||
return pred_action, loss_dict
|
||||
output = batch["action"][:, self.select_target_actions_indices]
|
||||
loss = self.action_head.loss_fn(pred_action, output, reduction="mean")
|
||||
return pred_action, loss
|
||||
|
||||
|
||||
class VQBeTHead(nn.Module):
|
||||
|
|
Loading…
Reference in New Issue