Fix unit tests

Remi Cadene 2025-04-22 10:35:20 +02:00
parent 601b5fdbfe
commit 367d9bda7d
5 changed files with 95 additions and 75 deletions

View File

@@ -187,6 +187,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
     for chunk_idx, file_idx in data_chunk_file_ids:
         path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
         df = pd.read_parquet(path)
+        # TODO(rcadene): update frame index
         update_data_func = get_update_episode_and_task_func(num_episodes, meta.tasks, aggr_meta.tasks)
         df = df.apply(update_data_func, axis=1)
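
For context, get_update_episode_and_task_func builds a row-wise updater that is applied with df.apply(..., axis=1). A minimal sketch of such an updater, assuming a hypothetical episode offset and task-index remapping (not the actual implementation):

import pandas as pd

def make_update_row_func(episode_offset: int, task_remap: dict):
    # Hypothetical row updater: shift episode indices and remap task indices
    # so rows from one source dataset fit into the aggregated dataset.
    def update_row(row: pd.Series) -> pd.Series:
        row["episode_index"] = row["episode_index"] + episode_offset
        row["task_index"] = task_remap.get(row["task_index"], row["task_index"])
        return row

    return update_row

df = pd.DataFrame({"episode_index": [0, 0, 1], "task_index": [0, 1, 0]})
df = df.apply(make_update_row_func(episode_offset=5, task_remap={0: 3, 1: 4}), axis=1)
print(df["episode_index"].tolist())  # [5, 5, 6]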

View File

@@ -197,16 +197,15 @@ def convert_data(root, new_root):
 def get_video_keys(root):
     info = load_info(root)
     features = info["features"]
-    image_keys = [key for key, ft in features.items() if ft["dtype"] == "image"]
-    if len(image_keys) != 0:
-        raise NotImplementedError()
     video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
     return video_keys


 def convert_videos(root: Path, new_root: Path):
     video_keys = get_video_keys(root)
+    if len(video_keys) == 0:
+        return None
     video_keys = sorted(video_keys)

     eps_metadata_per_cam = []
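
With the image-keys guard removed, get_video_keys simply returns the video features and convert_videos exits early when there are none. A toy illustration of the dtype filtering, using a hypothetical info dict:

# Hypothetical info dict mimicking meta/info.json; only "video" features are kept.
info = {
    "features": {
        "observation.images.top": {"dtype": "video"},
        "observation.images.wrist": {"dtype": "image"},
        "observation.state": {"dtype": "float32"},
    }
}
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
print(video_keys)  # ['observation.images.top']
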
@@ -284,24 +283,32 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key):
 def generate_episode_metadata_dict(
-    episodes_legacy_metadata, episodes_metadata, episodes_videos, episodes_stats
+    episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None
 ):
-    for ep_legacy_metadata, ep_metadata, ep_video, ep_stats, ep_idx_stats in zip(
-        episodes_legacy_metadata.values(),
-        episodes_metadata,
-        episodes_videos,
-        episodes_stats.values(),
-        episodes_stats.keys(),
-        strict=False,
-    ):
-        ep_idx = ep_legacy_metadata["episode_index"]
-        ep_idx_data = ep_metadata["episode_index"]
-        ep_idx_video = ep_video["episode_index"]
-        if len({ep_idx, ep_idx_data, ep_idx_video, ep_idx_stats}) != 1:
-            raise ValueError(
-                f"Number of episodes is not the same ({ep_idx=},{ep_idx_data=},{ep_idx_video=},{ep_idx_stats=})."
-            )
+    num_episodes = len(episodes_metadata)
+    episodes_legacy_metadata_vals = list(episodes_legacy_metadata.values())
+    episodes_stats_vals = list(episodes_stats.values())
+    episodes_stats_keys = list(episodes_stats.keys())
+
+    for i in range(num_episodes):
+        ep_legacy_metadata = episodes_legacy_metadata_vals[i]
+        ep_metadata = episodes_metadata[i]
+        ep_stats = episodes_stats_vals[i]
+        ep_ids_set = {
+            ep_legacy_metadata["episode_index"],
+            ep_metadata["episode_index"],
+            episodes_stats_keys[i],
+        }
+        if episodes_videos is None:
+            ep_video = {}
+        else:
+            ep_video = episodes_videos[i]
+            ep_ids_set.add(ep_video["episode_index"])
+
+        if len(ep_ids_set) != 1:
+            raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).")

         ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
         ep_dict["meta/episodes/chunk_index"] = 0
@@ -309,21 +316,20 @@ def generate_episode_metadata_dict(
         yield ep_dict


-def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata):
+def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None):
     episodes_legacy_metadata = legacy_load_episodes(root)
     episodes_stats = legacy_load_episodes_stats(root)
-    num_eps = len(episodes_legacy_metadata)
-    num_eps_metadata = len(episodes_metadata)
-    num_eps_video_metadata = len(episodes_video_metadata)
-    if len({num_eps, num_eps_metadata, num_eps_video_metadata}) != 1:
-        raise ValueError(
-            f"Number of episodes is not the same ({num_eps=},{num_eps_metadata=},{num_eps_video_metadata=})."
-        )
+    num_eps_set = {len(episodes_legacy_metadata), len(episodes_metadata)}
+    if episodes_video_metadata is not None:
+        num_eps_set.add(len(episodes_video_metadata))
+
+    if len(num_eps_set) != 1:
+        raise ValueError(f"Number of episodes is not the same ({num_eps_set}).")

     ds_episodes = Dataset.from_generator(
         lambda: generate_episode_metadata_dict(
-            episodes_legacy_metadata, episodes_metadata, episodes_video_metadata, episodes_stats
+            episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata
         )
     )
     write_episodes(ds_episodes, new_root)
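
The refactor above replaces the zip-based alignment with an explicit index loop and a set-based episode-index consistency check, so datasets without video metadata can be converted too. A standalone sketch of the same pattern with toy data and a local stand-in for flatten_dict (not the library code):

def flatten_stats(stats: dict, prefix: str = "stats") -> dict:
    # Stand-in for lerobot's flatten_dict, used only for this sketch.
    return {f"{prefix}/{k}": v for k, v in stats.items()}

def build_episode_rows(legacy_meta, episodes_meta, episodes_stats, episodes_videos=None):
    legacy_vals = list(legacy_meta.values())
    stats_vals = list(episodes_stats.values())
    stats_keys = list(episodes_stats.keys())
    for i in range(len(episodes_meta)):
        # Collect the episode index from every source and require that they agree.
        ids = {legacy_vals[i]["episode_index"], episodes_meta[i]["episode_index"], stats_keys[i]}
        ep_video = {} if episodes_videos is None else episodes_videos[i]
        if ep_video:
            ids.add(ep_video["episode_index"])
        if len(ids) != 1:
            raise ValueError(f"Episode indices disagree: {ids}")
        yield {**episodes_meta[i], **ep_video, **legacy_vals[i], **flatten_stats(stats_vals[i])}

# Toy usage: one episode, no video metadata (e.g. an image-only dataset).
rows = list(build_episode_rows(
    legacy_meta={0: {"episode_index": 0, "length": 10}},
    episodes_meta=[{"episode_index": 0, "data/chunk_index": 0}],
    episodes_stats={0: {"min": 0.0, "max": 1.0}},
))
print(rows[0]["stats/min"])  # 0.0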

View File

@@ -13,10 +13,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 import logging
 import re
-from copy import deepcopy
 from itertools import chain
 from pathlib import Path
@@ -36,8 +34,6 @@ from lerobot.common.datasets.lerobot_dataset import (
 )
 from lerobot.common.datasets.utils import (
     create_branch,
-    flatten_dict,
-    unflatten_dict,
 )
 from lerobot.common.envs.factory import make_env_config
 from lerobot.common.policies.factory import make_policy_config
@@ -100,6 +96,25 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
     assert dataset.num_frames == len(dataset)


+# TODO(rcadene, aliberts): do not run LeRobotDataset.create, instead refactor LeRobotDatasetMetadata.create
+# and test the small resulting function that validates the features
+def test_dataset_feature_with_forward_slash_raises_error():
+    # make sure dir does not exist
+    from lerobot.common.constants import HF_LEROBOT_HOME
+
+    dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
+    # make sure does not exist
+    if dataset_dir.exists():
+        dataset_dir.rmdir()
+
+    with pytest.raises(ValueError):
+        LeRobotDataset.create(
+            repo_id="lerobot/test/with/slash",
+            fps=30,
+            features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
+        )
+
+
 def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
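
The new test only checks that LeRobotDataset.create rejects feature names containing a forward slash. A hypothetical sketch of the kind of validation it exercises (the real check lives inside the dataset/metadata creation code):

def validate_feature_names(features: dict) -> None:
    # Hypothetical helper: feature names become file/column paths, so "/" is rejected.
    for name in features:
        if "/" in name:
            raise ValueError(f"Feature name '{name}' must not contain '/'.")

validate_feature_names({"state": {"dtype": "float32"}})  # ok
# validate_feature_names({"a/b": {"dtype": "float32"}})  # raises ValueError
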
@@ -329,6 +344,13 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
 # - [ ] test push_to_hub
 # - [ ] test smaller methods

+# TODO(rcadene):
+# - [ ] fix code so that old test_factory + backward pass
+# - [ ] write new unit tests to test save_episode + getitem
+# - [ ] save_episode : case where new dataset, concatenate same file, write new file (meta/episodes, data, videos)
+# - [ ]
+# - [ ] remove old tests
+

 @pytest.mark.parametrize(
     "env_name, repo_id, policy_name",
@@ -436,30 +458,6 @@ def test_multidataset_frames():
         assert torch.equal(sub_dataset_item[k], dataset_item[k])


-# TODO(aliberts): Move to more appropriate location
-def test_flatten_unflatten_dict():
-    d = {
-        "obs": {
-            "min": 0,
-            "max": 1,
-            "mean": 2,
-            "std": 3,
-        },
-        "action": {
-            "min": 4,
-            "max": 5,
-            "mean": 6,
-            "std": 7,
-        },
-    }
-    original_d = deepcopy(d)
-    d = unflatten_dict(flatten_dict(d))
-
-    # test equality between nested dicts
-    assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
-
-
 @pytest.mark.parametrize(
     "repo_id",
     [
@@ -569,20 +567,3 @@ def test_create_branch():
     # Clean
     api.delete_repo(repo_id, repo_type=repo_type)
-
-
-def test_dataset_feature_with_forward_slash_raises_error():
-    # make sure dir does not exist
-    from lerobot.common.constants import HF_LEROBOT_HOME
-
-    dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
-    # make sure does not exist
-    if dataset_dir.exists():
-        dataset_dir.rmdir()
-
-    with pytest.raises(ValueError):
-        LeRobotDataset.create(
-            repo_id="lerobot/test/with/slash",
-            fps=30,
-            features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
-        )

View File

@@ -14,12 +14,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import json
+from copy import deepcopy
+
 import torch
 from datasets import Dataset
 from huggingface_hub import DatasetCard

 from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
-from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
+from lerobot.common.datasets.utils import (
+    create_lerobot_dataset_card,
+    flatten_dict,
+    hf_transform_to_torch,
+    unflatten_dict,
+)


 def test_default_parameters():
@@ -53,3 +61,26 @@ def test_calculate_episode_data_index():
     episode_data_index = calculate_episode_data_index(dataset)
     assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
     assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
+
+
+def test_flatten_unflatten_dict():
+    d = {
+        "obs": {
+            "min": 0,
+            "max": 1,
+            "mean": 2,
+            "std": 3,
+        },
+        "action": {
+            "min": 4,
+            "max": 5,
+            "mean": 6,
+            "std": 7,
+        },
+    }
+    original_d = deepcopy(d)
+    d = unflatten_dict(flatten_dict(d))
+
+    # test equality between nested dicts
+    assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
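
The relocated test round-trips a nested dict through flatten_dict and unflatten_dict. A minimal sketch of what that round trip does, assuming the default "/" separator; the actual lerobot.common.datasets.utils implementations may differ in details:

def flatten_dict(d: dict, parent: str = "", sep: str = "/") -> dict:
    # Recursively turn {"obs": {"min": 0}} into {"obs/min": 0}.
    out = {}
    for k, v in d.items():
        key = f"{parent}{sep}{k}" if parent else k
        if isinstance(v, dict):
            out.update(flatten_dict(v, key, sep))
        else:
            out[key] = v
    return out

def unflatten_dict(d: dict, sep: str = "/") -> dict:
    # Rebuild the nested structure from the "/"-separated keys.
    out: dict = {}
    for key, v in d.items():
        node = out
        *parents, leaf = key.split(sep)
        for p in parents:
            node = node.setdefault(p, {})
        node[leaf] = v
    return out

assert unflatten_dict(flatten_dict({"obs": {"min": 0, "max": 1}})) == {"obs": {"min": 0, "max": 1}}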

View File

@@ -141,6 +141,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
     Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
     and for now we add tests as we see fit.
     """
+    policy_kwargs["device"] = DEVICE
     train_cfg = TrainPipelineConfig(
         # TODO(rcadene, aliberts): remove dataset download