from typing import Dict, Union

import numpy as np
import torch
import torch.nn as nn
import zarr

from lerobot.common.policies.diffusion.model.dict_of_tensor_mixin import DictOfTensorMixin
from lerobot.common.policies.diffusion.pytorch_utils import dict_apply


class LinearNormalizer(DictOfTensorMixin):
    available_modes = ["limits", "gaussian"]

    @torch.no_grad()
    def fit(
        self,
        data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array],
        last_n_dims=1,
        dtype=torch.float32,
        mode="limits",
        output_max=1.0,
        output_min=-1.0,
        range_eps=1e-4,
        fit_offset=True,
    ):
        if isinstance(data, dict):
            for key, value in data.items():
                self.params_dict[key] = _fit(
                    value,
                    last_n_dims=last_n_dims,
                    dtype=dtype,
                    mode=mode,
                    output_max=output_max,
                    output_min=output_min,
                    range_eps=range_eps,
                    fit_offset=fit_offset,
                )
        else:
            self.params_dict["_default"] = _fit(
                data,
                last_n_dims=last_n_dims,
                dtype=dtype,
                mode=mode,
                output_max=output_max,
                output_min=output_min,
                range_eps=range_eps,
                fit_offset=fit_offset,
            )

    def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> Union[Dict, torch.Tensor]:
        return self.normalize(x)

    def __getitem__(self, key: str) -> "SingleFieldLinearNormalizer":
        return SingleFieldLinearNormalizer(self.params_dict[key])

    def __setitem__(self, key: str, value: "SingleFieldLinearNormalizer"):
        self.params_dict[key] = value.params_dict

    def _normalize_impl(self, x, forward=True):
        if isinstance(x, dict):
            result = {}
            for key, value in x.items():
                params = self.params_dict[key]
                result[key] = _normalize(value, params, forward=forward)
            return result
        if "_default" not in self.params_dict:
            raise RuntimeError("Not initialized")
        params = self.params_dict["_default"]
        return _normalize(x, params, forward=forward)

    def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> Union[Dict, torch.Tensor]:
        return self._normalize_impl(x, forward=True)

    def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> Union[Dict, torch.Tensor]:
        return self._normalize_impl(x, forward=False)

    def get_input_stats(self) -> Dict:
        if len(self.params_dict) == 0:
            raise RuntimeError("Not initialized")
        if len(self.params_dict) == 1 and "_default" in self.params_dict:
            return self.params_dict["_default"]["input_stats"]

        result = {}
        for key, value in self.params_dict.items():
            if key != "_default":
                result[key] = value["input_stats"]
        return result

    def get_output_stats(self) -> Dict:
        input_stats = self.get_input_stats()
        if "min" in input_stats:
            # single "_default" field: stats are tensors, not a dict of fields
            return dict_apply(input_stats, self.normalize)

        result = {}
        for key, group in input_stats.items():
            this_dict = {}
            for name, value in group.items():
                this_dict[name] = self.normalize({key: value})[key]
            result[key] = this_dict
        return result
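# Illustrative usage sketch, not part of the library API: fit a
# LinearNormalizer on a dict of tensors and round-trip a batch. The field
# names ("obs", "action") and the shapes below are made up for this example;
# any dict of float tensors works the same way.
def _example_linear_normalizer_usage():
    batch = {
        "obs": torch.rand(64, 9) * 512,
        "action": torch.rand(64, 2) * 512,
    }
    normalizer = LinearNormalizer()
    # per key, map each channel's [min, max] onto [-1, 1]
    normalizer.fit(batch, mode="limits")
    normalized = normalizer(batch)  # __call__ forwards to normalize()
    restored = normalizer.unnormalize(normalized)
    for key in batch:
        assert torch.allclose(batch[key], restored[key], atol=1e-4)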
class SingleFieldLinearNormalizer(DictOfTensorMixin):
    available_modes = ["limits", "gaussian"]

    @torch.no_grad()
    def fit(
        self,
        data: Union[torch.Tensor, np.ndarray, zarr.Array],
        last_n_dims=1,
        dtype=torch.float32,
        mode="limits",
        output_max=1.0,
        output_min=-1.0,
        range_eps=1e-4,
        fit_offset=True,
    ):
        self.params_dict = _fit(
            data,
            last_n_dims=last_n_dims,
            dtype=dtype,
            mode=mode,
            output_max=output_max,
            output_min=output_min,
            range_eps=range_eps,
            fit_offset=fit_offset,
        )

    @classmethod
    def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs):
        obj = cls()
        obj.fit(data, **kwargs)
        return obj

    @classmethod
    def create_manual(
        cls,
        scale: Union[torch.Tensor, np.ndarray],
        offset: Union[torch.Tensor, np.ndarray],
        input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]],
    ):
        def to_tensor(x):
            if not isinstance(x, torch.Tensor):
                x = torch.from_numpy(x)
            x = x.flatten()
            return x

        # check that offset and every input stat match scale in shape and dtype
        for x in [offset] + list(input_stats_dict.values()):
            assert x.shape == scale.shape
            assert x.dtype == scale.dtype

        params_dict = nn.ParameterDict(
            {
                "scale": to_tensor(scale),
                "offset": to_tensor(offset),
                "input_stats": nn.ParameterDict(dict_apply(input_stats_dict, to_tensor)),
            }
        )
        return cls(params_dict)

    @classmethod
    def create_identity(cls, dtype=torch.float32):
        scale = torch.tensor([1], dtype=dtype)
        offset = torch.tensor([0], dtype=dtype)
        input_stats_dict = {
            "min": torch.tensor([-1], dtype=dtype),
            "max": torch.tensor([1], dtype=dtype),
            "mean": torch.tensor([0], dtype=dtype),
            "std": torch.tensor([1], dtype=dtype),
        }
        return cls.create_manual(scale, offset, input_stats_dict)

    def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        return _normalize(x, self.params_dict, forward=True)

    def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        return _normalize(x, self.params_dict, forward=False)

    def get_input_stats(self):
        return self.params_dict["input_stats"]

    def get_output_stats(self):
        return dict_apply(self.params_dict["input_stats"], self.normalize)

    def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        return self.normalize(x)
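# Illustrative sketch with hypothetical values: build a
# SingleFieldLinearNormalizer via create_manual() without fitting, for when
# the data range is known a priori. A sensor range of [0, 255] is assumed
# here and mapped to [-1, 1] via x * scale + offset; the stats mirror what
# fit() would record (std is the approximate std of a uniform [0, 255]
# distribution).
def _example_manual_normalizer():
    scale = torch.tensor([2.0 / 255.0])
    offset = torch.tensor([-1.0])
    input_stats_dict = {
        "min": torch.tensor([0.0]),
        "max": torch.tensor([255.0]),
        "mean": torch.tensor([127.5]),
        "std": torch.tensor([73.6]),
    }
    normalizer = SingleFieldLinearNormalizer.create_manual(scale, offset, input_stats_dict)
    x = torch.tensor([[0.0], [127.5], [255.0]])
    expected = torch.tensor([[-1.0], [0.0], [1.0]])
    assert torch.allclose(normalizer(x), expected, atol=1e-5)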
def _fit(
    data: Union[torch.Tensor, np.ndarray, zarr.Array],
    last_n_dims=1,
    dtype=torch.float32,
    mode="limits",
    output_max=1.0,
    output_min=-1.0,
    range_eps=1e-4,
    fit_offset=True,
):
    assert mode in ["limits", "gaussian"]
    assert last_n_dims >= 0
    assert output_max > output_min

    # convert data to a torch tensor of the requested dtype
    if isinstance(data, zarr.Array):
        data = data[:]
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)
    if dtype is not None:
        data = data.type(dtype)

    # flatten to (num_samples, dim); the last `last_n_dims` dims form the channels
    dim = 1
    if last_n_dims > 0:
        dim = np.prod(data.shape[-last_n_dims:])
    data = data.reshape(-1, dim)

    # compute per-channel input stats: min, max, mean, std
    input_min, _ = data.min(axis=0)
    input_max, _ = data.max(axis=0)
    input_mean = data.mean(axis=0)
    input_std = data.std(axis=0)

    # compute scale and offset
    if mode == "limits":
        if fit_offset:
            # map [input_min, input_max] onto [output_min, output_max]
            input_range = input_max - input_min
            ignore_dim = input_range < range_eps
            input_range[ignore_dim] = output_max - output_min
            scale = (output_max - output_min) / input_range
            offset = output_min - scale * input_min
            # near-constant dims are centered at the mean of output_max and output_min
            offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]
        else:
            # use this when data is pre-zero-centered
            assert output_max > 0
            assert output_min < 0
            # scale so the largest absolute value maps to the smaller output bound
            output_abs = min(abs(output_min), abs(output_max))
            input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max))
            ignore_dim = input_abs < range_eps
            # don't scale constant channels
            input_abs[ignore_dim] = output_abs
            scale = output_abs / input_abs
            offset = torch.zeros_like(input_mean)
    elif mode == "gaussian":
        ignore_dim = input_std < range_eps
        scale = input_std.clone()
        scale[ignore_dim] = 1
        scale = 1 / scale
        offset = -input_mean * scale if fit_offset else torch.zeros_like(input_mean)

    # save, with gradients disabled
    this_params = nn.ParameterDict(
        {
            "scale": scale,
            "offset": offset,
            "input_stats": nn.ParameterDict(
                {"min": input_min, "max": input_max, "mean": input_mean, "std": input_std}
            ),
        }
    )
    for p in this_params.parameters():
        p.requires_grad_(False)
    return this_params


def _normalize(x, params, forward=True):
    assert "scale" in params
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
    scale = params["scale"]
    offset = params["offset"]
    x = x.to(device=scale.device, dtype=scale.dtype)
    # flatten so scale/offset broadcast over the channel dimension
    src_shape = x.shape
    x = x.reshape(-1, scale.shape[0])
    x = x * scale + offset if forward else (x - offset) / scale
    x = x.reshape(src_shape)
    return x


def test():
    data = torch.zeros((100, 10, 9, 2)).uniform_()
    data[..., 0, 0] = 0

    normalizer = SingleFieldLinearNormalizer()
    normalizer.fit(data, mode="limits", last_n_dims=2)
    datan = normalizer.normalize(data)
    assert datan.shape == data.shape
    assert np.allclose(datan.max(), 1.0)
    assert np.allclose(datan.min(), -1.0)
    dataun = normalizer.unnormalize(datan)
    assert torch.allclose(data, dataun, atol=1e-7)

    _ = normalizer.get_input_stats()
    _ = normalizer.get_output_stats()

    normalizer = SingleFieldLinearNormalizer()
    normalizer.fit(data, mode="limits", last_n_dims=1, fit_offset=False)
    datan = normalizer.normalize(data)
    assert datan.shape == data.shape
    assert np.allclose(datan.max(), 1.0, atol=1e-3)
    assert np.allclose(datan.min(), 0.0, atol=1e-3)
    dataun = normalizer.unnormalize(datan)
    assert torch.allclose(data, dataun, atol=1e-7)

    data = torch.zeros((100, 10, 9, 2)).uniform_()
    normalizer = SingleFieldLinearNormalizer()
    normalizer.fit(data, mode="gaussian", last_n_dims=0)
    datan = normalizer.normalize(data)
    assert datan.shape == data.shape
    assert np.allclose(datan.mean(), 0.0, atol=1e-3)
    assert np.allclose(datan.std(), 1.0, atol=1e-3)
    dataun = normalizer.unnormalize(datan)
    assert torch.allclose(data, dataun, atol=1e-7)

    # dict
    data = torch.zeros((100, 10, 9, 2)).uniform_()
    data[..., 0, 0] = 0

    normalizer = LinearNormalizer()
    normalizer.fit(data, mode="limits", last_n_dims=2)
    datan = normalizer.normalize(data)
    assert datan.shape == data.shape
    assert np.allclose(datan.max(), 1.0)
    assert np.allclose(datan.min(), -1.0)
    dataun = normalizer.unnormalize(datan)
    assert torch.allclose(data, dataun, atol=1e-7)

    _ = normalizer.get_input_stats()
    _ = normalizer.get_output_stats()

    data = {
        "obs": torch.zeros((1000, 128, 9, 2)).uniform_() * 512,
        "action": torch.zeros((1000, 128, 2)).uniform_() * 512,
    }
    normalizer = LinearNormalizer()
    normalizer.fit(data)
    datan = normalizer.normalize(data)
    dataun = normalizer.unnormalize(datan)
    for key in data:
        assert torch.allclose(data[key], dataun[key], atol=1e-4)

    _ = normalizer.get_input_stats()
    _ = normalizer.get_output_stats()

    state_dict = normalizer.state_dict()
    n = LinearNormalizer()
    n.load_state_dict(state_dict)
    datan = n.normalize(data)
    dataun = n.unnormalize(datan)
    for key in data:
        assert torch.allclose(data[key], dataun[key], atol=1e-4)
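if __name__ == "__main__":
    # Run the smoke test when this module is executed as a script.
    test()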