fix: fix format

This commit is contained in:
fan-ziqi 2024-03-15 21:14:52 +08:00
parent 8d2ab423b5
commit 9020bb81ee
5 changed files with 47 additions and 30 deletions

View File

@ -28,6 +28,7 @@ class RL_Real : public RL
public:
RL_Real();
~RL_Real();
void runModel();
torch::Tensor forward() override;
torch::Tensor compute_observation() override;

View File

@ -1,21 +1,24 @@
#ifndef RL_SIM_HPP
#define RL_SIM_HPP
#include "../library/rl/rl.hpp"
#include "../library/observation_buffer/observation_buffer.hpp"
#include <ros/ros.h>
#include <gazebo_msgs/ModelStates.h>
#include <sensor_msgs/JointState.h>
#include <geometry_msgs/Twist.h>
#include "../library/rl/rl.hpp"
#include "../library/observation_buffer/observation_buffer.hpp"
#include "unitree_legged_msgs/MotorCmd.h"
#include <csignal>
class RL_Sim : public RL
{
public:
RL_Sim();
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
void runModel(const ros::TimerEvent &event);
torch::Tensor forward() override;
torch::Tensor compute_observation() override;
@ -24,8 +27,6 @@ public:
torch::Tensor history_obs;
private:
std::string ros_namespace;
std::vector<std::string> torque_command_topics;
ros::Subscriber model_state_subscriber_;
@ -33,7 +34,7 @@ private:
ros::Subscriber cmd_vel_subscriber_;
std::map<std::string, ros::Publisher> torque_publishers;
std::vector<unitree_legged_msgs::MotorCmd> torque_commands;
std::vector<unitree_legged_msgs::MotorCmd> motor_commands;
geometry_msgs::Twist vel;
geometry_msgs::Pose pose;

View File

@ -5,7 +5,8 @@
#include <iostream>
#include <string>
struct ModelParams {
struct ModelParams
{
int num_observations;
float damping;
float stiffness;
@ -25,7 +26,8 @@ struct ModelParams {
torch::Tensor default_dof_pos;
};
struct Observations {
struct Observations
{
torch::Tensor lin_vel;
torch::Tensor ang_vel;
torch::Tensor gravity_vec;
@ -36,9 +38,11 @@ struct Observations {
torch::Tensor actions;
};
class RL {
class RL
{
public:
RL(){};
ModelParams params;
Observations obs;
@ -61,8 +65,8 @@ protected:
torch::Tensor dof_pos;
torch::Tensor dof_vel;
torch::Tensor actions;
torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
// output buffer
torch::Tensor torques;
torch::Tensor target_dof_pos;
};

View File

@ -159,6 +159,7 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
target_dof_pos = params.default_dof_pos;
// InitEnvironment();
@ -185,9 +186,9 @@ void RL_Real::runModel()
{
if(init_state == STATE_RL_START)
{
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
std::cout << "Execution time: " << duration << " microseconds" << std::endl;
start_time = std::chrono::high_resolution_clock::now();
// auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
// start_time = std::chrono::high_resolution_clock::now();
// printf("%f, %f, %f\n", state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]);
// printf("%f, %f, %f, %f\n", state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]);

View File

@ -1,5 +1,6 @@
#include "../include/rl_sim.hpp"
#include <ros/package.h>
RL_Sim rl_sar;
RL_Sim::RL_Sim()
{
@ -8,7 +9,7 @@ RL_Sim::RL_Sim()
cmd_vel = geometry_msgs::Twist();
torque_commands.resize(12);
motor_commands.resize(12);
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt";
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt";
@ -49,14 +50,14 @@ RL_Sim::RL_Sim()
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
target_dof_pos = params.default_dof_pos;
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>(
"/cmd_vel", 10, &RL_Sim::cmdvelCallback, this);
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>("/cmd_vel", 10, &RL_Sim::cmdvelCallback, this);
timer = nh.createTimer(ros::Duration(0.02), &RL_Sim::runModel, this);
ros_namespace = "/a1_gazebo/";
std::string ros_namespace = "/a1_gazebo/";
joint_names = {
"FL_hip_joint", "FL_thigh_joint", "FL_calf_joint",
@ -98,9 +99,9 @@ void RL_Sim::jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
void RL_Sim::runModel(const ros::TimerEvent &event)
{
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
// auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
start_time = std::chrono::high_resolution_clock::now();
// start_time = std::chrono::high_resolution_clock::now();
this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}});
this->obs.ang_vel = torch::tensor({{vel.angular.x, vel.angular.y, vel.angular.z}});
@ -121,15 +122,15 @@ void RL_Sim::runModel(const ros::TimerEvent &event)
for (int i = 0; i < 12; ++i)
{
torque_commands[i].mode = 0x0A;
// torque_commands[i].tau = torques[0][i].item<double>();
torque_commands[i].tau = 0;
torque_commands[i].q = target_dof_pos[0][i].item<double>();
torque_commands[i].dq = 0;
torque_commands[i].Kp = params.stiffness;
torque_commands[i].Kd = params.damping;
motor_commands[i].mode = 0x0A;
// motor_commands[i].tau = torques[0][i].item<double>();
motor_commands[i].tau = 0;
motor_commands[i].q = target_dof_pos[0][i].item<double>();
motor_commands[i].dq = 0;
motor_commands[i].Kp = params.stiffness;
motor_commands[i].Kd = params.damping;
torque_publishers[joint_names[i]].publish(torque_commands[i]);
torque_publishers[joint_names[i]].publish(motor_commands[i]);
}
}
@ -168,10 +169,19 @@ torch::Tensor RL_Sim::forward()
return clamped;
}
void signalHandler(int signum)
{
ros::shutdown();
exit(0);
}
int main(int argc, char **argv)
{
signal(SIGINT, signalHandler);
ros::init(argc, argv, "rl_sar");
RL_Sim rl_sar;
ros::spin();
return 0;
}