Add video_info, fix image_writer
This commit is contained in:
parent 18ffa4248b
commit e210d795de
@@ -19,9 +19,6 @@ from math import ceil
 import einops
 import torch
 import tqdm
-from datasets import Image
-
-from lerobot.common.datasets.video_utils import VideoFrame
 
 
 def get_stats_einops_patterns(dataset, num_workers=0):
@@ -39,15 +36,13 @@ def get_stats_einops_patterns(dataset, num_workers=0):
     batch = next(iter(dataloader))
 
     stats_patterns = {}
-    for key, feats_type in dataset.features.items():
-        # NOTE: skip language_instruction embedding in stats computation
-        if key == "language_instruction":
-            continue
 
+    for key in dataset.features:
         # sanity check that tensors are not float64
         assert batch[key].dtype != torch.float64
 
-        if isinstance(feats_type, (VideoFrame, Image)):
+        # if isinstance(feats_type, (VideoFrame, Image)):
+        if key in dataset.camera_keys:
             # sanity check that images are channel first
             _, c, h, w = batch[key].shape
             assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
@@ -63,7 +58,7 @@ def get_stats_einops_patterns(dataset, num_workers=0):
         elif batch[key].ndim == 1:
             stats_patterns[key] = "b -> 1"
         else:
-            raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
+            raise ValueError(f"{key}, {batch[key].shape}")
 
     return stats_patterns
 
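Note on the hunks above: stats are now computed for every key in dataset.features (which, after this commit, is a list of column names plus video keys), and image keys are detected via dataset.camera_keys rather than the hf_dataset feature types. A minimal sketch of how einops patterns like the ones built here get applied to a batch; the batch contents and the two patterns below are illustrative, not taken from this diff:

    # Illustrative: reducing one batch with einops patterns like those built above.
    import einops
    import torch

    batch = {
        "observation.images.cam0": torch.rand(8, 3, 96, 96),  # b, c, h, w (channel first)
        "observation.state": torch.rand(8, 6),                # b, c
    }
    stats_patterns = {
        "observation.images.cam0": "b c h w -> c 1 1",
        "observation.state": "b c -> c",
    }

    stats = {
        key: {
            "mean": einops.reduce(batch[key], pattern, "mean"),
            "min": einops.reduce(batch[key], pattern, "min"),
            "max": einops.reduce(batch[key], pattern, "max"),
        }
        for key, pattern in stats_patterns.items()
    }
    print({key: value["mean"].shape for key, value in stats.items()})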
@@ -53,45 +53,54 @@ class ImageWriter:
     the number of threads. If it is still not stable, try to use 1 subprocess, or more.
     """
 
-    def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
+    def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1, timeout: int = 10):
         self.dir = write_dir
         self.dir.mkdir(parents=True, exist_ok=True)
         self.image_path = DEFAULT_IMAGE_PATH
         self.num_processes = num_processes
-        self.num_threads = self.num_threads_per_process = num_threads
+        self.num_threads = num_threads
+        self.timeout = timeout
 
-        if self.num_processes <= 0:
+        if self.num_processes == 0 and self.num_threads == 0:
+            self.type = "synchronous"
+        elif self.num_processes == 0 and self.num_threads > 0:
            self.type = "threads"
            self.threads = ThreadPoolExecutor(max_workers=self.num_threads)
            self.futures = []
        else:
            self.type = "processes"
-            self.num_threads_per_process = self.num_threads
+            self.main_event = multiprocessing.Event()
            self.image_queue = multiprocessing.Queue()
            self.processes: list[multiprocessing.Process] = []
-            for _ in range(num_processes):
-                process = multiprocessing.Process(target=self._loop_to_save_images_in_threads)
+            self.events: list[multiprocessing.Event] = []
+            for _ in range(self.num_processes):
+                event = multiprocessing.Event()
+                process = multiprocessing.Process(target=self._loop_to_save_images_in_threads, args=(event,))
                process.start()
                self.processes.append(process)
+                self.events.append(event)
 
-    def _loop_to_save_images_in_threads(self) -> None:
+    def _loop_to_save_images_in_threads(self, event: multiprocessing.Event) -> None:
        with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
            futures = []
            while True:
                frame_data = self.image_queue.get()
                if frame_data is None:
-                    break
+                    self._wait_threads(self.futures, 10)
+                    return
 
                image, file_path = frame_data
                futures.append(executor.submit(self._save_image, image, file_path))
 
-            with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
-                wait(futures)
-                progress_bar.update(len(futures))
+                if self.main_event.is_set():
+                    self._wait_threads(self.futures, 10)
+                    event.set()
 
    def async_save_image(self, image: torch.Tensor, file_path: Path) -> None:
        """Save an image asynchronously using threads or processes."""
-        if self.type == "threads":
+        if self.type == "synchronous":
+            self._save_image(image, file_path)
+        elif self.type == "threads":
            self.futures.append(self.threads.submit(self._save_image, image, file_path))
        else:
            self.image_queue.put((image, file_path))
 
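Note on the hunk above: in "processes" mode each worker process now gets its own multiprocessing.Event, and the shared main_event lets the main process ask every worker to flush. A stripped-down sketch of that handshake pattern, independent of the ImageWriter code (the sleep stands in for draining the image queue):

    # Simplified event handshake: main process requests a flush, each worker acknowledges.
    import multiprocessing
    import time

    def worker(main_event, done_event):
        while True:
            time.sleep(0.01)  # stand-in for pulling work off a queue and writing images
            if main_event.is_set():
                done_event.set()  # "my pending work is flushed"

    if __name__ == "__main__":
        main_event = multiprocessing.Event()
        done_events = []
        for _ in range(2):
            done = multiprocessing.Event()
            multiprocessing.Process(target=worker, args=(main_event, done), daemon=True).start()
            done_events.append(done)

        main_event.set()         # ask all workers to flush
        for done in done_events:
            done.wait()          # block until each worker confirms
        main_event.clear()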
@@ -111,12 +120,33 @@ class ImageWriter:
             episode_index=episode_index, image_key=image_key, frame_index=0
         ).parent
 
-    def stop(self, timeout=20) -> None:
+    def wait(self) -> None:
+        """Wait for the thread/processes to finish writing."""
+        if self.type == "synchronous":
+            return
+        elif self.type == "threads":
+            self._wait_threads(self.futures)
+        else:
+            self._wait_processes()
+
+    def _wait_threads(self, futures) -> None:
+        with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
+            wait(futures, timeout=self.timeout)
+            progress_bar.update(len(futures))
+
+    def _wait_processes(self) -> None:
+        self.main_event.set()
+        for event in self.events:
+            event.wait()
+
+        self.main_event.clear()
+
+    def shutdown(self, timeout=20) -> None:
         """Stop the image writer, waiting for all processes or threads to finish."""
-        if self.type == "threads":
-            with tqdm.tqdm(total=len(self.futures), desc="Writing images") as progress_bar:
-                wait(self.futures, timeout=timeout)
-                progress_bar.update(len(self.futures))
+        if self.type == "synchronous":
+            return
+        elif self.type == "threads":
+            self.threads.shutdown(wait=True)
         else:
             self._stop_processes(timeout)
 
@@ -127,8 +157,9 @@ class ImageWriter:
         for process in self.processes:
             process.join(timeout=timeout)
 
-            if process.is_alive():
-                process.terminate()
+        for process in self.processes:
+            if process.is_alive():
+                process.terminate()
 
         self.image_queue.close()
         self.image_queue.join_thread()
 
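Note on the hunks above: wait() flushes pending writes but keeps the writer usable, while shutdown() tears the threads/processes down for good. A rough lifecycle sketch; the directory, tensor, and the import of ImageWriter are placeholders, not part of this diff:

    # Rough lifecycle of the reworked writer (ImageWriter import path assumed/omitted).
    from pathlib import Path
    import torch

    writer = ImageWriter(write_dir=Path("tmp/images"), num_processes=0, num_threads=4)  # "threads" mode
    frame = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)
    writer.async_save_image(frame, writer.dir / "episode_000000" / "frame_000000.png")

    writer.wait()      # flush pending writes, e.g. before encoding this episode's video
    writer.shutdown()  # final teardown once nothing more will be written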
@@ -22,10 +22,10 @@ from pathlib import Path
 from typing import Callable
 
 import datasets
-import pyarrow.parquet as pq
 import torch
 import torch.utils
 from datasets import load_dataset
+from datasets.table import embed_table_storage
 from huggingface_hub import snapshot_download, upload_folder
 
 from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
@@ -57,6 +57,7 @@ from lerobot.common.datasets.video_utils import (
     VideoFrame,
     decode_video_frames_torchvision,
     encode_video_frames,
+    get_video_info,
 )
 from lerobot.common.robot_devices.robots.utils import Robot
 
@@ -391,7 +392,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return self.info["shapes"]
 
     @property
-    def features(self) -> datasets.Features:
+    def features(self) -> list[str]:
+        return list(self._features) + self.video_keys
+
+    @property
+    def _features(self) -> datasets.Features:
         """Features of the hf_dataset."""
         if self.hf_dataset is not None:
             return self.hf_dataset.features
 
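Note on the hunk above: features now returns plain key names (hf_dataset columns plus video keys) so callers like compute_stats can iterate every key, while the underlying datasets.Features object moves to _features. A tiny illustration of the intended distinction (the repo id and key names are placeholders):

    # Illustrative only.
    dataset = LeRobotDataset("some_user/some_repo")  # placeholder repo id
    print(dataset.features)    # e.g. ["action", "observation.state", "observation.images.cam0"]
    print(dataset._features)   # the datasets.Features mapping of the underlying hf_dataset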
@@ -583,6 +588,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                     image=frame[cam_key],
                     file_path=img_path,
                 )
+
             if cam_key in self.image_keys:
                 self.episode_buffer[cam_key].append(str(img_path))
 
@@ -592,7 +598,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
         the hub.
 
-        Use 'encode_videos' if you want to encode videos during the saving of each episode. Otherwise,
+        Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
         you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
         time for video encoding.
         """
@@ -608,7 +614,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         for key in self.episode_buffer:
             if key in self.image_keys:
                 continue
-            if key in self.keys:
+            elif key in self.keys:
                 self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
             elif key == "episode_index":
                 self.episode_buffer[key] = torch.full((episode_length,), episode_index)
@@ -619,6 +625,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length)
         self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
+
+        self._wait_image_writer()
         self._save_episode_table(episode_index)
 
         if encode_videos and len(self.video_keys) > 0:
@@ -629,11 +637,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.consolidated = False
 
     def _save_episode_table(self, episode_index: int) -> None:
-        ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.features, split="train")
-        ep_table = ep_dataset._data.table
+        ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
         ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
         ep_data_path.parent.mkdir(parents=True, exist_ok=True)
-        pq.write_table(ep_table, ep_data_path)
+
+        # Embed image bytes into the table before saving to parquet
+        format = ep_dataset.format
+        ep_dataset = ep_dataset.with_format("arrow")
+        ep_dataset = ep_dataset.map(embed_table_storage, batched=False)
+        ep_dataset = ep_dataset.with_format(**format)
+
+        ep_dataset.to_parquet(ep_data_path)
 
     def _save_episode_to_metadata(
         self, episode_index: int, episode_length: int, task: str, task_index: int
 
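Note on the hunk above: mapping embed_table_storage over the arrow-formatted dataset inlines the image bytes before to_parquet, instead of writing the raw table with pyarrow.parquet. A standalone sketch of the same pattern on a toy dataset; the file names and tiny placeholder images are illustrative:

    # Toy version of the embed-then-to_parquet pattern used in _save_episode_table.
    import datasets
    from datasets.table import embed_table_storage
    from PIL import Image

    paths = []
    for i in range(2):
        path = f"img_{i}.png"
        Image.new("RGB", (4, 4)).save(path)  # tiny placeholder images
        paths.append(path)

    ds = datasets.Dataset.from_dict({"image": paths, "index": [0, 1]})
    ds = ds.cast_column("image", datasets.Image())

    fmt = ds.format                                  # remember the current format
    ds = ds.with_format("arrow")                     # map over pyarrow tables
    ds = ds.map(embed_table_storage, batched=False)  # inline the image bytes
    ds = ds.with_format(**fmt)                       # restore the original format
    ds.to_parquet("episode_000000.parquet")          # parquet now embeds the images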
@@ -677,7 +691,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         # Reset the buffer
         self.episode_buffer = self._create_episode_buffer()
 
-    def start_image_writter(self, num_processes: int = 0, num_threads: int = 1) -> None:
+    def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
         if isinstance(self.image_writer, ImageWriter):
             logging.warning(
                 "You are starting a new ImageWriter that is replacing an already exising one in the dataset."
@@ -689,18 +703,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
             num_threads=num_threads,
         )
 
-    def stop_image_writter(self) -> None:
+    def stop_image_writer(self) -> None:
         """
         Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
         remove the image_write in order for the LeRobotDataset object to be pickleable and parallelized.
         """
         if self.image_writer is not None:
-            self.image_writer.stop()
+            self.image_writer.shutdown()
             self.image_writer = None
 
+    def _wait_image_writer(self) -> None:
+        """Wait for asynchronous image writer to finish."""
+        if self.image_writer is not None:
+            self.image_writer.wait()
+
     def encode_videos(self) -> None:
         # Use ffmpeg to convert frames stored as png into mp4 videos
-        for episode_index in range(self.num_episodes):
+        for episode_index in range(self.total_episodes):
             for key in self.video_keys:
                 # TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
                 # to call self.image_writer here
 
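Note on the hunk above: as the stop_image_writer docstring says, the writer holds thread/process handles that cannot be pickled, so it has to be released before the dataset is handed to a parallel DataLoader. A short sketch of that ordering, assuming dataset is a LeRobotDataset with an active writer (batch size and worker count are arbitrary):

    # Release the writer first, then parallelize.
    import torch

    dataset.stop_image_writer()  # otherwise DataLoader workers cannot pickle the dataset
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=4, shuffle=True)
    for batch in dataloader:
        pass  # stats computation or training would go here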
@@ -713,6 +732,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 # since video encoding with ffmpeg is already using multithreading.
                 encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
 
+    def _write_video_info(self) -> None:
+        """
+        Warning: this function writes info from first episode videos, implicitly assuming that all videos have
+        been encoded the same way. Also, this means it assumes the first episode exists.
+        """
+        for key in self.video_keys:
+            if key not in self.info["videos"]:
+                video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
+                self.info["videos"][key] = get_video_info(video_path)
+
+        write_json(self.info, self.root / INFO_PATH)
+
     def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
         self.hf_dataset = self.load_hf_dataset()
         self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
 
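Note on the hunk above: _write_video_info probes the first episode's video for each key with get_video_info and stores the result under info["videos"] before writing info.json. Based on the keys built in get_video_info further down in this commit, one entry would look roughly like this (all values are made up):

    # Illustrative shape of info["videos"]["observation.images.cam0"]
    {
        "video.fps": 30.0,
        "video.width": 640,
        "video.height": 480,
        "video.channels": 3,
        "video.codec": "av1",
        "video.pix_fmt": "yuv420p",
        "video.is_depth_map": False,
        "has_audio": False,
    }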
@@ -720,12 +751,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
         if len(self.video_keys) > 0:
             self.encode_videos()
+            self._write_video_info()
 
         if not keep_image_files and self.image_writer is not None:
             shutil.rmtree(self.image_writer.dir)
 
         if run_compute_stats:
-            self.stop_image_writter()
+            self.stop_image_writer()
             self.stats = compute_stats(self)
             write_stats(self.stats, self.root / STATS_PATH)
             self.consolidated = True
@@ -735,7 +767,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             )
 
         # TODO(aliberts)
-        # - [ ] add video info in info.json
+        # - [X] add video info in info.json
         # Sanity checks:
         # - [ ] shapes
         # - [ ] ep_lenghts
@@ -775,7 +807,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 "In this case, frames from lower fps cameras will be repeated to fill in the blanks"
             )
         if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
-            obj.start_image_writter(
+            obj.start_image_writer(
                 image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
             )
         elif (
@@ -791,7 +823,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         )
 
         if len(video_keys) > 0 and not use_videos:
-            raise ValueError
+            raise ValueError()
 
         obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
         obj.info = create_empty_dataset_info(
@@ -918,7 +950,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
     def features(self) -> datasets.Features:
         features = {}
         for dataset in self._datasets:
-            features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
+            features.update({k: v for k, v in dataset._features.items() if k not in self.disabled_data_keys})
         return features
 
     @property
 
@@ -116,7 +116,6 @@ import torch
 from datasets import Dataset
 from huggingface_hub import HfApi
 from huggingface_hub.errors import EntryNotFoundError
-from PIL import Image
 from safetensors.torch import load_file
 
 from lerobot.common.datasets.utils import (
@@ -136,7 +135,12 @@ from lerobot.common.datasets.utils import (
     write_json,
     write_jsonlines,
 )
-from lerobot.common.datasets.video_utils import VideoFrame  # noqa: F401
+from lerobot.common.datasets.video_utils import (
+    VideoFrame,  # noqa: F401
+    get_image_shapes,
+    get_video_info,
+    get_video_shapes,
+)
 from lerobot.common.utils.utils import init_hydra_config
 
 V16 = "v1.6"
@@ -391,83 +395,6 @@ def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
     return [f for f in video_files if f not in lfs_tracked_files]
 
 
-def _get_audio_info(video_path: Path | str) -> dict:
-    ffprobe_audio_cmd = [
-        "ffprobe",
-        "-v",
-        "error",
-        "-select_streams",
-        "a:0",
-        "-show_entries",
-        "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
-        "-of",
-        "json",
-        str(video_path),
-    ]
-    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
-    if result.returncode != 0:
-        raise RuntimeError(f"Error running ffprobe: {result.stderr}")
-
-    info = json.loads(result.stdout)
-    audio_stream_info = info["streams"][0] if info.get("streams") else None
-    if audio_stream_info is None:
-        return {"has_audio": False}
-
-    # Return the information, defaulting to None if no audio stream is present
-    return {
-        "has_audio": True,
-        "audio.channels": audio_stream_info.get("channels", None),
-        "audio.codec": audio_stream_info.get("codec_name", None),
-        "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
-        "audio.sample_rate": int(audio_stream_info["sample_rate"])
-        if audio_stream_info.get("sample_rate")
-        else None,
-        "audio.bit_depth": audio_stream_info.get("bit_depth", None),
-        "audio.channel_layout": audio_stream_info.get("channel_layout", None),
-    }
-
-
-def _get_video_info(video_path: Path | str) -> dict:
-    ffprobe_video_cmd = [
-        "ffprobe",
-        "-v",
-        "error",
-        "-select_streams",
-        "v:0",
-        "-show_entries",
-        "stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
-        "-of",
-        "json",
-        str(video_path),
-    ]
-    result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
-    if result.returncode != 0:
-        raise RuntimeError(f"Error running ffprobe: {result.stderr}")
-
-    info = json.loads(result.stdout)
-    video_stream_info = info["streams"][0]
-
-    # Calculate fps from r_frame_rate
-    r_frame_rate = video_stream_info["r_frame_rate"]
-    num, denom = map(int, r_frame_rate.split("/"))
-    fps = num / denom
-
-    pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
-
-    video_info = {
-        "video.fps": fps,
-        "video.width": video_stream_info["width"],
-        "video.height": video_stream_info["height"],
-        "video.channels": pixel_channels,
-        "video.codec": video_stream_info["codec_name"],
-        "video.pix_fmt": video_stream_info["pix_fmt"],
-        "video.is_depth_map": False,
-        **_get_audio_info(video_path),
-    }
-
-    return video_info
-
-
 def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
     hub_api = HfApi()
     videos_info_dict = {"videos_path": DEFAULT_VIDEO_PATH}
@@ -481,62 +408,11 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
         repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
     )
     for vid_key, vid_path in zip(video_keys, video_files, strict=True):
-        videos_info_dict[vid_key] = _get_video_info(local_dir / vid_path)
+        videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)
 
     return videos_info_dict
 
 
-def get_video_pixel_channels(pix_fmt: str) -> int:
-    if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
-        return 1
-    elif "rgba" in pix_fmt or "yuva" in pix_fmt:
-        return 4
-    elif "rgb" in pix_fmt or "yuv" in pix_fmt:
-        return 3
-    else:
-        raise ValueError("Unknown format")
-
-
-def get_image_pixel_channels(image: Image):
-    if image.mode == "L":
-        return 1  # Grayscale
-    elif image.mode == "LA":
-        return 2  # Grayscale + Alpha
-    elif image.mode == "RGB":
-        return 3  # RGB
-    elif image.mode == "RGBA":
-        return 4  # RGBA
-    else:
-        raise ValueError("Unknown format")
-
-
-def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
-    video_shapes = {}
-    for img_key in video_keys:
-        channels = get_video_pixel_channels(videos_info[img_key]["video.pix_fmt"])
-        video_shapes[img_key] = {
-            "width": videos_info[img_key]["video.width"],
-            "height": videos_info[img_key]["video.height"],
-            "channels": channels,
-        }
-
-    return video_shapes
-
-
-def get_image_shapes(dataset: Dataset, image_keys: list) -> dict:
-    image_shapes = {}
-    for img_key in image_keys:
-        image = dataset[0][img_key]  # Assuming first row
-        channels = get_image_pixel_channels(image)
-        image_shapes[img_key] = {
-            "width": image.width,
-            "height": image.height,
-            "channels": channels,
-        }
-
-    return image_shapes
-
-
 def get_generic_motor_names(sequence_shapes: dict) -> dict:
     return {key: [f"motor_{i}" for i in range(length)] for key, length in sequence_shapes.items()}
 
@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
 import logging
 import subprocess
 import warnings
@@ -24,7 +25,9 @@ from typing import Any, ClassVar
 import pyarrow as pa
 import torch
 import torchvision
+from datasets import Dataset
 from datasets.features.features import register_feature
+from PIL import Image
 
 
 def decode_video_frames_torchvision(
@@ -210,3 +213,131 @@ with warnings.catch_warnings():
     )
     # to make VideoFrame available in HuggingFace `datasets`
     register_feature(VideoFrame, "VideoFrame")
+
+
+def get_audio_info(video_path: Path | str) -> dict:
+    ffprobe_audio_cmd = [
+        "ffprobe",
+        "-v",
+        "error",
+        "-select_streams",
+        "a:0",
+        "-show_entries",
+        "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
+        "-of",
+        "json",
+        str(video_path),
+    ]
+    result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+    if result.returncode != 0:
+        raise RuntimeError(f"Error running ffprobe: {result.stderr}")
+
+    info = json.loads(result.stdout)
+    audio_stream_info = info["streams"][0] if info.get("streams") else None
+    if audio_stream_info is None:
+        return {"has_audio": False}
+
+    # Return the information, defaulting to None if no audio stream is present
+    return {
+        "has_audio": True,
+        "audio.channels": audio_stream_info.get("channels", None),
+        "audio.codec": audio_stream_info.get("codec_name", None),
+        "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
+        "audio.sample_rate": int(audio_stream_info["sample_rate"])
+        if audio_stream_info.get("sample_rate")
+        else None,
+        "audio.bit_depth": audio_stream_info.get("bit_depth", None),
+        "audio.channel_layout": audio_stream_info.get("channel_layout", None),
+    }
+
+
+def get_video_info(video_path: Path | str) -> dict:
+    ffprobe_video_cmd = [
+        "ffprobe",
+        "-v",
+        "error",
+        "-select_streams",
+        "v:0",
+        "-show_entries",
+        "stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
+        "-of",
+        "json",
+        str(video_path),
+    ]
+    result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+    if result.returncode != 0:
+        raise RuntimeError(f"Error running ffprobe: {result.stderr}")
+
+    info = json.loads(result.stdout)
+    video_stream_info = info["streams"][0]
+
+    # Calculate fps from r_frame_rate
+    r_frame_rate = video_stream_info["r_frame_rate"]
+    num, denom = map(int, r_frame_rate.split("/"))
+    fps = num / denom
+
+    pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
+
+    video_info = {
+        "video.fps": fps,
+        "video.width": video_stream_info["width"],
+        "video.height": video_stream_info["height"],
+        "video.channels": pixel_channels,
+        "video.codec": video_stream_info["codec_name"],
+        "video.pix_fmt": video_stream_info["pix_fmt"],
+        "video.is_depth_map": False,
+        **get_audio_info(video_path),
+    }
+
+    return video_info
+
+
+def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
+    video_shapes = {}
+    for img_key in video_keys:
+        channels = get_video_pixel_channels(videos_info[img_key]["video.pix_fmt"])
+        video_shapes[img_key] = {
+            "width": videos_info[img_key]["video.width"],
+            "height": videos_info[img_key]["video.height"],
+            "channels": channels,
+        }
+
+    return video_shapes
+
+
+def get_image_shapes(dataset: Dataset, image_keys: list) -> dict:
+    image_shapes = {}
+    for img_key in image_keys:
+        image = dataset[0][img_key]  # Assuming first row
+        channels = get_image_pixel_channels(image)
+        image_shapes[img_key] = {
+            "width": image.width,
+            "height": image.height,
+            "channels": channels,
+        }
+
+    return image_shapes
+
+
+def get_video_pixel_channels(pix_fmt: str) -> int:
+    if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
+        return 1
+    elif "rgba" in pix_fmt or "yuva" in pix_fmt:
+        return 4
+    elif "rgb" in pix_fmt or "yuv" in pix_fmt:
+        return 3
+    else:
+        raise ValueError("Unknown format")
+
+
+def get_image_pixel_channels(image: Image):
+    if image.mode == "L":
+        return 1  # Grayscale
+    elif image.mode == "LA":
+        return 2  # Grayscale + Alpha
+    elif image.mode == "RGB":
+        return 3  # RGB
+    elif image.mode == "RGBA":
+        return 4  # RGBA
+    else:
+        raise ValueError("Unknown format")
 
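Note on the hunk above: get_video_info and get_audio_info shell out to ffprobe, so they require ffmpeg/ffprobe on PATH; the shape helpers just reshape the probed metadata. A short usage sketch (the video path and camera key are placeholders, and the imports are assumed to come from lerobot.common.datasets.video_utils):

    # Usage sketch for the new helpers (requires ffprobe; path and key are placeholders).
    from pathlib import Path

    info = get_video_info(Path("videos/observation.images.cam0_episode_000000.mp4"))
    print(info["video.fps"], info["video.width"], info["video.height"], info["has_audio"])

    shapes = get_video_shapes({"observation.images.cam0": info}, ["observation.images.cam0"])
    # -> {"observation.images.cam0": {"width": ..., "height": ..., "channels": ...}}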
@@ -234,8 +234,8 @@ def record(
     dataset = LeRobotDataset.create(
         repo_id,
         fps,
-        robot,
         root=root,
+        robot=robot,
         image_writer_processes=num_image_writer_processes,
         image_writer_threads_per_camera=num_image_writer_threads_per_camera,
         use_videos=video,
@@ -307,10 +307,6 @@ def record(
     log_say("Stop recording", play_sounds, blocking=True)
     stop_recording(robot, listener, display_cameras)
 
-    if dataset.image_writer is not None:
-        logging.info("Waiting for image writer to terminate...")
-        dataset.image_writer.stop()
-
     if run_compute_stats:
         logging.info("Computing dataset statistics")
 
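Note on the hunk above: record() no longer stops the image writer explicitly; with this commit the writer is flushed in save_episode() via _wait_image_writer() and shut down in consolidate() via stop_image_writer(). A condensed, illustrative sketch of the resulting flow (the episode loop and frame source are placeholders):

    # Condensed recording flow after this change (illustrative only).
    for _ in range(num_episodes):                # from the record() configuration
        for frame in episode_frames:             # frames captured from the robot
            dataset.add_frame(frame)             # images go through ImageWriter.async_save_image
        dataset.save_episode(task)               # calls _wait_image_writer() before saving the table

    dataset.consolidate(run_compute_stats=True)  # encodes videos, writes video info, stops the writer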