diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index b5f40d11..2ed76898 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -106,6 +106,7 @@ def visualize_dataset( ws_port: int = 9087, save: bool = False, output_dir: Path | None = None, + root: Path | None = None, ) -> Path | None: if save: assert ( @@ -113,7 +114,7 @@ def visualize_dataset( ), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`." logging.info("Loading dataset") - dataset = LeRobotDataset(repo_id) + dataset = LeRobotDataset(repo_id, root=root) logging.info("Loading dataloader") episode_sampler = EpisodeSampler(dataset, episode_index) @@ -256,6 +257,12 @@ def main(): help="Directory path to write a .rrd file when `--save 1` is set.", ) + parser.add_argument( + "--root", + type=str, + help="Root directory for a dataset stored on a local machine.", + ) + args = parser.parse_args() visualize_dataset(**vars(args)) diff --git a/tests/test_visualize_dataset.py b/tests/test_visualize_dataset.py index 99954040..029c59ed 100644 --- a/tests/test_visualize_dataset.py +++ b/tests/test_visualize_dataset.py @@ -13,6 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path + import pytest from lerobot.scripts.visualize_dataset import visualize_dataset @@ -31,3 +33,20 @@ def test_visualize_dataset(tmpdir, repo_id): output_dir=tmpdir, ) assert rrd_path.exists() + + +@pytest.mark.parametrize( + "repo_id", + ["lerobot/pusht"], +) +@pytest.mark.parametrize("root", [Path(__file__).parent / "data"]) +def test_visualize_local_dataset(tmpdir, repo_id, root): + rrd_path = visualize_dataset( + repo_id, + episode_index=0, + batch_size=32, + save=True, + output_dir=tmpdir, + root=root, + ) + assert rrd_path.exists()