74 lines
2.3 KiB
Python
74 lines
2.3 KiB
Python
# Copyright (c) 2020 Preferred Networks, Inc.
|
|
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
|
|
from __future__ import annotations
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
|
|
class EmpiricalNormalization(nn.Module):
    """Normalize values with a running mean and variance estimated from the inputs.

    While the module is in training mode, every call to :meth:`forward` first folds the
    incoming batch into the running statistics and then normalizes the batch with them.
    In evaluation mode the statistics are left untouched.
    """

    def __init__(self, shape, eps: float = 1e-2, until=None):
        """Initialize EmpiricalNormalization module.

        Args:
            shape (int or tuple of int): Shape of input values except batch axis.
            eps (float): Small value added to the standard deviation for stability.
            until (int or None): If this arg is specified, the module stops learning input values once the
                sum of batch sizes seen so far reaches or exceeds it.
        """
        super().__init__()
        self.eps = eps
        self.until = until
        # The leading axis of size 1 lets the statistics broadcast against (batch, *shape) inputs.
        self.register_buffer("_mean", torch.zeros(shape).unsqueeze(0))
        self.register_buffer("_var", torch.ones(shape).unsqueeze(0))
        self.register_buffer("_std", torch.ones(shape).unsqueeze(0))
        # NOTE(review): ``count`` is a plain Python attribute, not a buffer, so it is not part of
        # ``state_dict()``. After reloading a checkpoint it is 0 again and the first training batch
        # gets rate == 1, replacing the loaded statistics. Left as is to keep checkpoints loadable.
        self.count = 0

    @property
    def mean(self) -> torch.Tensor:
        """Return a copy of the running mean without the leading broadcast axis."""
        return self._mean.squeeze(0).clone()

    @property
    def std(self) -> torch.Tensor:
        """Return a copy of the running standard deviation without the leading broadcast axis."""
        return self._std.squeeze(0).clone()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize mean and variance of values based on empirical values.

        In training mode the running statistics are updated with ``x`` before normalizing.

        Args:
            x (torch.Tensor): Input values, batch axis first.

        Returns:
            torch.Tensor: Normalized output values, ``(x - mean) / (std + eps)``.
        """
        if self.training:
            self.update(x)
        return (x - self._mean) / (self._std + self.eps)

    @torch.jit.unused
    def update(self, x: torch.Tensor) -> None:
        """Learn input values without computing the output values of them.

        Args:
            x (torch.Tensor): Input values, batch axis first.
        """
        if self.until is not None and self.count >= self.until:
            return

        count_x = x.shape[0]
        if count_x == 0:
            # An empty batch carries no information. Without this guard the rate below is
            # 0 / 0 on a fresh module, and NaN batch statistics would poison the buffers.
            return

        # The running statistics are not differentiable quantities. Updating the buffers
        # in place while autograd is recording would attach a graph to them that grows
        # with every call and breaks the next backward pass.
        with torch.no_grad():
            self.count += count_x
            rate = count_x / self.count

            var_x = torch.var(x, dim=0, unbiased=False, keepdim=True)
            mean_x = torch.mean(x, dim=0, keepdim=True)
            delta_mean = mean_x - self._mean
            self._mean += rate * delta_mean
            # ``self._mean`` is already the updated mean here; the product term accounts
            # for the shift between the old running mean and the batch mean.
            self._var += rate * (var_x - self._var + delta_mean * (mean_x - self._mean))
            self._std = torch.sqrt(self._var)

    @torch.jit.unused
    def inverse(self, y: torch.Tensor) -> torch.Tensor:
        """Undo the normalization applied by :meth:`forward`.

        Args:
            y (torch.Tensor): Normalized values.

        Returns:
            torch.Tensor: De-normalized values, ``y * (std + eps) + mean``.
        """
        return y * (self._std + self.eps) + self._mean
|