diff --git a/.gitattributes b/.gitattributes index df7d2d5b..4135de8f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,2 @@ *.memmap filter=lfs diff=lfs merge=lfs -text +*.stl filter=lfs diff=lfs merge=lfs -text diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index 8295ed48..d0206cc7 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -32,6 +32,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): collate_fn: Callable = None, writer: Writer = None, transform: "torchrl.envs.Transform" = None, + # storage = None, ): self.dataset_id = dataset_id self.version = version @@ -43,7 +44,12 @@ class AbstractExperienceReplay(TensorDictReplayBuffer): f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})." ) - storage = self._download_or_load_dataset() + # HACK + if dataset_id == "xarm_lift_medium": + self.data_dir = self.root / self.dataset_id + storage = self._download_and_preproc_obsolete() + else: + storage = self._download_or_load_dataset() super().__init__( storage=storage, diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index d7e2e18f..b5cec7e1 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -67,11 +67,11 @@ class SimxarmExperienceReplay(AbstractExperienceReplay): ) def _download_and_preproc_obsolete(self): - assert self.root is not None + # assert self.root is not None # TODO(rcadene): finish download - download() + # download() - dataset_path = self.root / f"{self.dataset_id}_raw" / "buffer.pkl" + dataset_path = self.root / f"{self.dataset_id}" / "buffer.pkl" print(f"Using offline dataset '{dataset_path}'") with open(dataset_path, "rb") as f: dataset_dict = pickle.load(f) @@ -105,8 +105,8 @@ class SimxarmExperienceReplay(AbstractExperienceReplay): "frame_id": torch.arange(0, num_frames, 1), ("next", "observation", "image"): next_image, ("next", "observation", "state"): next_state, - ("next", "observation", "reward"): next_reward, - ("next", "observation", "done"): next_done, + ("next", "reward"): next_reward, + ("next", "done"): next_done, }, batch_size=num_frames, ) diff --git a/lerobot/common/envs/aloha/assets/tabletop.stl b/lerobot/common/envs/aloha/assets/tabletop.stl index ab35cdf7..1c17d3f0 100644 Binary files a/lerobot/common/envs/aloha/assets/tabletop.stl and b/lerobot/common/envs/aloha/assets/tabletop.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl index 534c7af9..ef1f3f35 100644 Binary files a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl and b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl index d6a492c2..7eb8aefd 100644 Binary files a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl and b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl b/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl index d6df86be..4c2b3a1f 100644 Binary files a/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl and b/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl b/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl index 193014b6..8a30f7cc 100644 Binary files a/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl and b/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_1_base.stl b/lerobot/common/envs/aloha/assets/vx300s_1_base.stl index 5a7efda2..9198e625 100644 Binary files a/lerobot/common/envs/aloha/assets/vx300s_1_base.stl and b/lerobot/common/envs/aloha/assets/vx300s_1_base.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl b/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl index dc22aa7e..ab3d9570 100644 Binary files a/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl and b/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl b/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl index 111c586e..3d6f663c 100644 Binary files a/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl and b/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl b/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl index 8170d21c..4eb249e7 100644 Binary files a/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl and b/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl b/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl index 39581f83..34c76221 100644 Binary files a/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl and b/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl b/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl index ab8423e9..232fabf7 100644 Binary files a/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl and b/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl b/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl index 043db9ca..946c3c86 100644 Binary files a/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl and b/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl b/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl index 36099b42..28d5bd76 100644 Binary files a/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl and b/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl differ diff --git a/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl b/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl index eba3caa2..5201d5ea 100644 Binary files a/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl and b/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl differ diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 06c7c43f..855e073b 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -17,7 +17,7 @@ def make_env(cfg, transform=None): } if cfg.env.name == "simxarm": - from lerobot.common.envs.simxarm import SimxarmEnv + from lerobot.common.envs.simxarm.env import SimxarmEnv kwargs["task"] = cfg.env.task clsfunc = SimxarmEnv diff --git a/lerobot/common/envs/simxarm.py b/lerobot/common/envs/simxarm/env.py similarity index 93% rename from lerobot/common/envs/simxarm.py rename to lerobot/common/envs/simxarm/env.py index eac3666d..7236e911 100644 --- a/lerobot/common/envs/simxarm.py +++ b/lerobot/common/envs/simxarm/env.py @@ -19,8 +19,8 @@ from lerobot.common.utils import set_seed MAX_NUM_ACTIONS = 4 -_has_gym = importlib.util.find_spec("gym") is not None -_has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym +_has_gym = importlib.util.find_spec("gymnasium") is not None +# _has_simxarm = importlib.util.find_spec("simxarm") is not None and _has_gym class SimxarmEnv(AbstractEnv): @@ -49,13 +49,14 @@ class SimxarmEnv(AbstractEnv): ) def _make_env(self): - if not _has_simxarm: - raise ImportError("Cannot import simxarm.") + # if not _has_simxarm: + # raise ImportError("Cannot import simxarm.") if not _has_gym: raise ImportError("Cannot import gym.") - import gym - from simxarm import TASKS + import gymnasium + + from lerobot.common.envs.simxarm.simxarm import TASKS if self.task not in TASKS: raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}") @@ -63,7 +64,7 @@ class SimxarmEnv(AbstractEnv): self._env = TASKS[self.task]["env"]() num_actions = len(TASKS[self.task]["action_space"]) - self._action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,)) + self._action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,)) self._action_padding = np.zeros((MAX_NUM_ACTIONS - num_actions), dtype=np.float32) if "w" not in TASKS[self.task]["action_space"]: self._action_padding[-1] = 1.0 @@ -230,4 +231,7 @@ class SimxarmEnv(AbstractEnv): def _set_seed(self, seed: Optional[int]): set_seed(seed) - self._env.seed(seed) + # self._env.seed(seed) + # self._env.action_space.seed(seed) + # self.set_seed(seed) + self._seed = seed diff --git a/lerobot/common/envs/simxarm/simxarm/__init__.py b/lerobot/common/envs/simxarm/simxarm/__init__.py new file mode 100644 index 00000000..870eb5a9 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/__init__.py @@ -0,0 +1,165 @@ +from collections import OrderedDict, deque + +import gym +import numpy as np +from gym.wrappers import TimeLimit + +from lerobot.common.envs.simxarm.simxarm.task.lift import Lift +from lerobot.common.envs.simxarm.simxarm.task.peg_in_box import PegInBox +from lerobot.common.envs.simxarm.simxarm.task.push import Push +from lerobot.common.envs.simxarm.simxarm.task.reach import Reach + +TASKS = OrderedDict( + ( + ( + "reach", + { + "env": Reach, + "action_space": "xyz", + "episode_length": 50, + "description": "Reach a target location with the end effector", + }, + ), + ( + "push", + { + "env": Push, + "action_space": "xyz", + "episode_length": 50, + "description": "Push a cube to a target location", + }, + ), + ( + "peg_in_box", + { + "env": PegInBox, + "action_space": "xyz", + "episode_length": 50, + "description": "Insert a peg into a box", + }, + ), + ( + "lift", + { + "env": Lift, + "action_space": "xyzw", + "episode_length": 50, + "description": "Lift a cube above a height threshold", + }, + ), + ) +) + + +class SimXarmWrapper(gym.Wrapper): + """ + A wrapper for the SimXarm environments. This wrapper is used to + convert the action and observation spaces to the correct format. + """ + + def __init__(self, env, task, obs_mode, image_size, action_repeat, frame_stack=1, channel_last=False): + super().__init__(env) + self._env = env + self.obs_mode = obs_mode + self.image_size = image_size + self.action_repeat = action_repeat + self.frame_stack = frame_stack + self._frames = deque([], maxlen=frame_stack) + self.channel_last = channel_last + self._max_episode_steps = task["episode_length"] // action_repeat + + image_shape = ( + (image_size, image_size, 3 * frame_stack) + if channel_last + else (3 * frame_stack, image_size, image_size) + ) + if obs_mode == "state": + self.observation_space = env.observation_space["observation"] + elif obs_mode == "rgb": + self.observation_space = gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8) + elif obs_mode == "all": + self.observation_space = gym.spaces.Dict( + state=gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32), + rgb=gym.spaces.Box(low=0, high=255, shape=image_shape, dtype=np.uint8), + ) + else: + raise ValueError(f"Unknown obs_mode {obs_mode}. Must be one of [rgb, all, state]") + self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(len(task["action_space"]),)) + self.action_padding = np.zeros(4 - len(task["action_space"]), dtype=np.float32) + if "w" not in task["action_space"]: + self.action_padding[-1] = 1.0 + + def _render_obs(self): + obs = self.render(mode="rgb_array", width=self.image_size, height=self.image_size) + if not self.channel_last: + obs = obs.transpose(2, 0, 1) + return obs.copy() + + def _update_frames(self, reset=False): + pixels = self._render_obs() + self._frames.append(pixels) + if reset: + for _ in range(1, self.frame_stack): + self._frames.append(pixels) + assert len(self._frames) == self.frame_stack + + def transform_obs(self, obs, reset=False): + if self.obs_mode == "state": + return obs["observation"] + elif self.obs_mode == "rgb": + self._update_frames(reset=reset) + rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0) + return rgb_obs + elif self.obs_mode == "all": + self._update_frames(reset=reset) + rgb_obs = np.concatenate(list(self._frames), axis=-1 if self.channel_last else 0) + return OrderedDict((("rgb", rgb_obs), ("state", self.robot_state))) + else: + raise ValueError(f"Unknown obs_mode {self.obs_mode}. Must be one of [rgb, all, state]") + + def reset(self): + return self.transform_obs(self._env.reset(), reset=True) + + def step(self, action): + action = np.concatenate([action, self.action_padding]) + reward = 0.0 + for _ in range(self.action_repeat): + obs, r, done, info = self._env.step(action) + reward += r + return self.transform_obs(obs), reward, done, info + + def render(self, mode="rgb_array", width=384, height=384, **kwargs): + return self._env.render(mode, width=width, height=height) + + @property + def state(self): + return self._env.robot_state + + +def make(task, obs_mode="state", image_size=84, action_repeat=1, frame_stack=1, channel_last=False, seed=0): + """ + Create a new environment. + Args: + task (str): The task to create an environment for. Must be one of: + - 'reach' + - 'push' + - 'peg-in-box' + - 'lift' + obs_mode (str): The observation mode to use. Must be one of: + - 'state': Only state observations + - 'rgb': RGB images + - 'all': RGB images and state observations + image_size (int): The size of the image observations + action_repeat (int): The number of times to repeat the action + seed (int): The random seed to use + Returns: + gym.Env: The environment + """ + if task not in TASKS: + raise ValueError(f"Unknown task {task}. Must be one of {list(TASKS.keys())}") + env = TASKS[task]["env"]() + env = TimeLimit(env, TASKS[task]["episode_length"]) + env = SimXarmWrapper(env, TASKS[task], obs_mode, image_size, action_repeat, frame_stack, channel_last) + env.seed(seed) + + return env diff --git a/lerobot/common/envs/simxarm/simxarm/task/__init__.py b/lerobot/common/envs/simxarm/simxarm/task/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/lift.xml b/lerobot/common/envs/simxarm/simxarm/task/assets/lift.xml new file mode 100644 index 00000000..92231f92 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/lift.xml @@ -0,0 +1,53 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/base_link.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/base_link.stl new file mode 100644 index 00000000..f1f52955 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/base_link.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:21fb81ae7fba19e3c6b2d2ca60c8051712ba273357287eb5a397d92d61c7a736 +size 1211434 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/block_inner.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/block_inner.stl new file mode 100644 index 00000000..6cb88945 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/block_inner.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:be68ce180d11630a667a5f37f4dffcc3feebe4217d4bb3912c813b6d9ca3ec66 +size 3284 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/block_inner2.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/block_inner2.stl new file mode 100644 index 00000000..dab55ef5 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/block_inner2.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c6448552bf6b1c4f17334d686a5320ce051bcdfe31431edf69303d8a570d1de +size 3284 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/block_outer.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/block_outer.stl new file mode 100644 index 00000000..21cf11fa --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/block_outer.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:748b9e197e6521914f18d1f6383a36f211136b3f33f2ad2a8c11b9f921c2cf86 +size 6284 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/left_finger.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/left_finger.stl new file mode 100644 index 00000000..6bf4e502 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/left_finger.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a44756eb72f9c214cb37e61dc209cd7073fdff3e4271a7423476ef6fd090d2d4 +size 242684 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/left_inner_knuckle.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/left_inner_knuckle.stl new file mode 100644 index 00000000..817c7e1d --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/left_inner_knuckle.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e8e48692ad26837bb3d6a97582c89784d09948fc09bfe4e5a59017859ff04dac +size 366284 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/left_outer_knuckle.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/left_outer_knuckle.stl new file mode 100644 index 00000000..010c0f3b --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/left_outer_knuckle.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:501665812b08d67e764390db781e839adc6896a9540301d60adf606f57648921 +size 22284 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link1.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link1.stl new file mode 100644 index 00000000..f2b676f2 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link1.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:34b541122df84d2ef5fcb91b715eb19659dc15ad8d44a191dde481f780265636 +size 184184 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link2.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link2.stl new file mode 100644 index 00000000..bf93580c --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link2.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:61e641cd47c169ecef779683332e00e4914db729bf02dfb61bfbe69351827455 +size 225584 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link3.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link3.stl new file mode 100644 index 00000000..d316d233 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link3.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e2798e7946dd70046c95455d5ba96392d0b54a6069caba91dc4ca66e1379b42 +size 237084 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link4.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link4.stl new file mode 100644 index 00000000..f6d5fe94 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link4.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c757fee95f873191a0633c355c07a360032960771cabbd7593a6cdb0f1ffb089 +size 243684 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link5.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link5.stl new file mode 100644 index 00000000..e037b8b9 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link5.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:715ad5787c5dab57589937fd47289882707b5e1eb997e340d567785b02f4ec90 +size 229084 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link6.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link6.stl new file mode 100644 index 00000000..198c5300 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link6.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:85b320aa420497827223d16d492bba8de091173374e361396fc7a5dad7bdb0cb +size 399384 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link7.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link7.stl new file mode 100644 index 00000000..ce9a39ac --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link7.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:97115d848fbf802cb770cd9be639ae2af993103b9d9bbb0c50c943c738a36f18 +size 231684 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link_base.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link_base.stl new file mode 100644 index 00000000..110b9531 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/link_base.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f6fcbc18258090eb56c21cfb17baa5ae43abc98b1958cd366f3a73b9898fc7f0 +size 2106184 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/right_finger.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/right_finger.stl new file mode 100644 index 00000000..03f26e9a --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/right_finger.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c5dee87c7f37baf554b8456ebfe0b3e8ed0b22b8938bd1add6505c2ad6d32c7d +size 242684 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/right_inner_knuckle.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/right_inner_knuckle.stl new file mode 100644 index 00000000..8586f344 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/right_inner_knuckle.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b41dd2c2c550281bf78d7cc6fa117b14786700e5c453560a0cb5fd6dfa0ffb3e +size 366284 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/right_outer_knuckle.stl b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/right_outer_knuckle.stl new file mode 100644 index 00000000..ae7afc25 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/mesh/right_outer_knuckle.stl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:75ca1107d0a42a0f03802a9a49cab48419b31851ee8935f8f1ca06be1c1c91e8 +size 22284 diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/peg_in_box.xml b/lerobot/common/envs/simxarm/simxarm/task/assets/peg_in_box.xml new file mode 100644 index 00000000..0f85459f --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/peg_in_box.xml @@ -0,0 +1,74 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/push.xml b/lerobot/common/envs/simxarm/simxarm/task/assets/push.xml new file mode 100644 index 00000000..42a78c8a --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/push.xml @@ -0,0 +1,54 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/reach.xml b/lerobot/common/envs/simxarm/simxarm/task/assets/reach.xml new file mode 100644 index 00000000..ded6d209 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/reach.xml @@ -0,0 +1,48 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/shared.xml b/lerobot/common/envs/simxarm/simxarm/task/assets/shared.xml new file mode 100644 index 00000000..ee56f8f0 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/shared.xml @@ -0,0 +1,51 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/simxarm/simxarm/task/assets/xarm.xml b/lerobot/common/envs/simxarm/simxarm/task/assets/xarm.xml new file mode 100644 index 00000000..023474d6 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/assets/xarm.xml @@ -0,0 +1,88 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/lerobot/common/envs/simxarm/simxarm/task/base.py b/lerobot/common/envs/simxarm/simxarm/task/base.py new file mode 100644 index 00000000..d5f54f72 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/base.py @@ -0,0 +1,170 @@ +import os + +import glfw +import mujoco +import numpy as np + +# import gym +# from gym.envs.robotics import robot_env +from gymnasium_robotics.envs import robot_env + +from lerobot.common.envs.simxarm.simxarm.task import mocap + + +class Base(robot_env.MujocoRobotEnv): + """ + Superclass for all simxarm environments. + Args: + xml_name (str): name of the xml environment file + gripper_rotation (list): initial rotation of the gripper (given as a quaternion) + """ + + def __init__(self, xml_name, gripper_rotation=None): + if gripper_rotation is None: + gripper_rotation = [0, 1, 0, 0] + self.gripper_rotation = np.array(gripper_rotation, dtype=np.float32) + self.center_of_table = np.array([1.655, 0.3, 0.63625]) + self.max_z = 1.2 + self.min_z = 0.2 + super().__init__( + model_path=os.path.join(os.path.dirname(__file__), "assets", xml_name + ".xml"), + n_substeps=20, + n_actions=4, + initial_qpos={}, + ) + + @property + def dt(self): + return self.n_substeps * self.model.opt.timestep + + @property + def eef(self): + return self._utils.get_site_xpos(self.model, self.data, "grasp") + + @property + def obj(self): + return self._utils.get_site_xpos(self.model, self.data, "object_site") + + @property + def robot_state(self): + gripper_angle = self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint") + return np.concatenate([self.eef, gripper_angle]) + + def is_success(self): + return NotImplementedError() + + def get_reward(self): + raise NotImplementedError() + + def _sample_goal(self): + raise NotImplementedError() + + def get_obs(self): + return self._get_obs() + + def _step_callback(self): + self.sim.forward() + + def _limit_gripper(self, gripper_pos, pos_ctrl): + if gripper_pos[0] > self.center_of_table[0] - 0.105 + 0.15: + pos_ctrl[0] = min(pos_ctrl[0], 0) + if gripper_pos[0] < self.center_of_table[0] - 0.105 - 0.3: + pos_ctrl[0] = max(pos_ctrl[0], 0) + if gripper_pos[1] > self.center_of_table[1] + 0.3: + pos_ctrl[1] = min(pos_ctrl[1], 0) + if gripper_pos[1] < self.center_of_table[1] - 0.3: + pos_ctrl[1] = max(pos_ctrl[1], 0) + if gripper_pos[2] > self.max_z: + pos_ctrl[2] = min(pos_ctrl[2], 0) + if gripper_pos[2] < self.min_z: + pos_ctrl[2] = max(pos_ctrl[2], 0) + return pos_ctrl + + def _apply_action(self, action): + assert action.shape == (4,) + action = action.copy() + pos_ctrl, gripper_ctrl = action[:3], action[3] + pos_ctrl = self._limit_gripper( + self._utils.get_site_xpos(self.model, self.data, "grasp"), pos_ctrl + ) * (1 / self.n_substeps) + gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl]) + mocap.apply_action(self.sim, np.concatenate([pos_ctrl, self.gripper_rotation, gripper_ctrl])) + + def _viewer_setup(self): + body_id = self.sim.model.body_name2id("link7") + lookat = self.sim.data.body_xpos[body_id] + for idx, value in enumerate(lookat): + self.viewer.cam.lookat[idx] = value + self.viewer.cam.distance = 4.0 + self.viewer.cam.azimuth = 132.0 + self.viewer.cam.elevation = -14.0 + + def _render_callback(self): + # self.sim.forward() + self._mujoco.mj_forward(self.model, self.data) + + def _reset_sim(self): + # self.sim.set_state(self.initial_state) + self.data.time = self.initial_time + self.data.qpos[:] = np.copy(self.initial_qpos) + self.data.qvel[:] = np.copy(self.initial_qvel) + self._sample_goal() + for _ in range(10): + # self.sim.step() + self._mujoco.mj_forward(self.model, self.data) + return True + + def _set_gripper(self, gripper_pos, gripper_rotation): + # self.data.set_mocap_pos('robot0:mocap2', gripper_pos) + # self.data.set_mocap_quat('robot0:mocap2', gripper_rotation) + # self.data.set_joint_qpos('right_outer_knuckle_joint', 0) + self._utils.set_mocap_pos(self.model, self.data, "robot0:mocap", gripper_pos) + # self._utils.set_mocap_pos(self.model, self.data, "robot0:mocap", gripper_rotation) + self._utils.set_mocap_quat(self.model, self.data, "robot0:mocap", gripper_rotation) + self._utils.set_joint_qpos(self.model, self.data, "right_outer_knuckle_joint", 0) + self.data.qpos[10] = 0.0 + self.data.qpos[12] = 0.0 + + def _env_setup(self, initial_qpos): + for name, value in initial_qpos.items(): + # self.sim.data.set_joint_qpos(name, value) + self.data.set_joint_qpos(name, value) + mocap.reset(self.model, self.data) + mujoco.mj_forward(self.model, self.data) + # self.sim.forward() + self._sample_goal() + # self.sim.forward() + mujoco.mj_forward(self.model, self.data) + + def reset(self): + self._reset_sim() + return self._get_obs() + + def step(self, action): + assert action.shape == (4,) + assert self.action_space.contains(action), "{!r} ({}) invalid".format(action, type(action)) + self._apply_action(action) + for _ in range(2): + self.sim.step() + self._step_callback() + obs = self._get_obs() + reward = self.get_reward() + done = False + info = {"is_success": self.is_success(), "success": self.is_success()} + return obs, reward, done, info + + def render(self, mode="rgb_array", width=384, height=384): + self._render_callback() + # if mode == 'rgb_array': + # return self.sim.render(width, height, camera_name='camera0', depth=False)[::-1, :, :] + # elif mode == "human": + # self._get_viewer(mode).render() + return self.mujoco_renderer.render(mode) + + def close(self): + if self.viewer is not None: + # self.viewer.finish() + print("Closing window glfw") + glfw.destroy_window(self.viewer.window) + self.viewer = None + self._viewers = {} diff --git a/lerobot/common/envs/simxarm/simxarm/task/lift.py b/lerobot/common/envs/simxarm/simxarm/task/lift.py new file mode 100644 index 00000000..bd7fc500 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/lift.py @@ -0,0 +1,101 @@ +import numpy as np + +from lerobot.common.envs.simxarm.simxarm import Base + + +class Lift(Base): + def __init__(self): + self._z_threshold = 0.15 + super().__init__("lift") + + @property + def z_target(self): + return self._init_z + self._z_threshold + + def is_success(self): + return self.obj[2] >= self.z_target + + def get_reward(self): + reach_dist = np.linalg.norm(self.obj - self.eef) + reach_dist_xy = np.linalg.norm(self.obj[:-1] - self.eef[:-1]) + pick_completed = self.obj[2] >= (self.z_target - 0.01) + obj_dropped = (self.obj[2] < (self._init_z + 0.005)) and (reach_dist > 0.02) + + # Reach + if reach_dist < 0.05: + reach_reward = -reach_dist + max(self._action[-1], 0) / 50 + elif reach_dist_xy < 0.05: + reach_reward = -reach_dist + else: + z_bonus = np.linalg.norm(np.linalg.norm(self.obj[-1] - self.eef[-1])) + reach_reward = -reach_dist - 2 * z_bonus + + # Pick + if pick_completed and not obj_dropped: + pick_reward = self.z_target + elif (reach_dist < 0.1) and (self.obj[2] > (self._init_z + 0.005)): + pick_reward = min(self.z_target, self.obj[2]) + else: + pick_reward = 0 + + return reach_reward / 100 + pick_reward + + def _get_obs(self): + eef_velp = self._utils.get_site_xvelp(self.model, self.data, "grasp") * self.dt + gripper_angle = self._utils.get_joint_qpos(self.model, self.data, "right_outer_knuckle_joint") + eef = self.eef - self.center_of_table + + obj = self.obj - self.center_of_table + obj_rot = self._utils.get_joint_qpos(self.model, self.data, "object_joint0")[-4:] + obj_velp = self._utils.get_site_xvelp(self.model, self.data, "object_site") * self.dt + obj_velr = self._utils.get_site_xvelr(self.model, self.data, "object_site") * self.dt + + obs = np.concatenate( + [ + eef, + eef_velp, + obj, + obj_rot, + obj_velp, + obj_velr, + eef - obj, + np.array( + [ + np.linalg.norm(eef - obj), + np.linalg.norm(eef[:-1] - obj[:-1]), + self.z_target, + self.z_target - obj[-1], + self.z_target - eef[-1], + ] + ), + gripper_angle, + ], + axis=0, + ) + return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": eef} + + def _sample_goal(self): + # Gripper + gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3) + super()._set_gripper(gripper_pos, self.gripper_rotation) + + # Object + object_pos = self.center_of_table - np.array([0.15, 0.10, 0.07]) + object_pos[0] += self.np_random.uniform(-0.05, 0.05, size=1) + object_pos[1] += self.np_random.uniform(-0.05, 0.05, size=1) + object_qpos = self._utils.get_joint_qpos(self.model, self.data, "object_joint0") + object_qpos[:3] = object_pos + # self.sim.data.set_joint_qpos('object_joint0', object_qpos) + self._utils.set_joint_qpos(self.model, self.data, "object_joint0", object_qpos) + self._init_z = object_pos[2] + + # Goal + return object_pos + np.array([0, 0, self._z_threshold]) + + def reset(self): + self._action = np.zeros(4) + return super().reset() + + def step(self, action): + self._action = action.copy() + return super().step(action) diff --git a/lerobot/common/envs/simxarm/simxarm/task/mocap.py b/lerobot/common/envs/simxarm/simxarm/task/mocap.py new file mode 100644 index 00000000..45722f13 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/mocap.py @@ -0,0 +1,68 @@ +# import mujoco_py +import mujoco +import numpy as np + + +def apply_action(sim, action): + if sim.model.nmocap > 0: + pos_action, gripper_action = np.split(action, (sim.model.nmocap * 7,)) + if sim.data.ctrl is not None: + for i in range(gripper_action.shape[0]): + sim.data.ctrl[i] = gripper_action[i] + pos_action = pos_action.reshape(sim.model.nmocap, 7) + pos_delta, quat_delta = pos_action[:, :3], pos_action[:, 3:] + reset_mocap2body_xpos(sim) + sim.data.mocap_pos[:] = sim.data.mocap_pos + pos_delta + sim.data.mocap_quat[:] = sim.data.mocap_quat + quat_delta + + +def reset(model, data): + if model.nmocap > 0 and model.eq_data is not None: + for i in range(model.eq_data.shape[0]): + # if sim.model.eq_type[i] == mujoco_py.const.EQ_WELD: + if model.eq_type[i] == mujoco.mjtEq.mjEQ_WELD: + # model.eq_data[i, :] = np.array([0., 0., 0., 1., 0., 0., 0.]) + model.eq_data[i, :] = np.array( + [ + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ] + ) + # sim.forward() + mujoco.mj_forward(model, data) + + +def reset_mocap2body_xpos(sim): + if sim.model.eq_type is None or sim.model.eq_obj1id is None or sim.model.eq_obj2id is None: + return + + # For all weld constraints + for eq_type, obj1_id, obj2_id in zip( + sim.model.eq_type, sim.model.eq_obj1id, sim.model.eq_obj2id, strict=False + ): + # if eq_type != mujoco_py.const.EQ_WELD: + if eq_type != mujoco.mjtEq.mjEQ_WELD: + continue + body2 = sim.model.body_id2name(obj2_id) + if body2 == "B0" or body2 == "B9" or body2 == "B1": + continue + mocap_id = sim.model.body_mocapid[obj1_id] + if mocap_id != -1: + # obj1 is the mocap, obj2 is the welded body + body_idx = obj2_id + else: + # obj2 is the mocap, obj1 is the welded body + mocap_id = sim.model.body_mocapid[obj2_id] + body_idx = obj1_id + assert mocap_id != -1 + sim.data.mocap_pos[mocap_id][:] = sim.data.body_xpos[body_idx] + sim.data.mocap_quat[mocap_id][:] = sim.data.body_xquat[body_idx] diff --git a/lerobot/common/envs/simxarm/simxarm/task/peg_in_box.py b/lerobot/common/envs/simxarm/simxarm/task/peg_in_box.py new file mode 100644 index 00000000..42e41520 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/peg_in_box.py @@ -0,0 +1,86 @@ +import numpy as np + +from lerobot.common.envs.simxarm.simxarm import Base + + +class PegInBox(Base): + def __init__(self): + super().__init__("peg_in_box") + + def _reset_sim(self): + self._act_magnitude = 0 + super()._reset_sim() + for _ in range(10): + self._apply_action(np.array([0, 0, 0, 1], dtype=np.float32)) + self.sim.step() + + @property + def box(self): + return self.sim.data.get_site_xpos("box_site") + + def is_success(self): + return np.linalg.norm(self.obj - self.box) <= 0.05 + + def get_reward(self): + dist_xy = np.linalg.norm(self.obj[:2] - self.box[:2]) + dist_xyz = np.linalg.norm(self.obj - self.box) + return float(dist_xy <= 0.045) * (2 - 6 * dist_xyz) - 0.2 * np.square(self._act_magnitude) - dist_xy + + def _get_obs(self): + eef_velp = self.sim.data.get_site_xvelp("grasp") * self.dt + gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint") + eef, box = self.eef - self.center_of_table, self.box - self.center_of_table + + obj = self.obj - self.center_of_table + obj_rot = self.sim.data.get_joint_qpos("object_joint0")[-4:] + obj_velp = self.sim.data.get_site_xvelp("object_site") * self.dt + obj_velr = self.sim.data.get_site_xvelr("object_site") * self.dt + + obs = np.concatenate( + [ + eef, + eef_velp, + box, + obj, + obj_rot, + obj_velp, + obj_velr, + eef - box, + eef - obj, + obj - box, + np.array( + [ + np.linalg.norm(eef - box), + np.linalg.norm(eef - obj), + np.linalg.norm(obj - box), + gripper_angle, + ] + ), + ], + axis=0, + ) + return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": box} + + def _sample_goal(self): + # Gripper + gripper_pos = np.array([1.280, 0.295, 0.9]) + self.np_random.uniform(-0.05, 0.05, size=3) + super()._set_gripper(gripper_pos, self.gripper_rotation) + + # Object + object_pos = gripper_pos - np.array([0, 0, 0.06]) + self.np_random.uniform(-0.005, 0.005, size=3) + object_qpos = self.sim.data.get_joint_qpos("object_joint0") + object_qpos[:3] = object_pos + self.sim.data.set_joint_qpos("object_joint0", object_qpos) + + # Box + box_pos = np.array([1.61, 0.18, 0.58]) + box_pos[:2] += self.np_random.uniform(-0.11, 0.11, size=2) + box_qpos = self.sim.data.get_joint_qpos("box_joint0") + box_qpos[:3] = box_pos + self.sim.data.set_joint_qpos("box_joint0", box_qpos) + + return self.box + + def step(self, action): + self._act_magnitude = np.linalg.norm(action[:3]) + return super().step(action) diff --git a/lerobot/common/envs/simxarm/simxarm/task/push.py b/lerobot/common/envs/simxarm/simxarm/task/push.py new file mode 100644 index 00000000..36c4a550 --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/push.py @@ -0,0 +1,78 @@ +import numpy as np + +from lerobot.common.envs.simxarm.simxarm import Base + + +class Push(Base): + def __init__(self): + super().__init__("push") + + def _reset_sim(self): + self._act_magnitude = 0 + super()._reset_sim() + + def is_success(self): + return np.linalg.norm(self.obj - self.goal) <= 0.05 + + def get_reward(self): + dist = np.linalg.norm(self.obj - self.goal) + penalty = self._act_magnitude**2 + return -(dist + 0.15 * penalty) + + def _get_obs(self): + eef_velp = self.sim.data.get_site_xvelp("grasp") * self.dt + gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint") + eef, goal = self.eef - self.center_of_table, self.goal - self.center_of_table + + obj = self.obj - self.center_of_table + obj_rot = self.sim.data.get_joint_qpos("object_joint0")[-4:] + obj_velp = self.sim.data.get_site_xvelp("object_site") * self.dt + obj_velr = self.sim.data.get_site_xvelr("object_site") * self.dt + + obs = np.concatenate( + [ + eef, + eef_velp, + goal, + obj, + obj_rot, + obj_velp, + obj_velr, + eef - goal, + eef - obj, + obj - goal, + np.array( + [ + np.linalg.norm(eef - goal), + np.linalg.norm(eef - obj), + np.linalg.norm(obj - goal), + gripper_angle, + ] + ), + ], + axis=0, + ) + return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": goal} + + def _sample_goal(self): + # Gripper + gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3) + super()._set_gripper(gripper_pos, self.gripper_rotation) + + # Object + object_pos = self.center_of_table - np.array([0.25, 0, 0.07]) + object_pos[0] += self.np_random.uniform(-0.08, 0.08, size=1) + object_pos[1] += self.np_random.uniform(-0.08, 0.08, size=1) + object_qpos = self.sim.data.get_joint_qpos("object_joint0") + object_qpos[:3] = object_pos + self.sim.data.set_joint_qpos("object_joint0", object_qpos) + + # Goal + self.goal = np.array([1.600, 0.200, 0.545]) + self.goal[:2] += self.np_random.uniform(-0.1, 0.1, size=2) + self.sim.model.site_pos[self.sim.model.site_name2id("target0")] = self.goal + return self.goal + + def step(self, action): + self._act_magnitude = np.linalg.norm(action[:3]) + return super().step(action) diff --git a/lerobot/common/envs/simxarm/simxarm/task/reach.py b/lerobot/common/envs/simxarm/simxarm/task/reach.py new file mode 100644 index 00000000..941a586f --- /dev/null +++ b/lerobot/common/envs/simxarm/simxarm/task/reach.py @@ -0,0 +1,44 @@ +import numpy as np + +from lerobot.common.envs.simxarm.simxarm import Base + + +class Reach(Base): + def __init__(self): + super().__init__("reach") + + def _reset_sim(self): + self._act_magnitude = 0 + super()._reset_sim() + + def is_success(self): + return np.linalg.norm(self.eef - self.goal) <= 0.05 + + def get_reward(self): + dist = np.linalg.norm(self.eef - self.goal) + penalty = self._act_magnitude**2 + return -(dist + 0.15 * penalty) + + def _get_obs(self): + eef_velp = self.sim.data.get_site_xvelp("grasp") * self.dt + gripper_angle = self.sim.data.get_joint_qpos("right_outer_knuckle_joint") + eef, goal = self.eef - self.center_of_table, self.goal - self.center_of_table + obs = np.concatenate( + [eef, eef_velp, goal, eef - goal, np.array([np.linalg.norm(eef - goal), gripper_angle])], axis=0 + ) + return {"observation": obs, "state": eef, "achieved_goal": eef, "desired_goal": goal} + + def _sample_goal(self): + # Gripper + gripper_pos = np.array([1.280, 0.295, 0.735]) + self.np_random.uniform(-0.05, 0.05, size=3) + super()._set_gripper(gripper_pos, self.gripper_rotation) + + # Goal + self.goal = np.array([1.550, 0.287, 0.580]) + self.goal[:2] += self.np_random.uniform(-0.125, 0.125, size=2) + self.sim.model.site_pos[self.sim.model.site_name2id("target0")] = self.goal + return self.goal + + def step(self, action): + self._act_magnitude = np.linalg.norm(action[:3]) + return super().step(action) diff --git a/lerobot/configs/policy/tdmpc.yaml b/lerobot/configs/policy/tdmpc.yaml index 16b7018e..ff0e6b04 100644 --- a/lerobot/configs/policy/tdmpc.yaml +++ b/lerobot/configs/policy/tdmpc.yaml @@ -1,6 +1,7 @@ # @package _global_ n_action_steps: 1 +n_obs_steps: 1 policy: name: tdmpc diff --git a/poetry.lock b/poetry.lock index d2d39e7a..8bfdeb9b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -338,73 +338,6 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] -[[package]] -name = "cython" -version = "3.0.9" -description = "The Cython compiler for writing C extensions in the Python language." -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -files = [ - {file = "Cython-3.0.9-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:296bd30d4445ac61b66c9d766567f6e81a6e262835d261e903c60c891a6729d3"}, - {file = "Cython-3.0.9-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f496b52845cb45568a69d6359a2c335135233003e708ea02155c10ce3548aa89"}, - {file = "Cython-3.0.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:858c3766b9aa3ab8a413392c72bbab1c144a9766b7c7bfdef64e2e414363fa0c"}, - {file = "Cython-3.0.9-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c0eb1e6ef036028a52525fd9a012a556f6dd4788a0e8755fe864ba0e70cde2ff"}, - {file = "Cython-3.0.9-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c8191941073ea5896321de3c8c958fd66e5f304b0cd1f22c59edd0b86c4dd90d"}, - {file = "Cython-3.0.9-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e32b016030bc72a8a22a1f21f470a2f57573761a4f00fbfe8347263f4fbdb9f1"}, - {file = "Cython-3.0.9-cp310-cp310-win32.whl", hash = "sha256:d6f3ff1cd6123973fe03e0fb8ee936622f976c0c41138969975824d08886572b"}, - {file = "Cython-3.0.9-cp310-cp310-win_amd64.whl", hash = "sha256:56f3b643dbe14449248bbeb9a63fe3878a24256664bc8c8ef6efd45d102596d8"}, - {file = "Cython-3.0.9-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:35e6665a20d6b8a152d72b7fd87dbb2af6bb6b18a235b71add68122d594dbd41"}, - {file = "Cython-3.0.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f92f4960c40ad027bd8c364c50db11104eadc59ffeb9e5b7f605ca2f05946e20"}, - {file = "Cython-3.0.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:38df37d0e732fbd9a2fef898788492e82b770c33d1e4ed12444bbc8a3b3f89c0"}, - {file = "Cython-3.0.9-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ad7fd88ebaeaf2e76fd729a8919fae80dab3d6ac0005e28494261d52ff347a8f"}, - {file = "Cython-3.0.9-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1365d5f76bf4d19df3d19ce932584c9bb76e9fb096185168918ef9b36e06bfa4"}, - {file = "Cython-3.0.9-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c232e7f279388ac9625c3e5a5a9f0078a9334959c5d6458052c65bbbba895e1e"}, - {file = "Cython-3.0.9-cp311-cp311-win32.whl", hash = "sha256:357e2fad46a25030b0c0496487e01a9dc0fdd0c09df0897f554d8ba3c1bc4872"}, - {file = "Cython-3.0.9-cp311-cp311-win_amd64.whl", hash = "sha256:1315aee506506e8d69cf6631d8769e6b10131fdcc0eb66df2698f2a3ddaeeff2"}, - {file = "Cython-3.0.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:157973807c2796addbed5fbc4d9c882ab34bbc60dc297ca729504901479d5df7"}, - {file = "Cython-3.0.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:00b105b5d050645dd59e6767bc0f18b48a4aa11c85f42ec7dd8181606f4059e3"}, - {file = "Cython-3.0.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac5536d09bef240cae0416d5a703d298b74c7bbc397da803ac9d344e732d4369"}, - {file = "Cython-3.0.9-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09c44501d476d16aaa4cbc29c87f8c0f54fc20e69b650d59cbfa4863426fc70c"}, - {file = "Cython-3.0.9-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:cc9c3b9f20d8e298618e5ccd32083ca386e785b08f9893fbec4c50b6b85be772"}, - {file = "Cython-3.0.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a30d96938c633e3ec37000ac3796525da71254ef109e66bdfd78f29891af6454"}, - {file = "Cython-3.0.9-cp312-cp312-win32.whl", hash = "sha256:757ca93bdd80702546df4d610d2494ef2e74249cac4d5ba9464589fb464bd8a3"}, - {file = "Cython-3.0.9-cp312-cp312-win_amd64.whl", hash = "sha256:1dc320a9905ab95414013f6de805efbff9e17bb5fb3b90bbac533f017bec8136"}, - {file = "Cython-3.0.9-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:4ae349960ebe0da0d33724eaa7f1eb866688fe5434cc67ce4dbc06d6a719fbfc"}, - {file = "Cython-3.0.9-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63d2537bf688247f76ded6dee28ebd26274f019309aef1eb4f2f9c5c482fde2d"}, - {file = "Cython-3.0.9-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36f5a2dfc724bea1f710b649f02d802d80fc18320c8e6396684ba4a48412445a"}, - {file = "Cython-3.0.9-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:deaf4197d4b0bcd5714a497158ea96a2bd6d0f9636095437448f7e06453cc83d"}, - {file = "Cython-3.0.9-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:000af6deb7412eb7ac0c635ff5e637fb8725dd0a7b88cc58dfc2b3de14e701c4"}, - {file = "Cython-3.0.9-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:15c7f5c2d35bed9aa5f2a51eaac0df23ae72f2dbacf62fc672dd6bfaa75d2d6f"}, - {file = "Cython-3.0.9-cp36-cp36m-win32.whl", hash = "sha256:f49aa4970cd3bec66ac22e701def16dca2a49c59cceba519898dd7526e0be2c0"}, - {file = "Cython-3.0.9-cp36-cp36m-win_amd64.whl", hash = "sha256:4558814fa025b193058d42eeee498a53d6b04b2980d01339fc2444b23fd98e58"}, - {file = "Cython-3.0.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:539cd1d74fd61f6cfc310fa6bbbad5adc144627f2b7486a07075d4e002fd6aad"}, - {file = "Cython-3.0.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3232926cd406ee02eabb732206f6e882c3aed9d58f0fea764013d9240405bcf"}, - {file = "Cython-3.0.9-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33b6ac376538a7fc8c567b85d3c71504308a9318702ec0485dd66c059f3165cb"}, - {file = "Cython-3.0.9-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2cc92504b5d22ac66031ffb827bd3a967fc75a5f0f76ab48bce62df19be6fdfd"}, - {file = "Cython-3.0.9-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:22b8fae756c5c0d8968691bed520876de452f216c28ec896a00739a12dba3bd9"}, - {file = "Cython-3.0.9-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:9cda0d92a09f3520f29bd91009f1194ba9600777c02c30c6d2d4ac65fb63e40d"}, - {file = "Cython-3.0.9-cp37-cp37m-win32.whl", hash = "sha256:ec612418490941ed16c50c8d3784c7bdc4c4b2a10c361259871790b02ec8c1db"}, - {file = "Cython-3.0.9-cp37-cp37m-win_amd64.whl", hash = "sha256:976c8d2bedc91ff6493fc973d38b2dc01020324039e2af0e049704a8e1b22936"}, - {file = "Cython-3.0.9-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5055988b007c92256b6e9896441c3055556038c3497fcbf8c921a6c1fce90719"}, - {file = "Cython-3.0.9-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9360606d964c2d0492a866464efcf9d0a92715644eede3f6a2aa696de54a137"}, - {file = "Cython-3.0.9-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02c6e809f060bed073dc7cba1648077fe3b68208863d517c8b39f3920eecf9dd"}, - {file = "Cython-3.0.9-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:95ed792c966f969cea7489c32ff90150b415c1f3567db8d5a9d489c7c1602dac"}, - {file = "Cython-3.0.9-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8edd59d22950b400b03ca78d27dc694d2836a92ef0cac4f64cb4b2ff902f7e25"}, - {file = "Cython-3.0.9-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4cf0ed273bf60e97922fcbbdd380c39693922a597760160b4b4355e6078ca188"}, - {file = "Cython-3.0.9-cp38-cp38-win32.whl", hash = "sha256:5eb9bd4ae12ebb2bc79a193d95aacf090fbd8d7013e11ed5412711650cb34934"}, - {file = "Cython-3.0.9-cp38-cp38-win_amd64.whl", hash = "sha256:44457279da56e0f829bb1fc5a5dc0836e5d498dbcf9b2324f32f7cc9d2ec6569"}, - {file = "Cython-3.0.9-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c4b419a1adc2af43f4660e2f6eaf1e4fac2dbac59490771eb8ac3d6063f22356"}, - {file = "Cython-3.0.9-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f836192140f033b2319a0128936367c295c2b32e23df05b03b672a6015757ea"}, - {file = "Cython-3.0.9-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fd198c1a7f8e9382904d622cc0efa3c184605881fd5262c64cbb7168c4c1ec5"}, - {file = "Cython-3.0.9-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a274fe9ca5c53fafbcf5c8f262f8ad6896206a466f0eeb40aaf36a7951e957c0"}, - {file = "Cython-3.0.9-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:158c38360bbc5063341b1e78d3737f1251050f89f58a3df0d10fb171c44262be"}, - {file = "Cython-3.0.9-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8bf30b045f7deda0014b042c1b41c1d272facc762ab657529e3b05505888e878"}, - {file = "Cython-3.0.9-cp39-cp39-win32.whl", hash = "sha256:9a001fd95c140c94d934078544ff60a3c46aca2dc86e75a76e4121d3cd1f4b33"}, - {file = "Cython-3.0.9-cp39-cp39-win_amd64.whl", hash = "sha256:530c01c4aebba709c0ec9c7ecefe07177d0b9fd7ffee29450a118d92192ccbdf"}, - {file = "Cython-3.0.9-py2.py3-none-any.whl", hash = "sha256:bf96417714353c5454c2e3238fca9338599330cf51625cdc1ca698684465646f"}, - {file = "Cython-3.0.9.tar.gz", hash = "sha256:a2d354f059d1f055d34cfaa62c5b68bc78ac2ceab6407148d47fb508cf3ba4f3"}, -] - [[package]] name = "debugpy" version = "1.8.1" @@ -639,6 +572,17 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "farama-notifications" +version = "0.0.4" +description = "Notifications for all Farama Foundation maintained libraries." +optional = false +python-versions = "*" +files = [ + {file = "Farama-Notifications-0.0.4.tar.gz", hash = "sha256:13fceff2d14314cf80703c8266462ebf3733c7d165336eee998fc58e545efd18"}, + {file = "Farama_Notifications-0.0.4-py3-none-any.whl", hash = "sha256:14de931035a41961f7c056361dc7f980762a143d05791ef5794a751a2caf05ae"}, +] + [[package]] name = "fasteners" version = "0.19" @@ -877,6 +821,59 @@ files = [ {file = "gym_notices-0.0.8-py3-none-any.whl", hash = "sha256:e5f82e00823a166747b4c2a07de63b6560b1acb880638547e0cabf825a01e463"}, ] +[[package]] +name = "gymnasium" +version = "0.29.1" +description = "A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym)." +optional = false +python-versions = ">=3.8" +files = [ + {file = "gymnasium-0.29.1-py3-none-any.whl", hash = "sha256:61c3384b5575985bb7f85e43213bcb40f36fcdff388cae6bc229304c71f2843e"}, + {file = "gymnasium-0.29.1.tar.gz", hash = "sha256:1a532752efcb7590478b1cc7aa04f608eb7a2fdad5570cd217b66b6a35274bb1"}, +] + +[package.dependencies] +cloudpickle = ">=1.2.0" +farama-notifications = ">=0.0.1" +numpy = ">=1.21.0" +typing-extensions = ">=4.3.0" + +[package.extras] +accept-rom-license = ["autorom[accept-rom-license] (>=0.4.2,<0.5.0)"] +all = ["box2d-py (==2.3.5)", "cython (<3)", "imageio (>=2.14.1)", "jax (>=0.4.0)", "jaxlib (>=0.4.0)", "lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "mujoco (>=2.3.3)", "mujoco-py (>=2.1,<2.2)", "opencv-python (>=3.0)", "pygame (>=2.1.3)", "shimmy[atari] (>=0.1.0,<1.0)", "swig (==4.*)", "torch (>=1.0.0)"] +atari = ["shimmy[atari] (>=0.1.0,<1.0)"] +box2d = ["box2d-py (==2.3.5)", "pygame (>=2.1.3)", "swig (==4.*)"] +classic-control = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"] +jax = ["jax (>=0.4.0)", "jaxlib (>=0.4.0)"] +mujoco = ["imageio (>=2.14.1)", "mujoco (>=2.3.3)"] +mujoco-py = ["cython (<3)", "cython (<3)", "mujoco-py (>=2.1,<2.2)", "mujoco-py (>=2.1,<2.2)"] +other = ["lz4 (>=3.1.0)", "matplotlib (>=3.0)", "moviepy (>=1.0.0)", "opencv-python (>=3.0)", "torch (>=1.0.0)"] +testing = ["pytest (==7.1.3)", "scipy (>=1.7.3)"] +toy-text = ["pygame (>=2.1.3)", "pygame (>=2.1.3)"] + +[[package]] +name = "gymnasium-robotics" +version = "1.2.4" +description = "Robotics environments for the Gymnasium repo." +optional = false +python-versions = ">=3.8" +files = [ + {file = "gymnasium-robotics-1.2.4.tar.gz", hash = "sha256:d304192b066f8b800599dfbe3d9d90bba9b761ee884472bdc4d05968a8bc61cb"}, + {file = "gymnasium_robotics-1.2.4-py3-none-any.whl", hash = "sha256:c2cb23e087ca0280ae6802837eb7b3a6d14e5bd24c00803ab09f015fcff3eef5"}, +] + +[package.dependencies] +gymnasium = ">=0.26" +imageio = "*" +Jinja2 = ">=3.0.3" +mujoco = ">=2.3.3,<3.0" +numpy = ">=1.21.0" +PettingZoo = ">=1.23.0" + +[package.extras] +mujoco-py = ["cython (<3)", "mujoco-py (>=2.1,<2.2)"] +testing = ["Jinja2 (>=3.0.3)", "PettingZoo (>=1.23.0)", "cython (<3)", "mujoco-py (>=2.1,<2.2)", "pytest (==7.0.1)"] + [[package]] name = "h5py" version = "3.10.0" @@ -1506,25 +1503,6 @@ glfw = "*" numpy = "*" pyopengl = "*" -[[package]] -name = "mujoco-py" -version = "2.1.2.14" -description = "" -optional = false -python-versions = ">=3.6" -files = [ - {file = "mujoco-py-2.1.2.14.tar.gz", hash = "sha256:eb5b14485acf80a3cf8c15f4b080c6a28a9f79e68869aa696d16cbd51ea7706f"}, - {file = "mujoco_py-2.1.2.14-py3-none-any.whl", hash = "sha256:37c0b41bc0153a8a0eb3663103a67c60f65467753f74e4ff6e68b879f3e3a71f"}, -] - -[package.dependencies] -cffi = ">=1.10" -Cython = ">=0.27.2" -fasteners = ">=0.15,<1.0" -glfw = ">=1.4.0" -imageio = ">=2.1.2" -numpy = ">=1.11" - [[package]] name = "networkx" version = "3.2.1" @@ -1940,6 +1918,31 @@ sql-other = ["SQLAlchemy (>=2.0.0)", "adbc-driver-postgresql (>=0.8.0)", "adbc-d test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.9.2)"] +[[package]] +name = "pettingzoo" +version = "1.24.3" +description = "Gymnasium for multi-agent reinforcement learning." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pettingzoo-1.24.3-py3-none-any.whl", hash = "sha256:23ed90517d2e8a7098bdaf5e31234b3a7f7b73ca578d70d1ca7b9d0cb0e37982"}, + {file = "pettingzoo-1.24.3.tar.gz", hash = "sha256:91f9094f18e06fb74b98f4099cd22e8ae4396125e51719d50b30c9f1c7ab07e6"}, +] + +[package.dependencies] +gymnasium = ">=0.28.0" +numpy = ">=1.21.0" + +[package.extras] +all = ["box2d-py (==2.3.5)", "chess (==1.9.4)", "multi-agent-ale-py (==0.1.11)", "pillow (>=8.0.1)", "pygame (==2.3.0)", "pymunk (==6.2.0)", "rlcard (==1.0.5)", "scipy (>=1.4.1)", "shimmy[openspiel] (>=1.2.0)"] +atari = ["multi-agent-ale-py (==0.1.11)", "pygame (==2.3.0)"] +butterfly = ["pygame (==2.3.0)", "pymunk (==6.2.0)"] +classic = ["chess (==1.9.4)", "pygame (==2.3.0)", "rlcard (==1.0.5)", "shimmy[openspiel] (>=1.2.0)"] +mpe = ["pygame (==2.3.0)"] +other = ["pillow (>=8.0.1)"] +sisl = ["box2d-py (==2.3.5)", "pygame (==2.3.0)", "pymunk (==6.2.0)", "scipy (>=1.4.1)"] +testing = ["AutoROM", "pre-commit", "pynput", "pytest", "pytest-cov", "pytest-markdown-docs", "pytest-xdist"] + [[package]] name = "pillow" version = "10.2.0" @@ -3510,4 +3513,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "1a45c808e1c48bcbf4319d4cf6876771b7d50f40a5a8968a8b7f3af36192bf34" +content-hash = "abe6fc1c5b99d6f51f2efb0adda0e7cd1fcfe7b2d789879dafa441869e555745" diff --git a/pyproject.toml b/pyproject.toml index 7e9996a0..5f6c9456 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,6 @@ packages = [{include = "lerobot"}] [tool.poetry.dependencies] python = "^3.10" -cython = "^3.0.8" termcolor = "^2.4.0" omegaconf = "^2.3.0" dm-env = "^1.6" @@ -43,7 +42,6 @@ torch = "^2.2.1" tensordict = {git = "https://github.com/pytorch/tensordict"} torchrl = {git = "https://github.com/pytorch/rl", rev = "13bef426dcfa5887c6e5034a6e9697993fa92c37"} mujoco = "2.3.7" -mujoco-py = "^2.1.2.14" gym = "^0.26.2" opencv-python = "^4.9.0.80" diffusers = "^0.26.3" @@ -52,6 +50,7 @@ h5py = "^3.10.0" dm-control = "1.0.14" huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"} robomimic = "0.2.0" +gymnasium-robotics = "^1.2.4" [tool.poetry.group.dev.dependencies] diff --git a/tests/test_envs.py b/tests/test_envs.py index 7776ba3c..8931cf52 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -7,7 +7,7 @@ from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env from lerobot.common.envs.pusht.env import PushtEnv -from lerobot.common.envs.simxarm import SimxarmEnv +from lerobot.common.envs.simxarm.env import SimxarmEnv from .utils import DEVICE, init_config