diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py
index ea576c1d..48c9be5a 100644
--- a/lerobot/common/policies/vqbet/configuration_vqbet.py
+++ b/lerobot/common/policies/vqbet/configuration_vqbet.py
@@ -108,15 +108,18 @@ class VQBeTConfig:
     vqvae_groups: int = 2
     vqvae_n_embed: int = 16
     vqvae_embedding_dim: int = 256
+    vqvae_enc_hidden_dim: int = 128
     # VQ-BeT
-    block_size: int = 50
-    output_dim: int = 256
-    n_layer: int = 6
-    n_head: int = 6
-    n_embd: int = 120
+    gpt_block_size: int = 500
+    gpt_input_dim: int = 512
+    gpt_output_dim: int = 512
+    gpt_n_layer: int = 8
+    gpt_n_head: int = 8
+    gpt_n_embed: int = 512
     dropout: float = 0.1
     mlp_hidden_dim: int = 1024
     offset_loss_weight: float = 10000.
+    secondary_code_multiplier: float = 0.5

     def __post_init__(self):
         """Input validation (not exhaustive)."""
diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py
index 70d80e97..8713cb4e 100644
--- a/lerobot/common/policies/vqbet/modeling_vqbet.py
+++ b/lerobot/common/policies/vqbet/modeling_vqbet.py
@@ -62,7 +62,7 @@

         # queues are populated during rollout of the policy, they contain the n latest observations and actions
-        self._obs_queues = None
+        self._queues = None

         self.vqbet = VQBeTModel(config)
@@ -74,12 +74,11 @@
        """
        Clear observation and action queues. Should be called on `env.reset()`
        """
-        self._obs_queues = {
+        self._queues = {
            "observation.image": deque(maxlen=self.config.n_obs_steps),
            "observation.state": deque(maxlen=self.config.n_obs_steps),
+            "action": deque(maxlen=self.config.n_action_pred_chunk),
        }
-        if self.config.n_action_pred_chunk is not None:
-            self._action_queue = deque([], maxlen=self.config.n_action_pred_chunk)

    @torch.no_grad
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -93,7 +92,7 @@
        self.eval()

        batch = self.normalize_inputs(batch)
-        self._obs_queues = populate_queues(self._obs_queues, batch)
+        self._queues = populate_queues(self._queues, batch)

        if not self.check_discretized():
            self.vqbet._action_head._vqvae_model.discretized = True
@@ -101,16 +100,16 @@
        assert "observation.image" in batch
        assert "observation.state" in batch

-        if len(self._action_queue) == 0:
+        if len(self._queues["action"]) == 0:

-            batch = {key: torch.stack(list(self._obs_queues[key]), dim=1) for key in batch}
+            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
            actions = self.vqbet(batch, rollout=True)[:, : self.config.n_action_pred_chunk]

            actions = self.unnormalize_outputs({"action": actions})["action"]

-            self._action_queue.extend(actions.transpose(0, 1))
+            self._queues["action"].extend(actions.transpose(0, 1))

-        action = self._action_queue.popleft()
+        action = self._queues["action"].popleft()
        return action

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
@@ -123,9 +122,7 @@

        _, loss = self.vqbet(batch, rollout=False)

-        return {"loss": loss['actor_loss'], 'equal_single_code_rate': loss['equal_single_code_rate'], 'equal_single_code_rate2': loss['equal_single_code_rate2'], "offset_loss_weight": loss['offset_loss_weight'], \
-            "action_diff": loss['action_diff'], "action_diff_tot": loss['action_diff_tot'], "action_diff_mean_res1": loss['action_diff_mean_res1'], "action_diff_mean_res2": loss['action_diff_mean_res2'], \
-            "action_diff_max": loss['action_diff_max']}
+        return loss


 class VQBeTModel(nn.Module):
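
The queue refactor above folds the separate `_obs_queues` dict and `_action_queue` deque into a single `_queues` dict, so `select_action` predicts a new chunk only once the previous one is exhausted. A minimal sketch of that receding-horizon loop, assuming a hypothetical `predict_chunk` in place of the real policy:

```python
from collections import deque

import torch

# Sketch of the rollout loop behind `select_action`. Observations accumulate
# in bounded deques; a new action chunk is predicted only when the previous
# chunk has been fully consumed, then actions are popped one per env step.
n_obs_steps, n_action_pred_chunk, state_dim, action_dim = 5, 3, 4, 2
queues = {
    "observation.state": deque(maxlen=n_obs_steps),
    "action": deque(maxlen=n_action_pred_chunk),
}

def predict_chunk(obs_window: torch.Tensor) -> torch.Tensor:
    return torch.zeros(n_action_pred_chunk, action_dim)  # stand-in for the model

for step in range(7):
    queues["observation.state"].append(torch.randn(state_dim))
    if len(queues["action"]) == 0:
        obs_window = torch.stack(list(queues["observation.state"]))
        queues["action"].extend(predict_chunk(obs_window))
    action = queues["action"].popleft()  # one action per environment step
```
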
"action_diff_mean_res2": loss['action_diff_mean_res2'], \ - "action_diff_max": loss['action_diff_max']} + return loss class VQBeTModel(nn.Module): @@ -138,36 +135,19 @@ class VQBeTModel(nn.Module): self.global_cond_dim = self.rgb_encoder.feature_dim # action token and EOS token - self._action_token = nn.Parameter(torch.randn(1, 1, self.config.n_embd)) # Batch, Timestep, Data type, GPT input dim - self._eos_token = nn.Parameter(torch.randn(1, 1, self.config.n_embd)) + self._action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) # Batch, Timestep, Data type, GPT input dim + self._eos_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) self.state_projector = MLP( - config.output_shapes["action"][0], hidden_channels=[self.config.n_embd] + config.output_shapes["action"][0], + hidden_channels=[self.config.gpt_input_dim] ) self.obs_projector = MLP( - self.global_cond_dim, hidden_channels=[self.config.n_embd] + self.global_cond_dim, + hidden_channels=[self.config.gpt_input_dim] ) - self._policy = GPT( - GPTConfig( - block_size=self.config.block_size, - input_dim=self.config.n_embd, - output_dim=self.config.output_dim, - n_layer=self.config.n_layer, - n_head=self.config.n_head, - n_embd=self.config.n_embd, - dropout=self.config.dropout, - ) - ) - self._action_head = VQBeTHead( - config.output_dim, - config.output_shapes["action"][0], - offset_loss_weight=config.offset_loss_weight, - hidden_size=config.mlp_hidden_dim, - vqvae_groups=config.vqvae_groups, - vqvae_n_embed=config.vqvae_n_embed, - vqvae_embedding_dim=config.vqvae_embedding_dim, - n_action_pred_chunk=config.n_action_pred_chunk - ) + self._policy = GPT(config) + self._action_head = VQBeTHead(config) def discretize(self, discretize_step, actions): return self._action_head.discretize(discretize_step, actions) @@ -190,7 +170,7 @@ class VQBeTModel(nn.Module): torch.unsqueeze(self.obs_projector(img_features), dim=2), torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2), self._action_token.repeat(batch_size, n_obs_steps, 1, 1) - ], dim=-2).view(batch_size, -1, self.config.n_embd) + ], dim=-2).view(batch_size, -1, self.config.gpt_input_dim) if img_features.shape[1] != n_obs_steps: raise NotImplementedError # eos_token = self._eos_token.repeat(batch_size, 1, 1) # TODO remove EOS token @@ -231,57 +211,40 @@ class VQBeTModel(nn.Module): action, reduction="mean", ) - return pred_action, loss[0] if isinstance(loss, tuple) else loss + return pred_action, loss class VQBeTHead(nn.Module): - def __init__( - self, - # network_kwargs - input_size, - output_size, - hidden_size=1024, - # loss_kwargs - offset_loss_weight=100.0, - secondary_code_multiplier=0.5, - vqvae_groups=2, # G(number of groups) - vqvae_n_embed=16, # C(number of code integers) - vqvae_embedding_dim=512, # D(embedding dims) - n_action_pred_chunk=1, # action chunk size - ): - super().__init__() - self.input_size = input_size - self.output_size = output_size - self.hidden_size = hidden_size - self.offset_loss_weight = offset_loss_weight - self.secondary_code_multiplier = secondary_code_multiplier + def __init__(self, config: VQBeTConfig): + """ + TODO: add explanation for each value. 
+ """ - self._G = vqvae_groups # G(number of groups) - self._C = vqvae_n_embed # C(number of code integers) - self._D = vqvae_embedding_dim # D(embedding dims) - self.n_action_pred_chunk = n_action_pred_chunk # action chunk size + super().__init__() + self.input_size = config.gpt_output_dim + self.output_size = config.output_shapes["action"][0] + self.hidden_size = config.mlp_hidden_dim + self.offset_loss_weight = config.offset_loss_weight + self.secondary_code_multiplier = config.secondary_code_multiplier + + self.vqvae_groups = config.vqvae_groups + self.vqvae_n_embed = config.vqvae_n_embed # C(number of code integers) + self.vqvae_embedding_dim = config.vqvae_embedding_dim # D(embedding dims) + self.n_action_pred_chunk = config.n_action_pred_chunk # action chunk size self._map_to_cbet_preds_bin = MLP( in_channels=self.input_size, - hidden_channels=[self._G * self._C], + hidden_channels=[self.vqvae_groups * self.vqvae_n_embed], ) self._map_to_cbet_preds_offset = MLP( in_channels=self.input_size, hidden_channels=[ - self._G * self._C * n_action_pred_chunk * self.output_size, + self.vqvae_groups * self.vqvae_n_embed * config.n_action_pred_chunk * self.output_size, ], ) # init vqvae - vqvae_config = { - "action_chunk": self.n_action_pred_chunk, - "action_dim": self.output_size, - "vqvae_n_latent_dims": self._D, - "vqvae_n_embed": self._C, - "vqvae_groups": self._G, - "device": get_device_from_parameters(self), - } - self._vqvae_model = init_vqvae(vqvae_config) + self._vqvae_model = VqVae(config) # loss self._criterion = FocalLoss(gamma=2.0) @@ -307,10 +270,10 @@ class VQBeTHead(nn.Module): cbet_logits = self._map_to_cbet_preds_bin(x) cbet_offsets = self._map_to_cbet_preds_offset(x) cbet_logits = einops.rearrange( - cbet_logits, "(NT) (G C) -> (NT) G C", G=self._G + cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_groups ) cbet_offsets = einops.rearrange( - cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self._G, C=self._C + cbet_offsets, "(NT) (G C WA) -> (NT) G C WA", G=self.vqvae_groups, C=self.vqvae_n_embed ) cbet_probs = torch.softmax(cbet_logits, dim=-1) NT, G, choices = cbet_probs.shape @@ -322,7 +285,7 @@ class VQBeTHead(nn.Module): indices = ( torch.arange(NT).unsqueeze(1).cuda(), - torch.arange(self._G).unsqueeze(0).cuda(), + torch.arange(self.vqvae_groups).unsqueeze(0).cuda(), sampled_centers, ) # Use advanced indexing to sample the values @@ -330,7 +293,7 @@ class VQBeTHead(nn.Module): sampled_offsets = sampled_offsets.sum(dim=1) centers = self._vqvae_model.draw_code_forward(sampled_centers).view( - NT, -1, self._D + NT, -1, self.vqvae_embedding_dim ) return_decoder_input = einops.rearrange( centers.clone().detach(), "NT 1 D -> NT D" @@ -341,7 +304,7 @@ class VQBeTHead(nn.Module): .detach() ) # NT, A sampled_offsets = einops.rearrange( - sampled_offsets, "NT (W A) -> NT W A", W=self._vqvae_model.input_dim_h + sampled_offsets, "NT (W A) -> NT W A", W=self.n_action_pred_chunk ) predicted_action = decoded_action + sampled_offsets predicted_action = einops.rearrange( @@ -349,7 +312,7 @@ class VQBeTHead(nn.Module): "(N T) W A -> N T (W A)", N=N, T=T, - W=self._vqvae_model.input_dim_h, + W=self.n_action_pred_chunk, ) return { @@ -357,10 +320,6 @@ class VQBeTHead(nn.Module): "predicted_action": predicted_action, "sampled_centers": sampled_centers, "decoded_action": decoded_action, - "G": G, - "NT": NT, - "N": N, - "T": T, } def loss_fn(self, pred, target, **kwargs): @@ -369,11 +328,12 @@ class VQBeTHead(nn.Module): predicted_action = pred["predicted_action"] sampled_centers = 
pred["sampled_centers"] decoded_action = pred["decoded_action"] - G, NT, N, T = pred["G"], pred["NT"], pred["N"], pred["T"] + NT = predicted_action.shape[0] + T = predicted_action.shape[1] cbet_logits = pred["cbet_logits"] predicted_action = einops.rearrange( - predicted_action, "N T (W A) -> (N T) W A", W=self._vqvae_model.input_dim_h + predicted_action, "N T (W A) -> (N T) W A", W=self.n_action_pred_chunk ) action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A") @@ -400,74 +360,48 @@ class VQBeTHead(nn.Module): ) cbet_loss = cbet_loss1 * 5 + cbet_loss2 * self.secondary_code_multiplier - equal_total_code_rate = ( - torch.sum( - (torch.sum((action_bins == sampled_centers).int(), axis=1) == G).int() - ) - / NT - ) - equal_single_code_rate = torch.sum( + equal_primary_code_rate = torch.sum( (action_bins[:, 0] == sampled_centers[:, 0]).int() ) / (NT) - equal_single_code_rate2 = torch.sum( + equal_secondary_code_rate = torch.sum( (action_bins[:, 1] == sampled_centers[:, 1]).int() ) / (NT) - action_diff = F.mse_loss( - einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[:, 4, :, :], - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[ - :, 4, :, : - ], - ) # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout - action_diff_tot = F.mse_loss( - einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[:, :, :, :], - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[ - :, :, :, : - ], - ) # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout - action_diff_mean_res1 = ( + action_mse_error = F.mse_loss( + einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T), + einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T), + ) + vq_action_error = ( abs( - einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[ - :, 4, :, : - ] - - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T)[ - :, 4, :, : - ] + einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T) ) ).mean() - action_diff_mean_res2 = ( + offset_action_error = ( abs( - einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[ - :, 4, :, : - ] - - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[ - :, 4, :, : - ] + einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T) ) ).mean() - action_diff_max = ( + action_error_max = ( abs( - einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[ - :, 4, :, : - ] - - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[ - :, 4, :, : - ] + einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T) - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T) ) ).max() loss = cbet_loss + self.offset_loss_weight * offset_loss + loss_dict = { + "loss": loss, "classification_loss": cbet_loss.detach().cpu().item(), "offset_loss": offset_loss.detach().cpu().item(), - "total_loss": loss.detach().cpu().item(), - "equal_total_code_rate": equal_total_code_rate, - "equal_single_code_rate": equal_single_code_rate, - "equal_single_code_rate2": equal_single_code_rate2, + "equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(), + "equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(), + "vq_action_error": vq_action_error.detach().cpu().item(), + "offset_action_error": offset_action_error.detach().cpu().item(), + "action_error_max": action_error_max.detach().cpu().item(), + "action_mse_error": 
@@ -688,71 +622,43 @@ def _replace_submodules(
 class VqVae(nn.Module):
     def __init__(
-        self,
-        input_dim_h=10, # length of action chunk
-        input_dim_w=9, # action dim
-        n_latent_dims=512,
-        vqvae_n_embed=32,
-        vqvae_groups=4,
-        eval=True,
-        load_dir=None,
-        encoder_loss_multiplier=1.0,
-        act_scale=1.0,
+        self, config: VQBeTConfig,
     ):
+        super(VqVae, self).__init__()

-        self.n_latent_dims = n_latent_dims
-        self.input_dim_h = input_dim_h
-        self.input_dim_w = input_dim_w
-        self.rep_dim = self.n_latent_dims
-        self.vqvae_n_embed = vqvae_n_embed
-        self.vqvae_lr = 1e-3
-        self.vqvae_groups = vqvae_groups
-        self.encoder_loss_multiplier = encoder_loss_multiplier
-        self.act_scale = act_scale
+        self.n_action_pred_chunk = config.n_action_pred_chunk
+        self.action_dim = config.output_shapes["action"][0]

         self.discretized = False
         self.optimized_steps = 0
-        discrete_cfg = {"groups": self.vqvae_groups, "n_embed": self.vqvae_n_embed}

         self.vq_layer = ResidualVQ(
-            dim=self.n_latent_dims,
-            num_quantizers=discrete_cfg["groups"],
-            codebook_size=self.vqvae_n_embed,
+            dim=config.vqvae_embedding_dim,
+            num_quantizers=config.vqvae_groups,
+            codebook_size=config.vqvae_n_embed,
         )
-        self.embedding_dim = self.n_latent_dims
+        self.embedding_dim = config.vqvae_embedding_dim

-        if self.input_dim_h == 1:
+        if self.n_action_pred_chunk == 1:
             self.encoder = MLP(
-                in_channels=input_dim_w,
-                hidden_channels=[128, 128, n_latent_dims],
+                in_channels=self.action_dim,
+                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
             )
             self.decoder = MLP(
-                in_channels=n_latent_dims,
-                hidden_channels=[128, 128, input_dim_w],
+                in_channels=config.vqvae_embedding_dim,
+                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.action_dim],
             )
         else:
             self.encoder = MLP(
-                in_channels=input_dim_w * self.input_dim_h,
-                hidden_channels=[128, 128, n_latent_dims],
+                in_channels=self.action_dim * self.n_action_pred_chunk,
+                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, config.vqvae_embedding_dim],
             )
             self.decoder = MLP(
-                in_channels=n_latent_dims,
-                hidden_channels=[128, 128, input_dim_w * self.input_dim_h],
+                in_channels=config.vqvae_embedding_dim,
+                hidden_channels=[config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, self.action_dim * self.n_action_pred_chunk],
             )
-
-        if load_dir is not None:
-            try:
-                state_dict = torch.load(load_dir)
-            except RuntimeError:
-                state_dict = torch.load(load_dir, map_location=torch.device("cpu"))
-            self.load_state_dict(state_dict)
-
-        if eval:
-            self.eval()
-        else:
-            self.train()
+        self.train()

     def eval(self):
         self.training = False
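
The constructor now reads every width from `VQBeTConfig` instead of loose keyword arguments. For reference, a minimal round-trip through `vector_quantize_pytorch.ResidualVQ` with the defaults above (`vqvae_embedding_dim=256`, `vqvae_groups=2`, `vqvae_n_embed=16`); the 2D-input handling and return shapes are worth double-checking against the installed library version:

```python
import torch
from vector_quantize_pytorch import ResidualVQ

# Round-trip with the yaml defaults: a 256-dim latent quantized by 2 residual
# codebooks of 16 codes each. ResidualVQ returns the quantized latent, one
# code index per quantizer, and the commitment losses.
vq_layer = ResidualVQ(dim=256, num_quantizers=2, codebook_size=16)
latent = torch.randn(8, 256)  # stand-in for the encoder output on 8 chunks
quantized, codes, commit_loss = vq_layer(latent)
print(quantized.shape)  # torch.Size([8, 256]) -- latent snapped to the codebooks
print(codes.shape)      # torch.Size([8, 2])   -- one discrete code per group
```
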
@@ -783,23 +689,22 @@
         return z_embed

     def get_action_from_latent(self, latent):
-        output = self.decoder(latent) * self.act_scale
-        if self.input_dim_h == 1:
-            return einops.rearrange(output, "N (T A) -> N T A", A=self.input_dim_w)
+        output = self.decoder(latent)
+        if self.n_action_pred_chunk == 1:
+            return einops.rearrange(output, "N (T A) -> N T A", A=self.action_dim)
         else:
-            return einops.rearrange(output, "N (T A) -> N T A", A=self.input_dim_w)
+            return einops.rearrange(output, "N (T A) -> N T A", A=self.action_dim)

     def preprocess(self, state):
         if not torch.is_tensor(state):
             state = torch.FloatTensor(state.copy())
-        if self.input_dim_h == 1:
+        if self.n_action_pred_chunk == 1:
             state = state.squeeze(-2) # state.squeeze(-1)
         else:
             state = einops.rearrange(state, "N T A -> N (T A)")
         return state

     def get_code(self, state, required_recon=False):
-        state = state / self.act_scale
         state = self.preprocess(state)
         with torch.no_grad():
             state_rep = self.encoder(state)
@@ -810,9 +715,9 @@
             vq_code = vq_code.view(*state_rep_shape, -1)
             vq_loss_state = torch.sum(vq_loss_state)
             if required_recon:
-                recon_state = self.decoder(state_vq) * self.act_scale
-                recon_state_ae = self.decoder(state_rep) * self.act_scale
-                if self.input_dim_h == 1:
+                recon_state = self.decoder(state_vq)
+                recon_state_ae = self.decoder(state_rep)
+                if self.n_action_pred_chunk == 1:
                     return state_vq, vq_code, recon_state, recon_state_ae
                 else:
                     return (
@@ -826,7 +731,6 @@
                 return state_vq, vq_code

     def vqvae_forward(self, state):
-        state = state / self.act_scale
         state = self.preprocess(state)
         state_rep = self.encoder(state)
         state_rep_shape = state_rep.shape[:-1]
@@ -839,7 +743,7 @@
         dec_out = self.decoder(state_vq)
         encoder_loss = (state - dec_out).abs().mean()

-        rep_loss = encoder_loss * self.encoder_loss_multiplier + (vq_loss_state * 5)
+        rep_loss = encoder_loss + vq_loss_state * 5

         metric = (
             encoder_loss.clone().detach(),
@@ -857,28 +761,16 @@


-def init_vqvae(config):
-    # model
-    vqvae_model = VqVae(
-        input_dim_h=config["action_chunk"],
-        input_dim_w=config["action_dim"],
-        n_latent_dims=config["vqvae_n_latent_dims"],
-        vqvae_n_embed=config["vqvae_n_embed"],
-        vqvae_groups=config["vqvae_groups"],
-        eval=False,
-    )
-
-    return vqvae_model


 def pretrain_vqvae(vqvae_model, discretize_step, actions):
-    if vqvae_model.input_dim_h == 1:
+    if vqvae_model.n_action_pred_chunk == 1:
         # not using action chunk
         actions = actions.reshape(-1, 1, actions.shape[-1])
     else:
         # using action chunk
         slices = []
-        slices.extend([actions[:, j:j+vqvae_model.input_dim_h, :] for j in range(actions.shape[1]+1-vqvae_model.input_dim_h)])
+        slices.extend([actions[:, j:j+vqvae_model.n_action_pred_chunk, :] for j in range(actions.shape[1]+1-vqvae_model.n_action_pred_chunk)])

         actions = torch.cat(slices, dim=0)
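
`pretrain_vqvae` builds its training set by taking every contiguous `n_action_pred_chunk`-length window of each trajectory. The same slicing in isolation, with illustrative sizes:

```python
import torch

# Sliding-window slicing used for VQ-VAE pretraining data: every contiguous
# window of n_action_pred_chunk actions becomes one training item.
n_action_pred_chunk = 3
actions = torch.randn(4, 10, 2)  # (batch, horizon, action_dim)

slices = [
    actions[:, j : j + n_action_pred_chunk, :]
    for j in range(actions.shape[1] + 1 - n_action_pred_chunk)
]
chunks = torch.cat(slices, dim=0)
assert chunks.shape == (4 * (10 + 1 - n_action_pred_chunk), n_action_pred_chunk, 2)
```
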
@@ -2166,40 +2058,40 @@ class MLP(torch.nn.Sequential):

 class CausalSelfAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
-        assert config.n_embd % config.n_head == 0
+        assert config.gpt_n_embed % config.gpt_n_head == 0
         # key, query, value projections for all heads, but in a batch
-        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+        self.c_attn = nn.Linear(config.gpt_n_embed, 3 * config.gpt_n_embed)
         # output projection
-        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+        self.c_proj = nn.Linear(config.gpt_n_embed, config.gpt_n_embed)
         # regularization
         self.attn_dropout = nn.Dropout(config.dropout)
         self.resid_dropout = nn.Dropout(config.dropout)
         # causal mask to ensure that attention is only applied to the left in the input sequence
         self.register_buffer(
             "bias",
-            torch.tril(torch.ones(config.block_size, config.block_size)).view(
-                1, 1, config.block_size, config.block_size
+            torch.tril(torch.ones(config.gpt_block_size, config.gpt_block_size)).view(
+                1, 1, config.gpt_block_size, config.gpt_block_size
             ),
         )
-        self.n_head = config.n_head
-        self.n_embd = config.n_embd
+        self.gpt_n_head = config.gpt_n_head
+        self.gpt_n_embed = config.gpt_n_embed

     def forward(self, x):
         (
             B,
             T,
             C,
-        ) = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+        ) = x.size() # batch size, sequence length, embedding dimensionality (gpt_n_embed)

         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
-        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
-        k = k.view(B, T, self.n_head, C // self.n_head).transpose(
+        q, k, v = self.c_attn(x).split(self.gpt_n_embed, dim=2)
+        k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
             1, 2
         )  # (B, nh, T, hs)
-        q = q.view(B, T, self.n_head, C // self.n_head).transpose(
+        q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
             1, 2
         )  # (B, nh, T, hs)
-        v = v.view(B, T, self.n_head, C // self.n_head).transpose(
+        v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(
             1, 2
         )  # (B, nh, T, hs)

@@ -2222,13 +2114,13 @@ class CausalSelfAttention(nn.Module):
 class Block(nn.Module):
     def __init__(self, config):
         super().__init__()
-        self.ln_1 = nn.LayerNorm(config.n_embd)
+        self.ln_1 = nn.LayerNorm(config.gpt_n_embed)
         self.attn = CausalSelfAttention(config)
-        self.ln_2 = nn.LayerNorm(config.n_embd)
+        self.ln_2 = nn.LayerNorm(config.gpt_n_embed)
         self.mlp = nn.Sequential(
-            nn.Linear(config.n_embd, 4 * config.n_embd),
+            nn.Linear(config.gpt_n_embed, 4 * config.gpt_n_embed),
             nn.GELU(),
-            nn.Linear(4 * config.n_embd, config.n_embd),
+            nn.Linear(4 * config.gpt_n_embed, config.gpt_n_embed),
             nn.Dropout(config.dropout)
         )

@@ -2238,15 +2130,6 @@ class Block(nn.Module):
         return x


-@dataclass
-class GPTConfig:
-    block_size: int = 1024
-    input_dim: int = 256
-    output_dim: int = 256
-    n_layer: int = 12
-    n_head: int = 12
-    n_embd: int = 768
-    dropout: float = 0.1

 class GPT(nn.Module):
@@ -2287,29 +2170,28 @@ class GPT(nn.Module):
     """

-    def __init__(self, config):
+    def __init__(self, config: VQBeTConfig):
         super().__init__()
-        assert config.input_dim is not None
-        assert config.output_dim is not None
-        assert config.block_size is not None
+        assert config.gpt_output_dim is not None
+        assert config.gpt_block_size is not None
         self.config = config

         self.transformer = nn.ModuleDict(
             dict(
-                wte=nn.Linear(config.input_dim, config.n_embd),
-                wpe=nn.Embedding(config.block_size, config.n_embd),
+                wte=nn.Linear(config.gpt_input_dim, config.gpt_n_embed),
+                wpe=nn.Embedding(config.gpt_block_size, config.gpt_n_embed),
                 drop=nn.Dropout(config.dropout),
-                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
-                ln_f=nn.LayerNorm(config.n_embd),
+                h=nn.ModuleList([Block(config) for _ in range(config.gpt_n_layer)]),
+                ln_f=nn.LayerNorm(config.gpt_n_embed),
             )
         )
-        self.lm_head = nn.Linear(config.n_embd, config.output_dim, bias=False)
+        self.lm_head = nn.Linear(config.gpt_n_embed, config.gpt_output_dim, bias=False)
         # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper
         self.apply(self._init_weights)
         for pn, p in self.named_parameters():
             if pn.endswith("c_proj.weight"):
                 torch.nn.init.normal_(
-                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
+                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)
                 )

         # report number of parameters
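
With `GPTConfig` removed, every GPT sub-module reads the shared `gpt_*` fields, including the causal mask sized by `gpt_block_size`. That mask in isolation, a lower-triangular buffer cropped to the current sequence length:

```python
import torch

# The causal mask registered as a buffer in CausalSelfAttention: a lower-
# triangular (gpt_block_size, gpt_block_size) matrix broadcast over batch and
# heads, cropped to the current sequence length T before use.
gpt_block_size = 5  # illustrative; the yaml uses 500
bias = torch.tril(torch.ones(gpt_block_size, gpt_block_size)).view(
    1, 1, gpt_block_size, gpt_block_size
)
T = 3
att = torch.randn(1, 1, T, T)  # stand-in for q @ k^T / sqrt(head_dim)
att = att.masked_fill(bias[:, :, :T, :T] == 0, float("-inf"))
att = torch.softmax(att, dim=-1)  # each position attends only to itself and the past
```
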
@@ -2320,8 +2202,8 @@
         device = input.device
         b, t, d = input.size()
         assert (
-            t <= self.config.block_size
-        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+            t <= self.config.gpt_block_size
+        ), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
         pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(
             0
         )  # shape (1, t)
@@ -2329,10 +2211,10 @@
         # forward the GPT model itself
         tok_emb = self.transformer.wte(
             input
-        )  # token embeddings of shape (b, t, n_embd)
+        )  # token embeddings of shape (b, t, gpt_n_embed)
         pos_emb = self.transformer.wpe(
             pos
-        )  # position embeddings of shape (1, t, n_embd)
+        )  # position embeddings of shape (1, t, gpt_n_embed)
         x = self.transformer.drop(tok_emb + pos_emb)
         for block in self.transformer.h:
             x = block(x)
@@ -2351,14 +2233,14 @@
                 torch.nn.init.zeros_(module.bias)
             torch.nn.init.ones_(module.weight)

-    def crop_block_size(self, block_size):
-        assert block_size <= self.config.block_size
-        self.config.block_size = block_size
+    def crop_block_size(self, gpt_block_size):
+        assert gpt_block_size <= self.config.gpt_block_size
+        self.config.gpt_block_size = gpt_block_size
         self.transformer.wpe.weight = nn.Parameter(
-            self.transformer.wpe.weight[:block_size]
+            self.transformer.wpe.weight[:gpt_block_size]
         )
         for block in self.transformer.h:
-            block.attn.bias = block.attn.bias[:, :, :block_size, :block_size]
+            block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]

     def configure_optimizers(self, weight_decay, learning_rate, betas, optimizer="Adamw", eps=None):
         """
diff --git a/lerobot/configs/policy/vqbet.yaml b/lerobot/configs/policy/vqbet.yaml
index 00ad3ac8..dfd0d8d7 100644
--- a/lerobot/configs/policy/vqbet.yaml
+++ b/lerobot/configs/policy/vqbet.yaml
@@ -22,7 +22,7 @@ override_dataset_stats:
     max: [511.0, 511.0]

 training:
-  offline_steps: 800000
+  offline_steps: 250000
   online_steps: 0
   eval_freq: 20000
   save_freq: 20000
@@ -90,12 +90,15 @@ policy:
   vqvae_groups: 2
   vqvae_n_embed: 16
   vqvae_embedding_dim: 256
+  vqvae_enc_hidden_dim: 128
   # VQ-BeT
-  block_size: 500
-  output_dim: 512
-  n_layer: 8 # 8
-  n_head: 8 # 4
-  n_embd: 512
+  gpt_block_size: 500
+  gpt_input_dim: 512
+  gpt_output_dim: 512
+  gpt_n_layer: 8
+  gpt_n_head: 8
+  gpt_n_embed: 512
   dropout: 0.1
-  mlp_hidden_dim: 1024 # 512
-  offset_loss_weight: 10000.
\ No newline at end of file
+  mlp_hidden_dim: 1024
+  offset_loss_weight: 10000.
+  secondary_code_multiplier: 0.5
\ No newline at end of file
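
The yaml keys now map one-to-one onto the renamed dataclass fields. A sketch of the equivalent Python-side construction; only the fields touched by this diff are passed, and the real `VQBeTConfig` defines more (input/output shapes, normalization, etc.):

```python
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig

# Mirrors the yaml block above; every keyword corresponds to a field declared
# in the configuration diff at the top of this patch.
config = VQBeTConfig(
    vqvae_groups=2,
    vqvae_n_embed=16,
    vqvae_embedding_dim=256,
    vqvae_enc_hidden_dim=128,
    gpt_block_size=500,
    gpt_input_dim=512,
    gpt_output_dim=512,
    gpt_n_layer=8,
    gpt_n_head=8,
    gpt_n_embed=512,
    dropout=0.1,
    mlp_hidden_dim=1024,
    offset_loss_weight=10000.0,
    secondary_code_multiplier=0.5,
)
```
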