This commit is contained in:
fan-ziqi 2024-05-24 11:40:03 +08:00
parent 7b681c1834
commit 1d6ecb2771
10 changed files with 397 additions and 158 deletions

View File

@ -75,15 +75,14 @@ target_link_libraries(rl_sim
rl_sdk observation_buffer yaml-cpp
)
add_executable(rl_real_a1 src/rl_real_a1.cpp )
target_link_libraries(rl_real_a1
${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}"
rl_sdk observation_buffer yaml-cpp
)
# add_executable(rl_real_a1 src/rl_real_a1.cpp )
# target_link_libraries(rl_real_a1
# ${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}"
# rl_sdk observation_buffer yaml-cpp
# )
add_executable(rl_real_cyberdog src/rl_real_cyberdog.cpp )
target_link_libraries(rl_real_cyberdog
${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}"
rl_sdk observation_buffer cyberdog_motor_sdk yaml-cpp
)
# add_executable(rl_real_cyberdog src/rl_real_cyberdog.cpp )
# target_link_libraries(rl_real_cyberdog
# ${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}"
# rl_sdk observation_buffer cyberdog_motor_sdk yaml-cpp
# )

View File

@ -107,3 +107,30 @@ lite3_wheel:
"FR_hip_joint", "FR_thigh_joint", "FR_calf_joint",
"RL_hip_joint", "RL_thigh_joint", "RL_calf_joint",
"RR_hip_joint", "RR_thigh_joint", "RR_calf_joint"]
gr1t1:
model_name: "model_4000_jit.pt"
num_observations: 39
clip_obs: 100.0
clip_actions: 100.0
# damping: 0.5
# stiffness: 20.0
p_gains: [57.0, 43.0, 114.0, 114.0, 15.3,
57.0, 43.0, 114.0, 114.0, 15.3]
d_gains: [5.7, 4.3, 11.4, 11.4, 1.5,
5.7, 4.3, 11.4, 11.4, 1.5]
action_scale: 1.0
hip_scale_reduction: 1.0
hip_scale_reduction_indices: []
num_of_dofs: 10
lin_vel_scale: 2.0
ang_vel_scale: 0.25
dof_pos_scale: 1.0
dof_vel_scale: 0.05
commands_scale: [2.0, 2.0, 0.25]
torque_limits: [100.0, 100.0, 100.0, 100.0, 100.0,
100.0, 100.0, 100.0, 100.0, 100.0]
default_dof_pos: [0.0, 0.0, -0.2618, 0.5236, -0.2618,
0.0, 0.0, -0.2618, 0.5236, -0.2618]
joint_names: ["l_hip_roll_joint", "l_hip_yaw_joint", "l_hip_pitch_joint", "l_knee_pitch_joint", "l_ankle_pitch_joint",
"r_hip_roll_joint", "r_hip_yaw_joint", "r_hip_pitch_joint", "r_knee_pitch_joint", "r_ankle_pitch_joint"]

View File

