diff --git a/Go2Py/sim/mujoco.py b/Go2Py/sim/mujoco.py index 3ecc1f8..ae2d202 100644 --- a/Go2Py/sim/mujoco.py +++ b/Go2Py/sim/mujoco.py @@ -18,11 +18,14 @@ dist = np.zeros(nray, np.float64) class Go2Sim: - def __init__(self, mode='lowlevel', render=True, dt=0.002): + def __init__(self, mode='lowlevel', render=True, dt=0.002, xml_path=None): - self.model = mujoco.MjModel.from_xml_path( - os.path.join(ASSETS_PATH, 'mujoco/go2.xml') - ) + if xml_path is None: + self.model = mujoco.MjModel.from_xml_path( + os.path.join(ASSETS_PATH, 'mujoco/go2.xml') + ) + else: + self.model = mujoco.MjModel.from_xml_path(xml_path) self.simulated = True self.data = mujoco.MjData(self.model) self.dt = dt