lerobot/tests/test_examples.py

77 lines
2.3 KiB
Python
Raw Normal View History

# TODO(aliberts): Mute logging for these tests
import subprocess
import sys
2024-04-18 20:47:42 +08:00
from pathlib import Path
from tests.utils import require_package
2024-03-27 00:13:40 +08:00
2024-04-16 19:51:32 +08:00
def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
    """Apply literal find/replace pairs to ``text`` in order and return the result.

    Each pair ``(find, replace)`` replaces *every* occurrence of ``find``. Pairs are
    applied sequentially, so a later ``find`` is searched for in the text produced
    by the earlier replacements.

    Raises:
        AssertionError: if a ``find`` string is absent at the point it is applied.
            This guards the example tests against silently running an unpatched
            script after the example source has been edited.
    """
    for find, replace in finds_and_replaces:
        # Explicit check rather than a bare ``assert``: it is not stripped by
        # ``python -O`` and it names the missing snippet instead of leaving the
        # reader to diff the whole file contents.
        if find not in text:
            raise AssertionError(f"Expected to find {find!r} in the text, but it is missing.")
        text = text.replace(find, replace)
    return text
def _run_script(path):
    """Execute the Python script at ``path`` in a child process.

    The child is started with ``sys.executable`` (the interpreter running the
    tests). ``check=True`` turns a non-zero exit status into
    ``subprocess.CalledProcessError``. Nothing is returned.
    """
    command = [sys.executable, path]
    subprocess.run(command, check=True)
2024-03-27 00:13:40 +08:00
def test_example_1():
    """Run example 1 in a subprocess and check that it wrote the episode video."""
    script_path = "examples/1_load_lerobot_dataset.py"
    expected_video = Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4")

    _run_script(script_path)

    # Both paths are relative to the current working directory; presumably the
    # suite is run from the repository root — confirm against CI config.
    assert expected_video.exists()
2024-03-27 00:13:40 +08:00
def _exec_example(path: str, finds_and_replaces: list[tuple[str, str]]) -> None:
    """Execute an example script in-process after patching its source.

    Reads ``path`` as UTF-8 (explicit, so the result does not depend on the
    platform's locale encoding). It then applies ``finds_and_replaces`` via
    ``_find_and_replace``, where every find string must be present, and runs
    the patched source with ``exec``.
    """
    file_contents = Path(path).read_text(encoding="utf-8")
    file_contents = _find_and_replace(file_contents, finds_and_replaces)
    # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
    exec(file_contents, {})


@require_package("gym_pusht")
def test_examples_3_and_2():
    """
    Train a model with example 3, check the outputs.
    Evaluate the trained model with example 2, check the outputs.
    """
    # Example 3: do fewer steps, use a smaller batch, use CPU, and don't complicate
    # things with dataloader workers.
    _exec_example(
        "examples/3_train_policy.py",
        [
            ("training_steps = 5000", "training_steps = 1"),
            ("num_workers=4", "num_workers=0"),
            ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
            ("batch_size=64", "batch_size=1"),
        ],
    )
    # NOTE(review): these existence checks also pass on stale files left behind by
    # an earlier run; consider clearing the output directory first.
    for file_name in ["model.safetensors", "config.json"]:
        assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()

    # Example 2: do fewer evals, use CPU, and load the policy trained just above
    # instead of the downloaded one. Replacing "step += 1" with "break"
    # presumably ends the rollout loop after its first iteration — confirm
    # against the example source.
    _exec_example(
        "examples/2_evaluate_pretrained_policy.py",
        [
            ('pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))', ""),
            (
                '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
                'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
            ),
            ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
            ("step += 1", "break"),
        ],
    )
    assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists()