fix memory leak due to itertools.cycle
This commit is contained in:
parent
5af00d0c1e
commit
ad3379a73a
|
@ -203,3 +203,12 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
|
||||||
|
|
||||||
torch.save(stats, stats_path)
|
torch.save(stats, stats_path)
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
def cycle(iterable):
|
||||||
|
iterator = iter(iterable)
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
yield next(iterator)
|
||||||
|
except StopIteration:
|
||||||
|
iterator = iter(iterable)
|
||||||
|
|
|
@ -95,3 +95,15 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
|
||||||
)
|
)
|
||||||
cfg = hydra.compose(Path(config_path).stem, overrides)
|
cfg = hydra.compose(Path(config_path).stem, overrides)
|
||||||
return cfg
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
def print_cuda_memory_usage():
|
||||||
|
import gc
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
# Also clear the cache if you want to fully release the memory
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
|
||||||
|
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
|
||||||
|
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
|
||||||
|
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import logging
|
import logging
|
||||||
from itertools import cycle
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
|
@ -7,10 +6,16 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
|
from lerobot.common.datasets.utils import cycle
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.logger import Logger, log_output_dir
|
from lerobot.common.logger import Logger, log_output_dir
|
||||||
from lerobot.common.policies.factory import make_policy
|
from lerobot.common.policies.factory import make_policy
|
||||||
from lerobot.common.utils import format_big_number, get_safe_torch_device, init_logging, set_global_seed
|
from lerobot.common.utils import (
|
||||||
|
format_big_number,
|
||||||
|
get_safe_torch_device,
|
||||||
|
init_logging,
|
||||||
|
set_global_seed,
|
||||||
|
)
|
||||||
from lerobot.scripts.eval import eval_policy
|
from lerobot.scripts.eval import eval_policy
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue