Remi Cadene 2024-05-20 15:09:41 +00:00
parent 6d6f1fafc8
commit 4a19733dea
3 changed files with 29 additions and 6 deletions

View File

@@ -8,6 +8,7 @@ especially in the context of imitation learning. The most reliable approach is t
 on the target environment, whether that be in simulation or the real world.
 """
 
+import math
 from pathlib import Path
 
 import torch
@@ -39,11 +40,24 @@ delta_timestamps = {
     "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
 }
 
-# Load the last 10 episodes of the dataset as a validation set.
-# The `split` argument utilizes the `datasets` library's syntax for slicing datasets.
-# For more information on the Slice API, please see:
+# Load the last 10% of the episodes of the dataset as a validation set.
+# - Load full dataset
+full_dataset = LeRobotDataset("lerobot/pusht", split="train")
+# - Calculate train and val subsets
+num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
+num_val_episodes = full_dataset.num_episodes - num_train_episodes
+print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
+print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
+print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
+# - Get first frame index of the validation set
+first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
+# - Load frames subset belonging to validation set using the `split` argument.
+# It utilizes the `datasets` library's syntax for slicing datasets.
+# For more information on the Slice API, please see:
 # https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
-val_dataset = LeRobotDataset("lerobot/pusht", split="train[24342:]", delta_timestamps=delta_timestamps)
+val_dataset = LeRobotDataset(
+    "lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
+)
 
 # Create dataloader for evaluation.
 val_dataloader = torch.utils.data.DataLoader(
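The new split boundary comes from `episode_data_index["from"]`, which stores the first frame index of every episode, so slicing at that frame keeps whole episodes on one side of the split. A minimal sketch of the arithmetic, assuming an invented 206-episode dataset with a fixed 120 frames per episode (both numbers are illustrative, not the real lerobot/pusht metadata):

import math

import torch

# Hypothetical stand-in for a dataset's episode_data_index:
# episode i starts at frame i * 120.
num_episodes = 206
episode_data_index = {"from": torch.arange(num_episodes) * 120}

# 90/10 split on episode boundaries, as in the example above.
num_train_episodes = math.floor(num_episodes * 90 / 100)  # 185
num_val_episodes = num_episodes - num_train_episodes  # 21

# The first frame of the first validation episode marks the slice boundary.
first_val_frame_index = episode_data_index["from"][num_train_episodes].item()
print(f"train[{first_val_frame_index}:]")  # train[22200:]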

View File

@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import re
 from pathlib import Path
 from typing import Dict
@@ -80,7 +81,15 @@ def hf_transform_to_torch(items_dict):
 def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
     """hf_dataset contains all the observations, states, actions, rewards, etc."""
     if root is not None:
-        hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
+        hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
+        # TODO(rcadene): clean up this way of getting a subset of the dataset
+        if split != "train":
+            match = re.search(r"train\[(\d+):\]", split)
+            if match:
+                from_frame_index = int(match.group(1))
+                hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
+            else:
+                raise ValueError(split)
     else:
         hf_dataset = load_dataset(repo_id, revision=version, split=split)
     hf_dataset.set_transform(hf_transform_to_torch)
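This branch only supports splits of the form `train[N:]` with an absolute start frame; anything else raises `ValueError`. A standalone sketch of the same parse-and-select logic, applied to a toy in-memory `datasets.Dataset` (the toy data is invented for illustration, standing in for a dataset loaded with `load_from_disk`):

import re

from datasets import Dataset

# Toy dataset of 100 frames.
hf_dataset = Dataset.from_dict({"frame_index": list(range(100))})

split = "train[30:]"
match = re.search(r"train\[(\d+):\]", split)
if match:
    from_frame_index = int(match.group(1))
    # Keep frames from the parsed start index through the end of the dataset.
    hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
else:
    # Only "train" or "train[N:]" are handled by this code path.
    raise ValueError(split)

print(len(hf_dataset))  # 70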

View File

@@ -111,7 +111,7 @@ def test_examples_2_through_4():
            '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
            'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
        ),
-       ('split="train[24342:]"', 'split="train[-1:]"'),
+       ('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'),
        ("num_workers=4", "num_workers=0"),
        ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
        ("batch_size=64", "batch_size=1"),