Compare commits

...

2 Commits
v2.3 ... main

Author SHA1 Message Date
fan-ziqi d933e72fa8 fix: clone actions_scaled tensor for output calculations 2025-03-14 15:53:11 +08:00
fan-ziqi a7611f9778 fix: update library paths 2025-03-14 15:03:00 +08:00
2 changed files with 3 additions and 3 deletions

View File

@ -22,4 +22,4 @@ repos:
- --comment-style - --comment-style
- /*| *| */ - /*| *| */
exclude: ^(src/rl_sar/library/matplotlibcpp|src/rl_sar/library/unitree_legged_sdk_3.2|src/rl_sar/library/unitree_sdk2)/ exclude: ^(src/rl_sar/library/core/matplotlibcpp|src/rl_sar/library/thirdparty)/

View File

@ -103,12 +103,12 @@ void RL::InitControl()
void RL::ComputeOutput(const torch::Tensor &actions, torch::Tensor &output_dof_pos, torch::Tensor &output_dof_vel, torch::Tensor &output_dof_tau) void RL::ComputeOutput(const torch::Tensor &actions, torch::Tensor &output_dof_pos, torch::Tensor &output_dof_vel, torch::Tensor &output_dof_tau)
{ {
torch::Tensor actions_scaled = actions * this->params.action_scale; torch::Tensor actions_scaled = actions * this->params.action_scale;
torch::Tensor pos_actions_scaled = actions_scaled; torch::Tensor pos_actions_scaled = actions_scaled.clone();
torch::Tensor vel_actions_scaled = torch::zeros_like(actions); torch::Tensor vel_actions_scaled = torch::zeros_like(actions);
for (int i : this->params.wheel_indices) for (int i : this->params.wheel_indices)
{ {
pos_actions_scaled[0][i] = 0.0; pos_actions_scaled[0][i] = 0.0;
vel_actions_scaled[0][i] = actions[0][i]; vel_actions_scaled[0][i] = actions_scaled[0][i];
} }
torch::Tensor all_actions_scaled = pos_actions_scaled + vel_actions_scaled; torch::Tensor all_actions_scaled = pos_actions_scaled + vel_actions_scaled;
output_dof_pos = pos_actions_scaled + this->params.default_dof_pos; output_dof_pos = pos_actions_scaled + this->params.default_dof_pos;