fix bug in compute_stats for action normalization

This commit is contained in:
Cadene 2024-03-11 09:37:56 +00:00
parent 4cc7e1539e
commit a7ef4a6a33
2 changed files with 2 additions and 2 deletions

View File

@ -54,7 +54,7 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
return {
("observation", "state"): "b c -> 1 c",
("observation", "image"): "b c h w -> 1 c 1 1",
("action"): "b c -> 1 c",
"action": "b c -> 1 c",
}
@property

View File

@ -114,7 +114,7 @@ class AlohaExperienceReplay(AbstractExperienceReplay):
def stats_patterns(self) -> dict:
d = {
("observation", "state"): "b c -> 1 c",
("action"): "b c -> 1 c",
"action": "b c -> 1 c",
}
for cam in CAMERAS[self.dataset_id]:
d[("observation", "image", cam)] = "b c h w -> 1 c 1 1"