Merge branch 'main' into tutorial_act_pusht

Author: Alexander Soare, 2024-05-20 17:38:35 +01:00 (committed by GitHub)
commit 5d1a498733
3 changed files with 50 additions and 6 deletions

View File

@@ -8,6 +8,7 @@ especially in the context of imitation learning. The most reliable approach is t
 on the target environment, whether that be in simulation or the real world.
 """
+import math
 from pathlib import Path
 
 import torch
@@ -39,11 +40,29 @@ delta_timestamps = {
     "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
 }
 
-# Load the last 10 episodes of the dataset as a validation set.
-# The `split` argument utilizes the `datasets` library's syntax for slicing datasets.
+# Load the last 10% of episodes of the dataset as a validation set.
+# - Load full dataset
+full_dataset = LeRobotDataset("lerobot/pusht", split="train")
+# - Calculate train and val subsets
+num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
+num_val_episodes = full_dataset.num_episodes - num_train_episodes
+print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
+print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
+print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
+# - Get first frame index of the validation set
+first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
+# - Load frames subset belonging to validation set using the `split` argument.
+# It utilizes the `datasets` library's syntax for slicing datasets.
 # For more information on the Slice API, please see:
 # https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
-val_dataset = LeRobotDataset("lerobot/pusht", split="train[24342:]", delta_timestamps=delta_timestamps)
+train_dataset = LeRobotDataset(
+    "lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
+)
+val_dataset = LeRobotDataset(
+    "lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
+)
+print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
+print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
 
 # Create dataloader for evaluation.
 val_dataloader = torch.utils.data.DataLoader(
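
Note: the Slice API linked in the comments above accepts both absolute frame indices and percentages. A minimal sketch of the underlying `datasets` behaviour (the repo id is the one used above and the index 24342 is the value the old code hard-coded; whether this particular repo loads directly with `load_dataset` depends on its storage format, so treat the snippet as illustrative):

from datasets import load_dataset

# Absolute row indices: keep everything up to (but excluding) frame 24342 ...
train_split = load_dataset("lerobot/pusht", split="train[:24342]")
# ... and everything from that frame onwards.
val_split = load_dataset("lerobot/pusht", split="train[24342:]")
# Percentage-based slicing is also valid Slice API syntax, although the
# local-loading path changed below in this commit rejects "%" in the split string.
first_90_percent = load_dataset("lerobot/pusht", split="train[:90%]")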

View File

@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import re
 from pathlib import Path
 from typing import Dict
@@ -80,7 +81,23 @@ def hf_transform_to_torch(items_dict):
 def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
     """hf_dataset contains all the observations, states, actions, rewards, etc."""
     if root is not None:
-        hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
+        hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
+        # TODO(rcadene): clean this which enables getting a subset of dataset
+        if split != "train":
+            if "%" in split:
+                raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
+            match_from = re.search(r"train\[(\d+):\]", split)
+            match_to = re.search(r"train\[:(\d+)\]", split)
+            if match_from:
+                from_frame_index = int(match_from.group(1))
+                hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
+            elif match_to:
+                to_frame_index = int(match_to.group(1))
+                hf_dataset = hf_dataset.select(range(to_frame_index))
+            else:
+                raise ValueError(
+                    f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
+                )
     else:
         hf_dataset = load_dataset(repo_id, revision=version, split=split)
     hf_dataset.set_transform(hf_transform_to_torch)
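
Note: the new branch above emulates the hub's Slice API locally with `datasets.Dataset.select`. A minimal sketch of that mechanism on a toy in-memory dataset (column names and values are illustrative, not taken from the pusht data):

import re

from datasets import Dataset

# Toy stand-in for a dataset loaded via `load_from_disk`.
hf_dataset = Dataset.from_dict({"episode_index": [0, 0, 0, 1, 1, 2]})

split = "train[4:]"
match_from = re.search(r"train\[(\d+):\]", split)
if match_from:
    from_frame_index = int(match_from.group(1))
    # Keep only rows 4..end, mirroring what "train[4:]" means on the hub.
    hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))

print(len(hf_dataset))  # 2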
@@ -273,6 +290,12 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
         "to": [3, 7, 12]
     }
     """
+    if len(hf_dataset) == 0:
+        episode_data_index = {
+            "from": torch.tensor([]),
+            "to": torch.tensor([]),
+        }
+        return episode_data_index
     for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
         if episode_idx != current_episode:
             # We encountered a new episode, so we append its starting location to the "from" list
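
Note: the structure guarded here maps each episode to its first (inclusive) and last (exclusive) frame index. A short sketch using the docstring's example ("to" is [3, 7, 12], so the implied "from" is [0, 3, 7]):

import torch

episode_data_index = {
    "from": torch.tensor([0, 3, 7]),
    "to": torch.tensor([3, 7, 12]),
}

episode = 1
start = episode_data_index["from"][episode].item()  # 3
stop = episode_data_index["to"][episode].item()  # 7 (exclusive)
frames_of_episode_1 = list(range(start, stop))  # [3, 4, 5, 6]

# The new early return yields the same structure, just with empty tensors,
# when a sliced dataset ends up containing no frames at all.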
@@ -303,6 +326,8 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
     This brings the `episode_index` to the required format.
     """
+    if len(hf_dataset) == 0:
+        return hf_dataset
     unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist()
     episode_idx_to_reset_idx_mapping = {
         ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs)
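
Note: to make the purpose of that mapping concrete: after slicing, the surviving episodes can start at an arbitrary index, and `reset_episode_index` renumbers them from 0. A small sketch with made-up episode numbers:

unique_episode_idxs = [18, 19, 20]  # made-up surviving episode indices
episode_idx_to_reset_idx_mapping = {
    ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs)
}
print(episode_idx_to_reset_idx_mapping)  # {18: 0, 19: 1, 20: 2}

# The new guard simply skips this remapping when the dataset is empty, since
# torch.stack raises on an empty list of episode indices.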

View File

@@ -111,7 +111,7 @@ def test_examples_2_through_4():
            '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
            'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
        ),
-       ('split="train[24342:]"', 'split="train[-1:]"'),
+       ('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'),
        ("num_workers=4", "num_workers=0"),
        ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
        ("batch_size=64", "batch_size=1"),