change n_action_pred_chunk -> action_chunk_size
This commit is contained in:
parent
bc9e6874fc
commit
6d72847bfe
|
@ -14,7 +14,7 @@ class VQBeTConfig:
|
|||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
n_action_pred_token: Number of future tokens that VQ-BeT predicts.
|
||||
n_action_pred_chunk: Action chunk size of each aciton prediction token.
|
||||
action_chunk_size: Action chunk size of each action prediction token.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
The key represents the input data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "observation.image" refers to an input from
|
||||
|
@ -63,7 +63,7 @@ class VQBeTConfig:
|
|||
# Inputs / output structure.
|
||||
n_obs_steps: int = 5
|
||||
n_action_pred_token: int = 3
|
||||
n_action_pred_chunk: int = 5
|
||||
action_chunk_size: int = 5
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
|
|
|
@ -70,7 +70,7 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
self._queues = {
|
||||
"observation.image": deque(maxlen=self.config.n_obs_steps),
|
||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||
"action": deque(maxlen=self.config.n_action_pred_chunk),
|
||||
"action": deque(maxlen=self.config.action_chunk_size),
|
||||
}
|
||||
|
||||
@torch.no_grad
|
||||
|
@ -94,11 +94,11 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||
if len(self._queues["action"]) == 0:
|
||||
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
actions = self.vqbet(batch, rollout=True)[:, : self.config.n_action_pred_chunk]
|
||||
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
|
||||
|
||||
# the dimension of returned action is (batch_size, n_action_pred_chunk, action_dim)
|
||||
# the dimension of returned action is (batch_size, action_chunk_size, action_dim)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
# since the data in the action queue's dimension is (n_action_pred_chunk, batch_size, action_dim), we transpose the action and fill the queue
|
||||
# since the data in the action queue's dimension is (action_chunk_size, batch_size, action_dim), we transpose the action and fill the queue
|
||||
self._queues["action"].extend(actions.transpose(0, 1))
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
|
@ -200,7 +200,7 @@ class VQBeTModel(nn.Module):
|
|||
and finally generates a prediction for the action chunks.
|
||||
|
||||
-------------------------------** legend **-------------------------------
|
||||
│ n = n_obs_steps, p = n_action_pred_token, c = n_action_pred_chunk) │
|
||||
│ n = n_obs_steps, p = n_action_pred_token, c = action_chunk_size) │
|
||||
│ o_{t} : visual observation at timestep {t} │
|
||||
│ s_{t} : state observation at timestep {t} │
|
||||
│ a_{t} : action at timestep {t} │
|
||||
|
@ -319,12 +319,12 @@ class VQBeTModel(nn.Module):
|
|||
)
|
||||
# if rollout, VQ-BeT don't calculate loss
|
||||
if rollout:
|
||||
return pred_action["predicted_action"][:, n_obs_steps-1, :].reshape(batch_size, self.config.n_action_pred_chunk, -1)
|
||||
return pred_action["predicted_action"][:, n_obs_steps-1, :].reshape(batch_size, self.config.action_chunk_size, -1)
|
||||
# else, it calculate overall loss (bin prediction loss, and offset loss)
|
||||
else:
|
||||
action = batch["action"]
|
||||
n, total_w, act_dim = action.shape
|
||||
act_w = self.config.n_action_pred_chunk
|
||||
act_w = self.config.action_chunk_size
|
||||
num_token = total_w + 1 - act_w
|
||||
output_shape = (n, num_token, act_w, act_dim)
|
||||
output = torch.empty(output_shape).to(action.device)
|
||||
|
@ -353,10 +353,10 @@ class VQBeTHead(nn.Module):
|
|||
|
||||
self.map_to_cbet_preds_offset: output the predicted offsets for all the codes in all the layers.
|
||||
The input dimension of ` self.map_to_cbet_preds_offset` is same with the output of GPT,
|
||||
and the output dimension of ` self.map_to_cbet_preds_offset` is `self.config.vqvae_groups * self.config.vqvae_n_embed * config.n_action_pred_chunk * config.output_shapes["action"][0]`, where
|
||||
and the output dimension of ` self.map_to_cbet_preds_offset` is `self.config.vqvae_groups * self.config.vqvae_n_embed * config.action_chunk_size * config.output_shapes["action"][0]`, where
|
||||
`self.config.vqvae_groups` is number of RVQ layers,
|
||||
`self.config.vqvae_n_embed` is codebook size of RVQ,
|
||||
`config.n_action_pred_chunk is action chunk size of each token, and
|
||||
`config.action_chunk_size is action chunk size of each token, and
|
||||
`config.output_shapes["action"][0]` is the dimension of action
|
||||
"""
|
||||
|
||||
|
@ -372,7 +372,7 @@ class VQBeTHead(nn.Module):
|
|||
self.map_to_cbet_preds_offset = MLP(
|
||||
in_channels=config.gpt_output_dim,
|
||||
hidden_channels=[
|
||||
self.config.vqvae_groups * self.config.vqvae_n_embed * config.n_action_pred_chunk * config.output_shapes["action"][0],
|
||||
self.config.vqvae_groups * self.config.vqvae_n_embed * config.action_chunk_size * config.output_shapes["action"][0],
|
||||
],
|
||||
)
|
||||
# init vqvae
|
||||
|
@ -443,7 +443,7 @@ class VQBeTHead(nn.Module):
|
|||
)
|
||||
# reshaped extracted offset to match with decoded centroids
|
||||
sampled_offsets = einops.rearrange(
|
||||
sampled_offsets, "NT (W A) -> NT W A", W=self.config.n_action_pred_chunk
|
||||
sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size
|
||||
)
|
||||
# add offset and decoded centroids
|
||||
predicted_action = decoded_action + sampled_offsets
|
||||
|
@ -452,7 +452,7 @@ class VQBeTHead(nn.Module):
|
|||
"(N T) W A -> N T (W A)",
|
||||
N=N,
|
||||
T=T,
|
||||
W=self.config.n_action_pred_chunk,
|
||||
W=self.config.action_chunk_size,
|
||||
)
|
||||
|
||||
return {
|
||||
|
@ -482,7 +482,7 @@ class VQBeTHead(nn.Module):
|
|||
cbet_logits = pred["cbet_logits"]
|
||||
|
||||
predicted_action = einops.rearrange(
|
||||
predicted_action, "N T (W A) -> (N T) W A", W=self.config.n_action_pred_chunk
|
||||
predicted_action, "N T (W A) -> (N T) W A", W=self.config.action_chunk_size
|
||||
)
|
||||
|
||||
action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
|
||||
|
@ -772,12 +772,12 @@ class VqVae(nn.Module):
|
|||
)
|
||||
|
||||
self.encoder = MLP(
|
||||
in_channels=self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk,
|
||||
in_channels=self.config.output_shapes["action"][0] * self.config.action_chunk_size,
|
||||
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
|
||||
)
|
||||
self.decoder = MLP(
|
||||
in_channels=config.vqvae_embedding_dim,
|
||||
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.n_action_pred_chunk],
|
||||
hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.config.output_shapes["action"][0] * self.config.action_chunk_size],
|
||||
)
|
||||
|
||||
self.train()
|
||||
|
@ -823,7 +823,7 @@ class VqVae(nn.Module):
|
|||
|
||||
def get_action_from_latent(self, latent):
|
||||
output = self.decoder(latent)
|
||||
if self.config.n_action_pred_chunk == 1:
|
||||
if self.config.action_chunk_size == 1:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
else:
|
||||
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
|
||||
|
@ -831,7 +831,7 @@ class VqVae(nn.Module):
|
|||
def preprocess(self, state):
|
||||
if not torch.is_tensor(state):
|
||||
state = torch.FloatTensor(state.copy())
|
||||
if self.config.n_action_pred_chunk == 1:
|
||||
if self.config.action_chunk_size == 1:
|
||||
state = state.squeeze(-2)
|
||||
else:
|
||||
state = einops.rearrange(state, "N T A -> N (T A)")
|
||||
|
@ -850,7 +850,7 @@ class VqVae(nn.Module):
|
|||
if required_recon:
|
||||
recon_state = self.decoder(state_vq)
|
||||
recon_state_ae = self.decoder(state_rep)
|
||||
if self.config.n_action_pred_chunk == 1:
|
||||
if self.config.action_chunk_size == 1:
|
||||
return state_vq, vq_code, recon_state, recon_state_ae
|
||||
else:
|
||||
return (
|
||||
|
@ -896,13 +896,13 @@ class VqVae(nn.Module):
|
|||
|
||||
|
||||
def pretrain_vqvae(vqvae_model, discretize_step, actions):
|
||||
if vqvae_model.config.n_action_pred_chunk == 1:
|
||||
if vqvae_model.config.action_chunk_size == 1:
|
||||
# not using action chunk
|
||||
actions = actions.reshape(-1, 1, actions.shape[-1])
|
||||
else:
|
||||
# using action chunk
|
||||
slices = []
|
||||
slices.extend([actions[:, j:j+vqvae_model.config.n_action_pred_chunk, :] for j in range(actions.shape[1]+1-vqvae_model.config.n_action_pred_chunk)])
|
||||
slices.extend([actions[:, j:j+vqvae_model.config.action_chunk_size, :] for j in range(actions.shape[1]+1-vqvae_model.config.action_chunk_size)])
|
||||
actions = torch.cat(slices, dim=0)
|
||||
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ training:
|
|||
delta_timestamps:
|
||||
observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.n_action_pred_chunk} - 1)]"
|
||||
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.action_chunk_size} - 1)]"
|
||||
|
||||
eval:
|
||||
n_episodes: 500
|
||||
|
@ -59,7 +59,7 @@ policy:
|
|||
# Input / output structure.
|
||||
n_obs_steps: 5
|
||||
n_action_pred_token: 7
|
||||
n_action_pred_chunk: 5
|
||||
action_chunk_size: 5
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
|
|
Loading…
Reference in New Issue