From 5a46b8a2a9b6584edc1e69e8829bc416d0c85b01 Mon Sep 17 00:00:00 2001 From: Cadene Date: Tue, 26 Mar 2024 10:24:46 +0000 Subject: [PATCH] fix tests --- lerobot/common/datasets/factory.py | 52 ++++++++++++++---------------- tests/test_available.py | 2 +- 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 30fc5258..4212e023 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -103,40 +103,36 @@ def make_offline_buffer( else: img_keys = offline_buffer.image_keys - if normalize: - transforms = [Prod(in_keys=img_keys, prod=1 / 255)] + if normalize: + transforms = [Prod(in_keys=img_keys, prod=1 / 255)] - # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, - # min_max_from_spec - stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path) + # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, + # min_max_from_spec + stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path) - # we only normalize the state and action, since the images are usually normalized inside the model for - # now (except for tdmpc: see the following) - in_keys = [("observation", "state"), ("action")] + # we only normalize the state and action, since the images are usually normalized inside the model for + # now (except for tdmpc: see the following) + in_keys = [("observation", "state"), ("action")] - if cfg.policy.name == "tdmpc": - # TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now - in_keys += img_keys - # TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now. - in_keys += [("next", *key) for key in img_keys] - in_keys.append(("next", "observation", "state")) + if cfg.policy.name == "tdmpc": + # TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now + in_keys += img_keys + # TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now. + in_keys += [("next", *key) for key in img_keys] + in_keys.append(("next", "observation", "state")) - if cfg.policy.name == "diffusion" and cfg.env.name == "pusht": - # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this - stats["observation", "state", "min"] = torch.tensor( - [13.456424, 32.938293], dtype=torch.float32 - ) - stats["observation", "state", "max"] = torch.tensor( - [496.14618, 510.9579], dtype=torch.float32 - ) - stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) - stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) + if cfg.policy.name == "diffusion" and cfg.env.name == "pusht": + # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this + stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32) + stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32) + stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) + stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) - # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std - normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max" - transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode)) + # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std + normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max" + transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode)) - offline_buffer.set_transform(transforms) + offline_buffer.set_transform(transforms) if not overwrite_sampler: index = torch.arange(0, offline_buffer.num_samples, 1) diff --git a/tests/test_available.py b/tests/test_available.py index 83382633..9cc91efa 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -17,7 +17,7 @@ import lerobot from lerobot.common.envs.aloha.env import AlohaEnv from lerobot.common.envs.pusht.env import PushtEnv -from lerobot.common.envs.simxarm import SimxarmEnv +from lerobot.common.envs.simxarm.env import SimxarmEnv from lerobot.common.datasets.simxarm import SimxarmDataset from lerobot.common.datasets.aloha import AlohaDataset