diff --git a/src/rl_sar/library/rl/rl.cpp b/src/rl_sar/library/rl/rl.cpp index f93e9d3..7aef355 100644 --- a/src/rl_sar/library/rl/rl.cpp +++ b/src/rl_sar/library/rl/rl.cpp @@ -48,6 +48,48 @@ void RL::ReadYaml(std::string robot_name) config["default_dof_pos"][9].as(), config["default_dof_pos"][10].as(), config["default_dof_pos"][11].as()}}); } +void RL::CSVInit(std::string robot_name) +{ + csv_filename = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/motor"; + + // // Uncomment these lines if need timestamp for file name + // auto now = std::chrono::system_clock::now(); + // std::time_t now_c = std::chrono::system_clock::to_time_t(now); + // std::stringstream ss; + // ss << std::put_time(std::localtime(&now_c), "%Y%m%d%H%M%S"); + // std::string timestamp = ss.str(); + // csv_filename += "_" + timestamp; + + csv_filename += ".csv"; + std::ofstream file(csv_filename.c_str()); // 创建新文件 + + for(int i = 0; i < 12; ++i) {file << "torque_" << i << ",";} + for(int i = 0; i < 12; ++i) {file << "tau_est_" << i << ",";} + for(int i = 0; i < 12; ++i) {file << "joint_pos_" << i << ",";} + for(int i = 0; i < 12; ++i) {file << "joint_pos_target_" << i << ",";} + for(int i = 0; i < 12; ++i) {file << "joint_vel_" << i << ",";} + + file << std::endl; + + file.close(); // 关闭文件 +} + +void RL::CSVLogger(torch::Tensor torque, torch::Tensor tau_est, torch::Tensor joint_pos, torch::Tensor joint_pos_target, torch::Tensor joint_vel) +{ + std::ofstream file(csv_filename.c_str(), std::ios_base::app); // 打开文件以追加模式写入数据 + + // 写入数据 + for(int i = 0; i < 12; ++i) {file << torque[0][i].item() << ",";} + for(int i = 0; i < 12; ++i) {file << tau_est[0][i].item() << ",";} + for(int i = 0; i < 12; ++i) {file << joint_pos[0][i].item() << ",";} + for(int i = 0; i < 12; ++i) {file << joint_pos_target[0][i].item() << ",";} + for(int i = 0; i < 12; ++i) {file << joint_vel[0][i].item() << ",";} + + file << std::endl; + + file.close(); // 关闭文件 +} + torch::Tensor RL::QuatRotateInverse(torch::Tensor q, torch::Tensor v) { c10::IntArrayRef shape = q.sizes(); diff --git a/src/rl_sar/library/rl/rl.hpp b/src/rl_sar/library/rl/rl.hpp index f5b3b4b..923bbeb 100644 --- a/src/rl_sar/library/rl/rl.hpp +++ b/src/rl_sar/library/rl/rl.hpp @@ -60,6 +60,9 @@ public: void InitObservations(); void InitOutputs(); void ReadYaml(std::string robot_name); + std::string csv_filename; + void CSVInit(std::string robot_name); + void CSVLogger(torch::Tensor torque, torch::Tensor tau_est, torch::Tensor joint_pos, torch::Tensor joint_pos_target, torch::Tensor joint_vel); protected: // rl module