mirror of https://github.com/fan-ziqi/rl_sar.git
feat: add more state && format
This commit is contained in:
parent
6d3ac75843
commit
74a98ea429
|
@ -1,15 +1,17 @@
|
|||
cmake_minimum_required(VERSION 3.0.2)
|
||||
project(rl_sar)
|
||||
|
||||
add_definitions(-DCMAKE_CURRENT_SOURCE_DIR="${CMAKE_CURRENT_SOURCE_DIR}")
|
||||
|
||||
set(CMAKE_BUILD_TYPE Debug)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g")
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
|
||||
|
||||
add_definitions(-DCMAKE_CURRENT_SOURCE_DIR="${CMAKE_CURRENT_SOURCE_DIR}")
|
||||
|
||||
find_package(Torch REQUIRED)
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GAZEBO_CXX_FLAGS}")
|
||||
find_package(gazebo REQUIRED)
|
||||
|
||||
find_package(catkin REQUIRED COMPONENTS
|
||||
controller_manager
|
||||
genmsg
|
||||
|
@ -23,7 +25,6 @@ find_package(catkin REQUIRED COMPONENTS
|
|||
unitree_legged_msgs
|
||||
)
|
||||
|
||||
find_package(gazebo REQUIRED)
|
||||
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
|
||||
|
||||
catkin_package(
|
||||
|
@ -43,8 +44,6 @@ include_directories(
|
|||
library/matplotlibcpp
|
||||
)
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GAZEBO_CXX_FLAGS}")
|
||||
|
||||
add_library(rl library/rl/rl.cpp library/rl/rl.hpp)
|
||||
target_link_libraries(rl "${TORCH_LIBRARIES}" Python3::Python Python3::Module)
|
||||
set_property(TARGET rl PROPERTY CXX_STANDARD 14)
|
||||
|
|
|
@ -14,9 +14,9 @@
|
|||
|
||||
using namespace UNITREE_LEGGED_SDK;
|
||||
|
||||
enum InitState {
|
||||
enum RobotState {
|
||||
STATE_WAITING = 0,
|
||||
STATE_POS_INIT,
|
||||
STATE_POS_START,
|
||||
STATE_RL_INIT,
|
||||
STATE_RL_START,
|
||||
STATE_POS_STOP,
|
||||
|
@ -28,23 +28,23 @@ public:
|
|||
RL_Real();
|
||||
~RL_Real();
|
||||
|
||||
void runModel();
|
||||
torch::Tensor forward() override;
|
||||
torch::Tensor compute_observation() override;
|
||||
void RunModel();
|
||||
torch::Tensor Forward() override;
|
||||
torch::Tensor ComputeObservation() override;
|
||||
|
||||
ObservationBuffer history_obs_buf;
|
||||
torch::Tensor history_obs;
|
||||
int motiontime = 0;
|
||||
|
||||
//udp
|
||||
void UDPSend();
|
||||
void UDPRecv();
|
||||
void UDPSend(){udp.Send();}
|
||||
void UDPRecv(){udp.Recv();}
|
||||
void RobotControl();
|
||||
Safety safe;
|
||||
UDP udp;
|
||||
LowCmd cmd = {0};
|
||||
LowState state = {0};
|
||||
xRockerBtnDataStruct _keyData;
|
||||
int motiontime = 0;
|
||||
|
||||
std::shared_ptr<LoopFunc> loop_control;
|
||||
std::shared_ptr<LoopFunc> loop_udpSend;
|
||||
|
@ -52,14 +52,16 @@ public:
|
|||
std::shared_ptr<LoopFunc> loop_rl;
|
||||
std::shared_ptr<LoopFunc> loop_plot;
|
||||
|
||||
float _percent;
|
||||
float _startPos[12];
|
||||
float start_percent = 0.0;
|
||||
float stop_percent = 0.0;
|
||||
float start_pos[12];
|
||||
float stop_pos[12];
|
||||
|
||||
int robot_state = STATE_WAITING;
|
||||
|
||||
std::vector<double> _t;
|
||||
std::vector<std::vector<double>> _real_joint_pos, _target_joint_pos;
|
||||
void plot();
|
||||
std::vector<double> plot_t;
|
||||
std::vector<std::vector<double>> plot_real_joint_pos, plot_target_joint_pos;
|
||||
void Plot();
|
||||
private:
|
||||
std::vector<std::string> joint_names;
|
||||
std::vector<double> joint_positions;
|
||||
|
|
|
@ -19,14 +19,14 @@ public:
|
|||
RL_Sim();
|
||||
~RL_Sim();
|
||||
|
||||
void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
|
||||
void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
|
||||
void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
|
||||
void ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
|
||||
void JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
|
||||
void CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
|
||||
|
||||
void runModel();
|
||||
void RunModel();
|
||||
void RobotControl();
|
||||
torch::Tensor forward() override;
|
||||
torch::Tensor compute_observation() override;
|
||||
torch::Tensor Forward() override;
|
||||
torch::Tensor ComputeObservation() override;
|
||||
|
||||
ObservationBuffer history_obs_buf;
|
||||
torch::Tensor history_obs;
|
||||
|
@ -37,9 +37,9 @@ public:
|
|||
std::shared_ptr<LoopFunc> loop_rl;
|
||||
std::shared_ptr<LoopFunc> loop_plot;
|
||||
|
||||
std::vector<double> _t;
|
||||
std::vector<std::vector<double>> _real_joint_pos, _target_joint_pos;
|
||||
void plot();
|
||||
std::vector<double> plot_t;
|
||||
std::vector<std::vector<double>> plot_real_joint_pos, plot_target_joint_pos;
|
||||
void Plot();
|
||||
private:
|
||||
std::vector<std::string> torque_command_topics;
|
||||
|
||||
|
@ -58,8 +58,6 @@ private:
|
|||
std::vector<double> joint_positions;
|
||||
std::vector<double> joint_velocities;
|
||||
|
||||
torch::Tensor torques;
|
||||
|
||||
std::chrono::high_resolution_clock::time_point start_time;
|
||||
|
||||
// other rl module
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#include "rl.hpp"
|
||||
|
||||
torch::Tensor RL::quat_rotate_inverse(torch::Tensor q, torch::Tensor v)
|
||||
torch::Tensor RL::QuatRotateInverse(torch::Tensor q, torch::Tensor v)
|
||||
{
|
||||
c10::IntArrayRef shape = q.sizes();
|
||||
torch::Tensor q_w = q.index({torch::indexing::Slice(), -1});
|
||||
|
@ -11,7 +11,7 @@ torch::Tensor RL::quat_rotate_inverse(torch::Tensor q, torch::Tensor v)
|
|||
return a - b + c;
|
||||
}
|
||||
|
||||
void RL::init_observations()
|
||||
void RL::InitObservations()
|
||||
{
|
||||
this->obs.lin_vel = torch::tensor({{0.0, 0.0, 0.0}});
|
||||
this->obs.ang_vel = torch::tensor({{0.0, 0.0, 0.0}});
|
||||
|
@ -23,7 +23,7 @@ void RL::init_observations()
|
|||
this->obs.actions = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||
}
|
||||
|
||||
torch::Tensor RL::compute_torques(torch::Tensor actions)
|
||||
torch::Tensor RL::ComputeTorques(torch::Tensor actions)
|
||||
{
|
||||
torch::Tensor actions_scaled = actions * this->params.action_scale;
|
||||
int indices[] = {0, 3, 6, 9};
|
||||
|
@ -32,12 +32,12 @@ torch::Tensor RL::compute_torques(torch::Tensor actions)
|
|||
actions_scaled[0][i] *= this->params.hip_scale_reduction;
|
||||
}
|
||||
|
||||
torch::Tensor torques = this->params.p_gains * (actions_scaled + this->params.default_dof_pos - this->obs.dof_pos) - this->params.d_gains * this->obs.dof_vel;
|
||||
torch::Tensor clamped = torch::clamp(torques, -(this->params.torque_limits), this->params.torque_limits);
|
||||
torch::Tensor output_torques = this->params.p_gains * (actions_scaled + this->params.default_dof_pos - this->obs.dof_pos) - this->params.d_gains * this->obs.dof_vel;
|
||||
torch::Tensor clamped = torch::clamp(output_torques, -(this->params.torque_limits), this->params.torque_limits);
|
||||
return clamped;
|
||||
}
|
||||
|
||||
torch::Tensor RL::compute_pos(torch::Tensor actions)
|
||||
torch::Tensor RL::ComputePosition(torch::Tensor actions)
|
||||
{
|
||||
torch::Tensor actions_scaled = actions * this->params.action_scale;
|
||||
int indices[] = {0, 3, 6, 9};
|
||||
|
@ -49,12 +49,12 @@ torch::Tensor RL::compute_pos(torch::Tensor actions)
|
|||
return actions_scaled + this->params.default_dof_pos;
|
||||
}
|
||||
|
||||
/* You may need to override this compute_observation() function
|
||||
torch::Tensor RL::compute_observation()
|
||||
/* You may need to override this ComputeObservation() function
|
||||
torch::Tensor RL::ComputeObservation()
|
||||
{
|
||||
torch::Tensor obs = torch::cat({(this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
|
||||
(this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale,
|
||||
this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec),
|
||||
torch::Tensor obs = torch::cat({(this->QuatRotateInverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
|
||||
(this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale,
|
||||
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec),
|
||||
this->obs.commands * this->params.commands_scale,
|
||||
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
|
||||
this->obs.dof_vel * this->params.dof_vel_scale,
|
||||
|
@ -69,10 +69,10 @@ torch::Tensor RL::compute_observation()
|
|||
}
|
||||
*/
|
||||
|
||||
/* You may need to override this forward() function
|
||||
torch::Tensor RL::forward()
|
||||
/* You may need to override this Forward() function
|
||||
torch::Tensor RL::Forward()
|
||||
{
|
||||
torch::Tensor obs = this->compute_observation();
|
||||
torch::Tensor obs = this->ComputeObservation();
|
||||
|
||||
torch::Tensor actor_input = torch::cat({obs}, 1);
|
||||
|
||||
|
|
|
@ -49,12 +49,12 @@ public:
|
|||
ModelParams params;
|
||||
Observations obs;
|
||||
|
||||
virtual torch::Tensor forward() = 0;
|
||||
virtual torch::Tensor compute_observation() = 0;
|
||||
torch::Tensor compute_torques(torch::Tensor actions);
|
||||
torch::Tensor compute_pos(torch::Tensor actions);
|
||||
torch::Tensor quat_rotate_inverse(torch::Tensor q, torch::Tensor v);
|
||||
void init_observations();
|
||||
virtual torch::Tensor Forward() = 0;
|
||||
virtual torch::Tensor ComputeObservation() = 0;
|
||||
torch::Tensor ComputeTorques(torch::Tensor actions);
|
||||
torch::Tensor ComputePosition(torch::Tensor actions);
|
||||
torch::Tensor QuatRotateInverse(torch::Tensor q, torch::Tensor v);
|
||||
void InitObservations();
|
||||
|
||||
protected:
|
||||
// rl module
|
||||
|
@ -69,8 +69,8 @@ protected:
|
|||
torch::Tensor dof_vel;
|
||||
torch::Tensor actions;
|
||||
// output buffer
|
||||
torch::Tensor torques;
|
||||
torch::Tensor target_dof_pos;
|
||||
torch::Tensor output_torques;
|
||||
torch::Tensor output_dof_pos;
|
||||
};
|
||||
|
||||
#endif // RL_HPP
|
|
@ -131,7 +131,7 @@ namespace UNITREE_LEGGED_SDK
|
|||
uint32_t SN;
|
||||
uint8_t bandWidth;
|
||||
uint8_t mode; // 0:idle, default stand 1:forced stand 2:walk continuously
|
||||
float forwardSpeed; // speed of move forward or backward, scale: -1~1
|
||||
float forwardSpeed; // speed of move Forward or backward, scale: -1~1
|
||||
float sideSpeed; // speed of move left or right, scale: -1~1
|
||||
float rotateSpeed; // speed of spin left or right, scale: -1~1
|
||||
float bodyHeight; // body height, scale: -1~1
|
||||
|
|
|
@ -1,20 +1,9 @@
|
|||
#include "../include/rl_real.hpp"
|
||||
|
||||
// #define CONTROL_BY_TORQUE
|
||||
// #define PLOT
|
||||
|
||||
RL_Real rl_sar;
|
||||
|
||||
void RL_Real::UDPRecv()
|
||||
{
|
||||
udp.Recv();
|
||||
}
|
||||
|
||||
void RL_Real::UDPSend()
|
||||
{
|
||||
udp.Send();
|
||||
}
|
||||
|
||||
void RL_Real::RobotControl()
|
||||
{
|
||||
motiontime++;
|
||||
|
@ -23,14 +12,28 @@ void RL_Real::RobotControl()
|
|||
memcpy(&_keyData, state.wirelessRemote, 40);
|
||||
|
||||
// get joy button
|
||||
if(robot_state < STATE_POS_INIT && (int)_keyData.btn.components.R2 == 1)
|
||||
if(robot_state < STATE_POS_START && (int)_keyData.btn.components.R2 == 1)
|
||||
{
|
||||
robot_state = STATE_POS_INIT;
|
||||
start_percent = 0.0;
|
||||
for(int i = 0; i < 12; ++i)
|
||||
{
|
||||
start_pos[i] = state.motorState[i].q;
|
||||
}
|
||||
robot_state = STATE_POS_START;
|
||||
}
|
||||
else if(robot_state < STATE_RL_INIT && (int)_keyData.btn.components.R1 == 1)
|
||||
{
|
||||
robot_state = STATE_RL_INIT;
|
||||
}
|
||||
else if(robot_state == STATE_RL_START && (int)_keyData.btn.components.L2 == 1)
|
||||
{
|
||||
stop_percent = 0.0;
|
||||
for(int i = 0; i < 12; ++i)
|
||||
{
|
||||
stop_pos[i] = state.motorState[i].q;
|
||||
}
|
||||
robot_state = STATE_POS_STOP;
|
||||
}
|
||||
|
||||
// wait for standup
|
||||
if(robot_state == STATE_WAITING)
|
||||
|
@ -38,57 +41,68 @@ void RL_Real::RobotControl()
|
|||
for(int i = 0; i < 12; ++i)
|
||||
{
|
||||
cmd.motorCmd[i].q = state.motorState[i].q;
|
||||
_startPos[i] = state.motorState[i].q;
|
||||
}
|
||||
}
|
||||
// standup (position control)
|
||||
else if(robot_state == STATE_POS_INIT && _percent != 1)
|
||||
else if(robot_state == STATE_POS_START && start_percent != 1)
|
||||
{
|
||||
_percent += 1 / 1000.0;
|
||||
_percent = _percent > 1 ? 1 : _percent;
|
||||
start_percent += 1 / 1000.0;
|
||||
start_percent = start_percent > 1 ? 1 : start_percent;
|
||||
for(int i = 0; i < 12; ++i)
|
||||
{
|
||||
cmd.motorCmd[i].mode = 0x0A;
|
||||
cmd.motorCmd[i].q = (1 - _percent) * _startPos[i] + _percent * params.default_dof_pos[0][dof_mapping[i]].item<double>();
|
||||
cmd.motorCmd[i].q = (1 - start_percent) * start_pos[i] + start_percent * params.default_dof_pos[0][dof_mapping[i]].item<double>();
|
||||
cmd.motorCmd[i].dq = 0;
|
||||
cmd.motorCmd[i].Kp = 50;
|
||||
cmd.motorCmd[i].Kd = 3;
|
||||
cmd.motorCmd[i].tau = 0;
|
||||
}
|
||||
printf("initing %.3f%%\r", _percent*100.0);
|
||||
printf("starting %.3f%%\r", start_percent*100.0);
|
||||
}
|
||||
// init obs and start rl loop
|
||||
else if(robot_state == STATE_RL_INIT && _percent == 1)
|
||||
else if(robot_state == STATE_RL_INIT && start_percent == 1)
|
||||
{
|
||||
robot_state = STATE_RL_START;
|
||||
this->init_observations();
|
||||
this->InitObservations();
|
||||
printf("\nstart rl loop\n");
|
||||
loop_rl->start();
|
||||
}
|
||||
// rl loop
|
||||
else if(robot_state == STATE_RL_START)
|
||||
{
|
||||
#ifdef CONTROL_BY_TORQUE
|
||||
for (int i = 0; i < 12; ++i)
|
||||
{
|
||||
cmd.motorCmd[i].mode = 0x0A;
|
||||
cmd.motorCmd[i].q = 0;
|
||||
cmd.motorCmd[i].dq = 0;
|
||||
cmd.motorCmd[i].Kp = 0;
|
||||
cmd.motorCmd[i].Kd = 0;
|
||||
cmd.motorCmd[i].tau = torques[0][dof_mapping[i]].item<double>();
|
||||
}
|
||||
#else
|
||||
for (int i = 0; i < 12; ++i)
|
||||
{
|
||||
cmd.motorCmd[i].mode = 0x0A;
|
||||
cmd.motorCmd[i].q = target_dof_pos[0][dof_mapping[i]].item<double>();
|
||||
cmd.motorCmd[i].dq = 0;
|
||||
cmd.motorCmd[i].Kp = params.stiffness;
|
||||
cmd.motorCmd[i].Kd = params.damping;
|
||||
cmd.motorCmd[i].tau = 0;
|
||||
}
|
||||
#endif
|
||||
for (int i = 0; i < 12; ++i)
|
||||
{
|
||||
cmd.motorCmd[i].mode = 0x0A;
|
||||
cmd.motorCmd[i].q = output_dof_pos[0][dof_mapping[i]].item<double>();
|
||||
cmd.motorCmd[i].dq = 0;
|
||||
cmd.motorCmd[i].Kp = params.stiffness;
|
||||
cmd.motorCmd[i].Kd = params.damping;
|
||||
// cmd.motorCmd[i].tau = output_torques[0][dof_mapping[i]].item<double>();
|
||||
cmd.motorCmd[i].tau = 0;
|
||||
}
|
||||
}
|
||||
// move to start pos
|
||||
else if(robot_state == STATE_POS_STOP && stop_percent != 1)
|
||||
{
|
||||
stop_percent += 1 / 1000.0;
|
||||
stop_percent = stop_percent > 1 ? 1 : stop_percent;
|
||||
for(int i = 0; i < 12; ++i)
|
||||
{
|
||||
cmd.motorCmd[i].mode = 0x0A;
|
||||
cmd.motorCmd[i].q = (1 - stop_percent) * stop_pos[i] + stop_percent * start_pos[i];
|
||||
cmd.motorCmd[i].dq = 0;
|
||||
cmd.motorCmd[i].Kp = 50;
|
||||
cmd.motorCmd[i].Kd = 3;
|
||||
cmd.motorCmd[i].tau = 0;
|
||||
}
|
||||
printf("stopping %.3f%%\r", stop_percent*100.0);
|
||||
}
|
||||
else if(robot_state == STATE_POS_STOP && stop_percent == 1)
|
||||
{
|
||||
robot_state = STATE_WAITING;
|
||||
this->InitObservations();
|
||||
printf("\nstop rl loop\n");
|
||||
loop_rl->shutdown();
|
||||
}
|
||||
|
||||
safe.PowerProtect(cmd, state, 7);
|
||||
|
@ -108,7 +122,7 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
|
|||
this->actor = torch::jit::load(actor_path);
|
||||
this->encoder = torch::jit::load(encoder_path);
|
||||
this->vq = torch::jit::load(vq_path);
|
||||
this->init_observations();
|
||||
this->InitObservations();
|
||||
|
||||
this->params.num_observations = 45;
|
||||
this->params.clip_obs = 100.0;
|
||||
|
@ -139,22 +153,22 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
|
|||
|
||||
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
||||
|
||||
torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||
target_dof_pos = params.default_dof_pos;
|
||||
_real_joint_pos.resize(12);
|
||||
_target_joint_pos.resize(12);
|
||||
output_torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||
output_dof_pos = params.default_dof_pos;
|
||||
plot_real_joint_pos.resize(12);
|
||||
plot_target_joint_pos.resize(12);
|
||||
|
||||
loop_control = std::make_shared<LoopFunc>("loop_control", 0.002, boost::bind(&RL_Real::RobotControl, this));
|
||||
loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend", 0.002, 3, boost::bind(&RL_Real::UDPSend, this));
|
||||
loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv", 0.002, 3, boost::bind(&RL_Real::UDPRecv, this));
|
||||
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Real::runModel, this));
|
||||
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Real::RunModel, this));
|
||||
|
||||
loop_udpSend->start();
|
||||
loop_udpRecv->start();
|
||||
loop_control->start();
|
||||
|
||||
#ifdef PLOT
|
||||
loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Real::plot, this));
|
||||
loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Real::Plot, this));
|
||||
loop_plot->start();
|
||||
#endif
|
||||
}
|
||||
|
@ -171,25 +185,25 @@ RL_Real::~RL_Real()
|
|||
printf("exit\n");
|
||||
}
|
||||
|
||||
void RL_Real::plot()
|
||||
void RL_Real::Plot()
|
||||
{
|
||||
_t.push_back(motiontime);
|
||||
plot_t.push_back(motiontime);
|
||||
plt::cla();
|
||||
plt::clf();
|
||||
for(int i = 0; i < 12; ++i)
|
||||
{
|
||||
_real_joint_pos[i].push_back(state.motorState[i].q);
|
||||
_target_joint_pos[i].push_back(cmd.motorCmd[i].q);
|
||||
plot_real_joint_pos[i].push_back(state.motorState[i].q);
|
||||
plot_target_joint_pos[i].push_back(cmd.motorCmd[i].q);
|
||||
plt::subplot(4, 3, i+1);
|
||||
plt::named_plot("_real_joint_pos", _t, _real_joint_pos[i], "r");
|
||||
plt::named_plot("_target_joint_pos", _t, _target_joint_pos[i], "b");
|
||||
plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r");
|
||||
plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b");
|
||||
plt::xlim(motiontime-10000, motiontime);
|
||||
}
|
||||
// plt::legend();
|
||||
plt::pause(0.0001);
|
||||
}
|
||||
|
||||
void RL_Real::runModel()
|
||||
void RL_Real::RunModel()
|
||||
{
|
||||
if(robot_state == STATE_RL_START)
|
||||
{
|
||||
|
@ -224,21 +238,19 @@ void RL_Real::runModel()
|
|||
state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq,
|
||||
state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq}});
|
||||
|
||||
torch::Tensor actions = this->forward();
|
||||
#ifdef CONTROL_BY_TORQUE
|
||||
torques = this->compute_torques(actions);
|
||||
#else
|
||||
target_dof_pos = this->compute_pos(actions);
|
||||
#endif
|
||||
torch::Tensor actions = this->Forward();
|
||||
|
||||
output_torques = this->ComputeTorques(actions);
|
||||
output_dof_pos = this->ComputePosition(actions);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
torch::Tensor RL_Real::compute_observation()
|
||||
torch::Tensor RL_Real::ComputeObservation()
|
||||
{
|
||||
torch::Tensor obs = torch::cat({// (this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
|
||||
this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
|
||||
this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec),
|
||||
torch::Tensor obs = torch::cat({// (this->QuatRotateInverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
|
||||
this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
|
||||
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec),
|
||||
this->obs.commands * this->params.commands_scale,
|
||||
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
|
||||
this->obs.dof_vel * this->params.dof_vel_scale,
|
||||
|
@ -248,9 +260,9 @@ torch::Tensor RL_Real::compute_observation()
|
|||
return obs;
|
||||
}
|
||||
|
||||
torch::Tensor RL_Real::forward()
|
||||
torch::Tensor RL_Real::Forward()
|
||||
{
|
||||
torch::Tensor obs = this->compute_observation();
|
||||
torch::Tensor obs = this->ComputeObservation();
|
||||
|
||||
history_obs_buf.insert(obs);
|
||||
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
|
||||
|
|
|
@ -8,31 +8,30 @@ void RL_Sim::RobotControl()
|
|||
for (int i = 0; i < 12; ++i)
|
||||
{
|
||||
motor_commands[i].mode = 0x0A;
|
||||
// motor_commands[i].tau = torques[0][i].item<double>();
|
||||
motor_commands[i].tau = 0;
|
||||
motor_commands[i].q = target_dof_pos[0][i].item<double>();
|
||||
motor_commands[i].q = output_dof_pos[0][i].item<double>();
|
||||
motor_commands[i].dq = 0;
|
||||
motor_commands[i].Kp = params.stiffness;
|
||||
motor_commands[i].Kd = params.damping;
|
||||
// motor_commands[i].tau = output_torques[0][i].item<double>();
|
||||
motor_commands[i].tau = 0;
|
||||
|
||||
torque_publishers[joint_names[i]].publish(motor_commands[i]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void RL_Sim::plot()
|
||||
void RL_Sim::Plot()
|
||||
{
|
||||
int dof_mapping[13] = {1, 2, 0, 4, 5, 3, 7, 8, 6, 10, 11, 9};
|
||||
_t.push_back(motiontime);
|
||||
plot_t.push_back(motiontime);
|
||||
plt::cla();
|
||||
plt::clf();
|
||||
for(int i = 0; i < 12; ++i)
|
||||
{
|
||||
_real_joint_pos[i].push_back(joint_positions[dof_mapping[i]]);
|
||||
_target_joint_pos[i].push_back(motor_commands[i].q);
|
||||
plot_real_joint_pos[i].push_back(joint_positions[dof_mapping[i]]);
|
||||
plot_target_joint_pos[i].push_back(motor_commands[i].q);
|
||||
plt::subplot(4, 3, i+1);
|
||||
plt::named_plot("_real_joint_pos", _t, _real_joint_pos[i], "r");
|
||||
plt::named_plot("_target_joint_pos", _t, _target_joint_pos[i], "b");
|
||||
plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r");
|
||||
plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b");
|
||||
plt::xlim(motiontime-10000, motiontime);
|
||||
}
|
||||
// plt::legend();
|
||||
|
@ -55,7 +54,7 @@ RL_Sim::RL_Sim()
|
|||
this->actor = torch::jit::load(actor_path);
|
||||
this->encoder = torch::jit::load(encoder_path);
|
||||
this->vq = torch::jit::load(vq_path);
|
||||
this->init_observations();
|
||||
this->InitObservations();
|
||||
|
||||
this->params.num_observations = 45;
|
||||
this->params.clip_obs = 100.0;
|
||||
|
@ -87,14 +86,14 @@ RL_Sim::RL_Sim()
|
|||
|
||||
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
||||
|
||||
torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||
target_dof_pos = params.default_dof_pos;
|
||||
output_torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||
output_dof_pos = params.default_dof_pos;
|
||||
joint_positions = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
|
||||
joint_velocities = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
|
||||
_real_joint_pos.resize(12);
|
||||
_target_joint_pos.resize(12);
|
||||
plot_real_joint_pos.resize(12);
|
||||
plot_target_joint_pos.resize(12);
|
||||
|
||||
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>("/cmd_vel", 10, &RL_Sim::cmdvelCallback, this);
|
||||
cmd_vel_subscriber_ = nh.subscribe<geometry_msgs::Twist>("/cmd_vel", 10, &RL_Sim::CmdvelCallback, this);
|
||||
|
||||
std::string ros_namespace = "/a1_gazebo/";
|
||||
|
||||
|
@ -112,18 +111,18 @@ RL_Sim::RL_Sim()
|
|||
}
|
||||
|
||||
model_state_subscriber_ = nh.subscribe<gazebo_msgs::ModelStates>(
|
||||
"/gazebo/model_states", 10, &RL_Sim::modelStatesCallback, this);
|
||||
"/gazebo/model_states", 10, &RL_Sim::ModelStatesCallback, this);
|
||||
|
||||
joint_state_subscriber_ = nh.subscribe<sensor_msgs::JointState>(
|
||||
"/a1_gazebo/joint_states", 10, &RL_Sim::jointStatesCallback, this);
|
||||
"/a1_gazebo/joint_states", 10, &RL_Sim::JointStatesCallback, this);
|
||||
|
||||
loop_control = std::make_shared<LoopFunc>("loop_control", 0.002, boost::bind(&RL_Sim::RobotControl, this));
|
||||
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Sim::runModel, this));
|
||||
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Sim::RunModel, this));
|
||||
|
||||
loop_control->start();
|
||||
loop_rl->start();
|
||||
#ifdef PLOT
|
||||
loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Sim::plot, this));
|
||||
loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Sim::Plot, this));
|
||||
loop_plot->start();
|
||||
#endif
|
||||
}
|
||||
|
@ -138,25 +137,25 @@ RL_Sim::~RL_Sim()
|
|||
printf("exit\n");
|
||||
}
|
||||
|
||||
void RL_Sim::modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg)
|
||||
void RL_Sim::ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg)
|
||||
{
|
||||
|
||||
vel = msg->twist[2];
|
||||
pose = msg->pose[2];
|
||||
}
|
||||
|
||||
void RL_Sim::cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg)
|
||||
void RL_Sim::CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg)
|
||||
{
|
||||
cmd_vel = *msg;
|
||||
}
|
||||
|
||||
void RL_Sim::jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
|
||||
void RL_Sim::JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
|
||||
{
|
||||
joint_positions = msg->position;
|
||||
joint_velocities = msg->velocity;
|
||||
}
|
||||
|
||||
void RL_Sim::runModel()
|
||||
void RL_Sim::RunModel()
|
||||
{
|
||||
// auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
||||
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
||||
|
@ -190,16 +189,17 @@ void RL_Sim::runModel()
|
|||
joint_velocities[7], joint_velocities[8], joint_velocities[6],
|
||||
joint_velocities[10], joint_velocities[11], joint_velocities[9]}});
|
||||
|
||||
torch::Tensor actions = this->forward();
|
||||
torques = this->compute_torques(actions);
|
||||
target_dof_pos = this->compute_pos(actions);
|
||||
torch::Tensor actions = this->Forward();
|
||||
|
||||
output_torques = this->ComputeTorques(actions);
|
||||
output_dof_pos = this->ComputePosition(actions);
|
||||
}
|
||||
|
||||
torch::Tensor RL_Sim::compute_observation()
|
||||
torch::Tensor RL_Sim::ComputeObservation()
|
||||
{
|
||||
torch::Tensor obs = torch::cat({// (this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
|
||||
(this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale,
|
||||
this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec),
|
||||
torch::Tensor obs = torch::cat({// (this->QuatRotateInverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
|
||||
(this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale,
|
||||
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec),
|
||||
this->obs.commands * this->params.commands_scale,
|
||||
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
|
||||
this->obs.dof_vel * this->params.dof_vel_scale,
|
||||
|
@ -209,9 +209,9 @@ torch::Tensor RL_Sim::compute_observation()
|
|||
return obs;
|
||||
}
|
||||
|
||||
torch::Tensor RL_Sim::forward()
|
||||
torch::Tensor RL_Sim::Forward()
|
||||
{
|
||||
torch::Tensor obs = this->compute_observation();
|
||||
torch::Tensor obs = this->ComputeObservation();
|
||||
|
||||
history_obs_buf.insert(obs);
|
||||
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
|
||||
|
|
Loading…
Reference in New Issue