(WIP) Add gym-xarm

This commit is contained in:
Simon Alibert 2024-04-05 15:35:20 +02:00
parent c17dffe944
commit ab3cd3a7ba
4 changed files with 54 additions and 22 deletions

View File

@ -9,9 +9,17 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
kwargs = {}
if cfg.env.name == "simxarm":
kwargs["task"] = cfg.env.task
import gym_xarm # noqa: F401
assert cfg.env.task == "lift"
env_fn = lambda: gym.make(
"gym_xarm/XarmLift-v0",
render_mode="rgb_array",
max_episode_steps=cfg.env.episode_length,
**kwargs,
)
elif cfg.env.name == "pusht":
import gym_pusht # noqa
import gym_pusht # noqa: F401
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
kwargs.update(

25
poetry.lock generated
View File

@ -900,7 +900,27 @@ shapely = "^2.0.3"
type = "git"
url = "git@github.com:huggingface/gym-pusht.git"
reference = "HEAD"
resolved_reference = "d7e1a39a31b1368741e9674791007d7cccf046a3"
resolved_reference = "0fe4449cca5a2b08f529f7a07fbf5b9df24962ec"
[[package]]
name = "gym-xarm"
version = "0.1.0"
description = "A gym environment for xArm"
optional = true
python-versions = "^3.10"
files = []
develop = false
[package.dependencies]
gymnasium = "^0.29.1"
gymnasium-robotics = "^1.2.4"
mujoco = "^2.3.7"
[package.source]
type = "git"
url = "git@github.com:huggingface/gym-xarm.git"
reference = "HEAD"
resolved_reference = "2eb83fc4fc871b9d271c946d169e42f226ac3a7c"
[[package]]
name = "gymnasium"
@ -3611,8 +3631,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[extras]
pusht = ["gym_pusht"]
xarm = ["gym_xarm"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "3eee17e4bf2b7a570f41ef9c400ec5a24a3113f62a13162229cf43504ca0d005"
content-hash = "c9524cdf000eaa755a2ab3be669118222b4f8b1c262013f103f6874cbd54eeb6"

View File

@ -53,9 +53,13 @@ gymnasium-robotics = "^1.2.4"
gymnasium = "^0.29.1"
cmake = "^3.29.0.1"
gym_pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
gym_xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true}
# gym_pusht = { path = "../gym-pusht", develop = true, optional = true}
# gym_xarm = { path = "../gym-xarm", develop = true, optional = true}
[tool.poetry.extras]
pusht = ["gym_pusht"]
xarm = ["gym_xarm"]
[tool.poetry.group.dev.dependencies]
pre-commit = "^3.6.2"

View File

@ -3,6 +3,8 @@ from tensordict import TensorDict
import torch
from torchrl.envs.utils import check_env_specs, step_mdp
from lerobot.common.datasets.factory import make_dataset
import gymnasium as gym
from gymnasium.utils.env_checker import check_env
from lerobot.common.envs.aloha.env import AlohaEnv
from lerobot.common.envs.factory import make_env
@ -61,29 +63,26 @@ def test_aloha(task, from_pixels, pixels_only):
@pytest.mark.parametrize(
"task,from_pixels,pixels_only",
"task, obs_type",
[
("lift", False, False),
("lift", True, False),
("lift", True, True),
("XarmLift-v0", "state"),
("XarmLift-v0", "pixels"),
("XarmLift-v0", "pixels_agent_pos"),
# TODO(aliberts): Add simxarm other tasks
# ("reach", False, False),
# ("reach", True, False),
# ("push", False, False),
# ("push", True, False),
# ("peg_in_box", False, False),
# ("peg_in_box", True, False),
],
)
def test_simxarm(task, from_pixels, pixels_only):
env = SimxarmEnv(
task,
from_pixels=from_pixels,
pixels_only=pixels_only,
image_size=84 if from_pixels else None,
)
def test_xarm(env_task, obs_type):
import gym_xarm
env = gym.make(f"gym_xarm/{env_task}", obs_type=obs_type)
# env = SimxarmEnv(
# task,
# from_pixels=from_pixels,
# pixels_only=pixels_only,
# image_size=84 if from_pixels else None,
# )
# print_spec_rollout(env)
check_env_specs(env)
# check_env_specs(env)
check_env(env)
@pytest.mark.parametrize(