fix: rename

This commit is contained in:
fan-ziqi 2024-03-14 13:11:01 +08:00
parent 2794537724
commit cbaef30296
31 changed files with 98 additions and 131 deletions

View File

@ -52,7 +52,7 @@ catkin build
## Running ## Running
Before running, copy the trained pt model file to `rl_sar/src/unitree_rl/models` Before running, copy the trained pt model file to `rl_sar/src/rl_sar/models`
### Simulation ### Simulation
@ -60,14 +60,14 @@ Open a new terminal, launch the gazebo simulation environment
```bash ```bash
source devel/setup.bash source devel/setup.bash
roslaunch unitree_rl start_env.launch roslaunch rl_sar start_env.launch
``` ```
Open a new terminal, run the control program Open a new terminal, run the control program
```bash ```bash
source devel/setup.bash source devel/setup.bash
rosrun unitree_rl unitree_rl rosrun rl_sar unitree_rl
``` ```
Open a new terminal, run the keyboard control program Open a new terminal, run the keyboard control program
@ -82,7 +82,7 @@ Open a new terminal, run the control program
```bash ```bash
source devel/setup.bash source devel/setup.bash
rosrun unitree_rl unitree_rl_real rosrun rl_sar unitree_rl_real
``` ```
> Some code references: https://github.com/mertgungor/unitree_model_control > Some code references: https://github.com/mertgungor/unitree_model_control

View File

@ -52,7 +52,7 @@ catkin build
## 运行 ## 运行
运行前请将训练好的pt模型文件拷贝到`rl_sar/src/unitree_rl/models`中 运行前请将训练好的pt模型文件拷贝到`rl_sar/src/rl_sar/models`中
### 仿真 ### 仿真
@ -60,14 +60,14 @@ catkin build
```bash ```bash
source devel/setup.bash source devel/setup.bash
roslaunch unitree_rl start_env.launch roslaunch rl_sar start_env.launch
``` ```
新建终端,启动控制程序 新建终端,启动控制程序
```bash ```bash
source devel/setup.bash source devel/setup.bash
rosrun unitree_rl unitree_rl rosrun rl_sar unitree_rl
``` ```
新建终端,键盘控制程序 新建终端,键盘控制程序
@ -82,7 +82,7 @@ rosrun teleop_twist_keyboard teleop_twist_keyboard.py
```bash ```bash
source devel/setup.bash source devel/setup.bash
rosrun unitree_rl unitree_rl_real rosrun rl_sar unitree_rl_real
``` ```
> 部分代码参考https://github.com/mertgungor/unitree_model_control > 部分代码参考https://github.com/mertgungor/unitree_model_control

View File

