喝醉了

This commit is contained in:
fan-ziqi 2024-03-15 17:43:55 +08:00
parent 63d28c2bff
commit 6a23e62a47
3 changed files with 155 additions and 88 deletions

View File

@ -8,14 +8,26 @@
#include <unitree_legged_msgs/MotorCmd.h> #include <unitree_legged_msgs/MotorCmd.h>
#include <unitree_legged_msgs/MotorState.h> #include <unitree_legged_msgs/MotorState.h>
#include "unitree_legged_sdk/unitree_legged_sdk.h" #include "unitree_legged_sdk/unitree_legged_sdk.h"
#include "unitree_legged_sdk/unitree_joystick.h"
#include <pthread.h> #include <pthread.h>
#include <csignal>
// #include <signal.h>
using namespace UNITREE_LEGGED_SDK; using namespace UNITREE_LEGGED_SDK;
enum InitState {
STATE_WAITING = 0,
STATE_POS_INIT,
STATE_RL_INIT,
STATE_RL_START,
STATE_POS_STOP,
};
class RL_Real : public RL class RL_Real : public RL
{ {
public: public:
RL_Real(); RL_Real();
~RL_Real();
void runModel(); void runModel();
torch::Tensor forward() override; torch::Tensor forward() override;
torch::Tensor compute_observation() override; torch::Tensor compute_observation() override;
@ -23,8 +35,6 @@ public:
ObservationBuffer history_obs_buf; ObservationBuffer history_obs_buf;
torch::Tensor history_obs; torch::Tensor history_obs;
torch::Tensor torques;
//udp //udp
void UDPSend(); void UDPSend();
void UDPRecv(); void UDPRecv();
@ -33,6 +43,7 @@ public:
UDP udp; UDP udp;
LowCmd cmd = {0}; LowCmd cmd = {0};
LowState state = {0}; LowState state = {0};
xRockerBtnDataStruct _keyData;
int motiontime = 0; int motiontime = 0;
std::shared_ptr<LoopFunc> loop_control; std::shared_ptr<LoopFunc> loop_control;
@ -40,25 +51,34 @@ public:
std::shared_ptr<LoopFunc> loop_udpRecv; std::shared_ptr<LoopFunc> loop_udpRecv;
std::shared_ptr<LoopFunc> loop_rl; std::shared_ptr<LoopFunc> loop_rl;
float _percent; float _percent;
float _targetPos[12] = {0.0, 0.8, -1.6, 0.0, 0.8, -1.6, // float _targetPos[12] = {0.0, 0.8, -1.6, 0.0, 0.8, -1.6,
0.0, 0.8, -1.6, 0.0, 0.8, -1.6}; //0.0, 0.67, -1.3 // 0.0, 0.8, -1.6, 0.0, 0.8, -1.6};
float _startPos[12]; float _startPos[12];
bool init_done = false; int init_state = STATE_WAITING;
private: private:
std::string ros_namespace;
std::vector<std::string> torque_command_topics;
std::vector<unitree_legged_msgs::MotorCmd> torque_commands;
std::vector<std::string> joint_names; std::vector<std::string> joint_names;
std::vector<double> joint_positions; std::vector<double> joint_positions;
std::vector<double> joint_velocities; std::vector<double> joint_velocities;
torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
int dof_mapping[13] = {3, 4, 5,
0, 1, 2,
9, 10, 11,
6, 7, 8};
float Kp[13] = {20, 10, 10,
20, 10, 10,
20, 10, 10,
20, 10, 10};
float Kd[13] = {1.0, 0.5, 0.5,
1.0, 0.5, 0.5,
1.0, 0.5, 0.5,
1.0, 0.5, 0.5};
torch::Tensor target_dof_pos;
torch::Tensor compute_pos(torch::Tensor actions);
std::chrono::high_resolution_clock::time_point start_time; std::chrono::high_resolution_clock::time_point start_time;
// other rl module // other rl module

View File

@ -5,40 +5,40 @@ Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved.
#define UNITREE_JOYSTICK_H #define UNITREE_JOYSTICK_H
#include <stdint.h> #include <stdint.h>
// // 16b // 16b
// typedef union { typedef union {
// struct { struct {
// uint8_t R1 :1; uint8_t R1 :1;
// uint8_t L1 :1; uint8_t L1 :1;
// uint8_t start :1; uint8_t start :1;
// uint8_t select :1; uint8_t select :1;
// uint8_t R2 :1; uint8_t R2 :1;
// uint8_t L2 :1; uint8_t L2 :1;
// uint8_t F1 :1; uint8_t F1 :1;
// uint8_t F2 :1; uint8_t F2 :1;
// uint8_t A :1; uint8_t A :1;
// uint8_t B :1; uint8_t B :1;
// uint8_t X :1; uint8_t X :1;
// uint8_t Y :1; uint8_t Y :1;
// uint8_t up :1; uint8_t up :1;
// uint8_t right :1; uint8_t right :1;
// uint8_t down :1; uint8_t down :1;
// uint8_t left :1; uint8_t left :1;
// } components; } components;
// uint16_t value; uint16_t value;
// } xKeySwitchUnion; } xKeySwitchUnion;
// // 40 Byte (now used 24B) // 40 Byte (now used 24B)
// typedef struct { typedef struct {
// uint8_t head[2]; uint8_t head[2];
// xKeySwitchUnion btn; xKeySwitchUnion btn;
// float lx; float lx;
// float rx; float rx;
// float ry; float ry;
// float L2; float L2;
// float ly; float ly;
// uint8_t idle[16]; uint8_t idle[16];
// } xRockerBtnDataStruct; } xRockerBtnDataStruct;
#endif // UNITREE_JOYSTICK_H #endif // UNITREE_JOYSTICK_H

View File

@ -1,5 +1,9 @@
#include "../include/rl_real.hpp" #include "../include/rl_real.hpp"
// #define CONTROL_BY_TORQUE
RL_Real rl_sar;
void RL_Real::UDPRecv() void RL_Real::UDPRecv()
{ {
udp.Recv(); udp.Recv();
@ -15,7 +19,20 @@ void RL_Real::RobotControl()
motiontime++; motiontime++;
udp.GetRecv(state); udp.GetRecv(state);
if( motiontime < 500) memcpy(&_keyData, state.wirelessRemote, 40);
// get joy button
if(init_state < STATE_POS_INIT && (int)_keyData.btn.components.R2 == 1)
{
init_state = STATE_POS_INIT;
}
else if(init_state < STATE_RL_INIT && (int)_keyData.btn.components.R1 == 1)
{
init_state = STATE_RL_INIT;
}
// wait for standup
if(init_state == STATE_WAITING)
{ {
for(int i = 0; i < 12; ++i) for(int i = 0; i < 12; ++i)
{ {
@ -23,36 +40,39 @@ void RL_Real::RobotControl()
_startPos[i] = state.motorState[i].q; _startPos[i] = state.motorState[i].q;
} }
} }
// standup (position control)
if( motiontime >= 500 && _percent != 1) else if(init_state == STATE_POS_INIT && _percent != 1)
{ {
printf("initing %d%%\r", (int)(_percent*100));
_percent += (float) 1 / 1000; _percent += (float) 1 / 1000;
_percent = _percent > 1 ? 1 : _percent; _percent = _percent > 1 ? 1 : _percent;
for(int i = 0; i < 12; ++i) for(int i = 0; i < 12; ++i)
{ {
cmd.motorCmd[i].q = (1 - _percent) * _startPos[i] + _percent * _targetPos[i]; cmd.motorCmd[i].q = (1 - _percent) * _startPos[i] + _percent * params.default_dof_pos[0][dof_mapping[i]].item<double>();
cmd.motorCmd[i].dq = 0; cmd.motorCmd[i].dq = 0;
cmd.motorCmd[i].Kp = 50; cmd.motorCmd[i].Kp = 50;
cmd.motorCmd[i].Kd = 3; cmd.motorCmd[i].Kd = 3;
cmd.motorCmd[i].tau = 0; cmd.motorCmd[i].tau = 0;
} }
} }
if(_percent == 1 && !init_done) // init obs and start rl loop
else if(init_state == STATE_RL_INIT && _percent == 1)
{ {
init_done = true; init_state = STATE_RL_START;
loop_rl->start();
this->init_observations();
std::cout << "init done" << std::endl;
motiontime = 0; motiontime = 0;
this->init_observations();
printf("\nstart rl loop\n");
loop_rl->start();
} }
// rl loop
if(init_done) else if(init_state == STATE_RL_START)
{ {
// wait for 500 times
if( motiontime < 500) if( motiontime < 500)
{ {
for(int i = 0; i < 12; ++i) for(int i = 0; i < 12; ++i)
{ {
cmd.motorCmd[i].q = _targetPos[i]; cmd.motorCmd[i].q = params.default_dof_pos[0][dof_mapping[i]].item<double>();
cmd.motorCmd[i].dq = 0; cmd.motorCmd[i].dq = 0;
cmd.motorCmd[i].Kp = 50; cmd.motorCmd[i].Kp = 50;
cmd.motorCmd[i].Kd = 3; cmd.motorCmd[i].Kd = 3;
@ -62,24 +82,33 @@ void RL_Real::RobotControl()
} }
if( motiontime >= 500) if( motiontime >= 500)
{ {
int mapping[13] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8}; #ifdef CONTROL_BY_TORQUE
for (int i = 0; i < 12; ++i) for (int i = 0; i < 12; ++i)
{ {
float torque = torques[0][mapping[i]].item<double>(); float torque = torques[0][dof_mapping[i]].item<double>();
// if(torque > 5.0f) torque = 5.0f; // if(torque > 5.0f) torque = 5.0f;
// if(torque < -5.0f) torque = -5.0f; // if(torque < -5.0f) torque = -5.0f;
cmd.motorCmd[i].q = PosStopF; cmd.motorCmd[i].q = 0;
cmd.motorCmd[i].dq = VelStopF; cmd.motorCmd[i].dq = 0;
cmd.motorCmd[i].Kp = 0; cmd.motorCmd[i].Kp = 0;
cmd.motorCmd[i].Kd = 0; cmd.motorCmd[i].Kd = 0;
cmd.motorCmd[i].tau = torque; cmd.motorCmd[i].tau = torque;
} }
#else
for (int i = 0; i < 12; ++i)
{
cmd.motorCmd[i].q = target_dof_pos[0][dof_mapping[i]].item<double>();
cmd.motorCmd[i].dq = 0;
cmd.motorCmd[i].Kp = 15;
cmd.motorCmd[i].Kd = 1.5;
cmd.motorCmd[i].tau = 0;
}
#endif
} }
} }
safe.PowerProtect(cmd, state, 1); safe.PowerProtect(cmd, state, 7);
udp.SetSend(cmd); udp.SetSend(cmd);
} }
@ -89,8 +118,6 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
start_time = std::chrono::high_resolution_clock::now(); start_time = std::chrono::high_resolution_clock::now();
torque_commands.resize(12);
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt"; std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt";
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt"; std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt";
std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/vq_layer.pt"; std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/vq_layer.pt";
@ -103,8 +130,8 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
this->params.num_observations = 45; this->params.num_observations = 45;
this->params.clip_obs = 100.0; this->params.clip_obs = 100.0;
this->params.clip_actions = 100.0; this->params.clip_actions = 100.0;
this->params.damping = 0.5; this->params.damping = 1.0; // TODO
this->params.stiffness = 20; this->params.stiffness = 5; // TODO
this->params.d_gains = torch::ones(12) * this->params.damping; this->params.d_gains = torch::ones(12) * this->params.damping;
this->params.p_gains = torch::ones(12) * this->params.stiffness; this->params.p_gains = torch::ones(12) * this->params.stiffness;
this->params.action_scale = 0.25; this->params.action_scale = 0.25;
@ -116,7 +143,6 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
this->params.dof_vel_scale = 0.05; this->params.dof_vel_scale = 0.05;
this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale}); this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale});
this->params.torque_limits = torch::tensor({{20.0, 55.0, 55.0, this->params.torque_limits = torch::tensor({{20.0, 55.0, 55.0,
20.0, 55.0, 55.0, 20.0, 55.0, 55.0,
20.0, 55.0, 55.0, 20.0, 55.0, 55.0,
@ -130,6 +156,8 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
target_dof_pos = params.default_dof_pos;
// InitEnvironment(); // InitEnvironment();
loop_control = std::make_shared<LoopFunc>("control_loop", 0.002, boost::bind(&RL_Real::RobotControl, this)); loop_control = std::make_shared<LoopFunc>("control_loop", 0.002, boost::bind(&RL_Real::RobotControl, this));
loop_udpSend = std::make_shared<LoopFunc>("udp_send" , 0.002, 3, boost::bind(&RL_Real::UDPSend, this)); loop_udpSend = std::make_shared<LoopFunc>("udp_send" , 0.002, 3, boost::bind(&RL_Real::UDPSend, this));
@ -141,9 +169,30 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
loop_control->start(); loop_control->start();
} }
RL_Real::~RL_Real()
{
loop_udpSend->shutdown();
loop_udpRecv->shutdown();
loop_control->shutdown();
loop_rl->shutdown();
printf("shutdown\n");
}
torch::Tensor RL_Real::compute_pos(torch::Tensor actions)
{
torch::Tensor actions_scaled = actions * this->params.action_scale;
int indices[] = {0, 3, 6, 9};
for (int i : indices)
{
actions_scaled[0][i] *= this->params.hip_scale_reduction;
}
return actions_scaled + this->params.default_dof_pos;
}
void RL_Real::runModel() void RL_Real::runModel()
{ {
if(init_done) if(init_state == STATE_RL_START)
{ {
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count(); auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
// std::cout << "Execution time: " << duration << " microseconds" << std::endl; // std::cout << "Execution time: " << duration << " microseconds" << std::endl;
@ -155,7 +204,7 @@ void RL_Real::runModel()
// printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", state.motorState[FL_0].dq, state.motorState[FL_1].dq, state.motorState[FL_2].dq, state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq, state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq, state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq); // printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", state.motorState[FL_0].dq, state.motorState[FL_1].dq, state.motorState[FL_2].dq, state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq, state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq, state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq);
this->obs.ang_vel = torch::tensor({{state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]}}); this->obs.ang_vel = torch::tensor({{state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]}});
// this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}}); this->obs.commands = torch::tensor({{_keyData.ly, -_keyData.rx, -_keyData.lx}});
this->obs.base_quat = torch::tensor({{state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]}}); this->obs.base_quat = torch::tensor({{state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]}});
this->obs.dof_pos = torch::tensor({{state.motorState[FL_0].q, state.motorState[FL_1].q, state.motorState[FL_2].q, this->obs.dof_pos = torch::tensor({{state.motorState[FL_0].q, state.motorState[FL_1].q, state.motorState[FL_2].q,
state.motorState[FR_0].q, state.motorState[FR_1].q, state.motorState[FR_2].q, state.motorState[FR_0].q, state.motorState[FR_1].q, state.motorState[FR_2].q,
@ -167,31 +216,25 @@ void RL_Real::runModel()
state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq}}); state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq}});
torch::Tensor actions = this->forward(); torch::Tensor actions = this->forward();
#ifdef CONTROL_BY_TORQUE
torques = this->compute_torques(actions); torques = this->compute_torques(actions);
#else
target_dof_pos = this->compute_pos(actions);
#endif
} }
} }
torch::Tensor RL_Real::compute_observation() torch::Tensor RL_Real::compute_observation()
{ {
torch::Tensor ang_vel = this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel);
// float ang_vel_temp = ang_vel[0][0].item<double>();
// ang_vel[0][0] = ang_vel[0][1];
// ang_vel[0][1] = ang_vel_temp;
torch::Tensor grav = this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec);
// float grav_temp = grav[0][0].item<double>();
// grav[0][0] = grav[0][1];
// grav[0][1] = grav_temp;
torch::Tensor obs = torch::cat({// (this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale, torch::Tensor obs = torch::cat({// (this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
ang_vel * this->params.ang_vel_scale, this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
grav, this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec),
this->obs.commands * this->params.commands_scale, this->obs.commands * this->params.commands_scale,
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale, (this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
this->obs.dof_vel * this->params.dof_vel_scale, this->obs.dof_vel * this->params.dof_vel_scale,
this->obs.actions}, this->obs.actions
1); },1);
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs); obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
return obs; return obs;
} }
@ -217,13 +260,17 @@ torch::Tensor RL_Real::forward()
return clamped; return clamped;
} }
void signalHandler(int signum)
{
exit(0);
}
int main(int argc, char **argv) int main(int argc, char **argv)
{ {
RL_Real rl_sar; signal(SIGINT, signalHandler);
while(1){ while(1)
{
sleep(10); sleep(10);
}; };