# parkour/onboard_codes/a1_visual_embedding.py
#
# 318 lines
# 12 KiB
# Python

import os
import os.path as osp
import numpy as np
import torch
import json
from functools import partial
from collections import OrderedDict
from a1_real import UnitreeA1Real, resize2d
from rsl_rl import modules
import rospy
from unitree_legged_msgs.msg import Float32MultiArrayStamped
from sensor_msgs.msg import Image
import ros_numpy
import pyrealsense2 as rs
def get_encoder_script(logdir):
    """Load the trained policy from ``logdir`` and return its visual encoder as TorchScript.

    The function reads ``config.json`` in ``logdir``, builds the policy class named by
    ``config["runner"]["policy_class_name"]``, loads the ``model_*`` checkpoint whose
    trailing number is the largest, and scripts ``model.visual_encoder``.

    Args:
        logdir: directory containing ``config.json`` and the ``model_<N>.<ext>`` checkpoints.

    Returns:
        ``(script, model_device)``: the scripted visual encoder and the CUDA device
        the model was moved to.

    Raises:
        FileNotFoundError: if ``logdir`` contains no file whose name starts with ``model_``.
    """
    with open(osp.join(logdir, "config.json"), "r") as f:
        config_dict = json.load(f, object_pairs_hook= OrderedDict)
    model_device = torch.device("cuda")

    unitree_real_env = UnitreeA1Real(
        robot_namespace= "a112138",
        cfg= config_dict,
        forward_depth_topic= "", # this env only computes parameters to build the model
        forward_depth_embedding_dims= None,
        model_device= model_device,
    )
    policy_class = getattr(modules, config_dict["runner"]["policy_class_name"])
    model = policy_class(
        num_actor_obs= unitree_real_env.num_obs,
        num_critic_obs= unitree_real_env.num_privileged_obs,
        num_actions= 12,
        obs_segments= unitree_real_env.obs_segments,
        privileged_obs_segments= unitree_real_env.privileged_obs_segments,
        **config_dict["policy"],
    )

    model_names = [i for i in os.listdir(logdir) if i.startswith("model_")]
    if not model_names:
        raise FileNotFoundError('No "model_*" checkpoint found in {}'.format(logdir))
    # Sort by the integer between the last "_" and the first "." so the newest checkpoint is last.
    model_names.sort(key= lambda x: int(x.split("_")[-1].split(".")[0]))
    # Use the `logdir` parameter here; the previous code read the global `args.logdir`,
    # which only exists when this file is executed as a script.
    state_dict = torch.load(osp.join(logdir, model_names[-1]), map_location= "cpu")
    model.load_state_dict(state_dict["model_state_dict"])
    model.to(model_device)
    model.eval()

    visual_encoder = model.visual_encoder
    script = torch.jit.script(visual_encoder)
    return script, model_device