@ -1,8 +1,10 @@
cmake_minimum_required(VERSION 3.0.2) cmake_minimum_required(VERSION 3.0.2)
project(unitree_rl) project(rl_sar)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
add_definitions(-DCMAKE_CURRENT_SOURCE_DIR="${CMAKE_CURRENT_SOURCE_DIR}")
find_package(Torch REQUIRED) find_package(Torch REQUIRED)
find_package(catkin REQUIRED COMPONENTS find_package(catkin REQUIRED COMPONENTS
@ -39,22 +41,22 @@ include_directories(
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GAZEBO_CXX_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GAZEBO_CXX_FLAGS}")
add_library(model library/model/model.cpp library/model/model.hpp) add_library(rl library/rl/rl.cpp library/rl/rl.hpp)
target_link_libraries(model "${TORCH_LIBRARIES}") target_link_libraries(rl "${TORCH_LIBRARIES}")
set_property(TARGET model PROPERTY CXX_STANDARD 14) set_property(TARGET rl PROPERTY CXX_STANDARD 14)
add_library(observation_buffer library/observation_buffer/observation_buffer.cpp library/observation_buffer/observation_buffer.hpp) add_library(observation_buffer library/observation_buffer/observation_buffer.cpp library/observation_buffer/observation_buffer.hpp)
target_link_libraries(observation_buffer "${TORCH_LIBRARIES}") target_link_libraries(observation_buffer "${TORCH_LIBRARIES}")
set_property(TARGET observation_buffer PROPERTY CXX_STANDARD 14) set_property(TARGET observation_buffer PROPERTY CXX_STANDARD 14)
add_executable(unitree_rl src/unitree_rl.cpp ) add_executable(rl_sim src/rl_sim.cpp )
target_link_libraries(unitree_rl target_link_libraries(rl_sim
${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}" ${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}"
model observation_buffer rl observation_buffer
) )
add_executable(unitree_rl_real src/unitree_rl_real.cpp ) add_executable(rl_real src/rl_real.cpp )
target_link_libraries(unitree_rl_real target_link_libraries(rl_real
${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}" ${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}"
model observation_buffer rl observation_buffer
) )

View File

@ -1,11 +1,7 @@
#ifndef UNITREE_RL #ifndef RL_REAL_HPP
#define UNITREE_RL #define RL_REAL_HPP
#include <ros/ros.h> #include "../library/rl/rl.hpp"
#include <gazebo_msgs/ModelStates.h>
#include <sensor_msgs/JointState.h>
#include <geometry_msgs/Twist.h>
#include "../library/model/model.hpp"
#include "../library/observation_buffer/observation_buffer.hpp" #include "../library/observation_buffer/observation_buffer.hpp"
#include <unitree_legged_msgs/LowCmd.h> #include <unitree_legged_msgs/LowCmd.h>
#include "unitree_legged_msgs/LowState.h" #include "unitree_legged_msgs/LowState.h"
@ -16,13 +12,10 @@
using namespace UNITREE_LEGGED_SDK; using namespace UNITREE_LEGGED_SDK;
class Unitree_RL : public Model class RL_Real : public RL
{ {
public: public:
Unitree_RL(); RL_Real();
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
void runModel(); void runModel();
torch::Tensor forward() override; torch::Tensor forward() override;
torch::Tensor compute_observation() override; torch::Tensor compute_observation() override;
@ -60,35 +53,17 @@ private:
std::vector<std::string> torque_command_topics; std::vector<std::string> torque_command_topics;
ros::Subscriber model_state_subscriber_;
ros::Subscriber joint_state_subscriber_;
ros::Subscriber cmd_vel_subscriber_;
std::map<std::string, ros::Publisher> torque_publishers;
std::vector<unitree_legged_msgs::MotorCmd> torque_commands; std::vector<unitree_legged_msgs::MotorCmd> torque_commands;
geometry_msgs::Twist vel;
geometry_msgs::Pose pose;
geometry_msgs::Twist cmd_vel;
std::vector<std::string> joint_names; std::vector<std::string> joint_names;
std::vector<double> joint_positions; std::vector<double> joint_positions;
std::vector<double> joint_velocities; std::vector<double> joint_velocities;
ros::Timer timer;
std::chrono::high_resolution_clock::time_point start_time; std::chrono::high_resolution_clock::time_point start_time;
// other rl module // other rl module
torch::jit::script::Module encoder; torch::jit::script::Module encoder;
torch::jit::script::Module vq; torch::jit::script::Module vq;
UNITREE_LEGGED_SDK::LowCmd SendLowLCM = {0};
UNITREE_LEGGED_SDK::LowState RecvLowLCM = {0};
unitree_legged_msgs::LowCmd SendLowROS;
unitree_legged_msgs::LowState RecvLowROS;
}; };
#endif #endif

View File

@ -1,18 +1,18 @@
#ifndef UNITREE_RL #ifndef RL_SIM_HPP
#define UNITREE_RL #define RL_SIM_HPP
#include <ros/ros.h> #include <ros/ros.h>
#include <gazebo_msgs/ModelStates.h> #include <gazebo_msgs/ModelStates.h>
#include <sensor_msgs/JointState.h> #include <sensor_msgs/JointState.h>
#include <geometry_msgs/Twist.h> #include <geometry_msgs/Twist.h>
#include "../library/model/model.hpp" #include "../library/rl/rl.hpp"
#include "../library/observation_buffer/observation_buffer.hpp" #include "../library/observation_buffer/observation_buffer.hpp"
#include "unitree_legged_msgs/MotorCmd.h" #include "unitree_legged_msgs/MotorCmd.h"
class Unitree_RL : public Model class RL_Sim : public RL
{ {
public: public:
Unitree_RL(); RL_Sim();
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg); void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg); void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg); void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);

