喝醉了

This commit is contained in:
fan-ziqi 2024-03-15 17:43:55 +08:00
parent 63d28c2bff
commit 6a23e62a47
3 changed files with 155 additions and 88 deletions

View File

@ -8,14 +8,26 @@
#include <unitree_legged_msgs/MotorCmd.h>
#include <unitree_legged_msgs/MotorState.h>
#include "unitree_legged_sdk/unitree_legged_sdk.h"
#include "unitree_legged_sdk/unitree_joystick.h"
#include <pthread.h>
#include <csignal>
// #include <signal.h>
using namespace UNITREE_LEGGED_SDK;
enum InitState {
STATE_WAITING = 0,
STATE_POS_INIT,
STATE_RL_INIT,
STATE_RL_START,
STATE_POS_STOP,
};
class RL_Real : public RL
{
public:
RL_Real();
~RL_Real();
void runModel();
torch::Tensor forward() override;
torch::Tensor compute_observation() override;
@ -23,8 +35,6 @@ public:
ObservationBuffer history_obs_buf;
torch::Tensor history_obs;
torch::Tensor torques;
//udp
void UDPSend();
void UDPRecv();
@ -33,6 +43,7 @@ public:
UDP udp;
LowCmd cmd = {0};
LowState state = {0};
xRockerBtnDataStruct _keyData;
int motiontime = 0;
std::shared_ptr<LoopFunc> loop_control;
@ -40,25 +51,34 @@ public:
std::shared_ptr<LoopFunc> loop_udpRecv;
std::shared_ptr<LoopFunc> loop_rl;
float _percent;
float _targetPos[12] = {0.0, 0.8, -1.6, 0.0, 0.8, -1.6,
0.0, 0.8, -1.6, 0.0, 0.8, -1.6}; //0.0, 0.67, -1.3
// float _targetPos[12] = {0.0, 0.8, -1.6, 0.0, 0.8, -1.6,
// 0.0, 0.8, -1.6, 0.0, 0.8, -1.6};
float _startPos[12];
bool init_done = false;
int init_state = STATE_WAITING;
private:
std::string ros_namespace;
std::vector<std::string> torque_command_topics;
std::vector<unitree_legged_msgs::MotorCmd> torque_commands;
std::vector<std::string> joint_names;
std::vector<double> joint_positions;
std::vector<double> joint_velocities;
torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
int dof_mapping[13] = {3, 4, 5,
0, 1, 2,
9, 10, 11,
6, 7, 8};
float Kp[13] = {20, 10, 10,
20, 10, 10,
20, 10, 10,
20, 10, 10};
float Kd[13] = {1.0, 0.5, 0.5,
1.0, 0.5, 0.5,
1.0, 0.5, 0.5,
1.0, 0.5, 0.5};
torch::Tensor target_dof_pos;
torch::Tensor compute_pos(torch::Tensor actions);
std::chrono::high_resolution_clock::time_point start_time;
// other rl module

View File

@ -5,40 +5,40 @@ Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved.
#define UNITREE_JOYSTICK_H
#include <stdint.h>
// // 16b
// typedef union {
// struct {
// uint8_t R1 :1;
// uint8_t L1 :1;
// uint8_t start :1;
// uint8_t select :1;
// uint8_t R2 :1;
// uint8_t L2 :1;
// uint8_t F1 :1;
// uint8_t F2 :1;
// uint8_t A :1;
// uint8_t B :1;
// uint8_t X :1;
// uint8_t Y :1;
// uint8_t up :1;
// uint8_t right :1;
// uint8_t down :1;
// uint8_t left :1;
// } components;
// uint16_t value;
// } xKeySwitchUnion;
// 16b
typedef union {
struct {
uint8_t R1 :1;
uint8_t L1 :1;
uint8_t start :1;
uint8_t select :1;
uint8_t R2 :1;
uint8_t L2 :1;
uint8_t F1 :1;
uint8_t F2 :1;
uint8_t A :1;
uint8_t B :1;
uint8_t X :1;
uint8_t Y :1;
uint8_t up :1;
uint8_t right :1;
uint8_t down :1;
uint8_t left :1;
} components;
uint16_t value;
} xKeySwitchUnion;
// // 40 Byte (now used 24B)
// typedef struct {
// uint8_t head[2];
// xKeySwitchUnion btn;
// float lx;
// float rx;
// float ry;
// float L2;
// float ly;
// 40 Byte (now used 24B)
typedef struct {
uint8_t head[2];
xKeySwitchUnion btn;
float lx;
float rx;
float ry;
float L2;
float ly;
// uint8_t idle[16];
// } xRockerBtnDataStruct;
uint8_t idle[16];
} xRockerBtnDataStruct;
#endif // UNITREE_JOYSTICK_H

View File

@ -1,5 +1,9 @@
#include "../include/rl_real.hpp"
// #define CONTROL_BY_TORQUE
RL_Real rl_sar;
void RL_Real::UDPRecv()
{
udp.Recv();
@ -15,7 +19,20 @@ void RL_Real::RobotControl()
motiontime++;
udp.GetRecv(state);
if( motiontime < 500)
memcpy(&_keyData, state.wirelessRemote, 40);
// get joy button
if(init_state < STATE_POS_INIT && (int)_keyData.btn.components.R2 == 1)
{
init_state = STATE_POS_INIT;
}
else if(init_state < STATE_RL_INIT && (int)_keyData.btn.components.R1 == 1)
{
init_state = STATE_RL_INIT;
}
// wait for standup
if(init_state == STATE_WAITING)
{
for(int i = 0; i < 12; ++i)
{
@ -23,36 +40,39 @@ void RL_Real::RobotControl()
_startPos[i] = state.motorState[i].q;
}
}
if( motiontime >= 500 && _percent != 1)
// standup (position control)
else if(init_state == STATE_POS_INIT && _percent != 1)
{
printf("initing %d%%\r", (int)(_percent*100));
_percent += (float) 1 / 1000;
_percent = _percent > 1 ? 1 : _percent;
for(int i = 0; i < 12; ++i)
{
cmd.motorCmd[i].q = (1 - _percent) * _startPos[i] + _percent * _targetPos[i];
cmd.motorCmd[i].q = (1 - _percent) * _startPos[i] + _percent * params.default_dof_pos[0][dof_mapping[i]].item<double>();
cmd.motorCmd[i].dq = 0;
cmd.motorCmd[i].Kp = 50;
cmd.motorCmd[i].Kd = 3;
cmd.motorCmd[i].tau = 0;
}
}
if(_percent == 1 && !init_done)
// init obs and start rl loop
else if(init_state == STATE_RL_INIT && _percent == 1)
{
init_done = true;
loop_rl->start();
this->init_observations();
std::cout << "init done" << std::endl;
init_state = STATE_RL_START;
motiontime = 0;
this->init_observations();
printf("\nstart rl loop\n");
loop_rl->start();
}
if(init_done)
// rl loop
else if(init_state == STATE_RL_START)
{
// wait for 500 times
if( motiontime < 500)
{
for(int i = 0; i < 12; ++i)
{
cmd.motorCmd[i].q = _targetPos[i];
cmd.motorCmd[i].q = params.default_dof_pos[0][dof_mapping[i]].item<double>();
cmd.motorCmd[i].dq = 0;
cmd.motorCmd[i].Kp = 50;
cmd.motorCmd[i].Kd = 3;
@ -62,24 +82,33 @@ void RL_Real::RobotControl()
}
if( motiontime >= 500)
{
int mapping[13] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8};
#ifdef CONTROL_BY_TORQUE
for (int i = 0; i < 12; ++i)
{
float torque = torques[0][mapping[i]].item<double>();
float torque = torques[0][dof_mapping[i]].item<double>();
// if(torque > 5.0f) torque = 5.0f;
// if(torque < -5.0f) torque = -5.0f;
cmd.motorCmd[i].q = PosStopF;
cmd.motorCmd[i].dq = VelStopF;
cmd.motorCmd[i].q = 0;
cmd.motorCmd[i].dq = 0;
cmd.motorCmd[i].Kp = 0;
cmd.motorCmd[i].Kd = 0;
cmd.motorCmd[i].tau = torque;
}
#else
for (int i = 0; i < 12; ++i)
{
cmd.motorCmd[i].q = target_dof_pos[0][dof_mapping[i]].item<double>();
cmd.motorCmd[i].dq = 0;
cmd.motorCmd[i].Kp = 15;
cmd.motorCmd[i].Kd = 1.5;
cmd.motorCmd[i].tau = 0;
}
#endif
}
}
safe.PowerProtect(cmd, state, 1);
safe.PowerProtect(cmd, state, 7);
udp.SetSend(cmd);
}
@ -89,8 +118,6 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
start_time = std::chrono::high_resolution_clock::now();
torque_commands.resize(12);
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt";
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt";
std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/vq_layer.pt";
@ -103,8 +130,8 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
this->params.num_observations = 45;
this->params.clip_obs = 100.0;
this->params.clip_actions = 100.0;
this->params.damping = 0.5;
this->params.stiffness = 20;
this->params.damping = 1.0; // TODO
this->params.stiffness = 5; // TODO
this->params.d_gains = torch::ones(12) * this->params.damping;
this->params.p_gains = torch::ones(12) * this->params.stiffness;
this->params.action_scale = 0.25;
@ -115,14 +142,13 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
this->params.dof_pos_scale = 1.0;
this->params.dof_vel_scale = 0.05;
this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale});
this->params.torque_limits = torch::tensor({{20.0, 55.0, 55.0,
20.0, 55.0, 55.0,
20.0, 55.0, 55.0,
20.0, 55.0, 55.0}});
// hip, thigh, calf
// hip, thigh, calf
this->params.default_dof_pos = torch::tensor({{-0.1000, 0.8000, -1.5000, // front right
0.1000, 0.8000, -1.5000, // front left
-0.1000, 1.0000, -1.5000, // rear right
@ -130,6 +156,8 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
target_dof_pos = params.default_dof_pos;
// InitEnvironment();
loop_control = std::make_shared<LoopFunc>("control_loop", 0.002, boost::bind(&RL_Real::RobotControl, this));
loop_udpSend = std::make_shared<LoopFunc>("udp_send" , 0.002, 3, boost::bind(&RL_Real::UDPSend, this));
@ -141,9 +169,30 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
loop_control->start();
}
RL_Real::~RL_Real()
{
loop_udpSend->shutdown();
loop_udpRecv->shutdown();
loop_control->shutdown();
loop_rl->shutdown();
printf("shutdown\n");
}
torch::Tensor RL_Real::compute_pos(torch::Tensor actions)
{
torch::Tensor actions_scaled = actions * this->params.action_scale;
int indices[] = {0, 3, 6, 9};
for (int i : indices)
{
actions_scaled[0][i] *= this->params.hip_scale_reduction;
}
return actions_scaled + this->params.default_dof_pos;
}
void RL_Real::runModel()
{
if(init_done)
if(init_state == STATE_RL_START)
{
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
@ -155,7 +204,7 @@ void RL_Real::runModel()
// printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", state.motorState[FL_0].dq, state.motorState[FL_1].dq, state.motorState[FL_2].dq, state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq, state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq, state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq);
this->obs.ang_vel = torch::tensor({{state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]}});
// this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}});
this->obs.commands = torch::tensor({{_keyData.ly, -_keyData.rx, -_keyData.lx}});
this->obs.base_quat = torch::tensor({{state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]}});
this->obs.dof_pos = torch::tensor({{state.motorState[FL_0].q, state.motorState[FL_1].q, state.motorState[FL_2].q,
state.motorState[FR_0].q, state.motorState[FR_1].q, state.motorState[FR_2].q,
@ -167,31 +216,25 @@ void RL_Real::runModel()
state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq}});
torch::Tensor actions = this->forward();
#ifdef CONTROL_BY_TORQUE
torques = this->compute_torques(actions);
#else
target_dof_pos = this->compute_pos(actions);
#endif
}
}
torch::Tensor RL_Real::compute_observation()
{
torch::Tensor ang_vel = this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel);
// float ang_vel_temp = ang_vel[0][0].item<double>();
// ang_vel[0][0] = ang_vel[0][1];
// ang_vel[0][1] = ang_vel_temp;
torch::Tensor grav = this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec);
// float grav_temp = grav[0][0].item<double>();
// grav[0][0] = grav[0][1];
// grav[0][1] = grav_temp;
torch::Tensor obs = torch::cat({// (this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
ang_vel * this->params.ang_vel_scale,
grav,
this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec),
this->obs.commands * this->params.commands_scale,
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
this->obs.dof_vel * this->params.dof_vel_scale,
this->obs.actions},
1);
this->obs.actions
},1);
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
return obs;
}
@ -217,13 +260,17 @@ torch::Tensor RL_Real::forward()
return clamped;
}
void signalHandler(int signum)
{
exit(0);
}
int main(int argc, char **argv)
{
RL_Real rl_sar;
signal(SIGINT, signalHandler);
while(1){
while(1)
{
sleep(10);
};