Merge remote-tracking branch 'Cadene/user/rcadene/2024_03_31_remove_torchrl' into refactor_act_remove_torchrl

This commit is contained in:
Alexander Soare 2024-04-05 12:00:31 +01:00
commit 4863e54ce9
3 changed files with 28 additions and 2 deletions

View File

@ -203,3 +203,12 @@ def compute_or_load_stats(dataset, batch_size=32, max_num_samples=None):
torch.save(stats, stats_path)
return stats
def cycle(iterable):
iterator = iter(iterable)
while True:
try:
yield next(iterator)
except StopIteration:
iterator = iter(iterable)

View File

@ -95,3 +95,15 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
)
cfg = hydra.compose(Path(config_path).stem, overrides)
return cfg
def print_cuda_memory_usage():
import gc
gc.collect()
# Also clear the cache if you want to fully release the memory
torch.cuda.empty_cache()
print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2))
print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2))
print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2))
print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))

View File

@ -1,5 +1,4 @@
import logging
from itertools import cycle
from pathlib import Path
import hydra
@ -7,10 +6,16 @@ import numpy as np
import torch
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import format_big_number, get_safe_torch_device, init_logging, set_global_seed
from lerobot.common.utils import (
format_big_number,
get_safe_torch_device,
init_logging,
set_global_seed,
)
from lerobot.scripts.eval import eval_policy