add comment, change param names, delete kwargs
This commit is contained in:
parent
e301caf182
commit
340f7cfd6e
|
@ -96,7 +96,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
|
||||
if not self.check_discretized():
|
||||
self.vqbet._action_head._vqvae_model.discretized = True
|
||||
warnings.warn('To evaluate in the environment, the model was forced to stop learning the Residual VQ. If you are not evaluating with a pre-trained model, this can degrade overall performance.')
|
||||
# VQ-BeT can predict action only after finishing action discretization.
|
||||
# We added a logit to force self.vqbet._action_head._vqvae_model.discretized to be True if not self.check_discretized() to account for the case of predicting with a pretrained model, but this shouldn't happen if you're learning from scratch, so set eval_freq greater than discretize_step.
|
||||
warnings.warn('To evaluate in the environment, the model was forced to stop learning the Residual VQ. If you are not evaluating with a pre-trained model, this can degrade overall performance. To avoid this warning, please set "eval_freq" greater than "discretize_step".')
|
||||
assert "observation.image" in batch
|
||||
assert "observation.state" in batch
|
||||
|
||||
|
@ -105,8 +107,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
actions = self.vqbet(batch, rollout=True)[:, : self.config.n_action_pred_chunk]
|
||||
|
||||
# the dimension of returned action is (batch_size, n_action_pred_chunk, action_dim)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
# since the data in the action queue's dimension is (n_action_pred_chunk, batch_size, action_dim, we transpose the action and fill the queue
|
||||
self._queues["action"].extend(actions.transpose(0, 1))
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
|
@ -116,6 +119,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
|
||||
if not self.check_discretized():
|
||||
loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
|
||||
return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
|
||||
|
@ -126,13 +130,15 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
|
||||
|
||||
class VQBeTModel(nn.Module):
|
||||
"""
|
||||
TODO(jayLEE0301)
|
||||
"""
|
||||
def __init__(self, config: VQBeTConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.rgb_encoder = VQBeTRgbEncoder(config)
|
||||
|
||||
self.global_cond_dim = self.rgb_encoder.feature_dim
|
||||
|
||||
# action token and EOS token
|
||||
self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) # Batch, Timestep, Data type, GPT input dim
|
||||
|
@ -143,7 +149,7 @@ class VQBeTModel(nn.Module):
|
|||
hidden_channels=[self.config.gpt_input_dim]
|
||||
)
|
||||
self.obs_projector = MLP(
|
||||
self.global_cond_dim,
|
||||
self.rgb_encoder.feature_dim,
|
||||
hidden_channels=[self.config.gpt_input_dim]
|
||||
)
|
||||
self._policy = GPT(config)
|
||||
|
@ -152,7 +158,6 @@ class VQBeTModel(nn.Module):
|
|||
def discretize(self, discretize_step, actions):
|
||||
return self._action_head.discretize(discretize_step, actions)
|
||||
|
||||
# ========= inference ============
|
||||
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
|
||||
# Input validation.
|
||||
assert set(batch).issuperset({"observation.state", "observation.image"})
|
||||
|
@ -163,10 +168,9 @@ class VQBeTModel(nn.Module):
|
|||
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
|
||||
# Separate batch and sequence dims.
|
||||
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
|
||||
# Concatenate state and image features then flatten to (B, global_cond_dim).
|
||||
|
||||
|
||||
global_cond = torch.cat([
|
||||
observation_feature = torch.cat([
|
||||
torch.unsqueeze(self.obs_projector(img_features), dim=2),
|
||||
torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2),
|
||||
self._action_token.repeat(batch_size, n_obs_steps, 1, 1)
|
||||
|
@ -177,12 +181,11 @@ class VQBeTModel(nn.Module):
|
|||
len_additional_action_token = self.config.n_action_pred_token-1
|
||||
action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)
|
||||
|
||||
# prompt_length = global_cond.shape[1]+1
|
||||
global_cond = torch.cat([global_cond, action_token], dim=1)
|
||||
observation_feature = torch.cat([observation_feature, action_token], dim=1)
|
||||
|
||||
|
||||
# get action features
|
||||
features = self._policy(global_cond)
|
||||
features = self._policy(observation_feature)
|
||||
historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2 # TODO(jayLEE0301) make it compatible with other values
|
||||
features = torch.cat([
|
||||
features[:, historical_act_pred_index],
|
||||
|
@ -221,26 +224,18 @@ class VQBeTHead(nn.Module):
|
|||
"""
|
||||
|
||||
super().__init__()
|
||||
self.input_size = config.gpt_output_dim
|
||||
self.output_size = config.output_shapes["action"][0]
|
||||
self.hidden_size = config.mlp_hidden_dim
|
||||
self.offset_loss_weight = config.offset_loss_weight
|
||||
self.secondary_code_loss_weight = config.secondary_code_loss_weight
|
||||
self.config = config
|
||||
|
||||
self.vqvae_groups = config.vqvae_groups
|
||||
self.vqvae_n_embed = config.vqvae_n_embed # C(number of code integers)
|
||||
self.vqvae_embedding_dim = config.vqvae_embedding_dim # D(embedding dims)
|
||||
self.n_action_pred_chunk = config.n_action_pred_chunk # action chunk size
|
||||
|
||||
|
||||
self._map_to_cbet_preds_bin = MLP(
|
||||
in_channels=self.input_size,
|
||||
hidden_channels=[self.vqvae_groups * self.vqvae_n_embed],
|
||||
in_channels=config.gpt_output_dim,
|
||||
hidden_channels=[self.config.vqvae_groups * self.config.vqvae_n_embed],
|
||||
)
|
||||
self._map_to_cbet_preds_offset = MLP(
|
||||
in_channels=self.input_size,
|
||||
in_channels=config.gpt_output_dim,
|
||||
hidden_channels=[
|
||||
self.vqvae_groups * self.vqvae_n_embed * config.n_action_pred_chunk * self.output_size,
|
||||
self.config.vqvae_groups * self.config.vqvae_n_embed * config.n_action_pred_chunk * config.output_shapes["action"][0],
|
||||
],
|
||||
)
|
||||
# init vqvae
|
||||
|
@ -270,10 +265,10 @@ class VQBeTHead(nn.Module):
|
|||
cbet_logits = self._map_to_cbet_preds_bin(x)
|
||||
cbet_offsets = self._map_to_cbet_preds_offset(x)
|
||||
cbet_logits = einops.rearrange(
|
||||
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_groups
|
||||
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.config.vqvae_groups
|
||||
)
|
||||
cbet_offsets = einops.rearrange(
|
||||
cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self.vqvae_groups, C=self.vqvae_n_embed
|
||||
cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self.config.vqvae_groups, C=self.config.vqvae_n_embed
|
||||
)
|
||||
cbet_probs = torch.softmax(cbet_logits, dim=-1)
|
||||
NT, G, choices = cbet_probs.shape
|
||||
|
@ -285,7 +280,7 @@ class VQBeTHead(nn.Module):
|
|||
|
||||
indices = (
|
||||
torch.arange(NT).unsqueeze(1).cuda(),
|
||||
torch.arange(self.vqvae_groups).unsqueeze(0).cuda(),
|
||||
torch.arange(self.config.vqvae_groups).unsqueeze(0).cuda(),
|
||||
sampled_centers,
|
||||
)
|
||||
# Use advanced indexing to sample the values
|
||||
|
@ -293,7 +288,7 @@ class VQBeTHead(nn.Module):
|
|||
|
||||
sampled_offsets = sampled_offsets.sum(dim=1)
|
||||
centers = self._vqvae_model.draw_code_forward(sampled_centers).view(
|
||||
NT, -1, self.vqvae_embedding_dim
|
||||
NT, -1, self.config.vqvae_embedding_dim
|
||||
)
|
||||
return_decoder_input = einops.rearrange(
|
||||
centers.clone().detach(), "NT 1 D -> NT D"
|
||||
|
@ -304,7 +299,7 @@ class VQBeTHead(nn.Module):
|
|||
.detach()
|
||||
) # NT, A
|
||||
sampled_offsets = einops.rearrange(
|
||||
sampled_offsets, "NT (W A) -> NT W A", W=self.n_action_pred_chunk
|
||||
sampled_offsets, "NT (W A) -> NT W A", W=self.config.n_action_pred_chunk
|
||||
)
|
||||
predicted_action = decoded_action + sampled_offsets
|
||||
predicted_action = einops.rearrange(
|
||||
|
@ -312,7 +307,7 @@ class VQBeTHead(nn.Module):
|
|||
"(N T) W A -> N T (W A)",
|
||||
N=N,
|
||||
T=T,
|
||||
W=self.n_action_pred_chunk,
|
||||
W=self.config.n_action_pred_chunk,
|
||||
)
|
||||
|
||||
return {
|
||||
|
@ -333,7 +328,7 @@ class VQBeTHead(nn.Module):
|
|||
cbet_logits = pred["cbet_logits"]
|
||||
|
||||
predicted_action = einops.rearrange(
|
||||
predicted_action, "N T (W A) -> (N T) W A", W=self.n_action_pred_chunk
|
||||
predicted_action, "N T (W A) -> (N T) W A", W=self.config.n_action_pred_chunk
|
||||
)
|
||||
|
||||
action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
|
||||
|
@ -358,7 +353,7 @@ class VQBeTHead(nn.Module):
|
|||
cbet_logits[:, 1, :],
|
||||
action_bins[:, 1],
|
||||
)
|
||||
cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.secondary_code_loss_weight
|
||||
cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.config.secondary_code_loss_weight
|
||||
|
||||
equal_primary_code_rate = torch.sum(
|
||||
(action_bins[:, 0] == sampled_centers[:, 0]).int()
|
||||
|
@ -387,7 +382,7 @@ class VQBeTHead(nn.Module):
|
|||
)
|
||||
).max()
|
||||
|
||||
loss = cbet_loss + self.offset_loss_weight * offset_loss
|
||||
loss = cbet_loss + self.config.offset_loss_weight * offset_loss
|
||||
|
||||
loss_dict = {
|
||||
"loss": loss,
|
||||
|
@ -626,8 +621,7 @@ class VqVae(nn.Module):
|
|||
):
|
||||
|
||||
super(VqVae, self).__init__()
|
||||
self.n_action_pred_chunk = config.n_action_pred_chunk
|
||||
self.action_dim = config.output_shapes["action"][0]
|
||||
self.config = config
|
||||
|
||||
self.discretized = False
|
||||
self.optimized_steps = 0
|
||||
|
@ -637,25 +631,24 @@ class VqVae(nn.Module):
|
|||
num_quantizers=config.vqvae_groups,
|
||||
codebook_size=config.vqvae_n_embed,
|
||||
)
|
||||
self.embedding_dim = config.vqvae_embedding_dim
|
||||
|
||||
if self.n_action_pred_chunk == 1:
|
||||
if self.config.n_action_pred_chunk == 1:
|
||||
self.encoder = MLP(
|
||||
in_channels=self.action_dim,
|
||||
in_channels=self.config.output_shapes["action"][0],
|
||||
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
|
||||
)
|
||||
self.decoder = MLP(
|
||||
in_channels=config.vqvae_embedding_dim,
|
||||
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.action_dim],
|
||||
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0]],
|
||||
)
|
||||
else:
|
||||
self.encoder = MLP(
|
||||
in_channels=self.action_dim * self.n_action_pred_chunk,
|
||||
in_channels=self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk,
|
||||
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
|
||||
)
|
||||
self.decoder = MLP(
|
||||
in_channels=config.vqvae_embedding_dim,
|
||||
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.action_dim * self.n_action_pred_chunk],
|
||||
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk],
|
||||
)
|
||||
|
||||
self.train()
|
||||
|
@ -690,15 +683,15 @@ class VqVae(nn.Module):
|
|||
|
||||
def get_action_from_latent(self, latent):
|
||||
output = self.decoder(latent)
|
||||
if self.n_action_pred_chunk == 1:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.action_dim)
|
||||
if self.config.n_action_pred_chunk == 1:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
else:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.action_dim)
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
|
||||
def preprocess(self, state):
|
||||
if not torch.is_tensor(state):
|
||||
state = torch.FloatTensor(state.copy())
|
||||
if self.n_action_pred_chunk == 1:
|
||||
if self.config.n_action_pred_chunk == 1:
|
||||
state = state.squeeze(-2) # state.squeeze(-1)
|
||||
else:
|
||||
state = einops.rearrange(state, "N T A -> N (T A)")
|
||||
|
@ -717,7 +710,7 @@ class VqVae(nn.Module):
|
|||
if required_recon:
|
||||
recon_state = self.decoder(state_vq)
|
||||
recon_state_ae = self.decoder(state_rep)
|
||||
if self.n_action_pred_chunk == 1:
|
||||
if self.config.n_action_pred_chunk == 1:
|
||||
return state_vq, vq_code, recon_state, recon_state_ae
|
||||
else:
|
||||
return (
|
||||
|
@ -764,13 +757,13 @@ class VqVae(nn.Module):
|
|||
|
||||
|
||||
def pretrain_vqvae(vqvae_model, discretize_step, actions):
|
||||
if vqvae_model.n_action_pred_chunk == 1:
|
||||
if vqvae_model.config.n_action_pred_chunk == 1:
|
||||
# not using action chunk
|
||||
actions = actions.reshape(-1, 1, actions.shape[-1])
|
||||
else:
|
||||
# using action chunk
|
||||
slices = []
|
||||
slices.extend([actions[:, j:j+vqvae_model.n_action_pred_chunk, :] for j in range(actions.shape[1]+1-vqvae_model.n_action_pred_chunk)])
|
||||
slices.extend([actions[:, j:j+vqvae_model.config.n_action_pred_chunk, :] for j in range(actions.shape[1]+1-vqvae_model.config.n_action_pred_chunk)])
|
||||
actions = torch.cat(slices, dim=0)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue