Lidar and RL highlevel controller added to sim

This commit is contained in:
Rooholla-KhorramBakht 2024-05-23 14:06:10 -04:00
parent 933f189952
commit 61ce7244a5
5 changed files with 460 additions and 2 deletions

View File

@ -6,9 +6,19 @@ from Go2Py import ASSETS_PATH
import os import os
from scipy.spatial.transform import Rotation from scipy.spatial.transform import Rotation
pnt = np.array([-0.2, 0, 0.05])
lidar_angles = np.linspace(0.0, 2 * np.pi, 1024).reshape(-1, 1)
x_vec = np.cos(lidar_angles)
y_vec = np.sin(lidar_angles)
z_vec = np.zeros_like(x_vec)
vec = np.concatenate([x_vec, y_vec, z_vec], axis=1)
nray = vec.shape[0]
geomid = np.zeros(nray, np.int32)
dist = np.zeros(nray, np.float64)
class Go2Sim: class Go2Sim:
def __init__(self, render=True, dt=0.002): def __init__(self, mode='lowlevel', render=True, dt=0.002):
self.model = mujoco.MjModel.from_xml_path( self.model = mujoco.MjModel.from_xml_path(
os.path.join(ASSETS_PATH, 'mujoco/go2.xml') os.path.join(ASSETS_PATH, 'mujoco/go2.xml')
@ -68,6 +78,22 @@ class Go2Sim:
self.kv = np.zeros(12) self.kv = np.zeros(12)
self.latest_command_stamp = time.time() self.latest_command_stamp = time.time()
self.actuator_tau = np.zeros(12) self.actuator_tau = np.zeros(12)
self.mode = mode
if self.mode == 'highlevel':
from Go2Py.control.walk_these_ways import CommandInterface, loadParameters, Policy, WalkTheseWaysAgent, HistoryWrapper
checkpoint_path = os.path.join(ASSETS_PATH,'checkpoints/walk_these_ways')
self.cfg = loadParameters(checkpoint_path)
self.policy = Policy(checkpoint_path)
self.command_profile = CommandInterface()
self.agent = WalkTheseWaysAgent(self.cfg, self.command_profile, robot=self)
self.agent = HistoryWrapper(self.agent)
self.control_dt = self.cfg["control"]["decimation"] * self.cfg["sim"]["dt"]
self.obs = self.agent.reset()
self.standUpReset()
self.step_counter = 0
self.step = self.stepHighlevel
else:
self.step = self.stepLowlevel
def reset(self): def reset(self):
self.q_nominal = np.hstack( self.q_nominal = np.hstack(
@ -117,7 +143,7 @@ class Go2Sim:
self.tau_ff = tau_ff self.tau_ff = tau_ff
self.latest_command_stamp = time.time() self.latest_command_stamp = time.time()
def step(self): def stepLowlevel(self):
state = self.getJointStates() state = self.getJointStates()
q, dq = state['q'], state['dq'] q, dq = state['q'], state['dq']
tau = np.diag(self.kp) @ (self.q_des - q).reshape(12, 1) + \ tau = np.diag(self.kp) @ (self.q_des - q).reshape(12, 1) + \
@ -131,6 +157,19 @@ class Go2Sim:
if self.render and (self.step_counter % self.render_ds_ratio) == 0: if self.render and (self.step_counter % self.render_ds_ratio) == 0:
self.viewer.sync() self.viewer.sync()
def stepHighlevel(self, vx, vy, omega_z, body_z_offset=0):
policy_info = {}
if self.step_counter % (self.control_dt // self.dt) == 0:
action = self.policy(self.obs, policy_info)
self.obs, ret, done, info = self.agent.step(action)
self.step_counter+=1
self.stepLowlevel()
self.command_profile.yaw_vel_cmd = omega_z
self.command_profile.x_vel_cmd = vx
self.command_profile.y_vel_cmd = vy
self.command_profile.body_height_cmd = body_z_offset
def getSiteJacobian(self, site_name): def getSiteJacobian(self, site_name):
id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_SITE, site_name) id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_SITE, site_name)
assert id > 0, 'The requested site could not be found' assert id > 0, 'The requested site could not be found'
@ -151,6 +190,28 @@ class Go2Sim:
g_in_body = R.T @ np.array([0.0, 0.0, -1.0]).reshape(3, 1) g_in_body = R.T @ np.array([0.0, 0.0, -1.0]).reshape(3, 1)
return g_in_body return g_in_body
def getLidarData(self):
t, q = self.getPose()
world_R_body = Rotation.from_quat([q[1], q[2], q[3], q[0]]).as_matrix()
pnt = t.copy()
pnt[2]+=0.25
vec_in_w = (world_R_body@vec.T).T
mujoco.mj_multiRay(
m=self.model,
d=self.data,
pnt=pnt,
vec=vec.flatten(),
geomgroup=None,
flg_static=1,
bodyexclude=-1,
geomid=geomid,
dist=dist,
nray=nray,
cutoff=mujoco.mjMAXVAL,
)
pcd = dist.reshape(-1, 1) * vec
return {"pcd": pcd, "geomid": geomid, "dist": dist}
def overheat(self): def overheat(self):
return False return False

0
Go2Py/utils/__init__.py Normal file
View File

90
Go2Py/utils/ros2.py Normal file
View File

@ -0,0 +1,90 @@
import struct
import threading
import time
import numpy as np
import rclpy
import tf2_ros
from rclpy.node import Node
from rclpy.qos import QoSProfile
from rclpy.executors import MultiThreadedExecutor
from geometry_msgs.msg import TransformStamped
from scipy.spatial.transform import Rotation as R
def ros2_init(args=None):
rclpy.init(args=args)
def ros2_close():
rclpy.shutdown()
class ROS2ExecutorManager:
"""A class to manage the ROS2 executor. It allows to add nodes and start the executor in a separate thread."""
def __init__(self):
self.executor = MultiThreadedExecutor()
self.nodes = []
self.executor_thread = None
def add_node(self, node: Node):
"""Add a new node to the executor."""
self.nodes.append(node)
self.executor.add_node(node)
def _run_executor(self):
try:
self.executor.spin()
except KeyboardInterrupt:
pass
finally:
self.terminate()
def start(self):
"""Start spinning the nodes in a separate thread."""
self.executor_thread = threading.Thread(target=self._run_executor)
self.executor_thread.start()
def terminate(self):
"""Terminate all nodes and shutdown rclpy."""
for node in self.nodes:
node.destroy_node()
rclpy.shutdown()
if self.executor_thread:
self.executor_thread.join()
class ROS2TFInterface(Node):
def __init__(self, parent_name, child_name, node_name):
super().__init__(f'{node_name}_tf2_listener')
self.parent_name = parent_name
self.child_name = child_name
self.tfBuffer = tf2_ros.Buffer()
self.listener = tf2_ros.TransformListener(self.tfBuffer, self)
self.T = None
self.stamp = None
self.running = True
self.thread = threading.Thread(target=self.update_loop)
self.thread.start()
self.trans = None
def update_loop(self):
while self.running:
try:
self.trans = self.tfBuffer.lookup_transform(self.parent_name, self.child_name, rclpy.time.Time(), rclpy.time.Duration(seconds=0.1))
except (tf2_ros.LookupException, tf2_ros.ConnectivityException, tf2_ros.ExtrapolationException) as e:
print(e)
time.sleep(0.01)
def get_pose(self):
if self.trans is None:
return None
else:
translation = [self.trans.transform.translation.x, self.trans.transform.translation.y, self.trans.transform.translation.z]
rotation = [self.trans.transform.rotation.x, self.trans.transform.rotation.y, self.trans.transform.rotation.z, self.trans.transform.rotation.w]
self.T = np.eye(4)
self.T[0:3, 0:3] = R.from_quat(rotation).as_matrix()
self.T[:3, 3] = translation
self.stamp = self.trans.header.stamp.nanosec * 1e-9 + self.trans.header.stamp.sec
return self.T
def close(self):
self.running = False
self.thread.join()
self.destroy_node()

View File

@ -0,0 +1,31 @@
from launch import LaunchDescription
from launch.actions import IncludeLaunchDescription
from launch.launch_description_sources import PythonLaunchDescriptionSource
from launch.substitutions import ThisLaunchFileDir
from launch_ros.actions import Node
def generate_launch_description():
return LaunchDescription([
Node(
name='go2_cam',
namespace='go2/cam',
package='realsense2_camera',
executable='realsense2_camera_node',
parameters=[{
'enable_infra1': True,
'enable_infra2': True,
'enable_color': True,
'enable_depth': False,
'depth_module.emitter_enabled': 0,
'rgb_camera.profile':'640x480x30',
'depth_module.profile': '640x480x30',
'enable_gyro': True,
'enable_accel': True,
'gyro_fps': 400,
'accel_fps': 200,
'unite_imu_method': 2,
# 'tf_publish_rate': 0.0
}]
)
])

File diff suppressed because one or more lines are too long