fix: code format

This commit is contained in:
fan-ziqi 2024-10-12 22:03:26 +08:00
parent bd2975223f
commit 4e4157d2b3
10 changed files with 295 additions and 389 deletions

View File

@ -4,45 +4,34 @@
#include "rl_sdk.hpp" #include "rl_sdk.hpp"
#include "observation_buffer.hpp" #include "observation_buffer.hpp"
#include "loop.hpp" #include "loop.hpp"
#include <unitree/robot/channel/channel_publisher.hpp> #include <unitree/robot/channel/channel_publisher.hpp> // TODO go2的sdk没有传上来
#include <unitree/robot/channel/channel_subscriber.hpp> #include <unitree/robot/channel/channel_subscriber.hpp>
#include <unitree/idl/go2/LowState_.hpp> #include <unitree/idl/go2/LowState_.hpp>
#include <unitree/idl/go2/LowCmd_.hpp> #include <unitree/idl/go2/LowCmd_.hpp>
#include <unitree/common/time/time_tool.hpp> #include <unitree/common/time/time_tool.hpp>
#include <unitree/common/thread/thread.hpp> #include <unitree/common/thread/thread.hpp>
#include <unitree/robot/go2/robot_state/robot_state_client.hpp> #include <unitree/robot/go2/robot_state/robot_state_client.hpp>
#include "unitree_joystick.h" // TODO 这里调用的是a1的go2有自己的键盘接口吗
#include <csignal> #include <csignal>
#include "matplotlibcpp.h" #include "matplotlibcpp.h"
#include "unitree_joystick.h"
namespace plt = matplotlibcpp; namespace plt = matplotlibcpp;
using namespace unitree::common; using namespace unitree::common;
using namespace unitree::robot; using namespace unitree::robot;
using namespace unitree::robot::go2; using namespace unitree::robot::go2;
#define TOPIC_LOWCMD "rt/lowcmd" #define TOPIC_LOWCMD "rt/lowcmd"
#define TOPIC_LOWSTATE "rt/lowstate" #define TOPIC_LOWSTATE "rt/lowstate"
constexpr double PosStopF = (2.146E+9f); constexpr double PosStopF = (2.146E+9f);
constexpr double VelStopF = (16000.0f); constexpr double VelStopF = (16000.0f);
class RL_Real : public RL class RL_Real : public RL
{ {
public: public:
RL_Real(); RL_Real();
~RL_Real(); ~RL_Real();
void Init();
void InitRobotStateClient();
int queryServiceStatus(const std::string& serviceName);
void activateService(const std::string& serviceName,int activate);
private: private:
void LowCmdwriteHandler();
uint32_t crc32_core(uint32_t* ptr, uint32_t len);
void InitLowCmd();
void LowStateMessageHandler(const void* messages);
// rl functions // rl functions
torch::Tensor Forward() override; torch::Tensor Forward() override;
void GetState(RobotState<double> *state) override; void GetState(RobotState<double> *state) override;
@ -57,8 +46,6 @@ private:
// loop // loop
std::shared_ptr<LoopFunc> loop_keyboard; std::shared_ptr<LoopFunc> loop_keyboard;
std::shared_ptr<LoopFunc> loop_control; std::shared_ptr<LoopFunc> loop_control;
std::shared_ptr<LoopFunc> loop_udpSend;
std::shared_ptr<LoopFunc> loop_udpRecv;
std::shared_ptr<LoopFunc> loop_rl; std::shared_ptr<LoopFunc> loop_rl;
std::shared_ptr<LoopFunc> loop_plot; std::shared_ptr<LoopFunc> loop_plot;
@ -66,26 +53,23 @@ private:
const int plot_size = 100; const int plot_size = 100;
std::vector<int> plot_t; std::vector<int> plot_t;
std::vector<std::vector<double>> plot_real_joint_pos, plot_target_joint_pos; std::vector<std::vector<double>> plot_real_joint_pos, plot_target_joint_pos;
void Plot(); void Plot();
// unitree interface // unitree interface
void InitRobotStateClient();
int QueryServiceStatus(const std::string &serviceName);
void ActivateService(const std::string &serviceName, int activate);
void LowCmdwriteHandler();
uint32_t Crc32Core(uint32_t *ptr, uint32_t len);
void InitLowCmd();
void LowStateMessageHandler(const void *messages);
RobotStateClient rsc; RobotStateClient rsc;
unitree_go::msg::dds_::LowCmd_ unitree_low_command{}; // default init unitree_go::msg::dds_::LowCmd_ unitree_low_command{}; // default init
unitree_go::msg::dds_::LowState_ unitree_low_state{}; // default init unitree_go::msg::dds_::LowState_ unitree_low_state{}; // default init
ChannelPublisherPtr<unitree_go::msg::dds_::LowCmd_> lowcmd_publisher; // publisher
/*publisher*/ ChannelSubscriberPtr<unitree_go::msg::dds_::LowState_> lowstate_subscriber; //subscriber
ChannelPublisherPtr<unitree_go::msg::dds_::LowCmd_> lowcmd_publisher;
/*subscriber*/
ChannelSubscriberPtr<unitree_go::msg::dds_::LowState_> lowstate_subscriber;
xRockerBtnDataStruct unitree_joy; xRockerBtnDataStruct unitree_joy;
// others // others
int motiontime = 0; int motiontime = 0;
std::vector<double> mapped_joint_positions; std::vector<double> mapped_joint_positions;
@ -94,4 +78,4 @@ private:
int state_mapping[12] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8}; int state_mapping[12] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8};
}; };
#endif #endif // RL_REAL_HPP

