Make data recording smooth and small fix

Remi Cadene 2024-07-11 17:14:23 +02:00
parent e6aa8cc641
commit 73692e7459
1 changed file with 31 additions and 25 deletions


@@ -232,13 +232,13 @@ def record_dataset(
         timestamp = time.perf_counter() - start_time
-    # Start recording all episodes
-    ep_dicts = []
-    for episode_index in range(num_episodes):
     # Save images using threads to reach high fps (30 and more)
     # Using `with` to exit smoothly if an exception is raised.
     # Using only 4 worker threads to avoid blocking the main thread.
     with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+        # Start recording all episodes
+        ep_dicts = []
+        for episode_index in range(num_episodes):
             ep_dict = {}
             frame_index = 0
             timestamp = 0
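Moving the `ThreadPoolExecutor` outside the episode loop is what makes recording smooth: the worker pool no longer has to drain and restart between episodes. The per-frame image writes handled by this executor are not visible in this hunk; the following is a minimal sketch of the pattern, with a hypothetical `save_image` helper whose name and signature are illustrative, not part of this diff:

import concurrent.futures
from pathlib import Path

from PIL import Image


def save_image(img_array, key: str, frame_index: int, episode_index: int, videos_dir: Path):
    # Hypothetical helper: write one camera frame as a PNG into the per-episode temporary directory.
    img = Image.fromarray(img_array)
    path = videos_dir / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
    path.parent.mkdir(parents=True, exist_ok=True)
    img.save(str(path))


# Inside the control loop, saving is offloaded so the loop can keep its target fps:
# executor.submit(save_image, observation[key], key, frame_index, episode_index, videos_dir)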
@@ -279,21 +279,12 @@ def record_dataset(
                 timestamp = time.perf_counter() - start_time
-            logging.info("Encoding images to videos")
             num_frames = frame_index
             for key in image_keys:
                 tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
                 fname = f"{key}_episode_{episode_index:06d}.mp4"
-                video_path = local_dir / "videos" / fname
-                # note: video encoding is not done asynchronously
-                encode_video_frames(tmp_imgs_dir, video_path, fps)
-                # clean temporary images directory
-                shutil.rmtree(tmp_imgs_dir)
-                # store the reference to the video frame
+                # store the reference to the video frame, even though the videos are not yet encoded
                 ep_dict[key] = []
                 for i in range(num_frames):
                     ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
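With this change an episode's `ep_dict` only holds references to videos that will be encoded later. For illustration, with a hypothetical camera key and fps = 30, the first entries would look like:

# Hypothetical key name; the real image keys come from the robot's camera configuration.
ep_dict["observation.images.cam_high"] = [
    {"path": "videos/observation.images.cam_high_episode_000000.mp4", "timestamp": 0.0},
    {"path": "videos/observation.images.cam_high_episode_000000.mp4", "timestamp": 1 / 30},
]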
@@ -314,11 +305,29 @@ def record_dataset(
             ep_dicts.append(ep_dict)
-    data_dict = concatenate_episodes(ep_dicts)
+            # last episode
+            if episode_index == num_episodes - 1:
+                logging.info("Done recording")
+                os.system('say "Done recording" &')
+    data_dict = concatenate_episodes(ep_dicts)
     total_frames = data_dict["frame_index"].shape[0]
     data_dict["index"] = torch.arange(0, total_frames, 1)
+    logging.info("Encoding images to videos")
+    os.system('say "Encoding images to videos" &')
+    for episode_index in range(num_episodes):
+        for key in image_keys:
+            tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
+            fname = f"{key}_episode_{episode_index:06d}.mp4"
+            video_path = local_dir / "videos" / fname
+            # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speed up encoding,
+            # since video encoding with ffmpeg is already using multithreading.
+            encode_video_frames(tmp_imgs_dir, video_path, fps)
+            # Clean temporary images directory
+            shutil.rmtree(tmp_imgs_dir)
     hf_dataset = to_hf_dataset(data_dict, video)
     episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
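`encode_video_frames` itself is not shown in this diff. As a rough idea of what such a call reduces to, here is a minimal sketch that assumes the temporary directory contains zero-padded PNG frames (`frame_000000.png`, ...) and that ffmpeg is on the PATH; the codec and frame naming are assumptions, not the repository's actual settings:

import subprocess
from pathlib import Path


def encode_video_frames_sketch(imgs_dir: Path, video_path: Path, fps: int) -> None:
    # Blocking call: ffmpeg reads the numbered PNG frames and writes a single mp4.
    video_path.parent.mkdir(parents=True, exist_ok=True)
    subprocess.run(
        [
            "ffmpeg",
            "-f", "image2",
            "-r", str(fps),
            "-i", str(imgs_dir / "frame_%06d.png"),
            "-vcodec", "libx264",
            "-pix_fmt", "yuv420p",
            "-y",
            str(video_path),
        ],
        check=True,
    )

ffmpeg parallelizes the encode internally, which is why the comment in the hunk above notes that wrapping this call in another thread would not speed things up.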
@@ -326,13 +335,6 @@ def record_dataset(
         "video": video,
     }
-    meta_data_dir = local_dir / "meta_data"
-    for key in image_keys:
-        time.sleep(10)
-        tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
-        shutil.rmtree(tmp_imgs_dir)
     lerobot_dataset = LeRobotDataset.from_preloaded(
         repo_id=repo_id,
         hf_dataset=hf_dataset,
@@ -346,10 +348,14 @@ def record_dataset(
     hf_dataset = hf_dataset.with_format(None)  # to remove transforms that can't be saved
     hf_dataset.save_to_disk(str(local_dir / "train"))
+    meta_data_dir = local_dir / "meta_data"
     save_meta_data(info, stats, episode_data_index, meta_data_dir)
     # TODO(rcadene): push to hub
+    logging.info("Done, exiting")
+    os.system('say "Done, exiting" &')
     return lerobot_dataset
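The `say` announcements rely on the macOS text-to-speech command; on systems without it, `os.system` simply returns a non-zero status instead of raising. Once `save_to_disk` and `save_meta_data` have run, the recording can be reloaded without the robot attached; a minimal sketch, assuming a hypothetical `local_dir` with the `train/` layout written above:

from pathlib import Path

from datasets import load_from_disk

local_dir = Path("data/my_recorded_dataset")  # hypothetical path, for illustration only
hf_dataset = load_from_disk(str(local_dir / "train"))
print(hf_dataset.num_rows, hf_dataset.column_names)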