From 340f7cfd6e4622fec180c783d88b9d35a9555900 Mon Sep 17 00:00:00 2001
From: jayLEE0301
Date: Thu, 23 May 2024 19:08:37 -0400
Subject: [PATCH] add comments, change param names, delete kwargs

---
 .../common/policies/vqbet/modeling_vqbet.py | 87 +++++++++----------
 1 file changed, 40 insertions(+), 47 deletions(-)

diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py
index 54ff2569..c29a08db 100644
--- a/lerobot/common/policies/vqbet/modeling_vqbet.py
+++ b/lerobot/common/policies/vqbet/modeling_vqbet.py
@@ -96,7 +96,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         if not self.check_discretized():
             self.vqbet._action_head._vqvae_model.discretized = True
-            warnings.warn('To evaluate in the environment, the model was forced to stop learning the Residual VQ. If you are not evaluating with a pre-trained model, this can degrade overall performance.')
+            # VQ-BeT can predict actions only after the action discretization step has finished.
+            # We force self.vqbet._action_head._vqvae_model.discretized to True here to cover prediction with a pretrained model; when training from scratch this branch should never be reached, so set "eval_freq" greater than "discretize_step".
+            warnings.warn('To evaluate in the environment, the model was forced to stop learning the Residual VQ. If you are not evaluating with a pre-trained model, this can degrade overall performance. To avoid this warning, please set "eval_freq" greater than "discretize_step".')

         assert "observation.image" in batch
         assert "observation.state" in batch
@@ -105,8 +107,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
             batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
             actions = self.vqbet(batch, rollout=True)[:, : self.config.n_action_pred_chunk]

+            # the returned actions have shape (batch_size, n_action_pred_chunk, action_dim)
             actions = self.unnormalize_outputs({"action": actions})["action"]
-
+            # the action queue stores entries of shape (n_action_pred_chunk, batch_size, action_dim), so we transpose the actions before filling the queue
             self._queues["action"].extend(actions.transpose(0, 1))

         action = self._queues["action"].popleft()
@@ -116,6 +119,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         batch = self.normalize_targets(batch)
+        # VQ-BeT discretizes actions with a VQ-VAE before training BeT (see section 3.2 of the VQ-BeT paper, https://arxiv.org/pdf/2403.03181)
         if not self.check_discretized():
             loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
             return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
@@ -126,13 +130,15 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):


 class VQBeTModel(nn.Module):
+    """
+    TODO(jayLEE0301)
+    """
     def __init__(self, config: VQBeTConfig):
         super().__init__()
         self.config = config

         self.rgb_encoder = VQBeTRgbEncoder(config)
-        self.global_cond_dim = self.rgb_encoder.feature_dim

         # action token and EOS token
         self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) # Batch, Timestep, Data type, GPT input dim

         self.state_projector = MLP(
             hidden_channels=[self.config.gpt_input_dim]
         )
         self.obs_projector = MLP(
-            self.global_cond_dim,
+            self.rgb_encoder.feature_dim,
             hidden_channels=[self.config.gpt_input_dim]
         )
         self._policy = GPT(config)
@@ -152,7 +158,6 @@ class VQBeTModel(nn.Module):
     def discretize(self, discretize_step, actions):
         return self._action_head.discretize(discretize_step, actions)

-    # ========= inference ============
     def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
         # Input validation.
         assert set(batch).issuperset({"observation.state", "observation.image"})
@@ -163,10 +168,9 @@ class VQBeTModel(nn.Module):
         img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
         # Separate batch and sequence dims.
         img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
-        # Concatenate state and image features then flatten to (B, global_cond_dim).
-        global_cond = torch.cat([
+        observation_feature = torch.cat([
             torch.unsqueeze(self.obs_projector(img_features), dim=2),
             torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2),
             self._action_token.repeat(batch_size, n_obs_steps, 1, 1)
@@ -177,12 +181,11 @@ class VQBeTModel(nn.Module):
         len_additional_action_token = self.config.n_action_pred_token-1
         action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)

-        # prompt_length = global_cond.shape[1]+1
-        global_cond = torch.cat([global_cond, action_token], dim=1)
+        observation_feature = torch.cat([observation_feature, action_token], dim=1)

         # get action features
-        features = self._policy(global_cond)
+        features = self._policy(observation_feature)
         historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2 # TODO(jayLEE0301) make it compatible with other values
         features = torch.cat([
             features[:, historical_act_pred_index],
@@ -221,26 +224,18 @@ class VQBeTHead(nn.Module):
         """
         super().__init__()

-        self.input_size = config.gpt_output_dim
-        self.output_size = config.output_shapes["action"][0]
-        self.hidden_size = config.mlp_hidden_dim
-        self.offset_loss_weight = config.offset_loss_weight
-        self.secondary_code_loss_weight = config.secondary_code_loss_weight
+        self.config = config

-        self.vqvae_groups = config.vqvae_groups
-        self.vqvae_n_embed = config.vqvae_n_embed  # C(number of code integers)
-        self.vqvae_embedding_dim = config.vqvae_embedding_dim  # D(embedding dims)
-        self.n_action_pred_chunk = config.n_action_pred_chunk  # action chunk size

         self._map_to_cbet_preds_bin = MLP(
-            in_channels=self.input_size,
-            hidden_channels=[self.vqvae_groups * self.vqvae_n_embed],
+            in_channels=config.gpt_output_dim,
+            hidden_channels=[self.config.vqvae_groups * self.config.vqvae_n_embed],
         )
         self._map_to_cbet_preds_offset = MLP(
-            in_channels=self.input_size,
+            in_channels=config.gpt_output_dim,
             hidden_channels=[
-                self.vqvae_groups * self.vqvae_n_embed * config.n_action_pred_chunk * self.output_size,
+                self.config.vqvae_groups * self.config.vqvae_n_embed * config.n_action_pred_chunk * config.output_shapes["action"][0],
             ],
         )
         # init vqvae
@@ -270,10 +265,10 @@ class VQBeTHead(nn.Module):
         cbet_logits = self._map_to_cbet_preds_bin(x)
         cbet_offsets = self._map_to_cbet_preds_offset(x)
         cbet_logits = einops.rearrange(
-            cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_groups
+            cbet_logits, "(NT) (G C) -> (NT) G C", G=self.config.vqvae_groups
         )
         cbet_offsets = einops.rearrange(
-            cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self.vqvae_groups, C=self.vqvae_n_embed
+            cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self.config.vqvae_groups, C=self.config.vqvae_n_embed
         )
         cbet_probs = torch.softmax(cbet_logits, dim=-1)
         NT, G, choices = cbet_probs.shape
@@ -285,7 +280,7 @@ class VQBeTHead(nn.Module):
         indices = (
             torch.arange(NT).unsqueeze(1).cuda(),
-            torch.arange(self.vqvae_groups).unsqueeze(0).cuda(),
+            torch.arange(self.config.vqvae_groups).unsqueeze(0).cuda(),
             sampled_centers,
         )
         # Use advanced indexing to sample the values
@@ -293,7 +288,7 @@ class VQBeTHead(nn.Module):
         sampled_offsets = sampled_offsets.sum(dim=1)

         centers = self._vqvae_model.draw_code_forward(sampled_centers).view(
-            NT, -1, self.vqvae_embedding_dim
+            NT, -1, self.config.vqvae_embedding_dim
         )
         return_decoder_input = einops.rearrange(
             centers.clone().detach(), "NT 1 D -> NT D"
         )
@@ -304,7 +299,7 @@ class VQBeTHead(nn.Module):
             .detach()
         )  # NT, A
         sampled_offsets = einops.rearrange(
-            sampled_offsets, "NT (W A) -> NT W A", W=self.n_action_pred_chunk
+            sampled_offsets, "NT (W A) -> NT W A", W=self.config.n_action_pred_chunk
         )
         predicted_action = decoded_action + sampled_offsets
         predicted_action = einops.rearrange(
             predicted_action,
             "(N T) W A -> N T (W A)",
             N=N,
             T=T,
-            W=self.n_action_pred_chunk,
+            W=self.config.n_action_pred_chunk,
         )

         return {
@@ -333,7 +328,7 @@ class VQBeTHead(nn.Module):
         cbet_logits = pred["cbet_logits"]

         predicted_action = einops.rearrange(
-            predicted_action, "N T (W A) -> (N T) W A", W=self.n_action_pred_chunk
+            predicted_action, "N T (W A) -> (N T) W A", W=self.config.n_action_pred_chunk
         )

         action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
@@ -358,7 +353,7 @@ class VQBeTHead(nn.Module):
             cbet_logits[:, 1, :],
             action_bins[:, 1],
         )
-        cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.secondary_code_loss_weight
+        cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.config.secondary_code_loss_weight

         equal_primary_code_rate = torch.sum(
             (action_bins[:, 0] == sampled_centers[:, 0]).int()
@@ -387,7 +382,7 @@ class VQBeTHead(nn.Module):
                 )
             ).max()

-        loss = cbet_loss + self.offset_loss_weight * offset_loss
+        loss = cbet_loss + self.config.offset_loss_weight * offset_loss

         loss_dict = {
             "loss": loss,
@@ -626,8 +621,7 @@ class VqVae(nn.Module):
     ):
         super(VqVae, self).__init__()

-        self.n_action_pred_chunk = config.n_action_pred_chunk
-        self.action_dim = config.output_shapes["action"][0]
+        self.config = config

         self.discretized = False
         self.optimized_steps = 0
@@ -637,25 +631,24 @@ class VqVae(nn.Module):
             num_quantizers=config.vqvae_groups,
             codebook_size=config.vqvae_n_embed,
         )
-        self.embedding_dim = config.vqvae_embedding_dim

-        if self.n_action_pred_chunk == 1:
+        if self.config.n_action_pred_chunk == 1:
             self.encoder = MLP(
-                in_channels=self.action_dim,
+                in_channels=self.config.output_shapes["action"][0],
                 hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
             )
             self.decoder = MLP(
                 in_channels=config.vqvae_embedding_dim,
-                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.action_dim],
+                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0]],
             )
         else:
             self.encoder = MLP(
-                in_channels=self.action_dim * self.n_action_pred_chunk,
+                in_channels=self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk,
                 hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
             )
             self.decoder = MLP(
                 in_channels=config.vqvae_embedding_dim,
-                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.action_dim * self.n_action_pred_chunk],
+                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk],
             )
         self.train()

@@ -690,15 +683,15 @@ class VqVae(nn.Module):

     def get_action_from_latent(self, latent):
         output = self.decoder(latent)
-        if self.n_action_pred_chunk == 1:
-            return einops.rearrange(output, "N (T A) -> N T A", A=self.action_dim)
+        if self.config.n_action_pred_chunk == 1:
+            return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
         else:
-            return einops.rearrange(output, "N (T A) -> N T A", A=self.action_dim)
+            return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])

     def preprocess(self, state):
         if not torch.is_tensor(state):
             state = torch.FloatTensor(state.copy())
-        if self.n_action_pred_chunk == 1:
+        if self.config.n_action_pred_chunk == 1:
             state = state.squeeze(-2)  # state.squeeze(-1)
         else:
             state = einops.rearrange(state, "N T A -> N (T A)")
@@ -717,7 +710,7 @@ class VqVae(nn.Module):
         if required_recon:
             recon_state = self.decoder(state_vq)
             recon_state_ae = self.decoder(state_rep)
-            if self.n_action_pred_chunk == 1:
+            if self.config.n_action_pred_chunk == 1:
                 return state_vq, vq_code, recon_state, recon_state_ae
             else:
                 return (
@@ -764,13 +757,13 @@

 def pretrain_vqvae(vqvae_model, discretize_step, actions):
-    if vqvae_model.n_action_pred_chunk == 1:
+    if vqvae_model.config.n_action_pred_chunk == 1:
         # not using action chunk
         actions = actions.reshape(-1, 1, actions.shape[-1])
     else:
         # using action chunk
         slices = []
-        slices.extend([actions[:, j:j+vqvae_model.n_action_pred_chunk, :] for j in range(actions.shape[1]+1-vqvae_model.n_action_pred_chunk)])
+        slices.extend([actions[:, j:j+vqvae_model.config.n_action_pred_chunk, :] for j in range(actions.shape[1]+1-vqvae_model.config.n_action_pred_chunk)])
         actions = torch.cat(slices, dim=0)
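
Note on the rollout comments in the select_action hunks: the policy predicts a chunk of n_action_pred_chunk actions at once, queues them, and pops one action per environment step. A minimal sketch of that queueing logic, assuming hypothetical sizes (batch_size=2, n_action_pred_chunk=3, action_dim=7):

    from collections import deque

    import torch

    # model output: (batch_size, n_action_pred_chunk, action_dim)
    actions = torch.randn(2, 3, 7)

    queue = deque(maxlen=3)
    # transpose to (n_action_pred_chunk, batch_size, action_dim) so that each
    # queue entry holds one timestep's action for the whole batch
    queue.extend(actions.transpose(0, 1))

    action = queue.popleft()  # (batch_size, action_dim) for the current step
    assert action.shape == (2, 7)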
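Note on the VQBeTHead hunks: the head predicts a code per residual VQ group plus a continuous offset attached to that code, then sums the per-group offsets. The patch shows the softmax and the advanced-indexing gather but not how sampled_centers is drawn, so the sampling step below is an assumption (multinomial over the categorical head); sizes are made up (NT=8 flattened batch*time steps, G=2 groups, C=16 codes, WA=7 offset dims) and CPU tensors stand in for the .cuda() calls:

    import torch

    NT, G, C, WA = 8, 2, 16, 7
    cbet_logits = torch.randn(NT, G, C)
    cbet_offsets = torch.randn(NT, G, C, WA)

    # one code per group, sampled from the softmax over C codes
    # (assumed sampling step; the patch only shows the probabilities)
    cbet_probs = torch.softmax(cbet_logits, dim=-1)
    sampled_centers = torch.multinomial(cbet_probs.view(-1, C), num_samples=1).view(NT, G)

    # advanced indexing picks the offset attached to each sampled code,
    # then the per-group offsets are summed, mirroring the residual VQ structure
    indices = (torch.arange(NT).unsqueeze(1), torch.arange(G).unsqueeze(0), sampled_centers)
    sampled_offsets = cbet_offsets[indices].sum(dim=1)  # (NT, WA)
    assert sampled_offsets.shape == (NT, WA)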
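Note on the pretrain_vqvae hunk: when action chunking is enabled, every length-n_action_pred_chunk sliding window of each action sequence becomes one VQ-VAE training sample. A sketch with hypothetical shapes (4 sequences of 10 steps, 7-dim actions, chunk size 3 standing in for vqvae_model.config.n_action_pred_chunk):

    import torch

    actions = torch.randn(4, 10, 7)
    chunk = 3

    # each of the 10 + 1 - 3 = 8 sliding windows becomes a training sample
    slices = [actions[:, j : j + chunk, :] for j in range(actions.shape[1] + 1 - chunk)]
    actions = torch.cat(slices, dim=0)
    assert actions.shape == (32, 3, 7)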