Add video_info, fix image_writer

This commit is contained in:
Simon Alibert 2024-10-25 16:55:33 +02:00
parent 18ffa4248b
commit e210d795de
6 changed files with 241 additions and 180 deletions

View File

@ -19,9 +19,6 @@ from math import ceil
import einops import einops
import torch import torch
import tqdm import tqdm
from datasets import Image
from lerobot.common.datasets.video_utils import VideoFrame
def get_stats_einops_patterns(dataset, num_workers=0): def get_stats_einops_patterns(dataset, num_workers=0):
@ -39,15 +36,13 @@ def get_stats_einops_patterns(dataset, num_workers=0):
batch = next(iter(dataloader)) batch = next(iter(dataloader))
stats_patterns = {} stats_patterns = {}
for key, feats_type in dataset.features.items():
# NOTE: skip language_instruction embedding in stats computation
if key == "language_instruction":
continue
for key in dataset.features:
# sanity check that tensors are not float64 # sanity check that tensors are not float64
assert batch[key].dtype != torch.float64 assert batch[key].dtype != torch.float64
if isinstance(feats_type, (VideoFrame, Image)): # if isinstance(feats_type, (VideoFrame, Image)):
if key in dataset.camera_keys:
# sanity check that images are channel first # sanity check that images are channel first
_, c, h, w = batch[key].shape _, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}" assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
@ -63,7 +58,7 @@ def get_stats_einops_patterns(dataset, num_workers=0):
elif batch[key].ndim == 1: elif batch[key].ndim == 1:
stats_patterns[key] = "b -> 1" stats_patterns[key] = "b -> 1"
else: else:
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}") raise ValueError(f"{key}, {batch[key].shape}")
return stats_patterns return stats_patterns

View File

@ -53,45 +53,54 @@ class ImageWriter:
the number of threads. If it is still not stable, try to use 1 subprocess, or more. the number of threads. If it is still not stable, try to use 1 subprocess, or more.
""" """
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1): def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1, timeout: int = 10):
self.dir = write_dir self.dir = write_dir
self.dir.mkdir(parents=True, exist_ok=True) self.dir.mkdir(parents=True, exist_ok=True)
self.image_path = DEFAULT_IMAGE_PATH self.image_path = DEFAULT_IMAGE_PATH
self.num_processes = num_processes self.num_processes = num_processes
self.num_threads = self.num_threads_per_process = num_threads self.num_threads = num_threads
self.timeout = timeout
if self.num_processes <= 0: if self.num_processes == 0 and self.num_threads == 0:
self.type = "synchronous"
elif self.num_processes == 0 and self.num_threads > 0:
self.type = "threads" self.type = "threads"
self.threads = ThreadPoolExecutor(max_workers=self.num_threads) self.threads = ThreadPoolExecutor(max_workers=self.num_threads)
self.futures = [] self.futures = []
else: else:
self.type = "processes" self.type = "processes"
self.num_threads_per_process = self.num_threads self.main_event = multiprocessing.Event()
self.image_queue = multiprocessing.Queue() self.image_queue = multiprocessing.Queue()
self.processes: list[multiprocessing.Process] = [] self.processes: list[multiprocessing.Process] = []
for _ in range(num_processes): self.events: list[multiprocessing.Event] = []
process = multiprocessing.Process(target=self._loop_to_save_images_in_threads) for _ in range(self.num_processes):
event = multiprocessing.Event()
process = multiprocessing.Process(target=self._loop_to_save_images_in_threads, args=(event,))
process.start() process.start()
self.processes.append(process) self.processes.append(process)
self.events.append(event)
def _loop_to_save_images_in_threads(self) -> None: def _loop_to_save_images_in_threads(self, event: multiprocessing.Event) -> None:
with ThreadPoolExecutor(max_workers=self.num_threads) as executor: with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
futures = [] futures = []
while True: while True:
frame_data = self.image_queue.get() frame_data = self.image_queue.get()
if frame_data is None: if frame_data is None:
break self._wait_threads(self.futures, 10)
return
image, file_path = frame_data image, file_path = frame_data
futures.append(executor.submit(self._save_image, image, file_path)) futures.append(executor.submit(self._save_image, image, file_path))
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar: if self.main_event.is_set():
wait(futures) self._wait_threads(self.futures, 10)
progress_bar.update(len(futures)) event.set()
def async_save_image(self, image: torch.Tensor, file_path: Path) -> None: def async_save_image(self, image: torch.Tensor, file_path: Path) -> None:
"""Save an image asynchronously using threads or processes.""" """Save an image asynchronously using threads or processes."""
if self.type == "threads": if self.type == "synchronous":
self._save_image(image, file_path)
elif self.type == "threads":
self.futures.append(self.threads.submit(self._save_image, image, file_path)) self.futures.append(self.threads.submit(self._save_image, image, file_path))
else: else:
self.image_queue.put((image, file_path)) self.image_queue.put((image, file_path))
@ -111,12 +120,33 @@ class ImageWriter:
episode_index=episode_index, image_key=image_key, frame_index=0 episode_index=episode_index, image_key=image_key, frame_index=0
).parent ).parent
def stop(self, timeout=20) -> None: def wait(self) -> None:
"""Wait for the thread/processes to finish writing."""
if self.type == "synchronous":
return
elif self.type == "threads":
self._wait_threads(self.futures)
else:
self._wait_processes()
def _wait_threads(self, futures, timeout: float | None = None) -> None:
    """Block until the given image-writing futures finish or the timeout elapses.

    Args:
        futures: Futures to wait on (as produced by ``ThreadPoolExecutor.submit``).
        timeout: Maximum number of seconds to wait. When omitted, ``self.timeout``
            is used. The parameter is optional so that existing one-argument calls
            keep working, while callers that pass an explicit timeout as a second
            positional argument no longer fail with a ``TypeError``.
    """
    if timeout is None:
        timeout = self.timeout

    with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
        # `wait` returns the futures that completed and those still pending once
        # the timeout elapses; only count the completed ones so the bar does not
        # claim that images were written when the wait actually timed out.
        done, _not_done = wait(futures, timeout=timeout)
        progress_bar.update(len(done))
