WIP
This commit is contained in:
parent
bdc0ebd36a
commit
a86f387554
|
@ -98,7 +98,7 @@ class RangeRandomSharpness(Transform):
|
||||||
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
|
||||||
|
|
||||||
|
|
||||||
def make_transforms(cfg):
|
def make_transforms(cfg, to_dtype: torch.dtype = torch.float32):
|
||||||
transforms_list = [
|
transforms_list = [
|
||||||
v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
|
v2.ColorJitter(brightness=(cfg.brightness.min, cfg.brightness.max)),
|
||||||
v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),
|
v2.ColorJitter(contrast=(cfg.contrast.min, cfg.contrast.max)),
|
||||||
|
@ -118,4 +118,6 @@ def make_transforms(cfg):
|
||||||
transforms_list, p=transforms_weights, n_subset=cfg.max_num_transforms, random_order=cfg.random_order
|
transforms_list, p=transforms_weights, n_subset=cfg.max_num_transforms, random_order=cfg.random_order
|
||||||
)
|
)
|
||||||
|
|
||||||
return v2.Compose([transforms, v2.ToDtype(torch.float32, scale=True)])
|
# return transforms
|
||||||
|
# return v2.Compose([transforms, v2.ToDtype(to_dtype, scale=True)])
|
||||||
|
return v2.Compose([transforms, v2.ToDtype(to_dtype, scale=False)])
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from torchvision.transforms import ToPILImage, v2
|
import torch
|
||||||
|
from torchvision.transforms import v2
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.common.datasets.transforms import RangeRandomSharpness
|
from lerobot.common.datasets.transforms import RangeRandomSharpness
|
||||||
|
@ -9,7 +11,7 @@ from lerobot.common.utils.utils import seeded_context
|
||||||
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
|
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
|
||||||
ARTIFACT_DIR = "tests/data/save_image_transforms"
|
ARTIFACT_DIR = "tests/data/save_image_transforms"
|
||||||
SEED = 1336
|
SEED = 1336
|
||||||
to_pil = ToPILImage()
|
to_pil = v2.ToPILImage()
|
||||||
|
|
||||||
|
|
||||||
def main(repo_id):
|
def main(repo_id):
|
||||||
|
@ -30,12 +32,15 @@ def main(repo_id):
|
||||||
"sharpness": RangeRandomSharpness(0.0, 2.0),
|
"sharpness": RangeRandomSharpness(0.0, 2.0),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Apply each single transformation
|
# frames = {"original_frame": original_frame}
|
||||||
for name, transform in transforms.items():
|
for name, transform in transforms.items():
|
||||||
with seeded_context(SEED):
|
with seeded_context(SEED):
|
||||||
|
# transform = v2.Compose([transform, v2.ToDtype(torch.float32, scale=True)])
|
||||||
transformed_frame = transform(original_frame)
|
transformed_frame = transform(original_frame)
|
||||||
|
# frames[name] = transform(original_frame)
|
||||||
to_pil(transformed_frame).save(output_dir / f"{SEED}_{name}.png", quality=100)
|
to_pil(transformed_frame).save(output_dir / f"{SEED}_{name}.png", quality=100)
|
||||||
|
|
||||||
|
# save_file(frames, output_dir / f"transformed_frames_{SEED}.safetensors")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
repo_id = "lerobot/aloha_mobile_shrimp"
|
repo_id = "lerobot/aloha_mobile_shrimp"
|
||||||
|
|
|
@ -1,9 +1,17 @@
|
||||||
|
from pathlib import Path
|
||||||
|
import numpy as np
|
||||||
|
from omegaconf import OmegaConf
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torchvision.transforms import v2
|
from torchvision.transforms import v2
|
||||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||||
|
from PIL import Image
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness
|
from lerobot.common.datasets.transforms import RandomSubsetApply, RangeRandomSharpness, make_transforms
|
||||||
|
from lerobot.common.datasets.utils import flatten_dict
|
||||||
|
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||||
|
from tests.utils import DEFAULT_CONFIG_PATH
|
||||||
|
|
||||||
|
|
||||||
class TestRandomSubsetApply:
|
class TestRandomSubsetApply:
|
||||||
|
@ -76,5 +84,84 @@ class TestRangeRandomSharpness:
|
||||||
|
|
||||||
|
|
||||||
class TestMakeTransforms:
|
class TestMakeTransforms:
|
||||||
...
|
@pytest.fixture(autouse=True)
|
||||||
# TODO
|
def setup(self):
|
||||||
|
"""Seed should be the same as the one that was used to generate artifacts"""
|
||||||
|
self.config = {
|
||||||
|
"enable": True,
|
||||||
|
"max_num_transforms": 1,
|
||||||
|
"random_order": False,
|
||||||
|
"brightness": {
|
||||||
|
"weight": 0,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 2.0
|
||||||
|
},
|
||||||
|
"contrast": {
|
||||||
|
"weight": 0,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 2.0,
|
||||||
|
},
|
||||||
|
"saturation": {
|
||||||
|
"weight": 0,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 2.0,
|
||||||
|
},
|
||||||
|
"hue": {
|
||||||
|
"weight": 0,
|
||||||
|
"min": -0.5,
|
||||||
|
"max": 0.5,
|
||||||
|
},
|
||||||
|
"sharpness": {
|
||||||
|
"weight": 0,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 2.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
self.path = Path("tests/data/save_image_transforms")
|
||||||
|
# self.expected_frames = load_file(self.path / f"transformed_frames_1336.safetensors")
|
||||||
|
self.original_frame = self.load_png_to_tensor(self.path / "original_frame.png")
|
||||||
|
# self.original_frame = self.expected_frames["original_frame"]
|
||||||
|
self.transforms = {
|
||||||
|
"brightness": v2.ColorJitter(brightness=(0.0, 2.0)),
|
||||||
|
"contrast": v2.ColorJitter(contrast=(0.0, 2.0)),
|
||||||
|
"saturation": v2.ColorJitter(saturation=(0.0, 2.0)),
|
||||||
|
"hue": v2.ColorJitter(hue=(-0.5, 0.5)),
|
||||||
|
"sharpness": RangeRandomSharpness(0.0, 2.0),
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_png_to_tensor(path: Path):
|
||||||
|
return torch.from_numpy(np.array(Image.open(path).convert('RGB'))).permute(2, 0, 1)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"transform_key, seed",
|
||||||
|
[
|
||||||
|
("brightness", 1336),
|
||||||
|
("contrast", 1336),
|
||||||
|
("saturation", 1336),
|
||||||
|
("hue", 1336),
|
||||||
|
("sharpness", 1336),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def test_single_transform(self, transform_key, seed):
|
||||||
|
config = self.config
|
||||||
|
config[transform_key]["weight"] = 1
|
||||||
|
cfg = OmegaConf.create(config)
|
||||||
|
transform = make_transforms(cfg, to_dtype=torch.uint8)
|
||||||
|
|
||||||
|
# expected_t = self.transforms[transform_key]
|
||||||
|
with seeded_context(seed):
|
||||||
|
actual = transform(self.original_frame)
|
||||||
|
# torch.manual_seed(42)
|
||||||
|
# actual = actual_t(self.original_frame)
|
||||||
|
# torch.manual_seed(42)
|
||||||
|
# expected = expected_t(self.original_frame)
|
||||||
|
|
||||||
|
# with seeded_context(1336):
|
||||||
|
# expected = expected_t(self.original_frame)
|
||||||
|
|
||||||
|
expected = self.load_png_to_tensor(self.path / f"{seed}_{transform_key}.png")
|
||||||
|
# # expected = self.expected_frames[transform_key]
|
||||||
|
to_pil = v2.ToPILImage()
|
||||||
|
to_pil(actual).save(self.path / f"{seed}_{transform_key}_test.png", quality=100)
|
||||||
|
torch.testing.assert_close(actual, expected)
|
||||||
|
|
Loading…
Reference in New Issue