312 lines
10 KiB
Python
312 lines
10 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from unittest.mock import patch
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from lerobot.common.datasets.compute_stats import (
|
|
_assert_type_and_shape,
|
|
aggregate_feature_stats,
|
|
aggregate_stats,
|
|
compute_episode_stats,
|
|
estimate_num_samples,
|
|
get_feature_stats,
|
|
sample_images,
|
|
sample_indices,
|
|
)
|
|
|
|
|
|
def mock_load_image_as_numpy(path, dtype, channel_first):
    """Fake image loader used to patch `load_image_as_numpy` in these tests.

    `path` is accepted but never read. The returned array is filled with ones,
    has the requested `dtype`, and is shaped (3, 32, 32) when `channel_first`
    is truthy, otherwise (32, 32, 3).
    """
    if channel_first:
        shape = (3, 32, 32)
    else:
        shape = (32, 32, 3)
    return np.ones(shape, dtype=dtype)
|
|
|
|
|
|
@pytest.fixture
def sample_array():
    """Provide a 3x3 integer array holding the values 1 through 9, row by row."""
    rows = [
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9],
    ]
    return np.array(rows)
|
|
|
|
|
|
def test_estimate_num_samples():
    """Pin the output of `estimate_num_samples` for a range of input sizes."""
    # Maps the argument passed to `estimate_num_samples` to the value it must return.
    expected_by_input = {
        1: 1,
        10: 10,
        100: 100,
        200: 100,
        1000: 177,
        2000: 299,
        5000: 594,
        10_000: 1000,
        20_000: 1681,
        50_000: 3343,
        500_000: 10_000,
    }
    for size, expected in expected_by_input.items():
        assert estimate_num_samples(size) == expected
|
|
|
|
|
|
def test_sample_indices():
    """`sample_indices(10)` must span 0..9 and yield as many items as `estimate_num_samples(10)`."""
    sampled = sample_indices(10)
    n_sampled = len(sampled)

    # Guard against an empty result before indexing into it.
    assert n_sampled > 0

    first, last = sampled[0], sampled[-1]
    assert first == 0
    assert last == 9
    assert n_sampled == estimate_num_samples(10)
|
|
|
|
|
|
@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
def test_sample_images(mock_load):
    """`sample_images` must return a uint8 ndarray of (3, 32, 32) frames when loading is mocked.

    The image loader is patched with `mock_load_image_as_numpy`, so no file is read.
    """
    n_paths = 100
    paths = [f"image_{idx}.jpg" for idx in range(n_paths)]

    sampled = sample_images(paths)

    assert isinstance(sampled, np.ndarray)
    per_image_shape = sampled.shape[1:]
    assert per_image_shape == (3, 32, 32)
    assert sampled.dtype == np.uint8
    # One output frame per sampled index.
    assert len(sampled) == estimate_num_samples(n_paths)
|
|
|
|
|
|
def test_get_feature_stats_images():
    """Stats over a (100, 3, 32, 32) batch expose all five keys with consistently shaped arrays."""
    batch = np.random.rand(100, 3, 32, 32)
    stats = get_feature_stats(batch, axis=(0, 2, 3), keepdims=True)

    for key in ("min", "max", "mean", "std", "count"):
        assert key in stats

    np.testing.assert_equal(stats["count"], np.array([100]))

    # min/max/mean/std must all share a single shape.
    reference_shape = stats["min"].shape
    for key in ("max", "mean", "std"):
        assert stats[key].shape == reference_shape
|
|
|
|
|
|
def test_get_feature_stats_axis_0_keepdims(sample_array):
    """Reducing `sample_array` over axis 0 with keepdims=True yields (1, 3)-shaped statistics."""
    result = get_feature_stats(sample_array, axis=(0,), keepdims=True)

    expected = dict(
        min=np.array([[1, 2, 3]]),
        max=np.array([[7, 8, 9]]),
        mean=np.array([[4.0, 5.0, 6.0]]),
        std=np.array([[2.44948974, 2.44948974, 2.44948974]]),
        count=np.array([3]),
    )
    for name, want in expected.items():
        np.testing.assert_allclose(result[name], want)
