Redesign config
This commit is contained in:
parent
443b06b412
commit
fdf56e7a62
|
@ -1,8 +1,9 @@
|
|||
from typing import Any, Callable, Sequence
|
||||
from typing import Any, Callable, Dict, Sequence
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import Transform
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
|
||||
|
||||
class RandomSubsetApply(Transform):
|
||||
|
@ -45,17 +46,39 @@ class RandomSubsetApply(Transform):
|
|||
return "\n".join(format_string)
|
||||
|
||||
|
||||
def make_transforms(cfg):
|
||||
image_transforms = []
|
||||
if "colorjitter" in cfg.list:
|
||||
image_transforms.append(
|
||||
v2.ColorJitter(brightness=cfg.colorjitter_factor, contrast=cfg.colorjitter_factor)
|
||||
)
|
||||
if "sharpness" in cfg.list:
|
||||
image_transforms.append(v2.RandomAdjustSharpness(cfg.sharpness_factor, p=cfg.sharpness_p))
|
||||
if "blur" in cfg.list:
|
||||
image_transforms.append(v2.RandomAdjustSharpness(cfg.blur_factor, p=cfg.blur_p))
|
||||
class RangeRandomSharpness(Transform):
|
||||
"""Similar to RandomAdjustSharpness but with p=1 and a sharpness_factor sampled randomly
|
||||
each time in [range_min, range_max].
|
||||
|
||||
If the input is a :class:`torch.Tensor`,
|
||||
it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, range_min: float, range_max) -> None:
|
||||
super().__init__()
|
||||
self.range_min, self.range_max = self._check_input(range_min, range_max)
|
||||
|
||||
def _check_input(self, range_min, range_max):
|
||||
if range_min < 0:
|
||||
raise ValueError("range_min must be non negative.")
|
||||
if range_min > range_max:
|
||||
raise ValueError("range_max must greater or equal to range_min")
|
||||
return range_min, range_max
|
||||
|
||||
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
|
||||
sharpness_factor = self.range_min + (self.range_max - self.range_min) * torch.rand(1)
|
||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||
|
||||
|
||||
def make_transforms(cfg):
|
||||
image_transforms = [
|
||||
v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
|
||||
v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),
|
||||
v2.ColorJitter(saturation=(cfg.saturation.min, cfg.saturation.max)),
|
||||
v2.ColorJitter(hue=(cfg.hue.min, cfg.hue.max)),
|
||||
RangeRandomSharpness(cfg.sharpness.min, cfg.sharpness.max),
|
||||
]
|
||||
# WIP
|
||||
return v2.Compose(
|
||||
[RandomSubsetApply(image_transforms, n_subset=cfg.n_subset), v2.ToDtype(torch.float32, scale=True)]
|
||||
)
|
||||
|
|
|
@ -60,14 +60,26 @@ wandb:
|
|||
|
||||
image_transform:
|
||||
enable: false
|
||||
colorjitter_factor: 0.5
|
||||
colorjitter_p: 1.O
|
||||
sharpness_factor: 3.0
|
||||
# Should be more than 1, setting parameter to 1 does not change the image
|
||||
sharpness_p: 0.5
|
||||
blur_factor: 0.2
|
||||
# Should be less than 1, setting parameter to 1 does not change the image
|
||||
blur_p: 0.5
|
||||
n_subset: 3
|
||||
# Maximum number of transforms to apply
|
||||
list: ["colorjitter", "sharpness", "blur"]
|
||||
max_num_transforms: 3
|
||||
random_order: false
|
||||
brightness:
|
||||
weight: 1
|
||||
min: 0.5
|
||||
max: 1.5
|
||||
contrast:
|
||||
weight: 1
|
||||
min: 0.5
|
||||
max: 1.5
|
||||
saturation:
|
||||
weight: 1
|
||||
min: 0.5
|
||||
max: 1.5
|
||||
hue:
|
||||
weight: 1
|
||||
min: -0.1
|
||||
max: 0.1
|
||||
sharpness:
|
||||
weight: 1
|
||||
min: 0.9
|
||||
max: 1.1
|
||||
|
|
Loading…
Reference in New Issue