101 lines
3.5 KiB
Python
101 lines
3.5 KiB
Python
import torch
|
|
from torchvision.transforms.v2 import Compose, Transform
|
|
|
|
|
|
def apply_inverse_transform(item, transform):
|
|
transforms = transform.transforms if isinstance(transform, Compose) else [transform]
|
|
for tf in transforms[::-1]:
|
|
if tf.invertible:
|
|
item = tf.inverse_transform(item)
|
|
else:
|
|
raise ValueError(f"Inverse transform called on a non invertible transform ({tf}).")
|
|
return item
|
|
|
|
|
|
class Prod(Transform):
|
|
invertible = True
|
|
|
|
def __init__(self, in_keys: list[str], prod: float):
|
|
super().__init__()
|
|
self.in_keys = in_keys
|
|
self.prod = prod
|
|
self.original_dtypes = {}
|
|
|
|
def forward(self, item):
|
|
for key in self.in_keys:
|
|
if key not in item:
|
|
continue
|
|
self.original_dtypes[key] = item[key].dtype
|
|
item[key] = item[key].type(torch.float32) * self.prod
|
|
return item
|
|
|
|
def inverse_transform(self, item):
|
|
for key in self.in_keys:
|
|
if key not in item:
|
|
continue
|
|
item[key] = (item[key] / self.prod).type(self.original_dtypes[key])
|
|
return item
|
|
|
|
# def transform_observation_spec(self, obs_spec):
|
|
# for key in self.in_keys:
|
|
# if obs_spec.get(key, None) is None:
|
|
# continue
|
|
# obs_spec[key].space.high = obs_spec[key].space.high.type(torch.float32) * self.prod
|
|
# obs_spec[key].space.low = obs_spec[key].space.low.type(torch.float32) * self.prod
|
|
# obs_spec[key].dtype = torch.float32
|
|
# return obs_spec
|
|
|
|
|
|
class NormalizeTransform(Transform):
|
|
invertible = True
|
|
|
|
def __init__(
|
|
self,
|
|
stats: dict,
|
|
in_keys: list[str] = None,
|
|
out_keys: list[str] | None = None,
|
|
in_keys_inv: list[str] | None = None,
|
|
out_keys_inv: list[str] | None = None,
|
|
mode="mean_std",
|
|
):
|
|
super().__init__()
|
|
self.in_keys = in_keys
|
|
self.out_keys = in_keys if out_keys is None else out_keys
|
|
self.in_keys_inv = self.out_keys if in_keys_inv is None else in_keys_inv
|
|
self.out_keys_inv = self.in_keys if out_keys_inv is None else out_keys_inv
|
|
self.stats = stats
|
|
assert mode in ["mean_std", "min_max"]
|
|
self.mode = mode
|
|
|
|
def forward(self, item):
|
|
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
|
|
if inkey not in item:
|
|
continue
|
|
if self.mode == "mean_std":
|
|
mean = self.stats[inkey]["mean"]
|
|
std = self.stats[inkey]["std"]
|
|
item[outkey] = (item[inkey] - mean) / (std + 1e-8)
|
|
else:
|
|
min = self.stats[inkey]["min"]
|
|
max = self.stats[inkey]["max"]
|
|
# normalize to [0,1]
|
|
item[outkey] = (item[inkey] - min) / (max - min)
|
|
# normalize to [-1, 1]
|
|
item[outkey] = item[outkey] * 2 - 1
|
|
return item
|
|
|
|
def inverse_transform(self, item):
|
|
for inkey, outkey in zip(self.in_keys_inv, self.out_keys_inv, strict=False):
|
|
if inkey not in item:
|
|
continue
|
|
if self.mode == "mean_std":
|
|
mean = self.stats[inkey]["mean"]
|
|
std = self.stats[inkey]["std"]
|
|
item[outkey] = item[inkey] * std + mean
|
|
else:
|
|
min = self.stats[inkey]["min"]
|
|
max = self.stats[inkey]["max"]
|
|
item[outkey] = (item[inkey] + 1) / 2
|
|
item[outkey] = item[outkey] * (max - min) + min
|
|
return item
|