From 21ba4b5263048eeab3252170274a7de17b843bed Mon Sep 17 00:00:00 2001
From: Simon Alibert
Date: Sun, 6 Oct 2024 11:16:49 +0200
Subject: [PATCH] Add pixel channels

---
 convert_dataset_16_to_20.py | 45 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 45 insertions(+)

diff --git a/convert_dataset_16_to_20.py b/convert_dataset_16_to_20.py
index f2878605..abcd8ff0 100644
--- a/convert_dataset_16_to_20.py
+++ b/convert_dataset_16_to_20.py
@@ -261,10 +261,13 @@ def _get_video_info(video_path: Path | str) -> dict:
     num, denom = map(int, r_frame_rate.split("/"))
     fps = num / denom
 
+    pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
+
     video_info = {
         "video.fps": fps,
         "video.width": video_stream_info["width"],
         "video.height": video_stream_info["height"],
+        "video.channels": pixel_channels,
         "video.codec": video_stream_info["codec_name"],
         "video.pix_fmt": video_stream_info["pix_fmt"],
         **_get_audio_info(video_path),
@@ -293,12 +296,52 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str]) -> dic
     return videos_info_dict
 
 
+def get_video_pixel_channels(pix_fmt: str) -> int:
+    """Return the number of pixel channels implied by an ffmpeg pixel format name.
+
+    Matching is by substring so that bit-depth and endianness suffixes (e.g. "yuv420p10le") are
+    accepted. Alpha-carrying formats are tested first because names such as "argb" or "yuva420p"
+    also contain "rgb" / "yuv".
+
+    Raises:
+        ValueError: If the pixel format is not recognized.
+    """
+    if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
+        return 1
+    elif any(name in pix_fmt for name in ("rgba", "argb", "bgra", "abgr", "yuva")):
+        return 4
+    elif any(name in pix_fmt for name in ("rgb", "bgr", "yuv")):
+        return 3
+    else:
+        raise ValueError(f"Unknown pixel format: '{pix_fmt}'")
+
+
+def get_image_pixel_channels(image: Image.Image) -> int:
+    """Return the number of channels of a PIL image, derived from its mode.
+
+    Raises:
+        ValueError: If the mode is not one of "L", "LA", "RGB" or "RGBA".
+    """
+    if image.mode == "L":
+        return 1  # Grayscale
+    elif image.mode == "LA":
+        return 2  # Grayscale + Alpha
+    elif image.mode == "RGB":
+        return 3  # RGB
+    elif image.mode == "RGBA":
+        return 4  # RGBA
+    else:
+        raise ValueError(f"Unknown image mode: '{image.mode}'")
+
+
 def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
     video_shapes = {}
     for img_key in video_keys:
+        channels = get_video_pixel_channels(videos_info[img_key]["video.pix_fmt"])
         video_shapes[img_key] = {
             "width": videos_info[img_key]["video.width"],
             "height": videos_info[img_key]["video.height"],
+            "channels": channels,
         }
 
     return video_shapes
@@ -309,9 +352,11 @@ def get_image_shapes(table: pa.Table, image_keys: list) -> dict:
     for img_key in image_keys:
         image_bytes = table[img_key][0].as_py()  # Assuming first row
         image = Image.open(BytesIO(image_bytes["bytes"]))
+        channels = get_image_pixel_channels(image)
         image_shapes[img_key] = {
             "width": image.width,
             "height": image.height,
+            "channels": channels,
         }
 
     return image_shapes