diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py index fbe8773d..7533a819 100644 --- a/lerobot/common/policies/vqbet/configuration_vqbet.py +++ b/lerobot/common/policies/vqbet/configuration_vqbet.py @@ -104,6 +104,7 @@ class VQBeTConfig: gpt_n_layer: int = 8 gpt_n_head: int = 8 gpt_hidden_dim: int = 512 + gpt_num_obs_mode: int = 2 dropout: float = 0.1 mlp_hidden_dim: int = 1024 offset_loss_weight: float = 10000. diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index e8ca5d8c..817926ec 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -142,7 +142,6 @@ class VQBeTModel(nn.Module): # action token and EOS token self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) # Batch, Timestep, Data type, GPT input dim - self._eos_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) self.state_projector = MLP( config.output_shapes["action"][0], @@ -177,7 +176,6 @@ class VQBeTModel(nn.Module): ], dim=-2).view(batch_size, -1, self.config.gpt_input_dim) if img_features.shape[1] != n_obs_steps: raise NotImplementedError - # eos_token = self._eos_token.repeat(batch_size, 1, 1) # TODO(jayLEE0301) remove EOS token len_additional_action_token = self.config.n_action_pred_token-1 action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1) @@ -186,7 +184,7 @@ class VQBeTModel(nn.Module): # get action features features = self.policy(observation_feature) - historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2 # TODO(jayLEE0301) make it compatible with other values + historical_act_pred_index = np.arange(0, n_obs_steps) * (self.config.gpt_num_obs_mode+1) + self.config.gpt_num_obs_mode features = torch.cat([ features[:, historical_act_pred_index], features[:, -len_additional_action_token:] @@ -431,9 +429,6 @@ class VQBeTOptimizer: self.bet_optimizer1.add_param_group( {"params": policy.vqbet._action_token} ) - self.bet_optimizer1.add_param_group( - {"params": policy.vqbet._eos_token} - ) self.bet_optimizer1.add_param_group( {"params": policy.vqbet.state_projector.parameters()} ) diff --git a/lerobot/configs/policy/vqbet.yaml b/lerobot/configs/policy/vqbet.yaml index 0844bec9..524f1a21 100644 --- a/lerobot/configs/policy/vqbet.yaml +++ b/lerobot/configs/policy/vqbet.yaml @@ -96,6 +96,7 @@ policy: gpt_n_layer: 8 gpt_n_head: 8 gpt_hidden_dim: 512 + gpt_num_obs_mode: 2 dropout: 0.1 mlp_hidden_dim: 1024 offset_loss_weight: 10000.