Merge remote-tracking branch 'origin/main' into 2024_05_30_add_data_augmentation
This commit is contained in:
commit
20a3715469
|
@ -121,7 +121,6 @@ celerybeat.pid
|
||||||
# Environments
|
# Environments
|
||||||
.env
|
.env
|
||||||
.venv
|
.venv
|
||||||
env/
|
|
||||||
venv/
|
venv/
|
||||||
ENV/
|
ENV/
|
||||||
env.bak/
|
env.bak/
|
||||||
|
|
|
@ -45,6 +45,9 @@ import itertools
|
||||||
|
|
||||||
from lerobot.__version__ import __version__ # noqa: F401
|
from lerobot.__version__ import __version__ # noqa: F401
|
||||||
|
|
||||||
|
# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
|
||||||
|
# refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
|
||||||
|
# a yaml file AND an environment name. The difference should be more obvious.
|
||||||
available_tasks_per_env = {
|
available_tasks_per_env = {
|
||||||
"aloha": [
|
"aloha": [
|
||||||
"AlohaInsertion-v0",
|
"AlohaInsertion-v0",
|
||||||
|
@ -52,6 +55,7 @@ available_tasks_per_env = {
|
||||||
],
|
],
|
||||||
"pusht": ["PushT-v0"],
|
"pusht": ["PushT-v0"],
|
||||||
"xarm": ["XarmLift-v0"],
|
"xarm": ["XarmLift-v0"],
|
||||||
|
"dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
|
||||||
}
|
}
|
||||||
available_envs = list(available_tasks_per_env.keys())
|
available_envs = list(available_tasks_per_env.keys())
|
||||||
|
|
||||||
|
@ -77,6 +81,23 @@ available_datasets_per_env = {
|
||||||
"lerobot/xarm_push_medium_image",
|
"lerobot/xarm_push_medium_image",
|
||||||
"lerobot/xarm_push_medium_replay_image",
|
"lerobot/xarm_push_medium_replay_image",
|
||||||
],
|
],
|
||||||
|
"dora_aloha_real": [
|
||||||
|
"lerobot/aloha_static_battery",
|
||||||
|
"lerobot/aloha_static_candy",
|
||||||
|
"lerobot/aloha_static_coffee",
|
||||||
|
"lerobot/aloha_static_coffee_new",
|
||||||
|
"lerobot/aloha_static_cups_open",
|
||||||
|
"lerobot/aloha_static_fork_pick_up",
|
||||||
|
"lerobot/aloha_static_pingpong_test",
|
||||||
|
"lerobot/aloha_static_pro_pencil",
|
||||||
|
"lerobot/aloha_static_screw_driver",
|
||||||
|
"lerobot/aloha_static_tape",
|
||||||
|
"lerobot/aloha_static_thread_velcro",
|
||||||
|
"lerobot/aloha_static_towel",
|
||||||
|
"lerobot/aloha_static_vinh_cup",
|
||||||
|
"lerobot/aloha_static_vinh_cup_left",
|
||||||
|
"lerobot/aloha_static_ziploc_slide",
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
available_real_world_datasets = [
|
available_real_world_datasets = [
|
||||||
|
@ -108,16 +129,19 @@ available_datasets = list(
|
||||||
itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
|
itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# lists all available policies from `lerobot/common/policies` by their class attribute: `name`.
|
||||||
available_policies = [
|
available_policies = [
|
||||||
"act",
|
"act",
|
||||||
"diffusion",
|
"diffusion",
|
||||||
"tdmpc",
|
"tdmpc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# keys and values refer to yaml files
|
||||||
available_policies_per_env = {
|
available_policies_per_env = {
|
||||||
"aloha": ["act"],
|
"aloha": ["act"],
|
||||||
"pusht": ["diffusion"],
|
"pusht": ["diffusion"],
|
||||||
"xarm": ["tdmpc"],
|
"xarm": ["tdmpc"],
|
||||||
|
"dora_aloha_real": ["act_real"],
|
||||||
}
|
}
|
||||||
|
|
||||||
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
|
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
|
||||||
|
|
|
@ -16,17 +16,15 @@
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from math import ceil
|
from math import ceil
|
||||||
|
|
||||||
import datasets
|
|
||||||
import einops
|
import einops
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from datasets import Image
|
from datasets import Image
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|
||||||
from lerobot.common.datasets.video_utils import VideoFrame
|
from lerobot.common.datasets.video_utils import VideoFrame
|
||||||
|
|
||||||
|
|
||||||
def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0):
|
def get_stats_einops_patterns(dataset, num_workers=0):
|
||||||
"""These einops patterns will be used to aggregate batches and compute statistics.
|
"""These einops patterns will be used to aggregate batches and compute statistics.
|
||||||
|
|
||||||
Note: We assume the images are in channel first format
|
Note: We assume the images are in channel first format
|
||||||
|
@ -66,9 +64,8 @@ def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_wo
|
||||||
return stats_patterns
|
return stats_patterns
|
||||||
|
|
||||||
|
|
||||||
def compute_stats(
|
def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None):
|
||||||
dataset: LeRobotDataset | datasets.Dataset, batch_size=32, num_workers=16, max_num_samples=None
|
"""Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
|
||||||
):
|
|
||||||
if max_num_samples is None:
|
if max_num_samples is None:
|
||||||
max_num_samples = len(dataset)
|
max_num_samples = len(dataset)
|
||||||
|
|
||||||
|
@ -159,3 +156,54 @@ def compute_stats(
|
||||||
"min": min[key],
|
"min": min[key],
|
||||||
}
|
}
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_stats(ls_datasets) -> dict[str, dict[str, torch.Tensor]]:
    """Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.

    The final stats will have the union of all data keys from each of the datasets. For a given data key,
    only the datasets whose `stats` contain that key contribute to the result:
    - new_max = max(max_dataset_0, max_dataset_1, ...)
    - new_min = min(min_dataset_0, min_dataset_1, ...)
    - new_mean = (mean of all data)
    - new_std = (std of all data)

    Args:
        ls_datasets: Iterable (traversed several times, so it must be re-iterable) of objects exposing
            `stats` (a mapping `data_key -> {"min", "max", "mean", "std"}` of tensors) and `num_samples`.

    Returns:
        A dictionary mapping each data key to a dictionary with the aggregated "min", "max", "mean" and
        "std" tensors.
    """
    data_keys = set()
    for dataset in ls_datasets:
        data_keys.update(dataset.stats.keys())
    stats = {k: {} for k in data_keys}
    for data_key in data_keys:
        # Only the datasets that actually carry this key take part in its aggregation.
        contributing = [d for d in ls_datasets if data_key in d.stats]

        # Element-wise `min(dataset_0["min"], dataset_1["min"], ...)`, and likewise for "max".
        stats[data_key]["min"] = torch.stack([d.stats[data_key]["min"] for d in contributing], dim=0).amin(
            dim=0
        )
        stats[data_key]["max"] = torch.stack([d.stats[data_key]["max"] for d in contributing], dim=0).amax(
            dim=0
        )

        total_samples = sum(d.num_samples for d in contributing)
        # Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
        # dataset, then divide by total_samples to get the overall "mean".
        # NOTE: the brackets around (d.num_samples / total_samples) are needed to minimize the risk of
        # numerical overflow!
        combined_mean = sum(d.stats[data_key]["mean"] * (d.num_samples / total_samples) for d in contributing)
        stats[data_key]["mean"] = combined_mean

        # The derivation for standard deviation is a little more involved but is much in the same spirit as
        # the computation of the mean.
        # Given two sets of data where the statistics are known:
        # σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
        # where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
        # NOTE: the brackets around (d.num_samples / total_samples) are needed to minimize the risk of
        # numerical overflow!
        stats[data_key]["std"] = torch.sqrt(
            sum(
                (d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - combined_mean) ** 2)
                * (d.num_samples / total_samples)
                for d in contributing
            )
        )
    return stats
|
|
@ -16,10 +16,10 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from omegaconf import ListConfig, OmegaConf
|
||||||
from torchvision.transforms import v2
|
from torchvision.transforms import v2
|
||||||
from omegaconf import OmegaConf
|
|
||||||
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||||
|
|
||||||
|
|
||||||
def resolve_delta_timestamps(cfg):
|
def resolve_delta_timestamps(cfg):
|
||||||
|
@ -36,31 +36,72 @@ def resolve_delta_timestamps(cfg):
|
||||||
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
|
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
|
||||||
|
|
||||||
|
|
||||||
def make_dataset(
|
def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
|
||||||
cfg,
|
"""
|
||||||
split="train",
|
Args:
|
||||||
):
|
cfg: A Hydra config as per the LeRobot config scheme.
|
||||||
if cfg.env.name not in cfg.dataset_repo_id:
|
split: Select the data subset used to create an instance of LeRobotDataset.
|
||||||
|
All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
|
||||||
|
Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
|
||||||
|
slicer in the hugging face datasets:
|
||||||
|
https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
|
||||||
|
As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
|
||||||
|
`split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`.
|
||||||
|
Returns:
|
||||||
|
The LeRobotDataset.
|
||||||
|
"""
|
||||||
|
if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
|
||||||
|
raise ValueError(
|
||||||
|
"Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
|
||||||
|
"strings to load multiple datasets."
|
||||||
|
)
|
||||||
|
|
||||||
|
# A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
|
||||||
|
if cfg.env.name != "dora":
|
||||||
|
if isinstance(cfg.dataset_repo_id, str):
|
||||||
|
dataset_repo_ids = [cfg.dataset_repo_id] # single dataset
|
||||||
|
else:
|
||||||
|
dataset_repo_ids = cfg.dataset_repo_id # multiple datasets
|
||||||
|
|
||||||
|
for dataset_repo_id in dataset_repo_ids:
|
||||||
|
if cfg.env.name not in dataset_repo_id:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
|
f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
|
||||||
f"environment ({cfg.env.name=})."
|
f"environment ({cfg.env.name=})."
|
||||||
)
|
)
|
||||||
|
|
||||||
resolve_delta_timestamps(cfg)
|
resolve_delta_timestamps(cfg)
|
||||||
|
|
||||||
if cfg.image_transform.enable:
|
if cfg.image_transform.enable:
|
||||||
transform = v2.Compose([v2.ColorJitter(brightness=cfg.image_transform.colorjitter_factor, contrast=cfg.image_transform.colorjitter_factor),
|
transform = v2.Compose(
|
||||||
v2.RandomAdjustSharpness(cfg.image_transform.sharpness_factor, p=cfg.image_transform.sharpness_p), v2.RandomAdjustSharpness(cfg.image_transform.blur_factor, p=cfg.image_transform.blur_p),
|
[
|
||||||
|
v2.ColorJitter(
|
||||||
|
brightness=cfg.image_transform.colorjitter_factor,
|
||||||
|
contrast=cfg.image_transform.colorjitter_factor,
|
||||||
|
),
|
||||||
|
v2.RandomAdjustSharpness(
|
||||||
|
cfg.image_transform.sharpness_factor, p=cfg.image_transform.sharpness_p
|
||||||
|
),
|
||||||
|
v2.RandomAdjustSharpness(cfg.image_transform.blur_factor, p=cfg.image_transform.blur_p),
|
||||||
v2.ToDtype(torch.float32, scale=True),
|
v2.ToDtype(torch.float32, scale=True),
|
||||||
])
|
]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
transform = None
|
transform = None
|
||||||
|
|
||||||
|
if isinstance(cfg.dataset_repo_id, str):
|
||||||
dataset = LeRobotDataset(
|
dataset = LeRobotDataset(
|
||||||
cfg.dataset_repo_id,
|
cfg.dataset_repo_id,
|
||||||
split=split,
|
split=split,
|
||||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||||
transform=transform
|
transform=transform,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
dataset = MultiLeRobotDataset(
|
||||||
|
cfg.dataset_repo_id,
|
||||||
|
split=split,
|
||||||
|
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||||
|
transform=transform,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.get("override_dataset_stats"):
|
if cfg.get("override_dataset_stats"):
|
||||||
|
|
|
@ -13,12 +13,16 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import torch
|
import torch
|
||||||
|
import torch.utils
|
||||||
|
|
||||||
|
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
calculate_episode_data_index,
|
calculate_episode_data_index,
|
||||||
load_episode_data_index,
|
load_episode_data_index,
|
||||||
|
@ -42,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
version: str | None = CODEBASE_VERSION,
|
version: str | None = CODEBASE_VERSION,
|
||||||
root: Path | None = DATA_DIR,
|
root: Path | None = DATA_DIR,
|
||||||
split: str = "train",
|
split: str = "train",
|
||||||
transform: callable = None,
|
transform: Callable | None = None,
|
||||||
delta_timestamps: dict[list[float]] | None = None,
|
delta_timestamps: dict[list[float]] | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -172,7 +176,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_preloaded(
|
def from_preloaded(
|
||||||
cls,
|
cls,
|
||||||
repo_id: str,
|
repo_id: str = "from_preloaded",
|
||||||
version: str | None = CODEBASE_VERSION,
|
version: str | None = CODEBASE_VERSION,
|
||||||
root: Path | None = None,
|
root: Path | None = None,
|
||||||
split: str = "train",
|
split: str = "train",
|
||||||
|
@ -184,7 +188,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
stats=None,
|
stats=None,
|
||||||
info=None,
|
info=None,
|
||||||
videos_dir=None,
|
videos_dir=None,
|
||||||
):
|
) -> "LeRobotDataset":
|
||||||
|
"""Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.
|
||||||
|
|
||||||
|
It is especially useful when converting raw data into LeRobotDataset before saving the dataset
|
||||||
|
on the filesystem or uploading to the hub.
|
||||||
|
|
||||||
|
Note: Meta-data attributes like `repo_id`, `version`, `root`, etc are optional and potentially
|
||||||
|
meaningless depending on the downstream usage of the return dataset.
|
||||||
|
"""
|
||||||
# create an empty object of type LeRobotDataset
|
# create an empty object of type LeRobotDataset
|
||||||
obj = cls.__new__(cls)
|
obj = cls.__new__(cls)
|
||||||
obj.repo_id = repo_id
|
obj.repo_id = repo_id
|
||||||
|
@ -196,6 +208,193 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||||
obj.hf_dataset = hf_dataset
|
obj.hf_dataset = hf_dataset
|
||||||
obj.episode_data_index = episode_data_index
|
obj.episode_data_index = episode_data_index
|
||||||
obj.stats = stats
|
obj.stats = stats
|
||||||
obj.info = info
|
obj.info = info if info is not None else {}
|
||||||
obj.videos_dir = videos_dir
|
obj.videos_dir = videos_dir
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
class MultiLeRobotDataset(torch.utils.data.Dataset):
    """A dataset consisting of multiple underlying `LeRobotDataset`s.

    The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
    structure of `LeRobotDataset`.

    Items returned by `__getitem__` only contain the data keys that are common to all of the underlying
    datasets, plus a "dataset_index" key identifying the dataset the item was taken from.
    """

    def __init__(
        self,
        repo_ids: list[str],
        version: str | None = CODEBASE_VERSION,
        root: Path | None = DATA_DIR,
        split: str = "train",
        transform: Callable | None = None,
        delta_timestamps: dict[list[float]] | None = None,
    ):
        """Build one `LeRobotDataset` per repo id and check that they can be combined.

        Args:
            repo_ids: Repository ids of the datasets to combine. Their order defines the dataset index
                reported under "dataset_index" and by `repo_id_to_index`.
            version: Forwarded to every underlying `LeRobotDataset`.
            root: Forwarded to every underlying `LeRobotDataset`.
            split: Forwarded to every underlying `LeRobotDataset`.
            transform: Forwarded to every underlying `LeRobotDataset`.
            delta_timestamps: Forwarded to every underlying `LeRobotDataset`.

        Raises:
            ValueError: If the `info` of any dataset differs from the `info` of the first one.
            RuntimeError: If the datasets have no data key in common.
        """
        super().__init__()
        self.repo_ids = repo_ids
        # Construct the underlying datasets. Every argument, including `transform` and
        # `delta_timestamps`, is forwarded as-is, so each sub-dataset applies them itself.
        self._datasets = [
            LeRobotDataset(
                repo_id,
                version=version,
                root=root,
                split=split,
                delta_timestamps=delta_timestamps,
                transform=transform,
            )
            for repo_id in repo_ids
        ]
        # Check that some properties are consistent across datasets. Note: We may relax some of these
        # consistency requirements in future iterations of this class.
        for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
            if dataset.info != self._datasets[0].info:
                raise ValueError(
                    f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
                    "not yet supported."
                )
        # Disable any data keys that are not common across all of the datasets. Note: we may relax this
        # restriction in future iterations of this class. For now, this is necessary at least for being able
        # to use PyTorch's default DataLoader collate function.
        self.disabled_data_keys = set()
        intersection_data_keys = set(self._datasets[0].hf_dataset.features)
        for dataset in self._datasets:
            intersection_data_keys.intersection_update(dataset.hf_dataset.features)
        if not intersection_data_keys:
            raise RuntimeError(
                "Multiple datasets were provided but they had no keys common to all of them. The "
                "multi-dataset functionality currently only keeps common keys."
            )
        for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
            extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys)
            # Only warn when something is actually dropped for this dataset.
            if extra_keys:
                logging.warning(
                    f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
                    "other datasets."
                )
            self.disabled_data_keys.update(extra_keys)

        self.version = version
        self.root = root
        self.split = split
        self.transform = transform
        self.delta_timestamps = delta_timestamps
        self.stats = aggregate_stats(self._datasets)

    @property
    def repo_id_to_index(self):
        """Return a mapping from dataset repo_id to a dataset index automatically created by this class.

        This index is incorporated as a data key in the dictionary returned by `__getitem__`.
        """
        return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}

    @property
    def repo_index_to_id(self):
        """Return the inverse mapping of repo_id_to_index."""
        # `.items()` is required: iterating the dict directly would yield only the repo_id keys.
        return {v: k for k, v in self.repo_id_to_index.items()}

    @property
    def fps(self) -> int:
        """Frames per second used during data collection.

        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
        """
        return self._datasets[0].info["fps"]

    @property
    def video(self) -> bool:
        """Returns True if this dataset loads video frames from mp4 files.

        Returns False if it only loads images from png files.

        NOTE: For now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
        """
        return self._datasets[0].info.get("video", False)

    @property
    def features(self) -> datasets.Features:
        """Features of all sub-datasets merged together, without the disabled data keys."""
        features = {}
        for dataset in self._datasets:
            features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
        return features

    @property
    def camera_keys(self) -> list[str]:
        """Keys to access image and video stream from cameras."""
        keys = []
        for key, feats in self.features.items():
            if isinstance(feats, (datasets.Image, VideoFrame)):
                keys.append(key)
        return keys

    @property
    def video_frame_keys(self) -> list[str]:
        """Keys to access video frames that requires to be decoded into images.

        Note: It is empty if the dataset contains images only,
        or equal to `self.cameras` if the dataset contains videos only,
        or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
        """
        video_frame_keys = []
        for key, feats in self.features.items():
            if isinstance(feats, VideoFrame):
                video_frame_keys.append(key)
        return video_frame_keys

    @property
    def num_samples(self) -> int:
        """Number of samples/frames."""
        return sum(d.num_samples for d in self._datasets)

    @property
    def num_episodes(self) -> int:
        """Number of episodes."""
        return sum(d.num_episodes for d in self._datasets)

    @property
    def tolerance_s(self) -> float:
        """Tolerance in seconds used to discard loaded frames when their timestamps
        are not close enough from the requested frames. It is only used when `delta_timestamps`
        is provided or when loading video frames from mp4 files.
        """
        # 1e-4 to account for possible numerical error
        return 1 / self.fps - 1e-4

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return the item at global index `idx` of the concatenated datasets.

        The returned dictionary comes from the sub-dataset owning `idx`, gets an extra "dataset_index"
        entry, and has every disabled data key removed.

        Raises:
            IndexError: If `idx` is greater than or equal to `len(self)`.
        """
        # NOTE: negative indices are not normalized here; they are passed on to the first sub-dataset.
        if idx >= len(self):
            raise IndexError(f"Index {idx} out of bounds.")
        # Determine which dataset to get an item from based on the index.
        start_idx = 0
        dataset_idx = 0
        for dataset in self._datasets:
            if idx >= start_idx + dataset.num_samples:
                start_idx += dataset.num_samples
                dataset_idx += 1
                continue
            break
        else:
            raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
        item = self._datasets[dataset_idx][idx - start_idx]
        item["dataset_index"] = torch.tensor(dataset_idx)
        for data_key in self.disabled_data_keys:
            if data_key in item:
                del item[data_key]
        return item

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(\n"
            f"  Repository IDs: '{self.repo_ids}',\n"
            f"  Version: '{self.version}',\n"
            f"  Split: '{self.split}',\n"
            f"  Number of Samples: {self.num_samples},\n"
            f"  Number of Episodes: {self.num_episodes},\n"
            f"  Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
            f"  Recorded Frames per Second: {self.fps},\n"
            f"  Camera Keys: {self.camera_keys},\n"
            f"  Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
            f"  Transformations: {self.transform},\n"
            f")"
        )
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Iterator, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodeAwareSampler:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
episode_data_index: dict,
|
||||||
|
episode_indices_to_use: Union[list, None] = None,
|
||||||
|
drop_n_first_frames: int = 0,
|
||||||
|
drop_n_last_frames: int = 0,
|
||||||
|
shuffle: bool = False,
|
||||||
|
):
|
||||||
|
"""Sampler that optionally incorporates episode boundary information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
|
||||||
|
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
|
||||||
|
Assumes that episodes are indexed from 0 to N-1.
|
||||||
|
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||||
|
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||||
|
shuffle: Whether to shuffle the indices.
|
||||||
|
"""
|
||||||
|
indices = []
|
||||||
|
for episode_idx, (start_index, end_index) in enumerate(
|
||||||
|
zip(episode_data_index["from"], episode_data_index["to"], strict=True)
|
||||||
|
):
|
||||||
|
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||||
|
indices.extend(
|
||||||
|
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.indices = indices
|
||||||
|
self.shuffle = shuffle
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[int]:
|
||||||
|
if self.shuffle:
|
||||||
|
for i in torch.randperm(len(self.indices)):
|
||||||
|
yield self.indices[i]
|
||||||
|
else:
|
||||||
|
for i in self.indices:
|
||||||
|
yield i
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.indices)
|
|
@ -59,7 +59,7 @@ def unflatten_dict(d, sep="/"):
|
||||||
return outdict
|
return outdict
|
||||||
|
|
||||||
|
|
||||||
def hf_transform_to_torch(items_dict):
|
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
||||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
||||||
a channel last representation (h w c) of uint8 type, to a torch image representation
|
a channel last representation (h w c) of uint8 type, to a torch image representation
|
||||||
|
@ -73,6 +73,8 @@ def hf_transform_to_torch(items_dict):
|
||||||
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
|
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
|
||||||
# video frame will be processed downstream
|
# video frame will be processed downstream
|
||||||
pass
|
pass
|
||||||
|
elif first_item is None:
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
|
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
|
||||||
return items_dict
|
return items_dict
|
||||||
|
@ -318,8 +320,7 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
|
||||||
|
|
||||||
|
|
||||||
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
|
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
|
||||||
"""
|
"""Reset the `episode_index` of the provided HuggingFace Dataset.
|
||||||
Reset the `episode_index` of the provided HuggingFace Dataset.
|
|
||||||
|
|
||||||
`episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
|
`episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
|
||||||
`episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
|
`episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
|
||||||
|
@ -338,6 +339,7 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
|
||||||
return example
|
return example
|
||||||
|
|
||||||
hf_dataset = hf_dataset.map(modify_ep_idx_func)
|
hf_dataset = hf_dataset.map(modify_ep_idx_func)
|
||||||
|
|
||||||
return hf_dataset
|
return hf_dataset
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,13 @@ class ACTConfig:
|
||||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||||
Those are: `input_shapes` and 'output_shapes`.
|
Those are: `input_shapes` and 'output_shapes`.
|
||||||
|
|
||||||
|
Notes on the inputs and outputs:
|
||||||
|
- At least one key starting with "observation.image" is required as an input.
|
||||||
|
- If there are multiple keys beginning with "observation.images." they are treated as multiple camera
|
||||||
|
views. Right now we only support all images having the same shape.
|
||||||
|
- May optionally work without an "observation.state" key for the proprioceptive robot state.
|
||||||
|
- "action" is required as an output key.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||||
current step and additional steps going back).
|
current step and additional steps going back).
|
||||||
|
@ -33,15 +40,15 @@ class ACTConfig:
|
||||||
This should be no greater than the chunk size. For example, if the chunk size is 100, you may
|
This should be no greater than the chunk size. For example, if the chunk size is 100, you may
|
||||||
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
|
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
|
||||||
environment, and throws the other 50 out.
|
environment, and throws the other 50 out.
|
||||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||||
The key represents the input data name, and the value is a list indicating the dimensions
|
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||||
of the corresponding data. For example, "observation.images.top" refers to an input from the
|
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||||
"top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||||
Importantly, shapes doesn't include batch dimension or temporal dimension.
|
include batch dimension or temporal dimension.
|
||||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||||
The key represents the output data name, and the value is a list indicating the dimensions
|
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||||
14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
|
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||||
|
|
|
@ -198,27 +198,31 @@ class ACT(nn.Module):
|
||||||
def __init__(self, config: ACTConfig):
|
def __init__(self, config: ACTConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
||||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||||
|
self.use_input_state = "observation.state" in config.input_shapes
|
||||||
if self.config.use_vae:
|
if self.config.use_vae:
|
||||||
self.vae_encoder = ACTEncoder(config)
|
self.vae_encoder = ACTEncoder(config)
|
||||||
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
|
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
|
||||||
# Projection layer for joint-space configuration to hidden dimension.
|
# Projection layer for joint-space configuration to hidden dimension.
|
||||||
|
if self.use_input_state:
|
||||||
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
||||||
config.input_shapes["observation.state"][0], config.dim_model
|
config.input_shapes["observation.state"][0], config.dim_model
|
||||||
)
|
)
|
||||||
# Projection layer for action (joint-space target) to hidden dimension.
|
# Projection layer for action (joint-space target) to hidden dimension.
|
||||||
self.vae_encoder_action_input_proj = nn.Linear(
|
self.vae_encoder_action_input_proj = nn.Linear(
|
||||||
config.input_shapes["observation.state"][0], config.dim_model
|
config.output_shapes["action"][0], config.dim_model
|
||||||
)
|
)
|
||||||
self.latent_dim = config.latent_dim
|
|
||||||
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
||||||
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
|
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
|
||||||
# Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
|
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
|
||||||
# dimension.
|
# dimension.
|
||||||
|
num_input_token_encoder = 1 + config.chunk_size
|
||||||
|
if self.use_input_state:
|
||||||
|
num_input_token_encoder += 1
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"vae_encoder_pos_enc",
|
"vae_encoder_pos_enc",
|
||||||
create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0),
|
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Backbone for image feature extraction.
|
# Backbone for image feature extraction.
|
||||||
|
@ -238,15 +242,17 @@ class ACT(nn.Module):
|
||||||
|
|
||||||
# Transformer encoder input projections. The tokens will be structured like
|
# Transformer encoder input projections. The tokens will be structured like
|
||||||
# [latent, robot_state, image_feature_map_pixels].
|
# [latent, robot_state, image_feature_map_pixels].
|
||||||
|
if self.use_input_state:
|
||||||
self.encoder_robot_state_input_proj = nn.Linear(
|
self.encoder_robot_state_input_proj = nn.Linear(
|
||||||
config.input_shapes["observation.state"][0], config.dim_model
|
config.input_shapes["observation.state"][0], config.dim_model
|
||||||
)
|
)
|
||||||
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
|
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
|
||||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||||
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||||
)
|
)
|
||||||
# Transformer encoder positional embeddings.
|
# Transformer encoder positional embeddings.
|
||||||
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model)
|
num_input_token_decoder = 2 if self.use_input_state else 1
|
||||||
|
self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
|
||||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||||
|
|
||||||
# Transformer decoder.
|
# Transformer decoder.
|
||||||
|
@ -285,7 +291,7 @@ class ACT(nn.Module):
|
||||||
"action" in batch
|
"action" in batch
|
||||||
), "actions must be provided when using the variational objective in training mode."
|
), "actions must be provided when using the variational objective in training mode."
|
||||||
|
|
||||||
batch_size = batch["observation.state"].shape[0]
|
batch_size = batch["observation.images"].shape[0]
|
||||||
|
|
||||||
# Prepare the latent for input to the transformer encoder.
|
# Prepare the latent for input to the transformer encoder.
|
||||||
if self.config.use_vae and "action" in batch:
|
if self.config.use_vae and "action" in batch:
|
||||||
|
@ -293,11 +299,16 @@ class ACT(nn.Module):
|
||||||
cls_embed = einops.repeat(
|
cls_embed = einops.repeat(
|
||||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||||
) # (B, 1, D)
|
) # (B, 1, D)
|
||||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
|
if self.use_input_state:
|
||||||
1
|
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
||||||
) # (B, 1, D)
|
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
||||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||||
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
|
|
||||||
|
if self.use_input_state:
|
||||||
|
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||||
|
else:
|
||||||
|
vae_encoder_input = [cls_embed, action_embed]
|
||||||
|
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
|
||||||
|
|
||||||
# Prepare fixed positional embedding.
|
# Prepare fixed positional embedding.
|
||||||
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
|
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
|
||||||
|
@ -308,16 +319,17 @@ class ACT(nn.Module):
|
||||||
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
|
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
|
||||||
)[0] # select the class token, with shape (B, D)
|
)[0] # select the class token, with shape (B, D)
|
||||||
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
|
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
|
||||||
mu = latent_pdf_params[:, : self.latent_dim]
|
mu = latent_pdf_params[:, : self.config.latent_dim]
|
||||||
# This is 2log(sigma). Done this way to match the original implementation.
|
# This is 2log(sigma). Done this way to match the original implementation.
|
||||||
log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
|
log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]
|
||||||
|
|
||||||
# Sample the latent with the reparameterization trick.
|
# Sample the latent with the reparameterization trick.
|
||||||
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
|
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
|
||||||
else:
|
else:
|
||||||
# When not using the VAE encoder, we set the latent to be all zeros.
|
# When not using the VAE encoder, we set the latent to be all zeros.
|
||||||
mu = log_sigma_x2 = None
|
mu = log_sigma_x2 = None
|
||||||
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
|
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
||||||
|
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
|
||||||
batch["observation.state"].device
|
batch["observation.state"].device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -326,8 +338,10 @@ class ACT(nn.Module):
|
||||||
all_cam_features = []
|
all_cam_features = []
|
||||||
all_cam_pos_embeds = []
|
all_cam_pos_embeds = []
|
||||||
images = batch["observation.images"]
|
images = batch["observation.images"]
|
||||||
|
|
||||||
for cam_index in range(images.shape[-4]):
|
for cam_index in range(images.shape[-4]):
|
||||||
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
||||||
|
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
||||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||||
all_cam_features.append(cam_features)
|
all_cam_features.append(cam_features)
|
||||||
|
@ -337,13 +351,15 @@ class ACT(nn.Module):
|
||||||
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
|
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
|
||||||
|
|
||||||
# Get positional embeddings for robot state and latent.
|
# Get positional embeddings for robot state and latent.
|
||||||
|
if self.use_input_state:
|
||||||
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
|
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
|
||||||
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
|
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
|
||||||
|
|
||||||
# Stack encoder input and positional embeddings moving to (S, B, C).
|
# Stack encoder input and positional embeddings moving to (S, B, C).
|
||||||
|
encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
|
||||||
encoder_in = torch.cat(
|
encoder_in = torch.cat(
|
||||||
[
|
[
|
||||||
torch.stack([latent_embed, robot_state_embed], axis=0),
|
torch.stack(encoder_in_feats, axis=0),
|
||||||
einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
|
einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -357,6 +373,7 @@ class ACT(nn.Module):
|
||||||
|
|
||||||
# Forward pass through the transformer modules.
|
# Forward pass through the transformer modules.
|
||||||
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
|
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
|
||||||
|
# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
|
||||||
decoder_in = torch.zeros(
|
decoder_in = torch.zeros(
|
||||||
(self.config.chunk_size, batch_size, self.config.dim_model),
|
(self.config.chunk_size, batch_size, self.config.dim_model),
|
||||||
dtype=pos_embed.dtype,
|
dtype=pos_embed.dtype,
|
||||||
|
|
|
@ -26,21 +26,26 @@ class DiffusionConfig:
|
||||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||||
Those are: `input_shapes` and `output_shapes`.
|
Those are: `input_shapes` and `output_shapes`.
|
||||||
|
|
||||||
|
Notes on the inputs and outputs:
|
||||||
|
- "observation.state" is required as an input key.
|
||||||
|
- A key starting with "observation.image" is required as an input.
|
||||||
|
- "action" is required as an output key.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||||
current step and additional steps going back).
|
current step and additional steps going back).
|
||||||
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||||
See `DiffusionPolicy.select_action` for more details.
|
See `DiffusionPolicy.select_action` for more details.
|
||||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||||
The key represents the input data name, and the value is a list indicating the dimensions
|
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||||
of the corresponding data. For example, "observation.image" refers to an input from
|
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||||
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||||
Importantly, shapes doesnt include batch dimension or temporal dimension.
|
include batch dimension or temporal dimension.
|
||||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||||
The key represents the output data name, and the value is a list indicating the dimensions
|
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||||
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||||
|
|
|
@ -31,6 +31,15 @@ class TDMPCConfig:
|
||||||
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
|
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
|
||||||
action repeats in Q-learning or ask your favorite chatbot)
|
action repeats in Q-learning or ask your favorite chatbot)
|
||||||
horizon: Horizon for model predictive control.
|
horizon: Horizon for model predictive control.
|
||||||
|
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||||
|
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||||
|
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||||
|
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||||
|
include batch dimension or temporal dimension.
|
||||||
|
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||||
|
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||||
|
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||||
|
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||||
|
|
|
@ -120,13 +120,13 @@ def init_logging():
|
||||||
logging.getLogger().addHandler(console_handler)
|
logging.getLogger().addHandler(console_handler)
|
||||||
|
|
||||||
|
|
||||||
def format_big_number(num):
|
def format_big_number(num, precision=0):
|
||||||
suffixes = ["", "K", "M", "B", "T", "Q"]
|
suffixes = ["", "K", "M", "B", "T", "Q"]
|
||||||
divisor = 1000.0
|
divisor = 1000.0
|
||||||
|
|
||||||
for suffix in suffixes:
|
for suffix in suffixes:
|
||||||
if abs(num) < divisor:
|
if abs(num) < divisor:
|
||||||
return f"{num:.0f}{suffix}"
|
return f"{num:.{precision}f}{suffix}"
|
||||||
num /= divisor
|
num /= divisor
|
||||||
|
|
||||||
return num
|
return num
|
||||||
|
|
|
@ -23,6 +23,10 @@ use_amp: false
|
||||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||||
# AND for the evaluation environments.
|
# AND for the evaluation environments.
|
||||||
seed: ???
|
seed: ???
|
||||||
|
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
|
||||||
|
# keys common between the datasets are kept. Each dataset gets an additional transform that inserts the
|
||||||
|
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
|
||||||
|
# datasets are provided.
|
||||||
dataset_repo_id: lerobot/pusht
|
dataset_repo_id: lerobot/pusht
|
||||||
|
|
||||||
training:
|
training:
|
||||||
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
# @package _global_
|
||||||
|
|
||||||
|
fps: 30
|
||||||
|
|
||||||
|
env:
|
||||||
|
name: dora
|
||||||
|
task: DoraAloha-v0
|
||||||
|
state_dim: 14
|
||||||
|
action_dim: 14
|
||||||
|
fps: ${fps}
|
||||||
|
episode_length: 400
|
||||||
|
gym:
|
||||||
|
fps: ${fps}
|
|
@ -0,0 +1,115 @@
|
||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
|
||||||
|
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, cam_high,
|
||||||
|
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
|
||||||
|
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
|
||||||
|
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
|
||||||
|
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
|
||||||
|
#
|
||||||
|
# Example of usage for training:
|
||||||
|
# ```bash
|
||||||
|
# python lerobot/scripts/train.py \
|
||||||
|
# policy=act_real \
|
||||||
|
# env=dora_aloha_real
|
||||||
|
# ```
|
||||||
|
|
||||||
|
seed: 1000
|
||||||
|
dataset_repo_id: lerobot/aloha_static_vinh_cup
|
||||||
|
|
||||||
|
override_dataset_stats:
|
||||||
|
observation.images.cam_right_wrist:
|
||||||
|
# stats from imagenet, since we use a pretrained vision model
|
||||||
|
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||||
|
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||||
|
observation.images.cam_left_wrist:
|
||||||
|
# stats from imagenet, since we use a pretrained vision model
|
||||||
|
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||||
|
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||||
|
observation.images.cam_high:
|
||||||
|
# stats from imagenet, since we use a pretrained vision model
|
||||||
|
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||||
|
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||||
|
observation.images.cam_low:
|
||||||
|
# stats from imagenet, since we use a pretrained vision model
|
||||||
|
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||||
|
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||||
|
|
||||||
|
training:
|
||||||
|
offline_steps: 80000
|
||||||
|
online_steps: 0
|
||||||
|
eval_freq: -1
|
||||||
|
save_freq: 10000
|
||||||
|
log_freq: 100
|
||||||
|
save_checkpoint: true
|
||||||
|
|
||||||
|
batch_size: 8
|
||||||
|
lr: 1e-5
|
||||||
|
lr_backbone: 1e-5
|
||||||
|
weight_decay: 1e-4
|
||||||
|
grad_clip_norm: 10
|
||||||
|
online_steps_between_rollouts: 1
|
||||||
|
|
||||||
|
delta_timestamps:
|
||||||
|
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||||
|
|
||||||
|
eval:
|
||||||
|
n_episodes: 50
|
||||||
|
batch_size: 50
|
||||||
|
|
||||||
|
# See `configuration_act.py` for more details.
|
||||||
|
policy:
|
||||||
|
name: act
|
||||||
|
|
||||||
|
# Input / output structure.
|
||||||
|
n_obs_steps: 1
|
||||||
|
chunk_size: 100 # chunk_size
|
||||||
|
n_action_steps: 100
|
||||||
|
|
||||||
|
input_shapes:
|
||||||
|
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||||
|
observation.images.cam_right_wrist: [3, 480, 640]
|
||||||
|
observation.images.cam_left_wrist: [3, 480, 640]
|
||||||
|
observation.images.cam_high: [3, 480, 640]
|
||||||
|
observation.images.cam_low: [3, 480, 640]
|
||||||
|
observation.state: ["${env.state_dim}"]
|
||||||
|
output_shapes:
|
||||||
|
action: ["${env.action_dim}"]
|
||||||
|
|
||||||
|
# Normalization / Unnormalization
|
||||||
|
input_normalization_modes:
|
||||||
|
observation.images.cam_right_wrist: mean_std
|
||||||
|
observation.images.cam_left_wrist: mean_std
|
||||||
|
observation.images.cam_high: mean_std
|
||||||
|
observation.images.cam_low: mean_std
|
||||||
|
observation.state: mean_std
|
||||||
|
output_normalization_modes:
|
||||||
|
action: mean_std
|
||||||
|
|
||||||
|
# Architecture.
|
||||||
|
# Vision backbone.
|
||||||
|
vision_backbone: resnet18
|
||||||
|
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||||
|
replace_final_stride_with_dilation: false
|
||||||
|
# Transformer layers.
|
||||||
|
pre_norm: false
|
||||||
|
dim_model: 512
|
||||||
|
n_heads: 8
|
||||||
|
dim_feedforward: 3200
|
||||||
|
feedforward_activation: relu
|
||||||
|
n_encoder_layers: 4
|
||||||
|
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||||
|
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||||
|
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||||
|
n_decoder_layers: 1
|
||||||
|
# VAE.
|
||||||
|
use_vae: true
|
||||||
|
latent_dim: 32
|
||||||
|
n_vae_encoder_layers: 4
|
||||||
|
|
||||||
|
# Inference.
|
||||||
|
temporal_ensemble_momentum: null
|
||||||
|
|
||||||
|
# Training and loss computation.
|
||||||
|
dropout: 0.1
|
||||||
|
kl_weight: 10.0
|
|
@ -0,0 +1,111 @@
|
||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# Use `act_real_no_state.yaml` to train on real-world Aloha/Aloha2 datasets when cameras are moving (e.g. wrist cameras)
|
||||||
|
# Compared to `act_real.yaml`, it is camera only and does not use the state as input, which is a vector of robot joint positions.
|
||||||
|
# We validated experimentally that not using state reaches a better success rate. Our hypothesis is that `act_real.yaml` might
|
||||||
|
# overfit to the state, because the images are more complex to learn from since they are moving.
|
||||||
|
#
|
||||||
|
# Example of usage for training:
|
||||||
|
# ```bash
|
||||||
|
# python lerobot/scripts/train.py \
|
||||||
|
# policy=act_real_no_state \
|
||||||
|
# env=dora_aloha_real
|
||||||
|
# ```
|
||||||
|
|
||||||
|
seed: 1000
|
||||||
|
dataset_repo_id: lerobot/aloha_static_vinh_cup
|
||||||
|
|
||||||
|
override_dataset_stats:
|
||||||
|
observation.images.cam_right_wrist:
|
||||||
|
# stats from imagenet, since we use a pretrained vision model
|
||||||
|
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||||
|
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||||
|
observation.images.cam_left_wrist:
|
||||||
|
# stats from imagenet, since we use a pretrained vision model
|
||||||
|
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||||
|
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||||
|
observation.images.cam_high:
|
||||||
|
# stats from imagenet, since we use a pretrained vision model
|
||||||
|
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||||
|
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||||
|
observation.images.cam_low:
|
||||||
|
# stats from imagenet, since we use a pretrained vision model
|
||||||
|
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||||
|
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||||
|
|
||||||
|
training:
|
||||||
|
offline_steps: 80000
|
||||||
|
online_steps: 0
|
||||||
|
eval_freq: -1
|
||||||
|
save_freq: 10000
|
||||||
|
log_freq: 100
|
||||||
|
save_checkpoint: true
|
||||||
|
|
||||||
|
batch_size: 8
|
||||||
|
lr: 1e-5
|
||||||
|
lr_backbone: 1e-5
|
||||||
|
weight_decay: 1e-4
|
||||||
|
grad_clip_norm: 10
|
||||||
|
online_steps_between_rollouts: 1
|
||||||
|
|
||||||
|
delta_timestamps:
|
||||||
|
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||||
|
|
||||||
|
eval:
|
||||||
|
n_episodes: 50
|
||||||
|
batch_size: 50
|
||||||
|
|
||||||
|
# See `configuration_act.py` for more details.
|
||||||
|
policy:
|
||||||
|
name: act
|
||||||
|
|
||||||
|
# Input / output structure.
|
||||||
|
n_obs_steps: 1
|
||||||
|
chunk_size: 100 # chunk_size
|
||||||
|
n_action_steps: 100
|
||||||
|
|
||||||
|
input_shapes:
|
||||||
|
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||||
|
observation.images.cam_right_wrist: [3, 480, 640]
|
||||||
|
observation.images.cam_left_wrist: [3, 480, 640]
|
||||||
|
observation.images.cam_high: [3, 480, 640]
|
||||||
|
observation.images.cam_low: [3, 480, 640]
|
||||||
|
output_shapes:
|
||||||
|
action: ["${env.action_dim}"]
|
||||||
|
|
||||||
|
# Normalization / Unnormalization
|
||||||
|
input_normalization_modes:
|
||||||
|
observation.images.cam_right_wrist: mean_std
|
||||||
|
observation.images.cam_left_wrist: mean_std
|
||||||
|
observation.images.cam_high: mean_std
|
||||||
|
observation.images.cam_low: mean_std
|
||||||
|
output_normalization_modes:
|
||||||
|
action: mean_std
|
||||||
|
|
||||||
|
# Architecture.
|
||||||
|
# Vision backbone.
|
||||||
|
vision_backbone: resnet18
|
||||||
|
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||||
|
replace_final_stride_with_dilation: false
|
||||||
|
# Transformer layers.
|
||||||
|
pre_norm: false
|
||||||
|
dim_model: 512
|
||||||
|
n_heads: 8
|
||||||
|
dim_feedforward: 3200
|
||||||
|
feedforward_activation: relu
|
||||||
|
n_encoder_layers: 4
|
||||||
|
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||||
|
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||||
|
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||||
|
n_decoder_layers: 1
|
||||||
|
# VAE.
|
||||||
|
use_vae: true
|
||||||
|
latent_dim: 32
|
||||||
|
n_vae_encoder_layers: 4
|
||||||
|
|
||||||
|
# Inference.
|
||||||
|
temporal_ensemble_momentum: null
|
||||||
|
|
||||||
|
# Training and loss computation.
|
||||||
|
dropout: 0.1
|
||||||
|
kl_weight: 10.0
|
|
@ -44,6 +44,10 @@ training:
|
||||||
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||||
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
|
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
|
||||||
|
|
||||||
|
# The original implementation doesn't sample frames for the last 7 steps,
|
||||||
|
# which avoids excessive padding and leads to improved training results.
|
||||||
|
drop_n_last_frames: 7 # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
|
||||||
|
|
||||||
eval:
|
eval:
|
||||||
n_episodes: 50
|
n_episodes: 50
|
||||||
batch_size: 50
|
batch_size: 50
|
||||||
|
|
|
@ -71,9 +71,9 @@ import torch
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
|
from lerobot.common.datasets.compute_stats import compute_stats
|
||||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
|
|
||||||
from lerobot.common.datasets.utils import flatten_dict
|
from lerobot.common.datasets.utils import flatten_dict
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from copy import deepcopy
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
|
||||||
|
@ -28,6 +27,8 @@ from termcolor import colored
|
||||||
from torch.cuda.amp import GradScaler
|
from torch.cuda.amp import GradScaler
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
|
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
|
||||||
|
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||||
from lerobot.common.datasets.utils import cycle
|
from lerobot.common.datasets.utils import cycle
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.logger import Logger, log_output_dir
|
from lerobot.common.logger import Logger, log_output_dir
|
||||||
|
@ -280,6 +281,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
|
|
||||||
logging.info("make_dataset")
|
logging.info("make_dataset")
|
||||||
offline_dataset = make_dataset(cfg)
|
offline_dataset = make_dataset(cfg)
|
||||||
|
if isinstance(offline_dataset, MultiLeRobotDataset):
|
||||||
|
logging.info(
|
||||||
|
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||||
|
f"{pformat(offline_dataset.repo_id_to_index , indent=2)}"
|
||||||
|
)
|
||||||
|
|
||||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||||
|
@ -330,7 +336,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
max_episodes_rendered=4,
|
max_episodes_rendered=4,
|
||||||
start_seed=cfg.seed,
|
start_seed=cfg.seed,
|
||||||
)
|
)
|
||||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
|
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True)
|
||||||
if cfg.wandb.enable:
|
if cfg.wandb.enable:
|
||||||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||||
logging.info("Resume training")
|
logging.info("Resume training")
|
||||||
|
@ -351,18 +357,28 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
logging.info("Resume training")
|
logging.info("Resume training")
|
||||||
|
|
||||||
# create dataloader for offline training
|
# create dataloader for offline training
|
||||||
|
if cfg.training.get("drop_n_last_frames"):
|
||||||
|
shuffle = False
|
||||||
|
sampler = EpisodeAwareSampler(
|
||||||
|
offline_dataset.episode_data_index,
|
||||||
|
drop_n_last_frames=cfg.training.drop_n_last_frames,
|
||||||
|
shuffle=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
shuffle = True
|
||||||
|
sampler = None
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
offline_dataset,
|
offline_dataset,
|
||||||
num_workers=cfg.training.num_workers,
|
num_workers=cfg.training.num_workers,
|
||||||
batch_size=cfg.training.batch_size,
|
batch_size=cfg.training.batch_size,
|
||||||
shuffle=True,
|
shuffle=shuffle,
|
||||||
|
sampler=sampler,
|
||||||
pin_memory=device.type != "cpu",
|
pin_memory=device.type != "cpu",
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
)
|
)
|
||||||
dl_iter = cycle(dataloader)
|
dl_iter = cycle(dataloader)
|
||||||
|
|
||||||
policy.train()
|
policy.train()
|
||||||
is_offline = True
|
|
||||||
for _ in range(step, cfg.training.offline_steps):
|
for _ in range(step, cfg.training.offline_steps):
|
||||||
if step == 0:
|
if step == 0:
|
||||||
logging.info("Start offline training on a fixed dataset")
|
logging.info("Start offline training on a fixed dataset")
|
||||||
|
@ -382,7 +398,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
)
|
)
|
||||||
|
|
||||||
if step % cfg.training.log_freq == 0:
|
if step % cfg.training.log_freq == 0:
|
||||||
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
|
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
|
||||||
|
|
||||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
||||||
# so we pass in step + 1.
|
# so we pass in step + 1.
|
||||||
|
@ -390,41 +406,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||||
|
|
||||||
step += 1
|
step += 1
|
||||||
|
|
||||||
logging.info("End of offline training")
|
|
||||||
|
|
||||||
if cfg.training.online_steps == 0:
|
|
||||||
if cfg.training.eval_freq > 0:
|
|
||||||
eval_env.close()
|
eval_env.close()
|
||||||
return
|
logging.info("End of training")
|
||||||
|
|
||||||
# create an env dedicated to online episodes collection from policy rollout
|
|
||||||
online_training_env = make_env(cfg, n_envs=1)
|
|
||||||
|
|
||||||
# create an empty online dataset similar to offline dataset
|
|
||||||
online_dataset = deepcopy(offline_dataset)
|
|
||||||
online_dataset.hf_dataset = {}
|
|
||||||
online_dataset.episode_data_index = {}
|
|
||||||
|
|
||||||
# create dataloader for online training
|
|
||||||
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
|
|
||||||
weights = [1.0] * len(concat_dataset)
|
|
||||||
sampler = torch.utils.data.WeightedRandomSampler(
|
|
||||||
weights, num_samples=len(concat_dataset), replacement=True
|
|
||||||
)
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
|
||||||
concat_dataset,
|
|
||||||
num_workers=4,
|
|
||||||
batch_size=cfg.training.batch_size,
|
|
||||||
sampler=sampler,
|
|
||||||
pin_memory=device.type != "cpu",
|
|
||||||
drop_last=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.info("End of online training")
|
|
||||||
|
|
||||||
if cfg.training.eval_freq > 0:
|
|
||||||
eval_env.close()
|
|
||||||
online_training_env.close()
|
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "absl-py"
|
name = "absl-py"
|
||||||
|
@ -444,63 +444,63 @@ files = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "coverage"
|
name = "coverage"
|
||||||
version = "7.5.1"
|
version = "7.5.3"
|
||||||
description = "Code coverage measurement for Python"
|
description = "Code coverage measurement for Python"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
files = [
|
files = [
|
||||||
{file = "coverage-7.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0884920835a033b78d1c73b6d3bbcda8161a900f38a488829a83982925f6c2e"},
|
{file = "coverage-7.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a6519d917abb15e12380406d721e37613e2a67d166f9fb7e5a8ce0375744cd45"},
|
||||||
{file = "coverage-7.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:39afcd3d4339329c5f58de48a52f6e4e50f6578dd6099961cf22228feb25f38f"},
|
{file = "coverage-7.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:aea7da970f1feccf48be7335f8b2ca64baf9b589d79e05b9397a06696ce1a1ec"},
|
||||||
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7b0ceee8147444347da6a66be737c9d78f3353b0681715b668b72e79203e4a"},
|
{file = "coverage-7.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:923b7b1c717bd0f0f92d862d1ff51d9b2b55dbbd133e05680204465f454bb286"},
|
||||||
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a9ca3f2fae0088c3c71d743d85404cec8df9be818a005ea065495bedc33da35"},
|
{file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62bda40da1e68898186f274f832ef3e759ce929da9a9fd9fcf265956de269dbc"},
|
||||||
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd215c0c7d7aab005221608a3c2b46f58c0285a819565887ee0b718c052aa4e"},
|
{file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8b7339180d00de83e930358223c617cc343dd08e1aa5ec7b06c3a121aec4e1d"},
|
||||||
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4bf0655ab60d754491004a5efd7f9cccefcc1081a74c9ef2da4735d6ee4a6223"},
|
{file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:25a5caf742c6195e08002d3b6c2dd6947e50efc5fc2c2205f61ecb47592d2d83"},
|
||||||
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:61c4bf1ba021817de12b813338c9be9f0ad5b1e781b9b340a6d29fc13e7c1b5e"},
|
{file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:05ac5f60faa0c704c0f7e6a5cbfd6f02101ed05e0aee4d2822637a9e672c998d"},
|
||||||
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:db66fc317a046556a96b453a58eced5024af4582a8dbdc0c23ca4dbc0d5b3146"},
|
{file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:239a4e75e09c2b12ea478d28815acf83334d32e722e7433471fbf641c606344c"},
|
||||||
{file = "coverage-7.5.1-cp310-cp310-win32.whl", hash = "sha256:b016ea6b959d3b9556cb401c55a37547135a587db0115635a443b2ce8f1c7228"},
|
{file = "coverage-7.5.3-cp310-cp310-win32.whl", hash = "sha256:a5812840d1d00eafae6585aba38021f90a705a25b8216ec7f66aebe5b619fb84"},
|
||||||
{file = "coverage-7.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:df4e745a81c110e7446b1cc8131bf986157770fa405fe90e15e850aaf7619bc8"},
|
{file = "coverage-7.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:33ca90a0eb29225f195e30684ba4a6db05dbef03c2ccd50b9077714c48153cac"},
|
||||||
{file = "coverage-7.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:796a79f63eca8814ca3317a1ea443645c9ff0d18b188de470ed7ccd45ae79428"},
|
{file = "coverage-7.5.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f81bc26d609bf0fbc622c7122ba6307993c83c795d2d6f6f6fd8c000a770d974"},
|
||||||
{file = "coverage-7.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fc84a37bfd98db31beae3c2748811a3fa72bf2007ff7902f68746d9757f3746"},
|
{file = "coverage-7.5.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7cec2af81f9e7569280822be68bd57e51b86d42e59ea30d10ebdbb22d2cb7232"},
|
||||||
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6175d1a0559986c6ee3f7fccfc4a90ecd12ba0a383dcc2da30c2b9918d67d8a3"},
|
{file = "coverage-7.5.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55f689f846661e3f26efa535071775d0483388a1ccfab899df72924805e9e7cd"},
|
||||||
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fc81d5878cd6274ce971e0a3a18a8803c3fe25457165314271cf78e3aae3aa2"},
|
{file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50084d3516aa263791198913a17354bd1dc627d3c1639209640b9cac3fef5807"},
|
||||||
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:556cf1a7cbc8028cb60e1ff0be806be2eded2daf8129b8811c63e2b9a6c43bca"},
|
{file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:341dd8f61c26337c37988345ca5c8ccabeff33093a26953a1ac72e7d0103c4fb"},
|
||||||
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9981706d300c18d8b220995ad22627647be11a4276721c10911e0e9fa44c83e8"},
|
{file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ab0b028165eea880af12f66086694768f2c3139b2c31ad5e032c8edbafca6ffc"},
|
||||||
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d7fed867ee50edf1a0b4a11e8e5d0895150e572af1cd6d315d557758bfa9c057"},
|
{file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5bc5a8c87714b0c67cfeb4c7caa82b2d71e8864d1a46aa990b5588fa953673b8"},
|
||||||
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ef48e2707fb320c8f139424a596f5b69955a85b178f15af261bab871873bb987"},
|
{file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:38a3b98dae8a7c9057bd91fbf3415c05e700a5114c5f1b5b0ea5f8f429ba6614"},
|
||||||
{file = "coverage-7.5.1-cp311-cp311-win32.whl", hash = "sha256:9314d5678dcc665330df5b69c1e726a0e49b27df0461c08ca12674bcc19ef136"},
|
{file = "coverage-7.5.3-cp311-cp311-win32.whl", hash = "sha256:fcf7d1d6f5da887ca04302db8e0e0cf56ce9a5e05f202720e49b3e8157ddb9a9"},
|
||||||
{file = "coverage-7.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fa567e99765fe98f4e7d7394ce623e794d7cabb170f2ca2ac5a4174437e90dd"},
|
{file = "coverage-7.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:8c836309931839cca658a78a888dab9676b5c988d0dd34ca247f5f3e679f4e7a"},
|
||||||
{file = "coverage-7.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b6cf3764c030e5338e7f61f95bd21147963cf6aa16e09d2f74f1fa52013c1206"},
|
{file = "coverage-7.5.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:296a7d9bbc598e8744c00f7a6cecf1da9b30ae9ad51c566291ff1314e6cbbed8"},
|
||||||
{file = "coverage-7.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ec92012fefebee89a6b9c79bc39051a6cb3891d562b9270ab10ecfdadbc0c34"},
|
{file = "coverage-7.5.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:34d6d21d8795a97b14d503dcaf74226ae51eb1f2bd41015d3ef332a24d0a17b3"},
|
||||||
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16db7f26000a07efcf6aea00316f6ac57e7d9a96501e990a36f40c965ec7a95d"},
|
{file = "coverage-7.5.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e317953bb4c074c06c798a11dbdd2cf9979dbcaa8ccc0fa4701d80042d4ebf1"},
|
||||||
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:beccf7b8a10b09c4ae543582c1319c6df47d78fd732f854ac68d518ee1fb97fa"},
|
{file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:705f3d7c2b098c40f5b81790a5fedb274113373d4d1a69e65f8b68b0cc26f6db"},
|
||||||
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8748731ad392d736cc9ccac03c9845b13bb07d020a33423fa5b3a36521ac6e4e"},
|
{file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1196e13c45e327d6cd0b6e471530a1882f1017eb83c6229fc613cd1a11b53cd"},
|
||||||
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7352b9161b33fd0b643ccd1f21f3a3908daaddf414f1c6cb9d3a2fd618bf2572"},
|
{file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:015eddc5ccd5364dcb902eaecf9515636806fa1e0d5bef5769d06d0f31b54523"},
|
||||||
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7a588d39e0925f6a2bff87154752481273cdb1736270642aeb3635cb9b4cad07"},
|
{file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fd27d8b49e574e50caa65196d908f80e4dff64d7e592d0c59788b45aad7e8b35"},
|
||||||
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:68f962d9b72ce69ea8621f57551b2fa9c70509af757ee3b8105d4f51b92b41a7"},
|
{file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:33fc65740267222fc02975c061eb7167185fef4cc8f2770267ee8bf7d6a42f84"},
|
||||||
{file = "coverage-7.5.1-cp312-cp312-win32.whl", hash = "sha256:f152cbf5b88aaeb836127d920dd0f5e7edff5a66f10c079157306c4343d86c19"},
|
{file = "coverage-7.5.3-cp312-cp312-win32.whl", hash = "sha256:7b2a19e13dfb5c8e145c7a6ea959485ee8e2204699903c88c7d25283584bfc08"},
|
||||||
{file = "coverage-7.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:5a5740d1fb60ddf268a3811bcd353de34eb56dc24e8f52a7f05ee513b2d4f596"},
|
{file = "coverage-7.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:0bbddc54bbacfc09b3edaec644d4ac90c08ee8ed4844b0f86227dcda2d428fcb"},
|
||||||
{file = "coverage-7.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e2213def81a50519d7cc56ed643c9e93e0247f5bbe0d1247d15fa520814a7cd7"},
|
{file = "coverage-7.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f78300789a708ac1f17e134593f577407d52d0417305435b134805c4fb135adb"},
|
||||||
{file = "coverage-7.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5037f8fcc2a95b1f0e80585bd9d1ec31068a9bcb157d9750a172836e98bc7a90"},
|
{file = "coverage-7.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b368e1aee1b9b75757942d44d7598dcd22a9dbb126affcbba82d15917f0cc155"},
|
||||||
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3721c2c9e4c4953a41a26c14f4cef64330392a6d2d675c8b1db3b645e31f0e"},
|
{file = "coverage-7.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f836c174c3a7f639bded48ec913f348c4761cbf49de4a20a956d3431a7c9cb24"},
|
||||||
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca498687ca46a62ae590253fba634a1fe9836bc56f626852fb2720f334c9e4e5"},
|
{file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:244f509f126dc71369393ce5fea17c0592c40ee44e607b6d855e9c4ac57aac98"},
|
||||||
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cdcbc320b14c3e5877ee79e649677cb7d89ef588852e9583e6b24c2e5072661"},
|
{file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4c2872b3c91f9baa836147ca33650dc5c172e9273c808c3c3199c75490e709d"},
|
||||||
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:57e0204b5b745594e5bc14b9b50006da722827f0b8c776949f1135677e88d0b8"},
|
{file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:dd4b3355b01273a56b20c219e74e7549e14370b31a4ffe42706a8cda91f19f6d"},
|
||||||
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fe7502616b67b234482c3ce276ff26f39ffe88adca2acf0261df4b8454668b4"},
|
{file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f542287b1489c7a860d43a7d8883e27ca62ab84ca53c965d11dac1d3a1fab7ce"},
|
||||||
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9e78295f4144f9dacfed4f92935fbe1780021247c2fabf73a819b17f0ccfff8d"},
|
{file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:75e3f4e86804023e991096b29e147e635f5e2568f77883a1e6eed74512659ab0"},
|
||||||
{file = "coverage-7.5.1-cp38-cp38-win32.whl", hash = "sha256:1434e088b41594baa71188a17533083eabf5609e8e72f16ce8c186001e6b8c41"},
|
{file = "coverage-7.5.3-cp38-cp38-win32.whl", hash = "sha256:c59d2ad092dc0551d9f79d9d44d005c945ba95832a6798f98f9216ede3d5f485"},
|
||||||
{file = "coverage-7.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:0646599e9b139988b63704d704af8e8df7fa4cbc4a1f33df69d97f36cb0a38de"},
|
{file = "coverage-7.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:fa21a04112c59ad54f69d80e376f7f9d0f5f9123ab87ecd18fbb9ec3a2beed56"},
|
||||||
{file = "coverage-7.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4cc37def103a2725bc672f84bd939a6fe4522310503207aae4d56351644682f1"},
|
{file = "coverage-7.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5102a92855d518b0996eb197772f5ac2a527c0ec617124ad5242a3af5e25f85"},
|
||||||
{file = "coverage-7.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fc0b4d8bfeabd25ea75e94632f5b6e047eef8adaed0c2161ada1e922e7f7cece"},
|
{file = "coverage-7.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d1da0a2e3b37b745a2b2a678a4c796462cf753aebf94edcc87dcc6b8641eae31"},
|
||||||
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d0a0f5e06881ecedfe6f3dd2f56dcb057b6dbeb3327fd32d4b12854df36bf26"},
|
{file = "coverage-7.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8383a6c8cefba1b7cecc0149415046b6fc38836295bc4c84e820872eb5478b3d"},
|
||||||
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9735317685ba6ec7e3754798c8871c2f49aa5e687cc794a0b1d284b2389d1bd5"},
|
{file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aad68c3f2566dfae84bf46295a79e79d904e1c21ccfc66de88cd446f8686341"},
|
||||||
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d21918e9ef11edf36764b93101e2ae8cc82aa5efdc7c5a4e9c6c35a48496d601"},
|
{file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e079c9ec772fedbade9d7ebc36202a1d9ef7291bc9b3a024ca395c4d52853d7"},
|
||||||
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c3e757949f268364b96ca894b4c342b41dc6f8f8b66c37878aacef5930db61be"},
|
{file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bde997cac85fcac227b27d4fb2c7608a2c5f6558469b0eb704c5726ae49e1c52"},
|
||||||
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:79afb6197e2f7f60c4824dd4b2d4c2ec5801ceb6ba9ce5d2c3080e5660d51a4f"},
|
{file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:990fb20b32990b2ce2c5f974c3e738c9358b2735bc05075d50a6f36721b8f303"},
|
||||||
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1d0d98d95dd18fe29dc66808e1accf59f037d5716f86a501fc0256455219668"},
|
{file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3d5a67f0da401e105753d474369ab034c7bae51a4c31c77d94030d59e41df5bd"},
|
||||||
{file = "coverage-7.5.1-cp39-cp39-win32.whl", hash = "sha256:1cc0fe9b0b3a8364093c53b0b4c0c2dd4bb23acbec4c9240b5f284095ccf7981"},
|
{file = "coverage-7.5.3-cp39-cp39-win32.whl", hash = "sha256:e08c470c2eb01977d221fd87495b44867a56d4d594f43739a8028f8646a51e0d"},
|
||||||
{file = "coverage-7.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:dde0070c40ea8bb3641e811c1cfbf18e265d024deff6de52c5950677a8fb1e0f"},
|
{file = "coverage-7.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:1d2a830ade66d3563bb61d1e3c77c8def97b30ed91e166c67d0632c018f380f0"},
|
||||||
{file = "coverage-7.5.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:6537e7c10cc47c595828b8a8be04c72144725c383c4702703ff4e42e44577312"},
|
{file = "coverage-7.5.3-pp38.pp39.pp310-none-any.whl", hash = "sha256:3538d8fb1ee9bdd2e2692b3b18c22bb1c19ffbefd06880f5ac496e42d7bb3884"},
|
||||||
{file = "coverage-7.5.1.tar.gz", hash = "sha256:54de9ef3a9da981f7af93eafde4ede199e0846cd819eb27c88e2b712aae9708c"},
|
{file = "coverage-7.5.3.tar.gz", hash = "sha256:04aefca5190d1dc7a53a4c1a5a7f8568811306d7a8ee231c42fb69215571944f"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -785,6 +785,26 @@ files = [
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
six = ">=1.4.0"
|
six = ">=1.4.0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "dora-rs"
|
||||||
|
version = "0.3.4"
|
||||||
|
description = "`dora` goal is to be a low latency, composable, and distributed data flow."
|
||||||
|
optional = true
|
||||||
|
python-versions = "*"
|
||||||
|
files = [
|
||||||
|
{file = "dora_rs-0.3.4-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:d1b738eea5a4966d731c26c6b6a0a50a491a24f7e9e335475f983cfc6f0da19e"},
|
||||||
|
{file = "dora_rs-0.3.4-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:80b724871618c78a4e5863938fa66724176cc40352771087aebe1e62a8141157"},
|
||||||
|
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a3919e157b47dc1dbc74c040a73087a4485f0d1bee99b6adcdbc36559400fe2"},
|
||||||
|
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f7c95f6e5858fd651d6cd220e4f052e99db2944b9c37fb0b5402d60ac4b41a63"},
|
||||||
|
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37d915fbbca282446235c98a9ca08389aa3ef3155d4e88c6c136326e9a830042"},
|
||||||
|
{file = "dora_rs-0.3.4-cp37-abi3-win32.whl", hash = "sha256:c9f7f22f65c884ec9bee0245ce98d0c7fad25dec0f982e566f844b5e8e58818f"},
|
||||||
|
{file = "dora_rs-0.3.4-cp37-abi3-win_amd64.whl", hash = "sha256:0a6a37f96a9f6e13b58b02a6ea75af192af5fbe4f456f6a67b1f239c3cee3276"},
|
||||||
|
{file = "dora_rs-0.3.4.tar.gz", hash = "sha256:05c5d0db0d23d7c4669995ae34db11cd636dbf91f5705d832669bd04e7452903"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
pyarrow = "*"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "einops"
|
name = "einops"
|
||||||
version = "0.8.0"
|
version = "0.8.0"
|
||||||
|
@ -1066,6 +1086,27 @@ mujoco = ">=2.3.7,<3.0.0"
|
||||||
dev = ["debugpy (>=1.8.1)", "pre-commit (>=3.7.0)"]
|
dev = ["debugpy (>=1.8.1)", "pre-commit (>=3.7.0)"]
|
||||||
test = ["pytest (>=8.1.0)", "pytest-cov (>=5.0.0)"]
|
test = ["pytest (>=8.1.0)", "pytest-cov (>=5.0.0)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "gym-dora"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = ""
|
||||||
|
optional = true
|
||||||
|
python-versions = "^3.10"
|
||||||
|
files = []
|
||||||
|
develop = false
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
dora-rs = ">=0.3.4"
|
||||||
|
gymnasium = ">=0.29.1"
|
||||||
|
pyarrow = ">=12.0.0"
|
||||||
|
|
||||||
|
[package.source]
|
||||||
|
type = "git"
|
||||||
|
url = "https://github.com/dora-rs/dora-lerobot.git"
|
||||||
|
reference = "HEAD"
|
||||||
|
resolved_reference = "ed0c00a4fdc6ec856c9842551acd7dc7ee776f79"
|
||||||
|
subdirectory = "gym_dora"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "gym-pusht"
|
name = "gym-pusht"
|
||||||
version = "0.1.4"
|
version = "0.1.4"
|
||||||
|
@ -1269,13 +1310,13 @@ files = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "huggingface-hub"
|
name = "huggingface-hub"
|
||||||
version = "0.23.1"
|
version = "0.23.2"
|
||||||
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
|
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8.0"
|
python-versions = ">=3.8.0"
|
||||||
files = [
|
files = [
|
||||||
{file = "huggingface_hub-0.23.1-py3-none-any.whl", hash = "sha256:720a5bffd2b1b449deb793da8b0df7a9390a7e238534d5a08c9fbcdecb1dd3cb"},
|
{file = "huggingface_hub-0.23.2-py3-none-any.whl", hash = "sha256:48727a16e704d409c4bb5913613308499664f22a99743435dc3a13b23c485827"},
|
||||||
{file = "huggingface_hub-0.23.1.tar.gz", hash = "sha256:4f62dbf6ae94f400c6d3419485e52bce510591432a5248a65d0cb72e4d479eb4"},
|
{file = "huggingface_hub-0.23.2.tar.gz", hash = "sha256:f6829b62d5fdecb452a76fdbec620cba4c1573655a8d710c1df71735fd9edbd2"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -2061,18 +2102,15 @@ test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "nodeenv"
|
name = "nodeenv"
|
||||||
version = "1.8.0"
|
version = "1.9.0"
|
||||||
description = "Node.js virtual environment builder"
|
description = "Node.js virtual environment builder"
|
||||||
optional = true
|
optional = true
|
||||||
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
|
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||||
files = [
|
files = [
|
||||||
{file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"},
|
{file = "nodeenv-1.9.0-py2.py3-none-any.whl", hash = "sha256:508ecec98f9f3330b636d4448c0f1a56fc68017c68f1e7857ebc52acf0eb879a"},
|
||||||
{file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"},
|
{file = "nodeenv-1.9.0.tar.gz", hash = "sha256:07f144e90dae547bf0d4ee8da0ee42664a42a04e02ed68e06324348dafe4bdb1"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
setuptools = "*"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "numba"
|
name = "numba"
|
||||||
version = "0.59.1"
|
version = "0.59.1"
|
||||||
|
@ -2406,6 +2444,7 @@ optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
files = [
|
files = [
|
||||||
{file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
|
{file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
|
||||||
|
{file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"},
|
||||||
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"},
|
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"},
|
||||||
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"},
|
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"},
|
||||||
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"},
|
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"},
|
||||||
|
@ -2426,6 +2465,7 @@ files = [
|
||||||
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"},
|
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"},
|
||||||
{file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"},
|
{file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"},
|
||||||
{file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"},
|
{file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"},
|
||||||
|
{file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"},
|
||||||
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"},
|
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"},
|
||||||
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"},
|
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"},
|
||||||
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},
|
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},
|
||||||
|
@ -3188,13 +3228,13 @@ files = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "requests"
|
name = "requests"
|
||||||
version = "2.32.2"
|
version = "2.32.3"
|
||||||
description = "Python HTTP for Humans."
|
description = "Python HTTP for Humans."
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
files = [
|
files = [
|
||||||
{file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"},
|
{file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
|
||||||
{file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"},
|
{file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -3210,16 +3250,16 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rerun-sdk"
|
name = "rerun-sdk"
|
||||||
version = "0.16.0"
|
version = "0.16.1"
|
||||||
description = "The Rerun Logging SDK"
|
description = "The Rerun Logging SDK"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = "<3.13,>=3.8"
|
python-versions = "<3.13,>=3.8"
|
||||||
files = [
|
files = [
|
||||||
{file = "rerun_sdk-0.16.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:1cc6dc66d089e296f945dc238301889efb61dd6d338b5d00f76981cf7aed0a74"},
|
{file = "rerun_sdk-0.16.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:170c6976634008611753e10dfef8cdc395ce8180e634c169e7c61cef2f89a277"},
|
||||||
{file = "rerun_sdk-0.16.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:faf231897655e46eb975695df2b0ace07db362d697e697f9a3dff52f81c0dc5d"},
|
{file = "rerun_sdk-0.16.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c9a76eab7eb5559276737dad655200e9350df0837158dbc5a896970ab4201454"},
|
||||||
{file = "rerun_sdk-0.16.0-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:860a6394380d3e9b9e48bf34423bd56dda54d5b0158d2ae0e433698659b34198"},
|
{file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:4d6436752d57e8b8038489a0e7e37f0c760b088e96db5fb81667d3a376d63fea"},
|
||||||
{file = "rerun_sdk-0.16.0-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:5b8d1476f73a3ad1a5d3f21b61c633f3ab62aa80fa0b049f5ad10bf1227681ab"},
|
{file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:37b7b47948471873e84f224b16f417a94a91c7cbd6c72c68281eeff1ba414b8f"},
|
||||||
{file = "rerun_sdk-0.16.0-cp38-abi3-win_amd64.whl", hash = "sha256:aff0051a263b8c3067243c0126d319845baf4fe640899f04aeef7daf151f35e4"},
|
{file = "rerun_sdk-0.16.1-cp38-abi3-win_amd64.whl", hash = "sha256:be88799c8afdf68eafa99e64e2e4f0a484e187e017a180219abbe6bb988acd4e"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -3696,17 +3736,17 @@ files = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "sympy"
|
name = "sympy"
|
||||||
version = "1.12"
|
version = "1.12.1"
|
||||||
description = "Computer algebra system (CAS) in Python"
|
description = "Computer algebra system (CAS) in Python"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
files = [
|
files = [
|
||||||
{file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"},
|
{file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"},
|
||||||
{file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"},
|
{file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
mpmath = ">=0.19"
|
mpmath = ">=1.1.0,<1.4.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tbb"
|
name = "tbb"
|
||||||
|
@ -4220,13 +4260,13 @@ multidict = ">=4.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zarr"
|
name = "zarr"
|
||||||
version = "2.18.1"
|
version = "2.18.2"
|
||||||
description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
|
description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.9"
|
python-versions = ">=3.9"
|
||||||
files = [
|
files = [
|
||||||
{file = "zarr-2.18.1-py3-none-any.whl", hash = "sha256:a1770d194eec4ec0a41a01295a6f724e1c3471d704d3aca906d3b3a7f8830245"},
|
{file = "zarr-2.18.2-py3-none-any.whl", hash = "sha256:a638754902f97efa99b406083fdc807a0e2ccf12a949117389d2a4ba9b05df38"},
|
||||||
{file = "zarr-2.18.1.tar.gz", hash = "sha256:28c360ed123e606c425a694a83300227a907cb86a995fc9eef620ecafbe5f92d"},
|
{file = "zarr-2.18.2.tar.gz", hash = "sha256:9bb393b8a0a38fb121dbb913b047d75db28de9890f6d644a217a73cf4ae74f47"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
|
@ -4241,13 +4281,13 @@ jupyter = ["ipytree (>=0.2.2)", "ipywidgets (>=8.0.0)", "notebook"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zipp"
|
name = "zipp"
|
||||||
version = "3.18.2"
|
version = "3.19.0"
|
||||||
description = "Backport of pathlib-compatible object wrapper for zip files"
|
description = "Backport of pathlib-compatible object wrapper for zip files"
|
||||||
optional = false
|
optional = false
|
||||||
python-versions = ">=3.8"
|
python-versions = ">=3.8"
|
||||||
files = [
|
files = [
|
||||||
{file = "zipp-3.18.2-py3-none-any.whl", hash = "sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e"},
|
{file = "zipp-3.19.0-py3-none-any.whl", hash = "sha256:96dc6ad62f1441bcaccef23b274ec471518daf4fbbc580341204936a5a3dddec"},
|
||||||
{file = "zipp-3.18.2.tar.gz", hash = "sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059"},
|
{file = "zipp-3.19.0.tar.gz", hash = "sha256:952df858fb3164426c976d9338d3961e8e8b3758e2e059e0f754b8c4262625ee"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.extras]
|
[package.extras]
|
||||||
|
@ -4257,6 +4297,7 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more
|
||||||
[extras]
|
[extras]
|
||||||
aloha = ["gym-aloha"]
|
aloha = ["gym-aloha"]
|
||||||
dev = ["debugpy", "pre-commit"]
|
dev = ["debugpy", "pre-commit"]
|
||||||
|
dora = ["gym-dora"]
|
||||||
pusht = ["gym-pusht"]
|
pusht = ["gym-pusht"]
|
||||||
test = ["pytest", "pytest-cov"]
|
test = ["pytest", "pytest-cov"]
|
||||||
umi = ["imagecodecs"]
|
umi = ["imagecodecs"]
|
||||||
|
@ -4265,4 +4306,4 @@ xarm = ["gym-xarm"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.10,<3.13"
|
python-versions = ">=3.10,<3.13"
|
||||||
content-hash = "1ad6ef0f88f0056ab639e60e033e586f7460a9c5fc3676a477bbd47923f41cb6"
|
content-hash = "23ddb8dd774a4faf85d08a07dfdf19badb7c370120834b71df4afca254520771"
|
||||||
|
|
|
@ -46,6 +46,7 @@ h5py = ">=3.10.0"
|
||||||
huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
|
huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
|
||||||
gymnasium = ">=0.29.1"
|
gymnasium = ">=0.29.1"
|
||||||
cmake = ">=3.29.0.1"
|
cmake = ">=3.29.0.1"
|
||||||
|
gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
|
||||||
gym-pusht = { version = ">=0.1.3", optional = true}
|
gym-pusht = { version = ">=0.1.3", optional = true}
|
||||||
gym-xarm = { version = ">=0.1.1", optional = true}
|
gym-xarm = { version = ">=0.1.1", optional = true}
|
||||||
gym-aloha = { version = ">=0.1.1", optional = true}
|
gym-aloha = { version = ">=0.1.1", optional = true}
|
||||||
|
@ -62,6 +63,7 @@ deepdiff = ">=7.0.1"
|
||||||
|
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
|
dora = ["gym-dora"]
|
||||||
pusht = ["gym-pusht"]
|
pusht = ["gym-pusht"]
|
||||||
xarm = ["gym-xarm"]
|
xarm = ["gym-xarm"]
|
||||||
aloha = ["gym-aloha"]
|
aloha = ["gym-aloha"]
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:2fff6294b94cf42d4dd1249dcc5c3b0269d6d9c697f894e61b867d7ab81a94e4
|
||||||
|
size 5104
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:4aa23e51607604a18b70fa42edbbe1af34f119d985628fc27cc1bbb0efbc8901
|
||||||
|
size 31688
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:6fd368406c93cb562a69ff11cf7adf34a4b223507dcb2b9e9b8f44ee1036988a
|
||||||
|
size 68
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:5663ee79a13bb70a1604b887dd21bf89d18482287442419c6cc6c5bf0e753e99
|
||||||
|
size 34928
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:fb1a45463efd860af2ca22c16c77d55a18bd96fef080ae77978845a2f22ef716
|
||||||
|
size 5104
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:aa5a43e22f01d8e2f8d19f31753608794f1edbd74aaf71660091ab80ea58dc9b
|
||||||
|
size 30808
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:97455b4360748c99905cd103473c1a52da6901d0a73ffbc51b5ea3eb250d1386
|
||||||
|
size 68
|
|
@ -0,0 +1,3 @@
|
||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:54d1f75cf67a7b1d7a7c6865ecb9b1cc86a2f032d1890245f8996789ab6e0df6
|
||||||
|
size 33608
|
|
@ -75,15 +75,16 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
||||||
# HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
|
# HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
|
||||||
dataset.delta_timestamps = None
|
dataset.delta_timestamps = None
|
||||||
batch = next(iter(dataloader))
|
batch = next(iter(dataloader))
|
||||||
obs = {
|
obs = {}
|
||||||
k: batch[k]
|
for k in batch:
|
||||||
for k in batch
|
if k.startswith("observation"):
|
||||||
if k in ["observation.image", "observation.images.top", "observation.state"]
|
obs[k] = batch[k]
|
||||||
}
|
|
||||||
|
if "n_action_steps" in cfg.policy:
|
||||||
|
actions_queue = cfg.policy.n_action_steps
|
||||||
|
else:
|
||||||
|
actions_queue = cfg.policy.n_action_repeats
|
||||||
|
|
||||||
actions_queue = (
|
|
||||||
cfg.policy.n_action_steps if "n_action_steps" in cfg.policy else cfg.policy.n_action_repeats
|
|
||||||
)
|
|
||||||
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
|
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
|
||||||
return output_dict, grad_stats, param_stats, actions
|
return output_dict, grad_stats, param_stats, actions
|
||||||
|
|
||||||
|
@ -114,6 +115,8 @@ if __name__ == "__main__":
|
||||||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||||
),
|
),
|
||||||
("aloha", "act", ["policy.n_action_steps=10"]),
|
("aloha", "act", ["policy.n_action_steps=10"]),
|
||||||
|
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
|
||||||
|
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
|
||||||
]
|
]
|
||||||
for env, policy, extra_overrides in env_policies:
|
for env, policy, extra_overrides in env_policies:
|
||||||
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
|
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
|
|
@ -16,6 +16,7 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from itertools import chain
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import einops
|
import einops
|
||||||
|
@ -25,26 +26,34 @@ from datasets import Dataset
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
import lerobot
|
import lerobot
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.compute_stats import (
|
||||||
from lerobot.common.datasets.lerobot_dataset import (
|
aggregate_stats,
|
||||||
LeRobotDataset,
|
|
||||||
)
|
|
||||||
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import (
|
|
||||||
compute_stats,
|
compute_stats,
|
||||||
get_stats_einops_patterns,
|
get_stats_einops_patterns,
|
||||||
)
|
)
|
||||||
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
flatten_dict,
|
flatten_dict,
|
||||||
hf_transform_to_torch,
|
hf_transform_to_torch,
|
||||||
load_previous_and_future_frames,
|
load_previous_and_future_frames,
|
||||||
unflatten_dict,
|
unflatten_dict,
|
||||||
)
|
)
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
|
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
|
@pytest.mark.parametrize(
|
||||||
|
"env_name, repo_id, policy_name",
|
||||||
|
lerobot.env_dataset_policy_triplets
|
||||||
|
+ [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
|
||||||
|
)
|
||||||
def test_factory(env_name, repo_id, policy_name):
|
def test_factory(env_name, repo_id, policy_name):
|
||||||
|
"""
|
||||||
|
Tests that:
|
||||||
|
- we can create a dataset with the factory.
|
||||||
|
- for a commonly used set of data keys, the data dimensions are correct.
|
||||||
|
"""
|
||||||
cfg = init_hydra_config(
|
cfg = init_hydra_config(
|
||||||
DEFAULT_CONFIG_PATH,
|
DEFAULT_CONFIG_PATH,
|
||||||
overrides=[
|
overrides=[
|
||||||
|
@ -105,6 +114,39 @@ def test_factory(env_name, repo_id, policy_name):
|
||||||
assert key in item, f"{key}"
|
assert key in item, f"{key}"
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
|
||||||
|
def test_multilerobotdataset_frames():
|
||||||
|
"""Check that all dataset frames are incorporated."""
|
||||||
|
# Note: use the image variants of the dataset to make the test approx 3x faster.
|
||||||
|
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
|
||||||
|
# logic that wouldn't be caught with two repo IDs.
|
||||||
|
repo_ids = [
|
||||||
|
"lerobot/aloha_sim_insertion_human_image",
|
||||||
|
"lerobot/aloha_sim_transfer_cube_human_image",
|
||||||
|
"lerobot/aloha_sim_insertion_scripted_image",
|
||||||
|
]
|
||||||
|
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
|
||||||
|
dataset = MultiLeRobotDataset(repo_ids)
|
||||||
|
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
||||||
|
assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
|
||||||
|
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
|
||||||
|
|
||||||
|
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
|
||||||
|
# check they match.
|
||||||
|
expected_dataset_indices = []
|
||||||
|
for i, sub_dataset in enumerate(sub_datasets):
|
||||||
|
expected_dataset_indices.extend([i] * len(sub_dataset))
|
||||||
|
|
||||||
|
for expected_dataset_index, sub_dataset_item, dataset_item in zip(
|
||||||
|
expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
|
||||||
|
):
|
||||||
|
dataset_index = dataset_item.pop("dataset_index")
|
||||||
|
assert dataset_index == expected_dataset_index
|
||||||
|
assert sub_dataset_item.keys() == dataset_item.keys()
|
||||||
|
for k in sub_dataset_item:
|
||||||
|
assert torch.equal(sub_dataset_item[k], dataset_item[k])
|
||||||
|
|
||||||
|
|
||||||
def test_compute_stats_on_xarm():
|
def test_compute_stats_on_xarm():
|
||||||
"""Check that the statistics are computed correctly according to the stats_patterns property.
|
"""Check that the statistics are computed correctly according to the stats_patterns property.
|
||||||
|
|
||||||
|
@ -315,3 +357,31 @@ def test_backward_compatibility(repo_id):
|
||||||
# i = dataset.episode_data_index["to"][-1].item()
|
# i = dataset.episode_data_index["to"][-1].item()
|
||||||
# load_and_compare(i - 2)
|
# load_and_compare(i - 2)
|
||||||
# load_and_compare(i - 1)
|
# load_and_compare(i - 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_aggregate_stats():
|
||||||
|
"""Makes 3 basic datasets and checks that aggregate stats are computed correctly."""
|
||||||
|
with seeded_context(0):
|
||||||
|
data_a = torch.rand(30, dtype=torch.float32)
|
||||||
|
data_b = torch.rand(20, dtype=torch.float32)
|
||||||
|
data_c = torch.rand(20, dtype=torch.float32)
|
||||||
|
|
||||||
|
hf_dataset_1 = Dataset.from_dict(
|
||||||
|
{"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)}
|
||||||
|
)
|
||||||
|
hf_dataset_1.set_transform(hf_transform_to_torch)
|
||||||
|
hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)})
|
||||||
|
hf_dataset_2.set_transform(hf_transform_to_torch)
|
||||||
|
hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)})
|
||||||
|
hf_dataset_3.set_transform(hf_transform_to_torch)
|
||||||
|
dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1)
|
||||||
|
dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0)
|
||||||
|
dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2)
|
||||||
|
dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0)
|
||||||
|
dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3)
|
||||||
|
dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0)
|
||||||
|
stats = aggregate_stats([dataset_1, dataset_2, dataset_3])
|
||||||
|
for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True):
|
||||||
|
for agg_fn in ["mean", "min", "max"]:
|
||||||
|
assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
|
||||||
|
assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))
|
||||||
|
|
|
@ -30,7 +30,7 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
from lerobot.common.policies.policy_protocol import Policy
|
from lerobot.common.policies.policy_protocol import Policy
|
||||||
from lerobot.common.utils.utils import init_hydra_config
|
from lerobot.common.utils.utils import init_hydra_config
|
||||||
from tests.scripts.save_policy_to_safetensor import get_policy_stats
|
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
||||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||||
|
|
||||||
|
|
||||||
|
@ -72,6 +72,8 @@ def test_get_policy_and_config_classes(policy_name: str):
|
||||||
),
|
),
|
||||||
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
|
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
|
||||||
("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
|
("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
|
||||||
|
("dora_aloha_real", "act_real", []),
|
||||||
|
("dora_aloha_real", "act_real_no_state", []),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@require_env
|
@require_env
|
||||||
|
@ -84,6 +86,9 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||||
- Updating the policy.
|
- Updating the policy.
|
||||||
- Using the policy to select actions at inference time.
|
- Using the policy to select actions at inference time.
|
||||||
- Test the action can be applied to the policy
|
- Test the action can be applied to the policy
|
||||||
|
|
||||||
|
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
|
||||||
|
and for now we add tests as we see fit.
|
||||||
"""
|
"""
|
||||||
cfg = init_hydra_config(
|
cfg = init_hydra_config(
|
||||||
DEFAULT_CONFIG_PATH,
|
DEFAULT_CONFIG_PATH,
|
||||||
|
@ -135,7 +140,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||||
|
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
num_workers=4,
|
num_workers=0,
|
||||||
batch_size=2,
|
batch_size=2,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
pin_memory=DEVICE != "cpu",
|
pin_memory=DEVICE != "cpu",
|
||||||
|
@ -291,6 +296,8 @@ def test_normalize(insert_temporal_dim):
|
||||||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||||
),
|
),
|
||||||
("aloha", "act", ["policy.n_action_steps=10"]),
|
("aloha", "act", ["policy.n_action_steps=10"]),
|
||||||
|
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
|
||||||
|
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
# As artifacts have been generated on an x86_64 kernel, this test won't
|
# As artifacts have been generated on an x86_64 kernel, this test won't
|
||||||
|
|
|
@ -0,0 +1,90 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from datasets import Dataset
|
||||||
|
|
||||||
|
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||||
|
from lerobot.common.datasets.utils import (
|
||||||
|
calculate_episode_data_index,
|
||||||
|
hf_transform_to_torch,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_drop_n_first_frames():
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||||
|
"index": [0, 1, 2, 3, 4, 5],
|
||||||
|
"episode_index": [0, 0, 1, 2, 2, 2],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
dataset.set_transform(hf_transform_to_torch)
|
||||||
|
episode_data_index = calculate_episode_data_index(dataset)
|
||||||
|
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
|
||||||
|
assert sampler.indices == [1, 4, 5]
|
||||||
|
assert len(sampler) == 3
|
||||||
|
assert list(sampler) == [1, 4, 5]
|
||||||
|
|
||||||
|
|
||||||
|
def test_drop_n_last_frames():
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||||
|
"index": [0, 1, 2, 3, 4, 5],
|
||||||
|
"episode_index": [0, 0, 1, 2, 2, 2],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
dataset.set_transform(hf_transform_to_torch)
|
||||||
|
episode_data_index = calculate_episode_data_index(dataset)
|
||||||
|
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
|
||||||
|
assert sampler.indices == [0, 3, 4]
|
||||||
|
assert len(sampler) == 3
|
||||||
|
assert list(sampler) == [0, 3, 4]
|
||||||
|
|
||||||
|
|
||||||
|
def test_episode_indices_to_use():
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||||
|
"index": [0, 1, 2, 3, 4, 5],
|
||||||
|
"episode_index": [0, 0, 1, 2, 2, 2],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
dataset.set_transform(hf_transform_to_torch)
|
||||||
|
episode_data_index = calculate_episode_data_index(dataset)
|
||||||
|
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
|
||||||
|
assert sampler.indices == [0, 1, 3, 4, 5]
|
||||||
|
assert len(sampler) == 5
|
||||||
|
assert list(sampler) == [0, 1, 3, 4, 5]
|
||||||
|
|
||||||
|
|
||||||
|
def test_shuffle():
|
||||||
|
dataset = Dataset.from_dict(
|
||||||
|
{
|
||||||
|
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||||
|
"index": [0, 1, 2, 3, 4, 5],
|
||||||
|
"episode_index": [0, 0, 1, 2, 2, 2],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
dataset.set_transform(hf_transform_to_torch)
|
||||||
|
episode_data_index = calculate_episode_data_index(dataset)
|
||||||
|
sampler = EpisodeAwareSampler(episode_data_index, shuffle=False)
|
||||||
|
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
||||||
|
assert len(sampler) == 6
|
||||||
|
assert list(sampler) == [0, 1, 2, 3, 4, 5]
|
||||||
|
sampler = EpisodeAwareSampler(episode_data_index, shuffle=True)
|
||||||
|
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
||||||
|
assert len(sampler) == 6
|
||||||
|
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
Loading…
Reference in New Issue