mirror of https://github.com/fan-ziqi/rl_sar.git
feat: add control interface to adapt to Joy
This commit is contained in:
parent
3a9886e6ea
commit
0740bd6e42
|
@ -48,12 +48,12 @@ void RL::InitOutputs()
|
|||
this->output_dof_pos = params.default_dof_pos;
|
||||
}
|
||||
|
||||
void RL::InitKeyboard()
|
||||
void RL::InitControl()
|
||||
{
|
||||
this->keyboard.keyboard_state = STATE_WAITING;
|
||||
this->keyboard.x = 0.0;
|
||||
this->keyboard.y = 0.0;
|
||||
this->keyboard.yaw = 0.0;
|
||||
this->control.control_state = STATE_WAITING;
|
||||
this->control.x = 0.0;
|
||||
this->control.y = 0.0;
|
||||
this->control.yaw = 0.0;
|
||||
}
|
||||
|
||||
torch::Tensor RL::ComputeTorques(torch::Tensor actions)
|
||||
|
@ -89,9 +89,9 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
|||
{
|
||||
command->motor_command.q[i] = state->motor_state.q[i];
|
||||
}
|
||||
if(keyboard.keyboard_state == STATE_POS_GETUP)
|
||||
if(control.control_state == STATE_POS_GETUP)
|
||||
{
|
||||
keyboard.keyboard_state = STATE_WAITING;
|
||||
control.control_state = STATE_WAITING;
|
||||
getup_percent = 0.0;
|
||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||
{
|
||||
|
@ -118,15 +118,15 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
|||
}
|
||||
std::cout << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << "%\r";
|
||||
}
|
||||
if(keyboard.keyboard_state == STATE_RL_INIT)
|
||||
if(control.control_state == STATE_RL_INIT)
|
||||
{
|
||||
std::cout << std::endl;
|
||||
keyboard.keyboard_state = STATE_WAITING;
|
||||
control.control_state = STATE_WAITING;
|
||||
running_state = STATE_RL_INIT;
|
||||
}
|
||||
else if(keyboard.keyboard_state == STATE_POS_GETDOWN)
|
||||
else if(control.control_state == STATE_POS_GETDOWN)
|
||||
{
|
||||
keyboard.keyboard_state = STATE_WAITING;
|
||||
control.control_state = STATE_WAITING;
|
||||
getdown_percent = 0.0;
|
||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||
{
|
||||
|
@ -143,7 +143,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
|||
running_state = STATE_RL_RUNNING;
|
||||
this->InitObservations();
|
||||
this->InitOutputs();
|
||||
this->InitKeyboard();
|
||||
this->InitControl();
|
||||
}
|
||||
}
|
||||
// rl loop
|
||||
|
@ -157,9 +157,9 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
|||
command->motor_command.kd[i] = params.rl_kd[0][i].item<double>();
|
||||
command->motor_command.tau[i] = 0;
|
||||
}
|
||||
if(keyboard.keyboard_state == STATE_POS_GETDOWN)
|
||||
if(control.control_state == STATE_POS_GETDOWN)
|
||||
{
|
||||
keyboard.keyboard_state = STATE_WAITING;
|
||||
control.control_state = STATE_WAITING;
|
||||
getdown_percent = 0.0;
|
||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||
{
|
||||
|
@ -191,7 +191,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
|||
running_state = STATE_WAITING;
|
||||
this->InitObservations();
|
||||
this->InitOutputs();
|
||||
this->InitKeyboard();
|
||||
this->InitControl();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -221,10 +221,10 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques)
|
|||
double limit_lower = -this->params.torque_limits[0][index].item<double>();
|
||||
double limit_upper = this->params.torque_limits[0][index].item<double>();
|
||||
|
||||
std::cout << LOGGER::ERROR << "Torque(" << i+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
|
||||
std::cout << LOGGER::ERROR << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
|
||||
std::cout << LOGGER::ERROR << "Switching to STATE_POS_GETDOWN"<< std::endl;
|
||||
}
|
||||
keyboard.keyboard_state = STATE_POS_GETDOWN;
|
||||
control.control_state = STATE_POS_GETDOWN;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -247,11 +247,11 @@ static bool kbhit()
|
|||
return byteswaiting > 0;
|
||||
}
|
||||
|
||||
void RL::RunKeyboard()
|
||||
void RL::KeyboardInterface()
|
||||
{
|
||||
if(running_state == STATE_RL_RUNNING)
|
||||
{
|
||||
std::cout << LOGGER::INFO << "RL Controller x:" << keyboard.x << " y:" << keyboard.y << " yaw:" << keyboard.yaw << " \r";
|
||||
std::cout << LOGGER::INFO << "RL Controller x:" << control.x << " y:" << control.y << " yaw:" << control.yaw << " \r";
|
||||
}
|
||||
|
||||
if(kbhit())
|
||||
|
@ -259,20 +259,20 @@ void RL::RunKeyboard()
|
|||
int c = fgetc(stdin);
|
||||
switch(c)
|
||||
{
|
||||
case '0': keyboard.keyboard_state = STATE_POS_GETUP; break;
|
||||
case 'p': keyboard.keyboard_state = STATE_RL_INIT; break;
|
||||
case '1': keyboard.keyboard_state = STATE_POS_GETDOWN; break;
|
||||
case '0': control.control_state = STATE_POS_GETUP; break;
|
||||
case 'p': control.control_state = STATE_RL_INIT; break;
|
||||
case '1': control.control_state = STATE_POS_GETDOWN; break;
|
||||
case 'q': break;
|
||||
case 'w': keyboard.x += 0.1; break;
|
||||
case 's': keyboard.x -= 0.1; break;
|
||||
case 'a': keyboard.yaw += 0.1; break;
|
||||
case 'd': keyboard.yaw -= 0.1; break;
|
||||
case 'w': control.x += 0.1; break;
|
||||
case 's': control.x -= 0.1; break;
|
||||
case 'a': control.yaw += 0.1; break;
|
||||
case 'd': control.yaw -= 0.1; break;
|
||||
case 'i': break;
|
||||
case 'k': break;
|
||||
case 'j': keyboard.y += 0.1; break;
|
||||
case 'l': keyboard.y -= 0.1; break;
|
||||
case ' ': keyboard.x = 0; keyboard.y = 0; keyboard.yaw = 0; break;
|
||||
case 'r': keyboard.keyboard_state = STATE_RESET_SIMULATION; break;
|
||||
case 'j': control.y += 0.1; break;
|
||||
case 'l': control.y -= 0.1; break;
|
||||
case ' ': control.x = 0; control.y = 0; control.yaw = 0; break;
|
||||
case 'r': control.control_state = STATE_RESET_SIMULATION; break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -58,9 +58,9 @@ enum STATE {
|
|||
STATE_RESET_SIMULATION,
|
||||
};
|
||||
|
||||
struct KeyBoard
|
||||
struct Control
|
||||
{
|
||||
STATE keyboard_state;
|
||||
STATE control_state;
|
||||
double x = 0.0;
|
||||
double y = 0.0;
|
||||
double yaw = 0.0;
|
||||
|
@ -120,7 +120,7 @@ public:
|
|||
// init
|
||||
void InitObservations();
|
||||
void InitOutputs();
|
||||
void InitKeyboard();
|
||||
void InitControl();
|
||||
|
||||
// rl functions
|
||||
virtual torch::Tensor Forward() = 0;
|
||||
|
@ -140,9 +140,9 @@ public:
|
|||
void CSVInit(std::string robot_name);
|
||||
void CSVLogger(torch::Tensor torque, torch::Tensor tau_est, torch::Tensor joint_pos, torch::Tensor joint_pos_target, torch::Tensor joint_vel);
|
||||
|
||||
// keyboard
|
||||
KeyBoard keyboard;
|
||||
void RunKeyboard();
|
||||
// control
|
||||
Control control;
|
||||
void KeyboardInterface();
|
||||
|
||||
// others
|
||||
std::string robot_name;
|
||||
|
|
|
@ -22,18 +22,18 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
|
|||
now_pos.resize(params.num_of_dofs);
|
||||
this->InitObservations();
|
||||
this->InitOutputs();
|
||||
this->InitKeyboard();
|
||||
this->InitControl();
|
||||
|
||||
// model
|
||||
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/" + this->params.model_name;
|
||||
this->model = torch::jit::load(model_path);
|
||||
|
||||
// loop
|
||||
loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , boost::bind(&RL_Real::RunKeyboard, this));
|
||||
loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, boost::bind(&RL_Real::RobotControl, this));
|
||||
loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend" , 0.002, 3, boost::bind(&RL_Real::UDPSend, this));
|
||||
loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv" , 0.002, 3, boost::bind(&RL_Real::UDPRecv, this));
|
||||
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Real::RunModel, this));
|
||||
loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , boost::bind(&RL_Real::KeyboardInterface, this));
|
||||
loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, boost::bind(&RL_Real::RobotControl , this));
|
||||
loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend" , 0.002, 3, boost::bind(&RL_Real::UDPSend , this));
|
||||
loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv" , 0.002, 3, boost::bind(&RL_Real::UDPRecv , this));
|
||||
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Real::RunModel , this));
|
||||
loop_keyboard->start();
|
||||
loop_udpSend->start();
|
||||
loop_udpRecv->start();
|
||||
|
@ -74,15 +74,15 @@ void RL_Real::GetState(RobotState<double> *state)
|
|||
|
||||
if((int)unitree_joy.btn.components.R2 == 1)
|
||||
{
|
||||
keyboard.keyboard_state = STATE_POS_GETUP;
|
||||
control.control_state = STATE_POS_GETUP;
|
||||
}
|
||||
else if((int)unitree_joy.btn.components.R1 == 1)
|
||||
{
|
||||
keyboard.keyboard_state = STATE_RL_INIT;
|
||||
control.control_state = STATE_RL_INIT;
|
||||
}
|
||||
else if((int)unitree_joy.btn.components.L2 == 1)
|
||||
{
|
||||
keyboard.keyboard_state = STATE_POS_GETDOWN;
|
||||
control.control_state = STATE_POS_GETDOWN;
|
||||
}
|
||||
|
||||
state->imu.quaternion[3] = unitree_low_state.imu.quaternion[0]; // w
|
||||
|
|
|
@ -37,7 +37,7 @@ RL_Sim::RL_Sim()
|
|||
now_pos.resize(params.num_of_dofs);
|
||||
this->InitObservations();
|
||||
this->InitOutputs();
|
||||
this->InitKeyboard();
|
||||
this->InitControl();
|
||||
|
||||
// model
|
||||
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/" + this->params.model_name;
|
||||
|
@ -61,9 +61,9 @@ RL_Sim::RL_Sim()
|
|||
gazebo_reset_client = nh.serviceClient<std_srvs::Empty>("/gazebo/reset_simulation");
|
||||
|
||||
// loop
|
||||
loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , boost::bind(&RL_Sim::RunKeyboard, this));
|
||||
loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, boost::bind(&RL_Sim::RobotControl, this));
|
||||
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Sim::RunModel, this));
|
||||
loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , boost::bind(&RL_Sim::KeyboardInterface, this));
|
||||
loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, boost::bind(&RL_Sim::RobotControl , this));
|
||||
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Sim::RunModel , this));
|
||||
loop_keyboard->start();
|
||||
loop_control->start();
|
||||
loop_rl->start();
|
||||
|
@ -135,9 +135,9 @@ void RL_Sim::RobotControl()
|
|||
{
|
||||
motiontime++;
|
||||
|
||||
if(keyboard.keyboard_state == STATE_RESET_SIMULATION)
|
||||
if(control.control_state == STATE_RESET_SIMULATION)
|
||||
{
|
||||
keyboard.keyboard_state = STATE_WAITING;
|
||||
control.control_state = STATE_WAITING;
|
||||
std_srvs::Empty srv;
|
||||
gazebo_reset_client.call(srv);
|
||||
}
|
||||
|
@ -180,7 +180,7 @@ void RL_Sim::RunModel()
|
|||
// this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}});
|
||||
this->obs.ang_vel = torch::tensor(robot_state.imu.gyroscope).unsqueeze(0);
|
||||
// this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}});
|
||||
this->obs.commands = torch::tensor({{keyboard.x, keyboard.y, keyboard.yaw}});
|
||||
this->obs.commands = torch::tensor({{control.x, control.y, control.yaw}});
|
||||
this->obs.base_quat = torch::tensor(robot_state.imu.quaternion).unsqueeze(0);
|
||||
this->obs.dof_pos = torch::tensor(robot_state.motor_state.q).narrow(0, 0, params.num_of_dofs).unsqueeze(0);
|
||||
this->obs.dof_vel = torch::tensor(robot_state.motor_state.dq).narrow(0, 0, params.num_of_dofs).unsqueeze(0);
|
||||
|
|
Loading…
Reference in New Issue