From ab3cd3a7ba318e5eefea67a47b63b2f9688923dd Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Fri, 5 Apr 2024 15:35:20 +0200 Subject: [PATCH] (WIP) Add gym-xarm --- lerobot/common/envs/factory.py | 12 ++++++++++-- poetry.lock | 25 ++++++++++++++++++++++-- pyproject.toml | 4 ++++ tests/test_envs.py | 35 +++++++++++++++++----------------- 4 files changed, 54 insertions(+), 22 deletions(-) diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 788af3cb..ab3e9294 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -9,9 +9,17 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv: kwargs = {} if cfg.env.name == "simxarm": - kwargs["task"] = cfg.env.task + import gym_xarm # noqa: F401 + + assert cfg.env.task == "lift" + env_fn = lambda: gym.make( + "gym_xarm/XarmLift-v0", + render_mode="rgb_array", + max_episode_steps=cfg.env.episode_length, + **kwargs, + ) elif cfg.env.name == "pusht": - import gym_pusht # noqa + import gym_pusht # noqa: F401 # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range." kwargs.update( diff --git a/poetry.lock b/poetry.lock index 8fb6b7a7..b9e31930 100644 --- a/poetry.lock +++ b/poetry.lock @@ -900,7 +900,27 @@ shapely = "^2.0.3" type = "git" url = "git@github.com:huggingface/gym-pusht.git" reference = "HEAD" -resolved_reference = "d7e1a39a31b1368741e9674791007d7cccf046a3" +resolved_reference = "0fe4449cca5a2b08f529f7a07fbf5b9df24962ec" + +[[package]] +name = "gym-xarm" +version = "0.1.0" +description = "A gym environment for xArm" +optional = true +python-versions = "^3.10" +files = [] +develop = false + +[package.dependencies] +gymnasium = "^0.29.1" +gymnasium-robotics = "^1.2.4" +mujoco = "^2.3.7" + +[package.source] +type = "git" +url = "git@github.com:huggingface/gym-xarm.git" +reference = "HEAD" +resolved_reference = "2eb83fc4fc871b9d271c946d169e42f226ac3a7c" [[package]] name = "gymnasium" @@ -3611,8 +3631,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] pusht = ["gym_pusht"] +xarm = ["gym_xarm"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "3eee17e4bf2b7a570f41ef9c400ec5a24a3113f62a13162229cf43504ca0d005" +content-hash = "c9524cdf000eaa755a2ab3be669118222b4f8b1c262013f103f6874cbd54eeb6" diff --git a/pyproject.toml b/pyproject.toml index d0fc7c0d..b7e1b9fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,9 +53,13 @@ gymnasium-robotics = "^1.2.4" gymnasium = "^0.29.1" cmake = "^3.29.0.1" gym_pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true} +gym_xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true} +# gym_pusht = { path = "../gym-pusht", develop = true, optional = true} +# gym_xarm = { path = "../gym-xarm", develop = true, optional = true} [tool.poetry.extras] pusht = ["gym_pusht"] +xarm = ["gym_xarm"] [tool.poetry.group.dev.dependencies] pre-commit = "^3.6.2" diff --git a/tests/test_envs.py b/tests/test_envs.py index 0c56f4fc..665c1fba 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -3,6 +3,8 @@ from tensordict import TensorDict import torch from torchrl.envs.utils import check_env_specs, step_mdp from lerobot.common.datasets.factory import make_dataset +import gymnasium as gym +from gymnasium.utils.env_checker import check_env from lerobot.common.envs.aloha.env import AlohaEnv from lerobot.common.envs.factory import make_env @@ -61,29 +63,26 @@ def test_aloha(task, from_pixels, pixels_only): @pytest.mark.parametrize( - "task,from_pixels,pixels_only", + "task, obs_type", [ - ("lift", False, False), - ("lift", True, False), - ("lift", True, True), + ("XarmLift-v0", "state"), + ("XarmLift-v0", "pixels"), + ("XarmLift-v0", "pixels_agent_pos"), # TODO(aliberts): Add simxarm other tasks - # ("reach", False, False), - # ("reach", True, False), - # ("push", False, False), - # ("push", True, False), - # ("peg_in_box", False, False), - # ("peg_in_box", True, False), ], ) -def test_simxarm(task, from_pixels, pixels_only): - env = SimxarmEnv( - task, - from_pixels=from_pixels, - pixels_only=pixels_only, - image_size=84 if from_pixels else None, - ) +def test_xarm(env_task, obs_type): + import gym_xarm + env = gym.make(f"gym_xarm/{env_task}", obs_type=obs_type) + # env = SimxarmEnv( + # task, + # from_pixels=from_pixels, + # pixels_only=pixels_only, + # image_size=84 if from_pixels else None, + # ) # print_spec_rollout(env) - check_env_specs(env) + # check_env_specs(env) + check_env(env) @pytest.mark.parametrize(