Add ONNX inference.

This commit is contained in:
jogima-cyber 2024-10-05 00:47:37 +00:00
parent 124617ec54
commit 94cf0da3ba
2 changed files with 44 additions and 6 deletions

View File

@ -43,6 +43,7 @@ RUN pip3 install warp-lang scikit-learn casadi mujoco pin
RUN pip install jupyter ipykernel RUN pip install jupyter ipykernel
RUN pip install cyclonedds pygame RUN pip install cyclonedds pygame
RUN pip install pynput pygame RUN pip install pynput pygame
RUN pip install onnx onnxruntime
# Set environmental variables required for using ROS # Set environmental variables required for using ROS
RUN echo 'source /opt/ros/humble/setup.bash' >> ~/.bashrc RUN echo 'source /opt/ros/humble/setup.bash' >> ~/.bashrc

View File

@ -6,6 +6,8 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import onnxruntime
class RunningMeanStd(nn.Module): class RunningMeanStd(nn.Module):
def __init__(self, shape = (), epsilon=1e-08): def __init__(self, shape = (), epsilon=1e-08):
super(RunningMeanStd, self).__init__() super(RunningMeanStd, self).__init__()
@ -88,16 +90,44 @@ class Agent(nn.Module):
action_mean = self.actor_mean(x) action_mean = self.actor_mean(x)
return action_mean return action_mean
def forward(self, x):
    """Return the actor's mean action for a raw observation ``x``.

    The observation is first passed through ``self.obs_rms`` with
    ``update=False`` and the result is then fed to ``self.actor_mean``.
    """
    normalized_obs = self.obs_rms(x, update=False)
    return self.actor_mean(normalized_obs)
class Policy:
    """Serve actions from a trained ``Agent`` through ONNX Runtime.

    On construction the PyTorch checkpoint is loaded into an ``Agent``, the
    agent is exported to an ONNX file, and the exported graph is checked
    against the PyTorch output on one random input. Subsequent calls are
    answered by the ONNX Runtime session only.
    """

    def __init__(self, checkpoint_path, obs_dim=45):
        """Load the checkpoint, export it to ONNX and open an inference session.

        Args:
            checkpoint_path: Path of the ``state_dict`` checkpoint to load.
                The ONNX file is written next to it, with the final
                extension swapped for ``.onnx``.
            obs_dim: Length of the observation vector used to build the
                example input for the export (defaults to 45, the value
                previously hard-coded here).

        Raises:
            AssertionError: If the ONNX Runtime output differs from the
                PyTorch output beyond the tolerances below.
        """
        self.agent = Agent()
        actor_sd = torch.load(checkpoint_path, map_location="cpu")
        self.agent.load_state_dict(actor_sd)

        onnx_file_name = self._onnx_path_for(checkpoint_path)
        dummy_input = torch.randn(1, obs_dim)
        with torch.no_grad():
            torch_out = self.agent(dummy_input)

        torch.onnx.export(
            self.agent,                # The model being converted
            dummy_input,               # An example input for the model
            onnx_file_name,            # Output file name
            export_params=True,        # Store trained parameter weights inside the model file
            opset_version=11,          # ONNX opset to export to, adjust as needed
            do_constant_folding=True,  # Whether to perform constant folding optimization
            input_names=['input'],     # Name of the input in the ONNX graph
            output_names=['action'],   # Name of the single output (what Agent.forward returns)
        )

        self.ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=["CPUExecutionProvider"])

        # Sanity check: the exported graph must reproduce the PyTorch result.
        ort_inputs = {'input': dummy_input.numpy()}
        ort_outs = self.ort_session.run(None, ort_inputs)
        np.testing.assert_allclose(torch_out.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05)
        print("Exported model has been tested with ONNXRuntime, and the result looks good!")

    @staticmethod
    def _onnx_path_for(checkpoint_path):
        """Return ``checkpoint_path`` with its final extension replaced by ``.onnx``.

        Only the last extension is touched, so ``.pt`` fragments elsewhere in
        the path are preserved. A path without an extension gets ``.onnx``
        appended, so the export cannot overwrite the checkpoint itself.
        """
        import os

        root, _ext = os.path.splitext(os.fspath(checkpoint_path))
        return root + ".onnx"

    def __call__(self, obs, info):
        """Run the ONNX session on one observation and return its first output.

        Args:
            obs: A single observation (array-like). It is converted to
                ``float32`` and given a leading axis before being fed to the
                session under the name ``'input'``.
            info: Unused; kept so the call signature stays unchanged.

        Returns:
            The first element of the list returned by the session's ``run``.
        """
        ort_inputs = {'input': np.asarray(obs, dtype=np.float32)[np.newaxis]}
        ort_outs = self.ort_session.run(None, ort_inputs)
        return ort_outs[0]
class CommandInterface: class CommandInterface:
def __init__(self, limits=None): def __init__(self, limits=None):
@ -211,6 +241,7 @@ class CaTAgent:
self.joint_vel_target = np.zeros(12) self.joint_vel_target = np.zeros(12)
self.torques = np.zeros(12) self.torques = np.zeros(12)
self.contact_state = np.ones(4) self.contact_state = np.ones(4)
self.foot_contact_forces_mag = np.zeros(4)
self.test = 0 self.test = 0
def wait_for_state(self): def wait_for_state(self):
@ -228,6 +259,7 @@ class CaTAgent:
self.dof_pos = np.array(joint_state['q'])[self.unitree_to_policy_map[:, 1]] self.dof_pos = np.array(joint_state['q'])[self.unitree_to_policy_map[:, 1]]
self.dof_vel = np.array(joint_state['dq'])[self.unitree_to_policy_map[:, 1]] self.dof_vel = np.array(joint_state['dq'])[self.unitree_to_policy_map[:, 1]]
self.body_angular_vel = self.robot.getIMU()["gyro"] self.body_angular_vel = self.robot.getIMU()["gyro"]
self.foot_contact_forces_mag = self.robot.getFootContact()
ob = np.concatenate( ob = np.concatenate(
( (
@ -241,12 +273,16 @@ class CaTAgent:
axis=0, axis=0,
) )
return torch.tensor(ob, device=self.device).float() #return torch.tensor(ob, device=self.device).float()
return ob
def publish_action(self, action, hard_reset=False): def publish_action(self, action, hard_reset=False):
# command_for_robot = UnitreeLowCommand() # command_for_robot = UnitreeLowCommand()
#self.joint_pos_target = (
# action[0, :12].detach().cpu().numpy() * 0.25
#).flatten()
self.joint_pos_target = ( self.joint_pos_target = (
action[0, :12].detach().cpu().numpy() * 0.25 action[0, :12] * 0.25
).flatten() ).flatten()
self.joint_pos_target += self.default_dof_pos self.joint_pos_target += self.default_dof_pos
self.joint_vel_target = np.zeros(12) self.joint_vel_target = np.zeros(12)
@ -294,7 +330,8 @@ class CaTAgent:
"contact_state": self.contact_state[np.newaxis, :], "contact_state": self.contact_state[np.newaxis, :],
"body_linear_vel_cmd": self.commands[np.newaxis, 0:2], "body_linear_vel_cmd": self.commands[np.newaxis, 0:2],
"body_angular_vel_cmd": self.commands[np.newaxis, 2:], "body_angular_vel_cmd": self.commands[np.newaxis, 2:],
"torques": self.torques "torques": self.torques,
"foot_contact_forces_mag": self.foot_contact_forces_mag.copy()
} }
self.timestep += 1 self.timestep += 1