replace torch.load with safe_open

This commit is contained in:
lesjie-wen 2025-03-18 18:03:05 +08:00
parent a2f80f42fc
commit 8d03cc8ad2
2 changed files with 4 additions and 7 deletions

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen2VL model configuration"""
from .policy_heads import register_policy_heads
from .qwe2_vla import register_qwen2_vla
from dataclasses import dataclass, field
from typing import Tuple
@ -28,9 +30,6 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
logger = logging.get_logger(__name__)
from .policy_heads import register_policy_heads
from .qwe2_vla import register_qwen2_vla
register_policy_heads()
register_qwen2_vla()

View File

@ -12,7 +12,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
from collections import deque
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
import torchvision.transforms as transforms
from safetensors.torch import load_file
class DexVLAPolicy(PreTrainedPolicy):
"""Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot."""
@ -61,9 +61,7 @@ class DexVLAPolicy(PreTrainedPolicy):
print(
"\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Loading pretrained ScaleDP weights...<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<"
)
pretrain_scaledp_weights = torch.load(self.config.pretrained_scaledp_path, map_location="cpu")
pretrain_scaledp_weights = pretrain_scaledp_weights["nets"]["nets"]
pretrain_scaledp_weights = load_file(self.config.pretrained_scaledp_path)
keys_to_del_dit = []
pretrain_scaledp_weights = {