mirror of https://github.com/fan-ziqi/rl_sar.git
fix: rename
This commit is contained in:
parent
2794537724
commit
cbaef30296
|
@ -52,7 +52,7 @@ catkin build
|
||||||
|
|
||||||
## Running
|
## Running
|
||||||
|
|
||||||
Before running, copy the trained pt model file to `rl_sar/src/unitree_rl/models`
|
Before running, copy the trained pt model file to `rl_sar/src/rl_sar/models`
|
||||||
|
|
||||||
### Simulation
|
### Simulation
|
||||||
|
|
||||||
|
@ -60,14 +60,14 @@ Open a new terminal, launch the gazebo simulation environment
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
source devel/setup.bash
|
source devel/setup.bash
|
||||||
roslaunch unitree_rl start_env.launch
|
roslaunch rl_sar start_env.launch
|
||||||
```
|
```
|
||||||
|
|
||||||
Open a new terminal, run the control program
|
Open a new terminal, run the control program
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
source devel/setup.bash
|
source devel/setup.bash
|
||||||
rosrun unitree_rl unitree_rl
|
rosrun rl_sar unitree_rl
|
||||||
```
|
```
|
||||||
|
|
||||||
Open a new terminal, run the keyboard control program
|
Open a new terminal, run the keyboard control program
|
||||||
|
@ -82,7 +82,7 @@ Open a new terminal, run the control program
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
source devel/setup.bash
|
source devel/setup.bash
|
||||||
rosrun unitree_rl unitree_rl_real
|
rosrun rl_sar unitree_rl_real
|
||||||
```
|
```
|
||||||
|
|
||||||
> Some code references: https://github.com/mertgungor/unitree_model_control
|
> Some code references: https://github.com/mertgungor/unitree_model_control
|
|
@ -52,7 +52,7 @@ catkin build
|
||||||
|
|
||||||
## 运行
|
## 运行
|
||||||
|
|
||||||
运行前请将训练好的pt模型文件拷贝到`rl_sar/src/unitree_rl/models`中
|
运行前请将训练好的pt模型文件拷贝到`rl_sar/src/rl_sar/models`中
|
||||||
|
|
||||||
### 仿真
|
### 仿真
|
||||||
|
|
||||||
|
@ -60,14 +60,14 @@ catkin build
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
source devel/setup.bash
|
source devel/setup.bash
|
||||||
roslaunch unitree_rl start_env.launch
|
roslaunch rl_sar start_env.launch
|
||||||
```
|
```
|
||||||
|
|
||||||
新建终端,启动控制程序
|
新建终端,启动控制程序
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
source devel/setup.bash
|
source devel/setup.bash
|
||||||
rosrun unitree_rl unitree_rl
|
rosrun rl_sar unitree_rl
|
||||||
```
|
```
|
||||||
|
|
||||||
新建终端,键盘控制程序
|
新建终端,键盘控制程序
|
||||||
|
@ -82,7 +82,7 @@ rosrun teleop_twist_keyboard teleop_twist_keyboard.py
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
source devel/setup.bash
|
source devel/setup.bash
|
||||||
rosrun unitree_rl unitree_rl_real
|
rosrun rl_sar unitree_rl_real
|
||||||
```
|
```
|
||||||
|
|
||||||
> 部分代码参考https://github.com/mertgungor/unitree_model_control
|
> 部分代码参考https://github.com/mertgungor/unitree_model_control
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
cmake_minimum_required(VERSION 3.0.2)
|
cmake_minimum_required(VERSION 3.0.2)
|
||||||
project(unitree_rl)
|
project(rl_sar)
|
||||||
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
|
||||||
|
|
||||||
|
add_definitions(-DCMAKE_CURRENT_SOURCE_DIR="${CMAKE_CURRENT_SOURCE_DIR}")
|
||||||
|
|
||||||
find_package(Torch REQUIRED)
|
find_package(Torch REQUIRED)
|
||||||
|
|
||||||
find_package(catkin REQUIRED COMPONENTS
|
find_package(catkin REQUIRED COMPONENTS
|
||||||
|
@ -39,22 +41,22 @@ include_directories(
|
||||||
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GAZEBO_CXX_FLAGS}")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GAZEBO_CXX_FLAGS}")
|
||||||
|
|
||||||
add_library(model library/model/model.cpp library/model/model.hpp)
|
add_library(rl library/rl/rl.cpp library/rl/rl.hpp)
|
||||||
target_link_libraries(model "${TORCH_LIBRARIES}")
|
target_link_libraries(rl "${TORCH_LIBRARIES}")
|
||||||
set_property(TARGET model PROPERTY CXX_STANDARD 14)
|
set_property(TARGET rl PROPERTY CXX_STANDARD 14)
|
||||||
|
|
||||||
add_library(observation_buffer library/observation_buffer/observation_buffer.cpp library/observation_buffer/observation_buffer.hpp)
|
add_library(observation_buffer library/observation_buffer/observation_buffer.cpp library/observation_buffer/observation_buffer.hpp)
|
||||||
target_link_libraries(observation_buffer "${TORCH_LIBRARIES}")
|
target_link_libraries(observation_buffer "${TORCH_LIBRARIES}")
|
||||||
set_property(TARGET observation_buffer PROPERTY CXX_STANDARD 14)
|
set_property(TARGET observation_buffer PROPERTY CXX_STANDARD 14)
|
||||||
|
|
||||||
add_executable(unitree_rl src/unitree_rl.cpp )
|
add_executable(rl_sim src/rl_sim.cpp )
|
||||||
target_link_libraries(unitree_rl
|
target_link_libraries(rl_sim
|
||||||
${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}"
|
${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}"
|
||||||
model observation_buffer
|
rl observation_buffer
|
||||||
)
|
)
|
||||||
|
|
||||||
add_executable(unitree_rl_real src/unitree_rl_real.cpp )
|
add_executable(rl_real src/rl_real.cpp )
|
||||||
target_link_libraries(unitree_rl_real
|
target_link_libraries(rl_real
|
||||||
${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}"
|
${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}"
|
||||||
model observation_buffer
|
rl observation_buffer
|
||||||
)
|
)
|
|
@ -1,11 +1,7 @@
|
||||||
#ifndef UNITREE_RL
|
#ifndef RL_REAL_HPP
|
||||||
#define UNITREE_RL
|
#define RL_REAL_HPP
|
||||||
|
|
||||||
#include <ros/ros.h>
|
#include "../library/rl/rl.hpp"
|
||||||
#include <gazebo_msgs/ModelStates.h>
|
|
||||||
#include <sensor_msgs/JointState.h>
|
|
||||||
#include <geometry_msgs/Twist.h>
|
|
||||||
#include "../library/model/model.hpp"
|
|
||||||
#include "../library/observation_buffer/observation_buffer.hpp"
|
#include "../library/observation_buffer/observation_buffer.hpp"
|
||||||
#include <unitree_legged_msgs/LowCmd.h>
|
#include <unitree_legged_msgs/LowCmd.h>
|
||||||
#include "unitree_legged_msgs/LowState.h"
|
#include "unitree_legged_msgs/LowState.h"
|
||||||
|
@ -16,13 +12,10 @@
|
||||||
|
|
||||||
using namespace UNITREE_LEGGED_SDK;
|
using namespace UNITREE_LEGGED_SDK;
|
||||||
|
|
||||||
class Unitree_RL : public Model
|
class RL_Real : public RL
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Unitree_RL();
|
RL_Real();
|
||||||
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
|
|
||||||
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
|
|
||||||
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
|
|
||||||
void runModel();
|
void runModel();
|
||||||
torch::Tensor forward() override;
|
torch::Tensor forward() override;
|
||||||
torch::Tensor compute_observation() override;
|
torch::Tensor compute_observation() override;
|
||||||
|
@ -60,35 +53,17 @@ private:
|
||||||
|
|
||||||
std::vector<std::string> torque_command_topics;
|
std::vector<std::string> torque_command_topics;
|
||||||
|
|
||||||
ros::Subscriber model_state_subscriber_;
|
|
||||||
ros::Subscriber joint_state_subscriber_;
|
|
||||||
ros::Subscriber cmd_vel_subscriber_;
|
|
||||||
|
|
||||||
std::map<std::string, ros::Publisher> torque_publishers;
|
|
||||||
std::vector<unitree_legged_msgs::MotorCmd> torque_commands;
|
std::vector<unitree_legged_msgs::MotorCmd> torque_commands;
|
||||||
|
|
||||||
geometry_msgs::Twist vel;
|
|
||||||
geometry_msgs::Pose pose;
|
|
||||||
geometry_msgs::Twist cmd_vel;
|
|
||||||
|
|
||||||
std::vector<std::string> joint_names;
|
std::vector<std::string> joint_names;
|
||||||
std::vector<double> joint_positions;
|
std::vector<double> joint_positions;
|
||||||
std::vector<double> joint_velocities;
|
std::vector<double> joint_velocities;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
ros::Timer timer;
|
|
||||||
|
|
||||||
std::chrono::high_resolution_clock::time_point start_time;
|
std::chrono::high_resolution_clock::time_point start_time;
|
||||||
|
|
||||||
// other rl module
|
// other rl module
|
||||||
torch::jit::script::Module encoder;
|
torch::jit::script::Module encoder;
|
||||||
torch::jit::script::Module vq;
|
torch::jit::script::Module vq;
|
||||||
|
|
||||||
UNITREE_LEGGED_SDK::LowCmd SendLowLCM = {0};
|
|
||||||
UNITREE_LEGGED_SDK::LowState RecvLowLCM = {0};
|
|
||||||
unitree_legged_msgs::LowCmd SendLowROS;
|
|
||||||
unitree_legged_msgs::LowState RecvLowROS;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif
|
#endif
|
|
@ -1,18 +1,18 @@
|
||||||
#ifndef UNITREE_RL
|
#ifndef RL_SIM_HPP
|
||||||
#define UNITREE_RL
|
#define RL_SIM_HPP
|
||||||
|
|
||||||
#include <ros/ros.h>
|
#include <ros/ros.h>
|
||||||
#include <gazebo_msgs/ModelStates.h>
|
#include <gazebo_msgs/ModelStates.h>
|
||||||
#include <sensor_msgs/JointState.h>
|
#include <sensor_msgs/JointState.h>
|
||||||
#include <geometry_msgs/Twist.h>
|
#include <geometry_msgs/Twist.h>
|
||||||
#include "../library/model/model.hpp"
|
#include "../library/rl/rl.hpp"
|
||||||
#include "../library/observation_buffer/observation_buffer.hpp"
|
#include "../library/observation_buffer/observation_buffer.hpp"
|
||||||
#include "unitree_legged_msgs/MotorCmd.h"
|
#include "unitree_legged_msgs/MotorCmd.h"
|
||||||
|
|
||||||
class Unitree_RL : public Model
|
class RL_Sim : public RL
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
Unitree_RL();
|
RL_Sim();
|
||||||
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
|
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
|
||||||
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
|
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
|
||||||
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
|
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
|
|
@ -13,7 +13,7 @@
|
||||||
<arg name="user_debug" default="false"/>
|
<arg name="user_debug" default="false"/>
|
||||||
|
|
||||||
<include file="$(find gazebo_ros)/launch/empty_world.launch">
|
<include file="$(find gazebo_ros)/launch/empty_world.launch">
|
||||||
<arg name="world_name" value="$(find unitree_rl)/worlds/$(arg wname).world"/>
|
<arg name="world_name" value="$(find rl_sar)/worlds/$(arg wname).world"/>
|
||||||
<arg name="debug" value="$(arg debug)"/>
|
<arg name="debug" value="$(arg debug)"/>
|
||||||
<arg name="gui" value="$(arg gui)"/>
|
<arg name="gui" value="$(arg gui)"/>
|
||||||
<arg name="paused" value="$(arg paused)"/>
|
<arg name="paused" value="$(arg paused)"/>
|
|
@ -1,6 +1,6 @@
|
||||||
#include "model.hpp"
|
#include "rl.hpp"
|
||||||
|
|
||||||
torch::Tensor Model::quat_rotate_inverse(torch::Tensor q, torch::Tensor v)
|
torch::Tensor RL::quat_rotate_inverse(torch::Tensor q, torch::Tensor v)
|
||||||
{
|
{
|
||||||
c10::IntArrayRef shape = q.sizes();
|
c10::IntArrayRef shape = q.sizes();
|
||||||
torch::Tensor q_w = q.index({torch::indexing::Slice(), -1});
|
torch::Tensor q_w = q.index({torch::indexing::Slice(), -1});
|
||||||
|
@ -11,7 +11,19 @@ torch::Tensor Model::quat_rotate_inverse(torch::Tensor q, torch::Tensor v)
|
||||||
return a - b + c;
|
return a - b + c;
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor Model::compute_torques(torch::Tensor actions)
|
void RL::init_observations()
|
||||||
|
{
|
||||||
|
this->obs.lin_vel = torch::tensor({{0.0, 0.0, 0.0}});
|
||||||
|
this->obs.ang_vel = torch::tensor({{0.0, 0.0, 0.0}});
|
||||||
|
this->obs.gravity_vec = torch::tensor({{0.0, 0.0, -1.0}});
|
||||||
|
this->obs.commands = torch::tensor({{0.0, 0.0, 0.0}});
|
||||||
|
this->obs.base_quat = torch::tensor({{0.0, 0.0, 0.0, 1.0}});
|
||||||
|
this->obs.dof_pos = this->params.default_dof_pos;
|
||||||
|
this->obs.dof_vel = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||||
|
this->obs.actions = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||||
|
}
|
||||||
|
|
||||||
|
torch::Tensor RL::compute_torques(torch::Tensor actions)
|
||||||
{
|
{
|
||||||
torch::Tensor actions_scaled = actions * this->params.action_scale;
|
torch::Tensor actions_scaled = actions * this->params.action_scale;
|
||||||
int indices[] = {0, 3, 6, 9};
|
int indices[] = {0, 3, 6, 9};
|
||||||
|
@ -25,10 +37,9 @@ torch::Tensor Model::compute_torques(torch::Tensor actions)
|
||||||
return clamped;
|
return clamped;
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor Model::compute_observation()
|
/* You may need to override this compute_observation() function
|
||||||
|
torch::Tensor RL::compute_observation()
|
||||||
{
|
{
|
||||||
std::cout << "You may need to override this compute_observation() function" << std::endl;
|
|
||||||
|
|
||||||
torch::Tensor obs = torch::cat({(this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
|
torch::Tensor obs = torch::cat({(this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
|
||||||
(this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale,
|
(this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale,
|
||||||
this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec),
|
this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec),
|
||||||
|
@ -40,27 +51,15 @@ torch::Tensor Model::compute_observation()
|
||||||
|
|
||||||
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
|
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
|
||||||
|
|
||||||
// printf("observation size: %d, %d\n", obs.sizes()[0], obs.sizes()[1]);
|
printf("observation size: %d, %d\n", obs.sizes()[0], obs.sizes()[1]);
|
||||||
|
|
||||||
return obs;
|
return obs;
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
void Model::init_observations()
|
/* You may need to override this forward() function
|
||||||
|
torch::Tensor RL::forward()
|
||||||
{
|
{
|
||||||
this->obs.lin_vel = torch::tensor({{0.0, 0.0, 0.0}});
|
|
||||||
this->obs.ang_vel = torch::tensor({{0.0, 0.0, 0.0}});
|
|
||||||
this->obs.gravity_vec = torch::tensor({{0.0, 0.0, -1.0}});
|
|
||||||
this->obs.commands = torch::tensor({{0.0, 0.0, 0.0}});
|
|
||||||
this->obs.base_quat = torch::tensor({{0.0, 0.0, 0.0, 1.0}});
|
|
||||||
this->obs.dof_pos = this->params.default_dof_pos;
|
|
||||||
this->obs.dof_vel = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
|
||||||
this->obs.actions = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
|
||||||
}
|
|
||||||
|
|
||||||
torch::Tensor Model::forward()
|
|
||||||
{
|
|
||||||
std::cout << "You may need to override this forward() function" << std::endl;
|
|
||||||
|
|
||||||
torch::Tensor obs = this->compute_observation();
|
torch::Tensor obs = this->compute_observation();
|
||||||
|
|
||||||
torch::Tensor actor_input = torch::cat({obs}, 1);
|
torch::Tensor actor_input = torch::cat({obs}, 1);
|
||||||
|
@ -72,3 +71,4 @@ torch::Tensor Model::forward()
|
||||||
|
|
||||||
return clamped;
|
return clamped;
|
||||||
}
|
}
|
||||||
|
*/
|
|
@ -1,5 +1,5 @@
|
||||||
#ifndef MODEL_HPP
|
#ifndef RL_HPP
|
||||||
#define MODEL_HPP
|
#define RL_HPP
|
||||||
|
|
||||||
#include <torch/script.h>
|
#include <torch/script.h>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
@ -36,14 +36,14 @@ struct Observations {
|
||||||
torch::Tensor actions;
|
torch::Tensor actions;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Model {
|
class RL {
|
||||||
public:
|
public:
|
||||||
Model(){};
|
RL(){};
|
||||||
ModelParams params;
|
ModelParams params;
|
||||||
Observations obs;
|
Observations obs;
|
||||||
|
|
||||||
virtual torch::Tensor forward();
|
virtual torch::Tensor forward() = 0;
|
||||||
virtual torch::Tensor compute_observation();
|
virtual torch::Tensor compute_observation() = 0;
|
||||||
|
|
||||||
torch::Tensor compute_torques(torch::Tensor actions);
|
torch::Tensor compute_torques(torch::Tensor actions);
|
||||||
torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v);
|
torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v);
|
||||||
|
@ -63,4 +63,4 @@ protected:
|
||||||
torch::Tensor actions;
|
torch::Tensor actions;
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // MODEL_HPP
|
#endif // RL_HPP
|
|
@ -1,8 +1,8 @@
|
||||||
<?xml version="1.0"?>
|
<?xml version="1.0"?>
|
||||||
<package format="2">
|
<package format="2">
|
||||||
<name>unitree_rl</name>
|
<name>rl_sar</name>
|
||||||
<version>0.0.0</version>
|
<version>2.0.0</version>
|
||||||
<description>The unitree_rl package</description>
|
<description>The rl_sar package</description>
|
||||||
|
|
||||||
<maintainer email="fanziqi614@gmail.com">Ziqi Fan</maintainer>
|
<maintainer email="fanziqi614@gmail.com">Ziqi Fan</maintainer>
|
||||||
|
|
|
@ -1,17 +1,16 @@
|
||||||
#include "../include/unitree_rl_real.hpp"
|
#include "../include/rl_real.hpp"
|
||||||
#include <ros/package.h>
|
|
||||||
|
|
||||||
void Unitree_RL::UDPRecv()
|
void RL_Real::UDPRecv()
|
||||||
{
|
{
|
||||||
udp.Recv();
|
udp.Recv();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Unitree_RL::UDPSend()
|
void RL_Real::UDPSend()
|
||||||
{
|
{
|
||||||
udp.Send();
|
udp.Send();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Unitree_RL::RobotControl()
|
void RL_Real::RobotControl()
|
||||||
{
|
{
|
||||||
motiontime++;
|
motiontime++;
|
||||||
udp.GetRecv(state);
|
udp.GetRecv(state);
|
||||||
|
@ -82,20 +81,17 @@ void Unitree_RL::RobotControl()
|
||||||
udp.SetSend(cmd);
|
udp.SetSend(cmd);
|
||||||
}
|
}
|
||||||
|
|
||||||
Unitree_RL::Unitree_RL() : safe(LeggedType::A1), udp(LOWLEVEL)
|
RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
|
||||||
{
|
{
|
||||||
udp.InitCmdData(cmd);
|
udp.InitCmdData(cmd);
|
||||||
|
|
||||||
start_time = std::chrono::high_resolution_clock::now();
|
start_time = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
cmd_vel = geometry_msgs::Twist();
|
|
||||||
|
|
||||||
torque_commands.resize(12);
|
torque_commands.resize(12);
|
||||||
|
|
||||||
std::string package_name = "unitree_rl";
|
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt";
|
||||||
std::string actor_path = ros::package::getPath(package_name) + "/models/actor.pt";
|
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt";
|
||||||
std::string encoder_path = ros::package::getPath(package_name) + "/models/encoder.pt";
|
std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/vq_layer.pt";
|
||||||
std::string vq_path = ros::package::getPath(package_name) + "/models/vq_layer.pt";
|
|
||||||
|
|
||||||
this->actor = torch::jit::load(actor_path);
|
this->actor = torch::jit::load(actor_path);
|
||||||
this->encoder = torch::jit::load(encoder_path);
|
this->encoder = torch::jit::load(encoder_path);
|
||||||
|
@ -133,23 +129,18 @@ Unitree_RL::Unitree_RL() : safe(LeggedType::A1), udp(LOWLEVEL)
|
||||||
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
||||||
|
|
||||||
// InitEnvironment();
|
// InitEnvironment();
|
||||||
loop_control = std::make_shared<LoopFunc>("control_loop", 0.02 , boost::bind(&Unitree_RL::RobotControl, this));
|
loop_control = std::make_shared<LoopFunc>("control_loop", 0.02 , boost::bind(&RL_Real::RobotControl, this));
|
||||||
loop_udpSend = std::make_shared<LoopFunc>("udp_send" , 0.002, 3, boost::bind(&Unitree_RL::UDPSend, this));
|
loop_udpSend = std::make_shared<LoopFunc>("udp_send" , 0.002, 3, boost::bind(&RL_Real::UDPSend, this));
|
||||||
loop_udpRecv = std::make_shared<LoopFunc>("udp_recv" , 0.002, 3, boost::bind(&Unitree_RL::UDPRecv, this));
|
loop_udpRecv = std::make_shared<LoopFunc>("udp_recv" , 0.002, 3, boost::bind(&RL_Real::UDPRecv, this));
|
||||||
loop_rl = std::make_shared<LoopFunc>("rl_loop" , 0.02 , boost::bind(&Unitree_RL::runModel, this));
|
loop_rl = std::make_shared<LoopFunc>("rl_loop" , 0.02 , boost::bind(&RL_Real::runModel, this));
|
||||||
|
|
||||||
loop_udpSend->start();
|
loop_udpSend->start();
|
||||||
loop_udpRecv->start();
|
loop_udpRecv->start();
|
||||||
loop_control->start();
|
loop_control->start();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Unitree_RL::cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg)
|
// void RL_Real::runModel(const ros::TimerEvent &event)
|
||||||
{
|
void RL_Real::runModel()
|
||||||
cmd_vel = *msg;
|
|
||||||
}
|
|
||||||
|
|
||||||
// void Unitree_RL::runModel(const ros::TimerEvent &event)
|
|
||||||
void Unitree_RL::runModel()
|
|
||||||
{
|
{
|
||||||
if(init_done)
|
if(init_done)
|
||||||
{
|
{
|
||||||
|
@ -163,7 +154,7 @@ void Unitree_RL::runModel()
|
||||||
// printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", state.motorState[FL_0].dq, state.motorState[FL_1].dq, state.motorState[FL_2].dq, state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq, state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq, state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq);
|
// printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", state.motorState[FL_0].dq, state.motorState[FL_1].dq, state.motorState[FL_2].dq, state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq, state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq, state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq);
|
||||||
|
|
||||||
this->obs.ang_vel = torch::tensor({{state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]}});
|
this->obs.ang_vel = torch::tensor({{state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]}});
|
||||||
this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}});
|
// this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}});
|
||||||
this->obs.base_quat = torch::tensor({{state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]}});
|
this->obs.base_quat = torch::tensor({{state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]}});
|
||||||
this->obs.dof_pos = torch::tensor({{state.motorState[FR_0].q, state.motorState[FR_1].q, state.motorState[FR_2].q,
|
this->obs.dof_pos = torch::tensor({{state.motorState[FR_0].q, state.motorState[FR_1].q, state.motorState[FR_2].q,
|
||||||
state.motorState[FL_0].q, state.motorState[FL_1].q, state.motorState[FL_2].q,
|
state.motorState[FL_0].q, state.motorState[FL_1].q, state.motorState[FL_2].q,
|
||||||
|
@ -180,7 +171,7 @@ void Unitree_RL::runModel()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor Unitree_RL::compute_observation()
|
torch::Tensor RL_Real::compute_observation()
|
||||||
{
|
{
|
||||||
torch::Tensor ang_vel = this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel);
|
torch::Tensor ang_vel = this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel);
|
||||||
// float ang_vel_temp = ang_vel[0][0].item<double>();
|
// float ang_vel_temp = ang_vel[0][0].item<double>();
|
||||||
|
@ -204,7 +195,7 @@ torch::Tensor Unitree_RL::compute_observation()
|
||||||
return obs;
|
return obs;
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor Unitree_RL::forward()
|
torch::Tensor RL_Real::forward()
|
||||||
{
|
{
|
||||||
torch::Tensor obs = this->compute_observation();
|
torch::Tensor obs = this->compute_observation();
|
||||||
|
|
||||||
|
@ -229,7 +220,7 @@ torch::Tensor Unitree_RL::forward()
|
||||||
|
|
||||||
int main(int argc, char **argv)
|
int main(int argc, char **argv)
|
||||||
{
|
{
|
||||||
Unitree_RL unitree_rl;
|
RL_Real rl_sar;
|
||||||
|
|
||||||
while(1){
|
while(1){
|
||||||
sleep(10);
|
sleep(10);
|
|
@ -1,7 +1,7 @@
|
||||||
#include "../include/unitree_rl.hpp"
|
#include "../include/rl_sim.hpp"
|
||||||
#include <ros/package.h>
|
#include <ros/package.h>
|
||||||
|
|
||||||
Unitree_RL::Unitree_RL()
|
RL_Sim::RL_Sim()
|
||||||
{
|
{
|
||||||
ros::NodeHandle nh;
|
ros::NodeHandle nh;
|
||||||
start_time = std::chrono::high_resolution_clock::now();
|
start_time = std::chrono::high_resolution_clock::now();
|
||||||
|
@ -10,10 +10,9 @@ Unitree_RL::Unitree_RL()
|
||||||
|
|
||||||
torque_commands.resize(12);
|
torque_commands.resize(12);
|
||||||
|
|
||||||
std::string package_name = "unitree_rl";
|
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt";
|
||||||
std::string actor_path = ros::package::getPath(package_name) + "/models/actor.pt";
|
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt";
|
||||||
std::string encoder_path = ros::package::getPath(package_name) + "/models/encoder.pt";
|
std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/vq_layer.pt";
|
||||||
std::string vq_path = ros::package::getPath(package_name) + "/models/vq_layer.pt";
|
|
||||||
|
|
||||||
this->actor = torch::jit::load(actor_path);
|
this->actor = torch::jit::load(actor_path);
|
||||||
this->encoder = torch::jit::load(encoder_path);
|
this->encoder = torch::jit::load(encoder_path);
|
||||||
|
@ -51,9 +50,9 @@ Unitree_RL::Unitree_RL()
|
||||||
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
||||||
|
|
||||||
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>(
|
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>(
|
||||||
"/cmd_vel", 10, &Unitree_RL::cmdvelCallback, this);
|
"/cmd_vel", 10, &RL_Sim::cmdvelCallback, this);
|
||||||
|
|
||||||
timer = nh.createTimer(ros::Duration(0.005), &Unitree_RL::runModel, this);
|
timer = nh.createTimer(ros::Duration(0.005), &RL_Sim::runModel, this);
|
||||||
|
|
||||||
ros_namespace = "/a1_gazebo/";
|
ros_namespace = "/a1_gazebo/";
|
||||||
|
|
||||||
|
@ -71,31 +70,31 @@ Unitree_RL::Unitree_RL()
|
||||||
}
|
}
|
||||||
|
|
||||||
model_state_subscriber_ = nh.subscribe<gazebo_msgs::ModelStates>(
|
model_state_subscriber_ = nh.subscribe<gazebo_msgs::ModelStates>(
|
||||||
"/gazebo/model_states", 10, &Unitree_RL::modelStatesCallback, this);
|
"/gazebo/model_states", 10, &RL_Sim::modelStatesCallback, this);
|
||||||
|
|
||||||
joint_state_subscriber_ = nh.subscribe<sensor_msgs::JointState>(
|
joint_state_subscriber_ = nh.subscribe<sensor_msgs::JointState>(
|
||||||
"/a1_gazebo/joint_states", 10, &Unitree_RL::jointStatesCallback, this);
|
"/a1_gazebo/joint_states", 10, &RL_Sim::jointStatesCallback, this);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Unitree_RL::modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg)
|
void RL_Sim::modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg)
|
||||||
{
|
{
|
||||||
|
|
||||||
vel = msg->twist[2];
|
vel = msg->twist[2];
|
||||||
pose = msg->pose[2];
|
pose = msg->pose[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
void Unitree_RL::cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg)
|
void RL_Sim::cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg)
|
||||||
{
|
{
|
||||||
cmd_vel = *msg;
|
cmd_vel = *msg;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Unitree_RL::jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
|
void RL_Sim::jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
|
||||||
{
|
{
|
||||||
joint_positions = msg->position;
|
joint_positions = msg->position;
|
||||||
joint_velocities = msg->velocity;
|
joint_velocities = msg->velocity;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Unitree_RL::runModel(const ros::TimerEvent &event)
|
void RL_Sim::runModel(const ros::TimerEvent &event)
|
||||||
{
|
{
|
||||||
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
||||||
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
||||||
|
@ -126,7 +125,7 @@ void Unitree_RL::runModel(const ros::TimerEvent &event)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor Unitree_RL::compute_observation()
|
torch::Tensor RL_Sim::compute_observation()
|
||||||
{
|
{
|
||||||
torch::Tensor obs = torch::cat({// (this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
|
torch::Tensor obs = torch::cat({// (this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
|
||||||
(this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale,
|
(this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale,
|
||||||
|
@ -140,7 +139,7 @@ torch::Tensor Unitree_RL::compute_observation()
|
||||||
return obs;
|
return obs;
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor Unitree_RL::forward()
|
torch::Tensor RL_Sim::forward()
|
||||||
{
|
{
|
||||||
torch::Tensor obs = this->compute_observation();
|
torch::Tensor obs = this->compute_observation();
|
||||||
|
|
||||||
|
@ -163,8 +162,8 @@ torch::Tensor Unitree_RL::forward()
|
||||||
|
|
||||||
int main(int argc, char **argv)
|
int main(int argc, char **argv)
|
||||||
{
|
{
|
||||||
ros::init(argc, argv, "unitree_rl");
|
ros::init(argc, argv, "rl_sar");
|
||||||
Unitree_RL unitree_rl;
|
RL_Sim rl_sar;
|
||||||
ros::spin();
|
ros::spin();
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
Loading…
Reference in New Issue