Add pixel channels

This commit is contained in:
Simon Alibert 2024-10-06 11:16:49 +02:00
parent 028c17fd48
commit 21ba4b5263
1 changed files with 31 additions and 0 deletions

View File

@ -261,10 +261,13 @@ def _get_video_info(video_path: Path | str) -> dict:
num, denom = map(int, r_frame_rate.split("/")) num, denom = map(int, r_frame_rate.split("/"))
fps = num / denom fps = num / denom
pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
video_info = { video_info = {
"video.fps": fps, "video.fps": fps,
"video.width": video_stream_info["width"], "video.width": video_stream_info["width"],
"video.height": video_stream_info["height"], "video.height": video_stream_info["height"],
"video.channels": pixel_channels,
"video.codec": video_stream_info["codec_name"], "video.codec": video_stream_info["codec_name"],
"video.pix_fmt": video_stream_info["pix_fmt"], "video.pix_fmt": video_stream_info["pix_fmt"],
**_get_audio_info(video_path), **_get_audio_info(video_path),
@ -293,12 +296,38 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str]) -> dic
return videos_info_dict return videos_info_dict
def get_video_pixel_channels(pix_fmt: str) -> int:
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
return 1
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
return 4
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
return 3
else:
raise ValueError("Unknown format")
def get_image_pixel_channels(image: Image):
if image.mode == "L":
return 1 # Grayscale
elif image.mode == "LA":
return 2 # Grayscale + Alpha
elif image.mode == "RGB":
return 3 # RGB
elif image.mode == "RGBA":
return 4 # RGBA
else:
raise ValueError("Unknown format")
def get_video_shapes(videos_info: dict, video_keys: list) -> dict: def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
video_shapes = {} video_shapes = {}
for img_key in video_keys: for img_key in video_keys:
channels = get_video_pixel_channels(videos_info[img_key]["video.pix_fmt"])
video_shapes[img_key] = { video_shapes[img_key] = {
"width": videos_info[img_key]["video.width"], "width": videos_info[img_key]["video.width"],
"height": videos_info[img_key]["video.height"], "height": videos_info[img_key]["video.height"],
"channels": channels,
} }
return video_shapes return video_shapes
@ -309,9 +338,11 @@ def get_image_shapes(table: pa.Table, image_keys: list) -> dict:
for img_key in image_keys: for img_key in image_keys:
image_bytes = table[img_key][0].as_py() # Assuming first row image_bytes = table[img_key][0].as_py() # Assuming first row
image = Image.open(BytesIO(image_bytes["bytes"])) image = Image.open(BytesIO(image_bytes["bytes"]))
channels = get_image_pixel_channels(image)
image_shapes[img_key] = { image_shapes[img_key] = {
"width": image.width, "width": image.width,
"height": image.height, "height": image.height,
"channels": channels,
} }
return image_shapes return image_shapes