import logging
import os
from pathlib import Path

import torch
from torchvision.transforms import v2

from lerobot.common.datasets.utils import compute_stats
from lerobot.common.transforms import NormalizeTransform, Prod

# DATA_DIR specifies the location from which datasets are loaded. By default, DATA_DIR is None and
# we load from `$HOME/.cache/huggingface/hub/datasets`. For our unit tests, we set `DATA_DIR=tests/data`
# to load a subset of our datasets for faster continuous integration.
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None


def make_dataset(
    cfg,
    # set normalize=False to remove all transformations and keep images unnormalized in [0, 255]
    normalize=True,
    stats_path=None,
):
    if cfg.env.name == "xarm":
        from lerobot.common.datasets.xarm import XarmDataset

        clsfunc = XarmDataset
    elif cfg.env.name == "pusht":
        from lerobot.common.datasets.pusht import PushtDataset

        clsfunc = PushtDataset
    elif cfg.env.name == "aloha":
        from lerobot.common.datasets.aloha import AlohaDataset

        clsfunc = AlohaDataset
    else:
        raise ValueError(cfg.env.name)
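
    # When normalization is enabled (the default), image keys are rescaled from [0, 255] to [0, 1]
    # and the "observation.state" / "action" keys are normalized with dataset statistics
    # (mean/std or min/max, depending on the environment).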
    transforms = None
    if normalize:
        # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max,
        # min_max_from_spec
        # TODO(rcadene): remove this and put it in config. Ideally we want to reproduce SOTA results just with mean_std
        normalization_mode = "mean_std" if cfg.env.name == "aloha" else "min_max"

        if cfg.policy.name == "diffusion" and cfg.env.name == "pusht":
            stats = {}
            # TODO(rcadene): we overwrite stats to match the pretrained model's, but we should remove this
            stats["observation.state"] = {}
            stats["observation.state"]["min"] = torch.tensor([13.456424, 32.938293], dtype=torch.float32)
            stats["observation.state"]["max"] = torch.tensor([496.14618, 510.9579], dtype=torch.float32)
            stats["action"] = {}
            stats["action"]["min"] = torch.tensor([12.0, 25.0], dtype=torch.float32)
            stats["action"]["max"] = torch.tensor([511.0, 511.0], dtype=torch.float32)
        elif stats_path is None:
            # instantiate a one-frame dataset with a light transform
            stats_dataset = clsfunc(
                dataset_id=cfg.dataset_id,
                root=DATA_DIR,
                transform=Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
            )

            # load the stats if the file already exists, otherwise compute and save them
            precomputed_stats_path = stats_dataset.data_dir / "stats.pth"
            if precomputed_stats_path.exists():
                stats = torch.load(precomputed_stats_path)
            else:
                logging.info(f"compute_stats and save to {precomputed_stats_path}")
                stats = compute_stats(stats_dataset)
                torch.save(stats, precomputed_stats_path)
        else:
            stats = torch.load(stats_path)

        transforms = v2.Compose(
            [
                Prod(in_keys=clsfunc.image_keys, prod=1 / 255.0),
                NormalizeTransform(
                    stats,
                    in_keys=[
                        "observation.state",
                        "action",
                    ],
                    mode=normalization_mode,
                ),
            ]
        )
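
    # delta_timestamps maps a data key (e.g. "observation.state", "action") to the time offsets,
    # relative to the current frame, of the additional frames to load with it. String values are
    # Python expressions and are evaluated below.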
    delta_timestamps = cfg.policy.get("delta_timestamps")
    if delta_timestamps is not None:
        for key in delta_timestamps:
            if isinstance(delta_timestamps[key], str):
                delta_timestamps[key] = eval(delta_timestamps[key])

    dataset = clsfunc(
        dataset_id=cfg.dataset_id,
        root=DATA_DIR,
        delta_timestamps=delta_timestamps,
        transform=transforms,
    )

    return dataset
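

# Minimal usage sketch (illustration only, not part of the library): `make_dataset` only reads
# `cfg.env.name`, `cfg.policy.name`, `cfg.dataset_id` and `cfg.policy.delta_timestamps`, so any
# attribute-style config works, e.g. an OmegaConf DictConfig like the Hydra configs in this repo.
# The concrete values below ("pusht", "diffusion") are assumptions chosen for illustration.
if __name__ == "__main__":
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "env": {"name": "pusht"},
            "policy": {"name": "diffusion", "delta_timestamps": None},
            "dataset_id": "pusht",
        }
    )
    dataset = make_dataset(cfg)
    # each item is a dict of tensors; with normalize=True, image keys are scaled to [0, 1]
    print(len(dataset), dataset[0].keys())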