def _wait_processes(self) -> None:
    """Signal the writer processes and block until each one has acknowledged.

    Sets ``self.main_event``, waits on every event in ``self.events`` in turn, and
    finally clears ``self.main_event`` again.
    """
    self.main_event.set()
    try:
        for event in self.events:
            event.wait()
    finally:
        # Lower the flag even if the wait is interrupted (e.g. KeyboardInterrupt),
        # otherwise it would stay raised for every subsequent frame.
        self.main_event.clear()
    # NOTE(review): the per-process events are not cleared here, so once a worker
    # has set its event a later call presumably returns without waiting -- confirm
    # against the worker loop before relying on this as a barrier.
def shutdown(self, timeout=20) -> None:
"""Stop the image writer, waiting for all processes or threads to finish.""" """Stop the image writer, waiting for all processes or threads to finish."""
if self.type == "threads": if self.type == "synchronous":
with tqdm.tqdm(total=len(self.futures), desc="Writing images") as progress_bar: return
wait(self.futures, timeout=timeout) elif self.type == "threads":
progress_bar.update(len(self.futures)) self.threads.shutdown(wait=True)
else: else:
self._stop_processes(timeout) self._stop_processes(timeout)
@ -127,8 +157,9 @@ class ImageWriter:
for process in self.processes: for process in self.processes:
process.join(timeout=timeout) process.join(timeout=timeout)
if process.is_alive(): for process in self.processes:
process.terminate() if process.is_alive():
process.terminate()
self.image_queue.close() self.image_queue.close()
self.image_queue.join_thread() self.image_queue.join_thread()

View File

