fix: fix format

This commit is contained in:
fan-ziqi 2024-03-15 21:14:52 +08:00
parent 8d2ab423b5
commit 9020bb81ee
5 changed files with 47 additions and 30 deletions

View File

@ -28,6 +28,7 @@ class RL_Real : public RL
public: public:
RL_Real(); RL_Real();
~RL_Real(); ~RL_Real();
void runModel(); void runModel();
torch::Tensor forward() override; torch::Tensor forward() override;
torch::Tensor compute_observation() override; torch::Tensor compute_observation() override;

View File

@ -1,21 +1,24 @@
#ifndef RL_SIM_HPP #ifndef RL_SIM_HPP
#define RL_SIM_HPP #define RL_SIM_HPP
#include "../library/rl/rl.hpp"
#include "../library/observation_buffer/observation_buffer.hpp"
#include <ros/ros.h> #include <ros/ros.h>
#include <gazebo_msgs/ModelStates.h> #include <gazebo_msgs/ModelStates.h>
#include <sensor_msgs/JointState.h> #include <sensor_msgs/JointState.h>
#include <geometry_msgs/Twist.h> #include <geometry_msgs/Twist.h>
#include "../library/rl/rl.hpp"
#include "../library/observation_buffer/observation_buffer.hpp"
#include "unitree_legged_msgs/MotorCmd.h" #include "unitree_legged_msgs/MotorCmd.h"
#include <csignal>
class RL_Sim : public RL class RL_Sim : public RL
{ {
public: public:
RL_Sim(); RL_Sim();
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg); void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg); void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg); void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
void runModel(const ros::TimerEvent &event); void runModel(const ros::TimerEvent &event);
torch::Tensor forward() override; torch::Tensor forward() override;
torch::Tensor compute_observation() override; torch::Tensor compute_observation() override;
@ -24,8 +27,6 @@ public:
torch::Tensor history_obs; torch::Tensor history_obs;
private: private:
std::string ros_namespace;
std::vector<std::string> torque_command_topics; std::vector<std::string> torque_command_topics;
ros::Subscriber model_state_subscriber_; ros::Subscriber model_state_subscriber_;
@ -33,7 +34,7 @@ private:
ros::Subscriber cmd_vel_subscriber_; ros::Subscriber cmd_vel_subscriber_;
std::map<std::string, ros::Publisher> torque_publishers; std::map<std::string, ros::Publisher> torque_publishers;
std::vector<unitree_legged_msgs::MotorCmd> torque_commands; std::vector<unitree_legged_msgs::MotorCmd> motor_commands;
geometry_msgs::Twist vel; geometry_msgs::Twist vel;
geometry_msgs::Pose pose; geometry_msgs::Pose pose;

View File

@ -5,7 +5,8 @@
#include <iostream> #include <iostream>
#include <string> #include <string>
struct ModelParams { struct ModelParams
{
int num_observations; int num_observations;
float damping; float damping;
float stiffness; float stiffness;
@ -25,7 +26,8 @@ struct ModelParams {
torch::Tensor default_dof_pos; torch::Tensor default_dof_pos;
}; };
struct Observations { struct Observations
{
torch::Tensor lin_vel; torch::Tensor lin_vel;
torch::Tensor ang_vel; torch::Tensor ang_vel;
torch::Tensor gravity_vec; torch::Tensor gravity_vec;
@ -36,9 +38,11 @@ struct Observations {
torch::Tensor actions; torch::Tensor actions;
}; };
class RL { class RL
{
public: public:
RL(){}; RL(){};
ModelParams params; ModelParams params;
Observations obs; Observations obs;
@ -61,8 +65,8 @@ protected:
torch::Tensor dof_pos; torch::Tensor dof_pos;
torch::Tensor dof_vel; torch::Tensor dof_vel;
torch::Tensor actions; torch::Tensor actions;
// output buffer
torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}}); torch::Tensor torques;
torch::Tensor target_dof_pos; torch::Tensor target_dof_pos;
}; };

View File

