fix obs_projector to rgb_feature_projector, replace notimplementederror to assert
This commit is contained in:
parent
9f109538d9
commit
18b19b95f8
|
@ -262,7 +262,7 @@ class VQBeTModel(nn.Module):
|
|||
config.output_shapes["action"][0],
|
||||
hidden_channels=[self.config.gpt_input_dim]
|
||||
)
|
||||
self.obs_projector = MLP(
|
||||
self.rgb_feature_projector = MLP(
|
||||
self.rgb_encoder.feature_dim,
|
||||
hidden_channels=[self.config.gpt_input_dim]
|
||||
)
|
||||
|
@ -288,12 +288,11 @@ class VQBeTModel(nn.Module):
|
|||
|
||||
# image observation feature, state feature, and action query token are grouped together with the same timestpe to form a group, which is listed in order to be entered into GPT sequentially.
|
||||
observation_feature = torch.cat([
|
||||
torch.unsqueeze(self.obs_projector(img_features), dim=2),
|
||||
torch.unsqueeze(self.rgb_feature_projector(img_features), dim=2),
|
||||
torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2),
|
||||
self._action_token.repeat(batch_size, n_obs_steps, 1, 1)
|
||||
], dim=-2).view(batch_size, -1, self.config.gpt_input_dim)
|
||||
if img_features.shape[1] != n_obs_steps:
|
||||
raise NotImplementedError
|
||||
assert img_features.shape[1] == n_obs_steps, "The number of input image feature tokens should be same with n_obs_steps"
|
||||
len_additional_action_token = self.config.n_action_pred_token-1
|
||||
action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)
|
||||
|
||||
|
@ -577,7 +576,7 @@ class VQBeTOptimizer:
|
|||
{"params": policy.vqbet.state_projector.parameters()}
|
||||
)
|
||||
self.bet_optimizer1.add_param_group(
|
||||
{"params": policy.vqbet.obs_projector.parameters()}
|
||||
{"params": policy.vqbet.rgb_feature_projector.parameters()}
|
||||
)
|
||||
|
||||
self.bet_optimizer2 = torch.optim.AdamW(
|
||||
|
|
Loading…
Reference in New Issue