Redesign config

This commit is contained in:
Simon Alibert 2024-06-05 09:49:31 +00:00
parent 443b06b412
commit fdf56e7a62
2 changed files with 56 additions and 21 deletions

View File

@ -1,8 +1,9 @@
from typing import Any, Callable, Sequence from typing import Any, Callable, Dict, Sequence
import torch import torch
from torchvision.transforms import v2 from torchvision.transforms import v2
from torchvision.transforms.v2 import Transform from torchvision.transforms.v2 import Transform
from torchvision.transforms.v2 import functional as F # noqa: N812
class RandomSubsetApply(Transform): class RandomSubsetApply(Transform):
@ -45,17 +46,39 @@ class RandomSubsetApply(Transform):
return "\n".join(format_string) return "\n".join(format_string)
def make_transforms(cfg): class RangeRandomSharpness(Transform):
image_transforms = [] """Similar to RandomAdjustSharpness but with p=1 and a sharpness_factor sampled randomly
if "colorjitter" in cfg.list: each time in [range_min, range_max].
image_transforms.append(
v2.ColorJitter(brightness=cfg.colorjitter_factor, contrast=cfg.colorjitter_factor)
)
if "sharpness" in cfg.list:
image_transforms.append(v2.RandomAdjustSharpness(cfg.sharpness_factor, p=cfg.sharpness_p))
if "blur" in cfg.list:
image_transforms.append(v2.RandomAdjustSharpness(cfg.blur_factor, p=cfg.blur_p))
If the input is a :class:`torch.Tensor`,
it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
"""
def __init__(self, range_min: float, range_max) -> None:
super().__init__()
self.range_min, self.range_max = self._check_input(range_min, range_max)
def _check_input(self, range_min, range_max):
if range_min < 0:
raise ValueError("range_min must be non negative.")
if range_min > range_max:
raise ValueError("range_max must greater or equal to range_min")
return range_min, range_max
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
sharpness_factor = self.range_min + (self.range_max - self.range_min) * torch.rand(1)
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
def make_transforms(cfg):
image_transforms = [
v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),
v2.ColorJitter(saturation=(cfg.saturation.min, cfg.saturation.max)),
v2.ColorJitter(hue=(cfg.hue.min, cfg.hue.max)),
RangeRandomSharpness(cfg.sharpness.min, cfg.sharpness.max),
]
# WIP
return v2.Compose( return v2.Compose(
[RandomSubsetApply(image_transforms, n_subset=cfg.n_subset), v2.ToDtype(torch.float32, scale=True)] [RandomSubsetApply(image_transforms, n_subset=cfg.n_subset), v2.ToDtype(torch.float32, scale=True)]
) )

View File

@ -60,14 +60,26 @@ wandb:
image_transform: image_transform:
enable: false enable: false
colorjitter_factor: 0.5
colorjitter_p: 1.O
sharpness_factor: 3.0
# Should be more than 1, setting parameter to 1 does not change the image
sharpness_p: 0.5
blur_factor: 0.2
# Should be less than 1, setting parameter to 1 does not change the image
blur_p: 0.5
n_subset: 3
# Maximum number of transforms to apply # Maximum number of transforms to apply
list: ["colorjitter", "sharpness", "blur"] max_num_transforms: 3
random_order: false
brightness:
weight: 1
min: 0.5
max: 1.5
contrast:
weight: 1
min: 0.5
max: 1.5
saturation:
weight: 1
min: 0.5
max: 1.5
hue:
weight: 1
min: -0.1
max: 0.1
sharpness:
weight: 1
min: 0.9
max: 1.1