This commit is contained in:
Remi Cadene 2024-05-20 15:09:41 +00:00
parent 9b62c25f6c
commit 8ca0bc5e88
3 changed files with 29 additions and 6 deletions

View File

@ -8,6 +8,7 @@ especially in the context of imitation learning. The most reliable approach is t
on the target environment, whether that be in simulation or the real world.
"""
import math
from pathlib import Path
import torch
@ -39,11 +40,24 @@ delta_timestamps = {
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
}
# Load the last 10 episodes of the dataset as a validation set.
# The `split` argument utilizes the `datasets` library's syntax for slicing datasets.
# For more information on the Slice API, please see:
# Load the last 10% episodes of the dataset as a validation set.
# - Load full dataset
full_dataset = LeRobotDataset("lerobot/pusht", split="train")
# - Calculate train and val subsets
num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
num_val_episodes = full_dataset.num_episodes - num_train_episodes
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
print(f"Number of episodes in validation dataset (10% reset): {num_val_episodes}")
# - Get first frame index of the validation set
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
# - Load frames subset belonging to validation set using the `split` argument.
# It utilizes the `datasets` library's syntax for slicing datasets.
# For more information on the Slice API, please see:
# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
val_dataset = LeRobotDataset("lerobot/pusht", split="train[24342:]", delta_timestamps=delta_timestamps)
val_dataset = LeRobotDataset(
"lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
)
# Create dataloader for evaluation.
val_dataloader = torch.utils.data.DataLoader(

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from pathlib import Path
from typing import Dict
@ -80,7 +81,15 @@ def hf_transform_to_torch(items_dict):
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if root is not None:
hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
# TODO(rcadene): clean this which enables getting a subset of dataset
if split != "train":
match = match = re.search(r"train\[(\d+):\]", split)
if match:
from_frame_index = int(match.group(1))
hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
else:
raise ValueError(split)
else:
hf_dataset = load_dataset(repo_id, revision=version, split=split)
hf_dataset.set_transform(hf_transform_to_torch)

View File

@ -111,7 +111,7 @@ def test_examples_2_through_4():
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
),
('split="train[24342:]"', 'split="train[-1:]"'),
('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'),
("num_workers=4", "num_workers=0"),
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
("batch_size=64", "batch_size=1"),