Comply with torchvision 0.21 custom transforms (#665)

This commit is contained in:
Simon Alibert 2025-01-30 22:06:11 +01:00 committed by GitHub
parent c4d912a241
commit 1ee1acf8ad
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 2739 additions and 2576 deletions

View File

@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections import collections
from typing import Any, Callable, Dict, Sequence from typing import Any, Callable, Sequence
import torch import torch
from torchvision.transforms import v2 from torchvision.transforms import v2
@ -129,11 +129,12 @@ class SharpnessJitter(Transform):
return float(sharpness[0]), float(sharpness[1]) return float(sharpness[0]), float(sharpness[1])
def _generate_value(self, left: float, right: float) -> float: def make_params(self, flat_inputs: list[Any]) -> dict[str, Any]:
return torch.empty(1).uniform_(left, right).item() sharpness_factor = torch.empty(1).uniform_(self.sharpness[0], self.sharpness[1]).item()
return {"sharpness_factor": sharpness_factor}
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1]) sharpness_factor = params["sharpness_factor"]
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor) return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)

5302
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -41,7 +41,7 @@ numba = ">=0.59.0"
torch = ">=2.2.1" torch = ">=2.2.1"
opencv-python = ">=4.9.0" opencv-python = ">=4.9.0"
diffusers = ">=0.27.2" diffusers = ">=0.27.2"
torchvision = ">=0.17.1" torchvision = ">=0.21.0"
h5py = ">=3.10.0" h5py = ">=3.10.0"
huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.25.2"} huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.25.2"}
gymnasium = "==0.29.1" # TODO(rcadene, aliberts): Make gym 1.0.0 work gymnasium = "==0.29.1" # TODO(rcadene, aliberts): Make gym 1.0.0 work