{ "cells": [ { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [
"from pprint import pprint\n",
"\n",
"import torch\n",
"from safetensors.torch import load_file, save_file"
] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [
"# Original ACT checkpoint to convert, and the path intended for the converted safetensors file.\n",
"act_ckpt_dir = '/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/'\n",
"original_ckpt_path = act_ckpt_dir + 'policy_last.ckpt'\n",
"converted_ckpt_path = act_ckpt_dir + 'model.safetensors'\n",
"\n",
"# Reference LeRobot ACT checkpoint: supplies the target key names and tensor shapes.\n",
"comparison_main_path = '/home/thomwolf/Documents/Github/lerobot/examples/real_robot_example/outputs/train/blue_red_debug_no_masking/checkpoints/last/pretrained_model/'\n",
"comparison_safetensor_path = comparison_main_path + 'model.safetensors'\n",
"comparison_config_json_path = comparison_main_path + 'config.json'\n",
"comparison_config_yaml_path = comparison_main_path + 'config.yaml'"
] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [
"# Original ACT state dict. map_location='cpu' lets it load on a machine without the GPU\n",
"# it may have been saved from. torch.load unpickles the file: only use trusted checkpoints.\n",
"a = torch.load(original_ckpt_path, map_location='cpu')"
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [
"# Reference LeRobot state dict (target naming scheme and shapes).\n",
"b = load_file(comparison_safetensor_path)"
] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"['model.action_head.bias',\n",
" 'model.action_head.weight',\n",
" 'model.backbone.bn1.bias',\n",
" 'model.backbone.bn1.running_mean',\n",
" 'model.backbone.bn1.running_var',\n",
" 'model.backbone.bn1.weight',\n",
" 'model.backbone.conv1.weight',\n",
" 'model.backbone.layer1.0.bn1.bias',\n",
" 'model.backbone.layer1.0.bn1.running_mean',\n",
" 'model.backbone.layer1.0.bn1.running_var',\n",
" 'model.backbone.layer1.0.bn1.weight',\n",
" 'model.backbone.layer1.0.bn2.bias',\n",
" 'model.backbone.layer1.0.bn2.running_mean',\n",
" 'model.backbone.layer1.0.bn2.running_var',\n",
" 'model.backbone.layer1.0.bn2.weight',\n",
" 'model.backbone.layer1.0.conv1.weight',\n", "
'model.backbone.layer1.0.conv2.weight',\n", " 'model.backbone.layer1.1.bn1.bias',\n", " 'model.backbone.layer1.1.bn1.running_mean',\n", " 'model.backbone.layer1.1.bn1.running_var',\n", " 'model.backbone.layer1.1.bn1.weight',\n", " 'model.backbone.layer1.1.bn2.bias',\n", " 'model.backbone.layer1.1.bn2.running_mean',\n", " 'model.backbone.layer1.1.bn2.running_var',\n", " 'model.backbone.layer1.1.bn2.weight',\n", " 'model.backbone.layer1.1.conv1.weight',\n", " 'model.backbone.layer1.1.conv2.weight',\n", " 'model.backbone.layer2.0.bn1.bias',\n", " 'model.backbone.layer2.0.bn1.running_mean',\n", " 'model.backbone.layer2.0.bn1.running_var',\n", " 'model.backbone.layer2.0.bn1.weight',\n", " 'model.backbone.layer2.0.bn2.bias',\n", " 'model.backbone.layer2.0.bn2.running_mean',\n", " 'model.backbone.layer2.0.bn2.running_var',\n", " 'model.backbone.layer2.0.bn2.weight',\n", " 'model.backbone.layer2.0.conv1.weight',\n", " 'model.backbone.layer2.0.conv2.weight',\n", " 'model.backbone.layer2.0.downsample.0.weight',\n", " 'model.backbone.layer2.0.downsample.1.bias',\n", " 'model.backbone.layer2.0.downsample.1.running_mean',\n", " 'model.backbone.layer2.0.downsample.1.running_var',\n", " 'model.backbone.layer2.0.downsample.1.weight',\n", " 'model.backbone.layer2.1.bn1.bias',\n", " 'model.backbone.layer2.1.bn1.running_mean',\n", " 'model.backbone.layer2.1.bn1.running_var',\n", " 'model.backbone.layer2.1.bn1.weight',\n", " 'model.backbone.layer2.1.bn2.bias',\n", " 'model.backbone.layer2.1.bn2.running_mean',\n", " 'model.backbone.layer2.1.bn2.running_var',\n", " 'model.backbone.layer2.1.bn2.weight',\n", " 'model.backbone.layer2.1.conv1.weight',\n", " 'model.backbone.layer2.1.conv2.weight',\n", " 'model.backbone.layer3.0.bn1.bias',\n", " 'model.backbone.layer3.0.bn1.running_mean',\n", " 'model.backbone.layer3.0.bn1.running_var',\n", " 'model.backbone.layer3.0.bn1.weight',\n", " 'model.backbone.layer3.0.bn2.bias',\n", " 'model.backbone.layer3.0.bn2.running_mean',\n", " 
'model.backbone.layer3.0.bn2.running_var',\n", " 'model.backbone.layer3.0.bn2.weight',\n", " 'model.backbone.layer3.0.conv1.weight',\n", " 'model.backbone.layer3.0.conv2.weight',\n", " 'model.backbone.layer3.0.downsample.0.weight',\n", " 'model.backbone.layer3.0.downsample.1.bias',\n", " 'model.backbone.layer3.0.downsample.1.running_mean',\n", " 'model.backbone.layer3.0.downsample.1.running_var',\n", " 'model.backbone.layer3.0.downsample.1.weight',\n", " 'model.backbone.layer3.1.bn1.bias',\n", " 'model.backbone.layer3.1.bn1.running_mean',\n", " 'model.backbone.layer3.1.bn1.running_var',\n", " 'model.backbone.layer3.1.bn1.weight',\n", " 'model.backbone.layer3.1.bn2.bias',\n", " 'model.backbone.layer3.1.bn2.running_mean',\n", " 'model.backbone.layer3.1.bn2.running_var',\n", " 'model.backbone.layer3.1.bn2.weight',\n", " 'model.backbone.layer3.1.conv1.weight',\n", " 'model.backbone.layer3.1.conv2.weight',\n", " 'model.backbone.layer4.0.bn1.bias',\n", " 'model.backbone.layer4.0.bn1.running_mean',\n", " 'model.backbone.layer4.0.bn1.running_var',\n", " 'model.backbone.layer4.0.bn1.weight',\n", " 'model.backbone.layer4.0.bn2.bias',\n", " 'model.backbone.layer4.0.bn2.running_mean',\n", " 'model.backbone.layer4.0.bn2.running_var',\n", " 'model.backbone.layer4.0.bn2.weight',\n", " 'model.backbone.layer4.0.conv1.weight',\n", " 'model.backbone.layer4.0.conv2.weight',\n", " 'model.backbone.layer4.0.downsample.0.weight',\n", " 'model.backbone.layer4.0.downsample.1.bias',\n", " 'model.backbone.layer4.0.downsample.1.running_mean',\n", " 'model.backbone.layer4.0.downsample.1.running_var',\n", " 'model.backbone.layer4.0.downsample.1.weight',\n", " 'model.backbone.layer4.1.bn1.bias',\n", " 'model.backbone.layer4.1.bn1.running_mean',\n", " 'model.backbone.layer4.1.bn1.running_var',\n", " 'model.backbone.layer4.1.bn1.weight',\n", " 'model.backbone.layer4.1.bn2.bias',\n", " 'model.backbone.layer4.1.bn2.running_mean',\n", " 'model.backbone.layer4.1.bn2.running_var',\n", " 
'model.backbone.layer4.1.bn2.weight',\n", " 'model.backbone.layer4.1.conv1.weight',\n", " 'model.backbone.layer4.1.conv2.weight',\n", " 'model.decoder.layers.0.linear1.bias',\n", " 'model.decoder.layers.0.linear1.weight',\n", " 'model.decoder.layers.0.linear2.bias',\n", " 'model.decoder.layers.0.linear2.weight',\n", " 'model.decoder.layers.0.multihead_attn.in_proj_bias',\n", " 'model.decoder.layers.0.multihead_attn.in_proj_weight',\n", " 'model.decoder.layers.0.multihead_attn.out_proj.bias',\n", " 'model.decoder.layers.0.multihead_attn.out_proj.weight',\n", " 'model.decoder.layers.0.norm1.bias',\n", " 'model.decoder.layers.0.norm1.weight',\n", " 'model.decoder.layers.0.norm2.bias',\n", " 'model.decoder.layers.0.norm2.weight',\n", " 'model.decoder.layers.0.norm3.bias',\n", " 'model.decoder.layers.0.norm3.weight',\n", " 'model.decoder.layers.0.self_attn.in_proj_bias',\n", " 'model.decoder.layers.0.self_attn.in_proj_weight',\n", " 'model.decoder.layers.0.self_attn.out_proj.bias',\n", " 'model.decoder.layers.0.self_attn.out_proj.weight',\n", " 'model.decoder_pos_embed.weight',\n", " 'model.encoder.layers.0.linear1.bias',\n", " 'model.encoder.layers.0.linear1.weight',\n", " 'model.encoder.layers.0.linear2.bias',\n", " 'model.encoder.layers.0.linear2.weight',\n", " 'model.encoder.layers.0.norm1.bias',\n", " 'model.encoder.layers.0.norm1.weight',\n", " 'model.encoder.layers.0.norm2.bias',\n", " 'model.encoder.layers.0.norm2.weight',\n", " 'model.encoder.layers.0.self_attn.in_proj_bias',\n", " 'model.encoder.layers.0.self_attn.in_proj_weight',\n", " 'model.encoder.layers.0.self_attn.out_proj.bias',\n", " 'model.encoder.layers.0.self_attn.out_proj.weight',\n", " 'model.encoder.layers.1.linear1.bias',\n", " 'model.encoder.layers.1.linear1.weight',\n", " 'model.encoder.layers.1.linear2.bias',\n", " 'model.encoder.layers.1.linear2.weight',\n", " 'model.encoder.layers.1.norm1.bias',\n", " 'model.encoder.layers.1.norm1.weight',\n", " 'model.encoder.layers.1.norm2.bias',\n", " 
'model.encoder.layers.1.norm2.weight',\n", " 'model.encoder.layers.1.self_attn.in_proj_bias',\n", " 'model.encoder.layers.1.self_attn.in_proj_weight',\n", " 'model.encoder.layers.1.self_attn.out_proj.bias',\n", " 'model.encoder.layers.1.self_attn.out_proj.weight',\n", " 'model.encoder.layers.2.linear1.bias',\n", " 'model.encoder.layers.2.linear1.weight',\n", " 'model.encoder.layers.2.linear2.bias',\n", " 'model.encoder.layers.2.linear2.weight',\n", " 'model.encoder.layers.2.norm1.bias',\n", " 'model.encoder.layers.2.norm1.weight',\n", " 'model.encoder.layers.2.norm2.bias',\n", " 'model.encoder.layers.2.norm2.weight',\n", " 'model.encoder.layers.2.self_attn.in_proj_bias',\n", " 'model.encoder.layers.2.self_attn.in_proj_weight',\n", " 'model.encoder.layers.2.self_attn.out_proj.bias',\n", " 'model.encoder.layers.2.self_attn.out_proj.weight',\n", " 'model.encoder.layers.3.linear1.bias',\n", " 'model.encoder.layers.3.linear1.weight',\n", " 'model.encoder.layers.3.linear2.bias',\n", " 'model.encoder.layers.3.linear2.weight',\n", " 'model.encoder.layers.3.norm1.bias',\n", " 'model.encoder.layers.3.norm1.weight',\n", " 'model.encoder.layers.3.norm2.bias',\n", " 'model.encoder.layers.3.norm2.weight',\n", " 'model.encoder.layers.3.self_attn.in_proj_bias',\n", " 'model.encoder.layers.3.self_attn.in_proj_weight',\n", " 'model.encoder.layers.3.self_attn.out_proj.bias',\n", " 'model.encoder.layers.3.self_attn.out_proj.weight',\n", " 'model.encoder_img_feat_input_proj.bias',\n", " 'model.encoder_img_feat_input_proj.weight',\n", " 'model.encoder_latent_input_proj.bias',\n", " 'model.encoder_latent_input_proj.weight',\n", " 'model.encoder_robot_and_latent_pos_embed.weight',\n", " 'model.encoder_robot_state_input_proj.bias',\n", " 'model.encoder_robot_state_input_proj.weight',\n", " 'model.vae_encoder.layers.0.linear1.bias',\n", " 'model.vae_encoder.layers.0.linear1.weight',\n", " 'model.vae_encoder.layers.0.linear2.bias',\n", " 'model.vae_encoder.layers.0.linear2.weight',\n", " 
'model.vae_encoder.layers.0.norm1.bias',\n", " 'model.vae_encoder.layers.0.norm1.weight',\n", " 'model.vae_encoder.layers.0.norm2.bias',\n", " 'model.vae_encoder.layers.0.norm2.weight',\n", " 'model.vae_encoder.layers.0.self_attn.in_proj_bias',\n", " 'model.vae_encoder.layers.0.self_attn.in_proj_weight',\n", " 'model.vae_encoder.layers.0.self_attn.out_proj.bias',\n", " 'model.vae_encoder.layers.0.self_attn.out_proj.weight',\n", " 'model.vae_encoder.layers.1.linear1.bias',\n", " 'model.vae_encoder.layers.1.linear1.weight',\n", " 'model.vae_encoder.layers.1.linear2.bias',\n", " 'model.vae_encoder.layers.1.linear2.weight',\n", " 'model.vae_encoder.layers.1.norm1.bias',\n", " 'model.vae_encoder.layers.1.norm1.weight',\n", " 'model.vae_encoder.layers.1.norm2.bias',\n", " 'model.vae_encoder.layers.1.norm2.weight',\n", " 'model.vae_encoder.layers.1.self_attn.in_proj_bias',\n", " 'model.vae_encoder.layers.1.self_attn.in_proj_weight',\n", " 'model.vae_encoder.layers.1.self_attn.out_proj.bias',\n", " 'model.vae_encoder.layers.1.self_attn.out_proj.weight',\n", " 'model.vae_encoder.layers.2.linear1.bias',\n", " 'model.vae_encoder.layers.2.linear1.weight',\n", " 'model.vae_encoder.layers.2.linear2.bias',\n", " 'model.vae_encoder.layers.2.linear2.weight',\n", " 'model.vae_encoder.layers.2.norm1.bias',\n", " 'model.vae_encoder.layers.2.norm1.weight',\n", " 'model.vae_encoder.layers.2.norm2.bias',\n", " 'model.vae_encoder.layers.2.norm2.weight',\n", " 'model.vae_encoder.layers.2.self_attn.in_proj_bias',\n", " 'model.vae_encoder.layers.2.self_attn.in_proj_weight',\n", " 'model.vae_encoder.layers.2.self_attn.out_proj.bias',\n", " 'model.vae_encoder.layers.2.self_attn.out_proj.weight',\n", " 'model.vae_encoder.layers.3.linear1.bias',\n", " 'model.vae_encoder.layers.3.linear1.weight',\n", " 'model.vae_encoder.layers.3.linear2.bias',\n", " 'model.vae_encoder.layers.3.linear2.weight',\n", " 'model.vae_encoder.layers.3.norm1.bias',\n", " 'model.vae_encoder.layers.3.norm1.weight',\n", " 
'model.vae_encoder.layers.3.norm2.bias',\n",
" 'model.vae_encoder.layers.3.norm2.weight',\n",
" 'model.vae_encoder.layers.3.self_attn.in_proj_bias',\n",
" 'model.vae_encoder.layers.3.self_attn.in_proj_weight',\n",
" 'model.vae_encoder.layers.3.self_attn.out_proj.bias',\n",
" 'model.vae_encoder.layers.3.self_attn.out_proj.weight',\n",
" 'model.vae_encoder_action_input_proj.bias',\n",
" 'model.vae_encoder_action_input_proj.weight',\n",
" 'model.vae_encoder_cls_embed.weight',\n",
" 'model.vae_encoder_latent_output_proj.bias',\n",
" 'model.vae_encoder_latent_output_proj.weight',\n",
" 'model.vae_encoder_pos_enc',\n",
" 'model.vae_encoder_robot_state_input_proj.bias',\n",
" 'model.vae_encoder_robot_state_input_proj.weight',\n",
" 'normalize_inputs.buffer_observation_images_front.mean',\n",
" 'normalize_inputs.buffer_observation_images_front.std',\n",
" 'normalize_inputs.buffer_observation_images_top.mean',\n",
" 'normalize_inputs.buffer_observation_images_top.std',\n",
" 'normalize_inputs.buffer_observation_state.mean',\n",
" 'normalize_inputs.buffer_observation_state.std',\n",
" 'normalize_targets.buffer_action.mean',\n",
" 'normalize_targets.buffer_action.std',\n",
" 'unnormalize_outputs.buffer_action.mean',\n",
" 'unnormalize_outputs.buffer_action.std']\n" ] } ], "source": [
"# Key names of the reference LeRobot checkpoint: the naming scheme the conversion must produce.\n",
"dest = [name for name in b]\n",
"pprint(dest)"
] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"['model.pos_table',\n",
" 'model.transformer.encoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.transformer.encoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.transformer.encoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.transformer.encoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.transformer.encoder.layers.0.linear1.weight',\n",
" 'model.transformer.encoder.layers.0.linear1.bias',\n",
" 'model.transformer.encoder.layers.0.linear2.weight',\n", "
'model.transformer.encoder.layers.0.linear2.bias',\n", " 'model.transformer.encoder.layers.0.norm1.weight',\n", " 'model.transformer.encoder.layers.0.norm1.bias',\n", " 'model.transformer.encoder.layers.0.norm2.weight',\n", " 'model.transformer.encoder.layers.0.norm2.bias',\n", " 'model.transformer.encoder.layers.1.self_attn.in_proj_weight',\n", " 'model.transformer.encoder.layers.1.self_attn.in_proj_bias',\n", " 'model.transformer.encoder.layers.1.self_attn.out_proj.weight',\n", " 'model.transformer.encoder.layers.1.self_attn.out_proj.bias',\n", " 'model.transformer.encoder.layers.1.linear1.weight',\n", " 'model.transformer.encoder.layers.1.linear1.bias',\n", " 'model.transformer.encoder.layers.1.linear2.weight',\n", " 'model.transformer.encoder.layers.1.linear2.bias',\n", " 'model.transformer.encoder.layers.1.norm1.weight',\n", " 'model.transformer.encoder.layers.1.norm1.bias',\n", " 'model.transformer.encoder.layers.1.norm2.weight',\n", " 'model.transformer.encoder.layers.1.norm2.bias',\n", " 'model.transformer.encoder.layers.2.self_attn.in_proj_weight',\n", " 'model.transformer.encoder.layers.2.self_attn.in_proj_bias',\n", " 'model.transformer.encoder.layers.2.self_attn.out_proj.weight',\n", " 'model.transformer.encoder.layers.2.self_attn.out_proj.bias',\n", " 'model.transformer.encoder.layers.2.linear1.weight',\n", " 'model.transformer.encoder.layers.2.linear1.bias',\n", " 'model.transformer.encoder.layers.2.linear2.weight',\n", " 'model.transformer.encoder.layers.2.linear2.bias',\n", " 'model.transformer.encoder.layers.2.norm1.weight',\n", " 'model.transformer.encoder.layers.2.norm1.bias',\n", " 'model.transformer.encoder.layers.2.norm2.weight',\n", " 'model.transformer.encoder.layers.2.norm2.bias',\n", " 'model.transformer.encoder.layers.3.self_attn.in_proj_weight',\n", " 'model.transformer.encoder.layers.3.self_attn.in_proj_bias',\n", " 'model.transformer.encoder.layers.3.self_attn.out_proj.weight',\n", " 
'model.transformer.encoder.layers.3.self_attn.out_proj.bias',\n", " 'model.transformer.encoder.layers.3.linear1.weight',\n", " 'model.transformer.encoder.layers.3.linear1.bias',\n", " 'model.transformer.encoder.layers.3.linear2.weight',\n", " 'model.transformer.encoder.layers.3.linear2.bias',\n", " 'model.transformer.encoder.layers.3.norm1.weight',\n", " 'model.transformer.encoder.layers.3.norm1.bias',\n", " 'model.transformer.encoder.layers.3.norm2.weight',\n", " 'model.transformer.encoder.layers.3.norm2.bias',\n", " 'model.transformer.decoder.layers.0.self_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.0.self_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.0.self_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.0.self_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.0.multihead_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.0.multihead_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.0.multihead_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.0.multihead_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.0.linear1.weight',\n", " 'model.transformer.decoder.layers.0.linear1.bias',\n", " 'model.transformer.decoder.layers.0.linear2.weight',\n", " 'model.transformer.decoder.layers.0.linear2.bias',\n", " 'model.transformer.decoder.layers.0.norm1.weight',\n", " 'model.transformer.decoder.layers.0.norm1.bias',\n", " 'model.transformer.decoder.layers.0.norm2.weight',\n", " 'model.transformer.decoder.layers.0.norm2.bias',\n", " 'model.transformer.decoder.layers.0.norm3.weight',\n", " 'model.transformer.decoder.layers.0.norm3.bias',\n", " 'model.transformer.decoder.layers.1.self_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.1.self_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.1.self_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.1.self_attn.out_proj.bias',\n", " 
'model.transformer.decoder.layers.1.multihead_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.1.multihead_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.1.multihead_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.1.multihead_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.1.linear1.weight',\n", " 'model.transformer.decoder.layers.1.linear1.bias',\n", " 'model.transformer.decoder.layers.1.linear2.weight',\n", " 'model.transformer.decoder.layers.1.linear2.bias',\n", " 'model.transformer.decoder.layers.1.norm1.weight',\n", " 'model.transformer.decoder.layers.1.norm1.bias',\n", " 'model.transformer.decoder.layers.1.norm2.weight',\n", " 'model.transformer.decoder.layers.1.norm2.bias',\n", " 'model.transformer.decoder.layers.1.norm3.weight',\n", " 'model.transformer.decoder.layers.1.norm3.bias',\n", " 'model.transformer.decoder.layers.2.self_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.2.self_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.2.self_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.2.self_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.2.multihead_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.2.multihead_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.2.multihead_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.2.multihead_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.2.linear1.weight',\n", " 'model.transformer.decoder.layers.2.linear1.bias',\n", " 'model.transformer.decoder.layers.2.linear2.weight',\n", " 'model.transformer.decoder.layers.2.linear2.bias',\n", " 'model.transformer.decoder.layers.2.norm1.weight',\n", " 'model.transformer.decoder.layers.2.norm1.bias',\n", " 'model.transformer.decoder.layers.2.norm2.weight',\n", " 'model.transformer.decoder.layers.2.norm2.bias',\n", " 'model.transformer.decoder.layers.2.norm3.weight',\n", " 'model.transformer.decoder.layers.2.norm3.bias',\n", " 
'model.transformer.decoder.layers.3.self_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.3.self_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.3.self_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.3.self_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.3.multihead_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.3.multihead_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.3.multihead_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.3.multihead_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.3.linear1.weight',\n", " 'model.transformer.decoder.layers.3.linear1.bias',\n", " 'model.transformer.decoder.layers.3.linear2.weight',\n", " 'model.transformer.decoder.layers.3.linear2.bias',\n", " 'model.transformer.decoder.layers.3.norm1.weight',\n", " 'model.transformer.decoder.layers.3.norm1.bias',\n", " 'model.transformer.decoder.layers.3.norm2.weight',\n", " 'model.transformer.decoder.layers.3.norm2.bias',\n", " 'model.transformer.decoder.layers.3.norm3.weight',\n", " 'model.transformer.decoder.layers.3.norm3.bias',\n", " 'model.transformer.decoder.layers.4.self_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.4.self_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.4.self_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.4.self_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.4.multihead_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.4.multihead_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.4.multihead_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.4.multihead_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.4.linear1.weight',\n", " 'model.transformer.decoder.layers.4.linear1.bias',\n", " 'model.transformer.decoder.layers.4.linear2.weight',\n", " 'model.transformer.decoder.layers.4.linear2.bias',\n", " 'model.transformer.decoder.layers.4.norm1.weight',\n", " 
'model.transformer.decoder.layers.4.norm1.bias',\n", " 'model.transformer.decoder.layers.4.norm2.weight',\n", " 'model.transformer.decoder.layers.4.norm2.bias',\n", " 'model.transformer.decoder.layers.4.norm3.weight',\n", " 'model.transformer.decoder.layers.4.norm3.bias',\n", " 'model.transformer.decoder.layers.5.self_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.5.self_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.5.self_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.5.self_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.5.multihead_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.5.multihead_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.5.multihead_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.5.multihead_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.5.linear1.weight',\n", " 'model.transformer.decoder.layers.5.linear1.bias',\n", " 'model.transformer.decoder.layers.5.linear2.weight',\n", " 'model.transformer.decoder.layers.5.linear2.bias',\n", " 'model.transformer.decoder.layers.5.norm1.weight',\n", " 'model.transformer.decoder.layers.5.norm1.bias',\n", " 'model.transformer.decoder.layers.5.norm2.weight',\n", " 'model.transformer.decoder.layers.5.norm2.bias',\n", " 'model.transformer.decoder.layers.5.norm3.weight',\n", " 'model.transformer.decoder.layers.5.norm3.bias',\n", " 'model.transformer.decoder.layers.6.self_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.6.self_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.6.self_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.6.self_attn.out_proj.bias',\n", " 'model.transformer.decoder.layers.6.multihead_attn.in_proj_weight',\n", " 'model.transformer.decoder.layers.6.multihead_attn.in_proj_bias',\n", " 'model.transformer.decoder.layers.6.multihead_attn.out_proj.weight',\n", " 'model.transformer.decoder.layers.6.multihead_attn.out_proj.bias',\n", " 
'model.transformer.decoder.layers.6.linear1.weight',\n", " 'model.transformer.decoder.layers.6.linear1.bias',\n", " 'model.transformer.decoder.layers.6.linear2.weight',\n", " 'model.transformer.decoder.layers.6.linear2.bias',\n", " 'model.transformer.decoder.layers.6.norm1.weight',\n", " 'model.transformer.decoder.layers.6.norm1.bias',\n", " 'model.transformer.decoder.layers.6.norm2.weight',\n", " 'model.transformer.decoder.layers.6.norm2.bias',\n", " 'model.transformer.decoder.layers.6.norm3.weight',\n", " 'model.transformer.decoder.layers.6.norm3.bias',\n", " 'model.transformer.decoder.norm.weight',\n", " 'model.transformer.decoder.norm.bias',\n", " 'model.encoder.layers.0.self_attn.in_proj_weight',\n", " 'model.encoder.layers.0.self_attn.in_proj_bias',\n", " 'model.encoder.layers.0.self_attn.out_proj.weight',\n", " 'model.encoder.layers.0.self_attn.out_proj.bias',\n", " 'model.encoder.layers.0.linear1.weight',\n", " 'model.encoder.layers.0.linear1.bias',\n", " 'model.encoder.layers.0.linear2.weight',\n", " 'model.encoder.layers.0.linear2.bias',\n", " 'model.encoder.layers.0.norm1.weight',\n", " 'model.encoder.layers.0.norm1.bias',\n", " 'model.encoder.layers.0.norm2.weight',\n", " 'model.encoder.layers.0.norm2.bias',\n", " 'model.encoder.layers.1.self_attn.in_proj_weight',\n", " 'model.encoder.layers.1.self_attn.in_proj_bias',\n", " 'model.encoder.layers.1.self_attn.out_proj.weight',\n", " 'model.encoder.layers.1.self_attn.out_proj.bias',\n", " 'model.encoder.layers.1.linear1.weight',\n", " 'model.encoder.layers.1.linear1.bias',\n", " 'model.encoder.layers.1.linear2.weight',\n", " 'model.encoder.layers.1.linear2.bias',\n", " 'model.encoder.layers.1.norm1.weight',\n", " 'model.encoder.layers.1.norm1.bias',\n", " 'model.encoder.layers.1.norm2.weight',\n", " 'model.encoder.layers.1.norm2.bias',\n", " 'model.encoder.layers.2.self_attn.in_proj_weight',\n", " 'model.encoder.layers.2.self_attn.in_proj_bias',\n", " 'model.encoder.layers.2.self_attn.out_proj.weight',\n", 
" 'model.encoder.layers.2.self_attn.out_proj.bias',\n", " 'model.encoder.layers.2.linear1.weight',\n", " 'model.encoder.layers.2.linear1.bias',\n", " 'model.encoder.layers.2.linear2.weight',\n", " 'model.encoder.layers.2.linear2.bias',\n", " 'model.encoder.layers.2.norm1.weight',\n", " 'model.encoder.layers.2.norm1.bias',\n", " 'model.encoder.layers.2.norm2.weight',\n", " 'model.encoder.layers.2.norm2.bias',\n", " 'model.encoder.layers.3.self_attn.in_proj_weight',\n", " 'model.encoder.layers.3.self_attn.in_proj_bias',\n", " 'model.encoder.layers.3.self_attn.out_proj.weight',\n", " 'model.encoder.layers.3.self_attn.out_proj.bias',\n", " 'model.encoder.layers.3.linear1.weight',\n", " 'model.encoder.layers.3.linear1.bias',\n", " 'model.encoder.layers.3.linear2.weight',\n", " 'model.encoder.layers.3.linear2.bias',\n", " 'model.encoder.layers.3.norm1.weight',\n", " 'model.encoder.layers.3.norm1.bias',\n", " 'model.encoder.layers.3.norm2.weight',\n", " 'model.encoder.layers.3.norm2.bias',\n", " 'model.action_head.weight',\n", " 'model.action_head.bias',\n", " 'model.is_pad_head.weight',\n", " 'model.is_pad_head.bias',\n", " 'model.query_embed.weight',\n", " 'model.input_proj.weight',\n", " 'model.input_proj.bias',\n", " 'model.backbones.0.0.body.conv1.weight',\n", " 'model.backbones.0.0.body.bn1.weight',\n", " 'model.backbones.0.0.body.bn1.bias',\n", " 'model.backbones.0.0.body.bn1.running_mean',\n", " 'model.backbones.0.0.body.bn1.running_var',\n", " 'model.backbones.0.0.body.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer1.0.conv1.weight',\n", " 'model.backbones.0.0.body.layer1.0.bn1.weight',\n", " 'model.backbones.0.0.body.layer1.0.bn1.bias',\n", " 'model.backbones.0.0.body.layer1.0.bn1.running_mean',\n", " 'model.backbones.0.0.body.layer1.0.bn1.running_var',\n", " 'model.backbones.0.0.body.layer1.0.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer1.0.conv2.weight',\n", " 'model.backbones.0.0.body.layer1.0.bn2.weight',\n", " 
'model.backbones.0.0.body.layer1.0.bn2.bias',\n", " 'model.backbones.0.0.body.layer1.0.bn2.running_mean',\n", " 'model.backbones.0.0.body.layer1.0.bn2.running_var',\n", " 'model.backbones.0.0.body.layer1.0.bn2.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer1.1.conv1.weight',\n", " 'model.backbones.0.0.body.layer1.1.bn1.weight',\n", " 'model.backbones.0.0.body.layer1.1.bn1.bias',\n", " 'model.backbones.0.0.body.layer1.1.bn1.running_mean',\n", " 'model.backbones.0.0.body.layer1.1.bn1.running_var',\n", " 'model.backbones.0.0.body.layer1.1.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer1.1.conv2.weight',\n", " 'model.backbones.0.0.body.layer1.1.bn2.weight',\n", " 'model.backbones.0.0.body.layer1.1.bn2.bias',\n", " 'model.backbones.0.0.body.layer1.1.bn2.running_mean',\n", " 'model.backbones.0.0.body.layer1.1.bn2.running_var',\n", " 'model.backbones.0.0.body.layer1.1.bn2.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer2.0.conv1.weight',\n", " 'model.backbones.0.0.body.layer2.0.bn1.weight',\n", " 'model.backbones.0.0.body.layer2.0.bn1.bias',\n", " 'model.backbones.0.0.body.layer2.0.bn1.running_mean',\n", " 'model.backbones.0.0.body.layer2.0.bn1.running_var',\n", " 'model.backbones.0.0.body.layer2.0.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer2.0.conv2.weight',\n", " 'model.backbones.0.0.body.layer2.0.bn2.weight',\n", " 'model.backbones.0.0.body.layer2.0.bn2.bias',\n", " 'model.backbones.0.0.body.layer2.0.bn2.running_mean',\n", " 'model.backbones.0.0.body.layer2.0.bn2.running_var',\n", " 'model.backbones.0.0.body.layer2.0.bn2.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer2.0.downsample.0.weight',\n", " 'model.backbones.0.0.body.layer2.0.downsample.1.weight',\n", " 'model.backbones.0.0.body.layer2.0.downsample.1.bias',\n", " 'model.backbones.0.0.body.layer2.0.downsample.1.running_mean',\n", " 'model.backbones.0.0.body.layer2.0.downsample.1.running_var',\n", " 
'model.backbones.0.0.body.layer2.0.downsample.1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer2.1.conv1.weight',\n", " 'model.backbones.0.0.body.layer2.1.bn1.weight',\n", " 'model.backbones.0.0.body.layer2.1.bn1.bias',\n", " 'model.backbones.0.0.body.layer2.1.bn1.running_mean',\n", " 'model.backbones.0.0.body.layer2.1.bn1.running_var',\n", " 'model.backbones.0.0.body.layer2.1.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer2.1.conv2.weight',\n", " 'model.backbones.0.0.body.layer2.1.bn2.weight',\n", " 'model.backbones.0.0.body.layer2.1.bn2.bias',\n", " 'model.backbones.0.0.body.layer2.1.bn2.running_mean',\n", " 'model.backbones.0.0.body.layer2.1.bn2.running_var',\n", " 'model.backbones.0.0.body.layer2.1.bn2.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer3.0.conv1.weight',\n", " 'model.backbones.0.0.body.layer3.0.bn1.weight',\n", " 'model.backbones.0.0.body.layer3.0.bn1.bias',\n", " 'model.backbones.0.0.body.layer3.0.bn1.running_mean',\n", " 'model.backbones.0.0.body.layer3.0.bn1.running_var',\n", " 'model.backbones.0.0.body.layer3.0.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer3.0.conv2.weight',\n", " 'model.backbones.0.0.body.layer3.0.bn2.weight',\n", " 'model.backbones.0.0.body.layer3.0.bn2.bias',\n", " 'model.backbones.0.0.body.layer3.0.bn2.running_mean',\n", " 'model.backbones.0.0.body.layer3.0.bn2.running_var',\n", " 'model.backbones.0.0.body.layer3.0.bn2.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer3.0.downsample.0.weight',\n", " 'model.backbones.0.0.body.layer3.0.downsample.1.weight',\n", " 'model.backbones.0.0.body.layer3.0.downsample.1.bias',\n", " 'model.backbones.0.0.body.layer3.0.downsample.1.running_mean',\n", " 'model.backbones.0.0.body.layer3.0.downsample.1.running_var',\n", " 'model.backbones.0.0.body.layer3.0.downsample.1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer3.1.conv1.weight',\n", " 'model.backbones.0.0.body.layer3.1.bn1.weight',\n", " 
'model.backbones.0.0.body.layer3.1.bn1.bias',\n", " 'model.backbones.0.0.body.layer3.1.bn1.running_mean',\n", " 'model.backbones.0.0.body.layer3.1.bn1.running_var',\n", " 'model.backbones.0.0.body.layer3.1.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer3.1.conv2.weight',\n", " 'model.backbones.0.0.body.layer3.1.bn2.weight',\n", " 'model.backbones.0.0.body.layer3.1.bn2.bias',\n", " 'model.backbones.0.0.body.layer3.1.bn2.running_mean',\n", " 'model.backbones.0.0.body.layer3.1.bn2.running_var',\n", " 'model.backbones.0.0.body.layer3.1.bn2.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer4.0.conv1.weight',\n", " 'model.backbones.0.0.body.layer4.0.bn1.weight',\n", " 'model.backbones.0.0.body.layer4.0.bn1.bias',\n", " 'model.backbones.0.0.body.layer4.0.bn1.running_mean',\n", " 'model.backbones.0.0.body.layer4.0.bn1.running_var',\n", " 'model.backbones.0.0.body.layer4.0.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer4.0.conv2.weight',\n", " 'model.backbones.0.0.body.layer4.0.bn2.weight',\n", " 'model.backbones.0.0.body.layer4.0.bn2.bias',\n", " 'model.backbones.0.0.body.layer4.0.bn2.running_mean',\n", " 'model.backbones.0.0.body.layer4.0.bn2.running_var',\n", " 'model.backbones.0.0.body.layer4.0.bn2.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer4.0.downsample.0.weight',\n", " 'model.backbones.0.0.body.layer4.0.downsample.1.weight',\n", " 'model.backbones.0.0.body.layer4.0.downsample.1.bias',\n", " 'model.backbones.0.0.body.layer4.0.downsample.1.running_mean',\n", " 'model.backbones.0.0.body.layer4.0.downsample.1.running_var',\n", " 'model.backbones.0.0.body.layer4.0.downsample.1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer4.1.conv1.weight',\n", " 'model.backbones.0.0.body.layer4.1.bn1.weight',\n", " 'model.backbones.0.0.body.layer4.1.bn1.bias',\n", " 'model.backbones.0.0.body.layer4.1.bn1.running_mean',\n", " 'model.backbones.0.0.body.layer4.1.bn1.running_var',\n", " 
'model.backbones.0.0.body.layer4.1.bn1.num_batches_tracked',\n", " 'model.backbones.0.0.body.layer4.1.conv2.weight',\n", " 'model.backbones.0.0.body.layer4.1.bn2.weight',\n", " 'model.backbones.0.0.body.layer4.1.bn2.bias',\n", " 'model.backbones.0.0.body.layer4.1.bn2.running_mean',\n", " 'model.backbones.0.0.body.layer4.1.bn2.running_var',\n", " 'model.backbones.0.0.body.layer4.1.bn2.num_batches_tracked',\n", " 'model.input_proj_robot_state.weight',\n", " 'model.input_proj_robot_state.bias',\n", " 'model.cls_embed.weight',\n", " 'model.encoder_action_proj.weight',\n", " 'model.encoder_action_proj.bias',\n", " 'model.encoder_joint_proj.weight',\n", " 'model.encoder_joint_proj.bias',\n", " 'model.latent_proj.weight',\n", " 'model.latent_proj.bias',\n", " 'model.latent_out_proj.weight',\n", " 'model.latent_out_proj.bias',\n", " 'model.additional_pos_embed.weight']\n" ] } ], "source": [ "orig = list(a.keys())\n", "pprint(orig)" ] },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert the original ACT checkpoint to LeRobot key names\n",
    "\n",
    "Every key of the original checkpoint is either dropped (decoder layers 1-6, the decoder norm, the `is_pad_head`, BatchNorm `num_batches_tracked` counters) or renamed to its LeRobot equivalent via a prefix-based rename table. The cell asserts that no key is left unmapped and that no two keys collide on the same LeRobot name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-load so the cell is idempotent: the loop below pops every key out of `a`.\n",
    "# map_location='cpu' lets a checkpoint saved from GPU tensors load on a CPU-only machine.\n",
    "# NOTE: torch.load unpickles the file, so only run it on trusted checkpoints.\n",
    "a = torch.load(original_ckpt_path, map_location='cpu')\n",
    "\n",
    "# Original keys that are discarded entirely; of the decoder layers only layer 0 is carried over.\n",
    "to_remove_startswith = (\n",
    "    'model.transformer.decoder.layers.1.',\n",
    "    'model.transformer.decoder.layers.2.',\n",
    "    'model.transformer.decoder.layers.3.',\n",
    "    'model.transformer.decoder.layers.4.',\n",
    "    'model.transformer.decoder.layers.5.',\n",
    "    'model.transformer.decoder.layers.6.',\n",
    "    'model.transformer.decoder.norm.',\n",
    "    'model.is_pad_head',\n",
    ")\n",
    "# Keys containing these substrings (BatchNorm step counters) are not copied either.\n",
    "to_remove_in = ('num_batches_tracked',)\n",
    "\n",
    "# (prefix of the original key, substring to replace, replacement).\n",
    "# The first matching prefix wins; str.replace rewrites every occurrence of the substring.\n",
    "rename_rules = [\n",
    "    ('model.transformer.encoder.layers.', 'transformer.', ''),\n",
    "    ('model.transformer.decoder.layers.0.', 'transformer.', ''),\n",
    "    ('model.encoder.layers.', 'encoder.', 'vae_encoder.'),\n",
    "    ('model.action_head.', 'action_head', 'action_head'),  # name kept as-is\n",
    "    ('model.pos_table', 'pos_table', 'vae_encoder_pos_enc'),\n",
    "    ('model.query_embed.', 'query_embed', 'decoder_pos_embed'),\n",
    "    ('model.input_proj.', 'input_proj.', 'encoder_img_feat_input_proj.'),\n",
    "    ('model.input_proj_robot_state.', 'input_proj_robot_state.', 'encoder_robot_state_input_proj.'),\n",
    "    ('model.backbones.0.0.body.', 'backbones.0.0.body', 'backbone'),\n",
    "    ('model.cls_embed.', 'cls_embed', 'vae_encoder_cls_embed'),\n",
    "    ('model.encoder_action_proj.', 'encoder_action_proj', 'vae_encoder_action_input_proj'),\n",
    "    ('model.encoder_joint_proj.', 'encoder_joint_proj', 'vae_encoder_robot_state_input_proj'),\n",
    "    ('model.latent_proj.', 'latent_proj', 'vae_encoder_latent_output_proj'),\n",
    "    ('model.latent_out_proj.', 'latent_out_proj', 'encoder_latent_input_proj'),\n",
    "    ('model.additional_pos_embed.', 'additional_pos_embed', 'encoder_robot_and_latent_pos_embed'),\n",
    "]\n",
    "\n",
    "conv = {}\n",
    "for k in list(a.keys()):\n",
    "    if k.startswith(to_remove_startswith) or any(tr in k for tr in to_remove_in):\n",
    "        a.pop(k)\n",
    "        continue\n",
    "    for prefix, old, new in rename_rules:\n",
    "        if k.startswith(prefix):\n",
    "            new_k = k.replace(old, new)\n",
    "            assert new_k not in conv, f'{k} maps onto already-converted key {new_k}'\n",
    "            conv[new_k] = a.pop(k)\n",
    "            break\n",
    "\n",
    "# Every key must have been dropped or renamed; leftovers mean a rename rule is missing.\n",
    "assert not a, f'unmapped keys in the original checkpoint: {list(a)}'"
   ]
  },
  { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "OrderedDict()" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a" ] },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overwrite the matching LeRobot tensors with the converted ACT weights.\n",
    "for k, v in conv.items():\n",
    "    assert k in b, f'{k} has no counterpart in the LeRobot checkpoint'\n",
    "    assert b[k].shape == v.shape, f'shape mismatch for {k}: {tuple(b[k].shape)} vs {tuple(v.shape)}'\n",
    "    b[k] = v\n",
    "\n",
    "# LeRobot keys that were NOT overwritten keep the values of the comparison checkpoint\n",
    "# (presumably LeRobot-only buffers such as normalization stats; review this list).\n",
    "sorted(set(b) - set(conv))"
   ]
  },
  { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "save_file(b, converted_ckpt_path)" ] },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the config files next to the converted weights, mirroring the layout of the\n",
    "# comparison `pretrained_model/` directory (model.safetensors + config.json + config.yaml).\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "converted_dir = os.path.dirname(converted_ckpt_path)\n",
    "for config_path in (comparison_config_json_path, comparison_config_yaml_path):\n",
    "    shutil.copy(config_path, os.path.join(converted_dir, os.path.basename(config_path)))"
   ]
  }
 ], "metadata": { "kernelspec": { "display_name": "lerobot", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }