feat: add csv func for actuator_net

This commit is contained in:
fan-ziqi 2024-04-08 19:51:38 +08:00
parent 6df0191152
commit 909ab15728
2 changed files with 45 additions and 0 deletions

View File

@ -48,6 +48,48 @@ void RL::ReadYaml(std::string robot_name)
config["default_dof_pos"][9].as<float>(), config["default_dof_pos"][10].as<float>(), config["default_dof_pos"][11].as<float>()}});
}
void RL::CSVInit(std::string robot_name)
{
csv_filename = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/motor";
// // Uncomment these lines if need timestamp for file name
// auto now = std::chrono::system_clock::now();
// std::time_t now_c = std::chrono::system_clock::to_time_t(now);
// std::stringstream ss;
// ss << std::put_time(std::localtime(&now_c), "%Y%m%d%H%M%S");
// std::string timestamp = ss.str();
// csv_filename += "_" + timestamp;
csv_filename += ".csv";
std::ofstream file(csv_filename.c_str()); // 创建新文件
for(int i = 0; i < 12; ++i) {file << "torque_" << i << ",";}
for(int i = 0; i < 12; ++i) {file << "tau_est_" << i << ",";}
for(int i = 0; i < 12; ++i) {file << "joint_pos_" << i << ",";}
for(int i = 0; i < 12; ++i) {file << "joint_pos_target_" << i << ",";}
for(int i = 0; i < 12; ++i) {file << "joint_vel_" << i << ",";}
file << std::endl;
file.close(); // 关闭文件
}
void RL::CSVLogger(torch::Tensor torque, torch::Tensor tau_est, torch::Tensor joint_pos, torch::Tensor joint_pos_target, torch::Tensor joint_vel)
{
std::ofstream file(csv_filename.c_str(), std::ios_base::app); // 打开文件以追加模式写入数据
// 写入数据
for(int i = 0; i < 12; ++i) {file << torque[0][i].item<double>() << ",";}
for(int i = 0; i < 12; ++i) {file << tau_est[0][i].item<double>() << ",";}
for(int i = 0; i < 12; ++i) {file << joint_pos[0][i].item<double>() << ",";}
for(int i = 0; i < 12; ++i) {file << joint_pos_target[0][i].item<double>() << ",";}
for(int i = 0; i < 12; ++i) {file << joint_vel[0][i].item<double>() << ",";}
file << std::endl;
file.close(); // 关闭文件
}
torch::Tensor RL::QuatRotateInverse(torch::Tensor q, torch::Tensor v)
{
c10::IntArrayRef shape = q.sizes();

View File

@ -60,6 +60,9 @@ public:
void InitObservations();
void InitOutputs();
void ReadYaml(std::string robot_name);
std::string csv_filename;
void CSVInit(std::string robot_name);
void CSVLogger(torch::Tensor torque, torch::Tensor tau_est, torch::Tensor joint_pos, torch::Tensor joint_pos_target, torch::Tensor joint_vel);
protected:
// rl module