Replace hardcoded observation-mode count with gpt_num_obs_mode, delete unused EOS token
parent b3e0ec1afd
commit d71db341bc
@@ -104,6 +104,7 @@ class VQBeTConfig:
     gpt_n_layer: int = 8
     gpt_n_head: int = 8
     gpt_hidden_dim: int = 512
+    gpt_num_obs_mode: int = 2
     dropout: float = 0.1
     mlp_hidden_dim: int = 1024
     offset_loss_weight: float = 10000.
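A rough sketch of what the new field implies for the GPT input length, under the assumption (taken from the model code further down) that each observation step contributes gpt_num_obs_mode observation tokens plus one action token, and that n_action_pred_token - 1 extra action tokens are appended at the end; the numeric values below are illustrative, not from the commit:

# Illustrative sketch only: GPT sequence length implied by the new config field,
# assuming gpt_num_obs_mode observation tokens + 1 action token per observation
# step, plus (n_action_pred_token - 1) additional action tokens at the end.
gpt_num_obs_mode = 2      # new config default
n_obs_steps = 5           # illustrative value
n_action_pred_token = 3   # illustrative value

tokens_per_step = gpt_num_obs_mode + 1
gpt_sequence_length = n_obs_steps * tokens_per_step + (n_action_pred_token - 1)
print(gpt_sequence_length)  # 17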
@@ -142,7 +142,6 @@ class VQBeTModel(nn.Module):

         # action token and EOS token
         self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) # Batch, Timestep, Data type, GPT input dim
-        self._eos_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))

         self.state_projector = MLP(
             config.output_shapes["action"][0],
@@ -177,7 +176,6 @@ class VQBeTModel(nn.Module):
             ], dim=-2).view(batch_size, -1, self.config.gpt_input_dim)
         if img_features.shape[1] != n_obs_steps:
             raise NotImplementedError
-        # eos_token = self._eos_token.repeat(batch_size, 1, 1) # TODO(jayLEE0301) remove EOS token
         len_additional_action_token = self.config.n_action_pred_token-1
         action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)

@@ -186,7 +184,7 @@ class VQBeTModel(nn.Module):

         # get action features
         features = self.policy(observation_feature)
-        historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2 # TODO(jayLEE0301) make it compatible with other values
+        historical_act_pred_index = np.arange(0, n_obs_steps) * (self.config.gpt_num_obs_mode+1) + self.config.gpt_num_obs_mode
         features = torch.cat([
             features[:, historical_act_pred_index],
             features[:, -len_additional_action_token:]
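The rewritten index can be sanity-checked against the old hardcoded form: with gpt_num_obs_mode = 2, the new expression reduces to the previous `* 3 + 2`. A minimal sketch, assuming the interleaved token layout described above (values are illustrative):

import numpy as np

# Minimal check that the configurable index matches the old hardcoded one
# when there are two observation modes per step.
n_obs_steps = 4        # illustrative value
gpt_num_obs_mode = 2   # new config default

new_index = np.arange(0, n_obs_steps) * (gpt_num_obs_mode + 1) + gpt_num_obs_mode
old_index = np.arange(0, n_obs_steps) * 3 + 2  # hardcoded form removed by this commit

assert (new_index == old_index).all()
print(new_index)  # [ 2  5  8 11]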
@@ -431,9 +429,6 @@ class VQBeTOptimizer:
         self.bet_optimizer1.add_param_group(
             {"params": policy.vqbet._action_token}
         )
-        self.bet_optimizer1.add_param_group(
-            {"params": policy.vqbet._eos_token}
-        )
         self.bet_optimizer1.add_param_group(
             {"params": policy.vqbet.state_projector.parameters()}
         )
@@ -96,6 +96,7 @@ policy:
   gpt_n_layer: 8
   gpt_n_head: 8
   gpt_hidden_dim: 512
+  gpt_num_obs_mode: 2
   dropout: 0.1
   mlp_hidden_dim: 1024
   offset_loss_weight: 10000.