diff --git a/.gitignore b/.gitignore index 9aecb2f..680e994 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,4 @@ logs *fldlar* .cache *.json -# *gr1t1* \ No newline at end of file +__pycache__ \ No newline at end of file diff --git a/README.md b/README.md index 20371d8..fc9bb57 100644 --- a/README.md +++ b/README.md @@ -75,14 +75,22 @@ Before running, copy the trained pt model file to `rl_sar/src/rl_sar/models/YOUR ### Simulation -Open a new terminal, launch the gazebo simulation environment +Open a terminal, launch the gazebo simulation environment ```bash source devel/setup.bash roslaunch rl_sar gazebo_.launch ``` -Where \ can be `a1` or `gr1t1`. +Open a new terminal, launch the control program + +```bash +source devel/setup.bash +(for cpp version) rosrun rl_sar rl_sim +(for python version) rosrun rl_sar rl_sim.py +``` + +Where \ can be `a1` or `gr1t1` or `gr1t2`. Control: * Press **\** to toggle simulation start/stop. diff --git a/README_CN.md b/README_CN.md index d76362c..73667b9 100644 --- a/README_CN.md +++ b/README_CN.md @@ -75,14 +75,22 @@ catkin build ### 仿真 -新建终端,启动gazebo仿真环境 +打开一个终端,启动gazebo仿真环境 ```bash source devel/setup.bash roslaunch rl_sar gazebo_.launch ``` -其中 \ 可以是 `a1` 或 `gr1t1`. +打开一个新终端,启动控制程序 + +```bash +source devel/setup.bash +(for cpp version) rosrun rl_sar rl_sim +(for python version) rosrun rl_sar rl_sim.py +``` + +其中 \ 可以是 `a1` 或 `gr1t1` 或 `gr1t2`. 控制: diff --git a/src/rl_sar/CMakeLists.txt b/src/rl_sar/CMakeLists.txt index f8d01cb..5fe52c9 100644 --- a/src/rl_sar/CMakeLists.txt +++ b/src/rl_sar/CMakeLists.txt @@ -26,6 +26,7 @@ find_package(catkin REQUIRED COMPONENTS geometry_msgs robot_msgs robot_joint_controller + rospy ) find_package(Python3 COMPONENTS Interpreter Development REQUIRED) @@ -37,6 +38,7 @@ include_directories(${YAML_CPP_INCLUDE_DIR}) catkin_package( CATKIN_DEPENDS robot_joint_controller + rospy ) include_directories(library/unitree_legged_sdk_3.2/include) @@ -78,3 +80,9 @@ target_link_libraries(rl_real_a1 ${catkin_LIBRARIES} ${EXTRA_LIBS} rl_sdk observation_buffer yaml-cpp ) + +catkin_install_python(PROGRAMS + scripts/rl_sim.py + scripts/rl_sdk.py + DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} +) \ No newline at end of file diff --git a/src/rl_sar/launch/gazebo_a1.launch b/src/rl_sar/launch/gazebo_a1.launch index 32886be..a660845 100644 --- a/src/rl_sar/launch/gazebo_a1.launch +++ b/src/rl_sar/launch/gazebo_a1.launch @@ -1,7 +1,7 @@ - + @@ -37,8 +37,6 @@ - - - - diff --git a/src/rl_sar/launch/gazebo_gr1t1.launch b/src/rl_sar/launch/gazebo_gr1t1.launch index abf525c..1ad42d7 100644 --- a/src/rl_sar/launch/gazebo_gr1t1.launch +++ b/src/rl_sar/launch/gazebo_gr1t1.launch @@ -1,7 +1,7 @@ - + @@ -33,8 +33,6 @@ - - - - diff --git a/src/rl_sar/library/rl_sdk/rl_sdk.cpp b/src/rl_sar/library/rl_sdk/rl_sdk.cpp index d14ebb9..54f4489 100644 --- a/src/rl_sar/library/rl_sdk/rl_sdk.cpp +++ b/src/rl_sar/library/rl_sdk/rl_sdk.cpp @@ -1,31 +1,28 @@ #include "rl_sdk.hpp" /* You may need to override this ComputeObservation() function -torch::Tensor RL::ComputeObservation() +torch::Tensor RL_XXX::ComputeObservation() { - torch::Tensor obs = torch::cat({this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale, - this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec), - this->obs.commands * this->params.commands_scale, - (this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale, - this->obs.dof_vel * this->params.dof_vel_scale, - this->obs.actions - },1); + torch::Tensor obs = torch::cat({ + this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale, + this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec), + this->obs.commands * this->params.commands_scale, + (this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale, + this->obs.dof_vel * this->params.dof_vel_scale, + this->obs.actions + },1); torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs); return clamped_obs; } */ /* You may need to override this Forward() function -torch::Tensor RL::Forward() +torch::Tensor RL_XXX::Forward() { torch::autograd::GradMode::set_enabled(false); - torch::Tensor clamped_obs = this->ComputeObservation(); - torch::Tensor actions = this->model.forward({clamped_obs}).toTensor(); - torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper); - return clamped_actions; } */ diff --git a/src/rl_sar/package.xml b/src/rl_sar/package.xml index 8cbfbcf..d665e58 100644 --- a/src/rl_sar/package.xml +++ b/src/rl_sar/package.xml @@ -25,6 +25,8 @@ robot_state_publisher roscpp std_msgs + rospy + rospy robot_msgs robot_joint_controller diff --git a/src/rl_sar/scripts/observation_buffer.py b/src/rl_sar/scripts/observation_buffer.py new file mode 100644 index 0000000..ddb53b5 --- /dev/null +++ b/src/rl_sar/scripts/observation_buffer.py @@ -0,0 +1,37 @@ +import torch + +class ObservationBuffer: + def __init__(self, num_envs, num_obs, include_history_steps): + + self.num_envs = num_envs + self.num_obs = num_obs + self.include_history_steps = include_history_steps + + self.num_obs_total = num_obs * include_history_steps + + self.obs_buf = torch.zeros(self.num_envs, self.num_obs_total, dtype=torch.float) + + def reset(self, reset_idxs, new_obs): + self.obs_buf[reset_idxs] = new_obs.repeat(1, self.include_history_steps) + + def insert(self, new_obs): + # Shift observations back. + self.obs_buf[:, : self.num_obs * (self.include_history_steps - 1)] = self.obs_buf[:,self.num_obs : self.num_obs * self.include_history_steps].clone() + + # Add new observation. + self.obs_buf[:, -self.num_obs:] = new_obs + + def get_obs_vec(self, obs_ids): + """Gets history of observations indexed by obs_ids. + + Arguments: + obs_ids: An array of integers with which to index the desired + observations, where 0 is the latest observation and + include_history_steps - 1 is the oldest observation. + """ + + obs = [] + for obs_id in reversed(sorted(obs_ids)): + slice_idx = self.include_history_steps - obs_id - 1 + obs.append(self.obs_buf[:, slice_idx * self.num_obs : (slice_idx + 1) * self.num_obs]) + return torch.cat(obs, dim=-1) diff --git a/src/rl_sar/scripts/rl_sdk.py b/src/rl_sar/scripts/rl_sdk.py new file mode 100644 index 0000000..642c563 --- /dev/null +++ b/src/rl_sar/scripts/rl_sdk.py @@ -0,0 +1,356 @@ +import torch +import yaml +import os +from pynput import keyboard +from enum import Enum, auto + +CONFIG_PATH = os.path.join(os.path.dirname(__file__), "../config.yaml") + +class LOGGER: + INFO = "\033[0;37m[INFO]\033[0m " + WARNING = "\033[0;33m[WARNING]\033[0m " + ERROR = "\033[0;31m[ERROR]\033[0m " + DEBUG = "\033[0;32m[DEBUG]\033[0m " + +class RobotCommand: + def __init__(self): + self.motor_command = self.MotorCommand() + + class MotorCommand: + def __init__(self): + self.q = [0.0] * 32 + self.dq = [0.0] * 32 + self.tau = [0.0] * 32 + self.kp = [0.0] * 32 + self.kd = [0.0] * 32 + +class RobotState: + def __init__(self): + self.imu = self.IMU() + self.motor_state = self.MotorState() + + class IMU: + def __init__(self): + self.quaternion = [1.0, 0.0, 0.0, 0.0] # w, x, y, z + self.gyroscope = [0.0, 0.0, 0.0] + self.accelerometer = [0.0, 0.0, 0.0] + + class MotorState: + def __init__(self): + self.q = [0.0] * 32 + self.dq = [0.0] * 32 + self.ddq = [0.0] * 32 + self.tauEst = [0.0] * 32 + self.cur = [0.0] * 32 + +class STATE(Enum): + STATE_WAITING = 0 + STATE_POS_GETUP = auto() + STATE_RL_INIT = auto() + STATE_RL_RUNNING = auto() + STATE_POS_GETDOWN = auto() + STATE_RESET_SIMULATION = auto() + STATE_TOGGLE_SIMULATION = auto() + +class Control: + def __init__(self): + self.control_state = STATE.STATE_WAITING + self.x = 0.0 + self.y = 0.0 + self.yaw = 0.0 + +class ModelParams: + def __init__(self): + self.model_name = None + self.dt = None + self.decimation = None + self.num_observations = None + self.damping = None + self.stiffness = None + self.action_scale = None + self.hip_scale_reduction = None + self.hip_scale_reduction_indices = None + self.clip_actions_upper = None + self.clip_actions_lower = None + self.num_of_dofs = None + self.lin_vel_scale = None + self.ang_vel_scale = None + self.dof_pos_scale = None + self.dof_vel_scale = None + self.clip_obs = None + self.torque_limits = None + self.rl_kd = None + self.rl_kp = None + self.fixed_kp = None + self.fixed_kd = None + self.commands_scale = None + self.default_dof_pos = None + self.joint_controller_names = None + +class Observations: + def __init__(self): + self.lin_vel = None + self.ang_vel = None + self.gravity_vec = None + self.commands = None + self.base_quat = None + self.dof_pos = None + self.dof_vel = None + self.actions = None + +class RL: + # Static variables + start_state = RobotState() + now_state = RobotState() + getup_percent = 0.0 + getdown_percent = 0.0 + + def __init__(self): + ### public in cpp ### + self.params = ModelParams() + self.obs = Observations() + + self.robot_state = RobotState() + self.robot_command = RobotCommand() + + # control + self.control = Control() + + # others + self.robot_name = "" + self.running_state = STATE.STATE_RL_RUNNING # default running_state set to STATE_RL_RUNNING + self.simulation_running = False + + ### protected in cpp ### + # rl module + self.model = None + self.walk_model = None + self.stand_model = None + + # output buffer + self.output_torques = torch.zeros(1, 32) + self.output_dof_pos = torch.zeros(1, 32) + + def InitObservations(self): + self.obs.lin_vel = torch.zeros(1, 3, dtype=torch.float) + self.obs.ang_vel = torch.zeros(1, 3, dtype=torch.float) + self.obs.gravity_vec = torch.tensor([[0.0, 0.0, -1.0]]) + self.obs.commands = torch.zeros(1, 3, dtype=torch.float) + self.obs.base_quat = torch.zeros(1, 4, dtype=torch.float) + self.obs.dof_pos = self.params.default_dof_pos + self.obs.dof_vel = torch.zeros(1, self.params.num_of_dofs, dtype=torch.float) + self.obs.actions = torch.zeros(1, self.params.num_of_dofs, dtype=torch.float) + + def InitOutputs(self): + self.output_torques = torch.zeros(1, self.params.num_of_dofs, dtype=torch.float) + self.output_dof_pos = self.params.default_dof_pos + + def InitControl(self): + self.control.control_state = STATE.STATE_WAITING + self.control.x = 0.0 + self.control.y = 0.0 + self.control.yaw = 0.0 + + def ComputeTorques(self, actions): + actions_scaled = actions * self.params.action_scale + output_torques = self.params.rl_kp * (actions_scaled + self.params.default_dof_pos - self.obs.dof_pos) - self.params.rl_kd * self.obs.dof_vel + return output_torques + + def ComputePosition(self, actions): + actions_scaled = actions * self.params.action_scale + return actions_scaled + self.params.default_dof_pos + + def QuatRotateInverse(self, q, v): + shape = q.shape + q_w = q[:, -1] + q_vec = q[:, :3] + a = v * (2.0 * q_w ** 2 - 1.0).unsqueeze(-1) + b = torch.cross(q_vec, v, dim=-1) * q_w.unsqueeze(-1) * 2.0 + c = q_vec * torch.bmm(q_vec.view(shape[0], 1, 3), v.view(shape[0], 3, 1)).squeeze(-1) * 2.0 + return a - b + c + + def StateController(self, state, command): + # waiting + if self.running_state == STATE.STATE_WAITING: + for i in range(self.params.num_of_dofs): + command.motor_command.q[i] = state.motor_state.q[i] + if self.control.control_state == STATE.STATE_POS_GETUP: + self.control.control_state = STATE.STATE_WAITING + self.getup_percent = 0.0 + for i in range(self.params.num_of_dofs): + self.now_state.motor_state.q[i] = state.motor_state.q[i] + self.start_state.motor_state.q[i] = self.now_state.motor_state.q[i] + self.running_state = STATE.STATE_POS_GETUP + print("\r\n" + LOGGER.INFO + "Switching to STATE_POS_GETUP") + + # stand up (position control) + elif self.running_state == STATE.STATE_POS_GETUP: + if self.getup_percent < 1.0: + self.getup_percent += 1 / 500.0 + self.getup_percent = min(self.getup_percent, 1.0) + for i in range(self.params.num_of_dofs): + command.motor_command.q[i] = (1 - self.getup_percent) * self.now_state.motor_state.q[i] + self.getup_percent * self.params.default_dof_pos[0][i].item() + command.motor_command.dq[i] = 0 + command.motor_command.kp[i] = self.params.fixed_kp[0][i].item() + command.motor_command.kd[i] = self.params.fixed_kd[0][i].item() + command.motor_command.tau[i] = 0 + print("\r" + LOGGER.INFO + f"Getting up {self.getup_percent * 100.0:.1f}", end='', flush=True) + + if self.control.control_state == STATE.STATE_RL_INIT: + self.control.control_state = STATE.STATE_WAITING + self.running_state = STATE.STATE_RL_INIT + print("\r\n" + LOGGER.INFO + "Switching to STATE_RL_INIT") + + elif self.control.control_state == STATE.STATE_POS_GETDOWN: + self.control.control_state = STATE.STATE_WAITING + self.getdown_percent = 0.0 + for i in range(self.params.num_of_dofs): + self.now_state.motor_state.q[i] = state.motor_state.q[i] + self.running_state = STATE.STATE_POS_GETDOWN + print("\r\n" + LOGGER.INFO + "Switching to STATE_POS_GETDOWN") + + # init obs and start rl loop + elif self.running_state == STATE.STATE_RL_INIT: + if self.getup_percent == 1: + self.InitObservations() + self.InitOutputs() + self.InitControl() + self.running_state = STATE.STATE_RL_RUNNING + print("\r\n" + LOGGER.INFO + "Switching to STATE_RL_RUNNING") + + # rl loop + if self.running_state == STATE.STATE_RL_RUNNING: + print("\r" + LOGGER.INFO + f"RL Controller x: {self.control.x:.1f} y: {self.control.y:.1f} yaw: {self.control.yaw:.1f}", end='', flush=True) + for i in range(self.params.num_of_dofs): + command.motor_command.q[i] = self.output_dof_pos[0][i].item() + command.motor_command.dq[i] = 0 + command.motor_command.kp[i] = self.params.rl_kp[0][i].item() + command.motor_command.kd[i] = self.params.rl_kd[0][i].item() + command.motor_command.tau[i] = 0 + + if self.control.control_state == STATE.STATE_POS_GETDOWN: + self.control.control_state = STATE.STATE_WAITING + self.getdown_percent = 0.0 + for i in range(self.params.num_of_dofs): + self.now_state.motor_state.q[i] = state.motor_state.q[i] + self.running_state = STATE.STATE_POS_GETDOWN + print("\r\n" + LOGGER.INFO + "Switching to STATE_POS_GETDOWN") + + elif self.control.control_state == STATE.STATE_POS_GETUP: + self.control.control_state = STATE.STATE_WAITING + self.getup_percent = 0.0 + for i in range(self.params.num_of_dofs): + self.now_state.motor_state.q[i] = state.motor_state.q[i] + self.running_state = STATE.STATE_POS_GETUP + print("\r\n" + LOGGER.INFO + "Switching to STATE_POS_GETUP") + + # get down (position control) + elif self.running_state == STATE.STATE_POS_GETDOWN: + if self.getdown_percent < 1.0: + self.getdown_percent += 1 / 500.0 + self.getdown_percent = min(1.0, self.getdown_percent) + for i in range(self.params.num_of_dofs): + command.motor_command.q[i] = (1 - self.getdown_percent) * self.now_state.motor_state.q[i] + self.getdown_percent * self.start_state.motor_state.q[i] + command.motor_command.dq[i] = 0 + command.motor_command.kp[i] = self.params.fixed_kp[0][i].item() + command.motor_command.kd[i] = self.params.fixed_kd[0][i].item() + command.motor_command.tau[i] = 0 + print("\r" + LOGGER.INFO + f"Getting down {self.getdown_percent * 100.0:.1f}", end='', flush=True) + + if self.getdown_percent == 1: + self.InitObservations() + self.InitOutputs() + self.InitControl() + self.running_state = STATE.STATE_WAITING + print("\r\n" + LOGGER.INFO + "Switching to STATE_WAITING") + + def TorqueProtect(self, origin_output_torques): + out_of_range_indices = [] + out_of_range_values = [] + + for i in range(origin_output_torques.size(1)): + torque_value = origin_output_torques[0][i].item() + limit_lower = -self.params.torque_limits[0][i].item() + limit_upper = self.params.torque_limits[0][i].item() + + if torque_value < limit_lower or torque_value > limit_upper: + out_of_range_indices.append(i) + out_of_range_values.append(torque_value) + + if out_of_range_indices: + for i, index in enumerate(out_of_range_indices): + value = out_of_range_values[i] + limit_lower = -self.params.torque_limits[0][index].item() + limit_upper = self.params.torque_limits[0][index].item() + + print(LOGGER.WARNING + f"Torque({index + 1})={value} out of range({limit_lower}, {limit_upper})") + + # Just a reminder, no protection + self.control.control_state = STATE.STATE_POS_GETDOWN + print(LOGGER.INFO + "Switching to STATE_POS_GETDOWN") + + def KeyboardInterface(self, key): + try: + if hasattr(key, 'char'): + if key.char == '0': + self.control.control_state = STATE.STATE_POS_GETUP + elif key.char == 'p': + self.control.control_state = STATE.STATE_RL_INIT + elif key.char == '1': + self.control.control_state = STATE.STATE_POS_GETDOWN + elif key.char == 'w': + self.control.x += 0.1 + elif key.char == 's': + self.control.x -= 0.1 + elif key.char == 'a': + self.control.yaw += 0.1 + elif key.char == 'd': + self.control.yaw -= 0.1 + elif key.char == 'j': + self.control.y += 0.1 + elif key.char == 'l': + self.control.y -= 0.1 + elif key.char == 'r': + self.control.control_state = STATE.STATE_RESET_SIMULATION + else: + if key == keyboard.Key.enter: + self.control.control_state = STATE.STATE_TOGGLE_SIMULATION + elif key == keyboard.Key.space: + self.control.x = 0 + self.control.y = 0 + self.control.yaw = 0 + except AttributeError: + pass + + def ReadYaml(self, robot_name): + try: + with open(CONFIG_PATH, 'r') as f: + config = yaml.safe_load(f)[robot_name] + except FileNotFoundError as e: + print(LOGGER.ERROR + "The file '{CONFIG_PATH}' does not exist") + return + + self.params.model_name = config["model_name"] + self.params.dt = config["dt"] + self.params.decimation = config["decimation"] + self.params.num_observations = config["num_observations"] + self.params.clip_obs = config["clip_obs"] + self.params.action_scale = config["action_scale"] + self.params.hip_scale_reduction = config["hip_scale_reduction"] + self.params.hip_scale_reduction_indices = config["hip_scale_reduction_indices"] + self.params.clip_actions_upper = torch.tensor(config["clip_actions_upper"]).view(1, -1) + self.params.clip_actions_lower = torch.tensor(config["clip_actions_lower"]).view(1, -1) + self.params.num_of_dofs = config["num_of_dofs"] + self.params.lin_vel_scale = config["lin_vel_scale"] + self.params.ang_vel_scale = config["ang_vel_scale"] + self.params.dof_pos_scale = config["dof_pos_scale"] + self.params.dof_vel_scale = config["dof_vel_scale"] + self.params.commands_scale = torch.tensor([self.params.lin_vel_scale, self.params.lin_vel_scale, self.params.ang_vel_scale]) + self.params.rl_kp = torch.tensor(config["rl_kp"]).view(1, -1) + self.params.rl_kd = torch.tensor(config["rl_kd"]).view(1, -1) + self.params.fixed_kp = torch.tensor(config["fixed_kp"]).view(1, -1) + self.params.fixed_kd = torch.tensor(config["fixed_kd"]).view(1, -1) + self.params.torque_limits = torch.tensor(config["torque_limits"]).view(1, -1) + self.params.default_dof_pos = torch.tensor(config["default_dof_pos"]).view(1, -1) + self.params.joint_controller_names = config["joint_controller_names"] + diff --git a/src/rl_sar/scripts/rl_sim.py b/src/rl_sar/scripts/rl_sim.py new file mode 100644 index 0000000..c67419e --- /dev/null +++ b/src/rl_sar/scripts/rl_sim.py @@ -0,0 +1,234 @@ +import sys +import os +import torch +import threading +import time +import rospy +import numpy as np +from gazebo_msgs.msg import ModelStates +from sensor_msgs.msg import JointState +from geometry_msgs.msg import Twist, Pose +from robot_msgs.msg import MotorCommand +from gazebo_msgs.srv import SetModelState, SetModelStateRequest +from std_srvs.srv import Empty + +path = os.path.abspath(".") +sys.path.insert(0, path + "/src/rl_sar/scripts") +from rl_sdk import * +from observation_buffer import * + +class RL_Sim(RL): + def __init__(self): + super().__init__() + + # member variables for RL_Sim + self.vel = Twist() + self.pose = Pose() + self.cmd_vel = Twist() + + # start ros node + rospy.init_node("rl_sim") + + # read params from yaml + self.robot_name = rospy.get_param("robot_name", "") + self.ReadYaml(self.robot_name) + + # history + self.use_history = rospy.get_param("use_history", "") + if self.use_history: + self.history_obs_buf = ObservationBuffer(1, self.params.num_observations, 6) + + # Due to the fact that the robot_state_publisher sorts the joint names alphabetically, + # the mapping table is established according to the order defined in the YAML file + sorted_joint_controller_names = sorted(self.params.joint_controller_names) + self.sorted_to_original_index = {} + for i in range(len(self.params.joint_controller_names)): + self.sorted_to_original_index[sorted_joint_controller_names[i]] = i + self.mapped_joint_positions = [0.0] * self.params.num_of_dofs + self.mapped_joint_velocities = [0.0] * self.params.num_of_dofs + self.mapped_joint_efforts = [0.0] * self.params.num_of_dofs + + # init + torch.set_grad_enabled(False) + self.joint_publishers_commands = [MotorCommand() for _ in range(self.params.num_of_dofs)] + self.InitObservations() + self.InitOutputs() + self.InitControl() + + # model + model_path = os.path.join(os.path.dirname(__file__), f"../models/{self.robot_name}/{self.params.model_name}") + self.model = torch.jit.load(model_path) + + # publisher + self.ros_namespace = rospy.get_param("ros_namespace", "") + self.joint_publishers = {} + for i in range(self.params.num_of_dofs): + topic_name = f"{self.ros_namespace}{self.params.joint_controller_names[i]}/command" + self.joint_publishers[self.params.joint_controller_names[i]] = rospy.Publisher(topic_name, MotorCommand, queue_size=10) + + # subscriber + self.cmd_vel_subscriber = rospy.Subscriber("/cmd_vel", Twist, self.CmdvelCallback, queue_size=10) + self.model_state_subscriber = rospy.Subscriber("/gazebo/model_states", ModelStates, self.ModelStatesCallback, queue_size=10) + joint_states_topic = f"{self.ros_namespace}joint_states" + self.joint_state_subscriber = rospy.Subscriber(joint_states_topic, JointState, self.JointStatesCallback, queue_size=10) + + # service + self.gazebo_set_model_state_client = rospy.ServiceProxy("/gazebo/set_model_state", SetModelState) + self.gazebo_pause_physics_client = rospy.ServiceProxy("/gazebo/pause_physics", Empty) + self.gazebo_unpause_physics_client = rospy.ServiceProxy("/gazebo/unpause_physics", Empty) + + # loops + self.thread_control = threading.Thread(target=self.ThreadControl) + self.thread_rl = threading.Thread(target=self.ThreadRL) + self.thread_control.start() + self.thread_rl.start() + + # keyboard + self.listener_keyboard = keyboard.Listener(on_press=self.KeyboardInterface) + self.listener_keyboard.start() + + print(LOGGER.INFO + "RL_Sim start") + + def __del__(self): + print(LOGGER.INFO + "RL_Sim exit") + + def GetState(self, state): + state.imu.quaternion[3] = self.pose.orientation.w + state.imu.quaternion[0] = self.pose.orientation.x + state.imu.quaternion[1] = self.pose.orientation.y + state.imu.quaternion[2] = self.pose.orientation.z + + state.imu.gyroscope[0] = self.vel.angular.x + state.imu.gyroscope[1] = self.vel.angular.y + state.imu.gyroscope[2] = self.vel.angular.z + + # state.imu.accelerometer + + for i in range(self.params.num_of_dofs): + state.motor_state.q[i] = self.mapped_joint_positions[i] + state.motor_state.dq[i] = self.mapped_joint_velocities[i] + state.motor_state.tauEst[i] = self.mapped_joint_efforts[i] + + def SetCommand(self, command): + for i in range(self.params.num_of_dofs): + self.joint_publishers_commands[i].q = command.motor_command.q[i] + self.joint_publishers_commands[i].dq = command.motor_command.dq[i] + self.joint_publishers_commands[i].kp = command.motor_command.kp[i] + self.joint_publishers_commands[i].kd = command.motor_command.kd[i] + self.joint_publishers_commands[i].tau = command.motor_command.tau[i] + + for i in range(self.params.num_of_dofs): + self.joint_publishers[self.params.joint_controller_names[i]].publish(self.joint_publishers_commands[i]) + + def RobotControl(self): + if self.control.control_state == STATE.STATE_RESET_SIMULATION: + set_model_state = SetModelStateRequest().model_state + gazebo_model_name = f"{self.robot_name}_gazebo" + set_model_state.model_name = gazebo_model_name + set_model_state.pose.position.z = 1.0 + set_model_state.reference_frame = "world" + self.gazebo_set_model_state_client(set_model_state) + self.control.control_state = STATE.STATE_WAITING + if self.control.control_state == STATE.STATE_TOGGLE_SIMULATION: + if self.simulation_running: + self.gazebo_pause_physics_client() + print("\r\n" + LOGGER.INFO + "Simulation Stop") + else: + self.gazebo_unpause_physics_client() + print("\r\n" + LOGGER.INFO + "Simulation Start") + self.simulation_running = not self.simulation_running + self.control.control_state = STATE.STATE_WAITING + + if self.simulation_running: + self.GetState(self.robot_state) + self.StateController(self.robot_state, self.robot_command) + self.SetCommand(self.robot_command) + + def ModelStatesCallback(self, msg): + self.vel = msg.twist[2] + self.pose = msg.pose[2] + + def CmdvelCallback(self, msg): + self.cmd_vel = msg + + def MapData(self, source_data, target_data): + for i in range(len(source_data)): + target_data[i] = source_data[self.sorted_to_original_index[self.params.joint_controller_names[i]]] + + def JointStatesCallback(self, msg): + self.MapData(msg.position, self.mapped_joint_positions) + self.MapData(msg.velocity, self.mapped_joint_velocities) + self.MapData(msg.effort, self.mapped_joint_efforts) + + def RunModel(self): + if self.running_state == STATE.STATE_RL_RUNNING and self.simulation_running: + # self.obs.lin_vel = torch.tensor([[self.vel.linear.x, self.vel.linear.y, self.vel.linear.z]]) + self.obs.ang_vel = torch.tensor(self.robot_state.imu.gyroscope).unsqueeze(0) + # self.obs.commands = torch.tensor([[self.cmd_vel.linear.x, self.cmd_vel.linear.y, self.cmd_vel.angular.z]]) + self.obs.commands = torch.tensor([[self.control.x, self.control.y, self.control.yaw]]) + self.obs.base_quat = torch.tensor(self.robot_state.imu.quaternion).unsqueeze(0) + self.obs.dof_pos = torch.tensor(self.robot_state.motor_state.q).narrow(0, 0, self.params.num_of_dofs).unsqueeze(0) + self.obs.dof_vel = torch.tensor(self.robot_state.motor_state.dq).narrow(0, 0, self.params.num_of_dofs).unsqueeze(0) + + clamped_actions = self.Forward() + + for i in self.params.hip_scale_reduction_indices: + clamped_actions[0][i] *= self.params.hip_scale_reduction + + self.obs.actions = clamped_actions + + origin_output_torques = self.ComputeTorques(self.obs.actions) + + # self.TorqueProtect(origin_output_torques) + + self.output_torques = torch.clamp(origin_output_torques, -(self.params.torque_limits), self.params.torque_limits) + self.output_dof_pos = self.ComputePosition(self.obs.actions) + + def ComputeObservation(self): + obs = torch.cat([ + # self.obs.lin_vel * self.params.lin_vel_scale, + # self.obs.ang_vel * self.params.ang_vel_scale, # TODO is QuatRotateInverse necessery? + self.QuatRotateInverse(self.obs.base_quat, self.obs.ang_vel) * self.params.ang_vel_scale, + self.QuatRotateInverse(self.obs.base_quat, self.obs.gravity_vec), + self.obs.commands * self.params.commands_scale, + (self.obs.dof_pos - self.params.default_dof_pos) * self.params.dof_pos_scale, + self.obs.dof_vel * self.params.dof_vel_scale, + self.obs.actions + ], dim = -1) + clamped_obs = torch.clamp(obs, -self.params.clip_obs, self.params.clip_obs) + return clamped_obs + + def Forward(self): + torch.set_grad_enabled(False) + clamped_obs = self.ComputeObservation() + if self.use_history: + self.history_obs_buf.insert(clamped_obs) + history_obs = self.history_obs_buf.get_obs_vec(np.arange(6)) + actions = self.model.forward(history_obs) + else: + actions = self.model.forward(clamped_obs) + clamped_actions = torch.clamp(actions, self.params.clip_actions_lower, self.params.clip_actions_upper) + return clamped_actions + + def ThreadControl(self): + thread_period = self.params.dt + thread_name = "thread_control" + print(f"[Thread Start] named: {thread_name}, period: {thread_period * 1000:.0f}(ms), cpu unspecified") + while not rospy.is_shutdown(): + self.RobotControl() + time.sleep(thread_period) + print("[Thread End] named: " + thread_name) + + def ThreadRL(self): + thread_period = self.params.dt * self.params.decimation + thread_name = "thread_rl" + print(f"[Thread Start] named: {thread_name}, period: {thread_period * 1000:.0f}(ms), cpu unspecified") + while not rospy.is_shutdown(): + self.RunModel() + time.sleep(thread_period) + print("[Thread End] named: " + thread_name) + +if __name__ == "__main__": + rl_sim = RL_Sim() + rospy.spin() + \ No newline at end of file diff --git a/src/rl_sar/src/rl_sim.cpp b/src/rl_sar/src/rl_sim.cpp index fa9ad0f..417bcf2 100644 --- a/src/rl_sar/src/rl_sim.cpp +++ b/src/rl_sar/src/rl_sim.cpp @@ -61,13 +61,15 @@ RL_Sim::RL_Sim() this->gazebo_unpause_physics_client = nh.serviceClient("/gazebo/unpause_physics"); // loop - this->loop_keyboard = std::make_shared("loop_keyboard", 0.05, std::bind(&RL_Sim::KeyboardInterface, this)); this->loop_control = std::make_shared("loop_control", this->params.dt, std::bind(&RL_Sim::RobotControl, this)); this->loop_rl = std::make_shared("loop_rl", this->params.dt * this->params.decimation, std::bind(&RL_Sim::RunModel, this)); - this->loop_keyboard->start(); this->loop_control->start(); this->loop_rl->start(); + // keyboard + this->loop_keyboard = std::make_shared("loop_keyboard", 0.05, std::bind(&RL_Sim::KeyboardInterface, this)); + this->loop_keyboard->start(); + #ifdef PLOT this->plot_t = std::vector(this->plot_size, 0); this->plot_real_joint_pos.resize(this->params.num_of_dofs); @@ -80,6 +82,8 @@ RL_Sim::RL_Sim() #ifdef CSV_LOGGER this->CSVInit(this->robot_name); #endif + + std::cout << LOGGER::INFO << "RL_Sim start" << std::endl; } RL_Sim::~RL_Sim() @@ -150,10 +154,12 @@ void RL_Sim::RobotControl() if(simulation_running) { this->gazebo_pause_physics_client.call(empty); + std::cout << std::endl << LOGGER::INFO << "Simulation Stop" << std::endl; } else { this->gazebo_unpause_physics_client.call(empty); + std::cout << std::endl << LOGGER::INFO << "Simulation Start" << std::endl; } simulation_running = !simulation_running; this->control.control_state = STATE_WAITING; @@ -230,15 +236,16 @@ void RL_Sim::RunModel() torch::Tensor RL_Sim::ComputeObservation() { - torch::Tensor obs = torch::cat({// this->obs.lin_vel * this->params.lin_vel_scale, - this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale, - // this->obs.ang_vel * this->params.ang_vel_scale, // TODO - this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec), - this->obs.commands * this->params.commands_scale, - (this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale, - this->obs.dof_vel * this->params.dof_vel_scale, - this->obs.actions - },1); + torch::Tensor obs = torch::cat({ + // this->obs.lin_vel * this->params.lin_vel_scale, + // this->obs.ang_vel * this->params.ang_vel_scale, // TODO is QuatRotateInverse necessery? + this->QuatRotateInverse(this->obs.base_quat, this->obs.ang_vel) * this->params.ang_vel_scale, + this->QuatRotateInverse(this->obs.base_quat, this->obs.gravity_vec), + this->obs.commands * this->params.commands_scale, + (this->obs.dof_pos - this->params.default_dof_pos) * this->params.dof_pos_scale, + this->obs.dof_vel * this->params.dof_vel_scale, + this->obs.actions + },1); torch::Tensor clamped_obs = torch::clamp(obs, -this->params.clip_obs, this->params.clip_obs); return clamped_obs; } @@ -246,11 +253,8 @@ torch::Tensor RL_Sim::ComputeObservation() torch::Tensor RL_Sim::Forward() { torch::autograd::GradMode::set_enabled(false); - torch::Tensor clamped_obs = this->ComputeObservation(); - torch::Tensor actions; - if(this->use_history) { this->history_obs_buf.insert(clamped_obs); @@ -260,10 +264,9 @@ torch::Tensor RL_Sim::Forward() else { actions = this->model.forward({clamped_obs}).toTensor(); - } + } torch::Tensor clamped_actions = torch::clamp(actions, this->params.clip_actions_lower, this->params.clip_actions_upper); - return clamped_actions; } @@ -297,12 +300,8 @@ void signalHandler(int signum) int main(int argc, char **argv) { signal(SIGINT, signalHandler); - ros::init(argc, argv, "rl_sar"); - RL_Sim rl_sar; - ros::spin(); - return 0; }