Don't use batch dimension in normalization

This commit is contained in:
Alexander Soare 2024-03-19 18:52:36 +00:00
parent 099a465367
commit 9946203071
4 changed files with 6 additions and 6 deletions

View File

@ -49,9 +49,9 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
@property @property
def stats_patterns(self) -> dict: def stats_patterns(self) -> dict:
return { return {
("observation", "state"): "b c -> 1 c", ("observation", "state"): "b c -> c",
("observation", "image"): "b c h w -> 1 c 1 1", ("observation", "image"): "b c h w -> c 1 1",
("action",): "b c -> 1 c", ("action",): "b c -> c",
} }
@property @property

View File

@ -113,11 +113,11 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
@property @property
def stats_patterns(self) -> dict: def stats_patterns(self) -> dict:
d = { d = {
("observation", "state"): "b c -> 1 c", ("observation", "state"): "b c -> c",
("action",): "b c -> 1 c", ("action",): "b c -> c",
} }
for cam in CAMERAS[self.dataset_id]: for cam in CAMERAS[self.dataset_id]:
d[("observation", "image", cam)] = "b c h w -> 1 c 1 1" d[("observation", "image", cam)] = "b c h w -> c 1 1"
return d return d
@property @property

Binary file not shown.