Raise ValueError if horizon is incompatible with downsampling (#422)

This commit is contained in:
Alexander Soare 2024-09-09 17:22:46 +01:00 committed by GitHub
parent 9c463661c1
commit a60d27b132
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 9 additions and 0 deletions

View File

@ -196,3 +196,12 @@ class DiffusionConfig:
f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. " f"`noise_scheduler_type` must be one of {supported_noise_schedulers}. "
f"Got {self.noise_scheduler_type}." f"Got {self.noise_scheduler_type}."
) )
# Check that the horizon size and U-Net downsampling is compatible.
# U-Net downsamples by 2 with each stage.
downsampling_factor = 2 ** len(self.down_dims)
if self.horizon % downsampling_factor != 0:
raise ValueError(
"The horizon should be an integer multiple of the downsampling factor (which is determined "
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
)