style: code format

This commit is contained in:
fan-ziqi 2024-10-08 19:45:59 +08:00
parent 10808335fd
commit 8b6a2d6af1
22 changed files with 317 additions and 270 deletions

View File

@ -15,18 +15,18 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GAZEBO_CXX_FLAGS}")
find_package(gazebo REQUIRED)
find_package(catkin REQUIRED COMPONENTS
controller_manager
genmsg
joint_state_controller
robot_state_publisher
roscpp
gazebo_ros
std_msgs
tf
geometry_msgs
robot_msgs
robot_joint_controller
rospy
controller_manager
genmsg
joint_state_controller
robot_state_publisher
roscpp
gazebo_ros
std_msgs
tf
geometry_msgs
robot_msgs
robot_joint_controller
rospy
)
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
@ -36,9 +36,9 @@ link_directories(/usr/local/lib)
include_directories(${YAML_CPP_INCLUDE_DIR})
catkin_package(
CATKIN_DEPENDS
robot_joint_controller
rospy
CATKIN_DEPENDS
robot_joint_controller
rospy
)
include_directories(library/unitree_legged_sdk_3.2/include)
@ -46,13 +46,13 @@ link_directories(library/unitree_legged_sdk_3.2/lib)
set(EXTRA_LIBS -pthread libunitree_legged_sdk_amd64.so lcm)
include_directories(
include
${catkin_INCLUDE_DIRS}
${unitree_legged_sdk_INCLUDE_DIRS}
library/matplotlibcpp
library/observation_buffer
library/rl_sdk
library/loop
include
${catkin_INCLUDE_DIRS}
${unitree_legged_sdk_INCLUDE_DIRS}
library/matplotlibcpp
library/observation_buffer
library/rl_sdk
library/loop
)
add_library(rl_sdk library/rl_sdk/rl_sdk.cpp)
@ -60,9 +60,9 @@ target_link_libraries(rl_sdk "${TORCH_LIBRARIES}" Python3::Python Python3::Modul
set_property(TARGET rl_sdk PROPERTY CXX_STANDARD 14)
find_package(Python3 COMPONENTS NumPy)
if(Python3_NumPy_FOUND)
target_link_libraries(rl_sdk Python3::NumPy)
target_link_libraries(rl_sdk Python3::NumPy)
else()
target_compile_definitions(rl_sdk WITHOUT_NUMPY)
target_compile_definitions(rl_sdk WITHOUT_NUMPY)
endif()
add_library(observation_buffer library/observation_buffer/observation_buffer.cpp)
@ -71,13 +71,13 @@ set_property(TARGET observation_buffer PROPERTY CXX_STANDARD 14)
add_executable(rl_sim src/rl_sim.cpp )
target_link_libraries(rl_sim
${catkin_LIBRARIES} ${EXTRA_LIBS}
${catkin_LIBRARIES} ${EXTRA_LIBS}
rl_sdk observation_buffer yaml-cpp
)
add_executable(rl_real_a1 src/rl_real_a1.cpp )
target_link_libraries(rl_real_a1
${catkin_LIBRARIES} ${EXTRA_LIBS}
${catkin_LIBRARIES} ${EXTRA_LIBS}
rl_sdk observation_buffer yaml-cpp
)

View File

@ -16,6 +16,7 @@ class RL_Real : public RL
public:
RL_Real();
~RL_Real();
private:
// rl functions
torch::Tensor Forward() override;
@ -43,8 +44,8 @@ private:
void Plot();
// unitree interface
void UDPSend(){unitree_udp.Send();}
void UDPRecv(){unitree_udp.Recv();}
void UDPSend() { unitree_udp.Send(); }
void UDPRecv() { unitree_udp.Recv(); }
UNITREE_LEGGED_SDK::Safety unitree_safe;
UNITREE_LEGGED_SDK::UDP unitree_udp;
UNITREE_LEGGED_SDK::LowCmd unitree_low_command = {0};
@ -59,4 +60,4 @@ private:
int state_mapping[12] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8};
};
#endif
#endif // RL_REAL_HPP

View File

@ -21,6 +21,7 @@ class RL_Sim : public RL
public:
RL_Sim();
~RL_Sim();
private:
// rl functions
torch::Tensor Forward() override;
@ -69,7 +70,7 @@ private:
std::vector<double> mapped_joint_positions;
std::vector<double> mapped_joint_velocities;
std::vector<double> mapped_joint_efforts;
void MapData(const std::vector<double>& source_data, std::vector<double>& target_data);
void MapData(const std::vector<double> &source_data, std::vector<double> &target_data);
};
#endif
#endif // RL_SIM_HPP

View File

@ -14,7 +14,7 @@
<arg name="debug" default="false"/>
<!-- Debug mode will hung up the robot, use "true" or "false" to switch it. -->
<arg name="user_debug" default="false"/>
<include file="$(find gazebo_ros)/launch/empty_world.launch">
<arg name="world_name" value="$(find rl_sar)/worlds/$(arg wname).world"/>
<arg name="debug" value="$(arg debug)"/>
@ -26,7 +26,7 @@
<!-- Load the URDF into the ROS Parameter Server -->
<param name="robot_description"
command="$(find xacro)/xacro --inorder '$(arg dollar)$(arg robot_path)/xacro/robot.xacro'
command="$(find xacro)/xacro --inorder '$(arg dollar)$(arg robot_path)/xacro/robot.xacro'
DEBUG:=$(arg user_debug)"/>
<!-- Run a python script to the send a service call to gazebo_ros to spawn a URDF robot -->

View File

@ -14,7 +14,7 @@
<arg name="debug" default="false"/>
<!-- Debug mode will hung up the robot, use "true" or "false" to switch it. -->
<arg name="user_debug" default="false"/>
<include file="$(find gazebo_ros)/launch/empty_world.launch">
<arg name="world_name" value="$(find rl_sar)/worlds/$(arg wname).world"/>
<arg name="debug" value="$(arg debug)"/>
@ -26,7 +26,7 @@
<!-- Load the URDF into the ROS Parameter Server -->
<param name="robot_description"
command="$(find xacro)/xacro --inorder '$(arg dollar)$(arg robot_path)/xacro/robot.xacro'
command="$(find xacro)/xacro --inorder '$(arg dollar)$(arg robot_path)/xacro/robot.xacro'
DEBUG:=$(arg user_debug)"/>
<!-- Run a python script to the send a service call to gazebo_ros to spawn a URDF robot -->

View File

@ -12,7 +12,7 @@
<arg name="gui" default="true"/>
<arg name="headless" default="false"/>
<arg name="debug" default="false"/>
<include file="$(find gazebo_ros)/launch/empty_world.launch">
<arg name="world_name" value="$(find rl_sar)/worlds/$(arg wname).world"/>
<arg name="debug" value="$(arg debug)"/>

