feat: 1. use `tbb::concurrent_queue` instead of `std::mutex`

2. wheel robot support
3. add `mode` in `RobotCommand::MotorCommand`
This commit is contained in:
fan-ziqi 2025-01-11 21:52:33 +08:00
parent 2599d2ba5b
commit 4575a5b62d
15 changed files with 171 additions and 89 deletions

View File

@ -44,6 +44,12 @@ Install `yaml-cpp` and `lcm`. If you are using Ubuntu, you can directly use the
sudo apt install liblcm-dev libyaml-cpp-dev
```
This project uses the Intel TBB (Threading Building Blocks) library to implement data exchange between different threads. If you use Ubuntu, you can directly use the package manager to install it
```bash
sudo apt install libtbb-dev
```
<details>
<summary>You can also use source code installation, click to expand</summary>

View File

@ -44,6 +44,12 @@ echo 'export Torch_DIR=/path/to/your/torchlib' >> ~/.bashrc
sudo apt install liblcm-dev libyaml-cpp-dev
```
本项目使用Intel TBBThreading Building Blocks库进行线程间数据交换若您使用Ubuntu可直接使用包管理器进行安装
```bash
sudo apt install libtbb-dev
```
<details>
<summary>也可以使用源码安装,点击展开</summary>

View File

@ -14,6 +14,8 @@ find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${GAZEBO_CXX_FLAGS}")
find_package(gazebo REQUIRED)
find_package(TBB REQUIRED)
find_package(catkin REQUIRED COMPONENTS
controller_manager
genmsg
@ -63,7 +65,7 @@ include_directories(
)
add_library(rl_sdk library/rl_sdk/rl_sdk.cpp)
target_link_libraries(rl_sdk "${TORCH_LIBRARIES}" Python3::Python Python3::Module)
target_link_libraries(rl_sdk "${TORCH_LIBRARIES}" Python3::Python Python3::Module TBB::tbb)
set_property(TARGET rl_sdk PROPERTY CXX_STANDARD 14)
find_package(Python3 COMPONENTS NumPy)
if(Python3_NumPy_FOUND)

View File

@ -51,7 +51,12 @@ torch::Tensor RL::ComputeObservation()
}
else if (observation == "dof_pos")
{
obs_list.push_back((this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale);
torch::Tensor dof_pos_rel = this->obs.dof_pos - this->params.default_dof_pos;
for (int i : this->params.wheel_indices)
{
dof_pos_rel[0][i] = 0.0;
}
obs_list.push_back(dof_pos_rel * this->params.dof_pos_scale);
}
else if (observation == "dof_vel")
{
@ -82,8 +87,9 @@ void RL::InitObservations()
void RL::InitOutputs()
{
this->output_torques = torch::zeros({1, this->params.num_of_dofs});
this->output_dof_tau = torch::zeros({1, this->params.num_of_dofs});
this->output_dof_pos = this->params.default_dof_pos;
this->output_dof_vel = torch::zeros({1, this->params.num_of_dofs});
}
void RL::InitControl()
@ -94,17 +100,20 @@ void RL::InitControl()
this->control.yaw = 0.0;
}
torch::Tensor RL::ComputeTorques(torch::Tensor actions)
void RL::ComputeOutput(const torch::Tensor &actions, torch::Tensor &output_dof_pos, torch::Tensor &output_dof_vel, torch::Tensor &output_dof_tau)
{
torch::Tensor actions_scaled = actions * this->params.action_scale;
torch::Tensor output_torques = this->params.rl_kp * (actions_scaled + this->params.default_dof_pos - this->obs.dof_pos) - this->params.rl_kd * this->obs.dof_vel;
return output_torques;
}
torch::Tensor RL::ComputePosition(torch::Tensor actions)
{
torch::Tensor actions_scaled = actions * this->params.action_scale;
return actions_scaled + this->params.default_dof_pos;
torch::Tensor joint_actions_scaled = actions * this->params.action_scale;
torch::Tensor wheel_actions_scaled = torch::zeros({1, this->params.num_of_dofs});
for (int i : this->params.wheel_indices)
{
joint_actions_scaled[0][i] = 0.0;
wheel_actions_scaled[0][i] = actions[0][i] * this->params.action_scale_wheel;
}
torch::Tensor actions_scaled = joint_actions_scaled + wheel_actions_scaled;
output_dof_pos = joint_actions_scaled + this->params.default_dof_pos;
output_dof_vel = wheel_actions_scaled;
output_dof_tau = this->params.rl_kp * (actions_scaled + this->params.default_dof_pos - this->obs.dof_pos) - this->params.rl_kd * this->obs.dof_vel;
output_dof_tau = torch::clamp(output_dof_tau, -(this->params.torque_limits), this->params.torque_limits);
}
torch::Tensor RL::QuatRotateInverse(torch::Tensor q, torch::Tensor v, const std::string &framework)
@ -173,22 +182,25 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
}
std::cout << "\r" << std::flush << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << std::flush;
}
if (this->control.control_state == STATE_RL_INIT)
else
{
this->control.control_state = STATE_WAITING;
this->running_state = STATE_RL_INIT;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_INIT" << std::endl;
}
else if (this->control.control_state == STATE_POS_GETDOWN)
{
this->control.control_state = STATE_WAITING;
getdown_percent = 0.0;
for (int i = 0; i < this->params.num_of_dofs; ++i)
if (this->control.control_state == STATE_RL_INIT)
{
now_state.motor_state.q[i] = state->motor_state.q[i];
this->control.control_state = STATE_WAITING;
this->running_state = STATE_RL_INIT;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_RL_INIT" << std::endl;
}
else if (this->control.control_state == STATE_POS_GETDOWN)
{
this->control.control_state = STATE_WAITING;
getdown_percent = 0.0;
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
now_state.motor_state.q[i] = state->motor_state.q[i];
}
this->running_state = STATE_POS_GETDOWN;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl;
}
this->running_state = STATE_POS_GETDOWN;
std::cout << std::endl << LOGGER::INFO << "Switching to STATE_POS_GETDOWN" << std::endl;
}
}
// init obs and start rl loop
@ -207,13 +219,24 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
else if (this->running_state == STATE_RL_RUNNING)
{
std::cout << "\r" << std::flush << LOGGER::INFO << "RL Controller x:" << this->control.x << " y:" << this->control.y << " yaw:" << this->control.yaw << std::flush;
for (int i = 0; i < this->params.num_of_dofs; ++i)
torch::Tensor _output_dof_pos, _output_dof_vel;
if (this->output_dof_pos_queue.try_pop(_output_dof_pos) && this->output_dof_vel_queue.try_pop(_output_dof_vel))
{
command->motor_command.q[i] = this->output_dof_pos[0][i].item<double>();
command->motor_command.dq[i] = 0;
command->motor_command.kp[i] = this->params.rl_kp[0][i].item<double>();
command->motor_command.kd[i] = this->params.rl_kd[0][i].item<double>();
command->motor_command.tau[i] = 0;
for (int i = 0; i < this->params.num_of_dofs; ++i)
{
if (_output_dof_pos.defined() && _output_dof_pos.numel() > 0)
{
command->motor_command.q[i] = this->output_dof_pos[0][i].item<double>();
}
if (_output_dof_vel.defined() && _output_dof_vel.numel() > 0)
{
command->motor_command.dq[i] = this->output_dof_vel[0][i].item<double>();
}
command->motor_command.kp[i] = this->params.rl_kp[0][i].item<double>();
command->motor_command.kd[i] = this->params.rl_kd[0][i].item<double>();
command->motor_command.tau[i] = 0;
}
}
if (this->control.control_state == STATE_POS_GETDOWN)
{
@ -266,13 +289,13 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
}
}
void RL::TorqueProtect(torch::Tensor origin_output_torques)
void RL::TorqueProtect(torch::Tensor origin_output_dof_tau)
{
std::vector<int> out_of_range_indices;
std::vector<double> out_of_range_values;
for (int i = 0; i < origin_output_torques.size(1); ++i)
for (int i = 0; i < origin_output_dof_tau.size(1); ++i)
{
double torque_value = origin_output_torques[0][i].item<double>();
double torque_value = origin_output_dof_tau[0][i].item<double>();
double limit_lower = -this->params.torque_limits[0][i].item<double>();
double limit_upper = this->params.torque_limits[0][i].item<double>();
@ -511,6 +534,8 @@ void RL::ReadYaml(std::string robot_name)
this->params.action_scale = config["action_scale"].as<double>();
this->params.hip_scale_reduction = config["hip_scale_reduction"].as<double>();
this->params.hip_scale_reduction_indices = ReadVectorFromYaml<int>(config["hip_scale_reduction_indices"]);
this->params.action_scale_wheel = config["action_scale_wheel"].as<double>();
this->params.wheel_indices = ReadVectorFromYaml<int>(config["wheel_indices"]);
this->params.num_of_dofs = config["num_of_dofs"].as<int>();
this->params.lin_vel_scale = config["lin_vel_scale"].as<double>();
this->params.ang_vel_scale = config["ang_vel_scale"].as<double>();

View File

@ -10,7 +10,7 @@
#include <iostream>
#include <string>
#include <unistd.h>
#include <mutex>
#include <tbb/concurrent_queue.h>
#include <yaml-cpp/yaml.h>
@ -27,6 +27,7 @@ struct RobotCommand
{
struct MotorCommand
{
std::vector<int> mode = std::vector<int>(32, 0);
std::vector<T> q = std::vector<T>(32, 0.0);
std::vector<T> dq = std::vector<T>(32, 0.0);
std::vector<T> tau = std::vector<T>(32, 0.0);
@ -72,6 +73,7 @@ struct Control
double x = 0.0;
double y = 0.0;
double yaw = 0.0;
double wheel = 0.0;
};
struct ModelParams
@ -88,6 +90,8 @@ struct ModelParams
double action_scale;
double hip_scale_reduction;
std::vector<int> hip_scale_reduction_indices;
double action_scale_wheel;
std::vector<int> wheel_indices;
int num_of_dofs;
double lin_vel_scale;
double ang_vel_scale;
@ -129,7 +133,9 @@ public:
RobotState<double> robot_state;
RobotCommand<double> robot_command;
std::mutex robot_state_mutex;
tbb::concurrent_queue<torch::Tensor> output_dof_pos_queue;
tbb::concurrent_queue<torch::Tensor> output_dof_vel_queue;
tbb::concurrent_queue<torch::Tensor> output_dof_tau_queue;
// init
void InitObservations();
@ -142,8 +148,7 @@ public:
virtual void GetState(RobotState<double> *state) = 0;
virtual void SetCommand(const RobotCommand<double> *command) = 0;
void StateController(const RobotState<double> *state, RobotCommand<double> *command);
torch::Tensor ComputeTorques(torch::Tensor actions);
torch::Tensor ComputePosition(torch::Tensor actions);
void ComputeOutput(const torch::Tensor &actions, torch::Tensor &output_dof_pos, torch::Tensor &output_dof_vel, torch::Tensor &output_dof_tau);
torch::Tensor QuatRotateInverse(torch::Tensor q, torch::Tensor v, const std::string &framework);
// yaml params
@ -164,15 +169,24 @@ public:
bool simulation_running = false;
// protect func
void TorqueProtect(torch::Tensor origin_output_torques);
void TorqueProtect(torch::Tensor origin_output_dof_tau);
void AttitudeProtect(const std::vector<double> &quaternion, float pitch_threshold, float roll_threshold);
protected:
// rl module
torch::jit::script::Module model;
// output buffer
torch::Tensor output_torques;
torch::Tensor output_dof_tau;
torch::Tensor output_dof_pos;
torch::Tensor output_dof_vel;
};
template <typename T>
T clamp(T value, T min, T max)
{
if (value < min) return min;
if (value > max) return max;
return value;
}
#endif // RL_SDK_HPP

View File

@ -40,6 +40,8 @@ a1_isaacgym:
hip_scale_reduction_indices: [0, 3, 6, 9]
num_of_dofs: 12
action_scale: 0.25
action_scale_wheel: 0.0
wheel_indices: []
lin_vel_scale: 2.0
ang_vel_scale: 0.25
dof_pos_scale: 1.0

View File

@ -40,6 +40,8 @@ a1_isaacsim:
hip_scale_reduction_indices: []
num_of_dofs: 12
action_scale: 0.25
action_scale_wheel: 0.0
wheel_indices: []
lin_vel_scale: 2.0
ang_vel_scale: 0.25
dof_pos_scale: 1.0

View File

@ -40,6 +40,8 @@ go2_isaacgym:
hip_scale_reduction_indices: [0, 3, 6, 9]
num_of_dofs: 12
action_scale: 0.25
action_scale_wheel: 0.0
wheel_indices: []
lin_vel_scale: 2.0
ang_vel_scale: 0.25
dof_pos_scale: 1.0

View File

@ -28,6 +28,8 @@ gr1t1_isaacgym:
hip_scale_reduction_indices: []
num_of_dofs: 10
action_scale: 1.0
action_scale_wheel: 0.0
wheel_indices: []
lin_vel_scale: 1.0
ang_vel_scale: 1.0
dof_pos_scale: 1.0

View File

@ -28,6 +28,8 @@ gr1t2_isaacgym:
hip_scale_reduction_indices: []
num_of_dofs: 10
action_scale: 1.0
action_scale_wheel: 0.0
wheel_indices: []
lin_vel_scale: 1.0
ang_vel_scale: 1.0
dof_pos_scale: 1.0

View File

@ -135,7 +135,7 @@ class RL:
self.stand_model = None
# output buffer
self.output_torques = torch.zeros(1, 32)
self.output_dof_tau = torch.zeros(1, 32)
self.output_dof_pos = torch.zeros(1, 32)
def ComputeObservation(self):
@ -179,7 +179,7 @@ class RL:
self.obs.actions = torch.zeros(1, self.params.num_of_dofs, dtype=torch.float)
def InitOutputs(self):
self.output_torques = torch.zeros(1, self.params.num_of_dofs, dtype=torch.float)
self.output_dof_tau = torch.zeros(1, self.params.num_of_dofs, dtype=torch.float)
self.output_dof_pos = self.params.default_dof_pos
def InitControl(self):
@ -190,8 +190,8 @@ class RL:
def ComputeTorques(self, actions):
actions_scaled = actions * self.params.action_scale
output_torques = self.params.rl_kp * (actions_scaled + self.params.default_dof_pos - self.obs.dof_pos) - self.params.rl_kd * self.obs.dof_vel
return output_torques
output_dof_tau = self.params.rl_kp * (actions_scaled + self.params.default_dof_pos - self.obs.dof_pos) - self.params.rl_kd * self.obs.dof_vel
return output_dof_tau
def ComputePosition(self, actions):
actions_scaled = actions * self.params.action_scale
@ -305,12 +305,12 @@ class RL:
self.running_state = STATE.STATE_WAITING
print("\r\n" + LOGGER.INFO + "Switching to STATE_WAITING")
def TorqueProtect(self, origin_output_torques):
def TorqueProtect(self, origin_output_dof_tau):
out_of_range_indices = []
out_of_range_values = []
for i in range(origin_output_torques.size(1)):
torque_value = origin_output_torques[0][i].item()
for i in range(origin_output_dof_tau.size(1)):
torque_value = origin_output_dof_tau[0][i].item()
limit_lower = -self.params.torque_limits[0][i].item()
limit_upper = self.params.torque_limits[0][i].item()

View File

@ -194,18 +194,18 @@ class RL_Sim(RL):
self.obs.actions = clamped_actions
origin_output_torques = self.ComputeTorques(self.obs.actions)
origin_output_dof_tau = self.ComputeTorques(self.obs.actions)
# self.TorqueProtect(origin_output_torques)
# self.TorqueProtect(origin_output_dof_tau)
self.output_torques = torch.clamp(origin_output_torques, -(self.params.torque_limits), self.params.torque_limits)
self.output_dof_tau = torch.clamp(origin_output_dof_tau, -(self.params.torque_limits), self.params.torque_limits)
self.output_dof_pos = self.ComputePosition(self.obs.actions)
if CSV_LOGGER:
tau_est = torch.zeros((1, self.params.num_of_dofs))
for i in range(self.params.num_of_dofs):
tau_est[0, i] = self.joint_efforts[self.params.joint_controller_names[i]]
self.CSVLogger(self.output_torques, tau_est, self.obs.dof_pos, self.output_dof_pos, self.obs.dof_vel)
self.CSVLogger(self.output_dof_tau, tau_est, self.obs.dof_pos, self.output_dof_pos, self.obs.dof_vel)
def Forward(self):
torch.set_grad_enabled(False)

View File

@ -27,6 +27,7 @@ RL_Real::RL_Real() : unitree_safe(UNITREE_LEGGED_SDK::LeggedType::A1), unitree_u
// init rl
torch::autograd::GradMode::set_enabled(false);
torch::set_num_threads(4);
if (!this->params.observations_history.empty())
{
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, this->params.observations_history.size());
@ -81,7 +82,6 @@ RL_Real::~RL_Real()
void RL_Real::GetState(RobotState<double> *state)
{
// TODO-devel-mutex
this->unitree_udp.GetRecv(this->unitree_low_state);
memcpy(&this->unitree_joy, this->unitree_low_state.wirelessRemote, 40);
@ -144,8 +144,6 @@ void RL_Real::SetCommand(const RobotCommand<double> *command)
void RL_Real::RobotControl()
{
// std::lock_guard<std::mutex> lock(robot_state_mutex); // TODO will cause thread timeout
this->motiontime++;
this->GetState(&this->robot_state);
@ -155,8 +153,6 @@ void RL_Real::RobotControl()
void RL_Real::RunModel()
{
// std::lock_guard<std::mutex> lock(robot_state_mutex); // TODO will cause thread timeout
if (this->running_state == STATE_RL_RUNNING)
{
this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0);
@ -168,24 +164,34 @@ void RL_Real::RunModel()
torch::Tensor clamped_actions = this->Forward();
this->obs.actions = clamped_actions;
for (int i : this->params.hip_scale_reduction_indices)
{
clamped_actions[0][i] *= this->params.hip_scale_reduction;
}
this->obs.actions = clamped_actions;
this->ComputeOutput(this->obs.actions, this->output_dof_pos, this->output_dof_vel, this->output_dof_tau);
torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions);
if (this->output_dof_pos.defined() && this->output_dof_pos.numel() > 0)
{
output_dof_pos_queue.push(this->output_dof_pos);
}
if (this->output_dof_vel.defined() && this->output_dof_vel.numel() > 0)
{
output_dof_vel_queue.push(this->output_dof_vel);
}
if (this->output_dof_tau.defined() && this->output_dof_tau.numel() > 0)
{
output_dof_tau_queue.push(this->output_dof_tau);
}
this->TorqueProtect(origin_output_torques);
this->AttitudeProtect(this->robot_state.imu.quaternion, 60.0f, 60.0f);
this->output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
this->output_dof_pos = this->ComputePosition(this->obs.actions);
this->TorqueProtect(this->output_dof_tau);
this->AttitudeProtect(this->robot_state.imu.quaternion, 75.0f, 75.0f);
#ifdef CSV_LOGGER
torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tau_est).unsqueeze(0);
this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel);
this->CSVLogger(this->output_dof_tau, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel);
#endif
}
}

View File

@ -43,6 +43,7 @@ RL_Real::RL_Real()
// init rl
torch::autograd::GradMode::set_enabled(false);
torch::set_num_threads(4);
if (!this->params.observations_history.empty())
{
this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, this->params.observations_history.size());
@ -149,8 +150,6 @@ void RL_Real::SetCommand(const RobotCommand<double> *command)
void RL_Real::RobotControl()
{
// std::lock_guard<std::mutex> lock(robot_state_mutex); // TODO will cause thread timeout
this->motiontime++;
this->GetState(&this->robot_state);
@ -160,8 +159,6 @@ void RL_Real::RobotControl()
void RL_Real::RunModel()
{
// std::lock_guard<std::mutex> lock(robot_state_mutex); // TODO will cause thread timeout
if (this->running_state == STATE_RL_RUNNING)
{
this->obs.ang_vel = torch::tensor(this->robot_state.imu.gyroscope).unsqueeze(0);
@ -173,24 +170,34 @@ void RL_Real::RunModel()
torch::Tensor clamped_actions = this->Forward();
this->obs.actions = clamped_actions;
for (int i : this->params.hip_scale_reduction_indices)
{
clamped_actions[0][i] *= this->params.hip_scale_reduction;
}
this->obs.actions = clamped_actions;
this->ComputeOutput(this->obs.actions, this->output_dof_pos, this->output_dof_vel, this->output_dof_tau);
torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions);
if (this->output_dof_pos.defined() && this->output_dof_pos.numel() > 0)
{
output_dof_pos_queue.push(this->output_dof_pos);
}
if (this->output_dof_vel.defined() && this->output_dof_vel.numel() > 0)
{
output_dof_vel_queue.push(this->output_dof_vel);
}
if (this->output_dof_tau.defined() && this->output_dof_tau.numel() > 0)
{
output_dof_tau_queue.push(this->output_dof_tau);
}
this->TorqueProtect(origin_output_torques);
this->AttitudeProtect(this->robot_state.imu.quaternion, 60.0f, 60.0f);
this->output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
this->output_dof_pos = this->ComputePosition(this->obs.actions);
this->TorqueProtect(this->output_dof_tau);
this->AttitudeProtect(this->robot_state.imu.quaternion, 75.0f, 75.0f);
#ifdef CSV_LOGGER
torch::Tensor tau_est = torch::tensor(this->robot_state.motor_state.tau_est).unsqueeze(0);
this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel);
this->CSVLogger(this->output_dof_tau, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel);
#endif
}
}
@ -361,7 +368,7 @@ int main(int argc, char **argv)
while (1)
{
sleep(10);
};
}
return 0;
}

View File

@ -166,8 +166,6 @@ void RL_Sim::SetCommand(const RobotCommand<double> *command)
void RL_Sim::RobotControl()
{
// std::lock_guard<std::mutex> lock(robot_state_mutex); // TODO will cause thread timeout
if (this->control.control_state == STATE_RESET_SIMULATION)
{
gazebo_msgs::SetModelState set_model_state;
@ -260,8 +258,6 @@ void RL_Sim::JointStatesCallback(const robot_msgs::MotorState::ConstPtr &msg, co
void RL_Sim::RunModel()
{
// std::lock_guard<std::mutex> lock(robot_state_mutex); // TODO will cause thread timeout
if (this->running_state == STATE_RL_RUNNING && simulation_running)
{
this->obs.lin_vel = torch::tensor({{this->vel.linear.x, this->vel.linear.y, this->vel.linear.z}});
@ -274,19 +270,29 @@ void RL_Sim::RunModel()
torch::Tensor clamped_actions = this->Forward();
this->obs.actions = clamped_actions;
for (int i : this->params.hip_scale_reduction_indices)
{
clamped_actions[0][i] *= this->params.hip_scale_reduction;
}
this->obs.actions = clamped_actions;
this->ComputeOutput(this->obs.actions, this->output_dof_pos, this->output_dof_vel, this->output_dof_tau);
torch::Tensor origin_output_torques = this->ComputeTorques(this->obs.actions);
if (this->output_dof_pos.defined() && this->output_dof_pos.numel() > 0)
{
output_dof_pos_queue.push(this->output_dof_pos);
}
if (this->output_dof_vel.defined() && this->output_dof_vel.numel() > 0)
{
output_dof_vel_queue.push(this->output_dof_vel);
}
if (this->output_dof_tau.defined() && this->output_dof_tau.numel() > 0)
{
output_dof_tau_queue.push(this->output_dof_tau);
}
// this->TorqueProtect(origin_output_torques);
this->output_torques = torch::clamp(origin_output_torques, -(this->params.torque_limits), this->params.torque_limits);
this->output_dof_pos = this->ComputePosition(this->obs.actions);
// this->TorqueProtect(this->output_dof_tau);
#ifdef CSV_LOGGER
torch::Tensor tau_est = torch::zeros({1, this->params.num_of_dofs});
@ -294,7 +300,7 @@ void RL_Sim::RunModel()
{
tau_est[0][i] = this->joint_efforts[this->params.joint_controller_names[i]];
}
this->CSVLogger(this->output_torques, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel);
this->CSVLogger(this->output_dof_tau, tau_est, this->obs.dof_pos, this->output_dof_pos, this->obs.dof_vel);
#endif
}
}