Merge branch 'main' into octo-model-poc
This commit is contained in:
commit
c869930168
|
@ -70,6 +70,8 @@ jobs:
|
|||
# files: ./coverage.xml
|
||||
# verbose: true
|
||||
- name: Tests end-to-end
|
||||
env:
|
||||
DEVICE: cuda
|
||||
run: make test-end-to-end
|
||||
|
||||
# - name: Generate Report
|
||||
|
|
|
@ -121,7 +121,6 @@ celerybeat.pid
|
|||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
|
|
41
Makefile
41
Makefile
|
@ -10,6 +10,7 @@ endif
|
|||
|
||||
export PATH := $(dir $(PYTHON_PATH)):$(PATH)
|
||||
|
||||
DEVICE ?= cpu
|
||||
|
||||
build-cpu:
|
||||
docker build -t lerobot:latest -f docker/lerobot-cpu/Dockerfile .
|
||||
|
@ -18,16 +19,16 @@ build-gpu:
|
|||
docker build -t lerobot:latest -f docker/lerobot-gpu/Dockerfile .
|
||||
|
||||
test-end-to-end:
|
||||
${MAKE} test-act-ete-train
|
||||
${MAKE} test-act-ete-eval
|
||||
${MAKE} test-act-ete-train-amp
|
||||
${MAKE} test-act-ete-eval-amp
|
||||
${MAKE} test-diffusion-ete-train
|
||||
${MAKE} test-diffusion-ete-eval
|
||||
${MAKE} test-tdmpc-ete-train
|
||||
${MAKE} test-tdmpc-ete-eval
|
||||
${MAKE} test-default-ete-eval
|
||||
${MAKE} test-act-pusht-tutorial
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-train-amp
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-ete-eval-amp
|
||||
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train
|
||||
${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-default-ete-eval
|
||||
${MAKE} DEVICE=$(DEVICE) test-act-pusht-tutorial
|
||||
|
||||
test-act-ete-train:
|
||||
python lerobot/scripts/train.py \
|
||||
|
@ -39,7 +40,7 @@ test-act-ete-train:
|
|||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=cpu \
|
||||
device=$(DEVICE) \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
policy.n_action_steps=20 \
|
||||
|
@ -53,7 +54,7 @@ test-act-ete-eval:
|
|||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
device=$(DEVICE) \
|
||||
|
||||
test-act-ete-train-amp:
|
||||
python lerobot/scripts/train.py \
|
||||
|
@ -65,7 +66,7 @@ test-act-ete-train-amp:
|
|||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=cpu \
|
||||
device=$(DEVICE) \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
policy.n_action_steps=20 \
|
||||
|
@ -80,7 +81,7 @@ test-act-ete-eval-amp:
|
|||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
device=$(DEVICE) \
|
||||
use_amp=true
|
||||
|
||||
test-diffusion-ete-train:
|
||||
|
@ -95,7 +96,7 @@ test-diffusion-ete-train:
|
|||
training.online_steps=0 \
|
||||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
device=cpu \
|
||||
device=$(DEVICE) \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
|
@ -107,7 +108,7 @@ test-diffusion-ete-eval:
|
|||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
device=$(DEVICE) \
|
||||
|
||||
# TODO(alexander-soare): Restore online_steps to 2 when it is reinstated.
|
||||
test-tdmpc-ete-train:
|
||||
|
@ -122,7 +123,7 @@ test-tdmpc-ete-train:
|
|||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=2 \
|
||||
device=cpu \
|
||||
device=$(DEVICE) \
|
||||
training.save_checkpoint=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
|
@ -134,7 +135,7 @@ test-tdmpc-ete-eval:
|
|||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
device=$(DEVICE) \
|
||||
|
||||
test-default-ete-eval:
|
||||
python lerobot/scripts/eval.py \
|
||||
|
@ -142,7 +143,7 @@ test-default-ete-eval:
|
|||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=8 \
|
||||
device=cpu \
|
||||
device=$(DEVICE) \
|
||||
|
||||
test-act-pusht-tutorial:
|
||||
cp examples/advanced/1_train_act_pusht/act_pusht.yaml lerobot/configs/policy/created_by_Makefile.yaml
|
||||
|
@ -154,7 +155,7 @@ test-act-pusht-tutorial:
|
|||
eval.n_episodes=1 \
|
||||
eval.batch_size=1 \
|
||||
env.episode_length=2 \
|
||||
device=cpu \
|
||||
device=$(DEVICE) \
|
||||
training.save_model=true \
|
||||
training.save_freq=2 \
|
||||
training.batch_size=2 \
|
||||
|
|
|
@ -70,7 +70,7 @@ python lerobot/scripts/train.py policy=act env=aloha
|
|||
|
||||
There are two things to note here:
|
||||
- Config overrides are passed as `param_name=param_value`.
|
||||
- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/pusht.yaml`.
|
||||
- Here we have overridden the defaults section. `policy=act` tells Hydra to use `policy/act.yaml`, and `env=aloha` tells Hydra to use `env/aloha.yaml`.
|
||||
|
||||
_As an aside: we've set up all of our configurations so that they reproduce state-of-the-art results from papers in the literature._
|
||||
|
||||
|
|
|
@ -45,6 +45,9 @@ import itertools
|
|||
|
||||
from lerobot.__version__ import __version__ # noqa: F401
|
||||
|
||||
# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
|
||||
# refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
|
||||
# a yaml file AND a environment name. The difference should be more obvious.
|
||||
available_tasks_per_env = {
|
||||
"aloha": [
|
||||
"AlohaInsertion-v0",
|
||||
|
@ -52,6 +55,7 @@ available_tasks_per_env = {
|
|||
],
|
||||
"pusht": ["PushT-v0"],
|
||||
"xarm": ["XarmLift-v0"],
|
||||
"dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
|
||||
}
|
||||
available_envs = list(available_tasks_per_env.keys())
|
||||
|
||||
|
@ -77,6 +81,23 @@ available_datasets_per_env = {
|
|||
"lerobot/xarm_push_medium_image",
|
||||
"lerobot/xarm_push_medium_replay_image",
|
||||
],
|
||||
"dora_aloha_real": [
|
||||
"lerobot/aloha_static_battery",
|
||||
"lerobot/aloha_static_candy",
|
||||
"lerobot/aloha_static_coffee",
|
||||
"lerobot/aloha_static_coffee_new",
|
||||
"lerobot/aloha_static_cups_open",
|
||||
"lerobot/aloha_static_fork_pick_up",
|
||||
"lerobot/aloha_static_pingpong_test",
|
||||
"lerobot/aloha_static_pro_pencil",
|
||||
"lerobot/aloha_static_screw_driver",
|
||||
"lerobot/aloha_static_tape",
|
||||
"lerobot/aloha_static_thread_velcro",
|
||||
"lerobot/aloha_static_towel",
|
||||
"lerobot/aloha_static_vinh_cup",
|
||||
"lerobot/aloha_static_vinh_cup_left",
|
||||
"lerobot/aloha_static_ziploc_slide",
|
||||
],
|
||||
}
|
||||
|
||||
available_real_world_datasets = [
|
||||
|
@ -108,12 +129,20 @@ available_datasets = list(
|
|||
itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
|
||||
)
|
||||
|
||||
available_policies = ["act", "diffusion", "tdmpc", "octo"]
|
||||
# lists all available policies from `lerobot/common/policies` by their class attribute: `name`.
|
||||
available_policies = [
|
||||
"act",
|
||||
"diffusion",
|
||||
"tdmpc",
|
||||
"octo",
|
||||
]
|
||||
|
||||
# keys and values refer to yaml files
|
||||
available_policies_per_env = {
|
||||
"aloha": ["act"],
|
||||
"pusht": ["diffusion", "octo"],
|
||||
"xarm": ["tdmpc"],
|
||||
"dora_aloha_real": ["act_real"],
|
||||
}
|
||||
|
||||
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
|
||||
|
|
|
@ -16,17 +16,15 @@
|
|||
from copy import deepcopy
|
||||
from math import ceil
|
||||
|
||||
import datasets
|
||||
import einops
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import Image
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
|
||||
|
||||
def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_workers=0):
|
||||
def get_stats_einops_patterns(dataset, num_workers=0):
|
||||
"""These einops patterns will be used to aggregate batches and compute statistics.
|
||||
|
||||
Note: We assume the images are in channel first format
|
||||
|
@ -66,9 +64,8 @@ def get_stats_einops_patterns(dataset: LeRobotDataset | datasets.Dataset, num_wo
|
|||
return stats_patterns
|
||||
|
||||
|
||||
def compute_stats(
|
||||
dataset: LeRobotDataset | datasets.Dataset, batch_size=32, num_workers=16, max_num_samples=None
|
||||
):
|
||||
def compute_stats(dataset, batch_size=32, num_workers=16, max_num_samples=None):
|
||||
"""Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
|
||||
if max_num_samples is None:
|
||||
max_num_samples = len(dataset)
|
||||
|
||||
|
@ -159,3 +156,54 @@ def compute_stats(
|
|||
"min": min[key],
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
|
||||
"""Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
|
||||
|
||||
The final stats will have the union of all data keys from each of the datasets.
|
||||
|
||||
The final stats will have the union of all data keys from each of the datasets. For instance:
|
||||
- new_max = max(max_dataset_0, max_dataset_1, ...)
|
||||
- new_min = min(min_dataset_0, min_dataset_1, ...)
|
||||
- new_mean = (mean of all data)
|
||||
- new_std = (std of all data)
|
||||
"""
|
||||
data_keys = set()
|
||||
for dataset in ls_datasets:
|
||||
data_keys.update(dataset.stats.keys())
|
||||
stats = {k: {} for k in data_keys}
|
||||
for data_key in data_keys:
|
||||
for stat_key in ["min", "max"]:
|
||||
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
|
||||
stats[data_key][stat_key] = einops.reduce(
|
||||
torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
|
||||
"n ... -> ...",
|
||||
stat_key,
|
||||
)
|
||||
total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
|
||||
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
|
||||
# dataset, then divide by total_samples to get the overall "mean".
|
||||
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
|
||||
# numerical overflow!
|
||||
stats[data_key]["mean"] = sum(
|
||||
d.stats[data_key]["mean"] * (d.num_samples / total_samples)
|
||||
for d in ls_datasets
|
||||
if data_key in d.stats
|
||||
)
|
||||
# The derivation for standard deviation is a little more involved but is much in the same spirit as
|
||||
# the computation of the mean.
|
||||
# Given two sets of data where the statistics are known:
|
||||
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
|
||||
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
|
||||
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
|
||||
# numerical overflow!
|
||||
stats[data_key]["std"] = torch.sqrt(
|
||||
sum(
|
||||
(d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
|
||||
* (d.num_samples / total_samples)
|
||||
for d in ls_datasets
|
||||
if data_key in d.stats
|
||||
)
|
||||
)
|
||||
return stats
|
|
@ -16,9 +16,9 @@
|
|||
import logging
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf import ListConfig, OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
|
||||
|
||||
def resolve_delta_timestamps(cfg):
|
||||
|
@ -35,25 +35,54 @@ def resolve_delta_timestamps(cfg):
|
|||
cfg.training.delta_timestamps[key] = eval(delta_timestamps[key])
|
||||
|
||||
|
||||
def make_dataset(
|
||||
cfg,
|
||||
split="train",
|
||||
):
|
||||
if cfg.env.name not in cfg.dataset_repo_id:
|
||||
logging.warning(
|
||||
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
|
||||
f"environment ({cfg.env.name=})."
|
||||
def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotDataset:
|
||||
"""
|
||||
Args:
|
||||
cfg: A Hydra config as per the LeRobot config scheme.
|
||||
split: Select the data subset used to create an instance of LeRobotDataset.
|
||||
All datasets hosted on [lerobot](https://huggingface.co/lerobot) contain only one subset: "train".
|
||||
Thus, by default, `split="train"` selects all the available data. `split` aims to work like the
|
||||
slicer in the hugging face datasets:
|
||||
https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
|
||||
As of now, it only supports `split="train[:n]"` to load the first n frames of the dataset or
|
||||
`split="train[n:]"` to load the last n frames. For instance `split="train[:1000]"`.
|
||||
Returns:
|
||||
The LeRobotDataset.
|
||||
"""
|
||||
if not isinstance(cfg.dataset_repo_id, (str, ListConfig)):
|
||||
raise ValueError(
|
||||
"Expected cfg.dataset_repo_id to be either a single string to load one dataset or a list of "
|
||||
"strings to load multiple datasets."
|
||||
)
|
||||
|
||||
# A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
|
||||
if cfg.env.name != "dora":
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
dataset_repo_ids = [cfg.dataset_repo_id] # single dataset
|
||||
else:
|
||||
dataset_repo_ids = cfg.dataset_repo_id # multiple datasets
|
||||
|
||||
for dataset_repo_id in dataset_repo_ids:
|
||||
if cfg.env.name not in dataset_repo_id:
|
||||
logging.warning(
|
||||
f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
|
||||
f"environment ({cfg.env.name=})."
|
||||
)
|
||||
|
||||
resolve_delta_timestamps(cfg)
|
||||
|
||||
# TODO(rcadene): add data augmentations
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
)
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
)
|
||||
else:
|
||||
dataset = MultiLeRobotDataset(
|
||||
cfg.dataset_repo_id, split=split, delta_timestamps=cfg.training.get("delta_timestamps")
|
||||
)
|
||||
|
||||
if cfg.get("override_dataset_stats"):
|
||||
for key, stats_dict in cfg.override_dataset_stats.items():
|
||||
|
|
|
@ -13,12 +13,16 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import torch.utils
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
load_episode_data_index,
|
||||
|
@ -42,7 +46,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
transform: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
@ -171,7 +175,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
@classmethod
|
||||
def from_preloaded(
|
||||
cls,
|
||||
repo_id: str,
|
||||
repo_id: str = "from_preloaded",
|
||||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = None,
|
||||
split: str = "train",
|
||||
|
@ -183,7 +187,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
stats=None,
|
||||
info=None,
|
||||
videos_dir=None,
|
||||
):
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.
|
||||
|
||||
It is especially useful when converting raw data into LeRobotDataset before saving the dataset
|
||||
on the filesystem or uploading to the hub.
|
||||
|
||||
Note: Meta-data attributes like `repo_id`, `version`, `root`, etc are optional and potentially
|
||||
meaningless depending on the downstream usage of the return dataset.
|
||||
"""
|
||||
# create an empty object of type LeRobotDataset
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
|
@ -195,6 +207,193 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||
obj.hf_dataset = hf_dataset
|
||||
obj.episode_data_index = episode_data_index
|
||||
obj.stats = stats
|
||||
obj.info = info
|
||||
obj.info = info if info is not None else {}
|
||||
obj.videos_dir = videos_dir
|
||||
return obj
|
||||
|
||||
|
||||
class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
"""A dataset consisting of multiple underlying `LeRobotDataset`s.
|
||||
|
||||
The underlying `LeRobotDataset`s are effectively concatenated, and this class adopts much of the API
|
||||
structure of `LeRobotDataset`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo_ids: list[str],
|
||||
version: str | None = CODEBASE_VERSION,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
transform: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_ids = repo_ids
|
||||
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
|
||||
# are handled by this class.
|
||||
self._datasets = [
|
||||
LeRobotDataset(
|
||||
repo_id,
|
||||
version=version,
|
||||
root=root,
|
||||
split=split,
|
||||
delta_timestamps=delta_timestamps,
|
||||
transform=transform,
|
||||
)
|
||||
for repo_id in repo_ids
|
||||
]
|
||||
# Check that some properties are consistent across datasets. Note: We may relax some of these
|
||||
# consistency requirements in future iterations of this class.
|
||||
for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
|
||||
if dataset.info != self._datasets[0].info:
|
||||
raise ValueError(
|
||||
f"Detected a mismatch in dataset info between {self.repo_ids[0]} and {repo_id}. This is "
|
||||
"not yet supported."
|
||||
)
|
||||
# Disable any data keys that are not common across all of the datasets. Note: we may relax this
|
||||
# restriction in future iterations of this class. For now, this is necessary at least for being able
|
||||
# to use PyTorch's default DataLoader collate function.
|
||||
self.disabled_data_keys = set()
|
||||
intersection_data_keys = set(self._datasets[0].hf_dataset.features)
|
||||
for dataset in self._datasets:
|
||||
intersection_data_keys.intersection_update(dataset.hf_dataset.features)
|
||||
if len(intersection_data_keys) == 0:
|
||||
raise RuntimeError(
|
||||
"Multiple datasets were provided but they had no keys common to all of them. The "
|
||||
"multi-dataset functionality currently only keeps common keys."
|
||||
)
|
||||
for repo_id, dataset in zip(self.repo_ids, self._datasets, strict=True):
|
||||
extra_keys = set(dataset.hf_dataset.features).difference(intersection_data_keys)
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_data_keys.update(extra_keys)
|
||||
|
||||
self.version = version
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.transform = transform
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.stats = aggregate_stats(self._datasets)
|
||||
|
||||
@property
|
||||
def repo_id_to_index(self):
|
||||
"""Return a mapping from dataset repo_id to a dataset index automatically created by this class.
|
||||
|
||||
This index is incorporated as a data key in the dictionary returned by `__getitem__`.
|
||||
"""
|
||||
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
|
||||
|
||||
@property
|
||||
def repo_index_to_id(self):
|
||||
"""Return the inverse mapping if repo_id_to_index."""
|
||||
return {v: k for k, v in self.repo_id_to_index}
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
"""Frames per second used during data collection.
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].info["fps"]
|
||||
|
||||
@property
|
||||
def video(self) -> bool:
|
||||
"""Returns True if this dataset loads video frames from mp4 files.
|
||||
|
||||
Returns False if it only loads images from png files.
|
||||
|
||||
NOTE: Fow now, this relies on a check in __init__ to make sure all sub-datasets have the same info.
|
||||
"""
|
||||
return self._datasets[0].info.get("video", False)
|
||||
|
||||
@property
|
||||
def features(self) -> datasets.Features:
|
||||
features = {}
|
||||
for dataset in self._datasets:
|
||||
features.update({k: v for k, v in dataset.features.items() if k not in self.disabled_data_keys})
|
||||
return features
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access image and video stream from cameras."""
|
||||
keys = []
|
||||
for key, feats in self.features.items():
|
||||
if isinstance(feats, (datasets.Image, VideoFrame)):
|
||||
keys.append(key)
|
||||
return keys
|
||||
|
||||
@property
|
||||
def video_frame_keys(self) -> list[str]:
|
||||
"""Keys to access video frames that requires to be decoded into images.
|
||||
|
||||
Note: It is empty if the dataset contains images only,
|
||||
or equal to `self.cameras` if the dataset contains videos only,
|
||||
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
|
||||
"""
|
||||
video_frame_keys = []
|
||||
for key, feats in self.features.items():
|
||||
if isinstance(feats, VideoFrame):
|
||||
video_frame_keys.append(key)
|
||||
return video_frame_keys
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
"""Number of samples/frames."""
|
||||
return sum(d.num_samples for d in self._datasets)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
"""Number of episodes."""
|
||||
return sum(d.num_episodes for d in self._datasets)
|
||||
|
||||
@property
|
||||
def tolerance_s(self) -> float:
|
||||
"""Tolerance in seconds used to discard loaded frames when their timestamps
|
||||
are not close enough from the requested frames. It is only used when `delta_timestamps`
|
||||
is provided or when loading video frames from mp4 files.
|
||||
"""
|
||||
# 1e-4 to account for possible numerical error
|
||||
return 1 / self.fps - 1e-4
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
||||
if idx >= len(self):
|
||||
raise IndexError(f"Index {idx} out of bounds.")
|
||||
# Determine which dataset to get an item from based on the index.
|
||||
start_idx = 0
|
||||
dataset_idx = 0
|
||||
for dataset in self._datasets:
|
||||
if idx >= start_idx + dataset.num_samples:
|
||||
start_idx += dataset.num_samples
|
||||
dataset_idx += 1
|
||||
continue
|
||||
break
|
||||
else:
|
||||
raise AssertionError("We expect the loop to break out as long as the index is within bounds.")
|
||||
item = self._datasets[dataset_idx][idx - start_idx]
|
||||
item["dataset_index"] = torch.tensor(dataset_idx)
|
||||
for data_key in self.disabled_data_keys:
|
||||
if data_key in item:
|
||||
del item[data_key]
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}(\n"
|
||||
f" Repository IDs: '{self.repo_ids}',\n"
|
||||
f" Version: '{self.version}',\n"
|
||||
f" Split: '{self.split}',\n"
|
||||
f" Number of Samples: {self.num_samples},\n"
|
||||
f" Number of Episodes: {self.num_episodes},\n"
|
||||
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
|
||||
f" Recorded Frames per Second: {self.fps},\n"
|
||||
f" Camera Keys: {self.camera_keys},\n"
|
||||
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
||||
f" Transformations: {self.transform},\n"
|
||||
f")"
|
||||
)
|
||||
|
|
|
@ -0,0 +1,230 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Contains utilities to process raw data format from dora-record
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
|
||||
def check_format(raw_dir) -> bool:
|
||||
assert raw_dir.exists()
|
||||
|
||||
leader_file = list(raw_dir.glob("*.parquet"))
|
||||
if len(leader_file) == 0:
|
||||
raise ValueError(f"Missing parquet files in '{raw_dir}'")
|
||||
return True
|
||||
|
||||
|
||||
def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
|
||||
# Load data stream that will be used as reference for the timestamps synchronization
|
||||
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
|
||||
if len(reference_files) == 0:
|
||||
raise ValueError(f"Missing reference files for camera, starting with in '{raw_dir}'")
|
||||
# select first camera in alphanumeric order
|
||||
reference_key = sorted(reference_files)[0].stem
|
||||
reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
|
||||
reference_df = reference_df[["timestamp_utc", reference_key]]
|
||||
|
||||
# Merge all data stream using nearest backward strategy
|
||||
df = reference_df
|
||||
for path in raw_dir.glob("*.parquet"):
|
||||
key = path.stem # action or observation.state or ...
|
||||
if key == reference_key:
|
||||
continue
|
||||
if "failed_episode_index" in key:
|
||||
# TODO(rcadene): add support for removing episodes that are tagged as "failed"
|
||||
continue
|
||||
modality_df = pd.read_parquet(path)
|
||||
modality_df = modality_df[["timestamp_utc", key]]
|
||||
df = pd.merge_asof(
|
||||
df,
|
||||
modality_df,
|
||||
on="timestamp_utc",
|
||||
# "nearest" is the best option over "backward", since the latter can desynchronizes camera timestamps by
|
||||
# matching timestamps that are too far appart, in order to fit the backward constraints. It's not the case for "nearest".
|
||||
# However, note that "nearest" might synchronize the reference camera with other cameras on slightly future timestamps.
|
||||
# are too far appart.
|
||||
direction="nearest",
|
||||
tolerance=pd.Timedelta(f"{1/fps} seconds"),
|
||||
)
|
||||
# Remove rows with episode_index -1 which indicates data that correspond to in-between episodes
|
||||
df = df[df["episode_index"] != -1]
|
||||
|
||||
image_keys = [key for key in df if "observation.images." in key]
|
||||
|
||||
def get_episode_index(row):
|
||||
episode_index_per_cam = {}
|
||||
for key in image_keys:
|
||||
path = row[key][0]["path"]
|
||||
match = re.search(r"_(\d{6}).mp4", path)
|
||||
if not match:
|
||||
raise ValueError(path)
|
||||
episode_index = int(match.group(1))
|
||||
episode_index_per_cam[key] = episode_index
|
||||
if len(set(episode_index_per_cam.values())) != 1:
|
||||
raise ValueError(
|
||||
f"All cameras are expected to belong to the same episode, but getting {episode_index_per_cam}"
|
||||
)
|
||||
return episode_index
|
||||
|
||||
df["episode_index"] = df.apply(get_episode_index, axis=1)
|
||||
|
||||
# dora only use arrays, so single values are encapsulated into a list
|
||||
df["frame_index"] = df.groupby("episode_index").cumcount()
|
||||
df = df.reset_index()
|
||||
df["index"] = df.index
|
||||
|
||||
# set 'next.done' to True for the last frame of each episode
|
||||
df["next.done"] = False
|
||||
df.loc[df.groupby("episode_index").tail(1).index, "next.done"] = True
|
||||
|
||||
df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())
|
||||
# each episode starts with timestamp 0 to match the ones from the video
|
||||
df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])
|
||||
|
||||
del df["timestamp_utc"]
|
||||
|
||||
# sanity check
|
||||
has_nan = df.isna().any().any()
|
||||
if has_nan:
|
||||
raise ValueError("Dataset contains Nan values.")
|
||||
|
||||
# sanity check episode indices go from 0 to n-1
|
||||
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
|
||||
expected_ep_ids = list(range(df["episode_index"].max() + 1))
|
||||
if ep_ids != expected_ep_ids:
|
||||
raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")
|
||||
|
||||
# Create symlink to raw videos directory (that needs to be absolute not relative)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
videos_dir = out_dir / "videos"
|
||||
videos_dir.symlink_to((raw_dir / "videos").absolute())
|
||||
|
||||
# sanity check the video paths are well formated
|
||||
for key in df:
|
||||
if "observation.images." not in key:
|
||||
continue
|
||||
for ep_idx in ep_ids:
|
||||
video_path = videos_dir / f"{key}_episode_{ep_idx:06d}.mp4"
|
||||
if not video_path.exists():
|
||||
raise ValueError(f"Video file not found in {video_path}")
|
||||
|
||||
data_dict = {}
|
||||
for key in df:
|
||||
# is video frame
|
||||
if "observation.images." in key:
|
||||
# we need `[0] because dora only use arrays, so single values are encapsulated into a list.
|
||||
# it is the case for video_frame dictionary = [{"path": ..., "timestamp": ...}]
|
||||
data_dict[key] = [video_frame[0] for video_frame in df[key].values]
|
||||
|
||||
# sanity check the video path is well formated
|
||||
video_path = videos_dir.parent / data_dict[key][0]["path"]
|
||||
if not video_path.exists():
|
||||
raise ValueError(f"Video file not found in {video_path}")
|
||||
# is number
|
||||
elif df[key].iloc[0].ndim == 0 or df[key].iloc[0].shape[0] == 1:
|
||||
data_dict[key] = torch.from_numpy(df[key].values)
|
||||
# is vector
|
||||
elif df[key].iloc[0].shape[0] > 1:
|
||||
data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
|
||||
else:
|
||||
raise ValueError(key)
|
||||
|
||||
# Get the episode index containing for each unique episode index
|
||||
first_ep_index_df = df.groupby("episode_index").agg(start_index=("index", "first")).reset_index()
|
||||
from_ = first_ep_index_df["start_index"].tolist()
|
||||
to_ = from_[1:] + [len(df)]
|
||||
episode_data_index = {
|
||||
"from": from_,
|
||||
"to": to_,
|
||||
}
|
||||
|
||||
return data_dict, episode_data_index
|
||||
|
||||
|
||||
def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
features = {}
|
||||
|
||||
keys = [key for key in data_dict if "observation.images." in key]
|
||||
for key in keys:
|
||||
if video:
|
||||
features[key] = VideoFrame()
|
||||
else:
|
||||
features[key] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.velocity" in data_dict:
|
||||
features["observation.velocity"] = Sequence(
|
||||
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.effort" in data_dict:
|
||||
features["observation.effort"] = Sequence(
|
||||
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
features["episode_index"] = Value(dtype="int64", id=None)
|
||||
features["frame_index"] = Value(dtype="int64", id=None)
|
||||
features["timestamp"] = Value(dtype="float32", id=None)
|
||||
features["next.done"] = Value(dtype="bool", id=None)
|
||||
features["index"] = Value(dtype="int64", id=None)
|
||||
|
||||
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def from_raw_to_lerobot_format(raw_dir: Path, out_dir: Path, fps=None, video=True, debug=False):
|
||||
init_logging()
|
||||
|
||||
if debug:
|
||||
logging.warning("debug=True not implemented. Falling back to debug=False.")
|
||||
|
||||
# sanity check
|
||||
check_format(raw_dir)
|
||||
|
||||
if fps is None:
|
||||
fps = 30
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
if not video:
|
||||
raise NotImplementedError()
|
||||
|
||||
data_df, episode_data_index = load_from_raw(raw_dir, out_dir, fps)
|
||||
hf_dataset = to_hf_dataset(data_df, video)
|
||||
|
||||
info = {
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
return hf_dataset, episode_data_index, info
|
|
@ -0,0 +1,61 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Iterator, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class EpisodeAwareSampler:
|
||||
def __init__(
|
||||
self,
|
||||
episode_data_index: dict,
|
||||
episode_indices_to_use: Union[list, None] = None,
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
):
|
||||
"""Sampler that optionally incorporates episode boundary information.
|
||||
|
||||
Args:
|
||||
episode_data_index: Dictionary with keys 'from' and 'to' containing the start and end indices of each episode.
|
||||
episode_indices_to_use: List of episode indices to use. If None, all episodes are used.
|
||||
Assumes that episodes are indexed from 0 to N-1.
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
"""
|
||||
indices = []
|
||||
for episode_idx, (start_index, end_index) in enumerate(
|
||||
zip(episode_data_index["from"], episode_data_index["to"], strict=True)
|
||||
):
|
||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||
indices.extend(
|
||||
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
|
||||
)
|
||||
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
if self.shuffle:
|
||||
for i in torch.randperm(len(self.indices)):
|
||||
yield self.indices[i]
|
||||
else:
|
||||
for i in self.indices:
|
||||
yield i
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.indices)
|
|
@ -59,7 +59,7 @@ def unflatten_dict(d, sep="/"):
|
|||
return outdict
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict):
|
||||
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
||||
a channel last representation (h w c) of uint8 type, to a torch image representation
|
||||
|
@ -73,6 +73,8 @@ def hf_transform_to_torch(items_dict):
|
|||
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
|
||||
# video frame will be processed downstream
|
||||
pass
|
||||
elif first_item is None:
|
||||
pass
|
||||
else:
|
||||
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
|
||||
return items_dict
|
||||
|
@ -318,8 +320,7 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
|
|||
|
||||
|
||||
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
"""
|
||||
Reset the `episode_index` of the provided HuggingFace Dataset.
|
||||
"""Reset the `episode_index` of the provided HuggingFace Dataset.
|
||||
|
||||
`episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
|
||||
`episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
|
||||
|
@ -338,6 +339,7 @@ def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
|
|||
return example
|
||||
|
||||
hf_dataset = hf_dataset.map(modify_ep_idx_func)
|
||||
|
||||
return hf_dataset
|
||||
|
||||
|
||||
|
|
|
@ -27,14 +27,6 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
|||
if n_envs is not None and n_envs < 1:
|
||||
raise ValueError("`n_envs must be at least 1")
|
||||
|
||||
kwargs = {
|
||||
"obs_type": "pixels_agent_pos",
|
||||
"render_mode": "rgb_array",
|
||||
"max_episode_steps": cfg.env.episode_length,
|
||||
"visualization_width": 384,
|
||||
"visualization_height": 384,
|
||||
}
|
||||
|
||||
package_name = f"gym_{cfg.env.name}"
|
||||
|
||||
try:
|
||||
|
@ -46,12 +38,16 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
|||
raise e
|
||||
|
||||
gym_handle = f"{package_name}/{cfg.env.task}"
|
||||
gym_kwgs = dict(cfg.env.get("gym", {}))
|
||||
|
||||
if cfg.env.get("episode_length"):
|
||||
gym_kwgs["max_episode_steps"] = cfg.env.episode_length
|
||||
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv
|
||||
env = env_cls(
|
||||
[
|
||||
lambda: gym.make(gym_handle, disable_env_checker=True, **kwargs)
|
||||
lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs)
|
||||
for _ in range(n_envs if n_envs is not None else cfg.eval.batch_size)
|
||||
]
|
||||
)
|
||||
|
|
|
@ -25,6 +25,13 @@ class ACTConfig:
|
|||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes` and 'output_shapes`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- At least one key starting with "observation.image is required as an input.
|
||||
- If there are multiple keys beginning with "observation.images." they are treated as multiple camera
|
||||
views. Right now we only support all images having the same shape.
|
||||
- May optionally work without an "observation.state" key for the proprioceptive robot state.
|
||||
- "action" is required as an output key.
|
||||
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
|
@ -33,15 +40,15 @@ class ACTConfig:
|
|||
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
|
||||
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
|
||||
environment, and throws the other 50 out.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
The key represents the input data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "observation.images.top" refers to an input from the
|
||||
"top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||
Importantly, shapes doesn't include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
|
|
|
@ -198,27 +198,31 @@ class ACT(nn.Module):
|
|||
def __init__(self, config: ACTConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# BERT style VAE encoder with input [cls, *joint_space_configuration, *action_sequence].
|
||||
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||
self.use_input_state = "observation.state" in config.input_shapes
|
||||
if self.config.use_vae:
|
||||
self.vae_encoder = ACTEncoder(config)
|
||||
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
|
||||
# Projection layer for joint-space configuration to hidden dimension.
|
||||
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
)
|
||||
if self.use_input_state:
|
||||
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
)
|
||||
# Projection layer for action (joint-space target) to hidden dimension.
|
||||
self.vae_encoder_action_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
config.output_shapes["action"][0], config.dim_model
|
||||
)
|
||||
self.latent_dim = config.latent_dim
|
||||
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
|
||||
# Fixed sinusoidal positional embedding the whole input to the VAE encoder. Unsqueeze for batch
|
||||
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
|
||||
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
|
||||
# dimension.
|
||||
num_input_token_encoder = 1 + config.chunk_size
|
||||
if self.use_input_state:
|
||||
num_input_token_encoder += 1
|
||||
self.register_buffer(
|
||||
"vae_encoder_pos_enc",
|
||||
create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0),
|
||||
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
|
||||
)
|
||||
|
||||
# Backbone for image feature extraction.
|
||||
|
@ -238,15 +242,17 @@ class ACT(nn.Module):
|
|||
|
||||
# Transformer encoder input projections. The tokens will be structured like
|
||||
# [latent, robot_state, image_feature_map_pixels].
|
||||
self.encoder_robot_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
)
|
||||
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
|
||||
if self.use_input_state:
|
||||
self.encoder_robot_state_input_proj = nn.Linear(
|
||||
config.input_shapes["observation.state"][0], config.dim_model
|
||||
)
|
||||
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
|
||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||
)
|
||||
# Transformer encoder positional embeddings.
|
||||
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model)
|
||||
num_input_token_decoder = 2 if self.use_input_state else 1
|
||||
self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
|
||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||
|
||||
# Transformer decoder.
|
||||
|
@ -285,7 +291,7 @@ class ACT(nn.Module):
|
|||
"action" in batch
|
||||
), "actions must be provided when using the variational objective in training mode."
|
||||
|
||||
batch_size = batch["observation.state"].shape[0]
|
||||
batch_size = batch["observation.images"].shape[0]
|
||||
|
||||
# Prepare the latent for input to the transformer encoder.
|
||||
if self.config.use_vae and "action" in batch:
|
||||
|
@ -293,11 +299,16 @@ class ACT(nn.Module):
|
|||
cls_embed = einops.repeat(
|
||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||
) # (B, 1, D)
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
|
||||
1
|
||||
) # (B, 1, D)
|
||||
if self.use_input_state:
|
||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
||||
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
|
||||
|
||||
if self.use_input_state:
|
||||
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||
else:
|
||||
vae_encoder_input = [cls_embed, action_embed]
|
||||
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
|
||||
|
||||
# Prepare fixed positional embedding.
|
||||
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
|
||||
|
@ -308,16 +319,17 @@ class ACT(nn.Module):
|
|||
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
|
||||
)[0] # select the class token, with shape (B, D)
|
||||
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
|
||||
mu = latent_pdf_params[:, : self.latent_dim]
|
||||
mu = latent_pdf_params[:, : self.config.latent_dim]
|
||||
# This is 2log(sigma). Done this way to match the original implementation.
|
||||
log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
|
||||
log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]
|
||||
|
||||
# Sample the latent with the reparameterization trick.
|
||||
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
|
||||
else:
|
||||
# When not using the VAE encoder, we set the latent to be all zeros.
|
||||
mu = log_sigma_x2 = None
|
||||
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
||||
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
|
||||
batch["observation.state"].device
|
||||
)
|
||||
|
||||
|
@ -326,8 +338,10 @@ class ACT(nn.Module):
|
|||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
images = batch["observation.images"]
|
||||
|
||||
for cam_index in range(images.shape[-4]):
|
||||
cam_features = self.backbone(images[:, cam_index])["feature_map"]
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||
all_cam_features.append(cam_features)
|
||||
|
@ -337,13 +351,15 @@ class ACT(nn.Module):
|
|||
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
|
||||
|
||||
# Get positional embeddings for robot state and latent.
|
||||
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
|
||||
if self.use_input_state:
|
||||
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
|
||||
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
|
||||
|
||||
# Stack encoder input and positional embeddings moving to (S, B, C).
|
||||
encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
|
||||
encoder_in = torch.cat(
|
||||
[
|
||||
torch.stack([latent_embed, robot_state_embed], axis=0),
|
||||
torch.stack(encoder_in_feats, axis=0),
|
||||
einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
|
||||
]
|
||||
)
|
||||
|
@ -357,6 +373,7 @@ class ACT(nn.Module):
|
|||
|
||||
# Forward pass through the transformer modules.
|
||||
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
|
||||
# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
|
||||
decoder_in = torch.zeros(
|
||||
(self.config.chunk_size, batch_size, self.config.dim_model),
|
||||
dtype=pos_embed.dtype,
|
||||
|
|
|
@ -26,21 +26,26 @@ class DiffusionConfig:
|
|||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes` and `output_shapes`.
|
||||
|
||||
Notes on the inputs and outputs:
|
||||
- "observation.state" is required as an input key.
|
||||
- A key starting with "observation.image is required as an input.
|
||||
- "action" is required as an output key.
|
||||
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
|
||||
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
|
||||
See `DiffusionPolicy.select_action` for more details.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
The key represents the input data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "observation.image" refers to an input from
|
||||
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
|
||||
Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
The key represents the output data name, and the value is a list indicating the dimensions
|
||||
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
|
||||
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
|
|
|
@ -239,10 +239,8 @@ class DiffusionModel(nn.Module):
|
|||
global_cond = torch.cat([batch["observation.state"], img_features], dim=-1).flatten(start_dim=1)
|
||||
|
||||
# run sampling
|
||||
sample = self.conditional_sample(batch_size, global_cond=global_cond)
|
||||
actions = self.conditional_sample(batch_size, global_cond=global_cond)
|
||||
|
||||
# `horizon` steps worth of actions (from the first observation).
|
||||
actions = sample[..., : self.config.output_shapes["action"][0]]
|
||||
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
||||
start = n_obs_steps - 1
|
||||
end = start + self.config.n_action_steps
|
||||
|
|
|
@ -147,7 +147,7 @@ class Normalize(nn.Module):
|
|||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
# normalize to [0,1]
|
||||
batch[key] = (batch[key] - min) / (max - min)
|
||||
batch[key] = (batch[key] - min) / (max - min + 1e-8)
|
||||
# normalize to [-1, 1]
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
else:
|
||||
|
|
|
@ -31,6 +31,15 @@ class TDMPCConfig:
|
|||
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
|
||||
action repeats in Q-learning or ask your favorite chatbot)
|
||||
horizon: Horizon for model predictive control.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
|
|
|
@ -120,13 +120,13 @@ def init_logging():
|
|||
logging.getLogger().addHandler(console_handler)
|
||||
|
||||
|
||||
def format_big_number(num):
|
||||
def format_big_number(num, precision=0):
|
||||
suffixes = ["", "K", "M", "B", "T", "Q"]
|
||||
divisor = 1000.0
|
||||
|
||||
for suffix in suffixes:
|
||||
if abs(num) < divisor:
|
||||
return f"{num:.0f}{suffix}"
|
||||
return f"{num:.{precision}f}{suffix}"
|
||||
num /= divisor
|
||||
|
||||
return num
|
||||
|
|
|
@ -23,6 +23,10 @@ use_amp: false
|
|||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||
# AND for the evaluation environments.
|
||||
seed: ???
|
||||
# You may provide a list of datasets here. `train.py` creates them all and concatenates them. Note: only data
|
||||
# keys common between the datasets are kept. Each dataset gets and additional transform that inserts the
|
||||
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
|
||||
# datsets are provided.
|
||||
dataset_repo_id: lerobot/pusht
|
||||
|
||||
training:
|
||||
|
@ -37,6 +41,8 @@ training:
|
|||
save_freq: ???
|
||||
log_freq: 250
|
||||
save_checkpoint: true
|
||||
num_workers: 4
|
||||
batch_size: ???
|
||||
|
||||
eval:
|
||||
n_episodes: 1
|
||||
|
|
|
@ -5,10 +5,10 @@ fps: 50
|
|||
env:
|
||||
name: aloha
|
||||
task: AlohaInsertion-v0
|
||||
from_pixels: True
|
||||
pixels_only: False
|
||||
image_size: [3, 480, 640]
|
||||
episode_length: 400
|
||||
fps: ${fps}
|
||||
state_dim: 14
|
||||
action_dim: 14
|
||||
fps: ${fps}
|
||||
episode_length: 400
|
||||
gym:
|
||||
obs_type: pixels_agent_pos
|
||||
render_mode: rgb_array
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: dora
|
||||
task: DoraAloha-v0
|
||||
state_dim: 14
|
||||
action_dim: 14
|
||||
fps: ${fps}
|
||||
episode_length: 400
|
||||
gym:
|
||||
fps: ${fps}
|
|
@ -5,10 +5,13 @@ fps: 10
|
|||
env:
|
||||
name: pusht
|
||||
task: PushT-v0
|
||||
from_pixels: True
|
||||
pixels_only: False
|
||||
image_size: 96
|
||||
episode_length: 300
|
||||
fps: ${fps}
|
||||
state_dim: 2
|
||||
action_dim: 2
|
||||
fps: ${fps}
|
||||
episode_length: 300
|
||||
gym:
|
||||
obs_type: pixels_agent_pos
|
||||
render_mode: rgb_array
|
||||
visualization_width: 384
|
||||
visualization_height: 384
|
||||
|
|
|
@ -5,10 +5,13 @@ fps: 15
|
|||
env:
|
||||
name: xarm
|
||||
task: XarmLift-v0
|
||||
from_pixels: True
|
||||
pixels_only: False
|
||||
image_size: 84
|
||||
episode_length: 25
|
||||
fps: ${fps}
|
||||
state_dim: 4
|
||||
action_dim: 4
|
||||
fps: ${fps}
|
||||
episode_length: 25
|
||||
gym:
|
||||
obs_type: pixels_agent_pos
|
||||
render_mode: rgb_array
|
||||
visualization_width: 384
|
||||
visualization_height: 384
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
# @package _global_
|
||||
|
||||
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
|
||||
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, images,
|
||||
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
|
||||
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
|
||||
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
|
||||
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
|
||||
#
|
||||
# Example of usage for training:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_real \
|
||||
# env=dora_aloha_real
|
||||
# ```
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/aloha_static_vinh_cup
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.cam_right_wrist:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_left_wrist:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_high:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_low:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 10000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100 # chunk_size
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.cam_right_wrist: [3, 480, 640]
|
||||
observation.images.cam_left_wrist: [3, 480, 640]
|
||||
observation.images.cam_high: [3, 480, 640]
|
||||
observation.images.cam_low: [3, 480, 640]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.cam_right_wrist: mean_std
|
||||
observation.images.cam_left_wrist: mean_std
|
||||
observation.images.cam_high: mean_std
|
||||
observation.images.cam_low: mean_std
|
||||
observation.state: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
|
@ -0,0 +1,111 @@
|
|||
# @package _global_
|
||||
|
||||
# Use `act_real_no_state.yaml` to train on real-world Aloha/Aloha2 datasets when cameras are moving (e.g. wrist cameras)
|
||||
# Compared to `act_real.yaml`, it is camera only and does not use the state as input which is vector of robot joint positions.
|
||||
# We validated experimentaly that not using state reaches better success rate. Our hypothesis is that `act_real.yaml` might
|
||||
# overfits to the state, because the images are more complex to learn from since they are moving.
|
||||
#
|
||||
# Example of usage for training:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_real_no_state \
|
||||
# env=dora_aloha_real
|
||||
# ```
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/aloha_static_vinh_cup
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.cam_right_wrist:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_left_wrist:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_high:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.cam_low:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 10000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100 # chunk_size
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.cam_right_wrist: [3, 480, 640]
|
||||
observation.images.cam_left_wrist: [3, 480, 640]
|
||||
observation.images.cam_high: [3, 480, 640]
|
||||
observation.images.cam_low: [3, 480, 640]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.cam_right_wrist: mean_std
|
||||
observation.images.cam_left_wrist: mean_std
|
||||
observation.images.cam_high: mean_std
|
||||
observation.images.cam_low: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
|
@ -44,6 +44,10 @@ training:
|
|||
observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
|
||||
action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]"
|
||||
|
||||
# The original implementation doesn't sample frames for the last 7 steps,
|
||||
# which avoids excessive padding and leads to improved training results.
|
||||
drop_n_last_frames: 7 # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
|
|
@ -209,7 +209,7 @@ def eval_policy(
|
|||
policy: torch.nn.Module,
|
||||
n_episodes: int,
|
||||
max_episodes_rendered: int = 0,
|
||||
video_dir: Path | None = None,
|
||||
videos_dir: Path | None = None,
|
||||
return_episode_data: bool = False,
|
||||
start_seed: int | None = None,
|
||||
enable_progbar: bool = False,
|
||||
|
@ -347,8 +347,8 @@ def eval_policy(
|
|||
):
|
||||
if n_episodes_rendered >= max_episodes_rendered:
|
||||
break
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_path = video_dir / f"eval_episode_{n_episodes_rendered}.mp4"
|
||||
videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4"
|
||||
video_paths.append(str(video_path))
|
||||
thread = threading.Thread(
|
||||
target=write_video,
|
||||
|
@ -503,9 +503,10 @@ def _compile_episode_data(
|
|||
}
|
||||
|
||||
|
||||
def eval(
|
||||
def main(
|
||||
pretrained_policy_path: str | None = None,
|
||||
hydra_cfg_path: str | None = None,
|
||||
out_dir: str | None = None,
|
||||
config_overrides: list[str] | None = None,
|
||||
):
|
||||
assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
|
||||
|
@ -513,12 +514,8 @@ def eval(
|
|||
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", config_overrides)
|
||||
else:
|
||||
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
|
||||
out_dir = (
|
||||
f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
|
||||
)
|
||||
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(hydra_cfg.device, log=True)
|
||||
|
@ -546,7 +543,7 @@ def eval(
|
|||
policy,
|
||||
hydra_cfg.eval.n_episodes,
|
||||
max_episodes_rendered=10,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
videos_dir=Path(out_dir) / "videos",
|
||||
start_seed=hydra_cfg.seed,
|
||||
enable_progbar=True,
|
||||
enable_inner_progbar=True,
|
||||
|
@ -586,6 +583,13 @@ if __name__ == "__main__":
|
|||
),
|
||||
)
|
||||
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
|
||||
parser.add_argument(
|
||||
"--out-dir",
|
||||
help=(
|
||||
"Where to save the evaluation outputs. If not provided, outputs are saved in "
|
||||
"outputs/eval/{timestamp}_{env_name}_{policy_name}"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"overrides",
|
||||
nargs="*",
|
||||
|
@ -594,7 +598,7 @@ if __name__ == "__main__":
|
|||
args = parser.parse_args()
|
||||
|
||||
if args.pretrained_policy_name_or_path is None:
|
||||
eval(hydra_cfg_path=args.config, config_overrides=args.overrides)
|
||||
main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
|
||||
else:
|
||||
try:
|
||||
pretrained_policy_path = Path(
|
||||
|
@ -618,4 +622,8 @@ if __name__ == "__main__":
|
|||
"repo ID, nor is it an existing local directory."
|
||||
)
|
||||
|
||||
eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)
|
||||
main(
|
||||
pretrained_policy_path=pretrained_policy_path,
|
||||
out_dir=args.out_dir,
|
||||
config_overrides=args.overrides,
|
||||
)
|
||||
|
|
|
@ -71,9 +71,9 @@ import torch
|
|||
from huggingface_hub import HfApi
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.utils import flatten_dict
|
||||
|
||||
|
||||
|
@ -84,10 +84,14 @@ def get_from_raw_to_lerobot_format_fn(raw_format):
|
|||
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "aloha_hdf5":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "aloha_dora":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_dora_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "xarm_pkl":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
|
||||
else:
|
||||
raise ValueError(raw_format)
|
||||
raise ValueError(
|
||||
f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
|
||||
)
|
||||
|
||||
return from_raw_to_lerobot_format
|
||||
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
import logging
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
|
@ -28,6 +27,8 @@ from termcolor import colored
|
|||
from torch.cuda.amp import GradScaler
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
|
||||
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
|
@ -164,6 +165,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
|
|||
grad_norm = info["grad_norm"]
|
||||
lr = info["lr"]
|
||||
update_s = info["update_s"]
|
||||
dataloading_s = info["dataloading_s"]
|
||||
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
||||
|
@ -184,6 +186,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
|
|||
f"lr:{lr:0.1e}",
|
||||
# in seconds
|
||||
f"updt_s:{update_s:.3f}",
|
||||
f"data_s:{dataloading_s:.3f}", # if not ~0, you are bottlenecked by cpu or io
|
||||
]
|
||||
logging.info(" ".join(log_items))
|
||||
|
||||
|
@ -295,9 +298,18 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
|
||||
logging.info("make_dataset")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
if isinstance(offline_dataset, MultiLeRobotDataset):
|
||||
logging.info(
|
||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||
f"{pformat(offline_dataset.repo_id_to_index , indent=2)}"
|
||||
)
|
||||
|
||||
logging.info("make_env")
|
||||
eval_env = make_env(cfg)
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||
if cfg.training.eval_freq > 0:
|
||||
logging.info("make_env")
|
||||
eval_env = make_env(cfg)
|
||||
|
||||
logging.info("make_policy")
|
||||
policy = make_policy(
|
||||
|
@ -330,18 +342,21 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
|
||||
# Note: this helper will be used in offline and online training loops.
|
||||
def evaluate_and_checkpoint_if_needed(step):
|
||||
if step % cfg.training.eval_freq == 0:
|
||||
_num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
|
||||
step_identifier = f"{step:0{_num_digits}d}"
|
||||
|
||||
if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
cfg.eval.n_episodes,
|
||||
video_dir=Path(out_dir) / "eval",
|
||||
videos_dir=Path(out_dir) / "eval" / f"videos_step_{step_identifier}",
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
|
||||
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True)
|
||||
if cfg.wandb.enable:
|
||||
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
|
||||
logging.info("Resume training")
|
||||
|
@ -355,29 +370,40 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
policy,
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
identifier=str(step).zfill(
|
||||
max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
|
||||
),
|
||||
identifier=step_identifier,
|
||||
)
|
||||
logging.info("Resume training")
|
||||
|
||||
# create dataloader for offline training
|
||||
if cfg.training.get("drop_n_last_frames"):
|
||||
shuffle = False
|
||||
sampler = EpisodeAwareSampler(
|
||||
offline_dataset.episode_data_index,
|
||||
drop_n_last_frames=cfg.training.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
offline_dataset,
|
||||
num_workers=4,
|
||||
num_workers=cfg.training.num_workers,
|
||||
batch_size=cfg.training.batch_size,
|
||||
shuffle=True,
|
||||
shuffle=shuffle,
|
||||
sampler=sampler,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=False,
|
||||
)
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
is_offline = True
|
||||
for _ in range(step, cfg.training.offline_steps):
|
||||
if step == 0:
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
batch = next(dl_iter)
|
||||
dataloading_s = time.perf_counter() - start_time
|
||||
|
||||
for key in batch:
|
||||
batch[key] = batch[key].to(device, non_blocking=True)
|
||||
|
@ -392,8 +418,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
use_amp=cfg.use_amp,
|
||||
)
|
||||
|
||||
train_info["dataloading_s"] = dataloading_s
|
||||
|
||||
if step % cfg.training.log_freq == 0:
|
||||
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
|
||||
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
|
||||
|
||||
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
|
||||
# so we pass in step + 1.
|
||||
|
@ -401,26 +429,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
|||
|
||||
step += 1
|
||||
|
||||
# create an empty online dataset similar to offline dataset
|
||||
online_dataset = deepcopy(offline_dataset)
|
||||
online_dataset.hf_dataset = {}
|
||||
online_dataset.episode_data_index = {}
|
||||
|
||||
# create dataloader for online training
|
||||
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
|
||||
weights = [1.0] * len(concat_dataset)
|
||||
sampler = torch.utils.data.WeightedRandomSampler(
|
||||
weights, num_samples=len(concat_dataset), replacement=True
|
||||
)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
concat_dataset,
|
||||
num_workers=4,
|
||||
batch_size=cfg.training.batch_size,
|
||||
sampler=sampler,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
eval_env.close()
|
||||
logging.info("End of training")
|
||||
|
||||
|
|
|
@ -224,7 +224,8 @@ def main():
|
|||
help=(
|
||||
"Mode of viewing between 'local' or 'distant'. "
|
||||
"'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. "
|
||||
"'distant' creates a server on the distant machine where the data is stored. Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
|
||||
"'distant' creates a server on the distant machine where the data is stored. "
|
||||
"Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
|
@ -245,8 +246,8 @@ def main():
|
|||
default=0,
|
||||
help=(
|
||||
"Save a .rrd file in the directory provided by `--output-dir`. "
|
||||
"It also deactivates the spawning of a viewer. ",
|
||||
"Visualize the data by running `rerun path/to/file.rrd` on your local machine.",
|
||||
"It also deactivates the spawning of a viewer. "
|
||||
"Visualize the data by running `rerun path/to/file.rrd` on your local machine."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "absl-py"
|
||||
|
@ -444,63 +444,63 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "coverage"
|
||||
version = "7.5.1"
|
||||
version = "7.5.3"
|
||||
description = "Code coverage measurement for Python"
|
||||
optional = true
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "coverage-7.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0884920835a033b78d1c73b6d3bbcda8161a900f38a488829a83982925f6c2e"},
|
||||
{file = "coverage-7.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:39afcd3d4339329c5f58de48a52f6e4e50f6578dd6099961cf22228feb25f38f"},
|
||||
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7b0ceee8147444347da6a66be737c9d78f3353b0681715b668b72e79203e4a"},
|
||||
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a9ca3f2fae0088c3c71d743d85404cec8df9be818a005ea065495bedc33da35"},
|
||||
{file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd215c0c7d7aab005221608a3c2b46f58c0285a819565887ee0b718c052aa4e"},
|
||||
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4bf0655ab60d754491004a5efd7f9cccefcc1081a74c9ef2da4735d6ee4a6223"},
|
||||
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:61c4bf1ba021817de12b813338c9be9f0ad5b1e781b9b340a6d29fc13e7c1b5e"},
|
||||
{file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:db66fc317a046556a96b453a58eced5024af4582a8dbdc0c23ca4dbc0d5b3146"},
|
||||
{file = "coverage-7.5.1-cp310-cp310-win32.whl", hash = "sha256:b016ea6b959d3b9556cb401c55a37547135a587db0115635a443b2ce8f1c7228"},
|
||||
{file = "coverage-7.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:df4e745a81c110e7446b1cc8131bf986157770fa405fe90e15e850aaf7619bc8"},
|
||||
{file = "coverage-7.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:796a79f63eca8814ca3317a1ea443645c9ff0d18b188de470ed7ccd45ae79428"},
|
||||
{file = "coverage-7.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fc84a37bfd98db31beae3c2748811a3fa72bf2007ff7902f68746d9757f3746"},
|
||||
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6175d1a0559986c6ee3f7fccfc4a90ecd12ba0a383dcc2da30c2b9918d67d8a3"},
|
||||
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fc81d5878cd6274ce971e0a3a18a8803c3fe25457165314271cf78e3aae3aa2"},
|
||||
{file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:556cf1a7cbc8028cb60e1ff0be806be2eded2daf8129b8811c63e2b9a6c43bca"},
|
||||
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9981706d300c18d8b220995ad22627647be11a4276721c10911e0e9fa44c83e8"},
|
||||
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d7fed867ee50edf1a0b4a11e8e5d0895150e572af1cd6d315d557758bfa9c057"},
|
||||
{file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ef48e2707fb320c8f139424a596f5b69955a85b178f15af261bab871873bb987"},
|
||||
{file = "coverage-7.5.1-cp311-cp311-win32.whl", hash = "sha256:9314d5678dcc665330df5b69c1e726a0e49b27df0461c08ca12674bcc19ef136"},
|
||||
{file = "coverage-7.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fa567e99765fe98f4e7d7394ce623e794d7cabb170f2ca2ac5a4174437e90dd"},
|
||||
{file = "coverage-7.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b6cf3764c030e5338e7f61f95bd21147963cf6aa16e09d2f74f1fa52013c1206"},
|
||||
{file = "coverage-7.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ec92012fefebee89a6b9c79bc39051a6cb3891d562b9270ab10ecfdadbc0c34"},
|
||||
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16db7f26000a07efcf6aea00316f6ac57e7d9a96501e990a36f40c965ec7a95d"},
|
||||
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:beccf7b8a10b09c4ae543582c1319c6df47d78fd732f854ac68d518ee1fb97fa"},
|
||||
{file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8748731ad392d736cc9ccac03c9845b13bb07d020a33423fa5b3a36521ac6e4e"},
|
||||
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7352b9161b33fd0b643ccd1f21f3a3908daaddf414f1c6cb9d3a2fd618bf2572"},
|
||||
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7a588d39e0925f6a2bff87154752481273cdb1736270642aeb3635cb9b4cad07"},
|
||||
{file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:68f962d9b72ce69ea8621f57551b2fa9c70509af757ee3b8105d4f51b92b41a7"},
|
||||
{file = "coverage-7.5.1-cp312-cp312-win32.whl", hash = "sha256:f152cbf5b88aaeb836127d920dd0f5e7edff5a66f10c079157306c4343d86c19"},
|
||||
{file = "coverage-7.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:5a5740d1fb60ddf268a3811bcd353de34eb56dc24e8f52a7f05ee513b2d4f596"},
|
||||
{file = "coverage-7.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e2213def81a50519d7cc56ed643c9e93e0247f5bbe0d1247d15fa520814a7cd7"},
|
||||
{file = "coverage-7.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5037f8fcc2a95b1f0e80585bd9d1ec31068a9bcb157d9750a172836e98bc7a90"},
|
||||
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3721c2c9e4c4953a41a26c14f4cef64330392a6d2d675c8b1db3b645e31f0e"},
|
||||
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca498687ca46a62ae590253fba634a1fe9836bc56f626852fb2720f334c9e4e5"},
|
||||
{file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cdcbc320b14c3e5877ee79e649677cb7d89ef588852e9583e6b24c2e5072661"},
|
||||
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:57e0204b5b745594e5bc14b9b50006da722827f0b8c776949f1135677e88d0b8"},
|
||||
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fe7502616b67b234482c3ce276ff26f39ffe88adca2acf0261df4b8454668b4"},
|
||||
{file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9e78295f4144f9dacfed4f92935fbe1780021247c2fabf73a819b17f0ccfff8d"},
|
||||
{file = "coverage-7.5.1-cp38-cp38-win32.whl", hash = "sha256:1434e088b41594baa71188a17533083eabf5609e8e72f16ce8c186001e6b8c41"},
|
||||
{file = "coverage-7.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:0646599e9b139988b63704d704af8e8df7fa4cbc4a1f33df69d97f36cb0a38de"},
|
||||
{file = "coverage-7.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4cc37def103a2725bc672f84bd939a6fe4522310503207aae4d56351644682f1"},
|
||||
{file = "coverage-7.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fc0b4d8bfeabd25ea75e94632f5b6e047eef8adaed0c2161ada1e922e7f7cece"},
|
||||
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d0a0f5e06881ecedfe6f3dd2f56dcb057b6dbeb3327fd32d4b12854df36bf26"},
|
||||
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9735317685ba6ec7e3754798c8871c2f49aa5e687cc794a0b1d284b2389d1bd5"},
|
||||
{file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d21918e9ef11edf36764b93101e2ae8cc82aa5efdc7c5a4e9c6c35a48496d601"},
|
||||
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c3e757949f268364b96ca894b4c342b41dc6f8f8b66c37878aacef5930db61be"},
|
||||
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:79afb6197e2f7f60c4824dd4b2d4c2ec5801ceb6ba9ce5d2c3080e5660d51a4f"},
|
||||
{file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1d0d98d95dd18fe29dc66808e1accf59f037d5716f86a501fc0256455219668"},
|
||||
{file = "coverage-7.5.1-cp39-cp39-win32.whl", hash = "sha256:1cc0fe9b0b3a8364093c53b0b4c0c2dd4bb23acbec4c9240b5f284095ccf7981"},
|
||||
{file = "coverage-7.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:dde0070c40ea8bb3641e811c1cfbf18e265d024deff6de52c5950677a8fb1e0f"},
|
||||
{file = "coverage-7.5.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:6537e7c10cc47c595828b8a8be04c72144725c383c4702703ff4e42e44577312"},
|
||||
{file = "coverage-7.5.1.tar.gz", hash = "sha256:54de9ef3a9da981f7af93eafde4ede199e0846cd819eb27c88e2b712aae9708c"},
|
||||
{file = "coverage-7.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a6519d917abb15e12380406d721e37613e2a67d166f9fb7e5a8ce0375744cd45"},
|
||||
{file = "coverage-7.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:aea7da970f1feccf48be7335f8b2ca64baf9b589d79e05b9397a06696ce1a1ec"},
|
||||
{file = "coverage-7.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:923b7b1c717bd0f0f92d862d1ff51d9b2b55dbbd133e05680204465f454bb286"},
|
||||
{file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62bda40da1e68898186f274f832ef3e759ce929da9a9fd9fcf265956de269dbc"},
|
||||
{file = "coverage-7.5.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8b7339180d00de83e930358223c617cc343dd08e1aa5ec7b06c3a121aec4e1d"},
|
||||
{file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:25a5caf742c6195e08002d3b6c2dd6947e50efc5fc2c2205f61ecb47592d2d83"},
|
||||
{file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:05ac5f60faa0c704c0f7e6a5cbfd6f02101ed05e0aee4d2822637a9e672c998d"},
|
||||
{file = "coverage-7.5.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:239a4e75e09c2b12ea478d28815acf83334d32e722e7433471fbf641c606344c"},
|
||||
{file = "coverage-7.5.3-cp310-cp310-win32.whl", hash = "sha256:a5812840d1d00eafae6585aba38021f90a705a25b8216ec7f66aebe5b619fb84"},
|
||||
{file = "coverage-7.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:33ca90a0eb29225f195e30684ba4a6db05dbef03c2ccd50b9077714c48153cac"},
|
||||
{file = "coverage-7.5.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f81bc26d609bf0fbc622c7122ba6307993c83c795d2d6f6f6fd8c000a770d974"},
|
||||
{file = "coverage-7.5.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7cec2af81f9e7569280822be68bd57e51b86d42e59ea30d10ebdbb22d2cb7232"},
|
||||
{file = "coverage-7.5.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55f689f846661e3f26efa535071775d0483388a1ccfab899df72924805e9e7cd"},
|
||||
{file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50084d3516aa263791198913a17354bd1dc627d3c1639209640b9cac3fef5807"},
|
||||
{file = "coverage-7.5.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:341dd8f61c26337c37988345ca5c8ccabeff33093a26953a1ac72e7d0103c4fb"},
|
||||
{file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ab0b028165eea880af12f66086694768f2c3139b2c31ad5e032c8edbafca6ffc"},
|
||||
{file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:5bc5a8c87714b0c67cfeb4c7caa82b2d71e8864d1a46aa990b5588fa953673b8"},
|
||||
{file = "coverage-7.5.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:38a3b98dae8a7c9057bd91fbf3415c05e700a5114c5f1b5b0ea5f8f429ba6614"},
|
||||
{file = "coverage-7.5.3-cp311-cp311-win32.whl", hash = "sha256:fcf7d1d6f5da887ca04302db8e0e0cf56ce9a5e05f202720e49b3e8157ddb9a9"},
|
||||
{file = "coverage-7.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:8c836309931839cca658a78a888dab9676b5c988d0dd34ca247f5f3e679f4e7a"},
|
||||
{file = "coverage-7.5.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:296a7d9bbc598e8744c00f7a6cecf1da9b30ae9ad51c566291ff1314e6cbbed8"},
|
||||
{file = "coverage-7.5.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:34d6d21d8795a97b14d503dcaf74226ae51eb1f2bd41015d3ef332a24d0a17b3"},
|
||||
{file = "coverage-7.5.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e317953bb4c074c06c798a11dbdd2cf9979dbcaa8ccc0fa4701d80042d4ebf1"},
|
||||
{file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:705f3d7c2b098c40f5b81790a5fedb274113373d4d1a69e65f8b68b0cc26f6db"},
|
||||
{file = "coverage-7.5.3-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1196e13c45e327d6cd0b6e471530a1882f1017eb83c6229fc613cd1a11b53cd"},
|
||||
{file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:015eddc5ccd5364dcb902eaecf9515636806fa1e0d5bef5769d06d0f31b54523"},
|
||||
{file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:fd27d8b49e574e50caa65196d908f80e4dff64d7e592d0c59788b45aad7e8b35"},
|
||||
{file = "coverage-7.5.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:33fc65740267222fc02975c061eb7167185fef4cc8f2770267ee8bf7d6a42f84"},
|
||||
{file = "coverage-7.5.3-cp312-cp312-win32.whl", hash = "sha256:7b2a19e13dfb5c8e145c7a6ea959485ee8e2204699903c88c7d25283584bfc08"},
|
||||
{file = "coverage-7.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:0bbddc54bbacfc09b3edaec644d4ac90c08ee8ed4844b0f86227dcda2d428fcb"},
|
||||
{file = "coverage-7.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f78300789a708ac1f17e134593f577407d52d0417305435b134805c4fb135adb"},
|
||||
{file = "coverage-7.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b368e1aee1b9b75757942d44d7598dcd22a9dbb126affcbba82d15917f0cc155"},
|
||||
{file = "coverage-7.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f836c174c3a7f639bded48ec913f348c4761cbf49de4a20a956d3431a7c9cb24"},
|
||||
{file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:244f509f126dc71369393ce5fea17c0592c40ee44e607b6d855e9c4ac57aac98"},
|
||||
{file = "coverage-7.5.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c4c2872b3c91f9baa836147ca33650dc5c172e9273c808c3c3199c75490e709d"},
|
||||
{file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:dd4b3355b01273a56b20c219e74e7549e14370b31a4ffe42706a8cda91f19f6d"},
|
||||
{file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:f542287b1489c7a860d43a7d8883e27ca62ab84ca53c965d11dac1d3a1fab7ce"},
|
||||
{file = "coverage-7.5.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:75e3f4e86804023e991096b29e147e635f5e2568f77883a1e6eed74512659ab0"},
|
||||
{file = "coverage-7.5.3-cp38-cp38-win32.whl", hash = "sha256:c59d2ad092dc0551d9f79d9d44d005c945ba95832a6798f98f9216ede3d5f485"},
|
||||
{file = "coverage-7.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:fa21a04112c59ad54f69d80e376f7f9d0f5f9123ab87ecd18fbb9ec3a2beed56"},
|
||||
{file = "coverage-7.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5102a92855d518b0996eb197772f5ac2a527c0ec617124ad5242a3af5e25f85"},
|
||||
{file = "coverage-7.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d1da0a2e3b37b745a2b2a678a4c796462cf753aebf94edcc87dcc6b8641eae31"},
|
||||
{file = "coverage-7.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8383a6c8cefba1b7cecc0149415046b6fc38836295bc4c84e820872eb5478b3d"},
|
||||
{file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aad68c3f2566dfae84bf46295a79e79d904e1c21ccfc66de88cd446f8686341"},
|
||||
{file = "coverage-7.5.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e079c9ec772fedbade9d7ebc36202a1d9ef7291bc9b3a024ca395c4d52853d7"},
|
||||
{file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bde997cac85fcac227b27d4fb2c7608a2c5f6558469b0eb704c5726ae49e1c52"},
|
||||
{file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:990fb20b32990b2ce2c5f974c3e738c9358b2735bc05075d50a6f36721b8f303"},
|
||||
{file = "coverage-7.5.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3d5a67f0da401e105753d474369ab034c7bae51a4c31c77d94030d59e41df5bd"},
|
||||
{file = "coverage-7.5.3-cp39-cp39-win32.whl", hash = "sha256:e08c470c2eb01977d221fd87495b44867a56d4d594f43739a8028f8646a51e0d"},
|
||||
{file = "coverage-7.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:1d2a830ade66d3563bb61d1e3c77c8def97b30ed91e166c67d0632c018f380f0"},
|
||||
{file = "coverage-7.5.3-pp38.pp39.pp310-none-any.whl", hash = "sha256:3538d8fb1ee9bdd2e2692b3b18c22bb1c19ffbefd06880f5ac496e42d7bb3884"},
|
||||
{file = "coverage-7.5.3.tar.gz", hash = "sha256:04aefca5190d1dc7a53a4c1a5a7f8568811306d7a8ee231c42fb69215571944f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -785,6 +785,26 @@ files = [
|
|||
[package.dependencies]
|
||||
six = ">=1.4.0"
|
||||
|
||||
[[package]]
|
||||
name = "dora-rs"
|
||||
version = "0.3.4"
|
||||
description = "`dora` goal is to be a low latency, composable, and distributed data flow."
|
||||
optional = true
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "dora_rs-0.3.4-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:d1b738eea5a4966d731c26c6b6a0a50a491a24f7e9e335475f983cfc6f0da19e"},
|
||||
{file = "dora_rs-0.3.4-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:80b724871618c78a4e5863938fa66724176cc40352771087aebe1e62a8141157"},
|
||||
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3a3919e157b47dc1dbc74c040a73087a4485f0d1bee99b6adcdbc36559400fe2"},
|
||||
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f7c95f6e5858fd651d6cd220e4f052e99db2944b9c37fb0b5402d60ac4b41a63"},
|
||||
{file = "dora_rs-0.3.4-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37d915fbbca282446235c98a9ca08389aa3ef3155d4e88c6c136326e9a830042"},
|
||||
{file = "dora_rs-0.3.4-cp37-abi3-win32.whl", hash = "sha256:c9f7f22f65c884ec9bee0245ce98d0c7fad25dec0f982e566f844b5e8e58818f"},
|
||||
{file = "dora_rs-0.3.4-cp37-abi3-win_amd64.whl", hash = "sha256:0a6a37f96a9f6e13b58b02a6ea75af192af5fbe4f456f6a67b1f239c3cee3276"},
|
||||
{file = "dora_rs-0.3.4.tar.gz", hash = "sha256:05c5d0db0d23d7c4669995ae34db11cd636dbf91f5705d832669bd04e7452903"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pyarrow = "*"
|
||||
|
||||
[[package]]
|
||||
name = "einops"
|
||||
version = "0.8.0"
|
||||
|
@ -1066,6 +1086,27 @@ mujoco = ">=2.3.7,<3.0.0"
|
|||
dev = ["debugpy (>=1.8.1)", "pre-commit (>=3.7.0)"]
|
||||
test = ["pytest (>=8.1.0)", "pytest-cov (>=5.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "gym-dora"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
optional = true
|
||||
python-versions = "^3.10"
|
||||
files = []
|
||||
develop = false
|
||||
|
||||
[package.dependencies]
|
||||
dora-rs = ">=0.3.4"
|
||||
gymnasium = ">=0.29.1"
|
||||
pyarrow = ">=12.0.0"
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "https://github.com/dora-rs/dora-lerobot.git"
|
||||
reference = "HEAD"
|
||||
resolved_reference = "ed0c00a4fdc6ec856c9842551acd7dc7ee776f79"
|
||||
subdirectory = "gym_dora"
|
||||
|
||||
[[package]]
|
||||
name = "gym-pusht"
|
||||
version = "0.1.4"
|
||||
|
@ -1269,13 +1310,13 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "0.23.1"
|
||||
version = "0.23.2"
|
||||
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
{file = "huggingface_hub-0.23.1-py3-none-any.whl", hash = "sha256:720a5bffd2b1b449deb793da8b0df7a9390a7e238534d5a08c9fbcdecb1dd3cb"},
|
||||
{file = "huggingface_hub-0.23.1.tar.gz", hash = "sha256:4f62dbf6ae94f400c6d3419485e52bce510591432a5248a65d0cb72e4d479eb4"},
|
||||
{file = "huggingface_hub-0.23.2-py3-none-any.whl", hash = "sha256:48727a16e704d409c4bb5913613308499664f22a99743435dc3a13b23c485827"},
|
||||
{file = "huggingface_hub-0.23.2.tar.gz", hash = "sha256:f6829b62d5fdecb452a76fdbec620cba4c1573655a8d710c1df71735fd9edbd2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -2061,18 +2102,15 @@ test = ["pytest (>=7.2)", "pytest-cov (>=4.0)"]
|
|||
|
||||
[[package]]
|
||||
name = "nodeenv"
|
||||
version = "1.8.0"
|
||||
version = "1.9.0"
|
||||
description = "Node.js virtual environment builder"
|
||||
optional = true
|
||||
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*"
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
files = [
|
||||
{file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"},
|
||||
{file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"},
|
||||
{file = "nodeenv-1.9.0-py2.py3-none-any.whl", hash = "sha256:508ecec98f9f3330b636d4448c0f1a56fc68017c68f1e7857ebc52acf0eb879a"},
|
||||
{file = "nodeenv-1.9.0.tar.gz", hash = "sha256:07f144e90dae547bf0d4ee8da0ee42664a42a04e02ed68e06324348dafe4bdb1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
setuptools = "*"
|
||||
|
||||
[[package]]
|
||||
name = "numba"
|
||||
version = "0.59.1"
|
||||
|
@ -2406,6 +2444,7 @@ optional = false
|
|||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "pandas-2.2.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:90c6fca2acf139569e74e8781709dccb6fe25940488755716d1d354d6bc58bce"},
|
||||
{file = "pandas-2.2.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c7adfc142dac335d8c1e0dcbd37eb8617eac386596eb9e1a1b77791cf2498238"},
|
||||
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4abfe0be0d7221be4f12552995e58723c7422c80a659da13ca382697de830c08"},
|
||||
{file = "pandas-2.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8635c16bf3d99040fdf3ca3db669a7250ddf49c55dc4aa8fe0ae0fa8d6dcc1f0"},
|
||||
{file = "pandas-2.2.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:40ae1dffb3967a52203105a077415a86044a2bea011b5f321c6aa64b379a3f51"},
|
||||
|
@ -2426,6 +2465,7 @@ files = [
|
|||
{file = "pandas-2.2.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:43498c0bdb43d55cb162cdc8c06fac328ccb5d2eabe3cadeb3529ae6f0517c32"},
|
||||
{file = "pandas-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:d187d355ecec3629624fccb01d104da7d7f391db0311145817525281e2804d23"},
|
||||
{file = "pandas-2.2.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0ca6377b8fca51815f382bd0b697a0814c8bda55115678cbc94c30aacbb6eff2"},
|
||||
{file = "pandas-2.2.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9057e6aa78a584bc93a13f0a9bf7e753a5e9770a30b4d758b8d5f2a62a9433cd"},
|
||||
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:001910ad31abc7bf06f49dcc903755d2f7f3a9186c0c040b827e522e9cef0863"},
|
||||
{file = "pandas-2.2.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66b479b0bd07204e37583c191535505410daa8df638fd8e75ae1b383851fe921"},
|
||||
{file = "pandas-2.2.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:a77e9d1c386196879aa5eb712e77461aaee433e54c68cf253053a73b7e49c33a"},
|
||||
|
@ -3188,13 +3228,13 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "requests"
|
||||
version = "2.32.2"
|
||||
version = "2.32.3"
|
||||
description = "Python HTTP for Humans."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"},
|
||||
{file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"},
|
||||
{file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"},
|
||||
{file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -3210,16 +3250,16 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
|||
|
||||
[[package]]
|
||||
name = "rerun-sdk"
|
||||
version = "0.16.0"
|
||||
version = "0.16.1"
|
||||
description = "The Rerun Logging SDK"
|
||||
optional = false
|
||||
python-versions = "<3.13,>=3.8"
|
||||
files = [
|
||||
{file = "rerun_sdk-0.16.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:1cc6dc66d089e296f945dc238301889efb61dd6d338b5d00f76981cf7aed0a74"},
|
||||
{file = "rerun_sdk-0.16.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:faf231897655e46eb975695df2b0ace07db362d697e697f9a3dff52f81c0dc5d"},
|
||||
{file = "rerun_sdk-0.16.0-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:860a6394380d3e9b9e48bf34423bd56dda54d5b0158d2ae0e433698659b34198"},
|
||||
{file = "rerun_sdk-0.16.0-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:5b8d1476f73a3ad1a5d3f21b61c633f3ab62aa80fa0b049f5ad10bf1227681ab"},
|
||||
{file = "rerun_sdk-0.16.0-cp38-abi3-win_amd64.whl", hash = "sha256:aff0051a263b8c3067243c0126d319845baf4fe640899f04aeef7daf151f35e4"},
|
||||
{file = "rerun_sdk-0.16.1-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:170c6976634008611753e10dfef8cdc395ce8180e634c169e7c61cef2f89a277"},
|
||||
{file = "rerun_sdk-0.16.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c9a76eab7eb5559276737dad655200e9350df0837158dbc5a896970ab4201454"},
|
||||
{file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:4d6436752d57e8b8038489a0e7e37f0c760b088e96db5fb81667d3a376d63fea"},
|
||||
{file = "rerun_sdk-0.16.1-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:37b7b47948471873e84f224b16f417a94a91c7cbd6c72c68281eeff1ba414b8f"},
|
||||
{file = "rerun_sdk-0.16.1-cp38-abi3-win_amd64.whl", hash = "sha256:be88799c8afdf68eafa99e64e2e4f0a484e187e017a180219abbe6bb988acd4e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -3696,17 +3736,17 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "sympy"
|
||||
version = "1.12"
|
||||
version = "1.12.1"
|
||||
description = "Computer algebra system (CAS) in Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "sympy-1.12-py3-none-any.whl", hash = "sha256:c3588cd4295d0c0f603d0f2ae780587e64e2efeedb3521e46b9bb1d08d184fa5"},
|
||||
{file = "sympy-1.12.tar.gz", hash = "sha256:ebf595c8dac3e0fdc4152c51878b498396ec7f30e7a914d6071e674d49420fb8"},
|
||||
{file = "sympy-1.12.1-py3-none-any.whl", hash = "sha256:9b2cbc7f1a640289430e13d2a56f02f867a1da0190f2f99d8968c2f74da0e515"},
|
||||
{file = "sympy-1.12.1.tar.gz", hash = "sha256:2877b03f998cd8c08f07cd0de5b767119cd3ef40d09f41c30d722f6686b0fb88"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
mpmath = ">=0.19"
|
||||
mpmath = ">=1.1.0,<1.4.0"
|
||||
|
||||
[[package]]
|
||||
name = "tbb"
|
||||
|
@ -4220,13 +4260,13 @@ multidict = ">=4.0"
|
|||
|
||||
[[package]]
|
||||
name = "zarr"
|
||||
version = "2.18.1"
|
||||
version = "2.18.2"
|
||||
description = "An implementation of chunked, compressed, N-dimensional arrays for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "zarr-2.18.1-py3-none-any.whl", hash = "sha256:a1770d194eec4ec0a41a01295a6f724e1c3471d704d3aca906d3b3a7f8830245"},
|
||||
{file = "zarr-2.18.1.tar.gz", hash = "sha256:28c360ed123e606c425a694a83300227a907cb86a995fc9eef620ecafbe5f92d"},
|
||||
{file = "zarr-2.18.2-py3-none-any.whl", hash = "sha256:a638754902f97efa99b406083fdc807a0e2ccf12a949117389d2a4ba9b05df38"},
|
||||
{file = "zarr-2.18.2.tar.gz", hash = "sha256:9bb393b8a0a38fb121dbb913b047d75db28de9890f6d644a217a73cf4ae74f47"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
@ -4241,13 +4281,13 @@ jupyter = ["ipytree (>=0.2.2)", "ipywidgets (>=8.0.0)", "notebook"]
|
|||
|
||||
[[package]]
|
||||
name = "zipp"
|
||||
version = "3.18.2"
|
||||
version = "3.19.0"
|
||||
description = "Backport of pathlib-compatible object wrapper for zip files"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "zipp-3.18.2-py3-none-any.whl", hash = "sha256:dce197b859eb796242b0622af1b8beb0a722d52aa2f57133ead08edd5bf5374e"},
|
||||
{file = "zipp-3.18.2.tar.gz", hash = "sha256:6278d9ddbcfb1f1089a88fde84481528b07b0e10474e09dcfe53dad4069fa059"},
|
||||
{file = "zipp-3.19.0-py3-none-any.whl", hash = "sha256:96dc6ad62f1441bcaccef23b274ec471518daf4fbbc580341204936a5a3dddec"},
|
||||
{file = "zipp-3.19.0.tar.gz", hash = "sha256:952df858fb3164426c976d9338d3961e8e8b3758e2e059e0f754b8c4262625ee"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
|
@ -4257,6 +4297,7 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more
|
|||
[extras]
|
||||
aloha = ["gym-aloha"]
|
||||
dev = ["debugpy", "pre-commit"]
|
||||
dora = ["gym-dora"]
|
||||
pusht = ["gym-pusht"]
|
||||
test = ["pytest", "pytest-cov"]
|
||||
umi = ["imagecodecs"]
|
||||
|
@ -4265,4 +4306,4 @@ xarm = ["gym-xarm"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "1ad6ef0f88f0056ab639e60e033e586f7460a9c5fc3676a477bbd47923f41cb6"
|
||||
content-hash = "23ddb8dd774a4faf85d08a07dfdf19badb7c370120834b71df4afca254520771"
|
||||
|
|
|
@ -46,6 +46,7 @@ h5py = ">=3.10.0"
|
|||
huggingface-hub = {extras = ["hf-transfer"], version = "^0.23.0"}
|
||||
gymnasium = ">=0.29.1"
|
||||
cmake = ">=3.29.0.1"
|
||||
gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
|
||||
gym-pusht = { version = ">=0.1.3", optional = true}
|
||||
gym-xarm = { version = ">=0.1.1", optional = true}
|
||||
gym-aloha = { version = ">=0.1.1", optional = true}
|
||||
|
@ -62,6 +63,7 @@ deepdiff = ">=7.0.1"
|
|||
|
||||
|
||||
[tool.poetry.extras]
|
||||
dora = ["gym-dora"]
|
||||
pusht = ["gym-pusht"]
|
||||
xarm = ["gym-xarm"]
|
||||
aloha = ["gym-aloha"]
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2fff6294b94cf42d4dd1249dcc5c3b0269d6d9c697f894e61b867d7ab81a94e4
|
||||
size 5104
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4aa23e51607604a18b70fa42edbbe1af34f119d985628fc27cc1bbb0efbc8901
|
||||
size 31688
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6fd368406c93cb562a69ff11cf7adf34a4b223507dcb2b9e9b8f44ee1036988a
|
||||
size 68
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5663ee79a13bb70a1604b887dd21bf89d18482287442419c6cc6c5bf0e753e99
|
||||
size 34928
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fb1a45463efd860af2ca22c16c77d55a18bd96fef080ae77978845a2f22ef716
|
||||
size 5104
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:aa5a43e22f01d8e2f8d19f31753608794f1edbd74aaf71660091ab80ea58dc9b
|
||||
size 30808
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:97455b4360748c99905cd103473c1a52da6901d0a73ffbc51b5ea3eb250d1386
|
||||
size 68
|
|
@ -0,0 +1,3 @@
|
|||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:54d1f75cf67a7b1d7a7c6865ecb9b1cc86a2f032d1890245f8996789ab6e0df6
|
||||
size 33608
|
|
@ -75,15 +75,16 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
|||
# HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
|
||||
dataset.delta_timestamps = None
|
||||
batch = next(iter(dataloader))
|
||||
obs = {
|
||||
k: batch[k]
|
||||
for k in batch
|
||||
if k in ["observation.image", "observation.images.top", "observation.state"]
|
||||
}
|
||||
obs = {}
|
||||
for k in batch:
|
||||
if k.startswith("observation"):
|
||||
obs[k] = batch[k]
|
||||
|
||||
if "n_action_steps" in cfg.policy:
|
||||
actions_queue = cfg.policy.n_action_steps
|
||||
else:
|
||||
actions_queue = cfg.policy.n_action_repeats
|
||||
|
||||
actions_queue = (
|
||||
cfg.policy.n_action_steps if "n_action_steps" in cfg.policy else cfg.policy.n_action_repeats
|
||||
)
|
||||
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
|
||||
return output_dict, grad_stats, param_stats, actions
|
||||
|
||||
|
@ -114,6 +115,8 @@ if __name__ == "__main__":
|
|||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||
),
|
||||
("aloha", "act", ["policy.n_action_steps=10"]),
|
||||
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
|
||||
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
|
||||
(
|
||||
"pusht",
|
||||
"octo",
|
|
@ -16,6 +16,7 @@
|
|||
import json
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
|
@ -25,26 +26,34 @@ from datasets import Dataset
|
|||
from safetensors.torch import load_file
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
)
|
||||
from lerobot.common.datasets.push_dataset_to_hub.compute_stats import (
|
||||
from lerobot.common.datasets.compute_stats import (
|
||||
aggregate_stats,
|
||||
compute_stats,
|
||||
get_stats_einops_patterns,
|
||||
)
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
from lerobot.common.datasets.utils import (
|
||||
flatten_dict,
|
||||
hf_transform_to_torch,
|
||||
load_previous_and_future_frames,
|
||||
unflatten_dict,
|
||||
)
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_name, repo_id, policy_name", lerobot.env_dataset_policy_triplets)
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, repo_id, policy_name",
|
||||
lerobot.env_dataset_policy_triplets
|
||||
+ [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
|
||||
)
|
||||
def test_factory(env_name, repo_id, policy_name):
|
||||
"""
|
||||
Tests that:
|
||||
- we can create a dataset with the factory.
|
||||
- for a commonly used set of data keys, the data dimensions are correct.
|
||||
"""
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
overrides=[
|
||||
|
@ -105,6 +114,39 @@ def test_factory(env_name, repo_id, policy_name):
|
|||
assert key in item, f"{key}"
|
||||
|
||||
|
||||
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
|
||||
def test_multilerobotdataset_frames():
|
||||
"""Check that all dataset frames are incorporated."""
|
||||
# Note: use the image variants of the dataset to make the test approx 3x faster.
|
||||
# Note: We really do need three repo_ids here as at some point this caught an issue with the chaining
|
||||
# logic that wouldn't be caught with two repo IDs.
|
||||
repo_ids = [
|
||||
"lerobot/aloha_sim_insertion_human_image",
|
||||
"lerobot/aloha_sim_transfer_cube_human_image",
|
||||
"lerobot/aloha_sim_insertion_scripted_image",
|
||||
]
|
||||
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
|
||||
dataset = MultiLeRobotDataset(repo_ids)
|
||||
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
||||
assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
|
||||
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
|
||||
|
||||
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
|
||||
# check they match.
|
||||
expected_dataset_indices = []
|
||||
for i, sub_dataset in enumerate(sub_datasets):
|
||||
expected_dataset_indices.extend([i] * len(sub_dataset))
|
||||
|
||||
for expected_dataset_index, sub_dataset_item, dataset_item in zip(
|
||||
expected_dataset_indices, chain(*sub_datasets), dataset, strict=True
|
||||
):
|
||||
dataset_index = dataset_item.pop("dataset_index")
|
||||
assert dataset_index == expected_dataset_index
|
||||
assert sub_dataset_item.keys() == dataset_item.keys()
|
||||
for k in sub_dataset_item:
|
||||
assert torch.equal(sub_dataset_item[k], dataset_item[k])
|
||||
|
||||
|
||||
def test_compute_stats_on_xarm():
|
||||
"""Check that the statistics are computed correctly according to the stats_patterns property.
|
||||
|
||||
|
@ -315,3 +357,31 @@ def test_backward_compatibility(repo_id):
|
|||
# i = dataset.episode_data_index["to"][-1].item()
|
||||
# load_and_compare(i - 2)
|
||||
# load_and_compare(i - 1)
|
||||
|
||||
|
||||
def test_aggregate_stats():
|
||||
"""Makes 3 basic datasets and checks that aggregate stats are computed correctly."""
|
||||
with seeded_context(0):
|
||||
data_a = torch.rand(30, dtype=torch.float32)
|
||||
data_b = torch.rand(20, dtype=torch.float32)
|
||||
data_c = torch.rand(20, dtype=torch.float32)
|
||||
|
||||
hf_dataset_1 = Dataset.from_dict(
|
||||
{"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)}
|
||||
)
|
||||
hf_dataset_1.set_transform(hf_transform_to_torch)
|
||||
hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)})
|
||||
hf_dataset_2.set_transform(hf_transform_to_torch)
|
||||
hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)})
|
||||
hf_dataset_3.set_transform(hf_transform_to_torch)
|
||||
dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1)
|
||||
dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0)
|
||||
dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2)
|
||||
dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0)
|
||||
dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3)
|
||||
dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0)
|
||||
stats = aggregate_stats([dataset_1, dataset_2, dataset_3])
|
||||
for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True):
|
||||
for agg_fn in ["mean", "min", "max"]:
|
||||
assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
|
||||
assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))
|
||||
|
|
|
@ -30,7 +30,7 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_
|
|||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from tests.scripts.save_policy_to_safetensor import get_policy_stats
|
||||
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
|
||||
|
||||
|
@ -73,6 +73,8 @@ def test_get_policy_and_config_classes(policy_name: str):
|
|||
),
|
||||
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
|
||||
("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
|
||||
("dora_aloha_real", "act_real", []),
|
||||
("dora_aloha_real", "act_real_no_state", []),
|
||||
],
|
||||
)
|
||||
@require_env
|
||||
|
@ -85,6 +87,9 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
- Updating the policy.
|
||||
- Using the policy to select actions at inference time.
|
||||
- Test the action can be applied to the policy
|
||||
|
||||
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
|
||||
and for now we add tests as we see fit.
|
||||
"""
|
||||
cfg = init_hydra_config(
|
||||
DEFAULT_CONFIG_PATH,
|
||||
|
@ -136,7 +141,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
|||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
num_workers=0,
|
||||
batch_size=2,
|
||||
shuffle=True,
|
||||
pin_memory=DEVICE != "cpu",
|
||||
|
@ -292,6 +297,8 @@ def test_normalize(insert_temporal_dim):
|
|||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||
),
|
||||
("aloha", "act", ["policy.n_action_steps=10"]),
|
||||
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
|
||||
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
|
||||
(
|
||||
"pusht",
|
||||
"octo",
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
|
||||
|
||||
def test_drop_n_first_frames():
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
"index": [0, 1, 2, 3, 4, 5],
|
||||
"episode_index": [0, 0, 1, 2, 2, 2],
|
||||
},
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
|
||||
assert sampler.indices == [1, 4, 5]
|
||||
assert len(sampler) == 3
|
||||
assert list(sampler) == [1, 4, 5]
|
||||
|
||||
|
||||
def test_drop_n_last_frames():
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
"index": [0, 1, 2, 3, 4, 5],
|
||||
"episode_index": [0, 0, 1, 2, 2, 2],
|
||||
},
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
|
||||
assert sampler.indices == [0, 3, 4]
|
||||
assert len(sampler) == 3
|
||||
assert list(sampler) == [0, 3, 4]
|
||||
|
||||
|
||||
def test_episode_indices_to_use():
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
"index": [0, 1, 2, 3, 4, 5],
|
||||
"episode_index": [0, 0, 1, 2, 2, 2],
|
||||
},
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
|
||||
assert sampler.indices == [0, 1, 3, 4, 5]
|
||||
assert len(sampler) == 5
|
||||
assert list(sampler) == [0, 1, 3, 4, 5]
|
||||
|
||||
|
||||
def test_shuffle():
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
"index": [0, 1, 2, 3, 4, 5],
|
||||
"episode_index": [0, 0, 1, 2, 2, 2],
|
||||
},
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
sampler = EpisodeAwareSampler(episode_data_index, shuffle=False)
|
||||
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
||||
assert len(sampler) == 6
|
||||
assert list(sampler) == [0, 1, 2, 3, 4, 5]
|
||||
sampler = EpisodeAwareSampler(episode_data_index, shuffle=True)
|
||||
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
||||
assert len(sampler) == 6
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
Loading…
Reference in New Issue