diff --git a/examples/05-walk-these-ways-RL-controller.ipynb b/examples/05-walk-these-ways-RL-controller.ipynb index 97ed4c4..8a75b9e 100644 --- a/examples/05-walk-these-ways-RL-controller.ipynb +++ b/examples/05-walk-these-ways-RL-controller.ipynb @@ -1,91 +1,5 @@ { "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class walkTheseWaysController:\n", - " def __init__(self, robot, remote, checkpoint):\n", - " self.remote = remote\n", - " self.robot = robot\n", - " self.cfg = loadParameters(checkpoint)\n", - " self.policy = Policy(checkpoint)\n", - " self.command_profile = CommandInterface()\n", - " self.agent = WalkTheseWaysAgent(self.cfg, self.command_profile, self.robot)\n", - " self.agent = HistoryWrapper(self.agent)\n", - " self.init()\n", - "\n", - " def init(self):\n", - " self.obs = self.agent.reset()\n", - " self.policy_info = {}\n", - " self.command_profile.yaw_vel_cmd = 0.0\n", - " self.command_profile.x_vel_cmd = 0.0\n", - " self.command_profile.y_vel_cmd = 0.0\n", - " self.command_profile.stance_width_cmd=0.25\n", - " self.command_profile.footswing_height_cmd=0.08\n", - " self.command_profile.step_frequency_cmd = 3.0\n", - " self.command_profile.bodyHeight = 0.00\n", - "\n", - " def update(self, robot, remote):\n", - " action = self.policy(self.obs, self.policy_info)\n", - " self.obs, self.ret, self.done, self.info = self.agent.step(action)\n", - " vy = -robot.getRemoteState().lx\n", - " vx = robot.getRemoteState().ly\n", - " omega = -robot.getRemoteState().rx*2.2\n", - " self.command_profile.x_vel_cmd = vx*1.5\n", - " self.command_profile.y_vel_cmd = vy*1.5\n", - " self.command_profile.yaw_vel_cmd = omega" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class BaseRemote:\n", - " def __init__(self):\n", - " pass\n", - " def startSeq(self):\n", - " return False\n", - " def standUpDownSeq(self):\n", - " return False\n", - "\n", - " def flushStates(self):\n", - " pass\n", - "\n", - " def getEstop(self):\n", - " return False\n", - "\n", - "class UnitreeRemote(BaseRemote):\n", - " def __init__(self, robot):\n", - " self.robot = robot\n", - "\n", - " def startSeq(self):\n", - " remote = self.robot.getRemoteState()\n", - " if remote.btn.start:\n", - " return True\n", - " else:\n", - " return False\n", - "\n", - " def standUpDownSeq(self):\n", - " remote = self.robot.getRemoteState()\n", - " if remote.btn.L2 and remote.btn.A:\n", - " return True\n", - " else:\n", - " return False\n", - "\n", - " def getEstop(self):\n", - " remote = self.robot.getRemoteState()\n", - " if remote.btn.L2 and remote.btn.R2:\n", - " return True\n", - " else:\n", - " return False\n", - " " - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -132,14 +46,37 @@ "metadata": {}, "outputs": [], "source": [ - "controller.command_profile.pitch_cmd=0.0\n", - "controller.command_profile.body_height_cmd=0.0\n", - "controller.command_profile.footswing_height_cmd=0.08\n", - "controller.command_profile.roll_cmd=0.0\n", - "controller.command_profile.stance_width_cmd=0.2\n", - "controller.command_profile.x_vel_cmd=-0.2\n", - "controller.command_profile.y_vel_cmd=0.01\n", - "controller.command_profile.setGaitType(\"trotting\")" + "class walkTheseWaysController:\n", + " def __init__(self, robot, remote, checkpoint):\n", + " self.remote = remote\n", + " self.robot = robot\n", + " self.cfg = loadParameters(checkpoint)\n", + " self.policy = Policy(checkpoint)\n", + " self.command_profile = CommandInterface()\n", + " self.agent = WalkTheseWaysAgent(self.cfg, self.command_profile, self.robot)\n", + " self.agent = HistoryWrapper(self.agent)\n", + " self.init()\n", + "\n", + " def init(self):\n", + " self.obs = self.agent.reset()\n", + " self.policy_info = {}\n", + " self.command_profile.yaw_vel_cmd = 0.0\n", + " self.command_profile.x_vel_cmd = 0.0\n", + " self.command_profile.y_vel_cmd = 0.0\n", + " self.command_profile.stance_width_cmd=0.25\n", + " self.command_profile.footswing_height_cmd=0.08\n", + " self.command_profile.step_frequency_cmd = 3.0\n", + " self.command_profile.bodyHeight = 0.00\n", + "\n", + " def update(self, robot, remote):\n", + " action = self.policy(self.obs, self.policy_info)\n", + " self.obs, self.ret, self.done, self.info = self.agent.step(action)\n", + " vy = -robot.getRemoteState().lx\n", + " vx = robot.getRemoteState().ly\n", + " omega = -robot.getRemoteState().rx*2.2\n", + " self.command_profile.x_vel_cmd = vx*1.5\n", + " self.command_profile.y_vel_cmd = vy*1.5\n", + " self.command_profile.yaw_vel_cmd = omega" ] }, { @@ -153,6 +90,22 @@ "fsm = FSM(robot, remote, safety_hypervisor, user_controller_callback=controller.update)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "controller.command_profile.pitch_cmd=0.0\n", + "controller.command_profile.body_height_cmd=0.0\n", + "controller.command_profile.footswing_height_cmd=0.08\n", + "controller.command_profile.roll_cmd=0.0\n", + "controller.command_profile.stance_width_cmd=0.2\n", + "controller.command_profile.x_vel_cmd=-0.2\n", + "controller.command_profile.y_vel_cmd=0.01\n", + "controller.command_profile.setGaitType(\"trotting\")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -176,7 +129,7 @@ "outputs": [], "source": [ "from Go2Py.robot.fsm import FSM\n", - "from Go2Py.robot.remote import KeyboardRemote\n", + "from Go2Py.robot.remote import UnitreeRemote\n", "from Go2Py.robot.safety import SafetyHypervisor\n", "from Go2Py.control.walk_these_ways import *" ] @@ -187,7 +140,7 @@ "metadata": {}, "outputs": [], "source": [ - "from Go2Py.robot.interface.dds import GO2Real\n", + "from Go2Py.robot.interface import GO2Real\n", "import numpy as np\n", "robot = GO2Real(mode='lowlevel')" ] @@ -202,6 +155,45 @@ "safety_hypervisor = SafetyHypervisor(robot)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class walkTheseWaysController:\n", + " def __init__(self, robot, remote, checkpoint):\n", + " self.remote = remote\n", + " self.robot = robot\n", + " self.cfg = loadParameters(checkpoint)\n", + " self.policy = Policy(checkpoint)\n", + " self.command_profile = CommandInterface()\n", + " self.agent = WalkTheseWaysAgent(self.cfg, self.command_profile, self.robot)\n", + " self.agent = HistoryWrapper(self.agent)\n", + " self.init()\n", + "\n", + " def init(self):\n", + " self.obs = self.agent.reset()\n", + " self.policy_info = {}\n", + " self.command_profile.yaw_vel_cmd = 0.0\n", + " self.command_profile.x_vel_cmd = 0.0\n", + " self.command_profile.y_vel_cmd = 0.0\n", + " self.command_profile.stance_width_cmd=0.25\n", + " self.command_profile.footswing_height_cmd=0.08\n", + " self.command_profile.step_frequency_cmd = 3.0\n", + " self.command_profile.bodyHeight = 0.00\n", + "\n", + " def update(self, robot, remote):\n", + " action = self.policy(self.obs, self.policy_info)\n", + " self.obs, self.ret, self.done, self.info = self.agent.step(action)\n", + " vy = -robot.getRemoteState().lx\n", + " vx = robot.getRemoteState().ly\n", + " omega = -robot.getRemoteState().rx*2.2\n", + " self.command_profile.x_vel_cmd = vx*1.5\n", + " self.command_profile.y_vel_cmd = vy*1.5\n", + " self.command_profile.yaw_vel_cmd = omega" + ] + }, { "cell_type": "code", "execution_count": null,