fix_tests

This commit is contained in:
Alexander Soare 2024-05-21 15:05:59 +01:00
parent 607bea1cb3
commit b8837756fd
2 changed files with 23 additions and 6 deletions

View File

@ -49,8 +49,21 @@ test-act-ete-train:
test-act-ete-train-resume:
python lerobot/scripts/train.py \
hydra.run.dir=tests/outputs/act/ \
training.offline_steps=4 \
policy=act \
policy.dim_model=64 \
env=aloha \
wandb.enable=False \
training.offline_steps=2 \
training.online_steps=0 \
eval.n_episodes=1 \
eval.batch_size=1 \
device=cpu \
training.save_checkpoint=true \
training.save_freq=2 \
policy.n_action_steps=20 \
policy.chunk_size=20 \
training.batch_size=2 \
hydra.run.dir=tests/outputs/act/
resume=true

View File

@ -50,12 +50,14 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
def get_global_random_state() -> dict[str, Any]:
"""Get the random state for `random`, `numpy`, and `torch`."""
return {
random_state_dict = {
"random_state": random.getstate(),
"numpy_random_state": np.random.get_state(),
"torch_random_state": torch.random.get_rng_state(),
"torch_cuda_random_state": torch.cuda.random.get_rng_state(),
}
if torch.cuda.is_available():
random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
return random_state_dict
def set_global_random_state(random_state_dict: dict[str, Any]):
@ -67,7 +69,8 @@ def set_global_random_state(random_state_dict: dict[str, Any]):
random.setstate(random_state_dict["random_state"])
np.random.set_state(random_state_dict["numpy_random_state"])
torch.random.set_rng_state(random_state_dict["torch_random_state"])
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
if torch.cuda.is_available():
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
def set_global_seed(seed):
@ -75,7 +78,8 @@ def set_global_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
@contextmanager