#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
from copy import deepcopy
from uuid import uuid4

import numpy as np
import pytest
import torch

from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
# Some constants for OnlineBuffer tests.
data_key = "data"  # the single data key stored in test buffers (besides OnlineBuffer's own keys)
data_shape = (2, 3) # just some arbitrary > 1D shape
buffer_capacity = 100  # max frames the FIFO buffer holds; several tests derive sizes from this
fps = 10  # frames per second used to fabricate timestamps; some tests assert fps == 10
def make_new_buffer(
    write_dir: str | None = None, delta_timestamps: dict[str, list[float]] | None = None
) -> tuple[OnlineBuffer, str]:
    """Create an `OnlineBuffer` for testing.

    Args:
        write_dir: Directory for the buffer's backing storage. When not provided, a unique
            path inside the platform's temporary directory is generated.
        delta_timestamps: Optional delta timestamps forwarded to the `OnlineBuffer` constructor.

    Returns:
        A tuple of the new buffer and the directory it writes to (so tests can reload from it).
    """
    if write_dir is None:
        # Use the platform temp dir rather than hard-coding "/tmp" (portable, e.g. on Windows).
        write_dir = os.path.join(tempfile.gettempdir(), f"online_buffer_{uuid4().hex}")
    buffer = OnlineBuffer(
        write_dir,
        data_spec={data_key: {"shape": data_shape, "dtype": np.dtype("float32")}},
        buffer_capacity=buffer_capacity,
        fps=fps,
        delta_timestamps=delta_timestamps,
    )
    return buffer, write_dir
def make_spoof_data_frames(n_episodes: int, n_frames_per_episode: int) -> dict[str, np.ndarray]:
    """Fabricate a data dict with the keys `OnlineBuffer.add_data` expects.

    The `data_key` values form a running sequence so individual frames are
    distinguishable in test assertions.
    """
    n_frames = n_frames_per_episode * n_episodes
    within_episode = np.arange(n_frames_per_episode)
    return {
        data_key: np.arange(n_frames * np.prod(data_shape)).reshape(-1, *data_shape),
        OnlineBuffer.INDEX_KEY: np.arange(n_frames),
        OnlineBuffer.EPISODE_INDEX_KEY: np.repeat(np.arange(n_episodes), n_frames_per_episode),
        OnlineBuffer.FRAME_INDEX_KEY: np.tile(within_episode, n_episodes),
        OnlineBuffer.TIMESTAMP_KEY: np.tile(within_episode / fps, n_episodes),
    }
def test_non_mutate():
    """Checks that the data provided to the add_data method is copied rather than passed by reference.

    This means that mutating the data in the buffer does not mutate the original data.

    NOTE: If this test fails, it means some of the other tests may be compromised. For example, we can't trust
    a success case for `test_write_read`.
    """
    buffer, _ = make_new_buffer()
    original = make_spoof_data_frames(2, buffer_capacity // 4)
    snapshot = deepcopy(original)
    buffer.add_data(original)
    # Mutate the buffer's internal storage; the caller's arrays must be unaffected.
    buffer._data[data_key][:] += 1
    for key in original:
        assert np.array_equal(original[key], snapshot[key])
def test_index_error_no_data():
    """Indexing into an empty buffer should raise IndexError."""
    empty_buffer, _ = make_new_buffer()
    with pytest.raises(IndexError):
        _ = empty_buffer[0]
def test_index_error_with_data():
    """Out-of-range indices (both positive and negative) should raise IndexError."""
    buffer, _ = make_new_buffer()
    num_frames = buffer_capacity // 2
    buffer.add_data(make_spoof_data_frames(1, num_frames))
    # One past the end in each direction.
    for bad_index in (num_frames, -num_frames - 1):
        with pytest.raises(IndexError):
            _ = buffer[bad_index]
@pytest.mark.parametrize("do_reload", [False, True])
def test_write_read(do_reload: bool):
    """Checks that data can be added to the buffer and read back.

    If do_reload we delete the buffer object and load the buffer back from disk before reading.
    """
    buffer, write_dir = make_new_buffer()
    num_episodes = 2
    frames_per_episode = buffer_capacity // 4
    spoof_data = make_spoof_data_frames(num_episodes, frames_per_episode)
    buffer.add_data(spoof_data)

    if do_reload:
        # Drop the in-memory object and reconstruct the buffer from its on-disk files.
        del buffer
        buffer, _ = make_new_buffer(write_dir)

    assert len(buffer) == num_episodes * frames_per_episode
    for frame_index, item in enumerate(buffer):
        assert all(isinstance(value, torch.Tensor) for value in item.values())
        assert np.array_equal(item[data_key].numpy(), spoof_data[data_key][frame_index])
def test_read_data_key():
    """Tests that data can be added to a buffer and all data for a specific key can be read back."""
    buffer, _ = make_new_buffer()
    spoof_data = make_spoof_data_frames(2, buffer_capacity // 4)
    buffer.add_data(spoof_data)

    retrieved = buffer.get_data_by_key(data_key)
    assert isinstance(retrieved, torch.Tensor)
    assert np.array_equal(retrieved.numpy(), spoof_data[data_key])
def test_fifo():
    """Checks that if data is added beyond the buffer capacity, we discard the oldest data first."""
    buffer, _ = make_new_buffer()
    frames_per_episode = buffer_capacity // 4
    first_batch = make_spoof_data_frames(3, frames_per_episode)
    buffer.add_data(first_batch)
    # Developer sanity check (in case someone changes the global `buffer_capacity`).
    assert (3 + 2) * frames_per_episode > buffer_capacity, "Something went wrong with the test code."
    second_batch = make_spoof_data_frames(2, frames_per_episode)
    buffer.add_data(second_batch)
    assert len(buffer) == buffer_capacity, "The buffer should be full."

    # Concatenate, left-truncate, then roll, to imitate the cyclical FIFO pattern in OnlineBuffer.
    overflow = len(first_batch[data_key]) + len(second_batch[data_key]) - buffer_capacity
    expected = {
        key: np.roll(
            np.concatenate([first_batch[key], second_batch[key]])[-buffer_capacity:],
            shift=overflow,
            axis=0,
        )
        for key in first_batch
    }
    for frame_index, item in enumerate(buffer):
        assert all(isinstance(value, torch.Tensor) for value in item.values())
        assert np.array_equal(item[data_key].numpy(), expected[data_key][frame_index])
def test_delta_timestamps_within_tolerance():
    """Check that getting an item with delta_timestamps within tolerance succeeds.

    Note: Copied from `test_datasets.py::test_load_previous_and_future_frames_within_tolerance`.
    """
    # Sanity check on global fps as we are assuming it is 10 here.
    assert fps == 10, "This test assumes fps==10"
    buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.139]})
    buffer.add_data(make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5))
    buffer.tolerance_s = 0.04
    item = buffer[2]
    data = item["index"]
    is_pad = item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
    torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values")
    assert not is_pad.any(), "Unexpected padding detected"
