Adding onnx inference.
This commit is contained in:
parent
124617ec54
commit
94cf0da3ba
|
@ -43,6 +43,7 @@ RUN pip3 install warp-lang scikit-learn casadi mujoco pin
|
|||
RUN pip install jupyter ipykernel
|
||||
RUN pip install cyclonedds pygame
|
||||
RUN pip install pynput pygame
|
||||
RUN pip install onnx onnxruntime
|
||||
|
||||
# Set environmental variables required for using ROS
|
||||
RUN echo 'source /opt/ros/humble/setup.bash' >> ~/.bashrc
|
||||
|
|
|
@ -6,6 +6,8 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import onnxruntime
|
||||
|
||||
class RunningMeanStd(nn.Module):
|
||||
def __init__(self, shape = (), epsilon=1e-08):
|
||||
super(RunningMeanStd, self).__init__()
|
||||
|
@ -88,16 +90,44 @@ class Agent(nn.Module):
|
|||
action_mean = self.actor_mean(x)
|
||||
return action_mean
|
||||
|
||||
def forward(self, x):
|
||||
action_mean = self.actor_mean(self.obs_rms(x, update = False))
|
||||
return action_mean
|
||||
|
||||
class Policy:
|
||||
def __init__(self, checkpoint_path):
|
||||
self.agent = Agent()
|
||||
actor_sd = torch.load(checkpoint_path, map_location="cpu")
|
||||
self.agent.load_state_dict(actor_sd)
|
||||
onnx_file_name = checkpoint_path.replace(".pt", ".onnx")
|
||||
|
||||
dummy_input = torch.randn(1, 45)
|
||||
with torch.no_grad():
|
||||
torch_out = self.agent(dummy_input)
|
||||
|
||||
torch.onnx.export(
|
||||
self.agent, # The model being converted
|
||||
dummy_input, # An example input for the model
|
||||
onnx_file_name, # Output file name
|
||||
export_params=True, # Store trained parameter weights inside the model file
|
||||
opset_version=11, # ONNX version (opset) to export to, adjust as needed
|
||||
do_constant_folding=True, # Whether to perform constant folding optimization
|
||||
input_names=['input'], # Name of the input in the ONNX graph (can be customized)
|
||||
output_names=['action'], # Name of the output (assuming get_action and get_value are key outputs)
|
||||
)
|
||||
|
||||
self.ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=["CPUExecutionProvider"])
|
||||
ort_inputs = {'input': dummy_input.numpy()}
|
||||
ort_outs = self.ort_session.run(None, ort_inputs)
|
||||
np.testing.assert_allclose(torch_out.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05)
|
||||
print("Exported model has been tested with ONNXRuntime, and the result looks good!")
|
||||
|
||||
def __call__(self, obs, info):
|
||||
with torch.no_grad():
|
||||
action = self.agent.get_action(self.agent.obs_rms(obs.unsqueeze(0), update = False))
|
||||
return action
|
||||
#with torch.no_grad():
|
||||
# action = self.agent.get_action(self.agent.obs_rms(obs.unsqueeze(0), update = False))
|
||||
ort_inputs = {'input': obs[np.newaxis].astype(np.float32)}
|
||||
ort_outs = self.ort_session.run(None, ort_inputs)
|
||||
return ort_outs[0]
|
||||
|
||||
class CommandInterface:
|
||||
def __init__(self, limits=None):
|
||||
|
@ -211,6 +241,7 @@ class CaTAgent:
|
|||
self.joint_vel_target = np.zeros(12)
|
||||
self.torques = np.zeros(12)
|
||||
self.contact_state = np.ones(4)
|
||||
self.foot_contact_forces_mag = np.zeros(4)
|
||||
self.test = 0
|
||||
|
||||
def wait_for_state(self):
|
||||
|
@ -228,6 +259,7 @@ class CaTAgent:
|
|||
self.dof_pos = np.array(joint_state['q'])[self.unitree_to_policy_map[:, 1]]
|
||||
self.dof_vel = np.array(joint_state['dq'])[self.unitree_to_policy_map[:, 1]]
|
||||
self.body_angular_vel = self.robot.getIMU()["gyro"]
|
||||
self.foot_contact_forces_mag = self.robot.getFootContact()
|
||||
|
||||
ob = np.concatenate(
|
||||
(
|
||||
|
@ -241,12 +273,16 @@ class CaTAgent:
|
|||
axis=0,
|
||||
)
|
||||
|
||||
return torch.tensor(ob, device=self.device).float()
|
||||
#return torch.tensor(ob, device=self.device).float()
|
||||
return ob
|
||||
|
||||
def publish_action(self, action, hard_reset=False):
|
||||
# command_for_robot = UnitreeLowCommand()
|
||||
#self.joint_pos_target = (
|
||||
# action[0, :12].detach().cpu().numpy() * 0.25
|
||||
#).flatten()
|
||||
self.joint_pos_target = (
|
||||
action[0, :12].detach().cpu().numpy() * 0.25
|
||||
action[0, :12] * 0.25
|
||||
).flatten()
|
||||
self.joint_pos_target += self.default_dof_pos
|
||||
self.joint_vel_target = np.zeros(12)
|
||||
|
@ -294,7 +330,8 @@ class CaTAgent:
|
|||
"contact_state": self.contact_state[np.newaxis, :],
|
||||
"body_linear_vel_cmd": self.commands[np.newaxis, 0:2],
|
||||
"body_angular_vel_cmd": self.commands[np.newaxis, 2:],
|
||||
"torques": self.torques
|
||||
"torques": self.torques,
|
||||
"foot_contact_forces_mag": self.foot_contact_forces_mag.copy()
|
||||
}
|
||||
|
||||
self.timestep += 1
|
||||
|
|
Loading…
Reference in New Issue