diff --git a/src/rl_sar/include/rl_real_go2.hpp b/src/rl_sar/include/rl_real_go2.hpp new file mode 100644 index 0000000..84e51c7 --- /dev/null +++ b/src/rl_sar/include/rl_real_go2.hpp @@ -0,0 +1,97 @@ +#ifndef RL_REAL_HPP +#define RL_REAL_HPP + +#include "rl_sdk.hpp" +#include "observation_buffer.hpp" +#include "loop.hpp" +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "matplotlibcpp.h" +#include "unitree_joystick.h" +namespace plt = matplotlibcpp; + +using namespace unitree::common; +using namespace unitree::robot; +using namespace unitree::robot::go2; + +#define TOPIC_LOWCMD "rt/lowcmd" +#define TOPIC_LOWSTATE "rt/lowstate" + +constexpr double PosStopF = (2.146E+9f); +constexpr double VelStopF = (16000.0f); + + +class RL_Real : public RL +{ +public: + RL_Real(); + ~RL_Real(); + void Init(); + void InitRobotStateClient(); + int queryServiceStatus(const std::string& serviceName); + void activateService(const std::string& serviceName,int activate); +private: + void LowCmdwriteHandler(); + uint32_t crc32_core(uint32_t* ptr, uint32_t len); + void InitLowCmd(); + void LowStateMessageHandler(const void* messages); + // rl functions + torch::Tensor Forward() override; + void GetState(RobotState *state) override; + void SetCommand(const RobotCommand *command) override; + void RunModel(); + void RobotControl(); + + // history buffer + ObservationBuffer history_obs_buf; + torch::Tensor history_obs; + + // loop + std::shared_ptr loop_keyboard; + std::shared_ptr loop_control; + std::shared_ptr loop_udpSend; + std::shared_ptr loop_udpRecv; + std::shared_ptr loop_rl; + std::shared_ptr loop_plot; + + // plot + const int plot_size = 100; + std::vector plot_t; + std::vector> plot_real_joint_pos, plot_target_joint_pos; + + + void Plot(); + + // unitree interface + + RobotStateClient rsc; + + unitree_go::msg::dds_::LowCmd_ unitree_low_command{}; // default init + unitree_go::msg::dds_::LowState_ unitree_low_state{}; // default init + + /*publisher*/ + ChannelPublisherPtr lowcmd_publisher; + /*subscriber*/ + ChannelSubscriberPtr lowstate_subscriber; + + xRockerBtnDataStruct unitree_joy; + + + + // others + int motiontime = 0; + std::vector mapped_joint_positions; + std::vector mapped_joint_velocities; + int command_mapping[12] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8}; + int state_mapping[12] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8}; +}; + +#endif \ No newline at end of file diff --git a/src/rl_sar/include/unitree_joystick.h b/src/rl_sar/include/unitree_joystick.h new file mode 100644 index 0000000..5ea7b02 --- /dev/null +++ b/src/rl_sar/include/unitree_joystick.h @@ -0,0 +1,41 @@ +#ifndef UNITREE_JOYSTICK_H +#define UNITREE_JOYSTICK_H + +#include +// 16b +typedef union { + struct { + uint8_t R1 :1; + uint8_t L1 :1; + uint8_t start :1; + uint8_t select :1; + uint8_t R2 :1; + uint8_t L2 :1; + uint8_t F1 :1; + uint8_t F2 :1; + uint8_t A :1; + uint8_t B :1; + uint8_t X :1; + uint8_t Y :1; + uint8_t up :1; + uint8_t right :1; + uint8_t down :1; + uint8_t left :1; + } components; + uint16_t value; +} xKeySwitchUnion; + +// 40 Byte (now used 24B) +typedef struct { + uint8_t head[2]; + xKeySwitchUnion btn; + float lx; + float rx; + float ry; + float L2; + float ly; + + uint8_t idle[16]; +} xRockerBtnDataStruct; + +#endif // UNITREE_JOYSTICK_H \ No newline at end of file diff --git a/src/rl_sar/src/rl_real_go2.cpp b/src/rl_sar/src/rl_real_go2.cpp new file mode 100644 index 0000000..6b34e52 --- /dev/null +++ b/src/rl_sar/src/rl_real_go2.cpp @@ -0,0 +1,376 @@ +#include "../include/rl_real_go2.hpp" + +// #define PLOT +// #define CSV_LOGGER + + + +uint32_t RL_Real::crc32_core(uint32_t* ptr, uint32_t len) +{ + unsigned int xbit = 0; + unsigned int data = 0; + unsigned int CRC32 = 0xFFFFFFFF; + const unsigned int dwPolynomial = 0x04c11db7; + + for (unsigned int i = 0; i < len; i++) + { + xbit = 1 << 31; + data = ptr[i]; + for (unsigned int bits = 0; bits < 32; bits++) + { + if (CRC32 & 0x80000000) + { + CRC32 <<= 1; + CRC32 ^= dwPolynomial; + } + else + { + CRC32 <<= 1; + } + + if (data & xbit) + CRC32 ^= dwPolynomial; + xbit >>= 1; + } + } + + return CRC32; +} + +void RL_Real::Init() +{ + InitLowCmd(); + + /*create publisher*/ + lowcmd_publisher.reset(new ChannelPublisher(TOPIC_LOWCMD)); + lowcmd_publisher->InitChannel(); + + /*create subscriber*/ + lowstate_subscriber.reset(new ChannelSubscriber(TOPIC_LOWSTATE)); + lowstate_subscriber->InitChannel(std::bind(&RL_Real::LowStateMessageHandler, this, std::placeholders::_1), 1); + + /*loop publishing thread*/ + //lowCmdWriteThreadPtr = CreateRecurrentThreadEx("writebasiccmd", UT_CPU_ID_NONE, 2000, &RL_Real::LowCmdwriteHandler, this); + + // read params from yaml + this->robot_name = "go2_isaacgym"; + this->ReadYaml(this->robot_name); + + // history + this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); + + + // init + torch::autograd::GradMode::set_enabled(false); + this->InitObservations(); + this->InitOutputs(); + this->InitControl(); + + // model + std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + this->robot_name + "/" + this->params.model_name; + this->model = torch::jit::load(model_path); + + // loop + // this->loop_udpSend = std::make_shared("loop_udpSend" , 0.002, std::bind(&RL_Real::UDPSend, this), 3); + // this->loop_udpRecv = std::make_shared("loop_udpRecv" , 0.002, std::bind(&RL_Real::UDPRecv, this), 3); + this->loop_keyboard = std::make_shared("loop_keyboard", 0.05, std::bind(&RL_Real::KeyboardInterface, this)); + this->loop_control = std::make_shared("loop_control", this->params.dt, std::bind(&RL_Real::RobotControl, this)); + this->loop_rl = std::make_shared("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Real::RunModel, this)); + // this->loop_udpSend->start(); + // this->loop_udpRecv->start(); + this->loop_keyboard->start(); + this->loop_control->start(); + this->loop_rl->start(); + + +#ifdef PLOT + this->plot_t = std::vector(this->plot_size, 0); + this->plot_real_joint_pos.resize(this->params.num_of_dofs); + this->plot_target_joint_pos.resize(this->params.num_of_dofs); + for(auto& vector : this->plot_real_joint_pos) { vector = std::vector(this->plot_size, 0); } + for(auto& vector : this->plot_target_joint_pos) { vector = std::vector(this->plot_size, 0); } + this->loop_plot = std::make_shared("loop_plot" , 0.002, std::bind(&RL_Real::Plot, this)); + this->loop_plot->start(); +#endif +#ifdef CSV_LOGGER + this->CSVInit(this->robot_name); +#endif + +} + + + +void RL_Real::InitLowCmd() +{ + unitree_low_command.head()[0] = 0xFE; + unitree_low_command.head()[1] = 0xEF; + unitree_low_command.level_flag() = 0xFF; + unitree_low_command.gpio() = 0; + + for(int i=0; i<20; i++) + { + unitree_low_command.motor_cmd()[i].mode() = (0x01); // motor switch to servo (PMSM) mode + unitree_low_command.motor_cmd()[i].q() = (PosStopF); + unitree_low_command.motor_cmd()[i].kp() = (0); + unitree_low_command.motor_cmd()[i].dq() = (VelStopF); + unitree_low_command.motor_cmd()[i].kd() = (0); + unitree_low_command.motor_cmd()[i].tau() = (0); + } +} + +void RL_Real::InitRobotStateClient() +{ + rsc.SetTimeout(10.0f); + rsc.Init(); +} + +int RL_Real::queryServiceStatus(const std::string& serviceName) +{ + std::vector serviceStateList; + int ret,serviceStatus; + ret = rsc.ServiceList(serviceStateList); + size_t i, count=serviceStateList.size(); + for (i=0; iunitree_low_command.crc() = crc32_core((uint32_t *)&unitree_low_command, (sizeof(unitree_go::msg::dds_::LowCmd_)>>2)-1); + // this->lowcmd_publisher->Write(unitree_low_command); +} + + +RL_Real::RL_Real(){ + +} + +RL_Real::~RL_Real() +{ + // this->loop_udpSend->shutdown(); + // this->loop_udpRecv->shutdown(); + this->loop_keyboard->shutdown(); + this->loop_control->shutdown(); + this->loop_rl->shutdown(); +#ifdef PLOT + this->loop_plot->shutdown(); +#endif + std::cout << LOGGER::INFO << "RL_Real exit" << std::endl; +} + +void RL_Real::GetState(RobotState *state) +{ + //! 这里是不是要加锁? + memcpy(&this->unitree_joy, &this->unitree_low_state.wireless_remote()[0], 40); + + if((int)this->unitree_joy.btn.components.R2 == 1) + { + this->control.control_state = STATE_POS_GETUP; + } + else if((int)this->unitree_joy.btn.components.R1 == 1) + { + this->control.control_state = STATE_RL_INIT; + } + else if((int)this->unitree_joy.btn.components.L2 == 1) + { + this->control.control_state = STATE_POS_GETDOWN; + } + + if(this->params.framework == "isaacgym") + { + state->imu.quaternion[3] = this->unitree_low_state.imu_state().quaternion()[0]; // w + state->imu.quaternion[0] = this->unitree_low_state.imu_state().quaternion()[1]; // x + state->imu.quaternion[1] = this->unitree_low_state.imu_state().quaternion()[2]; // y + state->imu.quaternion[2] = this->unitree_low_state.imu_state().quaternion()[3]; // z + } + else if(this->params.framework == "isaacsim") + { + state->imu.quaternion[0] = this->unitree_low_state.imu_state().quaternion()[0]; // w + state->imu.quaternion[1] = this->unitree_low_state.imu_state().quaternion()[1]; // x + state->imu.quaternion[2] = this->unitree_low_state.imu_state().quaternion()[2]; // y + state->imu.quaternion[3] = this->unitree_low_state.imu_state().quaternion()[3]; // z + } + + for(int i = 0; i < 3; ++i) + { + state->imu.gyroscope[i] = this->unitree_low_state.imu_state().gyroscope()[i]; + } + for(int i = 0; i < this->params.num_of_dofs; ++i) + { + state->motor_state.q[i] = this->unitree_low_state.motor_state()[state_mapping[i]].q(); + state->motor_state.dq[i] = this->unitree_low_state.motor_state()[state_mapping[i]].dq(); + state->motor_state.tauEst[i] = this->unitree_low_state.motor_state()[state_mapping[i]].tau_est(); + } + + //std::cout<<"(int)this->unitree_joy.btn.components.R2: "<<(int)this->unitree_joy.btn.components.R2<unitree_low_state.imu_state().quaternion()[0]<<", "<< this->unitree_low_state.imu_state().quaternion()[1]<<", "<< this->unitree_low_state.imu_state().quaternion()[2]<<", "<< this->unitree_low_state.imu_state().quaternion()[3]< *command) +{ + for(int i = 0; i < this->params.num_of_dofs; ++i) + { + this->unitree_low_command.motor_cmd()[i].mode() = 0x01; + this->unitree_low_command.motor_cmd()[i].q() = command->motor_command.q[command_mapping[i]]; + this->unitree_low_command.motor_cmd()[i].dq() = command->motor_command.dq[command_mapping[i]]; + this->unitree_low_command.motor_cmd()[i].kp() = command->motor_command.kp[command_mapping[i]]; + this->unitree_low_command.motor_cmd()[i].kd() = command->motor_command.kd[command_mapping[i]]; + this->unitree_low_command.motor_cmd()[i].tau() = command->motor_command.tau[command_mapping[i]]; + //std::cout<<"q: "<motor_command.q[command_mapping[i]]<unitree_low_command.crc() = crc32_core((uint32_t *)&unitree_low_command, (sizeof(unitree_go::msg::dds_::LowCmd_)>>2)-1); + lowcmd_publisher->Write(unitree_low_command); + + +} + +void RL_Real::RobotControl() +{ + this->motiontime++; + + this->GetState(&this->robot_state); + this->StateController(&this->robot_state, &this->robot_command); + this->SetCommand(&this->robot_command); +} + +void RL_Real::RunModel() +{ + if(this->running_state == STATE_RL_RUNNING) + { + this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0); + // this->obs.commands = torch::tensor({{this->unitree_joy.ly, -this->unitree_joy.rx, -this->unitree_joy.lx}}); + this->obs.commands = torch::tensor({{this->control.x , this->control.y, this->control.yaw}}); + this->obs.base_quat = torch::tensor(this->robot_state.imu.quaternion).unsqueeze(0); + this->obs.dof_pos = torch::tensor(this->robot_state.motor_state.q).narrow(0, 0, this->params.num_of_dofs).unsqueeze(0); + this->obs.dof_vel = torch::tensor(this->robot_state.motor_state.dq).narrow(0, 0, this->params.num_of_dofs).unsqueeze(0); + + torch::Tensor clamped_actions = this->Forward(); + + for (int i : this->params.hip_scale_reduction_indices) + { + clamped_actions[0][i] *= this->params.hip_scale_reduction; + } + + this->obs.actions = clamped_actions; + + torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions); + + this->TorqueProtect(origin_output_torques); + + this->output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits); + this->output_dof_pos = this->ComputePosition(this->obs.actions); + +#ifdef CSV_LOGGER + torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tauEst).unsqueeze(0); + this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel); +#endif + } +} + + + +torch::Tensor RL_Real::Forward() +{ + torch::autograd::GradMode::set_enabled(false); + + torch::Tensor clamped_obs = this->ComputeObservation(); + + this->history_obs_buf.insert(clamped_obs); + this->history_obs = this->history_obs_buf.get_obs_vec({5, 4, 3, 2, 1, 0}); + + torch::Tensor actions = this->model.forward({this->history_obs}).toTensor(); + + torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper); + + return clamped_actions; +} + +void RL_Real::Plot() +{ + this->plot_t.erase(this->plot_t.begin()); + this->plot_t.push_back(this->motiontime); + plt::cla(); + plt::clf(); + for(int i = 0; i < this->params.num_of_dofs; ++i) + { + this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin()); + this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin()); + this->plot_real_joint_pos[i].push_back(this->unitree_low_state.motor_state()[i].q()); + this->plot_target_joint_pos[i].push_back(this->unitree_low_command.motor_cmd()[i].q()); + plt::subplot(4, 3, i + 1); + plt::named_plot("_real_joint_pos", this->plot_t, this->plot_real_joint_pos[i], "r"); + plt::named_plot("_target_joint_pos", this->plot_t, this->plot_target_joint_pos[i], "b"); + plt::xlim(this->plot_t.front(), this->plot_t.back()); + } + // plt::legend(); + plt::pause(0.0001); +} + +void signalHandler(int signum) +{ + exit(0); +} + +int main(int argc, char **argv) +{ + + if (argc < 2) + { + std::cout << "Usage: " << argv[0] << " networkInterface" << std::endl; + exit(-1); + } + + std::cout << "WARNING: Make sure the robot is hung up or lying on the ground." << std::endl + << "Press Enter to continue..." << std::endl; + std::cin.ignore(); + + ChannelFactory::Instance()->Init(0, argv[1]); + + RL_Real rl_sar; + rl_sar.InitRobotStateClient(); + while(rl_sar.queryServiceStatus("sport_mode")) + { + std::cout<<"Try to deactivate the service: "<<"sport_mode"<