remove _ in the names of models
This commit is contained in:
parent
340f7cfd6e
commit
0b663243e3
|
@ -67,7 +67,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
self.vqbet = VQBeTModel(config)
|
||||
|
||||
def check_discretized(self):
|
||||
return self.vqbet._action_head._vqvae_model.discretized
|
||||
return self.vqbet.action_head.vqvae_model.discretized
|
||||
|
||||
|
||||
def reset(self):
|
||||
|
@ -95,9 +95,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
if not self.check_discretized():
|
||||
self.vqbet._action_head._vqvae_model.discretized = True
|
||||
self.vqbet.action_head.vqvae_model.discretized = True
|
||||
# VQ-BeT can predict action only after finishing action discretization.
|
||||
# We added a logit to force self.vqbet._action_head._vqvae_model.discretized to be True if not self.check_discretized() to account for the case of predicting with a pretrained model, but this shouldn't happen if you're learning from scratch, so set eval_freq greater than discretize_step.
|
||||
# We added a logit to force self.vqbet.action_head.vqvae_model.discretized to be True if not self.check_discretized() to account for the case of predicting with a pretrained model, but this shouldn't happen if you're learning from scratch, so set eval_freq greater than discretize_step.
|
||||
warnings.warn('To evaluate in the environment, the model was forced to stop learning the Residual VQ. If you are not evaluating with a pre-trained model, this can degrade overall performance. To avoid this warning, please set "eval_freq" greater than "discretize_step".')
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
|
@ -152,11 +152,11 @@ class VQBeTModel(nn.Module):
|
|||
self.rgb_encoder.feature_dim,
|
||||
hidden_channels=[self.config.gpt_input_dim]
|
||||
)
|
||||
self._policy = GPT(config)
|
||||
self._action_head = VQBeTHead(config)
|
||||
self.policy = GPT(config)
|
||||
self.action_head = VQBeTHead(config)
|
||||
|
||||
def discretize(self, discretize_step, actions):
|
||||
return self._action_head.discretize(discretize_step, actions)
|
||||
return self.action_head.discretize(discretize_step, actions)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
|
||||
# Input validation.
|
||||
|
@ -185,14 +185,14 @@ class VQBeTModel(nn.Module):
|
|||
|
||||
|
||||
# get action features
|
||||
features = self._policy(observation_feature)
|
||||
features = self.policy(observation_feature)
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2 # TODO(jayLEE0301) make it compatible with other values
|
||||
features = torch.cat([
|
||||
features[:, historical_act_pred_index],
|
||||
features[:, -len_additional_action_token:]
|
||||
], dim=1)
|
||||
# action head
|
||||
pred_action = self._action_head(
|
||||
pred_action = self.action_head(
|
||||
features,
|
||||
)
|
||||
|
||||
|
@ -209,7 +209,7 @@ class VQBeTModel(nn.Module):
|
|||
output[:, i, :, :] = action[:, i : i + act_w, :]
|
||||
action = output
|
||||
|
||||
loss = self._action_head.loss_fn(
|
||||
loss = self.action_head.loss_fn(
|
||||
pred_action,
|
||||
action,
|
||||
reduction="mean",
|
||||
|
@ -228,33 +228,33 @@ class VQBeTHead(nn.Module):
|
|||
|
||||
|
||||
|
||||
self._map_to_cbet_preds_bin = MLP(
|
||||
self.map_to_cbet_preds_bin = MLP(
|
||||
in_channels=config.gpt_output_dim,
|
||||
hidden_channels=[self.config.vqvae_groups * self.config.vqvae_n_embed],
|
||||
)
|
||||
self._map_to_cbet_preds_offset = MLP(
|
||||
self.map_to_cbet_preds_offset = MLP(
|
||||
in_channels=config.gpt_output_dim,
|
||||
hidden_channels=[
|
||||
self.config.vqvae_groups * self.config.vqvae_n_embed * config.n_action_pred_chunk * config.output_shapes["action"][0],
|
||||
],
|
||||
)
|
||||
# init vqvae
|
||||
self._vqvae_model = VqVae(config)
|
||||
self.vqvae_model = VqVae(config)
|
||||
# loss
|
||||
self._criterion = FocalLoss(gamma=2.0)
|
||||
|
||||
def discretize(self, discretize_step, actions):
|
||||
if next(self._vqvae_model.encoder.parameters()).device != get_device_from_parameters(self):
|
||||
self._vqvae_model.encoder.to(get_device_from_parameters(self))
|
||||
self._vqvae_model.vq_layer.to(get_device_from_parameters(self))
|
||||
self._vqvae_model.decoder.to(get_device_from_parameters(self))
|
||||
self._vqvae_model.device = get_device_from_parameters(self)
|
||||
if next(self.vqvae_model.encoder.parameters()).device != get_device_from_parameters(self):
|
||||
self.vqvae_model.encoder.to(get_device_from_parameters(self))
|
||||
self.vqvae_model.vq_layer.to(get_device_from_parameters(self))
|
||||
self.vqvae_model.decoder.to(get_device_from_parameters(self))
|
||||
self.vqvae_model.device = get_device_from_parameters(self)
|
||||
|
||||
loss, n_different_codes, n_different_combinations = pretrain_vqvae(self._vqvae_model, discretize_step, actions)
|
||||
if self._vqvae_model.discretized:
|
||||
loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
|
||||
if self.vqvae_model.discretized:
|
||||
print("Finished discretizing action data!")
|
||||
self._vqvae_model.eval()
|
||||
for param in self._vqvae_model.vq_layer.parameters():
|
||||
self.vqvae_model.eval()
|
||||
for param in self.vqvae_model.vq_layer.parameters():
|
||||
param.requires_grad = False
|
||||
return loss, n_different_codes, n_different_combinations
|
||||
|
||||
|
@ -262,8 +262,8 @@ class VQBeTHead(nn.Module):
|
|||
N, T, _ = x.shape
|
||||
x = einops.rearrange(x, "N T WA -> (N T) WA")
|
||||
|
||||
cbet_logits = self._map_to_cbet_preds_bin(x)
|
||||
cbet_offsets = self._map_to_cbet_preds_offset(x)
|
||||
cbet_logits = self.map_to_cbet_preds_bin(x)
|
||||
cbet_offsets = self.map_to_cbet_preds_offset(x)
|
||||
cbet_logits = einops.rearrange(
|
||||
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.config.vqvae_groups
|
||||
)
|
||||
|
@ -287,14 +287,14 @@ class VQBeTHead(nn.Module):
|
|||
sampled_offsets = cbet_offsets[indices] # NT, G, W, A(?) or NT, G, A
|
||||
|
||||
sampled_offsets = sampled_offsets.sum(dim=1)
|
||||
centers = self._vqvae_model.draw_code_forward(sampled_centers).view(
|
||||
centers = self.vqvae_model.draw_code_forward(sampled_centers).view(
|
||||
NT, -1, self.config.vqvae_embedding_dim
|
||||
)
|
||||
return_decoder_input = einops.rearrange(
|
||||
centers.clone().detach(), "NT 1 D -> NT D"
|
||||
)
|
||||
decoded_action = (
|
||||
self._vqvae_model.get_action_from_latent(return_decoder_input)
|
||||
self.vqvae_model.get_action_from_latent(return_decoder_input)
|
||||
.clone()
|
||||
.detach()
|
||||
) # NT, A
|
||||
|
@ -334,7 +334,7 @@ class VQBeTHead(nn.Module):
|
|||
action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
|
||||
# Figure out the loss for the actions.
|
||||
# First, we need to find the closest cluster center for each action.
|
||||
state_vq, action_bins = self._vqvae_model.get_code(
|
||||
state_vq, action_bins = self.vqvae_model.get_code(
|
||||
action_seq
|
||||
) # action_bins: NT, G
|
||||
|
||||
|
@ -406,9 +406,9 @@ class VQBeTOptimizer:
|
|||
|
||||
|
||||
vqvae_params = (
|
||||
list(policy.vqbet._action_head._vqvae_model.encoder.parameters())
|
||||
+ list(policy.vqbet._action_head._vqvae_model.decoder.parameters())
|
||||
+ list(policy.vqbet._action_head._vqvae_model.vq_layer.parameters())
|
||||
list(policy.vqbet.action_head.vqvae_model.encoder.parameters())
|
||||
+ list(policy.vqbet.action_head.vqvae_model.decoder.parameters())
|
||||
+ list(policy.vqbet.action_head.vqvae_model.vq_layer.parameters())
|
||||
)
|
||||
self.vqvae_optimizer = torch.optim.Adam(
|
||||
vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001
|
||||
|
@ -422,7 +422,7 @@ class VQBeTOptimizer:
|
|||
cfg.training.adam_weight_decay,
|
||||
)
|
||||
|
||||
self.bet_optimizer1 = policy.vqbet._policy.configure_optimizers(
|
||||
self.bet_optimizer1 = policy.vqbet.policy.configure_optimizers(
|
||||
weight_decay=cfg.training.bet_weight_decay,
|
||||
learning_rate=cfg.training.bet_learning_rate,
|
||||
betas=cfg.training.bet_betas,
|
||||
|
@ -442,14 +442,14 @@ class VQBeTOptimizer:
|
|||
)
|
||||
|
||||
self.bet_optimizer2 = torch.optim.AdamW(
|
||||
policy.vqbet._action_head._map_to_cbet_preds_bin.parameters(),
|
||||
policy.vqbet.action_head.map_to_cbet_preds_bin.parameters(),
|
||||
lr=cfg.training.bet_learning_rate,
|
||||
weight_decay=cfg.training.bet_weight_decay,
|
||||
betas=cfg.training.bet_betas,
|
||||
)
|
||||
|
||||
self.bet_optimizer3 = torch.optim.AdamW(
|
||||
policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(),
|
||||
policy.vqbet.action_head.map_to_cbet_preds_offset.parameters(),
|
||||
lr=cfg.training.bet_learning_rate,
|
||||
weight_decay=cfg.training.bet_weight_decay,
|
||||
betas=cfg.training.bet_betas,
|
||||
|
@ -490,7 +490,7 @@ class VQBeTScheduler:
|
|||
def __init__(self, optimizer, cfg):
|
||||
from diffusers.optimization import get_scheduler
|
||||
self.discretize_step = cfg.training.discretize_step
|
||||
self.offline_steps = cfg.training.offline_steps
|
||||
# self.offline_steps = cfg.training.offline_steps
|
||||
self.optimizing_step = 0
|
||||
|
||||
self.lr_scheduler1 = get_scheduler(
|
||||
|
|
Loading…
Reference in New Issue