Merge branch 'user/adil-zouitine/2025-1-7-port-hil-serl-new' of github.com:huggingface/lerobot into user/adil-zouitine/2025-1-7-port-hil-serl-new
This commit is contained in:
commit
2449b3cca1
|
@ -94,6 +94,8 @@ python lerobot/scripts/find_motors_bus_port.py
|
||||||
|
|
||||||
#### b. Example outputs
|
#### b. Example outputs
|
||||||
|
|
||||||
|
#### b. Example outputs
|
||||||
|
|
||||||
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
|
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
|
||||||
```
|
```
|
||||||
Finding all available ports for the MotorBus.
|
Finding all available ports for the MotorBus.
|
||||||
|
@ -117,6 +119,8 @@ The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
||||||
Reconnect the usb cable.
|
Reconnect the usb cable.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### c. Troubleshooting
|
||||||
|
On Linux, you might need to give access to the USB ports by running:
|
||||||
#### c. Troubleshooting
|
#### c. Troubleshooting
|
||||||
On Linux, you might need to give access to the USB ports by running:
|
On Linux, you might need to give access to the USB ports by running:
|
||||||
```bash
|
```bash
|
||||||
|
@ -233,6 +237,7 @@ Follow the video for removing gears. You need to remove the gear for the motors
|
||||||
Follow the video for adding the motor horn. For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
|
Follow the video for adding the motor horn. For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
|
||||||
Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.
|
Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.
|
||||||
|
|
||||||
|
## D. Assemble the arms
|
||||||
## D. Assemble the arms
|
## D. Assemble the arms
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
|
@ -244,6 +249,7 @@ Try to avoid rotating the motor while doing so to keep position 2048 set during
|
||||||
|
|
||||||
Follow the video for assembling the arms. It is important to insert the cables into the motor that is being assembled before you assemble the motor into the arm! Inserting the cables beforehand is much easier than doing this afterward. The first arm should take a bit more than 1 hour to assemble, but once you get used to it, you can do it under 1 hour for the second arm.
|
Follow the video for assembling the arms. It is important to insert the cables into the motor that is being assembled before you assemble the motor into the arm! Inserting the cables beforehand is much easier than doing this afterward. The first arm should take a bit more than 1 hour to assemble, but once you get used to it, you can do it under 1 hour for the second arm.
|
||||||
|
|
||||||
|
## E. Calibrate
|
||||||
## E. Calibrate
|
## E. Calibrate
|
||||||
|
|
||||||
Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another.
|
Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another.
|
||||||
|
@ -268,6 +274,8 @@ python lerobot/scripts/control_robot.py \
|
||||||
--control.arms='["main_follower"]'
|
--control.arms='["main_follower"]'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### b. Manual calibration of leader arm
|
||||||
|
Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||||
#### b. Manual calibration of leader arm
|
#### b. Manual calibration of leader arm
|
||||||
Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||||
|
|
||||||
|
@ -284,6 +292,7 @@ python lerobot/scripts/control_robot.py \
|
||||||
--control.arms='["main_leader"]'
|
--control.arms='["main_leader"]'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## F. Teleoperate
|
||||||
## F. Teleoperate
|
## F. Teleoperate
|
||||||
|
|
||||||
**Simple teleop**
|
**Simple teleop**
|
||||||
|
@ -296,6 +305,7 @@ python lerobot/scripts/control_robot.py \
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
#### a. Teleop with displaying cameras
|
||||||
#### a. Teleop with displaying cameras
|
#### a. Teleop with displaying cameras
|
||||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||||
```bash
|
```bash
|
||||||
|
@ -304,6 +314,7 @@ python lerobot/scripts/control_robot.py \
|
||||||
--control.type=teleoperate
|
--control.type=teleoperate
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## G. Record a dataset
|
||||||
## G. Record a dataset
|
## G. Record a dataset
|
||||||
|
|
||||||
Once you're familiar with teleoperation, you can record your first dataset with SO-100.
|
Once you're familiar with teleoperation, you can record your first dataset with SO-100.
|
||||||
|
@ -337,6 +348,7 @@ python lerobot/scripts/control_robot.py \
|
||||||
|
|
||||||
Note: You can resume recording by adding `--control.resume=true`. Also if you didn't push your dataset yet, add `--control.local_files_only=true`.
|
Note: You can resume recording by adding `--control.resume=true`. Also if you didn't push your dataset yet, add `--control.local_files_only=true`.
|
||||||
|
|
||||||
|
## H. Visualize a dataset
|
||||||
## H. Visualize a dataset
|
## H. Visualize a dataset
|
||||||
|
|
||||||
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||||
|
@ -351,6 +363,7 @@ python lerobot/scripts/visualize_dataset_html.py \
|
||||||
--local-files-only 1
|
--local-files-only 1
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## I. Replay an episode
|
||||||
## I. Replay an episode
|
## I. Replay an episode
|
||||||
|
|
||||||
Now try to replay the first episode on your robot:
|
Now try to replay the first episode on your robot:
|
||||||
|
@ -365,6 +378,7 @@ python lerobot/scripts/control_robot.py \
|
||||||
|
|
||||||
Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
|
Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
|
||||||
|
|
||||||
|
## J. Train a policy
|
||||||
## J. Train a policy
|
## J. Train a policy
|
||||||
|
|
||||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||||
|
@ -388,6 +402,7 @@ Let's explain it:
|
||||||
|
|
||||||
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
|
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
|
||||||
|
|
||||||
|
## K. Evaluate your policy
|
||||||
## K. Evaluate your policy
|
## K. Evaluate your policy
|
||||||
|
|
||||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||||
|
@ -411,6 +426,7 @@ As you can see, it's almost the same command as previously used to record your t
|
||||||
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so100_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so100_test`).
|
1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so100_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so100_test`).
|
||||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so100_test`).
|
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so100_test`).
|
||||||
|
|
||||||
|
## L. More Information
|
||||||
## L. More Information
|
## L. More Information
|
||||||
|
|
||||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.
|
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.
|
||||||
|
|
|
@ -12,6 +12,10 @@ import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import (
|
||||||
|
LeRobotDataset,
|
||||||
|
LeRobotDatasetMetadata,
|
||||||
|
)
|
||||||
from lerobot.common.datasets.lerobot_dataset import (
|
from lerobot.common.datasets.lerobot_dataset import (
|
||||||
LeRobotDataset,
|
LeRobotDataset,
|
||||||
LeRobotDatasetMetadata,
|
LeRobotDatasetMetadata,
|
||||||
|
|
|
@ -61,6 +61,9 @@ class RandomSubsetApply(Transform):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"n_subset should be in the interval [1, {len(transforms)}]"
|
f"n_subset should be in the interval [1, {len(transforms)}]"
|
||||||
)
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"n_subset should be in the interval [1, {len(transforms)}]"
|
||||||
|
)
|
||||||
|
|
||||||
self.transforms = transforms
|
self.transforms = transforms
|
||||||
total = sum(p)
|
total = sum(p)
|
||||||
|
@ -124,6 +127,9 @@ class SharpnessJitter(Transform):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If sharpness is a single number, it must be non negative."
|
"If sharpness is a single number, it must be non negative."
|
||||||
)
|
)
|
||||||
|
raise ValueError(
|
||||||
|
"If sharpness is a single number, it must be non negative."
|
||||||
|
)
|
||||||
sharpness = [1.0 - sharpness, 1.0 + sharpness]
|
sharpness = [1.0 - sharpness, 1.0 + sharpness]
|
||||||
sharpness[0] = max(sharpness[0], 0.0)
|
sharpness[0] = max(sharpness[0], 0.0)
|
||||||
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
|
elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
|
||||||
|
@ -132,11 +138,17 @@ class SharpnessJitter(Transform):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"{sharpness=} should be a single number or a sequence with length 2."
|
f"{sharpness=} should be a single number or a sequence with length 2."
|
||||||
)
|
)
|
||||||
|
raise TypeError(
|
||||||
|
f"{sharpness=} should be a single number or a sequence with length 2."
|
||||||
|
)
|
||||||
|
|
||||||
if not 0.0 <= sharpness[0] <= sharpness[1]:
|
if not 0.0 <= sharpness[0] <= sharpness[1]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"sharpnesss values should be between (0., inf), but got {sharpness}."
|
f"sharpnesss values should be between (0., inf), but got {sharpness}."
|
||||||
)
|
)
|
||||||
|
raise ValueError(
|
||||||
|
f"sharpnesss values should be between (0., inf), but got {sharpness}."
|
||||||
|
)
|
||||||
|
|
||||||
return float(sharpness[0]), float(sharpness[1])
|
return float(sharpness[0]), float(sharpness[1])
|
||||||
|
|
||||||
|
|
|
@ -121,6 +121,10 @@ DATASETS = {
|
||||||
"single_task": "Pick up the candy and unwrap it.",
|
"single_task": "Pick up the candy and unwrap it.",
|
||||||
**ALOHA_STATIC_INFO,
|
**ALOHA_STATIC_INFO,
|
||||||
},
|
},
|
||||||
|
"aloha_static_candy": {
|
||||||
|
"single_task": "Pick up the candy and unwrap it.",
|
||||||
|
**ALOHA_STATIC_INFO,
|
||||||
|
},
|
||||||
"aloha_static_coffee": {
|
"aloha_static_coffee": {
|
||||||
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
|
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
|
||||||
**ALOHA_STATIC_INFO,
|
**ALOHA_STATIC_INFO,
|
||||||
|
@ -162,10 +166,12 @@ DATASETS = {
|
||||||
**ALOHA_STATIC_INFO,
|
**ALOHA_STATIC_INFO,
|
||||||
},
|
},
|
||||||
"aloha_static_vinh_cup": {
|
"aloha_static_vinh_cup": {
|
||||||
|
"single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
|
||||||
"single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
|
"single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
|
||||||
**ALOHA_STATIC_INFO,
|
**ALOHA_STATIC_INFO,
|
||||||
},
|
},
|
||||||
"aloha_static_vinh_cup_left": {
|
"aloha_static_vinh_cup_left": {
|
||||||
|
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
|
||||||
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
|
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
|
||||||
**ALOHA_STATIC_INFO,
|
**ALOHA_STATIC_INFO,
|
||||||
},
|
},
|
||||||
|
@ -177,6 +183,14 @@ DATASETS = {
|
||||||
"single_task": "Insert the peg into the socket.",
|
"single_task": "Insert the peg into the socket.",
|
||||||
**ALOHA_STATIC_INFO,
|
**ALOHA_STATIC_INFO,
|
||||||
},
|
},
|
||||||
|
"aloha_static_ziploc_slide": {
|
||||||
|
"single_task": "Slide open the ziploc bag.",
|
||||||
|
**ALOHA_STATIC_INFO,
|
||||||
|
},
|
||||||
|
"aloha_sim_insertion_scripted": {
|
||||||
|
"single_task": "Insert the peg into the socket.",
|
||||||
|
**ALOHA_STATIC_INFO,
|
||||||
|
},
|
||||||
"aloha_sim_insertion_scripted_image": {
|
"aloha_sim_insertion_scripted_image": {
|
||||||
"single_task": "Insert the peg into the socket.",
|
"single_task": "Insert the peg into the socket.",
|
||||||
**ALOHA_STATIC_INFO,
|
**ALOHA_STATIC_INFO,
|
||||||
|
@ -185,6 +199,10 @@ DATASETS = {
|
||||||
"single_task": "Insert the peg into the socket.",
|
"single_task": "Insert the peg into the socket.",
|
||||||
**ALOHA_STATIC_INFO,
|
**ALOHA_STATIC_INFO,
|
||||||
},
|
},
|
||||||
|
"aloha_sim_insertion_human": {
|
||||||
|
"single_task": "Insert the peg into the socket.",
|
||||||
|
**ALOHA_STATIC_INFO,
|
||||||
|
},
|
||||||
"aloha_sim_insertion_human_image": {
|
"aloha_sim_insertion_human_image": {
|
||||||
"single_task": "Insert the peg into the socket.",
|
"single_task": "Insert the peg into the socket.",
|
||||||
**ALOHA_STATIC_INFO,
|
**ALOHA_STATIC_INFO,
|
||||||
|
@ -213,11 +231,23 @@ DATASETS = {
|
||||||
"single_task": "Push the T-shaped block onto the T-shaped target.",
|
"single_task": "Push the T-shaped block onto the T-shaped target.",
|
||||||
**PUSHT_INFO,
|
**PUSHT_INFO,
|
||||||
},
|
},
|
||||||
|
"pusht": {
|
||||||
|
"single_task": "Push the T-shaped block onto the T-shaped target.",
|
||||||
|
**PUSHT_INFO,
|
||||||
|
},
|
||||||
|
"pusht_image": {
|
||||||
|
"single_task": "Push the T-shaped block onto the T-shaped target.",
|
||||||
|
**PUSHT_INFO,
|
||||||
|
},
|
||||||
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
|
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
|
||||||
"unitreeh1_rearrange_objects": {
|
"unitreeh1_rearrange_objects": {
|
||||||
"single_task": "Put the object into the bin.",
|
"single_task": "Put the object into the bin.",
|
||||||
**UNITREEH_INFO,
|
**UNITREEH_INFO,
|
||||||
},
|
},
|
||||||
|
"unitreeh1_rearrange_objects": {
|
||||||
|
"single_task": "Put the object into the bin.",
|
||||||
|
**UNITREEH_INFO,
|
||||||
|
},
|
||||||
"unitreeh1_two_robot_greeting": {
|
"unitreeh1_two_robot_greeting": {
|
||||||
"single_task": "Greet the other robot with a high five.",
|
"single_task": "Greet the other robot with a high five.",
|
||||||
**UNITREEH_INFO,
|
**UNITREEH_INFO,
|
||||||
|
@ -239,6 +269,18 @@ DATASETS = {
|
||||||
"single_task": "Pick up the cube and lift it.",
|
"single_task": "Pick up the cube and lift it.",
|
||||||
**XARM_INFO,
|
**XARM_INFO,
|
||||||
},
|
},
|
||||||
|
"xarm_lift_medium_image": {
|
||||||
|
"single_task": "Pick up the cube and lift it.",
|
||||||
|
**XARM_INFO,
|
||||||
|
},
|
||||||
|
"xarm_lift_medium_replay": {
|
||||||
|
"single_task": "Pick up the cube and lift it.",
|
||||||
|
**XARM_INFO,
|
||||||
|
},
|
||||||
|
"xarm_lift_medium_replay_image": {
|
||||||
|
"single_task": "Pick up the cube and lift it.",
|
||||||
|
**XARM_INFO,
|
||||||
|
},
|
||||||
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||||
"xarm_push_medium_image": {
|
"xarm_push_medium_image": {
|
||||||
"single_task": "Push the cube onto the target.",
|
"single_task": "Push the cube onto the target.",
|
||||||
|
@ -252,6 +294,18 @@ DATASETS = {
|
||||||
"single_task": "Push the cube onto the target.",
|
"single_task": "Push the cube onto the target.",
|
||||||
**XARM_INFO,
|
**XARM_INFO,
|
||||||
},
|
},
|
||||||
|
"xarm_push_medium_image": {
|
||||||
|
"single_task": "Push the cube onto the target.",
|
||||||
|
**XARM_INFO,
|
||||||
|
},
|
||||||
|
"xarm_push_medium_replay": {
|
||||||
|
"single_task": "Push the cube onto the target.",
|
||||||
|
**XARM_INFO,
|
||||||
|
},
|
||||||
|
"xarm_push_medium_replay_image": {
|
||||||
|
"single_task": "Push the cube onto the target.",
|
||||||
|
**XARM_INFO,
|
||||||
|
},
|
||||||
"umi_cup_in_the_wild": {
|
"umi_cup_in_the_wild": {
|
||||||
"single_task": "Put the cup on the plate.",
|
"single_task": "Put the cup on the plate.",
|
||||||
"license": "apache-2.0",
|
"license": "apache-2.0",
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import importlib
|
import importlib
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
import gymnasium as gym
|
import gymnasium as gym
|
||||||
|
|
||||||
|
@ -164,3 +165,99 @@ class ConvertToLeRobotEnv(gym.Wrapper):
|
||||||
ret["state"] = observation
|
ret["state"] = observation
|
||||||
ret["pixels"] = images
|
ret["pixels"] = images
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def make_maniskill_env(
|
||||||
|
cfg: DictConfig, n_envs: int | None = None
|
||||||
|
) -> gym.vector.VectorEnv | None:
|
||||||
|
"""Make ManiSkill3 gym environment"""
|
||||||
|
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||||
|
|
||||||
|
env = gym.make(
|
||||||
|
cfg.env.task,
|
||||||
|
obs_mode=cfg.env.obs,
|
||||||
|
control_mode=cfg.env.control_mode,
|
||||||
|
render_mode=cfg.env.render_mode,
|
||||||
|
sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
|
||||||
|
num_envs=n_envs,
|
||||||
|
)
|
||||||
|
# cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
|
||||||
|
env = ManiSkillVectorEnv(env, ignore_terminations=True)
|
||||||
|
# state should have the size of 25
|
||||||
|
# env = ConvertToLeRobotEnv(env, n_envs)
|
||||||
|
# env = PixelWrapper(cfg, env, n_envs)
|
||||||
|
env._max_episode_steps = env.max_episode_steps = (
|
||||||
|
50 # gym_utils.find_max_episode_steps_value(env)
|
||||||
|
)
|
||||||
|
env.unwrapped.metadata["render_fps"] = 20
|
||||||
|
|
||||||
|
return env
|
||||||
|
|
||||||
|
|
||||||
|
class PixelWrapper(gym.Wrapper):
|
||||||
|
"""
|
||||||
|
Wrapper for pixel observations. Works with Maniskill vectorized environments
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg, env, num_envs, num_frames=3):
|
||||||
|
super().__init__(env)
|
||||||
|
self.cfg = cfg
|
||||||
|
self.env = env
|
||||||
|
self.observation_space = gym.spaces.Box(
|
||||||
|
low=0,
|
||||||
|
high=255,
|
||||||
|
shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
|
||||||
|
dtype=np.uint8,
|
||||||
|
)
|
||||||
|
self._frames = deque([], maxlen=num_frames)
|
||||||
|
self._render_size = cfg.env.render_size
|
||||||
|
|
||||||
|
def _get_obs(self, obs):
|
||||||
|
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
|
||||||
|
self._frames.append(frame)
|
||||||
|
return {
|
||||||
|
"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(
|
||||||
|
self.env.device
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
def reset(self, seed):
|
||||||
|
obs, info = self.env.reset() # (seed=seed)
|
||||||
|
for _ in range(self._frames.maxlen):
|
||||||
|
obs_frames = self._get_obs(obs)
|
||||||
|
return obs_frames, info
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||||
|
return self._get_obs(obs), reward, terminated, truncated, info
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Remove this
|
||||||
|
class ConvertToLeRobotEnv(gym.Wrapper):
|
||||||
|
def __init__(self, env, num_envs):
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
|
def reset(self, seed=None, options=None):
|
||||||
|
obs, info = self.env.reset(seed=seed, options={})
|
||||||
|
return self._get_obs(obs), info
|
||||||
|
|
||||||
|
def step(self, action):
|
||||||
|
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||||
|
return self._get_obs(obs), reward, terminated, truncated, info
|
||||||
|
|
||||||
|
def _get_obs(self, observation):
|
||||||
|
sensor_data = observation.pop("sensor_data")
|
||||||
|
del observation["sensor_param"]
|
||||||
|
images = []
|
||||||
|
for cam_data in sensor_data.values():
|
||||||
|
images.append(cam_data["rgb"])
|
||||||
|
|
||||||
|
images = torch.concat(images, axis=-1)
|
||||||
|
# flatten the rest of the data which should just be state data
|
||||||
|
observation = common.flatten_state_dict(
|
||||||
|
observation, use_torch=True, device=self.base_env.device
|
||||||
|
)
|
||||||
|
ret = dict()
|
||||||
|
ret["state"] = observation
|
||||||
|
ret["pixels"] = images
|
||||||
|
return ret
|
||||||
|
|
|
@ -36,6 +36,12 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||||
# TODO: You have to merge all tensors from agent key and extra key
|
# TODO: You have to merge all tensors from agent key and extra key
|
||||||
# You don't keep sensor param key in the observation
|
# You don't keep sensor param key in the observation
|
||||||
# And you keep sensor data rgb
|
# And you keep sensor data rgb
|
||||||
|
for key, img in observations.items():
|
||||||
|
if "images" not in key:
|
||||||
|
continue
|
||||||
|
# TODO: You have to merge all tensors from agent key and extra key
|
||||||
|
# You don't keep sensor param key in the observation
|
||||||
|
# And you keep sensor data rgb
|
||||||
for key, img in observations.items():
|
for key, img in observations.items():
|
||||||
if "images" not in key:
|
if "images" not in key:
|
||||||
continue
|
continue
|
||||||
|
@ -50,7 +56,13 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||||
|
|
||||||
# sanity check that images are uint8
|
# sanity check that images are uint8
|
||||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||||
|
# sanity check that images are uint8
|
||||||
|
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||||
|
|
||||||
|
# convert to channel first of type float32 in range [0,1]
|
||||||
|
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||||
|
img = img.type(torch.float32)
|
||||||
|
img /= 255
|
||||||
# convert to channel first of type float32 in range [0,1]
|
# convert to channel first of type float32 in range [0,1]
|
||||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||||
img = img.type(torch.float32)
|
img = img.type(torch.float32)
|
||||||
|
@ -59,6 +71,9 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||||
return_observations[key] = img
|
return_observations[key] = img
|
||||||
# obs state agent qpos and qvel
|
# obs state agent qpos and qvel
|
||||||
# image
|
# image
|
||||||
|
return_observations[key] = img
|
||||||
|
# obs state agent qpos and qvel
|
||||||
|
# image
|
||||||
|
|
||||||
if "environment_state" in observations:
|
if "environment_state" in observations:
|
||||||
return_observations["observation.environment_state"] = torch.from_numpy(
|
return_observations["observation.environment_state"] = torch.from_numpy(
|
||||||
|
|
|
@ -12,6 +12,7 @@ from functools import cache
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from deepdiff import DeepDiff
|
from deepdiff import DeepDiff
|
||||||
|
@ -25,6 +26,9 @@ from lerobot.common.robot_devices.utils import busy_wait
|
||||||
from lerobot.common.utils.utils import get_safe_torch_device, has_method
|
from lerobot.common.utils.utils import get_safe_torch_device, has_method
|
||||||
|
|
||||||
|
|
||||||
|
def log_control_info(
|
||||||
|
robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None
|
||||||
|
):
|
||||||
def log_control_info(
|
def log_control_info(
|
||||||
robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None
|
robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None
|
||||||
):
|
):
|
||||||
|
@ -37,6 +41,7 @@ def log_control_info(
|
||||||
def log_dt(shortname, dt_val_s):
|
def log_dt(shortname, dt_val_s):
|
||||||
nonlocal log_items, fps
|
nonlocal log_items, fps
|
||||||
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
|
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
|
||||||
|
info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
|
||||||
if fps is not None:
|
if fps is not None:
|
||||||
actual_fps = 1 / dt_val_s
|
actual_fps = 1 / dt_val_s
|
||||||
if actual_fps < fps - 1:
|
if actual_fps < fps - 1:
|
||||||
|
@ -97,6 +102,9 @@ def predict_action(observation, policy, device, use_amp):
|
||||||
torch.autocast(device_type=device.type)
|
torch.autocast(device_type=device.type)
|
||||||
if device.type == "cuda" and use_amp
|
if device.type == "cuda" and use_amp
|
||||||
else nullcontext(),
|
else nullcontext(),
|
||||||
|
torch.autocast(device_type=device.type)
|
||||||
|
if device.type == "cuda" and use_amp
|
||||||
|
else nullcontext(),
|
||||||
):
|
):
|
||||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||||
for name in observation:
|
for name in observation:
|
||||||
|
@ -119,6 +127,16 @@ def predict_action(observation, policy, device, use_amp):
|
||||||
return action
|
return action
|
||||||
|
|
||||||
|
|
||||||
|
def init_keyboard_listener(assign_rewards=False):
|
||||||
|
"""
|
||||||
|
Initializes a keyboard listener to enable early termination of an episode
|
||||||
|
or environment reset by pressing the right arrow key ('->'). This may require
|
||||||
|
sudo permissions to allow the terminal to monitor keyboard events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
assign_rewards (bool): If True, allows annotating the collected trajectory
|
||||||
|
with a binary reward at the end of the episode to indicate success.
|
||||||
|
"""
|
||||||
def init_keyboard_listener(assign_rewards=False):
|
def init_keyboard_listener(assign_rewards=False):
|
||||||
"""
|
"""
|
||||||
Initializes a keyboard listener to enable early termination of an episode
|
Initializes a keyboard listener to enable early termination of an episode
|
||||||
|
@ -135,6 +153,8 @@ def init_keyboard_listener(assign_rewards=False):
|
||||||
events["stop_recording"] = False
|
events["stop_recording"] = False
|
||||||
if assign_rewards:
|
if assign_rewards:
|
||||||
events["next.reward"] = 0
|
events["next.reward"] = 0
|
||||||
|
if assign_rewards:
|
||||||
|
events["next.reward"] = 0
|
||||||
|
|
||||||
if is_headless():
|
if is_headless():
|
||||||
logging.warning(
|
logging.warning(
|
||||||
|
@ -152,6 +172,9 @@ def init_keyboard_listener(assign_rewards=False):
|
||||||
print("Right arrow key pressed. Exiting loop...")
|
print("Right arrow key pressed. Exiting loop...")
|
||||||
events["exit_early"] = True
|
events["exit_early"] = True
|
||||||
elif key == keyboard.Key.left:
|
elif key == keyboard.Key.left:
|
||||||
|
print(
|
||||||
|
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
|
||||||
|
)
|
||||||
print(
|
print(
|
||||||
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
|
"Left arrow key pressed. Exiting loop and rerecord the last episode..."
|
||||||
)
|
)
|
||||||
|
@ -168,6 +191,13 @@ def init_keyboard_listener(assign_rewards=False):
|
||||||
events["next.reward"],
|
events["next.reward"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif assign_rewards and key == keyboard.Key.space:
|
||||||
|
events["next.reward"] = 1 if events["next.reward"] == 0 else 0
|
||||||
|
print(
|
||||||
|
"Space key pressed. Assigning new reward to the subsequent frames. New reward:",
|
||||||
|
events["next.reward"],
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error handling key press: {e}")
|
print(f"Error handling key press: {e}")
|
||||||
|
|
||||||
|
@ -206,6 +236,7 @@ def record_episode(
|
||||||
use_amp,
|
use_amp,
|
||||||
fps,
|
fps,
|
||||||
record_delta_actions,
|
record_delta_actions,
|
||||||
|
record_delta_actions,
|
||||||
):
|
):
|
||||||
control_loop(
|
control_loop(
|
||||||
robot=robot,
|
robot=robot,
|
||||||
|
@ -218,6 +249,7 @@ def record_episode(
|
||||||
use_amp=use_amp,
|
use_amp=use_amp,
|
||||||
fps=fps,
|
fps=fps,
|
||||||
record_delta_actions=record_delta_actions,
|
record_delta_actions=record_delta_actions,
|
||||||
|
record_delta_actions=record_delta_actions,
|
||||||
teleoperate=policy is None,
|
teleoperate=policy is None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -261,10 +293,14 @@ def control_loop(
|
||||||
|
|
||||||
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
|
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||||
|
|
||||||
|
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||||
|
|
||||||
if teleoperate:
|
if teleoperate:
|
||||||
observation, action = robot.teleop_step(record_data=True)
|
observation, action = robot.teleop_step(record_data=True)
|
||||||
if record_delta_actions:
|
if record_delta_actions:
|
||||||
action["action"] = action["action"] - current_joint_positions
|
action["action"] = action["action"] - current_joint_positions
|
||||||
|
if record_delta_actions:
|
||||||
|
action["action"] = action["action"] - current_joint_positions
|
||||||
else:
|
else:
|
||||||
observation = robot.capture_observation()
|
observation = robot.capture_observation()
|
||||||
|
|
||||||
|
@ -277,6 +313,11 @@ def control_loop(
|
||||||
|
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
frame = {**observation, **action}
|
frame = {**observation, **action}
|
||||||
|
if "next.reward" in events:
|
||||||
|
frame["next.reward"] = events["next.reward"]
|
||||||
|
frame["next.done"] = (events["next.reward"] == 1) or (
|
||||||
|
events["exit_early"]
|
||||||
|
)
|
||||||
if "next.reward" in events:
|
if "next.reward" in events:
|
||||||
frame["next.reward"] = events["next.reward"]
|
frame["next.reward"] = events["next.reward"]
|
||||||
frame["next.done"] = (events["next.reward"] == 1) or (
|
frame["next.done"] = (events["next.reward"] == 1) or (
|
||||||
|
@ -287,12 +328,18 @@ def control_loop(
|
||||||
# if frame["next.done"]:
|
# if frame["next.done"]:
|
||||||
# break
|
# break
|
||||||
|
|
||||||
|
# if frame["next.done"]:
|
||||||
|
# break
|
||||||
|
|
||||||
if display_cameras and not is_headless():
|
if display_cameras and not is_headless():
|
||||||
image_keys = [key for key in observation if "image" in key]
|
image_keys = [key for key in observation if "image" in key]
|
||||||
for key in image_keys:
|
for key in image_keys:
|
||||||
cv2.imshow(
|
cv2.imshow(
|
||||||
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
|
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
|
||||||
)
|
)
|
||||||
|
cv2.imshow(
|
||||||
|
key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
|
||||||
|
)
|
||||||
cv2.waitKey(1)
|
cv2.waitKey(1)
|
||||||
|
|
||||||
if fps is not None:
|
if fps is not None:
|
||||||
|
@ -318,6 +365,8 @@ def reset_environment(robot, events, reset_time_s):
|
||||||
start_vencod_t = time.perf_counter()
|
start_vencod_t = time.perf_counter()
|
||||||
if "next.reward" in events:
|
if "next.reward" in events:
|
||||||
events["next.reward"] = 0
|
events["next.reward"] = 0
|
||||||
|
if "next.reward" in events:
|
||||||
|
events["next.reward"] = 0
|
||||||
|
|
||||||
# Wait if necessary
|
# Wait if necessary
|
||||||
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
||||||
|
@ -340,6 +389,16 @@ def reset_follower_position(robot: Robot, target_position):
|
||||||
busy_wait(0.015)
|
busy_wait(0.015)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_follower_position(robot: Robot, target_position):
|
||||||
|
current_position = robot.follower_arms["main"].read("Present_Position")
|
||||||
|
trajectory = torch.from_numpy(
|
||||||
|
np.linspace(current_position, target_position, 50)
|
||||||
|
) # NOTE: 30 is just an aribtrary number
|
||||||
|
for pose in trajectory:
|
||||||
|
robot.send_action(pose)
|
||||||
|
busy_wait(0.015)
|
||||||
|
|
||||||
|
|
||||||
def stop_recording(robot, listener, display_cameras):
|
def stop_recording(robot, listener, display_cameras):
|
||||||
robot.disconnect()
|
robot.disconnect()
|
||||||
|
|
||||||
|
@ -375,19 +434,32 @@ def sanity_check_dataset_robot_compatibility(
|
||||||
fps: int,
|
fps: int,
|
||||||
use_videos: bool,
|
use_videos: bool,
|
||||||
extra_features: dict = None,
|
extra_features: dict = None,
|
||||||
|
dataset: LeRobotDataset,
|
||||||
|
robot: Robot,
|
||||||
|
fps: int,
|
||||||
|
use_videos: bool,
|
||||||
|
extra_features: dict = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
features_from_robot = get_features_from_robot(robot, use_videos)
|
features_from_robot = get_features_from_robot(robot, use_videos)
|
||||||
if extra_features is not None:
|
if extra_features is not None:
|
||||||
features_from_robot.update(extra_features)
|
features_from_robot.update(extra_features)
|
||||||
|
|
||||||
|
features_from_robot = get_features_from_robot(robot, use_videos)
|
||||||
|
if extra_features is not None:
|
||||||
|
features_from_robot.update(extra_features)
|
||||||
|
|
||||||
fields = [
|
fields = [
|
||||||
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||||
("fps", dataset.fps, fps),
|
("fps", dataset.fps, fps),
|
||||||
("features", dataset.features, features_from_robot),
|
("features", dataset.features, features_from_robot),
|
||||||
|
("features", dataset.features, features_from_robot),
|
||||||
]
|
]
|
||||||
|
|
||||||
mismatches = []
|
mismatches = []
|
||||||
for field, dataset_value, present_value in fields:
|
for field, dataset_value, present_value in fields:
|
||||||
|
diff = DeepDiff(
|
||||||
|
dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]
|
||||||
|
)
|
||||||
diff = DeepDiff(
|
diff = DeepDiff(
|
||||||
dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]
|
dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]
|
||||||
)
|
)
|
||||||
|
@ -398,4 +470,6 @@ def sanity_check_dataset_robot_compatibility(
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Dataset metadata compatibility check failed with mismatches:\n"
|
"Dataset metadata compatibility check failed with mismatches:\n"
|
||||||
+ "\n".join(mismatches)
|
+ "\n".join(mismatches)
|
||||||
|
"Dataset metadata compatibility check failed with mismatches:\n"
|
||||||
|
+ "\n".join(mismatches)
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
_target_: lerobot.common.robot_devices.robots.manipulator.ManipulatorRobot
|
||||||
|
robot_type: koch
|
||||||
|
calibration_dir: .cache/calibration/koch
|
||||||
|
|
||||||
|
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||||
|
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||||
|
# the number of motors in your follower arms.
|
||||||
|
max_relative_target: null
|
||||||
|
|
||||||
|
leader_arms:
|
||||||
|
main:
|
||||||
|
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||||
|
port: /dev/tty.usbmodem58760430441
|
||||||
|
motors:
|
||||||
|
# name: (index, model)
|
||||||
|
shoulder_pan: [1, "xl330-m077"]
|
||||||
|
shoulder_lift: [2, "xl330-m077"]
|
||||||
|
elbow_flex: [3, "xl330-m077"]
|
||||||
|
wrist_flex: [4, "xl330-m077"]
|
||||||
|
wrist_roll: [5, "xl330-m077"]
|
||||||
|
gripper: [6, "xl330-m077"]
|
||||||
|
|
||||||
|
follower_arms:
|
||||||
|
main:
|
||||||
|
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||||
|
port: /dev/tty.usbmodem585A0083391
|
||||||
|
motors:
|
||||||
|
# name: (index, model)
|
||||||
|
shoulder_pan: [1, "xl430-w250"]
|
||||||
|
shoulder_lift: [2, "xl430-w250"]
|
||||||
|
elbow_flex: [3, "xl330-m288"]
|
||||||
|
wrist_flex: [4, "xl330-m288"]
|
||||||
|
wrist_roll: [5, "xl330-m288"]
|
||||||
|
gripper: [6, "xl330-m288"]
|
||||||
|
|
||||||
|
cameras:
|
||||||
|
laptop:
|
||||||
|
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||||
|
camera_index: 0
|
||||||
|
fps: 30
|
||||||
|
width: 640
|
||||||
|
height: 480
|
||||||
|
phone:
|
||||||
|
_target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
|
||||||
|
camera_index: 1
|
||||||
|
fps: 30
|
||||||
|
width: 640
|
||||||
|
height: 480
|
||||||
|
|
||||||
|
# ~ Koch specific settings ~
|
||||||
|
# Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
|
||||||
|
# to squeeze the gripper and have it spring back to an open position on its own.
|
||||||
|
gripper_open_degree: 35.156
|
|
@ -98,6 +98,9 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No motors detected. Please ensure you have one motor connected."
|
"No motors detected. Please ensure you have one motor connected."
|
||||||
)
|
)
|
||||||
|
raise ValueError(
|
||||||
|
"No motors detected. Please ensure you have one motor connected."
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Motor index found at: {motor_index}")
|
print(f"Motor index found at: {motor_index}")
|
||||||
|
|
||||||
|
@ -106,6 +109,9 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||||
motor_bus.write_with_motor_ids(
|
motor_bus.write_with_motor_ids(
|
||||||
motor_bus.motor_models, motor_index, "Lock", 0
|
motor_bus.motor_models, motor_index, "Lock", 0
|
||||||
)
|
)
|
||||||
|
motor_bus.write_with_motor_ids(
|
||||||
|
motor_bus.motor_models, motor_index, "Lock", 0
|
||||||
|
)
|
||||||
|
|
||||||
if baudrate != baudrate_des:
|
if baudrate != baudrate_des:
|
||||||
print(f"Setting its baudrate to {baudrate_des}")
|
print(f"Setting its baudrate to {baudrate_des}")
|
||||||
|
@ -115,6 +121,9 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||||
motor_bus.write_with_motor_ids(
|
motor_bus.write_with_motor_ids(
|
||||||
motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx
|
motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx
|
||||||
)
|
)
|
||||||
|
motor_bus.write_with_motor_ids(
|
||||||
|
motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx
|
||||||
|
)
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
motor_bus.set_bus_baudrate(baudrate_des)
|
motor_bus.set_bus_baudrate(baudrate_des)
|
||||||
present_baudrate_idx = motor_bus.read_with_motor_ids(
|
present_baudrate_idx = motor_bus.read_with_motor_ids(
|
||||||
|
@ -129,6 +138,9 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||||
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
|
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0)
|
||||||
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des)
|
motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des)
|
||||||
|
|
||||||
|
present_idx = motor_bus.read_with_motor_ids(
|
||||||
|
motor_bus.motor_models, motor_idx_des, "ID", num_retry=2
|
||||||
|
)
|
||||||
present_idx = motor_bus.read_with_motor_ids(
|
present_idx = motor_bus.read_with_motor_ids(
|
||||||
motor_bus.motor_models, motor_idx_des, "ID", num_retry=2
|
motor_bus.motor_models, motor_idx_des, "ID", num_retry=2
|
||||||
)
|
)
|
||||||
|
@ -178,6 +190,28 @@ if __name__ == "__main__":
|
||||||
help="Desired ID of the current motor (e.g. 1,2,3)",
|
help="Desired ID of the current motor (e.g. 1,2,3)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
"--port",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Motors bus port (e.g. dynamixel,feetech)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ID",
|
||||||
|
type=int,
|
||||||
|
required=True,
|
||||||
|
help="Desired ID of the current motor (e.g. 1,2,3)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--baudrate",
|
||||||
|
type=int,
|
||||||
|
default=1000000,
|
||||||
|
help="Desired baudrate for the motor (default: 1000000)",
|
||||||
"--baudrate",
|
"--baudrate",
|
||||||
type=int,
|
type=int,
|
||||||
default=1000000,
|
default=1000000,
|
||||||
|
|
|
@ -137,6 +137,7 @@ from lerobot.common.robot_devices.control_utils import (
|
||||||
record_episode,
|
record_episode,
|
||||||
reset_environment,
|
reset_environment,
|
||||||
reset_follower_position,
|
reset_follower_position,
|
||||||
|
reset_follower_position,
|
||||||
sanity_check_dataset_name,
|
sanity_check_dataset_name,
|
||||||
sanity_check_dataset_robot_compatibility,
|
sanity_check_dataset_robot_compatibility,
|
||||||
stop_recording,
|
stop_recording,
|
||||||
|
@ -244,10 +245,19 @@ def record(
|
||||||
# Load pretrained policy
|
# Load pretrained policy
|
||||||
policy = None if cfg.policy is None else make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)
|
policy = None if cfg.policy is None else make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)
|
||||||
|
|
||||||
|
# Load pretrained policy
|
||||||
|
policy = None if cfg.policy is None else make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)
|
||||||
|
|
||||||
|
# Load pretrained policy
|
||||||
|
policy = None if cfg.policy is None else make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)
|
||||||
|
|
||||||
if not robot.is_connected:
|
if not robot.is_connected:
|
||||||
robot.connect()
|
robot.connect()
|
||||||
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
|
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
|
||||||
|
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
|
||||||
|
|
||||||
|
if reset_follower:
|
||||||
|
initial_position = robot.follower_arms["main"].read("Present_Position")
|
||||||
if reset_follower:
|
if reset_follower:
|
||||||
initial_position = robot.follower_arms["main"].read("Present_Position")
|
initial_position = robot.follower_arms["main"].read("Present_Position")
|
||||||
|
|
||||||
|
@ -335,10 +345,13 @@ def replay(
|
||||||
|
|
||||||
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
log_say("Replaying episode", cfg.play_sounds, blocking=True)
|
||||||
for idx in range(dataset.num_frames):
|
for idx in range(dataset.num_frames):
|
||||||
|
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||||
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
|
current_joint_positions = robot.follower_arms["main"].read("Present_Position")
|
||||||
start_episode_t = time.perf_counter()
|
start_episode_t = time.perf_counter()
|
||||||
|
|
||||||
action = actions[idx]["action"]
|
action = actions[idx]["action"]
|
||||||
|
if replay_delta_actions:
|
||||||
|
action = action + current_joint_positions
|
||||||
if replay_delta_actions:
|
if replay_delta_actions:
|
||||||
action = action + current_joint_positions
|
action = action + current_joint_positions
|
||||||
robot.send_action(action)
|
robot.send_action(action)
|
||||||
|
|
|
@ -76,6 +76,9 @@ def make_optimizer_and_scheduler(cfg, policy):
|
||||||
optimizer_params_dicts,
|
optimizer_params_dicts,
|
||||||
lr=cfg.training.lr,
|
lr=cfg.training.lr,
|
||||||
weight_decay=cfg.training.weight_decay,
|
weight_decay=cfg.training.weight_decay,
|
||||||
|
optimizer_params_dicts,
|
||||||
|
lr=cfg.training.lr,
|
||||||
|
weight_decay=cfg.training.weight_decay,
|
||||||
)
|
)
|
||||||
lr_scheduler = None
|
lr_scheduler = None
|
||||||
elif cfg.policy.name == "diffusion":
|
elif cfg.policy.name == "diffusion":
|
||||||
|
@ -98,6 +101,23 @@ def make_optimizer_and_scheduler(cfg, policy):
|
||||||
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
|
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
|
||||||
lr_scheduler = None
|
lr_scheduler = None
|
||||||
|
|
||||||
|
elif policy.name == "sac":
|
||||||
|
optimizer = torch.optim.Adam(
|
||||||
|
[
|
||||||
|
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
|
||||||
|
{
|
||||||
|
"params": policy.critic_ensemble.parameters(),
|
||||||
|
"lr": policy.config.critic_lr,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": policy.temperature.parameters(),
|
||||||
|
"lr": policy.config.temperature_lr,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
lr_scheduler = None
|
||||||
|
|
||||||
|
|
||||||
elif policy.name == "sac":
|
elif policy.name == "sac":
|
||||||
optimizer = torch.optim.Adam(
|
optimizer = torch.optim.Adam(
|
||||||
[
|
[
|
||||||
|
@ -119,6 +139,10 @@ def make_optimizer_and_scheduler(cfg, policy):
|
||||||
VQBeTOptimizer,
|
VQBeTOptimizer,
|
||||||
VQBeTScheduler,
|
VQBeTScheduler,
|
||||||
)
|
)
|
||||||
|
from lerobot.common.policies.vqbet.modeling_vqbet import (
|
||||||
|
VQBeTOptimizer,
|
||||||
|
VQBeTScheduler,
|
||||||
|
)
|
||||||
|
|
||||||
optimizer = VQBeTOptimizer(policy, cfg)
|
optimizer = VQBeTOptimizer(policy, cfg)
|
||||||
lr_scheduler = VQBeTScheduler(optimizer, cfg)
|
lr_scheduler = VQBeTScheduler(optimizer, cfg)
|
||||||
|
@ -226,6 +250,9 @@ def train(cfg: TrainPipelineConfig):
|
||||||
if cfg.resume:
|
if cfg.resume:
|
||||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||||
|
|
||||||
|
num_learnable_params = sum(
|
||||||
|
p.numel() for p in policy.parameters() if p.requires_grad
|
||||||
|
)
|
||||||
num_learnable_params = sum(
|
num_learnable_params = sum(
|
||||||
p.numel() for p in policy.parameters() if p.requires_grad
|
p.numel() for p in policy.parameters() if p.requires_grad
|
||||||
)
|
)
|
||||||
|
|
|
@ -115,6 +115,8 @@ exclude = [
|
||||||
"venv",
|
"venv",
|
||||||
"*_pb2.py",
|
"*_pb2.py",
|
||||||
"*_pb2_grpc.py",
|
"*_pb2_grpc.py",
|
||||||
|
"*_pb2.py",
|
||||||
|
"*_pb2_grpc.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -335,6 +335,12 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
||||||
)
|
)
|
||||||
dataset = record(robot, rec_cfg)
|
dataset = record(robot, rec_cfg)
|
||||||
|
|
||||||
|
assert not mock_events[
|
||||||
|
"rerecord_episode"
|
||||||
|
], "`rerecord_episode` wasn't properly reset to False"
|
||||||
|
assert not mock_events[
|
||||||
|
"exit_early"
|
||||||
|
], "`exit_early` wasn't properly reset to False"
|
||||||
assert not mock_events[
|
assert not mock_events[
|
||||||
"rerecord_episode"
|
"rerecord_episode"
|
||||||
], "`rerecord_episode` wasn't properly reset to False"
|
], "`rerecord_episode` wasn't properly reset to False"
|
||||||
|
@ -398,6 +404,8 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"robot_type, mock, num_image_writer_processes",
|
"robot_type, mock, num_image_writer_processes",
|
||||||
[("koch", True, 0), ("koch", True, 1)],
|
[("koch", True, 0), ("koch", True, 1)],
|
||||||
|
"robot_type, mock, num_image_writer_processes",
|
||||||
|
[("koch", True, 0), ("koch", True, 1)],
|
||||||
)
|
)
|
||||||
@require_robot
|
@require_robot
|
||||||
def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes):
|
def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes):
|
||||||
|
|
Loading…
Reference in New Issue