style: code format

This commit is contained in:
fan-ziqi 2024-10-08 19:45:59 +08:00
parent 10808335fd
commit 8b6a2d6af1
22 changed files with 317 additions and 270 deletions

View File

@ -16,6 +16,7 @@ class RL_Real : public RL
public: public:
RL_Real(); RL_Real();
~RL_Real(); ~RL_Real();
private: private:
// rl functions // rl functions
torch::Tensor Forward() override; torch::Tensor Forward() override;
@ -59,4 +60,4 @@ private:
int state_mapping[12] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8}; int state_mapping[12] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8};
}; };
#endif #endif // RL_REAL_HPP

View File

@ -21,6 +21,7 @@ class RL_Sim : public RL
public: public:
RL_Sim(); RL_Sim();
~RL_Sim(); ~RL_Sim();
private: private:
// rl functions // rl functions
torch::Tensor Forward() override; torch::Tensor Forward() override;
@ -72,4 +73,4 @@ private:
void MapData(const std::vector<double> &source_data, std::vector<double> &target_data); void MapData(const std::vector<double> &source_data, std::vector<double> &target_data);
}; };
#endif #endif // RL_SIM_HPP

View File

@ -14,16 +14,6 @@
class LoopFunc class LoopFunc
{ {
private:
std::string _name;
double _period;
std::function<void()> _func;
int _bindCPU;
std::atomic<bool> _running;
std::mutex _mutex;
std::condition_variable _cv;
std::thread _thread;
public: public:
LoopFunc(const std::string &name, double period, std::function<void()> func, int bindCPU = -1) LoopFunc(const std::string &name, double period, std::function<void()> func, int bindCPU = -1)
: _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {} : _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {}
@ -57,7 +47,17 @@ class LoopFunc
} }
log("[Loop End] named: " + _name); log("[Loop End] named: " + _name);
} }
private: private:
std::string _name;
double _period;
std::function<void()> _func;
int _bindCPU;
std::atomic<bool> _running;
std::mutex _mutex;
std::condition_variable _cv;
std::thread _thread;
void loop() void loop()
{ {
while (_running) while (_running)
@ -72,7 +72,8 @@ class LoopFunc
if (sleepTime.count() > 0) if (sleepTime.count() > 0)
{ {
std::unique_lock<std::mutex> lock(_mutex); std::unique_lock<std::mutex> lock(_mutex);
if (_cv.wait_for(lock, sleepTime, [this]{ return !_running; })) if (_cv.wait_for(lock, sleepTime, [this]
{ return !_running; }))
{ {
break; break;
} }
@ -108,4 +109,4 @@ class LoopFunc
} }
}; };
#endif #endif // LOOP_H

View File

@ -16,7 +16,8 @@ ObservationBuffer::ObservationBuffer(int num_envs,
void ObservationBuffer::reset(std::vector<int> reset_idxs, torch::Tensor new_obs) void ObservationBuffer::reset(std::vector<int> reset_idxs, torch::Tensor new_obs)
{ {
std::vector<torch::indexing::TensorIndex> indices; std::vector<torch::indexing::TensorIndex> indices;
for (int idx : reset_idxs) { for (int idx : reset_idxs)
{
indices.push_back(torch::indexing::Slice(idx)); indices.push_back(torch::indexing::Slice(idx));
} }
obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps})); obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps}));

View File

