style: code format

This commit is contained in:
fan-ziqi 2024-10-08 19:45:59 +08:00
parent 10808335fd
commit 8b6a2d6af1
22 changed files with 317 additions and 270 deletions

View File

@ -16,6 +16,7 @@ class RL_Real : public RL
public: public:
RL_Real(); RL_Real();
~RL_Real(); ~RL_Real();
private: private:
// rl functions // rl functions
torch::Tensor Forward() override; torch::Tensor Forward() override;
@ -43,8 +44,8 @@ private:
void Plot(); void Plot();
// unitree interface // unitree interface
void UDPSend(){unitree_udp.Send();} void UDPSend() { unitree_udp.Send(); }
void UDPRecv(){unitree_udp.Recv();} void UDPRecv() { unitree_udp.Recv(); }
UNITREE_LEGGED_SDK::Safety unitree_safe; UNITREE_LEGGED_SDK::Safety unitree_safe;
UNITREE_LEGGED_SDK::UDP unitree_udp; UNITREE_LEGGED_SDK::UDP unitree_udp;
UNITREE_LEGGED_SDK::LowCmd unitree_low_command = {0}; UNITREE_LEGGED_SDK::LowCmd unitree_low_command = {0};
@ -59,4 +60,4 @@ private:
int state_mapping[12] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8}; int state_mapping[12] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8};
}; };
#endif #endif // RL_REAL_HPP

View File

@ -21,6 +21,7 @@ class RL_Sim : public RL
public: public:
RL_Sim(); RL_Sim();
~RL_Sim(); ~RL_Sim();
private: private:
// rl functions // rl functions
torch::Tensor Forward() override; torch::Tensor Forward() override;
@ -69,7 +70,7 @@ private:
std::vector<double> mapped_joint_positions; std::vector<double> mapped_joint_positions;
std::vector<double> mapped_joint_velocities; std::vector<double> mapped_joint_velocities;
std::vector<double> mapped_joint_efforts; std::vector<double> mapped_joint_efforts;
void MapData(const std::vector<double>& source_data, std::vector<double>& target_data); void MapData(const std::vector<double> &source_data, std::vector<double> &target_data);
}; };
#endif #endif // RL_SIM_HPP

View File

@ -14,17 +14,7 @@
class LoopFunc class LoopFunc
{ {
private: public:
std::string _name;
double _period;
std::function<void()> _func;
int _bindCPU;
std::atomic<bool> _running;
std::mutex _mutex;
std::condition_variable _cv;
std::thread _thread;
public:
LoopFunc(const std::string &name, double period, std::function<void()> func, int bindCPU = -1) LoopFunc(const std::string &name, double period, std::function<void()> func, int bindCPU = -1)
: _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {} : _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {}
@ -57,7 +47,17 @@ class LoopFunc
} }
log("[Loop End] named: " + _name); log("[Loop End] named: " + _name);
} }
private:
private:
std::string _name;
double _period;
std::function<void()> _func;
int _bindCPU;
std::atomic<bool> _running;
std::mutex _mutex;
std::condition_variable _cv;
std::thread _thread;
void loop() void loop()
{ {
while (_running) while (_running)
@ -72,7 +72,8 @@ class LoopFunc
if (sleepTime.count() > 0) if (sleepTime.count() > 0)
{ {
std::unique_lock<std::mutex> lock(_mutex); std::unique_lock<std::mutex> lock(_mutex);
if (_cv.wait_for(lock, sleepTime, [this]{ return !_running; })) if (_cv.wait_for(lock, sleepTime, [this]
{ return !_running; }))
{ {
break; break;
} }
@ -87,7 +88,7 @@ class LoopFunc
return stream.str(); return stream.str();
} }
void log(const std::string& message) void log(const std::string &message)
{ {
static std::mutex logMutex; static std::mutex logMutex;
std::lock_guard<std::mutex> lock(logMutex); std::lock_guard<std::mutex> lock(logMutex);
@ -108,4 +109,4 @@ class LoopFunc
} }
}; };
#endif #endif // LOOP_H

View File

@ -16,7 +16,8 @@ ObservationBuffer::ObservationBuffer(int num_envs,
void ObservationBuffer::reset(std::vector<int> reset_idxs, torch::Tensor new_obs) void ObservationBuffer::reset(std::vector<int> reset_idxs, torch::Tensor new_obs)
{ {
std::vector<torch::indexing::TensorIndex> indices; std::vector<torch::indexing::TensorIndex> indices;
for (int idx : reset_idxs) { for (int idx : reset_idxs)
{
indices.push_back(torch::indexing::Slice(idx)); indices.push_back(torch::indexing::Slice(idx));
} }
obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps})); obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps}));

View File

