lerobot/tests/test_examples.py

148 lines
4.8 KiB
Python
Raw Normal View History

2024-05-15 18:13:09 +08:00
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(aliberts): Mute logging for these tests
import io
import subprocess
import sys
2024-04-18 20:47:42 +08:00
from pathlib import Path
from tests.utils import require_package
2024-03-27 00:13:40 +08:00
2024-04-16 19:51:32 +08:00
def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
for f, r in finds_and_replaces:
2024-03-27 00:13:40 +08:00
assert f in text
text = text.replace(f, r)
return text
2024-07-03 03:15:48 +08:00
def _run_script(path, args=None):
subprocess.run([sys.executable, path] + args if args is not None else [], check=True)
def _read_file(path):
with open(path) as file:
return file.read()
2024-03-27 00:13:40 +08:00
def test_example_1():
    """Run example 1 as a subprocess and check that it wrote episode_0.mp4."""
    _run_script("examples/1_load_lerobot_dataset.py")
    expected_video = Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4")
    assert expected_video.exists()
@require_package("gym_pusht")
def test_examples_basic2_basic3_advanced1():
    """
    Train a model with example 3, check the outputs.
    Evaluate the trained model with example 2, check the outputs.
    Calculate the validation loss with the advanced example
    (examples/advanced/2_calculate_validation_loss.py), check the outputs.

    Each example is read as source text, patched with `_find_and_replace` so it runs
    quickly on CPU, and executed with `exec`. Example 2 and the advanced example are
    patched to load the checkpoint from outputs/train/example_pusht_diffusion, which is
    where example 3's outputs are checked, so the steps must run in this order.
    """
    # Function-scope import keeps this block independent of the module's import list.
    from contextlib import redirect_stdout

    ### Test example 3
    file_contents = _read_file("examples/3_train_policy.py")

    # Do fewer steps, use smaller batch, use CPU, and don't complicate things with dataloader workers.
    file_contents = _find_and_replace(
        file_contents,
        [
            ("training_steps = 5000", "training_steps = 1"),
            ("num_workers=4", "num_workers=0"),
            ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
            ("batch_size=64", "batch_size=1"),
        ],
    )

    # Pass empty globals to allow dictionary comprehension https://stackoverflow.com/a/32897127/4391249.
    exec(file_contents, {})

    for file_name in ["model.safetensors", "config.json"]:
        assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()

    ### Test example 2
    file_contents = _read_file("examples/2_evaluate_pretrained_policy.py")

    # Do fewer evals, use CPU, and use the local model.
    file_contents = _find_and_replace(
        file_contents,
        [
            (
                'pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))',
                "",
            ),
            (
                '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
                'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
            ),
            ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
            ("step += 1", "break"),
        ],
    )

    exec(file_contents, {})

    assert Path("outputs/eval/example_pusht_diffusion/rollout.mp4").exists()

    ### Test advanced example 2 (validation loss)
    file_contents = _read_file("examples/advanced/2_calculate_validation_loss.py")

    # Run on a single example from the last episode, use CPU, and use the local model.
    file_contents = _find_and_replace(
        file_contents,
        [
            (
                'pretrained_policy_path = Path(snapshot_download("lerobot/diffusion_pusht"))',
                "",
            ),
            (
                '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
                'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
            ),
            ('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'),
            ("num_workers=4", "num_workers=0"),
            ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
            ("batch_size=64", "batch_size=1"),
        ],
    )

    # Capture what the script prints. `redirect_stdout` restores the previous sys.stdout
    # (e.g. pytest's capture stream, not sys.__stdout__) and does so even if exec raises.
    output_buffer = io.StringIO()
    with redirect_stdout(output_buffer):
        exec(file_contents, {})
    printed_output = output_buffer.getvalue()

    assert "Average loss on validation set" in printed_output
def test_real_world_recording():
    """Run the real-robot recording example with `--mock-robot` and check for episode_0.mp4."""
    script = "examples/real_robot_example/record_training_data.py"
    # NOTE(review): `--data_dir` uses an underscore while the other flags use hyphens;
    # confirm it matches the argument name declared by the script.
    cli_args = ["--data_dir", "outputs/examples"]
    cli_args += ["--repo-id", "real_world_debug"]
    cli_args += ["--num-episodes", "2"]
    cli_args += ["--num-frames", "10"]
    cli_args.append("--mock-robot")
    _run_script(script, cli_args)
    recorded_video = Path("outputs/examples/real_world_debug/video/episode_0.mp4")
    assert recorded_video.exists()