Comply with torchvision 0.21 custom transforms (#665)
This commit is contained in:
parent
c4d912a241
commit
1ee1acf8ad
|
@ -14,7 +14,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections
|
||||
from typing import Any, Callable, Dict, Sequence
|
||||
from typing import Any, Callable, Sequence
|
||||
|
||||
import torch
|
||||
from torchvision.transforms import v2
|
||||
|
@ -129,11 +129,12 @@ class SharpnessJitter(Transform):
|
|||
|
||||
return float(sharpness[0]), float(sharpness[1])
|
||||
|
||||
def _generate_value(self, left: float, right: float) -> float:
|
||||
return torch.empty(1).uniform_(left, right).item()
|
||||
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
|
||||
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
|
||||
return {"sharpness_factor": sharpness_factor}
|
||||
|
||||
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
|
||||
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
|
||||
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
|
||||
sharpness_factor = params["sharpness_factor"]
|
||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -41,7 +41,7 @@ numba = ">=0.59.0"
|
|||
torch = ">=2.2.1"
|
||||
opencv-python = ">=4.9.0"
|
||||
diffusers = ">=0.27.2"
|
||||
torchvision = ">=0.17.1"
|
||||
torchvision = ">=0.21.0"
|
||||
h5py = ">=3.10.0"
|
||||
huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.25.2"}
|
||||
gymnasium = "==0.29.1" # TODO(rcadene, aliberts): Make gym 1.0.0 work
|
||||
|
|
Loading…
Reference in New Issue