Fix test_visualize_dataset_html
This commit is contained in:
parent
a91b7c6163
commit
8546358bc5
|
@ -130,16 +130,16 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
|||
from_idx = dataset.episode_data_index["from"][episode_index]
|
||||
to_idx = dataset.episode_data_index["to"][episode_index]
|
||||
|
||||
has_state = "observation.state" in dataset.hf_dataset.features
|
||||
has_action = "action" in dataset.hf_dataset.features
|
||||
has_state = "observation.state" in dataset.features
|
||||
has_action = "action" in dataset.features
|
||||
|
||||
# init header of csv with state and action names
|
||||
header = ["timestamp"]
|
||||
if has_state:
|
||||
dim_state = dataset.meta.shapes["observation.state"]
|
||||
dim_state = dataset.meta.shapes["observation.state"][0]
|
||||
header += [f"state_{i}" for i in range(dim_state)]
|
||||
if has_action:
|
||||
dim_action = dataset.meta.shapes["action"]
|
||||
dim_action = dataset.meta.shapes["action"][0]
|
||||
header += [f"action_{i}" for i in range(dim_action)]
|
||||
|
||||
columns = ["timestamp"]
|
||||
|
@ -175,23 +175,8 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
|
|||
]
|
||||
|
||||
|
||||
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
|
||||
# check if the dataset has language instructions
|
||||
if "language_instruction" not in dataset.hf_dataset.features:
|
||||
return None
|
||||
|
||||
# get first frame index
|
||||
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
|
||||
|
||||
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
|
||||
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
|
||||
# with the tf.tensor appearing in the string
|
||||
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
|
||||
|
||||
|
||||
def visualize_dataset_html(
|
||||
repo_id: str,
|
||||
root: Path | None = None,
|
||||
dataset: LeRobotDataset,
|
||||
episodes: list[int] = None,
|
||||
output_dir: Path | None = None,
|
||||
serve: bool = True,
|
||||
|
@ -201,13 +186,11 @@ def visualize_dataset_html(
|
|||
) -> Path | None:
|
||||
init_logging()
|
||||
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
|
||||
if len(dataset.meta.image_keys) > 0:
|
||||
raise NotImplementedError(f"Image keys ({dataset.meta.image_keys=}) are currently not supported.")
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
|
||||
output_dir = f"outputs/visualize_dataset_html/{dataset.repo_id}"
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
if output_dir.exists():
|
||||
|
@ -296,7 +279,11 @@ def main():
|
|||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
visualize_dataset_html(**vars(args))
|
||||
kwargs = vars(args)
|
||||
repo_id = kwargs.pop("repo_id")
|
||||
root = kwargs.pop("root")
|
||||
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
|
||||
visualize_dataset_html(dataset, **kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -14,23 +14,17 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id",
|
||||
["lerobot/pusht"],
|
||||
)
|
||||
def test_visualize_dataset_html(tmpdir, repo_id):
|
||||
tmpdir = Path(tmpdir)
|
||||
def test_visualize_dataset_html(tmp_path, lerobot_dataset_factory):
|
||||
root = tmp_path / "dataset"
|
||||
output_dir = tmp_path / "outputs"
|
||||
dataset = lerobot_dataset_factory(root=root)
|
||||
visualize_dataset_html(
|
||||
repo_id,
|
||||
dataset,
|
||||
episodes=[0],
|
||||
output_dir=tmpdir,
|
||||
output_dir=output_dir,
|
||||
serve=False,
|
||||
)
|
||||
assert (tmpdir / "static" / "episode_0.csv").exists()
|
||||
assert (output_dir / "static" / "episode_0.csv").exists()
|
||||
|
|
Loading…
Reference in New Issue