mirror of https://github.com/fan-ziqi/rl_sar.git
89 lines
2.3 KiB
C++
89 lines
2.3 KiB
C++
#ifndef RL_REAL_HPP
|
|
#define RL_REAL_HPP
|
|
|
|
#include "../library/rl/rl.hpp"
|
|
#include "../library/observation_buffer/observation_buffer.hpp"
|
|
#include <unitree_legged_msgs/LowCmd.h>
|
|
#include "unitree_legged_msgs/LowState.h"
|
|
#include <unitree_legged_msgs/MotorCmd.h>
|
|
#include <unitree_legged_msgs/MotorState.h>
|
|
#include "unitree_legged_sdk/unitree_legged_sdk.h"
|
|
#include "unitree_legged_sdk/unitree_joystick.h"
|
|
#include <pthread.h>
|
|
#include <csignal>
|
|
// #include <signal.h>
|
|
|
|
using namespace UNITREE_LEGGED_SDK;
|
|
|
|
enum InitState {
|
|
STATE_WAITING = 0,
|
|
STATE_POS_INIT,
|
|
STATE_RL_INIT,
|
|
STATE_RL_START,
|
|
STATE_POS_STOP,
|
|
};
|
|
|
|
class RL_Real : public RL
|
|
{
|
|
public:
|
|
RL_Real();
|
|
~RL_Real();
|
|
void runModel();
|
|
torch::Tensor forward() override;
|
|
torch::Tensor compute_observation() override;
|
|
|
|
ObservationBuffer history_obs_buf;
|
|
torch::Tensor history_obs;
|
|
|
|
//udp
|
|
void UDPSend();
|
|
void UDPRecv();
|
|
void RobotControl();
|
|
Safety safe;
|
|
UDP udp;
|
|
LowCmd cmd = {0};
|
|
LowState state = {0};
|
|
xRockerBtnDataStruct _keyData;
|
|
int motiontime = 0;
|
|
|
|
std::shared_ptr<LoopFunc> loop_control;
|
|
std::shared_ptr<LoopFunc> loop_udpSend;
|
|
std::shared_ptr<LoopFunc> loop_udpRecv;
|
|
std::shared_ptr<LoopFunc> loop_rl;
|
|
|
|
float _percent;
|
|
// float _targetPos[12] = {0.0, 0.8, -1.6, 0.0, 0.8, -1.6,
|
|
// 0.0, 0.8, -1.6, 0.0, 0.8, -1.6};
|
|
float _startPos[12];
|
|
|
|
int init_state = STATE_WAITING;
|
|
|
|
private:
|
|
std::vector<std::string> joint_names;
|
|
std::vector<double> joint_positions;
|
|
std::vector<double> joint_velocities;
|
|
|
|
torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
|
int dof_mapping[13] = {3, 4, 5,
|
|
0, 1, 2,
|
|
9, 10, 11,
|
|
6, 7, 8};
|
|
float Kp[13] = {20, 10, 10,
|
|
20, 10, 10,
|
|
20, 10, 10,
|
|
20, 10, 10};
|
|
float Kd[13] = {1.0, 0.5, 0.5,
|
|
1.0, 0.5, 0.5,
|
|
1.0, 0.5, 0.5,
|
|
1.0, 0.5, 0.5};
|
|
torch::Tensor target_dof_pos;
|
|
torch::Tensor compute_pos(torch::Tensor actions);
|
|
|
|
std::chrono::high_resolution_clock::time_point start_time;
|
|
|
|
// other rl module
|
|
torch::jit::script::Module encoder;
|
|
torch::jit::script::Module vq;
|
|
};
|
|
|
|
#endif |