squash commit

This commit is contained in:
Alexander Soare 2024-05-05 18:50:00 +01:00
parent 6eaffbef1d
commit b699a2f484
4 changed files with 41 additions and 7 deletions

View File

@ -1,15 +1,12 @@
import logging import logging
import torch import torch
from omegaconf import OmegaConf from omegaconf import DictConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def make_dataset( def make_dataset(cfg: DictConfig, split="train") -> LeRobotDataset:
cfg,
split="train",
):
if cfg.env.name not in cfg.dataset_repo_id: if cfg.env.name not in cfg.dataset_repo_id:
logging.warning( logging.warning(
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your " f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
@ -28,6 +25,7 @@ def make_dataset(
cfg.dataset_repo_id, cfg.dataset_repo_id,
split=split, split=split,
delta_timestamps=delta_timestamps, delta_timestamps=delta_timestamps,
n_end_keyframes_dropped=eval(cfg.training.get("n_end_keyframes_dropped", "0")),
) )
if cfg.get("override_dataset_stats"): if cfg.get("override_dataset_stats"):

View File

@ -27,7 +27,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
split: str = "train", split: str = "train",
transform: callable = None, transform: callable = None,
delta_timestamps: dict[list[float]] | None = None, delta_timestamps: dict[list[float]] | None = None,
n_end_keyframes_dropped: int = 0,
): ):
"""
Args:
delta_timestamps: A dictionary mapping lists of relative times (Δt) to data keys. When a frame is
sampled from the underlying dataset, we treat it as a "keyframe" and load multiple frames
according to the list of Δt's. For example {"action": [-0.05, 0, 0.05]} indicates
that we want to load the current keyframe's action, as well as one from 50 ms ago, and one
50 ms into the future. The action key then contains a (3, action_dim) tensor (whereas without
`delta_timestamps` there would just be a (action_dim,) tensor. When the Δt's demand that
frames outside of an episode boundary are retrieved, a copy padding strategy is used. See
`load_previous_and_future_frames` for more details.
n_end_keyframes_dropped: Don't sample the last n items in each episode. This option is handy when
used in combination with `delta_timestamps` when, for example, the Δt's demand multiple future
frames, but we want to avoid introducing too much copy padding into the data distribution.
For example if `delta_timestamps = {"action": [0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30]}`
and we sample the last frame in the episode, we would end up padding with 6 frames worth of
copies. Instead, we might want no padding (in which case we need n=6), or we might be okay
with up to 2 frames of padding (in which case we need n=4).
"""
super().__init__() super().__init__()
self.repo_id = repo_id self.repo_id = repo_id
self.version = version self.version = version
@ -44,6 +63,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.info = load_info(repo_id, version, root) self.info = load_info(repo_id, version, root)
if self.video: if self.video:
self.videos_dir = load_videos(repo_id, version, root) self.videos_dir = load_videos(repo_id, version, root)
# If `n_end_keyframes_dropped == 0`, `self.index` contains exactly the indices of the hf_dataset. If
# `n_end_keyframes_dropped > 0`, `self.index` contains a subset of the indices of the hf_dataset where
# we drop those indices pertaining to the last n frames of each episode.
self.index = []
for from_ix, to_ix in zip(*self.episode_data_index.values(), strict=True):
self.index.extend(list(range(from_ix, to_ix - n_end_keyframes_dropped)))
@property @property
def fps(self) -> int: def fps(self) -> int:
@ -78,7 +103,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property @property
def num_samples(self) -> int: def num_samples(self) -> int:
return len(self.hf_dataset) return len(self.index)
@property @property
def num_episodes(self) -> int: def num_episodes(self) -> int:
@ -97,7 +122,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
return self.num_samples return self.num_samples
def __getitem__(self, idx): def __getitem__(self, idx):
item = self.hf_dataset[idx] item = self.hf_dataset[self.index[idx]]
if self.delta_timestamps is not None: if self.delta_timestamps is not None:
item = load_previous_and_future_frames( item = load_previous_and_future_frames(

View File

@ -25,11 +25,21 @@ training:
adam_weight_decay: 1.0e-6 adam_weight_decay: 1.0e-6
online_steps_between_rollouts: 1 online_steps_between_rollouts: 1
# For each training batch we want (consider n_obs_steps=2, horizon=16):
# t | -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
# action | a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a
# observation | o, o, , , , , , , , , , , , , ,
# Note that at rollout we only use some of the actions (consider n_action_steps=8):
# action used | , a, a, a, a, a, a, a, a, , , , , , ,
delta_timestamps: delta_timestamps:
observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]" observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]" observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]" action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
# The original implementation doesn't sample keyframes for the last 7 steps. This is because, as described
# above, the last 7 actions from the diffusion model are not used.
n_end_keyframes_dropped: ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
eval: eval:
n_episodes: 50 n_episodes: 50
batch_size: 50 batch_size: 50

View File

@ -100,6 +100,7 @@ def test_compute_stats_on_xarm():
# reduce size of dataset sample on which stats compute is tested to 10 frames # reduce size of dataset sample on which stats compute is tested to 10 frames
dataset.hf_dataset = dataset.hf_dataset.select(range(10)) dataset.hf_dataset = dataset.hf_dataset.select(range(10))
dataset.index = [i for i in dataset.index if i < 10]
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics. While doing this, we also make sure it works when we don't divide the # computation of the statistics. While doing this, we also make sure it works when we don't divide the