add comment, change param names, delete kwargs

This commit is contained in:
jayLEE0301 2024-05-23 19:08:37 -04:00
parent e301caf182
commit 340f7cfd6e
1 changed files with 40 additions and 47 deletions

View File

@ -96,7 +96,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
if not self.check_discretized():
self.vqbet._action_head._vqvae_model.discretized = True
warnings.warn('To evaluate in the environment, the model was forced to stop learning the Residual VQ. If you are not evaluating with a pre-trained model, this can degrade overall performance.')
# VQ-BeT can predict action only after finishing action discretization.
# We added a logit to force self.vqbet._action_head._vqvae_model.discretized to be True if not self.check_discretized() to account for the case of predicting with a pretrained model, but this shouldn't happen if you're learning from scratch, so set eval_freq greater than discretize_step.
warnings.warn('To evaluate in the environment, the model was forced to stop learning the Residual VQ. If you are not evaluating with a pre-trained model, this can degrade overall performance. To avoid this warning, please set "eval_freq" greater than "discretize_step".')
assert "observation.image" in batch
assert "observation.state" in batch
@ -105,8 +107,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
actions = self.vqbet(batch, rollout=True)[:, : self.config.n_action_pred_chunk]
# the dimension of returned action is (batch_size, n_action_pred_chunk, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"]
# since the data in the action queue's dimension is (n_action_pred_chunk, batch_size, action_dim, we transpose the action and fill the queue
self._queues["action"].extend(actions.transpose(0, 1))
action = self._queues["action"].popleft()
@ -116,6 +119,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
if not self.check_discretized():
loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
@ -126,13 +130,15 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
class VQBeTModel(nn.Module):
"""
TODO(jayLEE0301)
"""
def __init__(self, config: VQBeTConfig):
super().__init__()
self.config = config
self.rgb_encoder = VQBeTRgbEncoder(config)
self.global_cond_dim = self.rgb_encoder.feature_dim
# action token and EOS token
self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) # Batch, Timestep, Data type, GPT input dim
@ -143,7 +149,7 @@ class VQBeTModel(nn.Module):
hidden_channels=[self.config.gpt_input_dim]
)
self.obs_projector = MLP(
self.global_cond_dim,
self.rgb_encoder.feature_dim,
hidden_channels=[self.config.gpt_input_dim]
)
self._policy = GPT(config)
@ -152,7 +158,6 @@ class VQBeTModel(nn.Module):
def discretize(self, discretize_step, actions):
return self._action_head.discretize(discretize_step, actions)
# ========= inference ============
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
# Input validation.
assert set(batch).issuperset({"observation.state", "observation.image"})
@ -163,10 +168,9 @@ class VQBeTModel(nn.Module):
img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
# Separate batch and sequence dims.
img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
# Concatenate state and image features then flatten to (B, global_cond_dim).
global_cond = torch.cat([
observation_feature = torch.cat([
torch.unsqueeze(self.obs_projector(img_features), dim=2),
torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2),
self._action_token.repeat(batch_size, n_obs_steps, 1, 1)
@ -177,12 +181,11 @@ class VQBeTModel(nn.Module):
len_additional_action_token = self.config.n_action_pred_token-1
action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)
# prompt_length = global_cond.shape[1]+1
global_cond = torch.cat([global_cond, action_token], dim=1)
observation_feature = torch.cat([observation_feature, action_token], dim=1)
# get action features
features = self._policy(global_cond)
features = self._policy(observation_feature)
historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2 # TODO(jayLEE0301) make it compatible with other values
features = torch.cat([
features[:, historical_act_pred_index],
@ -221,26 +224,18 @@ class VQBeTHead(nn.Module):
"""
super().__init__()
self.input_size = config.gpt_output_dim
self.output_size = config.output_shapes["action"][0]
self.hidden_size = config.mlp_hidden_dim
self.offset_loss_weight = config.offset_loss_weight
self.secondary_code_loss_weight = config.secondary_code_loss_weight
self.config = config
self.vqvae_groups = config.vqvae_groups
self.vqvae_n_embed = config.vqvae_n_embed # C(number of code integers)
self.vqvae_embedding_dim = config.vqvae_embedding_dim # D(embedding dims)
self.n_action_pred_chunk = config.n_action_pred_chunk # action chunk size
self._map_to_cbet_preds_bin = MLP(
in_channels=self.input_size,
hidden_channels=[self.vqvae_groups * self.vqvae_n_embed],
in_channels=config.gpt_output_dim,
hidden_channels=[self.config.vqvae_groups * self.config.vqvae_n_embed],
)
self._map_to_cbet_preds_offset = MLP(
in_channels=self.input_size,
in_channels=config.gpt_output_dim,
hidden_channels=[
self.vqvae_groups * self.vqvae_n_embed * config.n_action_pred_chunk * self.output_size,
self.config.vqvae_groups * self.config.vqvae_n_embed * config.n_action_pred_chunk * config.output_shapes["action"][0],
],
)
# init vqvae
@ -270,10 +265,10 @@ class VQBeTHead(nn.Module):
cbet_logits = self._map_to_cbet_preds_bin(x)
cbet_offsets = self._map_to_cbet_preds_offset(x)
cbet_logits = einops.rearrange(
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_groups
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.config.vqvae_groups
)
cbet_offsets = einops.rearrange(
cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self.vqvae_groups, C=self.vqvae_n_embed
cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self.config.vqvae_groups, C=self.config.vqvae_n_embed
)
cbet_probs = torch.softmax(cbet_logits, dim=-1)
NT, G, choices = cbet_probs.shape
@ -285,7 +280,7 @@ class VQBeTHead(nn.Module):
indices = (
torch.arange(NT).unsqueeze(1).cuda(),
torch.arange(self.vqvae_groups).unsqueeze(0).cuda(),
torch.arange(self.config.vqvae_groups).unsqueeze(0).cuda(),
sampled_centers,
)
# Use advanced indexing to sample the values
@ -293,7 +288,7 @@ class VQBeTHead(nn.Module):
sampled_offsets = sampled_offsets.sum(dim=1)
centers = self._vqvae_model.draw_code_forward(sampled_centers).view(
NT, -1, self.vqvae_embedding_dim
NT, -1, self.config.vqvae_embedding_dim
)
return_decoder_input = einops.rearrange(
centers.clone().detach(), "NT 1 D -> NT D"
@ -304,7 +299,7 @@ class VQBeTHead(nn.Module):
.detach()
) # NT, A
sampled_offsets = einops.rearrange(
sampled_offsets, "NT (W A) -> NT W A", W=self.n_action_pred_chunk
sampled_offsets, "NT (W A) -> NT W A", W=self.config.n_action_pred_chunk
)
predicted_action = decoded_action + sampled_offsets
predicted_action = einops.rearrange(
@ -312,7 +307,7 @@ class VQBeTHead(nn.Module):
"(N T) W A -> N T (W A)",
N=N,
T=T,
W=self.n_action_pred_chunk,
W=self.config.n_action_pred_chunk,
)
return {
@ -333,7 +328,7 @@ class VQBeTHead(nn.Module):
cbet_logits = pred["cbet_logits"]
predicted_action = einops.rearrange(
predicted_action, "N T (W A) -> (N T) W A", W=self.n_action_pred_chunk
predicted_action, "N T (W A) -> (N T) W A", W=self.config.n_action_pred_chunk
)
action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
@ -358,7 +353,7 @@ class VQBeTHead(nn.Module):
cbet_logits[:, 1, :],
action_bins[:, 1],
)
cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.secondary_code_loss_weight
cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.config.secondary_code_loss_weight
equal_primary_code_rate = torch.sum(
(action_bins[:, 0] == sampled_centers[:, 0]).int()
@ -387,7 +382,7 @@ class VQBeTHead(nn.Module):
)
).max()
loss = cbet_loss + self.offset_loss_weight * offset_loss
loss = cbet_loss + self.config.offset_loss_weight * offset_loss
loss_dict = {
"loss": loss,
@ -626,8 +621,7 @@ class VqVae(nn.Module):
):
super(VqVae, self).__init__()
self.n_action_pred_chunk = config.n_action_pred_chunk
self.action_dim = config.output_shapes["action"][0]
self.config = config
self.discretized = False
self.optimized_steps = 0
@ -637,25 +631,24 @@ class VqVae(nn.Module):
num_quantizers=config.vqvae_groups,
codebook_size=config.vqvae_n_embed,
)
self.embedding_dim = config.vqvae_embedding_dim
if self.n_action_pred_chunk == 1:
if self.config.n_action_pred_chunk == 1:
self.encoder = MLP(
in_channels=self.action_dim,
in_channels=self.config.output_shapes["action"][0],
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
)
self.decoder = MLP(
in_channels=config.vqvae_embedding_dim,
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.action_dim],
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0]],
)
else:
self.encoder = MLP(
in_channels=self.action_dim * self.n_action_pred_chunk,
in_channels=self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk,
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
)
self.decoder = MLP(
in_channels=config.vqvae_embedding_dim,
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.action_dim * self.n_action_pred_chunk],
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk],
)
self.train()
@ -690,15 +683,15 @@ class VqVae(nn.Module):
def get_action_from_latent(self, latent):
output = self.decoder(latent)
if self.n_action_pred_chunk == 1:
return einops.rearrange(output, "N (T A) -> N T A", A=self.action_dim)
if self.config.n_action_pred_chunk == 1:
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
else:
return einops.rearrange(output, "N (T A) -> N T A", A=self.action_dim)
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
def preprocess(self, state):
if not torch.is_tensor(state):
state = torch.FloatTensor(state.copy())
if self.n_action_pred_chunk == 1:
if self.config.n_action_pred_chunk == 1:
state = state.squeeze(-2) # state.squeeze(-1)
else:
state = einops.rearrange(state, "N T A -> N (T A)")
@ -717,7 +710,7 @@ class VqVae(nn.Module):
if required_recon:
recon_state = self.decoder(state_vq)
recon_state_ae = self.decoder(state_rep)
if self.n_action_pred_chunk == 1:
if self.config.n_action_pred_chunk == 1:
return state_vq, vq_code, recon_state, recon_state_ae
else:
return (
@ -764,13 +757,13 @@ class VqVae(nn.Module):
def pretrain_vqvae(vqvae_model, discretize_step, actions):
if vqvae_model.n_action_pred_chunk == 1:
if vqvae_model.config.n_action_pred_chunk == 1:
# not using action chunk
actions = actions.reshape(-1, 1, actions.shape[-1])
else:
# using action chunk
slices = []
slices.extend([actions[:, j:j+vqvae_model.n_action_pred_chunk, :] for j in range(actions.shape[1]+1-vqvae_model.n_action_pred_chunk)])
slices.extend([actions[:, j:j+vqvae_model.config.n_action_pred_chunk, :] for j in range(actions.shape[1]+1-vqvae_model.config.n_action_pred_chunk)])
actions = torch.cat(slices, dim=0)