remove _ in the names of models

This commit is contained in:
jayLEE0301 2024-05-24 14:20:58 -04:00
parent 340f7cfd6e
commit 0b663243e3
1 changed files with 33 additions and 33 deletions

View File

@ -67,7 +67,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
self.vqbet = VQBeTModel(config)
def check_discretized(self):
return self.vqbet._action_head._vqvae_model.discretized
return self.vqbet.action_head.vqvae_model.discretized
def reset(self):
@ -95,9 +95,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
self._queues = populate_queues(self._queues, batch)
if not self.check_discretized():
self.vqbet._action_head._vqvae_model.discretized = True
self.vqbet.action_head.vqvae_model.discretized = True
# VQ-BeT can predict action only after finishing action discretization.
# We added a logit to force self.vqbet._action_head._vqvae_model.discretized to be True if not self.check_discretized() to account for the case of predicting with a pretrained model, but this shouldn't happen if you're learning from scratch, so set eval_freq greater than discretize_step.
# We added a logit to force self.vqbet.action_head.vqvae_model.discretized to be True if not self.check_discretized() to account for the case of predicting with a pretrained model, but this shouldn't happen if you're learning from scratch, so set eval_freq greater than discretize_step.
warnings.warn('To evaluate in the environment, the model was forced to stop learning the Residual VQ. If you are not evaluating with a pre-trained model, this can degrade overall performance. To avoid this warning, please set "eval_freq" greater than "discretize_step".')
assert "observation.image" in batch
assert "observation.state" in batch
@ -152,11 +152,11 @@ class VQBeTModel(nn.Module):
self.rgb_encoder.feature_dim,
hidden_channels=[self.config.gpt_input_dim]
)
self._policy = GPT(config)
self._action_head = VQBeTHead(config)
self.policy = GPT(config)
self.action_head = VQBeTHead(config)
def discretize(self, discretize_step, actions):
return self._action_head.discretize(discretize_step, actions)
return self.action_head.discretize(discretize_step, actions)
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
# Input validation.
@ -185,14 +185,14 @@ class VQBeTModel(nn.Module):
# get action features
features = self._policy(observation_feature)
features = self.policy(observation_feature)
historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2 # TODO(jayLEE0301) make it compatible with other values
features = torch.cat([
features[:, historical_act_pred_index],
features[:, -len_additional_action_token:]
], dim=1)
# action head
pred_action = self._action_head(
pred_action = self.action_head(
features,
)
@ -209,7 +209,7 @@ class VQBeTModel(nn.Module):
output[:, i, :, :] = action[:, i : i + act_w, :]
action = output
loss = self._action_head.loss_fn(
loss = self.action_head.loss_fn(
pred_action,
action,
reduction="mean",
@ -228,33 +228,33 @@ class VQBeTHead(nn.Module):
self._map_to_cbet_preds_bin = MLP(
self.map_to_cbet_preds_bin = MLP(
in_channels=config.gpt_output_dim,
hidden_channels=[self.config.vqvae_groups * self.config.vqvae_n_embed],
)
self._map_to_cbet_preds_offset = MLP(
self.map_to_cbet_preds_offset = MLP(
in_channels=config.gpt_output_dim,
hidden_channels=[
self.config.vqvae_groups * self.config.vqvae_n_embed * config.n_action_pred_chunk * config.output_shapes["action"][0],
],
)
# init vqvae
self._vqvae_model = VqVae(config)
self.vqvae_model = VqVae(config)
# loss
self._criterion = FocalLoss(gamma=2.0)
def discretize(self, discretize_step, actions):
if next(self._vqvae_model.encoder.parameters()).device != get_device_from_parameters(self):
self._vqvae_model.encoder.to(get_device_from_parameters(self))
self._vqvae_model.vq_layer.to(get_device_from_parameters(self))
self._vqvae_model.decoder.to(get_device_from_parameters(self))
self._vqvae_model.device = get_device_from_parameters(self)
if next(self.vqvae_model.encoder.parameters()).device != get_device_from_parameters(self):
self.vqvae_model.encoder.to(get_device_from_parameters(self))
self.vqvae_model.vq_layer.to(get_device_from_parameters(self))
self.vqvae_model.decoder.to(get_device_from_parameters(self))
self.vqvae_model.device = get_device_from_parameters(self)
loss, n_different_codes, n_different_combinations = pretrain_vqvae(self._vqvae_model, discretize_step, actions)
if self._vqvae_model.discretized:
loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
if self.vqvae_model.discretized:
print("Finished discretizing action data!")
self._vqvae_model.eval()
for param in self._vqvae_model.vq_layer.parameters():
self.vqvae_model.eval()
for param in self.vqvae_model.vq_layer.parameters():
param.requires_grad = False
return loss, n_different_codes, n_different_combinations
@ -262,8 +262,8 @@ class VQBeTHead(nn.Module):
N, T, _ = x.shape
x = einops.rearrange(x, "N T WA -> (N T) WA")
cbet_logits = self._map_to_cbet_preds_bin(x)
cbet_offsets = self._map_to_cbet_preds_offset(x)
cbet_logits = self.map_to_cbet_preds_bin(x)
cbet_offsets = self.map_to_cbet_preds_offset(x)
cbet_logits = einops.rearrange(
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.config.vqvae_groups
)
@ -287,14 +287,14 @@ class VQBeTHead(nn.Module):
sampled_offsets = cbet_offsets[indices] # NT, G, W, A(?) or NT, G, A
sampled_offsets = sampled_offsets.sum(dim=1)
centers = self._vqvae_model.draw_code_forward(sampled_centers).view(
centers = self.vqvae_model.draw_code_forward(sampled_centers).view(
NT, -1, self.config.vqvae_embedding_dim
)
return_decoder_input = einops.rearrange(
centers.clone().detach(), "NT 1 D -> NT D"
)
decoded_action = (
self._vqvae_model.get_action_from_latent(return_decoder_input)
self.vqvae_model.get_action_from_latent(return_decoder_input)
.clone()
.detach()
) # NT, A
@ -334,7 +334,7 @@ class VQBeTHead(nn.Module):
action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
# Figure out the loss for the actions.
# First, we need to find the closest cluster center for each action.
state_vq, action_bins = self._vqvae_model.get_code(
state_vq, action_bins = self.vqvae_model.get_code(
action_seq
) # action_bins: NT, G
@ -406,9 +406,9 @@ class VQBeTOptimizer:
vqvae_params = (
list(policy.vqbet._action_head._vqvae_model.encoder.parameters())
+ list(policy.vqbet._action_head._vqvae_model.decoder.parameters())
+ list(policy.vqbet._action_head._vqvae_model.vq_layer.parameters())
list(policy.vqbet.action_head.vqvae_model.encoder.parameters())
+ list(policy.vqbet.action_head.vqvae_model.decoder.parameters())
+ list(policy.vqbet.action_head.vqvae_model.vq_layer.parameters())
)
self.vqvae_optimizer = torch.optim.Adam(
vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001
@ -422,7 +422,7 @@ class VQBeTOptimizer:
cfg.training.adam_weight_decay,
)
self.bet_optimizer1 = policy.vqbet._policy.configure_optimizers(
self.bet_optimizer1 = policy.vqbet.policy.configure_optimizers(
weight_decay=cfg.training.bet_weight_decay,
learning_rate=cfg.training.bet_learning_rate,
betas=cfg.training.bet_betas,
@ -442,14 +442,14 @@ class VQBeTOptimizer:
)
self.bet_optimizer2 = torch.optim.AdamW(
policy.vqbet._action_head._map_to_cbet_preds_bin.parameters(),
policy.vqbet.action_head.map_to_cbet_preds_bin.parameters(),
lr=cfg.training.bet_learning_rate,
weight_decay=cfg.training.bet_weight_decay,
betas=cfg.training.bet_betas,
)
self.bet_optimizer3 = torch.optim.AdamW(
policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(),
policy.vqbet.action_head.map_to_cbet_preds_offset.parameters(),
lr=cfg.training.bet_learning_rate,
weight_decay=cfg.training.bet_weight_decay,
betas=cfg.training.bet_betas,
@ -490,7 +490,7 @@ class VQBeTScheduler:
def __init__(self, optimizer, cfg):
from diffusers.optimization import get_scheduler
self.discretize_step = cfg.training.discretize_step
self.offline_steps = cfg.training.offline_steps
# self.offline_steps = cfg.training.offline_steps
self.optimizing_step = 0
self.lr_scheduler1 = get_scheduler(