From 5edd9a89a045b97de4820508ed8950d49dcc83f3 Mon Sep 17 00:00:00 2001 From: Cadene Date: Tue, 16 Apr 2024 07:04:08 +0000 Subject: [PATCH] Move stats_dataset init into else statement -> faster init --- lerobot/common/datasets/factory.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 59115542..ee9285a4 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -48,18 +48,17 @@ def make_dataset( stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32) stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32) elif stats_path is None: - # instantiate a one frame dataset with light transform - stats_dataset = clsfunc( - dataset_id=cfg.dataset_id, - transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0), - ) - # load stats if the file exists already or compute stats and save it precomputed_stats_path = Path("data") / cfg.dataset_id / "stats.pth" if precomputed_stats_path.exists(): stats = torch.load(precomputed_stats_path) else: logging.info(f"compute_stats and save to {precomputed_stats_path}") + # instantiate a one frame dataset with light transform + stats_dataset = clsfunc( + dataset_id=cfg.dataset_id, + transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0), + ) stats = compute_stats(stats_dataset) torch.save(stats, stats_path) else: