Fix unit tests
This commit is contained in:
parent
601b5fdbfe
commit
367d9bda7d
|
@ -187,6 +187,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
|
|||
for chunk_idx, file_idx in data_chunk_file_ids:
|
||||
path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
df = pd.read_parquet(path)
|
||||
# TODO(rcadene): update frame index
|
||||
update_data_func = get_update_episode_and_task_func(num_episodes, meta.tasks, aggr_meta.tasks)
|
||||
df = df.apply(update_data_func, axis=1)
|
||||
|
||||
|
|
|
@ -197,16 +197,15 @@ def convert_data(root, new_root):
|
|||
def get_video_keys(root):
|
||||
info = load_info(root)
|
||||
features = info["features"]
|
||||
image_keys = [key for key, ft in features.items() if ft["dtype"] == "image"]
|
||||
if len(image_keys) != 0:
|
||||
raise NotImplementedError()
|
||||
|
||||
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
|
||||
return video_keys
|
||||
|
||||
|
||||
def convert_videos(root: Path, new_root: Path):
|
||||
video_keys = get_video_keys(root)
|
||||
if len(video_keys) == 0:
|
||||
return None
|
||||
|
||||
video_keys = sorted(video_keys)
|
||||
|
||||
eps_metadata_per_cam = []
|
||||
|
@ -284,24 +283,32 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key):
|
|||
|
||||
|
||||
def generate_episode_metadata_dict(
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_videos, episodes_stats
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None
|
||||
):
|
||||
for ep_legacy_metadata, ep_metadata, ep_video, ep_stats, ep_idx_stats in zip(
|
||||
episodes_legacy_metadata.values(),
|
||||
episodes_metadata,
|
||||
episodes_videos,
|
||||
episodes_stats.values(),
|
||||
episodes_stats.keys(),
|
||||
strict=False,
|
||||
):
|
||||
ep_idx = ep_legacy_metadata["episode_index"]
|
||||
ep_idx_data = ep_metadata["episode_index"]
|
||||
ep_idx_video = ep_video["episode_index"]
|
||||
num_episodes = len(episodes_metadata)
|
||||
episodes_legacy_metadata_vals = list(episodes_legacy_metadata.values())
|
||||
episodes_stats_vals = list(episodes_stats.values())
|
||||
episodes_stats_keys = list(episodes_stats.keys())
|
||||
|
||||
if len({ep_idx, ep_idx_data, ep_idx_video, ep_idx_stats}) != 1:
|
||||
raise ValueError(
|
||||
f"Number of episodes is not the same ({ep_idx=},{ep_idx_data=},{ep_idx_video=},{ep_idx_stats=})."
|
||||
)
|
||||
for i in range(num_episodes):
|
||||
ep_legacy_metadata = episodes_legacy_metadata_vals[i]
|
||||
ep_metadata = episodes_metadata[i]
|
||||
ep_stats = episodes_stats_vals[i]
|
||||
|
||||
ep_ids_set = {
|
||||
ep_legacy_metadata["episode_index"],
|
||||
ep_metadata["episode_index"],
|
||||
episodes_stats_keys[i],
|
||||
}
|
||||
|
||||
if episodes_videos is None:
|
||||
ep_video = {}
|
||||
else:
|
||||
ep_video = episodes_videos[i]
|
||||
ep_ids_set.add(ep_video["episode_index"])
|
||||
|
||||
if len(ep_ids_set) != 1:
|
||||
raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).")
|
||||
|
||||
ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
|
||||
ep_dict["meta/episodes/chunk_index"] = 0
|
||||
|
@ -309,21 +316,20 @@ def generate_episode_metadata_dict(
|
|||
yield ep_dict
|
||||
|
||||
|
||||
def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata):
|
||||
def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None):
|
||||
episodes_legacy_metadata = legacy_load_episodes(root)
|
||||
episodes_stats = legacy_load_episodes_stats(root)
|
||||
|
||||
num_eps = len(episodes_legacy_metadata)
|
||||
num_eps_metadata = len(episodes_metadata)
|
||||
num_eps_video_metadata = len(episodes_video_metadata)
|
||||
if len({num_eps, num_eps_metadata, num_eps_video_metadata}) != 1:
|
||||
raise ValueError(
|
||||
f"Number of episodes is not the same ({num_eps=},{num_eps_metadata=},{num_eps_video_metadata=})."
|
||||
)
|
||||
num_eps_set = {len(episodes_legacy_metadata), len(episodes_metadata)}
|
||||
if episodes_video_metadata is not None:
|
||||
num_eps_set.add(len(episodes_video_metadata))
|
||||
|
||||
if len(num_eps_set) != 1:
|
||||
raise ValueError(f"Number of episodes is not the same ({num_eps_set}).")
|
||||
|
||||
ds_episodes = Dataset.from_generator(
|
||||
lambda: generate_episode_metadata_dict(
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_video_metadata, episodes_stats
|
||||
episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata
|
||||
)
|
||||
)
|
||||
write_episodes(ds_episodes, new_root)
|
||||
|
|
|
@ -13,10 +13,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -36,8 +34,6 @@ from lerobot.common.datasets.lerobot_dataset import (
|
|||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
create_branch,
|
||||
flatten_dict,
|
||||
unflatten_dict,
|
||||
)
|
||||
from lerobot.common.envs.factory import make_env_config
|
||||
from lerobot.common.policies.factory import make_policy_config
|
||||
|
@ -100,6 +96,25 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
|
|||
assert dataset.num_frames == len(dataset)
|
||||
|
||||
|
||||
# TODO(rcadene, aliberts): do not run LeRobotDataset.create, instead refactor LeRobotDatasetMetadata.create
|
||||
# and test the small resulting function that validates the features
|
||||
def test_dataset_feature_with_forward_slash_raises_error():
|
||||
# make sure dir does not exist
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
|
||||
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
|
||||
# make sure does not exist
|
||||
if dataset_dir.exists():
|
||||
dataset_dir.rmdir()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LeRobotDataset.create(
|
||||
repo_id="lerobot/test/with/slash",
|
||||
fps=30,
|
||||
features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
|
||||
)
|
||||
|
||||
|
||||
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
|
@ -329,6 +344,13 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
|
|||
# - [ ] test push_to_hub
|
||||
# - [ ] test smaller methods
|
||||
|
||||
# TODO(rcadene):
|
||||
# - [ ] fix code so that old test_factory + backward pass
|
||||
# - [ ] write new unit tests to test save_episode + getitem
|
||||
# - [ ] save_episode : case where new dataset, concatenate same file, write new file (meta/episodes, data, videos)
|
||||
# - [ ]
|
||||
# - [ ] remove old tests
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, repo_id, policy_name",
|
||||
|
@ -436,30 +458,6 @@ def test_multidataset_frames():
|
|||
assert torch.equal(sub_dataset_item[k], dataset_item[k])
|
||||
|
||||
|
||||
# TODO(aliberts): Move to more appropriate location
|
||||
def test_flatten_unflatten_dict():
|
||||
d = {
|
||||
"obs": {
|
||||
"min": 0,
|
||||
"max": 1,
|
||||
"mean": 2,
|
||||
"std": 3,
|
||||
},
|
||||
"action": {
|
||||
"min": 4,
|
||||
"max": 5,
|
||||
"mean": 6,
|
||||
"std": 7,
|
||||
},
|
||||
}
|
||||
|
||||
original_d = deepcopy(d)
|
||||
d = unflatten_dict(flatten_dict(d))
|
||||
|
||||
# test equality between nested dicts
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id",
|
||||
[
|
||||
|
@ -569,20 +567,3 @@ def test_create_branch():
|
|||
|
||||
# Clean
|
||||
api.delete_repo(repo_id, repo_type=repo_type)
|
||||
|
||||
|
||||
def test_dataset_feature_with_forward_slash_raises_error():
|
||||
# make sure dir does not exist
|
||||
from lerobot.common.constants import HF_LEROBOT_HOME
|
||||
|
||||
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
|
||||
# make sure does not exist
|
||||
if dataset_dir.exists():
|
||||
dataset_dir.rmdir()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LeRobotDataset.create(
|
||||
repo_id="lerobot/test/with/slash",
|
||||
fps=30,
|
||||
features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
|
||||
)
|
||||
|
|
|
@ -14,12 +14,20 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
|
||||
from lerobot.common.datasets.utils import (
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
hf_transform_to_torch,
|
||||
unflatten_dict,
|
||||
)
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
|
@ -53,3 +61,26 @@ def test_calculate_episode_data_index():
|
|||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
|
||||
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
|
||||
|
||||
|
||||
def test_flatten_unflatten_dict():
|
||||
d = {
|
||||
"obs": {
|
||||
"min": 0,
|
||||
"max": 1,
|
||||
"mean": 2,
|
||||
"std": 3,
|
||||
},
|
||||
"action": {
|
||||
"min": 4,
|
||||
"max": 5,
|
||||
"mean": 6,
|
||||
"std": 7,
|
||||
},
|
||||
}
|
||||
|
||||
original_d = deepcopy(d)
|
||||
d = unflatten_dict(flatten_dict(d))
|
||||
|
||||
# test equality between nested dicts
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
||||
|
|
|
@ -141,6 +141,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
|||
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
|
||||
and for now we add tests as we see fit.
|
||||
"""
|
||||
policy_kwargs["device"] = DEVICE
|
||||
|
||||
train_cfg = TrainPipelineConfig(
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
|
|
Loading…
Reference in New Issue