Review notes (free text ahead of the first file header; patch tools skip it).

Purpose of this patch, as far as its hunks show:
  * CMakeLists.txt: drop "${TORCH_LIBRARIES}" from the rl_sim and rl_real_a1
    link lines.
  * rl_sdk.hpp: add namespace LOGGER (ANSI-coloured "[INFO] ", "[WARNING] ",
    "[ERROR] ", "[DEBUG] " prefixes) and an empty RL destructor.
  * rl_sdk.cpp, rl_sim.cpp, rl_real_a1.cpp: route status/error output through
    std::cout with a LOGGER prefix; move the "RL Controller x/y/yaw" status
    line from StateController() to RunKeyboard(); refresh the commented-out
    ComputeObservation()/Forward() examples.

Changes made during review (TorqueProtect hunk only):
  * The torque error line now labels the joint with index + 1 instead of
    i + 1. The value and limits printed on that line are looked up with
    index = out_of_range_indices[i], so the label has to use the same index;
    i is only the position inside the out-of-range list. (Assumes the "+ 1"
    was meant as 1-based joint numbering - confirm.)
  * "Switching to STATE_POS_GETDOWN" is printed once after the loop rather
    than once per out-of-range joint; the state itself is assigned once.

Caveats:
  * This patch was received with its line breaks and indentation collapsed.
    Line structure was restored from the hunk headers; indentation and blank
    context lines are reconstructed, so apply with whitespace-insensitive
    matching (e.g. git apply --ignore-whitespace) and check the result.
  * The context lines "#include" and "template" in rl_sdk.hpp look like they
    lost angle-bracketed text in transit (possibly also the ".item()" calls
    and the RobotState/RobotCommand parameter types) - restore from the
    target file before applying.
  * The blob ids on the rl_sdk.cpp "index" line are those of the patch as
    received and do not describe the reviewed TorqueProtect result.
  * NOTE(review): std::fixed and std::setprecision(2) remain set on std::cout
    after the "Getting up"/"Getting down" lines and will affect later
    floating-point output such as the torque values - confirm intended.
    std::setprecision needs <iomanip>; no hunk here adds it - confirm it is
    already included.
  * NOTE(review): ~RL() is added as non-virtual. If RL objects are ever
    destroyed through a base pointer (RL_Sim / RL_Real appear to be derived
    classes), it should be virtual - confirm against the class definition.

diff --git a/src/rl_sar/CMakeLists.txt b/src/rl_sar/CMakeLists.txt
index 99443bb..f8d01cb 100644
--- a/src/rl_sar/CMakeLists.txt
+++ b/src/rl_sar/CMakeLists.txt
@@ -69,12 +69,12 @@ set_property(TARGET observation_buffer PROPERTY CXX_STANDARD 14)
 
 add_executable(rl_sim src/rl_sim.cpp )
 target_link_libraries(rl_sim
-  ${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}"
+  ${catkin_LIBRARIES} ${EXTRA_LIBS}
   rl_sdk observation_buffer yaml-cpp
 )
 
 add_executable(rl_real_a1 src/rl_real_a1.cpp )
 target_link_libraries(rl_real_a1
-  ${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}"
+  ${catkin_LIBRARIES} ${EXTRA_LIBS}
   rl_sdk observation_buffer yaml-cpp
 )
