Add keyboard interaction, Add tqdm, Optimize stuff, Fix, Add resuming

parent 7a659dbd6b
commit d525e1b0f8
@@ -207,7 +207,8 @@ def encode_video_frames(
     ffmpeg_args.append("-y")
 
     ffmpeg_cmd = ["ffmpeg"] + ffmpeg_args + [str(video_path)]
-    subprocess.run(ffmpeg_cmd, check=True)
+    # Redirect stdin to subprocess.DEVNULL to prevent reading random keyboard inputs from the terminal
+    subprocess.run(ffmpeg_cmd, check=True, stdin=subprocess.DEVNULL)
 
 
 @dataclass
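Why this one-line change matters: a child process inherits the parent's stdin by default, so a long-running `ffmpeg` encode can swallow (or act on, e.g. `q` quits ffmpeg) the very keystrokes the keyboard listener added in this commit needs. A minimal sketch of the pattern outside the LeRobot code (the command here is a placeholder):

```python
import subprocess

# Any long-running child process inherits the parent's stdin by default.
# With stdin=subprocess.DEVNULL, keystrokes typed in the terminal stay
# available to our own keyboard listener instead of being consumed or
# interpreted by the child process.
subprocess.run(
    ["ffmpeg", "-version"],  # placeholder; the real call encodes frames to mp4
    check=True,
    stdin=subprocess.DEVNULL,
)
```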
@@ -53,9 +53,15 @@ python lerobot/scripts/control_robot.py record_dataset \
 --reset-time-s 10
 ```
 
-**NOTE**: You can early exit while recording an episode or resetting the environment,
-by tapping the right arrow key '->'. This might require a sudo permission
-to allow your terminal to monitor keyboard events.
+**NOTE**: You can use your keyboard to control the data recording flow:
+- Tap the right arrow key '->' to exit early while recording an episode and go to resetting the environment.
+- Tap the right arrow key '->' to exit early while resetting the environment and go to recording the next episode.
+- Tap the left arrow key '<-' to exit early and re-record the current episode.
+- Tap the escape key 'esc' to stop the data recording.
+This might require sudo permission to allow your terminal to monitor keyboard events.
+
+**NOTE**: You can resume/continue data recording by running the same data recording command again.
+To start from scratch instead of resuming, delete the dataset by passing `--force-override 1`.
 
 - Train on this dataset with the ACT policy:
 ```bash
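The keyboard flow described in the NOTE above is implemented with `pynput`, imported further down in this diff. A minimal standalone sketch of the same listener pattern; the flag names mirror the ones used later in the diff, but this is a simplified illustration, not the script's exact logic:

```python
from pynput import keyboard

# Flags flipped by the listener thread and polled by the recording loop.
exit_early = False
stop_recording = False

def on_press(key):
    global exit_early, stop_recording
    if key == keyboard.Key.right:
        exit_early = True       # skip to the next phase
    elif key == keyboard.Key.esc:
        stop_recording = True   # end the whole session
        exit_early = True

# The listener runs in a background thread; the main loop keeps polling the flags.
listener = keyboard.Listener(on_press=on_press)
listener.start()
# ... main loop checks `exit_early` / `stop_recording` between sleeps ...
listener.stop()
```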
@@ -75,6 +81,7 @@ python lerobot/scripts/control_robot.py run_policy \
 
 import argparse
 import concurrent.futures
+import json
 import logging
 import os
 import shutil
@@ -83,10 +90,12 @@ from contextlib import nullcontext
 from pathlib import Path
 
 import torch
+import tqdm
 from omegaconf import DictConfig
 from PIL import Image
 from pynput import keyboard
 
+# from safetensors.torch import load_file, save_file
 from lerobot.common.datasets.compute_stats import compute_stats
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
@@ -200,6 +209,8 @@ def record_dataset(
     num_episodes=50,
     video=True,
     run_compute_stats=True,
+    num_image_writters=4,
+    force_override=False,
 ):
     # TODO(rcadene): Add option to record logs
 
@@ -210,12 +221,24 @@ def record_dataset(
     robot.connect()
 
     local_dir = Path(root) / repo_id
-    if local_dir.exists():
+    if local_dir.exists() and force_override:
         shutil.rmtree(local_dir)
 
+    episodes_dir = local_dir / "episodes"
+    episodes_dir.mkdir(parents=True, exist_ok=True)
+
     videos_dir = local_dir / "videos"
     videos_dir.mkdir(parents=True, exist_ok=True)
 
+    # Logic to resume data recording
+    rec_info_path = episodes_dir / "data_recording_info.json"
+    if rec_info_path.exists():
+        with open(rec_info_path) as f:
+            rec_info = json.load(f)
+        episode_index = rec_info["last_episode_index"] + 1
+    else:
+        episode_index = 0
+
     # Execute a few seconds without recording data, to give time
     # for the robot devices to connect and start synchronizing.
     timestamp = 0
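The resume logic above reduces to a small JSON checkpoint holding the last recorded episode index; re-running the same command then continues at `last_episode_index + 1`. A self-contained sketch of that round trip, with an illustrative path:

```python
import json
from pathlib import Path

episodes_dir = Path("/tmp/demo_episodes")  # illustrative path
episodes_dir.mkdir(parents=True, exist_ok=True)
rec_info_path = episodes_dir / "data_recording_info.json"

# Read: decide where to resume from.
if rec_info_path.exists():
    with open(rec_info_path) as f:
        episode_index = json.load(f)["last_episode_index"] + 1
else:
    episode_index = 0

# Write: after each saved episode, checkpoint the index.
with open(rec_info_path, "w") as f:
    json.dump({"last_episode_index": episode_index}, f)
```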
@@ -242,12 +265,25 @@ def record_dataset(
     # by tapping the right arrow key '->'. This might require a sudo permission
     # to allow your terminal to monitor keyboard events.
     exit_early = False
+    rerecord_episode = False
+    stop_recording = False
 
     def on_press(key):
-        nonlocal exit_early
-        if key == keyboard.Key.right:
-            print("Right arrow key pressed. Exiting loop...")
-            exit_early = True
+        nonlocal exit_early, rerecord_episode, stop_recording
+        try:
+            if key == keyboard.Key.right:
+                print("Right arrow key pressed. Exiting loop...")
+                exit_early = True
+            elif key == keyboard.Key.left:
+                print("Left arrow key pressed. Exiting loop and re-recording the last episode...")
+                rerecord_episode = True
+                exit_early = True
+            elif key == keyboard.Key.esc:
+                print("Escape key pressed. Stopping data recording...")
+                stop_recording = True
+                exit_early = True
+        except Exception as e:
+            print(f"Error handling key press: {e}")
 
     listener = keyboard.Listener(on_press=on_press)
     listener.start()
@@ -255,10 +291,10 @@ def record_dataset(
     # Save images using threads to reach high fps (30 and more)
     # Using `with` to exit smoothly if an exception is raised.
     # Using a limited number of worker threads to avoid blocking the main thread.
-    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+    futures = []
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_image_writters) as executor:
         # Start recording all episodes
-        ep_dicts = []
-        for episode_index in range(num_episodes):
+        while episode_index < num_episodes:
             logging.info(f"Recording episode {episode_index}")
             os.system(f'say "Recording episode {episode_index}" &')
             ep_dict = {}
@@ -273,7 +309,11 @@ def record_dataset(
                 not_image_keys = [key for key in observation if "image" not in key]
 
                 for key in image_keys:
-                    executor.submit(save_image, observation[key], key, frame_index, episode_index, videos_dir)
+                    futures += [
+                        executor.submit(
+                            save_image, observation[key], key, frame_index, episode_index, videos_dir
+                        )
+                    ]
 
                 for key in not_image_keys:
                     if key not in ep_dict:
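Keeping the `Future` handles returned by `executor.submit` (instead of the previous fire-and-forget call) is what lets the script later block until every frame is actually on disk. A minimal sketch of the collect-then-drain pattern, with a stand-in `save_image`:

```python
import concurrent.futures
import time

def save_image(i):
    # Stand-in for writing one PNG frame to disk.
    time.sleep(0.01)
    return i

futures = []
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    for i in range(100):
        # Keep the handle so completion can be awaited later.
        futures.append(executor.submit(save_image, i))

    # Drain: yields each future as its write finishes, in completion order.
    for future in concurrent.futures.as_completed(futures):
        future.result()  # re-raises any exception from the worker thread
```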
@@ -299,28 +339,24 @@ def record_dataset(
                     exit_early = False
                     break
 
-            # Skip resetting if 0 seconds are allocated or it is the last episode
-            if reset_time_s == 0 or episode_index == num_episodes - 1:
-                continue
-
-            logging.info("Resetting environment")
-            os.system('say "Resetting environment" &')
-            timestamp = 0
-            start_time = time.perf_counter()
-            while timestamp < reset_time_s:
-                time.sleep(1)
-                timestamp = time.perf_counter() - start_time
-                if exit_early:
-                    exit_early = False
-                    break
+            if not stop_recording:
+                # Start resetting the env while the executor is finishing
+                logging.info("Reset the environment")
+                os.system('say "Reset the environment" &')
+                timestamp = 0
+                start_time = time.perf_counter()
 
+            # During env reset we save the data and encode the videos
             num_frames = frame_index
 
             for key in image_keys:
                 tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
                 fname = f"{key}_episode_{episode_index:06d}.mp4"
-                # store the reference to the video frame, even tho the videos are not yet encoded
+                video_path = local_dir / "videos" / fname
+                if video_path.exists():
+                    video_path.unlink()
+                # Store the reference to the video frames, even though the videos are not yet encoded
                 ep_dict[key] = []
                 for i in range(num_frames):
                     ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
@@ -339,30 +375,71 @@ def record_dataset(
             done[-1] = True
             ep_dict["next.done"] = done
 
-            ep_dicts.append(ep_dict)
+            ep_path = episodes_dir / f"episode_{episode_index}.safetensors"
+            print("Saving episode dictionary...")
+            torch.save(ep_dict, ep_path)
 
-            # last episode
-            if episode_index == num_episodes - 1:
-                logging.info("Done recording")
-                os.system('say "Done recording" &')
+            rec_info = {
+                "last_episode_index": episode_index,
+            }
+            with open(rec_info_path, "w") as f:
+                json.dump(rec_info, f)
+
+            # Wait if necessary
+            with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
+                while timestamp < reset_time_s and not stop_recording:
+                    time.sleep(1)
+                    timestamp = time.perf_counter() - start_time
+                    pbar.update(1)
+                    if exit_early:
+                        exit_early = False
+                        break
+
+            # Skip updating the episode index, which forces re-recording the episode
+            if rerecord_episode:
+                rerecord_episode = False
+                continue
+
+            episode_index += 1
+
+            # Only for the last episode
+            if stop_recording or episode_index == num_episodes:
+                logging.info("Done recording")
+                os.system('say "Done recording"')
+                logging.info("Waiting for threads writing the images on disk to terminate...")
+                listener.stop()
+                for _ in tqdm.tqdm(
+                    concurrent.futures.as_completed(futures), total=len(futures), desc="Writing images"
+                ):
+                    pass
+                break
 
-    data_dict = concatenate_episodes(ep_dicts)
-    total_frames = data_dict["frame_index"].shape[0]
-    data_dict["index"] = torch.arange(0, total_frames, 1)
+    num_episodes = episode_index
 
-    logging.info("Encoding images to videos")
-    os.system('say "Encoding images to videos" &')
-    for episode_index in range(num_episodes):
+    logging.info("Encoding videos")
+    os.system('say "Encoding videos" &')
+    # Use ffmpeg to convert frames stored as png into mp4 videos
+    for episode_index in tqdm.tqdm(range(num_episodes)):
         for key in image_keys:
             tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
             fname = f"{key}_episode_{episode_index:06d}.mp4"
             video_path = local_dir / "videos" / fname
+            if video_path.exists():
+                continue
             # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speed up encoding,
             # since video encoding with ffmpeg is already using multithreading.
-            encode_video_frames(tmp_imgs_dir, video_path, fps)
-            # Clean temporary images directory
-            shutil.rmtree(tmp_imgs_dir)
+            encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
+
+    logging.info("Concatenating episodes")
+    ep_dicts = []
+    for episode_index in tqdm.tqdm(range(num_episodes)):
+        ep_path = episodes_dir / f"episode_{episode_index}.safetensors"
+        ep_dict = torch.load(ep_path)
+        ep_dicts.append(ep_dict)
+    data_dict = concatenate_episodes(ep_dicts)
+
+    total_frames = data_dict["frame_index"].shape[0]
+    data_dict["index"] = torch.arange(0, total_frames, 1)
 
     hf_dataset = to_hf_dataset(data_dict, video)
     episode_data_index = calculate_episode_data_index(hf_dataset)
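The "Wait if necessary" block above is the commit's tqdm pattern in miniature: a manually ticked progress bar around a one-second sleep loop, so the operator sees the reset countdown while the early-exit key keeps working. Stripped to its essentials (the wait time and flags here are illustrative):

```python
import time
import tqdm

reset_time_s = 5    # illustrative; the script takes this from --reset-time-s
exit_early = False  # would be flipped by the keyboard listener thread

start_time = time.perf_counter()
timestamp = 0
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
    while timestamp < reset_time_s:
        time.sleep(1)
        timestamp = time.perf_counter() - start_time
        pbar.update(1)  # one tick per second slept
        if exit_early:
            exit_early = False
            break
```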
@@ -378,8 +455,13 @@ def record_dataset(
         info=info,
         videos_dir=videos_dir,
     )
-    stats = compute_stats(lerobot_dataset) if run_compute_stats else {}
-    lerobot_dataset.stats = stats
+    if run_compute_stats:
+        logging.info("Computing dataset statistics")
+        os.system('say "Computing dataset statistics" &')
+        stats = compute_stats(lerobot_dataset)
+        lerobot_dataset.stats = stats
+    else:
+        logging.info("Skipping computation of the dataset statistics")
 
     hf_dataset = hf_dataset.with_format(None)  # to remove transforms that can't be saved
     hf_dataset.save_to_disk(str(local_dir / "train"))
@@ -389,8 +471,8 @@ def record_dataset(
 
     # TODO(rcadene): push to hub
 
-    logging.info("Done, exiting")
-    os.system('say "Done, exiting" &')
+    logging.info("Exiting")
+    os.system('say "Exiting" &')
 
     return lerobot_dataset
 
@@ -532,6 +614,19 @@ if __name__ == "__main__":
         help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
     )
+    parser_record.add_argument(
+        "--num-image-writters",
+        type=int,
+        default=4,
+        help="Number of threads writing the frames as png images on disk. Don't set this too high or you might get unstable fps because the main thread is blocked.",
+    )
+    parser_record.add_argument(
+        "--force-override",
+        type=int,
+        default=0,
+        help="By default, data recording is resumed. When set to 1, delete the local directory and start data recording from scratch.",
+    )
 
     parser_replay = subparsers.add_parser("replay_episode", parents=[base_parser])
     parser_replay.add_argument(
         "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
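Both new flags are plain integers used as booleans (hence `--force-override 1` in the docs above), matching the script's existing argument style rather than `store_true`. A small sketch of how such flags parse; the parser here is a standalone stand-in, not the script's actual `parser_record`:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num-image-writters", type=int, default=4)
parser.add_argument("--force-override", type=int, default=0)

args = parser.parse_args(["--force-override", "1"])
# Integers double as booleans: 0 is falsy, 1 is truthy.
assert bool(args.force_override) is True
assert args.num_image_writters == 4
```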