(WIP) Add gym-xarm
This commit is contained in:
parent
c17dffe944
commit
ab3cd3a7ba
|
@ -9,9 +9,17 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
|
|||
kwargs = {}
|
||||
|
||||
if cfg.env.name == "simxarm":
|
||||
kwargs["task"] = cfg.env.task
|
||||
import gym_xarm # noqa: F401
|
||||
|
||||
assert cfg.env.task == "lift"
|
||||
env_fn = lambda: gym.make(
|
||||
"gym_xarm/XarmLift-v0",
|
||||
render_mode="rgb_array",
|
||||
max_episode_steps=cfg.env.episode_length,
|
||||
**kwargs,
|
||||
)
|
||||
elif cfg.env.name == "pusht":
|
||||
import gym_pusht # noqa
|
||||
import gym_pusht # noqa: F401
|
||||
|
||||
# assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
|
||||
kwargs.update(
|
||||
|
|
|
@ -900,7 +900,27 @@ shapely = "^2.0.3"
|
|||
type = "git"
|
||||
url = "git@github.com:huggingface/gym-pusht.git"
|
||||
reference = "HEAD"
|
||||
resolved_reference = "d7e1a39a31b1368741e9674791007d7cccf046a3"
|
||||
resolved_reference = "0fe4449cca5a2b08f529f7a07fbf5b9df24962ec"
|
||||
|
||||
[[package]]
|
||||
name = "gym-xarm"
|
||||
version = "0.1.0"
|
||||
description = "A gym environment for xArm"
|
||||
optional = true
|
||||
python-versions = "^3.10"
|
||||
files = []
|
||||
develop = false
|
||||
|
||||
[package.dependencies]
|
||||
gymnasium = "^0.29.1"
|
||||
gymnasium-robotics = "^1.2.4"
|
||||
mujoco = "^2.3.7"
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "git@github.com:huggingface/gym-xarm.git"
|
||||
reference = "HEAD"
|
||||
resolved_reference = "2eb83fc4fc871b9d271c946d169e42f226ac3a7c"
|
||||
|
||||
[[package]]
|
||||
name = "gymnasium"
|
||||
|
@ -3611,8 +3631,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
|
|||
|
||||
[extras]
|
||||
pusht = ["gym_pusht"]
|
||||
xarm = ["gym_xarm"]
|
||||
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "3eee17e4bf2b7a570f41ef9c400ec5a24a3113f62a13162229cf43504ca0d005"
|
||||
content-hash = "c9524cdf000eaa755a2ab3be669118222b4f8b1c262013f103f6874cbd54eeb6"
|
||||
|
|
|
@ -53,9 +53,13 @@ gymnasium-robotics = "^1.2.4"
|
|||
gymnasium = "^0.29.1"
|
||||
cmake = "^3.29.0.1"
|
||||
gym_pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
|
||||
gym_xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true}
|
||||
# gym_pusht = { path = "../gym-pusht", develop = true, optional = true}
|
||||
# gym_xarm = { path = "../gym-xarm", develop = true, optional = true}
|
||||
|
||||
[tool.poetry.extras]
|
||||
pusht = ["gym_pusht"]
|
||||
xarm = ["gym_xarm"]
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pre-commit = "^3.6.2"
|
||||
|
|
|
@ -3,6 +3,8 @@ from tensordict import TensorDict
|
|||
import torch
|
||||
from torchrl.envs.utils import check_env_specs, step_mdp
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
import gymnasium as gym
|
||||
from gymnasium.utils.env_checker import check_env
|
||||
|
||||
from lerobot.common.envs.aloha.env import AlohaEnv
|
||||
from lerobot.common.envs.factory import make_env
|
||||
|
@ -61,29 +63,26 @@ def test_aloha(task, from_pixels, pixels_only):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"task,from_pixels,pixels_only",
|
||||
"task, obs_type",
|
||||
[
|
||||
("lift", False, False),
|
||||
("lift", True, False),
|
||||
("lift", True, True),
|
||||
("XarmLift-v0", "state"),
|
||||
("XarmLift-v0", "pixels"),
|
||||
("XarmLift-v0", "pixels_agent_pos"),
|
||||
# TODO(aliberts): Add simxarm other tasks
|
||||
# ("reach", False, False),
|
||||
# ("reach", True, False),
|
||||
# ("push", False, False),
|
||||
# ("push", True, False),
|
||||
# ("peg_in_box", False, False),
|
||||
# ("peg_in_box", True, False),
|
||||
],
|
||||
)
|
||||
def test_simxarm(task, from_pixels, pixels_only):
|
||||
env = SimxarmEnv(
|
||||
task,
|
||||
from_pixels=from_pixels,
|
||||
pixels_only=pixels_only,
|
||||
image_size=84 if from_pixels else None,
|
||||
)
|
||||
def test_xarm(env_task, obs_type):
|
||||
import gym_xarm
|
||||
env = gym.make(f"gym_xarm/{env_task}", obs_type=obs_type)
|
||||
# env = SimxarmEnv(
|
||||
# task,
|
||||
# from_pixels=from_pixels,
|
||||
# pixels_only=pixels_only,
|
||||
# image_size=84 if from_pixels else None,
|
||||
# )
|
||||
# print_spec_rollout(env)
|
||||
check_env_specs(env)
|
||||
# check_env_specs(env)
|
||||
check_env(env)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
Loading…
Reference in New Issue