71 lines
1.9 KiB
Python
71 lines
1.9 KiB
Python
from pathlib import Path
|
|
|
|
|
|
def _find_and_replace(text: str, finds: list[str], replaces: list[str]) -> str:
|
|
for f, r in zip(finds, replaces):
|
|
assert f in text
|
|
text = text.replace(f, r)
|
|
return text
|
|
|
|
|
|
def test_example_1():
    """Execute the dataset visualization example and check its output file.

    The example script is read from disk and run in-process with ``exec``;
    the test then checks that the expected episode video file exists.
    """
    script = Path("examples/1_visualize_dataset.py")
    source = script.read_text()
    exec(source)

    expected_video = Path("outputs/visualize_dataset/example/episode_0.mp4")
    assert expected_video.exists()
|
|
|
|
|
|
def test_examples_3_and_2():
    """
    Train a model with example 3, check the outputs.
    Evaluate the trained model with example 2, check the outputs.

    Both example scripts are read from disk, patched with `_find_and_replace`
    to use fewer steps/episodes on CPU, and executed in-process with `exec`.
    """

    path = "examples/3_train_policy.py"

    with open(path, "r") as file:
        file_contents = file.read()

    # Do less steps and use CPU.
    file_contents = _find_and_replace(
        file_contents,
        ['"offline_steps=5000"', '"device=cuda"'],
        ['"offline_steps=1"', '"device=cpu"'],
    )

    exec(file_contents)

    for file_name in ["model.pt", "stats.pth", "config.yaml"]:
        assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()

    path = "examples/2_evaluate_pretrained_policy.py"

    with open(path, "r") as file:
        file_contents = file.read()

    # Do less evals, use CPU, and use the local model.
    file_contents = _find_and_replace(
        file_contents,
        [
            '"eval_episodes=10"',
            '"rollout_batch_size=10"',
            '"device=cuda"',
            '# folder = Path("outputs/train/example_pusht_diffusion")',
            'hub_id = "lerobot/diffusion_policy_pusht_image"',
            "folder = Path(snapshot_download(hub_id)",
        ],
        [
            '"eval_episodes=1"',
            '"rollout_batch_size=1"',
            '"device=cpu"',
            'folder = Path("outputs/train/example_pusht_diffusion")',
            "",
            # The find string stops before the end of the statement (its parentheses are
            # unbalanced), so blanking it could leave trailing characters behind. Commenting
            # the line out neutralizes whatever follows on that line.
            # NOTE(review): assumes the statement fits on a single line - confirm against the example.
            "# folder = Path(snapshot_download(hub_id)",
        ],
    )

    # Run the patched evaluation script; without this the evaluation half of the
    # test would only check that the find/replace patterns are present.
    exec(file_contents)

    assert Path("outputs/train/example_pusht_diffusion").exists()
|