Add video_info, fix image_writer
This commit is contained in:
parent
18ffa4248b
commit
e210d795de
|
@ -19,9 +19,6 @@ from math import ceil
|
|||
import einops
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import Image
|
||||
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
|
||||
|
||||
def get_stats_einops_patterns(dataset, num_workers=0):
|
||||
|
@ -39,15 +36,13 @@ def get_stats_einops_patterns(dataset, num_workers=0):
|
|||
batch = next(iter(dataloader))
|
||||
|
||||
stats_patterns = {}
|
||||
for key, feats_type in dataset.features.items():
|
||||
# NOTE: skip language_instruction embedding in stats computation
|
||||
if key == "language_instruction":
|
||||
continue
|
||||
|
||||
for key in dataset.features:
|
||||
# sanity check that tensors are not float64
|
||||
assert batch[key].dtype != torch.float64
|
||||
|
||||
if isinstance(feats_type, (VideoFrame, Image)):
|
||||
# if isinstance(feats_type, (VideoFrame, Image)):
|
||||
if key in dataset.camera_keys:
|
||||
# sanity check that images are channel first
|
||||
_, c, h, w = batch[key].shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
|
||||
|
@ -63,7 +58,7 @@ def get_stats_einops_patterns(dataset, num_workers=0):
|
|||
elif batch[key].ndim == 1:
|
||||
stats_patterns[key] = "b -> 1"
|
||||
else:
|
||||
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
|
||||
raise ValueError(f"{key}, {batch[key].shape}")
|
||||
|
||||
return stats_patterns
|
||||
|
||||
|
|
|
@ -53,45 +53,54 @@ class ImageWriter:
|
|||
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
|
||||
"""
|
||||
|
||||
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
|
||||
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1, timeout: int = 10):
|
||||
self.dir = write_dir
|
||||
self.dir.mkdir(parents=True, exist_ok=True)
|
||||
self.image_path = DEFAULT_IMAGE_PATH
|
||||
self.num_processes = num_processes
|
||||
self.num_threads = self.num_threads_per_process = num_threads
|
||||
self.num_threads = num_threads
|
||||
self.timeout = timeout
|
||||
|
||||
if self.num_processes <= 0:
|
||||
if self.num_processes == 0 and self.num_threads == 0:
|
||||
self.type = "synchronous"
|
||||
elif self.num_processes == 0 and self.num_threads > 0:
|
||||
self.type = "threads"
|
||||
self.threads = ThreadPoolExecutor(max_workers=self.num_threads)
|
||||
self.futures = []
|
||||
else:
|
||||
self.type = "processes"
|
||||
self.num_threads_per_process = self.num_threads
|
||||
self.main_event = multiprocessing.Event()
|
||||
self.image_queue = multiprocessing.Queue()
|
||||
self.processes: list[multiprocessing.Process] = []
|
||||
for _ in range(num_processes):
|
||||
process = multiprocessing.Process(target=self._loop_to_save_images_in_threads)
|
||||
self.events: list[multiprocessing.Event] = []
|
||||
for _ in range(self.num_processes):
|
||||
event = multiprocessing.Event()
|
||||
process = multiprocessing.Process(target=self._loop_to_save_images_in_threads, args=(event,))
|
||||
process.start()
|
||||
self.processes.append(process)
|
||||
self.events.append(event)
|
||||
|
||||
def _loop_to_save_images_in_threads(self) -> None:
|
||||
def _loop_to_save_images_in_threads(self, event: multiprocessing.Event) -> None:
|
||||
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
|
||||
futures = []
|
||||
while True:
|
||||
frame_data = self.image_queue.get()
|
||||
if frame_data is None:
|
||||
break
|
||||
self._wait_threads(self.futures, 10)
|
||||
return
|
||||
|
||||
image, file_path = frame_data
|
||||
futures.append(executor.submit(self._save_image, image, file_path))
|
||||
|
||||
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
|
||||
wait(futures)
|
||||
progress_bar.update(len(futures))
|
||||
if self.main_event.is_set():
|
||||
self._wait_threads(self.futures, 10)
|
||||
event.set()
|
||||
|
||||
def async_save_image(self, image: torch.Tensor, file_path: Path) -> None:
|
||||
"""Save an image asynchronously using threads or processes."""
|
||||
if self.type == "threads":
|
||||
if self.type == "synchronous":
|
||||
self._save_image(image, file_path)
|
||||
elif self.type == "threads":
|
||||
self.futures.append(self.threads.submit(self._save_image, image, file_path))
|
||||
else:
|
||||
self.image_queue.put((image, file_path))
|
||||
|
@ -111,12 +120,33 @@ class ImageWriter:
|
|||
episode_index=episode_index, image_key=image_key, frame_index=0
|
||||
).parent
|
||||
|
||||
def stop(self, timeout=20) -> None:
|
||||
def wait(self) -> None:
|
||||
"""Wait for the thread/processes to finish writing."""
|
||||
if self.type == "synchronous":
|
||||
return
|
||||
elif self.type == "threads":
|
||||
self._wait_threads(self.futures)
|
||||
else:
|
||||
self._wait_processes()
|
||||
|
||||
def _wait_threads(self, futures) -> None:
|
||||
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
|
||||
wait(futures, timeout=self.timeout)
|
||||
progress_bar.update(len(futures))
|
||||
|
||||
def _wait_processes(self) -> None:
|
||||
self.main_event.set()
|
||||
for event in self.events:
|
||||
event.wait()
|
||||
|
||||
self.main_event.clear()
|
||||
|
||||
def shutdown(self, timeout=20) -> None:
|
||||
"""Stop the image writer, waiting for all processes or threads to finish."""
|
||||
if self.type == "threads":
|
||||
with tqdm.tqdm(total=len(self.futures), desc="Writing images") as progress_bar:
|
||||
wait(self.futures, timeout=timeout)
|
||||
progress_bar.update(len(self.futures))
|
||||
if self.type == "synchronous":
|
||||
return
|
||||
elif self.type == "threads":
|
||||
self.threads.shutdown(wait=True)
|
||||
else:
|
||||
self._stop_processes(timeout)
|
||||
|
||||
|
@ -127,6 +157,7 @@ class ImageWriter:
|
|||
for process in self.processes:
|
||||
process.join(timeout=timeout)
|
||||
|
||||
for process in self.processes:
|
||||
if process.is_alive():
|
||||
process.terminate()
|
||||
|
||||
|
|
|
@ -22,10 +22,10 @@ from pathlib import Path
|
|||
from typing import Callable
|
||||
|
||||
import datasets
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
import torch.utils
|
||||
from datasets import load_dataset
|
||||
from datasets.table import embed_table_storage
|
||||
from huggingface_hub import snapshot_download, upload_folder
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
|
||||
|
@ -57,6 +57,7 @@ from lerobot.common.datasets.video_utils import (
|
|||
VideoFrame,
|
||||
decode_video_frames_torchvision,
|
||||
encode_video_frames,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
|
||||
|
@ -391,7 +392,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
return self.info["shapes"]
|
||||
|
||||
@property
|
||||
def features(self) -> datasets.Features:
|
||||
def features(self) -> list[str]:
|
||||
return list(self._features) + self.video_keys
|
||||
|
||||
@property
|
||||
def _features(self) -> datasets.Features:
|
||||
"""Features of the hf_dataset."""
|
||||
if self.hf_dataset is not None:
|
||||
return self.hf_dataset.features
|
||||
|
@ -583,6 +588,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
image=frame[cam_key],
|
||||
file_path=img_path,
|
||||
)
|
||||
|
||||
if cam_key in self.image_keys:
|
||||
self.episode_buffer[cam_key].append(str(img_path))
|
||||
|
||||
|
@ -592,7 +598,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
|
||||
the hub.
|
||||
|
||||
Use 'encode_videos' if you want to encode videos during the saving of each episode. Otherwise,
|
||||
Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
|
||||
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
|
||||
time for video encoding.
|
||||
"""
|
||||
|
@ -608,7 +614,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
for key in self.episode_buffer:
|
||||
if key in self.image_keys:
|
||||
continue
|
||||
if key in self.keys:
|
||||
elif key in self.keys:
|
||||
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
|
||||
elif key == "episode_index":
|
||||
self.episode_buffer[key] = torch.full((episode_length,), episode_index)
|
||||
|
@ -619,6 +625,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length)
|
||||
self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
|
||||
|
||||
self._wait_image_writer()
|
||||
self._save_episode_table(episode_index)
|
||||
|
||||
if encode_videos and len(self.video_keys) > 0:
|
||||
|
@ -629,11 +637,17 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.consolidated = False
|
||||
|
||||
def _save_episode_table(self, episode_index: int) -> None:
|
||||
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.features, split="train")
|
||||
ep_table = ep_dataset._data.table
|
||||
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self._features, split="train")
|
||||
ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
|
||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(ep_table, ep_data_path)
|
||||
|
||||
# Embed image bytes into the table before saving to parquet
|
||||
format = ep_dataset.format
|
||||
ep_dataset = ep_dataset.with_format("arrow")
|
||||
ep_dataset = ep_dataset.map(embed_table_storage, batched=False)
|
||||
ep_dataset = ep_dataset.with_format(**format)
|
||||
|
||||
ep_dataset.to_parquet(ep_data_path)
|
||||
|
||||
def _save_episode_to_metadata(
|
||||
self, episode_index: int, episode_length: int, task: str, task_index: int
|
||||
|
@ -677,7 +691,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
# Reset the buffer
|
||||
self.episode_buffer = self._create_episode_buffer()
|
||||
|
||||
def start_image_writter(self, num_processes: int = 0, num_threads: int = 1) -> None:
|
||||
def start_image_writer(self, num_processes: int = 0, num_threads: int = 1) -> None:
|
||||
if isinstance(self.image_writer, ImageWriter):
|
||||
logging.warning(
|
||||
"You are starting a new ImageWriter that is replacing an already exising one in the dataset."
|
||||
|
@ -689,18 +703,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
def stop_image_writter(self) -> None:
|
||||
def stop_image_writer(self) -> None:
|
||||
"""
|
||||
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
|
||||
remove the image_write in order for the LeRobotDataset object to be pickleable and parallelized.
|
||||
"""
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.stop()
|
||||
self.image_writer.shutdown()
|
||||
self.image_writer = None
|
||||
|
||||
def _wait_image_writer(self) -> None:
|
||||
"""Wait for asynchronous image writer to finish."""
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.wait()
|
||||
|
||||
def encode_videos(self) -> None:
|
||||
# Use ffmpeg to convert frames stored as png into mp4 videos
|
||||
for episode_index in range(self.num_episodes):
|
||||
for episode_index in range(self.total_episodes):
|
||||
for key in self.video_keys:
|
||||
# TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
|
||||
# to call self.image_writer here
|
||||
|
@ -713,6 +732,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
# since video encoding with ffmpeg is already using multithreading.
|
||||
encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
|
||||
|
||||
def _write_video_info(self) -> None:
|
||||
"""
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
"""
|
||||
for key in self.video_keys:
|
||||
if key not in self.info["videos"]:
|
||||
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
|
||||
self.info["videos"][key] = get_video_info(video_path)
|
||||
|
||||
write_json(self.info, self.root / INFO_PATH)
|
||||
|
||||
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
|
||||
|
@ -720,12 +751,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
if len(self.video_keys) > 0:
|
||||
self.encode_videos()
|
||||
self._write_video_info()
|
||||
|
||||
if not keep_image_files and self.image_writer is not None:
|
||||
shutil.rmtree(self.image_writer.dir)
|
||||
|
||||
if run_compute_stats:
|
||||
self.stop_image_writter()
|
||||
self.stop_image_writer()
|
||||
self.stats = compute_stats(self)
|
||||
write_stats(self.stats, self.root / STATS_PATH)
|
||||
self.consolidated = True
|
||||
|
@ -735,7 +767,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
)
|
||||
|
||||
# TODO(aliberts)
|
||||
# - [ ] add video info in info.json
|
||||
# - [X] add video info in info.json
|
||||
# Sanity checks:
|
||||
# - [ ] shapes
|
||||
# - [ ] ep_lenghts
|
||||
|
@ -775,7 +807,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
|
||||
)
|
||||
if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
|
||||
obj.start_image_writter(
|
||||
obj.start_image_writer(
|
||||
image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
|
||||
)
|
||||
elif (
|
||||
|
@ -791,7 +823,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
)
|
||||
|
||||
if len(video_keys) > 0 and not use_videos:
|
||||
raise ValueError
|
||||
raise ValueError()
|
||||
|
||||
obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
|
||||
obj.info = create_empty_dataset_info(
|
||||
|
@ -918,7 +950,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||
def features(self) -> datasets.Features:
|
||||
features = {}
|
||||
for dataset in self._datasets:
|
||||
features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
|
||||
features.update({k: v for k, v in dataset._features.items() if k not in self.disabled_data_keys})
|
||||
return features
|
||||
|
||||
@property
|
||||
|
|
|
@ -116,7 +116,6 @@ import torch
|
|||
from datasets import Dataset
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.errors import EntryNotFoundError
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
|
@ -136,7 +135,12 @@ from lerobot.common.datasets.utils import (
|
|||
write_json,
|
||||
write_jsonlines,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame # noqa: F401
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
VideoFrame, # noqa: F401
|
||||
get_image_shapes,
|
||||
get_video_info,
|
||||
get_video_shapes,
|
||||
)
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
||||
V16 = "v1.6"
|
||||
|
@ -391,83 +395,6 @@ def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[st
|
|||
return [f for f in video_files if f not in lfs_tracked_files]
|
||||
|
||||
|
||||
def _get_audio_info(video_path: Path | str) -> dict:
|
||||
ffprobe_audio_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"a:0",
|
||||
"-show_entries",
|
||||
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
info = json.loads(result.stdout)
|
||||
audio_stream_info = info["streams"][0] if info.get("streams") else None
|
||||
if audio_stream_info is None:
|
||||
return {"has_audio": False}
|
||||
|
||||
# Return the information, defaulting to None if no audio stream is present
|
||||
return {
|
||||
"has_audio": True,
|
||||
"audio.channels": audio_stream_info.get("channels", None),
|
||||
"audio.codec": audio_stream_info.get("codec_name", None),
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
|
||||
"audio.sample_rate": int(audio_stream_info["sample_rate"])
|
||||
if audio_stream_info.get("sample_rate")
|
||||
else None,
|
||||
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
|
||||
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
|
||||
}
|
||||
|
||||
|
||||
def _get_video_info(video_path: Path | str) -> dict:
|
||||
ffprobe_video_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"v:0",
|
||||
"-show_entries",
|
||||
"stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
info = json.loads(result.stdout)
|
||||
video_stream_info = info["streams"][0]
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
r_frame_rate = video_stream_info["r_frame_rate"]
|
||||
num, denom = map(int, r_frame_rate.split("/"))
|
||||
fps = num / denom
|
||||
|
||||
pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
|
||||
|
||||
video_info = {
|
||||
"video.fps": fps,
|
||||
"video.width": video_stream_info["width"],
|
||||
"video.height": video_stream_info["height"],
|
||||
"video.channels": pixel_channels,
|
||||
"video.codec": video_stream_info["codec_name"],
|
||||
"video.pix_fmt": video_stream_info["pix_fmt"],
|
||||
"video.is_depth_map": False,
|
||||
**_get_audio_info(video_path),
|
||||
}
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
|
||||
hub_api = HfApi()
|
||||
videos_info_dict = {"videos_path": DEFAULT_VIDEO_PATH}
|
||||
|
@ -481,62 +408,11 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch
|
|||
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
|
||||
)
|
||||
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
|
||||
videos_info_dict[vid_key] = _get_video_info(local_dir / vid_path)
|
||||
videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)
|
||||
|
||||
return videos_info_dict
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
|
||||
return 4
|
||||
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
|
||||
return 3
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_image_pixel_channels(image: Image):
|
||||
if image.mode == "L":
|
||||
return 1 # Grayscale
|
||||
elif image.mode == "LA":
|
||||
return 2 # Grayscale + Alpha
|
||||
elif image.mode == "RGB":
|
||||
return 3 # RGB
|
||||
elif image.mode == "RGBA":
|
||||
return 4 # RGBA
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
|
||||
video_shapes = {}
|
||||
for img_key in video_keys:
|
||||
channels = get_video_pixel_channels(videos_info[img_key]["video.pix_fmt"])
|
||||
video_shapes[img_key] = {
|
||||
"width": videos_info[img_key]["video.width"],
|
||||
"height": videos_info[img_key]["video.height"],
|
||||
"channels": channels,
|
||||
}
|
||||
|
||||
return video_shapes
|
||||
|
||||
|
||||
def get_image_shapes(dataset: Dataset, image_keys: list) -> dict:
|
||||
image_shapes = {}
|
||||
for img_key in image_keys:
|
||||
image = dataset[0][img_key] # Assuming first row
|
||||
channels = get_image_pixel_channels(image)
|
||||
image_shapes[img_key] = {
|
||||
"width": image.width,
|
||||
"height": image.height,
|
||||
"channels": channels,
|
||||
}
|
||||
|
||||
return image_shapes
|
||||
|
||||
|
||||
def get_generic_motor_names(sequence_shapes: dict) -> dict:
|
||||
return {key: [f"motor_{i}" for i in range(length)] for key, length in sequence_shapes.items()}
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import warnings
|
||||
|
@ -24,7 +25,9 @@ from typing import Any, ClassVar
|
|||
import pyarrow as pa
|
||||
import torch
|
||||
import torchvision
|
||||
from datasets import Dataset
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def decode_video_frames_torchvision(
|
||||
|
@ -210,3 +213,131 @@ with warnings.catch_warnings():
|
|||
)
|
||||
# to make VideoFrame available in HuggingFace `datasets`
|
||||
register_feature(VideoFrame, "VideoFrame")
|
||||
|
||||
|
||||
def get_audio_info(video_path: Path | str) -> dict:
|
||||
ffprobe_audio_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"a:0",
|
||||
"-show_entries",
|
||||
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
info = json.loads(result.stdout)
|
||||
audio_stream_info = info["streams"][0] if info.get("streams") else None
|
||||
if audio_stream_info is None:
|
||||
return {"has_audio": False}
|
||||
|
||||
# Return the information, defaulting to None if no audio stream is present
|
||||
return {
|
||||
"has_audio": True,
|
||||
"audio.channels": audio_stream_info.get("channels", None),
|
||||
"audio.codec": audio_stream_info.get("codec_name", None),
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
|
||||
"audio.sample_rate": int(audio_stream_info["sample_rate"])
|
||||
if audio_stream_info.get("sample_rate")
|
||||
else None,
|
||||
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
|
||||
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
|
||||
}
|
||||
|
||||
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
ffprobe_video_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"v:0",
|
||||
"-show_entries",
|
||||
"stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
info = json.loads(result.stdout)
|
||||
video_stream_info = info["streams"][0]
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
r_frame_rate = video_stream_info["r_frame_rate"]
|
||||
num, denom = map(int, r_frame_rate.split("/"))
|
||||
fps = num / denom
|
||||
|
||||
pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
|
||||
|
||||
video_info = {
|
||||
"video.fps": fps,
|
||||
"video.width": video_stream_info["width"],
|
||||
"video.height": video_stream_info["height"],
|
||||
"video.channels": pixel_channels,
|
||||
"video.codec": video_stream_info["codec_name"],
|
||||
"video.pix_fmt": video_stream_info["pix_fmt"],
|
||||
"video.is_depth_map": False,
|
||||
**get_audio_info(video_path),
|
||||
}
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
|
||||
video_shapes = {}
|
||||
for img_key in video_keys:
|
||||
channels = get_video_pixel_channels(videos_info[img_key]["video.pix_fmt"])
|
||||
video_shapes[img_key] = {
|
||||
"width": videos_info[img_key]["video.width"],
|
||||
"height": videos_info[img_key]["video.height"],
|
||||
"channels": channels,
|
||||
}
|
||||
|
||||
return video_shapes
|
||||
|
||||
|
||||
def get_image_shapes(dataset: Dataset, image_keys: list) -> dict:
|
||||
image_shapes = {}
|
||||
for img_key in image_keys:
|
||||
image = dataset[0][img_key] # Assuming first row
|
||||
channels = get_image_pixel_channels(image)
|
||||
image_shapes[img_key] = {
|
||||
"width": image.width,
|
||||
"height": image.height,
|
||||
"channels": channels,
|
||||
}
|
||||
|
||||
return image_shapes
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
|
||||
return 4
|
||||
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
|
||||
return 3
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_image_pixel_channels(image: Image):
|
||||
if image.mode == "L":
|
||||
return 1 # Grayscale
|
||||
elif image.mode == "LA":
|
||||
return 2 # Grayscale + Alpha
|
||||
elif image.mode == "RGB":
|
||||
return 3 # RGB
|
||||
elif image.mode == "RGBA":
|
||||
return 4 # RGBA
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
|
|
@ -234,8 +234,8 @@ def record(
|
|||
dataset = LeRobotDataset.create(
|
||||
repo_id,
|
||||
fps,
|
||||
robot,
|
||||
root=root,
|
||||
robot=robot,
|
||||
image_writer_processes=num_image_writer_processes,
|
||||
image_writer_threads_per_camera=num_image_writer_threads_per_camera,
|
||||
use_videos=video,
|
||||
|
@ -307,10 +307,6 @@ def record(
|
|||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
stop_recording(robot, listener, display_cameras)
|
||||
|
||||
if dataset.image_writer is not None:
|
||||
logging.info("Waiting for image writer to terminate...")
|
||||
dataset.image_writer.stop()
|
||||
|
||||
if run_compute_stats:
|
||||
logging.info("Computing dataset statistics")
|
||||
|
||||
|
|
Loading…
Reference in New Issue