feat: humanoid works

This commit is contained in:
fan-ziqi 2024-05-30 12:01:32 +08:00
parent 53713805b5
commit e3dfb685a7
7 changed files with 157 additions and 102 deletions

View File

@ -1,5 +1,7 @@
a1: a1:
model_name: "model_0526.pt" model_name: "model_0526.pt"
dt: 0.005
decimation: 4
num_observations: 45 num_observations: 45
clip_obs: 100.0 clip_obs: 100.0
clip_actions_lower: [-100, -100, -100, clip_actions_lower: [-100, -100, -100,
@ -50,6 +52,8 @@ a1:
gr1t1: gr1t1:
model_name: "model_4000_jit.pt" model_name: "model_4000_jit.pt"
dt: 0.001
decimation: 20
num_observations: 39 num_observations: 39
clip_obs: 100.0 clip_obs: 100.0
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991, clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
@ -64,8 +68,8 @@ gr1t1:
57.0, 43.0, 114.0, 114.0, 15.3] 57.0, 43.0, 114.0, 114.0, 15.3]
fixed_kd: [5.7, 4.3, 11.4, 11.4, 1.5, fixed_kd: [5.7, 4.3, 11.4, 11.4, 1.5,
5.7, 4.3, 11.4, 11.4, 1.5] 5.7, 4.3, 11.4, 11.4, 1.5]
hip_scale_reduction: 0.5 hip_scale_reduction: 1.0
hip_scale_reduction_indices: [0, 3, 6, 9] hip_scale_reduction_indices: []
num_of_dofs: 10 num_of_dofs: 10
action_scale: 1.0 action_scale: 1.0
lin_vel_scale: 1.0 lin_vel_scale: 1.0

View File

@ -11,6 +11,7 @@
#include <geometry_msgs/Twist.h> #include <geometry_msgs/Twist.h>
#include "robot_msgs/MotorCommand.h" #include "robot_msgs/MotorCommand.h"
#include <csignal> #include <csignal>
#include <gazebo_msgs/SetModelState.h>
#include "matplotlibcpp.h" #include "matplotlibcpp.h"
namespace plt = matplotlibcpp; namespace plt = matplotlibcpp;
@ -54,7 +55,7 @@ private:
ros::Subscriber model_state_subscriber; ros::Subscriber model_state_subscriber;
ros::Subscriber joint_state_subscriber; ros::Subscriber joint_state_subscriber;
ros::Subscriber cmd_vel_subscriber; ros::Subscriber cmd_vel_subscriber;
ros::ServiceClient gazebo_reset_client; ros::ServiceClient gazebo_set_model_state_client;
std::map<std::string, ros::Publisher> joint_publishers; std::map<std::string, ros::Publisher> joint_publishers;
std::vector<robot_msgs::MotorCommand> joint_publishers_commands; std::vector<robot_msgs::MotorCommand> joint_publishers_commands;
void ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg); void ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);

View File