View File

@ -1,41 +0,0 @@
#ifndef UNITREE_JOYSTICK_H
#define UNITREE_JOYSTICK_H
#include <stdint.h>
// 16b
typedef union {
struct {
uint8_t R1 :1;
uint8_t L1 :1;
uint8_t start :1;
uint8_t select :1;
uint8_t R2 :1;
uint8_t L2 :1;
uint8_t F1 :1;
uint8_t F2 :1;
uint8_t A :1;
uint8_t B :1;
uint8_t X :1;
uint8_t Y :1;
uint8_t up :1;
uint8_t right :1;
uint8_t down :1;
uint8_t left :1;
} components;
uint16_t value;
} xKeySwitchUnion;
// 40 Byte (now used 24B)
typedef struct {
uint8_t head[2];
xKeySwitchUnion btn;
float lx;
float rx;
float ry;
float L2;
float ly;
uint8_t idle[16];
} xRockerBtnDataStruct;
#endif // UNITREE_JOYSTICK_H

View File

@ -1,5 +1,5 @@
go2_isaacgym: go2_isaacgym:
model_name: "himloco_3.pt" model_name: "himloco.pt"
framework: "isaacgym" framework: "isaacgym"
rows: 4 rows: 4
cols: 3 cols: 3

View File

@ -1,57 +1,12 @@
#include "../include/rl_real_go2.hpp" #include "rl_real_go2.hpp"
// #define PLOT // #define PLOT
// #define CSV_LOGGER // #define CSV_LOGGER
RL_Real rl_sar;
void RL_Real::RL_Real()
uint32_t RL_Real::crc32_core(uint32_t* ptr, uint32_t len)
{ {
unsigned int xbit = 0;
unsigned int data = 0;
unsigned int CRC32 = 0xFFFFFFFF;
const unsigned int dwPolynomial = 0x04c11db7;
for (unsigned int i = 0; i < len; i++)
{
xbit = 1 << 31;
data = ptr[i];
for (unsigned int bits = 0; bits < 32; bits++)
{
if (CRC32 & 0x80000000)
{
CRC32 <<= 1;
CRC32 ^= dwPolynomial;
}
else
{
CRC32 <<= 1;
}
if (data & xbit)
CRC32 ^= dwPolynomial;
xbit >>= 1;
}
}
return CRC32;
}
void RL_Real::Init()
{
InitLowCmd();
/*create publisher*/
lowcmd_publisher.reset(new ChannelPublisher<unitree_go::msg::dds_::LowCmd_>(TOPIC_LOWCMD));
lowcmd_publisher->InitChannel();
/*create subscriber*/
lowstate_subscriber.reset(new ChannelSubscriber<unitree_go::msg::dds_::LowState_>(TOPIC_LOWSTATE));
lowstate_subscriber->InitChannel(std::bind(&RL_Real::LowStateMessageHandler, this, std::placeholders::_1), 1);
/*loop publishing thread*/
//lowCmdWriteThreadPtr = CreateRecurrentThreadEx("writebasiccmd", UT_CPU_ID_NONE, 2000, &RL_Real::LowCmdwriteHandler, this);
// read params from yaml // read params from yaml
this->robot_name = "go2_isaacgym"; this->robot_name = "go2_isaacgym";
this->ReadYaml(this->robot_name); this->ReadYaml(this->robot_name);
@ -59,30 +14,43 @@ void RL_Real::Init()
// history // history
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
// init robot
this->InitRobotStateClient();
while (rl_sar.QueryServiceStatus("sport_mode"))
{
std::cout << "Try to deactivate the service: " << "sport_mode" << std::endl;
rl_sar.ActivateService("sport_mode", 0);
sleep(1);
}
this->InitLowCmd();
// create publisher
lowcmd_publisher.reset(new ChannelPublisher<unitree_go::msg::dds_::LowCmd_>(TOPIC_LOWCMD));
lowcmd_publisher->InitChannel();
// create subscriber
lowstate_subscriber.reset(new ChannelSubscriber<unitree_go::msg::dds_::LowState_>(TOPIC_LOWSTATE));
lowstate_subscriber->InitChannel(std::bind(&RL_Real::LowStateMessageHandler, this, std::placeholders::_1), 1);
// loop publishing thread TODO why?
// lowCmdWriteThreadPtr = CreateRecurrentThreadEx("writebasiccmd", UT_CPU_ID_NONE, 2000, &RL_Real::LowCmdwriteHandler, this);
// init // init rl
torch::autograd::GradMode::set_enabled(false); torch::autograd::GradMode::set_enabled(false);
this->InitObservations(); this->InitObservations();
this->InitOutputs(); this->InitOutputs();
this->InitControl(); this->InitControl();
running_state = STATE_WAITING;
// model // model
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + this->robot_name + "/" + this->params.model_name; std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + this->robot_name + "/" + this->params.model_name;
this->model = torch::jit::load(model_path); this->model = torch::jit::load(model_path);
// loop // loop
// this->loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend" , 0.002, std::bind(&RL_Real::UDPSend, this), 3);
// this->loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv" , 0.002, std::bind(&RL_Real::UDPRecv, this), 3);
this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05, std::bind(&RL_Real::KeyboardInterface, this)); this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05, std::bind(&RL_Real::KeyboardInterface, this));
this->loop_control = std::make_shared<LoopFunc>("loop_control", this->params.dt, std::bind(&RL_Real::RobotControl, this)); this->loop_control = std::make_shared<LoopFunc>("loop_control", this->params.dt, std::bind(&RL_Real::RobotControl, this));
this->loop_rl = std::make_shared<LoopFunc>("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Real::RunModel, this)); this->loop_rl = std::make_shared<LoopFunc>("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Real::RunModel, this));
// this->loop_udpSend->start();
// this->loop_udpRecv->start();
this->loop_keyboard->start(); this->loop_keyboard->start();
this->loop_control->start(); this->loop_control->start();
this->loop_rl->start(); this->loop_rl->start();
#ifdef PLOT #ifdef PLOT
this->plot_t = std::vector<int>(this->plot_size, 0); this->plot_t = std::vector<int>(this->plot_size, 0);
this->plot_real_joint_pos.resize(this->params.num_of_dofs); this->plot_real_joint_pos.resize(this->params.num_of_dofs);
@ -95,88 +63,10 @@ void RL_Real::Init()
#ifdef CSV_LOGGER #ifdef CSV_LOGGER
this->CSVInit(this->robot_name); this->CSVInit(this->robot_name);
#endif #endif
}
void RL_Real::InitLowCmd()
{
unitree_low_command.head()[0] = 0xFE;
unitree_low_command.head()[1] = 0xEF;
unitree_low_command.level_flag() = 0xFF;
unitree_low_command.gpio() = 0;
for(int i=0; i<20; i++)
{
unitree_low_command.motor_cmd()[i].mode() = (0x01); // motor switch to servo (PMSM) mode
unitree_low_command.motor_cmd()[i].q() = (PosStopF);
unitree_low_command.motor_cmd()[i].kp() = (0);
unitree_low_command.motor_cmd()[i].dq() = (VelStopF);
unitree_low_command.motor_cmd()[i].kd() = (0);
unitree_low_command.motor_cmd()[i].tau() = (0);
}
}
void RL_Real::InitRobotStateClient()
{
rsc.SetTimeout(10.0f);
rsc.Init();
}
int RL_Real::queryServiceStatus(const std::string& serviceName)
{
std::vector<ServiceState> serviceStateList;
int ret,serviceStatus;
ret = rsc.ServiceList(serviceStateList);
size_t i, count=serviceStateList.size();
for (i=0; i<count; i++)
{
const ServiceState& serviceState = serviceStateList[i];
if(serviceState.name == serviceName)
{
if(serviceState.status == 0)
{
std::cout << "name: " << serviceState.name <<" is activate"<<std::endl;
serviceStatus = 1;
}
else
{
std::cout << "name:" << serviceState.name <<" is deactivate"<<std::endl;
serviceStatus = 0;
}
}
}
return serviceStatus;
}
void RL_Real::activateService(const std::string& serviceName,int activate)
{
rsc.ServiceSwitch(serviceName, activate);
}
void RL_Real::LowStateMessageHandler(const void* message)
{
unitree_low_state = *(unitree_go::msg::dds_::LowState_*)message;
}
void RL_Real::LowCmdwriteHandler(){
// this->unitree_low_command.crc() = crc32_core((uint32_t *)&unitree_low_command, (sizeof(unitree_go::msg::dds_::LowCmd_)>>2)-1);
// this->lowcmd_publisher->Write(unitree_low_command);
}
RL_Real::RL_Real(){
} }
RL_Real::~RL_Real() RL_Real::~RL_Real()
{ {
// this->loop_udpSend->shutdown();
// this->loop_udpRecv->shutdown();
this->loop_keyboard->shutdown(); this->loop_keyboard->shutdown();
this->loop_control->shutdown(); this->loop_control->shutdown();
this->loop_rl->shutdown(); this->loop_rl->shutdown();
@ -188,7 +78,7 @@ RL_Real::~RL_Real()
void RL_Real::GetState(RobotState<double> *state) void RL_Real::GetState(RobotState<double> *state)
{ {
//! 这里是不是要加锁? // TODO 加锁
memcpy(&this->unitree_joy, &this->unitree_low_state.wireless_remote()[0], 40); memcpy(&this->unitree_joy, &this->unitree_low_state.wireless_remote()[0], 40);
if ((int)this->unitree_joy.btn.components.R2 == 1) if ((int)this->unitree_joy.btn.components.R2 == 1)
@ -229,10 +119,6 @@ void RL_Real::GetState(RobotState<double> *state)
state->motor_state.dq[i] = this->unitree_low_state.motor_state()[state_mapping[i]].dq(); state->motor_state.dq[i] = this->unitree_low_state.motor_state()[state_mapping[i]].dq();
state->motor_state.tauEst[i] = this->unitree_low_state.motor_state()[state_mapping[i]].tau_est(); state->motor_state.tauEst[i] = this->unitree_low_state.motor_state()[state_mapping[i]].tau_est();
} }
//std::cout<<"(int)this->unitree_joy.btn.components.R2: "<<(int)this->unitree_joy.btn.components.R2<<std::endl;
// std::cout<<"quat: "<< this->unitree_low_state.imu_state().quaternion()[0]<<", "<< this->unitree_low_state.imu_state().quaternion()[1]<<", "<< this->unitree_low_state.imu_state().quaternion()[2]<<", "<< this->unitree_low_state.imu_state().quaternion()[3]<<std::endl;
} }
void RL_Real::SetCommand(const RobotCommand<double> *command) void RL_Real::SetCommand(const RobotCommand<double> *command)
@ -245,14 +131,11 @@ void RL_Real::SetCommand(const RobotCommand<double> *command)
this->unitree_low_command.motor_cmd()[i].kp() = command->motor_command.kp[command_mapping[i]]; this->unitree_low_command.motor_cmd()[i].kp() = command->motor_command.kp[command_mapping[i]];
this->unitree_low_command.motor_cmd()[i].kd() = command->motor_command.kd[command_mapping[i]]; this->unitree_low_command.motor_cmd()[i].kd() = command->motor_command.kd[command_mapping[i]];
this->unitree_low_command.motor_cmd()[i].tau() = command->motor_command.tau[command_mapping[i]]; this->unitree_low_command.motor_cmd()[i].tau() = command->motor_command.tau[command_mapping[i]];
//std::cout<<"q: "<<command->motor_command.q[command_mapping[i]]<<std::endl;
} }
//暂时不发 // 暂时不发 TODO Why?
this->unitree_low_command.crc() = crc32_core((uint32_t *)&unitree_low_command, (sizeof(unitree_go::msg::dds_::LowCmd_)>>2)-1); this->unitree_low_command.crc() = Crc32Core((uint32_t *)&unitree_low_command, (sizeof(unitree_go::msg::dds_::LowCmd_) >> 2) - 1);
lowcmd_publisher->Write(unitree_low_command); lowcmd_publisher->Write(unitree_low_command);
} }
void RL_Real::RobotControl() void RL_Real::RobotControl()
@ -298,8 +181,6 @@ void RL_Real::RunModel()
} }
} }
torch::Tensor RL_Real::Forward() torch::Tensor RL_Real::Forward()
{ {
torch::autograd::GradMode::set_enabled(false); torch::autograd::GradMode::set_enabled(false);
@ -311,9 +192,14 @@ torch::Tensor RL_Real::Forward()
torch::Tensor actions = this->model.forward({this->history_obs}).toTensor(); torch::Tensor actions = this->model.forward({this->history_obs}).toTensor();
torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper); if (this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0)
{
return clamped_actions; return torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
}
else
{
return actions;
}
} }
void RL_Real::Plot() void RL_Real::Plot()
@ -337,6 +223,104 @@ void RL_Real::Plot()
plt::pause(0.0001); plt::pause(0.0001);
} }
uint32_t RL_Real::Crc32Core(uint32_t *ptr, uint32_t len)
{
unsigned int xbit = 0;
unsigned int data = 0;
unsigned int CRC32 = 0xFFFFFFFF;
const unsigned int dwPolynomial = 0x04c11db7;
for (unsigned int i = 0; i < len; i++)
{
xbit = 1 << 31;
data = ptr[i];
for (unsigned int bits = 0; bits < 32; bits++)
{
if (CRC32 & 0x80000000)
{
CRC32 <<= 1;
CRC32 ^= dwPolynomial;
}
else
{
CRC32 <<= 1;
}
if (data & xbit)
CRC32 ^= dwPolynomial;
xbit >>= 1;
}
}
return CRC32;
}
void RL_Real::InitLowCmd()
{
unitree_low_command.head()[0] = 0xFE;
unitree_low_command.head()[1] = 0xEF;
unitree_low_command.level_flag() = 0xFF;
unitree_low_command.gpio() = 0;
for (int i = 0; i < 20; i++)
{
unitree_low_command.motor_cmd()[i].mode() = (0x01); // motor switch to servo (PMSM) mode
unitree_low_command.motor_cmd()[i].q() = (PosStopF);
unitree_low_command.motor_cmd()[i].kp() = (0);
unitree_low_command.motor_cmd()[i].dq() = (VelStopF);
unitree_low_command.motor_cmd()[i].kd() = (0);
unitree_low_command.motor_cmd()[i].tau() = (0);
}
}
void RL_Real::InitRobotStateClient()
{
rsc.SetTimeout(10.0f);
rsc.Init();
}
int RL_Real::QueryServiceStatus(const std::string &serviceName)
{
std::vector<ServiceState> serviceStateList;
int ret, serviceStatus;
ret = rsc.ServiceList(serviceStateList);
size_t i, count = serviceStateList.size();
for (i = 0; i < count; i++)
{
const ServiceState &serviceState = serviceStateList[i];
if (serviceState.name == serviceName)
{
if (serviceState.status == 0)
{
std::cout << "name: " << serviceState.name << " is activate" << std::endl;
serviceStatus = 1;
}
else
{
std::cout << "name:" << serviceState.name << " is deactivate" << std::endl;
serviceStatus = 0;
}
}
}
return serviceStatus;
}
void RL_Real::ActivateService(const std::string &serviceName, int activate)
{
rsc.ServiceSwitch(serviceName, activate);
}
void RL_Real::LowStateMessageHandler(const void *message)
{
unitree_low_state = *(unitree_go::msg::dds_::LowState_ *)message;
}
void RL_Real::LowCmdwriteHandler()
{
// this->unitree_low_command.crc() = Crc32Core((uint32_t *)&unitree_low_command, (sizeof(unitree_go::msg::dds_::LowCmd_)>>2)-1);
// this->lowcmd_publisher->Write(unitree_low_command);
}
void signalHandler(int signum) void signalHandler(int signum)
{ {
exit(0); exit(0);
@ -344,28 +328,7 @@ void signalHandler(int signum)
int main(int argc, char **argv) int main(int argc, char **argv)
{ {
signal(SIGINT, signalHandler);
if (argc < 2)
{
std::cout << "Usage: " << argv[0] << " networkInterface" << std::endl;
exit(-1);
}
std::cout << "WARNING: Make sure the robot is hung up or lying on the ground." << std::endl
<< "Press Enter to continue..." << std::endl;
std::cin.ignore();
ChannelFactory::Instance()->Init(0, argv[1]);
RL_Real rl_sar;
rl_sar.InitRobotStateClient();
while(rl_sar.queryServiceStatus("sport_mode"))
{
std::cout<<"Try to deactivate the service: "<<"sport_mode"<<std::endl;
rl_sar.activateService("sport_mode",0);
sleep(1);
}
rl_sar.Init();
while (1) while (1)
{ {

View File

@ -252,6 +252,7 @@ torch::Tensor RL_Sim::Forward()
if (this->params.use_history) if (this->params.use_history)
{ {
this->history_obs_buf.insert(clamped_obs); this->history_obs_buf.insert(clamped_obs);
// TODO 这里要找一种方法适配不同的顺序不能直接改这里会导致a1的模型不可用
this->history_obs = this->history_obs_buf.get_obs_vec({5, 4, 3, 2, 1, 0}); this->history_obs = this->history_obs_buf.get_obs_vec({5, 4, 3, 2, 1, 0});
actions = this->model.forward({this->history_obs}).toTensor(); actions = this->model.forward({this->history_obs}).toTensor();
} }

View File

@ -1 +0,0 @@
controller_joint_names: ['', 'FL_hip_joint', 'Fl_thigh_joint', 'FL_calf_joint', 'FR_hip_joint', 'FR_thigh_joint', 'FR_calf_joint', 'RL_hip_joint', 'RL_thigh_joint', 'RL_calf_joint', 'RR_hip_joint', 'RR_thigh_joint', 'RR_calf_joint', ]