Make data recording smooth, plus a small fix
parent e6aa8cc641
commit 73692e7459
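Summary of the change in `record_dataset`: the episode-recording loop now runs inside a single ThreadPoolExecutor instead of recreating the pool per episode, video encoding moves out of the recording loop into one batch pass at the end, spoken status cues are added via the macOS `say` command, and a 10-second-sleep cleanup workaround is removed.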
@@ -232,13 +232,13 @@ def record_dataset(
         timestamp = time.perf_counter() - start_time
 
-    # Start recording all episodes
-    ep_dicts = []
-    for episode_index in range(num_episodes):
     # Save images using threads to reach high fps (30 and more)
     # Using `with` to exit smoothly if an exception is raised.
     # Using only 4 worker threads to avoid blocking the main thread.
     with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+        # Start recording all episodes
+        ep_dicts = []
+        for episode_index in range(num_episodes):
             ep_dict = {}
             frame_index = 0
             timestamp = 0
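The executor now wraps the whole episode loop rather than being recreated for each episode, so image writes from every episode share one thread pool and the `with` block joins the workers exactly once, after the last episode. A minimal sketch of the pattern, where `save_image`, the zeroed stand-in frame, and the loop bounds are illustrative assumptions (the actual `executor.submit` call sits outside this hunk's context lines):

import concurrent.futures
from pathlib import Path

import numpy as np
from PIL import Image


def save_image(img: np.ndarray, path: Path) -> None:
    # Disk I/O is the slow part, so it runs in a worker thread.
    path.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(img).save(path)


def record_all_episodes(num_episodes: int, frames_per_episode: int, imgs_dir: Path) -> None:
    # One pool for the whole session: `with` joins the workers once,
    # after the last episode, instead of once per episode.
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        for episode_index in range(num_episodes):
            for frame_index in range(frames_per_episode):
                img = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a camera frame
                fname = f"episode_{episode_index:06d}_frame_{frame_index:06d}.png"
                executor.submit(save_image, img, imgs_dir / fname)
    # Leaving the `with` block guarantees every queued image is on disk.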
@@ -279,21 +279,12 @@ def record_dataset(
                 timestamp = time.perf_counter() - start_time
 
-            logging.info("Encoding images to videos")
-
             num_frames = frame_index
 
             for key in image_keys:
                 tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
                 fname = f"{key}_episode_{episode_index:06d}.mp4"
                 video_path = local_dir / "videos" / fname
-                # note: video encoding is not done asynchronously
-                encode_video_frames(tmp_imgs_dir, video_path, fps)
-
-                # clean temporary images directory
-                shutil.rmtree(tmp_imgs_dir)
-
-                # store the reference to the video frame
+                # store the reference to the video frames, even though the videos are not yet encoded
                 ep_dict[key] = []
                 for i in range(num_frames):
                     ep_dict[key].append({"path": f"videos/{fname}", "timestamp": i / fps})
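The per-frame rows now reference the video file before it exists; encoding is deferred to a single pass after recording (next hunk). A quick illustration of the reference table this builds, with the camera key and the numbers made up for the example:

fps = 30
num_frames = 90  # illustrative values; the real ones come from recording
key = "observation.images.cam_high"  # hypothetical camera key
fname = f"{key}_episode_000000.mp4"

ep_dict = {}
# Each row points into the future video by path + timestamp; a decoder
# later seeks frame i at timestamp i / fps.
ep_dict[key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)]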
@@ -314,11 +305,29 @@ def record_dataset(
             ep_dicts.append(ep_dict)
 
-    data_dict = concatenate_episodes(ep_dicts)
+            # last episode
+            if episode_index == num_episodes - 1:
+                logging.info("Done recording")
+                os.system('say "Done recording" &')
+
+    data_dict = concatenate_episodes(ep_dicts)
 
     total_frames = data_dict["frame_index"].shape[0]
     data_dict["index"] = torch.arange(0, total_frames, 1)
 
+    logging.info("Encoding images to videos")
+    os.system('say "Encoding images to videos" &')
+
+    for episode_index in range(num_episodes):
+        for key in image_keys:
+            tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
+            fname = f"{key}_episode_{episode_index:06d}.mp4"
+            video_path = local_dir / "videos" / fname
+            # note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speed up encoding,
+            # since video encoding with ffmpeg already uses multithreading.
+            encode_video_frames(tmp_imgs_dir, video_path, fps)
+            # Clean temporary images directory
+            shutil.rmtree(tmp_imgs_dir)
+
     hf_dataset = to_hf_dataset(data_dict, video)
     episode_data_index = calculate_episode_data_index(hf_dataset)
     info = {
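Encoding now happens in one batch pass over all episodes, after `concatenate_episodes`, rather than interleaved with recording. The body of `encode_video_frames` is not shown in this diff; a typical implementation shells out to ffmpeg roughly like the sketch below (the frame filename pattern and codec choices are assumptions), which is why the new comment notes that ffmpeg's internal multithreading already saturates the encode and an async wrapper would gain nothing:

import subprocess
from pathlib import Path


def encode_video_frames(imgs_dir: Path, video_path: Path, fps: int) -> None:
    # Hypothetical stand-in for the real helper. ffmpeg blocks until done
    # and parallelizes internally, so calling it from a thread gains nothing.
    video_path.parent.mkdir(parents=True, exist_ok=True)
    subprocess.run(
        [
            "ffmpeg",
            "-f", "image2",
            "-r", str(fps),
            "-i", str(imgs_dir / "frame_%06d.png"),
            "-vcodec", "libx264",
            "-pix_fmt", "yuv420p",
            str(video_path),
        ],
        check=True,
    )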
@@ -326,13 +335,6 @@ def record_dataset(
         "video": video,
     }
 
-    meta_data_dir = local_dir / "meta_data"
-
-    for key in image_keys:
-        time.sleep(10)
-        tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
-        shutil.rmtree(tmp_imgs_dir)
-
     lerobot_dataset = LeRobotDataset.from_preloaded(
         repo_id=repo_id,
         hf_dataset=hf_dataset,
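The deleted block was a workaround: sleeping 10 seconds per camera key to let the image-writer threads finish before wiping the temporary directories. With the executor's `with` block now joining all workers before encoding, and cleanup moved next to `encode_video_frames` in the batch loop above, the sleep and the duplicate `rmtree` are no longer needed. Note that `meta_data_dir` is also removed here; it is redefined right before its only use in the last hunk.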
@@ -346,10 +348,14 @@ def record_dataset(
     hf_dataset = hf_dataset.with_format(None)  # to remove transforms that can't be saved
     hf_dataset.save_to_disk(str(local_dir / "train"))
 
+    meta_data_dir = local_dir / "meta_data"
     save_meta_data(info, stats, episode_data_index, meta_data_dir)
 
     # TODO(rcadene): push to hub
 
+    logging.info("Done, exiting")
+    os.system('say "Done, exiting" &')
+
     return lerobot_dataset
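The audio cues rely on the macOS `say` command, and the trailing `&` backgrounds the process so `os.system` returns immediately instead of blocking until speech finishes. A small helper showing the same idea with a platform guard (the guard is an illustrative addition; the diff calls `os.system` directly):

import os
import platform


def speak(text: str) -> None:
    # Fire-and-forget voice cue; '&' keeps `say` from blocking the script.
    # `say` only exists on macOS, hence the guard.
    if platform.system() == "Darwin":
        os.system(f'say "{text}" &')


speak("Done recording")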