Comply with torchvision 0.21 custom transforms (#665)
This commit is contained in:
parent
c4d912a241
commit
1ee1acf8ad
|
@ -14,7 +14,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import collections
|
import collections
|
||||||
from typing import Any, Callable, Dict, Sequence
|
from typing import Any, Callable, Sequence
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torchvision.transforms import v2
|
from torchvision.transforms import v2
|
||||||
|
@ -129,11 +129,12 @@ class SharpnessJitter(Transform):
|
||||||
|
|
||||||
return float(sharpness[0]), float(sharpness[1])
|
return float(sharpness[0]), float(sharpness[1])
|
||||||
|
|
||||||
def _generate_value(self, left: float, right: float) -> float:
|
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
|
||||||
return torch.empty(1).uniform_(left, right).item()
|
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
|
||||||
|
return {"sharpness_factor": sharpness_factor}
|
||||||
|
|
||||||
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
|
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
|
||||||
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
|
sharpness_factor = params["sharpness_factor"]
|
||||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||||
|
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -41,7 +41,7 @@ numba = ">=0.59.0"
|
||||||
torch = ">=2.2.1"
|
torch = ">=2.2.1"
|
||||||
opencv-python = ">=4.9.0"
|
opencv-python = ">=4.9.0"
|
||||||
diffusers = ">=0.27.2"
|
diffusers = ">=0.27.2"
|
||||||
torchvision = ">=0.17.1"
|
torchvision = ">=0.21.0"
|
||||||
h5py = ">=3.10.0"
|
h5py = ">=3.10.0"
|
||||||
huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.25.2"}
|
huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.25.2"}
|
||||||
gymnasium = "==0.29.1" # TODO(rcadene, aliberts): Make gym 1.0.0 work
|
gymnasium = "==0.29.1" # TODO(rcadene, aliberts): Make gym 1.0.0 work
|
||||||
|
|
Loading…
Reference in New Issue