@ -4,7 +4,8 @@
#include <torch/torch.h> #include <torch/torch.h>
#include <vector> #include <vector>
class ObservationBuffer { class ObservationBuffer
{
public: public:
ObservationBuffer(int num_envs, int num_obs, int include_history_steps); ObservationBuffer(int num_envs, int num_obs, int include_history_steps);
ObservationBuffer(); ObservationBuffer();

View File

@ -310,22 +310,52 @@ void RL::KeyboardInterface()
int c = fgetc(stdin); int c = fgetc(stdin);
switch (c) switch (c)
{ {
case '0': this->control.control_state = STATE_POS_GETUP; break; case '0':
case 'p': this->control.control_state = STATE_RL_INIT; break; this->control.control_state = STATE_POS_GETUP;
case '1': this->control.control_state = STATE_POS_GETDOWN; break; break;
case 'q': break; case 'p':
case 'w': this->control.x += 0.1; break; this->control.control_state = STATE_RL_INIT;
case 's': this->control.x -= 0.1; break; break;
case 'a': this->control.yaw += 0.1; break; case '1':
case 'd': this->control.yaw -= 0.1; break; this->control.control_state = STATE_POS_GETDOWN;
case 'i': break; break;
case 'k': break; case 'q':
case 'j': this->control.y += 0.1; break; break;
case 'l': this->control.y -= 0.1; break; case 'w':
case ' ': this->control.x = 0; this->control.y = 0; this->control.yaw = 0; break; this->control.x += 0.1;
case 'r': this->control.control_state = STATE_RESET_SIMULATION; break; break;
case '\n': this->control.control_state = STATE_TOGGLE_SIMULATION; break; case 's':
default: break; this->control.x -= 0.1;
break;
case 'a':
this->control.yaw += 0.1;
break;
case 'd':
this->control.yaw -= 0.1;
break;
case 'i':
break;
case 'k':
break;
case 'j':
this->control.y += 0.1;
break;
case 'l':
this->control.y -= 0.1;
break;
case ' ':
this->control.x = 0;
this->control.y = 0;
this->control.yaw = 0;
break;
case 'r':
this->control.control_state = STATE_RESET_SIMULATION;
break;
case '\n':
this->control.control_state = STATE_TOGGLE_SIMULATION;
break;
default:
break;
} }
} }
} }
@ -380,7 +410,8 @@ void RL::ReadYaml(std::string robot_name)
try try
{ {
config = YAML::LoadFile(config_path)[robot_name]; config = YAML::LoadFile(config_path)[robot_name];
} catch(YAML::BadFile &e) }
catch (YAML::BadFile &e)
{ {
std::cout << LOGGER::ERROR << "The file '" << config_path << "' does not exist" << std::endl; std::cout << LOGGER::ERROR << "The file '" << config_path << "' does not exist" << std::endl;
return; return;

View File

@ -8,7 +8,8 @@
#include <yaml-cpp/yaml.h> #include <yaml-cpp/yaml.h>
namespace LOGGER { namespace LOGGER
{
const char *const INFO = "\033[0;37m[INFO]\033[0m "; const char *const INFO = "\033[0;37m[INFO]\033[0m ";
const char *const WARNING = "\033[0;33m[WARNING]\033[0m "; const char *const WARNING = "\033[0;33m[WARNING]\033[0m ";
const char *const ERROR = "\033[0;31m[ERROR]\033[0m "; const char *const ERROR = "\033[0;31m[ERROR]\033[0m ";
@ -48,7 +49,8 @@ struct RobotState
} motor_state; } motor_state;
}; };
enum STATE { enum STATE
{
STATE_WAITING = 0, STATE_WAITING = 0,
STATE_POS_GETUP, STATE_POS_GETUP,
STATE_RL_INIT, STATE_RL_INIT,
@ -165,4 +167,4 @@ protected:
torch::Tensor output_dof_pos; torch::Tensor output_dof_pos;
}; };
#endif #endif // RL_SDK_HPP

View File

@ -234,4 +234,3 @@ class RL_Sim(RL):
if __name__ == "__main__": if __name__ == "__main__":
rl_sim = RL_Sim() rl_sim = RL_Sim()
rospy.spin() rospy.spin()

View File

@ -39,7 +39,6 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
this->loop_control->start(); this->loop_control->start();
this->loop_rl->start(); this->loop_rl->start();
#ifdef PLOT #ifdef PLOT
this->plot_t = std::vector<int>(this->plot_size, 0); this->plot_t = std::vector<int>(this->plot_size, 0);
this->plot_real_joint_pos.resize(this->params.num_of_dofs); this->plot_real_joint_pos.resize(this->params.num_of_dofs);
@ -225,7 +224,7 @@ int main(int argc, char **argv)
while (1) while (1)
{ {
sleep(10); sleep(10);
}; }
return 0; return 0;
} }

View File

@ -46,8 +46,8 @@ RL_Sim::RL_Sim()
for (int i = 0; i < this->params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
// joint need to rename as xxx_joint // joint need to rename as xxx_joint
this->joint_publishers[this->params.joint_controller_names[i]] = nh.advertise<robot_msgs::MotorCommand>( this->joint_publishers[this->params.joint_controller_names[i]] =
this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10); nh.advertise<robot_msgs::MotorCommand>(this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10);
} }
// subscriber // subscriber

View File

@ -66,7 +66,6 @@ public:
void setGains(const double &p, const double &i, const double &d, const double &i_max, const double &i_min, const bool &antiwindup = false); void setGains(const double &p, const double &i, const double &d, const double &i_max, const double &i_min, const bool &antiwindup = false);
void getGains(double &p, double &i, double &d, double &i_max, double &i_min, bool &antiwindup); void getGains(double &p, double &i, double &d, double &i_max, double &i_min, bool &antiwindup);
void getGains(double &p, double &i, double &d, double &i_max, double &i_min); void getGains(double &p, double &i, double &d, double &i_max, double &i_min);
}; };
} }

