mirror of https://github.com/fan-ziqi/rl_sar.git
feat: add plot && add thread in sim
This commit is contained in:
parent
05c181aeec
commit
89a8c9d814
|
@ -1,6 +1,9 @@
|
|||
cmake_minimum_required(VERSION 3.0.2)
|
||||
project(rl_sar)
|
||||
|
||||
set(CMAKE_BUILD_TYPE Debug)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g")
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
|
||||
|
||||
add_definitions(-DCMAKE_CURRENT_SOURCE_DIR="${CMAKE_CURRENT_SOURCE_DIR}")
|
||||
|
@ -21,6 +24,7 @@ find_package(catkin REQUIRED COMPONENTS
|
|||
)
|
||||
|
||||
find_package(gazebo REQUIRED)
|
||||
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
|
||||
|
||||
catkin_package(
|
||||
CATKIN_DEPENDS
|
||||
|
@ -36,14 +40,20 @@ include_directories(
|
|||
${catkin_INCLUDE_DIRS}
|
||||
${unitree_legged_sdk_INCLUDE_DIRS}
|
||||
../unitree_controller/include
|
||||
|
||||
library/matplotlibcpp
|
||||
)
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GAZEBO_CXX_FLAGS}")
|
||||
|
||||
add_library(rl library/rl/rl.cpp library/rl/rl.hpp)
|
||||
target_link_libraries(rl "${TORCH_LIBRARIES}")
|
||||
target_link_libraries(rl "${TORCH_LIBRARIES}" Python3::Python Python3::Module)
|
||||
set_property(TARGET rl PROPERTY CXX_STANDARD 14)
|
||||
find_package(Python3 COMPONENTS NumPy)
|
||||
if(Python3_NumPy_FOUND)
|
||||
target_link_libraries(rl Python3::NumPy)
|
||||
else()
|
||||
target_compile_definitions(rl WITHOUT_NUMPY)
|
||||
endif()
|
||||
|
||||
add_library(observation_buffer library/observation_buffer/observation_buffer.cpp library/observation_buffer/observation_buffer.hpp)
|
||||
target_link_libraries(observation_buffer "${TORCH_LIBRARIES}")
|
||||
|
|
|
@ -9,7 +9,6 @@
|
|||
#include <unitree_legged_msgs/MotorState.h>
|
||||
#include "unitree_legged_sdk/unitree_legged_sdk.h"
|
||||
#include "unitree_legged_sdk/unitree_joystick.h"
|
||||
#include <pthread.h>
|
||||
#include <csignal>
|
||||
// #include <signal.h>
|
||||
|
||||
|
@ -51,12 +50,16 @@ public:
|
|||
std::shared_ptr<LoopFunc> loop_udpSend;
|
||||
std::shared_ptr<LoopFunc> loop_udpRecv;
|
||||
std::shared_ptr<LoopFunc> loop_rl;
|
||||
std::shared_ptr<LoopFunc> loop_plot;
|
||||
|
||||
float _percent;
|
||||
float _startPos[12];
|
||||
|
||||
int init_state = STATE_WAITING;
|
||||
int robot_state = STATE_WAITING;
|
||||
|
||||
std::vector<double> _t;
|
||||
std::vector<std::vector<double>> _real_joint_pos, _target_joint_pos;
|
||||
void plot();
|
||||
private:
|
||||
std::vector<std::string> joint_names;
|
||||
std::vector<double> joint_positions;
|
||||
|
|
|
@ -8,24 +8,38 @@
|
|||
#include <sensor_msgs/JointState.h>
|
||||
#include <geometry_msgs/Twist.h>
|
||||
#include "unitree_legged_msgs/MotorCmd.h"
|
||||
#include "unitree_legged_sdk/loop.h"
|
||||
#include <csignal>
|
||||
|
||||
using namespace UNITREE_LEGGED_SDK;
|
||||
|
||||
class RL_Sim : public RL
|
||||
{
|
||||
public:
|
||||
RL_Sim();
|
||||
~RL_Sim();
|
||||
|
||||
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
|
||||
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
|
||||
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
|
||||
|
||||
void runModel(const ros::TimerEvent &event);
|
||||
void runModel();
|
||||
void RobotControl();
|
||||
torch::Tensor forward() override;
|
||||
torch::Tensor compute_observation() override;
|
||||
|
||||
ObservationBuffer history_obs_buf;
|
||||
torch::Tensor history_obs;
|
||||
|
||||
int motiontime = 0;
|
||||
|
||||
std::shared_ptr<LoopFunc> loop_control;
|
||||
std::shared_ptr<LoopFunc> loop_rl;
|
||||
std::shared_ptr<LoopFunc> loop_plot;
|
||||
|
||||
std::vector<double> _t;
|
||||
std::vector<std::vector<double>> _real_joint_pos, _target_joint_pos;
|
||||
void plot();
|
||||
private:
|
||||
std::vector<std::string> torque_command_topics;
|
||||
|
||||
|
@ -46,8 +60,6 @@ private:
|
|||
|
||||
torch::Tensor torques;
|
||||
|
||||
ros::Timer timer;
|
||||
|
||||
std::chrono::high_resolution_clock::time_point start_time;
|
||||
|
||||
// other rl module
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -5,6 +5,9 @@
|
|||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "matplotlibcpp.h"
|
||||
namespace plt = matplotlibcpp;
|
||||
|
||||
struct ModelParams
|
||||
{
|
||||
int num_observations;
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include "../include/rl_real.hpp"
|
||||
|
||||
// #define CONTROL_BY_TORQUE
|
||||
#define PLOT
|
||||
|
||||
RL_Real rl_sar;
|
||||
|
||||
|
@ -22,17 +23,17 @@ void RL_Real::RobotControl()
|
|||
memcpy(&_keyData, state.wirelessRemote, 40);
|
||||
|
||||
// get joy button
|
||||
if(init_state < STATE_POS_INIT && (int)_keyData.btn.components.R2 == 1)
|
||||
if(robot_state < STATE_POS_INIT && (int)_keyData.btn.components.R2 == 1)
|
||||
{
|
||||
init_state = STATE_POS_INIT;
|
||||
robot_state = STATE_POS_INIT;
|
||||
}
|
||||
else if(init_state < STATE_RL_INIT && (int)_keyData.btn.components.R1 == 1)
|
||||
else if(robot_state < STATE_RL_INIT && (int)_keyData.btn.components.R1 == 1)
|
||||
{
|
||||
init_state = STATE_RL_INIT;
|
||||
robot_state = STATE_RL_INIT;
|
||||
}
|
||||
|
||||
// wait for standup
|
||||
if(init_state == STATE_WAITING)
|
||||
if(robot_state == STATE_WAITING)
|
||||
{
|
||||
for(int i = 0; i < 12; ++i)
|
||||
{
|
||||
|
@ -41,10 +42,9 @@ void RL_Real::RobotControl()
|
|||
}
|
||||
}
|
||||
// standup (position control)
|
||||
else if(init_state == STATE_POS_INIT && _percent != 1)
|
||||
else if(robot_state == STATE_POS_INIT && _percent != 1)
|
||||
{
|
||||
printf("initing %d%%\r", (int)(_percent*100));
|
||||
_percent += (float) 1 / 1000;
|
||||
_percent += 1 / 1000.0;
|
||||
_percent = _percent > 1 ? 1 : _percent;
|
||||
for(int i = 0; i < 12; ++i)
|
||||
{
|
||||
|
@ -55,47 +55,28 @@ void RL_Real::RobotControl()
|
|||
cmd.motorCmd[i].Kd = 3;
|
||||
cmd.motorCmd[i].tau = 0;
|
||||
}
|
||||
printf("initing %.3f%%\r", _percent*100.0);
|
||||
}
|
||||
// init obs and start rl loop
|
||||
else if(init_state == STATE_RL_INIT && _percent == 1)
|
||||
else if(robot_state == STATE_RL_INIT && _percent == 1)
|
||||
{
|
||||
init_state = STATE_RL_START;
|
||||
motiontime = 0;
|
||||
robot_state = STATE_RL_START;
|
||||
this->init_observations();
|
||||
printf("\nstart rl loop\n");
|
||||
loop_rl->start();
|
||||
}
|
||||
// rl loop
|
||||
else if(init_state == STATE_RL_START)
|
||||
else if(robot_state == STATE_RL_START)
|
||||
{
|
||||
// wait for 500 times
|
||||
if( motiontime < 500)
|
||||
{
|
||||
for(int i = 0; i < 12; ++i)
|
||||
{
|
||||
cmd.motorCmd[i].mode = 0x0A;
|
||||
cmd.motorCmd[i].q = params.default_dof_pos[0][dof_mapping[i]].item<double>();
|
||||
cmd.motorCmd[i].dq = 0;
|
||||
cmd.motorCmd[i].Kp = 50;
|
||||
cmd.motorCmd[i].Kd = 3;
|
||||
cmd.motorCmd[i].tau = 0;
|
||||
_startPos[i] = state.motorState[i].q;
|
||||
}
|
||||
}
|
||||
if( motiontime >= 500)
|
||||
{
|
||||
#ifdef CONTROL_BY_TORQUE
|
||||
for (int i = 0; i < 12; ++i)
|
||||
{
|
||||
float torque = torques[0][dof_mapping[i]].item<double>();
|
||||
// if(torque > 5.0f) torque = 5.0f;
|
||||
// if(torque < -5.0f) torque = -5.0f;
|
||||
cmd.motorCmd[i].mode = 0x0A;
|
||||
cmd.motorCmd[i].q = 0;
|
||||
cmd.motorCmd[i].dq = 0;
|
||||
cmd.motorCmd[i].Kp = 0;
|
||||
cmd.motorCmd[i].Kd = 0;
|
||||
cmd.motorCmd[i].tau = torque;
|
||||
cmd.motorCmd[i].tau = torques[0][dof_mapping[i]].item<double>();
|
||||
}
|
||||
#else
|
||||
for (int i = 0; i < 12; ++i)
|
||||
|
@ -108,7 +89,6 @@ void RL_Real::RobotControl()
|
|||
cmd.motorCmd[i].tau = 0;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
safe.PowerProtect(cmd, state, 7);
|
||||
|
@ -161,16 +141,22 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
|
|||
|
||||
torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||
target_dof_pos = params.default_dof_pos;
|
||||
_real_joint_pos.resize(12);
|
||||
_target_joint_pos.resize(12);
|
||||
|
||||
// InitEnvironment();
|
||||
loop_control = std::make_shared<LoopFunc>("control_loop", 0.002, boost::bind(&RL_Real::RobotControl, this));
|
||||
loop_udpSend = std::make_shared<LoopFunc>("udp_send" , 0.002, 3, boost::bind(&RL_Real::UDPSend, this));
|
||||
loop_udpRecv = std::make_shared<LoopFunc>("udp_recv" , 0.002, 3, boost::bind(&RL_Real::UDPRecv, this));
|
||||
loop_rl = std::make_shared<LoopFunc>("rl_loop" , 0.02 , boost::bind(&RL_Real::runModel, this));
|
||||
loop_control = std::make_shared<LoopFunc>("loop_control", 0.002, boost::bind(&RL_Real::RobotControl, this));
|
||||
loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend", 0.002, 3, boost::bind(&RL_Real::UDPSend, this));
|
||||
loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv", 0.002, 3, boost::bind(&RL_Real::UDPRecv, this));
|
||||
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Real::runModel, this));
|
||||
|
||||
loop_udpSend->start();
|
||||
loop_udpRecv->start();
|
||||
loop_control->start();
|
||||
|
||||
#ifdef PLOT
|
||||
loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Real::plot, this));
|
||||
loop_plot->start();
|
||||
#endif
|
||||
}
|
||||
|
||||
RL_Real::~RL_Real()
|
||||
|
@ -179,21 +165,52 @@ RL_Real::~RL_Real()
|
|||
loop_udpRecv->shutdown();
|
||||
loop_control->shutdown();
|
||||
loop_rl->shutdown();
|
||||
printf("shutdown\n");
|
||||
#ifdef PLOT
|
||||
loop_plot->shutdown();
|
||||
#endif
|
||||
printf("exit\n");
|
||||
}
|
||||
|
||||
void RL_Real::plot()
|
||||
{
|
||||
_t.push_back(motiontime);
|
||||
plt::cla();
|
||||
plt::clf();
|
||||
for(int i = 0; i < 12; ++i)
|
||||
{
|
||||
_real_joint_pos[i].push_back(state.motorState[i].q);
|
||||
_target_joint_pos[i].push_back(cmd.motorCmd[i].q);
|
||||
plt::subplot(4, 3, i+1);
|
||||
plt::named_plot("_real_joint_pos", _t, _real_joint_pos[i], "r");
|
||||
plt::named_plot("_target_joint_pos", _t, _target_joint_pos[i], "b");
|
||||
plt::xlim(motiontime-10000, motiontime);
|
||||
}
|
||||
// plt::legend();
|
||||
plt::pause(0.0001);
|
||||
}
|
||||
|
||||
void RL_Real::runModel()
|
||||
{
|
||||
if(init_state == STATE_RL_START)
|
||||
if(robot_state == STATE_RL_START)
|
||||
{
|
||||
// auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
||||
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
||||
// start_time = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// printf("%f, %f, %f\n", state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]);
|
||||
// printf("%f, %f, %f, %f\n", state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]);
|
||||
// printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", state.motorState[FL_0].q, state.motorState[FL_1].q, state.motorState[FL_2].q, state.motorState[FR_0].q, state.motorState[FR_1].q, state.motorState[FR_2].q, state.motorState[RL_0].q, state.motorState[RL_1].q, state.motorState[RL_2].q, state.motorState[RR_0].q, state.motorState[RR_1].q, state.motorState[RR_2].q);
|
||||
// printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", state.motorState[FL_0].dq, state.motorState[FL_1].dq, state.motorState[FL_2].dq, state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq, state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq, state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq);
|
||||
// printf("%f, %f, %f\n",
|
||||
// state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]);
|
||||
// printf("%f, %f, %f, %f\n",
|
||||
// state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]);
|
||||
// printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n",
|
||||
// state.motorState[FL_0].q, state.motorState[FL_1].q, state.motorState[FL_2].q,
|
||||
// state.motorState[FR_0].q, state.motorState[FR_1].q, state.motorState[FR_2].q,
|
||||
// state.motorState[RL_0].q, state.motorState[RL_1].q, state.motorState[RL_2].q,
|
||||
// state.motorState[RR_0].q, state.motorState[RR_1].q, state.motorState[RR_2].q);
|
||||
// printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n",
|
||||
// state.motorState[FL_0].dq, state.motorState[FL_1].dq, state.motorState[FL_2].dq,
|
||||
// state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq,
|
||||
// state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq,
|
||||
// state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq);
|
||||
|
||||
this->obs.ang_vel = torch::tensor({{state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]}});
|
||||
this->obs.commands = torch::tensor({{_keyData.ly, -_keyData.rx, -_keyData.lx}});
|
||||
|
|
|
@ -1,5 +1,44 @@
|
|||
#include "../include/rl_sim.hpp"
|
||||
|
||||
#define PLOT
|
||||
|
||||
void RL_Sim::RobotControl()
|
||||
{
|
||||
motiontime++;
|
||||
for (int i = 0; i < 12; ++i)
|
||||
{
|
||||
motor_commands[i].mode = 0x0A;
|
||||
// motor_commands[i].tau = torques[0][i].item<double>();
|
||||
motor_commands[i].tau = 0;
|
||||
motor_commands[i].q = target_dof_pos[0][i].item<double>();
|
||||
motor_commands[i].dq = 0;
|
||||
motor_commands[i].Kp = params.stiffness;
|
||||
motor_commands[i].Kd = params.damping;
|
||||
|
||||
torque_publishers[joint_names[i]].publish(motor_commands[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void RL_Sim::plot()
|
||||
{
|
||||
int dof_mapping[13] = {1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9};
|
||||
_t.push_back(motiontime);
|
||||
plt::cla();
|
||||
plt::clf();
|
||||
for(int i = 0; i < 12; ++i)
|
||||
{
|
||||
_real_joint_pos[i].push_back(joint_positions[dof_mapping[i]]);
|
||||
_target_joint_pos[i].push_back(motor_commands[i].q);
|
||||
plt::subplot(4, 3, i+1);
|
||||
plt::named_plot("_real_joint_pos", _t, _real_joint_pos[i], "r");
|
||||
plt::named_plot("_target_joint_pos", _t, _target_joint_pos[i], "b");
|
||||
plt::xlim(motiontime-10000, motiontime);
|
||||
}
|
||||
// plt::legend();
|
||||
plt::pause(0.0001);
|
||||
}
|
||||
|
||||
RL_Sim::RL_Sim()
|
||||
{
|
||||
ros::NodeHandle nh;
|
||||
|
@ -50,11 +89,13 @@ RL_Sim::RL_Sim()
|
|||
|
||||
torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||
target_dof_pos = params.default_dof_pos;
|
||||
joint_positions = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
|
||||
joint_velocities = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
|
||||
_real_joint_pos.resize(12);
|
||||
_target_joint_pos.resize(12);
|
||||
|
||||
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>("/cmd_vel", 10, &RL_Sim::cmdvelCallback, this);
|
||||
|
||||
timer = nh.createTimer(ros::Duration(0.02), &RL_Sim::runModel, this);
|
||||
|
||||
std::string ros_namespace = "/a1_gazebo/";
|
||||
|
||||
joint_names = {
|
||||
|
@ -75,6 +116,26 @@ RL_Sim::RL_Sim()
|
|||
|
||||
joint_state_subscriber_ = nh.subscribe<sensor_msgs::JointState>(
|
||||
"/a1_gazebo/joint_states", 10, &RL_Sim::jointStatesCallback, this);
|
||||
|
||||
loop_control = std::make_shared<LoopFunc>("loop_control", 0.002, boost::bind(&RL_Sim::RobotControl, this));
|
||||
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Sim::runModel, this));
|
||||
|
||||
loop_control->start();
|
||||
loop_rl->start();
|
||||
#ifdef PLOT
|
||||
loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Sim::plot, this));
|
||||
loop_plot->start();
|
||||
#endif
|
||||
}
|
||||
|
||||
RL_Sim::~RL_Sim()
|
||||
{
|
||||
loop_control->shutdown();
|
||||
loop_rl->shutdown();
|
||||
#ifdef PLOT
|
||||
loop_plot->shutdown();
|
||||
#endif
|
||||
printf("exit\n");
|
||||
}
|
||||
|
||||
void RL_Sim::modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg)
|
||||
|
@ -95,12 +156,27 @@ void RL_Sim::jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
|
|||
joint_velocities = msg->velocity;
|
||||
}
|
||||
|
||||
void RL_Sim::runModel(const ros::TimerEvent &event)
|
||||
void RL_Sim::runModel()
|
||||
{
|
||||
// auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
||||
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
||||
// start_time = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// printf("%f, %f, %f\n",
|
||||
// vel.angular.x, vel.angular.y, vel.angular.z);
|
||||
// printf("%f, %f, %f, %f\n",
|
||||
// pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w);
|
||||
// printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n",
|
||||
// joint_positions[1], joint_positions[2], joint_positions[0],
|
||||
// joint_positions[4], joint_positions[5], joint_positions[3],
|
||||
// joint_positions[7], joint_positions[8], joint_positions[6],
|
||||
// joint_positions[10], joint_positions[11], joint_positions[9]);
|
||||
// printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n",
|
||||
// joint_velocities[1], joint_velocities[2], joint_velocities[0],
|
||||
// joint_velocities[4], joint_velocities[5], joint_velocities[3],
|
||||
// joint_velocities[7], joint_velocities[8], joint_velocities[6],
|
||||
// joint_velocities[10], joint_velocities[11], joint_velocities[9]);
|
||||
|
||||
this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}});
|
||||
this->obs.ang_vel = torch::tensor({{vel.angular.x, vel.angular.y, vel.angular.z}});
|
||||
this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}});
|
||||
|
@ -117,19 +193,6 @@ void RL_Sim::runModel(const ros::TimerEvent &event)
|
|||
torch::Tensor actions = this->forward();
|
||||
torques = this->compute_torques(actions);
|
||||
target_dof_pos = this->compute_pos(actions);
|
||||
|
||||
for (int i = 0; i < 12; ++i)
|
||||
{
|
||||
motor_commands[i].mode = 0x0A;
|
||||
// motor_commands[i].tau = torques[0][i].item<double>();
|
||||
motor_commands[i].tau = 0;
|
||||
motor_commands[i].q = target_dof_pos[0][i].item<double>();
|
||||
motor_commands[i].dq = 0;
|
||||
motor_commands[i].Kp = params.stiffness;
|
||||
motor_commands[i].Kd = params.damping;
|
||||
|
||||
torque_publishers[joint_names[i]].publish(motor_commands[i]);
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor RL_Sim::compute_observation()
|
||||
|
|
Loading…
Reference in New Issue