diff --git a/tests/scripts/save_policy_to_safetensor.py b/tests/scripts/save_policy_to_safetensor.py new file mode 100644 index 00000000..81748db1 --- /dev/null +++ b/tests/scripts/save_policy_to_safetensor.py @@ -0,0 +1,10 @@ +import shutil +from pathlib import Path + + +def save_policy_to_safetensors(output_dir, repo_id="lerobot/pusht"): + ... + repo_dir = Path(output_dir) / repo_id + + if repo_dir.exists(): + shutil.rmtree(repo_dir)