Add reset-time-s, Add keyboard early exit, Add comments

This commit is contained in:
Remi Cadene 2024-07-12 12:58:56 +02:00
parent 1993d29296
commit 7a659dbd6b
3 changed files with 244 additions and 18 deletions

View File

@ -39,16 +39,24 @@ python lerobot/scripts/control_robot.py replay_episode \
--episode 0 --episode 0
``` ```
- Record a full dataset in order to train a policy: - Record a full dataset in order to train a policy, with 2 seconds of warmup,
30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes:
```bash ```bash
python lerobot/scripts/control_robot.py record_dataset \ python lerobot/scripts/control_robot.py record_dataset \
--fps 30 \ --fps 30 \
--root data \ --root data \
--repo-id $USER/koch_pick_place_lego \ --repo-id $USER/koch_pick_place_lego \
--num-episodes 50 \ --num-episodes 50 \
--run-compute-stats 1 --run-compute-stats 1 \
--warmup-time-s 2 \
--episode-time-s 30 \
--reset-time-s 10
``` ```
**NOTE**: You can early exit while recording an episode or resetting the environment,
by tapping the right arrow key '->'. This might require a sudo permission
to allow your terminal to monitor keyboard events.
- Train on this dataset with the ACT policy: - Train on this dataset with the ACT policy:
```bash ```bash
DATA_DIR=data python lerobot/scripts/train.py \ DATA_DIR=data python lerobot/scripts/train.py \
@ -77,6 +85,7 @@ from pathlib import Path
import torch import torch
from omegaconf import DictConfig from omegaconf import DictConfig
from PIL import Image from PIL import Image
from pynput import keyboard
from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@ -187,6 +196,7 @@ def record_dataset(
repo_id="lerobot/debug", repo_id="lerobot/debug",
warmup_time_s=2, warmup_time_s=2,
episode_time_s=10, episode_time_s=10,
reset_time_s=5,
num_episodes=50, num_episodes=50,
video=True, video=True,
run_compute_stats=True, run_compute_stats=True,
@ -228,6 +238,20 @@ def record_dataset(
timestamp = time.perf_counter() - start_time timestamp = time.perf_counter() - start_time
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
exit_early = False
def on_press(key):
nonlocal exit_early
if key == keyboard.Key.right:
print("Right arrow key pressed. Exiting loop...")
exit_early = True
listener = keyboard.Listener(on_press=on_press)
listener.start()
# Save images using threads to reach high fps (30 and more) # Save images using threads to reach high fps (30 and more)
# Using `with` to exist smoothly if an execption is raised. # Using `with` to exist smoothly if an execption is raised.
# Using only 4 worker threads to avoid blocking the main thread. # Using only 4 worker threads to avoid blocking the main thread.
@ -235,17 +259,13 @@ def record_dataset(
# Start recording all episodes # Start recording all episodes
ep_dicts = [] ep_dicts = []
for episode_index in range(num_episodes): for episode_index in range(num_episodes):
logging.info(f"Recording episode {episode_index}")
os.system(f'say "Recording episode {episode_index}" &')
ep_dict = {} ep_dict = {}
frame_index = 0 frame_index = 0
timestamp = 0 timestamp = 0
start_time = time.perf_counter() start_time = time.perf_counter()
is_record_print = False
while timestamp < episode_time_s: while timestamp < episode_time_s:
if not is_record_print:
logging.info(f"Recording episode {episode_index}")
os.system(f'say "Recording episode {episode_index}" &')
is_record_print = True
now = time.perf_counter() now = time.perf_counter()
observation, action = robot.teleop_step(record_data=True) observation, action = robot.teleop_step(record_data=True)
@ -275,6 +295,26 @@ def record_dataset(
timestamp = time.perf_counter() - start_time timestamp = time.perf_counter() - start_time
if exit_early:
exit_early = False
break
# Skip resetting if 0 second allocated or it is the last episode
if reset_time_s == 0 or episode_index == num_episodes - 1:
continue
logging.info("Resetting environment")
os.system('say "Resetting environment" &')
timestamp = 0
start_time = time.perf_counter()
while timestamp < reset_time_s:
time.sleep(1)
timestamp = time.perf_counter() - start_time
if exit_early:
exit_early = False
break
num_frames = frame_index num_frames = frame_index
for key in image_keys: for key in image_keys:
@ -454,20 +494,61 @@ if __name__ == "__main__":
parser_record.add_argument( parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
) )
parser_record.add_argument("--root", type=Path, default="data", help="") parser_record.add_argument(
parser_record.add_argument("--repo-id", type=str, default="lerobot/test", help="") "--root",
parser_record.add_argument("--warmup-time-s", type=int, default=2, help="") type=Path,
parser_record.add_argument("--episode-time-s", type=int, default=10, help="") default="data",
parser_record.add_argument("--num-episodes", type=int, default=50, help="") help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
parser_record.add_argument("--run-compute-stats", type=int, default=1, help="") )
parser_record.add_argument(
"--repo-id",
type=str,
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_record.add_argument(
"--warmup-time-s",
type=int,
default=2,
help="Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.",
)
parser_record.add_argument(
"--episode-time-s",
type=int,
default=10,
help="Number of seconds for data recording for each episode.",
)
parser_record.add_argument(
"--reset-time-s",
type=int,
default=5,
help="Number of seconds for resetting the environment after each episode.",
)
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
parser_record.add_argument(
"--run-compute-stats",
type=int,
default=1,
help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
)
parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser]) parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser])
parser_replay.add_argument( parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
) )
parser_replay.add_argument("--root", type=Path, default="data", help="") parser_replay.add_argument(
parser_replay.add_argument("--repo-id", type=str, default="lerobot/test", help="") "--root",
parser_replay.add_argument("--episode", type=int, default=0, help="") type=Path,
default="data",
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
)
parser_replay.add_argument(
"--repo-id",
type=str,
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
parser_policy = subparsers.add_parser("run_policy", parents=[base_parser]) parser_policy = subparsers.add_parser("run_policy", parents=[base_parser])
parser_policy.add_argument( parser_policy.add_argument(

146
poetry.lock generated
View File

@ -831,6 +831,16 @@ files = [
{file = "einops-0.8.0.tar.gz", hash = "sha256:63486517fed345712a8385c100cb279108d9d47e6ae59099b07657e983deae85"}, {file = "einops-0.8.0.tar.gz", hash = "sha256:63486517fed345712a8385c100cb279108d9d47e6ae59099b07657e983deae85"},
] ]
[[package]]
name = "evdev"
version = "1.7.1"
description = "Bindings to the Linux input handling subsystem"
optional = false
python-versions = ">=3.6"
files = [
{file = "evdev-1.7.1.tar.gz", hash = "sha256:0c72c370bda29d857e188d931019c32651a9c1ea977c08c8d939b1ced1637fde"},
]
[[package]] [[package]]
name = "exceptiongroup" name = "exceptiongroup"
version = "1.2.1" version = "1.2.1"
@ -3000,6 +3010,126 @@ cffi = ">=1.15.0"
[package.extras] [package.extras]
dev = ["aafigure", "matplotlib", "numpy", "pygame", "pyglet (<2.0.0)", "sphinx", "wheel"] dev = ["aafigure", "matplotlib", "numpy", "pygame", "pyglet (<2.0.0)", "sphinx", "wheel"]
[[package]]
name = "pynput"
version = "1.7.7"
description = "Monitor and control user input devices"
optional = false
python-versions = "*"
files = [
{file = "pynput-1.7.7-py2.py3-none-any.whl", hash = "sha256:afc43f651684c98818de048abc76adf9f2d3d797083cb07c1f82be764a2d44cb"},
]
[package.dependencies]
evdev = {version = ">=1.3", markers = "sys_platform in \"linux\""}
pyobjc-framework-ApplicationServices = {version = ">=8.0", markers = "sys_platform == \"darwin\""}
pyobjc-framework-Quartz = {version = ">=8.0", markers = "sys_platform == \"darwin\""}
python-xlib = {version = ">=0.17", markers = "sys_platform in \"linux\""}
six = "*"
[[package]]
name = "pyobjc-core"
version = "10.3.1"
description = "Python<->ObjC Interoperability Module"
optional = false
python-versions = ">=3.8"
files = [
{file = "pyobjc_core-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ea46d2cda17921e417085ac6286d43ae448113158afcf39e0abe484c58fb3d78"},
{file = "pyobjc_core-10.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:899d3c84d2933d292c808f385dc881a140cf08632907845043a333a9d7c899f9"},
{file = "pyobjc_core-10.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:6ff5823d13d0a534cdc17fa4ad47cf5bee4846ce0fd27fc40012e12b46db571b"},
{file = "pyobjc_core-10.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:2581e8e68885bcb0e11ec619e81ef28e08ee3fac4de20d8cc83bc5af5bcf4a90"},
{file = "pyobjc_core-10.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ea98d4c2ec39ca29e62e0327db21418696161fb138ee6278daf2acbedf7ce504"},
{file = "pyobjc_core-10.3.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:4c179c26ee2123d0aabffb9dbc60324b62b6f8614fb2c2328b09386ef59ef6d8"},
{file = "pyobjc_core-10.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cb901fce65c9be420c40d8a6ee6fff5ff27c6945f44fd7191989b982baa66dea"},
{file = "pyobjc_core-10.3.1.tar.gz", hash = "sha256:b204a80ccc070f9ab3f8af423a3a25a6fd787e228508d00c4c30f8ac538ba720"},
]
[[package]]
name = "pyobjc-framework-applicationservices"
version = "10.3.1"
description = "Wrappers for the framework ApplicationServices on macOS"
optional = false
python-versions = ">=3.8"
files = [
{file = "pyobjc_framework_ApplicationServices-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b694260d423c470cb90c3a7009cfde93e332ea6fb4b9b9526ad3acbd33460e3d"},
{file = "pyobjc_framework_ApplicationServices-10.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d886ba1f65df47b77ff7546f3fc9bc7d08cfb6b3c04433b719f6b0689a2c0d1f"},
{file = "pyobjc_framework_ApplicationServices-10.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:be157f2c3ffb254064ef38249670af8cada5e519a714d2aa5da3740934d89bc8"},
{file = "pyobjc_framework_ApplicationServices-10.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:57737f41731661e4a3b78793ec9173f61242a32fa560c3e4e58484465d049c32"},
{file = "pyobjc_framework_ApplicationServices-10.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c429eca69ee675e781e4e55f79e939196b47f02560ad865b1ba9ac753b90bd77"},
{file = "pyobjc_framework_ApplicationServices-10.3.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:4f1814a17041a20adca454044080b52e39a4ebc567ad2c6a48866dd4beaa192a"},
{file = "pyobjc_framework_ApplicationServices-10.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1252f1137f83eb2c6b9968d8c591363e8859dd2484bc9441d8f365bcfb43a0e4"},
{file = "pyobjc_framework_applicationservices-10.3.1.tar.gz", hash = "sha256:f27cb64aa4d129ce671fd42638c985eb2a56d544214a95fe3214a007eacc4790"},
]
[package.dependencies]
pyobjc-core = ">=10.3.1"
pyobjc-framework-Cocoa = ">=10.3.1"
pyobjc-framework-CoreText = ">=10.3.1"
pyobjc-framework-Quartz = ">=10.3.1"
[[package]]
name = "pyobjc-framework-cocoa"
version = "10.3.1"
description = "Wrappers for the Cocoa frameworks on macOS"
optional = false
python-versions = ">=3.8"
files = [
{file = "pyobjc_framework_Cocoa-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4cb4f8491ab4d9b59f5187e42383f819f7a46306a4fa25b84f126776305291d1"},
{file = "pyobjc_framework_Cocoa-10.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5f31021f4f8fdf873b57a97ee1f3c1620dbe285e0b4eaed73dd0005eb72fd773"},
{file = "pyobjc_framework_Cocoa-10.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:11b4e0bad4bbb44a4edda128612f03cdeab38644bbf174de0c13129715497296"},
{file = "pyobjc_framework_Cocoa-10.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:de5e62e5ccf2871a94acf3bf79646b20ea893cc9db78afa8d1fe1b0d0f7cbdb0"},
{file = "pyobjc_framework_Cocoa-10.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c5af24610ab639bd1f521ce4500484b40787f898f691b7a23da3339e6bc8b90"},
{file = "pyobjc_framework_Cocoa-10.3.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:a7151186bb7805deea434fae9a4423335e6371d105f29e73cc2036c6779a9dbc"},
{file = "pyobjc_framework_Cocoa-10.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:743d2a1ac08027fd09eab65814c79002a1d0421d7c0074ffd1217b6560889744"},
{file = "pyobjc_framework_cocoa-10.3.1.tar.gz", hash = "sha256:1cf20714daaa986b488fb62d69713049f635c9d41a60c8da97d835710445281a"},
]
[package.dependencies]
pyobjc-core = ">=10.3.1"
[[package]]
name = "pyobjc-framework-coretext"
version = "10.3.1"
description = "Wrappers for the framework CoreText on macOS"
optional = false
python-versions = ">=3.8"
files = [
{file = "pyobjc_framework_CoreText-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:dd6123cfccc38e32be884d1a13fb62bd636ecb192b9e8ae2b8011c977dec229e"},
{file = "pyobjc_framework_CoreText-10.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:834142a14235bd80edaef8d3a28d1e203ed3c988810a9b78005df7c561390288"},
{file = "pyobjc_framework_CoreText-10.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ae6c09d29eeaf30a67aa70e08a465b1f1e47d12e22b3a34ae8bc8fdb7e2e7342"},
{file = "pyobjc_framework_CoreText-10.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:51ca95df1db9401366f11a7467f64be57f9a0630d31c357237d4062df0216938"},
{file = "pyobjc_framework_CoreText-10.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8b75bdc267945b3f33c937c108d79405baf9d7c4cd530f922e5df243082a5031"},
{file = "pyobjc_framework_CoreText-10.3.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:029b24c338f58fc32a004256d8559507e4f366dfe4eb09d3144273d536012d90"},
{file = "pyobjc_framework_CoreText-10.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:418a55047dbff999fcd2b78cca167c4105587020b6c51567cfa28993bbfdc8ed"},
{file = "pyobjc_framework_coretext-10.3.1.tar.gz", hash = "sha256:b8fa2d5078ed774431ae64ba886156e319aec0b8c6cc23dabfd86778265b416f"},
]
[package.dependencies]
pyobjc-core = ">=10.3.1"
pyobjc-framework-Cocoa = ">=10.3.1"
pyobjc-framework-Quartz = ">=10.3.1"
[[package]]
name = "pyobjc-framework-quartz"
version = "10.3.1"
description = "Wrappers for the Quartz frameworks on macOS"
optional = false
python-versions = ">=3.8"
files = [
{file = "pyobjc_framework_Quartz-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5ef4fd315ed2bc42ef77fdeb2bae28a88ec986bd7b8079a87ba3b3475348f96e"},
{file = "pyobjc_framework_Quartz-10.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:96578d4a3e70164efe44ad7dc320ecd4e211758ffcde5dcd694de1bbdfe090a4"},
{file = "pyobjc_framework_Quartz-10.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ca35f92486869a41847a1703bb176aab8a53dbfd8e678d1f4d68d8e6e1581c71"},
{file = "pyobjc_framework_Quartz-10.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:00a0933267e3a46ea4afcc35d117b2efb920f06de797fa66279c52e7057e3590"},
{file = "pyobjc_framework_Quartz-10.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a161bedb4c5257a02ad56a910cd7eefb28bdb0ea78607df0d70ed4efe4ea54c1"},
{file = "pyobjc_framework_Quartz-10.3.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:d7a8028e117a94923a511944bfa9daf9744e212f06cf89010c60934a479863a5"},
{file = "pyobjc_framework_Quartz-10.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:de00c983b3267eb26fa42c6ed9f15e2bf006bde8afa7fe2b390646aa21a5d6fc"},
{file = "pyobjc_framework_quartz-10.3.1.tar.gz", hash = "sha256:b6d7e346d735c9a7f147cd78e6da79eeae416a0b7d3874644c83a23786c6f886"},
]
[package.dependencies]
pyobjc-core = ">=10.3.1"
pyobjc-framework-Cocoa = ">=10.3.1"
[[package]] [[package]]
name = "pyopengl" name = "pyopengl"
version = "3.1.7" version = "3.1.7"
@ -3122,6 +3252,20 @@ files = [
[package.dependencies] [package.dependencies]
six = ">=1.5" six = ">=1.5"
[[package]]
name = "python-xlib"
version = "0.33"
description = "Python X Library"
optional = false
python-versions = "*"
files = [
{file = "python-xlib-0.33.tar.gz", hash = "sha256:55af7906a2c75ce6cb280a584776080602444f75815a7aff4d287bb2d7018b32"},
{file = "python_xlib-0.33-py2.py3-none-any.whl", hash = "sha256:c3534038d42e0df2f1392a1b30a15a4ff5fdc2b86cfa94f072bf11b10a164398"},
]
[package.dependencies]
six = ">=1.10.0"
[[package]] [[package]]
name = "pytz" name = "pytz"
version = "2024.1" version = "2024.1"
@ -4363,4 +4507,4 @@ xarm = ["gym-xarm"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.10,<3.13" python-versions = ">=3.10,<3.13"
content-hash = "223a6496a630da8181f119634f96bed3e0de3aaca714f1f1abd7edd562e3f1c6" content-hash = "6d82706d5216ce065ba5912ea9f802846dfdf7e7b1665f8560cc782fb0dfa354"

View File

@ -64,6 +64,7 @@ scikit-image = {version = "^0.23.2", optional = true}
pandas = {version = "^2.2.2", optional = true} pandas = {version = "^2.2.2", optional = true}
pytest-mock = {version = "^3.14.0", optional = true} pytest-mock = {version = "^3.14.0", optional = true}
dynamixel-sdk = {version = "^3.7.31", optional = true} dynamixel-sdk = {version = "^3.7.31", optional = true}
pynput = "^1.7.7"