diff --git a/README_CN.md b/README_CN.md index 2e39005..6d3a8ba 100644 --- a/README_CN.md +++ b/README_CN.md @@ -74,4 +74,18 @@ rosrun teleop_twist_keyboard teleop_twist_keyboard.py -> 部分代码参考https://github.com/mertgungor/unitree_model_control \ No newline at end of file +> 部分代码参考https://github.com/mertgungor/unitree_model_control + +0.003563, 0.001512, -0.101311 +-0.026536, 0.062404, 0.000014 +-0.000302, 0.000558, -0.011532, 0.999933 +0.499042, 1.120411, -2.696528, -0.497325, 1.120405, -2.696527, 0.495533, 1.120391, -2.696531, -0.493813, 1.120387, -2.696530 +0.838333, 0.031000, -0.103122, -0.768373, -0.006074, 0.000378, 0.669494, 0.020684, -0.066489, -0.588053, -0.004475, 0.000998 + +-0.014956, -0.002124, 0.007345 +0.005116, 0.010299, -0.021976, 0.999692 +0.484257, 0.969105, -2.556143, -0.369799, 1.061818, -2.635455, 0.314972, 1.066411, -2.661583, -0.353911, 1.047658, -2.649109 +0.000000, 0.000000, 0.012878, 0.000000, -0.012878, 0.127060, 0.006010, 0.000000, -0.050652, 0.042926, -0.024897, -0.091003 + +0.146841, 1.004310, -1.390954, -0.053719, 0.855462, -1.598009, -0.011909, 0.792030, -1.664962, -0.052048, 1.163194, -1.519249 +-0.062282, -0.523800, 0.810542, 0.027546, -0.772218, 1.406845, -0.005030, -0.555534, 1.106206, 0.010507, -0.411735, 0.578516 \ No newline at end of file diff --git a/src/unitree_rl/CMakeLists.txt b/src/unitree_rl/CMakeLists.txt index 17955fb..34be3c1 100644 --- a/src/unitree_rl/CMakeLists.txt +++ b/src/unitree_rl/CMakeLists.txt @@ -1,11 +1,9 @@ cmake_minimum_required(VERSION 3.0.2) project(unitree_rl) -set(EXTRA_LIBS -pthread libunitree_legged_sdk_amd64.so lcm) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") find_package(Torch REQUIRED) -find_package(unitree_legged_sdk REQUIRED) find_package(catkin REQUIRED COMPONENTS controller_manager @@ -27,14 +25,11 @@ catkin_package( unitree_legged_msgs ) -message("-- CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") -if("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "x86_64.*") - set(ARCH amd64) -else() - set(ARCH arm64) -endif() - -set(EXTRA_LIBS -pthread ${unitree_legged_sdk_LIBRARIES}) +include_directories(library/unitree_legged_sdk_3.2/include) +link_directories(library/unitree_legged_sdk_3.2/lib) +set(EXTRA_LIBS -pthread libunitree_legged_sdk_amd64.so lcm) +add_executable(lcm_server $ENV{UNITREE_LEGGED_SDK_PATH}/examples/lcm_server.cpp) +target_link_libraries(lcm_server ${EXTRA_LIBS} ${catkin_LIBRARIES}) include_directories( include @@ -55,8 +50,13 @@ target_link_libraries(observation_buffer "${TORCH_LIBRARIES}") set_property(TARGET observation_buffer PROPERTY CXX_STANDARD 14) add_executable(unitree_rl src/unitree_rl.cpp ) -target_link_libraries(${PROJECT_NAME} +target_link_libraries(unitree_rl ${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}" model observation_buffer ) -# add_dependencies(${PROJECT_NAME} unitree_legged_msgs_gencpp) \ No newline at end of file + +add_executable(unitree_rl_real src/unitree_rl_real.cpp ) +target_link_libraries(unitree_rl_real + ${catkin_LIBRARIES} ${EXTRA_LIBS} "${TORCH_LIBRARIES}" + model observation_buffer +) \ No newline at end of file diff --git a/src/unitree_rl/include/convert.h b/src/unitree_rl/include/convert.h new file mode 100644 index 0000000..e2e900d --- /dev/null +++ b/src/unitree_rl/include/convert.h @@ -0,0 +1,187 @@ +/************************************************************************ +Copyright (c) 2018-2019, Unitree Robotics.Co.Ltd. All rights reserved. +Use of this source code is governed by the MPL-2.0 license, see LICENSE. +************************************************************************/ + +#ifndef _CONVERT_H_ +#define _CONVERT_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "unitree_legged_sdk/unitree_legged_sdk.h" + +unitree_legged_msgs::IMU ToRos(UNITREE_LEGGED_SDK::IMU& lcm) +{ + unitree_legged_msgs::IMU ros; + ros.quaternion[0] = lcm.quaternion[0]; + ros.quaternion[1] = lcm.quaternion[1]; + ros.quaternion[2] = lcm.quaternion[2]; + ros.quaternion[3] = lcm.quaternion[3]; + ros.gyroscope[0] = lcm.gyroscope[0]; + ros.gyroscope[1] = lcm.gyroscope[1]; + ros.gyroscope[2] = lcm.gyroscope[2]; + ros.accelerometer[0] = lcm.accelerometer[0]; + ros.accelerometer[1] = lcm.accelerometer[1]; + ros.accelerometer[2] = lcm.accelerometer[2]; + // ros.rpy[0] = lcm.rpy[0]; + // ros.rpy[1] = lcm.rpy[1]; + // ros.rpy[2] = lcm.rpy[2]; + ros.temperature = lcm.temperature; + return ros; +} + +unitree_legged_msgs::MotorState ToRos(UNITREE_LEGGED_SDK::MotorState& lcm) +{ + unitree_legged_msgs::MotorState ros; + ros.mode = lcm.mode; + ros.q = lcm.q; + ros.dq = lcm.dq; + ros.ddq = lcm.ddq; + ros.tauEst = lcm.tauEst; + ros.q_raw = lcm.q_raw; + ros.dq_raw = lcm.dq_raw; + ros.ddq_raw = lcm.ddq_raw; + ros.temperature = lcm.temperature; + ros.reserve[0] = lcm.reserve[0]; + ros.reserve[1] = lcm.reserve[1]; + return ros; +} + +UNITREE_LEGGED_SDK::MotorCmd ToLcm(unitree_legged_msgs::MotorCmd& ros, UNITREE_LEGGED_SDK::MotorCmd lcmType) +{ + UNITREE_LEGGED_SDK::MotorCmd lcm; + lcm.mode = ros.mode; + lcm.q = ros.q; + lcm.dq = ros.dq; + lcm.tau = ros.tau; + lcm.Kp = ros.Kp; + lcm.Kd = ros.Kd; + lcm.reserve[0] = ros.reserve[0]; + lcm.reserve[1] = ros.reserve[1]; + lcm.reserve[2] = ros.reserve[2]; + return lcm; +} + +unitree_legged_msgs::LowState ToRos(UNITREE_LEGGED_SDK::LowState& lcm) +{ + unitree_legged_msgs::LowState ros; + ros.levelFlag = lcm.levelFlag; + ros.commVersion = lcm.commVersion; + ros.robotID = lcm.robotID; + ros.SN = lcm.SN; + ros.bandWidth = lcm.bandWidth; + ros.imu = ToRos(lcm.imu); + for(int i = 0; i<20; i++){ + ros.motorState[i] = ToRos(lcm.motorState[i]); + } + for(int i = 0; i<4; i++){ + ros.footForce[i] = lcm.footForce[i]; + ros.footForceEst[i] = lcm.footForceEst[i]; + } + ros.tick = lcm.tick; + for(int i = 0; i<40; i++){ + ros.wirelessRemote[i] = lcm.wirelessRemote[i]; + } + ros.reserve = lcm.reserve; + ros.crc = lcm.crc; + return ros; +} + +UNITREE_LEGGED_SDK::LowCmd ToLcm(unitree_legged_msgs::LowCmd& ros, UNITREE_LEGGED_SDK::LowCmd lcmType) +{ + UNITREE_LEGGED_SDK::LowCmd lcm; + lcm.levelFlag = ros.levelFlag; + lcm.commVersion = ros.commVersion; + lcm.robotID = ros.robotID; + lcm.SN = ros.SN; + lcm.bandWidth = ros.bandWidth; + for(int i = 0; i<20; i++){ + lcm.motorCmd[i] = ToLcm(ros.motorCmd[i], lcm.motorCmd[i]); + } + for(int i = 0; i<4; i++){ + lcm.led[i].r = ros.led[i].r; + lcm.led[i].g = ros.led[i].g; + lcm.led[i].b = ros.led[i].b; + } + for(int i = 0; i<40; i++){ + lcm.wirelessRemote[i] = ros.wirelessRemote[i]; + } + lcm.reserve = ros.reserve; + lcm.crc = ros.crc; + return lcm; +} + +unitree_legged_msgs::HighState ToRos(UNITREE_LEGGED_SDK::HighState& lcm) +{ + unitree_legged_msgs::HighState ros; + ros.levelFlag = lcm.levelFlag; + ros.commVersion = lcm.commVersion; + ros.robotID = lcm.robotID; + ros.SN = lcm.SN; + ros.bandWidth = lcm.bandWidth; + ros.mode = lcm.mode; + ros.imu = ToRos(lcm.imu); + ros.forwardSpeed = lcm.forwardSpeed; + ros.sideSpeed = lcm.sideSpeed; + ros.rotateSpeed = lcm.rotateSpeed; + ros.bodyHeight = lcm.bodyHeight; + ros.updownSpeed = lcm.updownSpeed; + ros.forwardPosition = lcm.forwardPosition; + ros.sidePosition = lcm.sidePosition; + for(int i = 0; i<4; i++){ + ros.footPosition2Body[i].x = lcm.footPosition2Body[i].x; + ros.footPosition2Body[i].y = lcm.footPosition2Body[i].y; + ros.footPosition2Body[i].z = lcm.footPosition2Body[i].z; + ros.footSpeed2Body[i].x = lcm.footSpeed2Body[i].x; + ros.footSpeed2Body[i].y = lcm.footSpeed2Body[i].y; + ros.footSpeed2Body[i].z = lcm.footSpeed2Body[i].z; + ros.footForce[i] = lcm.footForce[i]; + ros.footForceEst[i] = lcm.footForceEst[i]; + } + ros.tick = lcm.tick; + for(int i = 0; i<40; i++){ + ros.wirelessRemote[i] = lcm.wirelessRemote[i]; + } + ros.reserve = lcm.reserve; + ros.crc = lcm.crc; + return ros; +} + +UNITREE_LEGGED_SDK::HighCmd ToLcm(unitree_legged_msgs::HighCmd& ros, UNITREE_LEGGED_SDK::HighCmd lcmType) +{ + UNITREE_LEGGED_SDK::HighCmd lcm; + lcm.levelFlag = ros.levelFlag; + lcm.commVersion = ros.commVersion; + lcm.robotID = ros.robotID; + lcm.SN = ros.SN; + lcm.bandWidth = ros.bandWidth; + lcm.mode = ros.mode; + lcm.forwardSpeed = ros.forwardSpeed; + lcm.sideSpeed = ros.sideSpeed; + lcm.rotateSpeed = ros.rotateSpeed; + lcm.bodyHeight = ros.bodyHeight; + lcm.footRaiseHeight = ros.footRaiseHeight; + lcm.yaw = ros.yaw; + lcm.pitch = ros.pitch; + lcm.roll = ros.roll; + for(int i = 0; i<4; i++){ + lcm.led[i].r = ros.led[i].r; + lcm.led[i].g = ros.led[i].g; + lcm.led[i].b = ros.led[i].b; + } + for(int i = 0; i<40; i++){ + lcm.wirelessRemote[i] = ros.wirelessRemote[i]; + lcm.AppRemote[i] = ros.AppRemote[i]; + } + lcm.reserve = ros.reserve; + lcm.crc = ros.crc; + return lcm; +} + +#endif // _CONVERT_H_ \ No newline at end of file diff --git a/src/unitree_rl/include/unitree_rl_real.hpp b/src/unitree_rl/include/unitree_rl_real.hpp new file mode 100644 index 0000000..957b8b2 --- /dev/null +++ b/src/unitree_rl/include/unitree_rl_real.hpp @@ -0,0 +1,92 @@ +#ifndef UNITREE_RL +#define UNITREE_RL + +#include +#include +#include +#include +#include "../lib/model.cpp" +#include "../lib/observation_buffer.hpp" +#include +#include "unitree_legged_msgs/LowState.h" +#include "convert.h" +#include + +using namespace UNITREE_LEGGED_SDK; + +class Unitree_RL : public Model +{ +public: + Unitree_RL(); + void modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg); + void jointStatesCallback(const sensor_msgs::JointState::ConstPtr &msg); + void cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg); + void runModel(); + torch::Tensor forward() override; + torch::Tensor compute_observation() override; + + ObservationBuffer history_obs_buf; + torch::Tensor history_obs; + + torch::Tensor torques; + + //udp + void UDPSend(); + void UDPRecv(); + void RobotControl(); + Safety safe; + UDP udp; + LowCmd cmd = {0}; + LowState state = {0}; + int motiontime = 0; + + std::shared_ptr loop_control; + std::shared_ptr loop_udpSend; + std::shared_ptr loop_udpRecv; + std::shared_ptr loop_rl; + + + float _percent; + float _targetPos[12] = {0.0, 0.8, -1.6, 0.0, 0.8, -1.6, + 0.0, 0.8, -1.6, 0.0, 0.8, -1.6}; //0.0, 0.67, -1.3 + float _startPos[12]; + + bool init_done = false; + +private: + std::string ros_namespace; + + std::vector torque_command_topics; + + ros::Subscriber model_state_subscriber_; + ros::Subscriber joint_state_subscriber_; + ros::Subscriber cmd_vel_subscriber_; + + std::map torque_publishers; + std::vector torque_commands; + + geometry_msgs::Twist vel; + geometry_msgs::Pose pose; + geometry_msgs::Twist cmd_vel; + + std::vector joint_names; + std::vector joint_positions; + std::vector joint_velocities; + + + + ros::Timer timer; + + std::chrono::high_resolution_clock::time_point start_time; + + // other rl module + torch::jit::script::Module encoder; + torch::jit::script::Module vq; + + UNITREE_LEGGED_SDK::LowCmd SendLowLCM = {0}; + UNITREE_LEGGED_SDK::LowState RecvLowLCM = {0}; + unitree_legged_msgs::LowCmd SendLowROS; + unitree_legged_msgs::LowState RecvLowROS; +}; + +#endif \ No newline at end of file diff --git a/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/a1_const.h b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/a1_const.h new file mode 100644 index 0000000..a588d04 --- /dev/null +++ b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/a1_const.h @@ -0,0 +1,19 @@ +/************************************************************************ +Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved. +Use of this source code is governed by the MPL-2.0 license, see LICENSE. +************************************************************************/ + +#ifndef _UNITREE_LEGGED_A1_H_ +#define _UNITREE_LEGGED_A1_H_ + +namespace UNITREE_LEGGED_SDK +{ + constexpr double a1_Hip_max = 0.802; // unit:radian ( = 46 degree) + constexpr double a1_Hip_min = -0.802; // unit:radian ( = -46 degree) + constexpr double a1_Thigh_max = 4.19; // unit:radian ( = 240 degree) + constexpr double a1_Thigh_min = -1.05; // unit:radian ( = -60 degree) + constexpr double a1_Calf_max = -0.916; // unit:radian ( = -52.5 degree) + constexpr double a1_Calf_min = -2.7; // unit:radian ( = -154.5 degree) +} + +#endif diff --git a/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/aliengo_const.h b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/aliengo_const.h new file mode 100644 index 0000000..362c78d --- /dev/null +++ b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/aliengo_const.h @@ -0,0 +1,19 @@ +/************************************************************************ +Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved. +Use of this source code is governed by the MPL-2.0 license, see LICENSE. +************************************************************************/ + +#ifndef _UNITREE_LEGGED_ALIENGO_H_ +#define _UNITREE_LEGGED_ALIENGO_H_ + +namespace UNITREE_LEGGED_SDK +{ + constexpr double aliengo_Hip_max = 1.047; // unit:radian ( = 60 degree) + constexpr double aliengo_Hip_min = -0.873; // unit:radian ( = -50 degree) + constexpr double aliengo_Thigh_max = 3.927; // unit:radian ( = 225 degree) + constexpr double aliengo_Thigh_min = -0.524; // unit:radian ( = -30 degree) + constexpr double aliengo_Calf_max = -0.611; // unit:radian ( = -35 degree) + constexpr double aliengo_Calf_min = -2.775; // unit:radian ( = -159 degree) +} + +#endif \ No newline at end of file diff --git a/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/comm.h b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/comm.h new file mode 100644 index 0000000..975c4ee --- /dev/null +++ b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/comm.h @@ -0,0 +1,167 @@ +/************************************************************************ +Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved. +Use of this source code is governed by the MPL-2.0 license, see LICENSE. +************************************************************************/ + +#ifndef _UNITREE_LEGGED_COMM_H_ +#define _UNITREE_LEGGED_COMM_H_ + +#include + +namespace UNITREE_LEGGED_SDK +{ + + constexpr int HIGHLEVEL = 0x00; + constexpr int LOWLEVEL = 0xff; + constexpr double PosStopF = (2.146E+9f); + constexpr double VelStopF = (16000.0f); + +#pragma pack(1) + + typedef struct + { + float x; + float y; + float z; + } Cartesian; + + typedef struct + { + float quaternion[4]; // quaternion, normalized, (w,x,y,z) + float gyroscope[3]; // angular velocity (unit: rad/s) + float accelerometer[3]; // m/(s2) + float rpy[3]; // euler angle(unit: rad) + int8_t temperature; + } IMU; // when under accelerated motion, the attitude of the robot calculated by IMU will drift. + + typedef struct + { + uint8_t r; + uint8_t g; + uint8_t b; + } LED; // foot led brightness: 0~255 + + typedef struct + { + uint8_t mode; // motor working mode + float q; // current angle (unit: radian) + float dq; // current velocity (unit: radian/second) + float ddq; // current acc (unit: radian/second*second) + float tauEst; // current estimated output torque (unit: N.m) + float q_raw; // current angle (unit: radian) + float dq_raw; // current velocity (unit: radian/second) + float ddq_raw; + int8_t temperature; // current temperature (temperature conduction is slow that leads to lag) + uint32_t reserve[2]; + } MotorState; // motor feedback + + typedef struct + { + uint8_t mode; // desired working mode + float q; // desired angle (unit: radian) + float dq; // desired velocity (unit: radian/second) + float tau; // desired output torque (unit: N.m) + float Kp; // desired position stiffness (unit: N.m/rad ) + float Kd; // desired velocity stiffness (unit: N.m/(rad/s) ) + uint32_t reserve[3]; + } MotorCmd; // motor control + + typedef struct + { + uint8_t levelFlag; // flag to distinguish high level or low level + uint16_t commVersion; + uint16_t robotID; + uint32_t SN; + uint8_t bandWidth; + IMU imu; + MotorState motorState[20]; + int16_t footForce[4]; // force sensors + int16_t footForceEst[4]; // force sensors + uint32_t tick; // reference real-time from motion controller (unit: us) + uint8_t wirelessRemote[40]; // wireless commands + uint32_t reserve; + uint32_t crc; + } LowState; // low level feedback + + typedef struct + { + uint8_t levelFlag; + uint16_t commVersion; + uint16_t robotID; + uint32_t SN; + uint8_t bandWidth; + MotorCmd motorCmd[20]; + LED led[4]; + uint8_t wirelessRemote[40]; + uint32_t reserve; + uint32_t crc; + } LowCmd; // low level control + + typedef struct + { + uint8_t levelFlag; + uint16_t commVersion; + uint16_t robotID; + uint32_t SN; + uint8_t bandWidth; + uint8_t mode; + IMU imu; + float forwardSpeed; + float sideSpeed; + float rotateSpeed; + float bodyHeight; + float updownSpeed; // speed of stand up or squat down + float forwardPosition; // front or rear displacement, an integrated number form kinematics function, usually drift + float sidePosition; // left or right displacement, an integrated number form kinematics function, usually drift + Cartesian footPosition2Body[4]; // foot position relative to body + Cartesian footSpeed2Body[4]; // foot speed relative to body + int16_t footForce[4]; + int16_t footForceEst[4]; + uint32_t tick; // reference real-time from motion controller (unit: us) + uint8_t wirelessRemote[40]; + uint32_t reserve; + uint32_t crc; + } HighState; // high level feedback + + typedef struct + { + uint8_t levelFlag; + uint16_t commVersion; + uint16_t robotID; + uint32_t SN; + uint8_t bandWidth; + uint8_t mode; // 0:idle, default stand 1:forced stand 2:walk continuously + float forwardSpeed; // speed of move forward or backward, scale: -1~1 + float sideSpeed; // speed of move left or right, scale: -1~1 + float rotateSpeed; // speed of spin left or right, scale: -1~1 + float bodyHeight; // body height, scale: -1~1 + float footRaiseHeight; // foot up height while walking (unavailable now) + float yaw; // unit: radian, scale: -1~1 + float pitch; // unit: radian, scale: -1~1 + float roll; // unit: radian, scale: -1~1 + LED led[4]; + uint8_t wirelessRemote[40]; + uint8_t AppRemote[40]; + uint32_t reserve; + uint32_t crc; + } HighCmd; // high level control + +#pragma pack() + + typedef struct + { + unsigned long long TotalCount; // total loop count + unsigned long long SendCount; // total send count + unsigned long long RecvCount; // total receive count + unsigned long long SendError; // total send error + unsigned long long FlagError; // total flag error + unsigned long long RecvCRCError; // total reveive CRC error + unsigned long long RecvLoseError; // total lose package count + } UDPState; // UDP communication state + + constexpr int HIGH_CMD_LENGTH = (sizeof(HighCmd)); + constexpr int HIGH_STATE_LENGTH = (sizeof(HighState)); + +} + +#endif diff --git a/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/lcm.h b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/lcm.h new file mode 100644 index 0000000..957b7e6 --- /dev/null +++ b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/lcm.h @@ -0,0 +1,82 @@ +/************************************************************************ +Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved. +Use of this source code is governed by the MPL-2.0 license, see LICENSE. +************************************************************************/ + +#ifndef _UNITREE_LEGGED_LCM_H_ +#define _UNITREE_LEGGED_LCM_H_ + +#include "comm.h" +#include +#include + +namespace UNITREE_LEGGED_SDK +{ + + constexpr char highCmdChannel[] = "LCM_High_Cmd"; + constexpr char highStateChannel[] = "LCM_High_State"; + constexpr char lowCmdChannel[] = "LCM_Low_Cmd"; + constexpr char lowStateChannel[] = "LCM_Low_State"; + + template + class LCMHandler + { + public: + LCMHandler(){ + pthread_mutex_init(&countMut, NULL); + pthread_mutex_init(&recvMut, NULL); + } + + void onMsg(const lcm::ReceiveBuffer* rbuf, const std::string& channel){ + isrunning = true; + + pthread_mutex_lock(&countMut); + counter = 0; + pthread_mutex_unlock(&countMut); + + T *msg = NULL; + msg = (T *)(rbuf->data); + pthread_mutex_lock(&recvMut); + // sourceBuf = *msg; + memcpy(&sourceBuf, msg, sizeof(T)); + pthread_mutex_unlock(&recvMut); + } + + bool isrunning = false; + T sourceBuf = {0}; + pthread_mutex_t countMut; + pthread_mutex_t recvMut; + int counter = 0; + }; + + class LCM { + public: + LCM(uint8_t level); + ~LCM(); + void SubscribeCmd(); + void SubscribeState(); // remember to call this when change control level + int Send(HighCmd&); // lcm send cmd + int Send(LowCmd&); // lcm send cmd + int Send(HighState&); // lcm send state + int Send(LowState&); // lcm send state + int Recv(); // directly save in buffer + void Get(HighCmd&); + void Get(LowCmd&); + void Get(HighState&); + void Get(LowState&); + + LCMHandler highStateLCMHandler; + LCMHandler lowStateLCMHandler; + LCMHandler highCmdLCMHandler; + LCMHandler lowCmdLCMHandler; + private: + uint8_t levelFlag = HIGHLEVEL; // default: high level + lcm::LCM lcm; + lcm::Subscription* subLcm; + int lcmFd; + int LCM_PERIOD = 2000; //default 2ms + }; + +} + +#endif diff --git a/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/lcm_server.h b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/lcm_server.h new file mode 100644 index 0000000..d9bd662 --- /dev/null +++ b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/lcm_server.h @@ -0,0 +1,106 @@ +/************************************************************************ +Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved. +Use of this source code is governed by the MPL-2.0 license, see LICENSE. +************************************************************************/ + +#ifndef _UNITREE_LEGGED_LCM_SERVER_ +#define _UNITREE_LEGGED_LCM_SERVER_ + +#include "comm.h" +#include "unitree_legged_sdk/unitree_legged_sdk.h" + +namespace UNITREE_LEGGED_SDK +{ +// Low command Lcm Server +class Lcm_Server_Low +{ +public: + Lcm_Server_Low(LeggedType rname) : udp(LOWLEVEL), mylcm(LOWLEVEL){ + udp.InitCmdData(cmd); + } + void UDPRecv(){ + udp.Recv(); + } + void UDPSend(){ + udp.Send(); + } + void LCMRecv(); + void RobotControl(); + + UDP udp; + LCM mylcm; + LowCmd cmd = {0}; + LowState state = {0}; +}; +void Lcm_Server_Low::LCMRecv() +{ + if(mylcm.lowCmdLCMHandler.isrunning){ + pthread_mutex_lock(&mylcm.lowCmdLCMHandler.countMut); + mylcm.lowCmdLCMHandler.counter++; + if(mylcm.lowCmdLCMHandler.counter > 10){ + printf("Error! LCM Time out.\n"); + exit(-1); // can be commented out + } + pthread_mutex_unlock(&mylcm.lowCmdLCMHandler.countMut); + } + mylcm.Recv(); +} +void Lcm_Server_Low::RobotControl() +{ + udp.GetRecv(state); + mylcm.Send(state); + mylcm.Get(cmd); + udp.SetSend(cmd); +} + + + +// High command Lcm Server +class Lcm_Server_High +{ +public: + Lcm_Server_High(LeggedType rname): udp(HIGHLEVEL), mylcm(HIGHLEVEL){ + udp.InitCmdData(cmd); + } + void UDPRecv(){ + udp.Recv(); + } + void UDPSend(){ + udp.Send(); + } + void LCMRecv(); + void RobotControl(); + + UDP udp; + LCM mylcm; + HighCmd cmd = {0}; + HighState state = {0}; +}; + +void Lcm_Server_High::LCMRecv() +{ + if(mylcm.highCmdLCMHandler.isrunning){ + pthread_mutex_lock(&mylcm.highCmdLCMHandler.countMut); + mylcm.highCmdLCMHandler.counter++; + if(mylcm.highCmdLCMHandler.counter > 10){ + printf("Error! LCM Time out.\n"); + exit(-1); // can be commented out + } + pthread_mutex_unlock(&mylcm.highCmdLCMHandler.countMut); + } + mylcm.Recv(); +} + +void Lcm_Server_High::RobotControl() +{ + udp.GetRecv(state); + mylcm.Send(state); + mylcm.Get(cmd); + udp.SetSend(cmd); +} + + + + +} +#endif \ No newline at end of file diff --git a/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/loop.h b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/loop.h new file mode 100644 index 0000000..c58d67b --- /dev/null +++ b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/loop.h @@ -0,0 +1,58 @@ +/************************************************************************ +Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved. +Use of this source code is governed by the MPL-2.0 license, see LICENSE. +************************************************************************/ + +#ifndef _UNITREE_LEGGED_LOOP_H_ +#define _UNITREE_LEGGED_LOOP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace UNITREE_LEGGED_SDK +{ + +constexpr int THREAD_PRIORITY = 99; // real-time priority + +typedef boost::function Callback; + +class Loop { +public: + Loop(std::string name, float period, int bindCPU = -1):_name(name), _period(period), _bindCPU(bindCPU) {} + ~Loop(); + void start(); + void shutdown(); + virtual void functionCB() = 0; + +private: + void entryFunc(); + + std::string _name; + float _period; + int _bindCPU; + bool _bind_cpu_flag = false; + bool _isrunning = false; + std::thread _thread; +}; + +class LoopFunc : public Loop { +public: + LoopFunc(std::string name, float period, const Callback& _cb) + : Loop(name, period), _fp(_cb){} + LoopFunc(std::string name, float period, int bindCPU, const Callback& _cb) + : Loop(name, period, bindCPU), _fp(_cb){} + void functionCB() { (_fp)(); } +private: + boost::function _fp; +}; + + +} + +#endif diff --git a/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/quadruped.h b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/quadruped.h new file mode 100644 index 0000000..0c23f62 --- /dev/null +++ b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/quadruped.h @@ -0,0 +1,48 @@ +/************************************************************************ +Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved. +Use of this source code is governed by the MPL-2.0 license, see LICENSE. +************************************************************************/ + +#ifndef _UNITREE_LEGGED_QUADRUPED_H_ +#define _UNITREE_LEGGED_QUADRUPED_H_ + +namespace UNITREE_LEGGED_SDK +{ + +enum class LeggedType { + Aliengo, + A1 +}; + +enum class HighLevelType { + Basic, + Sport +}; + +void InitEnvironment(); // memory lock + +// definition of each leg and joint +constexpr int FR_ = 0; // leg index +constexpr int FL_ = 1; +constexpr int RR_ = 2; +constexpr int RL_ = 3; + +constexpr int FR_0 = 0; // joint index +constexpr int FR_1 = 1; +constexpr int FR_2 = 2; + +constexpr int FL_0 = 3; +constexpr int FL_1 = 4; +constexpr int FL_2 = 5; + +constexpr int RR_0 = 6; +constexpr int RR_1 = 7; +constexpr int RR_2 = 8; + +constexpr int RL_0 = 9; +constexpr int RL_1 = 10; +constexpr int RL_2 = 11; + +} + +#endif diff --git a/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/safety.h b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/safety.h new file mode 100644 index 0000000..0013f61 --- /dev/null +++ b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/safety.h @@ -0,0 +1,32 @@ +/************************************************************************ +Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved. +Use of this source code is governed by the MPL-2.0 license, see LICENSE. +************************************************************************/ + +#ifndef _UNITREE_LEGGED_SAFETY_H_ +#define _UNITREE_LEGGED_SAFETY_H_ + +#include "comm.h" +#include "quadruped.h" + +namespace UNITREE_LEGGED_SDK +{ + +class Safety{ +public: + Safety(LeggedType type); + ~Safety(); + void PositionLimit(LowCmd&); // only effect under Low Level control in Position mode + void PowerProtect(LowCmd&, LowState&, int); /* only effect under Low Level control, input factor: 1~10, + means 10%~100% power limit. If you are new, then use 1; if you are familiar, + then can try bigger number or even comment this function. */ + void PositionProtect(LowCmd&, LowState&, double limit = 0.087); // default limit is 5 degree +private: + int WattLimit, Wcount; // Watt. When limit to 100, you can triger it with 4 hands shaking. + double Hip_max, Hip_min, Thigh_max, Thigh_min, Calf_max, Calf_min; +}; + + +} + +#endif diff --git a/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/udp.h b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/udp.h new file mode 100644 index 0000000..35d318c --- /dev/null +++ b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/udp.h @@ -0,0 +1,64 @@ +/************************************************************************ +Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved. +Use of this source code is governed by the MPL-2.0 license, see LICENSE. +************************************************************************/ + +#ifndef _UNITREE_LEGGED_UDP_H_ +#define _UNITREE_LEGGED_UDP_H_ + +#include "comm.h" +#include "unitree_legged_sdk/quadruped.h" +#include + +namespace UNITREE_LEGGED_SDK +{ + +constexpr int UDP_CLIENT_PORT = 8080; // local port +constexpr int UDP_SERVER_PORT = 8007; // target port +constexpr char UDP_SERVER_IP_BASIC[] = "192.168.123.10"; // target IP address +constexpr char UDP_SERVER_IP_SPORT[] = "192.168.123.161"; // target IP address + +// Notice: User defined data(like struct) should add crc(4Byte) at the end. +class UDP { +public: + UDP(uint8_t level, HighLevelType highControl = HighLevelType::Basic); // unitree dafault IP and Port + UDP(uint16_t localPort, const char* targetIP, uint16_t targetPort, int sendLength, int recvLength); + UDP(uint16_t localPort, uint16_t targetPort, int sendLength, int recvLength); // as server, client IP can change + ~UDP(); + void InitCmdData(HighCmd& cmd); + void InitCmdData(LowCmd& cmd); + void switchLevel(int level); + + int SetSend(HighCmd&); + int SetSend(LowCmd&); + int SetSend(char* cmd); + void GetRecv(HighState&); + void GetRecv(LowState&); + void GetRecv(char*); + int Send(); + int Recv(); // directly save in buffer + + UDPState udpState; + char* targetIP; + uint16_t targetPort; + char* localIP; + uint16_t localPort; +private: + void init(uint16_t localPort, const char* targetIP, uint16_t targetPort); + + uint8_t levelFlag = HIGHLEVEL; // default: high level + int sockFd; + bool connected; // udp only works when connected + int sendLength; + int recvLength; + char* recvTemp; + char* recvBuf; + char* sendBuf; + int lose_recv; + pthread_mutex_t sendMut; + pthread_mutex_t recvMut; +}; + +} + +#endif diff --git a/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/unitree_joystick.h b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/unitree_joystick.h new file mode 100644 index 0000000..b9e0e14 --- /dev/null +++ b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/unitree_joystick.h @@ -0,0 +1,44 @@ +/***************************************************************** +Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved. +*****************************************************************/ +#ifndef UNITREE_JOYSTICK_H +#define UNITREE_JOYSTICK_H + +#include +// // 16b +// typedef union { +// struct { +// uint8_t R1 :1; +// uint8_t L1 :1; +// uint8_t start :1; +// uint8_t select :1; +// uint8_t R2 :1; +// uint8_t L2 :1; +// uint8_t F1 :1; +// uint8_t F2 :1; +// uint8_t A :1; +// uint8_t B :1; +// uint8_t X :1; +// uint8_t Y :1; +// uint8_t up :1; +// uint8_t right :1; +// uint8_t down :1; +// uint8_t left :1; +// } components; +// uint16_t value; +// } xKeySwitchUnion; + +// // 40 Byte (now used 24B) +// typedef struct { +// uint8_t head[2]; +// xKeySwitchUnion btn; +// float lx; +// float rx; +// float ry; +// float L2; +// float ly; + +// uint8_t idle[16]; +// } xRockerBtnDataStruct; + +#endif // UNITREE_JOYSTICK_H \ No newline at end of file diff --git a/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/unitree_legged_sdk.h b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/unitree_legged_sdk.h new file mode 100644 index 0000000..a93a6c1 --- /dev/null +++ b/src/unitree_rl/library/unitree_legged_sdk_3.2/include/unitree_legged_sdk/unitree_legged_sdk.h @@ -0,0 +1,19 @@ +/************************************************************************ +Copyright (c) 2020, Unitree Robotics.Co.Ltd. All rights reserved. +Use of this source code is governed by the MPL-2.0 license, see LICENSE. +************************************************************************/ + +#ifndef _UNITREE_LEGGED_SDK_H_ +#define _UNITREE_LEGGED_SDK_H_ + +#include "comm.h" +#include "safety.h" +#include "udp.h" +#include "loop.h" +#include "lcm.h" +#include "quadruped.h" +#include + +#define UT UNITREE_LEGGED_SDK //short name + +#endif diff --git a/src/unitree_rl/library/unitree_legged_sdk_3.2/lib/libunitree_legged_sdk_amd64.so b/src/unitree_rl/library/unitree_legged_sdk_3.2/lib/libunitree_legged_sdk_amd64.so new file mode 100644 index 0000000..eb0ada0 Binary files /dev/null and b/src/unitree_rl/library/unitree_legged_sdk_3.2/lib/libunitree_legged_sdk_amd64.so differ diff --git a/src/unitree_rl/library/unitree_legged_sdk_3.2/lib/libunitree_legged_sdk_arm32.so b/src/unitree_rl/library/unitree_legged_sdk_3.2/lib/libunitree_legged_sdk_arm32.so new file mode 100644 index 0000000..40c5b0d Binary files /dev/null and b/src/unitree_rl/library/unitree_legged_sdk_3.2/lib/libunitree_legged_sdk_arm32.so differ diff --git a/src/unitree_rl/library/unitree_legged_sdk_3.2/lib/libunitree_legged_sdk_arm64.so b/src/unitree_rl/library/unitree_legged_sdk_3.2/lib/libunitree_legged_sdk_arm64.so new file mode 100644 index 0000000..e734e82 Binary files /dev/null and b/src/unitree_rl/library/unitree_legged_sdk_3.2/lib/libunitree_legged_sdk_arm64.so differ diff --git a/src/unitree_rl/src/unitree_rl.cpp b/src/unitree_rl/src/unitree_rl.cpp index b69b5af..2f34a44 100644 --- a/src/unitree_rl/src/unitree_rl.cpp +++ b/src/unitree_rl/src/unitree_rl.cpp @@ -10,21 +10,6 @@ Unitree_RL::Unitree_RL() torque_commands.resize(12); - ros_namespace = "/a1_gazebo/"; - - joint_names = { - "FL_hip_joint", "FL_thigh_joint", "FL_calf_joint", - "FR_hip_joint", "FR_thigh_joint", "FR_calf_joint", - "RL_hip_joint", "RL_thigh_joint", "RL_calf_joint", - "RR_hip_joint", "RR_thigh_joint", "RR_calf_joint", - }; - - for (int i = 0; i < 12; ++i) - { - torque_publishers[joint_names[i]] = nh.advertise( - ros_namespace + joint_names[i].substr(0, joint_names[i].size() - 6) + "_controller/command", 10); - } - std::string package_name = "unitree_rl"; std::string actor_path = ros::package::getPath(package_name) + "/models/actor.pt"; std::string encoder_path = ros::package::getPath(package_name) + "/models/encoder.pt"; @@ -51,30 +36,45 @@ Unitree_RL::Unitree_RL() this->params.dof_vel_scale = 0.05; this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale}); - // hip, thigh, calf - this->params.torque_limits = torch::tensor({{20.0, 55.0, 55.0, // front left - 20.0, 55.0, 55.0, // front right - 20.0, 55.0, 55.0, // rear left - 20.0, 55.0, 55.0}}); // rear right + + this->params.torque_limits = torch::tensor({{20.0, 55.0, 55.0, + 20.0, 55.0, 55.0, + 20.0, 55.0, 55.0, + 20.0, 55.0, 55.0}}); - this->params.default_dof_pos = torch::tensor({{0.1000, 0.8000, -1.5000, - -0.1000, 0.8000, -1.5000, - 0.1000, 1.0000, -1.5000, - -0.1000, 1.0000, -1.5000}}); + // hip, thigh, calf + this->params.default_dof_pos = torch::tensor({{0.1000, 0.8000, -1.5000, // front left + -0.1000, 0.8000, -1.5000, // front right + 0.1000, 1.0000, -1.5000, // rear left + -0.1000, 1.0000, -1.5000}});// rear right this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); - // Create a subscriber object - model_state_subscriber_ = nh.subscribe( - "/gazebo/model_states", 10, &Unitree_RL::modelStatesCallback, this); - - joint_state_subscriber_ = nh.subscribe( - "/a1_gazebo/joint_states", 10, &Unitree_RL::jointStatesCallback, this); - cmd_vel_subscriber_ = nh.subscribe( "/cmd_vel", 10, &Unitree_RL::cmdvelCallback, this); timer = nh.createTimer(ros::Duration(0.005), &Unitree_RL::runModel, this); + + ros_namespace = "/a1_gazebo/"; + + joint_names = { + "FL_hip_joint", "FL_thigh_joint", "FL_calf_joint", + "FR_hip_joint", "FR_thigh_joint", "FR_calf_joint", + "RL_hip_joint", "RL_thigh_joint", "RL_calf_joint", + "RR_hip_joint", "RR_thigh_joint", "RR_calf_joint", + }; + + for (int i = 0; i < 12; ++i) + { + torque_publishers[joint_names[i]] = nh.advertise( + ros_namespace + joint_names[i].substr(0, joint_names[i].size() - 6) + "_controller/command", 10); + } + + model_state_subscriber_ = nh.subscribe( + "/gazebo/model_states", 10, &Unitree_RL::modelStatesCallback, this); + + joint_state_subscriber_ = nh.subscribe( + "/a1_gazebo/joint_states", 10, &Unitree_RL::jointStatesCallback, this); } void Unitree_RL::modelStatesCallback(const gazebo_msgs::ModelStates::ConstPtr &msg) @@ -98,6 +98,7 @@ void Unitree_RL::jointStatesCallback(const sensor_msgs::JointState::ConstPtr &ms void Unitree_RL::runModel(const ros::TimerEvent &event) { auto duration = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count(); + // std::cout << "Execution time: " << duration << " microseconds" << std::endl; start_time = std::chrono::high_resolution_clock::now(); this->obs.lin_vel = torch::tensor({{vel.linear.x, vel.linear.y, vel.linear.z}}); @@ -113,8 +114,8 @@ void Unitree_RL::runModel(const ros::TimerEvent &event) joint_velocities[7], joint_velocities[8], joint_velocities[6], joint_velocities[10], joint_velocities[11], joint_velocities[9]}}); - torques = this->compute_torques(this->forward()); - + torch::Tensor actions = this->forward(); + torques = this->compute_torques(actions); for (int i = 0; i < 12; ++i) { diff --git a/src/unitree_rl/src/unitree_rl_real.cpp b/src/unitree_rl/src/unitree_rl_real.cpp new file mode 100644 index 0000000..397d1d7 --- /dev/null +++ b/src/unitree_rl/src/unitree_rl_real.cpp @@ -0,0 +1,239 @@ +#include "../include/unitree_rl_real.hpp" +#include + +void Unitree_RL::UDPRecv() +{ + udp.Recv(); +} + +void Unitree_RL::UDPSend() +{ + udp.Send(); +} + +void Unitree_RL::RobotControl() +{ + motiontime++; + udp.GetRecv(state); + + if( motiontime < 50) + { + for(int i = 0; i < 12; ++i) + { + cmd.motorCmd[i].q = state.motorState[i].q; + _startPos[i] = state.motorState[i].q; + } + } + + if( motiontime >= 50 && _percent != 1) + { + _percent += (float) 1 / 100; + _percent = _percent > 1 ? 1 : _percent; + for(int i = 0; i < 12; ++i) + { + cmd.motorCmd[i].q = (1 - _percent) * _startPos[i] + _percent * _targetPos[i]; + cmd.motorCmd[i].dq = 0; + cmd.motorCmd[i].Kp = 50; + cmd.motorCmd[i].Kd = 3; + cmd.motorCmd[i].tau = 0; + } + } + if(_percent == 1 && !init_done) + { + init_done = true; + this->init_observations(); + loop_rl->start(); + std::cout << "init done" << std::endl; + motiontime = 0; + } + + if(init_done) + { + if( motiontime < 50) + { + for(int i = 0; i < 12; ++i) + { + cmd.motorCmd[i].q = _targetPos[i]; + cmd.motorCmd[i].dq = 0; + cmd.motorCmd[i].Kp = 50; + cmd.motorCmd[i].Kd = 3; + cmd.motorCmd[i].tau = 0; + _startPos[i] = state.motorState[i].q; + } + } + if( motiontime >= 50) + { + for (int i = 0; i < 12; ++i) + { + float torque = torques[0][i].item(); + // if(torque > 5.0f) torque = 5.0f; + // if(torque < -5.0f) torque = -5.0f; + + cmd.motorCmd[i].q = PosStopF; + cmd.motorCmd[i].dq = VelStopF; + cmd.motorCmd[i].Kp = 0; + cmd.motorCmd[i].Kd = 0; + cmd.motorCmd[i].tau = torque; + } + } + } + + // safe.PowerProtect(cmd, state, 1); + udp.SetSend(cmd); +} + +Unitree_RL::Unitree_RL() : safe(LeggedType::A1), udp(LOWLEVEL) +{ + udp.InitCmdData(cmd); + + start_time = std::chrono::high_resolution_clock::now(); + + cmd_vel = geometry_msgs::Twist(); + + torque_commands.resize(12); + + std::string package_name = "unitree_rl"; + std::string actor_path = ros::package::getPath(package_name) + "/models/actor.pt"; + std::string encoder_path = ros::package::getPath(package_name) + "/models/encoder.pt"; + std::string vq_path = ros::package::getPath(package_name) + "/models/vq_layer.pt"; + + this->actor = torch::jit::load(actor_path); + this->encoder = torch::jit::load(encoder_path); + this->vq = torch::jit::load(vq_path); + this->init_observations(); + + this->params.num_observations = 45; + this->params.clip_obs = 100.0; + this->params.clip_actions = 100.0; + this->params.damping = 0.5; + this->params.stiffness = 20; + this->params.d_gains = torch::ones(12) * this->params.damping; + this->params.p_gains = torch::ones(12) * this->params.stiffness; + this->params.action_scale = 0.25; + this->params.hip_scale_reduction = 0.5; + this->params.num_of_dofs = 12; + this->params.lin_vel_scale = 2.0; + this->params.ang_vel_scale = 0.25; + this->params.dof_pos_scale = 1.0; + this->params.dof_vel_scale = 0.05; + this->params.commands_scale = torch::tensor({this->params.lin_vel_scale, this->params.lin_vel_scale, this->params.ang_vel_scale}); + + + this->params.torque_limits = torch::tensor({{20.0, 55.0, 55.0, + 20.0, 55.0, 55.0, + 20.0, 55.0, 55.0, + 20.0, 55.0, 55.0}}); + + // hip, thigh, calf + this->params.default_dof_pos = torch::tensor({{-0.1000, 0.8000, -1.5000, // front right + 0.1000, 0.8000, -1.5000, // front left + -0.1000, 1.0000, -1.5000, // rear right + 0.1000, 1.0000, -1.5000}});// rear left + + this->history_obs_buf = ObservationBuffer(1, this->params.num_observations, 6); + + // InitEnvironment(); + loop_control = std::make_shared("control_loop", 0.02 , boost::bind(&Unitree_RL::RobotControl, this)); + loop_udpSend = std::make_shared("udp_send" , 0.002, 3, boost::bind(&Unitree_RL::UDPSend, this)); + loop_udpRecv = std::make_shared("udp_recv" , 0.002, 3, boost::bind(&Unitree_RL::UDPRecv, this)); + loop_rl = std::make_shared("rl_loop" , 0.02 , boost::bind(&Unitree_RL::runModel, this)); + + loop_udpSend->start(); + loop_udpRecv->start(); + loop_control->start(); +} + +void Unitree_RL::cmdvelCallback(const geometry_msgs::Twist::ConstPtr &msg) +{ + cmd_vel = *msg; +} + +// void Unitree_RL::runModel(const ros::TimerEvent &event) +void Unitree_RL::runModel() +{ + if(init_done) + { + auto duration = std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count(); + // std::cout << "Execution time: " << duration << " microseconds" << std::endl; + start_time = std::chrono::high_resolution_clock::now(); + + // printf("%f, %f, %f\n", state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]); + // printf("%f, %f, %f, %f\n", state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]); + // printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", state.motorState[FL_0].q, state.motorState[FL_1].q, state.motorState[FL_2].q, state.motorState[FR_0].q, state.motorState[FR_1].q, state.motorState[FR_2].q, state.motorState[RL_0].q, state.motorState[RL_1].q, state.motorState[RL_2].q, state.motorState[RR_0].q, state.motorState[RR_1].q, state.motorState[RR_2].q); + // printf("%f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", state.motorState[FL_0].dq, state.motorState[FL_1].dq, state.motorState[FL_2].dq, state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq, state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq, state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq); + + this->obs.ang_vel = torch::tensor({{state.imu.gyroscope[0], state.imu.gyroscope[1], state.imu.gyroscope[2]}}); + this->obs.commands = torch::tensor({{cmd_vel.linear.x, cmd_vel.linear.y, cmd_vel.angular.z}}); + this->obs.base_quat = torch::tensor({{state.imu.quaternion[1], state.imu.quaternion[2], state.imu.quaternion[3], state.imu.quaternion[0]}}); + this->obs.dof_pos = torch::tensor({{state.motorState[FR_0].q, state.motorState[FR_1].q, state.motorState[FR_2].q, + state.motorState[FL_0].q, state.motorState[FL_1].q, state.motorState[FL_2].q, + state.motorState[RR_0].q, state.motorState[RR_1].q, state.motorState[RR_2].q, + state.motorState[RL_0].q, state.motorState[RL_1].q, state.motorState[RL_2].q}}); + this->obs.dof_vel = torch::tensor({{state.motorState[FR_0].dq, state.motorState[FR_1].dq, state.motorState[FR_2].dq, + state.motorState[FL_0].dq, state.motorState[FL_1].dq, state.motorState[FL_2].dq, + state.motorState[RR_0].dq, state.motorState[RR_1].dq, state.motorState[RR_2].dq, + state.motorState[RL_0].dq, state.motorState[RL_1].dq, state.motorState[RL_2].dq}}); + + torch::Tensor actions = this->forward(); + torques = this->compute_torques(actions); + } + +} + +torch::Tensor Unitree_RL::compute_observation() +{ + torch::Tensor ang_vel = this->quat_rotate_inverse(this->obs.base_quat, this->obs.ang_vel); + // float ang_vel_temp = ang_vel[0][0].item(); + // ang_vel[0][0] = ang_vel[0][1]; + // ang_vel[0][1] = ang_vel_temp; + + torch::Tensor grav = this->quat_rotate_inverse(this->obs.base_quat, this->obs.gravity_vec); + // float grav_temp = grav[0][0].item(); + // grav[0][0] = grav[0][1]; + // grav[0][1] = grav_temp; + + torch::Tensor obs = torch::cat({// (this->quat_rotate_inverse(this->base_quat, this->lin_vel)) * this->params.lin_vel_scale, + ang_vel * this->params.ang_vel_scale, + grav, + this->obs.commands * this->params.commands_scale, + (this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale, + this->obs.dof_vel * this->params.dof_vel_scale, + this->obs.actions}, + 1); + obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs); + return obs; +} + +torch::Tensor Unitree_RL::forward() +{ + torch::Tensor obs = this->compute_observation(); + + history_obs_buf.insert(obs); + history_obs = history_obs_buf.get_obs_vec({0, 1, 2, 3, 4, 5}); + + torch::Tensor encoding = this->encoder.forward({history_obs}).toTensor(); + + torch::Tensor z = this->vq.forward({encoding}).toTensor(); + + torch::Tensor actor_input = torch::cat({obs, z}, 1); + + torch::Tensor action = this->actor.forward({actor_input}).toTensor(); + + this->obs.actions = action; + torch::Tensor clamped = torch::clamp(action, -this->params.clip_actions, this->params.clip_actions); + + return clamped; +} + + + +int main(int argc, char **argv) +{ + Unitree_RL unitree_rl; + + while(1){ + sleep(10); + }; + + return 0; +}