@ -159,6 +159,7 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
target_dof_pos = params.default_dof_pos; target_dof_pos = params.default_dof_pos;
// InitEnvironment(); // InitEnvironment();
@ -185,9 +186,9 @@ void RL_Real::runModel()
{ {
if(init_state == STATE_RL_START) if(init_state == STATE_RL_START)
{ {
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count(); // auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
std::cout << "Execution time: " << duration << " microseconds" << std::endl; // std::cout << "Execution time: " << duration << " microseconds" << std::endl;
start_time = std::chrono::high_resolution_clock::now(); // start_time = std::chrono::high_resolution_clock::now();
// printf("%f, %f, %f\n", state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]); // printf("%f, %f, %f\n", state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]);
// printf("%f, %f, %f, %f\n", state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]); // printf("%f, %f, %f, %f\n", state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]);

View File

@ -1,5 +1,6 @@
#include "../include/rl_sim.hpp" #include "../include/rl_sim.hpp"
#include <ros/package.h>
RL_Sim rl_sar;
RL_Sim::RL_Sim() RL_Sim::RL_Sim()
{ {
@ -8,7 +9,7 @@ RL_Sim::RL_Sim()
cmd_vel = geometry_msgs::Twist(); cmd_vel = geometry_msgs::Twist();
torque_commands.resize(12); motor_commands.resize(12);
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt"; std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt";
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt"; std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt";
@ -49,14 +50,14 @@ RL_Sim::RL_Sim()
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
target_dof_pos = params.default_dof_pos; target_dof_pos = params.default_dof_pos;
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>( cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>("/cmd_vel", 10, &RL_Sim::cmdvelCallback, this);
"/cmd_vel", 10, &RL_Sim::cmdvelCallback, this);
timer = nh.createTimer(ros::Duration(0.02), &RL_Sim::runModel, this); timer = nh.createTimer(ros::Duration(0.02), &RL_Sim::runModel, this);
ros_namespace = "/a1_gazebo/"; std::string ros_namespace = "/a1_gazebo/";
joint_names = { joint_names = {
"FL_hip_joint", "FL_thigh_joint", "FL_calf_joint", "FL_hip_joint", "FL_thigh_joint", "FL_calf_joint",
@ -98,9 +99,9 @@ void RL_Sim::jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
void RL_Sim::runModel(const ros::TimerEvent &event) void RL_Sim::runModel(const ros::TimerEvent &event)
{ {
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count(); // auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
// std::cout << "Execution time: " << duration << " microseconds" << std::endl; // std::cout << "Execution time: " << duration << " microseconds" << std::endl;
start_time = std::chrono::high_resolution_clock::now(); // start_time = std::chrono::high_resolution_clock::now();
this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}}); this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}});
this->obs.ang_vel = torch::tensor({{vel.angular.x, vel.angular.y, vel.angular.z}}); this->obs.ang_vel = torch::tensor({{vel.angular.x, vel.angular.y, vel.angular.z}});
@ -121,15 +122,15 @@ void RL_Sim::runModel(const ros::TimerEvent &event)
for (int i = 0; i < 12; ++i) for (int i = 0; i < 12; ++i)
{ {
torque_commands[i].mode = 0x0A; motor_commands[i].mode = 0x0A;
// torque_commands[i].tau = torques[0][i].item<double>(); // motor_commands[i].tau = torques[0][i].item<double>();
torque_commands[i].tau = 0; motor_commands[i].tau = 0;
torque_commands[i].q = target_dof_pos[0][i].item<double>(); motor_commands[i].q = target_dof_pos[0][i].item<double>();
torque_commands[i].dq = 0; motor_commands[i].dq = 0;
torque_commands[i].Kp = params.stiffness; motor_commands[i].Kp = params.stiffness;
torque_commands[i].Kd = params.damping; motor_commands[i].Kd = params.damping;
torque_publishers[joint_names[i]].publish(torque_commands[i]); torque_publishers[joint_names[i]].publish(motor_commands[i]);
} }
} }
@ -168,10 +169,19 @@ torch::Tensor RL_Sim::forward()
return clamped; return clamped;
} }
void signalHandler(int signum)
{
ros::shutdown();
exit(0);
}
int main(int argc, char **argv) int main(int argc, char **argv)
{ {
signal(SIGINT, signalHandler);
ros::init(argc, argv, "rl_sar"); ros::init(argc, argv, "rl_sar");
RL_Sim rl_sar;
ros::spin(); ros::spin();
return 0; return 0;
} }