@ -4,7 +4,8 @@
#include <torch/torch.h> #include <torch/torch.h>
#include <vector> #include <vector>
class ObservationBuffer { class ObservationBuffer
{
public: public:
ObservationBuffer(int num_envs, int num_obs, int include_history_steps); ObservationBuffer(int num_envs, int num_obs, int include_history_steps);
ObservationBuffer(); ObservationBuffer();

View File

@ -15,34 +15,34 @@ torch::Tensor RL::ComputeObservation()
{ {
std::vector<torch::Tensor> obs_list; std::vector<torch::Tensor> obs_list;
for(const std::string& observation : this->params.observations) for (const std::string &observation : this->params.observations)
{ {
if(observation == "lin_vel") if (observation == "lin_vel")
{ {
obs_list.push_back(this->obs.lin_vel * this->params.lin_vel_scale); obs_list.push_back(this->obs.lin_vel * this->params.lin_vel_scale);
} }
else if(observation == "ang_vel") else if (observation == "ang_vel")
{ {
// obs_list.push_back(this->obs.ang_vel * this->params.ang_vel_scale); // TODO is QuatRotateInverse necessery? // obs_list.push_back(this->obs.ang_vel * this->params.ang_vel_scale); // TODO is QuatRotateInverse necessery?
obs_list.push_back(this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel, this->params.framework) * this->params.ang_vel_scale); obs_list.push_back(this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel, this->params.framework) * this->params.ang_vel_scale);
} }
else if(observation == "gravity_vec") else if (observation == "gravity_vec")
{ {
obs_list.push_back(this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec, this->params.framework)); obs_list.push_back(this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec, this->params.framework));
} }
else if(observation == "commands") else if (observation == "commands")
{ {
obs_list.push_back(this->obs.commands * this->params.commands_scale); obs_list.push_back(this->obs.commands * this->params.commands_scale);
} }
else if(observation == "dof_pos") else if (observation == "dof_pos")
{ {
obs_list.push_back((this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale); obs_list.push_back((this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale);
} }
else if(observation == "dof_vel") else if (observation == "dof_vel")
{ {
obs_list.push_back(this->obs.dof_vel * this->params.dof_vel_scale); obs_list.push_back(this->obs.dof_vel * this->params.dof_vel_scale);
} }
else if(observation == "actions") else if (observation == "actions")
{ {
obs_list.push_back(this->obs.actions); obs_list.push_back(this->obs.actions);
} }
@ -92,16 +92,16 @@ torch::Tensor RL::ComputePosition(torch::Tensor actions)
return actions_scaled + this->params.default_dof_pos; return actions_scaled + this->params.default_dof_pos;
} }
torch::Tensor RL::QuatRotateInverse(torch::Tensor q, torch::Tensor v, const std::string& framework) torch::Tensor RL::QuatRotateInverse(torch::Tensor q, torch::Tensor v, const std::string &framework)
{ {
torch::Tensor q_w; torch::Tensor q_w;
torch::Tensor q_vec; torch::Tensor q_vec;
if(framework == "isaacsim") if (framework == "isaacsim")
{ {
q_w = q.index({torch::indexing::Slice(), 0}); q_w = q.index({torch::indexing::Slice(), 0});
q_vec = q.index({torch::indexing::Slice(), torch::indexing::Slice(1, 4)}); q_vec = q.index({torch::indexing::Slice(), torch::indexing::Slice(1, 4)});
} }
else if(framework == "isaacgym") else if (framework == "isaacgym")
{ {
q_w = q.index({torch::indexing::Slice(), 3}); q_w = q.index({torch::indexing::Slice(), 3});
q_vec = q.index({torch::indexing::Slice(), torch::indexing::Slice(0, 3)}); q_vec = q.index({torch::indexing::Slice(), torch::indexing::Slice(0, 3)});
@ -122,17 +122,17 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
static float getdown_percent = 0.0; static float getdown_percent = 0.0;
// waiting // waiting
if(this->running_state == STATE_WAITING) if (this->running_state == STATE_WAITING)
{ {
for(int i = 0; i < this->params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
command->motor_command.q[i] = state->motor_state.q[i]; command->motor_command.q[i] = state->motor_state.q[i];
} }
if(this->control.control_state == STATE_POS_GETUP) if (this->control.control_state == STATE_POS_GETUP)
{ {
this->control.control_state = STATE_WAITING; this->control.control_state = STATE_WAITING;
getup_percent = 0.0; getup_percent = 0.0;
for(int i = 0; i < this->params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
now_state.motor_state.q[i] = state->motor_state.q[i]; now_state.motor_state.q[i] = state->motor_state.q[i];
start_state.motor_state.q[i] = now_state.motor_state.q[i]; start_state.motor_state.q[i] = now_state.motor_state.q[i];
@ -142,13 +142,13 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
} }
} }
// stand up (position control) // stand up (position control)
else if(this->running_state == STATE_POS_GETUP) else if (this->running_state == STATE_POS_GETUP)
{ {
if(getup_percent < 1.0) if (getup_percent < 1.0)
{ {
getup_percent += 1 / 500.0; getup_percent += 1 / 500.0;
getup_percent = getup_percent > 1.0 ? 1.0 : getup_percent; getup_percent = getup_percent > 1.0 ? 1.0 : getup_percent;
for(int i = 0; i < this->params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
command->motor_command.q[i] = (1 - getup_percent) * now_state.motor_state.q[i] + getup_percent * this->params.default_dof_pos[0][i].item<double>(); command->motor_command.q[i] = (1 - getup_percent) * now_state.motor_state.q[i] + getup_percent * this->params.default_dof_pos[0][i].item<double>();
command->motor_command.dq[i] = 0; command->motor_command.dq[i] = 0;
@ -158,17 +158,17 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
} }
std::cout << "\r" << std::flush << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << std::flush; std::cout << "\r" << std::flush << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << std::flush;
} }
if(this->control.control_state == STATE_RL_INIT) if (this->control.control_state == STATE_RL_INIT)
{ {
this->control.control_state = STATE_WAITING; this->control.control_state = STATE_WAITING;
this->running_state = STATE_RL_INIT; this->running_state = STATE_RL_INIT;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_INIT" << std::endl; std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_INIT" << std::endl;
} }
else if(this->control.control_state == STATE_POS_GETDOWN) else if (this->control.control_state == STATE_POS_GETDOWN)
{ {
this->control.control_state = STATE_WAITING; this->control.control_state = STATE_WAITING;
getdown_percent = 0.0; getdown_percent = 0.0;
for(int i = 0; i < this->params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
now_state.motor_state.q[i] = state->motor_state.q[i]; now_state.motor_state.q[i] = state->motor_state.q[i];
} }
@ -177,9 +177,9 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
} }
} }
// init obs and start rl loop // init obs and start rl loop
else if(this->running_state == STATE_RL_INIT) else if (this->running_state == STATE_RL_INIT)
{ {
if(getup_percent == 1) if (getup_percent == 1)
{ {
this->InitObservations(); this->InitObservations();
this->InitOutputs(); this->InitOutputs();
@ -189,10 +189,10 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
} }
} }
// rl loop // rl loop
else if(this->running_state == STATE_RL_RUNNING) else if (this->running_state == STATE_RL_RUNNING)
{ {
std::cout << "\r" << std::flush << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << std::flush; std::cout << "\r" << std::flush << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << std::flush;
for(int i = 0; i < this->params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
command->motor_command.q[i] = this->output_dof_pos[0][i].item<double>(); command->motor_command.q[i] = this->output_dof_pos[0][i].item<double>();
command->motor_command.dq[i] = 0; command->motor_command.dq[i] = 0;
@ -200,22 +200,22 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
command->motor_command.kd[i] = this->params.rl_kd[0][i].item<double>(); command->motor_command.kd[i] = this->params.rl_kd[0][i].item<double>();
command->motor_command.tau[i] = 0; command->motor_command.tau[i] = 0;
} }
if(this->control.control_state == STATE_POS_GETDOWN) if (this->control.control_state == STATE_POS_GETDOWN)
{ {
this->control.control_state = STATE_WAITING; this->control.control_state = STATE_WAITING;
getdown_percent = 0.0; getdown_percent = 0.0;
for(int i = 0; i < this->params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
now_state.motor_state.q[i] = state->motor_state.q[i]; now_state.motor_state.q[i] = state->motor_state.q[i];
} }
this->running_state = STATE_POS_GETDOWN; this->running_state = STATE_POS_GETDOWN;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl; std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl;
} }
else if(this->control.control_state == STATE_POS_GETUP) else if (this->control.control_state == STATE_POS_GETUP)
{ {
this->control.control_state = STATE_WAITING; this->control.control_state = STATE_WAITING;
getup_percent = 0.0; getup_percent = 0.0;
for(int i = 0; i < this->params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
now_state.motor_state.q[i] = state->motor_state.q[i]; now_state.motor_state.q[i] = state->motor_state.q[i];
} }
@ -224,13 +224,13 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
} }
} }
// get down (position control) // get down (position control)
else if(this->running_state == STATE_POS_GETDOWN) else if (this->running_state == STATE_POS_GETDOWN)
{ {
if(getdown_percent < 1.0) if (getdown_percent < 1.0)
{ {
getdown_percent += 1 / 500.0; getdown_percent += 1 / 500.0;
getdown_percent = getdown_percent > 1.0 ? 1.0 : getdown_percent; getdown_percent = getdown_percent > 1.0 ? 1.0 : getdown_percent;
for(int i = 0; i < this->params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
command->motor_command.q[i] = (1 - getdown_percent) * now_state.motor_state.q[i] + getdown_percent * start_state.motor_state.q[i]; command->motor_command.q[i] = (1 - getdown_percent) * now_state.motor_state.q[i] + getdown_percent * start_state.motor_state.q[i];
command->motor_command.dq[i] = 0; command->motor_command.dq[i] = 0;
@ -240,7 +240,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
} }
std::cout << "\r" << std::flush << LOGGER::INFO << "Getting down " << std::fixed << std::setprecision(2) << getdown_percent * 100.0 << std::flush; std::cout << "\r" << std::flush << LOGGER::INFO << "Getting down " << std::fixed << std::setprecision(2) << getdown_percent * 100.0 << std::flush;
} }
if(getdown_percent == 1) if (getdown_percent == 1)
{ {
this->InitObservations(); this->InitObservations();
this->InitOutputs(); this->InitOutputs();
@ -255,28 +255,28 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques)
{ {
std::vector<int> out_of_range_indices; std::vector<int> out_of_range_indices;
std::vector<double> out_of_range_values; std::vector<double> out_of_range_values;
for(int i = 0; i < origin_output_torques.size(1); ++i) for (int i = 0; i < origin_output_torques.size(1); ++i)
{ {
double torque_value = origin_output_torques[0][i].item<double>(); double torque_value = origin_output_torques[0][i].item<double>();
double limit_lower = -this->params.torque_limits[0][i].item<double>(); double limit_lower = -this->params.torque_limits[0][i].item<double>();
double limit_upper = this->params.torque_limits[0][i].item<double>(); double limit_upper = this->params.torque_limits[0][i].item<double>();
if(torque_value < limit_lower || torque_value > limit_upper) if (torque_value < limit_lower || torque_value > limit_upper)
{ {
out_of_range_indices.push_back(i); out_of_range_indices.push_back(i);
out_of_range_values.push_back(torque_value); out_of_range_values.push_back(torque_value);
} }
} }
if(!out_of_range_indices.empty()) if (!out_of_range_indices.empty())
{ {
for(int i = 0; i < out_of_range_indices.size(); ++i) for (int i = 0; i < out_of_range_indices.size(); ++i)
{ {
int index = out_of_range_indices[i]; int index = out_of_range_indices[i];
double value = out_of_range_values[i]; double value = out_of_range_values[i];
double limit_lower = -this->params.torque_limits[0][index].item<double>(); double limit_lower = -this->params.torque_limits[0][index].item<double>();
double limit_upper = this->params.torque_limits[0][index].item<double>(); double limit_upper = this->params.torque_limits[0][index].item<double>();
std::cout << LOGGER::WARNING << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl; std::cout << LOGGER::WARNING << "Torque(" << index + 1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
} }
// Just a reminder, no protection // Just a reminder, no protection
// this->control.control_state = STATE_POS_GETDOWN; // this->control.control_state = STATE_POS_GETDOWN;
@ -305,64 +305,94 @@ static bool kbhit()
void RL::KeyboardInterface() void RL::KeyboardInterface()
{ {
if(kbhit()) if (kbhit())
{ {
int c = fgetc(stdin); int c = fgetc(stdin);
switch(c) switch (c)
{ {
case '0': this->control.control_state = STATE_POS_GETUP; break; case '0':
case 'p': this->control.control_state = STATE_RL_INIT; break; this->control.control_state = STATE_POS_GETUP;
case '1': this->control.control_state = STATE_POS_GETDOWN; break; break;
case 'q': break; case 'p':
case 'w': this->control.x += 0.1; break; this->control.control_state = STATE_RL_INIT;
case 's': this->control.x -= 0.1; break; break;
case 'a': this->control.yaw += 0.1; break; case '1':
case 'd': this->control.yaw -= 0.1; break; this->control.control_state = STATE_POS_GETDOWN;
case 'i': break; break;
case 'k': break; case 'q':
case 'j': this->control.y += 0.1; break; break;
case 'l': this->control.y -= 0.1; break; case 'w':
case ' ': this->control.x = 0; this->control.y = 0; this->control.yaw = 0; break; this->control.x += 0.1;
case 'r': this->control.control_state = STATE_RESET_SIMULATION; break; break;
case '\n': this->control.control_state = STATE_TOGGLE_SIMULATION; break; case 's':
default: break; this->control.x -= 0.1;
break;
case 'a':
this->control.yaw += 0.1;
break;
case 'd':
this->control.yaw -= 0.1;
break;
case 'i':
break;
case 'k':
break;
case 'j':
this->control.y += 0.1;
break;
case 'l':
this->control.y -= 0.1;
break;
case ' ':
this->control.x = 0;
this->control.y = 0;
this->control.yaw = 0;
break;
case 'r':
this->control.control_state = STATE_RESET_SIMULATION;
break;
case '\n':
this->control.control_state = STATE_TOGGLE_SIMULATION;
break;
default:
break;
} }
} }
} }
template<typename T> template <typename T>
std::vector<T> ReadVectorFromYaml(const YAML::Node& node) std::vector<T> ReadVectorFromYaml(const YAML::Node &node)
{ {
std::vector<T> values; std::vector<T> values;
for(const auto& val : node) for (const auto &val : node)
{ {
values.push_back(val.as<T>()); values.push_back(val.as<T>());
} }
return values; return values;
} }
template<typename T> template <typename T>
std::vector<T> ReadVectorFromYaml(const YAML::Node& node, const std::string& framework, const int& rows, const int& cols) std::vector<T> ReadVectorFromYaml(const YAML::Node &node, const std::string &framework, const int &rows, const int &cols)
{ {
std::vector<T> values; std::vector<T> values;
for(const auto& val : node) for (const auto &val : node)
{ {
values.push_back(val.as<T>()); values.push_back(val.as<T>());
} }
if(framework == "isaacsim") if (framework == "isaacsim")
{ {
std::vector<T> transposed_values(cols * rows); std::vector<T> transposed_values(cols * rows);
for(int r = 0; r < rows; ++r) for (int r = 0; r < rows; ++r)
{ {
for(int c = 0; c < cols; ++c) for (int c = 0; c < cols; ++c)
{ {
transposed_values[c * rows + r] = values[r * cols + c]; transposed_values[c * rows + r] = values[r * cols + c];
} }
} }
return transposed_values; return transposed_values;
} }
else if(framework == "isaacgym") else if (framework == "isaacgym")
{ {
return values; return values;
} }
@ -380,7 +410,8 @@ void RL::ReadYaml(std::string robot_name)
try try
{ {
config = YAML::LoadFile(config_path)[robot_name]; config = YAML::LoadFile(config_path)[robot_name];
} catch(YAML::BadFile &e) }
catch (YAML::BadFile &e)
{ {
std::cout << LOGGER::ERROR << "The file '" << config_path << "' does not exist" << std::endl; std::cout << LOGGER::ERROR << "The file '" << config_path << "' does not exist" << std::endl;
return; return;
@ -396,7 +427,7 @@ void RL::ReadYaml(std::string robot_name)
this->params.num_observations = config["num_observations"].as<int>(); this->params.num_observations = config["num_observations"].as<int>();
this->params.observations = ReadVectorFromYaml<std::string>(config["observations"]); this->params.observations = ReadVectorFromYaml<std::string>(config["observations"]);
this->params.clip_obs = config["clip_obs"].as<double>(); this->params.clip_obs = config["clip_obs"].as<double>();
if(config["clip_actions_lower"].IsNull() && config["clip_actions_upper"].IsNull()) if (config["clip_actions_lower"].IsNull() && config["clip_actions_upper"].IsNull())
{ {
this->params.clip_actions_upper = torch::tensor({}).view({1, -1}); this->params.clip_actions_upper = torch::tensor({}).view({1, -1});
this->params.clip_actions_lower = torch::tensor({}).view({1, -1}); this->params.clip_actions_lower = torch::tensor({}).view({1, -1});
@ -440,11 +471,11 @@ void RL::CSVInit(std::string robot_name)
csv_filename += ".csv"; csv_filename += ".csv";
std::ofstream file(csv_filename.c_str()); std::ofstream file(csv_filename.c_str());
for(int i = 0; i < 12; ++i) {file << "tau_cal_" << i << ",";} for(int i = 0; i < 12; ++i) { file << "tau_cal_" << i << ","; }
for(int i = 0; i < 12; ++i) {file << "tau_est_" << i << ",";} for(int i = 0; i < 12; ++i) { file << "tau_est_" << i << ","; }
for(int i = 0; i < 12; ++i) {file << "joint_pos_" << i << ",";} for(int i = 0; i < 12; ++i) { file << "joint_pos_" << i << ","; }
for(int i = 0; i < 12; ++i) {file << "joint_pos_target_" << i << ",";} for(int i = 0; i < 12; ++i) { file << "joint_pos_target_" << i << ","; }
for(int i = 0; i < 12; ++i) {file << "joint_vel_" << i << ",";} for(int i = 0; i < 12; ++i) { file << "joint_vel_" << i << ","; }
file << std::endl; file << std::endl;
@ -455,11 +486,11 @@ void RL::CSVLogger(torch::Tensor torque, torch::Tensor tau_est, torch::Tensor jo
{ {
std::ofstream file(csv_filename.c_str(), std::ios_base::app); std::ofstream file(csv_filename.c_str(), std::ios_base::app);
for(int i = 0; i < 12; ++i) {file << torque[0][i].item<double>() << ",";} for(int i = 0; i < 12; ++i) { file << torque[0][i].item<double>() << ","; }
for(int i = 0; i < 12; ++i) {file << tau_est[0][i].item<double>() << ",";} for(int i = 0; i < 12; ++i) { file << tau_est[0][i].item<double>() << ","; }
for(int i = 0; i < 12; ++i) {file << joint_pos[0][i].item<double>() << ",";} for(int i = 0; i < 12; ++i) { file << joint_pos[0][i].item<double>() << ","; }
for(int i = 0; i < 12; ++i) {file << joint_pos_target[0][i].item<double>() << ",";} for(int i = 0; i < 12; ++i) { file << joint_pos_target[0][i].item<double>() << ","; }
for(int i = 0; i < 12; ++i) {file << joint_vel[0][i].item<double>() << ",";} for(int i = 0; i < 12; ++i) { file << joint_vel[0][i].item<double>() << ","; }
file << std::endl; file << std::endl;

View File

@ -8,14 +8,15 @@
#include <yaml-cpp/yaml.h> #include <yaml-cpp/yaml.h>
namespace LOGGER { namespace LOGGER
const char* const INFO = "\033[0;37m[INFO]\033[0m "; {
const char* const WARNING = "\033[0;33m[WARNING]\033[0m "; const char *const INFO = "\033[0;37m[INFO]\033[0m ";
const char* const ERROR = "\033[0;31m[ERROR]\033[0m "; const char *const WARNING = "\033[0;33m[WARNING]\033[0m ";
const char* const DEBUG = "\033[0;32m[DEBUG]\033[0m "; const char *const ERROR = "\033[0;31m[ERROR]\033[0m ";
const char *const DEBUG = "\033[0;32m[DEBUG]\033[0m ";
} }
template<typename T> template <typename T>
struct RobotCommand struct RobotCommand
{ {
struct MotorCommand struct MotorCommand
@ -28,7 +29,7 @@ struct RobotCommand
} motor_command; } motor_command;
}; };
template<typename T> template <typename T>
struct RobotState struct RobotState
{ {
struct IMU struct IMU
@ -48,7 +49,8 @@ struct RobotState
} motor_state; } motor_state;
}; };
enum STATE { enum STATE
{
STATE_WAITING = 0, STATE_WAITING = 0,
STATE_POS_GETUP, STATE_POS_GETUP,
STATE_RL_INIT, STATE_RL_INIT,
@ -113,8 +115,8 @@ struct Observations
class RL class RL
{ {
public: public:
RL(){}; RL() {};
~RL(){}; ~RL() {};
ModelParams params; ModelParams params;
Observations obs; Observations obs;
@ -135,7 +137,7 @@ public:
void StateController(const RobotState<double> *state, RobotCommand<double> *command); void StateController(const RobotState<double> *state, RobotCommand<double> *command);
torch::Tensor ComputeTorques(torch::Tensor actions); torch::Tensor ComputeTorques(torch::Tensor actions);
torch::Tensor ComputePosition(torch::Tensor actions); torch::Tensor ComputePosition(torch::Tensor actions);
torch::Tensor QuatRotateInverse(torch::Tensor q, torch::Tensor v, const std::string& framework); torch::Tensor QuatRotateInverse(torch::Tensor q, torch::Tensor v, const std::string &framework);
// yaml params // yaml params
void ReadYaml(std::string robot_name); void ReadYaml(std::string robot_name);
@ -165,4 +167,4 @@ protected:
torch::Tensor output_dof_pos; torch::Tensor output_dof_pos;
}; };
#endif #endif // RL_SDK_HPP

View File

@ -234,4 +234,3 @@ class RL_Sim(RL):
if __name__ == "__main__": if __name__ == "__main__":
rl_sim = RL_Sim() rl_sim = RL_Sim()
rospy.spin() rospy.spin()

View File

@ -28,8 +28,8 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
this->model = torch::jit::load(model_path); this->model = torch::jit::load(model_path);
// loop // loop
this->loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend" , 0.002, std::bind(&RL_Real::UDPSend, this), 3); this->loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend", 0.002, std::bind(&RL_Real::UDPSend, this), 3);
this->loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv" , 0.002, std::bind(&RL_Real::UDPRecv, this), 3); this->loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv", 0.002, std::bind(&RL_Real::UDPRecv, this), 3);
this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05, std::bind(&RL_Real::KeyboardInterface, this)); this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05, std::bind(&RL_Real::KeyboardInterface, this));
this->loop_control = std::make_shared<LoopFunc>("loop_control", this->params.dt, std::bind(&RL_Real::RobotControl, this)); this->loop_control = std::make_shared<LoopFunc>("loop_control", this->params.dt, std::bind(&RL_Real::RobotControl, this));
this->loop_rl = std::make_shared<LoopFunc>("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Real::RunModel, this)); this->loop_rl = std::make_shared<LoopFunc>("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Real::RunModel, this));
@ -39,14 +39,13 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
this->loop_control->start(); this->loop_control->start();
this->loop_rl->start(); this->loop_rl->start();
#ifdef PLOT #ifdef PLOT
this->plot_t = std::vector<int>(this->plot_size, 0); this->plot_t = std::vector<int>(this->plot_size, 0);
this->plot_real_joint_pos.resize(this->params.num_of_dofs); this->plot_real_joint_pos.resize(this->params.num_of_dofs);
this->plot_target_joint_pos.resize(this->params.num_of_dofs); this->plot_target_joint_pos.resize(this->params.num_of_dofs);
for(auto& vector : this->plot_real_joint_pos) { vector = std::vector<double>(this->plot_size, 0); } for (auto &vector : this->plot_real_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
for(auto& vector : this->plot_target_joint_pos) { vector = std::vector<double>(this->plot_size, 0); } for (auto &vector : this->plot_target_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
this->loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, std::bind(&RL_Real::Plot, this)); this->loop_plot = std::make_shared<LoopFunc>("loop_plot", 0.002, std::bind(&RL_Real::Plot, this));
this->loop_plot->start(); this->loop_plot->start();
#endif #endif
#ifdef CSV_LOGGER #ifdef CSV_LOGGER
@ -72,27 +71,27 @@ void RL_Real::GetState(RobotState<double> *state)
this->unitree_udp.GetRecv(this->unitree_low_state); this->unitree_udp.GetRecv(this->unitree_low_state);
memcpy(&this->unitree_joy, this->unitree_low_state.wirelessRemote, 40); memcpy(&this->unitree_joy, this->unitree_low_state.wirelessRemote, 40);
if((int)this->unitree_joy.btn.components.R2 == 1) if ((int)this->unitree_joy.btn.components.R2 == 1)
{ {
this->control.control_state = STATE_POS_GETUP; this->control.control_state = STATE_POS_GETUP;
} }
else if((int)this->unitree_joy.btn.components.R1 == 1) else if ((int)this->unitree_joy.btn.components.R1 == 1)
{ {
this->control.control_state = STATE_RL_INIT; this->control.control_state = STATE_RL_INIT;
} }
else if((int)this->unitree_joy.btn.components.L2 == 1) else if ((int)this->unitree_joy.btn.components.L2 == 1)
{ {
this->control.control_state = STATE_POS_GETDOWN; this->control.control_state = STATE_POS_GETDOWN;
} }
if(this->params.framework == "isaacgym") if (this->params.framework == "isaacgym")
{ {
state->imu.quaternion[3] = this->unitree_low_state.imu.quaternion[0]; // w state->imu.quaternion[3] = this->unitree_low_state.imu.quaternion[0]; // w
state->imu.quaternion[0] = this->unitree_low_state.imu.quaternion[1]; // x state->imu.quaternion[0] = this->unitree_low_state.imu.quaternion[1]; // x
state->imu.quaternion[1] = this->unitree_low_state.imu.quaternion[2]; // y state->imu.quaternion[1] = this->unitree_low_state.imu.quaternion[2]; // y
state->imu.quaternion[2] = this->unitree_low_state.imu.quaternion[3]; // z state->imu.quaternion[2] = this->unitree_low_state.imu.quaternion[3]; // z
} }
else if(this->params.framework == "isaacsim") else if (this->params.framework == "isaacsim")
{ {
state->imu.quaternion[0] = this->unitree_low_state.imu.quaternion[0]; // w state->imu.quaternion[0] = this->unitree_low_state.imu.quaternion[0]; // w
state->imu.quaternion[1] = this->unitree_low_state.imu.quaternion[1]; // x state->imu.quaternion[1] = this->unitree_low_state.imu.quaternion[1]; // x
@ -100,11 +99,11 @@ void RL_Real::GetState(RobotState<double> *state)
state->imu.quaternion[3] = this->unitree_low_state.imu.quaternion[3]; // z state->imu.quaternion[3] = this->unitree_low_state.imu.quaternion[3]; // z
} }
for(int i = 0; i < 3; ++i) for (int i = 0; i < 3; ++i)
{ {
state->imu.gyroscope[i] = this->unitree_low_state.imu.gyroscope[i]; state->imu.gyroscope[i] = this->unitree_low_state.imu.gyroscope[i];
} }
for(int i = 0; i < this->params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
state->motor_state.q[i] = this->unitree_low_state.motorState[state_mapping[i]].q; state->motor_state.q[i] = this->unitree_low_state.motorState[state_mapping[i]].q;
state->motor_state.dq[i] = this->unitree_low_state.motorState[state_mapping[i]].dq; state->motor_state.dq[i] = this->unitree_low_state.motorState[state_mapping[i]].dq;
@ -114,7 +113,7 @@ void RL_Real::GetState(RobotState<double> *state)
void RL_Real::SetCommand(const RobotCommand<double> *command) void RL_Real::SetCommand(const RobotCommand<double> *command)
{ {
for(int i = 0; i < this->params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
this->unitree_low_command.motorCmd[i].mode = 0x0A; this->unitree_low_command.motorCmd[i].mode = 0x0A;
this->unitree_low_command.motorCmd[i].q = command->motor_command.q[command_mapping[i]]; this->unitree_low_command.motorCmd[i].q = command->motor_command.q[command_mapping[i]];
@ -140,7 +139,7 @@ void RL_Real::RobotControl()
void RL_Real::RunModel() void RL_Real::RunModel()
{ {
if(this->running_state == STATE_RL_RUNNING) if (this->running_state == STATE_RL_RUNNING)
{ {
this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0); this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0);
this->obs.commands = torch::tensor({{this->unitree_joy.ly, -this->unitree_joy.rx, -this->unitree_joy.lx}}); this->obs.commands = torch::tensor({{this->unitree_joy.ly, -this->unitree_joy.rx, -this->unitree_joy.lx}});
@ -182,7 +181,7 @@ torch::Tensor RL_Real::Forward()
torch::Tensor actions = this->model.forward({this->history_obs}).toTensor(); torch::Tensor actions = this->model.forward({this->history_obs}).toTensor();
if(this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0) if (this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0)
{ {
return torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper); return torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
} }
@ -198,7 +197,7 @@ void RL_Real::Plot()
this->plot_t.push_back(this->motiontime); this->plot_t.push_back(this->motiontime);
plt::cla(); plt::cla();
plt::clf(); plt::clf();
for(int i = 0; i < this->params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin()); this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin());
this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin()); this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin());
@ -222,10 +221,10 @@ int main(int argc, char **argv)
{ {
signal(SIGINT, signalHandler); signal(SIGINT, signalHandler);
while(1) while (1)
{ {
sleep(10); sleep(10);
}; }
return 0; return 0;
} }

View File

@ -12,7 +12,7 @@ RL_Sim::RL_Sim()
this->ReadYaml(this->robot_name); this->ReadYaml(this->robot_name);
// history // history
if(this->params.use_history) if (this->params.use_history)
{ {
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
} }
@ -21,7 +21,7 @@ RL_Sim::RL_Sim()
// the mapping table is established according to the order defined in the YAML file // the mapping table is established according to the order defined in the YAML file
std::vector<std::string> sorted_joint_controller_names = this->params.joint_controller_names; std::vector<std::string> sorted_joint_controller_names = this->params.joint_controller_names;
std::sort(sorted_joint_controller_names.begin(), sorted_joint_controller_names.end()); std::sort(sorted_joint_controller_names.begin(), sorted_joint_controller_names.end());
for(size_t i = 0; i < this->params.joint_controller_names.size(); ++i) for (size_t i = 0; i < this->params.joint_controller_names.size(); ++i)
{ {
this->sorted_to_original_index[sorted_joint_controller_names[i]] = i; this->sorted_to_original_index[sorted_joint_controller_names[i]] = i;
} }
@ -46,8 +46,8 @@ RL_Sim::RL_Sim()
for (int i = 0; i < this->params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
// joint need to rename as xxx_joint // joint need to rename as xxx_joint
this->joint_publishers[this->params.joint_controller_names[i]] = nh.advertise<robot_msgs::MotorCommand>( this->joint_publishers[this->params.joint_controller_names[i]] =
this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10); nh.advertise<robot_msgs::MotorCommand>(this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10);
} }
// subscriber // subscriber
@ -75,9 +75,9 @@ RL_Sim::RL_Sim()
this->plot_t = std::vector<int>(this->plot_size, 0); this->plot_t = std::vector<int>(this->plot_size, 0);
this->plot_real_joint_pos.resize(this->params.num_of_dofs); this->plot_real_joint_pos.resize(this->params.num_of_dofs);
this->plot_target_joint_pos.resize(this->params.num_of_dofs); this->plot_target_joint_pos.resize(this->params.num_of_dofs);
for(auto& vector : this->plot_real_joint_pos) { vector = std::vector<double>(this->plot_size, 0); } for (auto &vector : this->plot_real_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
for(auto& vector : this->plot_target_joint_pos) { vector = std::vector<double>(this->plot_size, 0); } for (auto &vector : this->plot_target_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
this->loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.001 , std::bind(&RL_Sim::Plot , this)); this->loop_plot = std::make_shared<LoopFunc>("loop_plot", 0.001, std::bind(&RL_Sim::Plot, this));
this->loop_plot->start(); this->loop_plot->start();
#endif #endif
#ifdef CSV_LOGGER #ifdef CSV_LOGGER
@ -100,14 +100,14 @@ RL_Sim::~RL_Sim()
void RL_Sim::GetState(RobotState<double> *state) void RL_Sim::GetState(RobotState<double> *state)
{ {
if(this->params.framework == "isaacgym") if (this->params.framework == "isaacgym")
{ {
state->imu.quaternion[3] = this->pose.orientation.w; state->imu.quaternion[3] = this->pose.orientation.w;
state->imu.quaternion[0] = this->pose.orientation.x; state->imu.quaternion[0] = this->pose.orientation.x;
state->imu.quaternion[1] = this->pose.orientation.y; state->imu.quaternion[1] = this->pose.orientation.y;
state->imu.quaternion[2] = this->pose.orientation.z; state->imu.quaternion[2] = this->pose.orientation.z;
} }
else if(this->params.framework == "isaacsim") else if (this->params.framework == "isaacsim")
{ {
state->imu.quaternion[0] = this->pose.orientation.w; state->imu.quaternion[0] = this->pose.orientation.w;
state->imu.quaternion[1] = this->pose.orientation.x; state->imu.quaternion[1] = this->pose.orientation.x;
@ -121,7 +121,7 @@ void RL_Sim::GetState(RobotState<double> *state)
// state->imu.accelerometer // state->imu.accelerometer
for(int i = 0; i < this->params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
state->motor_state.q[i] = this->mapped_joint_positions[i]; state->motor_state.q[i] = this->mapped_joint_positions[i];
state->motor_state.dq[i] = this->mapped_joint_velocities[i]; state->motor_state.dq[i] = this->mapped_joint_velocities[i];
@ -131,7 +131,7 @@ void RL_Sim::GetState(RobotState<double> *state)
void RL_Sim::SetCommand(const RobotCommand<double> *command) void RL_Sim::SetCommand(const RobotCommand<double> *command)
{ {
for(int i = 0; i < this->params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
this->joint_publishers_commands[i].q = command->motor_command.q[i]; this->joint_publishers_commands[i].q = command->motor_command.q[i];
this->joint_publishers_commands[i].dq = command->motor_command.dq[i]; this->joint_publishers_commands[i].dq = command->motor_command.dq[i];
@ -140,7 +140,7 @@ void RL_Sim::SetCommand(const RobotCommand<double> *command)
this->joint_publishers_commands[i].tau = command->motor_command.tau[i]; this->joint_publishers_commands[i].tau = command->motor_command.tau[i];
} }
for(int i = 0; i < this->params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
this->joint_publishers[this->params.joint_controller_names[i]].publish(this->joint_publishers_commands[i]); this->joint_publishers[this->params.joint_controller_names[i]].publish(this->joint_publishers_commands[i]);
} }
@ -148,7 +148,7 @@ void RL_Sim::SetCommand(const RobotCommand<double> *command)
void RL_Sim::RobotControl() void RL_Sim::RobotControl()
{ {
if(this->control.control_state == STATE_RESET_SIMULATION) if (this->control.control_state == STATE_RESET_SIMULATION)
{ {
gazebo_msgs::SetModelState set_model_state; gazebo_msgs::SetModelState set_model_state;
set_model_state.request.model_state.model_name = this->gazebo_model_name; set_model_state.request.model_state.model_name = this->gazebo_model_name;
@ -158,10 +158,10 @@ void RL_Sim::RobotControl()
this->control.control_state = STATE_WAITING; this->control.control_state = STATE_WAITING;
} }
if(this->control.control_state == STATE_TOGGLE_SIMULATION) if (this->control.control_state == STATE_TOGGLE_SIMULATION)
{ {
std_srvs::Empty empty; std_srvs::Empty empty;
if(simulation_running) if (simulation_running)
{ {
this->gazebo_pause_physics_client.call(empty); this->gazebo_pause_physics_client.call(empty);
std::cout << std::endl << LOGGER::INFO << "Simulation Stop" << std::endl; std::cout << std::endl << LOGGER::INFO << "Simulation Stop" << std::endl;
@ -174,7 +174,7 @@ void RL_Sim::RobotControl()
simulation_running = !simulation_running; simulation_running = !simulation_running;
this->control.control_state = STATE_WAITING; this->control.control_state = STATE_WAITING;
} }
if(simulation_running) if (simulation_running)
{ {
this->motiontime++; this->motiontime++;
this->GetState(&this->robot_state); this->GetState(&this->robot_state);
@ -194,9 +194,9 @@ void RL_Sim::CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg)
this->cmd_vel = *msg; this->cmd_vel = *msg;
} }
void RL_Sim::MapData(const std::vector<double>& source_data, std::vector<double>& target_data) void RL_Sim::MapData(const std::vector<double> &source_data, std::vector<double> &target_data)
{ {
for(size_t i = 0; i < source_data.size(); ++i) for (size_t i = 0; i < source_data.size(); ++i)
{ {
target_data[i] = source_data[this->sorted_to_original_index[this->params.joint_controller_names[i]]]; target_data[i] = source_data[this->sorted_to_original_index[this->params.joint_controller_names[i]]];
} }
@ -211,7 +211,7 @@ void RL_Sim::JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
void RL_Sim::RunModel() void RL_Sim::RunModel()
{ {
if(this->running_state == STATE_RL_RUNNING && simulation_running) if (this->running_state == STATE_RL_RUNNING && simulation_running)
{ {
this->obs.lin_vel = torch::tensor({{this->vel.linear.x, this->vel.linear.y, this->vel.linear.z}}); this->obs.lin_vel = torch::tensor({{this->vel.linear.x, this->vel.linear.y, this->vel.linear.z}});
this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0); this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0);
@ -249,7 +249,7 @@ torch::Tensor RL_Sim::Forward()
torch::autograd::GradMode::set_enabled(false); torch::autograd::GradMode::set_enabled(false);
torch::Tensor clamped_obs = this->ComputeObservation(); torch::Tensor clamped_obs = this->ComputeObservation();
torch::Tensor actions; torch::Tensor actions;
if(this->params.use_history) if (this->params.use_history)
{ {
this->history_obs_buf.insert(clamped_obs); this->history_obs_buf.insert(clamped_obs);
this->history_obs = this->history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5}); this->history_obs = this->history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
@ -260,7 +260,7 @@ torch::Tensor RL_Sim::Forward()
actions = this->model.forward({clamped_obs}).toTensor(); actions = this->model.forward({clamped_obs}).toTensor();
} }
if(this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0) if (this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0)
{ {
return torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper); return torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
} }
@ -276,13 +276,13 @@ void RL_Sim::Plot()
this->plot_t.push_back(this->motiontime); this->plot_t.push_back(this->motiontime);
plt::cla(); plt::cla();
plt::clf(); plt::clf();
for(int i = 0; i < this->params.num_of_dofs; ++i) for (int i = 0; i < this->params.num_of_dofs; ++i)
{ {
this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin()); this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin());
this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin()); this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin());
this->plot_real_joint_pos[i].push_back(this->mapped_joint_positions[i]); this->plot_real_joint_pos[i].push_back(this->mapped_joint_positions[i]);
this->plot_target_joint_pos[i].push_back(this->joint_publishers_commands[i].q); this->plot_target_joint_pos[i].push_back(this->joint_publishers_commands[i].q);
plt::subplot(4, 3, i+1); plt::subplot(4, 3, i + 1);
plt::named_plot("_real_joint_pos", this->plot_t, this->plot_real_joint_pos[i], "r"); plt::named_plot("_real_joint_pos", this->plot_t, this->plot_real_joint_pos[i], "r");
plt::named_plot("_target_joint_pos", this->plot_t, this->plot_target_joint_pos[i], "b"); plt::named_plot("_target_joint_pos", this->plot_t, this->plot_target_joint_pos[i], "b");
plt::xlim(this->plot_t.front(), this->plot_t.back()); plt::xlim(this->plot_t.front(), this->plot_t.back());

View File

@ -35,15 +35,15 @@ typedef struct
namespace robot_joint_controller namespace robot_joint_controller
{ {
class RobotJointController: public controller_interface::Controller<hardware_interface::EffortJointInterface> class RobotJointController : public controller_interface::Controller<hardware_interface::EffortJointInterface>
{ {
private: private:
hardware_interface::JointHandle joint; hardware_interface::JointHandle joint;
ros::Subscriber sub_command, sub_ft; ros::Subscriber sub_command, sub_ft;
control_toolbox::Pid pid_controller_; control_toolbox::Pid pid_controller_;
std::unique_ptr<realtime_tools::RealtimePublisher<robot_msgs::MotorState> > controller_state_publisher_ ; std::unique_ptr<realtime_tools::RealtimePublisher<robot_msgs::MotorState>> controller_state_publisher_;
public: public:
std::string name_space; std::string name_space;
std::string joint_name; std::string joint_name;
urdf::JointConstSharedPtr joint_urdf; urdf::JointConstSharedPtr joint_urdf;
@ -55,10 +55,10 @@ public:
RobotJointController(); RobotJointController();
~RobotJointController(); ~RobotJointController();
virtual bool init(hardware_interface::EffortJointInterface *robot, ros::NodeHandle &n); virtual bool init(hardware_interface::EffortJointInterface *robot, ros::NodeHandle &n);
virtual void starting(const ros::Time& time); virtual void starting(const ros::Time &time);
virtual void update(const ros::Time& time, const ros::Duration& period); virtual void update(const ros::Time &time, const ros::Duration &period);
virtual void stopping(); virtual void stopping();
void setCommandCB(const robot_msgs::MotorCommandConstPtr& msg); void setCommandCB(const robot_msgs::MotorCommandConstPtr &msg);
void positionLimits(double &position); void positionLimits(double &position);
void velocityLimits(double &velocity); void velocityLimits(double &velocity);
void effortLimits(double &effort); void effortLimits(double &effort);
@ -66,7 +66,6 @@ public:
void setGains(const double &p, const double &i, const double &d, const double &i_max, const double &i_min, const bool &antiwindup = false); void setGains(const double &p, const double &i, const double &d, const double &i_max, const double &i_min, const bool &antiwindup = false);
void getGains(double &p, double &i, double &d, double &i_max, double &i_min, bool &antiwindup); void getGains(double &p, double &i, double &d, double &i_max, double &i_min, bool &antiwindup);
void getGains(double &p, double &i, double &d, double &i_max, double &i_min); void getGains(double &p, double &i, double &d, double &i_max, double &i_min);
}; };
} }

View File

@ -3,10 +3,14 @@
// #define rqtTune // use rqt or not // #define rqtTune // use rqt or not
double clamp(double& value, double min, double max) { double clamp(double &value, double min, double max)
if (value < min) { {
if (value < min)
{
value = min; value = min;
} else if (value > max) { }
else if (value > max)
{
value = max; value = max;
} }
return value; return value;
@ -15,18 +19,20 @@ double clamp(double& value, double min, double max) {
namespace robot_joint_controller namespace robot_joint_controller
{ {
RobotJointController::RobotJointController(){ RobotJointController::RobotJointController()
{
memset(&lastCommand, 0, sizeof(robot_msgs::MotorCommand)); memset(&lastCommand, 0, sizeof(robot_msgs::MotorCommand));
memset(&lastState, 0, sizeof(robot_msgs::MotorState)); memset(&lastState, 0, sizeof(robot_msgs::MotorState));
memset(&servoCommand, 0, sizeof(ServoCommand)); memset(&servoCommand, 0, sizeof(ServoCommand));
} }
RobotJointController::~RobotJointController(){ RobotJointController::~RobotJointController()
{
sub_ft.shutdown(); sub_ft.shutdown();
sub_command.shutdown(); sub_command.shutdown();
} }
void RobotJointController::setCommandCB(const robot_msgs::MotorCommandConstPtr& msg) void RobotJointController::setCommandCB(const robot_msgs::MotorCommandConstPtr &msg)
{ {
lastCommand.q = msg->q; lastCommand.q = msg->q;
lastCommand.kp = msg->kp; lastCommand.kp = msg->kp;
@ -43,7 +49,8 @@ namespace robot_joint_controller
bool RobotJointController::init(hardware_interface::EffortJointInterface *robot, ros::NodeHandle &n) bool RobotJointController::init(hardware_interface::EffortJointInterface *robot, ros::NodeHandle &n)
{ {
name_space = n.getNamespace(); name_space = n.getNamespace();
if (!n.getParam("joint", joint_name)){ if (!n.getParam("joint", joint_name))
{
ROS_ERROR("No joint given in namespace: '%s')", n.getNamespace().c_str()); ROS_ERROR("No joint given in namespace: '%s')", n.getNamespace().c_str());
return false; return false;
} }
@ -56,12 +63,14 @@ namespace robot_joint_controller
#endif #endif
urdf::Model urdf; // Get URDF info about joint urdf::Model urdf; // Get URDF info about joint
if (!urdf.initParamWithNodeHandle("robot_description", n)){ if (!urdf.initParamWithNodeHandle("robot_description", n))
{
ROS_ERROR("Failed to parse urdf file"); ROS_ERROR("Failed to parse urdf file");
return false; return false;
} }
joint_urdf = urdf.getJoint(joint_name); joint_urdf = urdf.getJoint(joint_name);
if (!joint_urdf){ if (!joint_urdf)
{
ROS_ERROR("Could not find joint '%s' in urdf", joint_name.c_str()); ROS_ERROR("Could not find joint '%s' in urdf", joint_name.c_str());
return false; return false;
} }
@ -79,22 +88,22 @@ namespace robot_joint_controller
void RobotJointController::setGains(const double &p, const double &i, const double &d, const double &i_max, const double &i_min, const bool &antiwindup) void RobotJointController::setGains(const double &p, const double &i, const double &d, const double &i_max, const double &i_min, const bool &antiwindup)
{ {
pid_controller_.setGains(p,i,d,i_max,i_min,antiwindup); pid_controller_.setGains(p, i, d, i_max, i_min, antiwindup);
} }
void RobotJointController::getGains(double &p, double &i, double &d, double &i_max, double &i_min, bool &antiwindup) void RobotJointController::getGains(double &p, double &i, double &d, double &i_max, double &i_min, bool &antiwindup)
{ {
pid_controller_.getGains(p,i,d,i_max,i_min,antiwindup); pid_controller_.getGains(p, i, d, i_max, i_min, antiwindup);
} }
void RobotJointController::getGains(double &p, double &i, double &d, double &i_max, double &i_min) void RobotJointController::getGains(double &p, double &i, double &d, double &i_max, double &i_min)
{ {
bool dummy; bool dummy;
pid_controller_.getGains(p,i,d,i_max,i_min,dummy); pid_controller_.getGains(p, i, d, i_max, i_min, dummy);
} }
// Controller startup in realtime // Controller startup in realtime
void RobotJointController::starting(const ros::Time& time) void RobotJointController::starting(const ros::Time &time)
{ {
double init_pos = joint.getPosition(); double init_pos = joint.getPosition();
lastCommand.q = init_pos; lastCommand.q = init_pos;
@ -109,7 +118,7 @@ namespace robot_joint_controller
} }
// Controller update loop in realtime // Controller update loop in realtime
void RobotJointController::update(const ros::Time& time, const ros::Duration& period) void RobotJointController::update(const ros::Time &time, const ros::Duration &period)
{ {
double currentPos, currentVel, calcTorque; double currentPos, currentVel, calcTorque;
lastCommand = *(command.readFromRT()); lastCommand = *(command.readFromRT());
@ -118,13 +127,15 @@ namespace robot_joint_controller
servoCommand.pos = lastCommand.q; servoCommand.pos = lastCommand.q;
positionLimits(servoCommand.pos); positionLimits(servoCommand.pos);
servoCommand.posStiffness = lastCommand.kp; servoCommand.posStiffness = lastCommand.kp;
if(fabs(lastCommand.q - PosStopF) < 0.00001){ if (fabs(lastCommand.q - PosStopF) < 0.00001)
{
servoCommand.posStiffness = 0; servoCommand.posStiffness = 0;
} }
servoCommand.vel = lastCommand.dq; servoCommand.vel = lastCommand.dq;
velocityLimits(servoCommand.vel); velocityLimits(servoCommand.vel);
servoCommand.velStiffness = lastCommand.kd; servoCommand.velStiffness = lastCommand.kd;
if(fabs(lastCommand.dq - VelStopF) < 0.00001){ if (fabs(lastCommand.dq - VelStopF) < 0.00001)
{
servoCommand.velStiffness = 0; servoCommand.velStiffness = 0;
} }
servoCommand.torque = lastCommand.tau; servoCommand.torque = lastCommand.tau;
@ -133,7 +144,7 @@ namespace robot_joint_controller
// rqt set P D gains // rqt set P D gains
#ifdef rqtTune #ifdef rqtTune
double i, i_max, i_min; double i, i_max, i_min;
getGains(servoCommand.posStiffness,i,servoCommand.velStiffness,i_max,i_min); getGains(servoCommand.posStiffness, i, servoCommand.velStiffness, i_max, i_min);
#endif #endif
currentPos = joint.getPosition(); currentPos = joint.getPosition();
@ -151,7 +162,8 @@ namespace robot_joint_controller
lastState.tauEst = joint.getEffort(); lastState.tauEst = joint.getEffort();
// publish state // publish state
if (controller_state_publisher_ && controller_state_publisher_->trylock()) { if (controller_state_publisher_ && controller_state_publisher_->trylock())
{
controller_state_publisher_->msg_.q = lastState.q; controller_state_publisher_->msg_.q = lastState.q;
controller_state_publisher_->msg_.dq = lastState.dq; controller_state_publisher_->msg_.dq = lastState.dq;
controller_state_publisher_->msg_.tauEst = lastState.tauEst; controller_state_publisher_->msg_.tauEst = lastState.tauEst;
@ -160,7 +172,7 @@ namespace robot_joint_controller
} }
// Controller stopping in realtime // Controller stopping in realtime
void RobotJointController::stopping(){} void RobotJointController::stopping() {}
void RobotJointController::positionLimits(double &position) void RobotJointController::positionLimits(double &position)
{ {