diff --git a/src/rl_sar/library/rl_sdk/rl_sdk.cpp b/src/rl_sar/library/rl_sdk/rl_sdk.cpp
index 5088f1a..9817a05 100644
--- a/src/rl_sar/library/rl_sdk/rl_sdk.cpp
+++ b/src/rl_sar/library/rl_sdk/rl_sdk.cpp
@@ -3,34 +3,28 @@
 /* You may need to override this ComputeObservation() function
 torch::Tensor RL::ComputeObservation()
 {
-    torch::Tensor obs = torch::cat({(this->QuatRotateInverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale,
-                                    (this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale,
+    torch::Tensor obs = torch::cat({this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
                                     this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec),
                                     this->obs.commands * this->params.commands_scale,
                                     (this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
                                     this->obs.dof_vel * this->params.dof_vel_scale,
-                                    this->obs.actions},
-                                   1);
-
-    obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
-
-    printf("observation size: %d, %d\n", obs.sizes()[0], obs.sizes()[1]);
-
-    return obs;
+                                    this->obs.actions
+                                    },1);
+    torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
+    return clamped_obs;
 }
 */
 
 /* You may need to override this Forward() function
 torch::Tensor RL::Forward()
 {
-    torch::Tensor obs = this->ComputeObservation();
+    torch::autograd::GradMode::set_enabled(false);
 
-    torch::Tensor actor_input = torch::cat({obs}, 1);
+    torch::Tensor clamped_obs = this->ComputeObservation();
 
-    torch::Tensor actions = this->actor.forward({actor_input}).toTensor();
+    torch::Tensor actions = this->model.forward({clamped_obs}).toTensor();
 
-    this->obs.actions = actions;
-    torch::Tensor clamped_actions = torch::clamp(actions, -this->params.clip_actions, this->params.clip_actions);
+    torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
 
     return clamped_actions;
 }
@@ -122,7 +116,7 @@ void RL::StateController(const RobotState *state, RobotCommand *
                 command->motor_command.kd[i] = params.fixed_kd[0][i].item();
                 command->motor_command.tau[i] = 0;
             }
-            printf("Getting up %.3f%%\r", getup_percent*100.0);
+            std::cout << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << "%\r";
         }
         if(keyboard.keyboard_state == STATE_RL_INIT)
         {
@@ -155,8 +149,6 @@ void RL::StateController(const RobotState *state, RobotCommand *
     // rl loop
     else if(running_state == STATE_RL_RUNNING)
     {
-        std::cout << "[RL Controller] x:" << keyboard.x << " y:" << keyboard.y << " yaw:" << keyboard.yaw << " \r";
-
         for(int i = 0; i < params.num_of_dofs; ++i)
         {
             command->motor_command.q[i] = output_dof_pos[0][i].item();
@@ -191,7 +183,7 @@ void RL::StateController(const RobotState *state, RobotCommand *
                 command->motor_command.kd[i] = params.fixed_kd[0][i].item();
                 command->motor_command.tau[i] = 0;
             }
-            printf("Getting down %.3f%%\r", getdown_percent*100.0);
+            std::cout << LOGGER::INFO << "Getting down " << std::fixed << std::setprecision(2) << getdown_percent * 100.0 << "%\r";
         }
         if(getdown_percent == 1)
         {
@@ -222,16 +214,16 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques)
     }
     if(!out_of_range_indices.empty())
     {
-        std::cout << "Error: origin_output_torques is out of range at indices: ";
         for(int i = 0; i < out_of_range_indices.size(); ++i)
         {
-            std::cout << out_of_range_indices[i] << " (value: " << out_of_range_values[i] << ")";
-            if(i < out_of_range_indices.size() - 1)
-            {
-                std::cout << ", ";
-            }
+            int index = out_of_range_indices[i];
+            double value = out_of_range_values[i];
+            double limit_lower = -this->params.torque_limits[0][index].item();
+            double limit_upper = this->params.torque_limits[0][index].item();
+
+            std::cout << LOGGER::ERROR << "Torque(" << index + 1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
         }
-        std::cout << std::endl;
+        std::cout << LOGGER::ERROR << "Switching to STATE_POS_GETDOWN"<< std::endl;
         keyboard.keyboard_state = STATE_POS_GETDOWN;
     }
 }
@@ -257,6 +249,11 @@ static bool kbhit()
 
 void RL::RunKeyboard()
 {
+    if(running_state == STATE_RL_RUNNING)
+    {
+        std::cout << LOGGER::INFO << "RL Controller x:" << keyboard.x << " y:" << keyboard.y << " yaw:" << keyboard.yaw << " \r";
+    }
+
     if(kbhit())
     {
         int c = fgetc(stdin);
@@ -301,7 +298,7 @@ void RL::ReadYaml(std::string robot_name)
     }
     catch(YAML::BadFile &e)
     {
-        std::cout << "The file '" << CONFIG_PATH << "' does not exist" << std::endl;
+        std::cout << LOGGER::ERROR << "The file '" << CONFIG_PATH << "' does not exist" << std::endl;
         return;
     }
 
diff --git a/src/rl_sar/library/rl_sdk/rl_sdk.hpp b/src/rl_sar/library/rl_sdk/rl_sdk.hpp
index 6e549bf..91e3b39 100644
--- a/src/rl_sar/library/rl_sdk/rl_sdk.hpp
+++ b/src/rl_sar/library/rl_sdk/rl_sdk.hpp
@@ -9,6 +9,13 @@
 #include
 #define CONFIG_PATH CMAKE_CURRENT_SOURCE_DIR "/config.yaml"
 
+namespace LOGGER {
+    const char* const INFO = "\033[0;37m[INFO]\033[0m ";
+    const char* const WARNING = "\033[0;33m[WARNING]\033[0m ";
+    const char* const ERROR = "\033[0;31m[ERROR]\033[0m ";
+    const char* const DEBUG = "\033[0;32m[DEBUG]\033[0m ";
+}
+
 template
 struct RobotCommand
 {
@@ -102,6 +109,7 @@ class RL
 {
 public:
     RL(){};
+    ~RL(){};
 
     ModelParams params;
     Observations obs;
diff --git a/src/rl_sar/src/rl_real_a1.cpp b/src/rl_sar/src/rl_real_a1.cpp
index 42b6bb2..26ee249 100644
--- a/src/rl_sar/src/rl_real_a1.cpp
+++ b/src/rl_sar/src/rl_real_a1.cpp
@@ -64,7 +64,7 @@ RL_Real::~RL_Real()
 #ifdef PLOT
     loop_plot->shutdown();
 #endif
-    printf("exit\n");
+    std::cout << LOGGER::INFO << "RL_Real exit" << std::endl;
 }
 
 void RL_Real::GetState(RobotState *state)
diff --git a/src/rl_sar/src/rl_sim.cpp b/src/rl_sar/src/rl_sim.cpp
index 40c37b1..4b22699 100644
--- a/src/rl_sar/src/rl_sim.cpp
+++ b/src/rl_sar/src/rl_sim.cpp
@@ -90,7 +90,7 @@ RL_Sim::~RL_Sim()
 #ifdef PLOT
     loop_plot->shutdown();
 #endif
-    printf("exit\n");
+    std::cout << LOGGER::INFO << "RL_Sim exit" << std::endl;
 }
 
 void RL_Sim::GetState(RobotState *state)