style: code format

This commit is contained in:
fan-ziqi 2024-10-08 19:45:59 +08:00
parent 10808335fd
commit 8b6a2d6af1
22 changed files with 317 additions and 270 deletions

View File

@ -16,6 +16,7 @@ class RL_Real : public RL
public:
RL_Real();
~RL_Real();
private:
// rl functions
torch::Tensor Forward() override;
@ -59,4 +60,4 @@ private:
int state_mapping[12] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8};
};
#endif
#endif // RL_REAL_HPP

View File

@ -21,6 +21,7 @@ class RL_Sim : public RL
public:
RL_Sim();
~RL_Sim();
private:
// rl functions
torch::Tensor Forward() override;
@ -72,4 +73,4 @@ private:
void MapData(const std::vector<double> &source_data, std::vector<double> &target_data);
};
#endif
#endif // RL_SIM_HPP

View File

@ -14,16 +14,6 @@
class LoopFunc
{
private:
std::string _name;
double _period;
std::function<void()> _func;
int _bindCPU;
std::atomic<bool> _running;
std::mutex _mutex;
std::condition_variable _cv;
std::thread _thread;
public:
LoopFunc(const std::string &name, double period, std::function<void()> func, int bindCPU = -1)
: _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {}
@ -57,7 +47,17 @@ class LoopFunc
}
log("[Loop End] named: " + _name);
}
private:
std::string _name;
double _period;
std::function<void()> _func;
int _bindCPU;
std::atomic<bool> _running;
std::mutex _mutex;
std::condition_variable _cv;
std::thread _thread;
void loop()
{
while (_running)
@ -72,7 +72,8 @@ class LoopFunc
if (sleepTime.count() > 0)
{
std::unique_lock<std::mutex> lock(_mutex);
if (_cv.wait_for(lock, sleepTime, [this]{ return !_running; }))
if (_cv.wait_for(lock, sleepTime, [this]
{ return !_running; }))
{
break;
}
@ -108,4 +109,4 @@ class LoopFunc
}
};
#endif
#endif // LOOP_H

View File

@ -16,7 +16,8 @@ ObservationBuffer::ObservationBuffer(int num_envs,
void ObservationBuffer::reset(std::vector<int> reset_idxs, torch::Tensor new_obs)
{
std::vector<torch::indexing::TensorIndex> indices;
for (int idx : reset_idxs) {
for (int idx : reset_idxs)
{
indices.push_back(torch::indexing::Slice(idx));
}
obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps}));

View File

@ -4,7 +4,8 @@
#include <torch/torch.h>
#include <vector>
class ObservationBuffer {
class ObservationBuffer
{
public:
ObservationBuffer(int num_envs, int num_obs, int include_history_steps);
ObservationBuffer();

View File

@ -310,22 +310,52 @@ void RL::KeyboardInterface()
int c = fgetc(stdin);
switch (c)
{
case '0': this->control.control_state = STATE_POS_GETUP; break;
case 'p': this->control.control_state = STATE_RL_INIT; break;
case '1': this->control.control_state = STATE_POS_GETDOWN; break;
case 'q': break;
case 'w': this->control.x += 0.1; break;
case 's': this->control.x -= 0.1; break;
case 'a': this->control.yaw += 0.1; break;
case 'd': this->control.yaw -= 0.1; break;
case 'i': break;
case 'k': break;
case 'j': this->control.y += 0.1; break;
case 'l': this->control.y -= 0.1; break;
case ' ': this->control.x = 0; this->control.y = 0; this->control.yaw = 0; break;
case 'r': this->control.control_state = STATE_RESET_SIMULATION; break;
case '\n': this->control.control_state = STATE_TOGGLE_SIMULATION; break;
default: break;
case '0':
this->control.control_state = STATE_POS_GETUP;
break;
case 'p':
this->control.control_state = STATE_RL_INIT;
break;
case '1':
this->control.control_state = STATE_POS_GETDOWN;
break;
case 'q':
break;
case 'w':
this->control.x += 0.1;
break;
case 's':
this->control.x -= 0.1;
break;
case 'a':
this->control.yaw += 0.1;
break;
case 'd':
this->control.yaw -= 0.1;
break;
case 'i':
break;
case 'k':
break;
case 'j':
this->control.y += 0.1;
break;
case 'l':
this->control.y -= 0.1;
break;
case ' ':
this->control.x = 0;
this->control.y = 0;
this->control.yaw = 0;
break;
case 'r':
this->control.control_state = STATE_RESET_SIMULATION;
break;
case '\n':
this->control.control_state = STATE_TOGGLE_SIMULATION;
break;
default:
break;
}
}
}
@ -380,7 +410,8 @@ void RL::ReadYaml(std::string robot_name)
try
{
config = YAML::LoadFile(config_path)[robot_name];
} catch(YAML::BadFile &e)
}
catch (YAML::BadFile &e)
{
std::cout << LOGGER::ERROR << "The file '" << config_path << "' does not exist" << std::endl;
return;

View File

@ -8,7 +8,8 @@
#include <yaml-cpp/yaml.h>
namespace LOGGER {
namespace LOGGER
{
const char *const INFO = "\033[0;37m[INFO]\033[0m ";
const char *const WARNING = "\033[0;33m[WARNING]\033[0m ";
const char *const ERROR = "\033[0;31m[ERROR]\033[0m ";
@ -48,7 +49,8 @@ struct RobotState
} motor_state;
};
enum STATE {
enum STATE
{
STATE_WAITING = 0,
STATE_POS_GETUP,
STATE_RL_INIT,
@ -165,4 +167,4 @@ protected:
torch::Tensor output_dof_pos;
};
#endif
#endif // RL_SDK_HPP

View File

@ -234,4 +234,3 @@ class RL_Sim(RL):
if __name__ == "__main__":
rl_sim = RL_Sim()
rospy.spin()

View File

@ -39,7 +39,6 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
this->loop_control->start();
this->loop_rl->start();
#ifdef PLOT
this->plot_t = std::vector<int>(this->plot_size, 0);
this->plot_real_joint_pos.resize(this->params.num_of_dofs);
@ -225,7 +224,7 @@ int main(int argc, char **argv)
while (1)
{
sleep(10);
};
}
return 0;
}

