fix issue with saving freshly computed stats

This commit is contained in:
Alexander Soare 2024-04-17 08:49:28 +01:00
parent 3f1c322d56
commit 1331d3b4e4
1 changed files with 3 additions and 2 deletions
lerobot/common/datasets

View File

@ -62,7 +62,7 @@ def make_dataset(
stats = torch.load(precomputed_stats_path)
else:
logging.info(f"compute_stats and save to {precomputed_stats_path}")
# instantiate a one frame dataset with light transform
# Create a dataset for stats computation.
stats_dataset = clsfunc(
dataset_id=cfg.dataset_id,
split="train",
@ -70,7 +70,8 @@ def make_dataset(
transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
)
stats = compute_stats(stats_dataset)
torch.save(stats, stats_path)
os.makedirs(precomputed_stats_path.parent, exist_ok=True)
torch.save(stats, precomputed_stats_path)
else:
stats = torch.load(stats_path)