mirror of https://github.com/fan-ziqi/rl_sar.git
feat: add control interface to adapt to Joy
This commit is contained in:
parent
3a9886e6ea
commit
0740bd6e42
|
@ -48,12 +48,12 @@ void RL::InitOutputs()
|
||||||
this->output_dof_pos = params.default_dof_pos;
|
this->output_dof_pos = params.default_dof_pos;
|
||||||
}
|
}
|
||||||
|
|
||||||
void RL::InitKeyboard()
|
void RL::InitControl()
|
||||||
{
|
{
|
||||||
this->keyboard.keyboard_state = STATE_WAITING;
|
this->control.control_state = STATE_WAITING;
|
||||||
this->keyboard.x = 0.0;
|
this->control.x = 0.0;
|
||||||
this->keyboard.y = 0.0;
|
this->control.y = 0.0;
|
||||||
this->keyboard.yaw = 0.0;
|
this->control.yaw = 0.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
torch::Tensor RL::ComputeTorques(torch::Tensor actions)
|
torch::Tensor RL::ComputeTorques(torch::Tensor actions)
|
||||||
|
@ -89,9 +89,9 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
||||||
{
|
{
|
||||||
command->motor_command.q[i] = state->motor_state.q[i];
|
command->motor_command.q[i] = state->motor_state.q[i];
|
||||||
}
|
}
|
||||||
if(keyboard.keyboard_state == STATE_POS_GETUP)
|
if(control.control_state == STATE_POS_GETUP)
|
||||||
{
|
{
|
||||||
keyboard.keyboard_state = STATE_WAITING;
|
control.control_state = STATE_WAITING;
|
||||||
getup_percent = 0.0;
|
getup_percent = 0.0;
|
||||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||||
{
|
{
|
||||||
|
@ -118,15 +118,15 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
||||||
}
|
}
|
||||||
std::cout << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << "%\r";
|
std::cout << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << "%\r";
|
||||||
}
|
}
|
||||||
if(keyboard.keyboard_state == STATE_RL_INIT)
|
if(control.control_state == STATE_RL_INIT)
|
||||||
{
|
{
|
||||||
std::cout << std::endl;
|
std::cout << std::endl;
|
||||||
keyboard.keyboard_state = STATE_WAITING;
|
control.control_state = STATE_WAITING;
|
||||||
running_state = STATE_RL_INIT;
|
running_state = STATE_RL_INIT;
|
||||||
}
|
}
|
||||||
else if(keyboard.keyboard_state == STATE_POS_GETDOWN)
|
else if(control.control_state == STATE_POS_GETDOWN)
|
||||||
{
|
{
|
||||||
keyboard.keyboard_state = STATE_WAITING;
|
control.control_state = STATE_WAITING;
|
||||||
getdown_percent = 0.0;
|
getdown_percent = 0.0;
|
||||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||||
{
|
{
|
||||||
|
@ -143,7 +143,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
||||||
running_state = STATE_RL_RUNNING;
|
running_state = STATE_RL_RUNNING;
|
||||||
this->InitObservations();
|
this->InitObservations();
|
||||||
this->InitOutputs();
|
this->InitOutputs();
|
||||||
this->InitKeyboard();
|
this->InitControl();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// rl loop
|
// rl loop
|
||||||
|
@ -157,9 +157,9 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
||||||
command->motor_command.kd[i] = params.rl_kd[0][i].item<double>();
|
command->motor_command.kd[i] = params.rl_kd[0][i].item<double>();
|
||||||
command->motor_command.tau[i] = 0;
|
command->motor_command.tau[i] = 0;
|
||||||
}
|
}
|
||||||
if(keyboard.keyboard_state == STATE_POS_GETDOWN)
|
if(control.control_state == STATE_POS_GETDOWN)
|
||||||
{
|
{
|
||||||
keyboard.keyboard_state = STATE_WAITING;
|
control.control_state = STATE_WAITING;
|
||||||
getdown_percent = 0.0;
|
getdown_percent = 0.0;
|
||||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||||
{
|
{
|
||||||
|
@ -191,7 +191,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
||||||
running_state = STATE_WAITING;
|
running_state = STATE_WAITING;
|
||||||
this->InitObservations();
|
this->InitObservations();
|
||||||
this->InitOutputs();
|
this->InitOutputs();
|
||||||
this->InitKeyboard();
|
this->InitControl();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -221,10 +221,10 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques)
|
||||||
double limit_lower = -this->params.torque_limits[0][index].item<double>();
|
double limit_lower = -this->params.torque_limits[0][index].item<double>();
|
||||||
double limit_upper = this->params.torque_limits[0][index].item<double>();
|
double limit_upper = this->params.torque_limits[0][index].item<double>();
|
||||||
|
|
||||||
std::cout << LOGGER::ERROR << "Torque(" << i+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
|
std::cout << LOGGER::ERROR << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
|
||||||
std::cout << LOGGER::ERROR << "Switching to STATE_POS_GETDOWN"<< std::endl;
|
std::cout << LOGGER::ERROR << "Switching to STATE_POS_GETDOWN"<< std::endl;
|
||||||
}
|
}
|
||||||
keyboard.keyboard_state = STATE_POS_GETDOWN;
|
control.control_state = STATE_POS_GETDOWN;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -247,11 +247,11 @@ static bool kbhit()
|
||||||
return byteswaiting > 0;
|
return byteswaiting > 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
void RL::RunKeyboard()
|
void RL::KeyboardInterface()
|
||||||
{
|
{
|
||||||
if(running_state == STATE_RL_RUNNING)
|
if(running_state == STATE_RL_RUNNING)
|
||||||
{
|
{
|
||||||
std::cout << LOGGER::INFO << "RL Controller x:" << keyboard.x << " y:" << keyboard.y << " yaw:" << keyboard.yaw << " \r";
|
std::cout << LOGGER::INFO << "RL Controller x:" << control.x << " y:" << control.y << " yaw:" << control.yaw << " \r";
|
||||||
}
|
}
|
||||||
|
|
||||||
if(kbhit())
|
if(kbhit())
|
||||||
|
@ -259,20 +259,20 @@ void RL::RunKeyboard()
|
||||||
int c = fgetc(stdin);
|
int c = fgetc(stdin);
|
||||||
switch(c)
|
switch(c)
|
||||||
{
|
{
|
||||||
case '0': keyboard.keyboard_state = STATE_POS_GETUP; break;
|
case '0': control.control_state = STATE_POS_GETUP; break;
|
||||||
case 'p': keyboard.keyboard_state = STATE_RL_INIT; break;
|
case 'p': control.control_state = STATE_RL_INIT; break;
|
||||||
case '1': keyboard.keyboard_state = STATE_POS_GETDOWN; break;
|
case '1': control.control_state = STATE_POS_GETDOWN; break;
|
||||||
case 'q': break;
|
case 'q': break;
|
||||||
case 'w': keyboard.x += 0.1; break;
|
case 'w': control.x += 0.1; break;
|
||||||
case 's': keyboard.x -= 0.1; break;
|
case 's': control.x -= 0.1; break;
|
||||||
case 'a': keyboard.yaw += 0.1; break;
|
case 'a': control.yaw += 0.1; break;
|
||||||
case 'd': keyboard.yaw -= 0.1; break;
|
case 'd': control.yaw -= 0.1; break;
|
||||||
case 'i': break;
|
case 'i': break;
|
||||||
case 'k': break;
|
case 'k': break;
|
||||||
case 'j': keyboard.y += 0.1; break;
|
case 'j': control.y += 0.1; break;
|
||||||
case 'l': keyboard.y -= 0.1; break;
|
case 'l': control.y -= 0.1; break;
|
||||||
case ' ': keyboard.x = 0; keyboard.y = 0; keyboard.yaw = 0; break;
|
case ' ': control.x = 0; control.y = 0; control.yaw = 0; break;
|
||||||
case 'r': keyboard.keyboard_state = STATE_RESET_SIMULATION; break;
|
case 'r': control.control_state = STATE_RESET_SIMULATION; break;
|
||||||
default: break;
|
default: break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,9 +58,9 @@ enum STATE {
|
||||||
STATE_RESET_SIMULATION,
|
STATE_RESET_SIMULATION,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct KeyBoard
|
struct Control
|
||||||
{
|
{
|
||||||
STATE keyboard_state;
|
STATE control_state;
|
||||||
double x = 0.0;
|
double x = 0.0;
|
||||||
double y = 0.0;
|
double y = 0.0;
|
||||||
double yaw = 0.0;
|
double yaw = 0.0;
|
||||||
|
@ -120,7 +120,7 @@ public:
|
||||||
// init
|
// init
|
||||||
void InitObservations();
|
void InitObservations();
|
||||||
void InitOutputs();
|
void InitOutputs();
|
||||||
void InitKeyboard();
|
void InitControl();
|
||||||
|
|
||||||
// rl functions
|
// rl functions
|
||||||
virtual torch::Tensor Forward() = 0;
|
virtual torch::Tensor Forward() = 0;
|
||||||
|
@ -140,9 +140,9 @@ public:
|
||||||
void CSVInit(std::string robot_name);
|
void CSVInit(std::string robot_name);
|
||||||
void CSVLogger(torch::Tensor torque, torch::Tensor tau_est, torch::Tensor joint_pos, torch::Tensor joint_pos_target, torch::Tensor joint_vel);
|
void CSVLogger(torch::Tensor torque, torch::Tensor tau_est, torch::Tensor joint_pos, torch::Tensor joint_pos_target, torch::Tensor joint_vel);
|
||||||
|
|
||||||
// keyboard
|
// control
|
||||||
KeyBoard keyboard;
|
Control control;
|
||||||
void RunKeyboard();
|
void KeyboardInterface();
|
||||||
|
|
||||||
// others
|
// others
|
||||||
std::string robot_name;
|
std::string robot_name;
|
||||||
|
|
|
@ -22,18 +22,18 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
|
||||||
now_pos.resize(params.num_of_dofs);
|
now_pos.resize(params.num_of_dofs);
|
||||||
this->InitObservations();
|
this->InitObservations();
|
||||||
this->InitOutputs();
|
this->InitOutputs();
|
||||||
this->InitKeyboard();
|
this->InitControl();
|
||||||
|
|
||||||
// model
|
// model
|
||||||
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/" + this->params.model_name;
|
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/" + this->params.model_name;
|
||||||
this->model = torch::jit::load(model_path);
|
this->model = torch::jit::load(model_path);
|
||||||
|
|
||||||
// loop
|
// loop
|
||||||
loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , boost::bind(&RL_Real::RunKeyboard, this));
|
loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , boost::bind(&RL_Real::KeyboardInterface, this));
|
||||||
loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, boost::bind(&RL_Real::RobotControl, this));
|
loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, boost::bind(&RL_Real::RobotControl , this));
|
||||||
loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend" , 0.002, 3, boost::bind(&RL_Real::UDPSend, this));
|
loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend" , 0.002, 3, boost::bind(&RL_Real::UDPSend , this));
|
||||||
loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv" , 0.002, 3, boost::bind(&RL_Real::UDPRecv, this));
|
loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv" , 0.002, 3, boost::bind(&RL_Real::UDPRecv , this));
|
||||||
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Real::RunModel, this));
|
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Real::RunModel , this));
|
||||||
loop_keyboard->start();
|
loop_keyboard->start();
|
||||||
loop_udpSend->start();
|
loop_udpSend->start();
|
||||||
loop_udpRecv->start();
|
loop_udpRecv->start();
|
||||||
|
@ -74,15 +74,15 @@ void RL_Real::GetState(RobotState<double> *state)
|
||||||
|
|
||||||
if((int)unitree_joy.btn.components.R2 == 1)
|
if((int)unitree_joy.btn.components.R2 == 1)
|
||||||
{
|
{
|
||||||
keyboard.keyboard_state = STATE_POS_GETUP;
|
control.control_state = STATE_POS_GETUP;
|
||||||
}
|
}
|
||||||
else if((int)unitree_joy.btn.components.R1 == 1)
|
else if((int)unitree_joy.btn.components.R1 == 1)
|
||||||
{
|
{
|
||||||
keyboard.keyboard_state = STATE_RL_INIT;
|
control.control_state = STATE_RL_INIT;
|
||||||
}
|
}
|
||||||
else if((int)unitree_joy.btn.components.L2 == 1)
|
else if((int)unitree_joy.btn.components.L2 == 1)
|
||||||
{
|
{
|
||||||
keyboard.keyboard_state = STATE_POS_GETDOWN;
|
control.control_state = STATE_POS_GETDOWN;
|
||||||
}
|
}
|
||||||
|
|
||||||
state->imu.quaternion[3] = unitree_low_state.imu.quaternion[0]; // w
|
state->imu.quaternion[3] = unitree_low_state.imu.quaternion[0]; // w
|
||||||
|
|
|
@ -37,7 +37,7 @@ RL_Sim::RL_Sim()
|
||||||
now_pos.resize(params.num_of_dofs);
|
now_pos.resize(params.num_of_dofs);
|
||||||
this->InitObservations();
|
this->InitObservations();
|
||||||
this->InitOutputs();
|
this->InitOutputs();
|
||||||
this->InitKeyboard();
|
this->InitControl();
|
||||||
|
|
||||||
// model
|
// model
|
||||||
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/" + this->params.model_name;
|
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/" + this->params.model_name;
|
||||||
|
@ -61,9 +61,9 @@ RL_Sim::RL_Sim()
|
||||||
gazebo_reset_client = nh.serviceClient<std_srvs::Empty>("/gazebo/reset_simulation");
|
gazebo_reset_client = nh.serviceClient<std_srvs::Empty>("/gazebo/reset_simulation");
|
||||||
|
|
||||||
// loop
|
// loop
|
||||||
loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , boost::bind(&RL_Sim::RunKeyboard, this));
|
loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , boost::bind(&RL_Sim::KeyboardInterface, this));
|
||||||
loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, boost::bind(&RL_Sim::RobotControl, this));
|
loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, boost::bind(&RL_Sim::RobotControl , this));
|
||||||
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Sim::RunModel, this));
|
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Sim::RunModel , this));
|
||||||
loop_keyboard->start();
|
loop_keyboard->start();
|
||||||
loop_control->start();
|
loop_control->start();
|
||||||
loop_rl->start();
|
loop_rl->start();
|
||||||
|
@ -135,9 +135,9 @@ void RL_Sim::RobotControl()
|
||||||
{
|
{
|
||||||
motiontime++;
|
motiontime++;
|
||||||
|
|
||||||
if(keyboard.keyboard_state == STATE_RESET_SIMULATION)
|
if(control.control_state == STATE_RESET_SIMULATION)
|
||||||
{
|
{
|
||||||
keyboard.keyboard_state = STATE_WAITING;
|
control.control_state = STATE_WAITING;
|
||||||
std_srvs::Empty srv;
|
std_srvs::Empty srv;
|
||||||
gazebo_reset_client.call(srv);
|
gazebo_reset_client.call(srv);
|
||||||
}
|
}
|
||||||
|
@ -180,7 +180,7 @@ void RL_Sim::RunModel()
|
||||||
// this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}});
|
// this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}});
|
||||||
this->obs.ang_vel = torch::tensor(robot_state.imu.gyroscope).unsqueeze(0);
|
this->obs.ang_vel = torch::tensor(robot_state.imu.gyroscope).unsqueeze(0);
|
||||||
// this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}});
|
// this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}});
|
||||||
this->obs.commands = torch::tensor({{keyboard.x, keyboard.y, keyboard.yaw}});
|
this->obs.commands = torch::tensor({{control.x, control.y, control.yaw}});
|
||||||
this->obs.base_quat = torch::tensor(robot_state.imu.quaternion).unsqueeze(0);
|
this->obs.base_quat = torch::tensor(robot_state.imu.quaternion).unsqueeze(0);
|
||||||
this->obs.dof_pos = torch::tensor(robot_state.motor_state.q).narrow(0, 0, params.num_of_dofs).unsqueeze(0);
|
this->obs.dof_pos = torch::tensor(robot_state.motor_state.q).narrow(0, 0, params.num_of_dofs).unsqueeze(0);
|
||||||
this->obs.dof_vel = torch::tensor(robot_state.motor_state.dq).narrow(0, 0, params.num_of_dofs).unsqueeze(0);
|
this->obs.dof_vel = torch::tensor(robot_state.motor_state.dq).narrow(0, 0, params.num_of_dofs).unsqueeze(0);
|
||||||
|
|
Loading…
Reference in New Issue