This commit is contained in:
Simon Alibert 2024-03-24 19:31:47 +01:00
parent b905111895
commit 127de1258d
2 changed files with 29 additions and 23 deletions
lerobot/common/envs/simxarm/simxarm/task

View File

@ -63,7 +63,8 @@ class Base(robot_env.MujocoRobotEnv):
return self._get_obs()
def _step_callback(self):
self.sim.forward()
# self.sim.forward()
self._mujoco.mj_forward(self.model, self.data)
def _limit_gripper(self, gripper_pos, pos_ctrl):
if gripper_pos[0] > self.center_of_table[0] - 0.105 + 0.15:
@ -88,7 +89,12 @@ class Base(robot_env.MujocoRobotEnv):
self._utils.get_site_xpos(self.model, self.data, "grasp"), pos_ctrl
) * (1 / self.n_substeps)
gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl])
mocap.apply_action(self.sim, np.concatenate([pos_ctrl, self.gripper_rotation, gripper_ctrl]))
mocap.apply_action(
self.model,
self._model_names,
self.data,
np.concatenate([pos_ctrl, self.gripper_rotation, gripper_ctrl]),
)
def _viewer_setup(self):
body_id = self.sim.model.body_name2id("link7")
@ -144,8 +150,9 @@ class Base(robot_env.MujocoRobotEnv):
assert action.shape == (4,)
assert self.action_space.contains(action), "{!r} ({}) invalid".format(action, type(action))
self._apply_action(action)
for _ in range(2):
self.sim.step()
# for _ in range(2):
# self.sim.step()
self._mujoco.mj_step(self.model, self.data, nstep=2)
self._step_callback()
obs = self._get_obs()
reward = self.get_reward()

View File

@ -3,17 +3,17 @@ import mujoco
import numpy as np
def apply_action(sim, action):
if sim.model.nmocap > 0:
pos_action, gripper_action = np.split(action, (sim.model.nmocap * 7,))
if sim.data.ctrl is not None:
def apply_action(model, model_names, data, action):
if model.nmocap > 0:
pos_action, gripper_action = np.split(action, (model.nmocap * 7,))
if data.ctrl is not None:
for i in range(gripper_action.shape[0]):
sim.data.ctrl[i] = gripper_action[i]
pos_action = pos_action.reshape(sim.model.nmocap, 7)
data.ctrl[i] = gripper_action[i]
pos_action = pos_action.reshape(model.nmocap, 7)
pos_delta, quat_delta = pos_action[:, :3], pos_action[:, 3:]
reset_mocap2body_xpos(sim)
sim.data.mocap_pos[:] = sim.data.mocap_pos + pos_delta
sim.data.mocap_quat[:] = sim.data.mocap_quat + quat_delta
reset_mocap2body_xpos(model, model_names, data)
data.mocap_pos[:] = data.mocap_pos + pos_delta
data.mocap_quat[:] = data.mocap_quat + quat_delta
def reset(model, data):
@ -41,28 +41,27 @@ def reset(model, data):
mujoco.mj_forward(model, data)
def reset_mocap2body_xpos(sim):
if sim.model.eq_type is None or sim.model.eq_obj1id is None or sim.model.eq_obj2id is None:
def reset_mocap2body_xpos(model, model_names, data):
if model.eq_type is None or model.eq_obj1id is None or model.eq_obj2id is None:
return
# For all weld constraints
for eq_type, obj1_id, obj2_id in zip(
sim.model.eq_type, sim.model.eq_obj1id, sim.model.eq_obj2id, strict=False
):
for eq_type, obj1_id, obj2_id in zip(model.eq_type, model.eq_obj1id, model.eq_obj2id, strict=False):
# if eq_type != mujoco_py.const.EQ_WELD:
if eq_type != mujoco.mjtEq.mjEQ_WELD:
continue
body2 = sim.model.body_id2name(obj2_id)
# body2 = model.body_id2name(obj2_id)
body2 = model_names.body_id2name[obj2_id]
if body2 == "B0" or body2 == "B9" or body2 == "B1":
continue
mocap_id = sim.model.body_mocapid[obj1_id]
mocap_id = model.body_mocapid[obj1_id]
if mocap_id != -1:
# obj1 is the mocap, obj2 is the welded body
body_idx = obj2_id
else:
# obj2 is the mocap, obj1 is the welded body
mocap_id = sim.model.body_mocapid[obj2_id]
mocap_id = model.body_mocapid[obj2_id]
body_idx = obj1_id
assert mocap_id != -1
sim.data.mocap_pos[mocap_id][:] = sim.data.body_xpos[body_idx]
sim.data.mocap_quat[mocap_id][:] = sim.data.body_xquat[body_idx]
data.mocap_pos[mocap_id][:] = data.xpos[body_idx]
data.mocap_quat[mocap_id][:] = data.xquat[body_idx]