Merge branch 'main' into tutorial_act_pusht

commit 5d1a498733
@@ -8,6 +8,7 @@ especially in the context of imitation learning. The most reliable approach is t
 on the target environment, whether that be in simulation or the real world.
 """
 
+import math
 from pathlib import Path
 
 import torch
@@ -39,11 +40,29 @@ delta_timestamps = {
     "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
 }
 
-# Load the last 10 episodes of the dataset as a validation set.
-# The `split` argument utilizes the `datasets` library's syntax for slicing datasets.
-# For more information on the Slice API, please see:
+# Load the last 10% of episodes of the dataset as a validation set.
+# - Load full dataset
+full_dataset = LeRobotDataset("lerobot/pusht", split="train")
+# - Calculate train and val subsets
+num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
+num_val_episodes = full_dataset.num_episodes - num_train_episodes
+print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
+print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
+print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
+# - Get first frame index of the validation set
+first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
+# - Load frames subset belonging to validation set using the `split` argument.
+# It utilizes the `datasets` library's syntax for slicing datasets.
+# For more information on the Slice API, please see:
 # https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
-val_dataset = LeRobotDataset("lerobot/pusht", split="train[24342:]", delta_timestamps=delta_timestamps)
+train_dataset = LeRobotDataset(
+    "lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
+)
+val_dataset = LeRobotDataset(
+    "lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
+)
+print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
+print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
 
 # Create dataloader for evaluation.
 val_dataloader = torch.utils.data.DataLoader(
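
Note: the hunk above replaces a hard-coded frame index (24342) with an episode-based 90/10 split, using `episode_data_index["from"]` to turn an episode count into the first validation frame index. Below is a minimal sketch of that arithmetic on toy data; the episode boundaries are made up for illustration and are not taken from the real dataset.

    import math

    import torch

    # Hypothetical episode boundaries: episode i spans frames [from[i], to[i]).
    episode_data_index = {
        "from": torch.tensor([0, 120, 250, 400]),
        "to": torch.tensor([120, 250, 400, 520]),
    }
    num_episodes = len(episode_data_index["from"])

    # 90% of episodes (rounded down) go to training, the rest to validation.
    num_train_episodes = math.floor(num_episodes * 90 / 100)

    # The first validation frame is the start of the first held-out episode.
    first_val_frame_index = episode_data_index["from"][num_train_episodes].item()

    train_split = f"train[:{first_val_frame_index}]"  # frames 0 .. first_val_frame_index - 1
    val_split = f"train[{first_val_frame_index}:]"    # remaining frames
    print(train_split, val_split)  # train[:400] train[400:]
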
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import re
 from pathlib import Path
 from typing import Dict
 
@@ -80,7 +81,23 @@ def hf_transform_to_torch(items_dict):
 def load_hf_dataset(repo_id, version, root, split) -> datasets.Dataset:
     """hf_dataset contains all the observations, states, actions, rewards, etc."""
     if root is not None:
-        hf_dataset = load_from_disk(str(Path(root) / repo_id / split))
+        hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
+        # TODO(rcadene): clean this which enables getting a subset of dataset
+        if split != "train":
+            if "%" in split:
+                raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
+            match_from = re.search(r"train\[(\d+):\]", split)
+            match_to = re.search(r"train\[:(\d+)\]", split)
+            if match_from:
+                from_frame_index = int(match_from.group(1))
+                hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
+            elif match_to:
+                to_frame_index = int(match_to.group(1))
+                hf_dataset = hf_dataset.select(range(to_frame_index))
+            else:
+                raise ValueError(
+                    f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
+                )
     else:
         hf_dataset = load_dataset(repo_id, revision=version, split=split)
     hf_dataset.set_transform(hf_transform_to_torch)
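
As a standalone illustration (not part of the commit), here is how the two regular expressions above classify a few example `split` strings; the example strings are arbitrary.

    import re

    for split in ["train", "train[400:]", "train[:400]", "train[10%:]"]:
        match_from = re.search(r"train\[(\d+):\]", split)
        match_to = re.search(r"train\[:(\d+)\]", split)
        if split == "train":
            print(split, "-> keep full dataset")
        elif "%" in split:
            print(split, "-> not supported (percentage-based slicing)")
        elif match_from:
            print(split, "-> keep frames from index", int(match_from.group(1)))
        elif match_to:
            print(split, "-> keep frames up to index", int(match_to.group(1)))
        else:
            print(split, "-> invalid split string")
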
@@ -273,6 +290,12 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
         "to": [3, 7, 12]
     }
     """
+    if len(hf_dataset) == 0:
+        episode_data_index = {
+            "from": torch.tensor([]),
+            "to": torch.tensor([]),
+        }
+        return episode_data_index
     for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
         if episode_idx != current_episode:
             # We encountered a new episode, so we append its starting location to the "from" list
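
To make the "from"/"to" convention concrete, here is a small sketch (not the library implementation) that derives the boundaries shown in the docstring example from a toy `episode_index` column. The new guard above simply returns empty tensors when that column has no rows instead of entering the loop.

    import torch

    # Toy `episode_index` column: three episodes with 3, 4, and 5 frames.
    episode_index = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])

    # Episode i spans frames [from[i], to[i]).
    change_points = (torch.diff(episode_index) != 0).nonzero().squeeze(1) + 1
    boundaries = torch.cat([torch.tensor([0]), change_points, torch.tensor([len(episode_index)])])
    episode_data_index = {"from": boundaries[:-1], "to": boundaries[1:]}
    print(episode_data_index)  # {'from': tensor([0, 3, 7]), 'to': tensor([ 3,  7, 12])}
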
@@ -303,6 +326,8 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
 
     This brings the `episode_index` to the required format.
     """
+    if len(hf_dataset) == 0:
+        return hf_dataset
     unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist()
     episode_idx_to_reset_idx_mapping = {
         ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs)
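
A short sketch of what this remapping is meant to do after a slice such as `train[{first_val_frame_index}:]`: the surviving episode ids are compressed to start from zero, and the new guard returns the dataset unchanged when it is empty. The toy ids below are illustrative only.

    import torch

    # Toy `episode_index` column after slicing: only original episodes 3 and 7 remain.
    episode_index = torch.tensor([3, 3, 7, 7, 7])

    unique_episode_idxs = episode_index.unique().tolist()
    episode_idx_to_reset_idx_mapping = {
        ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs)
    }
    reset = torch.tensor([episode_idx_to_reset_idx_mapping[i.item()] for i in episode_index])
    print(reset)  # tensor([0, 0, 1, 1, 1])
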
@@ -111,7 +111,7 @@ def test_examples_2_through_4():
                '# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
                'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
            ),
-            ('split="train[24342:]"', 'split="train[-1:]"'),
+            ('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'),
            ("num_workers=4", "num_workers=0"),
            ('device = torch.device("cuda")', 'device = torch.device("cpu")'),
            ("batch_size=64", "batch_size=1"),
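
The test updates this substitution pair so the example can still be rewritten into a fast, CPU-only variant before it is executed. A minimal sketch of such a rewrite step, assuming a hypothetical helper that applies (old, new) string pairs to the example's source:

    def find_and_replace(text: str, pairs: list[tuple[str, str]]) -> str:
        # Apply each (old, new) substitution, insisting the old string is present.
        for old, new in pairs:
            assert old in text, f"pattern not found: {old}"
            text = text.replace(old, new)
        return text

    example_source = 'val_dataset = LeRobotDataset("lerobot/pusht", split=f"train[{first_val_frame_index}:]")'
    patched = find_and_replace(
        example_source,
        [('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"')],
    )
    print(patched)  # ... split="train[30:]" ...
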