diff --git a/src/rl_sar/library/core/rl_sdk/rl_sdk.cpp b/src/rl_sar/library/core/rl_sdk/rl_sdk.cpp index 0793677..ed47adc 100644 --- a/src/rl_sar/library/core/rl_sdk/rl_sdk.cpp +++ b/src/rl_sar/library/core/rl_sdk/rl_sdk.cpp @@ -103,12 +103,12 @@ void RL::InitControl() void RL::ComputeOutput(const torch::Tensor &actions, torch::Tensor &output_dof_pos, torch::Tensor &output_dof_vel, torch::Tensor &output_dof_tau) { torch::Tensor actions_scaled = actions * this->params.action_scale; - torch::Tensor pos_actions_scaled = actions_scaled; + torch::Tensor pos_actions_scaled = actions_scaled.clone(); torch::Tensor vel_actions_scaled = torch::zeros_like(actions); for (int i : this->params.wheel_indices) { pos_actions_scaled[0][i] = 0.0; - vel_actions_scaled[0][i] = actions[0][i]; + vel_actions_scaled[0][i] = actions_scaled[0][i]; } torch::Tensor all_actions_scaled = pos_actions_scaled + vel_actions_scaled; output_dof_pos = pos_actions_scaled + this->params.default_dof_pos;