@ -8,14 +8,6 @@
#include <csignal>
// #include <signal.h>
enum RobotState {
STATE_WAITING = 0,
STATE_POS_GETUP,
STATE_RL_INIT,
STATE_RL_RUNNING,
STATE_POS_GETDOWN,
};
class RL_Real : public RL
{
public:

View File

@ -8,29 +8,10 @@
#include <CustomInterface.h>
#include <csignal>
// #include <signal.h>
#include <termios.h>
#include <sys/ioctl.h>
#include <iostream>
using CyberdogData = Robot_Data;
using CyberdogCmd = Motor_Cmd;
enum RobotState {
STATE_WAITING = 0,
STATE_POS_GETUP,
STATE_RL_INIT,
STATE_RL_RUNNING,
STATE_POS_GETDOWN,
};
struct KeyBoard
{
RobotState robot_state;
float x = 0;
float y = 0;
float yaw = 0;
};
class RL_Real : public RL, public CustomInterface
{
public:
@ -56,20 +37,11 @@ public:
std::shared_ptr<UNITREE_LEGGED_SDK::LoopFunc> loop_rl;
std::shared_ptr<UNITREE_LEGGED_SDK::LoopFunc> loop_plot;
float getup_percent = 0.0;
float getdown_percent = 0.0;
float start_pos[12];
float now_pos[12];
int robot_state = STATE_WAITING;
const int plot_size = 100;
std::vector<int> plot_t;
std::vector<std::vector<double>> plot_real_joint_pos, plot_target_joint_pos;
void Plot();
void run_keyboard();
KeyBoard keyboard;
std::thread _keyboardThread;
private:
std::vector<std::string> joint_names;

View File

@ -29,6 +29,9 @@ public:
torch::Tensor Forward() override;
torch::Tensor ComputeObservation() override;
void GetState(RobotState<double> *state) override;
void SetCommand(const RobotCommand<double> *command) override;
ObservationBuffer history_obs_buf;
torch::Tensor history_obs;
@ -42,6 +45,8 @@ public:
std::vector<int> plot_t;
std::vector<std::vector<double>> plot_real_joint_pos, plot_target_joint_pos;
void Plot();
std::thread _keyboardThread;
private:
std::string ros_namespace;

View File

@ -51,4 +51,6 @@
<remap from="/joint_states" to="/$(arg rname)_gazebo/joint_states"/>
</node>
<node pkg="rl_sar" type="rl_sim" name="rl_sim" output="screen"/>
</launch>

View File

@ -1,5 +1,8 @@
#include "rl_sdk.hpp"
#include <termios.h>
#include <sys/ioctl.h>
template<typename T>
std::vector<T> ReadVectorFromYaml(const YAML::Node& node)
{
@ -26,26 +29,26 @@ void RL::ReadYaml(std::string robot_name)
this->params.model_name = config["model_name"].as<std::string>();
this->params.num_observations = config["num_observations"].as<int>();
this->params.clip_obs = config["clip_obs"].as<float>();
this->params.clip_actions = config["clip_actions"].as<float>();
this->params.action_scale = config["action_scale"].as<float>();
this->params.hip_scale_reduction = config["hip_scale_reduction"].as<float>();
this->params.clip_obs = config["clip_obs"].as<double>();
this->params.clip_actions = config["clip_actions"].as<double>();
this->params.action_scale = config["action_scale"].as<double>();
this->params.hip_scale_reduction = config["hip_scale_reduction"].as<double>();
this->params.hip_scale_reduction_indices = ReadVectorFromYaml<int>(config["hip_scale_reduction_indices"]);
this->params.num_of_dofs = config["num_of_dofs"].as<int>();
this->params.lin_vel_scale = config["lin_vel_scale"].as<float>();
this->params.ang_vel_scale = config["ang_vel_scale"].as<float>();
this->params.dof_pos_scale = config["dof_pos_scale"].as<float>();
this->params.dof_vel_scale = config["dof_vel_scale"].as<float>();
// this->params.commands_scale = torch::tensor(ReadVectorFromYaml<float>(config["commands_scale"])).view({1, -1});
this->params.lin_vel_scale = config["lin_vel_scale"].as<double>();
this->params.ang_vel_scale = config["ang_vel_scale"].as<double>();
this->params.dof_pos_scale = config["dof_pos_scale"].as<double>();
this->params.dof_vel_scale = config["dof_vel_scale"].as<double>();
// this->params.commands_scale = torch::tensor(ReadVectorFromYaml<double>(config["commands_scale"])).view({1, -1});
this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale});
// this->params.damping = config["damping"].as<float>();
// this->params.stiffness = config["stiffness"].as<float>();
// this->params.damping = config["damping"].as<double>();
// this->params.stiffness = config["stiffness"].as<double>();
// this->params.d_gains = torch::ones(12) * this->params.damping;
// this->params.p_gains = torch::ones(12) * this->params.stiffness;
this->params.p_gains = torch::tensor(ReadVectorFromYaml<float>(config["p_gains"])).view({1, -1});
this->params.d_gains = torch::tensor(ReadVectorFromYaml<float>(config["d_gains"])).view({1, -1});
this->params.torque_limits = torch::tensor(ReadVectorFromYaml<float>(config["torque_limits"])).view({1, -1});
this->params.default_dof_pos = torch::tensor(ReadVectorFromYaml<float>(config["default_dof_pos"])).view({1, -1});
this->params.p_gains = torch::tensor(ReadVectorFromYaml<double>(config["p_gains"])).view({1, -1});
this->params.d_gains = torch::tensor(ReadVectorFromYaml<double>(config["d_gains"])).view({1, -1});
this->params.torque_limits = torch::tensor(ReadVectorFromYaml<double>(config["torque_limits"])).view({1, -1});
this->params.default_dof_pos = torch::tensor(ReadVectorFromYaml<double>(config["default_dof_pos"])).view({1, -1});
this->params.joint_names = ReadVectorFromYaml<std::string>(config["joint_names"]);
}
@ -110,14 +113,22 @@ void RL::InitObservations()
this->obs.commands = torch::tensor({{0.0, 0.0, 0.0}});
this->obs.base_quat = torch::tensor({{0.0, 0.0, 0.0, 1.0}});
this->obs.dof_pos = this->params.default_dof_pos;
this->obs.dof_vel = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
this->obs.actions = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
this->obs.dof_vel = torch::zeros({1, params.num_of_dofs});
this->obs.actions = torch::zeros({1, params.num_of_dofs});
}
void RL::InitOutputs()
{
output_torques = torch::tensor({{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}});
output_dof_pos = params.default_dof_pos;
this->output_torques = torch::zeros({1, params.num_of_dofs});
this->output_dof_pos = params.default_dof_pos;
}
void RL::InitKeyboard()
{
this->keyboard.keyboard_state = STATE_WAITING;
this->keyboard.x = 0.0;
this->keyboard.y = 0.0;
this->keyboard.yaw = 0.0;
}
torch::Tensor RL::ComputeTorques(torch::Tensor actions)
@ -169,3 +180,172 @@ torch::Tensor RL::Forward()
return clamped;
}
*/
static bool kbhit()
{
termios term;
tcgetattr(0, &term);
termios term2 = term;
term2.c_lflag &= ~ICANON;
tcsetattr(0, TCSANOW, &term2);
int byteswaiting;
ioctl(0, FIONREAD, &byteswaiting);
tcsetattr(0, TCSANOW, &term);
return byteswaiting > 0;
}
void RL::run_keyboard()
{
int c;
// Check for keyboard input
while(true)
{
if(kbhit())
{
c = fgetc(stdin);
switch(c)
{
case '0': keyboard.keyboard_state = STATE_POS_GETUP; break;
case 'p': keyboard.keyboard_state = STATE_RL_INIT; break;
case '1': keyboard.keyboard_state = STATE_POS_GETDOWN; break;
case 'q': break;
case 'w': keyboard.x += 0.1; break;
case 's': keyboard.x -= 0.1; break;
case 'a': keyboard.yaw += 0.1; break;
case 'd': keyboard.yaw -= 0.1; break;
case 'i': break;
case 'k': break;
case 'j': keyboard.y += 0.1; break;
case 'l': keyboard.y -= 0.1; break;
case ' ': keyboard.x = 0; keyboard.y = 0; keyboard.yaw = 0; break;
default: break;
}
}
usleep(10000);
}
}
void RL::StateController(const RobotState<double> *state, RobotCommand<double> *command)
{
// waiting
if(running_state == STATE_WAITING)
{
for(int i = 0; i < params.num_of_dofs; ++i)
{
command->motor_command.q[i] = state->motor_state.q[i];
}
if(keyboard.keyboard_state == STATE_POS_GETUP)
{
keyboard.keyboard_state = STATE_WAITING;
getup_percent = 0.0;
for(int i = 0; i < params.num_of_dofs; ++i)
{
now_pos[i] = state->motor_state.q[i];
start_pos[i] = now_pos[i];
}
running_state = STATE_POS_GETUP;
}
}
// stand up (position control)
else if(running_state == STATE_POS_GETUP)
{
if(getup_percent != 1)
{
getup_percent += 1 / 1000.0;
getup_percent = getup_percent > 1 ? 1 : getup_percent;
for(int i = 0; i < params.num_of_dofs; ++i)
{
command->motor_command.q[i] = (1 - getup_percent) * now_pos[i] + getup_percent * params.default_dof_pos[0][i].item<double>();
command->motor_command.dq[i] = 0;
command->motor_command.kp[i] = 200;
command->motor_command.kd[i] = 10;
command->motor_command.tau[i] = 0;
}
printf("getting up %.3f%%\r", getup_percent*100.0);
}
if(keyboard.keyboard_state == STATE_RL_INIT)
{
keyboard.keyboard_state = STATE_WAITING;
running_state = STATE_RL_INIT;
}
else if(keyboard.keyboard_state == STATE_POS_GETDOWN)
{
keyboard.keyboard_state = STATE_WAITING;
getdown_percent = 0.0;
for(int i = 0; i < params.num_of_dofs; ++i)
{
now_pos[i] = state->motor_state.q[i];
}
running_state = STATE_POS_GETDOWN;
}
}
// init obs and start rl loop
else if(running_state == STATE_RL_INIT)
{
if(getup_percent == 1)
{
running_state = STATE_RL_RUNNING;
this->InitObservations();
this->InitOutputs();
this->InitKeyboard();
// printf("\nstart rl loop\n");
// loop_rl->start();
}
}
// rl loop
else if(running_state == STATE_RL_RUNNING)
{
for(int i = 0; i < params.num_of_dofs; ++i)
{
command->motor_command.q[i] = output_dof_pos[0][i].item<double>();
command->motor_command.dq[i] = 0;
// command->motor_command.kp[i] = params.stiffness;
// command->motor_command.kd[i] = params.damping;
command->motor_command.kp[i] = params.p_gains[0][i].item<double>();
command->motor_command.kd[i] = params.d_gains[0][i].item<double>();
// command->motor_command.tau[i] = output_torques[0][i].item<double>();
command->motor_command.tau[i] = 0;
}
if(keyboard.keyboard_state == STATE_POS_GETDOWN)
{
keyboard.keyboard_state = STATE_WAITING;
getdown_percent = 0.0;
for(int i = 0; i < params.num_of_dofs; ++i)
{
now_pos[i] = state->motor_state.q[i];
}
running_state = STATE_POS_GETDOWN;
}
}
// get down (position control)
else if(running_state == STATE_POS_GETDOWN)
{
if(getdown_percent != 1)
{
getdown_percent += 1 / 1000.0;
getdown_percent = getdown_percent > 1 ? 1 : getdown_percent;
for(int i = 0; i < params.num_of_dofs; ++i)
{
command->motor_command.q[i] = (1 - getdown_percent) * now_pos[i] + getdown_percent * start_pos[i];
command->motor_command.dq[i] = 0;
command->motor_command.kp[i] = 200;
command->motor_command.kd[i] = 10;
command->motor_command.tau[i] = 0;
}
printf("getting down %.3f%%\r", getdown_percent*100.0);
}
if(getdown_percent == 1)
{
running_state = STATE_WAITING;
this->InitObservations();
this->InitOutputs();
this->InitKeyboard();
// printf("\nstop rl loop\n");
// loop_rl->shutdown();
}
}
}