View File

@ -13,7 +13,7 @@
<arg name="user_debug" default="false"/> <arg name="user_debug" default="false"/>
<include file="$(find gazebo_ros)/launch/empty_world.launch"> <include file="$(find gazebo_ros)/launch/empty_world.launch">
<arg name="world_name" value="$(find unitree_rl)/worlds/$(arg wname).world"/> <arg name="world_name" value="$(find rl_sar)/worlds/$(arg wname).world"/>
<arg name="debug" value="$(arg debug)"/> <arg name="debug" value="$(arg debug)"/>
<arg name="gui" value="$(arg gui)"/> <arg name="gui" value="$(arg gui)"/>
<arg name="paused" value="$(arg paused)"/> <arg name="paused" value="$(arg paused)"/>

View File

@ -1,6 +1,6 @@
#include "model.hpp" #include "rl.hpp"
torch::Tensor Model::quat_rotate_inverse(torch::Tensor q, torch::Tensor v) torch::Tensor RL::quat_rotate_inverse(torch::Tensor q, torch::Tensor v)
{ {
c10::IntArrayRef shape = q.sizes(); c10::IntArrayRef shape = q.sizes();
torch::Tensor q_w = q.index({torch::indexing::Slice(), -1}); torch::Tensor q_w = q.index({torch::indexing::Slice(), -1});
@ -11,7 +11,19 @@ torch::Tensor Model::quat_rotate_inverse(torch::Tensor q, torch::Tensor v)
return a - b + c; return a - b + c;
} }
torch::Tensor Model::compute_torques(torch::Tensor actions) void RL::init_observations()
{
this->obs.lin_vel = torch::tensor({{0.0, 0.0, 0.0}});
this->obs.ang_vel = torch::tensor({{0.0, 0.0, 0.0}});
this->obs.gravity_vec = torch::tensor({{0.0, 0.0, -1.0}});
this->obs.commands = torch::tensor({{0.0, 0.0, 0.0}});
this->obs.base_quat = torch::tensor({{0.0, 0.0, 0.0, 1.0}});
this->obs.dof_pos = this->params.default_dof_pos;
this->obs.dof_vel = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
this->obs.actions = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
}
torch::Tensor RL::compute_torques(torch::Tensor actions)
{ {
torch::Tensor actions_scaled = actions * this->params.action_scale; torch::Tensor actions_scaled = actions * this->params.action_scale;
int indices[] = {0, 3, 6, 9}; int indices[] = {0, 3, 6, 9};
@ -25,10 +37,9 @@ torch::Tensor Model::compute_torques(torch::Tensor actions)
return clamped; return clamped;
} }
torch::Tensor Model::compute_observation() /* You may need to override this compute_observation() function
torch::Tensor RL::compute_observation()
{ {
std::cout << "You may need to override this compute_observation() function" << std::endl;
torch::Tensor obs = torch::cat({(this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale, torch::Tensor obs = torch::cat({(this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
(this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale, (this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale,
this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec), this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec),
@ -40,27 +51,15 @@ torch::Tensor Model::compute_observation()
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs); obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
// printf("observation size: %d, %d\n", obs.sizes()[0], obs.sizes()[1]); printf("observation size: %d, %d\n", obs.sizes()[0], obs.sizes()[1]);
return obs; return obs;
} }
*/
void Model::init_observations() /* You may need to override this forward() function
torch::Tensor RL::forward()
{ {
this->obs.lin_vel = torch::tensor({{0.0, 0.0, 0.0}});
this->obs.ang_vel = torch::tensor({{0.0, 0.0, 0.0}});
this->obs.gravity_vec = torch::tensor({{0.0, 0.0, -1.0}});
this->obs.commands = torch::tensor({{0.0, 0.0, 0.0}});
this->obs.base_quat = torch::tensor({{0.0, 0.0, 0.0, 1.0}});
this->obs.dof_pos = this->params.default_dof_pos;
this->obs.dof_vel = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
this->obs.actions = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
}
torch::Tensor Model::forward()
{
std::cout << "You may need to override this forward() function" << std::endl;
torch::Tensor obs = this->compute_observation(); torch::Tensor obs = this->compute_observation();
torch::Tensor actor_input = torch::cat({obs}, 1); torch::Tensor actor_input = torch::cat({obs}, 1);
@ -72,3 +71,4 @@ torch::Tensor Model::forward()
return clamped; return clamped;
} }
*/

