Add keyboard interaction, Add tqdm, Optimize stuff, Fix, Add resuming

Remi Cadene 2024-07-12 19:56:28 +02:00
parent 7a659dbd6b
commit d525e1b0f8
2 changed files with 156 additions and 60 deletions

lerobot/common/datasets/video_utils.py

@@ -207,7 +207,8 @@ def encode_video_frames(
     ffmpeg_args.append("-y")
 
     ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
-    subprocess.run(ffmpeg_cmd, check=True)
+    # redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from terminal
+    subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
 
 
 @dataclass
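The `stdin` redirect above deserves a note: a child process normally inherits the parent's terminal, and ffmpeg reads keystrokes from it (for example, `q` aborts encoding), which clashes with the keyboard listener introduced in this commit. A minimal sketch of the pattern, with an illustrative command line:

```python
import subprocess

# Illustrative command; any CLI tool that reads the terminal behaves similarly.
ffmpeg_cmd = ["ffmpeg", "-f", "image2", "-framerate", "30", "-i", "frame_%06d.png", "-y", "out.mp4"]

# Without stdin=subprocess.DEVNULL, ffmpeg inherits the terminal and reacts to
# stray key presses (it treats 'q' as "quit"), so arrow-key taps meant for the
# recording loop could interfere with the encoder.
subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
```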

lerobot/scripts/control_robot.py

@@ -53,9 +53,15 @@ python lerobot/scripts/control_robot.py record_dataset \
     --reset-time-s 10
 ```
 
-**NOTE**: You can early exit while recording an episode or resetting the environment,
-by tapping the right arrow key '->'. This might require a sudo permission
-to allow your terminal to monitor keyboard events.
+**NOTE**: You can use your keyboard to control the data recording flow (see the control-flow sketch below):
+- Tap the right arrow key '->' to early exit while recording an episode and go to resetting the environment.
+- Tap the right arrow key '->' to early exit while resetting the environment and go to recording the next episode.
+- Tap the left arrow key '<-' to early exit and re-record the current episode.
+- Tap the escape key 'esc' to stop the data recording.
+This might require sudo permission to allow your terminal to monitor keyboard events.
+
+**NOTE**: You can resume/continue data recording by running the same data recording command twice.
+To delete the local dataset and start recording from scratch instead of resuming, use `--force-override 1`.
 
 - Train on this dataset with the ACT policy:
 ```bash
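Taken together, the key bindings map onto three flags that steer the recording loop (the full diff follows below). A condensed, runnable sketch of that control flow, with print statements standing in for the real record/reset phases:

```python
# Condensed sketch of how the three keyboard flags steer the recording flow;
# the prints are stand-ins for the real recording and reset phases.
num_episodes = 3
episode_index = 0
rerecord_episode = False  # set by the left arrow key
stop_recording = False    # set by the escape key

while episode_index < num_episodes:
    print(f"recording episode {episode_index}")   # right arrow ends this phase early
    if not stop_recording:
        print("resetting environment")            # right arrow skips the reset wait
    if rerecord_episode:
        # Keep the same index so the episode is overwritten on the next pass.
        rerecord_episode = False
        continue
    episode_index += 1
    if stop_recording or episode_index == num_episodes:
        print("done recording")
        break
```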
@@ -75,6 +81,7 @@ python lerobot/scripts/control_robot.py run_policy \
 import argparse
 import concurrent.futures
+import json
 import logging
 import os
 import shutil
@@ -83,10 +90,12 @@ from contextlib import nullcontext
 from pathlib import Path
 
 import torch
+import tqdm
 from omegaconf import DictConfig
 from PIL import Image
 from pynput import keyboard
+# from safetensors.torch import load_file, save_file
 
 from lerobot.common.datasets.compute_stats import compute_stats
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
@@ -200,6 +209,8 @@ def record_dataset(
     num_episodes=50,
     video=True,
     run_compute_stats=True,
+    num_image_writters=4,
+    force_override=False,
 ):
     # TODO(rcadene): Add option to record logs
@@ -210,12 +221,24 @@ def record_dataset(
     robot.connect()
 
     local_dir = Path(root) / repo_id
-    if local_dir.exists():
+    if local_dir.exists() and force_override:
         shutil.rmtree(local_dir)
 
+    episodes_dir = local_dir / "episodes"
+    episodes_dir.mkdir(parents=True, exist_ok=True)
+
     videos_dir = local_dir / "videos"
     videos_dir.mkdir(parents=True, exist_ok=True)
 
+    # Logic to resume data recording
+    rec_info_path = episodes_dir / "data_recording_info.json"
+    if rec_info_path.exists():
+        with open(rec_info_path) as f:
+            rec_info = json.load(f)
+        episode_index = rec_info["last_episode_index"] + 1
+    else:
+        episode_index = 0
+
     # Execute a few seconds without recording data, to give time
     # to the robot devices to connect and start synchronizing.
     timestamp = 0
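The resume mechanism is just this one-key JSON file stored next to the saved episodes. A minimal, self-contained sketch of the round trip; the directory path is illustrative:

```python
import json
from pathlib import Path

episodes_dir = Path("data/koch_test/episodes")  # illustrative root
episodes_dir.mkdir(parents=True, exist_ok=True)
rec_info_path = episodes_dir / "data_recording_info.json"

# After finishing an episode, the recorder stores the last completed index:
with open(rec_info_path, "w") as f:
    json.dump({"last_episode_index": 12}, f)

# On the next run, recording resumes from the following index:
if rec_info_path.exists():
    with open(rec_info_path) as f:
        episode_index = json.load(f)["last_episode_index"] + 1
else:
    episode_index = 0
print(episode_index)  # 13
```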
@@ -242,12 +265,25 @@ def record_dataset(
     # by tapping the right arrow key '->'. This might require a sudo permission
     # to allow your terminal to monitor keyboard events.
     exit_early = False
+    rerecord_episode = False
+    stop_recording = False
 
     def on_press(key):
-        nonlocal exit_early
-        if key == keyboard.Key.right:
-            print("Right arrow key pressed. Exiting loop...")
-            exit_early = True
+        nonlocal exit_early, rerecord_episode, stop_recording
+        try:
+            if key == keyboard.Key.right:
+                print("Right arrow key pressed. Exiting loop...")
+                exit_early = True
+            elif key == keyboard.Key.left:
+                print("Left arrow key pressed. Exiting loop and re-recording the last episode...")
+                rerecord_episode = True
+                exit_early = True
+            elif key == keyboard.Key.esc:
+                print("Escape key pressed. Stopping data recording...")
+                stop_recording = True
+                exit_early = True
+        except Exception as e:
+            print(f"Error handling key press: {e}")
 
     listener = keyboard.Listener(on_press=on_press)
     listener.start()
@@ -255,10 +291,10 @@ def record_dataset(
     # Save images using threads to reach high fps (30 and more)
    # Using `with` to exit smoothly if an exception is raised.
     # Using only 4 worker threads to avoid blocking the main thread.
-    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+    futures = []
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writters) as executor:
         # Start recording all episodes
-        ep_dicts = []
-        for episode_index in range(num_episodes):
+        while episode_index < num_episodes:
             logging.info(f"Recording episode {episode_index}")
             os.system(f'say "Recording episode {episode_index}" &')
             ep_dict = {}
@@ -273,7 +309,11 @@ def record_dataset(
             not_image_keys = [key for key in observation if "image" not in key]
 
             for key in image_keys:
-                executor.submit(save_image, observation[key], key, frame_index, episode_index, videos_dir)
+                futures += [
+                    executor.submit(
+                        save_image, observation[key], key, frame_index, episode_index, videos_dir
+                    )
+                ]
 
             for key in not_image_keys:
                 if key not in ep_dict:
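Collecting every `Future` returned by `executor.submit` is what later lets the script block until all frames are actually on disk. A small runnable sketch of the same submit-then-drain pattern, with a sleep standing in for the image write:

```python
import concurrent.futures
import time

import tqdm

def save_image(i):
    time.sleep(0.01)  # stand-in for PNG encoding + disk write

futures = []
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    for i in range(100):
        futures.append(executor.submit(save_image, i))

    # Drain with a progress bar: as_completed yields each future as it finishes,
    # so this loop blocks until every submitted write has completed.
    for _ in tqdm.tqdm(
        concurrent.futures.as_completed(futures), total=len(futures), desc="Writing images"
    ):
        pass
```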
@@ -299,28 +339,24 @@ def record_dataset(
                     exit_early = False
                     break
 
-            # Skip resetting if 0 seconds allocated or it is the last episode
-            if reset_time_s == 0 or episode_index == num_episodes - 1:
-                continue
-
-            logging.info("Resetting environment")
-            os.system('say "Resetting environment" &')
+            if not stop_recording:
+                # Start resetting env while the executor is finishing
+                logging.info("Reset the environment")
+                os.system('say "Reset the environment" &')
 
             timestamp = 0
             start_time = time.perf_counter()
-            while timestamp < reset_time_s:
-                time.sleep(1)
-                timestamp = time.perf_counter() - start_time
-                if exit_early:
-                    exit_early = False
-                    break
 
+            # During env reset we save the data and encode the videos
             num_frames = frame_index
 
             for key in image_keys:
                 tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
                 fname = f"{key}_episode_{episode_index:06d}.mp4"
-                # store the reference to the video frame, even though the videos are not yet encoded
+                video_path = local_dir / "videos" / fname
+                if video_path.exists():
+                    video_path.unlink()
+                # Store the reference to the video frame, even though the videos are not yet encoded
                 ep_dict[key] = []
                 for i in range(num_frames):
                     ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
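Note the pairing of two idempotency choices: a re-recorded episode deletes its stale video up front (above), while the final encoding pass (further down) skips videos that already exist, so a resumed run only encodes what is missing. An illustrative sketch:

```python
from pathlib import Path

video_path = Path("videos/cam_episode_000000.mp4")  # illustrative name

# While re-recording an episode: remove the stale video so it is encoded again.
if video_path.exists():
    video_path.unlink()

# Final encoding pass: skip videos that already exist, so a resumed run
# only encodes the episodes that are still missing.
if video_path.exists():
    print("already encoded, skipping")
else:
    print("would encode", video_path)
```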
@@ -339,30 +375,71 @@ def record_dataset(
             done[-1] = True
             ep_dict["next.done"] = done
 
-            ep_dicts.append(ep_dict)
+            ep_path = episodes_dir / f"episode_{episode_index}.safetensors"
+            print("Saving episode dictionary...")
+            torch.save(ep_dict, ep_path)
 
-            # last episode
-            if episode_index == num_episodes - 1:
+            rec_info = {
+                "last_episode_index": episode_index,
+            }
+            with open(rec_info_path, "w") as f:
+                json.dump(rec_info, f)
+
+            # Wait if necessary
+            with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
+                while timestamp < reset_time_s and not stop_recording:
+                    time.sleep(1)
+                    timestamp = time.perf_counter() - start_time
+                    pbar.update(1)
+                    if exit_early:
+                        exit_early = False
+                        break
+
+            # Skip updating the episode index, which forces re-recording the episode
+            if rerecord_episode:
+                rerecord_episode = False
+                continue
+
+            episode_index += 1
+
+            # Only for the last episode
+            if stop_recording or episode_index == num_episodes:
                 logging.info("Done recording")
-                os.system('say "Done recording" &')
+                os.system('say "Done recording"')
+                logging.info("Waiting for threads writing the images on disk to terminate...")
+                listener.stop()
+
+                for _ in tqdm.tqdm(
+                    concurrent.futures.as_completed(futures), total=len(futures), desc="Writing images"
+                ):
+                    pass
+                break
 
-    data_dict = concatenate_episodes(ep_dicts)
-    total_frames = data_dict["frame_index"].shape[0]
-    data_dict["index"] = torch.arange(0, total_frames, 1)
+    num_episodes = episode_index
 
-    logging.info("Encoding images to videos")
-    os.system('say "Encoding images to videos" &')
+    logging.info("Encoding videos")
+    os.system('say "Encoding videos" &')
+    # Use ffmpeg to convert frames stored as png into mp4 videos
-    for episode_index in range(num_episodes):
+    for episode_index in tqdm.tqdm(range(num_episodes)):
         for key in image_keys:
             tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
             fname = f"{key}_episode_{episode_index:06d}.mp4"
             video_path = local_dir / "videos" / fname
+            if video_path.exists():
+                continue
             # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speed up encoding,
             # since video encoding with ffmpeg is already using multithreading.
-            encode_video_frames(tmp_imgs_dir, video_path, fps)
-            # Clean temporary images directory
-            shutil.rmtree(tmp_imgs_dir)
+            encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
+
+    logging.info("Concatenating episodes")
+    ep_dicts = []
+    for episode_index in tqdm.tqdm(range(num_episodes)):
+        ep_path = episodes_dir / f"episode_{episode_index}.safetensors"
+        ep_dict = torch.load(ep_path)
+        ep_dicts.append(ep_dict)
+    data_dict = concatenate_episodes(ep_dicts)
+
+    total_frames = data_dict["frame_index"].shape[0]
+    data_dict["index"] = torch.arange(0, total_frames, 1)
 
     hf_dataset = to_hf_dataset(data_dict, video)
     episode_data_index = calculate_episode_data_index(hf_dataset)
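One wrinkle worth flagging: the per-episode files carry a `.safetensors` suffix but are written with `torch.save` (the `safetensors` import is commented out above), so they must be read back with `torch.load`, not the safetensors loader. A minimal round trip with made-up keys and shapes:

```python
import torch

# Made-up episode dict; real keys/shapes come from the robot observations.
ep_dict = {
    "observation.state": torch.zeros(100, 6),
    "action": torch.zeros(100, 6),
}

# Despite the suffix, this is torch's pickle-based format, not safetensors.
torch.save(ep_dict, "episode_0.safetensors")
ep_back = torch.load("episode_0.safetensors")
assert ep_back.keys() == ep_dict.keys()
```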
@@ -378,8 +455,13 @@ def record_dataset(
         info=info,
         videos_dir=videos_dir,
     )
-    stats = compute_stats(lerobot_dataset) if run_compute_stats else {}
-    lerobot_dataset.stats = stats
+    if run_compute_stats:
+        logging.info("Computing dataset statistics")
+        os.system('say "Computing dataset statistics" &')
+        stats = compute_stats(lerobot_dataset)
+        lerobot_dataset.stats = stats
+    else:
+        logging.info("Skipping computation of the dataset statistics")
 
     hf_dataset = hf_dataset.with_format(None)  # to remove transforms that can't be saved
     hf_dataset.save_to_disk(str(local_dir / "train"))
@@ -389,8 +471,8 @@ def record_dataset(
     # TODO(rcadene): push to hub
 
-    logging.info("Done, exiting")
-    os.system('say "Done, exiting" &')
+    logging.info("Exiting")
+    os.system('say "Exiting" &')
 
     return lerobot_dataset
@@ -532,6 +614,19 @@ if __name__ == "__main__":
         help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
     )
+    parser_record.add_argument(
+        "--num-image-writters",
+        type=int,
+        default=4,
+        help="Number of threads writing the frames as png images on disk. Don't set too high or you might get unstable fps due to the main thread being blocked.",
+    )
+
+    parser_record.add_argument(
+        "--force-override",
+        type=int,
+        default=0,
+        help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
+    )
 
     parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser])
     parser_replay.add_argument(
         "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"