diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py
index 22dd1789..93aec158 100644
--- a/lerobot/common/datasets/factory.py
+++ b/lerobot/common/datasets/factory.py
@@ -1,15 +1,12 @@
 import logging
 
 import torch
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
 
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 
 
-def make_dataset(
-    cfg,
-    split="train",
-):
+def make_dataset(cfg: DictConfig, split="train") -> LeRobotDataset:
     if cfg.env.name not in cfg.dataset_repo_id:
         logging.warning(
             f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
@@ -28,6 +25,7 @@ def make_dataset(
         cfg.dataset_repo_id,
         split=split,
         delta_timestamps=delta_timestamps,
+        n_end_keyframes_dropped=eval(cfg.training.get("n_end_keyframes_dropped", "0")),
     )
 
     if cfg.get("override_dataset_stats"):
diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py
index c8cfbd8e..ffc1b1b6 100644
--- a/lerobot/common/datasets/lerobot_dataset.py
+++ b/lerobot/common/datasets/lerobot_dataset.py
@@ -27,7 +27,26 @@ class LeRobotDataset(torch.utils.data.Dataset):
         split: str = "train",
         transform: callable = None,
         delta_timestamps: dict[list[float]] | None = None,
+        n_end_keyframes_dropped: int = 0,
     ):
+        """
+        Args:
+            delta_timestamps: A dictionary mapping data keys to lists of relative times (Δt). When a frame is
+                sampled from the underlying dataset, we treat it as a "keyframe" and load multiple frames
+                according to the list of Δt's. For example, {"action": [-0.05, 0, 0.05]} indicates that we
+                want to load the current keyframe's action, as well as one from 50 ms ago and one 50 ms into
+                the future. The action key then contains a (3, action_dim) tensor (whereas without
+                `delta_timestamps` there would just be an (action_dim,) tensor). When the Δt's demand that
+                frames outside of an episode boundary are retrieved, a copy padding strategy is used. See
+                `load_previous_and_future_frames` for more details.
+            n_end_keyframes_dropped: Don't sample the last n items in each episode. This option is handy when
+                used in combination with `delta_timestamps`: for example, when the Δt's demand multiple
+                future frames but we want to avoid introducing too much copy padding into the data
+                distribution. For example, if `delta_timestamps = {"action": [0, 0.05, 0.10, 0.15, 0.20,
+                0.25, 0.30]}` and we sample the last frame in an episode, we would end up padding with 6
+                frames' worth of copies. Instead, we might want no padding (in which case we need n=6), or
+                we might be okay with up to 2 frames of padding (in which case we need n=4).
+        """
         super().__init__()
         self.repo_id = repo_id
         self.version = version
@@ -44,6 +63,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.info = load_info(repo_id, version, root)
         if self.video:
             self.videos_dir = load_videos(repo_id, version, root)
+        # If `n_end_keyframes_dropped == 0`, `self.index` contains exactly the indices of the hf_dataset. If
+        # `n_end_keyframes_dropped > 0`, `self.index` contains a subset of the indices of the hf_dataset
+        # where we drop those indices pertaining to the last n frames of each episode.
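+        # Each (from_ix, to_ix) pair from `episode_data_index` is one episode's half-open [from, to) index range.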
+        self.index = []
+        for from_ix, to_ix in zip(*self.episode_data_index.values(), strict=True):
+            self.index.extend(list(range(from_ix, to_ix - n_end_keyframes_dropped)))
 
     @property
     def fps(self) -> int:
@@ -78,7 +104,7 @@
 
     @property
     def num_samples(self) -> int:
-        return len(self.hf_dataset)
+        return len(self.index)
 
     @property
     def num_episodes(self) -> int:
@@ -97,7 +123,7 @@
         return self.num_samples
 
     def __getitem__(self, idx):
-        item = self.hf_dataset[idx]
+        item = self.hf_dataset[self.index[idx]]
 
         if self.delta_timestamps is not None:
             item = load_previous_and_future_frames(
diff --git a/lerobot/configs/policy/diffusion.yaml b/lerobot/configs/policy/diffusion.yaml
index 60061c38..e77a40e3 100644
--- a/lerobot/configs/policy/diffusion.yaml
+++ b/lerobot/configs/policy/diffusion.yaml
@@ -25,11 +25,21 @@ training:
   adam_weight_decay: 1.0e-6
   online_steps_between_rollouts: 1
 
+  # For each training batch we want (consider n_obs_steps=2, horizon=16):
+  # t           | -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14
+  # action      |  a, a, a, a, a, a, a, a, a, a, a,  a,  a,  a,  a,  a
+  # observation |  o, o,  ,  ,  ,  ,  ,  ,  ,  ,  ,   ,   ,   ,   ,
+  # Note that at rollout we only use some of the actions (consider n_action_steps=8):
+  # action used |   , a, a, a, a, a, a, a, a,  ,  ,   ,   ,   ,   ,
   delta_timestamps:
     observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
     observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
     action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
 
+  # The original implementation doesn't sample keyframes for the last 7 (= 16 - 8 - 2 + 1) steps. This is
+  # because, as described above, the last 7 actions from the diffusion model are not used.
+  n_end_keyframes_dropped: ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
+
 eval:
   n_episodes: 50
   batch_size: 50
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 22b271be..23996aba 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -100,6 +100,7 @@ def test_compute_stats_on_xarm():
 
     # reduce size of dataset sample on which stats compute is tested to 10 frames
     dataset.hf_dataset = dataset.hf_dataset.select(range(10))
+    dataset.index = [i for i in dataset.index if i < 10]
 
     # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
     # computation of the statistics. While doing this, we also make sure it works when we don't divide the