Go2Py/examples/06-CaT-RL-controller.ipynb

769 lines
25 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook runs an RL locomotion policy based on [SoloParkour: Constrained Reinforcement Learning for Visual Locomotion from Privileged Experience](https://arxiv.org/abs/2409.13678), first in simulation and then on the real robot."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Flat Ground"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test In Simulation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from Go2Py.robot.fsm import FSM\n",
"from Go2Py.robot.remote import KeyboardRemote, XBoxRemote\n",
"from Go2Py.robot.safety import SafetyHypervisor\n",
"from Go2Py.sim.mujoco import Go2Sim\n",
"from Go2Py.control.cat import *\n",
"#from Go2Py.control.cat_rnn import *\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from Go2Py.robot.model import FrictionModel\n",
"\n",
"# Per-joint friction parameter vectors (12 entries each), all zero by default.\n",
"# They are only consumed by the commented-out FrictionModel(Fs=Fs, mu_v=mu_v) variant below.\n",
"Fs, mu_v = np.zeros(12), np.zeros(12)\n",
"\n",
"# Optional per-joint overrides, kept for reference (uncomment a pair to use it):\n",
"#mu_v[[2,5,8,11]] = np.array([0.2167, -0.0647, -0.0420, -0.0834])\n",
"#Fs[[2,5,8,11]] = np.array([1.5259, 1.2380, 0.8917, 2.2461])\n",
"#\n",
"#mu_v[[0,3,6,9]] = np.array([0., 0., 0., 0.])\n",
"#Fs[[0,3,6,9]] = np.array([1.5, 1.5, 1.5, 1.5])\n",
"#mu_v[[2,5,8,11]] = np.array([0., 0., 0., 0.])\n",
"#Fs[[2,5,8,11]] = np.array([1.5, 1.5, 1.5, 1.5])\n",
"\n",
"# Active configuration: scalar friction parameters.\n",
"# Alternatives that were tried:\n",
"#friction_model = FrictionModel(Fs=0., mu_v=0.)\n",
"#friction_model = FrictionModel(Fs=Fs, mu_v=mu_v)\n",
"friction_model = FrictionModel(Fs=1.5, mu_v=0.3)\n",
"\n",
"# Simulated robot stepping with dt = 0.001 (presumably seconds -- confirm in Go2Sim).\n",
"robot = Go2Sim(dt=0.001, friction_model=friction_model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Command source for the controller; KeyboardRemote() is the commented-out alternative.\n",
"remote = XBoxRemote() # KeyboardRemote()\n",
"# NOTE(review): presumably puts the simulated robot into a sitting pose before control starts -- confirm in Go2Sim.\n",
"robot.sitDownReset()\n",
"# The safety hypervisor is handed to the FSM together with the robot and the remote.\n",
"safety_hypervisor = SafetyHypervisor(robot)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def getRemote(remote):\n",
"    \"\"\"Read the remote's command vector and condition it for the policy.\n",
"\n",
"    Entries 0 and 1 are scaled by 0.6. Two independent dead-bands are then\n",
"    applied: entries 0-1 are zeroed when their Euclidean norm is at most 0.2,\n",
"    and entry 2 is zeroed when its magnitude is at most 0.2.\n",
"\n",
"    The object returned by ``remote.getCommands()`` is modified in place and\n",
"    returned.\n",
"    \"\"\"\n",
"    cmd = remote.getCommands()\n",
"    for axis in (0, 1):\n",
"        cmd[axis] *= 0.6\n",
"\n",
"    planar_is_small = np.linalg.norm(cmd[:2]) <= 0.2\n",
"    third_is_small = np.abs(cmd[2]) <= 0.2\n",
"    if planar_is_small:\n",
"        cmd[:2] = np.zeros_like(cmd[:2])\n",
"    if third_is_small:\n",
"        cmd[2] = 0\n",
"    return cmd"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class CaTController:\n",
"    \"\"\"Glue between the remote, the CaT agent and a policy checkpoint.\n",
"\n",
"    ``update`` has the ``(robot, remote)`` signature and is passed to the FSM as\n",
"    ``user_controller_callback``. Every entry of the ``info`` dict returned by\n",
"    ``agent.step`` is appended to ``hist_data`` under its key.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, robot, remote, checkpoint):\n",
"        self.robot, self.remote = robot, remote\n",
"        self.policy = Policy(checkpoint)\n",
"        self.command_profile = CommandInterface()\n",
"        self.agent = CaTAgent(self.command_profile, self.robot)\n",
"        # info key -> list with one value per call to update()\n",
"        self.hist_data = {}\n",
"\n",
"    def init(self):\n",
"        \"\"\"Reset the agent, clear the policy info and zero all velocity commands.\"\"\"\n",
"        self.obs = self.agent.reset()\n",
"        self.policy_info = {}\n",
"        for field in (\"yaw_vel_cmd\", \"x_vel_cmd\", \"y_vel_cmd\"):\n",
"            setattr(self.command_profile, field, 0.0)\n",
"\n",
"    def update(self, robot, remote):\n",
"        \"\"\"Run one control step: read commands, query the policy, step the agent.\n",
"\n",
"        The ``robot`` argument is unused; the agent holds the robot passed to\n",
"        ``__init__``. ``init`` runs lazily on the first call.\n",
"        \"\"\"\n",
"        if not hasattr(self, \"obs\"):\n",
"            self.init()\n",
"\n",
"        cmd = getRemote(remote)\n",
"        profile = self.command_profile\n",
"        # Remote entry 1 drives x, negated entry 0 drives y, negated entry 2 drives yaw.\n",
"        profile.yaw_vel_cmd = -cmd[2]\n",
"        profile.x_vel_cmd = cmd[1]\n",
"        profile.y_vel_cmd = -cmd[0]\n",
"\n",
"        self.obs = self.agent.get_obs()\n",
"        action = self.policy(self.obs, self.policy_info)\n",
"        _, self.ret, self.done, self.info = self.agent.step(action)\n",
"\n",
"        for key, value in self.info.items():\n",
"            self.hist_data.setdefault(key, []).append(value)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Show the current joint states as the cell output (quick check that the simulated robot returns data).\n",
"robot.getJointStates()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from Go2Py import ASSETS_PATH\n",
"import os\n",
"# Policy checkpoint selection: the single uncommented `checkpoint_path = ...` line below is the one loaded;\n",
"# the commented lines are other checkpoints under ASSETS_PATH/checkpoints/SoloParkour.\n",
"# what we tested\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/trainparamsconfigmax_epochs1500_taskenvlearnlimitsfoot_contact_force_rate60_soft_07-20-22-43.pt')\n",
"# new one\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/dof_vel_3_10-00-05-00.pt')\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/actuator_net_17-21-28-47.pt')\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/motor_friction_randomization_18-16-13-06.pt')\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/motor_friction_randomization_cornomove_18-16-41-52.pt')\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/motor_friction_randomization_morenoise_18-20-04-09.pt')\n",
"\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/dof_pos_nmove_20-23-40-47.pt')\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/dof_pos_nomove_acrate120_21-16-02-48.pt')\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/dof_pos_nomove_friction05125_21-19-46-53.pt')\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/nomove_4footcontact_21-22-30-50.pt')\n",
"##checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/HFE_1_21-23-55-16.pt')\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/faster_gait_22-16-56-56.pt')\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/no_max_airtime_22-19-29-20.pt')\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/stand_still_rew_4_RNN_LOG_19-17-28-53.pt')\n",
"\n",
"# Best policy, friction from 0.5 to 1.25\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/TEST1_no_max_airtime_airtime025_22-20-13-45.pt')\n",
"\n",
"# Friction from 0.15 to 0.8, friction motor model\n",
"checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/frictionRange015to08_23-22-28-13.pt')\n",
"\n",
"# Friction from 0.15 to 0.8, NO friction motor model\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/frictionRange015to08_noRandomizeMotorFriction_23-22-38-07.pt')\n",
"\n",
"controller = CaTController(robot, remote, checkpoint_path)\n",
"# The control period handed to the FSM is decimation * robot.dt; the plotting cells reuse\n",
"# the same product to build their time axes.\n",
"decimation = 20\n",
"# controller.update is registered as the user controller callback of the FSM.\n",
"fsm = FSM(robot, remote, safety_hypervisor, control_dT=decimation * robot.dt, user_controller_callback=controller.update)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Slippage Analysis"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import time  # explicit import: the sampling loop below calls time.sleep\n",
"\n",
"# Foot sites whose world-frame velocity is sampled; the column order of feet_vels follows this list.\n",
"sites = ['FR_foot', 'FL_foot', 'RR_foot', 'RL_foot']\n",
"# A foot reading above this value counts as \"in contact\" (same units as robot.getFootContact()).\n",
"CONTACT_THRESHOLD = 15\n",
"# Seconds slept between two consecutive samples.\n",
"SAMPLE_PERIOD = 0.01\n",
"\n",
"contacts = []\n",
"feet_vels = []\n",
"\n",
"# Keep sampling until entry 1 of the Xbox controller's digital_cmd reads truthy.\n",
"while not remote.xbox_controller.digital_cmd[1]:\n",
"    contacts.append(robot.getFootContact() > CONTACT_THRESHOLD)\n",
"    feet_vels.append([np.linalg.norm(robot.getFootVelInWorld(site)) for site in sites])\n",
"    time.sleep(SAMPLE_PERIOD)\n",
"\n",
"if contacts:\n",
"    feet_vels = np.stack(feet_vels)\n",
"    contacts = np.stack(contacts)\n",
"else:\n",
"    # np.stack raises ValueError on an empty list (stop flag already set on entry);\n",
"    # fall back to empty 2-D arrays so the plotting cell still runs.\n",
"    # NOTE(review): assumes getFootContact() yields one value per entry of `sites` -- confirm.\n",
"    feet_vels = np.empty((0, len(sites)))\n",
"    contacts = np.empty((0, len(sites)), dtype=bool)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Sample window (indices into the recorded arrays) shown in the figure.\n",
"start, end = 300, 1200\n",
"window = slice(start, end)\n",
"\n",
"# Column 0 of both recordings: contact flag and velocity norm of the first recorded foot.\n",
"for series in (contacts, feet_vels):\n",
"    plt.plot(series[window, 0])\n",
"plt.legend(['contact state', 'foot velocity'])\n",
"plt.grid(True)\n",
"plt.tight_layout()\n",
"plt.savefig('foot_slipping_fric0.2.png')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**To Do**\n",
"- Train a policy without any actuator friction and check the plots for friction 0.2 and 0.6 \n",
"- Do the same experiment for the walk-these-ways policy\n",
"- While testing the walk-these-ways policy, check the output of the adaptation module for various friction values. Is there any correlation?"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Close the simulation FSM before analysing controller.hist_data\n",
"# (presumably this stops the control loop -- confirm in Go2Py.robot.fsm).\n",
"fsm.close()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Foot Contact Analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Element [0, -1] of the first recorded \"body_pos\" sample.\n",
"# NOTE(review): presumably the body height (last position coordinate) -- confirm against the agent's info dict.\n",
"np.array(controller.hist_data[\"body_pos\"])[0, 0, -1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Element [0, -1] of every recorded \"body_pos\" sample, plotted against the sample index.\n",
"plt.plot(np.array(controller.hist_data[\"body_pos\"])[:, 0, -1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# History of the \"body_linear_vel\" info entry, one sample per controller update.\n",
"# NOTE(review): [:, 0, :, 0] keeps axis 0 (time) and axis 2 of the stacked history;\n",
"# confirm what axis 2 enumerates before interpreting the individual panels.\n",
"body_linear_vel = np.array(controller.hist_data[\"body_linear_vel\"])[:, 0, :, 0]\n",
"\n",
"# One panel per column of the extracted array, three panels per row.\n",
"n_series = body_linear_vel.shape[1]\n",
"n_cols = 3\n",
"n_rows = int(np.ceil(n_series / n_cols))\n",
"\n",
"# squeeze=False always yields a 2-D array of axes, so flatten() works for any grid size.\n",
"fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)\n",
"axes = axes.flatten()\n",
"\n",
"# Samples are spaced by the control period decimation * robot.dt.\n",
"time_axis = np.arange(body_linear_vel.shape[0]) * robot.dt * decimation\n",
"for i in range(n_series):\n",
"    axes[i].plot(time_axis, body_linear_vel[:, i])\n",
"    axes[i].set_title(f'Body Linear Velocity {i+1}')\n",
"    axes[i].set_xlabel('Time')\n",
"    axes[i].set_ylabel('Velocity Value')\n",
"    axes[i].grid(True)\n",
"\n",
"# Remove the panels left over when n_series is not a multiple of n_cols.\n",
"for j in range(n_series, len(axes)):\n",
"    fig.delaxes(axes[j])\n",
"\n",
"plt.tight_layout()\n",
"# Dedicated file name so this figure is not overwritten by the torque plot (\"torque_profile.png\").\n",
"plt.savefig(\"body_linear_vel_profile.png\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Recorded \"torques\" info entries stacked into one array:\n",
"# axis 0 is the control step, axis 1 indexes the individual torque channels.\n",
"torques = np.array(controller.hist_data[\"torques\"])\n",
"torque_nb = torques.shape[1]\n",
"\n",
"# Grid layout: three panels per row, as many rows as needed.\n",
"n_cols = 3\n",
"n_rows = int(np.ceil(torque_nb / n_cols))\n",
"fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))\n",
"axes = axes.flatten()\n",
"\n",
"# Samples are spaced by the control period decimation * robot.dt.\n",
"time_axis = np.arange(torques.shape[0]) * robot.dt * decimation\n",
"for i, ax in enumerate(axes[:torque_nb]):\n",
"    ax.plot(time_axis, torques[:, i])\n",
"    ax.set_title(f'Torque {i+1}')\n",
"    ax.set_xlabel('Time')\n",
"    ax.set_ylabel('Torque Value')\n",
"    ax.grid(True)\n",
"\n",
"# Drop the panels left over when torque_nb is not a multiple of n_cols.\n",
"for ax in axes[torque_nb:]:\n",
"    fig.delaxes(ax)\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig(\"torque_profile.png\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# History of the \"joint_vel\" info entry; index 0 on axis 1 is selected so that\n",
"# axis 0 is the control step and axis 1 the joint (presumably -- confirm the recorded shape).\n",
"# NOTE(review): the labels now match the \"joint_vel\" key that is read here; if joint\n",
"# positions were intended, change the key instead.\n",
"joint_vel = np.array(controller.hist_data[\"joint_vel\"])[:, 0]\n",
"\n",
"# Size the grid from the data itself rather than from a variable defined by another cell.\n",
"num_joints = joint_vel.shape[1]\n",
"n_cols = 3\n",
"n_rows = int(np.ceil(num_joints / n_cols))\n",
"\n",
"# squeeze=False always yields a 2-D array of axes, so flatten() works for any grid size.\n",
"fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)\n",
"axes = axes.flatten()\n",
"\n",
"# One panel per joint; the x axis is the sample index.\n",
"for i in range(num_joints):\n",
"    axes[i].plot(joint_vel[:, i])\n",
"    axes[i].set_title(f'Joint Velocity {i+1}')\n",
"    axes[i].set_xlabel('Time')\n",
"    axes[i].set_ylabel('Velocity Value')\n",
"\n",
"# Remove the panels left over when num_joints is not a multiple of n_cols.\n",
"for j in range(num_joints, len(axes)):\n",
"    fig.delaxes(axes[j])\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig(\"joint_velocity_profile.png\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Recorded \"foot_contact_forces_mag\" info entries stacked into one array:\n",
"# axis 0 is the control step, axis 1 indexes the feet.\n",
"foot_contact_forces_mag = np.array(controller.hist_data[\"foot_contact_forces_mag\"])\n",
"foot_nb = foot_contact_forces_mag.shape[1]\n",
"\n",
"# Grid layout: three panels per row, as many rows as needed.\n",
"n_cols = 3\n",
"n_rows = int(np.ceil(foot_nb / n_cols))\n",
"fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))\n",
"axes = axes.flatten()\n",
"\n",
"# One panel per foot; the x axis is the sample index.\n",
"for i, ax in enumerate(axes[:foot_nb]):\n",
"    ax.plot(foot_contact_forces_mag[:, i])\n",
"    ax.set_title(f'Foot {i+1} Contact Force Magnitude')\n",
"    ax.set_xlabel('Time')\n",
"    ax.set_ylabel('Force Magnitude')\n",
"\n",
"# Drop the panels left over when foot_nb is not a multiple of n_cols.\n",
"for ax in axes[foot_nb:]:\n",
"    fig.delaxes(ax)\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig(\"foot_contact_profile.png\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# History of the \"joint_acc\" info entry; index 0 on axis 1 is selected so that\n",
"# axis 0 is the control step and axis 1 the joint (presumably -- confirm the recorded shape).\n",
"joint_acc = np.array(controller.hist_data[\"joint_acc\"])[:, 0]\n",
"\n",
"# Number of joints, taken from the data (kept separate from the foot count used by the foot plots).\n",
"num_joints = joint_acc.shape[1]\n",
"\n",
"# Grid layout: three panels per row, as many rows as needed.\n",
"n_cols = 3\n",
"n_rows = int(np.ceil(num_joints / n_cols))\n",
"\n",
"# squeeze=False always yields a 2-D array of axes, so flatten() works for any grid size.\n",
"fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)\n",
"axes = axes.flatten()\n",
"\n",
"# Iterate over the joints, not over every grid slot: the grid can have more slots than\n",
"# joints, and indexing joint_acc with a slot index would then raise IndexError.\n",
"for i in range(num_joints):\n",
"    axes[i].plot(joint_acc[:, i])\n",
"    axes[i].set_title(f'Joint Acceleration {i+1}')\n",
"    axes[i].set_xlabel('Time')\n",
"    axes[i].set_ylabel('Acceleration Value')\n",
"\n",
"# Remove the panels left over when num_joints is not a multiple of n_cols.\n",
"for j in range(num_joints, len(axes)):\n",
"    fig.delaxes(axes[j])\n",
"\n",
"# Adjust layout to prevent overlap\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Recorded \"joint_jerk\" info entries with index 0 selected on axis 1:\n",
"# axis 0 is the control step, axis 1 is assumed to index the joints.\n",
"joint_jerk = np.array(controller.hist_data[\"joint_jerk\"])[:, 0]\n",
"n_data_points = len(joint_jerk)\n",
"num_joints = joint_jerk.shape[1]\n",
"\n",
"# Grid layout: three panels per row, as many rows as needed.\n",
"n_cols = 3\n",
"n_rows = int(np.ceil(num_joints / n_cols))\n",
"fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))\n",
"axes = axes.flatten()\n",
"\n",
"# One panel per joint; the x axis is the sample index.\n",
"for i, ax in enumerate(axes[:num_joints]):\n",
"    ax.plot(joint_jerk[:, i])\n",
"    ax.set_title(f'Joint Jerk {i+1}')\n",
"    ax.set_xlabel('Time')\n",
"    ax.set_ylabel('Jerk Value')\n",
"\n",
"# Drop the panels left over when num_joints is not a multiple of n_cols.\n",
"for ax in axes[num_joints:]:\n",
"    fig.delaxes(ax)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Recorded \"foot_contact_rate\" info entries with index 0 selected on axis 1:\n",
"# axis 0 is the control step, axis 1 is assumed to index the feet.\n",
"foot_contact_rate = np.array(controller.hist_data[\"foot_contact_rate\"])[:, 0]\n",
"n_data_points, num_feet = foot_contact_rate.shape[0], foot_contact_rate.shape[1]\n",
"\n",
"# Grid layout: three panels per row, as many rows as needed.\n",
"n_cols = 3\n",
"n_rows = int(np.ceil(num_feet / n_cols))\n",
"fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))\n",
"axes = axes.flatten()\n",
"\n",
"# One panel per foot; the x axis is the sample index.\n",
"for i, ax in enumerate(axes[:num_feet]):\n",
"    ax.plot(foot_contact_rate[:, i])\n",
"    ax.set_title(f'Foot Contact Rate {i+1}')\n",
"    ax.set_xlabel('Time')\n",
"    ax.set_ylabel('Contact Rate')\n",
"\n",
"# Drop the panels left over when num_feet is not a multiple of n_cols.\n",
"for ax in axes[num_feet:]:\n",
"    fig.delaxes(ax)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Test on Real Robot (ToDo)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from Go2Py.robot.fsm import FSM\n",
"from Go2Py.robot.remote import XBoxRemote\n",
"from Go2Py.robot.safety import SafetyHypervisor\n",
"from Go2Py.control.cat import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from Go2Py.robot.interface import GO2Real\n",
"import numpy as np\n",
"# Rebinds `robot` from the simulator instance to the real-robot interface in 'lowlevel' mode.\n",
"# Every cell below therefore talks to hardware.\n",
"robot = GO2Real(mode='lowlevel')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Command source for the controller; KeyboardRemote() is the commented-out alternative.\n",
"remote = XBoxRemote() # KeyboardRemote()\n",
"# The safety hypervisor is handed to the FSM together with the robot and the remote.\n",
"safety_hypervisor = SafetyHypervisor(robot)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Show the current joint states as the cell output (quick check that the robot interface returns data).\n",
"robot.getJointStates()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Make sure the robot can take commands from python. The next cell should make the joints free to move (no damping)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import time\n",
"\n",
"start_time = time.time()\n",
"\n",
"# Stream all-zero commands (zero targets, zero gains, zero feed-forward torque)\n",
"# for 30 s, pausing 0.02 s between consecutive commands.\n",
"while time.time() - start_time < 30:\n",
"    q, dq = np.zeros(12), np.zeros(12)\n",
"    kp, kd = np.zeros(12), np.zeros(12)\n",
"    tau = np.zeros(12)\n",
"    robot.setCommands(q, dq, kp, kd, tau)\n",
"    time.sleep(0.02)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"def getRemote(remote):\n",
"    \"\"\"Read the remote's command vector and condition it for the policy.\n",
"\n",
"    Entries 0 and 1 are scaled by 0.6. Two independent dead-bands are then\n",
"    applied, matching the definition used in the simulation section: entries\n",
"    0-1 are zeroed when their Euclidean norm is at most 0.2, and entry 2 is\n",
"    zeroed when its magnitude is at most 0.2.\n",
"\n",
"    The object returned by ``remote.getCommands()`` is modified in place and\n",
"    returned.\n",
"    \"\"\"\n",
"    commands = remote.getCommands()\n",
"    commands[0] *= 0.6\n",
"    commands[1] *= 0.6\n",
"\n",
"    # Each dead-band is evaluated on its own so that small stick noise on one\n",
"    # axis group is suppressed even while the other group is actively commanded.\n",
"    if np.linalg.norm(commands[:2]) <= 0.2:\n",
"        commands[:2] = np.zeros_like(commands[:2])\n",
"    if np.abs(commands[2]) <= 0.2:\n",
"        commands[2] = 0\n",
"    return commands"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"class CaTController:\n",
"    \"\"\"Glue between the remote, the CaT agent and a policy checkpoint.\n",
"\n",
"    ``update`` has the ``(robot, remote)`` signature and is passed to the FSM as\n",
"    ``user_controller_callback``. Every entry of the ``info`` dict returned by\n",
"    ``agent.step`` is appended to ``hist_data`` under its key.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, robot, remote, checkpoint):\n",
"        self.robot, self.remote = robot, remote\n",
"        self.policy = Policy(checkpoint)\n",
"        self.command_profile = CommandInterface()\n",
"        self.agent = CaTAgent(self.command_profile, self.robot)\n",
"        # info key -> list with one value per call to update()\n",
"        self.hist_data = {}\n",
"\n",
"    def init(self):\n",
"        \"\"\"Reset the agent, clear the policy info and zero all velocity commands.\"\"\"\n",
"        self.obs = self.agent.reset()\n",
"        self.policy_info = {}\n",
"        for field in (\"yaw_vel_cmd\", \"x_vel_cmd\", \"y_vel_cmd\"):\n",
"            setattr(self.command_profile, field, 0.0)\n",
"\n",
"    def update(self, robot, remote):\n",
"        \"\"\"Run one control step: read commands, query the policy, step the agent.\n",
"\n",
"        The ``robot`` argument is unused; the agent holds the robot passed to\n",
"        ``__init__``. ``init`` runs lazily on the first call.\n",
"        \"\"\"\n",
"        if not hasattr(self, \"obs\"):\n",
"            self.init()\n",
"\n",
"        cmd = getRemote(remote)\n",
"        profile = self.command_profile\n",
"        # Remote entry 1 drives x, negated entry 0 drives y, negated entry 2 drives yaw.\n",
"        profile.yaw_vel_cmd = -cmd[2]\n",
"        profile.x_vel_cmd = cmd[1]\n",
"        profile.y_vel_cmd = -cmd[0]\n",
"\n",
"        self.obs = self.agent.get_obs()\n",
"        action = self.policy(self.obs, self.policy_info)\n",
"        _, self.ret, self.done, self.info = self.agent.step(action)\n",
"\n",
"        for key, value in self.info.items():\n",
"            self.hist_data.setdefault(key, []).append(value)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from Go2Py import ASSETS_PATH \n",
"import os\n",
"# Policy checkpoint selection for the real robot: the single uncommented `checkpoint_path = ...`\n",
"# line below is the one loaded; the commented lines are other checkpoints under\n",
"# ASSETS_PATH/checkpoints/SoloParkour.\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/trainparamsconfigmax_epochs1500_taskenvlearnlimitsfoot_contact_force_rate60_soft_07-20-22-43.pt')\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/actuator_net_17-21-28-47.pt')\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/motor_friction_randomization_cornomove_18-16-41-52.pt')\n",
"\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/dof_pos_nmove_20-23-40-47.pt')\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/dof_pos_nomove_friction05125_21-19-46-53.pt')\n",
"\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/HAA_01_21-20-57-43.pt')\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/nomove_4footcontact_21-22-30-50.pt')\n",
"###checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/HFE_1_21-23-55-16.pt')\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/faster_gait_22-16-56-56.pt')\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/no_max_airtime_22-19-29-20.pt')\n",
"\n",
"# Best policy\n",
"checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/TEST1_no_max_airtime_airtime025_22-20-13-45.pt')\n",
"\n",
"#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/frictionRange015to08_23-22-28-13.pt')\n",
"\n",
"# Builds the controller around the real robot and the checkpoint selected above.\n",
"controller = CaTController(robot, remote, checkpoint_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Control period of 1/50 (0.02), the same value as decimation * robot.dt = 20 * 0.001 in the simulation section.\n",
"# controller.update is registered as the user controller callback of the FSM.\n",
"fsm = FSM(robot, remote, safety_hypervisor, control_dT=1./50., user_controller_callback=controller.update)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Close the FSM when the experiment is over\n",
"# (presumably this stops the control loop -- confirm in Go2Py.robot.fsm).\n",
"fsm.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "b1-env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}