Add ONNX inference.

This commit is contained in:
jogima-cyber 2024-10-05 00:47:37 +00:00
parent 124617ec54
commit 94cf0da3ba
2 changed files with 44 additions and 6 deletions

View File

@ -43,6 +43,7 @@ RUN pip3 install warp-lang scikit-learn casadi mujoco pin
# Python tooling for notebooks, DDS communication, input handling and ONNX
# inference. Installed in a single layer (each package listed once) and
# without keeping pip's download cache, to keep the image small.
RUN pip install --no-cache-dir \
    jupyter ipykernel \
    cyclonedds pygame pynput \
    onnx onnxruntime
# Set environment variables required for using ROS
RUN echo 'source /opt/ros/humble/setup.bash' >> ~/.bashrc

View File

@ -6,6 +6,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import onnxruntime
class RunningMeanStd(nn.Module):
def __init__(self, shape = (), epsilon=1e-08):
super(RunningMeanStd, self).__init__()
@ -88,16 +90,44 @@ class Agent(nn.Module):
action_mean = self.actor_mean(x)
return action_mean
def forward(self, x):
    """Return the actor's mean action for the observations ``x``.

    ``x`` is first passed through ``obs_rms`` with ``update=False`` and the
    result is then fed to ``actor_mean``.
    """
    normalized_obs = self.obs_rms(x, update=False)
    return self.actor_mean(normalized_obs)
class Policy:
    """Run the trained actor through ONNX Runtime.

    On construction the PyTorch checkpoint is loaded into an ``Agent``, the
    agent is exported to an ONNX file whose path is the checkpoint path with
    its extension replaced by ``.onnx``, and the exported graph is checked
    against the PyTorch output. Calls then go through the ONNX Runtime
    session only.
    """

    # Size of the dummy observation used to trace the network for export.
    # NOTE(review): must match the input size ``Agent`` expects -- confirm.
    OBS_DIM = 45
    # Names given to the graph input/output; ``__call__`` feeds INPUT_NAME.
    INPUT_NAME = "input"
    OUTPUT_NAME = "action"

    def __init__(self, checkpoint_path):
        """Load the checkpoint, export it to ONNX and open an inference session.

        Args:
            checkpoint_path: Path of the state dict to load into ``Agent``.

        Raises:
            AssertionError: If the ONNX Runtime output differs from the
                PyTorch output beyond the tolerances used below.
        """
        import os  # local import: only needed to derive the export path

        self.agent = Agent()
        actor_sd = torch.load(checkpoint_path, map_location="cpu")
        self.agent.load_state_dict(actor_sd)
        # Inference only: make the PyTorch reference output and the exported
        # graph agree for layers that behave differently in training mode.
        self.agent.eval()

        # Swap only the file extension. ``str.replace(".pt", ".onnx")`` would
        # rewrite every ".pt" in the path and, for a checkpoint without ".pt",
        # return the same path so the export would overwrite the checkpoint.
        onnx_file_name = os.path.splitext(checkpoint_path)[0] + ".onnx"

        dummy_input = torch.randn(1, self.OBS_DIM)
        with torch.no_grad():
            torch_out = self.agent(dummy_input)

        torch.onnx.export(
            self.agent,                       # model being converted
            dummy_input,                      # example input used for tracing
            onnx_file_name,                   # output file name
            export_params=True,               # store trained weights in the file
            opset_version=11,                 # ONNX opset to export to
            do_constant_folding=True,         # apply constant-folding optimization
            input_names=[self.INPUT_NAME],    # name of the graph input
            output_names=[self.OUTPUT_NAME],  # name of the graph output
        )

        self.ort_session = onnxruntime.InferenceSession(
            onnx_file_name, providers=["CPUExecutionProvider"]
        )

        # Sanity check: the exported graph must reproduce the PyTorch output.
        ort_inputs = {self.INPUT_NAME: dummy_input.numpy()}
        ort_outs = self.ort_session.run(None, ort_inputs)
        np.testing.assert_allclose(
            torch_out.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05
        )
        print("Exported model has been tested with ONNXRuntime, and the result looks good!")

    def __call__(self, obs, info):
        """Return the action computed by the ONNX session for one observation.

        Args:
            obs: A single observation; it is converted to a float32 array and
                given a leading axis before being fed to the session.
            info: Unused by this method.

        Returns:
            The first output of the ONNX Runtime session.
        """
        batched_obs = np.asarray(obs, dtype=np.float32)[np.newaxis]
        ort_outs = self.ort_session.run(None, {self.INPUT_NAME: batched_obs})
        return ort_outs[0]
class CommandInterface:
def __init__(self, limits=None):
@ -211,6 +241,7 @@ class CaTAgent:
self.joint_vel_target = np.zeros(12)
self.torques = np.zeros(12)
self.contact_state = np.ones(4)
self.foot_contact_forces_mag = np.zeros(4)
self.test = 0
def wait_for_state(self):
@ -228,6 +259,7 @@ class CaTAgent:
self.dof_pos = np.array(joint_state['q'])[self.unitree_to_policy_map[:, 1]]
self.dof_vel = np.array(joint_state['dq'])[self.unitree_to_policy_map[:, 1]]
self.body_angular_vel = self.robot.getIMU()["gyro"]
self.foot_contact_forces_mag = self.robot.getFootContact()
ob = np.concatenate(
(
@ -241,12 +273,16 @@ class CaTAgent:
axis=0,
)
return torch.tensor(ob, device=self.device).float()
#return torch.tensor(ob, device=self.device).float()
return ob
def publish_action(self, action, hard_reset=False):
# command_for_robot = UnitreeLowCommand()
#self.joint_pos_target = (
# action[0, :12].detach().cpu().numpy() * 0.25
#).flatten()
self.joint_pos_target = (
action[0, :12].detach().cpu().numpy() * 0.25
action[0, :12] * 0.25
).flatten()
self.joint_pos_target += self.default_dof_pos
self.joint_vel_target = np.zeros(12)
@ -294,7 +330,8 @@ class CaTAgent:
"contact_state": self.contact_state[np.newaxis, :],
"body_linear_vel_cmd": self.commands[np.newaxis, 0:2],
"body_angular_vel_cmd": self.commands[np.newaxis, 2:],
"torques": self.torques
"torques": self.torques,
"foot_contact_forces_mag": self.foot_contact_forces_mag.copy()
}
self.timestep += 1