lerobot/tests/test_examples.py

68 lines
1.9 KiB
Python
Raw Normal View History

from pathlib import Path
2024-03-27 00:13:40 +08:00
2024-04-16 19:51:32 +08:00
def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
for f, r in finds_and_replaces:
2024-03-27 00:13:40 +08:00
assert f in text
text = text.replace(f, r)
return text
def test_example_1():
    """Run example 1 and check that it writes the expected episode video."""
    # Relative path: this appears to assume pytest is launched from the repo root.
    path = "examples/1_visualize_dataset.py"
    with open(path, encoding="utf-8") as file:
        file_contents = file.read()

    # Execute in a fresh global namespace, as if the example were its own module.
    # With exec's defaults inside a function, the script's top-level names land
    # in this function's locals while any functions/comprehensions it defines
    # resolve free names via this test module's globals -> possible NameError.
    exec(file_contents, {})

    assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
def test_examples_3_and_2():
    """
    Train a model with example 3 and check the outputs.
    Then patch example 2 so it would evaluate that locally trained model on CPU.

    NOTE(review): the patched example-2 source is never executed here, so the
    evaluation half currently only verifies that its patch targets still exist
    in the script. Add an exec once the patches below are confirmed to yield
    valid Python.
    """
    # Relative paths: this appears to assume pytest is launched from the repo root.
    path = "examples/3_train_policy.py"
    with open(path, encoding="utf-8") as file:
        file_contents = file.read()

    # Do less steps and use CPU.
    file_contents = _find_and_replace(
        file_contents,
        [
            ("offline_steps = 5000", "offline_steps = 1"),
            ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
        ],
    )

    # Execute in a fresh global namespace, as if the example were its own module.
    # With exec's defaults inside a function, the script's top-level names land
    # in this function's locals while any functions/comprehensions it defines
    # resolve free names via this test module's globals -> possible NameError.
    exec(file_contents, {})

    for file_name in ["model.pt", "stats.pth", "config.yaml"]:
        assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()

    path = "examples/2_evaluate_pretrained_policy.py"
    with open(path, encoding="utf-8") as file:
        file_contents = file.read()

    # Do less evals, use CPU, and use the local model.
    file_contents = _find_and_replace(
        file_contents,
        [
            ('"eval_episodes=10"', '"eval_episodes=1"'),
            ('"rollout_batch_size=10"', '"rollout_batch_size=1"'),
            ('"device=cuda"', '"device=cpu"'),
            (
                '# folder = Path("outputs/train/example_pusht_diffusion")',
                'folder = Path("outputs/train/example_pusht_diffusion")',
            ),
            ('hub_id = "lerobot/diffusion_policy_pusht_image"', ""),
            # NOTE(review): this pattern has two "(" but one ")". If the script's
            # line is `folder = Path(snapshot_download(hub_id))`, removing it
            # leaves a stray ")" -- verify before exec'ing this source.
            ("folder = Path(snapshot_download(hub_id)", ""),
        ],
    )

    # NOTE(review): `file_contents` is not executed, so this only re-confirms
    # the training output directory produced above.
    assert Path("outputs/train/example_pusht_diffusion").exists()