# Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch from datasets import Dataset from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index from lerobot.common.datasets.utils import ( hf_transform_to_torch, ) def test_calculate_episode_data_index(): dataset = Dataset.from_dict( { "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6], "index": [0, 1, 2, 3, 4, 5], "episode_index": [0, 0, 1, 2, 2, 2], }, ) dataset.set_transform(hf_transform_to_torch) episode_data_index = calculate_episode_data_index(dataset) assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3])) assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))