@ -22,10 +22,10 @@ from pathlib import Path
from typing import Callable from typing import Callable
import datasets import datasets
import pyarrow.parquet as pq
import torch import torch
import torch.utils import torch.utils
from datasets import load_dataset from datasets import load_dataset
from datasets.table import embed_table_storage
from huggingface_hub import snapshot_download, upload_folder from huggingface_hub import snapshot_download, upload_folder
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
@ -57,6 +57,7 @@ from lerobot.common.datasets.video_utils import (
VideoFrame, VideoFrame,
decode_video_frames_torchvision, decode_video_frames_torchvision,
encode_video_frames, encode_video_frames,
get_video_info,
) )
from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.robots.utils import Robot
@ -391,7 +392,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
return self.info["shapes"] return self.info["shapes"]
@property @property
def features(self) -> datasets.Features: def features(self) -> list[str]:
return list(self._features) + self.video_keys
@property
def _features(self) -> datasets.Features:
"""Features of the hf_dataset.""" """Features of the hf_dataset."""
if self.hf_dataset is not None: if self.hf_dataset is not None:
return self.hf_dataset.features return self.hf_dataset.features
@ -583,6 +588,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
image=frame[cam_key], image=frame[cam_key],
file_path=img_path, file_path=img_path,
) )
if cam_key in self.image_keys: if cam_key in self.image_keys:
self.episode_buffer[cam_key].append(str(img_path)) self.episode_buffer[cam_key].append(str(img_path))
@ -592,7 +598,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
the hub. the hub.
Use 'encode_videos' if you want to encode videos during the saving of each episode. Otherwise, Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
time for video encoding. time for video encoding.
""" """
@ -608,7 +614,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
for key in self.episode_buffer: for key in self.episode_buffer:
if key in self.image_keys: if key in self.image_keys:
continue continue
if key in self.keys: elif key in self.keys:
self.episode_buffer[key] = torch.stack(self.episode_buffer[key]) self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
elif key == "episode_index": elif key == "episode_index":
self.episode_buffer[key] = torch.full((episode_length,), episode_index) self.episode_buffer[key] = torch.full((episode_length,), episode_index)
@ -619,6 +625,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length) self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length)
self._save_episode_to_metadata(episode_index, episode_length, task, task_index) self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
self._wait_image_writer()
self._save_episode_table(episode_index) self._save_episode_table(episode_index)
if encode_videos and len(self.video_keys) > 0: if encode_videos and len(self.video_keys) > 0:
@ -629,11 +637,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.consolidated = False self.consolidated = False
def _save_episode_table(self, episode_index: int) -> None: def _save_episode_table(self, episode_index: int) -> None:
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.features, split="train") ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
ep_table = ep_dataset._data.table
ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index) ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
ep_data_path.parent.mkdir(parents=True, exist_ok=True) ep_data_path.parent.mkdir(parents=True, exist_ok=True)
pq.write_table(ep_table, ep_data_path)
# Embed image bytes into the table before saving to parquet
format = ep_dataset.format
ep_dataset = ep_dataset.with_format("arrow")
ep_dataset = ep_dataset.map(embed_table_storage, batched=False)
ep_dataset = ep_dataset.with_format(**format)
ep_dataset.to_parquet(ep_data_path)
def _save_episode_to_metadata( def _save_episode_to_metadata(
self, episode_index: int, episode_length: int, task: str, task_index: int self, episode_index: int, episode_length: int, task: str, task_index: int
@ -677,7 +691,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Reset the buffer # Reset the buffer
self.episode_buffer = self._create_episode_buffer() self.episode_buffer = self._create_episode_buffer()
def start_image_writter(self, num_processes: int = 0, num_threads: int = 1) -> None: def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
if isinstance(self.image_writer, ImageWriter): if isinstance(self.image_writer, ImageWriter):
logging.warning( logging.warning(
"You are starting a new ImageWriter that is replacing an already exising one in the dataset." "You are starting a new ImageWriter that is replacing an already exising one in the dataset."
@ -689,18 +703,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
num_threads=num_threads, num_threads=num_threads,
) )
def stop_image_writter(self) -> None: def stop_image_writer(self) -> None:
""" """
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
remove the image_write in order for the LeRobotDataset object to be pickleable and parallelized. remove the image_write in order for the LeRobotDataset object to be pickleable and parallelized.
""" """
if self.image_writer is not None: if self.image_writer is not None:
self.image_writer.stop() self.image_writer.shutdown()
self.image_writer = None self.image_writer = None
def _wait_image_writer(self) -> None:
    """Block until the image writer attached to this dataset, if any, is done.

    Does nothing when ``self.image_writer`` is ``None``.
    """
    if self.image_writer is None:
        return
    self.image_writer.wait()
def encode_videos(self) -> None: def encode_videos(self) -> None:
# Use ffmpeg to convert frames stored as png into mp4 videos # Use ffmpeg to convert frames stored as png into mp4 videos
for episode_index in range(self.num_episodes): for episode_index in range(self.total_episodes):
for key in self.video_keys: for key in self.video_keys:
# TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need # TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
# to call self.image_writer here # to call self.image_writer here
@ -713,6 +732,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
# since video encoding with ffmpeg is already using multithreading. # since video encoding with ffmpeg is already using multithreading.
encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True) encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
def _write_video_info(self) -> None:
    """Fill in missing per-camera video metadata and write the info file.

    Warning: the metadata is probed from the video of episode 0 only. This
    implicitly assumes that every episode was encoded the same way, and that
    the first episode exists.
    """
    for vid_key in self.video_keys:
        # Keys that already have an entry are left untouched.
        if vid_key in self.info["videos"]:
            continue
        first_episode_video = self.root / self.get_video_file_path(ep_index=0, vid_key=vid_key)
        self.info["videos"][vid_key] = get_video_info(first_episode_video)

    # The info file is rewritten on every call, whether or not an entry was added.
    write_json(self.info, self.root / INFO_PATH)
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None: def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
self.hf_dataset = self.load_hf_dataset() self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts) self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
@ -720,12 +751,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
if len(self.video_keys) > 0: if len(self.video_keys) > 0:
self.encode_videos() self.encode_videos()
self._write_video_info()
if not keep_image_files and self.image_writer is not None: if not keep_image_files and self.image_writer is not None:
shutil.rmtree(self.image_writer.dir) shutil.rmtree(self.image_writer.dir)
if run_compute_stats: if run_compute_stats:
self.stop_image_writter() self.stop_image_writer()
self.stats = compute_stats(self) self.stats = compute_stats(self)
write_stats(self.stats, self.root / STATS_PATH) write_stats(self.stats, self.root / STATS_PATH)
self.consolidated = True self.consolidated = True
@ -735,7 +767,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
) )
# TODO(aliberts) # TODO(aliberts)
# - [ ] add video info in info.json # - [X] add video info in info.json
# Sanity checks: # Sanity checks:
# - [ ] shapes # - [ ] shapes
# - [ ] ep_lenghts # - [ ] ep_lenghts
@ -775,7 +807,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
"In this case, frames from lower fps cameras will be repeated to fill in the blanks" "In this case, frames from lower fps cameras will be repeated to fill in the blanks"
) )
if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera): if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
obj.start_image_writter( obj.start_image_writer(
image_writer_processes, image_writer_threads_per_camera * robot.num_cameras image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
) )
elif ( elif (
@ -791,7 +823,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
) )
if len(video_keys) > 0 and not use_videos: if len(video_keys) > 0 and not use_videos:
raise ValueError raise ValueError()
obj.tasks, obj.stats, obj.episode_dicts = {}, {}, [] obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
obj.info = create_empty_dataset_info( obj.info = create_empty_dataset_info(
@ -918,7 +950,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
def features(self) -> datasets.Features: def features(self) -> datasets.Features:
features = {} features = {}
for dataset in self._datasets: for dataset in self._datasets:
features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys}) features.update({k: v for k, v in dataset._features.items() if k not in self.disabled_data_keys})
return features return features
@property @property