def get_input_filter(args):
    """Build the depth-image preprocessing callable used on the real robot.

    The preprocessing is not the same as the one in the simulator; it is meant to
    narrow the gap between real sensor readings and simulated depth images.

    Args:
        args: parsed command-line namespace; ``logdir``, ``crop_top``, ``crop_bottom``,
            ``crop_left``, ``crop_right`` and ``crop_far`` are read from it.

    Returns:
        ``(filter_fn, depth_range)`` where ``filter_fn`` is a ``functools.partial`` that
        takes the depth image as its only remaining positional argument, and
        ``depth_range`` is the ``(min, max)`` clipping range scaled by 1000
        ([m] -> [mm]).
    """
    config_path = osp.join(args.logdir, "config.json")
    with open(config_path, "r") as f:
        config_dict = json.load(f, object_pairs_hook= OrderedDict)
    camera_cfg = config_dict["sensor"]["forward_camera"]

    # "resolution" is looked up unconditionally; it is the fallback for "output_resolution".
    fallback_resolution = camera_cfg["resolution"]
    image_resolution = camera_cfg.get("output_resolution", fallback_resolution)

    range_in_meters = camera_cfg.get("depth_range", [0.0, 3.0])
    depth_range = (range_in_meters[0] * 1000, range_in_meters[1] * 1000) # [m] -> [mm]

    def input_filter(depth_image: torch.Tensor,
            crop_top: int,
            crop_bottom: int,
            crop_left: int,
            crop_right: int,
            crop_far: float,
            depth_min: int,
            depth_max: int,
            output_height: int,
            output_width: int,
        ):
        """ depth_image must have shape [1, 1, H, W] """
        # The stop index is `-crop - 1`, so one trailing row/column is dropped
        # even when the corresponding crop amount is 0.
        row_stop = -crop_bottom-1
        col_stop = -crop_right-1
        cropped = depth_image[:, :, crop_top: row_stop, crop_left: col_stop]
        # Readings beyond `crop_far` are overwritten in place with the far limit.
        too_far = cropped > crop_far
        cropped[too_far] = depth_max
        clipped = torch.clip(cropped, depth_min, depth_max)
        normalized = clipped / (depth_max - depth_min)
        return resize2d(normalized, (output_height, output_width))

    # NOTE: the filter is kept as a plain Python function (it is not torch.jit.script-ed).
    bound_kwargs = dict(
        crop_top= args.crop_top,
        crop_bottom= args.crop_bottom,
        crop_left= args.crop_left,
        crop_right= args.crop_right,
        crop_far= args.crop_far * 1000,
        depth_min= depth_range[0],
        depth_max= depth_range[1],
        output_height= image_resolution[0],
        output_width= image_resolution[1],
    )
    return partial(input_filter, **bound_kwargs), depth_range
def get_started_pipeline(
        height= 480,
        width= 640,
        fps= 30,
        enable_rgb= False,
    ):
    """Start a RealSense pipeline and build the depth post-processing chain.

    Args:
        height: requested stream height in pixels.
        width: requested stream width in pixels.
        fps: requested stream frame rate.
        enable_rgb: also enable the color stream (``rgb8``) when True; off by default.

    Returns:
        ``(pipeline, filter_func)``: the already started ``rs.pipeline`` and a callable
        that runs a frame through the hole-filling, spatial and temporal filters,
        in that order.
    """
    stream_config = rs.config()
    stream_config.enable_stream(rs.stream.depth, width, height, rs.format.z16, fps)
    if enable_rgb:
        stream_config.enable_stream(rs.stream.color, width, height, rs.format.rgb8, fps)
    pipeline = rs.pipeline()
    pipeline.start(stream_config)

    # Build the sensor filters; each one is configured right after it is created.
    hole_filling = rs.hole_filling_filter(2)

    spatial = rs.spatial_filter()
    spatial_options = (
        (rs.option.filter_magnitude, 5),
        (rs.option.filter_smooth_alpha, 0.75),
        (rs.option.filter_smooth_delta, 1),
        (rs.option.holes_fill, 4),
    )
    for option, value in spatial_options:
        spatial.set_option(option, value)

    temporal = rs.temporal_filter()
    temporal_options = (
        (rs.option.filter_smooth_alpha, 0.75),
        (rs.option.filter_smooth_delta, 1),
    )
    for option, value in temporal_options:
        temporal.set_option(option, value)

    # NOTE: a decimation filter (magnitude 2) is intentionally not part of the chain.
    filter_chain = (hole_filling, spatial, temporal)

    def filter_func(frame):
        for stage in filter_chain:
            frame = stage.process(frame)
        return frame

    return pipeline, filter_func
