mirror of https://github.com/fan-ziqi/rl_sar.git
喝醉了
This commit is contained in:
parent
63d28c2bff
commit
6a23e62a47
|
@ -8,14 +8,26 @@
|
||||||
#include <unitree_legged_msgs/MotorCmd.h>
|
#include <unitree_legged_msgs/MotorCmd.h>
|
||||||
#include <unitree_legged_msgs/MotorState.h>
|
#include <unitree_legged_msgs/MotorState.h>
|
||||||
#include "unitree_legged_sdk/unitree_legged_sdk.h"
|
#include "unitree_legged_sdk/unitree_legged_sdk.h"
|
||||||
|
#include "unitree_legged_sdk/unitree_joystick.h"
|
||||||
#include <pthread.h>
|
#include <pthread.h>
|
||||||
|
#include <csignal>
|
||||||
|
// #include <signal.h>
|
||||||
|
|
||||||
using namespace UNITREE_LEGGED_SDK;
|
using namespace UNITREE_LEGGED_SDK;
|
||||||
|
|
||||||
|
enum InitState {
|
||||||
|
STATE_WAITING = 0,
|
||||||
|
STATE_POS_INIT,
|
||||||
|
STATE_RL_INIT,
|
||||||
|
STATE_RL_START,
|
||||||
|
STATE_POS_STOP,
|
||||||
|
};
|
||||||
|
|
||||||
class RL_Real : public RL
|
class RL_Real : public RL
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
RL_Real();
|
RL_Real();
|
||||||
|
~RL_Real();
|
||||||
void runModel();
|
void runModel();
|
||||||
torch::Tensor forward() override;
|
torch::Tensor forward() override;
|
||||||
torch::Tensor compute_observation() override;
|
torch::Tensor compute_observation() override;
|
||||||
|
@ -23,8 +35,6 @@ public:
|
||||||
ObservationBuffer history_obs_buf;
|
ObservationBuffer history_obs_buf;
|
||||||
torch::Tensor history_obs;
|
torch::Tensor history_obs;
|
||||||
|
|
||||||
torch::Tensor torques;
|
|
||||||
|
|
||||||
//udp
|
//udp
|
||||||
void UDPSend();
|
void UDPSend();
|
||||||
void UDPRecv();
|
void UDPRecv();
|
||||||
|
@ -33,6 +43,7 @@ public:
|
||||||
UDP udp;
|
UDP udp;
|
||||||
LowCmd cmd = {0};
|
LowCmd cmd = {0};
|
||||||
LowState state = {0};
|
LowState state = {0};
|
||||||
|
xRockerBtnDataStruct _keyData;
|
||||||
int motiontime = 0;
|
int motiontime = 0;
|
||||||
|
|
||||||
std::shared_ptr<LoopFunc> loop_control;
|
std::shared_ptr<LoopFunc> loop_control;
|
||||||
|
@ -40,25 +51,34 @@ public:
|
||||||
std::shared_ptr<LoopFunc> loop_udpRecv;
|
std::shared_ptr<LoopFunc> loop_udpRecv;
|
||||||
std::shared_ptr<LoopFunc> loop_rl;
|
std::shared_ptr<LoopFunc> loop_rl;
|
||||||
|
|
||||||
|
|
||||||
float _percent;
|
float _percent;
|
||||||
float _targetPos[12] = {0.0, 0.8, -1.6, 0.0, 0.8, -1.6,
|
// float _targetPos[12] = {0.0, 0.8, -1.6, 0.0, 0.8, -1.6,
|
||||||
0.0, 0.8, -1.6, 0.0, 0.8, -1.6}; //0.0, 0.67, -1.3
|
// 0.0, 0.8, -1.6, 0.0, 0.8, -1.6};
|
||||||
float _startPos[12];
|
float _startPos[12];
|
||||||
|
|
||||||
bool init_done = false;
|
int init_state = STATE_WAITING;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string ros_namespace;
|
|
||||||
|
|
||||||
std::vector<std::string> torque_command_topics;
|
|
||||||
|
|
||||||
std::vector<unitree_legged_msgs::MotorCmd> torque_commands;
|
|
||||||
|
|
||||||
std::vector<std::string> joint_names;
|
std::vector<std::string> joint_names;
|
||||||
std::vector<double> joint_positions;
|
std::vector<double> joint_positions;
|
||||||
std::vector<double> joint_velocities;
|
std::vector<double> joint_velocities;
|
||||||
|
|
||||||
|
torch::Tensor torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
|
||||||
|
int dof_mapping[13] = {3, 4, 5,
|
||||||
|
0, 1, 2,
|
||||||
|
9, 10, 11,
|
||||||
|
6, 7, 8};
|
||||||
|
float Kp[13] = {20, 10, 10,
|
||||||
|
20, 10, 10,
|
||||||
|
20, 10, 10,
|
||||||
|
20, 10, 10};
|
||||||
|
float Kd[13] = {1.0, 0.5, 0.5,
|
||||||
|
1.0, 0.5, 0.5,
|
||||||
|
1.0, 0.5, 0.5,
|
||||||
|
1.0, 0.5, 0.5};
|
||||||
|
torch::Tensor target_dof_pos;
|
||||||
|
torch::Tensor compute_pos(torch::Tensor actions);
|
||||||
|
|
||||||
std::chrono::high_resolution_clock::time_point start_time;
|
std::chrono::high_resolution_clock::time_point start_time;
|
||||||
|
|
||||||
// other rl module
|
// other rl module
|
||||||
|
|
|
@ -5,40 +5,40 @@ Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved.
|
||||||
#define UNITREE_JOYSTICK_H
|
#define UNITREE_JOYSTICK_H
|
||||||
|
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
// // 16b
|
// 16b
|
||||||
// typedef union {
|
typedef union {
|
||||||
// struct {
|
struct {
|
||||||
// uint8_t R1 :1;
|
uint8_t R1 :1;
|
||||||
// uint8_t L1 :1;
|
uint8_t L1 :1;
|
||||||
// uint8_t start :1;
|
uint8_t start :1;
|
||||||
// uint8_t select :1;
|
uint8_t select :1;
|
||||||
// uint8_t R2 :1;
|
uint8_t R2 :1;
|
||||||
// uint8_t L2 :1;
|
uint8_t L2 :1;
|
||||||
// uint8_t F1 :1;
|
uint8_t F1 :1;
|
||||||
// uint8_t F2 :1;
|
uint8_t F2 :1;
|
||||||
// uint8_t A :1;
|
uint8_t A :1;
|
||||||
// uint8_t B :1;
|
uint8_t B :1;
|
||||||
// uint8_t X :1;
|
uint8_t X :1;
|
||||||
// uint8_t Y :1;
|
uint8_t Y :1;
|
||||||
// uint8_t up :1;
|
uint8_t up :1;
|
||||||
// uint8_t right :1;
|
uint8_t right :1;
|
||||||
// uint8_t down :1;
|
uint8_t down :1;
|
||||||
// uint8_t left :1;
|
uint8_t left :1;
|
||||||
// } components;
|
} components;
|
||||||
// uint16_t value;
|
uint16_t value;
|
||||||
// } xKeySwitchUnion;
|
} xKeySwitchUnion;
|
||||||
|
|
||||||
// // 40 Byte (now used 24B)
|
// 40 Byte (now used 24B)
|
||||||
// typedef struct {
|
typedef struct {
|
||||||
// uint8_t head[2];
|
uint8_t head[2];
|
||||||
// xKeySwitchUnion btn;
|
xKeySwitchUnion btn;
|
||||||
// float lx;
|
float lx;
|
||||||
// float rx;
|
float rx;
|
||||||
// float ry;
|
float ry;
|
||||||
// float L2;
|
float L2;
|
||||||
// float ly;
|
float ly;
|
||||||
|
|
||||||
// uint8_t idle[16];
|
uint8_t idle[16];
|
||||||
// } xRockerBtnDataStruct;
|
} xRockerBtnDataStruct;
|
||||||
|
|
||||||
#endif // UNITREE_JOYSTICK_H
|
#endif // UNITREE_JOYSTICK_H
|
|
@ -1,5 +1,9 @@
|
||||||
#include "../include/rl_real.hpp"
|
#include "../include/rl_real.hpp"
|
||||||
|
|
||||||
|
// #define CONTROL_BY_TORQUE
|
||||||
|
|
||||||
|
RL_Real rl_sar;
|
||||||
|
|
||||||
void RL_Real::UDPRecv()
|
void RL_Real::UDPRecv()
|
||||||
{
|
{
|
||||||
udp.Recv();
|
udp.Recv();
|
||||||
|
@ -15,7 +19,20 @@ void RL_Real::RobotControl()
|
||||||
motiontime++;
|
motiontime++;
|
||||||
udp.GetRecv(state);
|
udp.GetRecv(state);
|
||||||
|
|
||||||
if( motiontime < 500)
|
memcpy(&_keyData, state.wirelessRemote, 40);
|
||||||
|
|
||||||
|
// get joy button
|
||||||
|
if(init_state < STATE_POS_INIT && (int)_keyData.btn.components.R2 == 1)
|
||||||
|
{
|
||||||
|
init_state = STATE_POS_INIT;
|
||||||
|
}
|
||||||
|
else if(init_state < STATE_RL_INIT && (int)_keyData.btn.components.R1 == 1)
|
||||||
|
{
|
||||||
|
init_state = STATE_RL_INIT;
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for standup
|
||||||
|
if(init_state == STATE_WAITING)
|
||||||
{
|
{
|
||||||
for(int i = 0; i < 12; ++i)
|
for(int i = 0; i < 12; ++i)
|
||||||
{
|
{
|
||||||
|
@ -23,36 +40,39 @@ void RL_Real::RobotControl()
|
||||||
_startPos[i] = state.motorState[i].q;
|
_startPos[i] = state.motorState[i].q;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// standup (position control)
|
||||||
if( motiontime >= 500 && _percent != 1)
|
else if(init_state == STATE_POS_INIT && _percent != 1)
|
||||||
{
|
{
|
||||||
|
printf("initing %d%%\r", (int)(_percent*100));
|
||||||
_percent += (float) 1 / 1000;
|
_percent += (float) 1 / 1000;
|
||||||
_percent = _percent > 1 ? 1 : _percent;
|
_percent = _percent > 1 ? 1 : _percent;
|
||||||
for(int i = 0; i < 12; ++i)
|
for(int i = 0; i < 12; ++i)
|
||||||
{
|
{
|
||||||
cmd.motorCmd[i].q = (1 - _percent) * _startPos[i] + _percent * _targetPos[i];
|
cmd.motorCmd[i].q = (1 - _percent) * _startPos[i] + _percent * params.default_dof_pos[0][dof_mapping[i]].item<double>();
|
||||||
cmd.motorCmd[i].dq = 0;
|
cmd.motorCmd[i].dq = 0;
|
||||||
cmd.motorCmd[i].Kp = 50;
|
cmd.motorCmd[i].Kp = 50;
|
||||||
cmd.motorCmd[i].Kd = 3;
|
cmd.motorCmd[i].Kd = 3;
|
||||||
cmd.motorCmd[i].tau = 0;
|
cmd.motorCmd[i].tau = 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if(_percent == 1 && !init_done)
|
// init obs and start rl loop
|
||||||
|
else if(init_state == STATE_RL_INIT && _percent == 1)
|
||||||
{
|
{
|
||||||
init_done = true;
|
init_state = STATE_RL_START;
|
||||||
loop_rl->start();
|
|
||||||
this->init_observations();
|
|
||||||
std::cout << "init done" << std::endl;
|
|
||||||
motiontime = 0;
|
motiontime = 0;
|
||||||
|
this->init_observations();
|
||||||
|
printf("\nstart rl loop\n");
|
||||||
|
loop_rl->start();
|
||||||
}
|
}
|
||||||
|
// rl loop
|
||||||
if(init_done)
|
else if(init_state == STATE_RL_START)
|
||||||
{
|
{
|
||||||
|
// wait for 500 times
|
||||||
if( motiontime < 500)
|
if( motiontime < 500)
|
||||||
{
|
{
|
||||||
for(int i = 0; i < 12; ++i)
|
for(int i = 0; i < 12; ++i)
|
||||||
{
|
{
|
||||||
cmd.motorCmd[i].q = _targetPos[i];
|
cmd.motorCmd[i].q = params.default_dof_pos[0][dof_mapping[i]].item<double>();
|
||||||
cmd.motorCmd[i].dq = 0;
|
cmd.motorCmd[i].dq = 0;
|
||||||
cmd.motorCmd[i].Kp = 50;
|
cmd.motorCmd[i].Kp = 50;
|
||||||
cmd.motorCmd[i].Kd = 3;
|
cmd.motorCmd[i].Kd = 3;
|
||||||
|
@ -62,24 +82,33 @@ void RL_Real::RobotControl()
|
||||||
}
|
}
|
||||||
if( motiontime >= 500)
|
if( motiontime >= 500)
|
||||||
{
|
{
|
||||||
int mapping[13] = {3, 4, 5, 0, 1, 2, 9, 10, 11, 6, 7, 8};
|
#ifdef CONTROL_BY_TORQUE
|
||||||
|
|
||||||
for (int i = 0; i < 12; ++i)
|
for (int i = 0; i < 12; ++i)
|
||||||
{
|
{
|
||||||
float torque = torques[0][mapping[i]].item<double>();
|
float torque = torques[0][dof_mapping[i]].item<double>();
|
||||||
// if(torque > 5.0f) torque = 5.0f;
|
// if(torque > 5.0f) torque = 5.0f;
|
||||||
// if(torque < -5.0f) torque = -5.0f;
|
// if(torque < -5.0f) torque = -5.0f;
|
||||||
|
|
||||||
cmd.motorCmd[i].q = PosStopF;
|
cmd.motorCmd[i].q = 0;
|
||||||
cmd.motorCmd[i].dq = VelStopF;
|
cmd.motorCmd[i].dq = 0;
|
||||||
cmd.motorCmd[i].Kp = 0;
|
cmd.motorCmd[i].Kp = 0;
|
||||||
cmd.motorCmd[i].Kd = 0;
|
cmd.motorCmd[i].Kd = 0;
|
||||||
cmd.motorCmd[i].tau = torque;
|
cmd.motorCmd[i].tau = torque;
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
for (int i = 0; i < 12; ++i)
|
||||||
|
{
|
||||||
|
cmd.motorCmd[i].q = target_dof_pos[0][dof_mapping[i]].item<double>();
|
||||||
|
cmd.motorCmd[i].dq = 0;
|
||||||
|
cmd.motorCmd[i].Kp = 15;
|
||||||
|
cmd.motorCmd[i].Kd = 1.5;
|
||||||
|
cmd.motorCmd[i].tau = 0;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
safe.PowerProtect(cmd, state, 1);
|
safe.PowerProtect(cmd, state, 7);
|
||||||
udp.SetSend(cmd);
|
udp.SetSend(cmd);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,8 +118,6 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
|
||||||
|
|
||||||
start_time = std::chrono::high_resolution_clock::now();
|
start_time = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
torque_commands.resize(12);
|
|
||||||
|
|
||||||
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt";
|
std::string actor_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/actor.pt";
|
||||||
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt";
|
std::string encoder_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/encoder.pt";
|
||||||
std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/vq_layer.pt";
|
std::string vq_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/vq_layer.pt";
|
||||||
|
@ -103,8 +130,8 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
|
||||||
this->params.num_observations = 45;
|
this->params.num_observations = 45;
|
||||||
this->params.clip_obs = 100.0;
|
this->params.clip_obs = 100.0;
|
||||||
this->params.clip_actions = 100.0;
|
this->params.clip_actions = 100.0;
|
||||||
this->params.damping = 0.5;
|
this->params.damping = 1.0; // TODO
|
||||||
this->params.stiffness = 20;
|
this->params.stiffness = 5; // TODO
|
||||||
this->params.d_gains = torch::ones(12) * this->params.damping;
|
this->params.d_gains = torch::ones(12) * this->params.damping;
|
||||||
this->params.p_gains = torch::ones(12) * this->params.stiffness;
|
this->params.p_gains = torch::ones(12) * this->params.stiffness;
|
||||||
this->params.action_scale = 0.25;
|
this->params.action_scale = 0.25;
|
||||||
|
@ -116,13 +143,12 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
|
||||||
this->params.dof_vel_scale = 0.05;
|
this->params.dof_vel_scale = 0.05;
|
||||||
this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale});
|
this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale});
|
||||||
|
|
||||||
|
|
||||||
this->params.torque_limits = torch::tensor({{20.0, 55.0, 55.0,
|
this->params.torque_limits = torch::tensor({{20.0, 55.0, 55.0,
|
||||||
20.0, 55.0, 55.0,
|
20.0, 55.0, 55.0,
|
||||||
20.0, 55.0, 55.0,
|
20.0, 55.0, 55.0,
|
||||||
20.0, 55.0, 55.0}});
|
20.0, 55.0, 55.0}});
|
||||||
|
|
||||||
// hip, thigh, calf
|
// hip, thigh, calf
|
||||||
this->params.default_dof_pos = torch::tensor({{-0.1000, 0.8000, -1.5000, // front right
|
this->params.default_dof_pos = torch::tensor({{-0.1000, 0.8000, -1.5000, // front right
|
||||||
0.1000, 0.8000, -1.5000, // front left
|
0.1000, 0.8000, -1.5000, // front left
|
||||||
-0.1000, 1.0000, -1.5000, // rear right
|
-0.1000, 1.0000, -1.5000, // rear right
|
||||||
|
@ -130,6 +156,8 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
|
||||||
|
|
||||||
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
||||||
|
|
||||||
|
target_dof_pos = params.default_dof_pos;
|
||||||
|
|
||||||
// InitEnvironment();
|
// InitEnvironment();
|
||||||
loop_control = std::make_shared<LoopFunc>("control_loop", 0.002, boost::bind(&RL_Real::RobotControl, this));
|
loop_control = std::make_shared<LoopFunc>("control_loop", 0.002, boost::bind(&RL_Real::RobotControl, this));
|
||||||
loop_udpSend = std::make_shared<LoopFunc>("udp_send" , 0.002, 3, boost::bind(&RL_Real::UDPSend, this));
|
loop_udpSend = std::make_shared<LoopFunc>("udp_send" , 0.002, 3, boost::bind(&RL_Real::UDPSend, this));
|
||||||
|
@ -141,9 +169,30 @@ RL_Real::RL_Real() : safe(LeggedType::A1), udp(LOWLEVEL)
|
||||||
loop_control->start();
|
loop_control->start();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
RL_Real::~RL_Real()
|
||||||
|
{
|
||||||
|
loop_udpSend->shutdown();
|
||||||
|
loop_udpRecv->shutdown();
|
||||||
|
loop_control->shutdown();
|
||||||
|
loop_rl->shutdown();
|
||||||
|
printf("shutdown\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
torch::Tensor RL_Real::compute_pos(torch::Tensor actions)
|
||||||
|
{
|
||||||
|
torch::Tensor actions_scaled = actions * this->params.action_scale;
|
||||||
|
int indices[] = {0, 3, 6, 9};
|
||||||
|
for (int i : indices)
|
||||||
|
{
|
||||||
|
actions_scaled[0][i] *= this->params.hip_scale_reduction;
|
||||||
|
}
|
||||||
|
|
||||||
|
return actions_scaled + this->params.default_dof_pos;
|
||||||
|
}
|
||||||
|
|
||||||
void RL_Real::runModel()
|
void RL_Real::runModel()
|
||||||
{
|
{
|
||||||
if(init_done)
|
if(init_state == STATE_RL_START)
|
||||||
{
|
{
|
||||||
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - start_time).count();
|
||||||
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
// std::cout << "Execution time: " << duration << " microseconds" << std::endl;
|
||||||
|
@ -155,7 +204,7 @@ void RL_Real::runModel()
|
||||||
// printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", state.motorState[FL_0].dq, state.motorState[FL_1].dq, state.motorState[FL_2].dq, state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq, state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq, state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq);
|
// printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", state.motorState[FL_0].dq, state.motorState[FL_1].dq, state.motorState[FL_2].dq, state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq, state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq, state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq);
|
||||||
|
|
||||||
this->obs.ang_vel = torch::tensor({{state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]}});
|
this->obs.ang_vel = torch::tensor({{state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]}});
|
||||||
// this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}});
|
this->obs.commands = torch::tensor({{_keyData.ly, -_keyData.rx, -_keyData.lx}});
|
||||||
this->obs.base_quat = torch::tensor({{state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]}});
|
this->obs.base_quat = torch::tensor({{state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]}});
|
||||||
this->obs.dof_pos = torch::tensor({{state.motorState[FL_0].q, state.motorState[FL_1].q, state.motorState[FL_2].q,
|
this->obs.dof_pos = torch::tensor({{state.motorState[FL_0].q, state.motorState[FL_1].q, state.motorState[FL_2].q,
|
||||||
state.motorState[FR_0].q, state.motorState[FR_1].q, state.motorState[FR_2].q,
|
state.motorState[FR_0].q, state.motorState[FR_1].q, state.motorState[FR_2].q,
|
||||||
|
@ -167,31 +216,25 @@ void RL_Real::runModel()
|
||||||
state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq}});
|
state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq}});
|
||||||
|
|
||||||
torch::Tensor actions = this->forward();
|
torch::Tensor actions = this->forward();
|
||||||
|
#ifdef CONTROL_BY_TORQUE
|
||||||
torques = this->compute_torques(actions);
|
torques = this->compute_torques(actions);
|
||||||
|
#else
|
||||||
|
target_dof_pos = this->compute_pos(actions);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor RL_Real::compute_observation()
|
torch::Tensor RL_Real::compute_observation()
|
||||||
{
|
{
|
||||||
torch::Tensor ang_vel = this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel);
|
|
||||||
// float ang_vel_temp = ang_vel[0][0].item<double>();
|
|
||||||
// ang_vel[0][0] = ang_vel[0][1];
|
|
||||||
// ang_vel[0][1] = ang_vel_temp;
|
|
||||||
|
|
||||||
torch::Tensor grav = this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec);
|
|
||||||
// float grav_temp = grav[0][0].item<double>();
|
|
||||||
// grav[0][0] = grav[0][1];
|
|
||||||
// grav[0][1] = grav_temp;
|
|
||||||
|
|
||||||
torch::Tensor obs = torch::cat({// (this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
|
torch::Tensor obs = torch::cat({// (this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
|
||||||
ang_vel * this->params.ang_vel_scale,
|
this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
|
||||||
grav,
|
this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec),
|
||||||
this->obs.commands * this->params.commands_scale,
|
this->obs.commands * this->params.commands_scale,
|
||||||
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
|
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
|
||||||
this->obs.dof_vel * this->params.dof_vel_scale,
|
this->obs.dof_vel * this->params.dof_vel_scale,
|
||||||
this->obs.actions},
|
this->obs.actions
|
||||||
1);
|
},1);
|
||||||
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
|
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
|
||||||
return obs;
|
return obs;
|
||||||
}
|
}
|
||||||
|
@ -217,13 +260,17 @@ torch::Tensor RL_Real::forward()
|
||||||
return clamped;
|
return clamped;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void signalHandler(int signum)
|
||||||
|
{
|
||||||
|
exit(0);
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char **argv)
|
int main(int argc, char **argv)
|
||||||
{
|
{
|
||||||
RL_Real rl_sar;
|
signal(SIGINT, signalHandler);
|
||||||
|
|
||||||
while(1){
|
while(1)
|
||||||
|
{
|
||||||
sleep(10);
|
sleep(10);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue