Fix CI & update artifacts

This commit is contained in:
Simon Alibert 2024-06-10 13:36:39 +00:00
parent 9fba52332c
commit a65b67ff23
3 changed files with 8 additions and 6 deletions

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:649a5fbba619f76e9d36267a43b738bdee97834b78d46cbd7eb1f2a33d3ebd60
oid sha256:36f50697dacc82d52d1799dbc53c6c2fb722b9c0bd5bfa90a92dfa336591c74a
size 3686488

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a1de3eec04538a92cb762fe9f2011b8e51069684f8e8991bf609ff2f18cb7b14
oid sha256:c5a9f6d584a665f50d4a89e46b6d128b527b2ea2968f139a3c781a36242b4f47
size 22118896

View File

@ -11,7 +11,7 @@ from torchvision.transforms.v2 import functional as F # noqa: N812
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.utils import DEFAULT_CONFIG_PATH
from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel
ARTIFACT_DIR = Path("tests/data/save_image_transforms")
REPO_ID = "lerobot/aloha_mobile_shrimp"
@ -112,6 +112,7 @@ def test_get_image_transforms_max_num_transforms(img):
torch.testing.assert_close(tf_actual(img), tf_expected(img))
@require_x86_64_kernel
def test_get_image_transforms_random_order(img):
out_imgs = []
tf = get_image_transforms(
@ -122,8 +123,8 @@ def test_get_image_transforms_random_order(img):
sharpness_min_max=(0.5, 0.5),
random_order=True,
)
with seeded_context(1335):
for _ in range(20):
with seeded_context(1337):
for _ in range(10):
out_imgs.append(tf(img))
for i in range(1, len(out_imgs)):
@ -154,7 +155,8 @@ def test_backward_compatibility_torchvision(transform, img, single_transforms):
torch.testing.assert_close(actual, expected)
def test_backward_compatibility_default_yaml(img, default_transforms):
@require_x86_64_kernel
def test_backward_compatibility_default_config(img, default_transforms):
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
cfg_tf = cfg.image_transforms
default_tf = get_image_transforms(