View File

@ -116,7 +116,6 @@ import torch
from datasets import Dataset from datasets import Dataset
from huggingface_hub import HfApi from huggingface_hub import HfApi
from huggingface_hub.errors import EntryNotFoundError from huggingface_hub.errors import EntryNotFoundError
from PIL import Image
from safetensors.torch import load_file from safetensors.torch import load_file
from lerobot.common.datasets.utils import ( from lerobot.common.datasets.utils import (
@ -136,7 +135,12 @@ from lerobot.common.datasets.utils import (
write_json, write_json,
write_jsonlines, write_jsonlines,
) )
from lerobot.common.datasets.video_utils import VideoFrame # noqa: F401 from lerobot.common.datasets.video_utils import (
VideoFrame, # noqa: F401
get_image_shapes,
get_video_info,
get_video_shapes,
)
from lerobot.common.utils.utils import init_hydra_config from lerobot.common.utils.utils import init_hydra_config
V16 = "v1.6" V16 = "v1.6"
@ -391,83 +395,6 @@ def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[st
return [f for f in video_files if f not in lfs_tracked_files] return [f for f in video_files if f not in lfs_tracked_files]
def _get_audio_info(video_path: Path | str) -> dict:
ffprobe_audio_cmd = [
"ffprobe",
"-v",
"error",
"-select_streams",
"a:0",
"-show_entries",
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
"-of",
"json",
str(video_path),
]
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
info = json.loads(result.stdout)
audio_stream_info = info["streams"][0] if info.get("streams") else None
if audio_stream_info is None:
return {"has_audio": False}
# Return the information, defaulting to None if no audio stream is present
return {
"has_audio": True,
"audio.channels": audio_stream_info.get("channels", None),
"audio.codec": audio_stream_info.get("codec_name", None),
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
"audio.sample_rate": int(audio_stream_info["sample_rate"])
if audio_stream_info.get("sample_rate")
else None,
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
}
def _get_video_info(video_path: Path | str) -> dict:
ffprobe_video_cmd = [
"ffprobe",
"-v",
"error",
"-select_streams",
"v:0",
"-show_entries",
"stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
"-of",
"json",
str(video_path),
]
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if result.returncode != 0:
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
info = json.loads(result.stdout)
video_stream_info = info["streams"][0]
# Calculate fps from r_frame_rate
r_frame_rate = video_stream_info["r_frame_rate"]
num, denom = map(int, r_frame_rate.split("/"))
fps = num / denom
pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
video_info = {
"video.fps": fps,
"video.width": video_stream_info["width"],
"video.height": video_stream_info["height"],
"video.channels": pixel_channels,
"video.codec": video_stream_info["codec_name"],
"video.pix_fmt": video_stream_info["pix_fmt"],
"video.is_depth_map": False,
**_get_audio_info(video_path),
}
return video_info
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict: def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
hub_api = HfApi() hub_api = HfApi()
videos_info_dict = {"videos_path": DEFAULT_VIDEO_PATH} videos_info_dict = {"videos_path": DEFAULT_VIDEO_PATH}
@ -481,62 +408,11 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
) )
for vid_key, vid_path in zip(video_keys, video_files, strict=True): for vid_key, vid_path in zip(video_keys, video_files, strict=True):
videos_info_dict[vid_key] = _get_video_info(local_dir / vid_path) videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)
return videos_info_dict return videos_info_dict
def get_video_pixel_channels(pix_fmt: str) -> int:
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
return 1
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
return 4
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
return 3
else:
raise ValueError("Unknown format")
def get_image_pixel_channels(image: Image):
if image.mode == "L":
return 1 # Grayscale
elif image.mode == "LA":
return 2 # Grayscale + Alpha
elif image.mode == "RGB":
return 3 # RGB
elif image.mode == "RGBA":
return 4 # RGBA
else:
raise ValueError("Unknown format")
def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
video_shapes = {}
for img_key in video_keys:
channels = get_video_pixel_channels(videos_info[img_key]["video.pix_fmt"])
video_shapes[img_key] = {
"width": videos_info[img_key]["video.width"],
"height": videos_info[img_key]["video.height"],
"channels": channels,
}
return video_shapes
def get_image_shapes(dataset: Dataset, image_keys: list) -> dict:
image_shapes = {}
for img_key in image_keys:
image = dataset[0][img_key] # Assuming first row
channels = get_image_pixel_channels(image)
image_shapes[img_key] = {
"width": image.width,
"height": image.height,
"channels": channels,
}
return image_shapes
def get_generic_motor_names(sequence_shapes: dict) -> dict: def get_generic_motor_names(sequence_shapes: dict) -> dict:
return {key: [f"motor_{i}" for i in range(length)] for key, length in sequence_shapes.items()} return {key: [f"motor_{i}" for i in range(length)] for key, length in sequence_shapes.items()}

