From 94cf0da3ba4762ef345d58ef5c5161eeb8752ed7 Mon Sep 17 00:00:00 2001 From: jogima-cyber Date: Sat, 5 Oct 2024 00:47:37 +0000 Subject: [PATCH] Adding onnx inference. --- .devcontainer/Dockerfile | 1 + Go2Py/control/cat.py | 49 +++++++++++++++++++++++++++++++++++----- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 9b877e2..0b3ac48 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -43,6 +43,7 @@ RUN pip3 install warp-lang scikit-learn casadi mujoco pin RUN pip install jupyter ipykernel RUN pip install cyclonedds pygame RUN pip install pynput pygame +RUN pip install onnx onnxruntime # Set environmental variables required for using ROS RUN echo 'source /opt/ros/humble/setup.bash' >> ~/.bashrc diff --git a/Go2Py/control/cat.py b/Go2Py/control/cat.py index a968e7d..2062c89 100755 --- a/Go2Py/control/cat.py +++ b/Go2Py/control/cat.py @@ -6,6 +6,8 @@ import torch import torch.nn as nn import torch.nn.functional as F +import onnxruntime + class RunningMeanStd(nn.Module): def __init__(self, shape = (), epsilon=1e-08): super(RunningMeanStd, self).__init__() @@ -88,16 +90,44 @@ class Agent(nn.Module): action_mean = self.actor_mean(x) return action_mean + def forward(self, x): + action_mean = self.actor_mean(self.obs_rms(x, update = False)) + return action_mean + class Policy: def __init__(self, checkpoint_path): self.agent = Agent() actor_sd = torch.load(checkpoint_path, map_location="cpu") self.agent.load_state_dict(actor_sd) + onnx_file_name = checkpoint_path.replace(".pt", ".onnx") + + dummy_input = torch.randn(1, 45) + with torch.no_grad(): + torch_out = self.agent(dummy_input) + + torch.onnx.export( + self.agent, # The model being converted + dummy_input, # An example input for the model + onnx_file_name, # Output file name + export_params=True, # Store trained parameter weights inside the model file + opset_version=11, # ONNX version (opset) to export to, adjust as needed + do_constant_folding=True, # Whether to perform constant folding optimization + input_names=['input'], # Name of the input in the ONNX graph (can be customized) + output_names=['action'], # Name of the output (assuming get_action and get_value are key outputs) + ) + + self.ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=["CPUExecutionProvider"]) + ort_inputs = {'input': dummy_input.numpy()} + ort_outs = self.ort_session.run(None, ort_inputs) + np.testing.assert_allclose(torch_out.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05) + print("Exported model has been tested with ONNXRuntime, and the result looks good!") def __call__(self, obs, info): - with torch.no_grad(): - action = self.agent.get_action(self.agent.obs_rms(obs.unsqueeze(0), update = False)) - return action + #with torch.no_grad(): + # action = self.agent.get_action(self.agent.obs_rms(obs.unsqueeze(0), update = False)) + ort_inputs = {'input': obs[np.newaxis].astype(np.float32)} + ort_outs = self.ort_session.run(None, ort_inputs) + return ort_outs[0] class CommandInterface: def __init__(self, limits=None): @@ -211,6 +241,7 @@ class CaTAgent: self.joint_vel_target = np.zeros(12) self.torques = np.zeros(12) self.contact_state = np.ones(4) + self.foot_contact_forces_mag = np.zeros(4) self.test = 0 def wait_for_state(self): @@ -228,6 +259,7 @@ class CaTAgent: self.dof_pos = np.array(joint_state['q'])[self.unitree_to_policy_map[:, 1]] self.dof_vel = np.array(joint_state['dq'])[self.unitree_to_policy_map[:, 1]] self.body_angular_vel = self.robot.getIMU()["gyro"] + self.foot_contact_forces_mag = self.robot.getFootContact() ob = np.concatenate( ( @@ -241,12 +273,16 @@ class CaTAgent: axis=0, ) - return torch.tensor(ob, device=self.device).float() + #return torch.tensor(ob, device=self.device).float() + return ob def publish_action(self, action, hard_reset=False): # command_for_robot = UnitreeLowCommand() + #self.joint_pos_target = ( + # action[0, :12].detach().cpu().numpy() * 0.25 + #).flatten() self.joint_pos_target = ( - action[0, :12].detach().cpu().numpy() * 0.25 + action[0, :12] * 0.25 ).flatten() self.joint_pos_target += self.default_dof_pos self.joint_vel_target = np.zeros(12) @@ -294,7 +330,8 @@ class CaTAgent: "contact_state": self.contact_state[np.newaxis, :], "body_linear_vel_cmd": self.commands[np.newaxis, 0:2], "body_angular_vel_cmd": self.commands[np.newaxis, 2:], - "torques": self.torques + "torques": self.torques, + "foot_contact_forces_mag": self.foot_contact_forces_mag.copy() } self.timestep += 1