mirror of https://github.com/fan-ziqi/rl_sar.git
fix: fix format
This commit is contained in:
parent
8d2ab423b5
commit
9020bb81ee
|
@ -28,6 +28,7 @@ class RL_Real : public RL
|
|||
public:
|
||||
RL_Real();
|
||||
~RL_Real();
|
||||
|
||||
void runModel();
|
||||
torch::Tensor forward() override;
|
||||
torch::Tensor compute_observation() override;
|
||||
|
|
|
@ -1,21 +1,24 @@
|
|||
#ifndef RL_SIM_HPP
|
||||
#define RL_SIM_HPP
|
||||
|
||||
#include "../library/rl/rl.hpp"
|
||||
#include "../library/observation_buffer/observation_buffer.hpp"
|
||||
#include <ros/ros.h>
|
||||
#include <gazebo_msgs/ModelStates.h>
|
||||
#include <sensor_msgs/JointState.h>
|
||||
#include <geometry_msgs/Twist.h>
|
||||
#include "../library/rl/rl.hpp"
|
||||
#include "../library/observation_buffer/observation_buffer.hpp"
|
||||
#include "unitree_legged_msgs/MotorCmd.h"
|
||||
#include <csignal>
|
||||
|
||||
class RL_Sim : public RL
|
||||
{
|
||||
public:
|
||||
RL_Sim();
|
||||
|
||||
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
|
||||
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
|
||||
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
|
||||
|
||||
void runModel(const ros::TimerEvent &event);
|
||||
torch::Tensor forward() override;
|
||||
torch::Tensor compute_observation() override;
|
||||
|
@ -24,8 +27,6 @@ public:
|
|||
torch::Tensor history_obs;
|
||||
|
||||
private:
|
||||
std::string ros_namespace;
|
||||
|
||||
std::vector<std::string> torque_command_topics;
|
||||
|
||||
ros::Subscriber model_state_subscriber_;
|
||||
|
@ -33,7 +34,7 @@ private:
|
|||
ros::Subscriber cmd_vel_subscriber_;
|
||||
|
||||
std::map<std::string, ros::Publisher> torque_publishers;
|
||||
std::vector<unitree_legged_msgs::MotorCmd> torque_commands;
|
||||
std::vector<unitree_legged_msgs::MotorCmd> motor_commands;
|
||||
|
||||
geometry_msgs::Twist vel;
|
||||
geometry_msgs::Pose pose;
|
||||
|
|
|
@ -5,7 +5,8 @@
|
|||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
struct ModelParams {
|
||||
struct ModelParams
|
||||
{
|
||||
int num_observations;
|
||||
float damping;
|
||||
float stiffness;
|
||||
|
@ -25,7 +26,8 @@ struct ModelParams {
|
|||
torch::Tensor default_dof_pos;
|
||||
};
|
||||
|
||||
struct Observations {
|
||||
struct Observations
|
||||
{
|
||||
torch::Tensor lin_vel;
|
||||
torch::Tensor ang_vel;
|
||||
torch::Tensor gravity_vec;
|
||||
|
@ -36,9 +38,11 @@ struct Observations {
|
|||
torch::Tensor actions;
|
||||
};
|
||||
|
||||
class RL {
|
||||
class RL
|
||||
{
|
||||
public:
|
||||
RL(){};
|
||||
|
||||
ModelParams params;
|
||||
Observations obs;
|
||||
|
||||
|
@ -61,8 +65,8 @@ protected:
|
|||
torch::Tensor dof_pos;
|
||||
torch::Tensor dof_vel;
|
||||
torch::Tensor actions;
|
||||
|
||||
torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||
// output buffer
|
||||
torch::Tensor torques;
|
||||
torch::Tensor target_dof_pos;
|
||||
};
|
||||
|
||||
|
|
|
@ -159,6 +159,7 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
|
|||
|
||||
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
||||
|
||||
torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||
target_dof_pos = params.default_dof_pos;
|
||||
|
||||
// InitEnvironment();
|
||||
|
@ -185,9 +186,9 @@ void RL_Real::runModel()
|
|||
{
|
||||
if(init_state == STATE_RL_START)
|
||||
{
|
||||
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
||||
std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
||||
start_time = std::chrono::high_resolution_clock::now();
|
||||
// auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
||||
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
||||
// start_time = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// printf("%f, %f, %f\n", state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]);
|
||||
// printf("%f, %f, %f, %f\n", state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]);
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#include "../include/rl_sim.hpp"
|
||||
#include <ros/package.h>
|
||||
|
||||
RL_Sim rl_sar;
|
||||
|
||||
RL_Sim::RL_Sim()
|
||||
{
|
||||
|
@ -8,7 +9,7 @@ RL_Sim::RL_Sim()
|
|||
|
||||
cmd_vel = geometry_msgs::Twist();
|
||||
|
||||
torque_commands.resize(12);
|
||||
motor_commands.resize(12);
|
||||
|
||||
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt";
|
||||
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt";
|
||||
|
@ -49,14 +50,14 @@ RL_Sim::RL_Sim()
|
|||
|
||||
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
||||
|
||||
torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||
target_dof_pos = params.default_dof_pos;
|
||||
|
||||
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>(
|
||||
"/cmd_vel", 10, &RL_Sim::cmdvelCallback, this);
|
||||
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>("/cmd_vel", 10, &RL_Sim::cmdvelCallback, this);
|
||||
|
||||
timer = nh.createTimer(ros::Duration(0.02), &RL_Sim::runModel, this);
|
||||
|
||||
ros_namespace = "/a1_gazebo/";
|
||||
std::string ros_namespace = "/a1_gazebo/";
|
||||
|
||||
joint_names = {
|
||||
"FL_hip_joint", "FL_thigh_joint", "FL_calf_joint",
|
||||
|
@ -98,9 +99,9 @@ void RL_Sim::jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
|
|||
|
||||
void RL_Sim::runModel(const ros::TimerEvent &event)
|
||||
{
|
||||
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
||||
// auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
||||
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
||||
start_time = std::chrono::high_resolution_clock::now();
|
||||
// start_time = std::chrono::high_resolution_clock::now();
|
||||
|
||||
this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}});
|
||||
this->obs.ang_vel = torch::tensor({{vel.angular.x, vel.angular.y, vel.angular.z}});
|
||||
|
@ -121,15 +122,15 @@ void RL_Sim::runModel(const ros::TimerEvent &event)
|
|||
|
||||
for (int i = 0; i < 12; ++i)
|
||||
{
|
||||
torque_commands[i].mode = 0x0A;
|
||||
// torque_commands[i].tau = torques[0][i].item<double>();
|
||||
torque_commands[i].tau = 0;
|
||||
torque_commands[i].q = target_dof_pos[0][i].item<double>();
|
||||
torque_commands[i].dq = 0;
|
||||
torque_commands[i].Kp = params.stiffness;
|
||||
torque_commands[i].Kd = params.damping;
|
||||
motor_commands[i].mode = 0x0A;
|
||||
// motor_commands[i].tau = torques[0][i].item<double>();
|
||||
motor_commands[i].tau = 0;
|
||||
motor_commands[i].q = target_dof_pos[0][i].item<double>();
|
||||
motor_commands[i].dq = 0;
|
||||
motor_commands[i].Kp = params.stiffness;
|
||||
motor_commands[i].Kd = params.damping;
|
||||
|
||||
torque_publishers[joint_names[i]].publish(torque_commands[i]);
|
||||
torque_publishers[joint_names[i]].publish(motor_commands[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -168,10 +169,19 @@ torch::Tensor RL_Sim::forward()
|
|||
return clamped;
|
||||
}
|
||||
|
||||
void signalHandler(int signum)
|
||||
{
|
||||
ros::shutdown();
|
||||
exit(0);
|
||||
}
|
||||
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
signal(SIGINT, signalHandler);
|
||||
|
||||
ros::init(argc, argv, "rl_sar");
|
||||
RL_Sim rl_sar;
|
||||
|
||||
ros::spin();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue