Add save_image_transforms.py & artifacts
This commit is contained in:
parent
e444b0d529
commit
641d349df4
Binary file not shown.
After Width: | Height: | Size: 118 KiB |
Binary file not shown.
After Width: | Height: | Size: 118 KiB |
Binary file not shown.
After Width: | Height: | Size: 187 KiB |
Binary file not shown.
After Width: | Height: | Size: 168 KiB |
Binary file not shown.
After Width: | Height: | Size: 224 KiB |
Binary file not shown.
After Width: | Height: | Size: 185 KiB |
|
@ -0,0 +1,42 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from torchvision.transforms import ToPILImage, v2
|
||||||
|
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.common.datasets.transforms import RangeRandomSharpness
|
||||||
|
from lerobot.common.utils.utils import seeded_context
|
||||||
|
|
||||||
|
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
|
||||||
|
ARTIFACT_DIR = "tests/data/save_image_transforms"
|
||||||
|
SEED = 1336
|
||||||
|
to_pil = ToPILImage()
|
||||||
|
|
||||||
|
|
||||||
|
def main(repo_id):
|
||||||
|
dataset = LeRobotDataset(repo_id, transform=None)
|
||||||
|
output_dir = Path(ARTIFACT_DIR)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Get first frame of given episode
|
||||||
|
from_idx = dataset.episode_data_index["from"][0].item()
|
||||||
|
original_frame = dataset[from_idx][dataset.camera_keys[0]]
|
||||||
|
to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
|
||||||
|
|
||||||
|
transforms = {
|
||||||
|
"brightness": v2.ColorJitter(brightness=(0.0, 2.0)),
|
||||||
|
"contrast": v2.ColorJitter(contrast=(0.0, 2.0)),
|
||||||
|
"saturation": v2.ColorJitter(saturation=(0.0, 2.0)),
|
||||||
|
"hue": v2.ColorJitter(hue=(-0.5, 0.5)),
|
||||||
|
"sharpness": RangeRandomSharpness(0.0, 2.0),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Apply each single transformation
|
||||||
|
for name, transform in transforms.items():
|
||||||
|
with seeded_context(SEED):
|
||||||
|
transformed_frame = transform(original_frame)
|
||||||
|
to_pil(transformed_frame).save(output_dir / f"{SEED}_{name}.png", quality=100)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
repo_id = "lerobot/aloha_mobile_shrimp"
|
||||||
|
main(repo_id)
|
Loading…
Reference in New Issue