View File

@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import logging import logging
import subprocess import subprocess
import warnings import warnings
@ -24,7 +25,9 @@ from typing import Any, ClassVar
import pyarrow as pa import pyarrow as pa
import torch import torch
import torchvision import torchvision
from datasets import Dataset
from datasets.features.features import register_feature from datasets.features.features import register_feature
from PIL import Image
def decode_video_frames_torchvision( def decode_video_frames_torchvision(
@ -210,3 +213,131 @@ with warnings.catch_warnings():
) )
# to make VideoFrame available in HuggingFace `datasets` # to make VideoFrame available in HuggingFace `datasets`
register_feature(VideoFrame, "VideoFrame") register_feature(VideoFrame, "VideoFrame")
def get_audio_info(video_path: Path | str) -> dict:
    """Probe the first audio stream (``a:0``) of a media file with ``ffprobe``.

    Args:
        video_path: Path of the file handed to ``ffprobe``.

    Returns:
        ``{"has_audio": False}`` when ffprobe lists no audio stream; otherwise a
        dict with ``"has_audio": True`` plus ``audio.*`` entries, each of which
        is ``None`` when ffprobe did not report the corresponding field.

    Raises:
        RuntimeError: If ``ffprobe`` exits with a non-zero return code.
    """

    def _int_or_none(raw):
        # ffprobe's JSON is converted with int(); absent/empty values become None.
        return int(raw) if raw else None

    probe = subprocess.run(
        [
            "ffprobe",
            "-v",
            "error",
            "-select_streams",
            "a:0",
            "-show_entries",
            "stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
            "-of",
            "json",
            str(video_path),
        ],
        capture_output=True,
        text=True,
    )
    if probe.returncode != 0:
        raise RuntimeError(f"Error running ffprobe: {probe.stderr}")

    streams = json.loads(probe.stdout).get("streams")
    stream = streams[0] if streams else None
    if stream is None:
        return {"has_audio": False}

    return {
        "has_audio": True,
        "audio.channels": stream.get("channels", None),
        "audio.codec": stream.get("codec_name", None),
        "audio.bit_rate": _int_or_none(stream.get("bit_rate")),
        "audio.sample_rate": _int_or_none(stream.get("sample_rate")),
        "audio.bit_depth": stream.get("bit_depth", None),
        "audio.channel_layout": stream.get("channel_layout", None),
    }
