squash commit

parent 6eaffbef1d
commit b699a2f484
@@ -1,15 +1,12 @@
 import logging
 
 import torch
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 
 
-def make_dataset(
-    cfg,
-    split="train",
-):
+def make_dataset(cfg: DictConfig, split="train") -> LeRobotDataset:
     if cfg.env.name not in cfg.dataset_repo_id:
         logging.warning(
             f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
@@ -28,6 +25,7 @@ def make_dataset(
         cfg.dataset_repo_id,
         split=split,
         delta_timestamps=delta_timestamps,
+        n_end_keyframes_dropped=eval(cfg.training.get("n_end_keyframes_dropped", "0")),
     )
 
     if cfg.get("override_dataset_stats"):
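Note that `n_end_keyframes_dropped` arrives from the YAML config as a string expression (OmegaConf interpolates the `${policy.*}` references first), so the factory `eval`s it into an integer. A minimal sketch of that resolution, using the diffusion config values shown further below (horizon=16, n_action_steps=8, n_obs_steps=2):

# Sketch only: what eval(cfg.training.get("n_end_keyframes_dropped", "0")) resolves to.
# After OmegaConf interpolation, the config value is the string below.
expr = "16 - 8 - 2 + 1"
n_end_keyframes_dropped = eval(expr)
assert n_end_keyframes_dropped == 7  # i.e. drop the last 7 keyframes of each episode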
@@ -27,7 +27,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
         split: str = "train",
         transform: callable = None,
         delta_timestamps: dict[list[float]] | None = None,
+        n_end_keyframes_dropped: int = 0,
     ):
+        """
+        Args:
+            delta_timestamps: A dictionary mapping data keys to lists of relative times (Δt). When a frame is
+                sampled from the underlying dataset, we treat it as a "keyframe" and load multiple frames
+                according to the list of Δt's. For example, {"action": [-0.05, 0, 0.05]} indicates
+                that we want to load the current keyframe's action, as well as one from 50 ms ago, and one
+                50 ms into the future. The action key then contains a (3, action_dim) tensor (whereas without
+                `delta_timestamps` there would just be an (action_dim,) tensor). When the Δt's demand that
+                frames outside of an episode boundary are retrieved, a copy padding strategy is used. See
+                `load_previous_and_future_frames` for more details.
+            n_end_keyframes_dropped: Don't sample the last n items in each episode. This option is handy when
+                used in combination with `delta_timestamps` when, for example, the Δt's demand multiple future
+                frames, but we want to avoid introducing too much copy padding into the data distribution.
+                For example, if `delta_timestamps = {"action": [0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30]}`
+                and we sample the last frame in the episode, we would end up padding with 6 frames worth of
+                copies. Instead, we might want no padding (in which case we need n=6), or we might be okay
+                with up to 2 frames of padding (in which case we need n=4).
+        """
         super().__init__()
         self.repo_id = repo_id
         self.version = version
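To make the docstring's Δt arithmetic concrete, here is a small illustrative sketch; the fps value and the key name are assumptions for illustration, not taken from the diff:

# Illustrative sketch of the Δt -> frame-offset mapping described in the docstring.
fps = 20  # assumed: one frame every 0.05 s
delta_timestamps = {"action": [-0.05, 0, 0.05]}
keyframe_ix = 100
# Each Δt corresponds to an offset of Δt * fps frames around the keyframe.
frame_ixs = [keyframe_ix + round(dt * fps) for dt in delta_timestamps["action"]]
assert frame_ixs == [99, 100, 101]  # past frame, keyframe, future frame -> (3, action_dim)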
@@ -44,6 +63,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.info = load_info(repo_id, version, root)
         if self.video:
             self.videos_dir = load_videos(repo_id, version, root)
+        # If `n_end_keyframes_dropped == 0`, `self.index` contains exactly the indices of the hf_dataset. If
+        # `n_end_keyframes_dropped > 0`, `self.index` contains a subset of the indices of the hf_dataset where
+        # we drop those indices pertaining to the last n frames of each episode.
+        self.index = []
+        for from_ix, to_ix in zip(*self.episode_data_index.values(), strict=True):
+            self.index.extend(list(range(from_ix, to_ix - n_end_keyframes_dropped)))
 
     @property
     def fps(self) -> int:
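The loop above is the heart of the change; a standalone sketch with invented episode boundaries shows what ends up in `self.index`:

# Standalone sketch of the index construction (episode boundaries invented for illustration).
episode_data_index = {"from": [0, 50], "to": [50, 110]}  # two episodes, "to" is exclusive
n_end_keyframes_dropped = 6
index = []
for from_ix, to_ix in zip(episode_data_index["from"], episode_data_index["to"], strict=True):
    index.extend(range(from_ix, to_ix - n_end_keyframes_dropped))
assert index[:3] == [0, 1, 2] and len(index) == 44 + 54  # last 6 frames of each episode dropped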
@@ -78,7 +103,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
 
     @property
     def num_samples(self) -> int:
-        return len(self.hf_dataset)
+        return len(self.index)
 
     @property
     def num_episodes(self) -> int:
@@ -97,7 +122,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return self.num_samples
 
     def __getitem__(self, idx):
-        item = self.hf_dataset[idx]
+        item = self.hf_dataset[self.index[idx]]
 
         if self.delta_timestamps is not None:
             item = load_previous_and_future_frames(
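With `num_samples` and `__getitem__` now routed through `self.index`, sampler indices no longer equal hf_dataset row numbers; continuing the toy example above:

# Toy continuation: sampler index -> hf_dataset row via self.index.
index = list(range(0, 44)) + list(range(50, 104))  # as built in the sketch above
assert len(index) == 98   # what num_samples / len(dataset) now reports
assert index[44] == 50    # sampler index 44 maps past episode 0's dropped tail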
@@ -25,11 +25,21 @@ training:
   adam_weight_decay: 1.0e-6
   online_steps_between_rollouts: 1
 
+  # For each training batch we want (consider n_obs_steps=2, horizon=16):
+  # t           | -1,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14
+  # action      |  a,  a,  a,  a,  a,  a,  a,  a,  a,  a,  a,  a,  a,  a,  a,  a
+  # observation |  o,  o,   ,   ,   ,   ,   ,   ,   ,   ,   ,   ,   ,   ,   ,
+  # Note that at rollout we only use some of the actions (consider n_action_steps=8):
+  # action used |   ,  a,  a,  a,  a,  a,  a,  a,  a,   ,   ,   ,   ,   ,   ,
   delta_timestamps:
     observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
     observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
     action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
 
+  # The original implementation doesn't sample keyframes for the last 7 steps. This is because, as described
+  # above, the last 7 actions from the diffusion model are not used.
+  n_end_keyframes_dropped: ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
+
 eval:
   n_episodes: 50
   batch_size: 50
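The quoted list comprehensions are strings that get evaluated after OmegaConf substitutes `${fps}` and `${policy.*}`; a sketch of the resulting Δt lists, assuming fps=10 alongside the horizon/step values above:

# Sketch of what the interpolated strings evaluate to (fps=10 is an assumption).
fps, n_obs_steps, horizon = 10, 2, 16
obs_deltas = [i / fps for i in range(1 - n_obs_steps, 1)]
action_deltas = [i / fps for i in range(1 - n_obs_steps, 1 - n_obs_steps + horizon)]
assert obs_deltas == [-0.1, 0.0]     # two observation steps: previous and current
assert len(action_deltas) == 16      # one action per horizon step, starting at -0.1
# n_end_keyframes_dropped interpolates to "16 - 8 - 2 + 1", which eval's to 7,
# matching the comment about the last 7 actions going unused.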
@@ -100,6 +100,7 @@ def test_compute_stats_on_xarm():
 
     # reduce size of dataset sample on which stats compute is tested to 10 frames
     dataset.hf_dataset = dataset.hf_dataset.select(range(10))
+    dataset.index = [i for i in dataset.index if i < 10]
 
     # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
     # computation of the statistics. While doing this, we also make sure it works when we don't divide the
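The extra test line keeps `dataset.index` consistent with the truncated `hf_dataset`; an illustrative check of why that matters:

# Illustrative: after hf_dataset.select(range(10)), only rows 0..9 exist, so any
# surviving entry of dataset.index must stay below 10 for __getitem__ to be valid.
index = [0, 1, 2, 3, 8, 9, 12, 15]   # pretend indices, some beyond the truncated rows
index = [i for i in index if i < 10]  # mirrors the added test line
assert index == [0, 1, 2, 3, 8, 9]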