diff --git a/lerobot/common/policies/dexvla/README.md b/lerobot/common/policies/dexvla/README.md
index b34a40bb..4ea17308 100644
--- a/lerobot/common/policies/dexvla/README.md
+++ b/lerobot/common/policies/dexvla/README.md
@@ -85,14 +85,12 @@ python lerobot/scripts/train.py \
 --policy.pretrain_scaledp_path /path/to/pretrained/scale_dp_h/open_scale_dp_l_backbone.safetensors \
 --policy.policy_head_size 'scaledp_h' \
 --policy.training_stage 2 \
---dataset.repo_id folding_blue_tshirt \
---batch_size 2 \
+--dataset.repo_id lerobot/aloha_mobile_chair \
 --policy.using_film true \
 --output_dir /path/to/output \
 --steps 10000 \
 --save_freq 1000 \
---optimizer_lr 2e-5 \
---policy.device=cuda
+--optimizer_lr 2e-5
 ~~~
 
 ### Training Stage 3
@@ -104,14 +102,13 @@ python lerobot/scripts/train.py \
 --policy.pretrained_path /path/to/pretrained/stage2/weights \
 --policy.policy_head_size 'scaledp_h' \
 --policy.training_stage 3 \
---dataset.repo_id folding_blue_tshirt \
+--dataset.repo_id lerobot/aloha_mobile_chair \
 --batch_size 2 \
 --policy.using_film true \
 --output_dir /path/to/output \
 --steps 10000 \
 --save_freq 1000 \
---optimizer_lr 2e-5 \
---policy.device=cuda
+--optimizer_lr 2e-5
 ~~~
 
 ### Training Time
@@ -136,8 +133,7 @@ python lerobot/scripts/eval.py \
 --policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \
 --env.task AlohaInsertion-v0 \
 --eval.n_episodes 1 \
---eval.batch_size 1 \
---device cuda
+--eval.batch_size 1
 ~~~
 
 ### Inference Speed
diff --git a/lerobot/common/policies/dexvla/configuration_dexvla.py b/lerobot/common/policies/dexvla/configuration_dexvla.py
index a4743361..96a3944b 100644
--- a/lerobot/common/policies/dexvla/configuration_dexvla.py
+++ b/lerobot/common/policies/dexvla/configuration_dexvla.py
@@ -50,6 +50,8 @@ class DexVLAConfig(PreTrainedConfig):
     n_action_steps: int = 50
     n_obs_steps: int = 1
 
+    device: str = "cuda"
+
     hidden_size: int = 1536
     qwen2_vl_path: str = (
         None  # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official weights of qwen2vl
diff --git a/lerobot/common/policies/dexvla/fusion_modules.py b/lerobot/common/policies/dexvla/fusion_modules.py
index 701a4ada..39bbc57f 100644
--- a/lerobot/common/policies/dexvla/fusion_modules.py
+++ b/lerobot/common/policies/dexvla/fusion_modules.py
@@ -50,9 +50,9 @@ class FiLM(nn.Module):
         nn.init.zeros_(self.shift_fc.bias)
 
     def forward(self, x, condition):
-        # 计算缩放和偏移参数
+        # calculate scale and shift
         scale = self.scale_fc(condition)
         shift = self.shift_fc(condition)
 
-        # 应用 FiLM 调制
+        # apply FiLM modulation
         return x * (1 + scale) + shift
diff --git a/lerobot/common/policies/dexvla/modeling_dexvla.py b/lerobot/common/policies/dexvla/modeling_dexvla.py
index c932bdd7..e1133df8 100644
--- a/lerobot/common/policies/dexvla/modeling_dexvla.py
+++ b/lerobot/common/policies/dexvla/modeling_dexvla.py
@@ -125,7 +125,7 @@ class DexVLAPolicy(PreTrainedPolicy):
         try:
             reasonings = batch["reasoning"]
         except KeyError:
-            reasonings = ["no reasoning"] * len(task_descs)
+            reasonings = ["None."] * len(task_descs)
             pass
 
         is_pad = batch["action_is_pad"]
@@ -208,10 +208,13 @@ class DexVLAPolicy(PreTrainedPolicy):
         all_hidden_states = torch.cat(last_hidden_states, dim=1)
 
         action_hidden_states = None
+        labels_input = torch.ones((1, input_token_len)) * -100
+        labels_output = torch.ones((1, output_ids.shape[1] - input_token_len))
+        labels = torch.cat([labels_input, labels_output], dim=1)
 
         if self.model.using_film:
             action_hidden_states = self.model.film_forward(
-                labels=torch.ones_like(output_ids),
+                labels=labels,
                 input_ids=output_ids,
                 hidden_states=torch.cat(last_hidden_states, dim=1),
             )
diff --git a/lerobot/common/policies/dexvla/robot_data_processor.py b/lerobot/common/policies/dexvla/robot_data_processor.py
index eba11890..7af0aa05 100644
--- a/lerobot/common/policies/dexvla/robot_data_processor.py
+++ b/lerobot/common/policies/dexvla/robot_data_processor.py
@@ -69,10 +69,16 @@ class Qwen2VLAProcess:
         )
 
         if eval:
-            return model_inputs
+            new_dict = {}
+            for k, v in model_inputs.items():
+                if "image_grid" in k:
+                    new_dict["image_grid_spatiotemporal"] = v
+                else:
+                    new_dict[k] = v
+            return new_dict
         input_labels = torch.ones_like(model_inputs["input_ids"]) * -100
 
-        answer = reasoning + "Next action:" + "<|im_end|>" if use_reasoning else "" + "<|im_end|>"
+        answer = reasoning + " Next action:" + "<|im_end|>" if use_reasoning else "" + "<|im_end|>"
 
         output_text = self.tokenizer(answer, padding=True, return_tensors="pt")
         output_labels = output_text["input_ids"]
@@ -84,7 +90,10 @@ class Qwen2VLAProcess:
         data_dict["labels"] = labels
 
         for k, v in model_inputs.items():
-            data_dict[k] = v
+            if "image_grid" in k:
+                data_dict["image_grid_spatiotemporal"] = v
+            else:
+                data_dict[k] = v
         return data_dict
 
     def forward(self, batch, use_reasoning=True):
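Note on the new `labels` construction in `modeling_dexvla.py`: it mirrors the convention already used at training time in `robot_data_processor.py`, where prompt positions are filled with -100 (the ignore value) and answer/reasoning positions are not, instead of marking every position with 1 as before. Below is a minimal sketch of that convention; the mask-based selection at the end is an illustrative assumption about how the reasoning-token hidden states might be isolated, not the policy's actual `film_forward` logic.

~~~python
import torch

# Toy example: 6 prompt tokens followed by 4 generated (reasoning) tokens.
# Prompt positions get -100 (ignored), generated positions get 1.
input_token_len = 6
total_len = 10

labels_input = torch.ones((1, input_token_len)) * -100
labels_output = torch.ones((1, total_len - input_token_len))
labels = torch.cat([labels_input, labels_output], dim=1)  # shape (1, 10)

# Illustrative downstream use (assumption, not the real film_forward code):
# keep only the hidden states of the reasoning tokens for FiLM conditioning.
hidden_size = 1536
hidden_states = torch.randn(1, total_len, hidden_size)
reasoning_mask = labels != -100                   # (1, 10) boolean mask
reasoning_hidden = hidden_states[reasoning_mask]  # shape (4, 1536)
print(labels)
print(reasoning_hidden.shape)
~~~

Using the same -100 convention for the training labels built by the processor and the generation-time labels built in the policy keeps the two code paths consistent, which appears to be the intent of this change.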