replace hardcoded token indexing with gpt_num_obs_mode, delete unused EOS token

jayLEE0301 2024-05-24 15:46:29 -04:00
parent b3e0ec1afd
commit d71db341bc
3 changed files with 3 additions and 6 deletions


@@ -104,6 +104,7 @@ class VQBeTConfig:
     gpt_n_layer: int = 8
     gpt_n_head: int = 8
     gpt_hidden_dim: int = 512
+    gpt_num_obs_mode: int = 2
     dropout: float = 0.1
     mlp_hidden_dim: int = 1024
     offset_loss_weight: float = 10000.
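Reading of the new field (an inference from this diff, not an authoritative description): gpt_num_obs_mode appears to count the observation modalities interleaved into the GPT input at each timestep, two by default, presumably an image feature and a state token. Together with the per-step action token this gives gpt_num_obs_mode + 1 input tokens per observation step, which is exactly the layout the index change below depends on; see the sketch after that hunk.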


@@ -142,7 +142,6 @@ class VQBeTModel(nn.Module):
         # action token and EOS token
         self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))  # Batch, Timestep, Data type, GPT input dim
-        self._eos_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
         self.state_projector = MLP(
             config.output_shapes["action"][0],
@@ -177,7 +176,6 @@ class VQBeTModel(nn.Module):
         ], dim=-2).view(batch_size, -1, self.config.gpt_input_dim)
         if img_features.shape[1] != n_obs_steps:
             raise NotImplementedError
-        # eos_token = self._eos_token.repeat(batch_size, 1, 1)  # TODO(jayLEE0301) remove EOS token
         len_additional_action_token = self.config.n_action_pred_token-1
         action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)
@@ -186,7 +184,7 @@ class VQBeTModel(nn.Module):
         # get action features
         features = self.policy(observation_feature)
-        historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2  # TODO(jayLEE0301) make it compatible with other values
+        historical_act_pred_index = np.arange(0, n_obs_steps) * (self.config.gpt_num_obs_mode+1) + self.config.gpt_num_obs_mode
         features = torch.cat([
             features[:, historical_act_pred_index],
             features[:, -len_additional_action_token:]
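The rewritten index line can be sanity-checked in isolation. A minimal sketch, assuming the default gpt_num_obs_mode = 2 and an illustrative n_obs_steps = 5 (both values are assumptions for the example, not taken from this diff):

import numpy as np

n_obs_steps = 5        # illustrative value
gpt_num_obs_mode = 2   # the default added in the config hunk above

# Each observation step contributes gpt_num_obs_mode observation tokens plus
# one action token, so consecutive steps are (gpt_num_obs_mode + 1) tokens
# apart and the action-prediction slot sits at offset gpt_num_obs_mode.
new_index = np.arange(0, n_obs_steps) * (gpt_num_obs_mode + 1) + gpt_num_obs_mode
old_index = np.arange(0, n_obs_steps) * 3 + 2  # the previously hardcoded form

assert (new_index == old_index).all()
print(new_index)  # [ 2  5  8 11 14]

With the constants factored out, supporting a different number of observation modalities only requires changing the config value rather than editing this indexing logic.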
@@ -431,9 +429,6 @@ class VQBeTOptimizer:
         self.bet_optimizer1.add_param_group(
             {"params": policy.vqbet._action_token}
         )
-        self.bet_optimizer1.add_param_group(
-            {"params": policy.vqbet._eos_token}
-        )
         self.bet_optimizer1.add_param_group(
             {"params": policy.vqbet.state_projector.parameters()}
         )


@@ -96,6 +96,7 @@ policy:
   gpt_n_layer: 8
   gpt_n_head: 8
   gpt_hidden_dim: 512
+  gpt_num_obs_mode: 2
   dropout: 0.1
   mlp_hidden_dim: 1024
   offset_loss_weight: 10000.