def get_video_info(video_path: Path | str) -> dict:
    """Probe the first video stream (``v:0``) of a file with ``ffprobe``.

    Args:
        video_path: Path of the file handed to ``ffprobe``.

    Returns:
        A dict of ``video.*`` entries (fps, width, height, channels, codec,
        pixel format, and ``video.is_depth_map`` which is always ``False``),
        merged with the entries returned by ``get_audio_info`` for the same file.

    Raises:
        RuntimeError: If ``ffprobe`` exits with a non-zero return code.
    """
    probe = subprocess.run(
        [
            "ffprobe",
            "-v",
            "error",
            "-select_streams",
            "v:0",
            "-show_entries",
            "stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
            "-of",
            "json",
            str(video_path),
        ],
        capture_output=True,
        text=True,
    )
    if probe.returncode != 0:
        raise RuntimeError(f"Error running ffprobe: {probe.stderr}")

    stream = json.loads(probe.stdout)["streams"][0]

    # r_frame_rate is reported as a "numerator/denominator" string.
    numerator, denominator = (int(part) for part in stream["r_frame_rate"].split("/"))
    frames_per_second = numerator / denominator

    channel_count = get_video_pixel_channels(stream["pix_fmt"])

    return {
        "video.fps": frames_per_second,
        "video.width": stream["width"],
        "video.height": stream["height"],
        "video.channels": channel_count,
        "video.codec": stream["codec_name"],
        "video.pix_fmt": stream["pix_fmt"],
        "video.is_depth_map": False,
        **get_audio_info(video_path),
    }
