diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index e614e4d2..bcc29481 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -207,7 +207,8 @@ def encode_video_frames( ffmpeg_args.append("-y") ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)] - subprocess.run(ffmpeg_cmd, check=True) + # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal + subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL) @dataclass diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index c33a056f..c5227f82 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -53,9 +53,15 @@ python lerobot/scripts/control_robot.py record_dataset \ --reset-time-s 10 ``` -**NOTE**: You can early exit while recording an episode or resetting the environment, -by tapping the right arrow key '->'. This might require a sudo permission -to allow your terminal to monitor keyboard events. +**NOTE**: You can use your keyboard to control data recording flow. +- Tap right arrow key '->' to early exit while recording an episode and go to resetting the environment. +- Tap right arrow key '->' to early exit while resetting the environment and go to recording the next episode. +- Tap left arrow key '<-' to early exit and re-record the current episode. +- Tap escape key 'esc' to stop the data recording. +This might require a sudo permission to allow your terminal to monitor keyboard events. + +**NOTE**: You can resume/continue data recording by running the same data recording command again. +To delete the existing dataset and start from scratch instead of resuming, use `--force-override 1`. 
- Train on this dataset with the ACT policy: ```bash @@ -75,6 +81,7 @@ python lerobot/scripts/control_robot.py run_policy \ import argparse import concurrent.futures +import json import logging import os import shutil @@ -83,10 +90,12 @@ from contextlib import nullcontext from pathlib import Path import torch +import tqdm from omegaconf import DictConfig from PIL import Image from pynput import keyboard +# from safetensors.torch import load_file, save_file from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset @@ -200,6 +209,8 @@ def record_dataset( num_episodes=50, video=True, run_compute_stats=True, + num_image_writters=4, + force_override=False, ): # TODO(rcadene): Add option to record logs @@ -210,12 +221,24 @@ def record_dataset( robot.connect() local_dir = Path(root) / repo_id - if local_dir.exists(): + if local_dir.exists() and force_override: shutil.rmtree(local_dir) + episodes_dir = local_dir / "episodes" + episodes_dir.mkdir(parents=True, exist_ok=True) + videos_dir = local_dir / "videos" videos_dir.mkdir(parents=True, exist_ok=True) + # Logic to resume data recording + rec_info_path = episodes_dir / "data_recording_info.json" + if rec_info_path.exists(): + with open(rec_info_path) as f: + rec_info = json.load(f) + episode_index = rec_info["last_episode_index"] + 1 + else: + episode_index = 0 + # Execute a few seconds without recording data, to give times # to the robot devices to connect and start synchronizing. timestamp = 0 @@ -242,12 +265,25 @@ def record_dataset( # by tapping the right arrow key '->'. This might require a sudo permission # to allow your terminal to monitor keyboard events. exit_early = False + rerecord_episode = False + stop_recording = False def on_press(key): - nonlocal exit_early - if key == keyboard.Key.right: - print("Right arrow key pressed. 
Exiting loop...") - exit_early = True + nonlocal exit_early, rerecord_episode, stop_recording + try: + if key == keyboard.Key.right: + print("Right arrow key pressed. Exiting loop...") + exit_early = True + elif key == keyboard.Key.left: + print("Left arrow key pressed. Exiting loop and rerecord the last episode...") + rerecord_episode = True + exit_early = True + elif key == keyboard.Key.esc: + print("Escape key pressed. Stopping data recording...") + stop_recording = True + exit_early = True + except Exception as e: + print(f"Error handling key press: {e}") listener = keyboard.Listener(on_press=on_press) listener.start() @@ -255,10 +291,10 @@ def record_dataset( # Save images using threads to reach high fps (30 and more) # Using `with` to exist smoothly if an execption is raised. # Using only 4 worker threads to avoid blocking the main thread. - with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: + futures = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writters) as executor: # Start recording all episodes - ep_dicts = [] - for episode_index in range(num_episodes): + while episode_index < num_episodes: logging.info(f"Recording episode {episode_index}") os.system(f'say "Recording episode {episode_index}" &') ep_dict = {} @@ -273,7 +309,11 @@ def record_dataset( not_image_keys = [key for key in observation if "image" not in key] for key in image_keys: - executor.submit(save_image, observation[key], key, frame_index, episode_index, videos_dir) + futures += [ + executor.submit( + save_image, observation[key], key, frame_index, episode_index, videos_dir + ) + ] for key in not_image_keys: if key not in ep_dict: @@ -299,70 +339,107 @@ def record_dataset( exit_early = False break - # Skip resetting if 0 second allocated or it is the last episode - if reset_time_s == 0 or episode_index == num_episodes - 1: - continue + if not stop_recording: + # Start resetting env while the executor are finishing + logging.info("Reset the 
environment") + os.system('say "Reset the environment" &') - logging.info("Resetting environment") - os.system('say "Resetting environment" &') timestamp = 0 start_time = time.perf_counter() - while timestamp < reset_time_s: - time.sleep(1) - timestamp = time.perf_counter() - start_time - if exit_early: - exit_early = False - break + # During env reset we save the data and encode the videos + num_frames = frame_index - num_frames = frame_index + for key in image_keys: + tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}" + fname = f"{key}_episode_{episode_index:06d}.mp4" + video_path = local_dir / "videos" / fname + if video_path.exists(): + video_path.unlink() + # Store the reference to the video frame, even tho the videos are not yet encoded + ep_dict[key] = [] + for i in range(num_frames): + ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps}) - for key in image_keys: - tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}" - fname = f"{key}_episode_{episode_index:06d}.mp4" - # store the reference to the video frame, even tho the videos are not yet encoded - ep_dict[key] = [] - for i in range(num_frames): - ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps}) + for key in not_image_keys: + ep_dict[key] = torch.stack(ep_dict[key]) - for key in not_image_keys: - ep_dict[key] = torch.stack(ep_dict[key]) + for key in action: + ep_dict[key] = torch.stack(ep_dict[key]) - for key in action: - ep_dict[key] = torch.stack(ep_dict[key]) + ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames) + ep_dict["frame_index"] = torch.arange(0, num_frames, 1) + ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps - ep_dict["episode_index"] = torch.tensor([episode_index] * num_frames) - ep_dict["frame_index"] = torch.arange(0, num_frames, 1) - ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps + done = torch.zeros(num_frames, dtype=torch.bool) + done[-1] = True + ep_dict["next.done"] = done - 
done = torch.zeros(num_frames, dtype=torch.bool) - done[-1] = True - ep_dict["next.done"] = done + ep_path = episodes_dir / f"episode_{episode_index}.safetensors" + print("Saving episode dictionary...") + torch.save(ep_dict, ep_path) - ep_dicts.append(ep_dict) + rec_info = { + "last_episode_index": episode_index, + } + with open(rec_info_path, "w") as f: + json.dump(rec_info, f) - # last episode - if episode_index == num_episodes - 1: - logging.info("Done recording") - os.system('say "Done recording" &') + # Wait if necessary + with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar: + while timestamp < reset_time_s and not stop_recording: + time.sleep(1) + timestamp = time.perf_counter() - start_time + pbar.update(1) + if exit_early: + exit_early = False + break - data_dict = concatenate_episodes(ep_dicts) - total_frames = data_dict["frame_index"].shape[0] - data_dict["index"] = torch.arange(0, total_frames, 1) + # Skip updating episode index which forces re-recording episode + if rerecord_episode: + rerecord_episode = False + continue - logging.info("Encoding images to videos") - os.system('say "Encoding images to videos" &') + episode_index += 1 - for episode_index in range(num_episodes): + # Only for last episode + if stop_recording or episode_index == num_episodes: + logging.info("Done recording") + os.system('say "Done recording"') + logging.info("Waiting for threads writting the images on disk to terminate...") + listener.stop() + for _ in tqdm.tqdm( + concurrent.futures.as_completed(futures), total=len(futures), desc="Writting images" + ): + pass + break + + num_episodes = episode_index + + logging.info("Encoding videos") + os.system('say "Encoding videos" &') + # Use ffmpeg to convert frames stored as png into mp4 videos + for episode_index in tqdm.tqdm(range(num_episodes)): for key in image_keys: tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}" fname = f"{key}_episode_{episode_index:06d}.mp4" video_path = local_dir / "videos" / fname + 
if video_path.exists(): + continue # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, # since video encoding with ffmpeg is already using multithreading. - encode_video_frames(tmp_imgs_dir, video_path, fps) - # Clean temporary images directory - shutil.rmtree(tmp_imgs_dir) + encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True) + + logging.info("Concatenating episodes") + ep_dicts = [] + for episode_index in tqdm.tqdm(range(num_episodes)): + ep_path = episodes_dir / f"episode_{episode_index}.safetensors" + ep_dict = torch.load(ep_path) + ep_dicts.append(ep_dict) + data_dict = concatenate_episodes(ep_dicts) + + total_frames = data_dict["frame_index"].shape[0] + data_dict["index"] = torch.arange(0, total_frames, 1) hf_dataset = to_hf_dataset(data_dict, video) episode_data_index = calculate_episode_data_index(hf_dataset) @@ -378,8 +455,13 @@ def record_dataset( info=info, videos_dir=videos_dir, ) - stats = compute_stats(lerobot_dataset) if run_compute_stats else {} - lerobot_dataset.stats = stats + if run_compute_stats: + logging.info("Computing dataset statistics") + os.system('say "Computing dataset statistics" &') + stats = compute_stats(lerobot_dataset) + lerobot_dataset.stats = stats + else: + logging.info("Skipping computation of the dataset statistrics") hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved hf_dataset.save_to_disk(str(local_dir / "train")) @@ -389,8 +471,8 @@ def record_dataset( # TODO(rcadene): push to hub - logging.info("Done, exiting") - os.system('say "Done, exiting" &') + logging.info("Exiting") + os.system('say "Exiting" &') return lerobot_dataset @@ -532,6 +614,19 @@ if __name__ == "__main__": help="By default, run the computation of the data statistics at the end of data collection. 
Compute intensive and not required to just replay an episode.", ) + parser_record.add_argument( + "--num-image-writters", + type=int, + default=4, + help="Number of threads writting the frames as png images on disk. Don't set too much as you might get unstable fps due to main thread being blocked.", + ) + parser_record.add_argument( + "--force-override", + type=int, + default=0, + help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.", + ) + parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser]) parser_replay.add_argument( "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"