hot fix
commit 8ca0bc5e88
parent 9b62c25f6c
@@ -8,6 +8,7 @@ especially in the context of imitation learning. The most reliable approach is t
 on the target environment, whether that be in simulation or the real world.
 """
 
+import math
 from pathlib import Path
 
 import torch
@@ -39,11 +40,24 @@ delta_timestamps = {
     "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
 }
 
-# Load the last 10 episodes of the dataset as a validation set.
-# The `split` argument utilizes the `datasets` library's syntax for slicing datasets.
-# For more information on the Slice API, please see:
+# Load the last 10% of episodes of the dataset as a validation set.
+# - Load full dataset
+full_dataset = LeRobotDataset("lerobot/pusht", split="train")
+# - Calculate train and val subsets
+num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
+num_val_episodes = full_dataset.num_episodes - num_train_episodes
+print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
+print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
+print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
+# - Get first frame index of the validation set
+first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
+# - Load frames subset belonging to validation set using the `split` argument.
+# It utilizes the `datasets` library's syntax for slicing datasets.
+# For more information on the Slice API, please see:
 # https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
-val_dataset = LeRobotDataset("lerobot/pusht", split="train[24342:]", delta_timestamps=delta_timestamps)
+val_dataset = LeRobotDataset(
+    "lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
+)
 
 # Create dataloader for evaluation.
 val_dataloader = torch.utils.data.DataLoader(
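Note on the split arithmetic above: the example keeps episodes whole, putting the first floor(90% * num_episodes) episodes in training and the remainder in validation. Below is a minimal, self-contained sketch of the same idea, generalized to any train fraction. The import path and the helper name `first_val_frame_index_for` are assumptions for illustration; `num_episodes` and `episode_data_index["from"]` are used exactly as in the example.

import math

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset  # assumed import path


def first_val_frame_index_for(repo_id: str, train_fraction: float = 0.9) -> int:
    """Return the frame index where the validation split starts.

    Episodes are kept whole: the first floor(train_fraction * num_episodes)
    episodes form the training set, the remainder the validation set.
    """
    full_dataset = LeRobotDataset(repo_id, split="train")
    num_train_episodes = math.floor(full_dataset.num_episodes * train_fraction)
    # episode_data_index["from"][i] holds the index of episode i's first frame,
    # so indexing it with num_train_episodes yields the first validation frame.
    return full_dataset.episode_data_index["from"][num_train_episodes].item()


first_val_frame_index = first_val_frame_index_for("lerobot/pusht")
val_dataset = LeRobotDataset("lerobot/pusht", split=f"train[{first_val_frame_index}:]")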
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import re
 from pathlib import Path
 from typing import Dict
 
@@ -80,7 +81,15 @@ def hf_transform_to_torch(items_dict):
 def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
     """hf_dataset contains all the observations, states, actions, rewards, etc."""
     if root is not None:
-        hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
+        hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
+        # TODO(rcadene): clean this which enables getting a subset of dataset
+        if split != "train":
+            match = re.search(r"train\[(\d+):\]", split)
+            if match:
+                from_frame_index = int(match.group(1))
+                hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
+            else:
+                raise ValueError(split)
     else:
         hf_dataset = load_dataset(repo_id, revision=version, split=split)
     hf_dataset.set_transform(hf_transform_to_torch)
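The new branch in `load_hf_dataset` only understands split strings of the form `train[N:]`: the regex captures the start frame index, and any other non-"train" split raises ValueError. A standalone sketch of that logic against a plain `datasets.Dataset`; the helper name `select_split` and the toy dataset are illustrative, not part of the library:

import re

from datasets import Dataset


def select_split(hf_dataset: Dataset, split: str) -> Dataset:
    """Emulate `train[N:]` slicing for a dataset loaded from disk,
    mirroring the hot-fixed branch above."""
    if split == "train":
        return hf_dataset
    match = re.search(r"train\[(\d+):\]", split)
    if match is None:
        # Only `train` and `train[N:]` are supported, as in the fix.
        raise ValueError(split)
    from_frame_index = int(match.group(1))
    return hf_dataset.select(range(from_frame_index, len(hf_dataset)))


toy = Dataset.from_dict({"frame_index": list(range(10))})
print(len(select_split(toy, "train[7:]")))  # -> 3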
@@ -111,7 +111,7 @@ def test_examples_2_through_4():
             '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
             'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
         ),
-        ('split="train[24342:]"', 'split="train[-1:]"'),
+        ('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'),
         ("num_workers=4", "num_workers=0"),
         ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
         ("batch_size=64", "batch_size=1"),
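For context, the test patches each example script's source text before executing it, so the find strings must match the example files verbatim; that is why the pair above is updated to the new f-string split in lockstep with the example. A rough sketch of such a substitution helper (hypothetical; the actual test body is not shown in this diff):

def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
    """Apply each (find, replace) pair, asserting that the pattern exists
    so a drifted example fails the test loudly instead of going unpatched."""
    for find, replace in finds_and_replaces:
        assert find in text, f"Pattern not found: {find!r}"
        text = text.replace(find, replace)
    return text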