mirror of https://github.com/fan-ziqi/rl_sar.git
fix: fix format
This commit is contained in:
parent
8d2ab423b5
commit
9020bb81ee
|
@ -28,6 +28,7 @@ class RL_Real : public RL
|
||||||
public:
|
public:
|
||||||
RL_Real();
|
RL_Real();
|
||||||
~RL_Real();
|
~RL_Real();
|
||||||
|
|
||||||
void runModel();
|
void runModel();
|
||||||
torch::Tensor forward() override;
|
torch::Tensor forward() override;
|
||||||
torch::Tensor compute_observation() override;
|
torch::Tensor compute_observation() override;
|
||||||
|
|
|
@ -1,21 +1,24 @@
|
||||||
#ifndef RL_SIM_HPP
|
#ifndef RL_SIM_HPP
|
||||||
#define RL_SIM_HPP
|
#define RL_SIM_HPP
|
||||||
|
|
||||||
|
#include "../library/rl/rl.hpp"
|
||||||
|
#include "../library/observation_buffer/observation_buffer.hpp"
|
||||||
#include <ros/ros.h>
|
#include <ros/ros.h>
|
||||||
#include <gazebo_msgs/ModelStates.h>
|
#include <gazebo_msgs/ModelStates.h>
|
||||||
#include <sensor_msgs/JointState.h>
|
#include <sensor_msgs/JointState.h>
|
||||||
#include <geometry_msgs/Twist.h>
|
#include <geometry_msgs/Twist.h>
|
||||||
#include "../library/rl/rl.hpp"
|
|
||||||
#include "../library/observation_buffer/observation_buffer.hpp"
|
|
||||||
#include "unitree_legged_msgs/MotorCmd.h"
|
#include "unitree_legged_msgs/MotorCmd.h"
|
||||||
|
#include <csignal>
|
||||||
|
|
||||||
class RL_Sim : public RL
|
class RL_Sim : public RL
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
RL_Sim();
|
RL_Sim();
|
||||||
|
|
||||||
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
|
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
|
||||||
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
|
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
|
||||||
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
|
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
|
||||||
|
|
||||||
void runModel(const ros::TimerEvent &event);
|
void runModel(const ros::TimerEvent &event);
|
||||||
torch::Tensor forward() override;
|
torch::Tensor forward() override;
|
||||||
torch::Tensor compute_observation() override;
|
torch::Tensor compute_observation() override;
|
||||||
|
@ -24,8 +27,6 @@ public:
|
||||||
torch::Tensor history_obs;
|
torch::Tensor history_obs;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string ros_namespace;
|
|
||||||
|
|
||||||
std::vector<std::string> torque_command_topics;
|
std::vector<std::string> torque_command_topics;
|
||||||
|
|
||||||
ros::Subscriber model_state_subscriber_;
|
ros::Subscriber model_state_subscriber_;
|
||||||
|
@ -33,7 +34,7 @@ private:
|
||||||
ros::Subscriber cmd_vel_subscriber_;
|
ros::Subscriber cmd_vel_subscriber_;
|
||||||
|
|
||||||
std::map<std::string, ros::Publisher> torque_publishers;
|
std::map<std::string, ros::Publisher> torque_publishers;
|
||||||
std::vector<unitree_legged_msgs::MotorCmd> torque_commands;
|
std::vector<unitree_legged_msgs::MotorCmd> motor_commands;
|
||||||
|
|
||||||
geometry_msgs::Twist vel;
|
geometry_msgs::Twist vel;
|
||||||
geometry_msgs::Pose pose;
|
geometry_msgs::Pose pose;
|
||||||
|
|
|
@ -5,7 +5,8 @@
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
struct ModelParams {
|
struct ModelParams
|
||||||
|
{
|
||||||
int num_observations;
|
int num_observations;
|
||||||
float damping;
|
float damping;
|
||||||
float stiffness;
|
float stiffness;
|
||||||
|
@ -25,7 +26,8 @@ struct ModelParams {
|
||||||
torch::Tensor default_dof_pos;
|
torch::Tensor default_dof_pos;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Observations {
|
struct Observations
|
||||||
|
{
|
||||||
torch::Tensor lin_vel;
|
torch::Tensor lin_vel;
|
||||||
torch::Tensor ang_vel;
|
torch::Tensor ang_vel;
|
||||||
torch::Tensor gravity_vec;
|
torch::Tensor gravity_vec;
|
||||||
|
@ -36,9 +38,11 @@ struct Observations {
|
||||||
torch::Tensor actions;
|
torch::Tensor actions;
|
||||||
};
|
};
|
||||||
|
|
||||||
class RL {
|
class RL
|
||||||
|
{
|
||||||
public:
|
public:
|
||||||
RL(){};
|
RL(){};
|
||||||
|
|
||||||
ModelParams params;
|
ModelParams params;
|
||||||
Observations obs;
|
Observations obs;
|
||||||
|
|
||||||
|
@ -61,8 +65,8 @@ protected:
|
||||||
torch::Tensor dof_pos;
|
torch::Tensor dof_pos;
|
||||||
torch::Tensor dof_vel;
|
torch::Tensor dof_vel;
|
||||||
torch::Tensor actions;
|
torch::Tensor actions;
|
||||||
|
// output buffer
|
||||||
torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
torch::Tensor torques;
|
||||||
torch::Tensor target_dof_pos;
|
torch::Tensor target_dof_pos;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -159,6 +159,7 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
|
||||||
|
|
||||||
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
||||||
|
|
||||||
|
torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||||
target_dof_pos = params.default_dof_pos;
|
target_dof_pos = params.default_dof_pos;
|
||||||
|
|
||||||
// InitEnvironment();
|
// InitEnvironment();
|
||||||
|
@ -185,9 +186,9 @@ void RL_Real::runModel()
|
||||||
{
|
{
|
||||||
if(init_state == STATE_RL_START)
|
if(init_state == STATE_RL_START)
|
||||||
{
|
{
|
||||||
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
// auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
||||||
std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
||||||
start_time = std::chrono::high_resolution_clock::now();
|
// start_time = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
// printf("%f, %f, %f\n", state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]);
|
// printf("%f, %f, %f\n", state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]);
|
||||||
// printf("%f, %f, %f, %f\n", state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]);
|
// printf("%f, %f, %f, %f\n", state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]);
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
#include "../include/rl_sim.hpp"
|
#include "../include/rl_sim.hpp"
|
||||||
#include <ros/package.h>
|
|
||||||
|
RL_Sim rl_sar;
|
||||||
|
|
||||||
RL_Sim::RL_Sim()
|
RL_Sim::RL_Sim()
|
||||||
{
|
{
|
||||||
|
@ -8,7 +9,7 @@ RL_Sim::RL_Sim()
|
||||||
|
|
||||||
cmd_vel = geometry_msgs::Twist();
|
cmd_vel = geometry_msgs::Twist();
|
||||||
|
|
||||||
torque_commands.resize(12);
|
motor_commands.resize(12);
|
||||||
|
|
||||||
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt";
|
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt";
|
||||||
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt";
|
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt";
|
||||||
|
@ -49,14 +50,14 @@ RL_Sim::RL_Sim()
|
||||||
|
|
||||||
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
||||||
|
|
||||||
|
torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||||
target_dof_pos = params.default_dof_pos;
|
target_dof_pos = params.default_dof_pos;
|
||||||
|
|
||||||
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>(
|
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>("/cmd_vel", 10, &RL_Sim::cmdvelCallback, this);
|
||||||
"/cmd_vel", 10, &RL_Sim::cmdvelCallback, this);
|
|
||||||
|
|
||||||
timer = nh.createTimer(ros::Duration(0.02), &RL_Sim::runModel, this);
|
timer = nh.createTimer(ros::Duration(0.02), &RL_Sim::runModel, this);
|
||||||
|
|
||||||
ros_namespace = "/a1_gazebo/";
|
std::string ros_namespace = "/a1_gazebo/";
|
||||||
|
|
||||||
joint_names = {
|
joint_names = {
|
||||||
"FL_hip_joint", "FL_thigh_joint", "FL_calf_joint",
|
"FL_hip_joint", "FL_thigh_joint", "FL_calf_joint",
|
||||||
|
@ -98,9 +99,9 @@ void RL_Sim::jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
|
||||||
|
|
||||||
void RL_Sim::runModel(const ros::TimerEvent &event)
|
void RL_Sim::runModel(const ros::TimerEvent &event)
|
||||||
{
|
{
|
||||||
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
// auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
||||||
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
||||||
start_time = std::chrono::high_resolution_clock::now();
|
// start_time = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}});
|
this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}});
|
||||||
this->obs.ang_vel = torch::tensor({{vel.angular.x, vel.angular.y, vel.angular.z}});
|
this->obs.ang_vel = torch::tensor({{vel.angular.x, vel.angular.y, vel.angular.z}});
|
||||||
|
@ -121,15 +122,15 @@ void RL_Sim::runModel(const ros::TimerEvent &event)
|
||||||
|
|
||||||
for (int i = 0; i < 12; ++i)
|
for (int i = 0; i < 12; ++i)
|
||||||
{
|
{
|
||||||
torque_commands[i].mode = 0x0A;
|
motor_commands[i].mode = 0x0A;
|
||||||
// torque_commands[i].tau = torques[0][i].item<double>();
|
// motor_commands[i].tau = torques[0][i].item<double>();
|
||||||
torque_commands[i].tau = 0;
|
motor_commands[i].tau = 0;
|
||||||
torque_commands[i].q = target_dof_pos[0][i].item<double>();
|
motor_commands[i].q = target_dof_pos[0][i].item<double>();
|
||||||
torque_commands[i].dq = 0;
|
motor_commands[i].dq = 0;
|
||||||
torque_commands[i].Kp = params.stiffness;
|
motor_commands[i].Kp = params.stiffness;
|
||||||
torque_commands[i].Kd = params.damping;
|
motor_commands[i].Kd = params.damping;
|
||||||
|
|
||||||
torque_publishers[joint_names[i]].publish(torque_commands[i]);
|
torque_publishers[joint_names[i]].publish(motor_commands[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -168,10 +169,19 @@ torch::Tensor RL_Sim::forward()
|
||||||
return clamped;
|
return clamped;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void signalHandler(int signum)
|
||||||
|
{
|
||||||
|
ros::shutdown();
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char **argv)
|
int main(int argc, char **argv)
|
||||||
{
|
{
|
||||||
|
signal(SIGINT, signalHandler);
|
||||||
|
|
||||||
ros::init(argc, argv, "rl_sar");
|
ros::init(argc, argv, "rl_sar");
|
||||||
RL_Sim rl_sar;
|
|
||||||
ros::spin();
|
ros::spin();
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue