mirror of https://github.com/fan-ziqi/rl_sar.git
feat: humanoid works
This commit is contained in:
parent
53713805b5
commit
e3dfb685a7
|
@ -1,5 +1,7 @@
|
|||
a1:
|
||||
model_name: "model_0526.pt"
|
||||
dt: 0.005
|
||||
decimation: 4
|
||||
num_observations: 45
|
||||
clip_obs: 100.0
|
||||
clip_actions_lower: [-100, -100, -100,
|
||||
|
@ -50,6 +52,8 @@ a1:
|
|||
|
||||
gr1t1:
|
||||
model_name: "model_4000_jit.pt"
|
||||
dt: 0.001
|
||||
decimation: 20
|
||||
num_observations: 39
|
||||
clip_obs: 100.0
|
||||
clip_actions_lower: [-0.4391, -1.0491, -2.0991, -0.4391, -1.3991,
|
||||
|
@ -64,8 +68,8 @@ gr1t1:
|
|||
57.0, 43.0, 114.0, 114.0, 15.3]
|
||||
fixed_kd: [5.7, 4.3, 11.4, 11.4, 1.5,
|
||||
5.7, 4.3, 11.4, 11.4, 1.5]
|
||||
hip_scale_reduction: 0.5
|
||||
hip_scale_reduction_indices: [0, 3, 6, 9]
|
||||
hip_scale_reduction: 1.0
|
||||
hip_scale_reduction_indices: []
|
||||
num_of_dofs: 10
|
||||
action_scale: 1.0
|
||||
lin_vel_scale: 1.0
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
#include <geometry_msgs/Twist.h>
|
||||
#include "robot_msgs/MotorCommand.h"
|
||||
#include <csignal>
|
||||
#include <gazebo_msgs/SetModelState.h>
|
||||
|
||||
#include "matplotlibcpp.h"
|
||||
namespace plt = matplotlibcpp;
|
||||
|
@ -54,7 +55,7 @@ private:
|
|||
ros::Subscriber model_state_subscriber;
|
||||
ros::Subscriber joint_state_subscriber;
|
||||
ros::Subscriber cmd_vel_subscriber;
|
||||
ros::ServiceClient gazebo_reset_client;
|
||||
ros::ServiceClient gazebo_set_model_state_client;
|
||||
std::map<std::string, ros::Publisher> joint_publishers;
|
||||
std::vector<robot_msgs::MotorCommand> joint_publishers_commands;
|
||||
void ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
|
||||
|
|
|
@ -3,82 +3,109 @@
|
|||
|
||||
#include <iostream>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <atomic>
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
|
||||
typedef std::function<void()> Callback;
|
||||
class LoopFunc
|
||||
{
|
||||
private:
|
||||
std::string _name;
|
||||
double _period;
|
||||
std::function<void()> _func;
|
||||
int _bindCPU;
|
||||
std::atomic<bool> _running;
|
||||
std::mutex _mutex;
|
||||
std::condition_variable _cv;
|
||||
std::thread _thread;
|
||||
|
||||
class Loop {
|
||||
public:
|
||||
Loop(std::string name, float period, int bindCPU = -1)
|
||||
: _name(name), _period(period), _bindCPU(bindCPU) {}
|
||||
~Loop() {
|
||||
if (_isrunning) {
|
||||
shutdown(); // Ensure the loop is stopped when the object is destroyed
|
||||
}
|
||||
}
|
||||
public:
|
||||
LoopFunc(const std::string &name, double period, std::function<void()> func, int bindCPU = -1)
|
||||
: _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {}
|
||||
|
||||
void start() {
|
||||
_isrunning = true;
|
||||
_thread = std::thread([this]() {
|
||||
if (_bindCPU >= 0) {
|
||||
std::lock_guard<std::mutex> lock(_printMutex);
|
||||
std::cout << "[Loop Start] named: " << _name << ", period: " << _period * 1000 << " (ms), run at cpu: " << _bindCPU << std::endl;
|
||||
} else {
|
||||
std::lock_guard<std::mutex> lock(_printMutex);
|
||||
std::cout << "[Loop Start] named: " << _name << ", period: " << _period * 1000 << " (ms), cpu unspecified" << std::endl;
|
||||
}
|
||||
entryFunc();
|
||||
}); // Start the loop in a new thread
|
||||
}
|
||||
|
||||
void shutdown() {
|
||||
_isrunning = false;
|
||||
if (_thread.joinable()) {
|
||||
_thread.join(); // Wait for the loop thread to finish
|
||||
std::lock_guard<std::mutex> lock(_printMutex);
|
||||
std::cout << "[Loop End] named: " << _name << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
virtual void functionCB() = 0;
|
||||
|
||||
protected:
|
||||
void entryFunc() {
|
||||
while (_isrunning) {
|
||||
functionCB(); // Call the overridden functionCB in a loop
|
||||
std::this_thread::sleep_for(std::chrono::duration<float>(_period)); // Wait for the specified period
|
||||
}
|
||||
}
|
||||
|
||||
std::string _name;
|
||||
float _period;
|
||||
int _bindCPU;
|
||||
bool _isrunning = false;
|
||||
std::thread _thread;
|
||||
static std::mutex _printMutex;
|
||||
};
|
||||
|
||||
std::mutex Loop::_printMutex;
|
||||
|
||||
class LoopFunc : public Loop {
|
||||
public:
|
||||
LoopFunc(std::string name, float period, const Callback& cb)
|
||||
: Loop(name, period), _fp(cb) {}
|
||||
|
||||
LoopFunc(std::string name, float period, int bindCPU, const Callback& cb)
|
||||
: Loop(name, period, bindCPU), _fp(cb) {}
|
||||
|
||||
void functionCB() override {
|
||||
void start()
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(_printMutex);
|
||||
(_fp)(); // Call the provided callback function
|
||||
_running = true;
|
||||
log("[Loop Start] named: " + _name + ", period: " + formatPeriod() + "(ms)" + (_bindCPU != -1 ? ", run at cpu: " + std::to_string(_bindCPU) : ", cpu unspecified"));
|
||||
if (_bindCPU != -1)
|
||||
{
|
||||
_thread = std::thread(&LoopFunc::loop, this);
|
||||
setThreadAffinity(_thread.native_handle(), _bindCPU);
|
||||
}
|
||||
else
|
||||
{
|
||||
_thread = std::thread(&LoopFunc::loop, this);
|
||||
}
|
||||
_thread.detach();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Callback _fp;
|
||||
void shutdown()
|
||||
{
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(_mutex);
|
||||
_running = false;
|
||||
_cv.notify_one();
|
||||
}
|
||||
if (_thread.joinable())
|
||||
{
|
||||
_thread.join();
|
||||
}
|
||||
log("[Loop End] named: " + _name);
|
||||
}
|
||||
private:
|
||||
void loop()
|
||||
{
|
||||
while (_running)
|
||||
{
|
||||
auto start = std::chrono::steady_clock::now();
|
||||
|
||||
_func();
|
||||
|
||||
auto end = std::chrono::steady_clock::now();
|
||||
auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
|
||||
auto sleepTime = std::chrono::milliseconds(static_cast<int>((_period * 1000) - elapsed.count()));
|
||||
if (sleepTime.count() > 0)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(_mutex);
|
||||
if (_cv.wait_for(lock, sleepTime, [this]{ return !_running; }))
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string formatPeriod() const
|
||||
{
|
||||
std::ostringstream stream;
|
||||
stream << std::fixed << std::setprecision(0) << _period * 1000;
|
||||
return stream.str();
|
||||
}
|
||||
|
||||
void log(const std::string& message)
|
||||
{
|
||||
static std::mutex logMutex;
|
||||
std::lock_guard<std::mutex> lock(logMutex);
|
||||
std::cout << message << std::endl;
|
||||
}
|
||||
|
||||
void setThreadAffinity(std::thread::native_handle_type threadHandle, int cpuId)
|
||||
{
|
||||
cpu_set_t cpuset;
|
||||
CPU_ZERO(&cpuset);
|
||||
CPU_SET(cpuId, &cpuset);
|
||||
if (pthread_setaffinity_np(threadHandle, sizeof(cpu_set_t), &cpuset) != 0)
|
||||
{
|
||||
std::ostringstream oss;
|
||||
oss << "Error setting thread affinity: CPU " << cpuId << " may not be valid or accessible.";
|
||||
throw std::runtime_error(oss.str());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
|
@ -104,6 +104,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
|||
start_state.motor_state.q[i] = now_state.motor_state.q[i];
|
||||
}
|
||||
this->running_state = STATE_POS_GETUP;
|
||||
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETUP" << std::endl;
|
||||
}
|
||||
}
|
||||
// stand up (position control)
|
||||
|
@ -111,7 +112,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
|||
{
|
||||
if(getup_percent < 1.0)
|
||||
{
|
||||
getup_percent += 1 / 1000.0;
|
||||
getup_percent += 1 / 500.0;
|
||||
getup_percent = getup_percent > 1.0 ? 1.0 : getup_percent;
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
|
@ -125,9 +126,9 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
|||
}
|
||||
if(this->control.control_state == STATE_RL_INIT)
|
||||
{
|
||||
std::cout << std::endl;
|
||||
this->control.control_state = STATE_WAITING;
|
||||
this->running_state = STATE_RL_INIT;
|
||||
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_INIT" << std::endl;
|
||||
}
|
||||
else if(this->control.control_state == STATE_POS_GETDOWN)
|
||||
{
|
||||
|
@ -138,6 +139,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
|||
now_state.motor_state.q[i] = state->motor_state.q[i];
|
||||
}
|
||||
this->running_state = STATE_POS_GETDOWN;
|
||||
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl;
|
||||
}
|
||||
}
|
||||
// init obs and start rl loop
|
||||
|
@ -145,15 +147,18 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
|||
{
|
||||
if(getup_percent == 1)
|
||||
{
|
||||
this->running_state = STATE_RL_RUNNING;
|
||||
this->InitObservations();
|
||||
this->InitOutputs();
|
||||
this->InitControl();
|
||||
this->running_state = STATE_RL_RUNNING;
|
||||
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_RUNNING" << std::endl;
|
||||
}
|
||||
}
|
||||
// rl loop
|
||||
else if(this->running_state == STATE_RL_RUNNING)
|
||||
{
|
||||
std::cout << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << " \r";
|
||||
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
command->motor_command.q[i] = this->output_dof_pos[0][i].item<double>();
|
||||
|
@ -171,6 +176,18 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
|||
now_state.motor_state.q[i] = state->motor_state.q[i];
|
||||
}
|
||||
this->running_state = STATE_POS_GETDOWN;
|
||||
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl;
|
||||
}
|
||||
else if(this->control.control_state == STATE_POS_GETUP)
|
||||
{
|
||||
this->control.control_state = STATE_WAITING;
|
||||
getup_percent = 0.0;
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
now_state.motor_state.q[i] = state->motor_state.q[i];
|
||||
}
|
||||
this->running_state = STATE_POS_GETUP;
|
||||
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETUP" << std::endl;
|
||||
}
|
||||
}
|
||||
// get down (position control)
|
||||
|
@ -178,7 +195,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
|||
{
|
||||
if(getdown_percent < 1.0)
|
||||
{
|
||||
getdown_percent += 1 / 1000.0;
|
||||
getdown_percent += 1 / 500.0;
|
||||
getdown_percent = getdown_percent > 1.0 ? 1.0 : getdown_percent;
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
|
@ -192,11 +209,11 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
|||
}
|
||||
if(getdown_percent == 1)
|
||||
{
|
||||
std::cout << std::endl;
|
||||
this->running_state = STATE_WAITING;
|
||||
this->InitObservations();
|
||||
this->InitOutputs();
|
||||
this->InitControl();
|
||||
this->running_state = STATE_WAITING;
|
||||
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_WAITING" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -226,10 +243,11 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques)
|
|||
double limit_lower = -this->params.torque_limits[0][index].item<double>();
|
||||
double limit_upper = this->params.torque_limits[0][index].item<double>();
|
||||
|
||||
std::cout << LOGGER::ERROR << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
|
||||
std::cout << LOGGER::ERROR << "Switching to STATE_POS_GETDOWN"<< std::endl;
|
||||
std::cout << LOGGER::WARNING << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
|
||||
}
|
||||
this->control.control_state = STATE_POS_GETDOWN;
|
||||
// Just a reminder, no protection
|
||||
// this->control.control_state = STATE_POS_GETDOWN;
|
||||
// std::cout << LOGGER::INFO << "Switching to STATE_POS_GETDOWN"<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -254,11 +272,6 @@ static bool kbhit()
|
|||
|
||||
void RL::KeyboardInterface()
|
||||
{
|
||||
if(this->running_state == STATE_RL_RUNNING)
|
||||
{
|
||||
std::cout << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << " \r";
|
||||
}
|
||||
|
||||
if(kbhit())
|
||||
{
|
||||
int c = fgetc(stdin);
|
||||
|
@ -308,6 +321,8 @@ void RL::ReadYaml(std::string robot_name)
|
|||
}
|
||||
|
||||
this->params.model_name = config["model_name"].as<std::string>();
|
||||
this->params.dt = config["dt"].as<double>();
|
||||
this->params.decimation = config["decimation"].as<int>();
|
||||
this->params.num_observations = config["num_observations"].as<int>();
|
||||
this->params.clip_obs = config["clip_obs"].as<double>();
|
||||
this->params.clip_actions_upper = torch::tensor(ReadVectorFromYaml<double>(config["clip_actions_upper"])).view({1, -1});
|
||||
|
|
|
@ -69,6 +69,8 @@ struct Control
|
|||
struct ModelParams
|
||||
{
|
||||
std::string model_name;
|
||||
double dt;
|
||||
int decimation;
|
||||
int num_observations;
|
||||
double damping;
|
||||
double stiffness;
|
||||
|
|
|
@ -27,17 +27,18 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
|
|||
this->model = torch::jit::load(model_path);
|
||||
|
||||
// loop
|
||||
this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , std::bind(&RL_Real::KeyboardInterface, this));
|
||||
this->loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, std::bind(&RL_Real::RobotControl , this));
|
||||
this->loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend" , 0.002, 3, std::bind(&RL_Real::UDPSend , this));
|
||||
this->loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv" , 0.002, 3, std::bind(&RL_Real::UDPRecv , this));
|
||||
this->loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , std::bind(&RL_Real::RunModel , this));
|
||||
this->loop_keyboard->start();
|
||||
this->loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend" , 0.002, std::bind(&RL_Real::UDPSend, this), 3);
|
||||
this->loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv" , 0.002, std::bind(&RL_Real::UDPRecv, this), 3);
|
||||
this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05, std::bind(&RL_Real::KeyboardInterface, this));
|
||||
this->loop_control = std::make_shared<LoopFunc>("loop_control", this->params.dt, std::bind(&RL_Real::RobotControl, this));
|
||||
this->loop_rl = std::make_shared<LoopFunc>("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Real::RunModel, this));
|
||||
this->loop_udpSend->start();
|
||||
this->loop_udpRecv->start();
|
||||
this->loop_keyboard->start();
|
||||
this->loop_control->start();
|
||||
this->loop_rl->start();
|
||||
|
||||
|
||||
#ifdef PLOT
|
||||
this->plot_t = std::vector<int>(this->plot_size, 0);
|
||||
this->plot_real_joint_pos.resize(this->params.num_of_dofs);
|
||||
|
@ -54,9 +55,9 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
|
|||
|
||||
RL_Real::~RL_Real()
|
||||
{
|
||||
this->loop_keyboard->shutdown();
|
||||
this->loop_udpSend->shutdown();
|
||||
this->loop_udpRecv->shutdown();
|
||||
this->loop_keyboard->shutdown();
|
||||
this->loop_control->shutdown();
|
||||
this->loop_rl->shutdown();
|
||||
#ifdef PLOT
|
||||
|
|
|
@ -56,23 +56,23 @@ RL_Sim::RL_Sim()
|
|||
this->joint_state_subscriber = nh.subscribe<sensor_msgs::JointState>(this->ros_namespace + "joint_states", 10, &RL_Sim::JointStatesCallback, this);
|
||||
|
||||
// service
|
||||
this->gazebo_reset_client = nh.serviceClient<std_srvs::Empty>("/gazebo/reset_simulation");
|
||||
this->gazebo_set_model_state_client = nh.serviceClient<gazebo_msgs::SetModelState>("/gazebo/set_model_state");
|
||||
|
||||
// loop
|
||||
this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , std::bind(&RL_Sim::KeyboardInterface, this));
|
||||
this->loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, std::bind(&RL_Sim::RobotControl , this));
|
||||
this->loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , std::bind(&RL_Sim::RunModel , this));
|
||||
this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05, std::bind(&RL_Sim::KeyboardInterface, this));
|
||||
this->loop_control = std::make_shared<LoopFunc>("loop_control", this->params.dt, std::bind(&RL_Sim::RobotControl, this));
|
||||
this->loop_rl = std::make_shared<LoopFunc>("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Sim::RunModel, this));
|
||||
this->loop_keyboard->start();
|
||||
this->loop_control->start();
|
||||
this->loop_rl->start();
|
||||
|
||||
#ifdef PLOT
|
||||
plot_t = std::vector<int>(this->plot_size, 0);
|
||||
this->plot_t = std::vector<int>(this->plot_size, 0);
|
||||
this->plot_real_joint_pos.resize(this->params.num_of_dofs);
|
||||
this->plot_target_joint_pos.resize(this->params.num_of_dofs);
|
||||
for(auto& vector : this->plot_real_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
|
||||
for(auto& vector : this->plot_target_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
|
||||
this->loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, std::bind(&RL_Sim::Plot, this));
|
||||
this->loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.001 , std::bind(&RL_Sim::Plot , this));
|
||||
this->loop_plot->start();
|
||||
#endif
|
||||
#ifdef CSV_LOGGER
|
||||
|
@ -135,9 +135,14 @@ void RL_Sim::RobotControl()
|
|||
|
||||
if(this->control.control_state == STATE_RESET_SIMULATION)
|
||||
{
|
||||
gazebo_msgs::SetModelState set_model_state;
|
||||
std::string gazebo_model_name = this->robot_name + "_gazebo";
|
||||
set_model_state.request.model_state.model_name = gazebo_model_name;
|
||||
set_model_state.request.model_state.pose.position.z = 1.0;
|
||||
set_model_state.request.model_state.reference_frame = "world";
|
||||
this->gazebo_set_model_state_client.call(set_model_state);
|
||||
|
||||
this->control.control_state = STATE_WAITING;
|
||||
std_srvs::Empty srv;
|
||||
this->gazebo_reset_client.call(srv);
|
||||
}
|
||||
|
||||
this->GetState(&this->robot_state);
|
||||
|
@ -263,7 +268,7 @@ void RL_Sim::Plot()
|
|||
plt::xlim(this->plot_t.front(), this->plot_t.back());
|
||||
}
|
||||
// plt::legend();
|
||||
plt::pause(0.0001);
|
||||
plt::pause(0.01);
|
||||
}
|
||||
|
||||
void signalHandler(int signum)
|
||||
|
|
Loading…
Reference in New Issue