22 lines
719 B
Python
22 lines
719 B
Python
import torch
|
|
from datasets import Dataset
|
|
|
|
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
|
from lerobot.common.datasets.utils import (
|
|
hf_transform_to_torch,
|
|
)
|
|
|
|
|
|
def test_calculate_episode_data_index():
    """Check that per-episode boundaries are derived from ``episode_index``.

    The fixture lays out three episodes contiguously with lengths 2, 1 and 3
    (``episode_index`` = [0, 0, 1, 2, 2, 2]). The expected ``from`` tensor holds
    each episode's first index and ``to`` holds one past its last index, i.e.
    half-open ranges [from, to).
    """
    dataset = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [0, 0, 1, 2, 2, 2],
        },
    )
    # Columns are returned as torch tensors once this transform is set.
    dataset.set_transform(hf_transform_to_torch)

    episode_data_index = calculate_episode_data_index(dataset)

    expected_from = torch.tensor([0, 2, 3])
    expected_to = torch.tensor([2, 3, 6])
    # Show the actual tensor on failure; torch.equal alone only reports False.
    assert torch.equal(episode_data_index["from"], expected_from), episode_data_index["from"]
    assert torch.equal(episode_data_index["to"], expected_to), episode_data_index["to"]
|