@ -3,82 +3,109 @@
#include <iostream> #include <iostream>
#include <thread> #include <thread>
#include <mutex>
#include <chrono> #include <chrono>
#include <functional> #include <functional>
#include <mutex>
#include <condition_variable>
#include <atomic>
#include <vector>
#include <sstream>
#include <iomanip>
typedef std::function<void()> Callback; class LoopFunc
class Loop {
public:
Loop(std::string name, float period, int bindCPU = -1)
: _name(name), _period(period), _bindCPU(bindCPU) {}
~Loop() {
if (_isrunning) {
shutdown(); // Ensure the loop is stopped when the object is destroyed
}
}
void start() {
_isrunning = true;
_thread = std::thread([this]() {
if (_bindCPU >= 0) {
std::lock_guard<std::mutex> lock(_printMutex);
std::cout << "[Loop Start] named: " << _name << ", period: " << _period * 1000 << " (ms), run at cpu: " << _bindCPU << std::endl;
} else {
std::lock_guard<std::mutex> lock(_printMutex);
std::cout << "[Loop Start] named: " << _name << ", period: " << _period * 1000 << " (ms), cpu unspecified" << std::endl;
}
entryFunc();
}); // Start the loop in a new thread
}
void shutdown() {
_isrunning = false;
if (_thread.joinable()) {
_thread.join(); // Wait for the loop thread to finish
std::lock_guard<std::mutex> lock(_printMutex);
std::cout << "[Loop End] named: " << _name << std::endl;
}
}
virtual void functionCB() = 0;
protected:
void entryFunc() {
while (_isrunning) {
functionCB(); // Call the overridden functionCB in a loop
std::this_thread::sleep_for(std::chrono::duration<float>(_period)); // Wait for the specified period
}
}
std::string _name;
float _period;
int _bindCPU;
bool _isrunning = false;
std::thread _thread;
static std::mutex _printMutex;
};
std::mutex Loop::_printMutex;
class LoopFunc : public Loop {
public:
LoopFunc(std::string name, float period, const Callback& cb)
: Loop(name, period), _fp(cb) {}
LoopFunc(std::string name, float period, int bindCPU, const Callback& cb)
: Loop(name, period, bindCPU), _fp(cb) {}
void functionCB() override {
{ {
std::lock_guard<std::mutex> lock(_printMutex); private:
(_fp)(); // Call the provided callback function std::string _name;
double _period;
std::function<void()> _func;
int _bindCPU;
std::atomic<bool> _running;
std::mutex _mutex;
std::condition_variable _cv;
std::thread _thread;
public:
LoopFunc(const std::string &name, double period, std::function<void()> func, int bindCPU = -1)
: _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {}
void start()
{
_running = true;
log("[Loop Start] named: " + _name + ", period: " + formatPeriod() + "(ms)" + (_bindCPU != -1 ? ", run at cpu: " + std::to_string(_bindCPU) : ", cpu unspecified"));
if (_bindCPU != -1)
{
_thread = std::thread(&LoopFunc::loop, this);
setThreadAffinity(_thread.native_handle(), _bindCPU);
}
else
{
_thread = std::thread(&LoopFunc::loop, this);
}
_thread.detach();
}
void shutdown()
{
{
std::unique_lock<std::mutex> lock(_mutex);
_running = false;
_cv.notify_one();
}
if (_thread.joinable())
{
_thread.join();
}
log("[Loop End] named: " + _name);
}
private:
void loop()
{
while (_running)
{
auto start = std::chrono::steady_clock::now();
_func();
auto end = std::chrono::steady_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
auto sleepTime = std::chrono::milliseconds(static_cast<int>((_period * 1000) - elapsed.count()));
if (sleepTime.count() > 0)
{
std::unique_lock<std::mutex> lock(_mutex);
if (_cv.wait_for(lock, sleepTime, [this]{ return !_running; }))
{
break;
}
}
} }
} }
private: std::string formatPeriod() const
Callback _fp; {
std::ostringstream stream;
stream << std::fixed << std::setprecision(0) << _period * 1000;
return stream.str();
}
void log(const std::string& message)
{
static std::mutex logMutex;
std::lock_guard<std::mutex> lock(logMutex);
std::cout << message << std::endl;
}
void setThreadAffinity(std::thread::native_handle_type threadHandle, int cpuId)
{
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
CPU_SET(cpuId, &cpuset);
if (pthread_setaffinity_np(threadHandle, sizeof(cpu_set_t), &cpuset) != 0)
{
std::ostringstream oss;
oss << "Error setting thread affinity: CPU " << cpuId << " may not be valid or accessible.";
throw std::runtime_error(oss.str());
}
}
}; };
#endif #endif

View File

