Go2Py/examples/08-ActuatorNet.ipynb

530 lines
164 KiB
Plaintext
Raw Normal View History

2024-11-05 06:27:32 +08:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pygame 2.6.1 (SDL 2.28.4, Python 3.10.12)\n",
"Hello from the pygame community. https://www.pygame.org/contribute.html\n"
]
}
],
"source": [
"from Go2Py.robot.interface import GO2Real"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from Go2Py.robot.model import Go2Model"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"model = Go2Model()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"robot=GO2Real()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'q': [0.3250521421432495,\n",
" 0.40270543098449707,\n",
" -1.4649406671524048,\n",
" -0.44037926197052,\n",
" 0.35583579540252686,\n",
" -1.469537615776062,\n",
" 0.41242796182632446,\n",
" 0.29579007625579834,\n",
" -1.4764724969863892,\n",
" -0.38193845748901367,\n",
" 0.25427955389022827,\n",
" -1.4649723768234253],\n",
" 'dq': [0.007751047611236572,\n",
" 0.01937761902809143,\n",
" -0.014154086820781231,\n",
" 0.034879714250564575,\n",
" 0.04263076186180115,\n",
" 0.008088049478828907,\n",
" -0.034879714250564575,\n",
" -0.06200838088989258,\n",
" 0.024264149367809296,\n",
" 0.034879714250564575,\n",
" -0.003875523805618286,\n",
" -0.03033018670976162],\n",
" 'tau_est': [-0.14842969179153442,\n",
" 0.12369140982627869,\n",
" -0.1422451138496399,\n",
" -0.22264453768730164,\n",
" -0.12369140982627869,\n",
" 0.3319052755832672,\n",
" 0.049476563930511475,\n",
" 0.0,\n",
" -0.09483008086681366,\n",
" 0.049476563930511475,\n",
" -0.024738281965255737,\n",
" 0.0],\n",
" 'temperature': [27.0,\n",
" 24.0,\n",
" 26.0,\n",
" 26.0,\n",
" 25.0,\n",
" 26.0,\n",
" 27.0,\n",
" 25.0,\n",
" 26.0,\n",
" 27.0,\n",
" 24.0,\n",
" 25.0]}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"state = robot.getJointStates()\n",
"state"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.spatial.transform import Rotation as R\n",
"\n",
"# Base pose for the model: orientation from the IMU quaternion, zero translation.\n",
"imu = robot.getIMU()\n",
"guat = imu['quat']\n",
"Rot_mat = R.from_quat(guat).as_matrix()\n",
"T = np.eye(4)\n",
"T[:3, :3] = Rot_mat\n",
"\n",
"# Joint positions / velocities from the `state` read in the previous cell.\n",
"q, dq = np.array(state['q']), np.array(state['dq'])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['q', 'dq', 'tau_est', 'temperature'])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"state.keys()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"model.updateAllPose(q, dq, T, np.zeros(6))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([-0.68048197, 0.24158602, -0.19297757, 0.58341235, 0.2037472 ,\n",
" -0.19295897, -0.58055523, 0.15918967, -0.20132737, 0.64800928,\n",
" 0.12985214, -0.21048951])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.getInfo()['g'][6:]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def getGravComp(model, state, imu=None):\n",
"    \"\"\"Return the gravity-compensation torques of the 12 actuated joints.\n",
"\n",
"    model: Go2Model; its pose is updated in place via updateAllPose.\n",
"    state: dict from robot.getJointStates(), providing 'q' and 'dq'.\n",
"    imu:   dict from robot.getIMU(), providing 'quat' (base orientation).\n",
"           Pass the sample taken together with `state` so both are consistent;\n",
"           if None, a fresh sample is read from the global `robot`.\n",
"    Returns model.getInfo()['g'][6:], i.e. the 'g' vector without the\n",
"    6 floating-base entries.\n",
"    \"\"\"\n",
"    if imu is None:\n",
"        imu = robot.getIMU()\n",
"    # NOTE(review): scipy's Rotation.from_quat expects scalar-last (x, y, z, w);\n",
"    # confirm that imu['quat'] follows the same convention.\n",
"    T = np.eye(4)\n",
"    T[:3, :3] = R.from_quat(imu['quat']).as_matrix()\n",
"\n",
"    q = np.array(state['q'])\n",
"    dq = np.array(state['dq'])\n",
"    # zero base velocity is passed as the last (v) argument\n",
"    model.updateAllPose(q, dq, T, np.zeros(6))\n",
"    return model.getInfo()['g'][6:]"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[46], line 10\u001b[0m\n\u001b[1;32m 8\u001b[0m state \u001b[38;5;241m=\u001b[39m robot\u001b[38;5;241m.\u001b[39mgetJointStates()\n\u001b[1;32m 9\u001b[0m imu \u001b[38;5;241m=\u001b[39m robot\u001b[38;5;241m.\u001b[39mgetIMU()\n\u001b[0;32m---> 10\u001b[0m grav \u001b[38;5;241m=\u001b[39m \u001b[43mgetGravComp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mimu\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 11\u001b[0m q \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mzeros(\u001b[38;5;241m12\u001b[39m) \n\u001b[1;32m 12\u001b[0m dq \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mzeros(\u001b[38;5;241m12\u001b[39m)\n",
"Cell \u001b[0;32mIn[36], line 10\u001b[0m, in \u001b[0;36mgetGravComp\u001b[0;34m(model, state, imu)\u001b[0m\n\u001b[1;32m 8\u001b[0m q \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray(state[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mq\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[1;32m 9\u001b[0m dq \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray(state[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdq\u001b[39m\u001b[38;5;124m'\u001b[39m])\n\u001b[0;32m---> 10\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdateAllPose\u001b[49m\u001b[43m(\u001b[49m\u001b[43mq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mT\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mzeros\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m6\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m model\u001b[38;5;241m.\u001b[39mgetInfo()[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mg\u001b[39m\u001b[38;5;124m'\u001b[39m][\u001b[38;5;241m6\u001b[39m:]\n",
"File \u001b[0;32m/home/Go2py/Go2Py/robot/model.py:251\u001b[0m, in \u001b[0;36mGo2Model.updateAllPose\u001b[0;34m(self, q, dq, T, v)\u001b[0m\n\u001b[1;32m 249\u001b[0m dq_ \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mhstack([v, dq[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mq_reordering_idx]])\n\u001b[1;32m 250\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mupdateKinematics(q_)\n\u001b[0;32m--> 251\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdateDynamics\u001b[49m\u001b[43m(\u001b[49m\u001b[43mq_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdq_\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"import numpy as np\n",
"import time\n",
"\n",
"# Gravity compensation only for 30 s: all PD gains are zero and the commanded\n",
"# torque is the gravity-compensation torque from getGravComp.\n",
"q_cmd = np.zeros(12)\n",
"dq_cmd = np.zeros(12)\n",
"kp = np.zeros(12)\n",
"kd = np.zeros(12)\n",
"start_time = time.time()\n",
"while time.time()-start_time < 30:\n",
"    state = robot.getJointStates()\n",
"    imu = robot.getIMU()\n",
"    tau = getGravComp(model, state, imu)\n",
"    robot.setCommands(q_cmd, dq_cmd, kp, kd, tau)\n",
"    time.sleep(0.001)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import time\n",
"\n",
"KNEE_IDX = [2, 5, 8, 11]  # knee joints within the 12-joint vectors\n",
"KNEE_KP, KNEE_KD = 6.3, 0.5  # gains of the explicit knee PD law below\n",
"q_nominal = np.asarray(robot.standing_q)\n",
"As = [0.5, 1., 1.5, 2, 2.5]  # excitation amplitudes\n",
"freqs = [1.5, 2., 2.5, 3., 3.5]  # excitation frequencies in Hz (omega = 2*pi*freq)\n",
"\n",
"# Go home before anything: hold q_nominal with the internal PD controller for 3 s.\n",
"start_time = time.time()\n",
"while time.time()-start_time < 3:\n",
"    kp = np.ones(12)*3.0\n",
"    kd = np.ones(12)*0.1\n",
"    robot.setCommands(q_nominal, np.zeros(12), kp, kd, np.zeros(12))\n",
"    time.sleep(0.001)\n",
"\n",
"# Record a dataset for the knee joints: for every (freq, A) pair, drive the knees\n",
"# for 5 s with a sinusoidal torque on top of gravity compensation and an explicit\n",
"# PD law towards q_nominal.\n",
"dataset = []\n",
"for freq in freqs:\n",
"    omega = np.pi*2*freq\n",
"    for A in As:\n",
"        segment_start = time.time()\n",
"        while time.time()-segment_start < 5:\n",
"            state = robot.getJointStates()\n",
"            imu = robot.getIMU()\n",
"            grav = getGravComp(model, state, imu)\n",
"            q_meas = np.array(state['q'])\n",
"            dq_meas = np.array(state['dq'])\n",
"            tau_recorded = np.array(state['tau_est'])\n",
"\n",
"            kp = np.ones(12)*3.0\n",
"            kd = np.ones(12)*0.1\n",
"            # The internal controller of the knee joints should not be active\n",
"            kp[KNEE_IDX] = 0.\n",
"            kd[KNEE_IDX] = 0.\n",
"\n",
"            # Sample the excitation once so all four knees get the same phase.\n",
"            excitation = A*np.sin(omega*(time.time()-segment_start))\n",
"            pd_law = KNEE_KP*(q_nominal-q_meas) - KNEE_KD*dq_meas\n",
"            # Copy so the array returned by getGravComp is never modified in place.\n",
"            tau = np.array(grav, dtype=float)\n",
"            tau[KNEE_IDX] = excitation + grav[KNEE_IDX] + pd_law[KNEE_IDX]\n",
"\n",
"            robot.setCommands(q_nominal, np.zeros(12), kp, kd, tau)\n",
"            dataset.append([tau[KNEE_IDX], q_meas[KNEE_IDX], dq_meas[KNEE_IDX], tau_recorded[KNEE_IDX]])\n",
"            time.sleep(0.01)\n",
"\n",
"# shape: (samples, 4 signals [tau_cmd, q, dq, tau_est], 4 knees)\n",
"dataset = np.array(dataset)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x70f9a272c550>,\n",
" <matplotlib.lines.Line2D at 0x70f9a272c580>,\n",
" <matplotlib.lines.Line2D at 0x70f9a2770550>,\n",
" <matplotlib.lines.Line2D at 0x70f9a2770670>]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAiIAAAGdCAYAAAAvwBgXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAB7HUlEQVR4nO3dd5gTVdsH4N9Msr0XlmXZpSy9d5CO0kVQLCAigl1EhU9svL52fRHsgqKiAjZAQFCpIlWQLlV6X8pSt/dkzvdH2Gyym+xmN2cyJ5Pnvi4usjOTmZOTyZlnThuJMcZACCGEEKIBWesEEEIIIcR3USBCCCGEEM1QIEIIIYQQzVAgQgghhBDNUCBCCCGEEM1QIEIIIYQQzVAgQgghhBDNUCBCCCGEEM0YtU5AeRRFwYULFxAWFgZJkrRODiGEEEJcwBhDVlYWEhISIMvl13kIHYhcuHABSUlJWieDEEIIIVWQkpKCxMTEcrcROhAJCwsDYPkg4eHhGqeGEEIIIa7IzMxEUlKS9TpeHqEDkeLmmPDwcApECCGEEC/jSrcK6qxKCCGEEM1QIEIIIYQQzVAgQgghhBDNUCBCCCGEEM1QIEIIIYQQzVAgQgghhBDNUCBCCCGEEM1QIEIIIYQQzVAgQgghhBDNUCBCCCGEEM1QIEIIIYQQzVAgQgghhBDNUCBCCCGE6NjpjNOYfWA28k35WifFIaGfvksIIYQQ9wxeMhgAcL3gOp5t96zGqSmLakQIIYQQH7D38l6tk+AQBSKEEEKIjuy7sg/zD88HY0zrpLiEmmYIIaSUvVf2YkfqDoxpNgZGmYpJ4l1GLh8JAKgeUh29knppmxgXUI0IIYSUcv/y+/HJP5/gl2O/aJ0Ur3cl9woe+eMR/Hnmz3K3u5h9EQuOLkCBucBDKfNuBeYCfPLPJ9hzeY/TbU5mnHS6TqT8plCfEEKcKK8gJ66ZsmMKtl3chm0Xt2H/6P1Ot7vj1zuQa8rFxeyLeKbtMx5MoXf67t/v8PX+r/H1/q/LzVdnRMpvqhEhhPgkxhgUpmidDF2ad3gePv3nUwBAen56udtmFmaiyFyEXFMuAGDrxa1qJ08XTmWcqnCb8vqIFOf3lgtbuKWpqqhGhBDiM3ak7sCZzDO4u+HdeHz140grSMO8QfOgQIFRMkKSJK2TqAvvbHsHANC/Tn/ASZZezbuK+5ffj/PZ51EztKZ1ueTsDcSOQTZw2Q+D9h1aKRAhhPiMh1Y9BACoH1kfWy5a7gT3XNmDp9c8jRbVWuDLvl/abU8XRfdkFWbB9jpXXAtlkA2Yvns6zmefBwDr/wCQZ87D53s+R+9avdEoupGnk+w1DFLFgYgrQYYIgQg1zRBCfM43+7+xvv7j9B/IKsrC3xf+LrNdWkEaGGMoNBd6Mnm6Ufoi99jqx3Db4tuw5/IeLDq2yOF7jqUdw4y9M3D373d7Iom6d/DawTLLbJtsRBjiSzUihBCfs/7ceuvrnw7/ZH29+sxqzD883/r3spPLsOzkMgDAX8P/QmRgpKeSqB82lUrF/T9GrRilUWJ8z/Clw8ssm7x9sgYpcY5qRAgh5IZn1z+LbanbHK6bf2S+w+WEiCCrMMvlZ8nMPTxX5dRUDtWIEEKIC3JMOVonwStRPxv15RTloMvcLgCArjW7Wpe70uwiwsgxqhEhhBBXmKifSGWJ0P/AFxxLO2Z9vfn8Zuvr0kHG8fTjQn4nFIgQQogrrh7ROgVeR4H2d9u+wNmw89LLMwszsfj4YrtlNGqGEEKIbv13039hUkxaJ0P3MgoyXN72tb9fs/ubAhFCiJ2TGSdpqKigtC+uvc+l3EvYeWmn1snQvXFrxjlcPm33tArfeyztGI6nHeedpEqhQIQQQaw7uw63L7kdD656UOukEAdEbFsnhIdJmyZpenwKRAgRRPEET/uu7HP5PevOrsOU7VNgVsxqJYvcwABg21fA5k+1TgohXOWZ8jQ9Pg3fJUQQUl
Gu/YLjfwL7FwEDpwCB4Q7f88w6y1MzG0Y1xNAGQ9VOom9jCrDiecvrlsOBsOrapkdQVHPkGTyfi6T1EF6qESFEFJcP2//9w13A3p+ADVMqfuvZv1RKFCmWUpheMgZE4ztIQniiQIToQ/Zl4PBygJoI3ODkTjKz5IFgqTmpeHz149h4bqP9NtlXVEwXAYB12acwpkYcMmUJTh8pS4QYhUEqR+taLApECB+fdQLmjQB2fqt1SnRjcWgIJlaLQaHN3cprf7+Gvy/8XaaXPBX9nrE7MBDdayVqnQyhXc69rHUSSCVpPd8LBSKEj7zrlv+PrNA2HV7qat5V5JcKJ16tFoM/QkOw0HTVuuxs5lkne6BQxFMUSQI4ts/rSXp+Ovou7Kt1MkglKeYiTY9PnVUJ0di13Ku4ecHNTtdnoqS5y8RocigxUCDiyLH0YxVvRISj5F7X9PhUI0L4ojvFStt3eJHb+9C6jZcQ4r0UjW9wKBAhfNEFsfJMBWUWOQssmO2D1zSuTvVlZnqGCtGaop9z0GOByLvvvgtJkjBhwgRPHZIQ7+CgFqnn/J4lqwFLobP4CSDHpiPgW7E276AA0JNMNDrMIaqZ86Br2k7LzpNHApEdO3bgyy+/RMuWLT1xOEK8i4PCO60gzX7B6Y3A3rnUM0EQRfQgN6IxyZSvdRK4UT0Qyc7OxsiRIzFz5kxERUWpfThCdCeHKcjIvYJsSbJrEHiiejXra7oT9SwKRByjOURIVageiIwbNw6DBg1Cnz59Kty2oKAAmZmZdv8I0b0KOvjOMqWi26430blOEi4bSwa6bQ4Osr6m4t+zihg1zRD90Lr8UHX47rx58/DPP/9gx44dLm0/efJkvPHGG2omiRBC3Kb1BFCiohoRb6Vto69qNSIpKSkYP348fvzxRwQGBrr0nkmTJiEjI8P6LyUlRa3kEV7Obi37jBRSOVyaVegC4FHUFEYIN6rViOzatQuXL19G27ZtrcvMZjM2btyI6dOno6CgAAaDwe49AQEBCAgIUCtJhLeMc8C3/bVOBQGFIYQQ76VaINK7d2/s37/fbtmDDz6Ixo0b48UXXywThBAvdP2kg4U2l8RfxwEF2cA9s2miM5VRlbhnUW47Rp2mPUfS0Rg61QKRsLAwNG/e3G5ZSEgIYmJiyiwnXqq8QsdcBOz+wfI67RQQneyZNHkhCiIIIb6MnjVDKu/6KWD7TCC+hfNtbIMU29lAiZ3D1w9ja/pRrZNBKomCR8coX7yT1t+aRwOR9evXe/JwRC2zbwMyzzlcdTXtBL7/siXubno/kooX0lBHp+75/R4u+6Eqcc+i/CaEH3rWDKk8myDkmizDdmqn5/zz8G0gw5ijszyfLkKItig+I1VAgQipsqN+fuhVOxGPxcdZl+0KsgzVtp14ixC9oRoRQvihQIRU2dLQEADAjqCSeWIkKqAJIUR1ehqISLetpEJ5pjwEGUumE98eGIA5EeGINLvY94OCE9XpqEzyCoxmVnWIOquSqqBAhJRr7dm1GL9uPF7o8ALqRdSDJEl4rEb1Mtv9LzoKuwIDwGzCdAa6QHoKFf+EEG9FgQgp10t/vQQAmLpjarnbzY0IK7NsUGINtM8vwJuqpIwQDflYLd8n/3yCzec3Y87AOXa1o6VRjQipCuojQsplkKo+A26Knx8Wh4VyTA0hYmDMt5pmvt7/NQ5dP4TfT/xe/oY+FqARPigQIeXS0zTCWssqzFJttAV9S8QTzBXMCcSuHfdQSvSLMYbJ2ybjh4M/eO6YGhcgFIiQ8pl5zIpKd0l7r+xFl7ld8J9N/1Fl/5TDnuVrNSK28kx5mLFnBo5cP1J2Jc2iXClbLmxBak6q3bIDVw/gp8M/YcqOKRW8Wz+3HxSIkPIpRVqnQBe+3vc1AGDpyaUap8T7MMbw/cHvsefyHq2T4vMkSPhy75f4fO/nuPv3u63Li2v6tL6z9iZbLmzBY6sfQ9+Ffe2W55hyymybW5SLfFO+p5LmcRSIEN
UpCk3xrjZJx3Uiq8+sxtQdUzFqxSjkFJUtpEtv++HOD6GoXGPhy50yD10/ZH19MuMkZh+YjZ7ze+Jkxkno6S5dbTs
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"# dataset has shape (samples, 4 signals, 4 knees); signal index 2 (== -2) is dq.\n",
"fig, ax = plt.subplots()\n",
"ax.plot(dataset[:, -2])\n",
"ax.set(xlabel='sample', ylabel='dq', title='Recorded knee velocities')\n",
"ax.legend(['joint 2', 'joint 5', 'joint 8', 'joint 11']);"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"# One (samples, 4 knees) array per recorded signal.\n",
"signals = {\n",
"    'q': dataset[:, 1],\n",
"    'dq': dataset[:, 2],\n",
"    'tau': dataset[:, 0],\n",
"    'tau_recorded': dataset[:, -1],\n",
"}\n",
"with open('knee_friction_identification_dataset.pkl', 'wb') as f:\n",
"    pickle.dump(signals, f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Curve fitting: knee friction model identification"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"# Load the dataset saved above (only unpickle files you produced yourself).\n",
"with open('knee_friction_identification_dataset.pkl', 'rb') as f:\n",
"    dataset = pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class FrictionSSID(nn.Module):\n",
"    \"\"\"Friction model: tau_f = Fs * tanh(dq / temperature) + mu_v * dq.\n",
"\n",
"    Fs   -- learnable stiction magnitude, initialised to 0\n",
"    mu_v -- learnable viscous coefficient, initialised to 0\n",
"    tanh(dq / temperature) is a smooth stand-in for sign(dq).\n",
"    \"\"\"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.mu_v = nn.Parameter(data=torch.zeros(1), requires_grad=True)\n",
"        self.Fs = nn.Parameter(data=torch.zeros(1), requires_grad=True)\n",
"        # non-trainable smoothing temperature, stored as a buffer\n",
"        self.register_buffer(\"temperature\", 0.1 * torch.ones(()))\n",
"\n",
"    def forward(self, dq):\n",
"        stiction = self.Fs * self.softSign(dq, temperature=self.temperature)\n",
"        viscous = self.mu_v * dq\n",
"        return stiction + viscous\n",
"\n",
"    def softSign(self, u, temperature=0.1):\n",
"        \"\"\"Smooth approximation of sign(u); sharper as temperature -> 0.\"\"\"\n",
"        return torch.tanh(u / temperature)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"knee nb.: 0, epgetJointStatesoch 1/100, loss: 3.2810028187543234\n",
"knee nb.: 0, epgetJointStatesoch 26/100, loss: 1.5961371864755898\n",
"knee nb.: 0, epgetJointStatesoch 51/100, loss: 1.5878992784766226\n",
"knee nb.: 0, epgetJointStatesoch 76/100, loss: 1.5875244340091612\n",
"mu_v: Parameter containing:\n",
"tensor([0.2167], requires_grad=True), Fs: Parameter containing:\n",
"tensor([1.5259], requires_grad=True)\n",
"knee nb.: 1, epgetJointStatesoch 1/100, loss: 2.5653999508271745\n",
"knee nb.: 1, epgetJointStatesoch 26/100, loss: 1.620573471383013\n",
"knee nb.: 1, epgetJointStatesoch 51/100, loss: 1.6184235904911686\n",
"knee nb.: 1, epgetJointStatesoch 76/100, loss: 1.6184084277930868\n",
"mu_v: Parameter containing:\n",
"tensor([-0.0647], requires_grad=True), Fs: Parameter containing:\n",
"tensor([1.2380], requires_grad=True)\n",
"knee nb.: 2, epgetJointStatesoch 1/100, loss: 2.281400742101678\n",
"knee nb.: 2, epgetJointStatesoch 26/100, loss: 1.7584790933593724\n",
"knee nb.: 2, epgetJointStatesoch 51/100, loss: 1.7573205771828757\n",
"knee nb.: 2, epgetJointStatesoch 76/100, loss: 1.7573122681074191\n",
"mu_v: Parameter containing:\n",
"tensor([-0.0420], requires_grad=True), Fs: Parameter containing:\n",
"tensor([0.8917], requires_grad=True)\n",
"knee nb.: 3, epgetJointStatesoch 1/100, loss: 4.078545475140164\n",
"knee nb.: 3, epgetJointStatesoch 26/100, loss: 2.136013706473821\n",
"knee nb.: 3, epgetJointStatesoch 51/100, loss: 2.1031023316309967\n",
"knee nb.: 3, epgetJointStatesoch 76/100, loss: 2.100360134679885\n",
"mu_v: Parameter containing:\n",
"tensor([-0.0834], requires_grad=True), Fs: Parameter containing:\n",
"tensor([2.2461], requires_grad=True)\n"
]
}
],
"source": [
"for knee_nb in range(4):\n",
"    epoch_nb = 100\n",
"    friction_model = FrictionSSID()\n",
"    optim = torch.optim.SGD(params=friction_model.parameters(), lr=0.2)\n",
"    # The data does not change between epochs: build the tensors once per knee.\n",
"    dq = torch.tensor(dataset['dq'][:, knee_nb])\n",
"    tau_applied = torch.tensor(dataset['tau'][:, knee_nb])\n",
"    tau_recorded = torch.tensor(dataset['tau_recorded'][:, knee_nb])\n",
"    # NOTE(review): the fit target is tau_applied + tau_recorded, while the friction\n",
"    # plot labels tau - tau_recorded as friction; confirm the sign convention of tau_est.\n",
"    friction_target = tau_applied + tau_recorded\n",
"    for e in range(epoch_nb):\n",
"        loss = ((friction_target - friction_model(dq)) ** 2).mean()\n",
"        optim.zero_grad()\n",
"        loss.backward()\n",
"        optim.step()\n",
"        if e % 25 == 0:\n",
"            print(f\"knee nb.: {knee_nb}, epoch {e+1}/{epoch_nb}, loss: {loss.item()}\")\n",
"    print(f\"mu_v: {friction_model.mu_v}, Fs: {friction_model.Fs}\")\n",
"    # saves the whole pickled module (loading needs the FrictionSSID class)\n",
"    torch.save(friction_model, f\"friction_model_knee_{knee_nb}\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'friction')"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAkIAAAGwCAYAAABFFQqPAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOydd3QUVRuHn9mS3kMqCQm996YUAZWmgIhIU0Gs2BULgooFFHvHTxQVBVG69N5776GH9N57tsx8f0yymyUJCRBIgPuck8PMnTt37iy7O79971skRVEUBAKBQCAQCG5DNNU9AYFAIBAIBILqQgghgUAgEAgEty1CCAkEAoFAILhtEUJIIBAIBALBbYsQQgKBQCAQCG5bhBASCAQCgUBw2yKEkEAgEAgEgtsWXXVPoKYjyzJxcXG4uroiSVJ1T0cgEAgEAkElUBSF7OxsAgMD0WjKt/sIIVQBcXFxBAcHV/c0BAKBQCAQXAXR0dEEBQWVe1wIoQpwdXUF1BfSzc2tmmcjEAgEAoGgMmRlZREcHGx5jpeHEEIVULwc5ubmJoSQQCAQCAQ3GRW5tQhnaYFAIBAIBLctQggJBAKBQCC4bRFCSCAQCAQCwW2LEEICgUAgEAhuW4QQEggEAoFAcNsihJBAIBAIBILblptKCG3bto2BAwcSGBiIJEn8999/l+2/ZcsWJEkq9ZeQkHBjJiwQCAQCgaBGc1MJodzcXFq3bs306dOv6LwzZ84QHx9v+fP19b1OMxQIBAKBQHAzcVMlVOzfvz/9+/e/4vN8fX3x8PCo+gkJBAKBQCC4qbmpLEJXS5s2bQgICKB3797s3Lnzsn0LCwvJysqy+RMIBAKBQHBrcksLoYCAAH7++WcWLVrEokWLCA4OpmfPnhw6dKjcc6ZNm4a7u7vlTxRcFQgEAoHg1kVSFEWp7klcDZIksWTJEgYPHnxF5/Xo0YM6deowe/bsMo8XFhZSWFho2S8u2paZmSlqjQkEAoFAcJOQlZWFu7t7hc/vm8pHqCro1KkTO3bsKPe4vb099vb2N3BGNReTbEJBQa/RV/dUBAKBQCC4LtzSS2NlceTIEQICAqp7GjWeXGMufRb2YczqMRjMhuqejkAgEAgE14WbyiKUk5PD+fPnLfsXL17kyJEjeHl5UadOHSZOnEhsbCx//fUXAN9++y1169alefPmFBQUMHPmTDZt2sS6deuq6xZuGk6mnCQ5P5nk/GR+O/Ebz7V+rrqnJBAIBAJBlXNTCaEDBw7Qq1cvy/748eMBGDNmDLNmzSI+Pp6oqCjLcYPBwOuvv05sbCxOTk60atWKDRs22IwhKJvwzHDL9q/HfqVvaF/quderxhkJBAKBQFD13LTO0jeKyjpb3Wp8vOdj/j3zLxpJg6zItPdrz+99f0cj3XarqQKBQCC4Cans81s81QRlcjHzIgDjWo/DUefIwcSD/Hf+v+qdlEAgEAgEVYwQQoIyuZB5AYDutbvzQpsXAPjywJek5KdU57QEAoFAIKhShBASlCKzMNMieOq61+WRpo/Q1Ksp2YZsPt/3eTXPTiAQCASCqkMIIUEpipfF/J39cdY7o9PoeL/L+2gkDasjVrM9Zns1z1AgEAgEgqpBCCFBKYojxkpGiTX3bs6jTR8FYOqeqeQZ86plbgKBQCAQVCVCCAlKcSFD9Q+6NFy+lmMtAOJy4/jpyE83fF4CgUAgEFQ1QggJSmGxCHnYCqFdcbss23+G/UlYatgNnVelkWUwm6p7FgKBQCC4CRBCSFCK8IzSS2MAMdkxNvsf7v4Qk1zDBEdhNnzfBn7tBYbc6p6NQCAQCGo4QggJbMgz5hGXGwdAfff6lnaTbCIhN8Gmb1hqGHNPzb2h86uQsKWQEQkJx2D9+9U9G4FAIBDUcIQQEthwMUuNGPNy8MLDwcPSnpiXiEkxodfobeqOfXHgC2JzYm/0NMvnyD/W7f2/woVN1TcXgUAgENR4hBAS2FDRsl
htl9o83eppmno1tRybumcqNaJSS3oERO4AJGg6SG377wXIz6jGSQkEAoGgJiOEkMCGskLnoYQQcq2NXqPn424fW47tiN3B2oi1N26S5XF0nvpv3bvgwZ/Bqx5kx8HqCdU7L4FAIBDUWIQQEthgsQhdEjFWvPwV5BIEQEPPhrzS7hXL8Te3vUlmYeYNmmUZKAocLfJXajMK7JzhwRkgaeDYvxC2rPrmJhAIBIIaixBCAhsqsggFuwZb2h5v/jhNvJpY9r85+M0NmGE5RO1Wl8bsXKDpQLUtuBN0fVXdXvEq5CRV0+QEAoFAUFMRQkhgwWA2EJ0dDUB9j/o2x2JyrD5Cxeg0Oj676zPL/qJziziQcOAGzLQMjhRZg5oNVq1BxfR8G/xaQF4qLH9VtRwJBAKBQFCEEEICC5FZkZgVMy56F3wcfWyOFVuEglyDbNrrudfjjQ5vWPbHbRiHwWy4/pMtiSEPTv6nbrcZaXtMZ68ukWn0cGYlHP2n1OkCgUAguH0RQkhgoWRGaUmSLO25xlzSC9MBW4tQMY82fZQGHg0AKDQX8uvxX2/AbIvY8S18EgCGbPAIgTpdSvfxbwG9JqnbqydARvSNm59AIBAIajRCCAksVBQ672Hvgauda6nztBot3/X6zrL/89GfLWNdV2QZdv1g3W86EDTlvKW7vgJBnaAwC5Y+r54rEAgEgtseIYQEFootQiUzSoPVP6g4Yqws6rjV4e1Ob1v2n173NLJyncVG8inIS7HuR+0u3wdIo1VD6vVOcHGbmmxRIBAIBLc9QggJLFzILKo671F+DqHLMbLJSItYSspPYuHZhddhliW4uN12P/Yg7J9Zfn/v+tD7I3V7/fuQcu76zU0gEAgENwVCCAkAtZZYZGYkUHpp7GTqSQB8nXwvO4ZG0jCzr1WITNkzhZT8lMuccY1c3Grd9m2m/rtmIsQcLP+cjk9BvV5gyoclz4oq9QKBQHCbI4SQAFATJhpkAw5aBwJdAm2Orb64GoDZYbOJzrq8o3Ftl9q80/kdy/5zG567TO9rQDbDmVXW/SfWQpMBIBthweOQl1b2eZIED0wHe3fVgrSzGnMfCQQCgaDaEUJIAFgdpeu610Ujlf+2GLZiGOsj1192rOGNh+Pl4AXA6bTTbI7aXHUTLSb+qHW7xUPg4AaDfwLPupAZBUvGle8Q7V4b7vtc3d7yqe1YAoFAILitEEJIAFj9g+q617VpT81PtWx7OXiRY8xh/JbxTNs7rex8QbKMJEnMGzDP0vTy5pfJM+ZV7YTPlRBjbR9T/3Vwh2F/gdYezq29vLWn1XA1ykw2qaLJVFi18xMIBALBTYEQQgIALmZeBEpnlN4SvcWyvfahtTzR4gkA5p6ey2OrH7Nkogbg4Cz42B92foe/sz+TOk+yHBq/ZXzVTnjLJ9btundZtwNawX1fqNubppZ2qC5GkmDAt+DsA0lhsPnjsvsJBAKB4JZGCCEBABcyiiLGSjhK55vy+WD3B5Z9B50Dr7V/jen3TMfd3p2w1DCGLS9aKjPkwsYpYC6E9ZNh01RGNBqOVtICsDNuJ8eSj1XNZE0lLFGN+qmh8SVpNxpajwJFhoVPQHZC2eM414KBRfmPdn4PUXuqZn4CgUAguGkQQkiAoig2WaWL+e34b2X2vyvoLhYOXEgbnzbWpbJVT2LISwF7N7XTti+Q1r3LuofWWs57ZNUjmOQqiNI6u9q63XtK6eOSBPd/pUaS5SbBwifLjw5rcj+0eQRQ1CWywpxrn59AIBAIbhqEELoNUBQFo2ws93hCbgL5pnx0ks5SXV5RFNZGrLXpp5RIVujv7M/v/X5nbIuxAMzNPMljgX5E93ob+hc5Iu+Zju+mT3ij/euW897f9f6139DSl6zbPo3K7mPnpPoL2blA5A7YPLX88fpNA/dgSL8I69+79vkJBAKB4KZBCKHbgKl7ptL1n678e/pfGzFTTLE1KMQtBL1Gb2mLyIqw6Tfj2Aybfb1Gz/j245
ke2A93s5kwe3uGhc9hvX89GPQjIMHBWYw+s9NyzrILy4jMirz6m1EUKMxUt/1bXr5vrYYwqKgEx45v4Myasvs5uKs
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"# Markers, not lines: samples are in time order, not sorted by dq.\n",
"# NOTE(review): friction here is tau - tau_recorded, whereas the fit above uses\n",
"# tau + tau_recorded as its target; confirm which sign is intended.\n",
"fig, ax = plt.subplots()\n",
"for knee_nb in range(dataset['dq'].shape[1]):\n",
"    friction = dataset['tau'][:, knee_nb] - dataset['tau_recorded'][:, knee_nb]\n",
"    ax.plot(dataset['dq'][:, knee_nb], friction, '.', markersize=2, label=f'knee {knee_nb}')\n",
"ax.set_xlabel(\"dq\")\n",
"ax.set_ylabel(\"friction\")\n",
"ax.legend();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}