From b8837756fd9bafac8896c58ac2d0f64294fcb784 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 21 May 2024 15:05:59 +0100 Subject: [PATCH] fix_tests --- Makefile | 17 +++++++++++++++-- lerobot/common/utils/utils.py | 12 ++++++++---- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 45eb72a0..3e2e1fbb 100644 --- a/Makefile +++ b/Makefile @@ -49,8 +49,21 @@ test-act-ete-train: test-act-ete-train-resume: python lerobot/scripts/train.py \ - hydra.run.dir=tests/outputs/act/ \ - training.offline_steps=4 \ + policy=act \ + policy.dim_model=64 \ + env=aloha \ + wandb.enable=False \ + training.offline_steps=2 \ + training.online_steps=0 \ + eval.n_episodes=1 \ + eval.batch_size=1 \ + device=cpu \ + training.save_checkpoint=true \ + training.save_freq=2 \ + policy.n_action_steps=20 \ + policy.chunk_size=20 \ + training.batch_size=2 \ + hydra.run.dir=tests/outputs/act/ resume=true diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index 0eab089f..696999ad 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -50,12 +50,14 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device: def get_global_random_state() -> dict[str, Any]: """Get the random state for `random`, `numpy`, and `torch`.""" - return { + random_state_dict = { "random_state": random.getstate(), "numpy_random_state": np.random.get_state(), "torch_random_state": torch.random.get_rng_state(), - "torch_cuda_random_state": torch.cuda.random.get_rng_state(), } + if torch.cuda.is_available(): + random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state() + return random_state_dict def set_global_random_state(random_state_dict: dict[str, Any]): @@ -67,7 +69,8 @@ def set_global_random_state(random_state_dict: dict[str, Any]): random.setstate(random_state_dict["random_state"]) np.random.set_state(random_state_dict["numpy_random_state"]) torch.random.set_rng_state(random_state_dict["torch_random_state"]) - torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"]) + if torch.cuda.is_available(): + torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"]) def set_global_seed(seed): @@ -75,7 +78,8 @@ def set_global_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) @contextmanager