@ -104,6 +104,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
start_state.motor_state.q[i] = now_state.motor_state.q[i]; start_state.motor_state.q[i] = now_state.motor_state.q[i];
} }
this->running_state = STATE_POS_GETUP; this->running_state = STATE_POS_GETUP;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETUP" << std::endl;
} }
} }
// stand up (position control) // stand up (position control)
@ -111,7 +112,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
{ {
if(getup_percent < 1.0) if(getup_percent < 1.0)
{ {
getup_percent += 1 / 1000.0; getup_percent += 1 / 500.0;
getup_percent = getup_percent > 1.0 ? 1.0 : getup_percent; getup_percent = getup_percent > 1.0 ? 1.0 : getup_percent;
for(int i = 0; i < this->params.num_of_dofs; ++i) for(int i = 0; i < this->params.num_of_dofs; ++i)
{ {
@ -125,9 +126,9 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
} }
if(this->control.control_state == STATE_RL_INIT) if(this->control.control_state == STATE_RL_INIT)
{ {
std::cout << std::endl;
this->control.control_state = STATE_WAITING; this->control.control_state = STATE_WAITING;
this->running_state = STATE_RL_INIT; this->running_state = STATE_RL_INIT;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_INIT" << std::endl;
} }
else if(this->control.control_state == STATE_POS_GETDOWN) else if(this->control.control_state == STATE_POS_GETDOWN)
{ {
@ -138,6 +139,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
now_state.motor_state.q[i] = state->motor_state.q[i]; now_state.motor_state.q[i] = state->motor_state.q[i];
} }
this->running_state = STATE_POS_GETDOWN; this->running_state = STATE_POS_GETDOWN;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl;
} }
} }
// init obs and start rl loop // init obs and start rl loop
@ -145,15 +147,18 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
{ {
if(getup_percent == 1) if(getup_percent == 1)
{ {
this->running_state = STATE_RL_RUNNING;
this->InitObservations(); this->InitObservations();
this->InitOutputs(); this->InitOutputs();
this->InitControl(); this->InitControl();
this->running_state = STATE_RL_RUNNING;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_RUNNING" << std::endl;
} }
} }
// rl loop // rl loop
else if(this->running_state == STATE_RL_RUNNING) else if(this->running_state == STATE_RL_RUNNING)
{ {
std::cout << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << " \r";
for(int i = 0; i < this->params.num_of_dofs; ++i) for(int i = 0; i < this->params.num_of_dofs; ++i)
{ {
command->motor_command.q[i] = this->output_dof_pos[0][i].item<double>(); command->motor_command.q[i] = this->output_dof_pos[0][i].item<double>();
@ -171,6 +176,18 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
now_state.motor_state.q[i] = state->motor_state.q[i]; now_state.motor_state.q[i] = state->motor_state.q[i];
} }
this->running_state = STATE_POS_GETDOWN; this->running_state = STATE_POS_GETDOWN;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl;
}
else if(this->control.control_state == STATE_POS_GETUP)
{
this->control.control_state = STATE_WAITING;
getup_percent = 0.0;
for(int i = 0; i < this->params.num_of_dofs; ++i)
{
now_state.motor_state.q[i] = state->motor_state.q[i];
}
this->running_state = STATE_POS_GETUP;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETUP" << std::endl;
} }
} }
// get down (position control) // get down (position control)
@ -178,7 +195,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
{ {
if(getdown_percent < 1.0) if(getdown_percent < 1.0)
{ {
getdown_percent += 1 / 1000.0; getdown_percent += 1 / 500.0;
getdown_percent = getdown_percent > 1.0 ? 1.0 : getdown_percent; getdown_percent = getdown_percent > 1.0 ? 1.0 : getdown_percent;
for(int i = 0; i < this->params.num_of_dofs; ++i) for(int i = 0; i < this->params.num_of_dofs; ++i)
{ {
@ -192,11 +209,11 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
} }
if(getdown_percent == 1) if(getdown_percent == 1)
{ {
std::cout << std::endl;
this->running_state = STATE_WAITING;
this->InitObservations(); this->InitObservations();
this->InitOutputs(); this->InitOutputs();
this->InitControl(); this->InitControl();
this->running_state = STATE_WAITING;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_WAITING" << std::endl;
} }
} }
} }
@ -226,10 +243,11 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques)
double limit_lower = -this->params.torque_limits[0][index].item<double>(); double limit_lower = -this->params.torque_limits[0][index].item<double>();
double limit_upper = this->params.torque_limits[0][index].item<double>(); double limit_upper = this->params.torque_limits[0][index].item<double>();
std::cout << LOGGER::ERROR << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl; std::cout << LOGGER::WARNING << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
std::cout << LOGGER::ERROR << "Switching to STATE_POS_GETDOWN"<< std::endl;
} }
this->control.control_state = STATE_POS_GETDOWN; // Just a reminder, no protection
// this->control.control_state = STATE_POS_GETDOWN;
// std::cout << LOGGER::INFO << "Switching to STATE_POS_GETDOWN"<< std::endl;
} }
} }
@ -254,11 +272,6 @@ static bool kbhit()
void RL::KeyboardInterface() void RL::KeyboardInterface()
{ {
if(this->running_state == STATE_RL_RUNNING)
{
std::cout << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << " \r";
}
if(kbhit()) if(kbhit())
{ {
int c = fgetc(stdin); int c = fgetc(stdin);
@ -308,6 +321,8 @@ void RL::ReadYaml(std::string robot_name)
} }
this->params.model_name = config["model_name"].as<std::string>(); this->params.model_name = config["model_name"].as<std::string>();
this->params.dt = config["dt"].as<double>();
this->params.decimation = config["decimation"].as<int>();
this->params.num_observations = config["num_observations"].as<int>(); this->params.num_observations = config["num_observations"].as<int>();
this->params.clip_obs = config["clip_obs"].as<double>(); this->params.clip_obs = config["clip_obs"].as<double>();
this->params.clip_actions_upper = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_upper"])).view({1, -1}); this->params.clip_actions_upper = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_upper"])).view({1, -1});

