Merge pull request #17 from Cadene/user/rcadene/2024_03_11_bugfix_compute_stats

Fix bugs with normalization
This commit is contained in:
Remi 2024-03-11 13:44:07 +01:00 committed by GitHub
commit fab2b3240b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 27 additions and 10 deletions

View File

@@ -13,6 +13,7 @@ from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from torchrl.envs.transforms.transforms import Compose
class AbstractExperienceReplay(TensorDictReplayBuffer):
@@ -54,7 +55,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
return {
("observation", "state"): "b c -> 1 c",
("observation", "image"): "b c h w -> 1 c 1 1",
("action"): "b c -> 1 c",
("action",): "b c -> 1 c",
}
@property
@@ -73,8 +74,16 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
def num_episodes(self) -> int:
return len(self._storage._storage["episode"].unique())
@property
def transform(self):
return self._transform
def set_transform(self, transform):
self.transform = transform
if not isinstance(transform, Compose):
# required since torchrl calls `len(self._transform)` downstream
self._transform = Compose(transform)
else:
self._transform = transform
def compute_or_load_stats(self, num_batch=100, batch_size=32) -> TensorDict:
stats_path = self.data_dir / "stats.pth"

View File

@@ -114,7 +114,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
def stats_patterns(self) -> dict:
d = {
("observation", "state"): "b c -> 1 c",
("action"): "b c -> 1 c",
("action",): "b c -> 1 c",
}
for cam in CAMERAS[self.dataset_id]:
d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"

View File

@@ -87,9 +87,11 @@ def make_offline_buffer(
if normalize:
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
stats = offline_buffer.compute_or_load_stats()
# we only normalize the state and action, since the images are usually normalized inside the model for now (except for tdmpc: see the following)
in_keys = [("observation", "state"), ("action")]
if cfg.policy == "tdmpc":
if cfg.policy.name == "tdmpc":
for key in offline_buffer.image_keys:
# TODO(rcadene): imagenet normalization is applied inside diffusion policy, but no normalization inside tdmpc
in_keys.append(key)
@@ -97,7 +99,7 @@ def make_offline_buffer(
in_keys.append(("next", *key))
in_keys.append(("next", "observation", "state"))
if cfg.policy == "diffusion" and cfg.env.name == "pusht":
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)

View File

@@ -1,4 +1,4 @@
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.transforms import Compose, StepCounter, Transform, TransformedEnv
def make_env(cfg, transform=None):
@@ -33,7 +33,13 @@ def make_env(cfg, transform=None):
if transform is not None:
# useful to add normalization
env.append_transform(transform)
if isinstance(transform, Compose):
for tf in transform:
env.append_transform(tf.clone())
elif isinstance(transform, Transform):
env.append_transform(transform.clone())
else:
raise NotImplementedError()
return env

View File

@@ -126,7 +126,7 @@ def eval(cfg: dict, out_dir=None):
offline_buffer = make_offline_buffer(cfg)
logging.info("make_env")
env = make_env(cfg, transform=offline_buffer._transform)
env = make_env(cfg, transform=offline_buffer.transform)
if cfg.policy.pretrained_model_path:
policy = make_policy(cfg)

View File

@@ -142,11 +142,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
online_buffer = TensorDictReplayBuffer(
storage=LazyMemmapStorage(100_000),
sampler=online_sampler,
transform=offline_buffer._transform,
transform=offline_buffer.transform,
)
logging.info("make_env")
env = make_env(cfg, transform=offline_buffer._transform)
env = make_env(cfg, transform=offline_buffer.transform)
logging.info("make_policy")
policy = make_policy(cfg)