Fix unit tests

parent 601b5fdbfe
commit 367d9bda7d
@@ -187,6 +187,7 @@ def aggregate_datasets(repo_ids: list[str], aggr_repo_id: str, roots: list[Path]
     for chunk_idx, file_idx in data_chunk_file_ids:
         path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
         df = pd.read_parquet(path)
+        # TODO(rcadene): update frame index
         update_data_func = get_update_episode_and_task_func(num_episodes, meta.tasks, aggr_meta.tasks)
         df = df.apply(update_data_func, axis=1)

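Note: a minimal, self-contained sketch of the pandas pattern used in this hunk, where a per-row update function produced by a factory is applied with df.apply(..., axis=1). The factory below is a hypothetical stand-in for get_update_episode_and_task_func, not the repository implementation.

    import pandas as pd

    def make_update_func(episode_offset: int):
        # hypothetical stand-in: shift every row's episode_index by a fixed offset
        def update(row: pd.Series) -> pd.Series:
            row["episode_index"] = row["episode_index"] + episode_offset
            return row
        return update

    df = pd.DataFrame({"episode_index": [0, 0, 1], "frame_index": [0, 1, 0]})
    df = df.apply(make_update_func(episode_offset=10), axis=1)
    print(df["episode_index"].tolist())  # [10, 10, 11]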
@@ -197,16 +197,15 @@ def convert_data(root, new_root):
 def get_video_keys(root):
     info = load_info(root)
     features = info["features"]
-    image_keys = [key for key, ft in features.items() if ft["dtype"] == "image"]
-    if len(image_keys) != 0:
-        raise NotImplementedError()
-
     video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
     return video_keys


 def convert_videos(root: Path, new_root: Path):
     video_keys = get_video_keys(root)
+    if len(video_keys) == 0:
+        return None
+
     video_keys = sorted(video_keys)

     eps_metadata_per_cam = []
@@ -284,24 +283,32 @@ def convert_videos_of_camera(root: Path, new_root: Path, video_key):


 def generate_episode_metadata_dict(
-    episodes_legacy_metadata, episodes_metadata, episodes_videos, episodes_stats
+    episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_videos=None
 ):
-    for ep_legacy_metadata, ep_metadata, ep_video, ep_stats, ep_idx_stats in zip(
-        episodes_legacy_metadata.values(),
-        episodes_metadata,
-        episodes_videos,
-        episodes_stats.values(),
-        episodes_stats.keys(),
-        strict=False,
-    ):
-        ep_idx = ep_legacy_metadata["episode_index"]
-        ep_idx_data = ep_metadata["episode_index"]
-        ep_idx_video = ep_video["episode_index"]
+    num_episodes = len(episodes_metadata)
+    episodes_legacy_metadata_vals = list(episodes_legacy_metadata.values())
+    episodes_stats_vals = list(episodes_stats.values())
+    episodes_stats_keys = list(episodes_stats.keys())

-        if len({ep_idx, ep_idx_data, ep_idx_video, ep_idx_stats}) != 1:
-            raise ValueError(
-                f"Number of episodes is not the same ({ep_idx=},{ep_idx_data=},{ep_idx_video=},{ep_idx_stats=})."
-            )
+    for i in range(num_episodes):
+        ep_legacy_metadata = episodes_legacy_metadata_vals[i]
+        ep_metadata = episodes_metadata[i]
+        ep_stats = episodes_stats_vals[i]

+        ep_ids_set = {
+            ep_legacy_metadata["episode_index"],
+            ep_metadata["episode_index"],
+            episodes_stats_keys[i],
+        }
+
+        if episodes_videos is None:
+            ep_video = {}
+        else:
+            ep_video = episodes_videos[i]
+            ep_ids_set.add(ep_video["episode_index"])
+
+        if len(ep_ids_set) != 1:
+            raise ValueError(f"Number of episodes is not the same ({ep_ids_set}).")
+
         ep_dict = {**ep_metadata, **ep_video, **ep_legacy_metadata, **flatten_dict({"stats": ep_stats})}
         ep_dict["meta/episodes/chunk_index"] = 0
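Note: a minimal illustration (not repository code) of why the new code iterates by index instead of using zip(..., strict=False). With strict=False, zip silently stops at the shortest input, so a length mismatch between the metadata sources would go unnoticed; the explicit episode_index set check above turns such mismatches into a ValueError instead.

    legacy = {0: "ep0", 1: "ep1", 2: "ep2"}
    metadata = [{"episode_index": 0}, {"episode_index": 1}]  # one episode missing

    pairs = list(zip(legacy.values(), metadata, strict=False))
    print(len(pairs))  # 2 -- the third episode is dropped without any error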
@@ -309,21 +316,20 @@ def generate_episode_metadata_dict(
         yield ep_dict


-def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata):
+def convert_episodes_metadata(root, new_root, episodes_metadata, episodes_video_metadata=None):
     episodes_legacy_metadata = legacy_load_episodes(root)
     episodes_stats = legacy_load_episodes_stats(root)

-    num_eps = len(episodes_legacy_metadata)
-    num_eps_metadata = len(episodes_metadata)
-    num_eps_video_metadata = len(episodes_video_metadata)
-    if len({num_eps, num_eps_metadata, num_eps_video_metadata}) != 1:
-        raise ValueError(
-            f"Number of episodes is not the same ({num_eps=},{num_eps_metadata=},{num_eps_video_metadata=})."
-        )
+    num_eps_set = {len(episodes_legacy_metadata), len(episodes_metadata)}
+    if episodes_video_metadata is not None:
+        num_eps_set.add(len(episodes_video_metadata))
+
+    if len(num_eps_set) != 1:
+        raise ValueError(f"Number of episodes is not the same ({num_eps_set}).")

     ds_episodes = Dataset.from_generator(
         lambda: generate_episode_metadata_dict(
-            episodes_legacy_metadata, episodes_metadata, episodes_video_metadata, episodes_stats
+            episodes_legacy_metadata, episodes_metadata, episodes_stats, episodes_video_metadata
         )
     )
     write_episodes(ds_episodes, new_root)
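Note: a small sketch, independent of this repository, of the datasets API used above. Dataset.from_generator consumes a callable that yields flat dicts, which is why generate_episode_metadata_dict yields one flattened ep_dict per episode.

    from datasets import Dataset

    def gen():
        for i in range(3):
            yield {"episode_index": i, "stats/min": float(i)}

    ds = Dataset.from_generator(gen)
    print(ds.column_names)  # ['episode_index', 'stats/min']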
@@ -13,10 +13,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 import logging
 import re
-from copy import deepcopy
 from itertools import chain
 from pathlib import Path

@@ -36,8 +34,6 @@ from lerobot.common.datasets.lerobot_dataset import (
 )
 from lerobot.common.datasets.utils import (
     create_branch,
-    flatten_dict,
-    unflatten_dict,
 )
 from lerobot.common.envs.factory import make_env_config
 from lerobot.common.policies.factory import make_policy_config
@@ -100,6 +96,25 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
     assert dataset.num_frames == len(dataset)


+# TODO(rcadene, aliberts): do not run LeRobotDataset.create, instead refactor LeRobotDatasetMetadata.create
+# and test the small resulting function that validates the features
+def test_dataset_feature_with_forward_slash_raises_error():
+    # make sure dir does not exist
+    from lerobot.common.constants import HF_LEROBOT_HOME
+
+    dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
+    # make sure does not exist
+    if dataset_dir.exists():
+        dataset_dir.rmdir()
+
+    with pytest.raises(ValueError):
+        LeRobotDataset.create(
+            repo_id="lerobot/test/with/slash",
+            fps=30,
+            features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
+        )
+
+
 def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
     features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
     dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
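Note: a hedged sketch of the kind of validation the new test exercises; the helper below is hypothetical, not the library's actual implementation. Feature keys end up being used as path components and as flatten_dict separators, so a "/" inside a key has to be rejected.

    def validate_feature_keys(features: dict) -> None:
        # hypothetical check: reject any feature key containing a path separator
        for key in features:
            if "/" in key:
                raise ValueError(f"Feature key '{key}' must not contain '/'.")

    validate_feature_keys({"state": {"dtype": "float32", "shape": (1,)}})  # ok
    try:
        validate_feature_keys({"a/b": {"dtype": "float32", "shape": (2,)}})
    except ValueError as e:
        print(e)  # Feature key 'a/b' must not contain '/'.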
@@ -329,6 +344,13 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
 # - [ ] test push_to_hub
 # - [ ] test smaller methods

+# TODO(rcadene):
+# - [ ] fix code so that old test_factory + backward pass
+# - [ ] write new unit tests to test save_episode + getitem
+# - [ ] save_episode : case where new dataset, concatenate same file, write new file (meta/episodes, data, videos)
+# - [ ]
+# - [ ] remove old tests
+

 @pytest.mark.parametrize(
     "env_name, repo_id, policy_name",
@@ -436,30 +458,6 @@ def test_multidataset_frames():
         assert torch.equal(sub_dataset_item[k], dataset_item[k])


-# TODO(aliberts): Move to more appropriate location
-def test_flatten_unflatten_dict():
-    d = {
-        "obs": {
-            "min": 0,
-            "max": 1,
-            "mean": 2,
-            "std": 3,
-        },
-        "action": {
-            "min": 4,
-            "max": 5,
-            "mean": 6,
-            "std": 7,
-        },
-    }
-
-    original_d = deepcopy(d)
-    d = unflatten_dict(flatten_dict(d))
-
-    # test equality between nested dicts
-    assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
-
-
 @pytest.mark.parametrize(
     "repo_id",
     [
@@ -569,20 +567,3 @@ def test_create_branch():

     # Clean
     api.delete_repo(repo_id, repo_type=repo_type)
-
-
-def test_dataset_feature_with_forward_slash_raises_error():
-    # make sure dir does not exist
-    from lerobot.common.constants import HF_LEROBOT_HOME
-
-    dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
-    # make sure does not exist
-    if dataset_dir.exists():
-        dataset_dir.rmdir()
-
-    with pytest.raises(ValueError):
-        LeRobotDataset.create(
-            repo_id="lerobot/test/with/slash",
-            fps=30,
-            features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
-        )
@@ -14,12 +14,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
+from copy import deepcopy
+
 import torch
 from datasets import Dataset
 from huggingface_hub import DatasetCard

 from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
-from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
+from lerobot.common.datasets.utils import (
+    create_lerobot_dataset_card,
+    flatten_dict,
+    hf_transform_to_torch,
+    unflatten_dict,
+)


 def test_default_parameters():
@@ -53,3 +61,26 @@ def test_calculate_episode_data_index():
     episode_data_index = calculate_episode_data_index(dataset)
     assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
     assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
+
+
+def test_flatten_unflatten_dict():
+    d = {
+        "obs": {
+            "min": 0,
+            "max": 1,
+            "mean": 2,
+            "std": 3,
+        },
+        "action": {
+            "min": 4,
+            "max": 5,
+            "mean": 6,
+            "std": 7,
+        },
+    }
+
+    original_d = deepcopy(d)
+    d = unflatten_dict(flatten_dict(d))
+
+    # test equality between nested dicts
+    assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
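Note: a minimal sketch of the round-trip property this relocated test checks, assuming flatten_dict joins nested keys with "/" (consistent with keys such as "meta/episodes/chunk_index" elsewhere in this commit). These are illustrative re-implementations, not the library code.

    def flatten_dict(d: dict, parent: str = "", sep: str = "/") -> dict:
        # join nested keys into single "/"-separated keys
        out = {}
        for k, v in d.items():
            key = f"{parent}{sep}{k}" if parent else k
            if isinstance(v, dict):
                out.update(flatten_dict(v, key, sep))
            else:
                out[key] = v
        return out

    def unflatten_dict(d: dict, sep: str = "/") -> dict:
        # rebuild the nested structure from "/"-separated keys
        out: dict = {}
        for k, v in d.items():
            parts = k.split(sep)
            node = out
            for p in parts[:-1]:
                node = node.setdefault(p, {})
            node[parts[-1]] = v
        return out

    d = {"obs": {"min": 0, "max": 1}, "action": {"min": 4, "max": 5}}
    assert unflatten_dict(flatten_dict(d)) == d
    print(flatten_dict(d))  # {'obs/min': 0, 'obs/max': 1, 'action/min': 4, 'action/max': 5}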
@@ -141,6 +141,7 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
     Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
     and for now we add tests as we see fit.
     """
+    policy_kwargs["device"] = DEVICE

     train_cfg = TrainPipelineConfig(
         # TODO(rcadene, aliberts): remove dataset download