fix_tests
This commit is contained in:
parent
607bea1cb3
commit
b8837756fd
17
Makefile
17
Makefile
|
@ -49,8 +49,21 @@ test-act-ete-train:
|
|||
|
||||
test-act-ete-train-resume:
|
||||
python lerobot/scripts/train.py \
|
||||
hydra.run.dir=tests/outputs/act/ \
|
||||
training.offline_steps=4 \
|
||||
policy=act \
|
||||
policy.dim_model=64 \
|
||||
env=aloha \
|
||||
wandb.enable=False \
|
||||
training.offline_steps=2 \
|
||||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=cpu \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
policy.n_action_steps=20 \
|
||||
policy.chunk_size=20 \
|
||||
training.batch_size=2 \
|
||||
hydra.run.dir=tests/outputs/act/
|
||||
resume=true
|
||||
|
||||
|
||||
|
|
|
@ -50,12 +50,14 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
|
|||
|
||||
def get_global_random_state() -> dict[str, Any]:
|
||||
"""Get the random state for `random`, `numpy`, and `torch`."""
|
||||
return {
|
||||
random_state_dict = {
|
||||
"random_state": random.getstate(),
|
||||
"numpy_random_state": np.random.get_state(),
|
||||
"torch_random_state": torch.random.get_rng_state(),
|
||||
"torch_cuda_random_state": torch.cuda.random.get_rng_state(),
|
||||
}
|
||||
if torch.cuda.is_available():
|
||||
random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
|
||||
return random_state_dict
|
||||
|
||||
|
||||
def set_global_random_state(random_state_dict: dict[str, Any]):
|
||||
|
@ -67,7 +69,8 @@ def set_global_random_state(random_state_dict: dict[str, Any]):
|
|||
random.setstate(random_state_dict["random_state"])
|
||||
np.random.set_state(random_state_dict["numpy_random_state"])
|
||||
torch.random.set_rng_state(random_state_dict["torch_random_state"])
|
||||
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
|
||||
|
||||
|
||||
def set_global_seed(seed):
|
||||
|
@ -75,7 +78,8 @@ def set_global_seed(seed):
|
|||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
|
Loading…
Reference in New Issue