fix more bugs in normalization
This commit is contained in:
parent a7ef4a6a33
commit 816b2e9d63
@@ -54,7 +54,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
         return {
             ("observation", "state"): "b c -> 1 c",
             ("observation", "image"): "b c h w -> 1 c 1 1",
-            "action": "b c -> 1 c",
+            ("action",): "b c -> 1 c",
         }

     @property
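For context on this hunk: the values in stats_patterns are einops reduction patterns. Below is a minimal sketch, not the repo's actual compute_or_load_stats, of how such patterns can be consumed with einops.reduce to turn a batch into broadcastable per-channel statistics; the batch shapes and contents are made up.

import torch
from einops import reduce

# made-up batch, keyed the same way as stats_patterns
batch = {
    ("observation", "state"): torch.randn(32, 4),         # b c
    ("observation", "image"): torch.rand(32, 3, 96, 96),  # b c h w
    ("action",): torch.randn(32, 2),                      # b c
}
patterns = {
    ("observation", "state"): "b c -> 1 c",
    ("observation", "image"): "b c h w -> 1 c 1 1",
    ("action",): "b c -> 1 c",
}

stats = {}
for key, pattern in patterns.items():
    x = batch[key]
    mean = reduce(x, pattern, "mean")
    # keeping singleton dims lets mean/std broadcast at apply time
    stats[(*key, "mean")] = mean
    stats[(*key, "std")] = reduce((x - mean) ** 2, pattern, "mean").sqrt()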
@@ -73,8 +73,12 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
     def num_episodes(self) -> int:
         return len(self._storage._storage["episode"].unique())

+    @property
+    def transform(self):
+        return self._transform
+
     def set_transform(self, transform):
-        self.transform = transform
+        self._transform = transform

     def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
         stats_path = self.data_dir / "stats.pth"
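A note on the setter fix: once `transform` is a read-only property, the old `self.transform = transform` would raise AttributeError, and without the property it merely set a shadow attribute instead of the `_transform` field that other code reads. A minimal sketch of the pattern (class and names illustrative, not the repo's code):

class Buffer:
    def __init__(self, transform=None):
        self._transform = transform  # private backing field

    @property
    def transform(self):
        # public read-only accessor
        return self._transform

    def set_transform(self, transform):
        # writing self.transform = ... here would raise AttributeError
        self._transform = transform

b = Buffer()
b.set_transform(lambda td: td)
assert callable(b.transform)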
@@ -114,7 +114,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
     def stats_patterns(self) -> dict:
         d = {
             ("observation", "state"): "b c -> 1 c",
-            "action": "b c -> 1 c",
+            ("action",): "b c -> 1 c",
         }
         for cam in CAMERAS[self.dataset_id]:
             d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"
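Why `"action"` -> `("action",)` is a correctness fix rather than cosmetics: downstream code composes nested keys by splatting the key (the diff below does exactly this with `("next", *key)`), and splatting a plain string explodes it into characters:

key = "action"
print((*key, "mean"))  # ('a', 'c', 't', 'i', 'o', 'n', 'mean') -- broken
key = ("action",)
print((*key, "mean"))  # ('action', 'mean') -- the intended nested key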
@@ -87,9 +87,11 @@ def make_offline_buffer(
     if normalize:
         # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
         stats = offline_buffer.compute_or_load_stats()

+        # we only normalize the state and action, since the images are usually normalized inside the model for now (except for tdmpc: see the following)
+        in_keys = [("observation", "state"), ("action")]

-        if cfg.policy == "tdmpc":
+        if cfg.policy.name == "tdmpc":
             for key in offline_buffer.image_keys:
                 # TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
                 in_keys.append(key)
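The `cfg.policy` -> `cfg.policy.name` change is the core bug fix in this file: assuming a Hydra/OmegaConf-style config (suggested by the `cfg.policy.name` access), `cfg.policy` is a config group node, so comparing it to a string is always False and the branch silently never ran. A quick repro sketch with made-up config contents:

from omegaconf import OmegaConf

cfg = OmegaConf.create({"policy": {"name": "tdmpc", "lr": 1e-3}})

print(cfg.policy == "tdmpc")       # False -- a DictConfig never equals a str
print(cfg.policy.name == "tdmpc")  # True  -- the fixed comparison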
@@ -97,7 +99,7 @@ def make_offline_buffer(
             in_keys.append(("next", *key))
         in_keys.append(("next", "observation", "state"))

-    if cfg.policy == "diffusion" and cfg.env.name == "pusht":
+    if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
         # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
         stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
         stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
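What the overwritten min/max are used for, sketched under the usual min_max convention of mapping state into [-1, 1]; the exact formula applied downstream is an assumption, only the stats values come from the diff:

import torch

smin = torch.tensor([13.456424, 32.938293])
smax = torch.tensor([496.14618, 510.9579])

def min_max_normalize(state):
    # map [min, max] -> [-1, 1], per dimension
    return (state - smin) / (smax - smin) * 2 - 1

print(min_max_normalize(smin))               # tensor([-1., -1.])
print(min_max_normalize(smax))               # tensor([1., 1.])
print(min_max_normalize((smin + smax) / 2))  # ~ tensor([0., 0.])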
@@ -126,7 +126,7 @@ def eval(cfg: dict, out_dir=None):
     offline_buffer = make_offline_buffer(cfg)

     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer._transform)
+    env = make_env(cfg, transform=offline_buffer.transform)

     if cfg.policy.pretrained_model_path:
         policy = make_policy(cfg)
@@ -142,11 +142,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
     online_buffer = TensorDictReplayBuffer(
         storage=LazyMemmapStorage(100_000),
         sampler=online_sampler,
-        transform=offline_buffer._transform,
+        transform=offline_buffer.transform,
     )

     logging.info("make_env")
-    env = make_env(cfg, transform=offline_buffer._transform)
+    env = make_env(cfg, transform=offline_buffer.transform)

     logging.info("make_policy")
     policy = make_policy(cfg)
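The eval/train hunks swap the private `_transform` for the new public `transform` property; behavior is identical, but the point of the wiring is that the offline data, the online buffer, and the env all apply one and the same normalization. A toy illustration of the failure mode this avoids (the stats values are stand-ins):

import torch

buffer_mean, buffer_std = torch.tensor([2.0]), torch.tensor([4.0])
stale_mean, stale_std = torch.tensor([0.0]), torch.tensor([1.0])  # e.g. default/unset stats

obs = torch.tensor([10.0])
train_input = (obs - buffer_mean) / buffer_std  # what the policy trains on
rollout_input = (obs - stale_mean) / stale_std  # what a mismatched env would feed it
print(train_input, rollout_input)  # tensor([2.]) vs tensor([10.]) -- silent distribution shift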