save test_artifacts
This commit is contained in:
parent
bb8af34701
commit
625ea91adc
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:649a5fbba619f76e9d36267a43b738bdee97834b78d46cbd7eb1f2a33d3ebd60
|
||||
size 3686488
|
Binary file not shown.
Before Width: | Height: | Size: 185 KiB |
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:a1de3eec04538a92cb762fe9f2011b8e51069684f8e8991bf609ff2f18cb7b14
|
||||
size 22118896
|
|
@ -1,46 +1,69 @@
|
|||
from pathlib import Path
|
||||
|
||||
from torchvision.transforms import v2
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.transforms import SharpnessJitter
|
||||
from lerobot.common.utils.utils import seeded_context
|
||||
|
||||
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
|
||||
ARTIFACT_DIR = "tests/data/save_image_transforms"
|
||||
SEED = 1336
|
||||
to_pil = v2.ToPILImage()
|
||||
from lerobot.common.datasets.transforms import get_image_transforms
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from tests.test_transforms import ARTIFACT_DIR, REPO_ID
|
||||
from tests.utils import DEFAULT_CONFIG_PATH
|
||||
|
||||
|
||||
def main(repo_id):
|
||||
dataset = LeRobotDataset(repo_id, image_transforms=None)
|
||||
def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
|
||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
|
||||
cfg_tf = cfg.image_transforms
|
||||
default_tf = get_image_transforms(
|
||||
brightness_weight=cfg_tf.brightness.weight,
|
||||
brightness_min_max=cfg_tf.brightness.min_max,
|
||||
contrast_weight=cfg_tf.contrast.weight,
|
||||
contrast_min_max=cfg_tf.contrast.min_max,
|
||||
saturation_weight=cfg_tf.saturation.weight,
|
||||
saturation_min_max=cfg_tf.saturation.min_max,
|
||||
hue_weight=cfg_tf.hue.weight,
|
||||
hue_min_max=cfg_tf.hue.min_max,
|
||||
sharpness_weight=cfg_tf.sharpness.weight,
|
||||
sharpness_min_max=cfg_tf.sharpness.min_max,
|
||||
max_num_transforms=cfg_tf.max_num_transforms,
|
||||
random_order=cfg_tf.random_order,
|
||||
)
|
||||
|
||||
with seeded_context(1337):
|
||||
img_tf = default_tf(original_frame)
|
||||
|
||||
save_file({"default": img_tf}, output_dir / "default_transforms.safetensors")
|
||||
|
||||
|
||||
def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
|
||||
transforms = [
|
||||
"brightness",
|
||||
"contrast",
|
||||
"saturation",
|
||||
"hue",
|
||||
"sharpness",
|
||||
]
|
||||
|
||||
frames = {"original_frame": original_frame}
|
||||
for transform in transforms:
|
||||
kwargs = {
|
||||
f"{transform}_weight": 1.0,
|
||||
f"{transform}_min_max": (0.5, 0.5),
|
||||
}
|
||||
tf = get_image_transforms(**kwargs)
|
||||
frames[transform] = tf(original_frame)
|
||||
|
||||
save_file(frames, output_dir / "single_transforms.safetensors")
|
||||
|
||||
|
||||
def main():
|
||||
dataset = LeRobotDataset(REPO_ID, image_transforms=None)
|
||||
output_dir = Path(ARTIFACT_DIR)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
original_frame = dataset[0][dataset.camera_keys[0]]
|
||||
|
||||
# Get first frame of given episode
|
||||
from_idx = dataset.episode_data_index["from"][0].item()
|
||||
original_frame = dataset[from_idx][dataset.camera_keys[0]]
|
||||
to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
|
||||
|
||||
transforms = {
|
||||
"brightness": v2.ColorJitter(brightness=(0.0, 2.0)),
|
||||
"contrast": v2.ColorJitter(contrast=(0.0, 2.0)),
|
||||
"saturation": v2.ColorJitter(saturation=(0.0, 2.0)),
|
||||
"hue": v2.ColorJitter(hue=(-0.5, 0.5)),
|
||||
"sharpness": SharpnessJitter(0.0, 2.0),
|
||||
}
|
||||
|
||||
# frames = {"original_frame": original_frame}
|
||||
for name, transform in transforms.items():
|
||||
with seeded_context(SEED):
|
||||
# transform = v2.Compose([transform, v2.ToDtype(torch.float32, scale=True)])
|
||||
transformed_frame = transform(original_frame)
|
||||
# frames[name] = transform(original_frame)
|
||||
to_pil(transformed_frame).save(output_dir / f"{SEED}_{name}.png", quality=100)
|
||||
|
||||
# save_file(frames, output_dir / f"transformed_frames_{SEED}.safetensors")
|
||||
save_single_transforms(original_frame, output_dir)
|
||||
save_default_config_transform(original_frame, output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
repo_id = "lerobot/aloha_mobile_shrimp"
|
||||
main(repo_id)
|
||||
main()
|
||||
|
|
Loading…
Reference in New Issue