View File

@ -1,5 +1,5 @@
#ifndef MODEL_HPP #ifndef RL_HPP
#define MODEL_HPP #define RL_HPP
#include <torch/script.h> #include <torch/script.h>
#include <iostream> #include <iostream>
@ -36,14 +36,14 @@ struct Observations {
torch::Tensor actions; torch::Tensor actions;
}; };
class Model { class RL {
public: public:
Model(){}; RL(){};
ModelParams params; ModelParams params;
Observations obs; Observations obs;
virtual torch::Tensor forward(); virtual torch::Tensor forward() = 0;
virtual torch::Tensor compute_observation(); virtual torch::Tensor compute_observation() = 0;
torch::Tensor compute_torques(torch::Tensor actions); torch::Tensor compute_torques(torch::Tensor actions);
torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v); torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v);
@ -63,4 +63,4 @@ protected:
torch::Tensor actions; torch::Tensor actions;
}; };
#endif // MODEL_HPP #endif // RL_HPP

View File

@ -1,8 +1,8 @@
<?xml version="1.0"?> <?xml version="1.0"?>
<package format="2"> <package format="2">
<name>unitree_rl</name> <name>rl_sar</name>
<version>0.0.0</version> <version>2.0.0</version>
<description>The unitree_rl package</description> <description>The rl_sar package</description>
<maintainer email="fanziqi614@gmail.com">Ziqi Fan</maintainer> <maintainer email="fanziqi614@gmail.com">Ziqi Fan</maintainer>

View File

