change confusing terms regarding action_tokens

This commit is contained in:
jayLEE0301 2024-06-05 11:23:24 -04:00
parent 8c775c94fc
commit dbc029b353
1 changed files with 4 additions and 4 deletions

View File

@ -298,18 +298,18 @@ class VQBeTModel(nn.Module):
# First project features to token dimension.
rgb_tokens = self.rgb_feature_projector(img_features) # (batch, obs_step, d)
state_tokens = self.state_projector(batch["observation.state"]) # (batch, obs_step, d)
action_tokens = einops.repeat(
history_action_tokens = einops.repeat(
self._action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps
)
# Interleave tokens by stacking and rearranging.
input_tokens = torch.stack([rgb_tokens, state_tokens, action_tokens], dim=2)
input_tokens = torch.stack([rgb_tokens, state_tokens, history_action_tokens], dim=2)
input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")
len_additional_action_token = self.config.n_action_pred_token-1
action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)
future_action_tokens = self._action_token.repeat(batch_size, len_additional_action_token, 1)
# add additional action query tokens for predicting future action chunks
input_tokens = torch.cat([input_tokens, action_token], dim=1)
input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)
# get action features (pass through GPT)