Minor formatting

This commit is contained in:
Marina Barannikov 2024-06-05 13:31:40 +00:00
parent 0fb3dd745b
commit 1b1bbb1632
1 changed files with 7 additions and 7 deletions

View File

@ -1,7 +1,6 @@
from pathlib import Path from pathlib import Path
import hydra import hydra
from torchvision.transforms import ToPILImage from torchvision.transforms import ToPILImage
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@ -21,9 +20,9 @@ def main(cfg, output_dir=Path("outputs/image_transforms")):
Returns: Returns:
None None
""" """
dataset = LeRobotDataset(cfg.dataset_repo_id, transform=None) dataset = LeRobotDataset(cfg.dataset_repo_id, transform=None)
output_dir = Path(output_dir) / Path(cfg.dataset_repo_id.split("/")[-1]) output_dir = Path(output_dir) / Path(cfg.dataset_repo_id.split("/")[-1])
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
@ -36,14 +35,14 @@ def main(cfg, output_dir=Path("outputs/image_transforms")):
# Apply each single transformation # Apply each single transformation
for transform_name in transforms: for transform_name in transforms:
cfg.image_transform.enable=True cfg.image_transform.enable = True
cfg.image_transform.max_num_transforms=1 cfg.image_transform.max_num_transforms = 1
for t in transforms: for t in transforms:
if t == transform_name: if t == transform_name:
cfg.image_transform[t].weight=1 cfg.image_transform[t].weight = 1
else: else:
cfg.image_transform[t].weight=0 cfg.image_transform[t].weight = 0
transform = make_transforms(cfg.image_transform) transform = make_transforms(cfg.image_transform)
img = transform(frame) img = transform(frame)
@ -56,5 +55,6 @@ def visualize_transforms_cli(cfg: dict):
cfg, cfg,
) )
if __name__ == "__main__": if __name__ == "__main__":
visualize_transforms_cli() visualize_transforms_cli()