Comply with torchvision 0.21 custom transforms (#665)

This commit is contained in:
Simon Alibert 2025-01-30 22:06:11 +01:00 committed by GitHub
parent c4d912a241
commit 1ee1acf8ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 2739 additions and 2576 deletions

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from typing import Any, Callable, Dict, Sequence
from typing import Any, Callable, Sequence
import torch
from torchvision.transforms import v2
@ -129,11 +129,12 @@ class SharpnessJitter(Transform):
return float(sharpness[0]), float(sharpness[1])
def _generate_value(self, left: float, right: float) -> float:
return torch.empty(1).uniform_(left, right).item()
def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
return {"sharpness_factor": sharpness_factor}
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
sharpness_factor = params["sharpness_factor"]
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)

5302
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -41,7 +41,7 @@ numba = ">=0.59.0"
torch = ">=2.2.1"
opencv-python = ">=4.9.0"
diffusers = ">=0.27.2"
torchvision = ">=0.17.1"
torchvision = ">=0.21.0"
h5py = ">=3.10.0"
huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.25.2"}
gymnasium = "==0.29.1" # TODO(rcadene, aliberts): Make gym 1.0.0 work