improve code

lesjie-wen 2025-04-03 13:47:47 +08:00
parent 41ebc1bfb3
commit 31788f65dd
5 changed files with 26 additions and 16 deletions

View File

@@ -85,14 +85,12 @@ python lerobot/scripts/train.py \
--policy.pretrain_scaledp_path /path/to/pretrained/scale_dp_h/open_scale_dp_l_backbone.safetensors \
--policy.policy_head_size 'scaledp_h' \
--policy.training_stage 2 \
- --dataset.repo_id folding_blue_tshirt \
- --batch_size 2 \
+ --dataset.repo_id lerobot/aloha_mobile_chair \
--policy.using_film true \
--output_dir /path/to/output \
--steps 10000 \
--save_freq 1000 \
- --optimizer_lr 2e-5 \
- --policy.device=cuda
+ --optimizer_lr 2e-5
~~~
### Training Stage 3
@@ -104,14 +102,13 @@ python lerobot/scripts/train.py \
--policy.pretrained_path /path/to/pretrained/stage2/weights \
--policy.policy_head_size 'scaledp_h' \
--policy.training_stage 3 \
- --dataset.repo_id folding_blue_tshirt \
+ --dataset.repo_id lerobot/aloha_mobile_chair \
--batch_size 2 \
--policy.using_film true \
--output_dir /path/to/output \
--steps 10000 \
--save_freq 1000 \
- --optimizer_lr 2e-5 \
- --policy.device=cuda
+ --optimizer_lr 2e-5
~~~
### Training Time
@@ -136,8 +133,7 @@ python lerobot/scripts/eval.py \
--policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \
--env.task AlohaInsertion-v0 \
--eval.n_episodes 1 \
- --eval.batch_size 1 \
- --device cuda
+ --eval.batch_size 1
~~~
### Inference Speed

View File

@@ -50,6 +50,8 @@ class DexVLAConfig(PreTrainedConfig):
    n_action_steps: int = 50
    n_obs_steps: int = 1
    device: str = "cuda"
+   hidden_size: int = 1536
    qwen2_vl_path: str = (
        None  # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official weights of qwen2vl
    )
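The new `hidden_size` default, 1536, matches the hidden dimension of Qwen2-VL-2B-Instruct, so the conditioning and policy-head projections line up with the VLM embeddings. A minimal sketch of setting the field explicitly; the import path and constructor call are assumptions based on the dataclass-style fields shown above, not the full `DexVLAConfig` signature:

~~~python
# import path assumed from the lerobot repo layout
from lerobot.common.policies.dexvla.configuration_dexvla import DexVLAConfig

config = DexVLAConfig(
    n_action_steps=50,
    n_obs_steps=1,
    device="cuda",
    hidden_size=1536,  # must match the Qwen2-VL-2B backbone's hidden dim
    qwen2_vl_path="/path/to/Qwen2-VL-2B-Instruct",
)
~~~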

View File

@@ -50,9 +50,9 @@ class FiLM(nn.Module):
        nn.init.zeros_(self.shift_fc.bias)

    def forward(self, x, condition):
-       # 计算缩放和偏移参数 (compute the scale and shift parameters)
+       # calculate scale and shift
        scale = self.scale_fc(condition)
        shift = self.shift_fc(condition)
-       # 应用 FiLM 调制 (apply FiLM modulation)
+       # apply FiLM modulation
        return x * (1 + scale) + shift
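For context, the fragment above is the whole trick: FiLM modulates features with a per-channel affine transform predicted from a conditioning vector. A self-contained sketch consistent with the diff, assuming `scale_fc` and `shift_fc` are plain `nn.Linear` layers (only the bias init and `forward` appear in this hunk):

~~~python
import torch
import torch.nn as nn

class FiLM(nn.Module):
    """Feature-wise Linear Modulation: y = x * (1 + scale) + shift."""

    def __init__(self, feature_dim: int, condition_dim: int):
        super().__init__()
        self.scale_fc = nn.Linear(condition_dim, feature_dim)
        self.shift_fc = nn.Linear(condition_dim, feature_dim)
        # zero-init so the block starts out as an identity mapping
        nn.init.zeros_(self.scale_fc.weight)
        nn.init.zeros_(self.scale_fc.bias)
        nn.init.zeros_(self.shift_fc.weight)
        nn.init.zeros_(self.shift_fc.bias)

    def forward(self, x, condition):
        # calculate scale and shift from the conditioning vector
        scale = self.scale_fc(condition)
        shift = self.shift_fc(condition)
        # apply FiLM modulation
        return x * (1 + scale) + shift

film = FiLM(feature_dim=512, condition_dim=1536)
x = torch.randn(2, 10, 512)     # token features
cond = torch.randn(2, 1, 1536)  # e.g. a pooled language embedding
y = film(x, cond)               # same shape as x
~~~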

View File

@@ -125,7 +125,7 @@ class DexVLAPolicy(PreTrainedPolicy):
        try:
            reasonings = batch["reasoning"]
        except KeyError:
-           reasonings = ["no reasoning"] * len(task_descs)
+           reasonings = ["None."] * len(task_descs)
            pass

        is_pad = batch["action_is_pad"]
@@ -208,10 +208,13 @@ class DexVLAPolicy(PreTrainedPolicy):
        all_hidden_states = torch.cat(last_hidden_states, dim=1)
        action_hidden_states = None
+       labels_input = torch.ones((1, input_token_len)) * -100
+       labels_output = torch.ones((1, output_ids.shape[1] - input_token_len))
+       labels = torch.cat([labels_input, labels_output], dim=1)
        if self.model.using_film:
            action_hidden_states = self.model.film_forward(
-               labels=torch.ones_like(output_ids),
+               labels=labels,
                input_ids=output_ids,
                hidden_states=torch.cat(last_hidden_states, dim=1),
            )
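The `-100` fill value is the standard PyTorch/Hugging Face `ignore_index` for cross-entropy, so prompt positions carry no loss; here, presumably, it also lets `film_forward` distinguish prompt positions of `output_ids` from generated ones, replacing the previous all-ones mask. A standalone sketch of the construction with illustrative sizes (it mirrors the diff, including the float dtype of `torch.ones`):

~~~python
import torch

input_token_len = 5                           # illustrative prompt length
output_ids = torch.randint(0, 1000, (1, 12))  # prompt + generated tokens

# prompt positions get -100 (ignored); generated positions get a real label
labels_input = torch.ones((1, input_token_len)) * -100
labels_output = torch.ones((1, output_ids.shape[1] - input_token_len))
labels = torch.cat([labels_input, labels_output], dim=1)

assert labels.shape == output_ids.shape  # one label flag per token
~~~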

View File

@@ -69,10 +69,16 @@ class Qwen2VLAProcess:
        )
        if eval:
-           return model_inputs
+           new_dict = {}
+           for k, v in model_inputs.items():
+               if "image_grid" in k:
+                   new_dict["image_grid_spatiotemporal"] = v
+               else:
+                   new_dict[k] = v
+           return new_dict
        input_labels = torch.ones_like(model_inputs["input_ids"]) * -100
-       answer = reasoning + "Next action:" + "<|im_end|>" if use_reasoning else "" + "<|im_end|>"
+       answer = reasoning + " Next action:" + "<|im_end|>" if use_reasoning else "" + "<|im_end|>"
        output_text = self.tokenizer(answer, padding=True, return_tensors="pt")
        output_labels = output_text["input_ids"]
@@ -84,7 +90,10 @@ class Qwen2VLAProcess:
        data_dict["labels"] = labels
        for k, v in model_inputs.items():
-           data_dict[k] = v
+           if "image_grid" in k:
+               data_dict["image_grid_spatiotemporal"] = v
+           else:
+               data_dict[k] = v
        return data_dict

    def forward(self, batch, use_reasoning=True):
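The same `image_grid` renaming loop now appears in both the eval path and the training path above; a small helper would keep the key mapping in one place if a third call site shows up. A hypothetical refactor (the helper name is ours, not part of the commit):

~~~python
def rename_image_grid_keys(model_inputs: dict) -> dict:
    """Map the processor's image_grid_* outputs (e.g. image_grid_thw from the
    Qwen2-VL processor) to the image_grid_spatiotemporal key the model expects."""
    return {
        ("image_grid_spatiotemporal" if "image_grid" in k else k): v
        for k, v in model_inputs.items()
    }
~~~

Both branches would then reduce to `return rename_image_grid_keys(model_inputs)` and `data_dict.update(rename_image_grid_keys(model_inputs))`.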