fix more bugs in normalization

This commit is contained in:
Cadene 2024-03-11 11:03:13 +00:00
parent a7ef4a6a33
commit 816b2e9d63
5 changed files with 14 additions and 8 deletions

View File

@ -54,7 +54,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
return {
("observation", "state"): "b c -> 1 c",
("observation", "image"): "b c h w -> 1 c 1 1",
"action": "b c -> 1 c",
("action",): "b c -> 1 c",
}
@property
@ -73,8 +73,12 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
def num_episodes(self) -> int:
return len(self._storage._storage["episode"].unique())
@property
def transform(self):
return self._transform
def set_transform(self, transform):
self.transform = transform
self._transform = transform
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
stats_path = self.data_dir / "stats.pth"

View File

@ -114,7 +114,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
def stats_patterns(self) -> dict:
d = {
("observation", "state"): "b c -> 1 c",
"action": "b c -> 1 c",
("action",): "b c -> 1 c",
}
for cam in CAMERAS[self.dataset_id]:
d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"

View File

@ -87,9 +87,11 @@ def make_offline_buffer(
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
stats = offline_buffer.compute_or_load_stats()
# we only normalize the state and action, since the images are usually normalized inside the model for now (except for tdmpc: see the following)
in_keys = [("observation", "state"), ("action")]
if cfg.policy == "tdmpc":
if cfg.policy.name == "tdmpc":
for key in offline_buffer.image_keys:
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
in_keys.append(key)
@ -97,7 +99,7 @@ def make_offline_buffer(
in_keys.append(("next", *key))
in_keys.append(("next", "observation", "state"))
if cfg.policy == "diffusion" and cfg.env.name == "pusht":
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)

View File

@ -126,7 +126,7 @@ def eval(cfg: dict, out_dir=None):
offline_buffer = make_offline_buffer(cfg)
logging.info("make_env")
env = make_env(cfg, transform=offline_buffer._transform)
env = make_env(cfg, transform=offline_buffer.transform)
if cfg.policy.pretrained_model_path:
policy = make_policy(cfg)

View File

@ -142,11 +142,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
online_buffer = TensorDictReplayBuffer(
storage=LazyMemmapStorage(100_000),
sampler=online_sampler,
transform=offline_buffer._transform,
transform=offline_buffer.transform,
)
logging.info("make_env")
env = make_env(cfg, transform=offline_buffer._transform)
env = make_env(cfg, transform=offline_buffer.transform)
logging.info("make_policy")
policy = make_policy(cfg)