fix tests

This commit is contained in:
Cadene 2024-03-26 10:24:46 +00:00
parent 4a8c5e238e
commit 5a46b8a2a9
2 changed files with 25 additions and 29 deletions

View File

@ -123,12 +123,8 @@ def make_offline_buffer(
if cfg.policy.name == "diffusion" and cfg.env.name == "pusht": if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
# TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this # TODO(rcadene): we overwrite stats to have the same as pretrained model, but we should remove this
stats["observation", "state", "min"] = torch.tensor( stats["observation", "state", "min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
[13.456424, 32.938293], dtype=torch.float32 stats["observation", "state", "max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
)
stats["observation", "state", "max"] = torch.tensor(
[496.14618, 510.9579], dtype=torch.float32
)
stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) stats["action", "min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) stats["action", "max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)

View File

@ -17,7 +17,7 @@ import lerobot
from lerobot.common.envs.aloha.env import AlohaEnv from lerobot.common.envs.aloha.env import AlohaEnv
from lerobot.common.envs.pusht.env import PushtEnv from lerobot.common.envs.pusht.env import PushtEnv
from lerobot.common.envs.simxarm import SimxarmEnv from lerobot.common.envs.simxarm.env import SimxarmEnv
from lerobot.common.datasets.simxarm import SimxarmDataset from lerobot.common.datasets.simxarm import SimxarmDataset
from lerobot.common.datasets.aloha import AlohaDataset from lerobot.common.datasets.aloha import AlohaDataset