def main(args):
    """Run the ROS node that turns RealSense depth frames into visual embeddings.

    For every depth frame the node applies the RealSense post-processing filters, the
    input filter from ``get_input_filter`` and the scripted visual encoder from
    ``get_encoder_script``, then publishes the flattened embedding on
    ``<namespace>/visual_embedding``. With ``args.enable_vis`` the raw depth image and
    the de-normalized network input are published too, and with ``args.enable_rgb``
    as well, the color image.

    Args:
        args: parsed command-line namespace; ``logdir``, ``namespace``, ``height``,
            ``width``, ``fps``, ``enable_rgb`` and ``enable_vis`` are read here, and
            the whole namespace is forwarded to ``get_input_filter``.
    """
    rospy.init_node("a1_legged_gym_jetson")
    input_filter, depth_range = get_input_filter(args)
    model_script, model_device = get_encoder_script(args.logdir)
    with open(osp.join(args.logdir, "config.json"), "r") as f:
        config_dict = json.load(f, object_pairs_hook= OrderedDict)

    # Loop rate: use the camera refresh duration from the config when present,
    # otherwise fall back to the requested camera fps.
    refresh_duration = config_dict.get("sensor", dict()).get("forward_camera", dict()).get("refresh_duration", None)
    if refresh_duration is not None:
        ros_rate = rospy.Rate(1.0 / refresh_duration)
        rospy.loginfo("Using refresh duration {}s".format(refresh_duration))
    else:
        ros_rate = rospy.Rate(args.fps)

    rs_pipeline, rs_filters = get_started_pipeline(
        height= args.height,
        width= args.width,
        fps= args.fps,
        enable_rgb= args.enable_rgb,
    )
    # Everything after the pipeline has started runs inside try/finally so that the
    # camera is released even if publisher setup (not only the loop) raises.
    try:
        embedding_publisher = rospy.Publisher(
            args.namespace + "/visual_embedding",
            Float32MultiArrayStamped,
            queue_size= 1,
        )
        if args.enable_vis:
            depth_image_publisher = rospy.Publisher(
                args.namespace + "/camera/depth/image_rect_raw",
                Image,
                queue_size= 1,
            )
            network_input_publisher = rospy.Publisher(
                args.namespace + "/camera/depth/network_input_raw",
                Image,
                queue_size= 1,
            )
            if args.enable_rgb:
                rgb_image_publisher = rospy.Publisher(
                    args.namespace + "/camera/color/image_raw",
                    Image,
                    queue_size= 1,
                )

        rospy.loginfo("Depth range is clipped to [{}, {}] and normalized".format(depth_range[0], depth_range[1]))
        rospy.loginfo("ROS, model, realsense have been initialized.")
        if args.enable_vis:
            rospy.loginfo("Visualization enabled, sending depth{} images".format(", rgb" if args.enable_rgb else ""))

        embedding_msg = Float32MultiArrayStamped()
        embedding_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
        frame_got = False
        # Loop-invariant: wait at most the upper latency bound from the config (s -> ms).
        wait_timeout_ms = int(config_dict["sensor"]["forward_camera"]["latency_range"][1] * 1000)
        while not rospy.is_shutdown():
            # Wait for the depth image
            # NOTE(review): presumably raises if no frame arrives within the timeout - confirm.
            frames = rs_pipeline.wait_for_frames(wait_timeout_ms)
            embedding_msg.header.stamp = rospy.Time.now()
            depth_frame = frames.get_depth_frame()
            if not depth_frame:
                continue
            if not frame_got:
                frame_got = True
                rospy.loginfo("Realsense frame recieved. Sending embeddings...")

            if args.enable_rgb:
                color_frame = frames.get_color_frame()
                # Use this branch to log the time when image is acquired.
                # A missing frame is tested by truthiness, the same way as the depth
                # frame above; an `is None` test does not catch an empty frame object.
                if args.enable_vis and color_frame:
                    color_frame = np.asanyarray(color_frame.get_data())
                    rgb_image_msg = ros_numpy.msgify(Image, color_frame, encoding= "rgb8")
                    rgb_image_msg.header.stamp = rospy.Time.now()
                    rgb_image_msg.header.frame_id = args.namespace + "/camera_color_optical_frame"
                    rgb_image_publisher.publish(rgb_image_msg)

            # Process the depth image and publish
            depth_frame = rs_filters(depth_frame)
            depth_image_ = np.asanyarray(depth_frame.get_data())
            # astype() copies, so the in-place edits done by the input filter do not
            # touch `depth_image_`, which is published unmodified below.
            depth_image = torch.from_numpy(depth_image_.astype(np.float32)).unsqueeze(0).unsqueeze(0).to(model_device)
            depth_image = input_filter(depth_image)
            with torch.no_grad():
                depth_embedding = model_script(depth_image).reshape(-1).cpu().numpy()
            embedding_msg.header.seq += 1
            embedding_msg.data = depth_embedding.tolist()
            embedding_publisher.publish(embedding_msg)

            # Publish the acquired image if needed
            if args.enable_vis:
                depth_image_msg = ros_numpy.msgify(Image, depth_image_, encoding= "16UC1")
                depth_image_msg.header.stamp = rospy.Time.now()
                depth_image_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
                depth_image_publisher.publish(depth_image_msg)

                # Scale the normalized network input back to the clipping range for display.
                # NOTE(review): `depth_range[0]` is added here; confirm this matches the
                # normalization done by the input filter when the lower bound is not 0.
                network_input_np = (
                    depth_image.detach().cpu().numpy()[0, 0] * (depth_range[1] - depth_range[0])
                    + depth_range[0]
                ).astype(np.uint16)
                network_input_msg = ros_numpy.msgify(Image, network_input_np, encoding= "16UC1")
                network_input_msg.header.stamp = rospy.Time.now()
                network_input_msg.header.frame_id = args.namespace + "/camera_depth_optical_frame"
                network_input_publisher.publish(network_input_msg)

            ros_rate.sleep()
    finally:
        rs_pipeline.stop()
