(WIP) Add gym-xarm
This commit is contained in:
parent
c17dffe944
commit
ab3cd3a7ba
|
@ -9,9 +9,17 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
|
||||||
if cfg.env.name == "simxarm":
|
if cfg.env.name == "simxarm":
|
||||||
kwargs["task"] = cfg.env.task
|
import gym_xarm # noqa: F401
|
||||||
|
|
||||||
|
assert cfg.env.task == "lift"
|
||||||
|
env_fn = lambda: gym.make(
|
||||||
|
"gym_xarm/XarmLift-v0",
|
||||||
|
render_mode="rgb_array",
|
||||||
|
max_episode_steps=cfg.env.episode_length,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
elif cfg.env.name == "pusht":
|
elif cfg.env.name == "pusht":
|
||||||
import gym_pusht # noqa
|
import gym_pusht # noqa: F401
|
||||||
|
|
||||||
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
|
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
|
||||||
kwargs.update(
|
kwargs.update(
|
||||||
|
|
|
@ -900,7 +900,27 @@ shapely = "^2.0.3"
|
||||||
type = "git"
|
type = "git"
|
||||||
url = "git@github.com:huggingface/gym-pusht.git"
|
url = "git@github.com:huggingface/gym-pusht.git"
|
||||||
reference = "HEAD"
|
reference = "HEAD"
|
||||||
resolved_reference = "d7e1a39a31b1368741e9674791007d7cccf046a3"
|
resolved_reference = "0fe4449cca5a2b08f529f7a07fbf5b9df24962ec"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "gym-xarm"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "A gym environment for xArm"
|
||||||
|
optional = true
|
||||||
|
python-versions = "^3.10"
|
||||||
|
files = []
|
||||||
|
develop = false
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
gymnasium = "^0.29.1"
|
||||||
|
gymnasium-robotics = "^1.2.4"
|
||||||
|
mujoco = "^2.3.7"
|
||||||
|
|
||||||
|
[package.source]
|
||||||
|
type = "git"
|
||||||
|
url = "git@github.com:huggingface/gym-xarm.git"
|
||||||
|
reference = "HEAD"
|
||||||
|
resolved_reference = "2eb83fc4fc871b9d271c946d169e42f226ac3a7c"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "gymnasium"
|
name = "gymnasium"
|
||||||
|
@ -3611,8 +3631,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
||||||
|
|
||||||
[extras]
|
[extras]
|
||||||
pusht = ["gym_pusht"]
|
pusht = ["gym_pusht"]
|
||||||
|
xarm = ["gym_xarm"]
|
||||||
|
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.10"
|
python-versions = "^3.10"
|
||||||
content-hash = "3eee17e4bf2b7a570f41ef9c400ec5a24a3113f62a13162229cf43504ca0d005"
|
content-hash = "c9524cdf000eaa755a2ab3be669118222b4f8b1c262013f103f6874cbd54eeb6"
|
||||||
|
|
|
@ -53,9 +53,13 @@ gymnasium-robotics = "^1.2.4"
|
||||||
gymnasium = "^0.29.1"
|
gymnasium = "^0.29.1"
|
||||||
cmake = "^3.29.0.1"
|
cmake = "^3.29.0.1"
|
||||||
gym_pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
|
gym_pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
|
||||||
|
gym_xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true}
|
||||||
|
# gym_pusht = { path = "../gym-pusht", develop = true, optional = true}
|
||||||
|
# gym_xarm = { path = "../gym-xarm", develop = true, optional = true}
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
pusht = ["gym_pusht"]
|
pusht = ["gym_pusht"]
|
||||||
|
xarm = ["gym_xarm"]
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pre-commit = "^3.6.2"
|
pre-commit = "^3.6.2"
|
||||||
|
|
|
@ -3,6 +3,8 @@ from tensordict import TensorDict
|
||||||
import torch
|
import torch
|
||||||
from torchrl.envs.utils import check_env_specs, step_mdp
|
from torchrl.envs.utils import check_env_specs, step_mdp
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium.utils.env_checker import check_env
|
||||||
|
|
||||||
from lerobot.common.envs.aloha.env import AlohaEnv
|
from lerobot.common.envs.aloha.env import AlohaEnv
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
|
@ -61,29 +63,26 @@ def test_aloha(task, from_pixels, pixels_only):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"task,from_pixels,pixels_only",
|
"task, obs_type",
|
||||||
[
|
[
|
||||||
("lift", False, False),
|
("XarmLift-v0", "state"),
|
||||||
("lift", True, False),
|
("XarmLift-v0", "pixels"),
|
||||||
("lift", True, True),
|
("XarmLift-v0", "pixels_agent_pos"),
|
||||||
# TODO(aliberts): Add simxarm other tasks
|
# TODO(aliberts): Add simxarm other tasks
|
||||||
# ("reach", False, False),
|
|
||||||
# ("reach", True, False),
|
|
||||||
# ("push", False, False),
|
|
||||||
# ("push", True, False),
|
|
||||||
# ("peg_in_box", False, False),
|
|
||||||
# ("peg_in_box", True, False),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_simxarm(task, from_pixels, pixels_only):
|
def test_xarm(env_task, obs_type):
|
||||||
env = SimxarmEnv(
|
import gym_xarm
|
||||||
task,
|
env = gym.make(f"gym_xarm/{env_task}", obs_type=obs_type)
|
||||||
from_pixels=from_pixels,
|
# env = SimxarmEnv(
|
||||||
pixels_only=pixels_only,
|
# task,
|
||||||
image_size=84 if from_pixels else None,
|
# from_pixels=from_pixels,
|
||||||
)
|
# pixels_only=pixels_only,
|
||||||
|
# image_size=84 if from_pixels else None,
|
||||||
|
# )
|
||||||
# print_spec_rollout(env)
|
# print_spec_rollout(env)
|
||||||
check_env_specs(env)
|
# check_env_specs(env)
|
||||||
|
check_env(env)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
Loading…
Reference in New Issue