diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py
index dffa2d14..35aa6b6e 100644
--- a/lerobot/common/policies/vqbet/configuration_vqbet.py
+++ b/lerobot/common/policies/vqbet/configuration_vqbet.py
@@ -71,7 +71,9 @@ class VQBeTConfig:
 
     # Inputs / output structure.
     n_obs_steps: int = 5
-    n_action_steps: int = 5
+    # n_action_steps: int = 7
+    n_action_pred_token: int = 3
+    n_action_pred_chunk: int = 5
 
     input_shapes: dict[str, list[int]] = field(
         default_factory=lambda: {
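Note: a quick sanity check (not part of the patch) of how the two new fields replace the old n_action_steps. Each of the n_action_pred_token readout tokens decodes a chunk of n_action_pred_chunk consecutive actions, and consecutive chunks overlap by all but one step, so the commented-out 7 is derived rather than arbitrary:

    # Overlapping chunks offset by one step span token + chunk - 1 unique timesteps.
    n_action_pred_token, n_action_pred_chunk = 3, 5
    assert n_action_pred_token + n_action_pred_chunk - 1 == 7  # the old n_action_steps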
diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py
index 56d308fc..728d759f 100644
--- a/lerobot/common/policies/vqbet/modeling_vqbet.py
+++ b/lerobot/common/policies/vqbet/modeling_vqbet.py
@@ -77,8 +77,10 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
             "observation.image": deque(maxlen=self.config.n_obs_steps),
             "observation.state": deque(maxlen=self.config.n_obs_steps),
         }
-        if self.config.n_action_steps is not None:
-            self._action_queue = deque([], maxlen=self.config.n_action_steps)
+        if self.config.n_action_pred_chunk is not None:
+            self._action_queue = deque([], maxlen=self.config.n_action_pred_chunk)  # original one
+            # self._action_queue = deque([], maxlen=self.config.n_action_pred_token)  # jay temp2
+        self._action_history_queue = deque([], maxlen=self.config.n_obs_steps - 1)
 
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -96,6 +98,10 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         batch = self.normalize_inputs(batch)
         self._obs_queues = populate_queues(self._obs_queues, batch)
 
+        if len(self._action_history_queue) == 0:
+            while len(self._action_history_queue) != self._action_history_queue.maxlen:
+                self._action_history_queue.append(batch["observation.state"])
+
         if not self.check_discretized():
             self.vqbet._action_head._vqvae_model.discretized = True
             # raise NotImplementedError(
@@ -109,16 +115,35 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         # for act seq pred, we should provide averaged act over horizon.
         if len(self._action_queue) == 0:
-            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
-            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
+
+            # original one
+
             batch = {key: torch.stack(list(self._obs_queues[key]), dim=1) for key in batch}
-            actions = self.vqbet(batch)[:, : self.config.n_action_steps]
+            batch["action"] = torch.stack(list(self._action_history_queue), dim=1)
+            actions = self.vqbet(batch, rollout=True)[:, : self.config.n_action_pred_chunk]
 
             # TODO(rcadene): make _forward return output dictionary?
             actions = self.unnormalize_outputs({"action": actions})["action"]
             self._action_queue.extend(actions.transpose(0, 1))
-        return self._action_queue.popleft()
+
+            # jay temp2
+            # batch = {key: torch.stack(list(self._obs_queues[key]), dim=1) for key in batch}
+            # batch["action"] = torch.stack(list(self._action_history_queue), dim=1)
+            # action_predicted = []
+            # for i in range(self.config.n_action_pred_token):
+
+            #     actions = self.vqbet(batch, rollout=True, input_predicted=action_predicted)[:, i]
+
+            #     action_predicted.append(actions)
+            # actions = self.unnormalize_outputs({"action": torch.stack(action_predicted)})["action"]
+            # self._action_queue.extend(actions)
+            #############
+
+        action = self._action_queue.popleft()
+        self._action_history_queue.append(self.normalize_targets({"action": action})["action"])
+        return action
 
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         """Run the batch through the model and compute the loss for training or validation."""
@@ -127,7 +152,60 @@ class VQBeTPolicy(nn.Module, PyTorchModelHubMixin):
         if not self.check_discretized():
             loss = self.vqbet.discretize(self.config.discretize_step, batch['action'])
             return {"loss": loss}
-        _, loss = self.vqbet(batch)
+
+        # original one
+        _, loss = self.vqbet(batch, rollout=False)
+
+        # jay temp2
+        # action_predicted = []
+        # pred_values = []
+        # for i in range(self.config.n_action_pred_token):
+        #     pred = self.vqbet(batch, rollout=False, input_predicted=action_predicted, recurrent=True)
+        #     actions = pred["predicted_action"][:, i]
+        #     action_predicted.append(actions)
+        #     pred_values.append(pred)
+
+        # # Stack the per-token outputs from each dictionary in pred_values
+        # # jay TODO
+        # predicted_action = torch.stack([d["predicted_action"][:, i] for (i, d) in enumerate(pred_values)]).transpose(0, 1)
+        # decoded_action = torch.stack([einops.rearrange(d["decoded_action"], "(B T) 1 N->B T N", T=5)[:, i] for (i, d) in enumerate(pred_values)]).transpose(0, 1).reshape(-1, 1, 2)
+        # sampled_centers = torch.stack([einops.rearrange(d["sampled_centers"], "(B T) N->B T N", T=5)[:, i] for (i, d) in enumerate(pred_values)]).transpose(0, 1).reshape(-1, 2)
+        # cbet_logits = torch.stack([einops.rearrange(d["cbet_logits"], "(B T) G N->B T G N", T=5)[:, i] for (i, d) in enumerate(pred_values)]).transpose(0, 1).reshape(-1, 2, 16)
+
+        # gpt_output, G, NT, N, T = pred_values[-1]["input"], pred_values[-1]["G"], pred_values[-1]["NT"], pred_values[-1]["N"], pred_values[-1]["T"]
+
+        # pred_action = {
+        #     "input": gpt_output,
+        #     "cbet_logits1": None,
+        #     "cbet_logits2": None,
+        #     "cbet_logits": cbet_logits if "cbet_logits" in locals() else None,
+        #     "predicted_action": predicted_action,
+        #     "decoded_action": decoded_action,
+        #     "sampled_centers": sampled_centers,
+        #     "G": G,
+        #     "NT": NT,
+        #     "N": N,
+        #     "T": T,
+        # }
+
+        # action = batch["action"]
+        # n, total_w, act_dim = action.shape
+        # act_w = self.config.n_action_pred_chunk
+        # num_token = total_w + 1 - act_w
+        # output_shape = (n, num_token, act_w, act_dim)
+        # output = torch.empty(output_shape).to(action.device)
+        # for i in range(num_token):
+        #     output[:, i, :, :] = action[:, i : i + act_w, :]
+        # action = output  # batch size, 7, 5, 2
+
+        # loss = self.vqbet._action_head.loss_fn(
+        #     pred_action,
+        #     action[:, self.config.n_obs_steps - 1:],
+        #     reduction="mean",
+        # )[0]
+        #############
+        return {"loss": loss['actor_loss']}
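Note: a minimal, self-contained sketch (not part of the patch) of the receding-horizon bookkeeping select_action now performs; the zero tensors standing in for the policy call and the shapes are illustrative only:

    from collections import deque

    import torch

    n_obs_steps, n_action_pred_chunk, act_dim = 5, 5, 2
    action_queue = deque([], maxlen=n_action_pred_chunk)
    action_history_queue = deque([], maxlen=n_obs_steps - 1)

    # First call: pad the history with the current state so the GPT prompt
    # always sees n_obs_steps - 1 past "actions".
    state = torch.zeros(1, act_dim)
    while len(action_history_queue) != action_history_queue.maxlen:
        action_history_queue.append(state)

    # Whenever the queue is empty, one forward pass refills it with a chunk.
    # The (batch, chunk, act_dim) prediction is transposed so the queue holds
    # per-step (batch, act_dim) tensors.
    if len(action_queue) == 0:
        actions = torch.zeros(1, n_action_pred_chunk, act_dim)  # stand-in for self.vqbet(batch, rollout=True)
        action_queue.extend(actions.transpose(0, 1))

    action = action_queue.popleft()
    action_history_queue.append(action)  # the real code re-normalizes before appending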
@@ -138,11 +216,22 @@ class VQBeTModel(nn.Module):
 
         self.rgb_encoder = DiffusionRgbEncoder(config)
-        global_cond_dim = (config.output_shapes["action"][0] + self.rgb_encoder.feature_dim)
+        self.global_cond_dim = self.rgb_encoder.feature_dim
+
+        # action token and EOS token
+        self._action_token = nn.Parameter(torch.randn(1, 1, self.global_cond_dim))  # Batch, Timestep, Data type, GPT input dim
+        self._eos_token = nn.Parameter(torch.randn(1, 1, self.global_cond_dim))
+
+        self.state_projector = MLP(
+            config.output_shapes["action"][0], hidden_channels=[self.global_cond_dim]
+        )
+        self.action_projector = MLP(
+            config.output_shapes["action"][0], hidden_channels=[self.global_cond_dim]
+        )
         self._policy = GPT(
             GPTConfig(
                 block_size=self.config.block_size,
-                input_dim=global_cond_dim,
+                input_dim=self.global_cond_dim,
                 output_dim=self.config.output_dim,
                 n_layer=self.config.n_layer,
                 n_head=self.config.n_head,
@@ -157,14 +246,14 @@ class VQBeTModel(nn.Module):
             vqvae_groups=config.vqvae_groups,
             vqvae_n_embed=config.vqvae_n_embed,
             vqvae_embedding_dim=config.vqvae_embedding_dim,
-            n_action_steps=config.n_action_steps
+            n_action_pred_chunk=config.n_action_pred_chunk
         )
 
     def discretize(self, discretize_step, actions):
         return self._action_head.discretize(discretize_step, actions)
 
     # ========= inference  ============
-    def forward(self, batch: dict[str, Tensor]) -> Tensor:
+    def forward(self, batch: dict[str, Tensor], rollout: bool, input_predicted: list | None = None, recurrent: bool = False) -> Tensor:  # jay temp2
         # Input validation.
         assert set(batch).issuperset({"observation.state", "observation.image"})
         batch_size, n_obs_steps = batch["observation.state"].shape[:2]
@@ -175,13 +264,8 @@ class VQBeTModel(nn.Module):
         # Separate batch and sequence dims.
         img_features = einops.rearrange(img_features, "(b n) ... -> b n ...", b=batch_size)
         # Concatenate state and image features then flatten to (B, global_cond_dim).
-        global_cond = torch.cat([batch["observation.state"], img_features], dim=-1)
-
-        obs = global_cond
-        if 'action' in batch.keys():
-            action = batch["action"]
-        else:
-            action = None
+
+
         if 'goal' in batch.keys():
             goal = batch["goal"]
             num_goal_token = goal.shape[1]
@@ -190,21 +274,70 @@ class VQBeTModel(nn.Module):
             num_goal_token = 0
 
+        if rollout:
+            gpt_input_action = torch.cat((batch["action"], torch.zeros(batch_size, 1, batch["action"].shape[-1]).to(get_device_from_parameters(self))), dim=1)
+        else:
+            gpt_input_action = batch["action"][..., :n_obs_steps, :]
+            for i, val in enumerate(batch["frame_index"]):
+                if val < n_obs_steps - 1:
+                    # Update original_data based on the rules
+                    gpt_input_action[i, :-(val + 1)] = batch["observation.state"][i, :-(val + 1)]
+
+        gpt_input_action[:, -1] = 0
+
+        global_cond = torch.cat([
+            torch.unsqueeze(img_features, dim=2),
+            torch.unsqueeze(self.state_projector(batch["observation.state"]), dim=2),
+            # torch.unsqueeze(self.action_projector(gpt_input_action), dim=2)  # jay temp
+        ], dim=-2).view(batch_size, -1, self.global_cond_dim)  # batch size, num tokens, GPT input dim
+        # global_cond = global_cond[:, :-1]  # get rid of final action  # jay temp
+
+        # We could feed the preceding actions and predict only the rest, or just predict everything -> for now, condition on actions up to the previous step.
+        # A token could be a single action or a multi-step chunk -> for now, use chunks.
+        eos_token = self._eos_token.repeat(batch_size, 1, 1)
+        action_token = self._action_token.repeat(batch_size, self.config.n_action_pred_token, 1)
+
+        # jay temp2
+        # if (input_predicted is not None) and (len(input_predicted) > 0):
+        #     len_input = len(input_predicted)
+        #     action_token[:, :len_input] = self.action_projector(torch.stack(input_predicted, dim=1).to(get_device_from_parameters(self)))
+
+        prompt_length = global_cond.shape[1] + 1
+        global_cond = torch.cat([global_cond, eos_token, action_token], dim=1)
+
         # get action features
-        features = self._policy(obs)
-        features = features[:, num_goal_token:]
+        features = self._policy(global_cond)
+        features = features[:, prompt_length:]
 
         # action head
         pred_action = self._action_head(
             features,
-            **{"action_seq": action},
+            # **{"action_seq": action},
         )
 
-        if action is None:
-            return pred_action["predicted_action"][:, -1, :].reshape(batch_size, self.config.n_action_steps, -1)
+        if rollout:
+            return pred_action["predicted_action"][:, -1, :].reshape(batch_size, self.config.n_action_pred_chunk, -1)  # original one
+            # return pred_action  # jay temp2
+
+        # should put this part between "if" and "else"
+        # elif recurrent:  # jay temp2 => training
+        #     return pred_action
         else:
+            action = batch["action"]
+            n, total_w, act_dim = action.shape
+            act_w = self.config.n_action_pred_chunk
+            num_token = total_w + 1 - act_w
+            output_shape = (n, num_token, act_w, act_dim)
+            output = torch.empty(output_shape).to(action.device)
+            for i in range(num_token):
+                output[:, i, :, :] = action[:, i : i + act_w, :]
+            action = output  # batch size, 7, 5, 2
+
            loss = self._action_head.loss_fn(
                pred_action,
-                action,
+                action[:, n_obs_steps - 1:],
                reduction="mean",
            )
        return pred_action, loss[0] if isinstance(loss, tuple) else loss
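Note: the chunking loop above (and its commented-out twin in loss_fn below) builds overlapping action windows with a Python loop; the result can be cross-checked against Tensor.unfold. A sketch (not part of the patch) using the shapes implied by the "batch size, 7, 5, 2" comment, assuming batch 64 and total_w = 11:

    import torch

    action = torch.randn(64, 11, 2)   # (n, total_w, act_dim)
    act_w = 5                         # n_action_pred_chunk
    windows = action.unfold(1, act_w, 1).permute(0, 1, 3, 2)
    assert windows.shape == (64, 7, 5, 2)  # num_token = total_w + 1 - act_w
    assert torch.equal(windows[:, 0], action[:, :act_w])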
@@ -223,7 +356,7 @@ class VQBeTHead(nn.Module):
         vqvae_groups=2,  # G(number of groups)
         vqvae_n_embed=16,  # C(number of code integers)
         vqvae_embedding_dim=512,  # D(embedding dims)
-        n_action_steps=1,  # action chunk size
+        n_action_pred_chunk=1,  # action chunk size
     ):
         super().__init__()
         self.input_size = input_size
@@ -236,7 +369,7 @@ class VQBeTHead(nn.Module):
         self._G = vqvae_groups  # G(number of groups)
         self._C = vqvae_n_embed  # C(number of code integers)
         self._D = vqvae_embedding_dim  # D(embedding dims)
-        self.n_action_steps = n_action_steps  # action chunk size
+        self.n_action_pred_chunk = n_action_pred_chunk  # action chunk size
 
         if self.sequentially_select:
@@ -259,12 +392,12 @@ class VQBeTHead(nn.Module):
                 hidden_channels=[
                     self.hidden_size,
                     self.hidden_size,
-                    self._G * self._C * n_action_steps * self.output_size,
+                    self._G * self._C * n_action_pred_chunk * self.output_size,
                 ],
             )
         # init vqvae
         vqvae_config = {
-            "action_chunk": self.n_action_steps,
+            "action_chunk": self.n_action_pred_chunk,
             "action_dim": self.output_size,
             "vqvae_n_latent_dims": self._D,
             "vqvae_n_embed": self._C,
@@ -404,14 +537,14 @@ class VQBeTHead(nn.Module):
             predicted_action, "N T (W A) -> (N T) W A", W=self._vqvae_model.input_dim_h
         )
 
-        n, total_w, act_dim = action_seq.shape
-        act_w = self._vqvae_model.input_dim_h
-        obs_w = total_w + 1 - act_w
-        output_shape = (n, obs_w, act_w, act_dim)
-        output = torch.empty(output_shape).to(action_seq.device)
-        for i in range(obs_w):
-            output[:, i, :, :] = action_seq[:, i : i + act_w, :]
-        action_seq = einops.rearrange(output, "N T W A -> (N T) W A")
+        # n, total_w, act_dim = action_seq.shape
+        # act_w = self._vqvae_model.input_dim_h
+        # obs_w = total_w + 1 - act_w
+        # output_shape = (n, obs_w, act_w, act_dim)
+        # output = torch.empty(output_shape).to(action_seq.device)
+        # for i in range(obs_w):
+        #     output[:, i, :, :] = action_seq[:, i : i + act_w, :]
+        action_seq = einops.rearrange(action_seq, "N T W A -> (N T) W A")
 
         # Figure out the loss for the actions.
         # First, we need to find the closest cluster center for each action.
         state_vq, action_bins = self._vqvae_model.get_code(
@@ -425,43 +558,43 @@ class VQBeTHead(nn.Module):
 
         offset_loss = torch.nn.L1Loss()(action_seq, predicted_action)
 
         action_diff = F.mse_loss(
-            einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[:, -1, 0, :],
-            einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=obs_w)[
+            einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[:, -1, 0, :],
+            einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
                 :, -1, 0, :
             ],
         )  # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout
         action_diff_tot = F.mse_loss(
-            einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[:, -1, :, :],
-            einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=obs_w)[
+            einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[:, -1, :, :],
+            einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
                 :, -1, :, :
             ],
         )  # batch, time, windowsize (t ... t+N), action dim -> [:, -1, 0, :] is for rollout
         action_diff_mean_res1 = (
             abs(
-                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[
+                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[
                     :, -1, 0, :
                 ]
-                - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=obs_w)[
+                - einops.rearrange(decoded_action, "(N T) W A -> N T W A", T=T)[
                     :, -1, 0, :
                 ]
             )
         ).mean()
         action_diff_mean_res2 = (
             abs(
-                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[
+                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[
                     :, -1, 0, :
                 ]
-                - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=obs_w)[
+                - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
                     :, -1, 0, :
                 ]
             )
         ).mean()
         action_diff_max = (
             abs(
-                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=obs_w)[
+                einops.rearrange(action_seq, "(N T) W A -> N T W A", T=T)[
                     :, -1, 0, :
                 ]
-                - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=obs_w)[
+                - einops.rearrange(predicted_action, "(N T) W A -> N T W A", T=T)[
                     :, -1, 0, :
                 ]
             )
@@ -692,6 +825,19 @@ class VQBeTOptimizer:
             learning_rate=cfg.training.bet_learning_rate,
             betas=cfg.training.bet_betas,
         )
+
+        self.bet_optimizer1.add_param_group(
+            {"params": policy.vqbet._action_token}
+        )
+        self.bet_optimizer1.add_param_group(
+            {"params": policy.vqbet._eos_token}
+        )
+        self.bet_optimizer1.add_param_group(
+            {"params": policy.vqbet.state_projector.parameters()}
+        )
+        self.bet_optimizer1.add_param_group(
+            {"params": policy.vqbet.action_projector.parameters()}
+        )
         # if policy.vqbet._action_head.sequentially_select:
         #     self.bet_optimizer1.add_param_group(
         #         {"params": policy.vqbet._action_head._map_to_cbet_preds_bin1.parameters()}
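Note: a minimal sketch (not part of the patch) of the pattern used above to fold the newly added modules into an existing optimizer, assuming bet_optimizer1 is a standard torch.optim optimizer. torch.optim.Optimizer.add_param_group accepts a lone Parameter or a .parameters() iterator, and the new group inherits the constructor defaults unless overridden; the hyperparameter values below are the bet_* settings from the yaml that follows:

    import torch
    from torch import nn

    token = nn.Parameter(torch.randn(1, 1, 256))  # stand-in for an _action_token
    projector = nn.Linear(2, 256)                 # stand-in for a projector MLP

    opt = torch.optim.AdamW([token], lr=5.5e-5, betas=(0.9, 0.999), weight_decay=2e-4)
    # Modules created after the optimizer can be registered without rebuilding it.
    opt.add_param_group({"params": projector.parameters()})
    assert len(opt.param_groups) == 2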
diff --git a/lerobot/configs/policy/vqbet.yaml b/lerobot/configs/policy/vqbet.yaml
index dbc6d920..5fff7657 100644
--- a/lerobot/configs/policy/vqbet.yaml
+++ b/lerobot/configs/policy/vqbet.yaml
@@ -22,10 +22,10 @@ override_dataset_stats:
     max: [511.0, 511.0]
 
 training:
-  offline_steps: 200000
+  offline_steps: 800000
   online_steps: 0
-  eval_freq: 10000
-  save_freq: 5000
+  eval_freq: 20000 # jay
+  save_freq: 20000
   log_freq: 250
   save_model: true
 
@@ -41,7 +41,7 @@ training:
 
   # VQ-BeT specific
   vqvae_lr: 1.0e-3
-  discretize_step: 3000
+  discretize_step: 20000 # jay
   bet_weight_decay: 2e-4
   bet_learning_rate: 5.5e-5
   bet_betas: [0.9, 0.999]
@@ -49,7 +49,7 @@ training:
   delta_timestamps:
     observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
     observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
-    action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_steps})]"
+    action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.n_action_pred_chunk} - 1)]"
 
 eval:
   n_episodes: 50
@@ -60,7 +60,9 @@ policy:
 
   # Input / output structure.
   n_obs_steps: 5
-  n_action_steps: 5
+  # n_action_steps: 7 # n_action_pred_token + n_action_pred_chunk - 1
+  n_action_pred_token: 3 # 3 5
+  n_action_pred_chunk: 5 # 5 1 jay temp2
   input_shapes:
     # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
@@ -90,7 +92,7 @@ policy:
   vqvae_n_embed: 16
   vqvae_embedding_dim: 256
   # VQ-BeT
-  block_size: 50
+  block_size: 500
   output_dim: 256 # 512
   n_layer: 6 # 8
   n_head: 6 # 4
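Note: a quick consistency check (not part of the patch) tying the new action delta_timestamps line to the window construction in VQBeTModel.forward and the action[:, n_obs_steps - 1:] slice:

    n_obs_steps, n_action_pred_token, n_action_pred_chunk = 5, 3, 5

    # Frames fetched per sample by the delta_timestamps expression above.
    total_w = len(range(1 - n_obs_steps, n_action_pred_token + n_action_pred_chunk - 1))
    assert total_w == 11

    # Overlapping windows built in forward(); matches the "batch size, 7, 5, 2" comment.
    num_token = total_w + 1 - n_action_pred_chunk
    assert num_token == 7

    # Windows left after dropping the first n_obs_steps - 1 history positions.
    assert num_token - (n_obs_steps - 1) == n_action_pred_token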