mirror of https://github.com/fan-ziqi/rl_sar.git
style: code format
This commit is contained in:
parent
10808335fd
commit
8b6a2d6af1
|
@ -16,6 +16,7 @@ class RL_Real : public RL
|
||||||
public:
|
public:
|
||||||
RL_Real();
|
RL_Real();
|
||||||
~RL_Real();
|
~RL_Real();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// rl functions
|
// rl functions
|
||||||
torch::Tensor Forward() override;
|
torch::Tensor Forward() override;
|
||||||
|
@ -59,4 +60,4 @@ private:
|
||||||
int state_mapping[12] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8};
|
int state_mapping[12] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8};
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif // RL_REAL_HPP
|
||||||
|
|
|
@ -21,6 +21,7 @@ class RL_Sim : public RL
|
||||||
public:
|
public:
|
||||||
RL_Sim();
|
RL_Sim();
|
||||||
~RL_Sim();
|
~RL_Sim();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// rl functions
|
// rl functions
|
||||||
torch::Tensor Forward() override;
|
torch::Tensor Forward() override;
|
||||||
|
@ -72,4 +73,4 @@ private:
|
||||||
void MapData(const std::vector<double> &source_data, std::vector<double> &target_data);
|
void MapData(const std::vector<double> &source_data, std::vector<double> &target_data);
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif // RL_SIM_HPP
|
||||||
|
|
|
@ -14,16 +14,6 @@
|
||||||
|
|
||||||
class LoopFunc
|
class LoopFunc
|
||||||
{
|
{
|
||||||
private:
|
|
||||||
std::string _name;
|
|
||||||
double _period;
|
|
||||||
std::function<void()> _func;
|
|
||||||
int _bindCPU;
|
|
||||||
std::atomic<bool> _running;
|
|
||||||
std::mutex _mutex;
|
|
||||||
std::condition_variable _cv;
|
|
||||||
std::thread _thread;
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
LoopFunc(const std::string &name, double period, std::function<void()> func, int bindCPU = -1)
|
LoopFunc(const std::string &name, double period, std::function<void()> func, int bindCPU = -1)
|
||||||
: _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {}
|
: _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {}
|
||||||
|
@ -57,7 +47,17 @@ class LoopFunc
|
||||||
}
|
}
|
||||||
log("[Loop End] named: " + _name);
|
log("[Loop End] named: " + _name);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
std::string _name;
|
||||||
|
double _period;
|
||||||
|
std::function<void()> _func;
|
||||||
|
int _bindCPU;
|
||||||
|
std::atomic<bool> _running;
|
||||||
|
std::mutex _mutex;
|
||||||
|
std::condition_variable _cv;
|
||||||
|
std::thread _thread;
|
||||||
|
|
||||||
void loop()
|
void loop()
|
||||||
{
|
{
|
||||||
while (_running)
|
while (_running)
|
||||||
|
@ -72,7 +72,8 @@ class LoopFunc
|
||||||
if (sleepTime.count() > 0)
|
if (sleepTime.count() > 0)
|
||||||
{
|
{
|
||||||
std::unique_lock<std::mutex> lock(_mutex);
|
std::unique_lock<std::mutex> lock(_mutex);
|
||||||
if (_cv.wait_for(lock, sleepTime, [this]{ return !_running; }))
|
if (_cv.wait_for(lock, sleepTime, [this]
|
||||||
|
{ return !_running; }))
|
||||||
{
|
{
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -108,4 +109,4 @@ class LoopFunc
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif // LOOP_H
|
||||||
|
|
|
@ -16,7 +16,8 @@ ObservationBuffer::ObservationBuffer(int num_envs,
|
||||||
void ObservationBuffer::reset(std::vector<int> reset_idxs, torch::Tensor new_obs)
|
void ObservationBuffer::reset(std::vector<int> reset_idxs, torch::Tensor new_obs)
|
||||||
{
|
{
|
||||||
std::vector<torch::indexing::TensorIndex> indices;
|
std::vector<torch::indexing::TensorIndex> indices;
|
||||||
for (int idx : reset_idxs) {
|
for (int idx : reset_idxs)
|
||||||
|
{
|
||||||
indices.push_back(torch::indexing::Slice(idx));
|
indices.push_back(torch::indexing::Slice(idx));
|
||||||
}
|
}
|
||||||
obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps}));
|
obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps}));
|
||||||
|
|
|
@ -4,7 +4,8 @@
|
||||||
#include <torch/torch.h>
|
#include <torch/torch.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
class ObservationBuffer {
|
class ObservationBuffer
|
||||||
|
{
|
||||||
public:
|
public:
|
||||||
ObservationBuffer(int num_envs, int num_obs, int include_history_steps);
|
ObservationBuffer(int num_envs, int num_obs, int include_history_steps);
|
||||||
ObservationBuffer();
|
ObservationBuffer();
|
||||||
|
|
|
@ -310,22 +310,52 @@ void RL::KeyboardInterface()
|
||||||
int c = fgetc(stdin);
|
int c = fgetc(stdin);
|
||||||
switch (c)
|
switch (c)
|
||||||
{
|
{
|
||||||
case '0': this->control.control_state = STATE_POS_GETUP; break;
|
case '0':
|
||||||
case 'p': this->control.control_state = STATE_RL_INIT; break;
|
this->control.control_state = STATE_POS_GETUP;
|
||||||
case '1': this->control.control_state = STATE_POS_GETDOWN; break;
|
break;
|
||||||
case 'q': break;
|
case 'p':
|
||||||
case 'w': this->control.x += 0.1; break;
|
this->control.control_state = STATE_RL_INIT;
|
||||||
case 's': this->control.x -= 0.1; break;
|
break;
|
||||||
case 'a': this->control.yaw += 0.1; break;
|
case '1':
|
||||||
case 'd': this->control.yaw -= 0.1; break;
|
this->control.control_state = STATE_POS_GETDOWN;
|
||||||
case 'i': break;
|
break;
|
||||||
case 'k': break;
|
case 'q':
|
||||||
case 'j': this->control.y += 0.1; break;
|
break;
|
||||||
case 'l': this->control.y -= 0.1; break;
|
case 'w':
|
||||||
case ' ': this->control.x = 0; this->control.y = 0; this->control.yaw = 0; break;
|
this->control.x += 0.1;
|
||||||
case 'r': this->control.control_state = STATE_RESET_SIMULATION; break;
|
break;
|
||||||
case '\n': this->control.control_state = STATE_TOGGLE_SIMULATION; break;
|
case 's':
|
||||||
default: break;
|
this->control.x -= 0.1;
|
||||||
|
break;
|
||||||
|
case 'a':
|
||||||
|
this->control.yaw += 0.1;
|
||||||
|
break;
|
||||||
|
case 'd':
|
||||||
|
this->control.yaw -= 0.1;
|
||||||
|
break;
|
||||||
|
case 'i':
|
||||||
|
break;
|
||||||
|
case 'k':
|
||||||
|
break;
|
||||||
|
case 'j':
|
||||||
|
this->control.y += 0.1;
|
||||||
|
break;
|
||||||
|
case 'l':
|
||||||
|
this->control.y -= 0.1;
|
||||||
|
break;
|
||||||
|
case ' ':
|
||||||
|
this->control.x = 0;
|
||||||
|
this->control.y = 0;
|
||||||
|
this->control.yaw = 0;
|
||||||
|
break;
|
||||||
|
case 'r':
|
||||||
|
this->control.control_state = STATE_RESET_SIMULATION;
|
||||||
|
break;
|
||||||
|
case '\n':
|
||||||
|
this->control.control_state = STATE_TOGGLE_SIMULATION;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -380,7 +410,8 @@ void RL::ReadYaml(std::string robot_name)
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
config = YAML::LoadFile(config_path)[robot_name];
|
config = YAML::LoadFile(config_path)[robot_name];
|
||||||
} catch(YAML::BadFile &e)
|
}
|
||||||
|
catch (YAML::BadFile &e)
|
||||||
{
|
{
|
||||||
std::cout << LOGGER::ERROR << "The file '" << config_path << "' does not exist" << std::endl;
|
std::cout << LOGGER::ERROR << "The file '" << config_path << "' does not exist" << std::endl;
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -8,7 +8,8 @@
|
||||||
|
|
||||||
#include <yaml-cpp/yaml.h>
|
#include <yaml-cpp/yaml.h>
|
||||||
|
|
||||||
namespace LOGGER {
|
namespace LOGGER
|
||||||
|
{
|
||||||
const char *const INFO = "\033[0;37m[INFO]\033[0m ";
|
const char *const INFO = "\033[0;37m[INFO]\033[0m ";
|
||||||
const char *const WARNING = "\033[0;33m[WARNING]\033[0m ";
|
const char *const WARNING = "\033[0;33m[WARNING]\033[0m ";
|
||||||
const char *const ERROR = "\033[0;31m[ERROR]\033[0m ";
|
const char *const ERROR = "\033[0;31m[ERROR]\033[0m ";
|
||||||
|
@ -48,7 +49,8 @@ struct RobotState
|
||||||
} motor_state;
|
} motor_state;
|
||||||
};
|
};
|
||||||
|
|
||||||
enum STATE {
|
enum STATE
|
||||||
|
{
|
||||||
STATE_WAITING = 0,
|
STATE_WAITING = 0,
|
||||||
STATE_POS_GETUP,
|
STATE_POS_GETUP,
|
||||||
STATE_RL_INIT,
|
STATE_RL_INIT,
|
||||||
|
@ -165,4 +167,4 @@ protected:
|
||||||
torch::Tensor output_dof_pos;
|
torch::Tensor output_dof_pos;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif // RL_SDK_HPP
|
||||||
|
|
|
@ -234,4 +234,3 @@ class RL_Sim(RL):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
rl_sim = RL_Sim()
|
rl_sim = RL_Sim()
|
||||||
rospy.spin()
|
rospy.spin()
|
||||||
|
|
|
@ -39,7 +39,6 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
|
||||||
this->loop_control->start();
|
this->loop_control->start();
|
||||||
this->loop_rl->start();
|
this->loop_rl->start();
|
||||||
|
|
||||||
|
|
||||||
#ifdef PLOT
|
#ifdef PLOT
|
||||||
this->plot_t = std::vector<int>(this->plot_size, 0);
|
this->plot_t = std::vector<int>(this->plot_size, 0);
|
||||||
this->plot_real_joint_pos.resize(this->params.num_of_dofs);
|
this->plot_real_joint_pos.resize(this->params.num_of_dofs);
|
||||||
|
@ -225,7 +224,7 @@ int main(int argc, char **argv)
|
||||||
while (1)
|
while (1)
|
||||||
{
|
{
|
||||||
sleep(10);
|
sleep(10);
|
||||||
};
|
}
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,8 +46,8 @@ RL_Sim::RL_Sim()
|
||||||
for (int i = 0; i < this->params.num_of_dofs; ++i)
|
for (int i = 0; i < this->params.num_of_dofs; ++i)
|
||||||
{
|
{
|
||||||
// joint need to rename as xxx_joint
|
// joint need to rename as xxx_joint
|
||||||
this->joint_publishers[this->params.joint_controller_names[i]] = nh.advertise<robot_msgs::MotorCommand>(
|
this->joint_publishers[this->params.joint_controller_names[i]] =
|
||||||
this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10);
|
nh.advertise<robot_msgs::MotorCommand>(this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
// subscriber
|
// subscriber
|
||||||
|
|
|
@ -66,7 +66,6 @@ public:
|
||||||
void setGains(const double &p, const double &i, const double &d, const double &i_max, const double &i_min, const bool &antiwindup = false);
|
void setGains(const double &p, const double &i, const double &d, const double &i_max, const double &i_min, const bool &antiwindup = false);
|
||||||
void getGains(double &p, double &i, double &d, double &i_max, double &i_min, bool &antiwindup);
|
void getGains(double &p, double &i, double &d, double &i_max, double &i_min, bool &antiwindup);
|
||||||
void getGains(double &p, double &i, double &d, double &i_max, double &i_min);
|
void getGains(double &p, double &i, double &d, double &i_max, double &i_min);
|
||||||
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,10 +3,14 @@
|
||||||
|
|
||||||
// #define rqtTune // use rqt or not
|
// #define rqtTune // use rqt or not
|
||||||
|
|
||||||
double clamp(double& value, double min, double max) {
|
double clamp(double &value, double min, double max)
|
||||||
if (value < min) {
|
{
|
||||||
|
if (value < min)
|
||||||
|
{
|
||||||
value = min;
|
value = min;
|
||||||
} else if (value > max) {
|
}
|
||||||
|
else if (value > max)
|
||||||
|
{
|
||||||
value = max;
|
value = max;
|
||||||
}
|
}
|
||||||
return value;
|
return value;
|
||||||
|
@ -15,13 +19,15 @@ double clamp(double& value, double min, double max) {
|
||||||
namespace robot_joint_controller
|
namespace robot_joint_controller
|
||||||
{
|
{
|
||||||
|
|
||||||
RobotJointController::RobotJointController(){
|
RobotJointController::RobotJointController()
|
||||||
|
{
|
||||||
memset(&lastCommand, 0, sizeof(robot_msgs::MotorCommand));
|
memset(&lastCommand, 0, sizeof(robot_msgs::MotorCommand));
|
||||||
memset(&lastState, 0, sizeof(robot_msgs::MotorState));
|
memset(&lastState, 0, sizeof(robot_msgs::MotorState));
|
||||||
memset(&servoCommand, 0, sizeof(ServoCommand));
|
memset(&servoCommand, 0, sizeof(ServoCommand));
|
||||||
}
|
}
|
||||||
|
|
||||||
RobotJointController::~RobotJointController(){
|
RobotJointController::~RobotJointController()
|
||||||
|
{
|
||||||
sub_ft.shutdown();
|
sub_ft.shutdown();
|
||||||
sub_command.shutdown();
|
sub_command.shutdown();
|
||||||
}
|
}
|
||||||
|
@ -43,7 +49,8 @@ namespace robot_joint_controller
|
||||||
bool RobotJointController::init(hardware_interface::EffortJointInterface *robot, ros::NodeHandle &n)
|
bool RobotJointController::init(hardware_interface::EffortJointInterface *robot, ros::NodeHandle &n)
|
||||||
{
|
{
|
||||||
name_space = n.getNamespace();
|
name_space = n.getNamespace();
|
||||||
if (!n.getParam("joint", joint_name)){
|
if (!n.getParam("joint", joint_name))
|
||||||
|
{
|
||||||
ROS_ERROR("No joint given in namespace: '%s')", n.getNamespace().c_str());
|
ROS_ERROR("No joint given in namespace: '%s')", n.getNamespace().c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -56,12 +63,14 @@ namespace robot_joint_controller
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
urdf::Model urdf; // Get URDF info about joint
|
urdf::Model urdf; // Get URDF info about joint
|
||||||
if (!urdf.initParamWithNodeHandle("robot_description", n)){
|
if (!urdf.initParamWithNodeHandle("robot_description", n))
|
||||||
|
{
|
||||||
ROS_ERROR("Failed to parse urdf file");
|
ROS_ERROR("Failed to parse urdf file");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
joint_urdf = urdf.getJoint(joint_name);
|
joint_urdf = urdf.getJoint(joint_name);
|
||||||
if (!joint_urdf){
|
if (!joint_urdf)
|
||||||
|
{
|
||||||
ROS_ERROR("Could not find joint '%s' in urdf", joint_name.c_str());
|
ROS_ERROR("Could not find joint '%s' in urdf", joint_name.c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -118,13 +127,15 @@ namespace robot_joint_controller
|
||||||
servoCommand.pos = lastCommand.q;
|
servoCommand.pos = lastCommand.q;
|
||||||
positionLimits(servoCommand.pos);
|
positionLimits(servoCommand.pos);
|
||||||
servoCommand.posStiffness = lastCommand.kp;
|
servoCommand.posStiffness = lastCommand.kp;
|
||||||
if(fabs(lastCommand.q - PosStopF) < 0.00001){
|
if (fabs(lastCommand.q - PosStopF) < 0.00001)
|
||||||
|
{
|
||||||
servoCommand.posStiffness = 0;
|
servoCommand.posStiffness = 0;
|
||||||
}
|
}
|
||||||
servoCommand.vel = lastCommand.dq;
|
servoCommand.vel = lastCommand.dq;
|
||||||
velocityLimits(servoCommand.vel);
|
velocityLimits(servoCommand.vel);
|
||||||
servoCommand.velStiffness = lastCommand.kd;
|
servoCommand.velStiffness = lastCommand.kd;
|
||||||
if(fabs(lastCommand.dq - VelStopF) < 0.00001){
|
if (fabs(lastCommand.dq - VelStopF) < 0.00001)
|
||||||
|
{
|
||||||
servoCommand.velStiffness = 0;
|
servoCommand.velStiffness = 0;
|
||||||
}
|
}
|
||||||
servoCommand.torque = lastCommand.tau;
|
servoCommand.torque = lastCommand.tau;
|
||||||
|
@ -151,7 +162,8 @@ namespace robot_joint_controller
|
||||||
lastState.tauEst = joint.getEffort();
|
lastState.tauEst = joint.getEffort();
|
||||||
|
|
||||||
// publish state
|
// publish state
|
||||||
if (controller_state_publisher_ && controller_state_publisher_->trylock()) {
|
if (controller_state_publisher_ && controller_state_publisher_->trylock())
|
||||||
|
{
|
||||||
controller_state_publisher_->msg_.q = lastState.q;
|
controller_state_publisher_->msg_.q = lastState.q;
|
||||||
controller_state_publisher_->msg_.dq = lastState.dq;
|
controller_state_publisher_->msg_.dq = lastState.dq;
|
||||||
controller_state_publisher_->msg_.tauEst = lastState.tauEst;
|
controller_state_publisher_->msg_.tauEst = lastState.tauEst;
|
||||||
|
|
Loading…
Reference in New Issue