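"""Linear normalization utilities for the diffusion policy.

`LinearNormalizer` fits one affine transform (scale and offset) per key of a
dict of tensors; `SingleFieldLinearNormalizer` handles a single tensor. Two
fitting modes are supported: "limits" (rescale min/max into an output range)
and "gaussian" (zero mean, unit variance).

Example (shapes are illustrative, mirroring the `test()` function below):

    data = {
        "obs": torch.zeros((1000, 128, 9, 2)).uniform_() * 512,
        "action": torch.zeros((1000, 128, 2)).uniform_() * 512,
    }
    normalizer = LinearNormalizer()
    normalizer.fit(data, mode="limits")
    datan = normalizer.normalize(data)
"""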
from typing import Dict, Union

import numpy as np
import torch
import torch.nn as nn
import zarr

from lerobot.common.policies.diffusion.model.dict_of_tensor_mixin import DictOfTensorMixin
from lerobot.common.policies.diffusion.pytorch_utils import dict_apply


class LinearNormalizer(DictOfTensorMixin):
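    """Normalizer holding one affine transform (scale/offset) per dict key.

    Parameters live in `self.params_dict` (from `DictOfTensorMixin`), so they
    are saved and restored through `state_dict()` / `load_state_dict()`.
    """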

    available_modes = ["limits", "gaussian"]

    @torch.no_grad()
    def fit(
        self,
        data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array],
        last_n_dims=1,
        dtype=torch.float32,
        mode="limits",
        output_max=1.0,
        output_min=-1.0,
        range_eps=1e-4,
        fit_offset=True,
    ):
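        """Fit normalization parameters from `data`.

        If `data` is a dict, one parameter set is fit per key; otherwise the
        parameters are stored under the reserved "_default" key.
        """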
        if isinstance(data, dict):
            for key, value in data.items():
                self.params_dict[key] = _fit(
                    value,
                    last_n_dims=last_n_dims,
                    dtype=dtype,
                    mode=mode,
                    output_max=output_max,
                    output_min=output_min,
                    range_eps=range_eps,
                    fit_offset=fit_offset,
                )
        else:
            self.params_dict["_default"] = _fit(
                data,
                last_n_dims=last_n_dims,
                dtype=dtype,
                mode=mode,
                output_max=output_max,
                output_min=output_min,
                range_eps=range_eps,
                fit_offset=fit_offset,
            )

    def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
        return self.normalize(x)

    def __getitem__(self, key: str):
        return SingleFieldLinearNormalizer(self.params_dict[key])

    def __setitem__(self, key: str, value: "SingleFieldLinearNormalizer"):
        self.params_dict[key] = value.params_dict

    def _normalize_impl(self, x, forward=True):
        if isinstance(x, dict):
            result = {}
            for key, value in x.items():
                params = self.params_dict[key]
                result[key] = _normalize(value, params, forward=forward)
            return result
        else:
            if "_default" not in self.params_dict:
                raise RuntimeError("Not initialized")
            params = self.params_dict["_default"]
            return _normalize(x, params, forward=forward)

    def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
        return self._normalize_impl(x, forward=True)

    def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
        return self._normalize_impl(x, forward=False)

    def get_input_stats(self) -> Dict:
        if len(self.params_dict) == 0:
            raise RuntimeError("Not initialized")
        if len(self.params_dict) == 1 and "_default" in self.params_dict:
            return self.params_dict["_default"]["input_stats"]

        result = {}
        for key, value in self.params_dict.items():
            if key != "_default":
                result[key] = value["input_stats"]
        return result

    def get_output_stats(self, key="_default"):
        input_stats = self.get_input_stats()
        if "min" in input_stats:
            # single-field stats (no per-key dict): normalize them directly
            return dict_apply(input_stats, self.normalize)

        result = {}
        for key, group in input_stats.items():
            this_dict = {}
            for name, value in group.items():
                this_dict[name] = self.normalize({key: value})[key]
            result[key] = this_dict
        return result


class SingleFieldLinearNormalizer(DictOfTensorMixin):
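    """Normalizer for a single tensor field: one affine transform, no keys."""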

    available_modes = ["limits", "gaussian"]

    @torch.no_grad()
    def fit(
        self,
        data: Union[torch.Tensor, np.ndarray, zarr.Array],
        last_n_dims=1,
        dtype=torch.float32,
        mode="limits",
        output_max=1.0,
        output_min=-1.0,
        range_eps=1e-4,
        fit_offset=True,
    ):
        self.params_dict = _fit(
            data,
            last_n_dims=last_n_dims,
            dtype=dtype,
            mode=mode,
            output_max=output_max,
            output_min=output_min,
            range_eps=range_eps,
            fit_offset=fit_offset,
        )

    @classmethod
    def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs):
        obj = cls()
        obj.fit(data, **kwargs)
        return obj

    @classmethod
    def create_manual(
        cls,
        scale: Union[torch.Tensor, np.ndarray],
        offset: Union[torch.Tensor, np.ndarray],
        input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]],
    ):
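        """Build a normalizer from explicit scale/offset tensors and input stats."""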
        def to_tensor(x):
            if not isinstance(x, torch.Tensor):
                x = torch.from_numpy(x)
            x = x.flatten()
            return x

        # check that every stat matches the scale's shape and dtype
        for x in [offset] + list(input_stats_dict.values()):
            assert x.shape == scale.shape
            assert x.dtype == scale.dtype

        params_dict = nn.ParameterDict(
            {
                "scale": to_tensor(scale),
                "offset": to_tensor(offset),
                "input_stats": nn.ParameterDict(dict_apply(input_stats_dict, to_tensor)),
            }
        )
        return cls(params_dict)

    @classmethod
    def create_identity(cls, dtype=torch.float32):
        scale = torch.tensor([1], dtype=dtype)
        offset = torch.tensor([0], dtype=dtype)
        input_stats_dict = {
            "min": torch.tensor([-1], dtype=dtype),
            "max": torch.tensor([1], dtype=dtype),
            "mean": torch.tensor([0], dtype=dtype),
            "std": torch.tensor([1], dtype=dtype),
        }
        return cls.create_manual(scale, offset, input_stats_dict)

    def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        return _normalize(x, self.params_dict, forward=True)

    def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        return _normalize(x, self.params_dict, forward=False)

    def get_input_stats(self):
        return self.params_dict["input_stats"]

    def get_output_stats(self):
        return dict_apply(self.params_dict["input_stats"], self.normalize)

    def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        return self.normalize(x)


def _fit(
    data: Union[torch.Tensor, np.ndarray, zarr.Array],
    last_n_dims=1,
    dtype=torch.float32,
    mode="limits",
    output_max=1.0,
    output_min=-1.0,
    range_eps=1e-4,
    fit_offset=True,
):
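    """Compute affine normalization parameters for a single tensor.

    The last `last_n_dims` dimensions are flattened into one feature axis and
    statistics are taken over all remaining leading dimensions. Returns an
    `nn.ParameterDict` with `scale`, `offset`, and `input_stats`
    (min/max/mean/std), all with gradients disabled.
    """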
    assert mode in ["limits", "gaussian"]
    assert last_n_dims >= 0
    assert output_max > output_min

    # convert data to a torch tensor of the requested dtype
    if isinstance(data, zarr.Array):
        data = data[:]
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)
    if dtype is not None:
        data = data.type(dtype)

    # flatten the last `last_n_dims` dimensions into a single feature axis
    dim = 1
    if last_n_dims > 0:
        dim = np.prod(data.shape[-last_n_dims:])
    data = data.reshape(-1, dim)

    # compute input stats: min, max, mean, std
    input_min, _ = data.min(axis=0)
    input_max, _ = data.max(axis=0)
    input_mean = data.mean(axis=0)
    input_std = data.std(axis=0)

    # compute scale and offset
    if mode == "limits":
        if fit_offset:
            # unit scale: map [input_min, input_max] to [output_min, output_max]
            input_range = input_max - input_min
            ignore_dim = input_range < range_eps
            input_range[ignore_dim] = output_max - output_min
            scale = (output_max - output_min) / input_range
            offset = output_min - scale * input_min
            # near-constant dims are mapped to the midpoint of the output range
            offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]
        else:
            # use this when data is pre-zero-centered
            assert output_max > 0
            assert output_min < 0
            # unit abs: scale by the largest absolute value per dim
            output_abs = min(abs(output_min), abs(output_max))
            input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max))
            ignore_dim = input_abs < range_eps
            # don't scale near-constant channels
            input_abs[ignore_dim] = output_abs
            scale = output_abs / input_abs
            offset = torch.zeros_like(input_mean)
    elif mode == "gaussian":
        ignore_dim = input_std < range_eps
        scale = input_std.clone()
        scale[ignore_dim] = 1
        scale = 1 / scale

        offset = -input_mean * scale if fit_offset else torch.zeros_like(input_mean)

    # save parameters with gradients disabled
    this_params = nn.ParameterDict(
        {
            "scale": scale,
            "offset": offset,
            "input_stats": nn.ParameterDict(
                {"min": input_min, "max": input_max, "mean": input_mean, "std": input_std}
            ),
        }
    )
    for p in this_params.parameters():
        p.requires_grad_(False)
    return this_params


def _normalize(x, params, forward=True):
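    """Apply the affine transform in `params` to `x`, or invert it if `forward` is False.

    `x` may have any leading shape; it is flattened to (-1, feature_dim),
    mapped as `x * scale + offset` (forward) or `(x - offset) / scale`
    (inverse), then reshaped back.
    """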
assert "scale" in params
|
|
if isinstance(x, np.ndarray):
|
|
x = torch.from_numpy(x)
|
|
scale = params["scale"]
|
|
offset = params["offset"]
|
|
x = x.to(device=scale.device, dtype=scale.dtype)
|
|
src_shape = x.shape
|
|
x = x.reshape(-1, scale.shape[0])
|
|
x = x * scale + offset if forward else (x - offset) / scale
|
|
x = x.reshape(src_shape)
|
|
return x
|
|
|
|
|
|


def test():
    data = torch.zeros((100, 10, 9, 2)).uniform_()
    data[..., 0, 0] = 0

    normalizer = SingleFieldLinearNormalizer()
    normalizer.fit(data, mode="limits", last_n_dims=2)
    datan = normalizer.normalize(data)
    assert datan.shape == data.shape
    assert np.allclose(datan.max(), 1.0)
    assert np.allclose(datan.min(), -1.0)
    dataun = normalizer.unnormalize(datan)
    assert torch.allclose(data, dataun, atol=1e-7)

    _ = normalizer.get_input_stats()
    _ = normalizer.get_output_stats()

    normalizer = SingleFieldLinearNormalizer()
    normalizer.fit(data, mode="limits", last_n_dims=1, fit_offset=False)
    datan = normalizer.normalize(data)
    assert datan.shape == data.shape
    assert np.allclose(datan.max(), 1.0, atol=1e-3)
    assert np.allclose(datan.min(), 0.0, atol=1e-3)
    dataun = normalizer.unnormalize(datan)
    assert torch.allclose(data, dataun, atol=1e-7)

    data = torch.zeros((100, 10, 9, 2)).uniform_()
    normalizer = SingleFieldLinearNormalizer()
    normalizer.fit(data, mode="gaussian", last_n_dims=0)
    datan = normalizer.normalize(data)
    assert datan.shape == data.shape
    assert np.allclose(datan.mean(), 0.0, atol=1e-3)
    assert np.allclose(datan.std(), 1.0, atol=1e-3)
    dataun = normalizer.unnormalize(datan)
    assert torch.allclose(data, dataun, atol=1e-7)

    # LinearNormalizer: a plain tensor is stored under the "_default" key
    data = torch.zeros((100, 10, 9, 2)).uniform_()
    data[..., 0, 0] = 0

    normalizer = LinearNormalizer()
    normalizer.fit(data, mode="limits", last_n_dims=2)
    datan = normalizer.normalize(data)
    assert datan.shape == data.shape
    assert np.allclose(datan.max(), 1.0)
    assert np.allclose(datan.min(), -1.0)
    dataun = normalizer.unnormalize(datan)
    assert torch.allclose(data, dataun, atol=1e-7)

    _ = normalizer.get_input_stats()
    _ = normalizer.get_output_stats()

    # LinearNormalizer over a dict of tensors, one transform per key
    data = {
        "obs": torch.zeros((1000, 128, 9, 2)).uniform_() * 512,
        "action": torch.zeros((1000, 128, 2)).uniform_() * 512,
    }
    normalizer = LinearNormalizer()
    normalizer.fit(data)
    datan = normalizer.normalize(data)
    dataun = normalizer.unnormalize(datan)
    for key in data:
        assert torch.allclose(data[key], dataun[key], atol=1e-4)

    _ = normalizer.get_input_stats()
    _ = normalizer.get_output_stats()

    # round-trip through state_dict() / load_state_dict()
    state_dict = normalizer.state_dict()
    n = LinearNormalizer()
    n.load_state_dict(state_dict)
    datan = n.normalize(data)
    dataun = n.unnormalize(datan)
    for key in data:
        assert torch.allclose(data[key], dataun[key], atol=1e-4)
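

# Convenience entry point so the smoke tests above can be run directly.
if __name__ == "__main__":
    test()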