diff --git a/examples/real_robot_example/convert_original_act_checkpoint.ipynb b/examples/real_robot_example/convert_original_act_checkpoint.ipynb new file mode 100644 index 00000000..92306b86 --- /dev/null +++ b/examples/real_robot_example/convert_original_act_checkpoint.ipynb @@ -0,0 +1,840 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from safetensors.torch import load_file, save_file\n", + "from pprint import pprint" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [], + "source": [ + "original_ckpt_path = \"/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/policy_last.ckpt\"\n", + "converted_ckpt_path = \"/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/model.safetensors\"\n", + "\n", + "comparison_main_path = \"/home/thomwolf/Documents/Github/lerobot/examples/real_robot_example/outputs/train/blue_red_debug_no_masking/checkpoints/last/pretrained_model/\"\n", + "comparison_safetensor_path = comparison_main_path + \"model.safetensors\"\n", + "comparison_config_json_path = comparison_main_path + \"config.json\"\n", + "comparison_config_yaml_path = comparison_main_path + \"config.yaml\"" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "a = torch.load(original_ckpt_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "b = load_file(comparison_safetensor_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['model.action_head.bias',\n", + " 'model.action_head.weight',\n", + " 'model.backbone.bn1.bias',\n", + " 'model.backbone.bn1.running_mean',\n", + " 'model.backbone.bn1.running_var',\n", + " 'model.backbone.bn1.weight',\n", + " 'model.backbone.conv1.weight',\n", + " 'model.backbone.layer1.0.bn1.bias',\n", + " 'model.backbone.layer1.0.bn1.running_mean',\n", + " 'model.backbone.layer1.0.bn1.running_var',\n", + " 'model.backbone.layer1.0.bn1.weight',\n", + " 'model.backbone.layer1.0.bn2.bias',\n", + " 'model.backbone.layer1.0.bn2.running_mean',\n", + " 'model.backbone.layer1.0.bn2.running_var',\n", + " 'model.backbone.layer1.0.bn2.weight',\n", + " 'model.backbone.layer1.0.conv1.weight',\n", + " 'model.backbone.layer1.0.conv2.weight',\n", + " 'model.backbone.layer1.1.bn1.bias',\n", + " 'model.backbone.layer1.1.bn1.running_mean',\n", + " 'model.backbone.layer1.1.bn1.running_var',\n", + " 'model.backbone.layer1.1.bn1.weight',\n", + " 'model.backbone.layer1.1.bn2.bias',\n", + " 'model.backbone.layer1.1.bn2.running_mean',\n", + " 'model.backbone.layer1.1.bn2.running_var',\n", + " 'model.backbone.layer1.1.bn2.weight',\n", + " 'model.backbone.layer1.1.conv1.weight',\n", + " 'model.backbone.layer1.1.conv2.weight',\n", + " 'model.backbone.layer2.0.bn1.bias',\n", + " 'model.backbone.layer2.0.bn1.running_mean',\n", + " 'model.backbone.layer2.0.bn1.running_var',\n", + " 'model.backbone.layer2.0.bn1.weight',\n", + " 'model.backbone.layer2.0.bn2.bias',\n", + " 'model.backbone.layer2.0.bn2.running_mean',\n", + " 'model.backbone.layer2.0.bn2.running_var',\n", + " 'model.backbone.layer2.0.bn2.weight',\n", + " 'model.backbone.layer2.0.conv1.weight',\n", + " 'model.backbone.layer2.0.conv2.weight',\n", + " 'model.backbone.layer2.0.downsample.0.weight',\n", + " 'model.backbone.layer2.0.downsample.1.bias',\n", + " 'model.backbone.layer2.0.downsample.1.running_mean',\n", + " 'model.backbone.layer2.0.downsample.1.running_var',\n", + " 'model.backbone.layer2.0.downsample.1.weight',\n", + " 'model.backbone.layer2.1.bn1.bias',\n", + " 'model.backbone.layer2.1.bn1.running_mean',\n", + " 'model.backbone.layer2.1.bn1.running_var',\n", + " 'model.backbone.layer2.1.bn1.weight',\n", + " 'model.backbone.layer2.1.bn2.bias',\n", + " 'model.backbone.layer2.1.bn2.running_mean',\n", + " 'model.backbone.layer2.1.bn2.running_var',\n", + " 'model.backbone.layer2.1.bn2.weight',\n", + " 'model.backbone.layer2.1.conv1.weight',\n", + " 'model.backbone.layer2.1.conv2.weight',\n", + " 'model.backbone.layer3.0.bn1.bias',\n", + " 'model.backbone.layer3.0.bn1.running_mean',\n", + " 'model.backbone.layer3.0.bn1.running_var',\n", + " 'model.backbone.layer3.0.bn1.weight',\n", + " 'model.backbone.layer3.0.bn2.bias',\n", + " 'model.backbone.layer3.0.bn2.running_mean',\n", + " 'model.backbone.layer3.0.bn2.running_var',\n", + " 'model.backbone.layer3.0.bn2.weight',\n", + " 'model.backbone.layer3.0.conv1.weight',\n", + " 'model.backbone.layer3.0.conv2.weight',\n", + " 'model.backbone.layer3.0.downsample.0.weight',\n", + " 'model.backbone.layer3.0.downsample.1.bias',\n", + " 'model.backbone.layer3.0.downsample.1.running_mean',\n", + " 'model.backbone.layer3.0.downsample.1.running_var',\n", + " 'model.backbone.layer3.0.downsample.1.weight',\n", + " 'model.backbone.layer3.1.bn1.bias',\n", + " 'model.backbone.layer3.1.bn1.running_mean',\n", + " 'model.backbone.layer3.1.bn1.running_var',\n", + " 'model.backbone.layer3.1.bn1.weight',\n", + " 'model.backbone.layer3.1.bn2.bias',\n", + " 'model.backbone.layer3.1.bn2.running_mean',\n", + " 'model.backbone.layer3.1.bn2.running_var',\n", + " 'model.backbone.layer3.1.bn2.weight',\n", + " 'model.backbone.layer3.1.conv1.weight',\n", + " 'model.backbone.layer3.1.conv2.weight',\n", + " 'model.backbone.layer4.0.bn1.bias',\n", + " 'model.backbone.layer4.0.bn1.running_mean',\n", + " 'model.backbone.layer4.0.bn1.running_var',\n", + " 'model.backbone.layer4.0.bn1.weight',\n", + " 'model.backbone.layer4.0.bn2.bias',\n", + " 'model.backbone.layer4.0.bn2.running_mean',\n", + " 'model.backbone.layer4.0.bn2.running_var',\n", + " 'model.backbone.layer4.0.bn2.weight',\n", + " 'model.backbone.layer4.0.conv1.weight',\n", + " 'model.backbone.layer4.0.conv2.weight',\n", + " 'model.backbone.layer4.0.downsample.0.weight',\n", + " 'model.backbone.layer4.0.downsample.1.bias',\n", + " 'model.backbone.layer4.0.downsample.1.running_mean',\n", + " 'model.backbone.layer4.0.downsample.1.running_var',\n", + " 'model.backbone.layer4.0.downsample.1.weight',\n", + " 'model.backbone.layer4.1.bn1.bias',\n", + " 'model.backbone.layer4.1.bn1.running_mean',\n", + " 'model.backbone.layer4.1.bn1.running_var',\n", + " 'model.backbone.layer4.1.bn1.weight',\n", + " 'model.backbone.layer4.1.bn2.bias',\n", + " 'model.backbone.layer4.1.bn2.running_mean',\n", + " 'model.backbone.layer4.1.bn2.running_var',\n", + " 'model.backbone.layer4.1.bn2.weight',\n", + " 'model.backbone.layer4.1.conv1.weight',\n", + " 'model.backbone.layer4.1.conv2.weight',\n", + " 'model.decoder.layers.0.linear1.bias',\n", + " 'model.decoder.layers.0.linear1.weight',\n", + " 'model.decoder.layers.0.linear2.bias',\n", + " 'model.decoder.layers.0.linear2.weight',\n", + " 'model.decoder.layers.0.multihead_attn.in_proj_bias',\n", + " 'model.decoder.layers.0.multihead_attn.in_proj_weight',\n", + " 'model.decoder.layers.0.multihead_attn.out_proj.bias',\n", + " 'model.decoder.layers.0.multihead_attn.out_proj.weight',\n", + " 'model.decoder.layers.0.norm1.bias',\n", + " 'model.decoder.layers.0.norm1.weight',\n", + " 'model.decoder.layers.0.norm2.bias',\n", + " 'model.decoder.layers.0.norm2.weight',\n", + " 'model.decoder.layers.0.norm3.bias',\n", + " 'model.decoder.layers.0.norm3.weight',\n", + " 'model.decoder.layers.0.self_attn.in_proj_bias',\n", + " 'model.decoder.layers.0.self_attn.in_proj_weight',\n", + " 'model.decoder.layers.0.self_attn.out_proj.bias',\n", + " 'model.decoder.layers.0.self_attn.out_proj.weight',\n", + " 'model.decoder_pos_embed.weight',\n", + " 'model.encoder.layers.0.linear1.bias',\n", + " 'model.encoder.layers.0.linear1.weight',\n", + " 'model.encoder.layers.0.linear2.bias',\n", + " 'model.encoder.layers.0.linear2.weight',\n", + " 'model.encoder.layers.0.norm1.bias',\n", + " 'model.encoder.layers.0.norm1.weight',\n", + " 'model.encoder.layers.0.norm2.bias',\n", + " 'model.encoder.layers.0.norm2.weight',\n", + " 'model.encoder.layers.0.self_attn.in_proj_bias',\n", + " 'model.encoder.layers.0.self_attn.in_proj_weight',\n", + " 'model.encoder.layers.0.self_attn.out_proj.bias',\n", + " 'model.encoder.layers.0.self_attn.out_proj.weight',\n", + " 'model.encoder.layers.1.linear1.bias',\n", + " 'model.encoder.layers.1.linear1.weight',\n", + " 'model.encoder.layers.1.linear2.bias',\n", + " 'model.encoder.layers.1.linear2.weight',\n", + " 'model.encoder.layers.1.norm1.bias',\n", + " 'model.encoder.layers.1.norm1.weight',\n", + " 'model.encoder.layers.1.norm2.bias',\n", + " 'model.encoder.layers.1.norm2.weight',\n", + " 'model.encoder.layers.1.self_attn.in_proj_bias',\n", + " 'model.encoder.layers.1.self_attn.in_proj_weight',\n", + " 'model.encoder.layers.1.self_attn.out_proj.bias',\n", + " 'model.encoder.layers.1.self_attn.out_proj.weight',\n", + " 'model.encoder.layers.2.linear1.bias',\n", + " 'model.encoder.layers.2.linear1.weight',\n", + " 'model.encoder.layers.2.linear2.bias',\n", + " 'model.encoder.layers.2.linear2.weight',\n", + " 'model.encoder.layers.2.norm1.bias',\n", + " 'model.encoder.layers.2.norm1.weight',\n", + " 'model.encoder.layers.2.norm2.bias',\n", + " 'model.encoder.layers.2.norm2.weight',\n", + " 'model.encoder.layers.2.self_attn.in_proj_bias',\n", + " 'model.encoder.layers.2.self_attn.in_proj_weight',\n", + " 'model.encoder.layers.2.self_attn.out_proj.bias',\n", + " 'model.encoder.layers.2.self_attn.out_proj.weight',\n", + " 'model.encoder.layers.3.linear1.bias',\n", + " 'model.encoder.layers.3.linear1.weight',\n", + " 'model.encoder.layers.3.linear2.bias',\n", + " 'model.encoder.layers.3.linear2.weight',\n", + " 'model.encoder.layers.3.norm1.bias',\n", + " 'model.encoder.layers.3.norm1.weight',\n", + " 'model.encoder.layers.3.norm2.bias',\n", + " 'model.encoder.layers.3.norm2.weight',\n", + " 'model.encoder.layers.3.self_attn.in_proj_bias',\n", + " 'model.encoder.layers.3.self_attn.in_proj_weight',\n", + " 'model.encoder.layers.3.self_attn.out_proj.bias',\n", + " 'model.encoder.layers.3.self_attn.out_proj.weight',\n", + " 'model.encoder_img_feat_input_proj.bias',\n", + " 'model.encoder_img_feat_input_proj.weight',\n", + " 'model.encoder_latent_input_proj.bias',\n", + " 'model.encoder_latent_input_proj.weight',\n", + " 'model.encoder_robot_and_latent_pos_embed.weight',\n", + " 'model.encoder_robot_state_input_proj.bias',\n", + " 'model.encoder_robot_state_input_proj.weight',\n", + " 'model.vae_encoder.layers.0.linear1.bias',\n", + " 'model.vae_encoder.layers.0.linear1.weight',\n", + " 'model.vae_encoder.layers.0.linear2.bias',\n", + " 'model.vae_encoder.layers.0.linear2.weight',\n", + " 'model.vae_encoder.layers.0.norm1.bias',\n", + " 'model.vae_encoder.layers.0.norm1.weight',\n", + " 'model.vae_encoder.layers.0.norm2.bias',\n", + " 'model.vae_encoder.layers.0.norm2.weight',\n", + " 'model.vae_encoder.layers.0.self_attn.in_proj_bias',\n", + " 'model.vae_encoder.layers.0.self_attn.in_proj_weight',\n", + " 'model.vae_encoder.layers.0.self_attn.out_proj.bias',\n", + " 'model.vae_encoder.layers.0.self_attn.out_proj.weight',\n", + " 'model.vae_encoder.layers.1.linear1.bias',\n", + " 'model.vae_encoder.layers.1.linear1.weight',\n", + " 'model.vae_encoder.layers.1.linear2.bias',\n", + " 'model.vae_encoder.layers.1.linear2.weight',\n", + " 'model.vae_encoder.layers.1.norm1.bias',\n", + " 'model.vae_encoder.layers.1.norm1.weight',\n", + " 'model.vae_encoder.layers.1.norm2.bias',\n", + " 'model.vae_encoder.layers.1.norm2.weight',\n", + " 'model.vae_encoder.layers.1.self_attn.in_proj_bias',\n", + " 'model.vae_encoder.layers.1.self_attn.in_proj_weight',\n", + " 'model.vae_encoder.layers.1.self_attn.out_proj.bias',\n", + " 'model.vae_encoder.layers.1.self_attn.out_proj.weight',\n", + " 'model.vae_encoder.layers.2.linear1.bias',\n", + " 'model.vae_encoder.layers.2.linear1.weight',\n", + " 'model.vae_encoder.layers.2.linear2.bias',\n", + " 'model.vae_encoder.layers.2.linear2.weight',\n", + " 'model.vae_encoder.layers.2.norm1.bias',\n", + " 'model.vae_encoder.layers.2.norm1.weight',\n", + " 'model.vae_encoder.layers.2.norm2.bias',\n", + " 'model.vae_encoder.layers.2.norm2.weight',\n", + " 'model.vae_encoder.layers.2.self_attn.in_proj_bias',\n", + " 'model.vae_encoder.layers.2.self_attn.in_proj_weight',\n", + " 'model.vae_encoder.layers.2.self_attn.out_proj.bias',\n", + " 'model.vae_encoder.layers.2.self_attn.out_proj.weight',\n", + " 'model.vae_encoder.layers.3.linear1.bias',\n", + " 'model.vae_encoder.layers.3.linear1.weight',\n", + " 'model.vae_encoder.layers.3.linear2.bias',\n", + " 'model.vae_encoder.layers.3.linear2.weight',\n", + " 'model.vae_encoder.layers.3.norm1.bias',\n", + " 'model.vae_encoder.layers.3.norm1.weight',\n", + " 'model.vae_encoder.layers.3.norm2.bias',\n", + " 'model.vae_encoder.layers.3.norm2.weight',\n", + " 'model.vae_encoder.layers.3.self_attn.in_proj_bias',\n", + " 'model.vae_encoder.layers.3.self_attn.in_proj_weight',\n", + " 'model.vae_encoder.layers.3.self_attn.out_proj.bias',\n", + " 'model.vae_encoder.layers.3.self_attn.out_proj.weight',\n", + " 'model.vae_encoder_action_input_proj.bias',\n", + " 'model.vae_encoder_action_input_proj.weight',\n", + " 'model.vae_encoder_cls_embed.weight',\n", + " 'model.vae_encoder_latent_output_proj.bias',\n", + " 'model.vae_encoder_latent_output_proj.weight',\n", + " 'model.vae_encoder_pos_enc',\n", + " 'model.vae_encoder_robot_state_input_proj.bias',\n", + " 'model.vae_encoder_robot_state_input_proj.weight',\n", + " 'normalize_inputs.buffer_observation_images_front.mean',\n", + " 'normalize_inputs.buffer_observation_images_front.std',\n", + " 'normalize_inputs.buffer_observation_images_top.mean',\n", + " 'normalize_inputs.buffer_observation_images_top.std',\n", + " 'normalize_inputs.buffer_observation_state.mean',\n", + " 'normalize_inputs.buffer_observation_state.std',\n", + " 'normalize_targets.buffer_action.mean',\n", + " 'normalize_targets.buffer_action.std',\n", + " 'unnormalize_outputs.buffer_action.mean',\n", + " 'unnormalize_outputs.buffer_action.std']\n" + ] + } + ], + "source": [ + "dest = list(b.keys())\n", + "pprint(dest)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['model.pos_table',\n", + " 'model.transformer.encoder.layers.0.self_attn.in_proj_weight',\n", + " 'model.transformer.encoder.layers.0.self_attn.in_proj_bias',\n", + " 'model.transformer.encoder.layers.0.self_attn.out_proj.weight',\n", + " 'model.transformer.encoder.layers.0.self_attn.out_proj.bias',\n", + " 'model.transformer.encoder.layers.0.linear1.weight',\n", + " 'model.transformer.encoder.layers.0.linear1.bias',\n", + " 'model.transformer.encoder.layers.0.linear2.weight',\n", + " 'model.transformer.encoder.layers.0.linear2.bias',\n", + " 'model.transformer.encoder.layers.0.norm1.weight',\n", + " 'model.transformer.encoder.layers.0.norm1.bias',\n", + " 'model.transformer.encoder.layers.0.norm2.weight',\n", + " 'model.transformer.encoder.layers.0.norm2.bias',\n", + " 'model.transformer.encoder.layers.1.self_attn.in_proj_weight',\n", + " 'model.transformer.encoder.layers.1.self_attn.in_proj_bias',\n", + " 'model.transformer.encoder.layers.1.self_attn.out_proj.weight',\n", + " 'model.transformer.encoder.layers.1.self_attn.out_proj.bias',\n", + " 'model.transformer.encoder.layers.1.linear1.weight',\n", + " 'model.transformer.encoder.layers.1.linear1.bias',\n", + " 'model.transformer.encoder.layers.1.linear2.weight',\n", + " 'model.transformer.encoder.layers.1.linear2.bias',\n", + " 'model.transformer.encoder.layers.1.norm1.weight',\n", + " 'model.transformer.encoder.layers.1.norm1.bias',\n", + " 'model.transformer.encoder.layers.1.norm2.weight',\n", + " 'model.transformer.encoder.layers.1.norm2.bias',\n", + " 'model.transformer.encoder.layers.2.self_attn.in_proj_weight',\n", + " 'model.transformer.encoder.layers.2.self_attn.in_proj_bias',\n", + " 'model.transformer.encoder.layers.2.self_attn.out_proj.weight',\n", + " 'model.transformer.encoder.layers.2.self_attn.out_proj.bias',\n", + " 'model.transformer.encoder.layers.2.linear1.weight',\n", + " 'model.transformer.encoder.layers.2.linear1.bias',\n", + " 'model.transformer.encoder.layers.2.linear2.weight',\n", + " 'model.transformer.encoder.layers.2.linear2.bias',\n", + " 'model.transformer.encoder.layers.2.norm1.weight',\n", + " 'model.transformer.encoder.layers.2.norm1.bias',\n", + " 'model.transformer.encoder.layers.2.norm2.weight',\n", + " 'model.transformer.encoder.layers.2.norm2.bias',\n", + " 'model.transformer.encoder.layers.3.self_attn.in_proj_weight',\n", + " 'model.transformer.encoder.layers.3.self_attn.in_proj_bias',\n", + " 'model.transformer.encoder.layers.3.self_attn.out_proj.weight',\n", + " 'model.transformer.encoder.layers.3.self_attn.out_proj.bias',\n", + " 'model.transformer.encoder.layers.3.linear1.weight',\n", + " 'model.transformer.encoder.layers.3.linear1.bias',\n", + " 'model.transformer.encoder.layers.3.linear2.weight',\n", + " 'model.transformer.encoder.layers.3.linear2.bias',\n", + " 'model.transformer.encoder.layers.3.norm1.weight',\n", + " 'model.transformer.encoder.layers.3.norm1.bias',\n", + " 'model.transformer.encoder.layers.3.norm2.weight',\n", + " 'model.transformer.encoder.layers.3.norm2.bias',\n", + " 'model.transformer.decoder.layers.0.self_attn.in_proj_weight',\n", + " 'model.transformer.decoder.layers.0.self_attn.in_proj_bias',\n", + " 'model.transformer.decoder.layers.0.self_attn.out_proj.weight',\n", + " 'model.transformer.decoder.layers.0.self_attn.out_proj.bias',\n", + " 'model.transformer.decoder.layers.0.multihead_attn.in_proj_weight',\n", + " 'model.transformer.decoder.layers.0.multihead_attn.in_proj_bias',\n", + " 'model.transformer.decoder.layers.0.multihead_attn.out_proj.weight',\n", + " 'model.transformer.decoder.layers.0.multihead_attn.out_proj.bias',\n", + " 'model.transformer.decoder.layers.0.linear1.weight',\n", + " 'model.transformer.decoder.layers.0.linear1.bias',\n", + " 'model.transformer.decoder.layers.0.linear2.weight',\n", + " 'model.transformer.decoder.layers.0.linear2.bias',\n", + " 'model.transformer.decoder.layers.0.norm1.weight',\n", + " 'model.transformer.decoder.layers.0.norm1.bias',\n", + " 'model.transformer.decoder.layers.0.norm2.weight',\n", + " 'model.transformer.decoder.layers.0.norm2.bias',\n", + " 'model.transformer.decoder.layers.0.norm3.weight',\n", + " 'model.transformer.decoder.layers.0.norm3.bias',\n", + " 'model.transformer.decoder.layers.1.self_attn.in_proj_weight',\n", + " 'model.transformer.decoder.layers.1.self_attn.in_proj_bias',\n", + " 'model.transformer.decoder.layers.1.self_attn.out_proj.weight',\n", + " 'model.transformer.decoder.layers.1.self_attn.out_proj.bias',\n", + " 'model.transformer.decoder.layers.1.multihead_attn.in_proj_weight',\n", + " 'model.transformer.decoder.layers.1.multihead_attn.in_proj_bias',\n", + " 'model.transformer.decoder.layers.1.multihead_attn.out_proj.weight',\n", + " 'model.transformer.decoder.layers.1.multihead_attn.out_proj.bias',\n", + " 'model.transformer.decoder.layers.1.linear1.weight',\n", + " 'model.transformer.decoder.layers.1.linear1.bias',\n", + " 'model.transformer.decoder.layers.1.linear2.weight',\n", + " 'model.transformer.decoder.layers.1.linear2.bias',\n", + " 'model.transformer.decoder.layers.1.norm1.weight',\n", + " 'model.transformer.decoder.layers.1.norm1.bias',\n", + " 'model.transformer.decoder.layers.1.norm2.weight',\n", + " 'model.transformer.decoder.layers.1.norm2.bias',\n", + " 'model.transformer.decoder.layers.1.norm3.weight',\n", + " 'model.transformer.decoder.layers.1.norm3.bias',\n", + " 'model.transformer.decoder.layers.2.self_attn.in_proj_weight',\n", + " 'model.transformer.decoder.layers.2.self_attn.in_proj_bias',\n", + " 'model.transformer.decoder.layers.2.self_attn.out_proj.weight',\n", + " 'model.transformer.decoder.layers.2.self_attn.out_proj.bias',\n", + " 'model.transformer.decoder.layers.2.multihead_attn.in_proj_weight',\n", + " 'model.transformer.decoder.layers.2.multihead_attn.in_proj_bias',\n", + " 'model.transformer.decoder.layers.2.multihead_attn.out_proj.weight',\n", + " 'model.transformer.decoder.layers.2.multihead_attn.out_proj.bias',\n", + " 'model.transformer.decoder.layers.2.linear1.weight',\n", + " 'model.transformer.decoder.layers.2.linear1.bias',\n", + " 'model.transformer.decoder.layers.2.linear2.weight',\n", + " 'model.transformer.decoder.layers.2.linear2.bias',\n", + " 'model.transformer.decoder.layers.2.norm1.weight',\n", + " 'model.transformer.decoder.layers.2.norm1.bias',\n", + " 'model.transformer.decoder.layers.2.norm2.weight',\n", + " 'model.transformer.decoder.layers.2.norm2.bias',\n", + " 'model.transformer.decoder.layers.2.norm3.weight',\n", + " 'model.transformer.decoder.layers.2.norm3.bias',\n", + " 'model.transformer.decoder.layers.3.self_attn.in_proj_weight',\n", + " 'model.transformer.decoder.layers.3.self_attn.in_proj_bias',\n", + " 'model.transformer.decoder.layers.3.self_attn.out_proj.weight',\n", + " 'model.transformer.decoder.layers.3.self_attn.out_proj.bias',\n", + " 'model.transformer.decoder.layers.3.multihead_attn.in_proj_weight',\n", + " 'model.transformer.decoder.layers.3.multihead_attn.in_proj_bias',\n", + " 'model.transformer.decoder.layers.3.multihead_attn.out_proj.weight',\n", + " 'model.transformer.decoder.layers.3.multihead_attn.out_proj.bias',\n", + " 'model.transformer.decoder.layers.3.linear1.weight',\n", + " 'model.transformer.decoder.layers.3.linear1.bias',\n", + " 'model.transformer.decoder.layers.3.linear2.weight',\n", + " 'model.transformer.decoder.layers.3.linear2.bias',\n", + " 'model.transformer.decoder.layers.3.norm1.weight',\n", + " 'model.transformer.decoder.layers.3.norm1.bias',\n", + " 'model.transformer.decoder.layers.3.norm2.weight',\n", + " 'model.transformer.decoder.layers.3.norm2.bias',\n", + " 'model.transformer.decoder.layers.3.norm3.weight',\n", + " 'model.transformer.decoder.layers.3.norm3.bias',\n", + " 'model.transformer.decoder.layers.4.self_attn.in_proj_weight',\n", + " 'model.transformer.decoder.layers.4.self_attn.in_proj_bias',\n", + " 'model.transformer.decoder.layers.4.self_attn.out_proj.weight',\n", + " 'model.transformer.decoder.layers.4.self_attn.out_proj.bias',\n", + " 'model.transformer.decoder.layers.4.multihead_attn.in_proj_weight',\n", + " 'model.transformer.decoder.layers.4.multihead_attn.in_proj_bias',\n", + " 'model.transformer.decoder.layers.4.multihead_attn.out_proj.weight',\n", + " 'model.transformer.decoder.layers.4.multihead_attn.out_proj.bias',\n", + " 'model.transformer.decoder.layers.4.linear1.weight',\n", + " 'model.transformer.decoder.layers.4.linear1.bias',\n", + " 'model.transformer.decoder.layers.4.linear2.weight',\n", + " 'model.transformer.decoder.layers.4.linear2.bias',\n", + " 'model.transformer.decoder.layers.4.norm1.weight',\n", + " 'model.transformer.decoder.layers.4.norm1.bias',\n", + " 'model.transformer.decoder.layers.4.norm2.weight',\n", + " 'model.transformer.decoder.layers.4.norm2.bias',\n", + " 'model.transformer.decoder.layers.4.norm3.weight',\n", + " 'model.transformer.decoder.layers.4.norm3.bias',\n", + " 'model.transformer.decoder.layers.5.self_attn.in_proj_weight',\n", + " 'model.transformer.decoder.layers.5.self_attn.in_proj_bias',\n", + " 'model.transformer.decoder.layers.5.self_attn.out_proj.weight',\n", + " 'model.transformer.decoder.layers.5.self_attn.out_proj.bias',\n", + " 'model.transformer.decoder.layers.5.multihead_attn.in_proj_weight',\n", + " 'model.transformer.decoder.layers.5.multihead_attn.in_proj_bias',\n", + " 'model.transformer.decoder.layers.5.multihead_attn.out_proj.weight',\n", + " 'model.transformer.decoder.layers.5.multihead_attn.out_proj.bias',\n", + " 'model.transformer.decoder.layers.5.linear1.weight',\n", + " 'model.transformer.decoder.layers.5.linear1.bias',\n", + " 'model.transformer.decoder.layers.5.linear2.weight',\n", + " 'model.transformer.decoder.layers.5.linear2.bias',\n", + " 'model.transformer.decoder.layers.5.norm1.weight',\n", + " 'model.transformer.decoder.layers.5.norm1.bias',\n", + " 'model.transformer.decoder.layers.5.norm2.weight',\n", + " 'model.transformer.decoder.layers.5.norm2.bias',\n", + " 'model.transformer.decoder.layers.5.norm3.weight',\n", + " 'model.transformer.decoder.layers.5.norm3.bias',\n", + " 'model.transformer.decoder.layers.6.self_attn.in_proj_weight',\n", + " 'model.transformer.decoder.layers.6.self_attn.in_proj_bias',\n", + " 'model.transformer.decoder.layers.6.self_attn.out_proj.weight',\n", + " 'model.transformer.decoder.layers.6.self_attn.out_proj.bias',\n", + " 'model.transformer.decoder.layers.6.multihead_attn.in_proj_weight',\n", + " 'model.transformer.decoder.layers.6.multihead_attn.in_proj_bias',\n", + " 'model.transformer.decoder.layers.6.multihead_attn.out_proj.weight',\n", + " 'model.transformer.decoder.layers.6.multihead_attn.out_proj.bias',\n", + " 'model.transformer.decoder.layers.6.linear1.weight',\n", + " 'model.transformer.decoder.layers.6.linear1.bias',\n", + " 'model.transformer.decoder.layers.6.linear2.weight',\n", + " 'model.transformer.decoder.layers.6.linear2.bias',\n", + " 'model.transformer.decoder.layers.6.norm1.weight',\n", + " 'model.transformer.decoder.layers.6.norm1.bias',\n", + " 'model.transformer.decoder.layers.6.norm2.weight',\n", + " 'model.transformer.decoder.layers.6.norm2.bias',\n", + " 'model.transformer.decoder.layers.6.norm3.weight',\n", + " 'model.transformer.decoder.layers.6.norm3.bias',\n", + " 'model.transformer.decoder.norm.weight',\n", + " 'model.transformer.decoder.norm.bias',\n", + " 'model.encoder.layers.0.self_attn.in_proj_weight',\n", + " 'model.encoder.layers.0.self_attn.in_proj_bias',\n", + " 'model.encoder.layers.0.self_attn.out_proj.weight',\n", + " 'model.encoder.layers.0.self_attn.out_proj.bias',\n", + " 'model.encoder.layers.0.linear1.weight',\n", + " 'model.encoder.layers.0.linear1.bias',\n", + " 'model.encoder.layers.0.linear2.weight',\n", + " 'model.encoder.layers.0.linear2.bias',\n", + " 'model.encoder.layers.0.norm1.weight',\n", + " 'model.encoder.layers.0.norm1.bias',\n", + " 'model.encoder.layers.0.norm2.weight',\n", + " 'model.encoder.layers.0.norm2.bias',\n", + " 'model.encoder.layers.1.self_attn.in_proj_weight',\n", + " 'model.encoder.layers.1.self_attn.in_proj_bias',\n", + " 'model.encoder.layers.1.self_attn.out_proj.weight',\n", + " 'model.encoder.layers.1.self_attn.out_proj.bias',\n", + " 'model.encoder.layers.1.linear1.weight',\n", + " 'model.encoder.layers.1.linear1.bias',\n", + " 'model.encoder.layers.1.linear2.weight',\n", + " 'model.encoder.layers.1.linear2.bias',\n", + " 'model.encoder.layers.1.norm1.weight',\n", + " 'model.encoder.layers.1.norm1.bias',\n", + " 'model.encoder.layers.1.norm2.weight',\n", + " 'model.encoder.layers.1.norm2.bias',\n", + " 'model.encoder.layers.2.self_attn.in_proj_weight',\n", + " 'model.encoder.layers.2.self_attn.in_proj_bias',\n", + " 'model.encoder.layers.2.self_attn.out_proj.weight',\n", + " 'model.encoder.layers.2.self_attn.out_proj.bias',\n", + " 'model.encoder.layers.2.linear1.weight',\n", + " 'model.encoder.layers.2.linear1.bias',\n", + " 'model.encoder.layers.2.linear2.weight',\n", + " 'model.encoder.layers.2.linear2.bias',\n", + " 'model.encoder.layers.2.norm1.weight',\n", + " 'model.encoder.layers.2.norm1.bias',\n", + " 'model.encoder.layers.2.norm2.weight',\n", + " 'model.encoder.layers.2.norm2.bias',\n", + " 'model.encoder.layers.3.self_attn.in_proj_weight',\n", + " 'model.encoder.layers.3.self_attn.in_proj_bias',\n", + " 'model.encoder.layers.3.self_attn.out_proj.weight',\n", + " 'model.encoder.layers.3.self_attn.out_proj.bias',\n", + " 'model.encoder.layers.3.linear1.weight',\n", + " 'model.encoder.layers.3.linear1.bias',\n", + " 'model.encoder.layers.3.linear2.weight',\n", + " 'model.encoder.layers.3.linear2.bias',\n", + " 'model.encoder.layers.3.norm1.weight',\n", + " 'model.encoder.layers.3.norm1.bias',\n", + " 'model.encoder.layers.3.norm2.weight',\n", + " 'model.encoder.layers.3.norm2.bias',\n", + " 'model.action_head.weight',\n", + " 'model.action_head.bias',\n", + " 'model.is_pad_head.weight',\n", + " 'model.is_pad_head.bias',\n", + " 'model.query_embed.weight',\n", + " 'model.input_proj.weight',\n", + " 'model.input_proj.bias',\n", + " 'model.backbones.0.0.body.conv1.weight',\n", + " 'model.backbones.0.0.body.bn1.weight',\n", + " 'model.backbones.0.0.body.bn1.bias',\n", + " 'model.backbones.0.0.body.bn1.running_mean',\n", + " 'model.backbones.0.0.body.bn1.running_var',\n", + " 'model.backbones.0.0.body.bn1.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer1.0.conv1.weight',\n", + " 'model.backbones.0.0.body.layer1.0.bn1.weight',\n", + " 'model.backbones.0.0.body.layer1.0.bn1.bias',\n", + " 'model.backbones.0.0.body.layer1.0.bn1.running_mean',\n", + " 'model.backbones.0.0.body.layer1.0.bn1.running_var',\n", + " 'model.backbones.0.0.body.layer1.0.bn1.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer1.0.conv2.weight',\n", + " 'model.backbones.0.0.body.layer1.0.bn2.weight',\n", + " 'model.backbones.0.0.body.layer1.0.bn2.bias',\n", + " 'model.backbones.0.0.body.layer1.0.bn2.running_mean',\n", + " 'model.backbones.0.0.body.layer1.0.bn2.running_var',\n", + " 'model.backbones.0.0.body.layer1.0.bn2.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer1.1.conv1.weight',\n", + " 'model.backbones.0.0.body.layer1.1.bn1.weight',\n", + " 'model.backbones.0.0.body.layer1.1.bn1.bias',\n", + " 'model.backbones.0.0.body.layer1.1.bn1.running_mean',\n", + " 'model.backbones.0.0.body.layer1.1.bn1.running_var',\n", + " 'model.backbones.0.0.body.layer1.1.bn1.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer1.1.conv2.weight',\n", + " 'model.backbones.0.0.body.layer1.1.bn2.weight',\n", + " 'model.backbones.0.0.body.layer1.1.bn2.bias',\n", + " 'model.backbones.0.0.body.layer1.1.bn2.running_mean',\n", + " 'model.backbones.0.0.body.layer1.1.bn2.running_var',\n", + " 'model.backbones.0.0.body.layer1.1.bn2.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer2.0.conv1.weight',\n", + " 'model.backbones.0.0.body.layer2.0.bn1.weight',\n", + " 'model.backbones.0.0.body.layer2.0.bn1.bias',\n", + " 'model.backbones.0.0.body.layer2.0.bn1.running_mean',\n", + " 'model.backbones.0.0.body.layer2.0.bn1.running_var',\n", + " 'model.backbones.0.0.body.layer2.0.bn1.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer2.0.conv2.weight',\n", + " 'model.backbones.0.0.body.layer2.0.bn2.weight',\n", + " 'model.backbones.0.0.body.layer2.0.bn2.bias',\n", + " 'model.backbones.0.0.body.layer2.0.bn2.running_mean',\n", + " 'model.backbones.0.0.body.layer2.0.bn2.running_var',\n", + " 'model.backbones.0.0.body.layer2.0.bn2.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer2.0.downsample.0.weight',\n", + " 'model.backbones.0.0.body.layer2.0.downsample.1.weight',\n", + " 'model.backbones.0.0.body.layer2.0.downsample.1.bias',\n", + " 'model.backbones.0.0.body.layer2.0.downsample.1.running_mean',\n", + " 'model.backbones.0.0.body.layer2.0.downsample.1.running_var',\n", + " 'model.backbones.0.0.body.layer2.0.downsample.1.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer2.1.conv1.weight',\n", + " 'model.backbones.0.0.body.layer2.1.bn1.weight',\n", + " 'model.backbones.0.0.body.layer2.1.bn1.bias',\n", + " 'model.backbones.0.0.body.layer2.1.bn1.running_mean',\n", + " 'model.backbones.0.0.body.layer2.1.bn1.running_var',\n", + " 'model.backbones.0.0.body.layer2.1.bn1.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer2.1.conv2.weight',\n", + " 'model.backbones.0.0.body.layer2.1.bn2.weight',\n", + " 'model.backbones.0.0.body.layer2.1.bn2.bias',\n", + " 'model.backbones.0.0.body.layer2.1.bn2.running_mean',\n", + " 'model.backbones.0.0.body.layer2.1.bn2.running_var',\n", + " 'model.backbones.0.0.body.layer2.1.bn2.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer3.0.conv1.weight',\n", + " 'model.backbones.0.0.body.layer3.0.bn1.weight',\n", + " 'model.backbones.0.0.body.layer3.0.bn1.bias',\n", + " 'model.backbones.0.0.body.layer3.0.bn1.running_mean',\n", + " 'model.backbones.0.0.body.layer3.0.bn1.running_var',\n", + " 'model.backbones.0.0.body.layer3.0.bn1.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer3.0.conv2.weight',\n", + " 'model.backbones.0.0.body.layer3.0.bn2.weight',\n", + " 'model.backbones.0.0.body.layer3.0.bn2.bias',\n", + " 'model.backbones.0.0.body.layer3.0.bn2.running_mean',\n", + " 'model.backbones.0.0.body.layer3.0.bn2.running_var',\n", + " 'model.backbones.0.0.body.layer3.0.bn2.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer3.0.downsample.0.weight',\n", + " 'model.backbones.0.0.body.layer3.0.downsample.1.weight',\n", + " 'model.backbones.0.0.body.layer3.0.downsample.1.bias',\n", + " 'model.backbones.0.0.body.layer3.0.downsample.1.running_mean',\n", + " 'model.backbones.0.0.body.layer3.0.downsample.1.running_var',\n", + " 'model.backbones.0.0.body.layer3.0.downsample.1.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer3.1.conv1.weight',\n", + " 'model.backbones.0.0.body.layer3.1.bn1.weight',\n", + " 'model.backbones.0.0.body.layer3.1.bn1.bias',\n", + " 'model.backbones.0.0.body.layer3.1.bn1.running_mean',\n", + " 'model.backbones.0.0.body.layer3.1.bn1.running_var',\n", + " 'model.backbones.0.0.body.layer3.1.bn1.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer3.1.conv2.weight',\n", + " 'model.backbones.0.0.body.layer3.1.bn2.weight',\n", + " 'model.backbones.0.0.body.layer3.1.bn2.bias',\n", + " 'model.backbones.0.0.body.layer3.1.bn2.running_mean',\n", + " 'model.backbones.0.0.body.layer3.1.bn2.running_var',\n", + " 'model.backbones.0.0.body.layer3.1.bn2.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer4.0.conv1.weight',\n", + " 'model.backbones.0.0.body.layer4.0.bn1.weight',\n", + " 'model.backbones.0.0.body.layer4.0.bn1.bias',\n", + " 'model.backbones.0.0.body.layer4.0.bn1.running_mean',\n", + " 'model.backbones.0.0.body.layer4.0.bn1.running_var',\n", + " 'model.backbones.0.0.body.layer4.0.bn1.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer4.0.conv2.weight',\n", + " 'model.backbones.0.0.body.layer4.0.bn2.weight',\n", + " 'model.backbones.0.0.body.layer4.0.bn2.bias',\n", + " 'model.backbones.0.0.body.layer4.0.bn2.running_mean',\n", + " 'model.backbones.0.0.body.layer4.0.bn2.running_var',\n", + " 'model.backbones.0.0.body.layer4.0.bn2.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer4.0.downsample.0.weight',\n", + " 'model.backbones.0.0.body.layer4.0.downsample.1.weight',\n", + " 'model.backbones.0.0.body.layer4.0.downsample.1.bias',\n", + " 'model.backbones.0.0.body.layer4.0.downsample.1.running_mean',\n", + " 'model.backbones.0.0.body.layer4.0.downsample.1.running_var',\n", + " 'model.backbones.0.0.body.layer4.0.downsample.1.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer4.1.conv1.weight',\n", + " 'model.backbones.0.0.body.layer4.1.bn1.weight',\n", + " 'model.backbones.0.0.body.layer4.1.bn1.bias',\n", + " 'model.backbones.0.0.body.layer4.1.bn1.running_mean',\n", + " 'model.backbones.0.0.body.layer4.1.bn1.running_var',\n", + " 'model.backbones.0.0.body.layer4.1.bn1.num_batches_tracked',\n", + " 'model.backbones.0.0.body.layer4.1.conv2.weight',\n", + " 'model.backbones.0.0.body.layer4.1.bn2.weight',\n", + " 'model.backbones.0.0.body.layer4.1.bn2.bias',\n", + " 'model.backbones.0.0.body.layer4.1.bn2.running_mean',\n", + " 'model.backbones.0.0.body.layer4.1.bn2.running_var',\n", + " 'model.backbones.0.0.body.layer4.1.bn2.num_batches_tracked',\n", + " 'model.input_proj_robot_state.weight',\n", + " 'model.input_proj_robot_state.bias',\n", + " 'model.cls_embed.weight',\n", + " 'model.encoder_action_proj.weight',\n", + " 'model.encoder_action_proj.bias',\n", + " 'model.encoder_joint_proj.weight',\n", + " 'model.encoder_joint_proj.bias',\n", + " 'model.latent_proj.weight',\n", + " 'model.latent_proj.bias',\n", + " 'model.latent_out_proj.weight',\n", + " 'model.latent_out_proj.bias',\n", + " 'model.additional_pos_embed.weight']\n" + ] + } + ], + "source": [ + "orig = list(a.keys())\n", + "pprint(orig)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "a = torch.load(original_ckpt_path)\n", + "\n", + "to_remove_startswith = ['model.transformer.decoder.layers.1.',\n", + " 'model.transformer.decoder.layers.2.',\n", + " 'model.transformer.decoder.layers.3.',\n", + " 'model.transformer.decoder.layers.4.',\n", + " 'model.transformer.decoder.layers.5.',\n", + " 'model.transformer.decoder.layers.6.',\n", + " 'model.transformer.decoder.norm.',\n", + " 'model.is_pad_head']\n", + "\n", + "to_remove_in = ['num_batches_tracked',]\n", + "\n", + "conv = {}\n", + "\n", + "keys = list(a.keys())\n", + "for k in keys:\n", + " if any(k.startswith(tr) for tr in to_remove_startswith):\n", + " a.pop(k)\n", + " continue\n", + " if any(tr in k for tr in to_remove_in):\n", + " a.pop(k)\n", + " continue\n", + " if k.startswith('model.transformer.encoder.layers.'):\n", + " conv[k.replace('transformer.', '')] = a.pop(k)\n", + " if k.startswith('model.transformer.decoder.layers.0.'):\n", + " conv[k.replace('transformer.', '')] = a.pop(k)\n", + " if k.startswith('model.encoder.layers.'):\n", + " conv[k.replace('encoder.', 'vae_encoder.')] = a.pop(k)\n", + " if k.startswith('model.action_head.'):\n", + " conv[k] = a.pop(k)\n", + " if k.startswith('model.pos_table'):\n", + " conv[k.replace('pos_table', 'vae_encoder_pos_enc')] = a.pop(k)\n", + " if k.startswith('model.query_embed.'):\n", + " conv[k.replace('query_embed', 'decoder_pos_embed')] = a.pop(k)\n", + " if k.startswith('model.input_proj.'):\n", + " conv[k.replace('input_proj.', 'encoder_img_feat_input_proj.')] = a.pop(k)\n", + " if k.startswith('model.input_proj_robot_state.'):\n", + " conv[k.replace('input_proj_robot_state.', 'encoder_robot_state_input_proj.')] = a.pop(k)\n", + " if k.startswith('model.backbones.0.0.body.'):\n", + " conv[k.replace('backbones.0.0.body', 'backbone')] = a.pop(k)\n", + " if k.startswith('model.cls_embed.'):\n", + " conv[k.replace('cls_embed', 'vae_encoder_cls_embed')] = a.pop(k)\n", + " if k.startswith('model.encoder_action_proj.'):\n", + " conv[k.replace('encoder_action_proj', 'vae_encoder_action_input_proj')] = a.pop(k)\n", + " if k.startswith('model.encoder_joint_proj.'):\n", + " conv[k.replace('encoder_joint_proj', 'vae_encoder_robot_state_input_proj')] = a.pop(k)\n", + " if k.startswith('model.latent_proj.'):\n", + " conv[k.replace('latent_proj', 'vae_encoder_latent_output_proj')] = a.pop(k)\n", + " if k.startswith('model.latent_out_proj.'):\n", + " conv[k.replace('latent_out_proj', 'encoder_latent_input_proj')] = a.pop(k)\n", + " if k.startswith('model.additional_pos_embed.'):\n", + " conv[k.replace('additional_pos_embed', 'encoder_robot_and_latent_pos_embed')] = a.pop(k)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "OrderedDict()" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [], + "source": [ + "for k, v in conv.items():\n", + " assert b[k].shape == v.shape\n", + " b[k] = v" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [], + "source": [ + "save_file(b, converted_ckpt_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/config.yaml'" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Now also copy the config files\n", + "import shutil\n", + "shutil.copy(comparison_config_json_path, converted_ckpt_path.replace('model.safetensors', 'config.json'))\n", + "shutil.copy(comparison_config_yaml_path, converted_ckpt_path.replace('model.safetensors', 'config.yaml'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "lerobot", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}