Move stats_dataset init into else statement -> faster init

This commit is contained in:
Cadene 2024-04-16 07:04:08 +00:00
parent c7a8218620
commit 5edd9a89a0
1 changed files with 5 additions and 6 deletions

View File

@ -48,18 +48,17 @@ def make_dataset(
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
elif stats_path is None: elif stats_path is None:
# instantiate a one frame dataset with light transform
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
# load stats if the file exists already or compute stats and save it # load stats if the file exists already or compute stats and save it
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth" precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
if precomputed_stats_path.exists(): if precomputed_stats_path.exists():
stats = torch.load(precomputed_stats_path) stats = torch.load(precomputed_stats_path)
else: else:
logging.info(f"compute_stats and save to {precomputed_stats_path}") logging.info(f"compute_stats and save to {precomputed_stats_path}")
# instantiate a one frame dataset with light transform
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = compute_stats(stats_dataset) stats = compute_stats(stats_dataset)
torch.save(stats, stats_path) torch.save(stats, stats_path)
else: else: