Add pixel channels
This commit is contained in:
parent
028c17fd48
commit
21ba4b5263
|
@ -261,10 +261,13 @@ def _get_video_info(video_path: Path | str) -> dict:
|
||||||
num, denom = map(int, r_frame_rate.split("/"))
|
num, denom = map(int, r_frame_rate.split("/"))
|
||||||
fps = num / denom
|
fps = num / denom
|
||||||
|
|
||||||
|
pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
|
||||||
|
|
||||||
video_info = {
|
video_info = {
|
||||||
"video.fps": fps,
|
"video.fps": fps,
|
||||||
"video.width": video_stream_info["width"],
|
"video.width": video_stream_info["width"],
|
||||||
"video.height": video_stream_info["height"],
|
"video.height": video_stream_info["height"],
|
||||||
|
"video.channels": pixel_channels,
|
||||||
"video.codec": video_stream_info["codec_name"],
|
"video.codec": video_stream_info["codec_name"],
|
||||||
"video.pix_fmt": video_stream_info["pix_fmt"],
|
"video.pix_fmt": video_stream_info["pix_fmt"],
|
||||||
**_get_audio_info(video_path),
|
**_get_audio_info(video_path),
|
||||||
|
@ -293,12 +296,38 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str]) -> dic
|
||||||
return videos_info_dict
|
return videos_info_dict
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_pixel_channels(pix_fmt: str) -> int:
    """Return the number of pixel channels implied by a pixel-format name.

    The name is matched by substring, so suffixed variants such as
    ``"gray16le"`` or ``"yuv420p10le"`` resolve like their base format.

    Args:
        pix_fmt: Pixel-format name, e.g. ``"yuv420p"``, ``"rgb24"`` or
            ``"bgra"`` (presumably the ``pix_fmt`` field reported for the
            video stream; confirm against the caller).

    Returns:
        1 for grayscale / depth / monochrome formats, 4 for formats that
        carry an alpha channel, 3 for plain RGB, BGR, GBR or YUV formats.

    Raises:
        ValueError: If the name matches none of the known families.
    """
    # "mono" also covers the previously accepted "monochrome" spelling as
    # well as the short "monow" / "monob" names.
    if any(tag in pix_fmt for tag in ("gray", "depth", "mono")):
        return 1

    # Alpha-carrying names must be tested before the 3-channel families:
    # "argb" contains "rgb", "abgr" contains "bgr", "gbrap" contains "gbr".
    if any(tag in pix_fmt for tag in ("rgba", "bgra", "argb", "abgr", "gbrap", "yuva")):
        return 4

    if any(tag in pix_fmt for tag in ("rgb", "bgr", "gbr", "yuv")):
        return 3

    raise ValueError(f"Unknown format: {pix_fmt!r}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_pixel_channels(image: "Image.Image") -> int:
    """Return the number of channels (bands) of an image based on its mode.

    Args:
        image: Object exposing a ``mode`` attribute holding the mode string
            (e.g. ``"L"``, ``"RGB"``, ``"RGBA"``).

    Returns:
        The channel count associated with ``image.mode``.

    Raises:
        ValueError: If the mode is not in the lookup table below. Palette
            mode ``"P"`` is deliberately left out: its single band holds
            palette indices rather than colour values.
    """
    channels_by_mode = {
        "1": 1,  # 1-bit black and white
        "L": 1,  # Grayscale
        "I": 1,  # 32-bit integer, single band
        "I;16": 1,  # 16-bit integer, single band
        "F": 1,  # 32-bit float, single band
        "LA": 2,  # Grayscale + Alpha
        "RGB": 3,  # RGB
        "YCbCr": 3,
        "LAB": 3,
        "HSV": 3,
        "RGBA": 4,  # RGBA
        "CMYK": 4,
    }
    mode = image.mode
    try:
        return channels_by_mode[mode]
    except KeyError:
        # Name the offending mode so the failure is diagnosable.
        raise ValueError(f"Unknown format: {mode!r}") from None
|
||||||
|
|
||||||
|
|
||||||
def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
    """Build a ``{key: {"width", "height", "channels"}}`` mapping for videos.

    Args:
        videos_info: Mapping from video key to a dict that provides the
            ``"video.width"``, ``"video.height"`` and ``"video.pix_fmt"``
            entries.
        video_keys: Keys of ``videos_info`` to include in the result.

    Returns:
        A dict with one entry per key in ``video_keys``; the channel count
        is derived from the pixel format via ``get_video_pixel_channels``.
    """

    def _shape_of(info: dict) -> dict:
        # Resolve the channel count first so an unknown pixel format is
        # reported before the width/height lookups are attempted.
        n_channels = get_video_pixel_channels(info["video.pix_fmt"])
        return {
            "width": info["video.width"],
            "height": info["video.height"],
            "channels": n_channels,
        }

    return {key: _shape_of(videos_info[key]) for key in video_keys}
|
||||||
|
@ -309,9 +338,11 @@ def get_image_shapes(table: pa.Table, image_keys: list) -> dict:
|
||||||
for img_key in image_keys:
|
for img_key in image_keys:
|
||||||
image_bytes = table[img_key][0].as_py() # Assuming first row
|
image_bytes = table[img_key][0].as_py() # Assuming first row
|
||||||
image = Image.open(BytesIO(image_bytes["bytes"]))
|
image = Image.open(BytesIO(image_bytes["bytes"]))
|
||||||
|
channels = get_image_pixel_channels(image)
|
||||||
image_shapes[img_key] = {
|
image_shapes[img_key] = {
|
||||||
"width": image.width,
|
"width": image.width,
|
||||||
"height": image.height,
|
"height": image.height,
|
||||||
|
"channels": channels,
|
||||||
}
|
}
|
||||||
|
|
||||||
return image_shapes
|
return image_shapes
|
||||||
|
|
Loading…
Reference in New Issue