This commit is contained in:
Simon Alibert 2024-06-06 15:23:49 +00:00
parent bdc0ebd36a
commit a86f387554
3 changed files with 102 additions and 8 deletions

View File

@ -98,7 +98,7 @@ class RangeRandomSharpness(Transform):
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor) return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
def make_transforms(cfg): def make_transforms(cfg, to_dtype: torch.dtype = torch.float32):
transforms_list = [ transforms_list = [
v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)), v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)), v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),
@ -118,4 +118,6 @@ def make_transforms(cfg):
transforms_list, p=transforms_weights, n_subset=cfg.max_num_transforms, random_order=cfg.random_order transforms_list, p=transforms_weights, n_subset=cfg.max_num_transforms, random_order=cfg.random_order
) )
return v2.Compose([transforms, v2.ToDtype(torch.float32, scale=True)]) # return transforms
# return v2.Compose([transforms, v2.ToDtype(to_dtype, scale=True)])
return v2.Compose([transforms, v2.ToDtype(to_dtype, scale=False)])

View File

@ -1,6 +1,8 @@
from pathlib import Path from pathlib import Path
from torchvision.transforms import ToPILImage, v2 import torch
from torchvision.transforms import v2
from safetensors.torch import save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RangeRandomSharpness from lerobot.common.datasets.transforms import RangeRandomSharpness
@ -9,7 +11,7 @@ from lerobot.common.utils.utils import seeded_context
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml" DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
ARTIFACT_DIR = "tests/data/save_image_transforms" ARTIFACT_DIR = "tests/data/save_image_transforms"
SEED = 1336 SEED = 1336
to_pil = ToPILImage() to_pil = v2.ToPILImage()
def main(repo_id): def main(repo_id):
@ -30,12 +32,15 @@ def main(repo_id):
"sharpness": RangeRandomSharpness(0.0, 2.0), "sharpness": RangeRandomSharpness(0.0, 2.0),
} }
# Apply each single transformation # frames = {"original_frame": original_frame}
for name, transform in transforms.items(): for name, transform in transforms.items():
with seeded_context(SEED): with seeded_context(SEED):
# transform = v2.Compose([transform, v2.ToDtype(torch.float32, scale=True)])
transformed_frame = transform(original_frame) transformed_frame = transform(original_frame)
# frames[name] = transform(original_frame)
to_pil(transformed_frame).save(output_dir / f"{SEED}_{name}.png", quality=100) to_pil(transformed_frame).save(output_dir / f"{SEED}_{name}.png", quality=100)
# save_file(frames, output_dir / f"transformed_frames_{SEED}.safetensors")
if __name__ == "__main__": if __name__ == "__main__":
repo_id = "lerobot/aloha_mobile_shrimp" repo_id = "lerobot/aloha_mobile_shrimp"

View File

@ -1,9 +1,17 @@
from pathlib import Path
import numpy as np
from omegaconf import OmegaConf
import pytest import pytest
import torch import torch
from torchvision.transforms import v2 from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812 from torchvision.transforms.v2 import functional as F # noqa: N812
from PIL import Image
from safetensors.torch import load_file
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_transforms
from lerobot.common.datasets.utils import flatten_dict
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.utils import DEFAULT_CONFIG_PATH
class TestRandomSubsetApply: class TestRandomSubsetApply:
@ -76,5 +84,84 @@ class TestRangeRandomSharpness:
class TestMakeTransforms: class TestMakeTransforms:
... @pytest.fixture(autouse=True)
# TODO def setup(self):
"""Seed should be the same as the one that was used to generate artifacts"""
self.config = {
"enable": True,
"max_num_transforms": 1,
"random_order": False,
"brightness": {
"weight": 0,
"min": 0.0,
"max": 2.0
},
"contrast": {
"weight": 0,
"min": 0.0,
"max": 2.0,
},
"saturation": {
"weight": 0,
"min": 0.0,
"max": 2.0,
},
"hue": {
"weight": 0,
"min": -0.5,
"max": 0.5,
},
"sharpness": {
"weight": 0,
"min": 0.0,
"max": 2.0,
},
}
self.path = Path("tests/data/save_image_transforms")
# self.expected_frames = load_file(self.path / f"transformed_frames_1336.safetensors")
self.original_frame = self.load_png_to_tensor(self.path / "original_frame.png")
# self.original_frame = self.expected_frames["original_frame"]
self.transforms = {
"brightness": v2.ColorJitter(brightness=(0.0, 2.0)),
"contrast": v2.ColorJitter(contrast=(0.0, 2.0)),
"saturation": v2.ColorJitter(saturation=(0.0, 2.0)),
"hue": v2.ColorJitter(hue=(-0.5, 0.5)),
"sharpness": RangeRandomSharpness(0.0, 2.0),
}
@staticmethod
def load_png_to_tensor(path: Path):
return torch.from_numpy(np.array(Image.open(path).convert('RGB'))).permute(2, 0, 1)
@pytest.mark.parametrize(
"transform_key, seed",
[
("brightness", 1336),
("contrast", 1336),
("saturation", 1336),
("hue", 1336),
("sharpness", 1336),
]
)
def test_single_transform(self, transform_key, seed):
config = self.config
config[transform_key]["weight"] = 1
cfg = OmegaConf.create(config)
transform = make_transforms(cfg, to_dtype=torch.uint8)
# expected_t = self.transforms[transform_key]
with seeded_context(seed):
actual = transform(self.original_frame)
# torch.manual_seed(42)
# actual = actual_t(self.original_frame)
# torch.manual_seed(42)
# expected = expected_t(self.original_frame)
# with seeded_context(1336):
# expected = expected_t(self.original_frame)
expected = self.load_png_to_tensor(self.path / f"{seed}_{transform_key}.png")
# # expected = self.expected_frames[transform_key]
to_pil = v2.ToPILImage()
to_pil(actual).save(self.path / f"{seed}_{transform_key}_test.png", quality=100)
torch.testing.assert_close(actual, expected)