diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index fe335556..9c691fe4 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -173,6 +173,14 @@ def log_control_info(robot, dt_s, episode_index=None, frame_index=None): logging.info(" ".join(log_items)) +def get_is_headless(): + if platform.system() == "Linux": + display = os.environ.get("DISPLAY") + if display is None or display == "": + return True + return False + + ######################################################################################## # Control modes ######################################################################################## @@ -241,6 +249,8 @@ def record_dataset( else: episode_index = 0 + is_headless = get_is_headless() + # Execute a few seconds without recording data, to give times # to the robot devices to connect and start synchronizing. timestamp = 0 @@ -255,6 +265,9 @@ def record_dataset( now = time.perf_counter() observation, action = robot.teleop_step(record_data=True) + if not is_headless: + image_keys = [key for key in observation if "image" in key] + dt_s = time.perf_counter() - now busy_wait(1 / fps - dt_s) @@ -270,15 +283,8 @@ def record_dataset( rerecord_episode = False stop_recording = False - def is_headless(): - if platform.system() == "Linux": - display = os.environ.get("DISPLAY") - if display is None or display == "": - return True - return False - # Only import pynput if not in a headless environment - if is_headless(): + if is_headless: logging.info("Headless environment detected. Keyboard input will not be available.") else: from pynput import keyboard @@ -400,9 +406,11 @@ def record_dataset( with open(rec_info_path, "w") as f: json.dump(rec_info, f) + is_last_episode = stop_recording or (episode_index == (num_episodes - 1)) + # Wait if necessary with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar: - while timestamp < reset_time_s and not stop_recording: + while timestamp < reset_time_s and not is_last_episode: time.sleep(1) timestamp = time.perf_counter() - start_time pbar.update(1) @@ -417,11 +425,10 @@ def record_dataset( episode_index += 1 - # Only for last episode - if stop_recording or episode_index == num_episodes: + if is_last_episode: logging.info("Done recording") os.system('say "Done recording"') - if not is_headless(): + if not is_headless: listener.stop() logging.info("Waiting for threads writing the images on disk to terminate...") @@ -446,6 +453,7 @@ def record_dataset( # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, # since video encoding with ffmpeg is already using multithreading. encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True) + shutil.rmtree(tmp_imgs_dir) logging.info("Concatenating episodes") ep_dicts = []