Merge branch 'main' into tutorial_act_pusht
This commit is contained in:
commit
5d1a498733
|
@ -8,6 +8,7 @@ especially in the context of imitation learning. The most reliable approach is t
|
|||
on the target environment, whether that be in simulation or the real world.
|
||||
"""
|
||||
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
@ -39,11 +40,29 @@ delta_timestamps = {
|
|||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
|
||||
# Load the last 10 episodes of the dataset as a validation set.
|
||||
# The `split` argument utilizes the `datasets` library's syntax for slicing datasets.
|
||||
# For more information on the Slice API, please see:
|
||||
# Load the last 10% of episodes of the dataset as a validation set.
|
||||
# - Load full dataset
|
||||
full_dataset = LeRobotDataset("lerobot/pusht", split="train")
|
||||
# - Calculate train and val subsets
|
||||
num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
|
||||
num_val_episodes = full_dataset.num_episodes - num_train_episodes
|
||||
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
|
||||
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
|
||||
print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
|
||||
# - Get first frame index of the validation set
|
||||
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
|
||||
# - Load frames subset belonging to validation set using the `split` argument.
|
||||
# It utilizes the `datasets` library's syntax for slicing datasets.
|
||||
# For more information on the Slice API, please see:
|
||||
# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
|
||||
val_dataset = LeRobotDataset("lerobot/pusht", split="train[24342:]", delta_timestamps=delta_timestamps)
|
||||
train_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
|
||||
)
|
||||
val_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
|
||||
)
|
||||
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
|
||||
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
|
||||
|
||||
# Create dataloader for evaluation.
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
|
@ -80,7 +81,23 @@ def hf_transform_to_torch(items_dict):
|
|||
def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if root is not None:
|
||||
hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
|
||||
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
|
||||
# TODO(rcadene): clean this which enables getting a subset of dataset
|
||||
if split != "train":
|
||||
if "%" in split:
|
||||
raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
|
||||
match_from = re.search(r"train\[(\d+):\]", split)
|
||||
match_to = re.search(r"train\[:(\d+)\]", split)
|
||||
if match_from:
|
||||
from_frame_index = int(match_from.group(1))
|
||||
hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
|
||||
elif match_to:
|
||||
to_frame_index = int(match_to.group(1))
|
||||
hf_dataset = hf_dataset.select(range(to_frame_index))
|
||||
else:
|
||||
raise ValueError(
|
||||
f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
|
||||
)
|
||||
else:
|
||||
hf_dataset = load_dataset(repo_id, revision=version, split=split)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
@ -273,6 +290,12 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
|
|||
"to": [3, 7, 12]
|
||||
}
|
||||
"""
|
||||
if len(hf_dataset) == 0:
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([]),
|
||||
"to": torch.tensor([]),
|
||||
}
|
||||
return episode_data_index
|
||||
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
||||
if episode_idx != current_episode:
|
||||
# We encountered a new episode, so we append its starting location to the "from" list
|
||||
|
@ -303,6 +326,8 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
|
|||
|
||||
This brings the `episode_index` to the required format.
|
||||
"""
|
||||
if len(hf_dataset) == 0:
|
||||
return hf_dataset
|
||||
unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist()
|
||||
episode_idx_to_reset_idx_mapping = {
|
||||
ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs)
|
||||
|
|
|
@ -111,7 +111,7 @@ def test_examples_2_through_4():
|
|||
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
|
||||
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
|
||||
),
|
||||
('split="train[24342:]"', 'split="train[-1:]"'),
|
||||
('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'),
|
||||
("num_workers=4", "num_workers=0"),
|
||||
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
|
||||
("batch_size=64", "batch_size=1"),
|
||||
|
|
Loading…
Reference in New Issue