diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index ca0226b7..c33a056f 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -39,16 +39,24 @@ python lerobot/scripts/control_robot.py replay_episode \ --episode 0 ``` -- Record a full dataset in order to train a policy: +- Record a full dataset in order to train a policy, with 2 seconds of warmup, +30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes: ```bash python lerobot/scripts/control_robot.py record_dataset \ --fps 30 \ --root data \ --repo-id $USER/koch_pick_place_lego \ --num-episodes 50 \ - --run-compute-stats 1 + --run-compute-stats 1 \ + --warmup-time-s 2 \ + --episode-time-s 30 \ + --reset-time-s 10 ``` +**NOTE**: You can early exit while recording an episode or resetting the environment, +by tapping the right arrow key '->'. This might require a sudo permission +to allow your terminal to monitor keyboard events. + - Train on this dataset with the ACT policy: ```bash DATA_DIR=data python lerobot/scripts/train.py \ @@ -77,6 +85,7 @@ from pathlib import Path import torch from omegaconf import DictConfig from PIL import Image +from pynput import keyboard from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.lerobot_dataset import LeRobotDataset @@ -187,6 +196,7 @@ def record_dataset( repo_id="lerobot/debug", warmup_time_s=2, episode_time_s=10, + reset_time_s=5, num_episodes=50, video=True, run_compute_stats=True, @@ -228,6 +238,20 @@ def record_dataset( timestamp = time.perf_counter() - start_time + # Allow to exit early while recording an episode or resetting the environment, + # by tapping the right arrow key '->'. This might require a sudo permission + # to allow your terminal to monitor keyboard events. + exit_early = False + + def on_press(key): + nonlocal exit_early + if key == keyboard.Key.right: + print("Right arrow key pressed. Exiting loop...") + exit_early = True + + listener = keyboard.Listener(on_press=on_press) + listener.start() + # Save images using threads to reach high fps (30 and more) # Using `with` to exist smoothly if an execption is raised. # Using only 4 worker threads to avoid blocking the main thread. @@ -235,17 +259,13 @@ def record_dataset( # Start recording all episodes ep_dicts = [] for episode_index in range(num_episodes): + logging.info(f"Recording episode {episode_index}") + os.system(f'say "Recording episode {episode_index}" &') ep_dict = {} frame_index = 0 timestamp = 0 start_time = time.perf_counter() - is_record_print = False while timestamp < episode_time_s: - if not is_record_print: - logging.info(f"Recording episode {episode_index}") - os.system(f'say "Recording episode {episode_index}" &') - is_record_print = True - now = time.perf_counter() observation, action = robot.teleop_step(record_data=True) @@ -275,6 +295,26 @@ def record_dataset( timestamp = time.perf_counter() - start_time + if exit_early: + exit_early = False + break + + # Skip resetting if 0 second allocated or it is the last episode + if reset_time_s == 0 or episode_index == num_episodes - 1: + continue + + logging.info("Resetting environment") + os.system('say "Resetting environment" &') + timestamp = 0 + start_time = time.perf_counter() + while timestamp < reset_time_s: + time.sleep(1) + timestamp = time.perf_counter() - start_time + + if exit_early: + exit_early = False + break + num_frames = frame_index for key in image_keys: @@ -454,20 +494,61 @@ if __name__ == "__main__": parser_record.add_argument( "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" ) - parser_record.add_argument("--root", type=Path, default="data", help="") - parser_record.add_argument("--repo-id", type=str, default="lerobot/test", help="") - parser_record.add_argument("--warmup-time-s", type=int, default=2, help="") - parser_record.add_argument("--episode-time-s", type=int, default=10, help="") - parser_record.add_argument("--num-episodes", type=int, default=50, help="") - parser_record.add_argument("--run-compute-stats", type=int, default=1, help="") + parser_record.add_argument( + "--root", + type=Path, + default="data", + help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').", + ) + parser_record.add_argument( + "--repo-id", + type=str, + default="lerobot/test", + help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).", + ) + parser_record.add_argument( + "--warmup-time-s", + type=int, + default=2, + help="Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.", + ) + parser_record.add_argument( + "--episode-time-s", + type=int, + default=10, + help="Number of seconds for data recording for each episode.", + ) + parser_record.add_argument( + "--reset-time-s", + type=int, + default=5, + help="Number of seconds for resetting the environment after each episode.", + ) + parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.") + parser_record.add_argument( + "--run-compute-stats", + type=int, + default=1, + help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.", + ) parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser]) parser_replay.add_argument( "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" ) - parser_replay.add_argument("--root", type=Path, default="data", help="") - parser_replay.add_argument("--repo-id", type=str, default="lerobot/test", help="") - parser_replay.add_argument("--episode", type=int, default=0, help="") + parser_replay.add_argument( + "--root", + type=Path, + default="data", + help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').", + ) + parser_replay.add_argument( + "--repo-id", + type=str, + default="lerobot/test", + help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).", + ) + parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.") parser_policy = subparsers.add_parser("run_policy", parents=[base_parser]) parser_policy.add_argument( diff --git a/poetry.lock b/poetry.lock index cc917d70..4afc150e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -831,6 +831,16 @@ files = [ {file = "einops-0.8.0.tar.gz", hash = "sha256:63486517fed345712a8385c100cb279108d9d47e6ae59099b07657e983deae85"}, ] +[[package]] +name = "evdev" +version = "1.7.1" +description = "Bindings to the Linux input handling subsystem" +optional = false +python-versions = ">=3.6" +files = [ + {file = "evdev-1.7.1.tar.gz", hash = "sha256:0c72c370bda29d857e188d931019c32651a9c1ea977c08c8d939b1ced1637fde"}, +] + [[package]] name = "exceptiongroup" version = "1.2.1" @@ -3000,6 +3010,126 @@ cffi = ">=1.15.0" [package.extras] dev = ["aafigure", "matplotlib", "numpy", "pygame", "pyglet (<2.0.0)", "sphinx", "wheel"] +[[package]] +name = "pynput" +version = "1.7.7" +description = "Monitor and control user input devices" +optional = false +python-versions = "*" +files = [ + {file = "pynput-1.7.7-py2.py3-none-any.whl", hash = "sha256:afc43f651684c98818de048abc76adf9f2d3d797083cb07c1f82be764a2d44cb"}, +] + +[package.dependencies] +evdev = {version = ">=1.3", markers = "sys_platform in \"linux\""} +pyobjc-framework-ApplicationServices = {version = ">=8.0", markers = "sys_platform == \"darwin\""} +pyobjc-framework-Quartz = {version = ">=8.0", markers = "sys_platform == \"darwin\""} +python-xlib = {version = ">=0.17", markers = "sys_platform in \"linux\""} +six = "*" + +[[package]] +name = "pyobjc-core" +version = "10.3.1" +description = "Python<->ObjC Interoperability Module" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyobjc_core-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ea46d2cda17921e417085ac6286d43ae448113158afcf39e0abe484c58fb3d78"}, + {file = "pyobjc_core-10.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:899d3c84d2933d292c808f385dc881a140cf08632907845043a333a9d7c899f9"}, + {file = "pyobjc_core-10.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:6ff5823d13d0a534cdc17fa4ad47cf5bee4846ce0fd27fc40012e12b46db571b"}, + {file = "pyobjc_core-10.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:2581e8e68885bcb0e11ec619e81ef28e08ee3fac4de20d8cc83bc5af5bcf4a90"}, + {file = "pyobjc_core-10.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ea98d4c2ec39ca29e62e0327db21418696161fb138ee6278daf2acbedf7ce504"}, + {file = "pyobjc_core-10.3.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:4c179c26ee2123d0aabffb9dbc60324b62b6f8614fb2c2328b09386ef59ef6d8"}, + {file = "pyobjc_core-10.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:cb901fce65c9be420c40d8a6ee6fff5ff27c6945f44fd7191989b982baa66dea"}, + {file = "pyobjc_core-10.3.1.tar.gz", hash = "sha256:b204a80ccc070f9ab3f8af423a3a25a6fd787e228508d00c4c30f8ac538ba720"}, +] + +[[package]] +name = "pyobjc-framework-applicationservices" +version = "10.3.1" +description = "Wrappers for the framework ApplicationServices on macOS" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyobjc_framework_ApplicationServices-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b694260d423c470cb90c3a7009cfde93e332ea6fb4b9b9526ad3acbd33460e3d"}, + {file = "pyobjc_framework_ApplicationServices-10.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d886ba1f65df47b77ff7546f3fc9bc7d08cfb6b3c04433b719f6b0689a2c0d1f"}, + {file = "pyobjc_framework_ApplicationServices-10.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:be157f2c3ffb254064ef38249670af8cada5e519a714d2aa5da3740934d89bc8"}, + {file = "pyobjc_framework_ApplicationServices-10.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:57737f41731661e4a3b78793ec9173f61242a32fa560c3e4e58484465d049c32"}, + {file = "pyobjc_framework_ApplicationServices-10.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c429eca69ee675e781e4e55f79e939196b47f02560ad865b1ba9ac753b90bd77"}, + {file = "pyobjc_framework_ApplicationServices-10.3.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:4f1814a17041a20adca454044080b52e39a4ebc567ad2c6a48866dd4beaa192a"}, + {file = "pyobjc_framework_ApplicationServices-10.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:1252f1137f83eb2c6b9968d8c591363e8859dd2484bc9441d8f365bcfb43a0e4"}, + {file = "pyobjc_framework_applicationservices-10.3.1.tar.gz", hash = "sha256:f27cb64aa4d129ce671fd42638c985eb2a56d544214a95fe3214a007eacc4790"}, +] + +[package.dependencies] +pyobjc-core = ">=10.3.1" +pyobjc-framework-Cocoa = ">=10.3.1" +pyobjc-framework-CoreText = ">=10.3.1" +pyobjc-framework-Quartz = ">=10.3.1" + +[[package]] +name = "pyobjc-framework-cocoa" +version = "10.3.1" +description = "Wrappers for the Cocoa frameworks on macOS" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyobjc_framework_Cocoa-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4cb4f8491ab4d9b59f5187e42383f819f7a46306a4fa25b84f126776305291d1"}, + {file = "pyobjc_framework_Cocoa-10.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5f31021f4f8fdf873b57a97ee1f3c1620dbe285e0b4eaed73dd0005eb72fd773"}, + {file = "pyobjc_framework_Cocoa-10.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:11b4e0bad4bbb44a4edda128612f03cdeab38644bbf174de0c13129715497296"}, + {file = "pyobjc_framework_Cocoa-10.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:de5e62e5ccf2871a94acf3bf79646b20ea893cc9db78afa8d1fe1b0d0f7cbdb0"}, + {file = "pyobjc_framework_Cocoa-10.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c5af24610ab639bd1f521ce4500484b40787f898f691b7a23da3339e6bc8b90"}, + {file = "pyobjc_framework_Cocoa-10.3.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:a7151186bb7805deea434fae9a4423335e6371d105f29e73cc2036c6779a9dbc"}, + {file = "pyobjc_framework_Cocoa-10.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:743d2a1ac08027fd09eab65814c79002a1d0421d7c0074ffd1217b6560889744"}, + {file = "pyobjc_framework_cocoa-10.3.1.tar.gz", hash = "sha256:1cf20714daaa986b488fb62d69713049f635c9d41a60c8da97d835710445281a"}, +] + +[package.dependencies] +pyobjc-core = ">=10.3.1" + +[[package]] +name = "pyobjc-framework-coretext" +version = "10.3.1" +description = "Wrappers for the framework CoreText on macOS" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyobjc_framework_CoreText-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:dd6123cfccc38e32be884d1a13fb62bd636ecb192b9e8ae2b8011c977dec229e"}, + {file = "pyobjc_framework_CoreText-10.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:834142a14235bd80edaef8d3a28d1e203ed3c988810a9b78005df7c561390288"}, + {file = "pyobjc_framework_CoreText-10.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ae6c09d29eeaf30a67aa70e08a465b1f1e47d12e22b3a34ae8bc8fdb7e2e7342"}, + {file = "pyobjc_framework_CoreText-10.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:51ca95df1db9401366f11a7467f64be57f9a0630d31c357237d4062df0216938"}, + {file = "pyobjc_framework_CoreText-10.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8b75bdc267945b3f33c937c108d79405baf9d7c4cd530f922e5df243082a5031"}, + {file = "pyobjc_framework_CoreText-10.3.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:029b24c338f58fc32a004256d8559507e4f366dfe4eb09d3144273d536012d90"}, + {file = "pyobjc_framework_CoreText-10.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:418a55047dbff999fcd2b78cca167c4105587020b6c51567cfa28993bbfdc8ed"}, + {file = "pyobjc_framework_coretext-10.3.1.tar.gz", hash = "sha256:b8fa2d5078ed774431ae64ba886156e319aec0b8c6cc23dabfd86778265b416f"}, +] + +[package.dependencies] +pyobjc-core = ">=10.3.1" +pyobjc-framework-Cocoa = ">=10.3.1" +pyobjc-framework-Quartz = ">=10.3.1" + +[[package]] +name = "pyobjc-framework-quartz" +version = "10.3.1" +description = "Wrappers for the Quartz frameworks on macOS" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyobjc_framework_Quartz-10.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5ef4fd315ed2bc42ef77fdeb2bae28a88ec986bd7b8079a87ba3b3475348f96e"}, + {file = "pyobjc_framework_Quartz-10.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:96578d4a3e70164efe44ad7dc320ecd4e211758ffcde5dcd694de1bbdfe090a4"}, + {file = "pyobjc_framework_Quartz-10.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ca35f92486869a41847a1703bb176aab8a53dbfd8e678d1f4d68d8e6e1581c71"}, + {file = "pyobjc_framework_Quartz-10.3.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:00a0933267e3a46ea4afcc35d117b2efb920f06de797fa66279c52e7057e3590"}, + {file = "pyobjc_framework_Quartz-10.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a161bedb4c5257a02ad56a910cd7eefb28bdb0ea78607df0d70ed4efe4ea54c1"}, + {file = "pyobjc_framework_Quartz-10.3.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:d7a8028e117a94923a511944bfa9daf9744e212f06cf89010c60934a479863a5"}, + {file = "pyobjc_framework_Quartz-10.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:de00c983b3267eb26fa42c6ed9f15e2bf006bde8afa7fe2b390646aa21a5d6fc"}, + {file = "pyobjc_framework_quartz-10.3.1.tar.gz", hash = "sha256:b6d7e346d735c9a7f147cd78e6da79eeae416a0b7d3874644c83a23786c6f886"}, +] + +[package.dependencies] +pyobjc-core = ">=10.3.1" +pyobjc-framework-Cocoa = ">=10.3.1" + [[package]] name = "pyopengl" version = "3.1.7" @@ -3122,6 +3252,20 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "python-xlib" +version = "0.33" +description = "Python X Library" +optional = false +python-versions = "*" +files = [ + {file = "python-xlib-0.33.tar.gz", hash = "sha256:55af7906a2c75ce6cb280a584776080602444f75815a7aff4d287bb2d7018b32"}, + {file = "python_xlib-0.33-py2.py3-none-any.whl", hash = "sha256:c3534038d42e0df2f1392a1b30a15a4ff5fdc2b86cfa94f072bf11b10a164398"}, +] + +[package.dependencies] +six = ">=1.10.0" + [[package]] name = "pytz" version = "2024.1" @@ -4363,4 +4507,4 @@ xarm = ["gym-xarm"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "223a6496a630da8181f119634f96bed3e0de3aaca714f1f1abd7edd562e3f1c6" +content-hash = "6d82706d5216ce065ba5912ea9f802846dfdf7e7b1665f8560cc782fb0dfa354" diff --git a/pyproject.toml b/pyproject.toml index 0385772d..8ea2431c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ scikit-image = {version = "^0.23.2", optional = true} pandas = {version = "^2.2.2", optional = true} pytest-mock = {version = "^3.14.0", optional = true} dynamixel-sdk = {version = "^3.7.31", optional = true} +pynput = "^1.7.7"