{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "RL policy based on [SoloParkour: Constrained Reinforcement Learning for Visual Locomotion from Privileged Experience](https://arxiv.org/abs/2409.13678)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Flat Ground" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test In Simulation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from Go2Py.robot.fsm import FSM\n", "from Go2Py.robot.remote import KeyboardRemote, XBoxRemote\n", "from Go2Py.robot.safety import SafetyHypervisor\n", "from Go2Py.sim.mujoco import Go2Sim\n", "from Go2Py.control.cat import *\n", "import torch" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from Go2Py.robot.model import FrictionModel\n", "friction_model = FrictionModel(Fs=3, mu_v=0.05)\n", "robot = Go2Sim(dt=0.001, friction_model=friction_model)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "remote = XBoxRemote() # KeyboardRemote()\n", "robot.sitDownReset()\n", "safety_hypervisor = SafetyHypervisor(robot)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class CaTController:\n", "    def __init__(self, robot, remote, checkpoint):\n", "        self.remote = remote\n", "        self.robot = robot\n", "        self.policy = Policy(checkpoint)\n", "        self.command_profile = CommandInterface()\n", "        self.agent = CaTAgent(self.command_profile, self.robot)\n", "        self.hist_data = {}\n", "\n", "    def init(self):\n", "        self.obs = self.agent.reset()\n", "        self.policy_info = {}\n", "        self.command_profile.yaw_vel_cmd = 0.0\n", "        self.command_profile.x_vel_cmd = 0.0\n", "        self.command_profile.y_vel_cmd = 0.0\n", "\n", "    def update(self, robot, remote):\n", "        # Reset the agent lazily on the first call\n", "        if not hasattr(self, \"obs\"):\n", "            self.init()\n", "        # Map the remote axes to velocity commands\n", "        commands = remote.getCommands()\n", "        self.command_profile.yaw_vel_cmd = 
-commands[2]\n", "        self.command_profile.x_vel_cmd = commands[1] * 0.6\n", "        self.command_profile.y_vel_cmd = -commands[0] * 0.6\n", "\n", "        # Query the policy and advance the agent by one control step\n", "        action = self.policy(self.obs, self.policy_info)\n", "        self.obs, self.ret, self.done, self.info = self.agent.step(action)\n", "        # Log every info field so it can be plotted later\n", "        for key, value in self.info.items():\n", "            if key in self.hist_data:\n", "                self.hist_data[key].append(value)\n", "            else:\n", "                self.hist_data[key] = [value]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "robot.getJointStates()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from Go2Py import ASSETS_PATH\n", "import os\n", "# Previously tested checkpoint\n", "#checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/trainparamsconfigmax_epochs1500_taskenvlearnlimitsfoot_contact_force_rate60_soft_07-20-22-43.pt')\n", "# Newer checkpoint\n", "checkpoint_path = os.path.join(ASSETS_PATH, 'checkpoints/SoloParkour/dof_vel_3_10-00-05-00.pt')\n", "controller = CaTController(robot, remote, checkpoint_path)\n", "# The policy runs once every `decimation` simulation steps\n", "decimation = 20\n", "fsm = FSM(robot, remote, safety_hypervisor, control_dT=decimation * robot.dt, user_controller_callback=controller.update)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "remote.x_vel_cmd=0.6\n", "remote.y_vel_cmd=0.0\n", "remote.yaw_vel_cmd = 0.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Pressing `u` on the keyboard makes the robot stand up. This is equivalent to the `L2+A` combo of the Go2 built-in state machine. Once the robot is on its feet, pressing `s` hands control over to the RL policy. This is equivalent to the `start` key of the built-in controller. When you want to stop, pressing `u` again locks the robot in standing mode, as on the real robot. Finally, pressing `u` once more commands the robot to sit down." 
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "fsm.close()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "# 'controller.hist_data[\"body_linear_vel\"]' holds one body linear velocity sample per control step\n", "body_lin_vel = np.array(controller.hist_data[\"body_linear_vel\"])[:, 0, :, 0]\n", "\n", "# Number of velocity components\n", "vel_nb = body_lin_vel.shape[1]\n", "\n", "# Number of rows needed for the grid, with 3 columns per row\n", "n_cols = 3\n", "n_rows = int(np.ceil(vel_nb / n_cols))\n", "\n", "# Create the figure and axes for subplots\n", "fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))\n", "\n", "# Flatten the axes array for easy indexing (in case of multiple rows)\n", "axes = axes.flatten()\n", "\n", "# Plot each velocity component against time in seconds\n", "for i in range(vel_nb):\n", "    axes[i].plot(np.arange(body_lin_vel.shape[0]) * robot.dt * decimation, body_lin_vel[:, i])\n", "    axes[i].set_title(f'Body Linear Velocity {i+1}')\n", "    axes[i].set_xlabel('Time (s)')\n", "    axes[i].set_ylabel('Velocity')\n", "    axes[i].grid(True)\n", "\n", "# Remove any empty subplots if vel_nb is not a multiple of 3\n", "for j in range(vel_nb, len(axes)):\n", "    fig.delaxes(axes[j])\n", "\n", "# Adjust layout\n", "plt.tight_layout()\n", "plt.savefig(\"body_linear_vel_profile.png\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "# 'controller.hist_data[\"torques\"]' holds one torque sample per control step\n", "torques = np.array(controller.hist_data[\"torques\"])\n", "\n", "# Number of torque profiles\n", "torque_nb = torques.shape[1]\n", "\n", "# Number of rows needed for the grid, with 3 columns per row\n", "n_cols = 3\n", "n_rows = int(np.ceil(torque_nb / n_cols))\n", "\n", "# Create the figure and axes for subplots\n", "fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * 
n_rows))\n", "\n", "# Flatten the axes array for easy indexing (in case of multiple rows)\n", "axes = axes.flatten()\n", "\n", "# Plot each torque profile against time in seconds\n", "for i in range(torque_nb):\n", "    axes[i].plot(np.arange(torques.shape[0]) * robot.dt * decimation, torques[:, i])\n", "    axes[i].set_title(f'Torque {i+1}')\n", "    axes[i].set_xlabel('Time (s)')\n", "    axes[i].set_ylabel('Torque Value')\n", "    axes[i].grid(True)\n", "\n", "# Remove any empty subplots if torque_nb is not a multiple of 3\n", "for j in range(torque_nb, len(axes)):\n", "    fig.delaxes(axes[j])\n", "\n", "# Adjust layout\n", "plt.tight_layout()\n", "plt.savefig(\"torque_profile.png\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Extract the joint position data over time\n", "joint_pos = np.array(controller.hist_data[\"joint_pos\"])[:, 0]\n", "\n", "# Number of joints (assuming the second dimension corresponds to joints)\n", "num_joints = joint_pos.shape[1]\n", "\n", "# Number of rows needed for the grid, with 3 columns per row\n", "n_cols = 3\n", "n_rows = int(np.ceil(num_joints / n_cols))\n", "\n", "# Create the figure and axes for subplots\n", "fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))\n", "\n", "# Flatten the axes array for easy indexing (in case of multiple rows)\n", "axes = axes.flatten()\n", "\n", "# Plot the joint position data for each joint\n", "for i in range(num_joints):\n", "    axes[i].plot(joint_pos[:, i])\n", "    axes[i].set_title(f'Joint Position {i+1}')\n", "    axes[i].set_xlabel('Time')\n", "    axes[i].set_ylabel('Position Value')\n", "\n", "# Remove any empty subplots if num_joints is not a multiple of 3\n", "for j in range(num_joints, len(axes)):\n", "    fig.delaxes(axes[j])\n", "\n", "# Adjust layout\n", "plt.tight_layout()\n", "plt.savefig(\"joint_position_profile.png\")\n", "plt.show()" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "# Assuming 'controller.hist_data[\"foot_contact_forces_mag\"]' is a dictionary with foot contact force magnitudes\n", "foot_contact_forces_mag = np.array(controller.hist_data[\"foot_contact_forces_mag\"])\n", "\n", "# Number of feet (foot_nb)\n", "foot_nb = foot_contact_forces_mag.shape[1]\n", "\n", "# Number of rows needed for the grid, with 3 columns per row\n", "n_cols = 3\n", "n_rows = int(np.ceil(foot_nb / n_cols))\n", "\n", "# Create the figure and axes for subplots\n", "fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))\n", "\n", "# Flatten the axes array for easy indexing (in case of multiple rows)\n", "axes = axes.flatten()\n", "\n", "# Plot each foot's contact force magnitude\n", "for i in range(foot_nb):\n", " axes[i].plot(foot_contact_forces_mag[:, i])\n", " axes[i].set_title(f'Foot {i+1} Contact Force Magnitude')\n", " axes[i].set_xlabel('Time')\n", " axes[i].set_ylabel('Force Magnitude')\n", "\n", "# Remove any empty subplots if foot_nb is not a multiple of 3\n", "for j in range(foot_nb, len(axes)):\n", " fig.delaxes(axes[j])\n", "\n", "# Adjust layout\n", "plt.tight_layout()\n", "plt.savefig(\"foot_contact_profile.png\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Extract the joint acceleration data for the first joint over time\n", "joint_acc = np.array(controller.hist_data[\"joint_acc\"])[:, 0]\n", "\n", "# Number of data points in joint_acc\n", "n_data_points = len(joint_acc)\n", "\n", "# Number of feet (foot_nb)\n", "foot_nb = joint_acc.shape[1]\n", "\n", "# Number of rows needed for the grid, with 3 columns per row\n", "n_cols = 3\n", "n_rows = int(np.ceil(foot_nb / n_cols))\n", "\n", "# Create the figure and axes for subplots\n", "fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))\n", "\n", "# Flatten the axes array for easy indexing\n", 
"axes = axes.flatten()\n", "\n", "# Plot the same joint acceleration data in every subplot (as per grid requirement)\n", "for i in range(n_rows * n_cols):\n", " axes[i].plot(joint_acc[:, i])\n", " axes[i].set_title(f'Joint Acceleration {i+1}')\n", " axes[i].set_xlabel('Time')\n", " axes[i].set_ylabel('Acceleration Value')\n", "\n", "# Adjust layout to prevent overlap\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Extract the joint jerk data over time\n", "joint_jerk = np.array(controller.hist_data[\"joint_jerk\"])[:, 0]\n", "\n", "# Number of data points in joint_jerk\n", "n_data_points = len(joint_jerk)\n", "\n", "# Number of joints (assuming the second dimension corresponds to joints)\n", "num_joints = joint_jerk.shape[1]\n", "\n", "# Number of columns per row in the subplot grid\n", "n_cols = 3\n", "# Number of rows needed for the grid\n", "n_rows = int(np.ceil(num_joints / n_cols))\n", "\n", "# Create the figure and axes for subplots\n", "fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))\n", "\n", "# Flatten the axes array for easy indexing\n", "axes = axes.flatten()\n", "\n", "# Plot the joint jerk data for each joint\n", "for i in range(num_joints):\n", " axes[i].plot(joint_jerk[:, i])\n", " axes[i].set_title(f'Joint Jerk {i+1}')\n", " axes[i].set_xlabel('Time')\n", " axes[i].set_ylabel('Jerk Value')\n", "\n", "# Hide any unused subplots\n", "for i in range(num_joints, len(axes)):\n", " fig.delaxes(axes[i])\n", "\n", "# Adjust layout to prevent overlap\n", "plt.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Extract the foot contact rate data over time\n", "foot_contact_rate = np.array(controller.hist_data[\"foot_contact_rate\"])[:, 0]\n", "\n", "# Number of data points in foot_contact_rate\n", "n_data_points = foot_contact_rate.shape[0]\n", "\n", "# Number of feet 
(assuming the second dimension corresponds to feet)\n", "num_feet = foot_contact_rate.shape[1]\n", "\n", "# Number of columns per row in the subplot grid\n", "n_cols = 3\n", "# Number of rows needed for the grid\n", "n_rows = int(np.ceil(num_feet / n_cols))\n", "\n", "# Create the figure and axes for subplots\n", "fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))\n", "\n", "# Flatten the axes array for easy indexing\n", "axes = axes.flatten()\n", "\n", "# Plot the foot contact rate data for each foot\n", "for i in range(num_feet):\n", "    axes[i].plot(foot_contact_rate[:, i])\n", "    axes[i].set_title(f'Foot Contact Rate {i+1}')\n", "    axes[i].set_xlabel('Time')\n", "    axes[i].set_ylabel('Contact Rate')\n", "\n", "# Hide any unused subplots\n", "for i in range(num_feet, len(axes)):\n", "    fig.delaxes(axes[i])\n", "\n", "# Adjust layout to prevent overlap\n", "plt.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test on Real Robot (ToDo)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from Go2Py.robot.fsm import FSM\n", "from Go2Py.robot.remote import XBoxRemote\n", "from Go2Py.robot.safety import SafetyHypervisor\n", "from Go2Py.control.cat import *" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "from Go2Py.robot.interface import GO2Real\n", "import numpy as np\n", "robot = GO2Real(mode='lowlevel')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "robot.getJointStates()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Make sure the robot can take commands from Python. The next cell sends zero gains and zero torques for 10 seconds, so the joints should be free to move (no damping)." 
] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import time\n", "start_time = time.time()\n", "\n", "# Send zero gains and zero torques at roughly 50 Hz for 10 seconds\n", "while time.time()-start_time < 10:\n", "    q = np.zeros(12)\n", "    dq = np.zeros(12)\n", "    kp = np.ones(12)*0.0\n", "    kd = np.ones(12)*0.0\n", "    tau = np.zeros(12)\n", "    tau[0] = 0.0\n", "    robot.setCommands(q, dq, kp, kd, tau)\n", "    time.sleep(0.02)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "remote = XBoxRemote() # KeyboardRemote()\n", "safety_hypervisor = SafetyHypervisor(robot)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class CaTController:\n", "    def __init__(self, robot, remote, checkpoint):\n", "        self.remote = remote\n", "        self.robot = robot\n", "        self.policy = Policy(checkpoint)\n", "        self.command_profile = CommandInterface()\n", "        self.agent = CaTAgent(self.command_profile, self.robot)\n", "        self.init()\n", "        self.hist_data = {}\n", "\n", "    def init(self):\n", "        self.obs = self.agent.reset()\n", "        self.policy_info = {}\n", "        self.command_profile.yaw_vel_cmd = 0.0\n", "        self.command_profile.x_vel_cmd = 0.0\n", "        self.command_profile.y_vel_cmd = 0.0\n", "\n", "    def update(self, robot, remote):\n", "        # Map the remote axes to velocity commands; backward speed is limited to -0.3\n", "        commands = remote.getCommands()\n", "        self.command_profile.yaw_vel_cmd = -commands[2]\n", "        self.command_profile.x_vel_cmd = max(commands[1] * 0.5, -0.3)\n", "        self.command_profile.y_vel_cmd = -commands[0]\n", "\n", "        # Query the policy and advance the agent by one control step\n", "        action = self.policy(self.obs, self.policy_info)\n", "        self.obs, self.ret, self.done, self.info = self.agent.step(action)\n", "        # Log every info field so it can be plotted later\n", "        for key, value in self.info.items():\n", "            if key in self.hist_data:\n", "                self.hist_data[key].append(value)\n", "            else:\n", "                self.hist_data[key] = [value]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from Go2Py import ASSETS_PATH\n", "import os\n", "checkpoint_path = os.path.join(ASSETS_PATH, 
'checkpoints/SoloParkour/trainparamsconfigmax_epochs1500_taskenvlearnlimitsfoot_contact_force_rate60_soft_07-20-22-43.pt')\n", "controller = CaTController(robot, remote, checkpoint_path)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "fsm = FSM(robot, remote, safety_hypervisor, control_dT=1./50., user_controller_callback=controller.update)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "fsm.close()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "b1-env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }