squash commit
This commit is contained in:
parent
6eaffbef1d
commit
b699a2f484
|
@ -1,15 +1,12 @@
|
|||
import logging
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def make_dataset(
|
||||
cfg,
|
||||
split="train",
|
||||
):
|
||||
def make_dataset(cfg: DictConfig, split="train") -> LeRobotDataset:
|
||||
if cfg.env.name not in cfg.dataset_repo_id:
|
||||
logging.warning(
|
||||
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
|
||||
|
@ -28,6 +25,7 @@ def make_dataset(
|
|||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=delta_timestamps,
|
||||
n_end_keyframes_dropped=eval(cfg.training.get("n_end_keyframes_dropped", "0")),
|
||||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
|
|
|
@ -27,7 +27,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
n_end_keyframes_dropped: int = 0,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
delta_timestamps: A dictionary mapping lists of relative times (Δt) to data keys. When a frame is
|
||||
sampled from the underlying dataset, we treat it as a "keyframe" and load multiple frames
|
||||
according to the list of Δt's. For example {"action": [-0.05, 0, 0.05]} indicates
|
||||
that we want to load the current keyframe's action, as well as one from 50 ms ago, and one
|
||||
50 ms into the future. The action key then contains a (3, action_dim) tensor (whereas without
|
||||
`delta_timestamps` there would just be a (action_dim,) tensor. When the Δt's demand that
|
||||
frames outside of an episode boundary are retrieved, a copy padding strategy is used. See
|
||||
`load_previous_and_future_frames` for more details.
|
||||
n_end_keyframes_dropped: Don't sample the last n items in each episode. This option is handy when
|
||||
used in combination with `delta_timestamps` when, for example, the Δt's demand multiple future
|
||||
frames, but we want to avoid introducing too much copy padding into the data distribution.
|
||||
For example if `delta_timestamps = {"action": [0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30]}`
|
||||
and we sample the last frame in the episode, we would end up padding with 6 frames worth of
|
||||
copies. Instead, we might want no padding (in which case we need n=6), or we might be okay
|
||||
with up to 2 frames of padding (in which case we need n=4).
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.version = version
|
||||
|
@ -44,6 +63,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
self.info = load_info(repo_id, version, root)
|
||||
if self.video:
|
||||
self.videos_dir = load_videos(repo_id, version, root)
|
||||
# If `n_end_keyframes_dropped == 0`, `self.index` contains exactly the indices of the hf_dataset. If
|
||||
# `n_end_keyframes_dropped > 0`, `self.index` contains a subset of the indices of the hf_dataset where
|
||||
# we drop those indices pertaining to the last n frames of each episode.
|
||||
self.index = []
|
||||
for from_ix, to_ix in zip(*self.episode_data_index.values(), strict=True):
|
||||
self.index.extend(list(range(from_ix, to_ix - n_end_keyframes_dropped)))
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
|
@ -78,7 +103,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
return len(self.hf_dataset)
|
||||
return len(self.index)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
|
@ -97,7 +122,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.hf_dataset[idx]
|
||||
item = self.hf_dataset[self.index[idx]]
|
||||
|
||||
if self.delta_timestamps is not None:
|
||||
item = load_previous_and_future_frames(
|
||||
|
|
|
@ -25,11 +25,21 @@ training:
|
|||
adam_weight_decay: 1.0e-6
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
# For each training batch we want (consider n_obs_steps=2, horizon=16):
|
||||
# t | -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
|
||||
# action | a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a
|
||||
# observation | o, o, , , , , , , , , , , , , ,
|
||||
# Note that at rollout we only use some of the actions (consider n_action_steps=8):
|
||||
# action used | , a, a, a, a, a, a, a, a, , , , , , ,
|
||||
delta_timestamps:
|
||||
observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
|
||||
|
||||
# The original implementation doesn't sample keyframes for the last 7 steps. This is because, as described
|
||||
# above, the last 7 actions from the diffusion model are not used.
|
||||
n_end_keyframes_dropped: ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
|
|
@ -100,6 +100,7 @@ def test_compute_stats_on_xarm():
|
|||
|
||||
# reduce size of dataset sample on which stats compute is tested to 10 frames
|
||||
dataset.hf_dataset = dataset.hf_dataset.select(range(10))
|
||||
dataset.index = [i for i in dataset.index if i < 10]
|
||||
|
||||
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
|
||||
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
|
||||
|
|
Loading…
Reference in New Issue