#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from uuid import uuid4

import numpy as np
import pytest
import torch

from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights

# Some constants for OnlineBuffer tests.
data_key = "data"
data_shape = (2, 3)  # just some arbitrary > 1D shape
buffer_capacity = 100
fps = 10


def make_new_buffer(
    write_dir: str | None = None, delta_timestamps: dict[str, list[float]] | None = None
) -> tuple[OnlineBuffer, str]:
    if write_dir is None:
        write_dir = f"/tmp/online_buffer_{uuid4().hex}"
    buffer = OnlineBuffer(
        write_dir,
        data_spec={data_key: {"shape": data_shape, "dtype": np.dtype("float32")}},
        buffer_capacity=buffer_capacity,
        fps=fps,
        delta_timestamps=delta_timestamps,
    )
    return buffer, write_dir


def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]:
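    """Build a dict of spoof frames: consecutive integers for `data_key`, plus matching index,
    episode_index, frame_index and timestamp entries laid out episode by episode at the global `fps`.
    """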
    new_data = {
        data_key: np.arange(n_frames_per_episode * n_episodes * np.prod(data_shape)).reshape(-1, *data_shape),
        OnlineBuffer.INDEX_KEY: np.arange(n_frames_per_episode * n_episodes),
        OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode),
        OnlineBuffer.FRAME_INDEX_KEY: np.tile(np.arange(n_frames_per_episode), n_episodes),
        OnlineBuffer.TIMESTAMP_KEY: np.tile(np.arange(n_frames_per_episode) / fps, n_episodes),
    }
    return new_data


def test_non_mutate():
    """Checks that the data provided to the add_data method is copied rather than passed by reference.

    This means that mutating the data in the buffer does not mutate the original data.

    NOTE: If this test fails, it means some of the other tests may be compromised. For example, we can't trust
    a success case for `test_write_read`.
    """
    buffer, _ = make_new_buffer()
    new_data = make_spoof_data_frames(2, buffer_capacity // 4)
    new_data_copy = deepcopy(new_data)
    buffer.add_data(new_data)
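    # Mutate the buffer's internal storage in place. If `add_data` copied its input, the caller's arrays
    # (compared against `new_data_copy` below) should be unaffected.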
    buffer._data[data_key][:] += 1
    assert all(np.array_equal(new_data[k], new_data_copy[k]) for k in new_data)


def test_index_error_no_data():
    buffer, _ = make_new_buffer()
    with pytest.raises(IndexError):
        buffer[0]


def test_index_error_with_data():
    buffer, _ = make_new_buffer()
    n_frames = buffer_capacity // 2
    new_data = make_spoof_data_frames(1, n_frames)
    buffer.add_data(new_data)
    with pytest.raises(IndexError):
        buffer[n_frames]
    with pytest.raises(IndexError):
        buffer[-n_frames - 1]


@pytest.mark.parametrize("do_reload", [False, True])
def test_write_read(do_reload: bool):
    """Checks that data can be added to the buffer and read back.

    If `do_reload`, we delete the buffer object and load the buffer back from disk before reading.
    """
    buffer, write_dir = make_new_buffer()
    n_episodes = 2
    n_frames_per_episode = buffer_capacity // 4
    new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
    buffer.add_data(new_data)

    if do_reload:
        del buffer
        buffer, _ = make_new_buffer(write_dir)

    assert len(buffer) == n_frames_per_episode * n_episodes
    for i, item in enumerate(buffer):
        assert all(isinstance(item[k], torch.Tensor) for k in item)
        assert np.array_equal(item[data_key].numpy(), new_data[data_key][i])


def test_read_data_key():
    """Tests that data can be added to a buffer and all data for a specific key can be read back."""
    buffer, _ = make_new_buffer()
    n_episodes = 2
    n_frames_per_episode = buffer_capacity // 4
    new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
    buffer.add_data(new_data)

    data_from_buffer = buffer.get_data_by_key(data_key)
    assert isinstance(data_from_buffer, torch.Tensor)
    assert np.array_equal(data_from_buffer.numpy(), new_data[data_key])


def test_fifo():
    """Checks that if data is added beyond the buffer capacity, we discard the oldest data first."""
    buffer, _ = make_new_buffer()
    n_frames_per_episode = buffer_capacity // 4
    n_episodes = 3
    new_data = make_spoof_data_frames(n_episodes, n_frames_per_episode)
    buffer.add_data(new_data)
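    # Add enough extra episodes that the total number of frames exceeds the buffer capacity, so the
    # oldest frames have to be evicted.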
    n_more_episodes = 2
    # Developer sanity check (in case someone changes the global `buffer_capacity`).
    assert (n_episodes + n_more_episodes) * n_frames_per_episode > buffer_capacity, (
        "Something went wrong with the test code."
    )
    more_new_data = make_spoof_data_frames(n_more_episodes, n_frames_per_episode)
    buffer.add_data(more_new_data)
    assert len(buffer) == buffer_capacity, "The buffer should be full."

    expected_data = {}
    for k in new_data:
        # Concatenate, left-truncate, then roll, to imitate the cyclical FIFO pattern in OnlineBuffer.
        expected_data[k] = np.roll(
            np.concatenate([new_data[k], more_new_data[k]])[-buffer_capacity:],
            shift=len(new_data[k]) + len(more_new_data[k]) - buffer_capacity,
            axis=0,
        )

    for i, item in enumerate(buffer):
        assert all(isinstance(item[k], torch.Tensor) for k in item)
        assert np.array_equal(item[data_key].numpy(), expected_data[data_key][i])


def test_delta_timestamps_within_tolerance():
    """Check that getting an item with delta_timestamps within tolerance succeeds.

    Note: Copied from `test_datasets.py::test_load_previous_and_future_frames_within_tolerance`.
    """
    # Sanity check on global fps as we are assuming it is 10 here.
    assert fps == 10, "This test assumes fps==10"
    buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.139]})
    new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5)
    buffer.add_data(new_data)
    buffer.tolerance_s = 0.04
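    # With fps=10 the 5 frames sit at 0.0, 0.1, 0.2, 0.3, 0.4 s. Querying frame 2 (t=0.2) with deltas
    # [-0.2, 0, 0.139] requests timestamps 0.0, 0.2 and 0.339. The last is 0.039 s from the nearest frame
    # (0.3 s), just inside the 0.04 s tolerance, so indices [0, 2, 3] should come back with no padding.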
    item = buffer[2]
    data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
    torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values")
    assert not is_pad.any(), "Unexpected padding detected"


def test_delta_timestamps_outside_tolerance_inside_episode_range():
    """Check that getting an item with delta_timestamps outside of tolerance fails.

    We expect it to fail if and only if the requested timestamps are within the episode range.

    Note: Copied from
    `test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_inside_episode_range`
    """
    # Sanity check on global fps as we are assuming it is 10 here.
    assert fps == 10, "This test assumes fps==10"
    buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.141]})
    new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5)
    buffer.add_data(new_data)
    buffer.tolerance_s = 0.04
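    # Here the 0.141 delta requests timestamp 0.341 s, which is 0.041 s away from the nearest frame (0.3 s).
    # That exceeds the 0.04 s tolerance while still lying inside the episode range, so indexing must fail.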
    with pytest.raises(AssertionError):
        buffer[2]


def test_delta_timestamps_outside_tolerance_outside_episode_range():
    """Check that copy-padding of timestamps outside of the episode range works.

    Note: Copied from
    `test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_outside_episode_range`
    """
    # Sanity check on global fps as we are assuming it is 10 here.
    assert fps == 10, "This test assumes fps==10"
    buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.3, -0.24, 0, 0.26, 0.3]})
    new_data = make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5)
    buffer.add_data(new_data)
    buffer.tolerance_s = 0.04
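    # Around frame 2 (t=0.2 s) the deltas request timestamps -0.1, -0.04, 0.2, 0.46 and 0.5 s. The test
    # expects the requests that fall more than tolerance_s outside the episode range [0.0, 0.4]
    # (-0.1, 0.46, 0.5) to be copy-padded with the boundary frame and flagged in is_pad, while -0.04 is
    # close enough (within 0.04 s) to the first frame to resolve to index 0 without padding.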
    item = buffer[2]
    data, is_pad = item["index"], item["index_is_pad"]
    assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
    assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), (
        "Padding does not match expected values"
    )


# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
@pytest.mark.parametrize("offline_dataset_size", [1, 6])
@pytest.mark.parametrize("online_dataset_size", [0, 4])
@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0])
def test_compute_sampler_weights_trivial(
    lerobot_dataset_factory,
    tmp_path,
    offline_dataset_size: int,
    online_dataset_size: int,
    online_sampling_ratio: float,
):
    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
    online_dataset, _ = make_new_buffer()
    if online_dataset_size > 0:
        online_dataset.add_data(
            make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2)
        )

    weights = compute_sampler_weights(
        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
    )
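    # Trivial cases: if the online buffer is empty all of the mass falls uniformly on the offline frames;
    # otherwise a ratio of exactly 0.0 or 1.0 puts all of the mass uniformly on the offline or online
    # frames respectively.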
    if offline_dataset_size == 0 or online_dataset_size == 0:
        expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
    elif online_sampling_ratio == 0:
        expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)])
    elif online_sampling_ratio == 1:
        expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
    expected_weights /= expected_weights.sum()
    torch.testing.assert_close(weights, expected_weights)


def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
    # Arbitrarily set small dataset sizes, making sure to have uneven sizes.
    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
    online_dataset, _ = make_new_buffer()
    online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
    online_sampling_ratio = 0.8
    weights = compute_sampler_weights(
        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
    )
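    # The 4 offline frames share the remaining 0.2 of the mass (0.05 each) and the 8 online frames share
    # 0.8 (0.1 each).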
    torch.testing.assert_close(
        weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
    )


def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
    # Arbitrarily set small dataset sizes, making sure to have uneven sizes.
    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
    online_dataset, _ = make_new_buffer()
    online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
    weights = compute_sampler_weights(
        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
    )
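    # The last frame of each of the 4 two-frame online episodes is dropped (weight 0.0), so the 0.8 online
    # mass is split across the 4 remaining frames (0.2 each); the 4 offline frames still get 0.05 each.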
    torch.testing.assert_close(
        weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
    )


def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
    """Note: test copied from test_sampler."""
    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
    online_dataset, _ = make_new_buffer()
    online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))

    weights = compute_sampler_weights(
        offline_dataset,
        offline_drop_n_last_frames=1,
        online_dataset=online_dataset,
        online_sampling_ratio=0.5,
        online_drop_n_last_frames=1,
    )
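    # Offline: 2 frames with the last one dropped, so the single remaining frame carries the full 0.5
    # offline mass. Online: dropping the last frame of each two-frame episode leaves 4 frames sharing
    # 0.5 (0.125 each).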
    torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))