Merge remote-tracking branch 'upstream/main' into sim_latency

This commit is contained in:
Alexander Soare 2024-06-10 14:06:21 +01:00
commit 42e6a5e9b3
4 changed files with 94 additions and 18 deletions

18
.github/workflows/trufflehog.yml vendored Normal file
View File

@ -0,0 +1,18 @@
on:
push:
name: Secret Leaks
permissions:
contents: read
jobs:
trufflehog:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main

View File

@ -13,39 +13,71 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Use this script to get a quick summary of your system config.
It should be able to run without any of LeRobot's dependencies or LeRobot itself installed.
"""
import platform
import huggingface_hub
HAS_HF_HUB = True
HAS_HF_DATASETS = True
HAS_NP = True
HAS_TORCH = True
HAS_LEROBOT = True
# import dataset
import numpy as np
import torch
try:
import huggingface_hub
except ImportError:
HAS_HF_HUB = False
from lerobot import __version__ as version
try:
import datasets
except ImportError:
HAS_HF_DATASETS = False
pt_version = torch.__version__
pt_cuda_available = torch.cuda.is_available()
pt_cuda_available = torch.cuda.is_available()
cuda_version = torch._C._cuda_getCompiledVersion() if torch.version.cuda is not None else "N/A"
try:
import numpy as np
except ImportError:
HAS_NP = False
try:
import torch
except ImportError:
HAS_TORCH = False
try:
import lerobot
except ImportError:
HAS_LEROBOT = False
lerobot_version = lerobot.__version__ if HAS_LEROBOT else "N/A"
hf_hub_version = huggingface_hub.__version__ if HAS_HF_HUB else "N/A"
hf_datasets_version = datasets.__version__ if HAS_HF_DATASETS else "N/A"
np_version = np.__version__ if HAS_NP else "N/A"
torch_version = torch.__version__ if HAS_TORCH else "N/A"
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
# TODO(aliberts): refactor into an actual command `lerobot env`
def display_sys_info() -> dict:
"""Run this to get basic system info to help for tracking issues & bugs."""
info = {
"`lerobot` version": version,
"`lerobot` version": lerobot_version,
"Platform": platform.platform(),
"Python version": platform.python_version(),
"Huggingface_hub version": huggingface_hub.__version__,
# TODO(aliberts): Add dataset when https://github.com/huggingface/lerobot/pull/73 is merged
# "Dataset version": dataset.__version__,
"Numpy version": np.__version__,
"PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
"Huggingface_hub version": hf_hub_version,
"Dataset version": hf_datasets_version,
"Numpy version": np_version,
"PyTorch version (GPU?)": f"{torch_version} ({torch_cuda_available})",
"Cuda version": cuda_version,
"Using GPU in script?": "<fill in>",
"Using distributed or parallel set-up in script?": "<fill in>",
# "Using distributed or parallel set-up in script?": "<fill in>",
}
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n")
print(format_dict(info))
return info

View File

@ -106,6 +106,7 @@ def visualize_dataset(
ws_port: int = 9087,
save: bool = False,
output_dir: Path | None = None,
root: Path | None = None,
) -> Path | None:
if save:
assert (
@ -113,7 +114,7 @@ def visualize_dataset(
), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id)
dataset = LeRobotDataset(repo_id, root=root)
logging.info("Loading dataloader")
episode_sampler = EpisodeSampler(dataset, episode_index)
@ -256,6 +257,12 @@ def main():
help="Directory path to write a .rrd file when `--save 1` is set.",
)
parser.add_argument(
"--root",
type=str,
help="Root directory for a dataset stored on a local machine.",
)
args = parser.parse_args()
visualize_dataset(**vars(args))

View File

@ -13,6 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import pytest
from lerobot.scripts.visualize_dataset import visualize_dataset
@ -31,3 +33,20 @@ def test_visualize_dataset(tmpdir, repo_id):
output_dir=tmpdir,
)
assert rrd_path.exists()
@pytest.mark.parametrize(
"repo_id",
["lerobot/pusht"],
)
@pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
def test_visualize_local_dataset(tmpdir, repo_id, root):
rrd_path = visualize_dataset(
repo_id,
episode_index=0,
batch_size=32,
save=True,
output_dir=tmpdir,
root=root,
)
assert rrd_path.exists()