View File

@ -69,6 +69,8 @@ struct Control
struct ModelParams struct ModelParams
{ {
std::string model_name; std::string model_name;
double dt;
int decimation;
int num_observations; int num_observations;
double damping; double damping;
double stiffness; double stiffness;

View File

@ -27,17 +27,18 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
this->model = torch::jit::load(model_path); this->model = torch::jit::load(model_path);
// loop // loop
this->loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend" , 0.002, std::bind(&RL_Real::UDPSend, this), 3);
this->loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv" , 0.002, std::bind(&RL_Real::UDPRecv, this), 3);
this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05, std::bind(&RL_Real::KeyboardInterface, this)); this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05, std::bind(&RL_Real::KeyboardInterface, this));
this->loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, std::bind(&RL_Real::RobotControl , this)); this->loop_control = std::make_shared<LoopFunc>("loop_control", this->params.dt, std::bind(&RL_Real::RobotControl, this));
this->loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend" , 0.002, 3, std::bind(&RL_Real::UDPSend , this)); this->loop_rl = std::make_shared<LoopFunc>("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Real::RunModel, this));
this->loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv" , 0.002, 3, std::bind(&RL_Real::UDPRecv , this));
this->loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , std::bind(&RL_Real::RunModel , this));
this->loop_keyboard->start();
this->loop_udpSend->start(); this->loop_udpSend->start();
this->loop_udpRecv->start(); this->loop_udpRecv->start();
this->loop_keyboard->start();
this->loop_control->start(); this->loop_control->start();
this->loop_rl->start(); this->loop_rl->start();
#ifdef PLOT #ifdef PLOT
this->plot_t = std::vector<int>(this->plot_size, 0); this->plot_t = std::vector<int>(this->plot_size, 0);
this->plot_real_joint_pos.resize(this->params.num_of_dofs); this->plot_real_joint_pos.resize(this->params.num_of_dofs);
@ -54,9 +55,9 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
RL_Real::~RL_Real() RL_Real::~RL_Real()
{ {
this->loop_keyboard->shutdown();
this->loop_udpSend->shutdown(); this->loop_udpSend->shutdown();
this->loop_udpRecv->shutdown(); this->loop_udpRecv->shutdown();
this->loop_keyboard->shutdown();
this->loop_control->shutdown(); this->loop_control->shutdown();
this->loop_rl->shutdown(); this->loop_rl->shutdown();
#ifdef PLOT #ifdef PLOT

View File

@ -56,23 +56,23 @@ RL_Sim::RL_Sim()
this->joint_state_subscriber = nh.subscribe<sensor_msgs::JointState>(this->ros_namespace + "joint_states", 10, &RL_Sim::JointStatesCallback, this); this->joint_state_subscriber = nh.subscribe<sensor_msgs::JointState>(this->ros_namespace + "joint_states", 10, &RL_Sim::JointStatesCallback, this);
// service // service
this->gazebo_reset_client = nh.serviceClient<std_srvs::Empty>("/gazebo/reset_simulation"); this->gazebo_set_model_state_client = nh.serviceClient<gazebo_msgs::SetModelState>("/gazebo/set_model_state");
// loop // loop
this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05, std::bind(&RL_Sim::KeyboardInterface, this)); this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05, std::bind(&RL_Sim::KeyboardInterface, this));
this->loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, std::bind(&RL_Sim::RobotControl , this)); this->loop_control = std::make_shared<LoopFunc>("loop_control", this->params.dt, std::bind(&RL_Sim::RobotControl, this));
this->loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , std::bind(&RL_Sim::RunModel , this)); this->loop_rl = std::make_shared<LoopFunc>("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Sim::RunModel, this));
this->loop_keyboard->start(); this->loop_keyboard->start();
this->loop_control->start(); this->loop_control->start();
this->loop_rl->start(); this->loop_rl->start();
#ifdef PLOT #ifdef PLOT
plot_t = std::vector<int>(this->plot_size, 0); this->plot_t = std::vector<int>(this->plot_size, 0);
this->plot_real_joint_pos.resize(this->params.num_of_dofs); this->plot_real_joint_pos.resize(this->params.num_of_dofs);
this->plot_target_joint_pos.resize(this->params.num_of_dofs); this->plot_target_joint_pos.resize(this->params.num_of_dofs);
for(auto& vector : this->plot_real_joint_pos) { vector = std::vector<double>(this->plot_size, 0); } for(auto& vector : this->plot_real_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
for(auto& vector : this->plot_target_joint_pos) { vector = std::vector<double>(this->plot_size, 0); } for(auto& vector : this->plot_target_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
this->loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, std::bind(&RL_Sim::Plot, this)); this->loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.001 , std::bind(&RL_Sim::Plot , this));
this->loop_plot->start(); this->loop_plot->start();
#endif #endif
#ifdef CSV_LOGGER #ifdef CSV_LOGGER
@ -135,9 +135,14 @@ void RL_Sim::RobotControl()
if(this->control.control_state == STATE_RESET_SIMULATION) if(this->control.control_state == STATE_RESET_SIMULATION)
{ {
gazebo_msgs::SetModelState set_model_state;
std::string gazebo_model_name = this->robot_name + "_gazebo";
set_model_state.request.model_state.model_name = gazebo_model_name;
set_model_state.request.model_state.pose.position.z = 1.0;
set_model_state.request.model_state.reference_frame = "world";
this->gazebo_set_model_state_client.call(set_model_state);
this->control.control_state = STATE_WAITING; this->control.control_state = STATE_WAITING;
std_srvs::Empty srv;
this->gazebo_reset_client.call(srv);
} }
this->GetState(&this->robot_state); this->GetState(&this->robot_state);
@ -263,7 +268,7 @@ void RL_Sim::Plot()
plt::xlim(this->plot_t.front(), this->plot_t.back()); plt::xlim(this->plot_t.front(), this->plot_t.back());
} }
// plt::legend(); // plt::legend();
plt::pause(0.0001); plt::pause(0.01);
} }
void signalHandler(int signum) void signalHandler(int signum)