mirror of https://github.com/fan-ziqi/rl_sar.git
style: code format
This commit is contained in:
parent
10808335fd
commit
8b6a2d6af1
|
@ -16,6 +16,7 @@ class RL_Real : public RL
|
|||
public:
|
||||
RL_Real();
|
||||
~RL_Real();
|
||||
|
||||
private:
|
||||
// rl functions
|
||||
torch::Tensor Forward() override;
|
||||
|
@ -59,4 +60,4 @@ private:
|
|||
int state_mapping[12] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8};
|
||||
};
|
||||
|
||||
#endif
|
||||
#endif // RL_REAL_HPP
|
||||
|
|
|
@ -21,6 +21,7 @@ class RL_Sim : public RL
|
|||
public:
|
||||
RL_Sim();
|
||||
~RL_Sim();
|
||||
|
||||
private:
|
||||
// rl functions
|
||||
torch::Tensor Forward() override;
|
||||
|
@ -72,4 +73,4 @@ private:
|
|||
void MapData(const std::vector<double> &source_data, std::vector<double> &target_data);
|
||||
};
|
||||
|
||||
#endif
|
||||
#endif // RL_SIM_HPP
|
||||
|
|
|
@ -14,16 +14,6 @@
|
|||
|
||||
class LoopFunc
|
||||
{
|
||||
private:
|
||||
std::string _name;
|
||||
double _period;
|
||||
std::function<void()> _func;
|
||||
int _bindCPU;
|
||||
std::atomic<bool> _running;
|
||||
std::mutex _mutex;
|
||||
std::condition_variable _cv;
|
||||
std::thread _thread;
|
||||
|
||||
public:
|
||||
LoopFunc(const std::string &name, double period, std::function<void()> func, int bindCPU = -1)
|
||||
: _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {}
|
||||
|
@ -57,7 +47,17 @@ class LoopFunc
|
|||
}
|
||||
log("[Loop End] named: " + _name);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string _name;
|
||||
double _period;
|
||||
std::function<void()> _func;
|
||||
int _bindCPU;
|
||||
std::atomic<bool> _running;
|
||||
std::mutex _mutex;
|
||||
std::condition_variable _cv;
|
||||
std::thread _thread;
|
||||
|
||||
void loop()
|
||||
{
|
||||
while (_running)
|
||||
|
@ -72,7 +72,8 @@ class LoopFunc
|
|||
if (sleepTime.count() > 0)
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(_mutex);
|
||||
if (_cv.wait_for(lock, sleepTime, [this]{ return !_running; }))
|
||||
if (_cv.wait_for(lock, sleepTime, [this]
|
||||
{ return !_running; }))
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
@ -108,4 +109,4 @@ class LoopFunc
|
|||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
#endif // LOOP_H
|
||||
|
|
|
@ -16,7 +16,8 @@ ObservationBuffer::ObservationBuffer(int num_envs,
|
|||
void ObservationBuffer::reset(std::vector<int> reset_idxs, torch::Tensor new_obs)
|
||||
{
|
||||
std::vector<torch::indexing::TensorIndex> indices;
|
||||
for (int idx : reset_idxs) {
|
||||
for (int idx : reset_idxs)
|
||||
{
|
||||
indices.push_back(torch::indexing::Slice(idx));
|
||||
}
|
||||
obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps}));
|
||||
|
|
|
@ -4,7 +4,8 @@
|
|||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
class ObservationBuffer {
|
||||
class ObservationBuffer
|
||||
{
|
||||
public:
|
||||
ObservationBuffer(int num_envs, int num_obs, int include_history_steps);
|
||||
ObservationBuffer();
|
||||
|
|
|
@ -310,22 +310,52 @@ void RL::KeyboardInterface()
|
|||
int c = fgetc(stdin);
|
||||
switch (c)
|
||||
{
|
||||
case '0': this->control.control_state = STATE_POS_GETUP; break;
|
||||
case 'p': this->control.control_state = STATE_RL_INIT; break;
|
||||
case '1': this->control.control_state = STATE_POS_GETDOWN; break;
|
||||
case 'q': break;
|
||||
case 'w': this->control.x += 0.1; break;
|
||||
case 's': this->control.x -= 0.1; break;
|
||||
case 'a': this->control.yaw += 0.1; break;
|
||||
case 'd': this->control.yaw -= 0.1; break;
|
||||
case 'i': break;
|
||||
case 'k': break;
|
||||
case 'j': this->control.y += 0.1; break;
|
||||
case 'l': this->control.y -= 0.1; break;
|
||||
case ' ': this->control.x = 0; this->control.y = 0; this->control.yaw = 0; break;
|
||||
case 'r': this->control.control_state = STATE_RESET_SIMULATION; break;
|
||||
case '\n': this->control.control_state = STATE_TOGGLE_SIMULATION; break;
|
||||
default: break;
|
||||
case '0':
|
||||
this->control.control_state = STATE_POS_GETUP;
|
||||
break;
|
||||
case 'p':
|
||||
this->control.control_state = STATE_RL_INIT;
|
||||
break;
|
||||
case '1':
|
||||
this->control.control_state = STATE_POS_GETDOWN;
|
||||
break;
|
||||
case 'q':
|
||||
break;
|
||||
case 'w':
|
||||
this->control.x += 0.1;
|
||||
break;
|
||||
case 's':
|
||||
this->control.x -= 0.1;
|
||||
break;
|
||||
case 'a':
|
||||
this->control.yaw += 0.1;
|
||||
break;
|
||||
case 'd':
|
||||
this->control.yaw -= 0.1;
|
||||
break;
|
||||
case 'i':
|
||||
break;
|
||||
case 'k':
|
||||
break;
|
||||
case 'j':
|
||||
this->control.y += 0.1;
|
||||
break;
|
||||
case 'l':
|
||||
this->control.y -= 0.1;
|
||||
break;
|
||||
case ' ':
|
||||
this->control.x = 0;
|
||||
this->control.y = 0;
|
||||
this->control.yaw = 0;
|
||||
break;
|
||||
case 'r':
|
||||
this->control.control_state = STATE_RESET_SIMULATION;
|
||||
break;
|
||||
case '\n':
|
||||
this->control.control_state = STATE_TOGGLE_SIMULATION;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -380,7 +410,8 @@ void RL::ReadYaml(std::string robot_name)
|
|||
try
|
||||
{
|
||||
config = YAML::LoadFile(config_path)[robot_name];
|
||||
} catch(YAML::BadFile &e)
|
||||
}
|
||||
catch (YAML::BadFile &e)
|
||||
{
|
||||
std::cout << LOGGER::ERROR << "The file '" << config_path << "' does not exist" << std::endl;
|
||||
return;
|
||||
|
|
|
@ -8,7 +8,8 @@
|
|||
|
||||
#include <yaml-cpp/yaml.h>
|
||||
|
||||
namespace LOGGER {
|
||||
namespace LOGGER
|
||||
{
|
||||
const char *const INFO = "\033[0;37m[INFO]\033[0m ";
|
||||
const char *const WARNING = "\033[0;33m[WARNING]\033[0m ";
|
||||
const char *const ERROR = "\033[0;31m[ERROR]\033[0m ";
|
||||
|
@ -48,7 +49,8 @@ struct RobotState
|
|||
} motor_state;
|
||||
};
|
||||
|
||||
enum STATE {
|
||||
enum STATE
|
||||
{
|
||||
STATE_WAITING = 0,
|
||||
STATE_POS_GETUP,
|
||||
STATE_RL_INIT,
|
||||
|
@ -165,4 +167,4 @@ protected:
|
|||
torch::Tensor output_dof_pos;
|
||||
};
|
||||
|
||||
#endif
|
||||
#endif // RL_SDK_HPP
|
||||
|
|
|
@ -234,4 +234,3 @@ class RL_Sim(RL):
|
|||
if __name__ == "__main__":
|
||||
rl_sim = RL_Sim()
|
||||
rospy.spin()
|
||||
|
|
@ -39,7 +39,6 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
|
|||
this->loop_control->start();
|
||||
this->loop_rl->start();
|
||||
|
||||
|
||||
#ifdef PLOT
|
||||
this->plot_t = std::vector<int>(this->plot_size, 0);
|
||||
this->plot_real_joint_pos.resize(this->params.num_of_dofs);
|
||||
|
@ -225,7 +224,7 @@ int main(int argc, char **argv)
|
|||
while (1)
|
||||
{
|
||||
sleep(10);
|
||||
};
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -46,8 +46,8 @@ RL_Sim::RL_Sim()
|
|||
for (int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
// joint need to rename as xxx_joint
|
||||
this->joint_publishers[this->params.joint_controller_names[i]] = nh.advertise<robot_msgs::MotorCommand>(
|
||||
this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10);
|
||||
this->joint_publishers[this->params.joint_controller_names[i]] =
|
||||
nh.advertise<robot_msgs::MotorCommand>(this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10);
|
||||
}
|
||||
|
||||
// subscriber
|
||||
|
|
|
@ -66,7 +66,6 @@ public:
|
|||
void setGains(const double &p, const double &i, const double &d, const double &i_max, const double &i_min, const bool &antiwindup = false);
|
||||
void getGains(double &p, double &i, double &d, double &i_max, double &i_min, bool &antiwindup);
|
||||
void getGains(double &p, double &i, double &d, double &i_max, double &i_min);
|
||||
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -3,10 +3,14 @@
|
|||
|
||||
// #define rqtTune // use rqt or not
|
||||
|
||||
double clamp(double& value, double min, double max) {
|
||||
if (value < min) {
|
||||
double clamp(double &value, double min, double max)
|
||||
{
|
||||
if (value < min)
|
||||
{
|
||||
value = min;
|
||||
} else if (value > max) {
|
||||
}
|
||||
else if (value > max)
|
||||
{
|
||||
value = max;
|
||||
}
|
||||
return value;
|
||||
|
@ -15,13 +19,15 @@ double clamp(double& value, double min, double max) {
|
|||
namespace robot_joint_controller
|
||||
{
|
||||
|
||||
RobotJointController::RobotJointController(){
|
||||
RobotJointController::RobotJointController()
|
||||
{
|
||||
memset(&lastCommand, 0, sizeof(robot_msgs::MotorCommand));
|
||||
memset(&lastState, 0, sizeof(robot_msgs::MotorState));
|
||||
memset(&servoCommand, 0, sizeof(ServoCommand));
|
||||
}
|
||||
|
||||
RobotJointController::~RobotJointController(){
|
||||
RobotJointController::~RobotJointController()
|
||||
{
|
||||
sub_ft.shutdown();
|
||||
sub_command.shutdown();
|
||||
}
|
||||
|
@ -43,7 +49,8 @@ namespace robot_joint_controller
|
|||
bool RobotJointController::init(hardware_interface::EffortJointInterface *robot, ros::NodeHandle &n)
|
||||
{
|
||||
name_space = n.getNamespace();
|
||||
if (!n.getParam("joint", joint_name)){
|
||||
if (!n.getParam("joint", joint_name))
|
||||
{
|
||||
ROS_ERROR("No joint given in namespace: '%s')", n.getNamespace().c_str());
|
||||
return false;
|
||||
}
|
||||
|
@ -56,12 +63,14 @@ namespace robot_joint_controller
|
|||
#endif
|
||||
|
||||
urdf::Model urdf; // Get URDF info about joint
|
||||
if (!urdf.initParamWithNodeHandle("robot_description", n)){
|
||||
if (!urdf.initParamWithNodeHandle("robot_description", n))
|
||||
{
|
||||
ROS_ERROR("Failed to parse urdf file");
|
||||
return false;
|
||||
}
|
||||
joint_urdf = urdf.getJoint(joint_name);
|
||||
if (!joint_urdf){
|
||||
if (!joint_urdf)
|
||||
{
|
||||
ROS_ERROR("Could not find joint '%s' in urdf", joint_name.c_str());
|
||||
return false;
|
||||
}
|
||||
|
@ -118,13 +127,15 @@ namespace robot_joint_controller
|
|||
servoCommand.pos = lastCommand.q;
|
||||
positionLimits(servoCommand.pos);
|
||||
servoCommand.posStiffness = lastCommand.kp;
|
||||
if(fabs(lastCommand.q - PosStopF) < 0.00001){
|
||||
if (fabs(lastCommand.q - PosStopF) < 0.00001)
|
||||
{
|
||||
servoCommand.posStiffness = 0;
|
||||
}
|
||||
servoCommand.vel = lastCommand.dq;
|
||||
velocityLimits(servoCommand.vel);
|
||||
servoCommand.velStiffness = lastCommand.kd;
|
||||
if(fabs(lastCommand.dq - VelStopF) < 0.00001){
|
||||
if (fabs(lastCommand.dq - VelStopF) < 0.00001)
|
||||
{
|
||||
servoCommand.velStiffness = 0;
|
||||
}
|
||||
servoCommand.torque = lastCommand.tau;
|
||||
|
@ -151,7 +162,8 @@ namespace robot_joint_controller
|
|||
lastState.tauEst = joint.getEffort();
|
||||
|
||||
// publish state
|
||||
if (controller_state_publisher_ && controller_state_publisher_->trylock()) {
|
||||
if (controller_state_publisher_ && controller_state_publisher_->trylock())
|
||||
{
|
||||
controller_state_publisher_->msg_.q = lastState.q;
|
||||
controller_state_publisher_->msg_.dq = lastState.dq;
|
||||
controller_state_publisher_->msg_.tauEst = lastState.tauEst;
|
||||
|
|
Loading…
Reference in New Issue