fix: rename

This commit is contained in:
fan-ziqi 2024-03-14 13:11:01 +08:00
parent 2794537724
commit cbaef30296
31 changed files with 98 additions and 131 deletions

View File

@ -52,7 +52,7 @@ catkin build
## Running
Before running, copy the trained pt model file to `rl_sar/src/unitree_rl/models`
Before running, copy the trained pt model file to `rl_sar/src/rl_sar/models`
### Simulation
@ -60,14 +60,14 @@ Open a new terminal, launch the gazebo simulation environment
```bash
source devel/setup.bash
roslaunch unitree_rl start_env.launch
roslaunch rl_sar start_env.launch
```
Open a new terminal, run the control program
```bash
source devel/setup.bash
rosrun unitree_rl unitree_rl
rosrun rl_sar unitree_rl
```
Open a new terminal, run the keyboard control program
@ -82,7 +82,7 @@ Open a new terminal, run the control program
```bash
source devel/setup.bash
rosrun unitree_rl unitree_rl_real
rosrun rl_sar unitree_rl_real
```
> Some code references: https://github.com/mertgungor/unitree_model_control

View File

@ -52,7 +52,7 @@ catkin build
## 运行
运行前请将训练好的pt模型文件拷贝到`rl_sar/src/unitree_rl/models`中
运行前请将训练好的pt模型文件拷贝到`rl_sar/src/rl_sar/models`中
### 仿真
@ -60,14 +60,14 @@ catkin build
```bash
source devel/setup.bash
roslaunch unitree_rl start_env.launch
roslaunch rl_sar start_env.launch
```
新建终端,启动控制程序
```bash
source devel/setup.bash
rosrun unitree_rl unitree_rl
rosrun rl_sar unitree_rl
```
新建终端,键盘控制程序
@ -82,7 +82,7 @@ rosrun teleop_twist_keyboard teleop_twist_keyboard.py
```bash
source devel/setup.bash
rosrun unitree_rl unitree_rl_real
rosrun rl_sar unitree_rl_real
```
> 部分代码参考https://github.com/mertgungor/unitree_model_control

View File