@ -1,17 +1,16 @@
#include "../include/unitree_rl_real.hpp" #include "../include/rl_real.hpp"
#include <ros/package.h>
void Unitree_RL::UDPRecv() void RL_Real::UDPRecv()
{ {
udp.Recv(); udp.Recv();
} }
void Unitree_RL::UDPSend() void RL_Real::UDPSend()
{ {
udp.Send(); udp.Send();
} }
void Unitree_RL::RobotControl() void RL_Real::RobotControl()
{ {
motiontime++; motiontime++;
udp.GetRecv(state); udp.GetRecv(state);
@ -82,20 +81,17 @@ void Unitree_RL::RobotControl()
udp.SetSend(cmd); udp.SetSend(cmd);
} }
Unitree_RL::Unitree_RL() : safe(LeggedType::A1), udp(LOWLEVEL) RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
{ {
udp.InitCmdData(cmd); udp.InitCmdData(cmd);
start_time = std::chrono::high_resolution_clock::now(); start_time = std::chrono::high_resolution_clock::now();
cmd_vel = geometry_msgs::Twist();
torque_commands.resize(12); torque_commands.resize(12);
std::string package_name = "unitree_rl"; std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt";
std::string actor_path = ros::package::getPath(package_name) + "/models/actor.pt"; std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt";
std::string encoder_path = ros::package::getPath(package_name) + "/models/encoder.pt"; std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/vq_layer.pt";
std::string vq_path = ros::package::getPath(package_name) + "/models/vq_layer.pt";
this->actor = torch::jit::load(actor_path); this->actor = torch::jit::load(actor_path);
this->encoder = torch::jit::load(encoder_path); this->encoder = torch::jit::load(encoder_path);
@ -133,23 +129,18 @@ Unitree_RL::Unitree_RL() : safe(LeggedType::A1), udp(LOWLEVEL)
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
// InitEnvironment(); // InitEnvironment();
loop_control = std::make_shared<LoopFunc>("control_loop", 0.02 , boost::bind(&Unitree_RL::RobotControl, this)); loop_control = std::make_shared<LoopFunc>("control_loop", 0.02 , boost::bind(&RL_Real::RobotControl, this));
loop_udpSend = std::make_shared<LoopFunc>("udp_send" , 0.002, 3, boost::bind(&Unitree_RL::UDPSend, this)); loop_udpSend = std::make_shared<LoopFunc>("udp_send" , 0.002, 3, boost::bind(&RL_Real::UDPSend, this));
loop_udpRecv = std::make_shared<LoopFunc>("udp_recv" , 0.002, 3, boost::bind(&Unitree_RL::UDPRecv, this)); loop_udpRecv = std::make_shared<LoopFunc>("udp_recv" , 0.002, 3, boost::bind(&RL_Real::UDPRecv, this));
loop_rl = std::make_shared<LoopFunc>("rl_loop" , 0.02 , boost::bind(&Unitree_RL::runModel, this)); loop_rl = std::make_shared<LoopFunc>("rl_loop" , 0.02 , boost::bind(&RL_Real::runModel, this));
loop_udpSend->start(); loop_udpSend->start();
loop_udpRecv->start(); loop_udpRecv->start();
loop_control->start(); loop_control->start();
} }
void Unitree_RL::cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg) // void RL_Real::runModel(const ros::TimerEvent &event)
{ void RL_Real::runModel()
cmd_vel = *msg;
}
// void Unitree_RL::runModel(const ros::TimerEvent &event)
void Unitree_RL::runModel()
{ {
if(init_done) if(init_done)
{ {
@ -163,7 +154,7 @@ void Unitree_RL::runModel()
// printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", state.motorState[FL_0].dq, state.motorState[FL_1].dq, state.motorState[FL_2].dq, state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq, state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq, state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq); // printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", state.motorState[FL_0].dq, state.motorState[FL_1].dq, state.motorState[FL_2].dq, state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq, state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq, state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq);
this->obs.ang_vel = torch::tensor({{state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]}}); this->obs.ang_vel = torch::tensor({{state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]}});
this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}}); // this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}});
this->obs.base_quat = torch::tensor({{state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]}}); this->obs.base_quat = torch::tensor({{state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]}});
this->obs.dof_pos = torch::tensor({{state.motorState[FR_0].q, state.motorState[FR_1].q, state.motorState[FR_2].q, this->obs.dof_pos = torch::tensor({{state.motorState[FR_0].q, state.motorState[FR_1].q, state.motorState[FR_2].q,
state.motorState[FL_0].q, state.motorState[FL_1].q, state.motorState[FL_2].q, state.motorState[FL_0].q, state.motorState[FL_1].q, state.motorState[FL_2].q,
@ -180,7 +171,7 @@ void Unitree_RL::runModel()
} }
torch::Tensor Unitree_RL::compute_observation() torch::Tensor RL_Real::compute_observation()
{ {
torch::Tensor ang_vel = this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel); torch::Tensor ang_vel = this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel);
// float ang_vel_temp = ang_vel[0][0].item<double>(); // float ang_vel_temp = ang_vel[0][0].item<double>();
@ -204,7 +195,7 @@ torch::Tensor Unitree_RL::compute_observation()
return obs; return obs;
} }
torch::Tensor Unitree_RL::forward() torch::Tensor RL_Real::forward()
{ {
torch::Tensor obs = this->compute_observation(); torch::Tensor obs = this->compute_observation();
@ -229,7 +220,7 @@ torch::Tensor Unitree_RL::forward()
int main(int argc, char **argv) int main(int argc, char **argv)
{ {
Unitree_RL unitree_rl; RL_Real rl_sar;
while(1){ while(1){
sleep(10); sleep(10);

View File

@ -1,7 +1,7 @@
#include "../include/unitree_rl.hpp" #include "../include/rl_sim.hpp"
#include <ros/package.h> #include <ros/package.h>
Unitree_RL::Unitree_RL() RL_Sim::RL_Sim()
{ {
ros::NodeHandle nh; ros::NodeHandle nh;
start_time = std::chrono::high_resolution_clock::now(); start_time = std::chrono::high_resolution_clock::now();
@ -10,10 +10,9 @@ Unitree_RL::Unitree_RL()
torque_commands.resize(12); torque_commands.resize(12);
std::string package_name = "unitree_rl"; std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt";
std::string actor_path = ros::package::getPath(package_name) + "/models/actor.pt"; std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt";
std::string encoder_path = ros::package::getPath(package_name) + "/models/encoder.pt"; std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/vq_layer.pt";
std::string vq_path = ros::package::getPath(package_name) + "/models/vq_layer.pt";
this->actor = torch::jit::load(actor_path); this->actor = torch::jit::load(actor_path);
this->encoder = torch::jit::load(encoder_path); this->encoder = torch::jit::load(encoder_path);
@ -51,9 +50,9 @@ Unitree_RL::Unitree_RL()
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>( cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>(
"/cmd_vel", 10, &Unitree_RL::cmdvelCallback, this); "/cmd_vel", 10, &RL_Sim::cmdvelCallback, this);
timer = nh.createTimer(ros::Duration(0.005), &Unitree_RL::runModel, this); timer = nh.createTimer(ros::Duration(0.005), &RL_Sim::runModel, this);
ros_namespace = "/a1_gazebo/"; ros_namespace = "/a1_gazebo/";
@ -71,31 +70,31 @@ Unitree_RL::Unitree_RL()
} }
model_state_subscriber_ = nh.subscribe<gazebo_msgs::ModelStates>( model_state_subscriber_ = nh.subscribe<gazebo_msgs::ModelStates>(
"/gazebo/model_states", 10, &Unitree_RL::modelStatesCallback, this); "/gazebo/model_states", 10, &RL_Sim::modelStatesCallback, this);
joint_state_subscriber_ = nh.subscribe<sensor_msgs::JointState>( joint_state_subscriber_ = nh.subscribe<sensor_msgs::JointState>(
"/a1_gazebo/joint_states", 10, &Unitree_RL::jointStatesCallback, this); "/a1_gazebo/joint_states", 10, &RL_Sim::jointStatesCallback, this);
} }
void Unitree_RL::modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg) void RL_Sim::modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg)
{ {
vel = msg->twist[2]; vel = msg->twist[2];
pose = msg->pose[2]; pose = msg->pose[2];
} }
void Unitree_RL::cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg) void RL_Sim::cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg)
{ {
cmd_vel = *msg; cmd_vel = *msg;
} }
void Unitree_RL::jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg) void RL_Sim::jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
{ {
joint_positions = msg->position; joint_positions = msg->position;
joint_velocities = msg->velocity; joint_velocities = msg->velocity;
} }
void Unitree_RL::runModel(const ros::TimerEvent &event) void RL_Sim::runModel(const ros::TimerEvent &event)
{ {
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count(); auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
// std::cout << "Execution time: " << duration << " microseconds" << std::endl; // std::cout << "Execution time: " << duration << " microseconds" << std::endl;
@ -126,7 +125,7 @@ void Unitree_RL::runModel(const ros::TimerEvent &event)
} }
} }
torch::Tensor Unitree_RL::compute_observation() torch::Tensor RL_Sim::compute_observation()
{ {
torch::Tensor obs = torch::cat({// (this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale, torch::Tensor obs = torch::cat({// (this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
(this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale, (this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale,
@ -140,7 +139,7 @@ torch::Tensor Unitree_RL::compute_observation()
return obs; return obs;
} }
torch::Tensor Unitree_RL::forward() torch::Tensor RL_Sim::forward()
{ {
torch::Tensor obs = this->compute_observation(); torch::Tensor obs = this->compute_observation();
@ -163,8 +162,8 @@ torch::Tensor Unitree_RL::forward()
int main(int argc, char **argv) int main(int argc, char **argv)
{ {
ros::init(argc, argv, "unitree_rl"); ros::init(argc, argv, "rl_sar");
Unitree_RL unitree_rl; RL_Sim rl_sar;
ros::spin(); ros::spin();
return 0; return 0;
} }