Rename obs_projector to rgb_feature_projector; replace NotImplementedError with assert

This commit is contained in:
jayLEE0301 2024-06-04 17:20:57 -04:00
parent 9f109538d9
commit 18b19b95f8
1 changed file with 4 additions and 5 deletions


@@ -262,7 +262,7 @@ class VQBeTModel(nn.Module):
             config.output_shapes["action"][0],
             hidden_channels=[self.config.gpt_input_dim]
         )
-        self.obs_projector = MLP(
+        self.rgb_feature_projector = MLP(
             self.rgb_encoder.feature_dim,
             hidden_channels=[self.config.gpt_input_dim]
         )
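
For context, the renamed module is just a small projection head: it maps the RGB encoder's feature vector into the GPT token width. The snippet below is a minimal sketch of that idea, assuming a plain torch.nn stack; make_projector, the 512/768 dimensions, and the tensor shapes are illustrative placeholders, not the repository's MLP helper or config values.

# Hypothetical sketch only: a stand-in for the MLP helper used above,
# projecting RGB encoder features (here 512-d) to the GPT token width (here 768-d).
import torch
from torch import nn

def make_projector(in_dim, hidden_channels):
    # Linear layers with ReLU between them, ending at the last hidden size.
    layers, prev = [], in_dim
    for i, width in enumerate(hidden_channels):
        layers.append(nn.Linear(prev, width))
        if i < len(hidden_channels) - 1:
            layers.append(nn.ReLU())
        prev = width
    return nn.Sequential(*layers)

rgb_feature_projector = make_projector(512, [768])
img_features = torch.randn(8, 5, 512)             # (batch, n_obs_steps, feature_dim), made-up shapes
tokens = rgb_feature_projector(img_features)      # -> (8, 5, 768), one token per observation step
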
@@ -288,12 +288,11 @@ class VQBeTModel(nn.Module):
         # image observation feature, state feature, and action query token are grouped together with the same timestpe to form a group, which is listed in order to be entered into GPT sequentially.
         observation_feature = torch.cat([
-            torch.unsqueeze(self.obs_projector(img_features), dim=2),
+            torch.unsqueeze(self.rgb_feature_projector(img_features), dim=2),
             torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2),
             self._action_token.repeat(batch_size, n_obs_steps, 1, 1)
         ], dim=-2).view(batch_size, -1, self.config.gpt_input_dim)
-        if img_features.shape[1] != n_obs_steps:
-            raise NotImplementedError
+        assert img_features.shape[1] == n_obs_steps, "The number of input image feature tokens should be same with n_obs_steps"
         len_additional_action_token = self.config.n_action_pred_token-1
         action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)
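
To make the grouping described in the comment concrete: for each observation step, one projected image token, one projected state token, and one learnable action-query token are stacked into a group, and the groups are flattened into a single sequence for the GPT. The sketch below uses stand-in linear projectors and invented shapes; only the concatenation pattern and the assert mirror the diff.

# Stand-in modules and invented shapes; only the token-grouping pattern and
# the assert follow the hunk above.
import torch
from torch import nn

batch_size, n_obs_steps, gpt_input_dim = 8, 5, 768
img_features = torch.randn(batch_size, n_obs_steps, 512)
state = torch.randn(batch_size, n_obs_steps, 14)

rgb_feature_projector = nn.Linear(512, gpt_input_dim)             # stand-in for the MLP projector
state_projector = nn.Linear(14, gpt_input_dim)                    # stand-in
action_token = nn.Parameter(torch.randn(1, 1, 1, gpt_input_dim))  # learnable action query token

# The assert introduced by this commit: fail loudly if the image features
# do not provide exactly one token per observation step.
assert img_features.shape[1] == n_obs_steps, "The number of input image feature tokens should be same with n_obs_steps"

observation_feature = torch.cat([
    torch.unsqueeze(rgb_feature_projector(img_features), dim=2),  # (B, T, 1, D)
    torch.unsqueeze(state_projector(state), dim=2),               # (B, T, 1, D)
    action_token.repeat(batch_size, n_obs_steps, 1, 1),           # (B, T, 1, D)
], dim=-2).view(batch_size, -1, gpt_input_dim)                    # (B, T*3, D) sequence for the GPT
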
@@ -577,7 +576,7 @@ class VQBeTOptimizer:
             {"params": policy.vqbet.state_projector.parameters()}
         )
         self.bet_optimizer1.add_param_group(
-            {"params": policy.vqbet.obs_projector.parameters()}
+            {"params": policy.vqbet.rgb_feature_projector.parameters()}
         )
         self.bet_optimizer2 = torch.optim.AdamW(
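
Because the module was renamed, the optimizer's parameter group has to be re-pointed as well; a group still referencing the old attribute name would raise an AttributeError, and omitting it would leave the projector untrained. Below is a hedged sketch of the add_param_group pattern with placeholder modules and a placeholder learning rate, not lerobot's actual optimizer setup.

# Placeholder modules and learning rate; only the add_param_group pattern
# mirrors the hunk above.
import torch
from torch import nn

state_projector = nn.Linear(14, 768)
rgb_feature_projector = nn.Linear(512, 768)

bet_optimizer1 = torch.optim.AdamW(state_projector.parameters(), lr=1e-4)
# Register the renamed projector so its weights receive gradient updates
# from the same optimizer.
bet_optimizer1.add_param_group({"params": rgb_feature_projector.parameters()})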