replace torch.load with safe_open
This commit is contained in:
parent
a2f80f42fc
commit
8d03cc8ad2
|
@ -12,6 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Qwen2VL model configuration"""
|
||||
from .policy_heads import register_policy_heads
|
||||
from .qwe2_vla import register_qwen2_vla
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Tuple
|
||||
|
@ -28,9 +30,6 @@ from lerobot.configs.policies import PreTrainedConfig
|
|||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
from .policy_heads import register_policy_heads
|
||||
from .qwe2_vla import register_qwen2_vla
|
||||
|
||||
register_policy_heads()
|
||||
register_qwen2_vla()
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy
|
|||
from collections import deque
|
||||
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from safetensors.torch import load_file
|
||||
class DexVLAPolicy(PreTrainedPolicy):
|
||||
"""Wrapper class around Qwen2VLForConditionalGenerationForVLA model to train and run inference within LeRobot."""
|
||||
|
||||
|
@ -61,9 +61,7 @@ class DexVLAPolicy(PreTrainedPolicy):
|
|||
print(
|
||||
"\n>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>Loading pretrained ScaleDP weights...<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<"
|
||||
)
|
||||
pretrain_scaledp_weights = torch.load(self.config.pretrained_scaledp_path, map_location="cpu")
|
||||
|
||||
pretrain_scaledp_weights = pretrain_scaledp_weights["nets"]["nets"]
|
||||
pretrain_scaledp_weights = load_file(self.config.pretrained_scaledp_path)
|
||||
|
||||
keys_to_del_dit = []
|
||||
pretrain_scaledp_weights = {
|
||||
|
|
Loading…
Reference in New Issue