change n_action_pred_chunk -> action_chunk_size

jayLEE0301 2024-06-04 17:59:49 -04:00
parent bc9e6874fc
commit 6d72847bfe
3 changed files with 24 additions and 24 deletions

View File

@@ -14,7 +14,7 @@ class VQBeTConfig:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
n_action_pred_token: Number of future tokens that VQ-BeT predicts.
- n_action_pred_chunk: Action chunk size of each aciton prediction token.
+ action_chunk_size: Action chunk size of each action prediction token.
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.image" refers to an input from
@@ -63,7 +63,7 @@ class VQBeTConfig:
# Inputs / output structure.
n_obs_steps: int = 5
n_action_pred_token: int = 3
- n_action_pred_chunk: int = 5
+ action_chunk_size: int = 5
input_shapes: dict[str, list[int]] = field(
default_factory=lambda: {

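Not part of the commit: a minimal usage sketch of the renamed field. The import path and keyword construction are assumptions based on the dataclass fields shown above.

```python
# Hypothetical usage after the rename; the import path is an assumption.
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig

config = VQBeTConfig(n_obs_steps=5, n_action_pred_token=3, action_chunk_size=5)
# Callers that previously read config.n_action_pred_chunk now read:
assert config.action_chunk_size == 5
```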
View File

@@ -70,7 +70,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
self._queues = {
"observation.image": deque(maxlen=self.config.n_obs_steps),
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_pred_chunk),
"action": deque(maxlen=self.config.action_chunk_size),
}
@torch.no_grad
@@ -94,11 +94,11 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
- actions = self.vqbet(batch, rollout=True)[:, : self.config.n_action_pred_chunk]
+ actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
- # the dimension of returned action is (batch_size, n_action_pred_chunk, action_dim)
+ # the dimension of returned action is (batch_size, action_chunk_size, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"]
- # since the data in the action queue's dimension is (n_action_pred_chunk, batch_size, action_dim), we transpose the action and fill the queue
+ # since the data in the action queue's dimension is (action_chunk_size, batch_size, action_dim), we transpose the action and fill the queue
self._queues["action"].extend(actions.transpose(0, 1))
action = self._queues["action"].popleft()
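For readers following the queue logic in this hunk, a small self-contained sketch of the refill-then-pop pattern; `fake_policy_chunk` is a hypothetical stand-in for `self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]`.

```python
import torch
from collections import deque

batch_size, action_chunk_size, action_dim = 2, 5, 3
action_queue = deque(maxlen=action_chunk_size)

def fake_policy_chunk():
    # stand-in for the sliced vqbet rollout output: (batch, chunk, action_dim)
    return torch.randn(batch_size, action_chunk_size, action_dim)

for step in range(7):
    if len(action_queue) == 0:
        actions = fake_policy_chunk()
        # queue items are per-timestep tensors of shape (batch, action_dim)
        action_queue.extend(actions.transpose(0, 1))
    action = action_queue.popleft()  # one action per environment step
    assert action.shape == (batch_size, action_dim)
```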
@@ -200,7 +200,7 @@ class VQBeTModel(nn.Module):
and finally generates a prediction for the action chunks.
-------------------------------** legend **-------------------------------
- n = n_obs_steps, p = n_action_pred_token, c = n_action_pred_chunk)
+ n = n_obs_steps, p = n_action_pred_token, c = action_chunk_size)
o_{t} : visual observation at timestep {t}
s_{t} : state observation at timestep {t}
a_{t} : action at timestep {t}
@@ -319,12 +319,12 @@ class VQBeTModel(nn.Module):
)
# if rollout, VQ-BeT don't calculate loss
if rollout:
- return pred_action["predicted_action"][:, n_obs_steps-1, :].reshape(batch_size, self.config.n_action_pred_chunk, -1)
+ return pred_action["predicted_action"][:, n_obs_steps-1, :].reshape(batch_size, self.config.action_chunk_size, -1)
# else, it calculate overall loss (bin prediction loss, and offset loss)
else:
action = batch["action"]
n, total_w, act_dim = action.shape
- act_w = self.config.n_action_pred_chunk
+ act_w = self.config.action_chunk_size
num_token = total_w + 1 - act_w
output_shape = (n, num_token, act_w, act_dim)
output = torch.empty(output_shape).to(action.device)
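The loop that fills `output` sits outside this hunk; the sketch below reproduces the intended sliding-window layout on toy shapes (an assumption about intent, using `Tensor.unfold` rather than the actual loop).

```python
import torch

n, total_w, act_dim = 2, 15, 3
act_w = 5                        # self.config.action_chunk_size
num_token = total_w + 1 - act_w  # 11 overlapping chunks

action = torch.randn(n, total_w, act_dim)
# unfold yields (n, num_token, act_dim, act_w); move the window axis before the action axis
output = action.unfold(dimension=1, size=act_w, step=1).permute(0, 1, 3, 2)
assert output.shape == (n, num_token, act_w, act_dim)
```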
@@ -353,10 +353,10 @@ class VQBeTHead(nn.Module):
self.map_to_cbet_preds_offset: output the predicted offsets for all the codes in all the layers.
The input dimension of ` self.map_to_cbet_preds_offset` is same with the output of GPT,
- and the output dimension of ` self.map_to_cbet_preds_offset` is `self.config.vqvae_groups * self.config.vqvae_n_embed * config.n_action_pred_chunk * config.output_shapes["action"][0]`, where
+ and the output dimension of ` self.map_to_cbet_preds_offset` is `self.config.vqvae_groups * self.config.vqvae_n_embed * config.action_chunk_size * config.output_shapes["action"][0]`, where
`self.config.vqvae_groups` is number of RVQ layers,
`self.config.vqvae_n_embed` is codebook size of RVQ,
- `config.n_action_pred_chunk is action chunk size of each token, and
+ `config.action_chunk_size is action chunk size of each token, and
`config.output_shapes["action"][0]` is the dimension of action
"""
@@ -372,7 +372,7 @@ class VQBeTHead(nn.Module):
self.map_to_cbet_preds_offset = MLP(
in_channels=config.gpt_output_dim,
hidden_channels=[
- self.config.vqvae_groups * self.config.vqvae_n_embed * config.n_action_pred_chunk * config.output_shapes["action"][0],
+ self.config.vqvae_groups * self.config.vqvae_n_embed * config.action_chunk_size * config.output_shapes["action"][0],
],
)
# init vqvae
@@ -443,7 +443,7 @@ class VQBeTHead(nn.Module):
)
# reshaped extracted offset to match with decoded centroids
sampled_offsets = einops.rearrange(
- sampled_offsets, "NT (W A) -> NT W A", W=self.config.n_action_pred_chunk
+ sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
)
# add offset and decoded centroids
predicted_action = decoded_action + sampled_offsets
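A toy shape walk-through of the offset handling above, with random stand-ins for the offsets and decoded centroids; the per-code offset selection that precedes this in the head is skipped.

```python
import torch
import einops

N, T = 2, 11   # batch and token axes, flattened to NT inside the head
W, A = 5, 3    # action_chunk_size and action dimension

sampled_offsets = torch.randn(N * T, W * A)  # flat offsets, one (W, A) block per sample
decoded_action = torch.randn(N * T, W, A)    # decoded RVQ centroids per chunk

sampled_offsets = einops.rearrange(sampled_offsets, "NT (W A) -> NT W A", W=W)
predicted_action = decoded_action + sampled_offsets
predicted_action = einops.rearrange(predicted_action, "(N T) W A -> N T (W A)", N=N, T=T, W=W)
assert predicted_action.shape == (N, T, W * A)
```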
@@ -452,7 +452,7 @@ class VQBeTHead(nn.Module):
"(N T) W A -> N T (W A)",
N=N,
T=T,
- W=self.config.n_action_pred_chunk,
+ W=self.config.action_chunk_size,
)
return {
@@ -482,7 +482,7 @@ class VQBeTHead(nn.Module):
cbet_logits = pred["cbet_logits"]
predicted_action = einops.rearrange(
predicted_action, "N T (W A) -> (N T) W A", W=self.config.n_action_pred_chunk
predicted_action, "N T (W A) -> (N T) W A", W=self.config.action_chunk_size
)
action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
@@ -772,12 +772,12 @@ class VqVae(nn.Module):
)
self.encoder = MLP(
- in_channels=self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk,
+ in_channels=self.config.output_shapes["action"][0] * self.config.action_chunk_size,
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
)
self.decoder = MLP(
in_channels=config.vqvae_embedding_dim,
- hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk],
+ hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.action_chunk_size],
)
self.train()
@@ -823,7 +823,7 @@ class VqVae(nn.Module):
def get_action_from_latent(self, latent):
output = self.decoder(latent)
- if self.config.n_action_pred_chunk == 1:
+ if self.config.action_chunk_size == 1:
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
else:
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
@@ -831,7 +831,7 @@ class VqVae(nn.Module):
def preprocess(self, state):
if not torch.is_tensor(state):
state = torch.FloatTensor(state.copy())
- if self.config.n_action_pred_chunk == 1:
+ if self.config.action_chunk_size == 1:
state = state.squeeze(-2)
else:
state = einops.rearrange(state, "N T A -> N (T A)")
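A toy round trip of the chunk flattening done by `preprocess` and undone by `get_action_from_latent`; the encoder/decoder MLPs are replaced by an identity stand-in so only the shapes are exercised, and the 2-D action dimension is an assumption.

```python
import torch
import einops

N, T, A = 8, 5, 2  # batch, action_chunk_size, action dim (2-D actions assumed)
chunk = torch.randn(N, T, A)

flat = einops.rearrange(chunk, "N T A -> N (T A)")  # what the encoder MLP consumes
decoded_flat = flat                                  # stand-in for decoder(encoder(flat))
decoded = einops.rearrange(decoded_flat, "N (T A) -> N T A", A=A)
assert torch.equal(decoded, chunk)
```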
@@ -850,7 +850,7 @@ class VqVae(nn.Module):
if required_recon:
recon_state = self.decoder(state_vq)
recon_state_ae = self.decoder(state_rep)
- if self.config.n_action_pred_chunk == 1:
+ if self.config.action_chunk_size == 1:
return state_vq, vq_code, recon_state, recon_state_ae
else:
return (
@@ -896,13 +896,13 @@ class VqVae(nn.Module):
def pretrain_vqvae(vqvae_model, discretize_step, actions):
- if vqvae_model.config.n_action_pred_chunk == 1:
+ if vqvae_model.config.action_chunk_size == 1:
# not using action chunk
actions = actions.reshape(-1, 1, actions.shape[-1])
else:
# using action chunk
slices = []
- slices.extend([actions[:, j:j+vqvae_model.config.n_action_pred_chunk, :] for j in range(actions.shape[1]+1-vqvae_model.config.n_action_pred_chunk)])
+ slices.extend([actions[:, j:j+vqvae_model.config.action_chunk_size, :] for j in range(actions.shape[1]+1-vqvae_model.config.action_chunk_size)])
actions = torch.cat(slices, dim=0)
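On toy shapes, the list comprehension above turns every length-`action_chunk_size` window of the action sequence into its own VQ-VAE training sample, concatenated along the batch axis:

```python
import torch

action_chunk_size = 5
actions = torch.randn(4, 15, 3)  # (batch, total_w, action_dim)

slices = [actions[:, j : j + action_chunk_size, :] for j in range(actions.shape[1] + 1 - action_chunk_size)]
chunks = torch.cat(slices, dim=0)
# 11 windows per sequence * batch of 4 -> 44 chunk samples
assert chunks.shape == (4 * (15 + 1 - action_chunk_size), action_chunk_size, 3)
```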

View File

@@ -47,7 +47,7 @@ training:
delta_timestamps:
observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.n_action_pred_chunk} - 1)]"
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.action_chunk_size} - 1)]"
eval:
n_episodes: 500
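Evaluating the delta_timestamps expressions above with the ${...} interpolations replaced by plain Python; fps is assumed to be 10 for illustration, and the parameter values are taken from the policy block below.

```python
fps = 10  # assumed for illustration; the real value comes from the env config
n_obs_steps, n_action_pred_token, action_chunk_size = 5, 7, 5

obs_deltas = [i / fps for i in range(1 - n_obs_steps, 1)]
action_deltas = [i / fps for i in range(1 - n_obs_steps, n_action_pred_token + action_chunk_size - 1)]

print(obs_deltas)           # [-0.4, -0.3, -0.2, -0.1, 0.0]
print(len(action_deltas))   # 15 action frames per sample
print(len(action_deltas) + 1 - action_chunk_size)  # 11 chunk targets, matching num_token in VQBeTModel
```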
@@ -59,7 +59,7 @@ policy:
# Input / output structure.
n_obs_steps: 5
n_action_pred_token: 7
- n_action_pred_chunk: 5
+ action_chunk_size: 5
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?