remove _ in the names of models
parent 340f7cfd6e
commit 0b663243e3
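This commit renames VQ-BeT's submodules from underscore-prefixed attributes to public ones: _policy -> policy, _action_head -> action_head, _vqvae_model -> vqvae_model, and _map_to_cbet_preds_bin / _map_to_cbet_preds_offset -> map_to_cbet_preds_bin / map_to_cbet_preds_offset. A minimal sketch of the resulting access path (the helper below is illustrative, not part of the commit; policy is assumed to be a constructed VQBeTPolicy):

def vqvae_is_discretized(policy) -> bool:
    # Illustrative helper, not part of this commit.
    # Old path: policy.vqbet._action_head._vqvae_model.discretized
    # New path after this commit (same attribute, now public):
    return policy.vqbet.action_head.vqvae_model.discretized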
@@ -67,7 +67,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         self.vqbet = VQBeTModel(config)
 
     def check_discretized(self):
-        return self.vqbet._action_head._vqvae_model.discretized
+        return self.vqbet.action_head.vqvae_model.discretized
 
 
     def reset(self):
@@ -95,9 +95,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         self._queues = populate_queues(self._queues, batch)
 
         if not self.check_discretized():
-            self.vqbet._action_head._vqvae_model.discretized = True
+            self.vqbet.action_head.vqvae_model.discretized = True
             # VQ-BeT can predict action only after finishing action discretization.
-            # We added a logit to force self.vqbet._action_head._vqvae_model.discretized to be True if not self.check_discretized() to account for the case of predicting with a pretrained model, but this shouldn't happen if you're learning from scratch, so set eval_freq greater than discretize_step.
+            # We added a logit to force self.vqbet.action_head.vqvae_model.discretized to be True if not self.check_discretized() to account for the case of predicting with a pretrained model, but this shouldn't happen if you're learning from scratch, so set eval_freq greater than discretize_step.
             warnings.warn('To evaluate in the environment, the model was forced to stop learning the Residual VQ. If you are not evaluating with a pre-trained model, this can degrade overall performance. To avoid this warning, please set "eval_freq" greater than "discretize_step".')
         assert "observation.image" in batch
         assert "observation.state" in batch
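The warning above encodes a scheduling constraint: the Residual VQ must finish discretizing actions (which takes discretize_step optimization steps) before the policy is first evaluated in the environment, so eval_freq should exceed discretize_step. A minimal guard for that constraint is sketched below; it assumes both values live under cfg.training, mirroring the cfg.training.* accesses later in this diff, and is not part of the commit:

def check_vqbet_schedule(cfg) -> None:
    # Assumes cfg.training.eval_freq and cfg.training.discretize_step exist
    # (an assumption based on the cfg.training.* usage elsewhere in this diff).
    if cfg.training.eval_freq <= cfg.training.discretize_step:
        raise ValueError(
            "Set eval_freq greater than discretize_step so the Residual VQ "
            "finishes action discretization before the first environment eval."
        )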
@@ -152,11 +152,11 @@ class VQBeTModel(nn.Module):
             self.rgb_encoder.feature_dim,
             hidden_channels=[self.config.gpt_input_dim]
         )
-        self._policy = GPT(config)
-        self._action_head = VQBeTHead(config)
+        self.policy = GPT(config)
+        self.action_head = VQBeTHead(config)
 
     def discretize(self, discretize_step, actions):
-        return self._action_head.discretize(discretize_step, actions)
+        return self.action_head.discretize(discretize_step, actions)
 
     def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
         # Input validation.
@@ -185,14 +185,14 @@ class VQBeTModel(nn.Module):
 
 
         # get action features
-        features = self._policy(observation_feature)
+        features = self.policy(observation_feature)
         historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2 # TODO(jayLEE0301) make it compatible with other values
         features = torch.cat([
             features[:, historical_act_pred_index],
             features[:, -len_additional_action_token:]
         ], dim=1)
         # action head
-        pred_action = self._action_head(
+        pred_action = self.action_head(
             features,
         )
 
@@ -209,7 +209,7 @@ class VQBeTModel(nn.Module):
                 output[:, i, :, :] = action[:, i : i + act_w, :]
             action = output
 
-            loss = self._action_head.loss_fn(
+            loss = self.action_head.loss_fn(
                 pred_action,
                 action,
                 reduction="mean",
@@ -228,33 +228,33 @@ class VQBeTHead(nn.Module):
 
 
 
-        self._map_to_cbet_preds_bin = MLP(
+        self.map_to_cbet_preds_bin = MLP(
             in_channels=config.gpt_output_dim,
             hidden_channels=[self.config.vqvae_groups * self.config.vqvae_n_embed],
         )
-        self._map_to_cbet_preds_offset = MLP(
+        self.map_to_cbet_preds_offset = MLP(
             in_channels=config.gpt_output_dim,
             hidden_channels=[
                 self.config.vqvae_groups * self.config.vqvae_n_embed * config.n_action_pred_chunk * config.output_shapes["action"][0],
             ],
         )
         # init vqvae
-        self._vqvae_model = VqVae(config)
+        self.vqvae_model = VqVae(config)
         # loss
         self._criterion = FocalLoss(gamma=2.0)
 
     def discretize(self, discretize_step, actions):
-        if next(self._vqvae_model.encoder.parameters()).device != get_device_from_parameters(self):
-            self._vqvae_model.encoder.to(get_device_from_parameters(self))
-            self._vqvae_model.vq_layer.to(get_device_from_parameters(self))
-            self._vqvae_model.decoder.to(get_device_from_parameters(self))
-            self._vqvae_model.device = get_device_from_parameters(self)
+        if next(self.vqvae_model.encoder.parameters()).device != get_device_from_parameters(self):
+            self.vqvae_model.encoder.to(get_device_from_parameters(self))
+            self.vqvae_model.vq_layer.to(get_device_from_parameters(self))
+            self.vqvae_model.decoder.to(get_device_from_parameters(self))
+            self.vqvae_model.device = get_device_from_parameters(self)
 
-        loss, n_different_codes, n_different_combinations = pretrain_vqvae(self._vqvae_model, discretize_step, actions)
-        if self._vqvae_model.discretized:
+        loss, n_different_codes, n_different_combinations = pretrain_vqvae(self.vqvae_model, discretize_step, actions)
+        if self.vqvae_model.discretized:
             print("Finished discretizing action data!")
-            self._vqvae_model.eval()
-            for param in self._vqvae_model.vq_layer.parameters():
+            self.vqvae_model.eval()
+            for param in self.vqvae_model.vq_layer.parameters():
                 param.requires_grad = False
         return loss, n_different_codes, n_different_combinations
 
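For reference, the two prediction heads above size their outputs from the VQ-VAE codebook layout: map_to_cbet_preds_bin emits vqvae_groups * vqvae_n_embed logits per token, while map_to_cbet_preds_offset emits vqvae_groups * vqvae_n_embed * n_action_pred_chunk * action_dim offsets. A small worked example with assumed, illustrative config values (not taken from this diff):

# Assumed values for illustration only; the real ones come from the VQ-BeT config.
vqvae_groups = 2           # residual VQ groups
vqvae_n_embed = 16         # codes per group
n_action_pred_chunk = 1    # length of the predicted action chunk
action_dim = 2             # config.output_shapes["action"][0]

bin_head_out = vqvae_groups * vqvae_n_embed  # 32 logits per token
offset_head_out = vqvae_groups * vqvae_n_embed * n_action_pred_chunk * action_dim  # 64 offsets per token
print(bin_head_out, offset_head_out)  # 32 64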
@@ -262,8 +262,8 @@ class VQBeTHead(nn.Module):
         N, T, _ = x.shape
         x = einops.rearrange(x, "N T WA -> (N T) WA")
 
-        cbet_logits = self._map_to_cbet_preds_bin(x)
-        cbet_offsets = self._map_to_cbet_preds_offset(x)
+        cbet_logits = self.map_to_cbet_preds_bin(x)
+        cbet_offsets = self.map_to_cbet_preds_offset(x)
         cbet_logits = einops.rearrange(
             cbet_logits, "(NT) (G C) -> (NT) G C", G=self.config.vqvae_groups
         )
@@ -287,14 +287,14 @@ class VQBeTHead(nn.Module):
         sampled_offsets = cbet_offsets[indices] # NT, G, W, A(?) or NT, G, A
 
         sampled_offsets = sampled_offsets.sum(dim=1)
-        centers = self._vqvae_model.draw_code_forward(sampled_centers).view(
+        centers = self.vqvae_model.draw_code_forward(sampled_centers).view(
             NT, -1, self.config.vqvae_embedding_dim
         )
         return_decoder_input = einops.rearrange(
             centers.clone().detach(), "NT 1 D -> NT D"
         )
         decoded_action = (
-            self._vqvae_model.get_action_from_latent(return_decoder_input)
+            self.vqvae_model.get_action_from_latent(return_decoder_input)
             .clone()
             .detach()
         ) # NT, A
@@ -334,7 +334,7 @@ class VQBeTHead(nn.Module):
         action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
         # Figure out the loss for the actions.
         # First, we need to find the closest cluster center for each action.
-        state_vq, action_bins = self._vqvae_model.get_code(
+        state_vq, action_bins = self.vqvae_model.get_code(
             action_seq
         ) # action_bins: NT, G
 
@@ -406,9 +406,9 @@ class VQBeTOptimizer:
 
 
         vqvae_params = (
-            list(policy.vqbet._action_head._vqvae_model.encoder.parameters())
-            + list(policy.vqbet._action_head._vqvae_model.decoder.parameters())
-            + list(policy.vqbet._action_head._vqvae_model.vq_layer.parameters())
+            list(policy.vqbet.action_head.vqvae_model.encoder.parameters())
+            + list(policy.vqbet.action_head.vqvae_model.decoder.parameters())
+            + list(policy.vqbet.action_head.vqvae_model.vq_layer.parameters())
         )
         self.vqvae_optimizer = torch.optim.Adam(
             vqvae_params, lr=cfg.training.vqvae_lr, weight_decay=0.0001
@@ -422,7 +422,7 @@ class VQBeTOptimizer:
             cfg.training.adam_weight_decay,
         )
 
-        self.bet_optimizer1 = policy.vqbet._policy.configure_optimizers(
+        self.bet_optimizer1 = policy.vqbet.policy.configure_optimizers(
             weight_decay=cfg.training.bet_weight_decay,
             learning_rate=cfg.training.bet_learning_rate,
             betas=cfg.training.bet_betas,
@@ -442,14 +442,14 @@ class VQBeTOptimizer:
         )
 
         self.bet_optimizer2 = torch.optim.AdamW(
-            policy.vqbet._action_head._map_to_cbet_preds_bin.parameters(),
+            policy.vqbet.action_head.map_to_cbet_preds_bin.parameters(),
             lr=cfg.training.bet_learning_rate,
             weight_decay=cfg.training.bet_weight_decay,
             betas=cfg.training.bet_betas,
         )
 
         self.bet_optimizer3 = torch.optim.AdamW(
-            policy.vqbet._action_head._map_to_cbet_preds_offset.parameters(),
+            policy.vqbet.action_head.map_to_cbet_preds_offset.parameters(),
             lr=cfg.training.bet_learning_rate,
             weight_decay=cfg.training.bet_weight_decay,
             betas=cfg.training.bet_betas,
@@ -490,7 +490,7 @@ class VQBeTScheduler:
     def __init__(self, optimizer, cfg):
         from diffusers.optimization import get_scheduler
         self.discretize_step = cfg.training.discretize_step
-        self.offline_steps = cfg.training.offline_steps
+        # self.offline_steps = cfg.training.offline_steps
         self.optimizing_step = 0
 
         self.lr_scheduler1 = get_scheduler(