save test_artifacts

This commit is contained in:
Simon Alibert 2024-06-10 14:47:22 +02:00
parent bb8af34701
commit 625ea91adc
4 changed files with 63 additions and 34 deletions

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:649a5fbba619f76e9d36267a43b738bdee97834b78d46cbd7eb1f2a33d3ebd60
size 3686488

Binary file not shown.

Before

Width:  |  Height:  |  Size: 185 KiB

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a1de3eec04538a92cb762fe9f2011b8e51069684f8e8991bf609ff2f18cb7b14
size 22118896

View File

@ -1,46 +1,69 @@
from pathlib import Path
from torchvision.transforms import v2
import torch
from safetensors.torch import save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import SharpnessJitter
from lerobot.common.utils.utils import seeded_context
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
ARTIFACT_DIR = "tests/data/save_image_transforms"
SEED = 1336
to_pil = v2.ToPILImage()
from lerobot.common.datasets.transforms import get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.test_transforms import ARTIFACT_DIR, REPO_ID
from tests.utils import DEFAULT_CONFIG_PATH
def main(repo_id):
dataset = LeRobotDataset(repo_id, image_transforms=None)
def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
cfg_tf = cfg.image_transforms
default_tf = get_image_transforms(
brightness_weight=cfg_tf.brightness.weight,
brightness_min_max=cfg_tf.brightness.min_max,
contrast_weight=cfg_tf.contrast.weight,
contrast_min_max=cfg_tf.contrast.min_max,
saturation_weight=cfg_tf.saturation.weight,
saturation_min_max=cfg_tf.saturation.min_max,
hue_weight=cfg_tf.hue.weight,
hue_min_max=cfg_tf.hue.min_max,
sharpness_weight=cfg_tf.sharpness.weight,
sharpness_min_max=cfg_tf.sharpness.min_max,
max_num_transforms=cfg_tf.max_num_transforms,
random_order=cfg_tf.random_order,
)
with seeded_context(1337):
img_tf = default_tf(original_frame)
save_file({"default": img_tf}, output_dir / "default_transforms.safetensors")
def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
transforms = [
"brightness",
"contrast",
"saturation",
"hue",
"sharpness",
]
frames = {"original_frame": original_frame}
for transform in transforms:
kwargs = {
f"{transform}_weight": 1.0,
f"{transform}_min_max": (0.5, 0.5),
}
tf = get_image_transforms(**kwargs)
frames[transform] = tf(original_frame)
save_file(frames, output_dir / "single_transforms.safetensors")
def main():
dataset = LeRobotDataset(REPO_ID, image_transforms=None)
output_dir = Path(ARTIFACT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
original_frame = dataset[0][dataset.camera_keys[0]]
# Get first frame of given episode
from_idx = dataset.episode_data_index["from"][0].item()
original_frame = dataset[from_idx][dataset.camera_keys[0]]
to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
transforms = {
"brightness": v2.ColorJitter(brightness=(0.0, 2.0)),
"contrast": v2.ColorJitter(contrast=(0.0, 2.0)),
"saturation": v2.ColorJitter(saturation=(0.0, 2.0)),
"hue": v2.ColorJitter(hue=(-0.5, 0.5)),
"sharpness": SharpnessJitter(0.0, 2.0),
}
# frames = {"original_frame": original_frame}
for name, transform in transforms.items():
with seeded_context(SEED):
# transform = v2.Compose([transform, v2.ToDtype(torch.float32, scale=True)])
transformed_frame = transform(original_frame)
# frames[name] = transform(original_frame)
to_pil(transformed_frame).save(output_dir / f"{SEED}_{name}.png", quality=100)
# save_file(frames, output_dir / f"transformed_frames_{SEED}.safetensors")
save_single_transforms(original_frame, output_dir)
save_default_config_transform(original_frame, output_dir)
if __name__ == "__main__":
repo_id = "lerobot/aloha_mobile_shrimp"
main(repo_id)
main()