improve code

parent 41ebc1bfb3
commit 31788f65dd
@@ -85,14 +85,12 @@ python lerobot/scripts/train.py \
 --policy.pretrain_scaledp_path /path/to/pretrained/scale_dp_h/open_scale_dp_l_backbone.safetensors \
 --policy.policy_head_size 'scaledp_h' \
 --policy.training_stage 2 \
---dataset.repo_id folding_blue_tshirt \
 --batch_size 2 \
+--dataset.repo_id lerobot/aloha_mobile_chair \
 --policy.using_film true \
 --output_dir /path/to/output \
 --steps 10000 \
 --save_freq 1000 \
---optimizer_lr 2e-5 \
---policy.device=cuda
+--optimizer_lr 2e-5
 ~~~

 ### Training Stage 3

@@ -104,14 +102,13 @@ python lerobot/scripts/train.py \
 --policy.pretrained_path /path/to/pretrained/stage2/weights \
 --policy.policy_head_size 'scaledp_h' \
 --policy.training_stage 3 \
---dataset.repo_id folding_blue_tshirt \
+--dataset.repo_id lerobot/aloha_mobile_chair \
 --batch_size 2 \
 --policy.using_film true \
 --output_dir /path/to/output \
 --steps 10000 \
 --save_freq 1000 \
---optimizer_lr 2e-5 \
---policy.device=cuda
+--optimizer_lr 2e-5
 ~~~

 ### Training Time

@@ -136,8 +133,7 @@ python lerobot/scripts/eval.py \
 --policy.qwen2_vl_path /path/to/official/Qwen2-VL-2B-Instruct \
 --env.task AlohaInsertion-v0 \
 --eval.n_episodes 1 \
---eval.batch_size 1 \
---device cuda
+--eval.batch_size 1
 ~~~

 ### Inference Speed

@@ -50,6 +50,8 @@ class DexVLAConfig(PreTrainedConfig):
     n_action_steps: int = 50
     n_obs_steps: int = 1

+    device: str = "cuda"
+
     hidden_size: int = 1536
     qwen2_vl_path: str = (
         None  # '/media/rl/HDD/data/weights/Qwen2-VL-2B-Instruct', official weights of qwen2vl
@@ -50,9 +50,9 @@ class FiLM(nn.Module):
         nn.init.zeros_(self.shift_fc.bias)

     def forward(self, x, condition):
-        # 计算缩放和偏移参数 (compute the scale and shift parameters)
+        # calculate scale and shift
         scale = self.scale_fc(condition)
         shift = self.shift_fc(condition)

-        # 应用 FiLM 调制 (apply FiLM modulation)
+        # film
         return x * (1 + scale) + shift

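FiLM (feature-wise linear modulation) conditions a feature vector with a per-channel affine transform, which is exactly the `x * (1 + scale) + shift` line above. A minimal self-contained sketch of such a layer follows; the dimensions, class name, and zero-initialized weights are illustrative assumptions, not copied from the DexVLA source:

~~~python
import torch
import torch.nn as nn


class FiLMSketch(nn.Module):
    """Feature-wise linear modulation: a condition vector yields per-feature scale/shift."""

    def __init__(self, feature_dim: int, condition_dim: int):
        super().__init__()
        self.scale_fc = nn.Linear(condition_dim, feature_dim)
        self.shift_fc = nn.Linear(condition_dim, feature_dim)
        # Zero init (assumed here) makes the layer start as the identity:
        # x * (1 + 0) + 0 == x, so the modulation is learned gradually.
        for fc in (self.scale_fc, self.shift_fc):
            nn.init.zeros_(fc.weight)
            nn.init.zeros_(fc.bias)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        scale = self.scale_fc(condition)  # (batch, feature_dim)
        shift = self.shift_fc(condition)  # (batch, feature_dim)
        return x * (1 + scale) + shift


# Usage with made-up sizes: modulate 1536-d features with a 512-d condition.
film = FiLMSketch(feature_dim=1536, condition_dim=512)
out = film(torch.randn(2, 1536), torch.randn(2, 512))
~~~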
@@ -125,7 +125,7 @@ class DexVLAPolicy(PreTrainedPolicy):
         try:
             reasonings = batch["reasoning"]
         except KeyError:
-            reasonings = ["no reasoning"] * len(task_descs)
+            reasonings = ["None."] * len(task_descs)

             pass
         is_pad = batch["action_is_pad"]

@@ -208,10 +208,13 @@ class DexVLAPolicy(PreTrainedPolicy):
         all_hidden_states = torch.cat(last_hidden_states, dim=1)

         action_hidden_states = None
+        labels_input = torch.ones((1, input_token_len)) * -100
+        labels_output = torch.ones((1, output_ids.shape[1] - input_token_len))
+        labels = torch.cat([labels_input, labels_output], dim=1)

         if self.model.using_film:
             action_hidden_states = self.model.film_forward(
-                labels=torch.ones_like(output_ids),
+                labels=labels,
                 input_ids=output_ids,
                 hidden_states=torch.cat(last_hidden_states, dim=1),
             )

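The `-100` fill value above follows the Hugging Face / PyTorch labeling convention: cross-entropy is called with `ignore_index=-100` by default, so positions labeled `-100` are excluded from supervision. The hunk builds a label row of `-100` over the prompt tokens and `1` over the generated tokens, presumably so `film_forward` can tell the two regions apart. A toy illustration of where the convention comes from (shapes and label values are made up):

~~~python
import torch
import torch.nn.functional as F

vocab_size = 8
logits = torch.randn(1, 5, vocab_size)          # (batch, seq, vocab)
labels = torch.tensor([[-100, -100, 3, 1, 2]])  # prompt masked, answer supervised

# Only the last three positions contribute; -100 positions are skipped entirely.
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1), ignore_index=-100)
print(loss)
~~~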
@@ -69,10 +69,16 @@ class Qwen2VLAProcess:
         )

         if eval:
-            return model_inputs
+            new_dict = {}
+            for k, v in model_inputs.items():
+                if "image_grid" in k:
+                    new_dict["image_grid_spatiotemporal"] = v
+                else:
+                    new_dict[k] = v
+            return new_dict

         input_labels = torch.ones_like(model_inputs["input_ids"]) * -100
-        answer = reasoning + "Next action:" + "<|im_end|>" if use_reasoning else "" + "<|im_end|>"
+        answer = reasoning + " Next action:" + "<|im_end|>" if use_reasoning else "" + "<|im_end|>"

         output_text = self.tokenizer(answer, padding=True, return_tensors="pt")
         output_labels = output_text["input_ids"]

@@ -84,7 +90,10 @@ class Qwen2VLAProcess:

         data_dict["labels"] = labels
         for k, v in model_inputs.items():
-            data_dict[k] = v
+            if "image_grid" in k:
+                data_dict["image_grid_spatiotemporal"] = v
+            else:
+                data_dict[k] = v
         return data_dict

     def forward(self, batch, use_reasoning=True):

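Both `Qwen2VLAProcess` hunks apply the same renaming: any processor output whose key contains `image_grid` (Qwen2-VL processors emit `image_grid_thw`) is re-keyed to `image_grid_spatiotemporal` before the batch moves on. A standalone sketch of that mapping; the helper name and sample dict are illustrative, not part of the source:

~~~python
import torch


def rename_image_grid_keys(model_inputs: dict) -> dict:
    """Re-key any 'image_grid*' entry to 'image_grid_spatiotemporal'."""
    renamed = {}
    for k, v in model_inputs.items():
        if "image_grid" in k:
            renamed["image_grid_spatiotemporal"] = v
        else:
            renamed[k] = v
    return renamed


# Example: Qwen2-VL processors return 'image_grid_thw' next to the usual token tensors.
sample = {
    "input_ids": torch.tensor([[1, 2, 3]]),
    "image_grid_thw": torch.tensor([[1, 14, 14]]),
}
print(rename_image_grid_keys(sample).keys())
# dict_keys(['input_ids', 'image_grid_spatiotemporal'])
~~~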