Bug fix: fix error when setting select_target_actions_indices in vqbet (#310)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Seungjae Lee 2024-07-11 01:56:11 +09:00 committed by GitHub
parent e410e5d711
commit 64425d5e00
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 1 deletions

View File

@ -298,7 +298,8 @@ class VQBeTModel(nn.Module):
# bin prediction head / offset prediction head part of VQ-BeT
self.action_head = VQBeTHead(config)
num_tokens = self.config.n_action_pred_token + self.config.action_chunk_size - 1
# Action tokens for: each observation step, the current action token, and all future action tokens.
num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1
self.register_buffer(
"select_target_actions_indices",
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),