rsl_rl/rsl_rl/modules/normalizer.py

# Copyright (c) 2020 Preferred Networks, Inc.
# Copyright 2021 ETH Zurich, NVIDIA CORPORATION
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import torch
from torch import nn
class EmpiricalNormalization(nn.Module):
    """Normalize mean and variance of values based on empirical values."""

    def __init__(self, shape, eps=1e-2, until=None):
        """Initialize EmpiricalNormalization module.

        Args:
            shape (int or tuple of int): Shape of input values except batch axis.
            eps (float): Small value for numerical stability.
            until (int or None): If specified, the module stops updating its statistics
                once the total number of observed samples exceeds this value.
        """
        super().__init__()
        self.eps = eps
        self.until = until
        # Running statistics are registered as buffers so they follow the module
        # across devices and are saved in the state dict; the leading axis of size 1
        # lets them broadcast against batched inputs.
        self.register_buffer("_mean", torch.zeros(shape).unsqueeze(0))
        self.register_buffer("_var", torch.ones(shape).unsqueeze(0))
        self.register_buffer("_std", torch.ones(shape).unsqueeze(0))
        self.count = 0
    @property
    def mean(self):
        return self._mean.squeeze(0).clone()

    @property
    def std(self):
        return self._std.squeeze(0).clone()
    def forward(self, x):
        """Normalize mean and variance of values based on empirical values.

        Args:
            x (torch.Tensor): Input values

        Returns:
            torch.Tensor: Normalized output values
        """
        if self.training:
            self.update(x)
        return (x - self._mean) / (self._std + self.eps)
    @torch.jit.unused
    def update(self, x):
        """Learn input values without computing the output values of them."""
        # Stop updating once the sample budget given by `until` is exhausted.
        if self.until is not None and self.count >= self.until:
            return

        count_x = x.shape[0]
        self.count += count_x
        rate = count_x / self.count

        # Batched incremental update of the running mean and variance
        # (Chan et al. parallel algorithm): the new batch contributes with
        # weight rate = batch_size / total_count.
        var_x = torch.var(x, dim=0, unbiased=False, keepdim=True)
        mean_x = torch.mean(x, dim=0, keepdim=True)
        delta_mean = mean_x - self._mean
        self._mean += rate * delta_mean
        # delta_mean uses the old mean and (mean_x - self._mean) the updated one;
        # their product supplies the cross term that corrects the pooled variance.
        self._var += rate * (var_x - self._var + delta_mean * (mean_x - self._mean))
        self._std = torch.sqrt(self._var)
    @torch.jit.unused
    def inverse(self, y):
        """Undo the normalization: map normalized values back to the original scale."""
        return y * (self._std + self.eps) + self._mean
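

# A minimal, hypothetical usage sketch: feed batched observations through the
# module in training mode to update its running statistics, then freeze them
# with .eval() for evaluation or deployment. The shape, batch size, and sample
# data below are illustrative assumptions, not part of the library.
if __name__ == "__main__":
    norm = EmpiricalNormalization(shape=3, until=10_000)

    norm.train()  # running statistics update only in training mode
    obs = torch.randn(64, 3) * 5.0 + 2.0  # a batch of 64 three-dimensional observations
    normalized = norm(obs)  # updates the statistics, then normalizes

    norm.eval()  # statistics are frozen; forward() only normalizes
    restored = norm.inverse(norm(obs))  # inverse() undoes the normalization
    print(torch.allclose(restored, obs, atol=1e-5))  # True, up to float rounding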