From 0b663243e39941e7d6d8b1cf068f408b349f79c9 Mon Sep 17 00:00:00 2001
From: jayLEE0301
Date: Fri, 24 May 2024 14:20:58 -0400
Subject: [PATCH] Remove leading underscores from model attribute names

---
 .../common/policies/vqbet/modeling_vqbet.py | 66 +++++++++----------
 1 file changed, 33 insertions(+), 33 deletions(-)

diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py
index c29a08db..1de6d79c 100644
--- a/lerobot/common/policies/vqbet/modeling_vqbet.py
+++ b/lerobot/common/policies/vqbet/modeling_vqbet.py
@@ -67,7 +67,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
 
         self.vqbet = VQBeTModel(config)
 
     def check_discretized(self):
-        return self.vqbet._action_head._vqvae_model.discretized
+        return self.vqbet.action_head.vqvae_model.discretized
 
     def reset(self):
@@ -95,9 +95,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         self._queues = populate_queues(self._queues, batch)
 
         if not self.check_discretized():
-            self.vqbet._action_head._vqvae_model.discretized = True
+            self.vqbet.action_head.vqvae_model.discretized = True
             # VQ-BeT can predict action only after finishing action discretization.
-            # We added a logit to force self.vqbet._action_head._vqvae_model.discretized to be True if not self.check_discretized() to account for the case of predicting with a pretrained model, but this shouldn't happen if you're learning from scratch, so set eval_freq greater than discretize_step.
+            # We added logic to force self.vqbet.action_head.vqvae_model.discretized to be True if not self.check_discretized(), to account for the case of predicting with a pretrained model. This shouldn't happen if you're learning from scratch, so set eval_freq greater than discretize_step.
             warnings.warn('To evaluate in the environment, the model was forced to stop learning the Residual VQ. If you are not evaluating with a pre-trained model, this can degrade overall performance. To avoid this warning, please set "eval_freq" greater than "discretize_step".')
         assert "observation.image" in batch
         assert "observation.state" in batch
@@ -152,11 +152,11 @@ class VQBeTModel(nn.Module):
             self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
         )
 
-        self._policy = GPT(config)
-        self._action_head = VQBeTHead(config)
+        self.policy = GPT(config)
+        self.action_head = VQBeTHead(config)
 
     def discretize(self, discretize_step, actions):
-        return self._action_head.discretize(discretize_step, actions)
+        return self.action_head.discretize(discretize_step, actions)
 
     def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
         # Input validation.
@@ -185,14 +185,14 @@ class VQBeTModel(nn.Module):
 
         # get action features
-        features = self._policy(observation_feature)
+        features = self.policy(observation_feature)
         historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2
         # TODO(jayLEE0301) make it compatible with other values
         features = torch.cat([
             features[:, historical_act_pred_index],
             features[:, -len_additional_action_token:]
         ], dim=1)
 
         # action head
-        pred_action = self._action_head(
+        pred_action = self.action_head(
             features,
         )
@@ -209,7 +209,7 @@ class VQBeTModel(nn.Module):
                 output[:, i, :, :] = action[:, i : i + act_w, :]
             action = output
 
-            loss = self._action_head.loss_fn(
+            loss = self.action_head.loss_fn(
                 pred_action,
                 action,
                 reduction="mean",
@@ -228,33 +228,33 @@ class VQBeTHead(nn.Module):
 
-        self._map_to_cbet_preds_bin = MLP(
+        self.map_to_cbet_preds_bin = MLP(
             in_channels=config.gpt_output_dim,
             hidden_channels=[self.config.vqvae_groups * self.config.vqvae_n_embed],
         )
-        self._map_to_cbet_preds_offset = MLP(
+        self.map_to_cbet_preds_offset = MLP(
             in_channels=config.gpt_output_dim,
             hidden_channels=[
                 self.config.vqvae_groups
                 * self.config.vqvae_n_embed
                 * config.n_action_pred_chunk
                 * config.output_shapes["action"][0],
             ],
         )
         # init vqvae
-        self._vqvae_model = VqVae(config)
+        self.vqvae_model = VqVae(config)
         # loss
         self._criterion = FocalLoss(gamma=2.0)
 
     def discretize(self, discretize_step, actions):
-        if next(self._vqvae_model.encoder.parameters()).device != get_device_from_parameters(self):
-            self._vqvae_model.encoder.to(get_device_from_parameters(self))
-            self._vqvae_model.vq_layer.to(get_device_from_parameters(self))
-            self._vqvae_model.decoder.to(get_device_from_parameters(self))
-            self._vqvae_model.device = get_device_from_parameters(self)
+        if next(self.vqvae_model.encoder.parameters()).device != get_device_from_parameters(self):
+            self.vqvae_model.encoder.to(get_device_from_parameters(self))
+            self.vqvae_model.vq_layer.to(get_device_from_parameters(self))
+            self.vqvae_model.decoder.to(get_device_from_parameters(self))
+            self.vqvae_model.device = get_device_from_parameters(self)
 
-        loss, n_different_codes, n_different_combinations = pretrain_vqvae(self._vqvae_model, discretize_step, actions)
-        if self._vqvae_model.discretized:
+        loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
+        if self.vqvae_model.discretized:
             print("Finished discretizing action data!")
-            self._vqvae_model.eval()
-            for param in self._vqvae_model.vq_layer.parameters():
+            self.vqvae_model.eval()
+            for param in self.vqvae_model.vq_layer.parameters():
                 param.requires_grad = False
         return loss, n_different_codes, n_different_combinations
@@ -262,8 +262,8 @@ class VQBeTHead(nn.Module):
         N, T, _ = x.shape
         x = einops.rearrange(x, "N T WA -> (N T) WA")
 
-        cbet_logits = self._map_to_cbet_preds_bin(x)
-        cbet_offsets = self._map_to_cbet_preds_offset(x)
+        cbet_logits = self.map_to_cbet_preds_bin(x)
+        cbet_offsets = self.map_to_cbet_preds_offset(x)
         cbet_logits = einops.rearrange(
             cbet_logits, "(NT) (G C) -> (NT) G C", G=self.config.vqvae_groups
         )
@@ -287,14 +287,14 @@ class VQBeTHead(nn.Module):
         sampled_offsets = cbet_offsets[indices]  # NT, G, W, A(?) or NT, G, A
         sampled_offsets = sampled_offsets.sum(dim=1)
 
-        centers = self._vqvae_model.draw_code_forward(sampled_centers).view(
+        centers = self.vqvae_model.draw_code_forward(sampled_centers).view(
             NT, -1, self.config.vqvae_embedding_dim
         )
         return_decoder_input = einops.rearrange(
             centers.clone().detach(), "NT 1 D -> NT D"
         )
         decoded_action = (
-            self._vqvae_model.get_action_from_latent(return_decoder_input)
+            self.vqvae_model.get_action_from_latent(return_decoder_input)
             .clone()
             .detach()
         )  # NT, A
@@ -334,7 +334,7 @@ class VQBeTHead(nn.Module):
 
         action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
         # Figure out the loss for the actions.
         # First, we need to find the closest cluster center for each action.
-        state_vq, action_bins = self._vqvae_model.get_code(
+        state_vq, action_bins = self.vqvae_model.get_code(
             action_seq
         )  # action_bins: NT, G
@@ -406,9 +406,9 @@ class VQBeTOptimizer:
 
         vqvae_params = (
-            list(policy.vqbet._action_head._vqvae_model.encoder.parameters())
-            + list(policy.vqbet._action_head._vqvae_model.decoder.parameters())
-            + list(policy.vqbet._action_head._vqvae_model.vq_layer.parameters())
+            list(policy.vqbet.action_head.vqvae_model.encoder.parameters())
+            + list(policy.vqbet.action_head.vqvae_model.decoder.parameters())
+            + list(policy.vqbet.action_head.vqvae_model.vq_layer.parameters())
         )
         self.vqvae_optimizer = torch.optim.Adam(
             vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001
         )
@@ -422,7 +422,7 @@ class VQBeTOptimizer:
             cfg.training.adam_weight_decay,
         )
 
-        self.bet_optimizer1 = policy.vqbet._policy.configure_optimizers(
+        self.bet_optimizer1 = policy.vqbet.policy.configure_optimizers(
            weight_decay=cfg.training.bet_weight_decay,
            learning_rate=cfg.training.bet_learning_rate,
            betas=cfg.training.bet_betas,
@@ -442,14 +442,14 @@ class VQBeTOptimizer:
         )
 
         self.bet_optimizer2 = torch.optim.AdamW(
-            policy.vqbet._action_head._map_to_cbet_preds_bin.parameters(),
+            policy.vqbet.action_head.map_to_cbet_preds_bin.parameters(),
             lr=cfg.training.bet_learning_rate,
             weight_decay=cfg.training.bet_weight_decay,
             betas=cfg.training.bet_betas,
         )
 
         self.bet_optimizer3 = torch.optim.AdamW(
-            policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(),
+            policy.vqbet.action_head.map_to_cbet_preds_offset.parameters(),
             lr=cfg.training.bet_learning_rate,
             weight_decay=cfg.training.bet_weight_decay,
             betas=cfg.training.bet_betas,
@@ -490,7 +490,7 @@ class VQBeTScheduler:
     def __init__(self, optimizer, cfg):
         from diffusers.optimization import get_scheduler
         self.discretize_step = cfg.training.discretize_step
-        self.offline_steps = cfg.training.offline_steps
+        # self.offline_steps = cfg.training.offline_steps
         self.optimizing_step = 0
 
         self.lr_scheduler1 = get_scheduler(
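
Usage note: after this rename, the submodules of VQBeTPolicy are reachable
through public attribute paths; the old underscore-prefixed paths no longer
exist. Below is a minimal sketch of the new access pattern. The `config`
object is an assumption for illustration, standing in for however the policy
is configured in lerobot's factory code; it is not part of this patch.

    # Hypothetical usage sketch (not part of the patch).
    policy = VQBeTPolicy(config)

    # New public paths introduced by this patch:
    vqvae = policy.vqbet.action_head.vqvae_model  # residual VQ-VAE used for action discretization
    gpt = policy.vqbet.policy                     # GPT backbone

    # check_discretized() reads the same public path internally:
    assert policy.check_discretized() == vqvae.discretized

    # Old (removed) paths, for comparison:
    #   policy.vqbet._policy
    #   policy.vqbet._action_head._vqvae_model

All call sites inside modeling_vqbet.py (e.g. VQBeTOptimizer) are updated in
the same patch, so nothing in this file should still reference _policy,
_action_head, _vqvae_model, _map_to_cbet_preds_bin, or _map_to_cbet_preds_offset.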