Adding onnx inference.
This commit is contained in:
parent
124617ec54
commit
94cf0da3ba
|
@ -43,6 +43,7 @@ RUN pip3 install warp-lang scikit-learn casadi mujoco pin
|
||||||
RUN pip install jupyter ipykernel
|
RUN pip install jupyter ipykernel
|
||||||
RUN pip install cyclonedds pygame
|
RUN pip install cyclonedds pygame
|
||||||
RUN pip install pynput pygame
|
RUN pip install pynput pygame
|
||||||
|
RUN pip install onnx onnxruntime
|
||||||
|
|
||||||
# Set environmental variables required for using ROS
|
# Set environmental variables required for using ROS
|
||||||
RUN echo 'source /opt/ros/humble/setup.bash' >> ~/.bashrc
|
RUN echo 'source /opt/ros/humble/setup.bash' >> ~/.bashrc
|
||||||
|
|
|
@ -6,6 +6,8 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import onnxruntime
|
||||||
|
|
||||||
class RunningMeanStd(nn.Module):
|
class RunningMeanStd(nn.Module):
|
||||||
def __init__(self, shape = (), epsilon=1e-08):
|
def __init__(self, shape = (), epsilon=1e-08):
|
||||||
super(RunningMeanStd, self).__init__()
|
super(RunningMeanStd, self).__init__()
|
||||||
|
@ -88,16 +90,44 @@ class Agent(nn.Module):
|
||||||
action_mean = self.actor_mean(x)
|
action_mean = self.actor_mean(x)
|
||||||
return action_mean
|
return action_mean
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
action_mean = self.actor_mean(self.obs_rms(x, update = False))
|
||||||
|
return action_mean
|
||||||
|
|
||||||
class Policy:
|
class Policy:
|
||||||
def __init__(self, checkpoint_path):
|
def __init__(self, checkpoint_path):
|
||||||
self.agent = Agent()
|
self.agent = Agent()
|
||||||
actor_sd = torch.load(checkpoint_path, map_location="cpu")
|
actor_sd = torch.load(checkpoint_path, map_location="cpu")
|
||||||
self.agent.load_state_dict(actor_sd)
|
self.agent.load_state_dict(actor_sd)
|
||||||
|
onnx_file_name = checkpoint_path.replace(".pt", ".onnx")
|
||||||
|
|
||||||
|
dummy_input = torch.randn(1, 45)
|
||||||
|
with torch.no_grad():
|
||||||
|
torch_out = self.agent(dummy_input)
|
||||||
|
|
||||||
|
torch.onnx.export(
|
||||||
|
self.agent, # The model being converted
|
||||||
|
dummy_input, # An example input for the model
|
||||||
|
onnx_file_name, # Output file name
|
||||||
|
export_params=True, # Store trained parameter weights inside the model file
|
||||||
|
opset_version=11, # ONNX version (opset) to export to, adjust as needed
|
||||||
|
do_constant_folding=True, # Whether to perform constant folding optimization
|
||||||
|
input_names=['input'], # Name of the input in the ONNX graph (can be customized)
|
||||||
|
output_names=['action'], # Name of the output (assuming get_action and get_value are key outputs)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=["CPUExecutionProvider"])
|
||||||
|
ort_inputs = {'input': dummy_input.numpy()}
|
||||||
|
ort_outs = self.ort_session.run(None, ort_inputs)
|
||||||
|
np.testing.assert_allclose(torch_out.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05)
|
||||||
|
print("Exported model has been tested with ONNXRuntime, and the result looks good!")
|
||||||
|
|
||||||
def __call__(self, obs, info):
|
def __call__(self, obs, info):
|
||||||
with torch.no_grad():
|
#with torch.no_grad():
|
||||||
action = self.agent.get_action(self.agent.obs_rms(obs.unsqueeze(0), update = False))
|
# action = self.agent.get_action(self.agent.obs_rms(obs.unsqueeze(0), update = False))
|
||||||
return action
|
ort_inputs = {'input': obs[np.newaxis].astype(np.float32)}
|
||||||
|
ort_outs = self.ort_session.run(None, ort_inputs)
|
||||||
|
return ort_outs[0]
|
||||||
|
|
||||||
class CommandInterface:
|
class CommandInterface:
|
||||||
def __init__(self, limits=None):
|
def __init__(self, limits=None):
|
||||||
|
@ -211,6 +241,7 @@ class CaTAgent:
|
||||||
self.joint_vel_target = np.zeros(12)
|
self.joint_vel_target = np.zeros(12)
|
||||||
self.torques = np.zeros(12)
|
self.torques = np.zeros(12)
|
||||||
self.contact_state = np.ones(4)
|
self.contact_state = np.ones(4)
|
||||||
|
self.foot_contact_forces_mag = np.zeros(4)
|
||||||
self.test = 0
|
self.test = 0
|
||||||
|
|
||||||
def wait_for_state(self):
|
def wait_for_state(self):
|
||||||
|
@ -228,6 +259,7 @@ class CaTAgent:
|
||||||
self.dof_pos = np.array(joint_state['q'])[self.unitree_to_policy_map[:, 1]]
|
self.dof_pos = np.array(joint_state['q'])[self.unitree_to_policy_map[:, 1]]
|
||||||
self.dof_vel = np.array(joint_state['dq'])[self.unitree_to_policy_map[:, 1]]
|
self.dof_vel = np.array(joint_state['dq'])[self.unitree_to_policy_map[:, 1]]
|
||||||
self.body_angular_vel = self.robot.getIMU()["gyro"]
|
self.body_angular_vel = self.robot.getIMU()["gyro"]
|
||||||
|
self.foot_contact_forces_mag = self.robot.getFootContact()
|
||||||
|
|
||||||
ob = np.concatenate(
|
ob = np.concatenate(
|
||||||
(
|
(
|
||||||
|
@ -241,12 +273,16 @@ class CaTAgent:
|
||||||
axis=0,
|
axis=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
return torch.tensor(ob, device=self.device).float()
|
#return torch.tensor(ob, device=self.device).float()
|
||||||
|
return ob
|
||||||
|
|
||||||
def publish_action(self, action, hard_reset=False):
|
def publish_action(self, action, hard_reset=False):
|
||||||
# command_for_robot = UnitreeLowCommand()
|
# command_for_robot = UnitreeLowCommand()
|
||||||
|
#self.joint_pos_target = (
|
||||||
|
# action[0, :12].detach().cpu().numpy() * 0.25
|
||||||
|
#).flatten()
|
||||||
self.joint_pos_target = (
|
self.joint_pos_target = (
|
||||||
action[0, :12].detach().cpu().numpy() * 0.25
|
action[0, :12] * 0.25
|
||||||
).flatten()
|
).flatten()
|
||||||
self.joint_pos_target += self.default_dof_pos
|
self.joint_pos_target += self.default_dof_pos
|
||||||
self.joint_vel_target = np.zeros(12)
|
self.joint_vel_target = np.zeros(12)
|
||||||
|
@ -294,7 +330,8 @@ class CaTAgent:
|
||||||
"contact_state": self.contact_state[np.newaxis, :],
|
"contact_state": self.contact_state[np.newaxis, :],
|
||||||
"body_linear_vel_cmd": self.commands[np.newaxis, 0:2],
|
"body_linear_vel_cmd": self.commands[np.newaxis, 0:2],
|
||||||
"body_angular_vel_cmd": self.commands[np.newaxis, 2:],
|
"body_angular_vel_cmd": self.commands[np.newaxis, 2:],
|
||||||
"torques": self.torques
|
"torques": self.torques,
|
||||||
|
"foot_contact_forces_mag": self.foot_contact_forces_mag.copy()
|
||||||
}
|
}
|
||||||
|
|
||||||
self.timestep += 1
|
self.timestep += 1
|
||||||
|
|
Loading…
Reference in New Issue