def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
    """Build a ``{key: {"width", "height", "channels"}}`` mapping for video keys.

    Args:
        videos_info: Mapping from video key to a dict holding ``video.width``,
            ``video.height`` and ``video.pix_fmt`` entries.
        video_keys: Keys of ``videos_info`` to describe.

    Returns:
        One shape dict per key; the channel count is derived from the pixel
        format through ``get_video_pixel_channels``.
    """

    def _shape_of(info: dict) -> dict:
        n_channels = get_video_pixel_channels(info["video.pix_fmt"])
        return {
            "width": info["video.width"],
            "height": info["video.height"],
            "channels": n_channels,
        }

    return {vid_key: _shape_of(videos_info[vid_key]) for vid_key in video_keys}
def get_image_shapes(dataset: Dataset, image_keys: list) -> dict:
    """Build a ``{key: {"width", "height", "channels"}}`` mapping for image keys.

    The shape of each key is read from the first row of ``dataset`` only.

    Args:
        dataset: Dataset indexed as ``dataset[0][key]`` to fetch a sample image.
        image_keys: Column names whose image shape should be reported.

    Returns:
        One shape dict per key; the channel count comes from
        ``get_image_pixel_channels``.
    """
    shapes = {}
    for key in image_keys:
        sample = dataset[0][key]  # first row is taken as representative
        n_channels = get_image_pixel_channels(sample)
        shapes[key] = dict(width=sample.width, height=sample.height, channels=n_channels)
    return shapes
def get_video_pixel_channels(pix_fmt: str) -> int:
    """Return the number of pixel channels implied by a pixel-format name.

    The format is classified by substring, checked in this order:
    single-channel names, then names carrying an alpha channel, then plain
    three-channel names. Alpha variants must be tested before the
    three-channel ones because e.g. ``"argb"`` contains ``"rgb"``.

    Args:
        pix_fmt: Pixel-format name (e.g. ``"yuv420p"``, ``"rgba"``, ``"gray"``).

    Returns:
        1, 3 or 4.

    Raises:
        ValueError: If the format matches none of the known families; the
            message includes the offending format.
    """
    single_channel = ("gray", "depth", "monochrome")
    with_alpha = ("rgba", "bgra", "argb", "abgr", "yuva")
    three_channel = ("rgb", "bgr", "yuv")

    if any(tag in pix_fmt for tag in single_channel):
        return 1
    if any(tag in pix_fmt for tag in with_alpha):
        return 4
    if any(tag in pix_fmt for tag in three_channel):
        return 3
    raise ValueError(f"Unknown format: {pix_fmt!r}")
def get_image_pixel_channels(image: "Image.Image") -> int:
    """Return the number of channels of an image from its ``mode`` attribute.

    Args:
        image: Object exposing a ``mode`` string attribute.

    Returns:
        1 for ``"L"`` (grayscale), 2 for ``"LA"`` (grayscale + alpha),
        3 for ``"RGB"`` and 4 for ``"RGBA"``.

    Raises:
        ValueError: If the mode is none of the above; the message includes
            the offending mode.
    """
    channels_by_mode = {
        "L": 1,  # Grayscale
        "LA": 2,  # Grayscale + Alpha
        "RGB": 3,  # RGB
        "RGBA": 4,  # RGBA
    }
    mode = image.mode
    if mode not in channels_by_mode:
        raise ValueError(f"Unknown format: {mode!r}")
    return channels_by_mode[mode]

View File

@ -234,8 +234,8 @@ def record(
dataset = LeRobotDataset.create( dataset = LeRobotDataset.create(
repo_id, repo_id,
fps, fps,
robot,
root=root, root=root,
robot=robot,
image_writer_processes=num_image_writer_processes, image_writer_processes=num_image_writer_processes,
image_writer_threads_per_camera=num_image_writer_threads_per_camera, image_writer_threads_per_camera=num_image_writer_threads_per_camera,
use_videos=video, use_videos=video,
@ -307,10 +307,6 @@ def record(
log_say("Stop recording", play_sounds, blocking=True) log_say("Stop recording", play_sounds, blocking=True)
stop_recording(robot, listener, display_cameras) stop_recording(robot, listener, display_cameras)
if dataset.image_writer is not None:
logging.info("Waiting for image writer to terminate...")
dataset.image_writer.stop()
if run_compute_stats: if run_compute_stats:
logging.info("Computing dataset statistics") logging.info("Computing dataset statistics")