diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 478be771..c1b14780 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -204,7 +204,7 @@ jobs:
source .venv/bin/activate
python lerobot/scripts/train.py \
policy=tdmpc \
- env=simxarm \
+ env=xarm \
wandb.enable=False \
offline_steps=1 \
online_steps=1 \
@@ -229,6 +229,6 @@ jobs:
python lerobot/scripts/eval.py \
--config lerobot/configs/default.yaml \
policy=tdmpc \
- env=simxarm \
+ env=xarm \
eval_episodes=1 \
device=cpu
diff --git a/README.md b/README.md
index 31fdde0a..51e03d65 100644
--- a/README.md
+++ b/README.md
@@ -62,21 +62,29 @@
Download our source code:
```bash
-git clone https://github.com/huggingface/lerobot.git
-cd lerobot
+git clone https://github.com/huggingface/lerobot.git && cd lerobot
```
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
```bash
-conda create -y -n lerobot python=3.10
-conda activate lerobot
+conda create -y -n lerobot python=3.10 && conda activate lerobot
```
-Then, install 🤗 LeRobot:
+Install 🤗 LeRobot:
```bash
python -m pip install .
```
+For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
+- [aloha](https://github.com/huggingface/gym-aloha)
+- [xarm](https://github.com/huggingface/gym-xarm)
+- [pusht](https://github.com/huggingface/gym-pusht)
+
+For instance, to install 🤗 LeRobot with aloha and pusht, use:
+```bash
+python -m pip install ".[aloha, pusht]"
+```
+
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiments tracking, log in with
```bash
wandb login
@@ -89,11 +97,11 @@ wandb login
├── lerobot
| ├── configs # contains hydra yaml files with all options that you can override in the command line
| | ├── default.yaml # selected by default, it loads pusht environment and diffusion policy
-| | ├── env # various sim environments and their datasets: aloha.yaml, pusht.yaml, simxarm.yaml
+| | ├── env # various sim environments and their datasets: aloha.yaml, pusht.yaml, xarm.yaml
| | └── policy # various policies: act.yaml, diffusion.yaml, tdmpc.yaml
| ├── common # contains classes and utilities
-| | ├── datasets # various datasets of human demonstrations: aloha, pusht, simxarm
-| | ├── envs # various sim environments: aloha, pusht, simxarm
+| | ├── datasets # various datasets of human demonstrations: aloha, pusht, xarm
+| | ├── envs # various sim environments: aloha, pusht, xarm
| | └── policies # various policies: act, diffusion, tdmpc
| └── scripts # contains functions to execute via command line
| ├── visualize_dataset.py # load a dataset and render its demonstrations
@@ -198,21 +206,33 @@ pre-commit install
pre-commit
```
-### Add dependencies
+### Dependencies
Instead of using `pip` directly, we use `poetry` for development purposes to easily track our dependencies.
If you don't have it already, follow the [instructions](https://python-poetry.org/docs/#installation) to install it.
-Install the project with:
+Install the project with dev dependencies and all environments:
```bash
-poetry install
+poetry install --sync --with dev --all-extras
+```
+This command should be run when pulling code with and updated version of `pyproject.toml` and `poetry.lock` in order to synchronize your virtual environment with the dependencies.
+
+To selectively install environments (for example aloha and pusht) use:
+```bash
+poetry install --sync --with dev --extras "aloha pusht"
```
-Then, the equivalent of `pip install some-package`, would just be:
+The equivalent of `pip install some-package`, would just be:
```bash
poetry add some-package
```
+When changes are made to the poetry sections of the `pyproject.toml`, you should run the following command to lock dependencies.
+```bash
+poetry lock --no-update
+```
+
+
**NOTE:** Currently, to ensure the CI works properly, any new package must also be added in the CPU-only environment dedicated to the CI. To do this, you should create a separate environment and add the new package there as well. For example:
```bash
# Add the new package to your main poetry env
diff --git a/lerobot/__init__.py b/lerobot/__init__.py
index 5cf8bdb8..4673aab0 100644
--- a/lerobot/__init__.py
+++ b/lerobot/__init__.py
@@ -27,7 +27,7 @@ from lerobot.__version__ import __version__ # noqa: F401
available_envs = [
"aloha",
"pusht",
- "simxarm",
+ "xarm",
]
available_tasks_per_env = {
@@ -36,7 +36,7 @@ available_tasks_per_env = {
"sim_transfer_cube",
],
"pusht": ["pusht"],
- "simxarm": ["lift"],
+ "xarm": ["lift"],
}
available_datasets_per_env = {
@@ -47,7 +47,7 @@ available_datasets_per_env = {
"aloha_sim_transfer_cube_scripted",
],
"pusht": ["pusht"],
- "simxarm": ["xarm_lift_medium"],
+ "xarm": ["xarm_lift_medium"],
}
available_datasets = [dataset for env in available_envs for dataset in available_datasets_per_env[env]]
diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index c22ae698..0dab5d4b 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -19,10 +19,10 @@ def make_dataset(
normalize=True,
stats_path=None,
):
- if cfg.env.name == "simxarm":
- from lerobot.common.datasets.simxarm import SimxarmDataset
+ if cfg.env.name == "xarm":
+ from lerobot.common.datasets.xarm import XarmDataset
- clsfunc = SimxarmDataset
+ clsfunc = XarmDataset
elif cfg.env.name == "pusht":
from lerobot.common.datasets.pusht import PushtDataset
diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/xarm.py
similarity index 99%
rename from lerobot/common/datasets/simxarm.py
rename to lerobot/common/datasets/xarm.py
index 7bddf608..733267ab 100644
--- a/lerobot/common/datasets/simxarm.py
+++ b/lerobot/common/datasets/xarm.py
@@ -24,7 +24,7 @@ def download(raw_dir):
zip_path.unlink()
-class SimxarmDataset(torch.utils.data.Dataset):
+class XarmDataset(torch.utils.data.Dataset):
available_datasets = [
"xarm_lift_medium",
]
diff --git a/lerobot/common/envs/aloha/__init__.py b/lerobot/common/envs/aloha/__init__.py
deleted file mode 100644
index 48907a4c..00000000
--- a/lerobot/common/envs/aloha/__init__.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from gymnasium.envs.registration import register
-
-register(
- id="gym_aloha/AlohaInsertion-v0",
- entry_point="lerobot.common.envs.aloha.env:AlohaEnv",
- max_episode_steps=300,
- # Even after seeding, the rendered observations are slightly different,
- # so we set `nondeterministic=True` to pass `check_env` tests
- nondeterministic=True,
- kwargs={"obs_type": "state", "task": "insertion"},
-)
-
-register(
- id="gym_aloha/AlohaTransferCube-v0",
- entry_point="lerobot.common.envs.aloha.env:AlohaEnv",
- max_episode_steps=300,
- # Even after seeding, the rendered observations are slightly different,
- # so we set `nondeterministic=True` to pass `check_env` tests
- nondeterministic=True,
- kwargs={"obs_type": "state", "task": "transfer_cube"},
-)
diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml
deleted file mode 100644
index 8002838c..00000000
--- a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_insertion.xml
+++ /dev/null
@@ -1,59 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml
deleted file mode 100644
index 05249ad2..00000000
--- a/lerobot/common/envs/aloha/assets/bimanual_viperx_end_effector_transfer_cube.xml
+++ /dev/null
@@ -1,48 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml
deleted file mode 100644
index 511f7947..00000000
--- a/lerobot/common/envs/aloha/assets/bimanual_viperx_insertion.xml
+++ /dev/null
@@ -1,53 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml b/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml
deleted file mode 100644
index 2d85a47c..00000000
--- a/lerobot/common/envs/aloha/assets/bimanual_viperx_transfer_cube.xml
+++ /dev/null
@@ -1,42 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/lerobot/common/envs/aloha/assets/scene.xml b/lerobot/common/envs/aloha/assets/scene.xml
deleted file mode 100644
index 0f61b8a5..00000000
--- a/lerobot/common/envs/aloha/assets/scene.xml
+++ /dev/null
@@ -1,38 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/lerobot/common/envs/aloha/assets/tabletop.stl b/lerobot/common/envs/aloha/assets/tabletop.stl
deleted file mode 100644
index 1c17d3f0..00000000
--- a/lerobot/common/envs/aloha/assets/tabletop.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:76a1571d1aa36520f2bd81c268991b99816c2a7819464d718e0fd9976fe30dce
-size 684
diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl
deleted file mode 100644
index ef1f3f35..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_left.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:df73ae5b9058e5d50a6409ac2ab687dade75053a86591bb5e23ab051dbf2d659
-size 83384
diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl b/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl
deleted file mode 100644
index 7eb8aefd..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_10_custom_finger_right.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:56fb3cc1236d4193106038adf8e457c7252ae9e86c7cee6dabf0578c53666358
-size 83384
diff --git a/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl b/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl
deleted file mode 100644
index 4c2b3a1f..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_10_gripper_finger.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a4baacd9a64df1be60ea5e98f50f3c660e1b7a1fe9684aace6004c5058c09483
-size 42884
diff --git a/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl b/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl
deleted file mode 100644
index 8a30f7cc..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_11_ar_tag.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a18a1601074d29ed1d546ead70cd18fbb063f1db7b5b96b9f0365be714f3136a
-size 3884
diff --git a/lerobot/common/envs/aloha/assets/vx300s_1_base.stl b/lerobot/common/envs/aloha/assets/vx300s_1_base.stl
deleted file mode 100644
index 9198e625..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_1_base.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d100cafe656671ca8fde98fb6a4cf2d1b746995c51c61c25ad9ea2715635d146
-size 99984
diff --git a/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl b/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl
deleted file mode 100644
index ab3d9570..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_2_shoulder.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:139745a74055cb0b23430bb5bc032bf68cf7bea5e4975c8f4c04107ae005f7f0
-size 63884
diff --git a/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl b/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl
deleted file mode 100644
index 3d6f663c..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_3_upper_arm.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:900f236320dd3d500870c5fde763b2d47502d51e043a5c377875e70237108729
-size 102984
diff --git a/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl b/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl
deleted file mode 100644
index 4eb249e7..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_4_upper_forearm.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4104fc54bbfb8a9b533029f1e7e3ade3d54d638372b3195daa0c98f57e0295b5
-size 49584
diff --git a/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl b/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl
deleted file mode 100644
index 34c76221..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_5_lower_forearm.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:66814e27fa728056416e25e02e89eb7d34c51d51c51e7c3df873829037ddc6b8
-size 99884
diff --git a/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl b/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl
deleted file mode 100644
index 232fabf7..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_6_wrist.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:90eb145c85627968c3776ae6de23ccff7e112c9dd713c46bc9acdfdaa859a048
-size 70784
diff --git a/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl b/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl
deleted file mode 100644
index 946c3c86..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_7_gripper.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:786c1077bfd226f14219581b11d5f19464ca95b17132e0bb7532503568f5af90
-size 450084
diff --git a/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl b/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl
deleted file mode 100644
index 28d5bd76..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_8_gripper_prop.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d1275a93fe2157c83dbc095617fb7e672888bdd48ec070a35ef4ab9ebd9755b0
-size 31684
diff --git a/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl b/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl
deleted file mode 100644
index 5201d5ea..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_9_gripper_bar.stl
+++ /dev/null
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a4de62c9a2ed2c78433010e4c05530a1254b1774a7651967f406120c9bf8973e
-size 379484
diff --git a/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml b/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml
deleted file mode 100644
index 93037ab7..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_dependencies.xml
+++ /dev/null
@@ -1,17 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/lerobot/common/envs/aloha/assets/vx300s_left.xml b/lerobot/common/envs/aloha/assets/vx300s_left.xml
deleted file mode 100644
index 3af6c235..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_left.xml
+++ /dev/null
@@ -1,59 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/lerobot/common/envs/aloha/assets/vx300s_right.xml b/lerobot/common/envs/aloha/assets/vx300s_right.xml
deleted file mode 100644
index 495df478..00000000
--- a/lerobot/common/envs/aloha/assets/vx300s_right.xml
+++ /dev/null
@@ -1,59 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/lerobot/common/envs/aloha/constants.py b/lerobot/common/envs/aloha/constants.py
deleted file mode 100644
index e582e5f3..00000000
--- a/lerobot/common/envs/aloha/constants.py
+++ /dev/null
@@ -1,163 +0,0 @@
-from pathlib import Path
-
-### Simulation envs fixed constants
-DT = 0.02 # 0.02 ms -> 1/0.2 = 50 hz
-FPS = 50
-
-
-JOINTS = [
- # absolute joint position
- "left_arm_waist",
- "left_arm_shoulder",
- "left_arm_elbow",
- "left_arm_forearm_roll",
- "left_arm_wrist_angle",
- "left_arm_wrist_rotate",
- # normalized gripper position 0: close, 1: open
- "left_arm_gripper",
- # absolute joint position
- "right_arm_waist",
- "right_arm_shoulder",
- "right_arm_elbow",
- "right_arm_forearm_roll",
- "right_arm_wrist_angle",
- "right_arm_wrist_rotate",
- # normalized gripper position 0: close, 1: open
- "right_arm_gripper",
-]
-
-ACTIONS = [
- # position and quaternion for end effector
- "left_arm_waist",
- "left_arm_shoulder",
- "left_arm_elbow",
- "left_arm_forearm_roll",
- "left_arm_wrist_angle",
- "left_arm_wrist_rotate",
- # normalized gripper position (0: close, 1: open)
- "left_arm_gripper",
- "right_arm_waist",
- "right_arm_shoulder",
- "right_arm_elbow",
- "right_arm_forearm_roll",
- "right_arm_wrist_angle",
- "right_arm_wrist_rotate",
- # normalized gripper position (0: close, 1: open)
- "right_arm_gripper",
-]
-
-
-START_ARM_POSE = [
- 0,
- -0.96,
- 1.16,
- 0,
- -0.3,
- 0,
- 0.02239,
- -0.02239,
- 0,
- -0.96,
- 1.16,
- 0,
- -0.3,
- 0,
- 0.02239,
- -0.02239,
-]
-
-ASSETS_DIR = Path(__file__).parent.resolve() / "assets" # note: absolute path
-
-# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
-MASTER_GRIPPER_POSITION_OPEN = 0.02417
-MASTER_GRIPPER_POSITION_CLOSE = 0.01244
-PUPPET_GRIPPER_POSITION_OPEN = 0.05800
-PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
-
-# Gripper joint limits (qpos[6])
-MASTER_GRIPPER_JOINT_OPEN = 0.3083
-MASTER_GRIPPER_JOINT_CLOSE = -0.6842
-PUPPET_GRIPPER_JOINT_OPEN = 1.4910
-PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
-
-MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
-
-############################ Helper functions ############################
-
-
-def normalize_master_gripper_position(x):
- return (x - MASTER_GRIPPER_POSITION_CLOSE) / (
- MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
- )
-
-
-def normalize_puppet_gripper_position(x):
- return (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
- PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
- )
-
-
-def unnormalize_master_gripper_position(x):
- return x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
-
-
-def unnormalize_puppet_gripper_position(x):
- return x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
-
-
-def convert_position_from_master_to_puppet(x):
- return unnormalize_puppet_gripper_position(normalize_master_gripper_position(x))
-
-
-def normalizer_master_gripper_joint(x):
- return (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
-
-
-def normalize_puppet_gripper_joint(x):
- return (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
-
-
-def unnormalize_master_gripper_joint(x):
- return x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
-
-
-def unnormalize_puppet_gripper_joint(x):
- return x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
-
-
-def convert_join_from_master_to_puppet(x):
- return unnormalize_puppet_gripper_joint(normalizer_master_gripper_joint(x))
-
-
-def normalize_master_gripper_velocity(x):
- return x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
-
-
-def normalize_puppet_gripper_velocity(x):
- return x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
-
-
-def convert_master_from_position_to_joint(x):
- return (
- normalize_master_gripper_position(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
- + MASTER_GRIPPER_JOINT_CLOSE
- )
-
-
-def convert_master_from_joint_to_position(x):
- return unnormalize_master_gripper_position(
- (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
- )
-
-
-def convert_puppet_from_position_to_join(x):
- return (
- normalize_puppet_gripper_position(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
- + PUPPET_GRIPPER_JOINT_CLOSE
- )
-
-
-def convert_puppet_from_joint_to_position(x):
- return unnormalize_puppet_gripper_position(
- (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
- )
diff --git a/lerobot/common/envs/aloha/env.py b/lerobot/common/envs/aloha/env.py
deleted file mode 100644
index bd14e6d8..00000000
--- a/lerobot/common/envs/aloha/env.py
+++ /dev/null
@@ -1,178 +0,0 @@
-import gymnasium as gym
-import numpy as np
-from dm_control import mujoco
-from dm_control.rl import control
-from gymnasium import spaces
-
-from lerobot.common.envs.aloha.constants import (
- ACTIONS,
- ASSETS_DIR,
- DT,
- JOINTS,
-)
-from lerobot.common.envs.aloha.tasks.sim import BOX_POSE, InsertionTask, TransferCubeTask
-from lerobot.common.envs.aloha.tasks.sim_end_effector import (
- InsertionEndEffectorTask,
- TransferCubeEndEffectorTask,
-)
-from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
-
-
-class AlohaEnv(gym.Env):
- metadata = {"render_modes": [], "render_fps": 50}
-
- def __init__(
- self,
- task,
- obs_type="state",
- observation_width=640,
- observation_height=480,
- visualization_width=640,
- visualization_height=480,
- ):
- super().__init__()
- self.task = task
- self.obs_type = obs_type
- self.observation_width = observation_width
- self.observation_height = observation_height
- self.visualization_width = visualization_width
- self.visualization_height = visualization_height
-
- self._env = self._make_env_task(self.task)
-
- if self.obs_type == "state":
- raise NotImplementedError()
- self.observation_space = spaces.Box(
- low=np.array([0] * len(JOINTS)), # ???
- high=np.array([255] * len(JOINTS)), # ???
- dtype=np.float64,
- )
- elif self.obs_type == "pixels":
- self.observation_space = spaces.Dict(
- {
- "top": spaces.Box(
- low=0,
- high=255,
- shape=(self.observation_height, self.observation_width, 3),
- dtype=np.uint8,
- )
- }
- )
- elif self.obs_type == "pixels_agent_pos":
- self.observation_space = spaces.Dict(
- {
- "pixels": spaces.Dict(
- {
- "top": spaces.Box(
- low=0,
- high=255,
- shape=(self.observation_height, self.observation_width, 3),
- dtype=np.uint8,
- )
- }
- ),
- "agent_pos": spaces.Box(
- low=-np.inf,
- high=np.inf,
- shape=(len(JOINTS),),
- dtype=np.float64,
- ),
- }
- )
-
- self.action_space = spaces.Box(low=-1, high=1, shape=(len(ACTIONS),), dtype=np.float32)
-
- def render(self, mode="rgb_array"):
- # TODO(rcadene): render and visualizer several cameras (e.g. angle, front_close)
- if mode in ["visualize", "human"]:
- height, width = self.visualize_height, self.visualize_width
- elif mode == "rgb_array":
- height, width = self.observation_height, self.observation_width
- else:
- raise ValueError(mode)
- image = self._env.physics.render(height=height, width=width, camera_id="top")
- return image
-
- def _make_env_task(self, task_name):
- # time limit is controlled by StepCounter in env factory
- time_limit = float("inf")
-
- if "transfer_cube" in task_name:
- xml_path = ASSETS_DIR / "bimanual_viperx_transfer_cube.xml"
- physics = mujoco.Physics.from_xml_path(str(xml_path))
- task = TransferCubeTask()
- elif "insertion" in task_name:
- xml_path = ASSETS_DIR / "bimanual_viperx_insertion.xml"
- physics = mujoco.Physics.from_xml_path(str(xml_path))
- task = InsertionTask()
- elif "end_effector_transfer_cube" in task_name:
- raise NotImplementedError()
- xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_transfer_cube.xml"
- physics = mujoco.Physics.from_xml_path(str(xml_path))
- task = TransferCubeEndEffectorTask()
- elif "end_effector_insertion" in task_name:
- raise NotImplementedError()
- xml_path = ASSETS_DIR / "bimanual_viperx_end_effector_insertion.xml"
- physics = mujoco.Physics.from_xml_path(str(xml_path))
- task = InsertionEndEffectorTask()
- else:
- raise NotImplementedError(task_name)
-
- env = control.Environment(
- physics, task, time_limit, control_timestep=DT, n_sub_steps=None, flat_observation=False
- )
- return env
-
- def _format_raw_obs(self, raw_obs):
- if self.obs_type == "state":
- raise NotImplementedError()
- elif self.obs_type == "pixels":
- obs = {"top": raw_obs["images"]["top"].copy()}
- elif self.obs_type == "pixels_agent_pos":
- obs = {
- "pixels": {"top": raw_obs["images"]["top"].copy()},
- "agent_pos": raw_obs["qpos"],
- }
- return obs
-
- def reset(self, seed=None, options=None):
- super().reset(seed=seed)
-
- # TODO(rcadene): how to seed the env?
- if seed is not None:
- self._env.task.random.seed(seed)
- self._env.task._random = np.random.RandomState(seed)
-
- # TODO(rcadene): do not use global variable for this
- if "transfer_cube" in self.task:
- BOX_POSE[0] = sample_box_pose(seed) # used in sim reset
- elif "insertion" in self.task:
- BOX_POSE[0] = np.concatenate(sample_insertion_pose(seed)) # used in sim reset
- else:
- raise ValueError(self.task)
-
- raw_obs = self._env.reset()
-
- observation = self._format_raw_obs(raw_obs.observation)
-
- info = {"is_success": False}
- return observation, info
-
- def step(self, action):
- assert action.ndim == 1
- # TODO(rcadene): add info["is_success"] and info["success"] ?
-
- _, reward, _, raw_obs = self._env.step(action)
-
- # TODO(rcadene): add an enum
- terminated = is_success = reward == 4
-
- info = {"is_success": is_success}
-
- observation = self._format_raw_obs(raw_obs)
-
- truncated = False
- return observation, reward, terminated, truncated, info
-
- def close(self):
- pass
diff --git a/lerobot/common/envs/aloha/tasks/sim.py b/lerobot/common/envs/aloha/tasks/sim.py
deleted file mode 100644
index ee1d0927..00000000
--- a/lerobot/common/envs/aloha/tasks/sim.py
+++ /dev/null
@@ -1,219 +0,0 @@
-import collections
-
-import numpy as np
-from dm_control.suite import base
-
-from lerobot.common.envs.aloha.constants import (
- START_ARM_POSE,
- normalize_puppet_gripper_position,
- normalize_puppet_gripper_velocity,
- unnormalize_puppet_gripper_position,
-)
-
-BOX_POSE = [None] # to be changed from outside
-
-"""
-Environment for simulated robot bi-manual manipulation, with joint position control
-Action space: [left_arm_qpos (6), # absolute joint position
- left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
- right_arm_qpos (6), # absolute joint position
- right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
-
-Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
- left_gripper_position (1), # normalized gripper position (0: close, 1: open)
- right_arm_qpos (6), # absolute joint position
- right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
- "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
- left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
- right_arm_qvel (6), # absolute joint velocity (rad)
- right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
- "images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
-"""
-
-
-class BimanualViperXTask(base.Task):
- def __init__(self, random=None):
- super().__init__(random=random)
-
- def before_step(self, action, physics):
- left_arm_action = action[:6]
- right_arm_action = action[7 : 7 + 6]
- normalized_left_gripper_action = action[6]
- normalized_right_gripper_action = action[7 + 6]
-
- left_gripper_action = unnormalize_puppet_gripper_position(normalized_left_gripper_action)
- right_gripper_action = unnormalize_puppet_gripper_position(normalized_right_gripper_action)
-
- full_left_gripper_action = [left_gripper_action, -left_gripper_action]
- full_right_gripper_action = [right_gripper_action, -right_gripper_action]
-
- env_action = np.concatenate(
- [left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action]
- )
- super().before_step(env_action, physics)
- return
-
- def initialize_episode(self, physics):
- """Sets the state of the environment at the start of each episode."""
- super().initialize_episode(physics)
-
- @staticmethod
- def get_qpos(physics):
- qpos_raw = physics.data.qpos.copy()
- left_qpos_raw = qpos_raw[:8]
- right_qpos_raw = qpos_raw[8:16]
- left_arm_qpos = left_qpos_raw[:6]
- right_arm_qpos = right_qpos_raw[:6]
- left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
- right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
- return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
-
- @staticmethod
- def get_qvel(physics):
- qvel_raw = physics.data.qvel.copy()
- left_qvel_raw = qvel_raw[:8]
- right_qvel_raw = qvel_raw[8:16]
- left_arm_qvel = left_qvel_raw[:6]
- right_arm_qvel = right_qvel_raw[:6]
- left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
- right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
- return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
-
- @staticmethod
- def get_env_state(physics):
- raise NotImplementedError
-
- def get_observation(self, physics):
- obs = collections.OrderedDict()
- obs["qpos"] = self.get_qpos(physics)
- obs["qvel"] = self.get_qvel(physics)
- obs["env_state"] = self.get_env_state(physics)
- obs["images"] = {}
- obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
- obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
- obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
-
- return obs
-
- def get_reward(self, physics):
- # return whether left gripper is holding the box
- raise NotImplementedError
-
-
-class TransferCubeTask(BimanualViperXTask):
- def __init__(self, random=None):
- super().__init__(random=random)
- self.max_reward = 4
-
- def initialize_episode(self, physics):
- """Sets the state of the environment at the start of each episode."""
- # TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
- # reset qpos, control and box position
- with physics.reset_context():
- physics.named.data.qpos[:16] = START_ARM_POSE
- np.copyto(physics.data.ctrl, START_ARM_POSE)
- assert BOX_POSE[0] is not None
- physics.named.data.qpos[-7:] = BOX_POSE[0]
- # print(f"{BOX_POSE=}")
- super().initialize_episode(physics)
-
- @staticmethod
- def get_env_state(physics):
- env_state = physics.data.qpos.copy()[16:]
- return env_state
-
- def get_reward(self, physics):
- # return whether left gripper is holding the box
- all_contact_pairs = []
- for i_contact in range(physics.data.ncon):
- id_geom_1 = physics.data.contact[i_contact].geom1
- id_geom_2 = physics.data.contact[i_contact].geom2
- name_geom_1 = physics.model.id2name(id_geom_1, "geom")
- name_geom_2 = physics.model.id2name(id_geom_2, "geom")
- contact_pair = (name_geom_1, name_geom_2)
- all_contact_pairs.append(contact_pair)
-
- touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
- touch_table = ("red_box", "table") in all_contact_pairs
-
- reward = 0
- if touch_right_gripper:
- reward = 1
- if touch_right_gripper and not touch_table: # lifted
- reward = 2
- if touch_left_gripper: # attempted transfer
- reward = 3
- if touch_left_gripper and not touch_table: # successful transfer
- reward = 4
- return reward
-
-
-class InsertionTask(BimanualViperXTask):
- def __init__(self, random=None):
- super().__init__(random=random)
- self.max_reward = 4
-
- def initialize_episode(self, physics):
- """Sets the state of the environment at the start of each episode."""
- # TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
- # reset qpos, control and box position
- with physics.reset_context():
- physics.named.data.qpos[:16] = START_ARM_POSE
- np.copyto(physics.data.ctrl, START_ARM_POSE)
- assert BOX_POSE[0] is not None
- physics.named.data.qpos[-7 * 2 :] = BOX_POSE[0] # two objects
- # print(f"{BOX_POSE=}")
- super().initialize_episode(physics)
-
- @staticmethod
- def get_env_state(physics):
- env_state = physics.data.qpos.copy()[16:]
- return env_state
-
- def get_reward(self, physics):
- # return whether peg touches the pin
- all_contact_pairs = []
- for i_contact in range(physics.data.ncon):
- id_geom_1 = physics.data.contact[i_contact].geom1
- id_geom_2 = physics.data.contact[i_contact].geom2
- name_geom_1 = physics.model.id2name(id_geom_1, "geom")
- name_geom_2 = physics.model.id2name(id_geom_2, "geom")
- contact_pair = (name_geom_1, name_geom_2)
- all_contact_pairs.append(contact_pair)
-
- touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
- touch_left_gripper = (
- ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- )
-
- peg_touch_table = ("red_peg", "table") in all_contact_pairs
- socket_touch_table = (
- ("socket-1", "table") in all_contact_pairs
- or ("socket-2", "table") in all_contact_pairs
- or ("socket-3", "table") in all_contact_pairs
- or ("socket-4", "table") in all_contact_pairs
- )
- peg_touch_socket = (
- ("red_peg", "socket-1") in all_contact_pairs
- or ("red_peg", "socket-2") in all_contact_pairs
- or ("red_peg", "socket-3") in all_contact_pairs
- or ("red_peg", "socket-4") in all_contact_pairs
- )
- pin_touched = ("red_peg", "pin") in all_contact_pairs
-
- reward = 0
- if touch_left_gripper and touch_right_gripper: # touch both
- reward = 1
- if (
- touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
- ): # grasp both
- reward = 2
- if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
- reward = 3
- if pin_touched: # successful insertion
- reward = 4
- return reward
diff --git a/lerobot/common/envs/aloha/tasks/sim_end_effector.py b/lerobot/common/envs/aloha/tasks/sim_end_effector.py
deleted file mode 100644
index d93c8330..00000000
--- a/lerobot/common/envs/aloha/tasks/sim_end_effector.py
+++ /dev/null
@@ -1,263 +0,0 @@
-import collections
-
-import numpy as np
-from dm_control.suite import base
-
-from lerobot.common.envs.aloha.constants import (
- PUPPET_GRIPPER_POSITION_CLOSE,
- START_ARM_POSE,
- normalize_puppet_gripper_position,
- normalize_puppet_gripper_velocity,
- unnormalize_puppet_gripper_position,
-)
-from lerobot.common.envs.aloha.utils import sample_box_pose, sample_insertion_pose
-
-"""
-Environment for simulated robot bi-manual manipulation, with end-effector control.
-Action space: [left_arm_pose (7), # position and quaternion for end effector
- left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
- right_arm_pose (7), # position and quaternion for end effector
- right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
-
-Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
- left_gripper_position (1), # normalized gripper position (0: close, 1: open)
- right_arm_qpos (6), # absolute joint position
- right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
- "qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
- left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
- right_arm_qvel (6), # absolute joint velocity (rad)
- right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
- "images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
-"""
-
-
-class BimanualViperXEndEffectorTask(base.Task):
- def __init__(self, random=None):
- super().__init__(random=random)
-
- def before_step(self, action, physics):
- a_len = len(action) // 2
- action_left = action[:a_len]
- action_right = action[a_len:]
-
- # set mocap position and quat
- # left
- np.copyto(physics.data.mocap_pos[0], action_left[:3])
- np.copyto(physics.data.mocap_quat[0], action_left[3:7])
- # right
- np.copyto(physics.data.mocap_pos[1], action_right[:3])
- np.copyto(physics.data.mocap_quat[1], action_right[3:7])
-
- # set gripper
- g_left_ctrl = unnormalize_puppet_gripper_position(action_left[7])
- g_right_ctrl = unnormalize_puppet_gripper_position(action_right[7])
- np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))
-
- def initialize_robots(self, physics):
- # reset joint position
- physics.named.data.qpos[:16] = START_ARM_POSE
-
- # reset mocap to align with end effector
- # to obtain these numbers:
- # (1) make an ee_sim env and reset to the same start_pose
- # (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
- # get env._physics.named.data.xquat['vx300s_left/gripper_link']
- # repeat the same for right side
- np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
- np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
- # right
- np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
- np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])
-
- # reset gripper control
- close_gripper_control = np.array(
- [
- PUPPET_GRIPPER_POSITION_CLOSE,
- -PUPPET_GRIPPER_POSITION_CLOSE,
- PUPPET_GRIPPER_POSITION_CLOSE,
- -PUPPET_GRIPPER_POSITION_CLOSE,
- ]
- )
- np.copyto(physics.data.ctrl, close_gripper_control)
-
- def initialize_episode(self, physics):
- """Sets the state of the environment at the start of each episode."""
- super().initialize_episode(physics)
-
- @staticmethod
- def get_qpos(physics):
- qpos_raw = physics.data.qpos.copy()
- left_qpos_raw = qpos_raw[:8]
- right_qpos_raw = qpos_raw[8:16]
- left_arm_qpos = left_qpos_raw[:6]
- right_arm_qpos = right_qpos_raw[:6]
- left_gripper_qpos = [normalize_puppet_gripper_position(left_qpos_raw[6])]
- right_gripper_qpos = [normalize_puppet_gripper_position(right_qpos_raw[6])]
- return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
-
- @staticmethod
- def get_qvel(physics):
- qvel_raw = physics.data.qvel.copy()
- left_qvel_raw = qvel_raw[:8]
- right_qvel_raw = qvel_raw[8:16]
- left_arm_qvel = left_qvel_raw[:6]
- right_arm_qvel = right_qvel_raw[:6]
- left_gripper_qvel = [normalize_puppet_gripper_velocity(left_qvel_raw[6])]
- right_gripper_qvel = [normalize_puppet_gripper_velocity(right_qvel_raw[6])]
- return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
-
- @staticmethod
- def get_env_state(physics):
- raise NotImplementedError
-
- def get_observation(self, physics):
- # note: it is important to do .copy()
- obs = collections.OrderedDict()
- obs["qpos"] = self.get_qpos(physics)
- obs["qvel"] = self.get_qvel(physics)
- obs["env_state"] = self.get_env_state(physics)
- obs["images"] = {}
- obs["images"]["top"] = physics.render(height=480, width=640, camera_id="top")
- obs["images"]["angle"] = physics.render(height=480, width=640, camera_id="angle")
- obs["images"]["vis"] = physics.render(height=480, width=640, camera_id="front_close")
- # used in scripted policy to obtain starting pose
- obs["mocap_pose_left"] = np.concatenate(
- [physics.data.mocap_pos[0], physics.data.mocap_quat[0]]
- ).copy()
- obs["mocap_pose_right"] = np.concatenate(
- [physics.data.mocap_pos[1], physics.data.mocap_quat[1]]
- ).copy()
-
- # used when replaying joint trajectory
- obs["gripper_ctrl"] = physics.data.ctrl.copy()
- return obs
-
- def get_reward(self, physics):
- raise NotImplementedError
-
-
-class TransferCubeEndEffectorTask(BimanualViperXEndEffectorTask):
- def __init__(self, random=None):
- super().__init__(random=random)
- self.max_reward = 4
-
- def initialize_episode(self, physics):
- """Sets the state of the environment at the start of each episode."""
- self.initialize_robots(physics)
- # randomize box position
- cube_pose = sample_box_pose()
- box_start_idx = physics.model.name2id("red_box_joint", "joint")
- np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
- # print(f"randomized cube position to {cube_position}")
-
- super().initialize_episode(physics)
-
- @staticmethod
- def get_env_state(physics):
- env_state = physics.data.qpos.copy()[16:]
- return env_state
-
- def get_reward(self, physics):
- # return whether left gripper is holding the box
- all_contact_pairs = []
- for i_contact in range(physics.data.ncon):
- id_geom_1 = physics.data.contact[i_contact].geom1
- id_geom_2 = physics.data.contact[i_contact].geom2
- name_geom_1 = physics.model.id2name(id_geom_1, "geom")
- name_geom_2 = physics.model.id2name(id_geom_2, "geom")
- contact_pair = (name_geom_1, name_geom_2)
- all_contact_pairs.append(contact_pair)
-
- touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
- touch_table = ("red_box", "table") in all_contact_pairs
-
- reward = 0
- if touch_right_gripper:
- reward = 1
- if touch_right_gripper and not touch_table: # lifted
- reward = 2
- if touch_left_gripper: # attempted transfer
- reward = 3
- if touch_left_gripper and not touch_table: # successful transfer
- reward = 4
- return reward
-
-
-class InsertionEndEffectorTask(BimanualViperXEndEffectorTask):
- def __init__(self, random=None):
- super().__init__(random=random)
- self.max_reward = 4
-
- def initialize_episode(self, physics):
- """Sets the state of the environment at the start of each episode."""
- self.initialize_robots(physics)
- # randomize peg and socket position
- peg_pose, socket_pose = sample_insertion_pose()
-
- def id2index(j_id):
- return 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky
-
- peg_start_id = physics.model.name2id("red_peg_joint", "joint")
- peg_start_idx = id2index(peg_start_id)
- np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)
- # print(f"randomized cube position to {cube_position}")
-
- socket_start_id = physics.model.name2id("blue_socket_joint", "joint")
- socket_start_idx = id2index(socket_start_id)
- np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)
- # print(f"randomized cube position to {cube_position}")
-
- super().initialize_episode(physics)
-
- @staticmethod
- def get_env_state(physics):
- env_state = physics.data.qpos.copy()[16:]
- return env_state
-
- def get_reward(self, physics):
- # return whether peg touches the pin
- all_contact_pairs = []
- for i_contact in range(physics.data.ncon):
- id_geom_1 = physics.data.contact[i_contact].geom1
- id_geom_2 = physics.data.contact[i_contact].geom2
- name_geom_1 = physics.model.id2name(id_geom_1, "geom")
- name_geom_2 = physics.model.id2name(id_geom_2, "geom")
- contact_pair = (name_geom_1, name_geom_2)
- all_contact_pairs.append(contact_pair)
-
- touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
- touch_left_gripper = (
- ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- or ("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- or ("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- or ("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
- )
-
- peg_touch_table = ("red_peg", "table") in all_contact_pairs
- socket_touch_table = (
- ("socket-1", "table") in all_contact_pairs
- or ("socket-2", "table") in all_contact_pairs
- or ("socket-3", "table") in all_contact_pairs
- or ("socket-4", "table") in all_contact_pairs
- )
- peg_touch_socket = (
- ("red_peg", "socket-1") in all_contact_pairs
- or ("red_peg", "socket-2") in all_contact_pairs
- or ("red_peg", "socket-3") in all_contact_pairs
- or ("red_peg", "socket-4") in all_contact_pairs
- )
- pin_touched = ("red_peg", "pin") in all_contact_pairs
-
- reward = 0
- if touch_left_gripper and touch_right_gripper: # touch both
- reward = 1
- if (
- touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table)
- ): # grasp both
- reward = 2
- if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
- reward = 3
- if pin_touched: # successful insertion
- reward = 4
- return reward
diff --git a/lerobot/common/envs/aloha/utils.py b/lerobot/common/envs/aloha/utils.py
deleted file mode 100644
index 5b7d8cfe..00000000
--- a/lerobot/common/envs/aloha/utils.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import numpy as np
-
-
-def sample_box_pose(seed=None):
- x_range = [0.0, 0.2]
- y_range = [0.4, 0.6]
- z_range = [0.05, 0.05]
-
- rng = np.random.RandomState(seed)
-
- ranges = np.vstack([x_range, y_range, z_range])
- cube_position = rng.uniform(ranges[:, 0], ranges[:, 1])
-
- cube_quat = np.array([1, 0, 0, 0])
- return np.concatenate([cube_position, cube_quat])
-
-
-def sample_insertion_pose(seed=None):
- # Peg
- x_range = [0.1, 0.2]
- y_range = [0.4, 0.6]
- z_range = [0.05, 0.05]
-
- rng = np.random.RandomState(seed)
-
- ranges = np.vstack([x_range, y_range, z_range])
- peg_position = rng.uniform(ranges[:, 0], ranges[:, 1])
-
- peg_quat = np.array([1, 0, 0, 0])
- peg_pose = np.concatenate([peg_position, peg_quat])
-
- # Socket
- x_range = [-0.2, -0.1]
- y_range = [0.4, 0.6]
- z_range = [0.05, 0.05]
-
- ranges = np.vstack([x_range, y_range, z_range])
- socket_position = rng.uniform(ranges[:, 0], ranges[:, 1])
-
- socket_quat = np.array([1, 0, 0, 0])
- socket_pose = np.concatenate([socket_position, socket_quat])
-
- return peg_pose, socket_pose
diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py
index 9225cbc5..c8d10851 100644
--- a/lerobot/common/envs/factory.py
+++ b/lerobot/common/envs/factory.py
@@ -1,3 +1,5 @@
+import importlib
+
import gymnasium as gym
@@ -8,43 +10,30 @@ def make_env(cfg, num_parallel_envs=0) -> gym.Env | gym.vector.SyncVectorEnv:
"""
kwargs = {
"obs_type": "pixels_agent_pos",
+ "render_mode": "rgb_array",
"max_episode_steps": cfg.env.episode_length,
"visualization_width": 384,
"visualization_height": 384,
}
- if cfg.env.name == "simxarm":
- import gym_xarm # noqa: F401
+ package_name = f"gym_{cfg.env.name}"
- assert cfg.env.task == "lift"
- env_fn = lambda: gym.make( # noqa: E731
- "gym_xarm/XarmLift-v0",
- **kwargs,
+ try:
+ importlib.import_module(package_name)
+ except ModuleNotFoundError as e:
+ print(
+ f"{package_name} is not installed. Please install it with `pip install 'lerobot[{cfg.env.name}]'`"
)
- elif cfg.env.name == "pusht":
- import gym_pusht # noqa: F401
+ raise e
- # assert kwargs["seed"] > 200, "Seed 0-200 are used for the demonstration dataset, so we don't want to seed the eval env with this range."
- env_fn = lambda: gym.make( # noqa: E731
- "gym_pusht/PushTPixels-v0",
- **kwargs,
- )
- elif cfg.env.name == "aloha":
- from lerobot.common.envs import aloha as gym_aloha # noqa: F401
-
- kwargs["task"] = cfg.env.task
-
- env_fn = lambda: gym.make( # noqa: E731
- "gym_aloha/AlohaInsertion-v0",
- **kwargs,
- )
- else:
- raise ValueError(cfg.env.name)
+ gym_handle = f"{package_name}/{cfg.env.task}"
if num_parallel_envs == 0:
# non-batched version of the env that returns an observation of shape (c)
- env = env_fn()
+ env = gym.make(gym_handle, **kwargs)
else:
# batched version of the env that returns an observation of shape (b, c)
- env = gym.vector.SyncVectorEnv([env_fn for _ in range(num_parallel_envs)])
+ env = gym.vector.SyncVectorEnv(
+ [lambda: gym.make(gym_handle, **kwargs) for _ in range(num_parallel_envs)]
+ )
return env
diff --git a/lerobot/configs/env/aloha.yaml b/lerobot/configs/env/aloha.yaml
index 2bfbbaa8..7a8d8b58 100644
--- a/lerobot/configs/env/aloha.yaml
+++ b/lerobot/configs/env/aloha.yaml
@@ -4,7 +4,7 @@ eval_episodes: 50
eval_freq: 7500
save_freq: 75000
log_freq: 250
-# TODO: same as simxarm, need to adjust
+# TODO: same as xarm, need to adjust
offline_steps: 25000
online_steps: 25000
@@ -14,7 +14,7 @@ dataset_id: aloha_sim_insertion_human
env:
name: aloha
- task: insertion
+ task: AlohaInsertion-v0
from_pixels: True
pixels_only: False
image_size: [3, 480, 640]
diff --git a/lerobot/configs/env/pusht.yaml b/lerobot/configs/env/pusht.yaml
index 0050530e..a5fbcc25 100644
--- a/lerobot/configs/env/pusht.yaml
+++ b/lerobot/configs/env/pusht.yaml
@@ -4,7 +4,7 @@ eval_episodes: 50
eval_freq: 7500
save_freq: 75000
log_freq: 250
-# TODO: same as simxarm, need to adjust
+# TODO: same as xarm, need to adjust
offline_steps: 25000
online_steps: 25000
@@ -14,7 +14,7 @@ dataset_id: pusht
env:
name: pusht
- task: pusht
+ task: PushT-v0
from_pixels: True
pixels_only: False
image_size: 96
diff --git a/lerobot/configs/env/simxarm.yaml b/lerobot/configs/env/xarm.yaml
similarity index 91%
rename from lerobot/configs/env/simxarm.yaml
rename to lerobot/configs/env/xarm.yaml
index 843f80c6..8b3c72ef 100644
--- a/lerobot/configs/env/simxarm.yaml
+++ b/lerobot/configs/env/xarm.yaml
@@ -12,8 +12,8 @@ fps: 15
dataset_id: xarm_lift_medium
env:
- name: simxarm
- task: lift
+ name: xarm
+ task: XarmLift-v0
from_pixels: True
pixels_only: False
image_size: 84
diff --git a/poetry.lock b/poetry.lock
index b9e31930..98449df4 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -879,10 +879,30 @@ files = [
[package.extras]
protobuf = ["grpcio-tools (>=1.62.1)"]
+[[package]]
+name = "gym-aloha"
+version = "0.1.0"
+description = "A gym environment for ALOHA"
+optional = true
+python-versions = "^3.10"
+files = []
+develop = false
+
+[package.dependencies]
+dm-control = "1.0.14"
+gymnasium = "^0.29.1"
+mujoco = "^2.3.7"
+
+[package.source]
+type = "git"
+url = "git@github.com:huggingface/gym-aloha.git"
+reference = "HEAD"
+resolved_reference = "ec7200831e36c14e343cf7d275c6b047f2fe9d11"
+
[[package]]
name = "gym-pusht"
version = "0.1.0"
-description = "PushT environment for LeRobot"
+description = "A gymnasium environment for PushT."
optional = true
python-versions = "^3.10"
files = []
@@ -900,7 +920,7 @@ shapely = "^2.0.3"
type = "git"
url = "git@github.com:huggingface/gym-pusht.git"
reference = "HEAD"
-resolved_reference = "0fe4449cca5a2b08f529f7a07fbf5b9df24962ec"
+resolved_reference = "6c9893504f670ff069d0f759a733e971ea1efdbf"
[[package]]
name = "gym-xarm"
@@ -920,7 +940,7 @@ mujoco = "^2.3.7"
type = "git"
url = "git@github.com:huggingface/gym-xarm.git"
reference = "HEAD"
-resolved_reference = "2eb83fc4fc871b9d271c946d169e42f226ac3a7c"
+resolved_reference = "08ddd5a9400783a6898bbf3c3014fc5da3961b9d"
[[package]]
name = "gymnasium"
@@ -3630,10 +3650,11 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
[extras]
-pusht = ["gym_pusht"]
-xarm = ["gym_xarm"]
+aloha = ["gym-aloha"]
+pusht = ["gym-pusht"]
+xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "c9524cdf000eaa755a2ab3be669118222b4f8b1c262013f103f6874cbd54eeb6"
+content-hash = "cb450ac7186e004536d75409edd42cd96062f7b1fd47822a5460d12eab8762f9"
diff --git a/pyproject.toml b/pyproject.toml
index b7e1b9fb..e78a502d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,14 +52,17 @@ robomimic = "0.2.0"
gymnasium-robotics = "^1.2.4"
gymnasium = "^0.29.1"
cmake = "^3.29.0.1"
-gym_pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
-gym_xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true}
-# gym_pusht = { path = "../gym-pusht", develop = true, optional = true}
-# gym_xarm = { path = "../gym-xarm", develop = true, optional = true}
+gym-pusht = { git = "git@github.com:huggingface/gym-pusht.git", optional = true}
+gym-xarm = { git = "git@github.com:huggingface/gym-xarm.git", optional = true}
+gym-aloha = { git = "git@github.com:huggingface/gym-aloha.git", optional = true}
+# gym-pusht = { path = "../gym-pusht", develop = true, optional = true}
+# gym-xarm = { path = "../gym-xarm", develop = true, optional = true}
+# gym-aloha = { path = "../gym-aloha", develop = true, optional = true}
[tool.poetry.extras]
-pusht = ["gym_pusht"]
-xarm = ["gym_xarm"]
+pusht = ["gym-pusht"]
+xarm = ["gym-xarm"]
+aloha = ["gym-aloha"]
[tool.poetry.group.dev.dependencies]
pre-commit = "^3.6.2"
diff --git a/tests/test_available.py b/tests/test_available.py
index 8a2ece38..8df2c945 100644
--- a/tests/test_available.py
+++ b/tests/test_available.py
@@ -19,7 +19,7 @@ import lerobot
# from gym_pusht.envs import PushtEnv
# from gym_xarm.envs import SimxarmEnv
-# from lerobot.common.datasets.simxarm import SimxarmDataset
+# from lerobot.common.datasets.xarm import SimxarmDataset
# from lerobot.common.datasets.aloha import AlohaDataset
# from lerobot.common.datasets.pusht import PushtDataset
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index e7777c16..e24d7b4d 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -11,7 +11,7 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
@pytest.mark.parametrize(
"env_name,dataset_id,policy_name",
[
- ("simxarm", "xarm_lift_medium", "tdmpc"),
+ ("xarm", "xarm_lift_medium", "tdmpc"),
("pusht", "pusht", "diffusion"),
("aloha", "aloha_sim_insertion_human", "act"),
("aloha", "aloha_sim_insertion_scripted", "act"),
diff --git a/tests/test_envs.py b/tests/test_envs.py
index effe4032..72bc93c4 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -1,3 +1,4 @@
+import importlib
import pytest
import torch
from lerobot.common.datasets.factory import make_dataset
@@ -13,49 +14,25 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
@pytest.mark.parametrize(
- "env_task, obs_type",
+ "env_name, task, obs_type",
[
# ("AlohaInsertion-v0", "state"),
- ("AlohaInsertion-v0", "pixels"),
- ("AlohaInsertion-v0", "pixels_agent_pos"),
- ("AlohaTransferCube-v0", "pixels"),
- ("AlohaTransferCube-v0", "pixels_agent_pos"),
+ ("aloha", "AlohaInsertion-v0", "pixels"),
+ ("aloha", "AlohaInsertion-v0", "pixels_agent_pos"),
+ ("aloha", "AlohaTransferCube-v0", "pixels"),
+ ("aloha", "AlohaTransferCube-v0", "pixels_agent_pos"),
+ ("xarm", "XarmLift-v0", "state"),
+ ("xarm", "XarmLift-v0", "pixels"),
+ ("xarm", "XarmLift-v0", "pixels_agent_pos"),
+ ("pusht", "PushT-v0", "state"),
+ ("pusht", "PushT-v0", "pixels"),
+ ("pusht", "PushT-v0", "pixels_agent_pos"),
],
)
-def test_aloha(env_task, obs_type):
- from lerobot.common.envs import aloha as gym_aloha # noqa: F401
- env = gym.make(f"gym_aloha/{env_task}", obs_type=obs_type)
- check_env(env.unwrapped)
-
-
-
-@pytest.mark.parametrize(
- "env_task, obs_type",
- [
- ("XarmLift-v0", "state"),
- ("XarmLift-v0", "pixels"),
- ("XarmLift-v0", "pixels_agent_pos"),
- # TODO(aliberts): Add gym_xarm other tasks
- ],
-)
-def test_xarm(env_task, obs_type):
- import gym_xarm # noqa: F401
- env = gym.make(f"gym_xarm/{env_task}", obs_type=obs_type)
- check_env(env.unwrapped)
-
-
-
-@pytest.mark.parametrize(
- "env_task, obs_type",
- [
- ("PushTPixels-v0", "state"),
- ("PushTPixels-v0", "pixels"),
- ("PushTPixels-v0", "pixels_agent_pos"),
- ],
-)
-def test_pusht(env_task, obs_type):
- import gym_pusht # noqa: F401
- env = gym.make(f"gym_pusht/{env_task}", obs_type=obs_type)
+def test_env(env_name, task, obs_type):
+ package_name = f"gym_{env_name}"
+ importlib.import_module(package_name)
+ env = gym.make(f"{package_name}/{task}", obs_type=obs_type)
check_env(env.unwrapped)
@@ -63,7 +40,7 @@ def test_pusht(env_task, obs_type):
"env_name",
[
"pusht",
- "simxarm",
+ "xarm",
"aloha",
],
)
@@ -76,7 +53,7 @@ def test_factory(env_name):
dataset = make_dataset(cfg)
env = make_env(cfg, num_parallel_envs=1)
- obs, info = env.reset()
+ obs, _ = env.reset()
obs = preprocess_observation(obs, transform=dataset.transform)
for key in dataset.image_keys:
img = obs[key]
diff --git a/tests/test_policies.py b/tests/test_policies.py
index c79bff94..82033b78 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -12,15 +12,15 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
@pytest.mark.parametrize(
"env_name,policy_name,extra_overrides",
[
- ("simxarm", "tdmpc", ["policy.mpc=true"]),
+ ("xarm", "tdmpc", ["policy.mpc=true"]),
("pusht", "tdmpc", ["policy.mpc=false"]),
("pusht", "diffusion", []),
# ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
#("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
#("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
#("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
- # TODO(aliberts): simxarm not working with diffusion
- # ("simxarm", "diffusion", []),
+ # TODO(aliberts): xarm not working with diffusion
+ # ("xarm", "diffusion", []),
],
)
def test_policy(env_name, policy_name, extra_overrides):