parkour/onboard_codes/go1/a1_visual_embedding.py

318 lines
12 KiB
Python
Raw Normal View History

import os
import os.path as osp
import numpy as np
import torch
import json
from functools import partial
from collections import OrderedDict
from a1_real import UnitreeA1Real, resize2d
from rsl_rl import modules
import rospy
from unitree_legged_msgs.msg import Float32MultiArrayStamped
from sensor_msgs.msg import Image
import ros_numpy
import pyrealsense2 as rs
def get_encoder_script(logdir):
with open(osp.join(logdir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
model_device = torch.device("cuda")
unitree_real_env = UnitreeA1Real(
robot_namespace= "a112138",
cfg= config_dict,
forward_depth_topic= "", # this env only computes parameters to build the model
forward_depth_embedding_dims= None,
model_device= model_device,
)
model = getattr(modules, config_dict["runner"]["policy_class_name"])(
num_actor_obs= unitree_real_env.num_obs,
num_critic_obs= unitree_real_env.num_privileged_obs,
num_actions= 12,
obs_segments= unitree_real_env.obs_segments,
privileged_obs_segments= unitree_real_env.privileged_obs_segments,
**config_dict["policy"],
)
model_names = [i for i in os.listdir(logdir) if i.startswith("model_")]
model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
state_dict = torch.load(osp.join(args.logdir, model_names[-1]), map_location= "cpu")
model.load_state_dict(state_dict["model_state_dict"])
model.to(model_device)
model.eval()
visual_encoder = model.visual_encoder
script = torch.jit.script(visual_encoder)
return script, model_device
def get_input_filter(args):
""" This is the filter different from the simulator, but try to close the gap. """
with open(osp.join(args.logdir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
image_resolution = config_dict["sensor"]["forward_camera"].get(
"output_resolution",
config_dict["sensor"]["forward_camera"]["resolution"],
)
depth_range = config_dict["sensor"]["forward_camera"].get(
"depth_range",
[0.0, 3.0],
)
depth_range = (depth_range[0] * 1000, depth_range[1] * 1000) # [m] -> [mm]
crop_top, crop_bottom, crop_left, crop_right = args.crop_top, args.crop_bottom, args.crop_left, args.crop_right
crop_far = args.crop_far * 1000
def input_filter(depth_image: torch.Tensor,
crop_top: int,
crop_bottom: int,
crop_left: int,
crop_right: int,
crop_far: float,
depth_min: int,
depth_max: int,
output_height: int,
output_width: int,
):
""" depth_image must have shape [1, 1, H, W] """
depth_image = depth_image[:, :,
crop_top: -crop_bottom-1,
crop_left: -crop_right-1,
]
depth_image[depth_image > crop_far] = depth_max
depth_image = torch.clip(
depth_image,
depth_min,
depth_max,
) / (depth_max - depth_min)
depth_image = resize2d(depth_image, (output_height, output_width))
return depth_image
# input_filter = torch.jit.script(input_filter)
return partial(input_filter,
crop_top= crop_top,
crop_bottom= crop_bottom,
crop_left= crop_left,
crop_right= crop_right,
crop_far= crop_far,
depth_min= depth_range[0],
depth_max= depth_range[1],
output_height= image_resolution[0],
output_width= image_resolution[1],
), depth_range
def get_started_pipeline(
height= 480,
width= 640,
fps= 30,
enable_rgb= False,
):
# By default, rgb is not used.
pipeline = rs.pipeline()
config = rs.config()
config.enable_stream(rs.stream.depth, width, height, rs.format.z16, fps)
if enable_rgb:
config.enable_stream(rs.stream.color, width, height, rs.format.rgb8, fps)
profile = pipeline.start(config)
# build the sensor filter
hole_filling_filter = rs.hole_filling_filter(2)
spatial_filter = rs.spatial_filter()
spatial_filter.set_option(rs.option.filter_magnitude, 5)
spatial_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
spatial_filter.set_option(rs.option.filter_smooth_delta, 1)
spatial_filter.set_option(rs.option.holes_fill, 4)
temporal_filter = rs.temporal_filter()
temporal_filter.set_option(rs.option.filter_smooth_alpha, 0.75)
temporal_filter.set_option(rs.option.filter_smooth_delta, 1)
# decimation_filter = rs.decimation_filter()
# decimation_filter.set_option(rs.option.filter_magnitude, 2)
def filter_func(frame):
frame = hole_filling_filter.process(frame)
frame = spatial_filter.process(frame)
frame = temporal_filter.process(frame)
# frame = decimation_filter.process(frame)
return frame
return pipeline, filter_func
def main(args):
rospy.init_node("a1_legged_gym_jetson")
input_filter, depth_range = get_input_filter(args)
model_script, model_device = get_encoder_script(args.logdir)
with open(osp.join(args.logdir, "config.json"), "r") as f:
config_dict = json.load(f, object_pairs_hook= OrderedDict)
if config_dict.get("sensor", dict()).get("forward_camera", dict()).get("refresh_duration", None) is not None:
refresh_duration = config_dict["sensor"]["forward_camera"]["refresh_duration"]
ros_rate = rospy.Rate(1.0 / refresh_duration)
rospy.loginfo("Using refresh duration {}s".format(refresh_duration))
else:
ros_rate = rospy.Rate(args.fps)
rs_pipeline, rs_filters = get_started_pipeline(
height= args.height,
width= args.width,
fps= args.fps,
enable_rgb= args.enable_rgb,
)
# gyro_pipeline = rs.pipeline()
# gyro_config = rs.config()
# gyro_config.enable_stream(rs.stream.gyro, rs.format.motion_xyz32f, 200)
# gyro_profile = gyro_pipeline.start(gyro_config)
embedding_publisher = rospy.Publisher(
args.namespace + "/visual_embedding",
Float32MultiArrayStamped,
queue_size= 1,
)
if args.enable_vis:
depth_image_publisher = rospy.Publisher(
args.namespace + "/camera/depth/image_rect_raw",
Image,
queue_size= 1,
)
network_input_publisher = rospy.Publisher(
args.namespace + "/camera/depth/network_input_raw",
Image,
queue_size= 1,
)
if args.enable_rgb:
rgb_image_publisher = rospy.Publisher(
args.namespace + "/camera/color/image_raw",
Image,
queue_size= 1,
)
rospy.loginfo("Depth range is clipped to [{}, {}] and normalized".format(depth_range[0], depth_range[1]))
rospy.loginfo("ROS, model, realsense have been initialized.")
if args.enable_vis:
rospy.loginfo("Visualization enabled, sending depth{} images".format(", rgb" if args.enable_rgb else ""))
try:
embedding_msg = Float32MultiArrayStamped()
embedding_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
frame_got = False
while not rospy.is_shutdown():
# Wait for the depth image
frames = rs_pipeline.wait_for_frames(int( \
config_dict["sensor"]["forward_camera"]["latency_range"][1] \
* 1000)) # ms
embedding_msg.header.stamp = rospy.Time.now()
depth_frame = frames.get_depth_frame()
if not depth_frame:
continue
if not frame_got:
frame_got = True
rospy.loginfo("Realsense frame recieved. Sending embeddings...")
if args.enable_rgb:
color_frame = frames.get_color_frame()
# Use this branch to log the time when image is acquired
if args.enable_vis and not color_frame is None:
color_frame = np.asanyarray(color_frame.get_data())
rgb_image_msg = ros_numpy.msgify(Image, color_frame, encoding= "rgb8")
rgb_image_msg.header.stamp = rospy.Time.now()
rgb_image_msg.header.frame_id = args.namespace + "/camera_color_optical_frame"
rgb_image_publisher.publish(rgb_image_msg)
# Process the depth image and publish
depth_frame = rs_filters(depth_frame)
depth_image_ = np.asanyarray(depth_frame.get_data())
depth_image = torch.from_numpy(depth_image_.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(model_device)
depth_image = input_filter(depth_image)
with torch.no_grad():
depth_embedding = model_script(depth_image).reshape(-1).cpu().numpy()
embedding_msg.header.seq += 1
embedding_msg.data = depth_embedding.tolist()
embedding_publisher.publish(embedding_msg)
# Publish the acquired image if needed
if args.enable_vis:
depth_image_msg = ros_numpy.msgify(Image, depth_image_, encoding= "16UC1")
depth_image_msg.header.stamp = rospy.Time.now()
depth_image_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
depth_image_publisher.publish(depth_image_msg)
network_input_np = (\
depth_image.detach().cpu().numpy()[0, 0] * (depth_range[1] - depth_range[0]) \
+ depth_range[0]
).astype(np.uint16)
network_input_msg = ros_numpy.msgify(Image, network_input_np, encoding= "16UC1")
network_input_msg.header.stamp = rospy.Time.now()
network_input_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
network_input_publisher.publish(network_input_msg)
ros_rate.sleep()
finally:
rs_pipeline.stop()
if __name__ == "__main__":
""" This script is designed to load the model and process the realsense image directly
from realsense SDK without realsense ROS wrapper
"""
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--namespace",
type= str,
default= "/a112138",
)
parser.add_argument("--logdir",
type= str,
help= "The log directory of the trained model",
)
parser.add_argument("--height",
type= int,
default= 240,
help= "The height of the realsense image",
)
parser.add_argument("--width",
type= int,
default= 424,
help= "The width of the realsense image",
)
parser.add_argument("--fps",
type= int,
default= 30,
help= "The fps of the realsense image",
)
parser.add_argument("--crop_left",
type= int,
default= 60,
help= "num of pixel to crop in the original pyrealsense readings."
)
parser.add_argument("--crop_right",
type= int,
default= 46,
help= "num of pixel to crop in the original pyrealsense readings."
)
parser.add_argument("--crop_top",
type= int,
default= 0,
help= "num of pixel to crop in the original pyrealsense readings."
)
parser.add_argument("--crop_bottom",
type= int,
default= 0,
help= "num of pixel to crop in the original pyrealsense readings."
)
parser.add_argument("--crop_far",
type= float,
default= 3.0,
help= "asside from the config far limit, make all depth readings larger than this value to be 3.0 in un-normalized network input."
)
parser.add_argument("--enable_rgb",
action= "store_true",
help= "Whether to enable rgb image",
)
parser.add_argument("--enable_vis",
action= "store_true",
help= "Whether to publish realsense image",
)
args = parser.parse_args()
main(args)