Bug: Fix VQ-BeT not working when n_action_pred_token=1 (#420)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
parent
9ff829a3a1
commit
f17d9a2ba1
|
@ -350,17 +350,22 @@ class VQBeTModel(nn.Module):
|
||||||
|
|
||||||
# get action features (pass through GPT)
|
# get action features (pass through GPT)
|
||||||
features = self.policy(input_tokens)
|
features = self.policy(input_tokens)
|
||||||
# len(self.config.input_shapes) is the number of different observation modes. this line gets the index of action prompt tokens.
|
# len(self.config.input_shapes) is the number of different observation modes.
|
||||||
|
# this line gets the index of action prompt tokens.
|
||||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_shapes) + 1) + len(
|
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_shapes) + 1) + len(
|
||||||
self.config.input_shapes
|
self.config.input_shapes
|
||||||
)
|
)
|
||||||
|
|
||||||
# only extract the output tokens at the position of action query:
|
# only extract the output tokens at the position of action query:
|
||||||
# Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models, mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://arxiv.org/pdf/2206.11251).
|
# Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models,
|
||||||
# Thus, it predict historical action sequence, in addition to current and future actions (predicting future actions : optional).
|
# mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://arxiv.org/pdf/2206.11251).
|
||||||
|
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions is optional).
|
||||||
|
if len_additional_action_token > 0:
|
||||||
features = torch.cat(
|
features = torch.cat(
|
||||||
[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
|
[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
features = features[:, historical_act_pred_index]
|
||||||
# pass through action head
|
# pass through action head
|
||||||
action_head_output = self.action_head(features)
|
action_head_output = self.action_head(features)
|
||||||
# if rollout, VQ-BeT doesn't calculate loss
|
# if rollout, VQ-BeT doesn't calculate loss
|
||||||
|
|
Loading…
Reference in New Issue