replace hardcoded part with gpt_num_obs_mode, delete unused eos token
This commit is contained in:
parent
b3e0ec1afd
commit
d71db341bc
|
@ -104,6 +104,7 @@ class VQBeTConfig:
|
|||
gpt_n_layer: int = 8
|
||||
gpt_n_head: int = 8
|
||||
gpt_hidden_dim: int = 512
|
||||
gpt_num_obs_mode: int = 2
|
||||
dropout: float = 0.1
|
||||
mlp_hidden_dim: int = 1024
|
||||
offset_loss_weight: float = 10000.
|
||||
|
|
|
@ -142,7 +142,6 @@ class VQBeTModel(nn.Module):
|
|||
|
||||
# action token and EOS token
|
||||
self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) # Batch, Timestep, Data type, GPT input dim
|
||||
self._eos_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
|
||||
|
||||
self.state_projector = MLP(
|
||||
config.output_shapes["action"][0],
|
||||
|
@ -177,7 +176,6 @@ class VQBeTModel(nn.Module):
|
|||
], dim=-2).view(batch_size, -1, self.config.gpt_input_dim)
|
||||
if img_features.shape[1] != n_obs_steps:
|
||||
raise NotImplementedError
|
||||
# eos_token = self._eos_token.repeat(batch_size, 1, 1) # TODO(jayLEE0301) remove EOS token
|
||||
len_additional_action_token = self.config.n_action_pred_token-1
|
||||
action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)
|
||||
|
||||
|
@ -186,7 +184,7 @@ class VQBeTModel(nn.Module):
|
|||
|
||||
# get action features
|
||||
features = self.policy(observation_feature)
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2 # TODO(jayLEE0301) make it compatible with other values
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (self.config.gpt_num_obs_mode+1) + self.config.gpt_num_obs_mode
|
||||
features = torch.cat([
|
||||
features[:, historical_act_pred_index],
|
||||
features[:, -len_additional_action_token:]
|
||||
|
@ -431,9 +429,6 @@ class VQBeTOptimizer:
|
|||
self.bet_optimizer1.add_param_group(
|
||||
{"params": policy.vqbet._action_token}
|
||||
)
|
||||
self.bet_optimizer1.add_param_group(
|
||||
{"params": policy.vqbet._eos_token}
|
||||
)
|
||||
self.bet_optimizer1.add_param_group(
|
||||
{"params": policy.vqbet.state_projector.parameters()}
|
||||
)
|
||||
|
|
|
@ -96,6 +96,7 @@ policy:
|
|||
gpt_n_layer: 8
|
||||
gpt_n_head: 8
|
||||
gpt_hidden_dim: 512
|
||||
gpt_num_obs_mode: 2
|
||||
dropout: 0.1
|
||||
mlp_hidden_dim: 1024
|
||||
offset_loss_weight: 10000.
|
||||
|
|
Loading…
Reference in New Issue