|
|
|
|
|
|
def test_get_feature_stats_axis_1(sample_array):
    """Reducing `sample_array` over axis 1 without keepdims yields one statistic per row."""
    result = get_feature_stats(sample_array, axis=(1,), keepdims=False)

    expected = dict(
        min=np.array([1, 4, 7]),
        max=np.array([3, 6, 9]),
        mean=np.array([2.0, 5.0, 8.0]),
        std=np.array([0.81649658, 0.81649658, 0.81649658]),
        count=np.array([3]),
    )
    for name, want in expected.items():
        np.testing.assert_allclose(result[name], want)
|
|
|
|
|
|
def test_get_feature_stats_no_axis(sample_array):
    """With axis=None the statistics collapse to scalars, while `count` is still expected to be [3]."""
    result = get_feature_stats(sample_array, axis=None, keepdims=False)

    expected = dict(
        min=np.array(1),
        max=np.array(9),
        mean=np.array(5.0),
        std=np.array(2.5819889),
        count=np.array([3]),
    )
    for name, want in expected.items():
        np.testing.assert_allclose(result[name], want)
|
|
|
|
|
|
def test_get_feature_stats_empty_array():
    """An empty input array must make `get_feature_stats` raise ValueError."""
    empty = np.array([])
    call_kwargs = {"axis": (0,), "keepdims": True}
    with pytest.raises(ValueError):
        get_feature_stats(empty, **call_kwargs)
|
|
|
|
|
|
def test_get_feature_stats_single_value():
    """A 1x1 array gives min == max == mean == the value, a std of 0 and a count of 1."""
    result = get_feature_stats(np.array([[1337]]), axis=None, keepdims=True)

    expected = {
        "min": np.array(1337),
        "max": np.array(1337),
        "mean": np.array(1337.0),
        "std": np.array(0.0),
        "count": np.array([1]),
    }
    for name, want in expected.items():
        np.testing.assert_equal(result[name], want)
|
|
|
|
|
|
def test_compute_episode_stats():
    """Episode stats cover both the image and the state feature, with the image loader mocked out."""
    image_key = "observation.image"
    state_key = "observation.state"
    num_frames = 100

    episode_data = {
        image_key: [f"image_{idx}.jpg" for idx in range(num_frames)],
        state_key: np.random.rand(num_frames, 10),
    }
    features = {
        image_key: {"dtype": "image"},
        state_key: {"dtype": "numeric"},
    }

    # Replace the image loader so the fake paths above are never opened.
    loader_patch = patch(
        "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
    )
    with loader_patch:
        stats = compute_episode_stats(episode_data, features)

    assert image_key in stats
    assert state_key in stats
    assert stats[image_key]["count"].item() == 100
    assert stats[state_key]["count"].item() == 100
    assert stats[image_key]["mean"].shape == (3, 1, 1)
|
|
|
|
|
|
def test_assert_type_and_shape_valid():
    """Well-formed stats (every entry a numpy array, `count` of shape (1,)) must pass without raising."""
    feature_stats = {
        "min": np.array([1.0]),
        "max": np.array([10.0]),
        "mean": np.array([5.0]),
        "std": np.array([2.0]),
        "count": np.array([1]),
    }
    _assert_type_and_shape([{"feature1": feature_stats}])
|
|
|
|
|
|
def test_assert_type_and_shape_invalid_type():
    """A stat stored as a plain list instead of a numpy array must be rejected with ValueError."""
    feature_stats = {
        "min": [1.0],  # deliberately a Python list, not a numpy array
        "max": np.array([10.0]),
        "mean": np.array([5.0]),
        "std": np.array([2.0]),
        "count": np.array([1]),
    }
    invalid_stats = [{"feature1": feature_stats}]

    with pytest.raises(ValueError, match="Stats must be composed of numpy array"):
        _assert_type_and_shape(invalid_stats)
|
|
|
|
|
|
def test_assert_type_and_shape_invalid_shape():
    """A `count` array with two elements must be rejected with the shape error message."""
    bad_count = np.array([1, 2])  # two elements, so the shape is not (1,)
    invalid_stats = [{"feature1": {"count": bad_count}}]

    with pytest.raises(ValueError, match=r"Shape of 'count' must be \(1\)"):
        _assert_type_and_shape(invalid_stats)