if __name__ == "__main__":
    # Script entry point: parse the command line and run the embedding node.
    import argparse
    parser = argparse.ArgumentParser(
        description= "This script is designed to load the model and process the realsense image directly "
            "from realsense SDK without realsense ROS wrapper",
    )
    parser.add_argument("--namespace",
        type= str,
        default= "/a112138",
    )
    # The log directory is mandatory: every stage reads `config.json` from it, and
    # without it the run would only fail later when the path is joined.
    parser.add_argument("--logdir",
        type= str,
        required= True,
        help= "The log directory of the trained model",
    )
    parser.add_argument("--height",
        type= int,
        default= 240,
        help= "The height of the realsense image",
    )
    parser.add_argument("--width",
        type= int,
        default= 424,
        help= "The width of the realsense image",
    )
    parser.add_argument("--fps",
        type= int,
        default= 30,
        help= "The fps of the realsense image",
    )
    parser.add_argument("--crop_left",
        type= int,
        default= 60,
        help= "num of pixel to crop in the original pyrealsense readings."
    )
    parser.add_argument("--crop_right",
        type= int,
        default= 46,
        help= "num of pixel to crop in the original pyrealsense readings."
    )
    parser.add_argument("--crop_top",
        type= int,
        default= 0,
        help= "num of pixel to crop in the original pyrealsense readings."
    )
    parser.add_argument("--crop_bottom",
        type= int,
        default= 0,
        help= "num of pixel to crop in the original pyrealsense readings."
    )
    parser.add_argument("--crop_far",
        type= float,
        default= 3.0,
        help= "asside from the config far limit, make all depth readings larger than this value to be 3.0 in un-normalized network input."
    )
    parser.add_argument("--enable_rgb",
        action= "store_true",
        help= "Whether to enable rgb image",
    )
    parser.add_argument("--enable_vis",
        action= "store_true",
        help= "Whether to publish realsense image",
    )
    args = parser.parse_args()
    main(args)