fix_tests

This commit is contained in:
Alexander Soare 2024-05-21 15:05:59 +01:00
parent 607bea1cb3
commit b8837756fd
2 changed files with 23 additions and 6 deletions

View File

@ -49,8 +49,21 @@ test-act-ete-train:
test-act-ete-train-resume: test-act-ete-train-resume:
python lerobot/scripts/train.py \ python lerobot/scripts/train.py \
hydra.run.dir=tests/outputs/act/ \ policy=act \
training.offline_steps=4 \ policy.dim_model=64 \
env=aloha \
wandb.enable=False \
training.offline_steps=2 \
training.online_steps=0 \
eval.n_episodes=1 \
eval.batch_size=1 \
device=cpu \
training.save_checkpoint=true \
training.save_freq=2 \
policy.n_action_steps=20 \
policy.chunk_size=20 \
training.batch_size=2 \
hydra.run.dir=tests/outputs/act/
resume=true resume=true

View File

@ -50,12 +50,14 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
def get_global_random_state() -> dict[str, Any]: def get_global_random_state() -> dict[str, Any]:
"""Get the random state for `random`, `numpy`, and `torch`.""" """Get the random state for `random`, `numpy`, and `torch`."""
return { random_state_dict = {
"random_state": random.getstate(), "random_state": random.getstate(),
"numpy_random_state": np.random.get_state(), "numpy_random_state": np.random.get_state(),
"torch_random_state": torch.random.get_rng_state(), "torch_random_state": torch.random.get_rng_state(),
"torch_cuda_random_state": torch.cuda.random.get_rng_state(),
} }
if torch.cuda.is_available():
random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
return random_state_dict
def set_global_random_state(random_state_dict: dict[str, Any]): def set_global_random_state(random_state_dict: dict[str, Any]):
@ -67,6 +69,7 @@ def set_global_random_state(random_state_dict: dict[str, Any]):
random.setstate(random_state_dict["random_state"]) random.setstate(random_state_dict["random_state"])
np.random.set_state(random_state_dict["numpy_random_state"]) np.random.set_state(random_state_dict["numpy_random_state"])
torch.random.set_rng_state(random_state_dict["torch_random_state"]) torch.random.set_rng_state(random_state_dict["torch_random_state"])
if torch.cuda.is_available():
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"]) torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
@ -75,6 +78,7 @@ def set_global_seed(seed):
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)