fix vqbet

This commit is contained in:
Remi Cadene 2024-07-17 00:45:34 +02:00
parent 4ea6481390
commit df303d7311
1 changed files with 1 additions and 1 deletions

View File

@ -298,7 +298,7 @@ class VQBeTModel(nn.Module):
# bin prediction head / offset prediction head part of VQ-BeT
self.action_head = VQBeTHead(config)
num_tokens = self.config.n_action_pred_token + self.config.action_chunk_size - 1
num_tokens = self.config.n_action_pred_token
self.register_buffer(
"select_target_actions_indices",
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),