mirror of https://github.com/fan-ziqi/rl_sar.git
fix: add this to class member
This commit is contained in:
parent
be79e7f64a
commit
8327b52d01
|
@ -51,12 +51,12 @@ private:
|
|||
geometry_msgs::Twist vel;
|
||||
geometry_msgs::Pose pose;
|
||||
geometry_msgs::Twist cmd_vel;
|
||||
std::vector<std::string> torque_command_topics;
|
||||
ros::Subscriber model_state_subscriber;
|
||||
ros::Subscriber joint_state_subscriber;
|
||||
ros::Subscriber cmd_vel_subscriber;
|
||||
std::map<std::string, ros::Publisher> torque_publishers;
|
||||
ros::ServiceClient gazebo_reset_client;
|
||||
std::map<std::string, ros::Publisher> joint_publishers;
|
||||
std::vector<robot_msgs::MotorCommand> joint_publishers_commands;
|
||||
void ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg);
|
||||
void JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg);
|
||||
void CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg);
|
||||
|
@ -64,7 +64,6 @@ private:
|
|||
// others
|
||||
int motiontime = 0;
|
||||
std::map<std::string, size_t> sorted_to_original_index;
|
||||
std::vector<robot_msgs::MotorCommand> motor_commands;
|
||||
std::vector<double> mapped_joint_positions;
|
||||
std::vector<double> mapped_joint_velocities;
|
||||
std::vector<double> mapped_joint_efforts;
|
||||
|
|
|
@ -38,14 +38,14 @@ void RL::InitObservations()
|
|||
this->obs.commands = torch::tensor({{0.0, 0.0, 0.0}});
|
||||
this->obs.base_quat = torch::tensor({{0.0, 0.0, 0.0, 1.0}});
|
||||
this->obs.dof_pos = this->params.default_dof_pos;
|
||||
this->obs.dof_vel = torch::zeros({1, params.num_of_dofs});
|
||||
this->obs.actions = torch::zeros({1, params.num_of_dofs});
|
||||
this->obs.dof_vel = torch::zeros({1, this->params.num_of_dofs});
|
||||
this->obs.actions = torch::zeros({1, this->params.num_of_dofs});
|
||||
}
|
||||
|
||||
void RL::InitOutputs()
|
||||
{
|
||||
this->output_torques = torch::zeros({1, params.num_of_dofs});
|
||||
this->output_dof_pos = params.default_dof_pos;
|
||||
this->output_torques = torch::zeros({1, this->params.num_of_dofs});
|
||||
this->output_dof_pos = this->params.default_dof_pos;
|
||||
}
|
||||
|
||||
void RL::InitControl()
|
||||
|
@ -88,104 +88,104 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
|||
static float getdown_percent = 0.0;
|
||||
|
||||
// waiting
|
||||
if(running_state == STATE_WAITING)
|
||||
if(this->running_state == STATE_WAITING)
|
||||
{
|
||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
command->motor_command.q[i] = state->motor_state.q[i];
|
||||
}
|
||||
if(control.control_state == STATE_POS_GETUP)
|
||||
if(this->control.control_state == STATE_POS_GETUP)
|
||||
{
|
||||
control.control_state = STATE_WAITING;
|
||||
this->control.control_state = STATE_WAITING;
|
||||
getup_percent = 0.0;
|
||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
now_state.motor_state.q[i] = state->motor_state.q[i];
|
||||
start_state.motor_state.q[i] = now_state.motor_state.q[i];
|
||||
}
|
||||
running_state = STATE_POS_GETUP;
|
||||
this->running_state = STATE_POS_GETUP;
|
||||
}
|
||||
}
|
||||
// stand up (position control)
|
||||
else if(running_state == STATE_POS_GETUP)
|
||||
else if(this->running_state == STATE_POS_GETUP)
|
||||
{
|
||||
if(getup_percent < 1.0)
|
||||
{
|
||||
getup_percent += 1 / 1000.0;
|
||||
getup_percent = getup_percent > 1.0 ? 1.0 : getup_percent;
|
||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
command->motor_command.q[i] = (1 - getup_percent) * now_state.motor_state.q[i] + getup_percent * params.default_dof_pos[0][i].item<double>();
|
||||
command->motor_command.q[i] = (1 - getup_percent) * now_state.motor_state.q[i] + getup_percent * this->params.default_dof_pos[0][i].item<double>();
|
||||
command->motor_command.dq[i] = 0;
|
||||
command->motor_command.kp[i] = params.fixed_kp[0][i].item<double>();
|
||||
command->motor_command.kd[i] = params.fixed_kd[0][i].item<double>();
|
||||
command->motor_command.kp[i] = this->params.fixed_kp[0][i].item<double>();
|
||||
command->motor_command.kd[i] = this->params.fixed_kd[0][i].item<double>();
|
||||
command->motor_command.tau[i] = 0;
|
||||
}
|
||||
std::cout << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << "%\r";
|
||||
}
|
||||
if(control.control_state == STATE_RL_INIT)
|
||||
if(this->control.control_state == STATE_RL_INIT)
|
||||
{
|
||||
std::cout << std::endl;
|
||||
control.control_state = STATE_WAITING;
|
||||
running_state = STATE_RL_INIT;
|
||||
this->control.control_state = STATE_WAITING;
|
||||
this->running_state = STATE_RL_INIT;
|
||||
}
|
||||
else if(control.control_state == STATE_POS_GETDOWN)
|
||||
else if(this->control.control_state == STATE_POS_GETDOWN)
|
||||
{
|
||||
control.control_state = STATE_WAITING;
|
||||
this->control.control_state = STATE_WAITING;
|
||||
getdown_percent = 0.0;
|
||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
now_state.motor_state.q[i] = state->motor_state.q[i];
|
||||
}
|
||||
running_state = STATE_POS_GETDOWN;
|
||||
this->running_state = STATE_POS_GETDOWN;
|
||||
}
|
||||
}
|
||||
// init obs and start rl loop
|
||||
else if(running_state == STATE_RL_INIT)
|
||||
else if(this->running_state == STATE_RL_INIT)
|
||||
{
|
||||
if(getup_percent == 1)
|
||||
{
|
||||
running_state = STATE_RL_RUNNING;
|
||||
this->running_state = STATE_RL_RUNNING;
|
||||
this->InitObservations();
|
||||
this->InitOutputs();
|
||||
this->InitControl();
|
||||
}
|
||||
}
|
||||
// rl loop
|
||||
else if(running_state == STATE_RL_RUNNING)
|
||||
else if(this->running_state == STATE_RL_RUNNING)
|
||||
{
|
||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
command->motor_command.q[i] = output_dof_pos[0][i].item<double>();
|
||||
command->motor_command.q[i] = this->output_dof_pos[0][i].item<double>();
|
||||
command->motor_command.dq[i] = 0;
|
||||
command->motor_command.kp[i] = params.rl_kp[0][i].item<double>();
|
||||
command->motor_command.kd[i] = params.rl_kd[0][i].item<double>();
|
||||
command->motor_command.kp[i] = this->params.rl_kp[0][i].item<double>();
|
||||
command->motor_command.kd[i] = this->params.rl_kd[0][i].item<double>();
|
||||
command->motor_command.tau[i] = 0;
|
||||
}
|
||||
if(control.control_state == STATE_POS_GETDOWN)
|
||||
if(this->control.control_state == STATE_POS_GETDOWN)
|
||||
{
|
||||
control.control_state = STATE_WAITING;
|
||||
this->control.control_state = STATE_WAITING;
|
||||
getdown_percent = 0.0;
|
||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
now_state.motor_state.q[i] = state->motor_state.q[i];
|
||||
}
|
||||
running_state = STATE_POS_GETDOWN;
|
||||
this->running_state = STATE_POS_GETDOWN;
|
||||
}
|
||||
}
|
||||
// get down (position control)
|
||||
else if(running_state == STATE_POS_GETDOWN)
|
||||
else if(this->running_state == STATE_POS_GETDOWN)
|
||||
{
|
||||
if(getdown_percent < 1.0)
|
||||
{
|
||||
getdown_percent += 1 / 1000.0;
|
||||
getdown_percent = getdown_percent > 1.0 ? 1.0 : getdown_percent;
|
||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
command->motor_command.q[i] = (1 - getdown_percent) * now_state.motor_state.q[i] + getdown_percent * start_state.motor_state.q[i];
|
||||
command->motor_command.dq[i] = 0;
|
||||
command->motor_command.kp[i] = params.fixed_kp[0][i].item<double>();
|
||||
command->motor_command.kd[i] = params.fixed_kd[0][i].item<double>();
|
||||
command->motor_command.kp[i] = this->params.fixed_kp[0][i].item<double>();
|
||||
command->motor_command.kd[i] = this->params.fixed_kd[0][i].item<double>();
|
||||
command->motor_command.tau[i] = 0;
|
||||
}
|
||||
std::cout << LOGGER::INFO << "Getting down " << std::fixed << std::setprecision(2) << getdown_percent * 100.0 << "%\r";
|
||||
|
@ -193,7 +193,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
|
|||
if(getdown_percent == 1)
|
||||
{
|
||||
std::cout << std::endl;
|
||||
running_state = STATE_WAITING;
|
||||
this->running_state = STATE_WAITING;
|
||||
this->InitObservations();
|
||||
this->InitOutputs();
|
||||
this->InitControl();
|
||||
|
@ -229,7 +229,7 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques)
|
|||
std::cout << LOGGER::ERROR << "Torque(" << index+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
|
||||
std::cout << LOGGER::ERROR << "Switching to STATE_POS_GETDOWN"<< std::endl;
|
||||
}
|
||||
control.control_state = STATE_POS_GETDOWN;
|
||||
this->control.control_state = STATE_POS_GETDOWN;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -254,9 +254,9 @@ static bool kbhit()
|
|||
|
||||
void RL::KeyboardInterface()
|
||||
{
|
||||
if(running_state == STATE_RL_RUNNING)
|
||||
if(this->running_state == STATE_RL_RUNNING)
|
||||
{
|
||||
std::cout << LOGGER::INFO << "RL Controller x:" << control.x << " y:" << control.y << " yaw:" << control.yaw << " \r";
|
||||
std::cout << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << " \r";
|
||||
}
|
||||
|
||||
if(kbhit())
|
||||
|
@ -264,20 +264,20 @@ void RL::KeyboardInterface()
|
|||
int c = fgetc(stdin);
|
||||
switch(c)
|
||||
{
|
||||
case '0': control.control_state = STATE_POS_GETUP; break;
|
||||
case 'p': control.control_state = STATE_RL_INIT; break;
|
||||
case '1': control.control_state = STATE_POS_GETDOWN; break;
|
||||
case '0': this->control.control_state = STATE_POS_GETUP; break;
|
||||
case 'p': this->control.control_state = STATE_RL_INIT; break;
|
||||
case '1': this->control.control_state = STATE_POS_GETDOWN; break;
|
||||
case 'q': break;
|
||||
case 'w': control.x += 0.1; break;
|
||||
case 's': control.x -= 0.1; break;
|
||||
case 'a': control.yaw += 0.1; break;
|
||||
case 'd': control.yaw -= 0.1; break;
|
||||
case 'w': this->control.x += 0.1; break;
|
||||
case 's': this->control.x -= 0.1; break;
|
||||
case 'a': this->control.yaw += 0.1; break;
|
||||
case 'd': this->control.yaw -= 0.1; break;
|
||||
case 'i': break;
|
||||
case 'k': break;
|
||||
case 'j': control.y += 0.1; break;
|
||||
case 'l': control.y -= 0.1; break;
|
||||
case ' ': control.x = 0; control.y = 0; control.yaw = 0; break;
|
||||
case 'r': control.control_state = STATE_RESET_SIMULATION; break;
|
||||
case 'j': this->control.y += 0.1; break;
|
||||
case 'l': this->control.y -= 0.1; break;
|
||||
case ' ': this->control.x = 0; this->control.y = 0; this->control.yaw = 0; break;
|
||||
case 'r': this->control.control_state = STATE_RESET_SIMULATION; break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,13 +8,13 @@ RL_Real rl_sar;
|
|||
RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_udp(UNITREE_LEGGED_SDK::LOWLEVEL)
|
||||
{
|
||||
// read params from yaml
|
||||
robot_name = "a1";
|
||||
ReadYaml(robot_name);
|
||||
this->robot_name = "a1";
|
||||
this->ReadYaml(this->robot_name);
|
||||
|
||||
// history
|
||||
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
||||
|
||||
unitree_udp.InitCmdData(unitree_low_command);
|
||||
this->unitree_udp.InitCmdData(this->unitree_low_command);
|
||||
|
||||
// init
|
||||
torch::autograd::GradMode::set_enabled(false);
|
||||
|
@ -23,117 +23,117 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
|
|||
this->InitControl();
|
||||
|
||||
// model
|
||||
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/" + this->params.model_name;
|
||||
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + this->robot_name + "/" + this->params.model_name;
|
||||
this->model = torch::jit::load(model_path);
|
||||
|
||||
// loop
|
||||
loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , boost::bind(&RL_Real::KeyboardInterface, this));
|
||||
loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, boost::bind(&RL_Real::RobotControl , this));
|
||||
loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend" , 0.002, 3, boost::bind(&RL_Real::UDPSend , this));
|
||||
loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv" , 0.002, 3, boost::bind(&RL_Real::UDPRecv , this));
|
||||
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Real::RunModel , this));
|
||||
loop_keyboard->start();
|
||||
loop_udpSend->start();
|
||||
loop_udpRecv->start();
|
||||
loop_control->start();
|
||||
loop_rl->start();
|
||||
this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , boost::bind(&RL_Real::KeyboardInterface, this));
|
||||
this->loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, boost::bind(&RL_Real::RobotControl , this));
|
||||
this->loop_udpSend = std::make_shared<LoopFunc>("loop_udpSend" , 0.002, 3, boost::bind(&RL_Real::UDPSend , this));
|
||||
this->loop_udpRecv = std::make_shared<LoopFunc>("loop_udpRecv" , 0.002, 3, boost::bind(&RL_Real::UDPRecv , this));
|
||||
this->loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Real::RunModel , this));
|
||||
this->loop_keyboard->start();
|
||||
this->loop_udpSend->start();
|
||||
this->loop_udpRecv->start();
|
||||
this->loop_control->start();
|
||||
this->loop_rl->start();
|
||||
|
||||
#ifdef PLOT
|
||||
plot_t = std::vector<int>(plot_size, 0);
|
||||
plot_real_joint_pos.resize(params.num_of_dofs);
|
||||
plot_target_joint_pos.resize(params.num_of_dofs);
|
||||
for(auto& vector : plot_real_joint_pos) { vector = std::vector<double>(plot_size, 0); }
|
||||
for(auto& vector : plot_target_joint_pos) { vector = std::vector<double>(plot_size, 0); }
|
||||
loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Real::Plot, this));
|
||||
loop_plot->start();
|
||||
this->plot_t = std::vector<int>(this->plot_size, 0);
|
||||
this->plot_real_joint_pos.resize(this->params.num_of_dofs);
|
||||
this->plot_target_joint_pos.resize(this->params.num_of_dofs);
|
||||
for(auto& vector : this->plot_real_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
|
||||
for(auto& vector : this->plot_target_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
|
||||
this->loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Real::Plot, this));
|
||||
this->loop_plot->start();
|
||||
#endif
|
||||
#ifdef CSV_LOGGER
|
||||
CSVInit(robot_name);
|
||||
this->CSVInit(this->robot_name);
|
||||
#endif
|
||||
}
|
||||
|
||||
RL_Real::~RL_Real()
|
||||
{
|
||||
loop_keyboard->shutdown();
|
||||
loop_udpSend->shutdown();
|
||||
loop_udpRecv->shutdown();
|
||||
loop_control->shutdown();
|
||||
loop_rl->shutdown();
|
||||
this->loop_keyboard->shutdown();
|
||||
this->loop_udpSend->shutdown();
|
||||
this->loop_udpRecv->shutdown();
|
||||
this->loop_control->shutdown();
|
||||
this->loop_rl->shutdown();
|
||||
#ifdef PLOT
|
||||
loop_plot->shutdown();
|
||||
this->loop_plot->shutdown();
|
||||
#endif
|
||||
std::cout << LOGGER::INFO << "RL_Real exit" << std::endl;
|
||||
}
|
||||
|
||||
void RL_Real::GetState(RobotState<double> *state)
|
||||
{
|
||||
unitree_udp.GetRecv(unitree_low_state);
|
||||
memcpy(&unitree_joy, unitree_low_state.wirelessRemote, 40);
|
||||
this->unitree_udp.GetRecv(this->unitree_low_state);
|
||||
memcpy(&this->unitree_joy, this->unitree_low_state.wirelessRemote, 40);
|
||||
|
||||
if((int)unitree_joy.btn.components.R2 == 1)
|
||||
if((int)this->unitree_joy.btn.components.R2 == 1)
|
||||
{
|
||||
control.control_state = STATE_POS_GETUP;
|
||||
this->control.control_state = STATE_POS_GETUP;
|
||||
}
|
||||
else if((int)unitree_joy.btn.components.R1 == 1)
|
||||
else if((int)this->unitree_joy.btn.components.R1 == 1)
|
||||
{
|
||||
control.control_state = STATE_RL_INIT;
|
||||
this->control.control_state = STATE_RL_INIT;
|
||||
}
|
||||
else if((int)unitree_joy.btn.components.L2 == 1)
|
||||
else if((int)this->unitree_joy.btn.components.L2 == 1)
|
||||
{
|
||||
control.control_state = STATE_POS_GETDOWN;
|
||||
this->control.control_state = STATE_POS_GETDOWN;
|
||||
}
|
||||
|
||||
state->imu.quaternion[3] = unitree_low_state.imu.quaternion[0]; // w
|
||||
state->imu.quaternion[0] = unitree_low_state.imu.quaternion[1]; // x
|
||||
state->imu.quaternion[1] = unitree_low_state.imu.quaternion[2]; // y
|
||||
state->imu.quaternion[2] = unitree_low_state.imu.quaternion[3]; // z
|
||||
state->imu.quaternion[3] = this->unitree_low_state.imu.quaternion[0]; // w
|
||||
state->imu.quaternion[0] = this->unitree_low_state.imu.quaternion[1]; // x
|
||||
state->imu.quaternion[1] = this->unitree_low_state.imu.quaternion[2]; // y
|
||||
state->imu.quaternion[2] = this->unitree_low_state.imu.quaternion[3]; // z
|
||||
for(int i = 0; i < 3; ++i)
|
||||
{
|
||||
state->imu.gyroscope[i] = unitree_low_state.imu.gyroscope[i];
|
||||
state->imu.gyroscope[i] = this->unitree_low_state.imu.gyroscope[i];
|
||||
}
|
||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
state->motor_state.q[i] = unitree_low_state.motorState[state_mapping[i]].q;
|
||||
state->motor_state.dq[i] = unitree_low_state.motorState[state_mapping[i]].dq;
|
||||
state->motor_state.tauEst[i] = unitree_low_state.motorState[state_mapping[i]].tauEst;
|
||||
state->motor_state.q[i] = this->unitree_low_state.motorState[state_mapping[i]].q;
|
||||
state->motor_state.dq[i] = this->unitree_low_state.motorState[state_mapping[i]].dq;
|
||||
state->motor_state.tauEst[i] = this->unitree_low_state.motorState[state_mapping[i]].tauEst;
|
||||
}
|
||||
}
|
||||
|
||||
void RL_Real::SetCommand(const RobotCommand<double> *command)
|
||||
{
|
||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
unitree_low_command.motorCmd[i].mode = 0x0A;
|
||||
unitree_low_command.motorCmd[i].q = command->motor_command.q[command_mapping[i]];
|
||||
unitree_low_command.motorCmd[i].dq = command->motor_command.dq[command_mapping[i]];
|
||||
unitree_low_command.motorCmd[i].Kp = command->motor_command.kp[command_mapping[i]];
|
||||
unitree_low_command.motorCmd[i].Kd = command->motor_command.kd[command_mapping[i]];
|
||||
unitree_low_command.motorCmd[i].tau = command->motor_command.tau[command_mapping[i]];
|
||||
this->unitree_low_command.motorCmd[i].mode = 0x0A;
|
||||
this->unitree_low_command.motorCmd[i].q = command->motor_command.q[command_mapping[i]];
|
||||
this->unitree_low_command.motorCmd[i].dq = command->motor_command.dq[command_mapping[i]];
|
||||
this->unitree_low_command.motorCmd[i].Kp = command->motor_command.kp[command_mapping[i]];
|
||||
this->unitree_low_command.motorCmd[i].Kd = command->motor_command.kd[command_mapping[i]];
|
||||
this->unitree_low_command.motorCmd[i].tau = command->motor_command.tau[command_mapping[i]];
|
||||
}
|
||||
|
||||
unitree_safe.PowerProtect(unitree_low_command, unitree_low_state, 6);
|
||||
// unitree_safe.PositionProtect(unitree_low_command, unitree_low_state);
|
||||
unitree_udp.SetSend(unitree_low_command);
|
||||
this->unitree_safe.PowerProtect(this->unitree_low_command, this->unitree_low_state, 6);
|
||||
// this->unitree_safe.PositionProtect(this->unitree_low_command, this->unitree_low_state);
|
||||
this->unitree_udp.SetSend(this->unitree_low_command);
|
||||
}
|
||||
|
||||
void RL_Real::RobotControl()
|
||||
{
|
||||
motiontime++;
|
||||
this->motiontime++;
|
||||
|
||||
GetState(&robot_state);
|
||||
StateController(&robot_state, &robot_command);
|
||||
SetCommand(&robot_command);
|
||||
this->GetState(&this->robot_state);
|
||||
this->StateController(&this->robot_state, &this->robot_command);
|
||||
this->SetCommand(&this->robot_command);
|
||||
}
|
||||
|
||||
void RL_Real::RunModel()
|
||||
{
|
||||
if(running_state == STATE_RL_RUNNING)
|
||||
if(this->running_state == STATE_RL_RUNNING)
|
||||
{
|
||||
this->obs.ang_vel = torch::tensor(robot_state.imu.gyroscope).unsqueeze(0);
|
||||
this->obs.commands = torch::tensor({{unitree_joy.ly, -unitree_joy.rx, -unitree_joy.lx}});
|
||||
this->obs.base_quat = torch::tensor(robot_state.imu.quaternion).unsqueeze(0);
|
||||
this->obs.dof_pos = torch::tensor(robot_state.motor_state.q).narrow(0, 0, params.num_of_dofs).unsqueeze(0);
|
||||
this->obs.dof_vel = torch::tensor(robot_state.motor_state.dq).narrow(0, 0, params.num_of_dofs).unsqueeze(0);
|
||||
this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0);
|
||||
this->obs.commands = torch::tensor({{this->unitree_joy.ly, -this->unitree_joy.rx, -this->unitree_joy.lx}});
|
||||
this->obs.base_quat = torch::tensor(this->robot_state.imu.quaternion).unsqueeze(0);
|
||||
this->obs.dof_pos = torch::tensor(this->robot_state.motor_state.q).narrow(0, 0, this->params.num_of_dofs).unsqueeze(0);
|
||||
this->obs.dof_vel = torch::tensor(this->robot_state.motor_state.dq).narrow(0, 0, this->params.num_of_dofs).unsqueeze(0);
|
||||
|
||||
torch::Tensor clamped_actions = this->Forward();
|
||||
|
||||
|
@ -146,14 +146,14 @@ void RL_Real::RunModel()
|
|||
|
||||
torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions);
|
||||
|
||||
TorqueProtect(origin_output_torques);
|
||||
this->TorqueProtect(origin_output_torques);
|
||||
|
||||
output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
|
||||
output_dof_pos = this->ComputePosition(this->obs.actions);
|
||||
this->output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
|
||||
this->output_dof_pos = this->ComputePosition(this->obs.actions);
|
||||
|
||||
#ifdef CSV_LOGGER
|
||||
torch::Tensor tau_est = torch::tensor(robot_state.motor_state.tauEst).unsqueeze(0);
|
||||
CSVLogger(output_torques, tau_est, this->obs.dof_pos, output_dof_pos, this->obs.dof_vel);
|
||||
torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tauEst).unsqueeze(0);
|
||||
this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
@ -178,10 +178,10 @@ torch::Tensor RL_Real::Forward()
|
|||
|
||||
torch::Tensor clamped_obs = this->ComputeObservation();
|
||||
|
||||
history_obs_buf.insert(clamped_obs);
|
||||
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
|
||||
this->history_obs_buf.insert(clamped_obs);
|
||||
this->history_obs = this->history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
|
||||
|
||||
torch::Tensor actions = this->model.forward({history_obs}).toTensor();
|
||||
torch::Tensor actions = this->model.forward({this->history_obs}).toTensor();
|
||||
|
||||
torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
|
||||
|
||||
|
@ -190,20 +190,20 @@ torch::Tensor RL_Real::Forward()
|
|||
|
||||
void RL_Real::Plot()
|
||||
{
|
||||
plot_t.erase(plot_t.begin());
|
||||
plot_t.push_back(motiontime);
|
||||
this->plot_t.erase(this->plot_t.begin());
|
||||
this->plot_t.push_back(this->motiontime);
|
||||
plt::cla();
|
||||
plt::clf();
|
||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
plot_real_joint_pos[i].erase(plot_real_joint_pos[i].begin());
|
||||
plot_target_joint_pos[i].erase(plot_target_joint_pos[i].begin());
|
||||
plot_real_joint_pos[i].push_back(unitree_low_state.motorState[i].q);
|
||||
plot_target_joint_pos[i].push_back(unitree_low_command.motorCmd[i].q);
|
||||
this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin());
|
||||
this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin());
|
||||
this->plot_real_joint_pos[i].push_back(this->unitree_low_state.motorState[i].q);
|
||||
this->plot_target_joint_pos[i].push_back(this->unitree_low_command.motorCmd[i].q);
|
||||
plt::subplot(4, 3, i + 1);
|
||||
plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r");
|
||||
plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b");
|
||||
plt::xlim(plot_t.front(), plot_t.back());
|
||||
plt::named_plot("_real_joint_pos", this->plot_t, this->plot_real_joint_pos[i], "r");
|
||||
plt::named_plot("_target_joint_pos", this->plot_t, this->plot_target_joint_pos[i], "b");
|
||||
plt::xlim(this->plot_t.front(), this->plot_t.back());
|
||||
}
|
||||
// plt::legend();
|
||||
plt::pause(0.0001);
|
||||
|
|
|
@ -8,180 +8,180 @@ RL_Sim::RL_Sim()
|
|||
ros::NodeHandle nh;
|
||||
|
||||
// read params from yaml
|
||||
nh.param<std::string>("robot_name", robot_name, "");
|
||||
ReadYaml(robot_name);
|
||||
nh.param<std::string>("robot_name", this->robot_name, "");
|
||||
this->ReadYaml(this->robot_name);
|
||||
|
||||
// history
|
||||
nh.param<bool>("use_history", use_history, "");
|
||||
if(use_history)
|
||||
nh.param<bool>("use_history", this->use_history, "");
|
||||
if(this->use_history)
|
||||
{
|
||||
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
|
||||
}
|
||||
|
||||
// Due to the fact that the robot_state_publisher sorts the joint names alphabetically,
|
||||
// the mapping table is established according to the order defined in the YAML file
|
||||
std::vector<std::string> sorted_joint_controller_names = params.joint_controller_names;
|
||||
std::vector<std::string> sorted_joint_controller_names = this->params.joint_controller_names;
|
||||
std::sort(sorted_joint_controller_names.begin(), sorted_joint_controller_names.end());
|
||||
for(size_t i = 0; i < params.joint_controller_names.size(); ++i)
|
||||
for(size_t i = 0; i < this->params.joint_controller_names.size(); ++i)
|
||||
{
|
||||
sorted_to_original_index[sorted_joint_controller_names[i]] = i;
|
||||
this->sorted_to_original_index[sorted_joint_controller_names[i]] = i;
|
||||
}
|
||||
mapped_joint_positions = std::vector<double>(params.num_of_dofs, 0.0);
|
||||
mapped_joint_velocities = std::vector<double>(params.num_of_dofs, 0.0);
|
||||
mapped_joint_efforts = std::vector<double>(params.num_of_dofs, 0.0);
|
||||
this->mapped_joint_positions = std::vector<double>(this->params.num_of_dofs, 0.0);
|
||||
this->mapped_joint_velocities = std::vector<double>(this->params.num_of_dofs, 0.0);
|
||||
this->mapped_joint_efforts = std::vector<double>(this->params.num_of_dofs, 0.0);
|
||||
|
||||
// init
|
||||
torch::autograd::GradMode::set_enabled(false);
|
||||
motor_commands.resize(params.num_of_dofs);
|
||||
this->joint_publishers_commands.resize(this->params.num_of_dofs);
|
||||
this->InitObservations();
|
||||
this->InitOutputs();
|
||||
this->InitControl();
|
||||
|
||||
// model
|
||||
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/" + this->params.model_name;
|
||||
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + this->robot_name + "/" + this->params.model_name;
|
||||
this->model = torch::jit::load(model_path);
|
||||
|
||||
// publisher
|
||||
nh.param<std::string>("ros_namespace", ros_namespace, "");
|
||||
for (int i = 0; i < params.num_of_dofs; ++i)
|
||||
nh.param<std::string>("ros_namespace", this->ros_namespace, "");
|
||||
for (int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
// joint need to rename as xxx_joint
|
||||
torque_publishers[params.joint_controller_names[i]] = nh.advertise<robot_msgs::MotorCommand>(
|
||||
ros_namespace + params.joint_controller_names[i] + "/command", 10);
|
||||
this->joint_publishers[this->params.joint_controller_names[i]] = nh.advertise<robot_msgs::MotorCommand>(
|
||||
this->ros_namespace + this->params.joint_controller_names[i] + "/command", 10);
|
||||
}
|
||||
|
||||
// subscriber
|
||||
cmd_vel_subscriber = nh.subscribe<geometry_msgs::Twist>("/cmd_vel", 10, &RL_Sim::CmdvelCallback, this);
|
||||
model_state_subscriber = nh.subscribe<gazebo_msgs::ModelStates>("/gazebo/model_states", 10, &RL_Sim::ModelStatesCallback, this);
|
||||
joint_state_subscriber = nh.subscribe<sensor_msgs::JointState>(ros_namespace + "joint_states", 10, &RL_Sim::JointStatesCallback, this);
|
||||
this->cmd_vel_subscriber = nh.subscribe<geometry_msgs::Twist>("/cmd_vel", 10, &RL_Sim::CmdvelCallback, this);
|
||||
this->model_state_subscriber = nh.subscribe<gazebo_msgs::ModelStates>("/gazebo/model_states", 10, &RL_Sim::ModelStatesCallback, this);
|
||||
this->joint_state_subscriber = nh.subscribe<sensor_msgs::JointState>(this->ros_namespace + "joint_states", 10, &RL_Sim::JointStatesCallback, this);
|
||||
|
||||
// service
|
||||
gazebo_reset_client = nh.serviceClient<std_srvs::Empty>("/gazebo/reset_simulation");
|
||||
this->gazebo_reset_client = nh.serviceClient<std_srvs::Empty>("/gazebo/reset_simulation");
|
||||
|
||||
// loop
|
||||
loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , boost::bind(&RL_Sim::KeyboardInterface, this));
|
||||
loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, boost::bind(&RL_Sim::RobotControl , this));
|
||||
loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Sim::RunModel , this));
|
||||
loop_keyboard->start();
|
||||
loop_control->start();
|
||||
loop_rl->start();
|
||||
this->loop_keyboard = std::make_shared<LoopFunc>("loop_keyboard", 0.05 , boost::bind(&RL_Sim::KeyboardInterface, this));
|
||||
this->loop_control = std::make_shared<LoopFunc>("loop_control" , 0.002, boost::bind(&RL_Sim::RobotControl , this));
|
||||
this->loop_rl = std::make_shared<LoopFunc>("loop_rl" , 0.02 , boost::bind(&RL_Sim::RunModel , this));
|
||||
this->loop_keyboard->start();
|
||||
this->loop_control->start();
|
||||
this->loop_rl->start();
|
||||
|
||||
#ifdef PLOT
|
||||
plot_t = std::vector<int>(plot_size, 0);
|
||||
plot_real_joint_pos.resize(params.num_of_dofs);
|
||||
plot_target_joint_pos.resize(params.num_of_dofs);
|
||||
for(auto& vector : plot_real_joint_pos) { vector = std::vector<double>(plot_size, 0); }
|
||||
for(auto& vector : plot_target_joint_pos) { vector = std::vector<double>(plot_size, 0); }
|
||||
loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Sim::Plot, this));
|
||||
loop_plot->start();
|
||||
plot_t = std::vector<int>(this->plot_size, 0);
|
||||
this->plot_real_joint_pos.resize(this->params.num_of_dofs);
|
||||
this->plot_target_joint_pos.resize(this->params.num_of_dofs);
|
||||
for(auto& vector : this->plot_real_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
|
||||
for(auto& vector : this->plot_target_joint_pos) { vector = std::vector<double>(this->plot_size, 0); }
|
||||
this->loop_plot = std::make_shared<LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Sim::Plot, this));
|
||||
this->loop_plot->start();
|
||||
#endif
|
||||
#ifdef CSV_LOGGER
|
||||
CSVInit(robot_name);
|
||||
this->CSVInit(this->robot_name);
|
||||
#endif
|
||||
}
|
||||
|
||||
RL_Sim::~RL_Sim()
|
||||
{
|
||||
loop_keyboard->shutdown();
|
||||
loop_control->shutdown();
|
||||
loop_rl->shutdown();
|
||||
this->loop_keyboard->shutdown();
|
||||
this->loop_control->shutdown();
|
||||
this->loop_rl->shutdown();
|
||||
#ifdef PLOT
|
||||
loop_plot->shutdown();
|
||||
this->loop_plot->shutdown();
|
||||
#endif
|
||||
std::cout << LOGGER::INFO << "RL_Sim exit" << std::endl;
|
||||
}
|
||||
|
||||
void RL_Sim::GetState(RobotState<double> *state)
|
||||
{
|
||||
state->imu.quaternion[3] = pose.orientation.w;
|
||||
state->imu.quaternion[0] = pose.orientation.x;
|
||||
state->imu.quaternion[1] = pose.orientation.y;
|
||||
state->imu.quaternion[2] = pose.orientation.z;
|
||||
state->imu.quaternion[3] = this->pose.orientation.w;
|
||||
state->imu.quaternion[0] = this->pose.orientation.x;
|
||||
state->imu.quaternion[1] = this->pose.orientation.y;
|
||||
state->imu.quaternion[2] = this->pose.orientation.z;
|
||||
|
||||
state->imu.gyroscope[0] = vel.angular.x;
|
||||
state->imu.gyroscope[1] = vel.angular.y;
|
||||
state->imu.gyroscope[2] = vel.angular.z;
|
||||
state->imu.gyroscope[0] = this->vel.angular.x;
|
||||
state->imu.gyroscope[1] = this->vel.angular.y;
|
||||
state->imu.gyroscope[2] = this->vel.angular.z;
|
||||
|
||||
// state->imu.accelerometer
|
||||
|
||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
state->motor_state.q[i] = mapped_joint_positions[i];
|
||||
state->motor_state.dq[i] = mapped_joint_velocities[i];
|
||||
state->motor_state.tauEst[i] = mapped_joint_efforts[i];
|
||||
state->motor_state.q[i] = this->mapped_joint_positions[i];
|
||||
state->motor_state.dq[i] = this->mapped_joint_velocities[i];
|
||||
state->motor_state.tauEst[i] = this->mapped_joint_efforts[i];
|
||||
}
|
||||
}
|
||||
|
||||
void RL_Sim::SetCommand(const RobotCommand<double> *command)
|
||||
{
|
||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
motor_commands[i].q = command->motor_command.q[i];
|
||||
motor_commands[i].dq = command->motor_command.dq[i];
|
||||
motor_commands[i].kp = command->motor_command.kp[i];
|
||||
motor_commands[i].kd = command->motor_command.kd[i];
|
||||
motor_commands[i].tau = command->motor_command.tau[i];
|
||||
this->joint_publishers_commands[i].q = command->motor_command.q[i];
|
||||
this->joint_publishers_commands[i].dq = command->motor_command.dq[i];
|
||||
this->joint_publishers_commands[i].kp = command->motor_command.kp[i];
|
||||
this->joint_publishers_commands[i].kd = command->motor_command.kd[i];
|
||||
this->joint_publishers_commands[i].tau = command->motor_command.tau[i];
|
||||
}
|
||||
|
||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
torque_publishers[params.joint_controller_names[i]].publish(motor_commands[i]);
|
||||
this->joint_publishers[this->params.joint_controller_names[i]].publish(this->joint_publishers_commands[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void RL_Sim::RobotControl()
|
||||
{
|
||||
motiontime++;
|
||||
this->motiontime++;
|
||||
|
||||
if(control.control_state == STATE_RESET_SIMULATION)
|
||||
if(this->control.control_state == STATE_RESET_SIMULATION)
|
||||
{
|
||||
control.control_state = STATE_WAITING;
|
||||
this->control.control_state = STATE_WAITING;
|
||||
std_srvs::Empty srv;
|
||||
gazebo_reset_client.call(srv);
|
||||
this->gazebo_reset_client.call(srv);
|
||||
}
|
||||
|
||||
GetState(&robot_state);
|
||||
StateController(&robot_state, &robot_command);
|
||||
SetCommand(&robot_command);
|
||||
this->GetState(&this->robot_state);
|
||||
this->StateController(&this->robot_state, &this->robot_command);
|
||||
this->SetCommand(&this->robot_command);
|
||||
}
|
||||
|
||||
void RL_Sim::ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg)
|
||||
{
|
||||
vel = msg->twist[2];
|
||||
pose = msg->pose[2];
|
||||
this->vel = msg->twist[2];
|
||||
this->pose = msg->pose[2];
|
||||
}
|
||||
|
||||
void RL_Sim::CmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg)
|
||||
{
|
||||
cmd_vel = *msg;
|
||||
this->cmd_vel = *msg;
|
||||
}
|
||||
|
||||
void RL_Sim::MapData(const std::vector<double>& source_data, std::vector<double>& target_data)
|
||||
{
|
||||
for(size_t i = 0; i < source_data.size(); ++i)
|
||||
{
|
||||
target_data[i] = source_data[sorted_to_original_index[params.joint_controller_names[i]]];
|
||||
target_data[i] = source_data[this->sorted_to_original_index[this->params.joint_controller_names[i]]];
|
||||
}
|
||||
}
|
||||
|
||||
void RL_Sim::JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
|
||||
{
|
||||
MapData(msg->position, mapped_joint_positions);
|
||||
MapData(msg->velocity, mapped_joint_velocities);
|
||||
MapData(msg->effort, mapped_joint_efforts);
|
||||
MapData(msg->position, this->mapped_joint_positions);
|
||||
MapData(msg->velocity, this->mapped_joint_velocities);
|
||||
MapData(msg->effort, this->mapped_joint_efforts);
|
||||
}
|
||||
|
||||
void RL_Sim::RunModel()
|
||||
{
|
||||
if(running_state == STATE_RL_RUNNING)
|
||||
if(this->running_state == STATE_RL_RUNNING)
|
||||
{
|
||||
// this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}});
|
||||
this->obs.ang_vel = torch::tensor(robot_state.imu.gyroscope).unsqueeze(0);
|
||||
// this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}});
|
||||
this->obs.commands = torch::tensor({{control.x, control.y, control.yaw}});
|
||||
this->obs.base_quat = torch::tensor(robot_state.imu.quaternion).unsqueeze(0);
|
||||
this->obs.dof_pos = torch::tensor(robot_state.motor_state.q).narrow(0, 0, params.num_of_dofs).unsqueeze(0);
|
||||
this->obs.dof_vel = torch::tensor(robot_state.motor_state.dq).narrow(0, 0, params.num_of_dofs).unsqueeze(0);
|
||||
// this->obs.lin_vel = torch::tensor({{this->vel.linear.x, this->vel.linear.y, this->vel.linear.z}});
|
||||
this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0);
|
||||
// this->obs.commands = torch::tensor({{this->cmd_vel.linear.x, this->cmd_vel.linear.y, this->cmd_vel.angular.z}});
|
||||
this->obs.commands = torch::tensor({{this->control.x, this->control.y, this->control.yaw}});
|
||||
this->obs.base_quat = torch::tensor(this->robot_state.imu.quaternion).unsqueeze(0);
|
||||
this->obs.dof_pos = torch::tensor(this->robot_state.motor_state.q).narrow(0, 0, this->params.num_of_dofs).unsqueeze(0);
|
||||
this->obs.dof_vel = torch::tensor(this->robot_state.motor_state.dq).narrow(0, 0, this->params.num_of_dofs).unsqueeze(0);
|
||||
|
||||
torch::Tensor clamped_actions = this->Forward();
|
||||
|
||||
|
@ -194,14 +194,14 @@ void RL_Sim::RunModel()
|
|||
|
||||
torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions);
|
||||
|
||||
TorqueProtect(origin_output_torques);
|
||||
this->TorqueProtect(origin_output_torques);
|
||||
|
||||
output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
|
||||
output_dof_pos = this->ComputePosition(this->obs.actions);
|
||||
this->output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
|
||||
this->output_dof_pos = this->ComputePosition(this->obs.actions);
|
||||
|
||||
#ifdef CSV_LOGGER
|
||||
torch::Tensor tau_est = torch::tensor(mapped_joint_efforts).unsqueeze(0);
|
||||
CSVLogger(output_torques, tau_est, this->obs.dof_pos, output_dof_pos, this->obs.dof_vel);
|
||||
torch::Tensor tau_est = torch::tensor(this->mapped_joint_efforts).unsqueeze(0);
|
||||
this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
@ -229,11 +229,11 @@ torch::Tensor RL_Sim::Forward()
|
|||
|
||||
torch::Tensor actions;
|
||||
|
||||
if(use_history)
|
||||
if(this->use_history)
|
||||
{
|
||||
history_obs_buf.insert(clamped_obs);
|
||||
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
|
||||
actions = this->model.forward({history_obs}).toTensor();
|
||||
this->history_obs_buf.insert(clamped_obs);
|
||||
this->history_obs = this->history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
|
||||
actions = this->model.forward({this->history_obs}).toTensor();
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -247,20 +247,20 @@ torch::Tensor RL_Sim::Forward()
|
|||
|
||||
void RL_Sim::Plot()
|
||||
{
|
||||
plot_t.erase(plot_t.begin());
|
||||
plot_t.push_back(motiontime);
|
||||
this->plot_t.erase(this->plot_t.begin());
|
||||
this->plot_t.push_back(this->motiontime);
|
||||
plt::cla();
|
||||
plt::clf();
|
||||
for(int i = 0; i < params.num_of_dofs; ++i)
|
||||
for(int i = 0; i < this->params.num_of_dofs; ++i)
|
||||
{
|
||||
plot_real_joint_pos[i].erase(plot_real_joint_pos[i].begin());
|
||||
plot_target_joint_pos[i].erase(plot_target_joint_pos[i].begin());
|
||||
plot_real_joint_pos[i].push_back(mapped_joint_positions[i]);
|
||||
plot_target_joint_pos[i].push_back(motor_commands[i].q);
|
||||
this->plot_real_joint_pos[i].erase(this->plot_real_joint_pos[i].begin());
|
||||
this->plot_target_joint_pos[i].erase(this->plot_target_joint_pos[i].begin());
|
||||
this->plot_real_joint_pos[i].push_back(this->mapped_joint_positions[i]);
|
||||
this->plot_target_joint_pos[i].push_back(this->joint_publishers_commands[i].q);
|
||||
plt::subplot(4, 3, i+1);
|
||||
plt::named_plot("_real_joint_pos", plot_t, plot_real_joint_pos[i], "r");
|
||||
plt::named_plot("_target_joint_pos", plot_t, plot_target_joint_pos[i], "b");
|
||||
plt::xlim(plot_t.front(), plot_t.back());
|
||||
plt::named_plot("_real_joint_pos", this->plot_t, this->plot_real_joint_pos[i], "r");
|
||||
plt::named_plot("_target_joint_pos", this->plot_t, this->plot_target_joint_pos[i], "b");
|
||||
plt::xlim(this->plot_t.front(), this->plot_t.back());
|
||||
}
|
||||
// plt::legend();
|
||||
plt::pause(0.0001);
|
||||
|
|
Loading…
Reference in New Issue