def test_delta_timestamps_outside_tolerance_inside_episode_range():
    """Check that getting an item with delta_timestamps outside of tolerance fails.

    We expect it to fail if and only if the requested timestamps are within the episode range.

    Note: Copied from
    `test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_inside_episode_range`
    """
    # Sanity check on global fps as we are assuming it is 10 here.
    assert fps == 10, "This test assumes fps==10"
    buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.2, 0, 0.141]})
    buffer.add_data(make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5))
    buffer.tolerance_s = 0.04
    with pytest.raises(AssertionError):
        _ = buffer[2]
def test_delta_timestamps_outside_tolerance_outside_episode_range():
    """Check that copy-padding of timestamps outside of the episode range works.

    Note: Copied from
    `test_datasets.py::test_load_previous_and_future_frames_outside_tolerance_outside_episode_range`
    """
    # Sanity check on global fps as we are assuming it is 10 here.
    assert fps == 10, "This test assumes fps==10"
    buffer, _ = make_new_buffer(delta_timestamps={"index": [-0.3, -0.24, 0, 0.26, 0.3]})
    buffer.add_data(make_spoof_data_frames(n_episodes=1, n_frames_per_episode=5))
    buffer.tolerance_s = 0.04
    item = buffer[2]
    data = item["index"]
    is_pad = item["index_is_pad"]
    assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
    assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), (
        "Padding does not match expected values"
    )
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
@pytest.mark.parametrize("offline_dataset_size", [1, 6])
@pytest.mark.parametrize("online_dataset_size", [0, 4])
@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0])
def test_compute_sampler_weights_trivial(
    lerobot_dataset_factory,
    tmp_path,
    offline_dataset_size: int,
    online_dataset_size: int,
    online_sampling_ratio: float,
):
    """Degenerate cases: one dataset empty, or an all-or-nothing sampling ratio."""
    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
    online_dataset, _ = make_new_buffer()
    if online_dataset_size > 0:
        online_dataset.add_data(
            make_spoof_data_frames(n_episodes=2, n_frames_per_episode=online_dataset_size // 2)
        )

    weights = compute_sampler_weights(
        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=online_sampling_ratio
    )

    if offline_dataset_size == 0 or online_dataset_size == 0:
        # With one dataset empty, weight is spread uniformly over whichever frames exist.
        expected_weights = torch.ones(offline_dataset_size + online_dataset_size)
    elif online_sampling_ratio == 0:
        # All weight on the offline frames.
        expected_weights = torch.cat([torch.ones(offline_dataset_size), torch.zeros(online_dataset_size)])
    elif online_sampling_ratio == 1:
        # All weight on the online frames.
        expected_weights = torch.cat([torch.zeros(offline_dataset_size), torch.ones(online_dataset_size)])
    expected_weights = expected_weights / expected_weights.sum()
    torch.testing.assert_close(weights, expected_weights)
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
    """An 80/20 online/offline split distributes weight uniformly within each dataset."""
    # Arbitrarily set small dataset sizes, making sure to have uneven sizes.
    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
    online_dataset, _ = make_new_buffer()
    online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))

    weights = compute_sampler_weights(
        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8
    )

    # 20% shared over 4 offline frames, 80% shared over 8 online frames.
    expected = torch.tensor([0.05] * 4 + [0.1] * 8)
    torch.testing.assert_close(weights, expected)
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
    """Like the nontrivial-ratio case, but the last frame of each online episode gets zero weight."""
    # Arbitrarily set small dataset sizes, making sure to have uneven sizes.
    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
    online_dataset, _ = make_new_buffer()
    online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))

    weights = compute_sampler_weights(
        offline_dataset, online_dataset=online_dataset, online_sampling_ratio=0.8, online_drop_n_last_frames=1
    )

    # 20% over 4 offline frames; 80% over the 4 remaining online frames (one dropped per episode).
    expected = torch.tensor([0.05] * 4 + [0.2, 0.0] * 4)
    torch.testing.assert_close(weights, expected)
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
    """Note: test copied from test_sampler."""
    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
    online_dataset, _ = make_new_buffer()
    online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))

    weights = compute_sampler_weights(
        offline_dataset,
        offline_drop_n_last_frames=1,
        online_dataset=online_dataset,
        online_sampling_ratio=0.5,
        online_drop_n_last_frames=1,
    )

    # Last frame of the offline dataset and of every online episode is zero-weighted.
    expected = torch.tensor([0.5, 0] + [0.125, 0] * 4)
    torch.testing.assert_close(weights, expected)