View File

@ -11,22 +11,71 @@ namespace plt = matplotlibcpp;
#include <yaml-cpp/yaml.h>
#define CONFIG_PATH CMAKE_CURRENT_SOURCE_DIR "/config.yaml"
template<typename T>
struct RobotCommand
{
struct MotorCommand
{
std::vector<T> q = std::vector<T>(32, 0.0);
std::vector<T> dq = std::vector<T>(32, 0.0);
std::vector<T> tau = std::vector<T>(32, 0.0);
std::vector<T> kp = std::vector<T>(32, 0.0);
std::vector<T> kd = std::vector<T>(32, 0.0);
} motor_command;
};
template<typename T>
struct RobotState
{
struct IMU
{
T quaternion[4] = {1.0, 0.0, 0.0, 0.0}; // w, x, y, z
T gyroscope[3] = {0.0, 0.0, 0.0};
T accelerometer[3] = {0.0, 0.0, 0.0};
} imu;
struct MotorState
{
std::vector<T> q = std::vector<T>(32, 0.0);
std::vector<T> dq = std::vector<T>(32, 0.0);
std::vector<T> ddq = std::vector<T>(32, 0.0);
std::vector<T> tauEst = std::vector<T>(32, 0.0);
std::vector<T> cur = std::vector<T>(32, 0.0);
} motor_state;
};
enum STATE {
STATE_WAITING = 0,
STATE_POS_GETUP,
STATE_RL_INIT,
STATE_RL_RUNNING,
STATE_POS_GETDOWN,
};
struct KeyBoard
{
STATE keyboard_state;
double x = 0.0;
double y = 0.0;
double yaw = 0.0;
};
struct ModelParams
{
std::string model_name;
int num_observations;
float damping;
float stiffness;
float action_scale;
float hip_scale_reduction;
double damping;
double stiffness;
double action_scale;
double hip_scale_reduction;
std::vector<int> hip_scale_reduction_indices;
int num_of_dofs;
float lin_vel_scale;
float ang_vel_scale;
float dof_pos_scale;
float dof_vel_scale;
float clip_obs;
float clip_actions;
double lin_vel_scale;
double ang_vel_scale;
double dof_pos_scale;
double dof_vel_scale;
double clip_obs;
double clip_actions;
torch::Tensor torque_limits;
torch::Tensor d_gains;
torch::Tensor p_gains;
@ -62,10 +111,26 @@ public:
torch::Tensor QuatRotateInverse(torch::Tensor q, torch::Tensor v);
void InitObservations();
void InitOutputs();
void InitKeyboard();
void ReadYaml(std::string robot_name);
std::string csv_filename;
void CSVInit(std::string robot_name);
void CSVLogger(torch::Tensor torque, torch::Tensor tau_est, torch::Tensor joint_pos, torch::Tensor joint_pos_target, torch::Tensor joint_vel);
void run_keyboard();
float getup_percent = 0.0;
float getdown_percent = 0.0;
std::vector<double> start_pos;
std::vector<double> now_pos;
int running_state = STATE_WAITING;
RobotState<double> robot_state;
RobotCommand<double> robot_command;
virtual void GetState(RobotState<double> *state) = 0;
virtual void SetCommand(const RobotCommand<double> *command) = 0;
void StateController(const RobotState<double> *state, RobotCommand<double> *command);
protected:
// rl module
@ -82,6 +147,8 @@ protected:
// output buffer
torch::Tensor output_torques;
torch::Tensor output_dof_pos;
// keyboard
KeyBoard keyboard;
};
#endif // RL_SDK_HPP

