feat: add LOGGER

This commit is contained in:
fan-ziqi 2024-05-27 21:33:07 +08:00
parent e84f2d85bc
commit 3a9886e6ea
5 changed files with 36 additions and 31 deletions

View File

@ -69,12 +69,12 @@ set_property(TARGET observation_buffer PROPERTY CXX_STANDARD 14)
add_executable(rl_sim src/rl_sim.cpp ) add_executable(rl_sim src/rl_sim.cpp )
target_link_libraries(rl_sim target_link_libraries(rl_sim
${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}" ${catkin_LIBRARIES} ${EXTRA_LIBS}
rl_sdk observation_buffer yaml-cpp rl_sdk observation_buffer yaml-cpp
) )
add_executable(rl_real_a1 src/rl_real_a1.cpp ) add_executable(rl_real_a1 src/rl_real_a1.cpp )
target_link_libraries(rl_real_a1 target_link_libraries(rl_real_a1
${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}" ${catkin_LIBRARIES} ${EXTRA_LIBS}
rl_sdk observation_buffer yaml-cpp rl_sdk observation_buffer yaml-cpp
) )

View File

@ -3,34 +3,28 @@
/* You may need to override this ComputeObservation() function /* You may need to override this ComputeObservation() function
torch::Tensor RL::ComputeObservation() torch::Tensor RL::ComputeObservation()
{ {
torch::Tensor obs = torch::cat({(this->QuatRotateInverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale, torch::Tensor obs = torch::cat({this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale,
(this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel)) * this->params.ang_vel_scale,
this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec), this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec),
this->obs.commands * this->params.commands_scale, this->obs.commands * this->params.commands_scale,
(this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale, (this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale,
this->obs.dof_vel * this->params.dof_vel_scale, this->obs.dof_vel * this->params.dof_vel_scale,
this->obs.actions}, this->obs.actions
1); },1);
torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs);
obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs); return clamped_obs;
printf("observation size: %d, %d\n", obs.sizes()[0], obs.sizes()[1]);
return obs;
} }
*/ */
/* You may need to override this Forward() function /* You may need to override this Forward() function
torch::Tensor RL::Forward() torch::Tensor RL::Forward()
{ {
torch::Tensor obs = this->ComputeObservation(); torch::autograd::GradMode::set_enabled(false);
torch::Tensor actor_input = torch::cat({obs}, 1); torch::Tensor clamped_obs = this->ComputeObservation();
torch::Tensor actions = this->actor.forward({actor_input}).toTensor(); torch::Tensor actions = this->model.forward({clamped_obs}).toTensor();
this->obs.actions = actions; torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper);
torch::Tensor clamped_actions = torch::clamp(actions, -this->params.clip_actions, this->params.clip_actions);
return clamped_actions; return clamped_actions;
} }
@ -122,7 +116,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
command->motor_command.kd[i] = params.fixed_kd[0][i].item<double>(); command->motor_command.kd[i] = params.fixed_kd[0][i].item<double>();
command->motor_command.tau[i] = 0; command->motor_command.tau[i] = 0;
} }
printf("Getting up %.3f%%\r", getup_percent*100.0); std::cout << LOGGER::INFO << "Getting up " << std::fixed << std::setprecision(2) << getup_percent * 100.0 << "%\r";
} }
if(keyboard.keyboard_state == STATE_RL_INIT) if(keyboard.keyboard_state == STATE_RL_INIT)
{ {
@ -155,8 +149,6 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
// rl loop // rl loop
else if(running_state == STATE_RL_RUNNING) else if(running_state == STATE_RL_RUNNING)
{ {
std::cout << "[RL Controller] x:" << keyboard.x << " y:" << keyboard.y << " yaw:" << keyboard.yaw << " \r";
for(int i = 0; i < params.num_of_dofs; ++i) for(int i = 0; i < params.num_of_dofs; ++i)
{ {
command->motor_command.q[i] = output_dof_pos[0][i].item<double>(); command->motor_command.q[i] = output_dof_pos[0][i].item<double>();
@ -191,7 +183,7 @@ void RL::StateController(const RobotState<double> *state, RobotCommand<double> *
command->motor_command.kd[i] = params.fixed_kd[0][i].item<double>(); command->motor_command.kd[i] = params.fixed_kd[0][i].item<double>();
command->motor_command.tau[i] = 0; command->motor_command.tau[i] = 0;
} }
printf("Getting down %.3f%%\r", getdown_percent*100.0); std::cout << LOGGER::INFO << "Getting down " << std::fixed << std::setprecision(2) << getdown_percent * 100.0 << "%\r";
} }
if(getdown_percent == 1) if(getdown_percent == 1)
{ {
@ -222,16 +214,16 @@ void RL::TorqueProtect(torch::Tensor origin_output_torques)
} }
if(!out_of_range_indices.empty()) if(!out_of_range_indices.empty())
{ {
std::cout << "Error: origin_output_torques is out of range at indices: ";
for(int i = 0; i < out_of_range_indices.size(); ++i) for(int i = 0; i < out_of_range_indices.size(); ++i)
{ {
std::cout << out_of_range_indices[i] << " (value: " << out_of_range_values[i] << ")"; int index = out_of_range_indices[i];
if(i < out_of_range_indices.size() - 1) double value = out_of_range_values[i];
{ double limit_lower = -this->params.torque_limits[0][index].item<double>();
std::cout << ", "; double limit_upper = this->params.torque_limits[0][index].item<double>();
}
std::cout << LOGGER::ERROR << "Torque(" << i+1 << ")=" << value << " out of range(" << limit_lower << ", " << limit_upper << ")" << std::endl;
std::cout << LOGGER::ERROR << "Switching to STATE_POS_GETDOWN"<< std::endl;
} }
std::cout << std::endl;
keyboard.keyboard_state = STATE_POS_GETDOWN; keyboard.keyboard_state = STATE_POS_GETDOWN;
} }
} }
@ -257,6 +249,11 @@ static bool kbhit()
void RL::RunKeyboard() void RL::RunKeyboard()
{ {
if(running_state == STATE_RL_RUNNING)
{
std::cout << LOGGER::INFO << "RL Controller x:" << keyboard.x << " y:" << keyboard.y << " yaw:" << keyboard.yaw << " \r";
}
if(kbhit()) if(kbhit())
{ {
int c = fgetc(stdin); int c = fgetc(stdin);
@ -301,7 +298,7 @@ void RL::ReadYaml(std::string robot_name)
} catch(YAML::BadFile &e) } catch(YAML::BadFile &e)
{ {
std::cout << "The file '" << CONFIG_PATH << "' does not exist" << std::endl; std::cout << LOGGER::ERROR << "The file '" << CONFIG_PATH << "' does not exist" << std::endl;
return; return;
} }

View File

@ -9,6 +9,13 @@
#include <yaml-cpp/yaml.h> #include <yaml-cpp/yaml.h>
#define CONFIG_PATH CMAKE_CURRENT_SOURCE_DIR "/config.yaml" #define CONFIG_PATH CMAKE_CURRENT_SOURCE_DIR "/config.yaml"
namespace LOGGER {
const char* const INFO = "\033[0;37m[INFO]\033[0m ";
const char* const WARNING = "\033[0;33m[WARNING]\033[0m ";
const char* const ERROR = "\033[0;31m[ERROR]\033[0m ";
const char* const DEBUG = "\033[0;32m[DEBUG]\033[0m ";
}
template<typename T> template<typename T>
struct RobotCommand struct RobotCommand
{ {
@ -102,6 +109,7 @@ class RL
{ {
public: public:
RL(){}; RL(){};
~RL(){};
ModelParams params; ModelParams params;
Observations obs; Observations obs;

View File

@ -64,7 +64,7 @@ RL_Real::~RL_Real()
#ifdef PLOT #ifdef PLOT
loop_plot->shutdown(); loop_plot->shutdown();
#endif #endif
printf("exit\n"); std::cout << LOGGER::INFO << "RL_Real exit" << std::endl;
} }
void RL_Real::GetState(RobotState<double> *state) void RL_Real::GetState(RobotState<double> *state)

View File

@ -90,7 +90,7 @@ RL_Sim::~RL_Sim()
#ifdef PLOT #ifdef PLOT
loop_plot->shutdown(); loop_plot->shutdown();
#endif #endif
printf("exit\n"); std::cout << LOGGER::INFO << "RL_Sim exit" << std::endl;
} }
void RL_Sim::GetState(RobotState<double> *state) void RL_Sim::GetState(RobotState<double> *state)