{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "RL policy based on the [SoloParkour: Constrained Reinforcement Learning for Visual Locomotion from Privileged Experience](https://arxiv.org/abs/2409.13678). " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Flat Ground" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test In Simulation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from Go2Py.robot.fsm import FSM\n", "from Go2Py.robot.remote import KeyboardRemote, XBoxRemote\n", "from Go2Py.robot.safety import SafetyHypervisor\n", "from Go2Py.sim.mujoco import Go2Sim\n", "from Go2Py.control.cat import *\n", "#from Go2Py.control.cat_rnn import *\n", "import torch" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from Go2Py.robot.model import FrictionModel\n", "friction_model = None\n", "Fs = np.zeros(12)\n", "mu_v = np.zeros(12)\n", "#mu_v[[2,5,8,11]] = np.array([0.2167, -0.0647, -0.0420, -0.0834])\n", "#Fs[[2,5,8,11]] = np.array([1.5259, 1.2380, 0.8917, 2.2461])\n", "\n", "#mu_v[[0,3,6,9]] = np.array([0., 0., 0., 0.])\n", "#Fs[[0,3,6,9]] = np.array([1.5, 1.5, 1.5, 1.5])\n", "#mu_v[[2,5,8,11]] = np.array([0., 0., 0., 0.])\n", "#Fs[[2,5,8,11]] = np.array([1.5, 1.5, 1.5, 1.5])\n", "\n", "friction_model = FrictionModel(Fs=1.5, mu_v=0.3)\n", "#friction_model = FrictionModel(Fs=0., mu_v=0.)\n", "#friction_model = FrictionModel(Fs=Fs, mu_v=mu_v)\n", "robot = Go2Sim(dt = 0.001, friction_model=friction_model)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "remote = XBoxRemote() # KeyboardRemote()\n", "robot.sitDownReset()\n", "safety_hypervisor = SafetyHypervisor(robot)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def getRemote(remote):\n", " commands = remote.getCommands()\n", " commands[0] *= 0.6\n", " commands[1] *= 0.6\n", " zero_commands_xy = np.linalg.norm(commands[:2]) <= 0.2\n", " zero_commands_yaw = np.abs(commands[2]) <= 0.2\n", " if zero_commands_xy:\n", " commands[:2] = np.zeros_like(commands[:2])\n", " if zero_commands_yaw:\n", " commands[2] = 0\n", " return commands" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class CaTController:\n", " def __init__(self, robot, remote, checkpoint):\n", " self.remote = remote\n", " self.robot = robot\n", " self.policy = Policy(checkpoint)\n", " self.command_profile = CommandInterface()\n", " self.agent = CaTAgent(self.command_profile, self.robot)\n", " self.hist_data = {}\n", "\n", " def init(self):\n", " self.obs = self.agent.reset()\n", " self.policy_info = {}\n", " self.command_profile.yaw_vel_cmd = 0.0\n", " self.command_profile.x_vel_cmd = 0.0\n", " self.command_profile.y_vel_cmd = 0.0\n", "\n", " def update(self, robot, remote):\n", " if not hasattr(self, \"obs\"):\n", " self.init()\n", " commands = getRemote(remote)\n", " self.command_profile.yaw_vel_cmd = -commands[2]\n", " self.command_profile.x_vel_cmd = commands[1]\n", " self.command_profile.y_vel_cmd = -commands[0]\n", "\n", " self.obs = self.agent.get_obs()\n", " action = self.policy(self.obs, self.policy_info)\n", " _, self.ret, self.done, self.info = self.agent.step(action)\n", " for key, value in self.info.items():\n", " if key in self.hist_data:\n", " self.hist_data[key].append(value)\n", " else:\n", " self.hist_data[key] = [value]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "robot.getJointStates()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from Go2Py import ASSETS_PATH\n", "import os\n", "# what we tested\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/trainparamsconfigmax_epochs1500_taskenvlearnlimitsfoot_contact_force_rate60_soft_07-20-22-43.pt')\n", "# new one\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/dof_vel_3_10-00-05-00.pt')\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/actuator_net_17-21-28-47.pt')\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/motor_friction_randomization_18-16-13-06.pt')\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/motor_friction_randomization_cornomove_18-16-41-52.pt')\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/motor_friction_randomization_morenoise_18-20-04-09.pt')\n", "\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/dof_pos_nmove_20-23-40-47.pt')\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/dof_pos_nomove_acrate120_21-16-02-48.pt')\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/dof_pos_nomove_friction05125_21-19-46-53.pt')\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/nomove_4footcontact_21-22-30-50.pt')\n", "##checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/HFE_1_21-23-55-16.pt')\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/faster_gait_22-16-56-56.pt')\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/no_max_airtime_22-19-29-20.pt')\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/stand_still_rew_4_RNN_LOG_19-17-28-53.pt')\n", "\n", "# Best policy, friction from 0.5 to 1.25\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/TEST1_no_max_airtime_airtime025_22-20-13-45.pt')\n", "\n", "# Friction from 0.15 to 0.8, friction motor model\n", "checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/frictionRange015to08_23-22-28-13.pt')\n", "\n", "# Friction from 0.15 to 0.8, NO friction motor model\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/frictionRange015to08_noRandomizeMotorFriction_23-22-38-07.pt')\n", "\n", "controller = CaTController(robot, remote, checkpoint_path)\n", "decimation = 20\n", "fsm = FSM(robot, remote, safety_hypervisor, control_dT=decimation * robot.dt, user_controller_callback=controller.update)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Slippage Analysis" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "contacts = []\n", "feet_vels = []\n", "\n", "while True:\n", " if remote.xbox_controller.digital_cmd[1]:\n", " break\n", " contact_state = robot.getFootContact()>15\n", " sites = ['FR_foot', 'FL_foot', 'RR_foot', 'RL_foot']\n", " feet_vel = [np.linalg.norm(robot.getFootVelInWorld(s)) for s in sites]\n", " contacts.append(contact_state)\n", " feet_vels.append(feet_vel)\n", " time.sleep(0.01)\n", "\n", "feet_vels = np.stack(feet_vels)\n", "contacts = np.stack(contacts)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "start = 300\n", "end = 1200\n", "plt.plot(contacts[start:end,0])\n", "plt.plot(feet_vels[start:end,0])\n", "plt.legend(['contact state', 'foot velocity'])\n", "plt.grid(True)\n", "plt.tight_layout()\n", "plt.savefig('foot_slipping_fric0.2.png')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**To Do**\n", "- Train a policy without any actuator friction and check the plots for friction 0.2 and 0.6 \n", "- Do the same experiment for the walk-these-ways policy\n", "- While testing the walk these ways, check the output of the adaptation module for various friction numbers, any correlation?" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "fsm.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Foot Contanct Analysis" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "np.array(controller.hist_data[\"body_pos\"])[0, 0, -1]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plt.plot(np.array(controller.hist_data[\"body_pos\"])[:, 0, -1])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "# Assuming 'controller.hist_data[\"torques\"]' is a dictionary with torque profiles\n", "torques = np.array(controller.hist_data[\"body_linear_vel\"])[:, 0, :, 0]\n", "\n", "# Number of torque profiles\n", "torque_nb = torques.shape[1]\n", "\n", "# Number of rows needed for the grid, with 3 columns per row\n", "n_cols = 3\n", "n_rows = int(np.ceil(torque_nb / n_cols))\n", "\n", "# Create the figure and axes for subplots\n", "fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))\n", "\n", "# Flatten the axes array for easy indexing (in case of multiple rows)\n", "axes = axes.flatten()\n", "\n", "# Plot each torque profile\n", "for i in range(torque_nb):\n", " axes[i].plot(np.arange(torques.shape[0]) * robot.dt * decimation, torques[:, i])\n", " axes[i].set_title(f'Torque {i+1}')\n", " axes[i].set_xlabel('Time')\n", " axes[i].set_ylabel('Torque Value')\n", " axes[i].grid(True)\n", "\n", "# Remove any empty subplots if torque_nb is not a multiple of 3\n", "for j in range(torque_nb, len(axes)):\n", " fig.delaxes(axes[j])\n", "\n", "# Adjust layout\n", "plt.tight_layout()\n", "plt.savefig(\"torque_profile.png\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "# Assuming 'controller.hist_data[\"torques\"]' is a dictionary with torque profiles\n", "torques = np.array(controller.hist_data[\"torques\"])\n", "\n", "# Number of torque profiles\n", "torque_nb = torques.shape[1]\n", "\n", "# Number of rows needed for the grid, with 3 columns per row\n", "n_cols = 3\n", "n_rows = int(np.ceil(torque_nb / n_cols))\n", "\n", "# Create the figure and axes for subplots\n", "fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))\n", "\n", "# Flatten the axes array for easy indexing (in case of multiple rows)\n", "axes = axes.flatten()\n", "\n", "# Plot each torque profile\n", "for i in range(torque_nb):\n", " axes[i].plot(np.arange(torques.shape[0]) * robot.dt * decimation, torques[:, i])\n", " axes[i].set_title(f'Torque {i+1}')\n", " axes[i].set_xlabel('Time')\n", " axes[i].set_ylabel('Torque Value')\n", " axes[i].grid(True)\n", "\n", "# Remove any empty subplots if torque_nb is not a multiple of 3\n", "for j in range(torque_nb, len(axes)):\n", " fig.delaxes(axes[j])\n", "\n", "# Adjust layout\n", "plt.tight_layout()\n", "plt.savefig(\"torque_profile.png\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Extract the joint position data for the first joint over time\n", "joint_pos = np.array(controller.hist_data[\"joint_vel\"])[:, 0]\n", "\n", "# Number of data points in joint_pos\n", "n_data_points = len(joint_pos)\n", "\n", "# Since you're plotting only one joint, no need for multiple subplots in this case.\n", "# But to follow the grid requirement, we'll replicate the data across multiple subplots.\n", "# For example, let's assume you want to visualize this data 9 times in a 3x3 grid.\n", "\n", "n_cols = 3\n", "n_rows = int(np.ceil(torque_nb / n_cols))\n", "\n", "# Create the figure and axes for subplots\n", "fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))\n", "\n", "# Flatten the axes array for easy indexing (in case of multiple rows)\n", "axes = axes.flatten()\n", "\n", "# Plot the same joint position data in every subplot (as per grid requirement)\n", "for i in range(n_rows * n_cols):\n", " axes[i].plot(joint_pos[:, i])\n", " axes[i].set_title(f'Joint Position {i+1}')\n", " axes[i].set_xlabel('Time')\n", " axes[i].set_ylabel('Position Value')\n", "\n", "# Adjust layout\n", "plt.tight_layout()\n", "plt.savefig(\"joint_position_profile.png\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "# Assuming 'controller.hist_data[\"foot_contact_forces_mag\"]' is a dictionary with foot contact force magnitudes\n", "foot_contact_forces_mag = np.array(controller.hist_data[\"foot_contact_forces_mag\"])\n", "\n", "# Number of feet (foot_nb)\n", "foot_nb = foot_contact_forces_mag.shape[1]\n", "\n", "# Number of rows needed for the grid, with 3 columns per row\n", "n_cols = 3\n", "n_rows = int(np.ceil(foot_nb / n_cols))\n", "\n", "# Create the figure and axes for subplots\n", "fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))\n", "\n", "# Flatten the axes array for easy indexing (in case of multiple rows)\n", "axes = axes.flatten()\n", "\n", "# Plot each foot's contact force magnitude\n", "for i in range(foot_nb):\n", " axes[i].plot(foot_contact_forces_mag[:, i])\n", " axes[i].set_title(f'Foot {i+1} Contact Force Magnitude')\n", " axes[i].set_xlabel('Time')\n", " axes[i].set_ylabel('Force Magnitude')\n", "\n", "# Remove any empty subplots if foot_nb is not a multiple of 3\n", "for j in range(foot_nb, len(axes)):\n", " fig.delaxes(axes[j])\n", "\n", "# Adjust layout\n", "plt.tight_layout()\n", "plt.savefig(\"foot_contact_profile.png\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Extract the joint acceleration data for the first joint over time\n", "joint_acc = np.array(controller.hist_data[\"joint_acc\"])[:, 0]\n", "\n", "# Number of data points in joint_acc\n", "n_data_points = len(joint_acc)\n", "\n", "# Number of feet (foot_nb)\n", "foot_nb = joint_acc.shape[1]\n", "\n", "# Number of rows needed for the grid, with 3 columns per row\n", "n_cols = 3\n", "n_rows = int(np.ceil(foot_nb / n_cols))\n", "\n", "# Create the figure and axes for subplots\n", "fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))\n", "\n", "# Flatten the axes array for easy indexing\n", "axes = axes.flatten()\n", "\n", "# Plot the same joint acceleration data in every subplot (as per grid requirement)\n", "for i in range(n_rows * n_cols):\n", " axes[i].plot(joint_acc[:, i])\n", " axes[i].set_title(f'Joint Acceleration {i+1}')\n", " axes[i].set_xlabel('Time')\n", " axes[i].set_ylabel('Acceleration Value')\n", "\n", "# Adjust layout to prevent overlap\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Extract the joint jerk data over time\n", "joint_jerk = np.array(controller.hist_data[\"joint_jerk\"])[:, 0]\n", "\n", "# Number of data points in joint_jerk\n", "n_data_points = len(joint_jerk)\n", "\n", "# Number of joints (assuming the second dimension corresponds to joints)\n", "num_joints = joint_jerk.shape[1]\n", "\n", "# Number of columns per row in the subplot grid\n", "n_cols = 3\n", "# Number of rows needed for the grid\n", "n_rows = int(np.ceil(num_joints / n_cols))\n", "\n", "# Create the figure and axes for subplots\n", "fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))\n", "\n", "# Flatten the axes array for easy indexing\n", "axes = axes.flatten()\n", "\n", "# Plot the joint jerk data for each joint\n", "for i in range(num_joints):\n", " axes[i].plot(joint_jerk[:, i])\n", " axes[i].set_title(f'Joint Jerk {i+1}')\n", " axes[i].set_xlabel('Time')\n", " axes[i].set_ylabel('Jerk Value')\n", "\n", "# Hide any unused subplots\n", "for i in range(num_joints, len(axes)):\n", " fig.delaxes(axes[i])\n", "\n", "# Adjust layout to prevent overlap\n", "plt.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Extract the foot contact rate data over time\n", "foot_contact_rate = np.array(controller.hist_data[\"foot_contact_rate\"])[:, 0]\n", "\n", "# Number of data points in foot_contact_rate\n", "n_data_points = foot_contact_rate.shape[0]\n", "\n", "# Number of feet (assuming the second dimension corresponds to feet)\n", "num_feet = foot_contact_rate.shape[1]\n", "\n", "# Number of columns per row in the subplot grid\n", "n_cols = 3\n", "# Number of rows needed for the grid\n", "n_rows = int(np.ceil(num_feet / n_cols))\n", "\n", "# Create the figure and axes for subplots\n", "fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))\n", "\n", "# Flatten the axes array for easy indexing\n", "axes = axes.flatten()\n", "\n", "# Plot the foot contact rate data for each foot\n", "for i in range(num_feet):\n", " axes[i].plot(foot_contact_rate[:, i])\n", " axes[i].set_title(f'Foot Contact Rate {i+1}')\n", " axes[i].set_xlabel('Time')\n", " axes[i].set_ylabel('Contact Rate')\n", "\n", "# Hide any unused subplots\n", "for i in range(num_feet, len(axes)):\n", " fig.delaxes(axes[i])\n", "\n", "# Adjust layout to prevent overlap\n", "plt.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test on Real Robot (ToDo)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "from Go2Py.robot.fsm import FSM\n", "from Go2Py.robot.remote import XBoxRemote\n", "from Go2Py.robot.safety import SafetyHypervisor\n", "from Go2Py.control.cat import *" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from Go2Py.robot.interface import GO2Real\n", "import numpy as np\n", "robot = GO2Real(mode='lowlevel')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "remote = XBoxRemote() # KeyboardRemote()\n", "safety_hypervisor = SafetyHypervisor(robot)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "robot.getJointStates()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Make sure the robot can take commands from python. The next cell should make the joints free to move (no damping)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import time\n", "start_time = time.time()\n", "\n", "while time.time()-start_time < 30:\n", " q = np.zeros(12) \n", " dq = np.zeros(12)\n", " kp = np.ones(12)*0.0\n", " kd = np.ones(12)*0.0\n", " tau = np.zeros(12)\n", " tau[0] = 0.0\n", " robot.setCommands(q, dq, kp, kd, tau)\n", " time.sleep(0.02)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "def getRemote(remote):\n", " commands = remote.getCommands()\n", " commands[0] *= 0.6\n", " commands[1] *= 0.6\n", " zero_commands = np.logical_and(\n", " np.linalg.norm(commands[:2]) <= 0.2,\n", " np.abs(commands[2]) <= 0.2\n", " )\n", " if zero_commands:\n", " commands = np.zeros_like(commands)\n", " return commands" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "class CaTController:\n", " def __init__(self, robot, remote, checkpoint):\n", " self.remote = remote\n", " self.robot = robot\n", " self.policy = Policy(checkpoint)\n", " self.command_profile = CommandInterface()\n", " self.agent = CaTAgent(self.command_profile, self.robot)\n", " self.hist_data = {}\n", "\n", " def init(self):\n", " self.obs = self.agent.reset()\n", " self.policy_info = {}\n", " self.command_profile.yaw_vel_cmd = 0.0\n", " self.command_profile.x_vel_cmd = 0.0\n", " self.command_profile.y_vel_cmd = 0.0\n", "\n", " def update(self, robot, remote):\n", " if not hasattr(self, \"obs\"):\n", " self.init()\n", " commands = getRemote(remote)\n", " self.command_profile.yaw_vel_cmd = -commands[2]\n", " self.command_profile.x_vel_cmd = commands[1]\n", " self.command_profile.y_vel_cmd = -commands[0]\n", "\n", " self.obs = self.agent.get_obs()\n", " action = self.policy(self.obs, self.policy_info)\n", " _, self.ret, self.done, self.info = self.agent.step(action)\n", " for key, value in self.info.items():\n", " if key in self.hist_data:\n", " self.hist_data[key].append(value)\n", " else:\n", " self.hist_data[key] = [value]\n", "\n", " " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from Go2Py import ASSETS_PATH \n", "import os\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/trainparamsconfigmax_epochs1500_taskenvlearnlimitsfoot_contact_force_rate60_soft_07-20-22-43.pt')\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/actuator_net_17-21-28-47.pt')\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/motor_friction_randomization_cornomove_18-16-41-52.pt')\n", "\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/dof_pos_nmove_20-23-40-47.pt')\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/dof_pos_nomove_friction05125_21-19-46-53.pt')\n", "\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/HAA_01_21-20-57-43.pt')\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/nomove_4footcontact_21-22-30-50.pt')\n", "###checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/HFE_1_21-23-55-16.pt')\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/faster_gait_22-16-56-56.pt')\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/no_max_airtime_22-19-29-20.pt')\n", "\n", "# Best policy\n", "checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/TEST1_no_max_airtime_airtime025_22-20-13-45.pt')\n", "\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/frictionRange015to08_23-22-28-13.pt')\n", "\n", "controller = CaTController(robot, remote, checkpoint_path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fsm = FSM(robot, remote, safety_hypervisor, control_dT=1./50., user_controller_callback=controller.update)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fsm.close()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "b1-env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }