Redesign config
This commit is contained in:
parent
443b06b412
commit
fdf56e7a62
|
@ -1,8 +1,9 @@
|
||||||
from typing import Any, Callable, Sequence
|
from typing import Any, Callable, Dict, Sequence
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torchvision.transforms import v2
|
from torchvision.transforms import v2
|
||||||
from torchvision.transforms.v2 import Transform
|
from torchvision.transforms.v2 import Transform
|
||||||
|
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||||
|
|
||||||
|
|
||||||
class RandomSubsetApply(Transform):
|
class RandomSubsetApply(Transform):
|
||||||
|
@ -45,17 +46,39 @@ class RandomSubsetApply(Transform):
|
||||||
return "\n".join(format_string)
|
return "\n".join(format_string)
|
||||||
|
|
||||||
|
|
||||||
def make_transforms(cfg):
|
class RangeRandomSharpness(Transform):
|
||||||
image_transforms = []
|
"""Similar to RandomAdjustSharpness but with p=1 and a sharpness_factor sampled randomly
|
||||||
if "colorjitter" in cfg.list:
|
each time in [range_min, range_max].
|
||||||
image_transforms.append(
|
|
||||||
v2.ColorJitter(brightness=cfg.colorjitter_factor, contrast=cfg.colorjitter_factor)
|
|
||||||
)
|
|
||||||
if "sharpness" in cfg.list:
|
|
||||||
image_transforms.append(v2.RandomAdjustSharpness(cfg.sharpness_factor, p=cfg.sharpness_p))
|
|
||||||
if "blur" in cfg.list:
|
|
||||||
image_transforms.append(v2.RandomAdjustSharpness(cfg.blur_factor, p=cfg.blur_p))
|
|
||||||
|
|
||||||
|
If the input is a :class:`torch.Tensor`,
|
||||||
|
it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, range_min: float, range_max) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.range_min, self.range_max = self._check_input(range_min, range_max)
|
||||||
|
|
||||||
|
def _check_input(self, range_min, range_max):
|
||||||
|
if range_min < 0:
|
||||||
|
raise ValueError("range_min must be non negative.")
|
||||||
|
if range_min > range_max:
|
||||||
|
raise ValueError("range_max must greater or equal to range_min")
|
||||||
|
return range_min, range_max
|
||||||
|
|
||||||
|
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
|
||||||
|
sharpness_factor = self.range_min + (self.range_max - self.range_min) * torch.rand(1)
|
||||||
|
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||||
|
|
||||||
|
|
||||||
|
def make_transforms(cfg):
|
||||||
|
image_transforms = [
|
||||||
|
v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
|
||||||
|
v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),
|
||||||
|
v2.ColorJitter(saturation=(cfg.saturation.min, cfg.saturation.max)),
|
||||||
|
v2.ColorJitter(hue=(cfg.hue.min, cfg.hue.max)),
|
||||||
|
RangeRandomSharpness(cfg.sharpness.min, cfg.sharpness.max),
|
||||||
|
]
|
||||||
|
# WIP
|
||||||
return v2.Compose(
|
return v2.Compose(
|
||||||
[RandomSubsetApply(image_transforms, n_subset=cfg.n_subset), v2.ToDtype(torch.float32, scale=True)]
|
[RandomSubsetApply(image_transforms, n_subset=cfg.n_subset), v2.ToDtype(torch.float32, scale=True)]
|
||||||
)
|
)
|
||||||
|
|
|
@ -60,14 +60,26 @@ wandb:
|
||||||
|
|
||||||
image_transform:
|
image_transform:
|
||||||
enable: false
|
enable: false
|
||||||
colorjitter_factor: 0.5
|
|
||||||
colorjitter_p: 1.O
|
|
||||||
sharpness_factor: 3.0
|
|
||||||
# Should be more than 1, setting parameter to 1 does not change the image
|
|
||||||
sharpness_p: 0.5
|
|
||||||
blur_factor: 0.2
|
|
||||||
# Should be less than 1, setting parameter to 1 does not change the image
|
|
||||||
blur_p: 0.5
|
|
||||||
n_subset: 3
|
|
||||||
# Maximum number of transforms to apply
|
# Maximum number of transforms to apply
|
||||||
list: ["colorjitter", "sharpness", "blur"]
|
max_num_transforms: 3
|
||||||
|
random_order: false
|
||||||
|
brightness:
|
||||||
|
weight: 1
|
||||||
|
min: 0.5
|
||||||
|
max: 1.5
|
||||||
|
contrast:
|
||||||
|
weight: 1
|
||||||
|
min: 0.5
|
||||||
|
max: 1.5
|
||||||
|
saturation:
|
||||||
|
weight: 1
|
||||||
|
min: 0.5
|
||||||
|
max: 1.5
|
||||||
|
hue:
|
||||||
|
weight: 1
|
||||||
|
min: -0.1
|
||||||
|
max: 0.1
|
||||||
|
sharpness:
|
||||||
|
weight: 1
|
||||||
|
min: 0.9
|
||||||
|
max: 1.1
|
||||||
|
|
Loading…
Reference in New Issue