Add mode to NormalizeTransform with mean_std or min_max (Not fully tested)
This commit is contained in:
parent
48ded3dbc7
commit
cbbed590a9
|
@ -1,4 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
@ -134,18 +135,19 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
else:
|
else:
|
||||||
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
storage = TensorStorage(TensorDict.load_memmap(self.root / dataset_id))
|
||||||
|
|
||||||
mean_std = self._compute_or_load_mean_std(storage)
|
stats = self._compute_or_load_stats(storage)
|
||||||
mean_std["next", "observation", "image"] = mean_std["observation", "image"]
|
stats["next", "observation", "image"] = stats["observation", "image"]
|
||||||
mean_std["next", "observation", "state"] = mean_std["observation", "state"]
|
stats["next", "observation", "state"] = stats["observation", "state"]
|
||||||
transform = NormalizeTransform(
|
transform = NormalizeTransform(
|
||||||
mean_std,
|
stats,
|
||||||
in_keys=[
|
in_keys=[
|
||||||
("observation", "image"),
|
# ("observation", "image"),
|
||||||
("observation", "state"),
|
("observation", "state"),
|
||||||
("next", "observation", "image"),
|
# ("next", "observation", "image"),
|
||||||
("next", "observation", "state"),
|
("next", "observation", "state"),
|
||||||
("action"),
|
("action"),
|
||||||
],
|
],
|
||||||
|
mode="min_max",
|
||||||
)
|
)
|
||||||
|
|
||||||
if writer is None:
|
if writer is None:
|
||||||
|
@ -282,7 +284,7 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
|
|
||||||
return TensorStorage(td_data.lock_())
|
return TensorStorage(td_data.lock_())
|
||||||
|
|
||||||
def _compute_mean_std(self, storage, num_batch=10, batch_size=32):
|
def _compute_stats(self, storage, num_batch=100, batch_size=32):
|
||||||
rb = TensorDictReplayBuffer(
|
rb = TensorDictReplayBuffer(
|
||||||
storage=storage,
|
storage=storage,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
|
@ -291,15 +293,27 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
batch = rb.sample()
|
batch = rb.sample()
|
||||||
image_mean = torch.zeros(batch["observation", "image"].shape[1])
|
image_mean = torch.zeros(batch["observation", "image"].shape[1])
|
||||||
image_std = torch.zeros(batch["observation", "image"].shape[1])
|
image_std = torch.zeros(batch["observation", "image"].shape[1])
|
||||||
|
image_max = -math.inf
|
||||||
|
image_min = math.inf
|
||||||
state_mean = torch.zeros(batch["observation", "state"].shape[1])
|
state_mean = torch.zeros(batch["observation", "state"].shape[1])
|
||||||
state_std = torch.zeros(batch["observation", "state"].shape[1])
|
state_std = torch.zeros(batch["observation", "state"].shape[1])
|
||||||
|
state_max = -math.inf
|
||||||
|
state_min = math.inf
|
||||||
action_mean = torch.zeros(batch["action"].shape[1])
|
action_mean = torch.zeros(batch["action"].shape[1])
|
||||||
action_std = torch.zeros(batch["action"].shape[1])
|
action_std = torch.zeros(batch["action"].shape[1])
|
||||||
|
action_max = -math.inf
|
||||||
|
action_min = math.inf
|
||||||
|
|
||||||
for _ in tqdm.tqdm(range(num_batch)):
|
for _ in tqdm.tqdm(range(num_batch)):
|
||||||
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
|
image_mean += einops.reduce(batch["observation", "image"], "b c h w -> c", reduction="mean")
|
||||||
state_mean += batch["observation", "state"].mean(dim=0)
|
state_mean += batch["observation", "state"].mean(dim=0)
|
||||||
action_mean += batch["action"].mean(dim=0)
|
action_mean += batch["action"].mean(dim=0)
|
||||||
|
image_max = max(image_max, batch["observation", "image"].max().item())
|
||||||
|
image_min = min(image_min, batch["observation", "image"].min().item())
|
||||||
|
state_max = max(state_max, batch["observation", "state"].max().item())
|
||||||
|
state_min = min(state_min, batch["observation", "state"].min().item())
|
||||||
|
action_max = max(action_max, batch["action"].max().item())
|
||||||
|
action_min = min(action_min, batch["action"].min().item())
|
||||||
batch = rb.sample()
|
batch = rb.sample()
|
||||||
|
|
||||||
image_mean /= num_batch
|
image_mean /= num_batch
|
||||||
|
@ -311,6 +325,12 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
image_std += (image_mean_batch - image_mean) ** 2
|
image_std += (image_mean_batch - image_mean) ** 2
|
||||||
state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
|
state_std += (batch["observation", "state"].mean(dim=0) - state_mean) ** 2
|
||||||
action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
|
action_std += (batch["action"].mean(dim=0) - action_mean) ** 2
|
||||||
|
image_max = max(image_max, batch["observation", "image"].max().item())
|
||||||
|
image_min = min(image_min, batch["observation", "image"].min().item())
|
||||||
|
state_max = max(state_max, batch["observation", "state"].max().item())
|
||||||
|
state_min = min(state_min, batch["observation", "state"].min().item())
|
||||||
|
action_max = max(action_max, batch["action"].max().item())
|
||||||
|
action_min = min(action_min, batch["action"].min().item())
|
||||||
if i < num_batch - 1:
|
if i < num_batch - 1:
|
||||||
batch = rb.sample()
|
batch = rb.sample()
|
||||||
|
|
||||||
|
@ -318,25 +338,31 @@ class PushtExperienceReplay(TensorDictReplayBuffer):
|
||||||
state_std = torch.sqrt(state_std / num_batch)
|
state_std = torch.sqrt(state_std / num_batch)
|
||||||
action_std = torch.sqrt(action_std / num_batch)
|
action_std = torch.sqrt(action_std / num_batch)
|
||||||
|
|
||||||
mean_std = TensorDict(
|
stats = TensorDict(
|
||||||
{
|
{
|
||||||
("observation", "image", "mean"): image_mean[None, :, None, None],
|
("observation", "image", "mean"): image_mean[None, :, None, None],
|
||||||
("observation", "image", "std"): image_std[None, :, None, None],
|
("observation", "image", "std"): image_std[None, :, None, None],
|
||||||
|
("observation", "image", "max"): torch.tensor(image_max),
|
||||||
|
("observation", "image", "min"): torch.tensor(image_min),
|
||||||
("observation", "state", "mean"): state_mean[None, :],
|
("observation", "state", "mean"): state_mean[None, :],
|
||||||
("observation", "state", "std"): state_std[None, :],
|
("observation", "state", "std"): state_std[None, :],
|
||||||
|
("observation", "state", "max"): torch.tensor(state_max),
|
||||||
|
("observation", "state", "min"): torch.tensor(state_min),
|
||||||
("action", "mean"): action_mean[None, :],
|
("action", "mean"): action_mean[None, :],
|
||||||
("action", "std"): action_std[None, :],
|
("action", "std"): action_std[None, :],
|
||||||
|
("action", "max"): torch.tensor(action_max),
|
||||||
|
("action", "min"): torch.tensor(action_min),
|
||||||
},
|
},
|
||||||
batch_size=[],
|
batch_size=[],
|
||||||
)
|
)
|
||||||
return mean_std
|
return stats
|
||||||
|
|
||||||
def _compute_or_load_mean_std(self, storage) -> TensorDict:
|
def _compute_or_load_stats(self, storage) -> TensorDict:
|
||||||
mean_std_path = self.root / self.dataset_id / "mean_std.pth"
|
stats_path = self.root / self.dataset_id / "stats.pth"
|
||||||
if mean_std_path.exists():
|
if stats_path.exists():
|
||||||
mean_std = torch.load(mean_std_path)
|
stats = torch.load(stats_path)
|
||||||
else:
|
else:
|
||||||
logging.info(f"compute_mean_std and save to {mean_std_path}")
|
logging.info(f"compute_stats and save to {stats_path}")
|
||||||
mean_std = self._compute_mean_std(storage)
|
stats = self._compute_stats(storage)
|
||||||
torch.save(mean_std, mean_std_path)
|
torch.save(stats, stats_path)
|
||||||
return mean_std
|
return stats
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
from torchrl.envs.transforms import StepCounter, TransformedEnv
|
from torchrl.envs.transforms import StepCounter, TransformedEnv
|
||||||
|
|
||||||
from lerobot.common.envs.transforms import Prod
|
|
||||||
|
|
||||||
|
|
||||||
def make_env(cfg, transform=None):
|
def make_env(cfg, transform=None):
|
||||||
kwargs = {
|
kwargs = {
|
||||||
|
@ -28,12 +26,8 @@ def make_env(cfg, transform=None):
|
||||||
# limit rollout to max_steps
|
# limit rollout to max_steps
|
||||||
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
|
env = TransformedEnv(env, StepCounter(max_steps=cfg.env.episode_length))
|
||||||
|
|
||||||
if cfg.env.name == "pusht":
|
|
||||||
# to ensure pusht is in [0,255] like simxarm
|
|
||||||
env.append_transform(Prod(in_keys=[("observation", "image")], prod=255.0))
|
|
||||||
|
|
||||||
if transform is not None:
|
if transform is not None:
|
||||||
# useful to add mean and std normalization
|
# useful to add normalization
|
||||||
env.append_transform(transform)
|
env.append_transform(transform)
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
|
@ -28,11 +28,12 @@ class NormalizeTransform(Transform):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
mean_std: TensorDictBase,
|
stats: TensorDictBase,
|
||||||
in_keys: Sequence[NestedKey] = None,
|
in_keys: Sequence[NestedKey] = None,
|
||||||
out_keys: Sequence[NestedKey] | None = None,
|
out_keys: Sequence[NestedKey] | None = None,
|
||||||
in_keys_inv: Sequence[NestedKey] | None = None,
|
in_keys_inv: Sequence[NestedKey] | None = None,
|
||||||
out_keys_inv: Sequence[NestedKey] | None = None,
|
out_keys_inv: Sequence[NestedKey] | None = None,
|
||||||
|
mode="mean_std",
|
||||||
):
|
):
|
||||||
if out_keys is None:
|
if out_keys is None:
|
||||||
out_keys = in_keys
|
out_keys = in_keys
|
||||||
|
@ -43,7 +44,14 @@ class NormalizeTransform(Transform):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
|
in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
|
||||||
)
|
)
|
||||||
self.mean_std = mean_std
|
self.stats = stats
|
||||||
|
assert mode in ["mean_std", "min_max"]
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
|
||||||
|
# _reset is called once when the environment reset to normalize the first observation
|
||||||
|
tensordict_reset = self._call(tensordict_reset)
|
||||||
|
return tensordict_reset
|
||||||
|
|
||||||
@dispatch(source="in_keys", dest="out_keys")
|
@dispatch(source="in_keys", dest="out_keys")
|
||||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
|
||||||
|
@ -54,9 +62,17 @@ class NormalizeTransform(Transform):
|
||||||
# TODO(rcadene): don't know how to do `inkey not in td`
|
# TODO(rcadene): don't know how to do `inkey not in td`
|
||||||
if td.get(inkey, None) is None:
|
if td.get(inkey, None) is None:
|
||||||
continue
|
continue
|
||||||
mean = self.mean_std[inkey]["mean"]
|
if self.mode == "mean_std":
|
||||||
std = self.mean_std[inkey]["std"]
|
mean = self.stats[inkey]["mean"]
|
||||||
td[outkey] = (td[inkey] - mean) / (std + 1e-8)
|
std = self.stats[inkey]["std"]
|
||||||
|
td[outkey] = (td[inkey] - mean) / (std + 1e-8)
|
||||||
|
else:
|
||||||
|
min = self.stats[inkey]["min"]
|
||||||
|
max = self.stats[inkey]["max"]
|
||||||
|
# normalize to [0,1]
|
||||||
|
td[outkey] = (td[inkey] - min) / (max - min)
|
||||||
|
# normalize to [-1, 1]
|
||||||
|
td[outkey] = td[outkey] * 2 - 1
|
||||||
return td
|
return td
|
||||||
|
|
||||||
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
|
def _inv_call(self, td: TensorDictBase) -> TensorDictBase:
|
||||||
|
@ -64,7 +80,13 @@ class NormalizeTransform(Transform):
|
||||||
# TODO(rcadene): don't know how to do `inkey not in td`
|
# TODO(rcadene): don't know how to do `inkey not in td`
|
||||||
if td.get(inkey, None) is None:
|
if td.get(inkey, None) is None:
|
||||||
continue
|
continue
|
||||||
mean = self.mean_std[inkey]["mean"]
|
if self.mode == "mean_std":
|
||||||
std = self.mean_std[inkey]["std"]
|
mean = self.stats[inkey]["mean"]
|
||||||
td[outkey] = td[inkey] * std + mean
|
std = self.stats[inkey]["std"]
|
||||||
|
td[outkey] = td[inkey] * std + mean
|
||||||
|
else:
|
||||||
|
min = self.stats[inkey]["min"]
|
||||||
|
max = self.stats[inkey]["max"]
|
||||||
|
td[outkey] = (td[inkey] + 1) / 2
|
||||||
|
td[outkey] = td[outkey] * (max - min) + min
|
||||||
return td
|
return td
|
||||||
|
|
|
@ -118,7 +118,7 @@ def eval(cfg: dict, out_dir=None):
|
||||||
offline_buffer = make_offline_buffer(cfg)
|
offline_buffer = make_offline_buffer(cfg)
|
||||||
|
|
||||||
logging.info("make_env")
|
logging.info("make_env")
|
||||||
env = make_env(cfg, transform=offline_buffer.transform)
|
env = make_env(cfg, transform=offline_buffer._transform)
|
||||||
|
|
||||||
if cfg.policy.pretrained_model_path:
|
if cfg.policy.pretrained_model_path:
|
||||||
policy = make_policy(cfg)
|
policy = make_policy(cfg)
|
||||||
|
|
Loading…
Reference in New Issue