@ -1,8 +1,10 @@
cmake_minimum_required(VERSION 3.0.2)
project(unitree_rl)
project(rl_sar)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
add_definitions(-DCMAKE_CURRENT_SOURCE_DIR="${CMAKE_CURRENT_SOURCE_DIR}")
find_package(Torch REQUIRED)
find_package(catkin REQUIRED COMPONENTS
@ -39,22 +41,22 @@ include_directories(
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GAZEBO_CXX_FLAGS}")
add_library(model library/model/model.cpp library/model/model.hpp)
target_link_libraries(model "${TORCH_LIBRARIES}")
set_property(TARGET model PROPERTY CXX_STANDARD 14)
add_library(rl library/rl/rl.cpp library/rl/rl.hpp)
target_link_libraries(rl "${TORCH_LIBRARIES}")
set_property(TARGET rl PROPERTY CXX_STANDARD 14)
add_library(observation_buffer library/observation_buffer/observation_buffer.cpp library/observation_buffer/observation_buffer.hpp)
target_link_libraries(observation_buffer "${TORCH_LIBRARIES}")
set_property(TARGET observation_buffer PROPERTY CXX_STANDARD 14)
add_executable(unitree_rl src/unitree_rl.cpp )
target_link_libraries(unitree_rl
add_executable(rl_sim src/rl_sim.cpp )
target_link_libraries(rl_sim
${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}"
model observation_buffer
rl observation_buffer
)
add_executable(unitree_rl_real src/unitree_rl_real.cpp )
target_link_libraries(unitree_rl_real
add_executable(rl_real src/rl_real.cpp )
target_link_libraries(rl_real
${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}"
model observation_buffer
rl observation_buffer
)

View File

@ -1,11 +1,7 @@
#ifndef UNITREE_RL
#define UNITREE_RL
#ifndef RL_REAL_HPP
#define RL_REAL_HPP
#include <ros/ros.h>
#include <gazebo_msgs/ModelStates.h>
#include <sensor_msgs/JointState.h>
#include <geometry_msgs/Twist.h>
#include "../library/model/model.hpp"
#include "../library/rl/rl.hpp"
#include "../library/observation_buffer/observation_buffer.hpp"
#include <unitree_legged_msgs/LowCmd.h>
#include "unitree_legged_msgs/LowState.h"
@ -16,13 +12,10 @@
using namespace UNITREE_LEGGED_SDK;
class Unitree_RL : public Model
class RL_Real : public RL
{
public:
Unitree_RL();
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
RL_Real();
void runModel();
torch::Tensor forward() override;
torch::Tensor compute_observation() override;
@ -60,35 +53,17 @@ private:
std::vector<std::string> torque_command_topics;
ros::Subscriber model_state_subscriber_;
ros::Subscriber joint_state_subscriber_;
ros::Subscriber cmd_vel_subscriber_;
std::map<std::string, ros::Publisher> torque_publishers;
std::vector<unitree_legged_msgs::MotorCmd> torque_commands;
geometry_msgs::Twist vel;
geometry_msgs::Pose pose;
geometry_msgs::Twist cmd_vel;
std::vector<std::string> joint_names;
std::vector<double> joint_positions;
std::vector<double> joint_velocities;
ros::Timer timer;
std::chrono::high_resolution_clock::time_point start_time;
// other rl module
torch::jit::script::Module encoder;
torch::jit::script::Module vq;
UNITREE_LEGGED_SDK::LowCmd SendLowLCM = {0};
UNITREE_LEGGED_SDK::LowState RecvLowLCM = {0};
unitree_legged_msgs::LowCmd SendLowROS;
unitree_legged_msgs::LowState RecvLowROS;
};
#endif

View File

@ -1,18 +1,18 @@
#ifndef UNITREE_RL
#define UNITREE_RL
#ifndef RL_SIM_HPP
#define RL_SIM_HPP
#include <ros/ros.h>
#include <gazebo_msgs/ModelStates.h>
#include <sensor_msgs/JointState.h>
#include <geometry_msgs/Twist.h>
#include "../library/model/model.hpp"
#include "../library/rl/rl.hpp"
#include "../library/observation_buffer/observation_buffer.hpp"
#include "unitree_legged_msgs/MotorCmd.h"
class Unitree_RL : public Model
class RL_Sim : public RL
{
public:
Unitree_RL();
RL_Sim();
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);

View File

@ -13,7 +13,7 @@
<arg name="user_debug" default="false"/>
<include file="$(find gazebo_ros)/launch/empty_world.launch">
<arg name="world_name" value="$(find unitree_rl)/worlds/$(arg wname).world"/>
<arg name="world_name" value="$(find rl_sar)/worlds/$(arg wname).world"/>
<arg name="debug" value="$(arg debug)"/>
<arg name="gui" value="$(arg gui)"/>
<arg name="paused" value="$(arg paused)"/>

View File

@ -1,6 +1,6 @@
#include "model.hpp"
#include "rl.hpp"
torch::Tensor Model::quat_rotate_inverse(torch::Tensor q, torch::Tensor v)
torch::Tensor RL::quat_rotate_inverse(torch::Tensor q, torch::Tensor v)
{
c10::IntArrayRef shape = q.sizes();
torch::Tensor q_w = q.index({torch::indexing::Slice(), -1});
@ -11,7 +11,19 @@ torch::Tensor Model::quat_rotate_inverse(torch::Tensor q, torch::Tensor v)
return a - b + c;
}
torch::Tensor Model::compute_torques(torch::Tensor actions)
void RL::init_observations()
{
this->obs.lin_vel = torch::tensor({{0.0, 0.0, 0.0}});
this->obs.ang_vel = torch::tensor({{0.0, 0.0, 0.0}});
this->obs.gravity_vec = torch::tensor({{0.0, 0.0, -1.0}});
this->obs.commands = torch::tensor({{0.0, 0.0, 0.0}});
this->obs.base_quat = torch::tensor({{0.0, 0.0, 0.0, 1.0}});
this->obs.dof_pos = this->params.default_dof_pos;
this->obs.dof_vel = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
this->obs.actions = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
}
torch::Tensor RL::compute_torques(torch::Tensor actions)
{
torch::Tensor actions_scaled = actions * this->params.action_scale;
int indices[] = {0, 3, 6, 9};
@ -25,10 +37,9 @@ torch::Tensor Model::compute_torques(torch::Tensor actions)
return clamped;
}
torch::Tensor Model::compute_observation()
/* You may need to override this compute_observation() function
torch::Tensor RL::compute_observation()
{
std::cout << "You may need to override this compute_observation() function" << std::endl;
torch::Tensor obs = torch::cat({(this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
(this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale,
this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec),
@ -40,27 +51,15 @@ torch::Tensor Model::compute_observation()
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
// printf("observation size: %d, %d\n", obs.sizes()[0], obs.sizes()[1]);
printf("observation size: %d, %d\n", obs.sizes()[0], obs.sizes()[1]);
return obs;
}
*/
void Model::init_observations()
/* You may need to override this forward() function
torch::Tensor RL::forward()
{
this->obs.lin_vel = torch::tensor({{0.0, 0.0, 0.0}});
this->obs.ang_vel = torch::tensor({{0.0, 0.0, 0.0}});
this->obs.gravity_vec = torch::tensor({{0.0, 0.0, -1.0}});
this->obs.commands = torch::tensor({{0.0, 0.0, 0.0}});
this->obs.base_quat = torch::tensor({{0.0, 0.0, 0.0, 1.0}});
this->obs.dof_pos = this->params.default_dof_pos;
this->obs.dof_vel = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
this->obs.actions = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
}
torch::Tensor Model::forward()
{
std::cout << "You may need to override this forward() function" << std::endl;
torch::Tensor obs = this->compute_observation();
torch::Tensor actor_input = torch::cat({obs}, 1);
@ -72,3 +71,4 @@ torch::Tensor Model::forward()
return clamped;
}
*/

View File

@ -1,5 +1,5 @@
#ifndef MODEL_HPP
#define MODEL_HPP
#ifndef RL_HPP
#define RL_HPP
#include <torch/script.h>
#include <iostream>
@ -36,14 +36,14 @@ struct Observations {
torch::Tensor actions;
};
class Model {
class RL {
public:
Model(){};
RL(){};
ModelParams params;
Observations obs;
virtual torch::Tensor forward();
virtual torch::Tensor compute_observation();
virtual torch::Tensor forward() = 0;
virtual torch::Tensor compute_observation() = 0;
torch::Tensor compute_torques(torch::Tensor actions);
torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v);
@ -63,4 +63,4 @@ protected:
torch::Tensor actions;
};
#endif // MODEL_HPP
#endif // RL_HPP

View File

@ -1,8 +1,8 @@
<?xml version="1.0"?>
<package format="2">
<name>unitree_rl</name>
<version>0.0.0</version>
<description>The unitree_rl package</description>
<name>rl_sar</name>
<version>2.0.0</version>
<description>The rl_sar package</description>
<maintainer email="fanziqi614@gmail.com">Ziqi Fan</maintainer>

View File

@ -1,17 +1,16 @@
#include "../include/unitree_rl_real.hpp"
#include <ros/package.h>
#include "../include/rl_real.hpp"
void Unitree_RL::UDPRecv()
void RL_Real::UDPRecv()
{
udp.Recv();
}
void Unitree_RL::UDPSend()
void RL_Real::UDPSend()
{
udp.Send();
}
void Unitree_RL::RobotControl()
void RL_Real::RobotControl()
{
motiontime++;
udp.GetRecv(state);
@ -82,20 +81,17 @@ void Unitree_RL::RobotControl()
udp.SetSend(cmd);
}
Unitree_RL::Unitree_RL() : safe(LeggedType::A1), udp(LOWLEVEL)
RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
{
udp.InitCmdData(cmd);
start_time = std::chrono::high_resolution_clock::now();
cmd_vel = geometry_msgs::Twist();
torque_commands.resize(12);
std::string package_name = "unitree_rl";
std::string actor_path = ros::package::getPath(package_name) + "/models/actor.pt";
std::string encoder_path = ros::package::getPath(package_name) + "/models/encoder.pt";
std::string vq_path = ros::package::getPath(package_name) + "/models/vq_layer.pt";
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt";
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt";
std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/vq_layer.pt";
this->actor = torch::jit::load(actor_path);
this->encoder = torch::jit::load(encoder_path);
@ -133,23 +129,18 @@ Unitree_RL::Unitree_RL() : safe(LeggedType::A1), udp(LOWLEVEL)
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
// InitEnvironment();
loop_control = std::make_shared<LoopFunc>("control_loop", 0.02 , boost::bind(&Unitree_RL::RobotControl, this));
loop_udpSend = std::make_shared<LoopFunc>("udp_send" , 0.002, 3, boost::bind(&Unitree_RL::UDPSend, this));
loop_udpRecv = std::make_shared<LoopFunc>("udp_recv" , 0.002, 3, boost::bind(&Unitree_RL::UDPRecv, this));
loop_rl = std::make_shared<LoopFunc>("rl_loop" , 0.02 , boost::bind(&Unitree_RL::runModel, this));
loop_control = std::make_shared<LoopFunc>("control_loop", 0.02 , boost::bind(&RL_Real::RobotControl, this));
loop_udpSend = std::make_shared<LoopFunc>("udp_send" , 0.002, 3, boost::bind(&RL_Real::UDPSend, this));
loop_udpRecv = std::make_shared<LoopFunc>("udp_recv" , 0.002, 3, boost::bind(&RL_Real::UDPRecv, this));
loop_rl = std::make_shared<LoopFunc>("rl_loop" , 0.02 , boost::bind(&RL_Real::runModel, this));
loop_udpSend->start();
loop_udpRecv->start();
loop_control->start();
}
void Unitree_RL::cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg)
{
cmd_vel = *msg;
}
// void Unitree_RL::runModel(const ros::TimerEvent &event)
void Unitree_RL::runModel()
// void RL_Real::runModel(const ros::TimerEvent &event)
void RL_Real::runModel()
{
if(init_done)
{
@ -163,7 +154,7 @@ void Unitree_RL::runModel()
// printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", state.motorState[FL_0].dq, state.motorState[FL_1].dq, state.motorState[FL_2].dq, state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq, state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq, state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq);
this->obs.ang_vel = torch::tensor({{state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]}});
this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}});
// this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}});
this->obs.base_quat = torch::tensor({{state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]}});
this->obs.dof_pos = torch::tensor({{state.motorState[FR_0].q, state.motorState[FR_1].q, state.motorState[FR_2].q,
state.motorState[FL_0].q, state.motorState[FL_1].q, state.motorState[FL_2].q,
@ -180,7 +171,7 @@ void Unitree_RL::runModel()
}
torch::Tensor Unitree_RL::compute_observation()
torch::Tensor RL_Real::compute_observation()
{
torch::Tensor ang_vel = this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel);
// float ang_vel_temp = ang_vel[0][0].item<double>();
@ -204,7 +195,7 @@ torch::Tensor Unitree_RL::compute_observation()
return obs;
}
torch::Tensor Unitree_RL::forward()
torch::Tensor RL_Real::forward()
{
torch::Tensor obs = this->compute_observation();
@ -229,7 +220,7 @@ torch::Tensor Unitree_RL::forward()
int main(int argc, char **argv)
{
Unitree_RL unitree_rl;
RL_Real rl_sar;
while(1){
sleep(10);

View File

@ -1,7 +1,7 @@
#include "../include/unitree_rl.hpp"
#include "../include/rl_sim.hpp"
#include <ros/package.h>
Unitree_RL::Unitree_RL()
RL_Sim::RL_Sim()
{
ros::NodeHandle nh;
start_time = std::chrono::high_resolution_clock::now();
@ -10,10 +10,9 @@ Unitree_RL::Unitree_RL()
torque_commands.resize(12);
std::string package_name = "unitree_rl";
std::string actor_path = ros::package::getPath(package_name) + "/models/actor.pt";
std::string encoder_path = ros::package::getPath(package_name) + "/models/encoder.pt";
std::string vq_path = ros::package::getPath(package_name) + "/models/vq_layer.pt";
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt";
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt";
std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/vq_layer.pt";
this->actor = torch::jit::load(actor_path);
this->encoder = torch::jit::load(encoder_path);
@ -51,9 +50,9 @@ Unitree_RL::Unitree_RL()
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>(
"/cmd_vel", 10, &Unitree_RL::cmdvelCallback, this);
"/cmd_vel", 10, &RL_Sim::cmdvelCallback, this);
timer = nh.createTimer(ros::Duration(0.005), &Unitree_RL::runModel, this);
timer = nh.createTimer(ros::Duration(0.005), &RL_Sim::runModel, this);
ros_namespace = "/a1_gazebo/";
@ -71,31 +70,31 @@ Unitree_RL::Unitree_RL()
}
model_state_subscriber_ = nh.subscribe<gazebo_msgs::ModelStates>(
"/gazebo/model_states", 10, &Unitree_RL::modelStatesCallback, this);
"/gazebo/model_states", 10, &RL_Sim::modelStatesCallback, this);
joint_state_subscriber_ = nh.subscribe<sensor_msgs::JointState>(
"/a1_gazebo/joint_states", 10, &Unitree_RL::jointStatesCallback, this);
"/a1_gazebo/joint_states", 10, &RL_Sim::jointStatesCallback, this);
}
void Unitree_RL::modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg)
void RL_Sim::modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg)
{
vel = msg->twist[2];
pose = msg->pose[2];
}
void Unitree_RL::cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg)
void RL_Sim::cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg)
{
cmd_vel = *msg;
}
void Unitree_RL::jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
void RL_Sim::jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
{
joint_positions = msg->position;
joint_velocities = msg->velocity;
}
void Unitree_RL::runModel(const ros::TimerEvent &event)
void RL_Sim::runModel(const ros::TimerEvent &event)
{
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
@ -126,7 +125,7 @@ void Unitree_RL::runModel(const ros::TimerEvent &event)
}
}
torch::Tensor Unitree_RL::compute_observation()
torch::Tensor RL_Sim::compute_observation()
{
torch::Tensor obs = torch::cat({// (this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
(this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale,
@ -140,7 +139,7 @@ torch::Tensor Unitree_RL::compute_observation()
return obs;
}
torch::Tensor Unitree_RL::forward()
torch::Tensor RL_Sim::forward()
{
torch::Tensor obs = this->compute_observation();
@ -163,8 +162,8 @@ torch::Tensor Unitree_RL::forward()
int main(int argc, char **argv)
{
ros::init(argc, argv, "unitree_rl");
Unitree_RL unitree_rl;
ros::init(argc, argv, "rl_sar");
RL_Sim rl_sar;
ros::spin();
return 0;
}