diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 5db061ec..48f81abd 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -97,11 +97,9 @@ def test_compute_stats_on_xarm(): # TODO(rcadene): Reduce size of dataset sample on which stats compute is tested from lerobot.common.datasets.xarm import XarmDataset - data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None - dataset = XarmDataset( dataset_id="xarm_lift_medium", - root=data_dir, + root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None, ) # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched @@ -246,9 +244,11 @@ def test_backward_compatibility(): # TODO(rcadene): make it work for all datasets with LeRobotDataset(repo_id) dataset_id = "pusht" data_dir = Path("tests/data/save_dataset_to_safetensors") / dataset_id + dataset = PushtDataset( dataset_id=dataset_id, split="train", + root=Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None, ) def load_and_compare(i):