remove gpt_num_obs_mode
This commit is contained in:
parent
b3fc3b7e21
commit
31245325bd
|
@ -51,7 +51,6 @@ class VQBeTConfig:
|
|||
gpt_n_layer: Number of layers of GPT
|
||||
gpt_n_head: Number of headers of GPT
|
||||
gpt_hidden_dim: Size of hidden dimensions of GPT
|
||||
gpt_num_obs_mode: Number of different observation modes. (e.g., PushT env: {state, image observation}, thus 2.)
|
||||
dropout: Dropout rate for GPT
|
||||
mlp_hidden_dim: Size of hidden dimensions of offset header / bin prediction headers parts of VQ-BeT
|
||||
offset_loss_weight: A constant that is multiplied to the offset loss
|
||||
|
@ -109,7 +108,6 @@ class VQBeTConfig:
|
|||
gpt_n_layer: int = 8
|
||||
gpt_n_head: int = 8
|
||||
gpt_hidden_dim: int = 512
|
||||
gpt_num_obs_mode: int = 2
|
||||
dropout: float = 0.1
|
||||
mlp_hidden_dim: int = 1024
|
||||
offset_loss_weight: float = 10000.
|
||||
|
|
|
@ -314,7 +314,8 @@ class VQBeTModel(nn.Module):
|
|||
|
||||
# get action features (pass through GPT)
|
||||
features = self.policy(input_tokens)
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (self.config.gpt_num_obs_mode+1) + self.config.gpt_num_obs_mode
|
||||
# len(self.config.input_shapes) is the number of different observation modes. this line gets the index of action prompt tokens.
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_shapes)+1) + len(self.config.input_shapes)
|
||||
|
||||
# only extract the output tokens at the position of action query
|
||||
features = torch.cat([
|
||||
|
|
|
@ -96,7 +96,6 @@ policy:
|
|||
gpt_n_layer: 8
|
||||
gpt_n_head: 8
|
||||
gpt_hidden_dim: 512
|
||||
gpt_num_obs_mode: 2
|
||||
dropout: 0.1
|
||||
mlp_hidden_dim: 1024
|
||||
offset_loss_weight: 10000.
|
||||
|
|
Loading…
Reference in New Issue