View File

@ -12,7 +12,7 @@
<arg name="gui" default="true"/>
<arg name="headless" default="false"/>
<arg name="debug" default="false"/>
<include file="$(find gazebo_ros)/launch/empty_world.launch">
<arg name="world_name" value="$(find rl_sar)/worlds/$(arg wname).world"/>
<arg name="debug" value="$(arg debug)"/>

View File

@ -14,19 +14,9 @@
class LoopFunc
{
private:
std::string _name;
double _period;
std::function<void()> _func;
int _bindCPU;
std::atomic<bool> _running;
std::mutex _mutex;
std::condition_variable _cv;
std::thread _thread;
public:
public:
LoopFunc(const std::string &name, double period, std::function<void()> func, int bindCPU = -1)
: _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {}
: _name(name), _period(period), _func(func), _bindCPU(bindCPU), _running(false) {}
void start()
{
@ -57,12 +47,22 @@ class LoopFunc
}
log("[Loop End] named: " + _name);
}
private:
private:
std::string _name;
double _period;
std::function<void()> _func;
int _bindCPU;
std::atomic<bool> _running;
std::mutex _mutex;
std::condition_variable _cv;
std::thread _thread;
void loop()
{
while (_running)
{
auto start = std::chrono::steady_clock::now();
while (_running)
{
auto start = std::chrono::steady_clock::now();
_func();
@ -72,7 +72,8 @@ class LoopFunc
if (sleepTime.count() > 0)
{
std::unique_lock<std::mutex> lock(_mutex);
if (_cv.wait_for(lock, sleepTime, [this]{ return !_running; }))
if (_cv.wait_for(lock, sleepTime, [this]
{ return !_running; }))
{
break;
}
@ -87,7 +88,7 @@ class LoopFunc
return stream.str();
}
void log(const std::string& message)
void log(const std::string &message)
{
static std::mutex logMutex;
std::lock_guard<std::mutex> lock(logMutex);
@ -108,4 +109,4 @@ class LoopFunc
}
};
#endif
#endif // LOOP_H

View File

@ -2,11 +2,11 @@
ObservationBuffer::ObservationBuffer() {}
ObservationBuffer::ObservationBuffer(int num_envs,
int num_obs,
int include_history_steps)
: num_envs(num_envs),
num_obs(num_obs),
ObservationBuffer::ObservationBuffer(int num_envs,
int num_obs,
int include_history_steps)
: num_envs(num_envs),
num_obs(num_obs),
include_history_steps(include_history_steps)
{
num_obs_total = num_obs * include_history_steps;
@ -16,7 +16,8 @@ ObservationBuffer::ObservationBuffer(int num_envs,
void ObservationBuffer::reset(std::vector<int> reset_idxs, torch::Tensor new_obs)
{
std::vector<torch::indexing::TensorIndex> indices;
for (int idx : reset_idxs) {
for (int idx : reset_idxs)
{
indices.push_back(torch::indexing::Slice(idx));
}
obs_buf.index_put_(indices, new_obs.repeat({1, include_history_steps}));

View File

@ -4,7 +4,8 @@
#include <torch/torch.h>
#include <vector>
class ObservationBuffer {
class ObservationBuffer
{
public:
ObservationBuffer(int num_envs, int num_obs, int include_history_steps);
ObservationBuffer();
@ -21,4 +22,4 @@ private:
torch::Tensor obs_buf;
};
#endif // OBSERVATION_BUFFER_HPP
#endif // OBSERVATION_BUFFER_HPP

View File

@ -15,34 +15,34 @@ torch::Tensor RL::ComputeObservation()
{
std::vector<torch::Tensor> obs_list;
for(const std::string& observation : this->params.observations)
for (const std::string &observation : this->params.observations)
{
if(observation == "lin_vel")
if (observation == "lin_vel")
{
obs_list.push_back(this->obs.lin_vel * this->params.lin_vel_scale);
}
else if(observation == "ang_vel")
else if (observation == "ang_vel")
{
// obs_list.push_back(this->obs.ang_vel * this->params.ang_vel_scale); // TODO is QuatRotateInverse necessery?
obs_list.push_back(this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel, this->params.framework) * this->params.ang_vel_scale);
}
else if(observation == "gravity_vec")
else if (observation == "gravity_vec")
{
obs_list.push_back(this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec, this->params.framework));
}
else if(observation == "commands")
else if (observation == "commands")
{
obs_list.push_back(this->obs.commands * this->params.commands_scale);
}
else if(observation == "dof_pos")
else if (observation == "dof_pos")
{
obs_list.push_back((this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale);
}
else if(observation == "dof_vel")
else if (observation == "dof_vel")
{
obs_list.push_back(this->obs.dof_vel * this->params.dof_vel_scale);
}
else if(observation == "actions")
else if (observation == "actions")
{
obs_list.push_back(this->obs.actions);
}
@ -92,22 +92,22 @@ torch::Tensor RL::ComputePosition(torch::Tensor actions)
return actions_scaled + this->params.default_dof_pos;
}
torch::Tensor RL::QuatRotateInverse(torch::Tensor q, torch::Tensor v, const std::string& framework)
torch::Tensor RL::QuatRotateInverse(torch::Tensor q, torch::Tensor v, const std::string &framework)
{
torch::Tensor q_w;
torch::Tensor q_vec;
if(framework == "isaacsim")
if (framework == "isaacsim")
{
q_w = q.index({torch::indexing::Slice(), 0});
q_vec = q.index({torch::indexing::Slice(), torch::indexing::Slice(1, 4)});
}
else if(framework == "isaacgym")
else if (framework == "isaacgym")
{
q_w = q.index({torch::indexing::Slice(), 3});
q_vec = q.index({torch::indexing::Slice(), torch::indexing::Slice(0, 3)});
}
c10::IntArrayRef shape = q.sizes();
torch::Tensor a = v * (2.0 * torch::pow(q_w, 2) - 1.0).unsqueeze(-1);
torch::Tensor b = torch::cross(q_vec, v, -1) * q_w.unsqueeze(-1) * 2.0;
torch::Tensor c = q_vec * torch::bmm(q_vec.view({shape[0], 1, 3}), v.view({shape[0], 3, 1})).squeeze(-1) * 2.0;
@ -122,17 +122,17 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
static float getdown_percent = 0.0;
// waiting
if(this->running_state == STATE_WAITING)
if (this->running_state == STATE_WAITING)
{
for(int i = 0; i < this->params.num_of_dofs; ++i)
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
command->motor_command.q[i] = state->motor_state.q[i];
}
if(this->control.control_state == STATE_POS_GETUP)
if (this->control.control_state == STATE_POS_GETUP)
{
this->control.control_state = STATE_WAITING;
getup_percent = 0.0;
for(int i = 0; i < this->params.num_of_dofs; ++i)
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
now_state.motor_state.q[i] = state->motor_state.q[i];
start_state.motor_state.q[i] = now_state.motor_state.q[i];
@ -142,13 +142,13 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
}
}
// stand up (position control)
else if(this->running_state == STATE_POS_GETUP)
else if (this->running_state == STATE_POS_GETUP)
{
if(getup_percent < 1.0)
if (getup_percent < 1.0)
{
getup_percent += 1 / 500.0;
getup_percent = getup_percent > 1.0 ? 1.0 : getup_percent;
for(int i = 0; i < this->params.num_of_dofs; ++i)
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
command->motor_command.q[i] = (1 - getup_percent) * now_state.motor_state.q[i] + getup_percent * this->params.default_dof_pos[0][i].item<double>();
command->motor_command.dq[i] = 0;
@ -158,17 +158,17 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
}
std::cout << "\r" << std::flush << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << std::flush;
}
if(this->control.control_state == STATE_RL_INIT)
if (this->control.control_state == STATE_RL_INIT)
{
this->control.control_state = STATE_WAITING;
this->running_state = STATE_RL_INIT;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_INIT" << std::endl;
}
else if(this->control.control_state == STATE_POS_GETDOWN)
else if (this->control.control_state == STATE_POS_GETDOWN)
{
this->control.control_state = STATE_WAITING;
getdown_percent = 0.0;
for(int i = 0; i < this->params.num_of_dofs; ++i)
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
now_state.motor_state.q[i] = state->motor_state.q[i];
}
@ -177,9 +177,9 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
}
}
// init obs and start rl loop
else if(this->running_state == STATE_RL_INIT)
else if (this->running_state == STATE_RL_INIT)
{
if(getup_percent == 1)
if (getup_percent == 1)
{
this->InitObservations();
this->InitOutputs();
@ -189,10 +189,10 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
}
}
// rl loop
else if(this->running_state == STATE_RL_RUNNING)
else if (this->running_state == STATE_RL_RUNNING)
{
std::cout << "\r" << std::flush << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << std::flush;
for(int i = 0; i < this->params.num_of_dofs; ++i)
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
command->motor_command.q[i] = this->output_dof_pos[0][i].item<double>();
command->motor_command.dq[i] = 0;
@ -200,22 +200,22 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
command->motor_command.kd[i] = this->params.rl_kd[0][i].item<double>();
command->motor_command.tau[i] = 0;
}
if(this->control.control_state == STATE_POS_GETDOWN)
if (this->control.control_state == STATE_POS_GETDOWN)
{
this->control.control_state = STATE_WAITING;
getdown_percent = 0.0;
for(int i = 0; i < this->params.num_of_dofs; ++i)
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
now_state.motor_state.q[i] = state->motor_state.q[i];
}
this->running_state = STATE_POS_GETDOWN;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl;
}
else if(this->control.control_state == STATE_POS_GETUP)
else if (this->control.control_state == STATE_POS_GETUP)
{
this->control.control_state = STATE_WAITING;
getup_percent = 0.0;
for(int i = 0; i < this->params.num_of_dofs; ++i)
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
now_state.motor_state.q[i] = state->motor_state.q[i];
}
@ -224,13 +224,13 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
}
}
// get down (position control)
else if(this->running_state == STATE_POS_GETDOWN)
else if (this->running_state == STATE_POS_GETDOWN)
{
if(getdown_percent < 1.0)
if (getdown_percent < 1.0)
{
getdown_percent += 1 / 500.0;
getdown_percent = getdown_percent > 1.0 ? 1.0 : getdown_percent;
for(int i = 0; i < this->params.num_of_dofs; ++i)
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
command->motor_command.q[i] = (1 - getdown_percent) * now_state.motor_state.q[i] + getdown_percent * start_state.motor_state.q[i];
command->motor_command.dq[i] = 0;
@ -240,7 +240,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
}
std::cout << "\r" << std::flush << LOGGER::INFO << "Getting down " << std::fixed << std::setprecision(2) << getdown_percent * 100.0 << std::flush;
}
if(getdown_percent == 1)
if (getdown_percent == 1)
{
this->InitObservations();
this->InitOutputs();
@ -255,28 +255,28 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques)
{
std::vector<int> out_of_range_indices;
std::vector<double> out_of_range_values;
for(int i = 0; i < origin_output_torques.size(1); ++i)
for (int i = 0; i < origin_output_torques.size(1); ++i)
{
double torque_value = origin_output_torques[0][i].item<double>();
double limit_lower = -this->params.torque_limits[0][i].item<double>();
double limit_upper = this->params.torque_limits[0][i].item<double>();
if(torque_value < limit_lower || torque_value > limit_upper)
if (torque_value < limit_lower || torque_value > limit_upper)
{
out_of_range_indices.push_back(i);
out_of_range_values.push_back(torque_value);
}
}
if(!out_of_range_indices.empty())
if (!out_of_range_indices.empty())
{
for(int i = 0; i < out_of_range_indices.size(); ++i)
for (int i = 0; i < out_of_range_indices.size(); ++i)
{
int index = out_of_range_indices[i];
double value = out_of_range_values[i];
double limit_lower = -this->params.torque_limits[0][index].item<double>();
double limit_upper = this->params.torque_limits[0][index].item<double>();
std::cout << LOGGER::WARNING << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
std::cout << LOGGER::WARNING << "Torque(" << index + 1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
}
// Just a reminder, no protection
// this->control.control_state = STATE_POS_GETDOWN;
@ -290,79 +290,109 @@ static bool kbhit()
{
termios term;
tcgetattr(0, &term);
termios term2 = term;
term2.c_lflag &= ~ICANON;
tcsetattr(0, TCSANOW, &term2);
int byteswaiting;
ioctl(0, FIONREAD, &byteswaiting);
tcsetattr(0, TCSANOW, &term);
return byteswaiting > 0;
}
void RL::KeyboardInterface()
{
if(kbhit())
if (kbhit())
{
int c = fgetc(stdin);
switch(c)
switch (c)
{
case '0': this->control.control_state = STATE_POS_GETUP; break;
case 'p': this->control.control_state = STATE_RL_INIT; break;
case '1': this->control.control_state = STATE_POS_GETDOWN; break;
case 'q': break;
case 'w': this->control.x += 0.1; break;
case 's': this->control.x -= 0.1; break;
case 'a': this->control.yaw += 0.1; break;
case 'd': this->control.yaw -= 0.1; break;
case 'i': break;
case 'k': break;
case 'j': this->control.y += 0.1; break;
case 'l': this->control.y -= 0.1; break;
case ' ': this->control.x = 0; this->control.y = 0; this->control.yaw = 0; break;
case 'r': this->control.control_state = STATE_RESET_SIMULATION; break;
case '\n': this->control.control_state = STATE_TOGGLE_SIMULATION; break;
default: break;
case '0':
this->control.control_state = STATE_POS_GETUP;
break;
case 'p':
this->control.control_state = STATE_RL_INIT;
break;
case '1':
this->control.control_state = STATE_POS_GETDOWN;
break;
case 'q':
break;
case 'w':
this->control.x += 0.1;
break;
case 's':
this->control.x -= 0.1;
break;
case 'a':
this->control.yaw += 0.1;
break;
case 'd':
this->control.yaw -= 0.1;
break;
case 'i':
break;
case 'k':
break;
case 'j':
this->control.y += 0.1;
break;
case 'l':
this->control.y -= 0.1;
break;
case ' ':
this->control.x = 0;
this->control.y = 0;
this->control.yaw = 0;
break;
case 'r':
this->control.control_state = STATE_RESET_SIMULATION;
break;
case '\n':
this->control.control_state = STATE_TOGGLE_SIMULATION;
break;
default:
break;
}
}
}
template<typename T>
std::vector<T> ReadVectorFromYaml(const YAML::Node& node)
template <typename T>
std::vector<T> ReadVectorFromYaml(const YAML::Node &node)
{
std::vector<T> values;
for(const auto& val : node)
for (const auto &val : node)
{
values.push_back(val.as<T>());
}
return values;
}
template<typename T>
std::vector<T> ReadVectorFromYaml(const YAML::Node& node, const std::string& framework, const int& rows, const int& cols)
template <typename T>
std::vector<T> ReadVectorFromYaml(const YAML::Node &node, const std::string &framework, const int &rows, const int &cols)
{
std::vector<T> values;
for(const auto& val : node)
for (const auto &val : node)
{
values.push_back(val.as<T>());
}
if(framework == "isaacsim")
if (framework == "isaacsim")
{
std::vector<T> transposed_values(cols * rows);
for(int r = 0; r < rows; ++r)
for (int r = 0; r < rows; ++r)
{
for(int c = 0; c < cols; ++c)
for (int c = 0; c < cols; ++c)
{
transposed_values[c * rows + r] = values[r * cols + c];
}
}
return transposed_values;
}
else if(framework == "isaacgym")
else if (framework == "isaacgym")
{
return values;
}
@ -380,7 +410,8 @@ void RL::ReadYaml(std::string robot_name)
try
{
config = YAML::LoadFile(config_path)[robot_name];
} catch(YAML::BadFile &e)
}
catch (YAML::BadFile &e)
{
std::cout << LOGGER::ERROR << "The file '" << config_path << "' does not exist" << std::endl;
return;
@ -396,7 +427,7 @@ void RL::ReadYaml(std::string robot_name)
this->params.num_observations = config["num_observations"].as<int>();
this->params.observations = ReadVectorFromYaml<std::string>(config["observations"]);
this->params.clip_obs = config["clip_obs"].as<double>();
if(config["clip_actions_lower"].IsNull() && config["clip_actions_upper"].IsNull())
if (config["clip_actions_lower"].IsNull() && config["clip_actions_upper"].IsNull())
{
this->params.clip_actions_upper = torch::tensor({}).view({1, -1});
this->params.clip_actions_lower = torch::tensor({}).view({1, -1});
@ -440,11 +471,11 @@ void RL::CSVInit(std::string robot_name)
csv_filename += ".csv";
std::ofstream file(csv_filename.c_str());
for(int i = 0; i < 12; ++i) {file << "tau_cal_" << i << ",";}
for(int i = 0; i < 12; ++i) {file << "tau_est_" << i << ",";}
for(int i = 0; i < 12; ++i) {file << "joint_pos_" << i << ",";}
for(int i = 0; i < 12; ++i) {file << "joint_pos_target_" << i << ",";}
for(int i = 0; i < 12; ++i) {file << "joint_vel_" << i << ",";}
for(int i = 0; i < 12; ++i) { file << "tau_cal_" << i << ","; }
for(int i = 0; i < 12; ++i) { file << "tau_est_" << i << ","; }
for(int i = 0; i < 12; ++i) { file << "joint_pos_" << i << ","; }
for(int i = 0; i < 12; ++i) { file << "joint_pos_target_" << i << ","; }
for(int i = 0; i < 12; ++i) { file << "joint_vel_" << i << ","; }
file << std::endl;
@ -455,13 +486,13 @@ void RL::CSVLogger(torch::Tensor torque, torch::Tensor tau_est, torch::Tensor jo
{
std::ofstream file(csv_filename.c_str(), std::ios_base::app);
for(int i = 0; i < 12; ++i) {file << torque[0][i].item<double>() << ",";}
for(int i = 0; i < 12; ++i) {file << tau_est[0][i].item<double>() << ",";}
for(int i = 0; i < 12; ++i) {file << joint_pos[0][i].item<double>() << ",";}
for(int i = 0; i < 12; ++i) {file << joint_pos_target[0][i].item<double>() << ",";}
for(int i = 0; i < 12; ++i) {file << joint_vel[0][i].item<double>() << ",";}
for(int i = 0; i < 12; ++i) { file << torque[0][i].item<double>() << ","; }
for(int i = 0; i < 12; ++i) { file << tau_est[0][i].item<double>() << ","; }
for(int i = 0; i < 12; ++i) { file << joint_pos[0][i].item<double>() << ","; }
for(int i = 0; i < 12; ++i) { file << joint_pos_target[0][i].item<double>() << ","; }
for(int i = 0; i < 12; ++i) { file << joint_vel[0][i].item<double>() << ","; }
file << std::endl;
file.close();
}
}

View File

@ -8,14 +8,15 @@
#include <yaml-cpp/yaml.h>
namespace LOGGER {
const char* const INFO = "\033[0;37m[INFO]\033[0m ";
const char* const WARNING = "\033[0;33m[WARNING]\033[0m ";
const char* const ERROR = "\033[0;31m[ERROR]\033[0m ";
const char* const DEBUG = "\033[0;32m[DEBUG]\033[0m ";
namespace LOGGER
{
const char *const INFO = "\033[0;37m[INFO]\033[0m ";
const char *const WARNING = "\033[0;33m[WARNING]\033[0m ";
const char *const ERROR = "\033[0;31m[ERROR]\033[0m ";
const char *const DEBUG = "\033[0;32m[DEBUG]\033[0m ";
}
template<typename T>
template <typename T>
struct RobotCommand
{
struct MotorCommand
@ -28,7 +29,7 @@ struct RobotCommand
} motor_command;
};
template<typename T>
template <typename T>
struct RobotState
{
struct IMU
@ -48,7 +49,8 @@ struct RobotState
} motor_state;
};
enum STATE {
enum STATE
{
STATE_WAITING = 0,
STATE_POS_GETUP,
STATE_RL_INIT,
@ -100,21 +102,21 @@ struct ModelParams
struct Observations
{
torch::Tensor lin_vel;
torch::Tensor ang_vel;
torch::Tensor gravity_vec;
torch::Tensor commands;
torch::Tensor base_quat;
torch::Tensor dof_pos;
torch::Tensor dof_vel;
torch::Tensor lin_vel;
torch::Tensor ang_vel;
torch::Tensor gravity_vec;
torch::Tensor commands;
torch::Tensor base_quat;
torch::Tensor dof_pos;
torch::Tensor dof_vel;
torch::Tensor actions;
};
class RL
{
public:
RL(){};
~RL(){};
RL() {};
~RL() {};
ModelParams params;
Observations obs;
@ -135,7 +137,7 @@ public:
void StateController(const RobotState<double> *state, RobotCommand<double> *command);
torch::Tensor ComputeTorques(torch::Tensor actions);
torch::Tensor ComputePosition(torch::Tensor actions);
torch::Tensor QuatRotateInverse(torch::Tensor q, torch::Tensor v, const std::string& framework);
torch::Tensor QuatRotateInverse(torch::Tensor q, torch::Tensor v, const std::string &framework);
// yaml params
void ReadYaml(std::string robot_name);
@ -165,4 +167,4 @@ protected:
torch::Tensor output_dof_pos;
};
#endif
#endif // RL_SDK_HPP

View File

@ -8,7 +8,7 @@
<license>TODO</license>
<buildtool_depend>catkin</buildtool_depend>
<buildtool_depend>catkin</buildtool_depend>
<buildtool_depend>genmsg</buildtool_depend>
<build_depend>controller_manager</build_depend>
<build_depend>joint_state_controller</build_depend>

View File

@ -95,7 +95,7 @@ def load_data(data_path):
for key in data_dict.keys():
data_dict[key] = np.array(data_dict[key]).T
return data_dict, num_motors
def process_data(data_dict, num_motors, step):
@ -122,7 +122,7 @@ def process_data(data_dict, num_motors, step):
xs_joint = torch.cat(xs_joint, dim=1)
xs.append(xs_joint)
ys.append(tau_ests_joint)
xs = torch.cat(xs, dim=0)
ys = torch.cat(ys, dim=0)
return xs, ys

View File

@ -23,7 +23,7 @@ class ObservationBuffer:
def get_obs_vec(self, obs_ids):
"""Gets history of observations indexed by obs_ids.
Arguments:
obs_ids: An array of integers with which to index the desired
observations, where 0 is the latest observation and

View File

@ -124,7 +124,7 @@ class RL:
self.robot_name = ""
self.running_state = STATE.STATE_RL_RUNNING # default running_state set to STATE_RL_RUNNING
self.simulation_running = False
### protected in cpp ###
# rl module
self.model = None
@ -156,7 +156,7 @@ class RL:
obs = torch.cat(obs_list, dim=-1)
clamped_obs = torch.clamp(obs, -self.params.clip_obs, self.params.clip_obs)
return clamped_obs
def InitObservations(self):
self.obs.lin_vel = torch.zeros(1, 3, dtype=torch.float)
self.obs.ang_vel = torch.zeros(1, 3, dtype=torch.float)
@ -409,35 +409,35 @@ class RL:
def CSVInit(self, robot_name):
self.csv_filename = os.path.join(BASE_PATH, "models", robot_name, 'motor')
# Uncomment these lines if need timestamp for file name
# now = datetime.now()
# timestamp = now.strftime("%Y%m%d%H%M%S")
# self.csv_filename += f"_{timestamp}"
self.csv_filename += ".csv"
with open(self.csv_filename, 'w', newline='') as file:
writer = csv.writer(file)
header = []
header += [f"tau_cal_{i}" for i in range(12)]
header += [f"tau_est_{i}" for i in range(12)]
header += [f"joint_pos_{i}" for i in range(12)]
header += [f"joint_pos_target_{i}" for i in range(12)]
header += [f"joint_vel_{i}" for i in range(12)]
writer.writerow(header)
def CSVLogger(self, torque, tau_est, joint_pos, joint_pos_target, joint_vel):
with open(self.csv_filename, 'a', newline='') as file:
writer = csv.writer(file)
row = []
row += [torque[0][i].item() for i in range(12)]
row += [tau_est[0][i].item() for i in range(12)]
row += [joint_pos[0][i].item() for i in range(12)]
row += [joint_pos_target[0][i].item() for i in range(12)]
row += [joint_vel[0][i].item() for i in range(12)]
writer.writerow(row)

View File

@ -234,4 +234,3 @@ class RL_Sim(RL):
if __name__ == "__main__":
rl_sim = RL_Sim()
rospy.spin()

View File

@ -28,10 +28,10 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
this->model = torch::jit::load(model_path);
// loop
this->loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend" , 0.002, std::bind(&RL_Real::UDPSend, this), 3);
this->loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv" , 0.002, std::bind(&RL_Real::UDPRecv, this), 3);
this->loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend", 0.002, std::bind(&RL_Real::UDPSend, this), 3);
this->loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv", 0.002, std::bind(&RL_Real::UDPRecv, this), 3);
this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05, std::bind(&RL_Real::KeyboardInterface, this));
this->loop_control = std::make_shared<LoopFunc>("loop_control", this->params.dt, std::bind(&RL_Real::RobotControl, this));
this->loop_control = std::make_shared<LoopFunc>("loop_control", this->params.dt, std::bind(&RL_Real::RobotControl, this));
this->loop_rl = std::make_shared<LoopFunc>("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Real::RunModel, this));
this->loop_udpSend->start();
this->loop_udpRecv->start();
@ -39,14 +39,13 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
this->loop_control->start();
this->loop_rl->start();
#ifdef PLOT
this->plot_t = std::vector<int>(this->plot_size, 0);
this->plot_real_joint_pos.resize(this->params.num_of_dofs);
this->plot_target_joint_pos.resize(this->params.num_of_dofs);
for(auto& vector : this->plot_real_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
for(auto& vector : this->plot_target_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
this->loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, std::bind(&RL_Real::Plot, this));
for (auto &vector : this->plot_real_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
for (auto &vector : this->plot_target_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
this->loop_plot = std::make_shared<LoopFunc>("loop_plot", 0.002, std::bind(&RL_Real::Plot, this));
this->loop_plot->start();
#endif
#ifdef CSV_LOGGER
@ -72,27 +71,27 @@ void RL_Real::GetState(RobotState<double> *state)
this->unitree_udp.GetRecv(this->unitree_low_state);
memcpy(&this->unitree_joy, this->unitree_low_state.wirelessRemote, 40);
if((int)this->unitree_joy.btn.components.R2 == 1)
if ((int)this->unitree_joy.btn.components.R2 == 1)
{
this->control.control_state = STATE_POS_GETUP;
}
else if((int)this->unitree_joy.btn.components.R1 == 1)
else if ((int)this->unitree_joy.btn.components.R1 == 1)
{
this->control.control_state = STATE_RL_INIT;
}
else if((int)this->unitree_joy.btn.components.L2 == 1)
else if ((int)this->unitree_joy.btn.components.L2 == 1)
{
this->control.control_state = STATE_POS_GETDOWN;
}
if(this->params.framework == "isaacgym")
if (this->params.framework == "isaacgym")
{
state->imu.quaternion[3] = this->unitree_low_state.imu.quaternion[0]; // w
state->imu.quaternion[0] = this->unitree_low_state.imu.quaternion[1]; // x
state->imu.quaternion[1] = this->unitree_low_state.imu.quaternion[2]; // y
state->imu.quaternion[2] = this->unitree_low_state.imu.quaternion[3]; // z
}
else if(this->params.framework == "isaacsim")
else if (this->params.framework == "isaacsim")
{
state->imu.quaternion[0] = this->unitree_low_state.imu.quaternion[0]; // w
state->imu.quaternion[1] = this->unitree_low_state.imu.quaternion[1]; // x
@ -100,11 +99,11 @@ void RL_Real::GetState(RobotState<double> *state)
state->imu.quaternion[3] = this->unitree_low_state.imu.quaternion[3]; // z
}
for(int i = 0; i < 3; ++i)
for (int i = 0; i < 3; ++i)
{
state->imu.gyroscope[i] = this->unitree_low_state.imu.gyroscope[i];
}
for(int i = 0; i < this->params.num_of_dofs; ++i)
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
state->motor_state.q[i] = this->unitree_low_state.motorState[state_mapping[i]].q;
state->motor_state.dq[i] = this->unitree_low_state.motorState[state_mapping[i]].dq;
@ -114,7 +113,7 @@ void RL_Real::GetState(RobotState<double> *state)
void RL_Real::SetCommand(const RobotCommand<double> *command)
{
for(int i = 0; i < this->params.num_of_dofs; ++i)
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
this->unitree_low_command.motorCmd[i].mode = 0x0A;
this->unitree_low_command.motorCmd[i].q = command->motor_command.q[command_mapping[i]];
@ -140,7 +139,7 @@ void RL_Real::RobotControl()
void RL_Real::RunModel()
{
if(this->running_state == STATE_RL_RUNNING)
if (this->running_state == STATE_RL_RUNNING)
{
this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0);
this->obs.commands = torch::tensor({{this->unitree_joy.ly, -this->unitree_joy.rx, -this->unitree_joy.lx}});
@ -182,7 +181,7 @@ torch::Tensor RL_Real::Forward()
torch::Tensor actions = this->model.forward({this->history_obs}).toTensor();
if(this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0)
if (this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0)
{
return torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
}
@ -198,7 +197,7 @@ void RL_Real::Plot()
this->plot_t.push_back(this->motiontime);
plt::cla();
plt::clf();
for(int i = 0; i < this->params.num_of_dofs; ++i)
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin());
this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin());
@ -222,10 +221,10 @@ int main(int argc, char **argv)
{
signal(SIGINT, signalHandler);
while(1)
while (1)
{
sleep(10);
};
}
return 0;
}

View File

@ -12,7 +12,7 @@ RL_Sim::RL_Sim()
this->ReadYaml(this->robot_name);
// history
if(this->params.use_history)
if (this->params.use_history)
{
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
}
@ -21,7 +21,7 @@ RL_Sim::RL_Sim()
// the mapping table is established according to the order defined in the YAML file
std::vector<std::string> sorted_joint_controller_names = this->params.joint_controller_names;
std::sort(sorted_joint_controller_names.begin(), sorted_joint_controller_names.end());
for(size_t i = 0; i < this->params.joint_controller_names.size(); ++i)
for (size_t i = 0; i < this->params.joint_controller_names.size(); ++i)
{
this->sorted_to_original_index[sorted_joint_controller_names[i]] = i;
}
@ -46,8 +46,8 @@ RL_Sim::RL_Sim()
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
// joint need to rename as xxx_joint
this->joint_publishers[this->params.joint_controller_names[i]] = nh.advertise<robot_msgs::MotorCommand>(
this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10);
this->joint_publishers[this->params.joint_controller_names[i]] =
nh.advertise<robot_msgs::MotorCommand>(this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10);
}
// subscriber
@ -62,7 +62,7 @@ RL_Sim::RL_Sim()
this->gazebo_unpause_physics_client = nh.serviceClient<std_srvs::Empty>("/gazebo/unpause_physics");
// loop
this->loop_control = std::make_shared<LoopFunc>("loop_control", this->params.dt, std::bind(&RL_Sim::RobotControl, this));
this->loop_control = std::make_shared<LoopFunc>("loop_control", this->params.dt, std::bind(&RL_Sim::RobotControl, this));
this->loop_rl = std::make_shared<LoopFunc>("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Sim::RunModel, this));
this->loop_control->start();
this->loop_rl->start();
@ -75,9 +75,9 @@ RL_Sim::RL_Sim()
this->plot_t = std::vector<int>(this->plot_size, 0);
this->plot_real_joint_pos.resize(this->params.num_of_dofs);
this->plot_target_joint_pos.resize(this->params.num_of_dofs);
for(auto& vector : this->plot_real_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
for(auto& vector : this->plot_target_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
this->loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.001 , std::bind(&RL_Sim::Plot , this));
for (auto &vector : this->plot_real_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
for (auto &vector : this->plot_target_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
this->loop_plot = std::make_shared<LoopFunc>("loop_plot", 0.001, std::bind(&RL_Sim::Plot, this));
this->loop_plot->start();
#endif
#ifdef CSV_LOGGER
@ -100,14 +100,14 @@ RL_Sim::~RL_Sim()
void RL_Sim::GetState(RobotState<double> *state)
{
if(this->params.framework == "isaacgym")
if (this->params.framework == "isaacgym")
{
state->imu.quaternion[3] = this->pose.orientation.w;
state->imu.quaternion[0] = this->pose.orientation.x;
state->imu.quaternion[1] = this->pose.orientation.y;
state->imu.quaternion[2] = this->pose.orientation.z;
}
else if(this->params.framework == "isaacsim")
else if (this->params.framework == "isaacsim")
{
state->imu.quaternion[0] = this->pose.orientation.w;
state->imu.quaternion[1] = this->pose.orientation.x;
@ -121,7 +121,7 @@ void RL_Sim::GetState(RobotState<double> *state)
// state->imu.accelerometer
for(int i = 0; i < this->params.num_of_dofs; ++i)
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
state->motor_state.q[i] = this->mapped_joint_positions[i];
state->motor_state.dq[i] = this->mapped_joint_velocities[i];
@ -131,7 +131,7 @@ void RL_Sim::GetState(RobotState<double> *state)
void RL_Sim::SetCommand(const RobotCommand<double> *command)
{
for(int i = 0; i < this->params.num_of_dofs; ++i)
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
this->joint_publishers_commands[i].q = command->motor_command.q[i];
this->joint_publishers_commands[i].dq = command->motor_command.dq[i];
@ -140,7 +140,7 @@ void RL_Sim::SetCommand(const RobotCommand<double> *command)
this->joint_publishers_commands[i].tau = command->motor_command.tau[i];
}
for(int i = 0; i < this->params.num_of_dofs; ++i)
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
this->joint_publishers[this->params.joint_controller_names[i]].publish(this->joint_publishers_commands[i]);
}
@ -148,7 +148,7 @@ void RL_Sim::SetCommand(const RobotCommand<double> *command)
void RL_Sim::RobotControl()
{
if(this->control.control_state == STATE_RESET_SIMULATION)
if (this->control.control_state == STATE_RESET_SIMULATION)
{
gazebo_msgs::SetModelState set_model_state;
set_model_state.request.model_state.model_name = this->gazebo_model_name;
@ -158,10 +158,10 @@ void RL_Sim::RobotControl()
this->control.control_state = STATE_WAITING;
}
if(this->control.control_state == STATE_TOGGLE_SIMULATION)
if (this->control.control_state == STATE_TOGGLE_SIMULATION)
{
std_srvs::Empty empty;
if(simulation_running)
if (simulation_running)
{
this->gazebo_pause_physics_client.call(empty);
std::cout << std::endl << LOGGER::INFO << "Simulation Stop" << std::endl;
@ -174,7 +174,7 @@ void RL_Sim::RobotControl()
simulation_running = !simulation_running;
this->control.control_state = STATE_WAITING;
}
if(simulation_running)
if (simulation_running)
{
this->motiontime++;
this->GetState(&this->robot_state);
@ -194,9 +194,9 @@ void RL_Sim::CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg)
this->cmd_vel = *msg;
}
void RL_Sim::MapData(const std::vector<double>& source_data, std::vector<double>& target_data)
void RL_Sim::MapData(const std::vector<double> &source_data, std::vector<double> &target_data)
{
for(size_t i = 0; i < source_data.size(); ++i)
for (size_t i = 0; i < source_data.size(); ++i)
{
target_data[i] = source_data[this->sorted_to_original_index[this->params.joint_controller_names[i]]];
}
@ -211,7 +211,7 @@ void RL_Sim::JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
void RL_Sim::RunModel()
{
if(this->running_state == STATE_RL_RUNNING && simulation_running)
if (this->running_state == STATE_RL_RUNNING && simulation_running)
{
this->obs.lin_vel = torch::tensor({{this->vel.linear.x, this->vel.linear.y, this->vel.linear.z}});
this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0);
@ -249,7 +249,7 @@ torch::Tensor RL_Sim::Forward()
torch::autograd::GradMode::set_enabled(false);
torch::Tensor clamped_obs = this->ComputeObservation();
torch::Tensor actions;
if(this->params.use_history)
if (this->params.use_history)
{
this->history_obs_buf.insert(clamped_obs);
this->history_obs = this->history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
@ -260,7 +260,7 @@ torch::Tensor RL_Sim::Forward()
actions = this->model.forward({clamped_obs}).toTensor();
}
if(this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0)
if (this->params.clip_actions_upper.numel() != 0 && this->params.clip_actions_lower.numel() != 0)
{
return torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
}
@ -276,13 +276,13 @@ void RL_Sim::Plot()
this->plot_t.push_back(this->motiontime);
plt::cla();
plt::clf();
for(int i = 0; i < this->params.num_of_dofs; ++i)
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin());
this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin());
this->plot_real_joint_pos[i].push_back(this->mapped_joint_positions[i]);
this->plot_target_joint_pos[i].push_back(this->joint_publishers_commands[i].q);
plt::subplot(4, 3, i+1);
plt::subplot(4, 3, i + 1);
plt::named_plot("_real_joint_pos", this->plot_t, this->plot_real_joint_pos[i], "r");
plt::named_plot("_target_joint_pos", this->plot_t, this->plot_target_joint_pos[i], "b");
plt::xlim(this->plot_t.front(), this->plot_t.back());

View File

@ -9,16 +9,16 @@
<gravity>0 0 -9.81</gravity>
<ode>
<solver>
<type>quick</type>
<iters>50</iters>
<type>quick</type>
<iters>50</iters>
<sor>1.3</sor>
</solver>
</solver>
<constraints>
<cfm>0.0</cfm>
<erp>0.2</erp>
<contact_max_correcting_vel>10.0</contact_max_correcting_vel>
<contact_surface_layer>0.001</contact_surface_layer>
</constraints>
</constraints>
</ode>
</physics>

View File

@ -20,10 +20,10 @@
#include <algorithm>
#include <math.h>
#define PosStopF (2.146E+9f)
#define VelStopF (16000.0f)
#define PosStopF (2.146E+9f)
#define VelStopF (16000.0f)
typedef struct
typedef struct
{
uint8_t mode;
double pos;
@ -35,15 +35,15 @@ typedef struct
namespace robot_joint_controller
{
class RobotJointController: public controller_interface::Controller<hardware_interface::EffortJointInterface>
class RobotJointController : public controller_interface::Controller<hardware_interface::EffortJointInterface>
{
private:
private:
hardware_interface::JointHandle joint;
ros::Subscriber sub_command, sub_ft;
control_toolbox::Pid pid_controller_;
std::unique_ptr<realtime_tools::RealtimePublisher<robot_msgs::MotorState> > controller_state_publisher_ ;
std::unique_ptr<realtime_tools::RealtimePublisher<robot_msgs::MotorState>> controller_state_publisher_;
public:
public:
std::string name_space;
std::string joint_name;
urdf::JointConstSharedPtr joint_urdf;
@ -55,10 +55,10 @@ public:
RobotJointController();
~RobotJointController();
virtual bool init(hardware_interface::EffortJointInterface *robot, ros::NodeHandle &n);
virtual void starting(const ros::Time& time);
virtual void update(const ros::Time& time, const ros::Duration& period);
virtual void starting(const ros::Time &time);
virtual void update(const ros::Time &time, const ros::Duration &period);
virtual void stopping();
void setCommandCB(const robot_msgs::MotorCommandConstPtr& msg);
void setCommandCB(const robot_msgs::MotorCommandConstPtr &msg);
void positionLimits(double &position);
void velocityLimits(double &velocity);
void effortLimits(double &effort);
@ -66,7 +66,6 @@ public:
void setGains(const double &p, const double &i, const double &d, const double &i_max, const double &i_min, const bool &antiwindup = false);
void getGains(double &p, double &i, double &d, double &i_max, double &i_min, bool &antiwindup);
void getGains(double &p, double &i, double &d, double &i_max, double &i_min);
};
}

View File

@ -3,30 +3,36 @@
// #define rqtTune // use rqt or not
double clamp(double& value, double min, double max) {
if (value < min) {
double clamp(double &value, double min, double max)
{
if (value < min)
{
value = min;
} else if (value > max) {
}
else if (value > max)
{
value = max;
}
return value;
}
namespace robot_joint_controller
namespace robot_joint_controller
{
RobotJointController::RobotJointController(){
RobotJointController::RobotJointController()
{
memset(&lastCommand, 0, sizeof(robot_msgs::MotorCommand));
memset(&lastState, 0, sizeof(robot_msgs::MotorState));
memset(&servoCommand, 0, sizeof(ServoCommand));
}
RobotJointController::~RobotJointController(){
RobotJointController::~RobotJointController()
{
sub_ft.shutdown();
sub_command.shutdown();
}
void RobotJointController::setCommandCB(const robot_msgs::MotorCommandConstPtr& msg)
void RobotJointController::setCommandCB(const robot_msgs::MotorCommandConstPtr &msg)
{
lastCommand.q = msg->q;
lastCommand.kp = msg->kp;
@ -43,28 +49,31 @@ namespace robot_joint_controller
bool RobotJointController::init(hardware_interface::EffortJointInterface *robot, ros::NodeHandle &n)
{
name_space = n.getNamespace();
if (!n.getParam("joint", joint_name)){
if (!n.getParam("joint", joint_name))
{
ROS_ERROR("No joint given in namespace: '%s')", n.getNamespace().c_str());
return false;
}
// load pid param from ymal only if rqt need
// load pid param from ymal only if rqt need
#ifdef rqtTune
// Load PID Controller using gains set on parameter server
if (!pid_controller_.init(ros::NodeHandle(n, "pid")))
return false;
// Load PID Controller using gains set on parameter server
if (!pid_controller_.init(ros::NodeHandle(n, "pid")))
return false;
#endif
urdf::Model urdf; // Get URDF info about joint
if (!urdf.initParamWithNodeHandle("robot_description", n)){
if (!urdf.initParamWithNodeHandle("robot_description", n))
{
ROS_ERROR("Failed to parse urdf file");
return false;
}
joint_urdf = urdf.getJoint(joint_name);
if (!joint_urdf){
if (!joint_urdf)
{
ROS_ERROR("Could not find joint '%s' in urdf", joint_name.c_str());
return false;
}
}
joint = robot->getHandle(joint_name);
// Start command subscriber
@ -72,29 +81,29 @@ namespace robot_joint_controller
// Start realtime state publisher
controller_state_publisher_.reset(
new realtime_tools::RealtimePublisher<robot_msgs::MotorState>(n, name_space + "/state", 1));
new realtime_tools::RealtimePublisher<robot_msgs::MotorState>(n, name_space + "/state", 1));
return true;
}
void RobotJointController::setGains(const double &p, const double &i, const double &d, const double &i_max, const double &i_min, const bool &antiwindup)
{
pid_controller_.setGains(p,i,d,i_max,i_min,antiwindup);
pid_controller_.setGains(p, i, d, i_max, i_min, antiwindup);
}
void RobotJointController::getGains(double &p, double &i, double &d, double &i_max, double &i_min, bool &antiwindup)
{
pid_controller_.getGains(p,i,d,i_max,i_min,antiwindup);
pid_controller_.getGains(p, i, d, i_max, i_min, antiwindup);
}
void RobotJointController::getGains(double &p, double &i, double &d, double &i_max, double &i_min)
{
bool dummy;
pid_controller_.getGains(p,i,d,i_max,i_min,dummy);
pid_controller_.getGains(p, i, d, i_max, i_min, dummy);
}
// Controller startup in realtime
void RobotJointController::starting(const ros::Time& time)
void RobotJointController::starting(const ros::Time &time)
{
double init_pos = joint.getPosition();
lastCommand.q = init_pos;
@ -109,7 +118,7 @@ namespace robot_joint_controller
}
// Controller update loop in realtime
void RobotJointController::update(const ros::Time& time, const ros::Duration& period)
void RobotJointController::update(const ros::Time &time, const ros::Duration &period)
{
double currentPos, currentVel, calcTorque;
lastCommand = *(command.readFromRT());
@ -118,27 +127,29 @@ namespace robot_joint_controller
servoCommand.pos = lastCommand.q;
positionLimits(servoCommand.pos);
servoCommand.posStiffness = lastCommand.kp;
if(fabs(lastCommand.q - PosStopF) < 0.00001){
if (fabs(lastCommand.q - PosStopF) < 0.00001)
{
servoCommand.posStiffness = 0;
}
servoCommand.vel = lastCommand.dq;
velocityLimits(servoCommand.vel);
servoCommand.velStiffness = lastCommand.kd;
if(fabs(lastCommand.dq - VelStopF) < 0.00001){
if (fabs(lastCommand.dq - VelStopF) < 0.00001)
{
servoCommand.velStiffness = 0;
}
servoCommand.torque = lastCommand.tau;
effortLimits(servoCommand.torque);
// rqt set P D gains
#ifdef rqtTune
double i, i_max, i_min;
getGains(servoCommand.posStiffness,i,servoCommand.velStiffness,i_max,i_min);
double i, i_max, i_min;
getGains(servoCommand.posStiffness, i, servoCommand.velStiffness, i_max, i_min);
#endif
currentPos = joint.getPosition();
// currentVel = computeVel(currentPos, (double)lastState.q, (double)lastState.dq, period.toSec());
// calcTorque = computeTorque(currentPos, currentVel, servoCommand);
// calcTorque = computeTorque(currentPos, currentVel, servoCommand);
currentVel = (currentPos - (double)lastState.q) / period.toSec();
calcTorque = servoCommand.posStiffness * (servoCommand.pos - currentPos) + servoCommand.velStiffness * (servoCommand.vel - currentVel) + servoCommand.torque;
effortLimits(calcTorque);
@ -151,7 +162,8 @@ namespace robot_joint_controller
lastState.tauEst = joint.getEffort();
// publish state
if (controller_state_publisher_ && controller_state_publisher_->trylock()) {
if (controller_state_publisher_ && controller_state_publisher_->trylock())
{
controller_state_publisher_->msg_.q = lastState.q;
controller_state_publisher_->msg_.dq = lastState.dq;
controller_state_publisher_->msg_.tauEst = lastState.tauEst;
@ -160,7 +172,7 @@ namespace robot_joint_controller
}
// Controller stopping in realtime
void RobotJointController::stopping(){}
void RobotJointController::stopping() {}
void RobotJointController::positionLimits(double &position)
{