add comment, change param names, delete kwargs

jayLEE0301 2024-05-23 19:08:37 -04:00
parent e301caf182
commit 340f7cfd6e
1 changed file with 40 additions and 47 deletions


@@ -96,7 +96,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         if not self.check_discretized():
             self.vqbet._action_head._vqvae_model.discretized = True
-            warnings.warn('To evaluate in the environment, the model was forced to stop learning the Residual VQ. If you are not evaluating with a pre-trained model, this can degrade overall performance.')
+            # VQ-BeT can predict actions only after action discretization is finished.
+            # We added logic to force self.vqbet._action_head._vqvae_model.discretized to True when check_discretized() is False, to cover prediction with a pretrained model. This shouldn't happen if you are learning from scratch, so set "eval_freq" greater than "discretize_step".
+            warnings.warn('To evaluate in the environment, the model was forced to stop learning the Residual VQ. If you are not evaluating with a pre-trained model, this can degrade overall performance. To avoid this warning, please set "eval_freq" greater than "discretize_step".')
         assert "observation.image" in batch
         assert "observation.state" in batch
@@ -105,8 +107,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
             batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
             actions = self.vqbet(batch, rollout=True)[:, : self.config.n_action_pred_chunk]
+            # the returned actions have shape (batch_size, n_action_pred_chunk, action_dim)
             actions = self.unnormalize_outputs({"action": actions})["action"]
+            # the action queue stores data with shape (n_action_pred_chunk, batch_size, action_dim), so we transpose the actions before filling the queue
             self._queues["action"].extend(actions.transpose(0, 1))
         action = self._queues["action"].popleft()
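
This queue amortizes one forward pass over several environment steps: a whole chunk is predicted at once, then one action is popped per step. A rough sketch of the same pattern with toy tensors (shapes only, made-up sizes):

    # Sketch of the action-chunk queue: predict a chunk once, then pop one action per step.
    from collections import deque

    import torch

    batch_size, n_action_pred_chunk, action_dim = 2, 5, 7
    action_queue = deque(maxlen=n_action_pred_chunk)

    if len(action_queue) == 0:
        # Pretend this came from the model: (batch_size, n_action_pred_chunk, action_dim).
        actions = torch.randn(batch_size, n_action_pred_chunk, action_dim)
        # The queue stores per-step tensors of shape (batch_size, action_dim),
        # so transpose to (n_action_pred_chunk, batch_size, action_dim) before extending.
        action_queue.extend(actions.transpose(0, 1))

    action = action_queue.popleft()
    assert action.shape == (batch_size, action_dim)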
@@ -116,6 +119,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         batch = self.normalize_targets(batch)
+        # VQ-BeT discretizes actions with a VQ-VAE before training BeT (see section 3.2 of the VQ-BeT paper, https://arxiv.org/pdf/2403.03181)
         if not self.check_discretized():
             loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
             return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
@@ -126,13 +130,15 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
 class VQBeTModel(nn.Module):
+    """
+    TODO(jayLEE0301)
+    """
     def __init__(self, config: VQBeTConfig):
         super().__init__()
         self.config = config
         self.rgb_encoder = VQBeTRgbEncoder(config)
-        self.global_cond_dim = self.rgb_encoder.feature_dim
         # action token and EOS token
         self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) # Batch, Timestep, Data type, GPT input dim
@@ -143,7 +149,7 @@ class VQBeTModel(nn.Module):
             hidden_channels=[self.config.gpt_input_dim]
         )
         self.obs_projector = MLP(
-            self.global_cond_dim,
+            self.rgb_encoder.feature_dim,
             hidden_channels=[self.config.gpt_input_dim]
         )
         self._policy = GPT(config)
@@ -152,7 +158,6 @@ class VQBeTModel(nn.Module):
     def discretize(self, discretize_step, actions):
         return self._action_head.discretize(discretize_step, actions)
-    # ========= inference ============
     def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
         # Input validation.
         assert set(batch).issuperset({"observation.state", "observation.image"})
@@ -163,10 +168,9 @@ class VQBeTModel(nn.Module):
         img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
         # Separate batch and sequence dims.
         img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
-        # Concatenate state and image features then flatten to (B, global_cond_dim).
-        global_cond = torch.cat([
+        observation_feature = torch.cat([
             torch.unsqueeze(self.obs_projector(img_features), dim=2),
             torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2),
             self._action_token.repeat(batch_size, n_obs_steps, 1, 1)
@@ -177,12 +181,11 @@ class VQBeTModel(nn.Module):
         len_additional_action_token = self.config.n_action_pred_token-1
         action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)
-        # prompt_length = global_cond.shape[1]+1
-        global_cond = torch.cat([global_cond, action_token], dim=1)
+        observation_feature = torch.cat([observation_feature, action_token], dim=1)
         # get action features
-        features = self._policy(global_cond)
+        features = self._policy(observation_feature)
         historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2 # TODO(jayLEE0301) make it compatible with other values
         features = torch.cat([
             features[:, historical_act_pred_index],
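
The index arithmetic above assumes each observation step contributes exactly three tokens to the GPT input (image token, state token, action token), so the action-prediction slot for observation step i sits at position 3*i + 2. A small NumPy check of that layout (toy sizes, not the repository code):

    import numpy as np

    # Each observation step yields [image_token, state_token, action_token] -> 3 tokens.
    n_obs_steps = 4
    tokens_per_step = 3
    # Position of the action token (the slot whose GPT output predicts the action) per step.
    historical_act_pred_index = np.arange(0, n_obs_steps) * tokens_per_step + 2
    print(historical_act_pred_index)  # [ 2  5  8 11]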
@@ -221,26 +224,18 @@ class VQBeTHead(nn.Module):
         """
         super().__init__()
-        self.input_size = config.gpt_output_dim
-        self.output_size = config.output_shapes["action"][0]
-        self.hidden_size = config.mlp_hidden_dim
-        self.offset_loss_weight = config.offset_loss_weight
-        self.secondary_code_loss_weight = config.secondary_code_loss_weight
-        self.vqvae_groups = config.vqvae_groups
-        self.vqvae_n_embed = config.vqvae_n_embed # C(number of code integers)
-        self.vqvae_embedding_dim = config.vqvae_embedding_dim # D(embedding dims)
-        self.n_action_pred_chunk = config.n_action_pred_chunk # action chunk size
+        self.config = config
         self._map_to_cbet_preds_bin = MLP(
-            in_channels=self.input_size,
-            hidden_channels=[self.vqvae_groups * self.vqvae_n_embed],
+            in_channels=config.gpt_output_dim,
+            hidden_channels=[self.config.vqvae_groups * self.config.vqvae_n_embed],
         )
         self._map_to_cbet_preds_offset = MLP(
-            in_channels=self.input_size,
+            in_channels=config.gpt_output_dim,
             hidden_channels=[
-                self.vqvae_groups * self.vqvae_n_embed * config.n_action_pred_chunk * self.output_size,
+                self.config.vqvae_groups * self.config.vqvae_n_embed * config.n_action_pred_chunk * config.output_shapes["action"][0],
             ],
         )
         # init vqvae
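
After this refactor the head reads every hyperparameter from config instead of cached attributes. The two MLPs size their outputs from the residual-VQ layout: bin logits are G*C wide and offsets are G*C*W*A wide, later reshaped with einops. A toy shape check (made-up numbers, not the defaults):

    import torch
    import einops

    G, C, W, A = 2, 16, 5, 7          # vqvae_groups, vqvae_n_embed, n_action_pred_chunk, action_dim
    NT = 3                            # flattened (batch * time)

    cbet_logits = torch.randn(NT, G * C)
    cbet_offsets = torch.randn(NT, G * C * W * A)

    cbet_logits = einops.rearrange(cbet_logits, "NT (G C) -> NT G C", G=G)
    cbet_offsets = einops.rearrange(cbet_offsets, "NT (G C WA) -> NT G C WA", G=G, C=C)
    print(cbet_logits.shape, cbet_offsets.shape)  # torch.Size([3, 2, 16]) torch.Size([3, 2, 16, 35])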
@@ -270,10 +265,10 @@ class VQBeTHead(nn.Module):
         cbet_logits = self._map_to_cbet_preds_bin(x)
         cbet_offsets = self._map_to_cbet_preds_offset(x)
         cbet_logits = einops.rearrange(
-            cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_groups
+            cbet_logits, "(NT) (G C) -> (NT) G C", G=self.config.vqvae_groups
         )
         cbet_offsets = einops.rearrange(
-            cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self.vqvae_groups, C=self.vqvae_n_embed
+            cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self.config.vqvae_groups, C=self.config.vqvae_n_embed
         )
         cbet_probs = torch.softmax(cbet_logits, dim=-1)
         NT, G, choices = cbet_probs.shape
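
Downstream of this hunk, one codebook index per group is drawn from cbet_probs. One plausible way to do that per-group sampling, shown here only as an illustration with toy shapes (not a claim about the exact sampling used in the repository):

    import torch

    NT, G, C = 4, 2, 16
    cbet_logits = torch.randn(NT, G, C)
    cbet_probs = torch.softmax(cbet_logits, dim=-1)

    # Sample one codebook index per (sequence position, group).
    sampled_centers = torch.multinomial(cbet_probs.view(-1, C), num_samples=1).view(NT, G)
    print(sampled_centers.shape)  # torch.Size([4, 2])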
@@ -285,7 +280,7 @@ class VQBeTHead(nn.Module):
         indices = (
             torch.arange(NT).unsqueeze(1).cuda(),
-            torch.arange(self.vqvae_groups).unsqueeze(0).cuda(),
+            torch.arange(self.config.vqvae_groups).unsqueeze(0).cuda(),
             sampled_centers,
         )
         # Use advanced indexing to sample the values
@@ -293,7 +288,7 @@ class VQBeTHead(nn.Module):
         sampled_offsets = sampled_offsets.sum(dim=1)
         centers = self._vqvae_model.draw_code_forward(sampled_centers).view(
-            NT, -1, self.vqvae_embedding_dim
+            NT, -1, self.config.vqvae_embedding_dim
         )
         return_decoder_input = einops.rearrange(
             centers.clone().detach(), "NT 1 D -> NT D"
@@ -304,7 +299,7 @@ class VQBeTHead(nn.Module):
             .detach()
         ) # NT, A
         sampled_offsets = einops.rearrange(
-            sampled_offsets, "NT (W A) -> NT W A", W=self.n_action_pred_chunk
+            sampled_offsets, "NT (W A) -> NT W A", W=self.config.n_action_pred_chunk
        )
         predicted_action = decoded_action + sampled_offsets
         predicted_action = einops.rearrange(
@@ -312,7 +307,7 @@ class VQBeTHead(nn.Module):
             "(N T) W A -> N T (W A)",
             N=N,
             T=T,
-            W=self.n_action_pred_chunk,
+            W=self.config.n_action_pred_chunk,
         )
         return {
@@ -333,7 +328,7 @@ class VQBeTHead(nn.Module):
         cbet_logits = pred["cbet_logits"]
         predicted_action = einops.rearrange(
-            predicted_action, "N T (W A) -> (N T) W A", W=self.n_action_pred_chunk
+            predicted_action, "N T (W A) -> (N T) W A", W=self.config.n_action_pred_chunk
         )
         action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
@@ -358,7 +353,7 @@ class VQBeTHead(nn.Module):
             cbet_logits[:, 1, :],
             action_bins[:, 1],
         )
-        cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.secondary_code_loss_weight
+        cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.config.secondary_code_loss_weight
         equal_primary_code_rate = torch.sum(
             (action_bins[:, 0] == sampled_centers[:, 0]).int()
@@ -387,7 +382,7 @@ class VQBeTHead(nn.Module):
             )
         ).max()
-        loss = cbet_loss + self.offset_loss_weight * offset_loss
+        loss = cbet_loss + self.config.offset_loss_weight * offset_loss
         loss_dict = {
             "loss": loss,
@@ -626,8 +621,7 @@ class VqVae(nn.Module):
     ):
         super(VqVae, self).__init__()
-        self.n_action_pred_chunk = config.n_action_pred_chunk
-        self.action_dim = config.output_shapes["action"][0]
+        self.config = config
         self.discretized = False
         self.optimized_steps = 0
@@ -637,25 +631,24 @@ class VqVae(nn.Module):
             num_quantizers=config.vqvae_groups,
             codebook_size=config.vqvae_n_embed,
         )
-        self.embedding_dim = config.vqvae_embedding_dim
-        if self.n_action_pred_chunk == 1:
+        if self.config.n_action_pred_chunk == 1:
             self.encoder = MLP(
-                in_channels=self.action_dim,
+                in_channels=self.config.output_shapes["action"][0],
                 hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
             )
             self.decoder = MLP(
                 in_channels=config.vqvae_embedding_dim,
-                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.action_dim],
+                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0]],
             )
         else:
             self.encoder = MLP(
-                in_channels=self.action_dim * self.n_action_pred_chunk,
+                in_channels=self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk,
                 hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
             )
             self.decoder = MLP(
                 in_channels=config.vqvae_embedding_dim,
-                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.action_dim * self.n_action_pred_chunk],
+                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk],
             )
         self.train()
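
The encoder and decoder widths above depend only on whether action chunking is used: with chunking, the VQ-VAE flattens (n_action_pred_chunk, action_dim) into a single input vector and reconstructs the same flattened chunk. A quick sanity check of those sizes with illustrative numbers:

    # Illustrative VQ-VAE input/output sizes (numbers are made up).
    action_dim = 7
    n_action_pred_chunk = 5
    vqvae_embedding_dim = 256

    encoder_in = action_dim if n_action_pred_chunk == 1 else action_dim * n_action_pred_chunk
    decoder_out = encoder_in        # the decoder reconstructs the flattened chunk
    print(encoder_in, vqvae_embedding_dim, decoder_out)  # 35 256 35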
@@ -690,15 +683,15 @@ class VqVae(nn.Module):
     def get_action_from_latent(self, latent):
         output = self.decoder(latent)
-        if self.n_action_pred_chunk == 1:
-            return einops.rearrange(output, "N (T A) -> N T A", A=self.action_dim)
+        if self.config.n_action_pred_chunk == 1:
+            return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
         else:
-            return einops.rearrange(output, "N (T A) -> N T A", A=self.action_dim)
+            return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
     def preprocess(self, state):
         if not torch.is_tensor(state):
             state = torch.FloatTensor(state.copy())
-        if self.n_action_pred_chunk == 1:
+        if self.config.n_action_pred_chunk == 1:
             state = state.squeeze(-2) # state.squeeze(-1)
         else:
             state = einops.rearrange(state, "N T A -> N (T A)")
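
Both helpers above just move between the chunked layout (N, T, A) and the flat layout (N, T*A) that the encoder and decoder MLPs expect. A toy round-trip with arbitrary sizes:

    import torch
    import einops

    N, T, A = 4, 5, 7
    actions = torch.randn(N, T, A)

    flat = einops.rearrange(actions, "N T A -> N (T A)")        # what preprocess feeds the encoder
    restored = einops.rearrange(flat, "N (T A) -> N T A", A=A)  # what get_action_from_latent returns
    assert torch.equal(actions, restored)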
@@ -717,7 +710,7 @@ class VqVae(nn.Module):
         if required_recon:
             recon_state = self.decoder(state_vq)
             recon_state_ae = self.decoder(state_rep)
-            if self.n_action_pred_chunk == 1:
+            if self.config.n_action_pred_chunk == 1:
                 return state_vq, vq_code, recon_state, recon_state_ae
             else:
                 return (
@@ -764,13 +757,13 @@ class VqVae(nn.Module):
 def pretrain_vqvae(vqvae_model, discretize_step, actions):
-    if vqvae_model.n_action_pred_chunk == 1:
+    if vqvae_model.config.n_action_pred_chunk == 1:
         # not using action chunk
         actions = actions.reshape(-1, 1, actions.shape[-1])
     else:
         # using action chunk
         slices = []
-        slices.extend([actions[:, j:j+vqvae_model.n_action_pred_chunk, :] for j in range(actions.shape[1]+1-vqvae_model.n_action_pred_chunk)])
+        slices.extend([actions[:, j:j+vqvae_model.config.n_action_pred_chunk, :] for j in range(actions.shape[1]+1-vqvae_model.config.n_action_pred_chunk)])
         actions = torch.cat(slices, dim=0)
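
The list comprehension above builds every overlapping window of length n_action_pred_chunk from each trajectory before pretraining the VQ-VAE. A small demonstration with toy tensors (sizes are arbitrary):

    import torch

    n_action_pred_chunk = 3
    actions = torch.arange(2 * 6 * 1, dtype=torch.float32).reshape(2, 6, 1)  # (batch, horizon, action_dim)

    # Every overlapping window of length n_action_pred_chunk along the time axis.
    slices = [
        actions[:, j : j + n_action_pred_chunk, :]
        for j in range(actions.shape[1] + 1 - n_action_pred_chunk)
    ]
    chunks = torch.cat(slices, dim=0)
    print(chunks.shape)  # torch.Size([8, 3, 1]) -> 2 trajectories * 4 windows each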