fix_tests
This commit is contained in:
parent
607bea1cb3
commit
b8837756fd
17
Makefile
17
Makefile
|
@ -49,8 +49,21 @@ test-act-ete-train:
|
||||||
|
|
||||||
test-act-ete-train-resume:
|
test-act-ete-train-resume:
|
||||||
python lerobot/scripts/train.py \
|
python lerobot/scripts/train.py \
|
||||||
hydra.run.dir=tests/outputs/act/ \
|
policy=act \
|
||||||
training.offline_steps=4 \
|
policy.dim_model=64 \
|
||||||
|
env=aloha \
|
||||||
|
wandb.enable=False \
|
||||||
|
training.offline_steps=2 \
|
||||||
|
training.online_steps=0 \
|
||||||
|
eval.n_episodes=1 \
|
||||||
|
eval.batch_size=1 \
|
||||||
|
device=cpu \
|
||||||
|
training.save_checkpoint=true \
|
||||||
|
training.save_freq=2 \
|
||||||
|
policy.n_action_steps=20 \
|
||||||
|
policy.chunk_size=20 \
|
||||||
|
training.batch_size=2 \
|
||||||
|
hydra.run.dir=tests/outputs/act/
|
||||||
resume=true
|
resume=true
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -50,12 +50,14 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
|
||||||
|
|
||||||
def get_global_random_state() -> dict[str, Any]:
|
def get_global_random_state() -> dict[str, Any]:
|
||||||
"""Get the random state for `random`, `numpy`, and `torch`."""
|
"""Get the random state for `random`, `numpy`, and `torch`."""
|
||||||
return {
|
random_state_dict = {
|
||||||
"random_state": random.getstate(),
|
"random_state": random.getstate(),
|
||||||
"numpy_random_state": np.random.get_state(),
|
"numpy_random_state": np.random.get_state(),
|
||||||
"torch_random_state": torch.random.get_rng_state(),
|
"torch_random_state": torch.random.get_rng_state(),
|
||||||
"torch_cuda_random_state": torch.cuda.random.get_rng_state(),
|
|
||||||
}
|
}
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
random_state_dict["torch_cuda_random_state"] = torch.cuda.random.get_rng_state()
|
||||||
|
return random_state_dict
|
||||||
|
|
||||||
|
|
||||||
def set_global_random_state(random_state_dict: dict[str, Any]):
|
def set_global_random_state(random_state_dict: dict[str, Any]):
|
||||||
|
@ -67,7 +69,8 @@ def set_global_random_state(random_state_dict: dict[str, Any]):
|
||||||
random.setstate(random_state_dict["random_state"])
|
random.setstate(random_state_dict["random_state"])
|
||||||
np.random.set_state(random_state_dict["numpy_random_state"])
|
np.random.set_state(random_state_dict["numpy_random_state"])
|
||||||
torch.random.set_rng_state(random_state_dict["torch_random_state"])
|
torch.random.set_rng_state(random_state_dict["torch_random_state"])
|
||||||
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
|
||||||
|
|
||||||
|
|
||||||
def set_global_seed(seed):
|
def set_global_seed(seed):
|
||||||
|
@ -75,7 +78,8 @@ def set_global_seed(seed):
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
torch.cuda.manual_seed_all(seed)
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
|
Loading…
Reference in New Issue