fix tests

Cadene 2024-03-26 10:24:46 +00:00
parent 4a8c5e238e
commit 5a46b8a2a9
2 changed files with 25 additions and 29 deletions

View File

@@ -103,40 +103,36 @@ def make_offline_buffer(
     else:
         img_keys = offline_buffer.image_keys

     if normalize:
         transforms = [Prod(in_keys=img_keys, prod=1 / 255)]

         # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
         # min_max_from_spec
         stats = offline_buffer.compute_or_load_stats() if stats_path is None else torch.load(stats_path)

         # we only normalize the state and action, since the images are usually normalized inside the model for
         # now (except for tdmpc: see the following)
         in_keys = [("observation", "state"), ("action")]

         if cfg.policy.name == "tdmpc":
             # TODO(rcadene): we add img_keys to the keys to normalize for tdmpc only, since diffusion and act policies normalize the image inside the model for now
             in_keys += img_keys
             # TODO(racdene): since we use next observations in tdmpc, we also add them to the normalization. We are wasting a bit of compute on this for now.
             in_keys += [("next", *key) for key in img_keys]
             in_keys.append(("next", "observation", "state"))

         if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
             # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
-            stats["observation", "state", "min"] = torch.tensor(
-                [13.456424, 32.938293], dtype=torch.float32
-            )
-            stats["observation", "state", "max"] = torch.tensor(
-                [496.14618, 510.9579], dtype=torch.float32
-            )
-            stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
-            stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
+            stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
+            stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
+            stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
+            stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)

         # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
         normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"
         transforms.append(NormalizeTransform(stats, in_keys, mode=normalization_mode))

         offline_buffer.set_transform(transforms)

     if not overwrite_sampler:
         index = torch.arange(0, offline_buffer.num_samples, 1)

View File

@@ -17,7 +17,7 @@ import lerobot
 from lerobot.common.envs.aloha.env import AlohaEnv
 from lerobot.common.envs.pusht.env import PushtEnv
-from lerobot.common.envs.simxarm import SimxarmEnv
+from lerobot.common.envs.simxarm.env import SimxarmEnv
 from lerobot.common.datasets.simxarm import SimxarmDataset
 from lerobot.common.datasets.aloha import AlohaDataset
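
The change above corrects the SimxarmEnv import to point at the env module, which is presumably what broke the tests. A minimal smoke-test sketch (hypothetical, not part of this commit) that would catch such a failure simply by importing from the corrected path:

# Hypothetical smoke test: importing SimxarmEnv from its module path is enough
# to exercise the corrected import.
def test_simxarm_env_import():
    from lerobot.common.envs.simxarm.env import SimxarmEnv
    assert SimxarmEnv is not None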