|
|
|
|
|
|
def test_aggregate_feature_stats():
    """Aggregating two single-count stat dicts must give the pinned min/max/mean/std and a count of 2."""
    first = {
        "min": np.array([1.0]),
        "max": np.array([10.0]),
        "mean": np.array([5.0]),
        "std": np.array([2.0]),
        "count": np.array([1]),
    }
    second = {
        "min": np.array([2.0]),
        "max": np.array([12.0]),
        "mean": np.array([6.0]),
        "std": np.array([2.5]),
        "count": np.array([1]),
    }

    result = aggregate_feature_stats([first, second])

    # (stat name, expected value, extra tolerance kwargs for assert_allclose)
    checks = [
        ("min", [1.0], {}),
        ("max", [12.0], {}),
        ("mean", [5.5], {}),
        ("std", [2.318405], {"atol": 1e-6}),
        ("count", [2], {}),
    ]
    for name, want, tolerance in checks:
        np.testing.assert_allclose(result[name], np.array(want), **tolerance)
|
|
|
|
|
|
def test_aggregate_stats():
    """Aggregate two episodes' stats and compare every feature against hand-computed values.

    Both episodes share "observation.image" and "observation.state"; each also
    carries one feature the other lacks ("extra_key_0" / "extra_key_1"), whose
    stats are expected to come through unchanged.
    """
    image_key = "observation.image"

    def _to_numpy(feature_key, raw_stats):
        """Convert one feature's plain-Python stats into the numpy layout used by this test."""
        converted = {}
        for name, value in raw_stats.items():
            is_count = name == "count"
            arr = np.array(value, dtype=np.int64 if is_count else np.float32)
            if feature_key == image_key and not is_count:
                arr = arr.reshape(3, 1, 1)  # for normalization on image channels
            else:
                arr = arr.reshape(1)
            converted[name] = arr
        return converted

    raw_episodes = [
        {
            "observation.image": {
                "min": [1, 2, 3],
                "max": [10, 20, 30],
                "mean": [5.5, 10.5, 15.5],
                "std": [2.87, 5.87, 8.87],
                "count": 10,
            },
            "observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
            "extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
        },
        {
            "observation.image": {
                "min": [2, 1, 0],
                "max": [15, 10, 5],
                "mean": [8.5, 5.5, 2.5],
                "std": [3.42, 2.42, 1.42],
                "count": 15,
            },
            "observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
            "extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
        },
    ]

    raw_expected = {
        "observation.image": {
            "min": [1, 1, 0],
            "max": [15, 20, 30],
            "mean": [7.3, 7.5, 7.7],
            "std": [3.5317, 4.8267, 8.5581],
            "count": 25,
        },
        "observation.state": {
            "min": 1,
            "max": 15,
            "mean": 7.3,
            "std": 3.5317,
            "count": 25,
        },
        "extra_key_0": {
            "min": 5,
            "max": 25,
            "mean": 15.0,
            "std": 6.0,
            "count": 6,
        },
        "extra_key_1": {
            "min": 0,
            "max": 20,
            "mean": 10.0,
            "std": 5.0,
            "count": 5,
        },
    }

    # cast to numpy
    all_stats = [
        {feature_key: _to_numpy(feature_key, raw) for feature_key, raw in episode.items()}
        for episode in raw_episodes
    ]
    expected_agg_stats = {
        feature_key: _to_numpy(feature_key, raw) for feature_key, raw in raw_expected.items()
    }

    results = aggregate_stats(all_stats)

    for feature_key, expected in expected_agg_stats.items():
        actual = results[feature_key]
        for name in ("min", "max", "mean"):
            np.testing.assert_allclose(actual[name], expected[name])
        # The expected std values are rounded to four decimals, hence the looser tolerance.
        np.testing.assert_allclose(actual["std"], expected["std"], atol=1e-04, rtol=1e-04)
        np.testing.assert_allclose(actual["count"], expected["count"])
|