View File

@ -3,7 +3,7 @@
#define ROBOT_NAME "cyberdog1"
// #define PLOT
#define CSV_LOGGER
// #define CSV_LOGGER
RL_Real rl_sar;
@ -202,53 +202,6 @@ void RL_Real::UserCode()
motor_cmd = cyberdogCmd;
}
static bool kbhit()
{
termios term;
tcgetattr(0, &term);
termios term2 = term;
term2.c_lflag &= ~ICANON;
tcsetattr(0, TCSANOW, &term2);
int byteswaiting;
ioctl(0, FIONREAD, &byteswaiting);
tcsetattr(0, TCSANOW, &term);
return byteswaiting > 0;
}
void RL_Real::run_keyboard()
{
int c;
// Check for keyboard input
while(true)
{
if(kbhit())
{
c = fgetc(stdin);
switch(c)
{
case '0': keyboard.robot_state = STATE_POS_GETUP; break;
case 'p': keyboard.robot_state = STATE_RL_INIT; break;
case '1': keyboard.robot_state = STATE_POS_GETDOWN; break;
case 'q': break;
case 'w': keyboard.x += 0.5; break;
case 's': keyboard.x -= 0.5; break;
case 'a': keyboard.yaw += 0.5; break;
case 'd': keyboard.yaw -= 0.5; break;
case 'i': break;
case 'k': break;
case 'j': keyboard.y += 0.5; break;
case 'l': keyboard.y -= 0.5; break;
case ' ': keyboard.x = 0; keyboard.y = 0; keyboard.yaw = 0; break;
default: break;
}
}
usleep(10000);
}
}
void RL_Real::RunModel()
{
if(robot_state == STATE_RL_RUNNING)

View File

@ -1,10 +1,11 @@
#include "../include/rl_sim.hpp"
#define ROBOT_NAME "a1"
// #define ROBOT_NAME "a1"
#define ROBOT_NAME "gr1t1"
// #define PLOT
// #define CSV_LOGGER
#define USE_HISTORY
// #define USE_HISTORY
RL_Sim::RL_Sim()
{
@ -27,12 +28,15 @@ RL_Sim::RL_Sim()
cmd_vel = geometry_msgs::Twist();
motor_commands.resize(params.num_of_dofs);
start_pos.resize(params.num_of_dofs);
now_pos.resize(params.num_of_dofs);
std::string model_path = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + ROBOT_NAME + "/" + this->params.model_name;
this->model = torch::jit::load(model_path);
this->InitObservations();
this->InitOutputs();
this->InitKeyboard();
#ifdef USE_HISTORY
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6);
@ -73,6 +77,7 @@ RL_Sim::RL_Sim()
loop_plot = std::make_shared<UNITREE_LEGGED_SDK::LoopFunc>("loop_plot" , 0.002, boost::bind(&RL_Sim::Plot, this));
loop_plot->start();
#endif
_keyboardThread = std::thread(&RL_Sim::run_keyboard, this);
#ifdef CSV_LOGGER
CSVInit(ROBOT_NAME);
@ -89,24 +94,57 @@ RL_Sim::~RL_Sim()
printf("exit\n");
}
void RL_Sim::RobotControl()
void RL_Sim::GetState(RobotState<double> *state)
{
motiontime++;
for (int i = 0; i < params.num_of_dofs; ++i)
{
motor_commands[i].q = output_dof_pos[0][i].item<double>();
motor_commands[i].dq = 0;
// motor_commands[i].Kp = params.stiffness;
// motor_commands[i].Kd = params.damping;
motor_commands[i].kp = params.p_gains[0][i].item<double>();
motor_commands[i].kd = params.d_gains[0][i].item<double>();
// motor_commands[i].tau = output_torques[0][i].item<double>();
motor_commands[i].tau = 0;
state->imu.quaternion[0] = pose.orientation.w;
state->imu.quaternion[1] = pose.orientation.x;
state->imu.quaternion[2] = pose.orientation.y;
state->imu.quaternion[3] = pose.orientation.z;
state->imu.gyroscope[0] = vel.angular.x;
state->imu.gyroscope[1] = vel.angular.y;
state->imu.gyroscope[2] = vel.angular.z;
// state->imu.accelerometer
for(int i = 0; i < params.num_of_dofs; ++i)
{
state->motor_state.q[i] = joint_positions[i];
state->motor_state.dq[i] = joint_velocities[i];
state->motor_state.tauEst[i] = joint_efforts[i];
}
}
void RL_Sim::SetCommand(const RobotCommand<double> *command)
{
for(int i = 0; i < params.num_of_dofs; ++i)
{
motor_commands[i].q = command->motor_command.q[i];
motor_commands[i].dq = command->motor_command.dq[i];
motor_commands[i].kp = command->motor_command.kp[i];
motor_commands[i].kd = command->motor_command.kd[i];
motor_commands[i].tau = command->motor_command.tau[i];
}
for(int i = 0; i < params.num_of_dofs; ++i)
{
torque_publishers[params.joint_names[i]].publish(motor_commands[i]);
}
}
void RL_Sim::RobotControl()
{
std::cout << "running_state " << keyboard.keyboard_state
<< " x" << keyboard.x << " y" << keyboard.y << " yaw" << keyboard.yaw
<< " \r";
motiontime++;
GetState(&robot_state);
StateController(&robot_state, &robot_command);
SetCommand(&robot_command);
}
void RL_Sim::ModelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg)
{
vel = msg->twist[2];
@ -135,27 +173,31 @@ void RL_Sim::JointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg)
void RL_Sim::RunModel()
{
if(running_state == STATE_RL_RUNNING)
{
// this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}});
this->obs.ang_vel = torch::tensor({{vel.angular.x, vel.angular.y, vel.angular.z}});
this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}});
// this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}});
this->obs.commands = torch::tensor({{keyboard.x, keyboard.y, keyboard.yaw}});
this->obs.base_quat = torch::tensor({{pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w}});
this->obs.dof_pos = torch::tensor(joint_positions).unsqueeze(0);
this->obs.dof_vel = torch::tensor(joint_velocities).unsqueeze(0);
torch::Tensor actions = this->Forward();
torch::Tensor clamped_actions = this->Forward();
for (int i : this->params.hip_scale_reduction_indices)
{
actions[0][i] *= this->params.hip_scale_reduction;
clamped_actions[0][i] *= this->params.hip_scale_reduction;
}
output_torques = this->ComputeTorques(actions);
output_dof_pos = this->ComputePosition(actions);
output_torques = this->ComputeTorques(clamped_actions);
output_dof_pos = this->ComputePosition(clamped_actions);
#ifdef CSV_LOGGER
torch::Tensor tau_est = torch::tensor(joint_efforts).unsqueeze(0);
CSVLogger(output_torques, tau_est, this->obs.dof_pos, output_dof_pos, this->obs.dof_vel);
#endif
}
}
torch::Tensor RL_Sim::ComputeObservation()
@ -181,13 +223,13 @@ torch::Tensor RL_Sim::Forward()
history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5});
torch::Tensor action = this->model.forward({history_obs}).toTensor();
#else
torch::Tensor action = this->model.forward({obs}).toTensor();
torch::Tensor actions = this->model.forward({obs}).toTensor();
#endif
this->obs.actions = action;
torch::Tensor clamped = torch::clamp(action, -this->params.clip_actions, this->params.clip_actions);
this->obs.actions = actions;
torch::Tensor clamped_actions = torch::clamp(actions, -this->params.clip_actions, this->params.clip_actions);
return clamped;
return clamped_actions;
}
void RL_Sim::Plot()