Redesign config

This commit is contained in:
Simon Alibert 2024-06-05 09:49:31 +00:00
parent 443b06b412
commit fdf56e7a62
2 changed files with 56 additions and 21 deletions

View File

@ -1,8 +1,9 @@
from typing import Any, Callable, Sequence
from typing import Any, Callable, Dict, Sequence
import torch
from torchvision.transforms import v2
from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2 import functional as F # noqa: N812
class RandomSubsetApply(Transform):
@ -45,17 +46,39 @@ class RandomSubsetApply(Transform):
return "\n".join(format_string)
def make_transforms(cfg):
image_transforms = []
if "colorjitter" in cfg.list:
image_transforms.append(
v2.ColorJitter(brightness=cfg.colorjitter_factor, contrast=cfg.colorjitter_factor)
)
if "sharpness" in cfg.list:
image_transforms.append(v2.RandomAdjustSharpness(cfg.sharpness_factor, p=cfg.sharpness_p))
if "blur" in cfg.list:
image_transforms.append(v2.RandomAdjustSharpness(cfg.blur_factor, p=cfg.blur_p))
class RangeRandomSharpness(Transform):
"""Similar to RandomAdjustSharpness but with p=1 and a sharpness_factor sampled randomly
each time in [range_min, range_max].
If the input is a :class:`torch.Tensor`,
it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
"""
def __init__(self, range_min: float, range_max) -> None:
super().__init__()
self.range_min, self.range_max = self._check_input(range_min, range_max)
def _check_input(self, range_min, range_max):
if range_min < 0:
raise ValueError("range_min must be non negative.")
if range_min > range_max:
raise ValueError("range_max must greater or equal to range_min")
return range_min, range_max
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
sharpness_factor = self.range_min + (self.range_max - self.range_min) * torch.rand(1)
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
def make_transforms(cfg):
image_transforms = [
v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),
v2.ColorJitter(saturation=(cfg.saturation.min, cfg.saturation.max)),
v2.ColorJitter(hue=(cfg.hue.min, cfg.hue.max)),
RangeRandomSharpness(cfg.sharpness.min, cfg.sharpness.max),
]
# WIP
return v2.Compose(
[RandomSubsetApply(image_transforms, n_subset=cfg.n_subset), v2.ToDtype(torch.float32, scale=True)]
)

View File

@ -60,14 +60,26 @@ wandb:
image_transform:
enable: false
colorjitter_factor: 0.5
colorjitter_p: 1.O
sharpness_factor: 3.0
# Should be more than 1, setting parameter to 1 does not change the image
sharpness_p: 0.5
blur_factor: 0.2
# Should be less than 1, setting parameter to 1 does not change the image
blur_p: 0.5
n_subset: 3
# Maximum number of transforms to apply
list: ["colorjitter", "sharpness", "blur"]
max_num_transforms: 3
random_order: false
brightness:
weight: 1
min: 0.5
max: 1.5
contrast:
weight: 1
min: 0.5
max: 1.5
saturation:
weight: 1
min: 0.5
max: 1.5
hue:
weight: 1
min: -0.1
max: 0.1
sharpness:
weight: 1
min: 0.9
max: 1.1