From 6a23e62a470256f6ddf38727a84f013159fbbd40 Mon Sep 17 00:00:00 2001 From: fan-ziqi Date: Fri, 15 Mar 2024 17:43:55 +0800 Subject: [PATCH] =?UTF-8?q?=E5=96=9D=E9=86=89=E4=BA=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rl_sar/include/rl_real.hpp | 44 ++++-- .../unitree_legged_sdk/unitree_joystick.h | 66 ++++----- src/rl_sar/src/rl_real.cpp | 133 ++++++++++++------ 3 files changed, 155 insertions(+), 88 deletions(-) diff --git a/src/rl_sar/include/rl_real.hpp b/src/rl_sar/include/rl_real.hpp index 69a4c03..06469f6 100644 --- a/src/rl_sar/include/rl_real.hpp +++ b/src/rl_sar/include/rl_real.hpp @@ -8,14 +8,26 @@ #include #include #include "unitree_legged_sdk/unitree_legged_sdk.h" +#include "unitree_legged_sdk/unitree_joystick.h" #include +#include +// #include using namespace UNITREE_LEGGED_SDK; +enum InitState { + STATE_WAITING = 0, + STATE_POS_INIT, + STATE_RL_INIT, + STATE_RL_START, + STATE_POS_STOP, +}; + class RL_Real : public RL { public: RL_Real(); + ~RL_Real(); void runModel(); torch::Tensor forward() override; torch::Tensor compute_observation() override; @@ -23,8 +35,6 @@ public: ObservationBuffer history_obs_buf; torch::Tensor history_obs; - torch::Tensor torques; - //udp void UDPSend(); void UDPRecv(); @@ -33,6 +43,7 @@ public: UDP udp; LowCmd cmd = {0}; LowState state = {0}; + xRockerBtnDataStruct _keyData; int motiontime = 0; std::shared_ptr loop_control; @@ -40,25 +51,34 @@ public: std::shared_ptr loop_udpRecv; std::shared_ptr loop_rl; - float _percent; - float _targetPos[12] = {0.0, 0.8, -1.6, 0.0, 0.8, -1.6, - 0.0, 0.8, -1.6, 0.0, 0.8, -1.6}; //0.0, 0.67, -1.3 + // float _targetPos[12] = {0.0, 0.8, -1.6, 0.0, 0.8, -1.6, + // 0.0, 0.8, -1.6, 0.0, 0.8, -1.6}; float _startPos[12]; - bool init_done = false; + int init_state = STATE_WAITING; private: - std::string ros_namespace; - - std::vector torque_command_topics; - - std::vector torque_commands; - std::vector joint_names; std::vector joint_positions; std::vector joint_velocities; + torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); + int dof_mapping[13] = {3, 4, 5, + 0, 1, 2, + 9, 10, 11, + 6, 7, 8}; + float Kp[13] = {20, 10, 10, + 20, 10, 10, + 20, 10, 10, + 20, 10, 10}; + float Kd[13] = {1.0, 0.5, 0.5, + 1.0, 0.5, 0.5, + 1.0, 0.5, 0.5, + 1.0, 0.5, 0.5}; + torch::Tensor target_dof_pos; + torch::Tensor compute_pos(torch::Tensor actions); + std::chrono::high_resolution_clock::time_point start_time; // other rl module diff --git a/src/rl_sar/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/unitree_joystick.h b/src/rl_sar/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/unitree_joystick.h index b9e0e14..8ea87c0 100644 --- a/src/rl_sar/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/unitree_joystick.h +++ b/src/rl_sar/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/unitree_joystick.h @@ -5,40 +5,40 @@ Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved. #define UNITREE_JOYSTICK_H #include -// // 16b -// typedef union { -// struct { -// uint8_t R1 :1; -// uint8_t L1 :1; -// uint8_t start :1; -// uint8_t select :1; -// uint8_t R2 :1; -// uint8_t L2 :1; -// uint8_t F1 :1; -// uint8_t F2 :1; -// uint8_t A :1; -// uint8_t B :1; -// uint8_t X :1; -// uint8_t Y :1; -// uint8_t up :1; -// uint8_t right :1; -// uint8_t down :1; -// uint8_t left :1; -// } components; -// uint16_t value; -// } xKeySwitchUnion; +// 16b +typedef union { + struct { + uint8_t R1 :1; + uint8_t L1 :1; + uint8_t start :1; + uint8_t select :1; + uint8_t R2 :1; + uint8_t L2 :1; + uint8_t F1 :1; + uint8_t F2 :1; + uint8_t A :1; + uint8_t B :1; + uint8_t X :1; + uint8_t Y :1; + uint8_t up :1; + uint8_t right :1; + uint8_t down :1; + uint8_t left :1; + } components; + uint16_t value; +} xKeySwitchUnion; -// // 40 Byte (now used 24B) -// typedef struct { -// uint8_t head[2]; -// xKeySwitchUnion btn; -// float lx; -// float rx; -// float ry; -// float L2; -// float ly; +// 40 Byte (now used 24B) +typedef struct { + uint8_t head[2]; + xKeySwitchUnion btn; + float lx; + float rx; + float ry; + float L2; + float ly; -// uint8_t idle[16]; -// } xRockerBtnDataStruct; + uint8_t idle[16]; +} xRockerBtnDataStruct; #endif // UNITREE_JOYSTICK_H \ No newline at end of file diff --git a/src/rl_sar/src/rl_real.cpp b/src/rl_sar/src/rl_real.cpp index a81910b..0d3146c 100644 --- a/src/rl_sar/src/rl_real.cpp +++ b/src/rl_sar/src/rl_real.cpp @@ -1,5 +1,9 @@ #include "../include/rl_real.hpp" +// #define CONTROL_BY_TORQUE + +RL_Real rl_sar; + void RL_Real::UDPRecv() { udp.Recv(); @@ -15,7 +19,20 @@ void RL_Real::RobotControl() motiontime++; udp.GetRecv(state); - if( motiontime < 500) + memcpy(&_keyData, state.wirelessRemote, 40); + + // get joy button + if(init_state < STATE_POS_INIT && (int)_keyData.btn.components.R2 == 1) + { + init_state = STATE_POS_INIT; + } + else if(init_state < STATE_RL_INIT && (int)_keyData.btn.components.R1 == 1) + { + init_state = STATE_RL_INIT; + } + + // wait for standup + if(init_state == STATE_WAITING) { for(int i = 0; i < 12; ++i) { @@ -23,36 +40,39 @@ void RL_Real::RobotControl() _startPos[i] = state.motorState[i].q; } } - - if( motiontime >= 500 && _percent != 1) + // standup (position control) + else if(init_state == STATE_POS_INIT && _percent != 1) { + printf("initing %d%%\r", (int)(_percent*100)); _percent += (float) 1 / 1000; _percent = _percent > 1 ? 1 : _percent; for(int i = 0; i < 12; ++i) { - cmd.motorCmd[i].q = (1 - _percent) * _startPos[i] + _percent * _targetPos[i]; + cmd.motorCmd[i].q = (1 - _percent) * _startPos[i] + _percent * params.default_dof_pos[0][dof_mapping[i]].item(); cmd.motorCmd[i].dq = 0; cmd.motorCmd[i].Kp = 50; cmd.motorCmd[i].Kd = 3; cmd.motorCmd[i].tau = 0; } } - if(_percent == 1 && !init_done) + // init obs and start rl loop + else if(init_state == STATE_RL_INIT && _percent == 1) { - init_done = true; - loop_rl->start(); - this->init_observations(); - std::cout << "init done" << std::endl; + init_state = STATE_RL_START; motiontime = 0; + this->init_observations(); + printf("\nstart rl loop\n"); + loop_rl->start(); } - - if(init_done) + // rl loop + else if(init_state == STATE_RL_START) { + // wait for 500 times if( motiontime < 500) { for(int i = 0; i < 12; ++i) { - cmd.motorCmd[i].q = _targetPos[i]; + cmd.motorCmd[i].q = params.default_dof_pos[0][dof_mapping[i]].item(); cmd.motorCmd[i].dq = 0; cmd.motorCmd[i].Kp = 50; cmd.motorCmd[i].Kd = 3; @@ -62,24 +82,33 @@ void RL_Real::RobotControl() } if( motiontime >= 500) { - int mapping[13] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8}; - +#ifdef CONTROL_BY_TORQUE for (int i = 0; i < 12; ++i) { - float torque = torques[0][mapping[i]].item(); + float torque = torques[0][dof_mapping[i]].item(); // if(torque > 5.0f) torque = 5.0f; // if(torque < -5.0f) torque = -5.0f; - cmd.motorCmd[i].q = PosStopF; - cmd.motorCmd[i].dq = VelStopF; + cmd.motorCmd[i].q = 0; + cmd.motorCmd[i].dq = 0; cmd.motorCmd[i].Kp = 0; cmd.motorCmd[i].Kd = 0; cmd.motorCmd[i].tau = torque; } +#else + for (int i = 0; i < 12; ++i) + { + cmd.motorCmd[i].q = target_dof_pos[0][dof_mapping[i]].item(); + cmd.motorCmd[i].dq = 0; + cmd.motorCmd[i].Kp = 15; + cmd.motorCmd[i].Kd = 1.5; + cmd.motorCmd[i].tau = 0; + } +#endif } } - safe.PowerProtect(cmd, state, 1); + safe.PowerProtect(cmd, state, 7); udp.SetSend(cmd); } @@ -89,8 +118,6 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL) start_time = std::chrono::high_resolution_clock::now(); - torque_commands.resize(12); - std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt"; std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt"; std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/vq_layer.pt"; @@ -103,8 +130,8 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL) this->params.num_observations = 45; this->params.clip_obs = 100.0; this->params.clip_actions = 100.0; - this->params.damping = 0.5; - this->params.stiffness = 20; + this->params.damping = 1.0; // TODO + this->params.stiffness = 5; // TODO this->params.d_gains = torch::ones(12) * this->params.damping; this->params.p_gains = torch::ones(12) * this->params.stiffness; this->params.action_scale = 0.25; @@ -115,14 +142,13 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL) this->params.dof_pos_scale = 1.0; this->params.dof_vel_scale = 0.05; this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale}); - this->params.torque_limits = torch::tensor({{20.0, 55.0, 55.0, 20.0, 55.0, 55.0, 20.0, 55.0, 55.0, 20.0, 55.0, 55.0}}); - // hip, thigh, calf + // hip, thigh, calf this->params.default_dof_pos = torch::tensor({{-0.1000, 0.8000, -1.5000, // front right 0.1000, 0.8000, -1.5000, // front left -0.1000, 1.0000, -1.5000, // rear right @@ -130,6 +156,8 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL) this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); + target_dof_pos = params.default_dof_pos; + // InitEnvironment(); loop_control = std::make_shared("control_loop", 0.002, boost::bind(&RL_Real::RobotControl, this)); loop_udpSend = std::make_shared("udp_send" , 0.002, 3, boost::bind(&RL_Real::UDPSend, this)); @@ -141,9 +169,30 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL) loop_control->start(); } +RL_Real::~RL_Real() +{ + loop_udpSend->shutdown(); + loop_udpRecv->shutdown(); + loop_control->shutdown(); + loop_rl->shutdown(); + printf("shutdown\n"); +} + +torch::Tensor RL_Real::compute_pos(torch::Tensor actions) +{ + torch::Tensor actions_scaled = actions * this->params.action_scale; + int indices[] = {0, 3, 6, 9}; + for (int i : indices) + { + actions_scaled[0][i] *= this->params.hip_scale_reduction; + } + + return actions_scaled + this->params.default_dof_pos; +} + void RL_Real::runModel() { - if(init_done) + if(init_state == STATE_RL_START) { auto duration = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count(); // std::cout << "Execution time: " << duration << " microseconds" << std::endl; @@ -155,7 +204,7 @@ void RL_Real::runModel() // printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", state.motorState[FL_0].dq, state.motorState[FL_1].dq, state.motorState[FL_2].dq, state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq, state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq, state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq); this->obs.ang_vel = torch::tensor({{state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]}}); - // this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}}); + this->obs.commands = torch::tensor({{_keyData.ly, -_keyData.rx, -_keyData.lx}}); this->obs.base_quat = torch::tensor({{state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]}}); this->obs.dof_pos = torch::tensor({{state.motorState[FL_0].q, state.motorState[FL_1].q, state.motorState[FL_2].q, state.motorState[FR_0].q, state.motorState[FR_1].q, state.motorState[FR_2].q, @@ -167,31 +216,25 @@ void RL_Real::runModel() state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq}}); torch::Tensor actions = this->forward(); +#ifdef CONTROL_BY_TORQUE torques = this->compute_torques(actions); +#else + target_dof_pos = this->compute_pos(actions); +#endif } } torch::Tensor RL_Real::compute_observation() { - torch::Tensor ang_vel = this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel); - // float ang_vel_temp = ang_vel[0][0].item(); - // ang_vel[0][0] = ang_vel[0][1]; - // ang_vel[0][1] = ang_vel_temp; - - torch::Tensor grav = this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec); - // float grav_temp = grav[0][0].item(); - // grav[0][0] = grav[0][1]; - // grav[0][1] = grav_temp; - torch::Tensor obs = torch::cat({// (this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale, - ang_vel * this->params.ang_vel_scale, - grav, + this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale, + this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec), this->obs.commands * this->params.commands_scale, (this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale, this->obs.dof_vel * this->params.dof_vel_scale, - this->obs.actions}, - 1); + this->obs.actions + },1); obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs); return obs; } @@ -217,13 +260,17 @@ torch::Tensor RL_Real::forward() return clamped; } - +void signalHandler(int signum) +{ + exit(0); +} int main(int argc, char **argv) { - RL_Real rl_sar; + signal(SIGINT, signalHandler); - while(1){ + while(1) + { sleep(10); };