change confusing terms regarding action_tokens
This commit is contained in:
parent
8c775c94fc
commit
dbc029b353
|
@ -298,18 +298,18 @@ class VQBeTModel(nn.Module):
|
|||
# First project features to token dimension.
|
||||
rgb_tokens = self.rgb_feature_projector(img_features) # (batch, obs_step, d)
|
||||
state_tokens = self.state_projector(batch["observation.state"]) # (batch, obs_step, d)
|
||||
action_tokens = einops.repeat(
|
||||
history_action_tokens = einops.repeat(
|
||||
self._action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps
|
||||
)
|
||||
# Interleave tokens by stacking and rearranging.
|
||||
input_tokens = torch.stack([rgb_tokens, state_tokens, action_tokens], dim=2)
|
||||
input_tokens = torch.stack([rgb_tokens, state_tokens, history_action_tokens], dim=2)
|
||||
input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d")
|
||||
|
||||
len_additional_action_token = self.config.n_action_pred_token-1
|
||||
action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)
|
||||
future_action_tokens = self._action_token.repeat(batch_size, len_additional_action_token, 1)
|
||||
|
||||
# add additional action query tokens for predicting future action chunks
|
||||
input_tokens = torch.cat([input_tokens, action_token], dim=1)
|
||||
input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1)
|
||||
|
||||
|
||||
# get action features (pass through GPT)
|
||||
|
|
Loading…
Reference in New Issue