add comment, change param names, delete kwargs
parent e301caf182
commit 340f7cfd6e
@@ -96,7 +96,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
 
         if not self.check_discretized():
             self.vqbet._action_head._vqvae_model.discretized = True
-            warnings.warn('To evaluate in the environment, the model was forced to stop learning the Residual VQ. If you are not evaluating with a pre-trained model, this can degrade overall performance.')
+            # VQ-BeT can predict action only after finishing action discretization.
+            # We added a logic to force self.vqbet._action_head._vqvae_model.discretized to be True if not self.check_discretized() to account for the case of predicting with a pretrained model, but this shouldn't happen if you're learning from scratch, so set eval_freq greater than discretize_step.
+            warnings.warn('To evaluate in the environment, the model was forced to stop learning the Residual VQ. If you are not evaluating with a pre-trained model, this can degrade overall performance. To avoid this warning, please set "eval_freq" greater than "discretize_step".')
         assert "observation.image" in batch
         assert "observation.state" in batch
 
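For context, a minimal sketch of the rule the new comment recommends; check_discretized and discretize_step are taken from the diff, while the helper name and the numbers below are assumptions.

# Hypothetical helper illustrating why eval_freq should exceed discretize_step:
# the Residual VQ keeps learning until `discretize_step` optimization steps have run,
# so rolling out earlier forces `discretized = True` before the codebook has settled.
def ready_for_rollout(current_step: int, discretize_step: int) -> bool:
    return current_step >= discretize_step

# Example: with discretize_step = 2000, choose eval_freq = 2500 (anything > 2000)
# so the first evaluation happens only after action discretization has finished.
assert ready_for_rollout(2500, 2000)
assert not ready_for_rollout(1000, 2000)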
@@ -105,8 +107,9 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
             batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
             actions = self.vqbet(batch, rollout=True)[:, : self.config.n_action_pred_chunk]
 
+            # the dimension of returned action is (batch_size, n_action_pred_chunk, action_dim)
             actions = self.unnormalize_outputs({"action": actions})["action"]
+            # since the data in the action queue's dimension is (n_action_pred_chunk, batch_size, action_dim), we transpose the action and fill the queue
             self._queues["action"].extend(actions.transpose(0, 1))
 
         action = self._queues["action"].popleft()
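A shape-only illustration (made-up sizes) of the queue handling described in the new comments: deque.extend iterates over dim 0, hence the transpose from (batch_size, n_action_pred_chunk, action_dim) to (n_action_pred_chunk, batch_size, action_dim).

# Standalone sketch of the action queue; sizes are arbitrary, only the shapes matter.
from collections import deque
import torch

batch_size, n_action_pred_chunk, action_dim = 4, 5, 7
actions = torch.randn(batch_size, n_action_pred_chunk, action_dim)  # model output

queue = deque([], maxlen=n_action_pred_chunk)
queue.extend(actions.transpose(0, 1))  # queue now holds n_action_pred_chunk tensors
action = queue.popleft()               # one (batch_size, action_dim) action per env step
assert action.shape == (batch_size, action_dim)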
@@ -116,6 +119,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         batch = self.normalize_targets(batch)
+        # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
         if not self.check_discretized():
             loss, n_different_codes, n_different_combinations = self.vqbet.discretize(self.config.discretize_step, batch['action'])
             return {"loss": loss, "n_different_codes": n_different_codes, "n_different_combinations": n_different_combinations}
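A hedged restatement of the two-phase flow this hunk comments on; the function name and the policy argument are illustrative, the member names follow the diff.

def compute_loss_sketch(policy, batch):
    # Phase 1: until check_discretized() is True, only the VQ-VAE action discretizer
    # is trained (VQ-BeT paper, section 3.2); the BeT head is not updated yet.
    batch = policy.normalize_inputs(batch)
    batch = policy.normalize_targets(batch)
    if not policy.check_discretized():
        loss, n_codes, n_combos = policy.vqbet.discretize(policy.config.discretize_step, batch["action"])
        return {"loss": loss, "n_different_codes": n_codes, "n_different_combinations": n_combos}
    ...  # Phase 2: once discretized, the usual BeT loss is computed further down in forward().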
@@ -126,13 +130,15 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
 
 
 class VQBeTModel(nn.Module):
+    """
+    TODO(jayLEE0301)
+    """
     def __init__(self, config: VQBeTConfig):
         super().__init__()
         self.config = config
 
         self.rgb_encoder = VQBeTRgbEncoder(config)
 
-        self.global_cond_dim = self.rgb_encoder.feature_dim
 
         # action token and EOS token
         self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) # Batch, Timestep, Data type, GPT input dim
@@ -143,7 +149,7 @@ class VQBeTModel(nn.Module):
             hidden_channels=[self.config.gpt_input_dim]
         )
         self.obs_projector = MLP(
-            self.global_cond_dim,
+            self.rgb_encoder.feature_dim,
             hidden_channels=[self.config.gpt_input_dim]
         )
         self._policy = GPT(config)
@@ -152,7 +158,6 @@ class VQBeTModel(nn.Module):
     def discretize(self, discretize_step, actions):
         return self._action_head.discretize(discretize_step, actions)
 
-    # ========= inference ============
     def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
         # Input validation.
         assert set(batch).issuperset({"observation.state", "observation.image"})
@@ -163,10 +168,9 @@ class VQBeTModel(nn.Module):
         img_features = self.rgb_encoder(einops.rearrange(batch["observation.image"], "b n ... -> (b n) ..."))
         # Separate batch and sequence dims.
         img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
-        # Concatenate state and image features then flatten to (B, global_cond_dim).
 
 
-        global_cond = torch.cat([
+        observation_feature = torch.cat([
             torch.unsqueeze(self.obs_projector(img_features), dim=2),
             torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2),
             self._action_token.repeat(batch_size, n_obs_steps, 1, 1)
@@ -177,12 +181,11 @@ class VQBeTModel(nn.Module):
         len_additional_action_token = self.config.n_action_pred_token-1
         action_token = self._action_token.repeat(batch_size, len_additional_action_token, 1)
 
-        # prompt_length = global_cond.shape[1]+1
-        global_cond = torch.cat([global_cond, action_token], dim=1)
+        observation_feature = torch.cat([observation_feature, action_token], dim=1)
 
 
         # get action features
-        features = self._policy(global_cond)
+        features = self._policy(observation_feature)
         historical_act_pred_index = np.arange(0, n_obs_steps) * 3 + 2 # TODO(jayLEE0301) make it compatible with other values
         features = torch.cat([
             features[:, historical_act_pred_index],
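The "* 3 + 2" index in this hunk relies on each observation step contributing three tokens (image feature, state feature, learned action token), as built by the torch.cat above. A small check with assumed sizes:

import numpy as np

n_obs_steps = 4
tokens_per_step = 3  # [image token, state token, action token] per observation step
historical_act_pred_index = np.arange(0, n_obs_steps) * tokens_per_step + 2
# picks the action-token position of every observation step
assert historical_act_pred_index.tolist() == [2, 5, 8, 11]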
@@ -221,26 +224,18 @@ class VQBeTHead(nn.Module):
         """
 
         super().__init__()
-        self.input_size = config.gpt_output_dim
-        self.output_size = config.output_shapes["action"][0]
-        self.hidden_size = config.mlp_hidden_dim
-        self.offset_loss_weight = config.offset_loss_weight
-        self.secondary_code_loss_weight = config.secondary_code_loss_weight
+        self.config = config
 
-        self.vqvae_groups = config.vqvae_groups
-        self.vqvae_n_embed = config.vqvae_n_embed # C(number of code integers)
-        self.vqvae_embedding_dim = config.vqvae_embedding_dim # D(embedding dims)
-        self.n_action_pred_chunk = config.n_action_pred_chunk # action chunk size
 
 
         self._map_to_cbet_preds_bin = MLP(
-            in_channels=self.input_size,
-            hidden_channels=[self.vqvae_groups * self.vqvae_n_embed],
+            in_channels=config.gpt_output_dim,
+            hidden_channels=[self.config.vqvae_groups * self.config.vqvae_n_embed],
         )
         self._map_to_cbet_preds_offset = MLP(
-            in_channels=self.input_size,
+            in_channels=config.gpt_output_dim,
             hidden_channels=[
-                self.vqvae_groups * self.vqvae_n_embed * config.n_action_pred_chunk * self.output_size,
+                self.config.vqvae_groups * self.config.vqvae_n_embed * config.n_action_pred_chunk * config.output_shapes["action"][0],
             ],
         )
         # init vqvae
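The pattern this commit applies throughout (the "change param names" part of the title): keep the whole config on the module and read hyperparameters from it, instead of copying each field onto self. A minimal stand-in example; the dataclass below is a toy, not the real VQBeTConfig:

from dataclasses import dataclass
import torch.nn as nn

@dataclass
class ToyConfig:  # stand-in for VQBeTConfig, with only the fields used below
    gpt_output_dim: int = 512
    vqvae_groups: int = 2
    vqvae_n_embed: int = 16

class ToyHead(nn.Module):
    def __init__(self, config: ToyConfig):
        super().__init__()
        self.config = config  # new style: one attribute instead of many copied fields
        self.proj = nn.Linear(config.gpt_output_dim, config.vqvae_groups * config.vqvae_n_embed)

head = ToyHead(ToyConfig())
assert head.config.vqvae_groups == 2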
@@ -270,10 +265,10 @@ class VQBeTHead(nn.Module):
         cbet_logits = self._map_to_cbet_preds_bin(x)
         cbet_offsets = self._map_to_cbet_preds_offset(x)
         cbet_logits = einops.rearrange(
-            cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_groups
+            cbet_logits, "(NT) (G C) -> (NT) G C", G=self.config.vqvae_groups
         )
         cbet_offsets = einops.rearrange(
-            cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self.vqvae_groups, C=self.vqvae_n_embed
+            cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self.config.vqvae_groups, C=self.config.vqvae_n_embed
         )
         cbet_probs = torch.softmax(cbet_logits, dim=-1)
         NT, G, choices = cbet_probs.shape
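A quick shape check (illustrative sizes) for the two rearranges in this hunk: G is the number of residual-VQ groups, C the codebook size, and WA the flattened chunk-by-action-dim offset.

import torch
import einops

NT, G, C, W, A = 6, 2, 16, 5, 7
cbet_logits = torch.randn(NT, G * C)
cbet_offsets = torch.randn(NT, G * C * W * A)

logits = einops.rearrange(cbet_logits, "NT (G C) -> NT G C", G=G)
offsets = einops.rearrange(cbet_offsets, "NT (G C WA) -> NT G C WA", G=G, C=C)
assert logits.shape == (NT, G, C)
assert offsets.shape == (NT, G, C, W * A)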
@@ -285,7 +280,7 @@ class VQBeTHead(nn.Module):
 
         indices = (
             torch.arange(NT).unsqueeze(1).cuda(),
-            torch.arange(self.vqvae_groups).unsqueeze(0).cuda(),
+            torch.arange(self.config.vqvae_groups).unsqueeze(0).cuda(),
             sampled_centers,
         )
         # Use advanced indexing to sample the values
@@ -293,7 +288,7 @@ class VQBeTHead(nn.Module):
 
         sampled_offsets = sampled_offsets.sum(dim=1)
         centers = self._vqvae_model.draw_code_forward(sampled_centers).view(
-            NT, -1, self.vqvae_embedding_dim
+            NT, -1, self.config.vqvae_embedding_dim
         )
         return_decoder_input = einops.rearrange(
             centers.clone().detach(), "NT 1 D -> NT D"
@@ -304,7 +299,7 @@ class VQBeTHead(nn.Module):
             .detach()
         ) # NT, A
         sampled_offsets = einops.rearrange(
-            sampled_offsets, "NT (W A) -> NT W A", W=self.n_action_pred_chunk
+            sampled_offsets, "NT (W A) -> NT W A", W=self.config.n_action_pred_chunk
         )
         predicted_action = decoded_action + sampled_offsets
         predicted_action = einops.rearrange(
@@ -312,7 +307,7 @@ class VQBeTHead(nn.Module):
             "(N T) W A -> N T (W A)",
             N=N,
             T=T,
-            W=self.n_action_pred_chunk,
+            W=self.config.n_action_pred_chunk,
         )
 
         return {
@@ -333,7 +328,7 @@ class VQBeTHead(nn.Module):
         cbet_logits = pred["cbet_logits"]
 
         predicted_action = einops.rearrange(
-            predicted_action, "N T (W A) -> (N T) W A", W=self.n_action_pred_chunk
+            predicted_action, "N T (W A) -> (N T) W A", W=self.config.n_action_pred_chunk
         )
 
         action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
@@ -358,7 +353,7 @@ class VQBeTHead(nn.Module):
             cbet_logits[:, 1, :],
             action_bins[:, 1],
         )
-        cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.secondary_code_loss_weight
+        cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.config.secondary_code_loss_weight
 
         equal_primary_code_rate = torch.sum(
             (action_bins[:, 0] == sampled_centers[:, 0]).int()
@@ -387,7 +382,7 @@ class VQBeTHead(nn.Module):
             )
         ).max()
 
-        loss = cbet_loss + self.offset_loss_weight * offset_loss
+        loss = cbet_loss + self.config.offset_loss_weight * offset_loss
 
         loss_dict = {
             "loss": loss,
@@ -626,8 +621,7 @@ class VqVae(nn.Module):
     ):
 
         super(VqVae, self).__init__()
-        self.n_action_pred_chunk = config.n_action_pred_chunk
-        self.action_dim = config.output_shapes["action"][0]
+        self.config = config
 
         self.discretized = False
         self.optimized_steps = 0
@@ -637,25 +631,24 @@ class VqVae(nn.Module):
             num_quantizers=config.vqvae_groups,
             codebook_size=config.vqvae_n_embed,
         )
-        self.embedding_dim = config.vqvae_embedding_dim
 
-        if self.n_action_pred_chunk == 1:
+        if self.config.n_action_pred_chunk == 1:
             self.encoder = MLP(
-                in_channels=self.action_dim,
+                in_channels=self.config.output_shapes["action"][0],
                 hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
             )
             self.decoder = MLP(
                 in_channels=config.vqvae_embedding_dim,
-                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.action_dim],
+                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0]],
            )
         else:
             self.encoder = MLP(
-                in_channels=self.action_dim * self.n_action_pred_chunk,
+                in_channels=self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk,
                 hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
             )
             self.decoder = MLP(
                 in_channels=config.vqvae_embedding_dim,
-                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.action_dim * self.n_action_pred_chunk],
+                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk],
             )
 
         self.train()
@@ -690,15 +683,15 @@ class VqVae(nn.Module):
 
     def get_action_from_latent(self, latent):
         output = self.decoder(latent)
-        if self.n_action_pred_chunk == 1:
-            return einops.rearrange(output, "N (T A) -> N T A", A=self.action_dim)
+        if self.config.n_action_pred_chunk == 1:
+            return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
         else:
-            return einops.rearrange(output, "N (T A) -> N T A", A=self.action_dim)
+            return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
 
     def preprocess(self, state):
         if not torch.is_tensor(state):
             state = torch.FloatTensor(state.copy())
-        if self.n_action_pred_chunk == 1:
+        if self.config.n_action_pred_chunk == 1:
             state = state.squeeze(-2) # state.squeeze(-1)
         else:
             state = einops.rearrange(state, "N T A -> N (T A)")
@@ -717,7 +710,7 @@ class VqVae(nn.Module):
         if required_recon:
             recon_state = self.decoder(state_vq)
             recon_state_ae = self.decoder(state_rep)
-            if self.n_action_pred_chunk == 1:
+            if self.config.n_action_pred_chunk == 1:
                 return state_vq, vq_code, recon_state, recon_state_ae
             else:
                 return (
@@ -764,13 +757,13 @@ class VqVae(nn.Module):
 
 
 def pretrain_vqvae(vqvae_model, discretize_step, actions):
-    if vqvae_model.n_action_pred_chunk == 1:
+    if vqvae_model.config.n_action_pred_chunk == 1:
         # not using action chunk
         actions = actions.reshape(-1, 1, actions.shape[-1])
     else:
         # using action chunk
         slices = []
-        slices.extend([actions[:, j:j+vqvae_model.n_action_pred_chunk, :] for j in range(actions.shape[1]+1-vqvae_model.n_action_pred_chunk)])
+        slices.extend([actions[:, j:j+vqvae_model.config.n_action_pred_chunk, :] for j in range(actions.shape[1]+1-vqvae_model.config.n_action_pred_chunk)])
         actions = torch.cat(slices, dim=0)
 
 
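The sliding-window slicing in pretrain_vqvae turns (N, T, A) action sequences into overlapping chunks of length n_action_pred_chunk before VQ-VAE pretraining. A shape-only check with made-up sizes:

import torch

N, T, A, chunk = 3, 6, 7, 5                     # chunk stands in for n_action_pred_chunk
actions = torch.randn(N, T, A)
slices = [actions[:, j:j + chunk, :] for j in range(T + 1 - chunk)]
chunked = torch.cat(slices, dim=0)              # every length-`chunk` window, batched
assert chunked.shape == (N * (T + 1 - chunk), chunk, A)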