Move stats_dataset init into else statement -> faster init
This commit is contained in:
parent
c7a8218620
commit
5edd9a89a0
|
@ -48,18 +48,17 @@ def make_dataset(
|
||||||
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
|
||||||
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
|
||||||
elif stats_path is None:
|
elif stats_path is None:
|
||||||
# instantiate a one frame dataset with light transform
|
|
||||||
stats_dataset = clsfunc(
|
|
||||||
dataset_id=cfg.dataset_id,
|
|
||||||
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
|
||||||
)
|
|
||||||
|
|
||||||
# load stats if the file exists already or compute stats and save it
|
# load stats if the file exists already or compute stats and save it
|
||||||
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
|
precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth"
|
||||||
if precomputed_stats_path.exists():
|
if precomputed_stats_path.exists():
|
||||||
stats = torch.load(precomputed_stats_path)
|
stats = torch.load(precomputed_stats_path)
|
||||||
else:
|
else:
|
||||||
logging.info(f"compute_stats and save to {precomputed_stats_path}")
|
logging.info(f"compute_stats and save to {precomputed_stats_path}")
|
||||||
|
# instantiate a one frame dataset with light transform
|
||||||
|
stats_dataset = clsfunc(
|
||||||
|
dataset_id=cfg.dataset_id,
|
||||||
|
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
|
||||||
|
)
|
||||||
stats = compute_stats(stats_dataset)
|
stats = compute_stats(stats_dataset)
|
||||||
torch.save(stats, stats_path)
|
torch.save(stats, stats_path)
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue