mirror of https://github.com/fan-ziqi/rl_sar.git
feat: add csv func for actuator_net
This commit is contained in:
parent
6df0191152
commit
909ab15728
|
@ -48,6 +48,48 @@ void RL::ReadYaml(std::string robot_name)
|
||||||
config["default_dof_pos"][9].as<float>(), config["default_dof_pos"][10].as<float>(), config["default_dof_pos"][11].as<float>()}});
|
config["default_dof_pos"][9].as<float>(), config["default_dof_pos"][10].as<float>(), config["default_dof_pos"][11].as<float>()}});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void RL::CSVInit(std::string robot_name)
|
||||||
|
{
|
||||||
|
csv_filename = std::string(CMAKE_CURRENT_SOURCE_DIR) + "/models/" + robot_name + "/motor";
|
||||||
|
|
||||||
|
// // Uncomment these lines if need timestamp for file name
|
||||||
|
// auto now = std::chrono::system_clock::now();
|
||||||
|
// std::time_t now_c = std::chrono::system_clock::to_time_t(now);
|
||||||
|
// std::stringstream ss;
|
||||||
|
// ss << std::put_time(std::localtime(&now_c), "%Y%m%d%H%M%S");
|
||||||
|
// std::string timestamp = ss.str();
|
||||||
|
// csv_filename += "_" + timestamp;
|
||||||
|
|
||||||
|
csv_filename += ".csv";
|
||||||
|
std::ofstream file(csv_filename.c_str()); // 创建新文件
|
||||||
|
|
||||||
|
for(int i = 0; i < 12; ++i) {file << "torque_" << i << ",";}
|
||||||
|
for(int i = 0; i < 12; ++i) {file << "tau_est_" << i << ",";}
|
||||||
|
for(int i = 0; i < 12; ++i) {file << "joint_pos_" << i << ",";}
|
||||||
|
for(int i = 0; i < 12; ++i) {file << "joint_pos_target_" << i << ",";}
|
||||||
|
for(int i = 0; i < 12; ++i) {file << "joint_vel_" << i << ",";}
|
||||||
|
|
||||||
|
file << std::endl;
|
||||||
|
|
||||||
|
file.close(); // 关闭文件
|
||||||
|
}
|
||||||
|
|
||||||
|
void RL::CSVLogger(torch::Tensor torque, torch::Tensor tau_est, torch::Tensor joint_pos, torch::Tensor joint_pos_target, torch::Tensor joint_vel)
|
||||||
|
{
|
||||||
|
std::ofstream file(csv_filename.c_str(), std::ios_base::app); // 打开文件以追加模式写入数据
|
||||||
|
|
||||||
|
// 写入数据
|
||||||
|
for(int i = 0; i < 12; ++i) {file << torque[0][i].item<double>() << ",";}
|
||||||
|
for(int i = 0; i < 12; ++i) {file << tau_est[0][i].item<double>() << ",";}
|
||||||
|
for(int i = 0; i < 12; ++i) {file << joint_pos[0][i].item<double>() << ",";}
|
||||||
|
for(int i = 0; i < 12; ++i) {file << joint_pos_target[0][i].item<double>() << ",";}
|
||||||
|
for(int i = 0; i < 12; ++i) {file << joint_vel[0][i].item<double>() << ",";}
|
||||||
|
|
||||||
|
file << std::endl;
|
||||||
|
|
||||||
|
file.close(); // 关闭文件
|
||||||
|
}
|
||||||
|
|
||||||
torch::Tensor RL::QuatRotateInverse(torch::Tensor q, torch::Tensor v)
|
torch::Tensor RL::QuatRotateInverse(torch::Tensor q, torch::Tensor v)
|
||||||
{
|
{
|
||||||
c10::IntArrayRef shape = q.sizes();
|
c10::IntArrayRef shape = q.sizes();
|
||||||
|
|
|
@ -60,6 +60,9 @@ public:
|
||||||
void InitObservations();
|
void InitObservations();
|
||||||
void InitOutputs();
|
void InitOutputs();
|
||||||
void ReadYaml(std::string robot_name);
|
void ReadYaml(std::string robot_name);
|
||||||
|
std::string csv_filename;
|
||||||
|
void CSVInit(std::string robot_name);
|
||||||
|
void CSVLogger(torch::Tensor torque, torch::Tensor tau_est, torch::Tensor joint_pos, torch::Tensor joint_pos_target, torch::Tensor joint_vel);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// rl module
|
// rl module
|
||||||
|
|
Loading…
Reference in New Issue