View File

@ -3,10 +3,14 @@
// #define rqtTune // use rqt or not // #define rqtTune // use rqt or not
double clamp(double& value, double min, double max) { double clamp(double &value, double min, double max)
if (value < min) { {
if (value < min)
{
value = min; value = min;
} else if (value > max) { }
else if (value > max)
{
value = max; value = max;
} }
return value; return value;
@ -15,13 +19,15 @@ double clamp(double& value, double min, double max) {
namespace robot_joint_controller namespace robot_joint_controller
{ {
RobotJointController::RobotJointController(){ RobotJointController::RobotJointController()
{
memset(&lastCommand, 0, sizeof(robot_msgs::MotorCommand)); memset(&lastCommand, 0, sizeof(robot_msgs::MotorCommand));
memset(&lastState, 0, sizeof(robot_msgs::MotorState)); memset(&lastState, 0, sizeof(robot_msgs::MotorState));
memset(&servoCommand, 0, sizeof(ServoCommand)); memset(&servoCommand, 0, sizeof(ServoCommand));
} }
RobotJointController::~RobotJointController(){ RobotJointController::~RobotJointController()
{
sub_ft.shutdown(); sub_ft.shutdown();
sub_command.shutdown(); sub_command.shutdown();
} }
@ -43,7 +49,8 @@ namespace robot_joint_controller
bool RobotJointController::init(hardware_interface::EffortJointInterface *robot, ros::NodeHandle &n) bool RobotJointController::init(hardware_interface::EffortJointInterface *robot, ros::NodeHandle &n)
{ {
name_space = n.getNamespace(); name_space = n.getNamespace();
if (!n.getParam("joint", joint_name)){ if (!n.getParam("joint", joint_name))
{
ROS_ERROR("No joint given in namespace: '%s')", n.getNamespace().c_str()); ROS_ERROR("No joint given in namespace: '%s')", n.getNamespace().c_str());
return false; return false;
} }
@ -56,12 +63,14 @@ namespace robot_joint_controller
#endif #endif
urdf::Model urdf; // Get URDF info about joint urdf::Model urdf; // Get URDF info about joint
if (!urdf.initParamWithNodeHandle("robot_description", n)){ if (!urdf.initParamWithNodeHandle("robot_description", n))
{
ROS_ERROR("Failed to parse urdf file"); ROS_ERROR("Failed to parse urdf file");
return false; return false;
} }
joint_urdf = urdf.getJoint(joint_name); joint_urdf = urdf.getJoint(joint_name);
if (!joint_urdf){ if (!joint_urdf)
{
ROS_ERROR("Could not find joint '%s' in urdf", joint_name.c_str()); ROS_ERROR("Could not find joint '%s' in urdf", joint_name.c_str());
return false; return false;
} }
@ -118,13 +127,15 @@ namespace robot_joint_controller
servoCommand.pos = lastCommand.q; servoCommand.pos = lastCommand.q;
positionLimits(servoCommand.pos); positionLimits(servoCommand.pos);
servoCommand.posStiffness = lastCommand.kp; servoCommand.posStiffness = lastCommand.kp;
if(fabs(lastCommand.q - PosStopF) < 0.00001){ if (fabs(lastCommand.q - PosStopF) < 0.00001)
{
servoCommand.posStiffness = 0; servoCommand.posStiffness = 0;
} }
servoCommand.vel = lastCommand.dq; servoCommand.vel = lastCommand.dq;
velocityLimits(servoCommand.vel); velocityLimits(servoCommand.vel);
servoCommand.velStiffness = lastCommand.kd; servoCommand.velStiffness = lastCommand.kd;
if(fabs(lastCommand.dq - VelStopF) < 0.00001){ if (fabs(lastCommand.dq - VelStopF) < 0.00001)
{
servoCommand.velStiffness = 0; servoCommand.velStiffness = 0;
} }
servoCommand.torque = lastCommand.tau; servoCommand.torque = lastCommand.tau;
@ -151,7 +162,8 @@ namespace robot_joint_controller
lastState.tauEst = joint.getEffort(); lastState.tauEst = joint.getEffort();
// publish state // publish state
if (controller_state_publisher_ && controller_state_publisher_->trylock()) { if (controller_state_publisher_ && controller_state_publisher_->trylock())
{
controller_state_publisher_->msg_.q = lastState.q; controller_state_publisher_->msg_.q = lastState.q;
controller_state_publisher_->msg_.dq = lastState.dq; controller_state_publisher_->msg_.dq = lastState.dq;
controller_state_publisher_->msg_.tauEst = lastState.tauEst; controller_state_publisher_->msg_.tauEst = lastState.tauEst;