View File

@ -46,8 +46,8 @@ RL_Sim::RL_Sim()
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
// joint need to rename as xxx_joint
this->joint_publishers[this->params.joint_controller_names[i]] = nh.advertise<robot_msgs::MotorCommand>(
this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10);
this->joint_publishers[this->params.joint_controller_names[i]] =
nh.advertise<robot_msgs::MotorCommand>(this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10);
}
// subscriber

View File

@ -66,7 +66,6 @@ public:
void setGains(const double &p, const double &i, const double &d, const double &i_max, const double &i_min, const bool &antiwindup = false);
void getGains(double &p, double &i, double &d, double &i_max, double &i_min, bool &antiwindup);
void getGains(double &p, double &i, double &d, double &i_max, double &i_min);
};
}

View File

@ -3,10 +3,14 @@
// #define rqtTune // use rqt or not
double clamp(double& value, double min, double max) {
if (value < min) {
double clamp(double &value, double min, double max)
{
if (value < min)
{
value = min;
} else if (value > max) {
}
else if (value > max)
{
value = max;
}
return value;
@ -15,13 +19,15 @@ double clamp(double& value, double min, double max) {
namespace robot_joint_controller
{
RobotJointController::RobotJointController(){
RobotJointController::RobotJointController()
{
memset(&lastCommand, 0, sizeof(robot_msgs::MotorCommand));
memset(&lastState, 0, sizeof(robot_msgs::MotorState));
memset(&servoCommand, 0, sizeof(ServoCommand));
}
RobotJointController::~RobotJointController(){
RobotJointController::~RobotJointController()
{
sub_ft.shutdown();
sub_command.shutdown();
}
@ -43,7 +49,8 @@ namespace robot_joint_controller
bool RobotJointController::init(hardware_interface::EffortJointInterface *robot, ros::NodeHandle &n)
{
name_space = n.getNamespace();
if (!n.getParam("joint", joint_name)){
if (!n.getParam("joint", joint_name))
{
ROS_ERROR("No joint given in namespace: '%s')", n.getNamespace().c_str());
return false;
}
@ -56,12 +63,14 @@ namespace robot_joint_controller
#endif
urdf::Model urdf; // Get URDF info about joint
if (!urdf.initParamWithNodeHandle("robot_description", n)){
if (!urdf.initParamWithNodeHandle("robot_description", n))
{
ROS_ERROR("Failed to parse urdf file");
return false;
}
joint_urdf = urdf.getJoint(joint_name);
if (!joint_urdf){
if (!joint_urdf)
{
ROS_ERROR("Could not find joint '%s' in urdf", joint_name.c_str());
return false;
}
@ -118,13 +127,15 @@ namespace robot_joint_controller
servoCommand.pos = lastCommand.q;
positionLimits(servoCommand.pos);
servoCommand.posStiffness = lastCommand.kp;
if(fabs(lastCommand.q - PosStopF) < 0.00001){
if (fabs(lastCommand.q - PosStopF) < 0.00001)
{
servoCommand.posStiffness = 0;
}
servoCommand.vel = lastCommand.dq;
velocityLimits(servoCommand.vel);
servoCommand.velStiffness = lastCommand.kd;
if(fabs(lastCommand.dq - VelStopF) < 0.00001){
if (fabs(lastCommand.dq - VelStopF) < 0.00001)
{
servoCommand.velStiffness = 0;
}
servoCommand.torque = lastCommand.tau;
@ -151,7 +162,8 @@ namespace robot_joint_controller
lastState.tauEst = joint.getEffort();
// publish state
if (controller_state_publisher_ && controller_state_publisher_->trylock()) {
if (controller_state_publisher_ && controller_state_publisher_->trylock())
{
controller_state_publisher_->msg_.q = lastState.q;
controller_state_publisher_->msg_.dq = lastState.dq;
controller_state_publisher_->msg_.tauEst = lastState.tauEst;