squash commit

This commit is contained in:
Alexander Soare 2024-05-05 18:50:00 +01:00
parent 6eaffbef1d
commit b699a2f484
4 changed files with 41 additions and 7 deletions

View File

@ -1,15 +1,12 @@
import logging
import torch
from omegaconf import OmegaConf
from omegaconf import DictConfig, OmegaConf
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def make_dataset(
cfg,
split="train",
):
def make_dataset(cfg: DictConfig, split="train") -> LeRobotDataset:
if cfg.env.name not in cfg.dataset_repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
@ -28,6 +25,7 @@ def make_dataset(
cfg.dataset_repo_id,
split=split,
delta_timestamps=delta_timestamps,
n_end_keyframes_dropped=eval(cfg.training.get("n_end_keyframes_dropped", "0")),
)
if cfg.get("override_dataset_stats"):

View File

@ -27,7 +27,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
split: str = "train",
transform: callable = None,
delta_timestamps: dict[list[float]] | None = None,
n_end_keyframes_dropped: int = 0,
):
"""
Args:
delta_timestamps: A dictionary mapping lists of relative times (Δt) to data keys. When a frame is
sampled from the underlying dataset, we treat it as a "keyframe" and load multiple frames
according to the list of Δt's. For example {"action": [-0.05, 0, 0.05]} indicates
that we want to load the current keyframe's action, as well as one from 50 ms ago, and one
50 ms into the future. The action key then contains a (3, action_dim) tensor (whereas without
`delta_timestamps` there would just be a (action_dim,) tensor. When the Δt's demand that
frames outside of an episode boundary are retrieved, a copy padding strategy is used. See
`load_previous_and_future_frames` for more details.
n_end_keyframes_dropped: Don't sample the last n items in each episode. This option is handy when
used in combination with `delta_timestamps` when, for example, the Δt's demand multiple future
frames, but we want to avoid introducing too much copy padding into the data distribution.
For example if `delta_timestamps = {"action": [0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30]}`
and we sample the last frame in the episode, we would end up padding with 6 frames worth of
copies. Instead, we might want no padding (in which case we need n=6), or we might be okay
with up to 2 frames of padding (in which case we need n=4).
"""
super().__init__()
self.repo_id = repo_id
self.version = version
@ -44,6 +63,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.info = load_info(repo_id, version, root)
if self.video:
self.videos_dir = load_videos(repo_id, version, root)
# If `n_end_keyframes_dropped == 0`, `self.index` contains exactly the indices of the hf_dataset. If
# `n_end_keyframes_dropped > 0`, `self.index` contains a subset of the indices of the hf_dataset where
# we drop those indices pertaining to the last n frames of each episode.
self.index = []
for from_ix, to_ix in zip(*self.episode_data_index.values(), strict=True):
self.index.extend(list(range(from_ix, to_ix - n_end_keyframes_dropped)))
@property
def fps(self) -> int:
@ -78,7 +103,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
@property
def num_samples(self) -> int:
return len(self.hf_dataset)
return len(self.index)
@property
def num_episodes(self) -> int:
@ -97,7 +122,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
return self.num_samples
def __getitem__(self, idx):
item = self.hf_dataset[idx]
item = self.hf_dataset[self.index[idx]]
if self.delta_timestamps is not None:
item = load_previous_and_future_frames(

View File

@ -25,11 +25,21 @@ training:
adam_weight_decay: 1.0e-6
online_steps_between_rollouts: 1
# For each training batch we want (consider n_obs_steps=2, horizon=16):
# t | -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
# action | a, a, a, a, a, a, a, a, a, a, a, a, a, a, a, a
# observation | o, o, , , , , , , , , , , , , ,
# Note that at rollout we only use some of the actions (consider n_action_steps=8):
# action used | , a, a, a, a, a, a, a, a, , , , , , ,
delta_timestamps:
observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
# The original implementation doesn't sample keyframes for the last 7 steps. This is because, as described
# above, the last 7 actions from the diffusion model are not used.
n_end_keyframes_dropped: ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
eval:
n_episodes: 50
batch_size: 50

View File

@ -100,6 +100,7 @@ def test_compute_stats_on_xarm():
# reduce size of dataset sample on which stats compute is tested to 10 frames
dataset.hf_dataset = dataset.hf_dataset.select(range(10))
dataset.index = [i for i in dataset.index if i < 10]
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics. While doing this, we also make sure it works when we don't divide the