Fix test_visualize_dataset_html

parent a91b7c6163
commit 8546358bc5
lerobot/scripts/visualize_dataset_html.py

@@ -130,16 +130,16 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
     from_idx = dataset.episode_data_index["from"][episode_index]
     to_idx = dataset.episode_data_index["to"][episode_index]

-    has_state = "observation.state" in dataset.hf_dataset.features
-    has_action = "action" in dataset.hf_dataset.features
+    has_state = "observation.state" in dataset.features
+    has_action = "action" in dataset.features

     # init header of csv with state and action names
     header = ["timestamp"]
     if has_state:
-        dim_state = dataset.meta.shapes["observation.state"]
+        dim_state = dataset.meta.shapes["observation.state"][0]
         header += [f"state_{i}" for i in range(dim_state)]
     if has_action:
-        dim_action = dataset.meta.shapes["action"]
+        dim_action = dataset.meta.shapes["action"][0]
         header += [f"action_{i}" for i in range(dim_action)]

     columns = ["timestamp"]
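For context on the [0] indexing introduced above: the CSV header dimensions are now read from dataset.meta.shapes, whose values appear to be shape tuples rather than plain ints (hence the fix), so the scalar dimension is the first element. A minimal sketch of the header construction, using a hypothetical shapes mapping in place of a real dataset:

    # Stand-in for dataset.meta.shapes; values are assumed to be 1-element shape tuples.
    shapes = {"observation.state": (6,), "action": (6,)}

    header = ["timestamp"]
    dim_state = shapes["observation.state"][0]  # 6, not (6,)
    header += [f"state_{i}" for i in range(dim_state)]
    dim_action = shapes["action"][0]
    header += [f"action_{i}" for i in range(dim_action)]
    # header == ["timestamp", "state_0", ..., "state_5", "action_0", ..., "action_5"]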
@@ -175,23 +175,8 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
     ]


-def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
-    # check if the dataset has language instructions
-    if "language_instruction" not in dataset.hf_dataset.features:
-        return None
-
-    # get first frame index
-    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
-
-    language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
-    # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
-    # with the tf.tensor appearing in the string
-    return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
-
-
 def visualize_dataset_html(
-    repo_id: str,
-    root: Path | None = None,
+    dataset: LeRobotDataset,
     episodes: list[int] = None,
     output_dir: Path | None = None,
     serve: bool = True,
@@ -201,13 +186,11 @@ def visualize_dataset_html(
 ) -> Path | None:
     init_logging()

-    dataset = LeRobotDataset(repo_id, root=root)
-
     if len(dataset.meta.image_keys) > 0:
         raise NotImplementedError(f"Image keys ({dataset.meta.image_keys=}) are currently not supported.")

     if output_dir is None:
-        output_dir = f"outputs/visualize_dataset_html/{repo_id}"
+        output_dir = f"outputs/visualize_dataset_html/{dataset.repo_id}"

     output_dir = Path(output_dir)
     if output_dir.exists():
@@ -296,7 +279,11 @@ def main():
     )

     args = parser.parse_args()
-    visualize_dataset_html(**vars(args))
+    kwargs = vars(args)
+    repo_id = kwargs.pop("repo_id")
+    root = kwargs.pop("root")
+    dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
+    visualize_dataset_html(dataset, **kwargs)


 if __name__ == "__main__":
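With this change, visualize_dataset_html() no longer constructs the dataset itself: main() builds the LeRobotDataset (with local_files_only=True) and passes it in. A minimal usage sketch of the new signature, assuming the LeRobotDataset import path below and using placeholder repo id and paths:

    from pathlib import Path

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
    from lerobot.scripts.visualize_dataset_html import visualize_dataset_html

    # Placeholder repo id and local root; local_files_only mirrors the updated main().
    dataset = LeRobotDataset("lerobot/pusht", root=Path("data/pusht"), local_files_only=True)
    visualize_dataset_html(dataset, episodes=[0], output_dir=Path("outputs/pusht_html"), serve=False)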
tests/test_visualize_dataset_html.py

@@ -14,23 +14,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from pathlib import Path
-
-import pytest
-
 from lerobot.scripts.visualize_dataset_html import visualize_dataset_html


-@pytest.mark.parametrize(
-    "repo_id",
-    ["lerobot/pusht"],
-)
-def test_visualize_dataset_html(tmpdir, repo_id):
-    tmpdir = Path(tmpdir)
+def test_visualize_dataset_html(tmp_path, lerobot_dataset_factory):
+    root = tmp_path / "dataset"
+    output_dir = tmp_path / "outputs"
+    dataset = lerobot_dataset_factory(root=root)
     visualize_dataset_html(
-        repo_id,
+        dataset,
         episodes=[0],
-        output_dir=tmpdir,
+        output_dir=output_dir,
         serve=False,
     )
-    assert (tmpdir / "static" / "episode_0.csv").exists()
+    assert (output_dir / "static" / "episode_0.csv").exists()
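Assuming the updated test lives at tests/test_visualize_dataset_html.py and the lerobot_dataset_factory fixture is provided by the suite's conftest, it can be run locally with pytest:

    python -m pytest tests/test_visualize_dataset_html.py -k test_visualize_dataset_html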