LeRobotDataset v2.1 (#711)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: Remi Cadene <re.cadene@gmail.com>
This commit is contained in:
Simon Alibert 2025-02-25 15:27:29 +01:00 committed by GitHub
parent aca464ca72
commit 3354d919fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
43 changed files with 2023 additions and 1322 deletions

View File

@ -210,7 +210,7 @@ A `LeRobotDataset` is serialised using several widespread file formats for each
- videos are stored in mp4 format to save space
- metadata are stored in plain json/jsonl files
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can use the `local_files_only` argument and specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
### Evaluate a pretrained policy

View File

@ -335,7 +335,7 @@ python lerobot/scripts/control_robot.py \
--control.push_to_hub=true
```
Note: You can resume recording by adding `--control.resume=true`. Also if you didn't push your dataset yet, add `--control.local_files_only=true`.
Note: You can resume recording by adding `--control.resume=true`.
## H. Visualize a dataset
@ -363,8 +363,6 @@ python lerobot/scripts/control_robot.py \
--control.episode=0
```
Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
## J. Train a policy
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
@ -378,8 +376,6 @@ python lerobot/scripts/train.py \
--wandb.enable=true
```
Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so100_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.

View File

@ -391,7 +391,7 @@ python lerobot/scripts/control_robot.py \
--control.push_to_hub=true
```
Note: You can resume recording by adding `--control.resume=true`. Also if you didn't push your dataset yet, add `--control.local_files_only=true`.
Note: You can resume recording by adding `--control.resume=true`.
# H. Visualize a dataset
@ -418,8 +418,6 @@ python lerobot/scripts/control_robot.py \
--control.episode=0
```
Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
## J. Train a policy
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
@ -433,8 +431,6 @@ python lerobot/scripts/train.py \
--wandb.enable=true
```
Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/lekiwi_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.

View File

@ -256,7 +256,7 @@ python lerobot/scripts/control_robot.py \
--control.push_to_hub=true
```
Note: You can resume recording by adding `--control.resume=true`. Also if you didn't push your dataset yet, add `--control.local_files_only=true`.
Note: You can resume recording by adding `--control.resume=true`.
## Visualize a dataset
@ -284,8 +284,6 @@ python lerobot/scripts/control_robot.py \
--control.episode=0
```
Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
## Train a policy
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
@ -299,8 +297,6 @@ python lerobot/scripts/train.py \
--wandb.enable=true
```
Note: If you didn't push your dataset yet, add `--control.local_files_only=true`.
Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/moss_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.

View File

@ -768,7 +768,7 @@ You can use the `record` function from [`lerobot/scripts/control_robot.py`](../l
1. Frames from cameras are saved on disk in threads, and encoded into videos at the end of each episode recording.
2. Video streams from cameras are displayed in window so that you can verify them.
3. Data is stored with [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py) format which is pushed to your Hugging Face page (unless `--control.push_to_hub=false` is provided).
4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again with `--control.resume=true`. You might need to add `--control.local_files_only=true` if your dataset was not uploaded to hugging face hub. Also you will need to manually delete the dataset directory to start recording from scratch.
4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again with `--control.resume=true`. You will need to manually delete the dataset directory if you want to start recording from scratch.
5. Set the flow of data recording using command line arguments:
- `--control.warmup_time_s=10` defines the number of seconds before starting data collection. It allows the robot devices to warmup and synchronize (10 seconds by default).
- `--control.episode_time_s=60` defines the number of seconds for data recording for each episode (60 seconds by default).
@ -883,8 +883,6 @@ python lerobot/scripts/control_robot.py \
--control.episode=0
```
Note: You might need to add `--control.local_files_only=true` if your dataset was not uploaded to hugging face hub.
Your robot should replicate movements similar to those you recorded. For example, check out [this video](https://x.com/RemiCadene/status/1793654950905680090) where we use `replay` on a Aloha robot from [Trossen Robotics](https://www.trossenrobotics.com).
## 4. Train a policy on your data
@ -902,8 +900,6 @@ python lerobot/scripts/train.py \
--wandb.enable=true
```
Note: You might need to add `--dataset.local_files_only=true` if your dataset was not uploaded to hugging face hub.
Let's explain it:
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/koch_test`.
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.

View File

@ -2,9 +2,10 @@ import shutil
from pathlib import Path
import numpy as np
import torch
from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
@ -89,9 +90,9 @@ def calculate_coverage(zarr_data):
num_frames = len(block_pos)
coverage = np.zeros((num_frames,))
coverage = np.zeros((num_frames,), dtype=np.float32)
# 8 keypoints with 2 coords each
keypoints = np.zeros((num_frames, 16))
keypoints = np.zeros((num_frames, 16), dtype=np.float32)
# Set x, y, theta (in radians)
goal_pos_angle = np.array([256, 256, np.pi / 4])
@ -117,7 +118,7 @@ def calculate_coverage(zarr_data):
intersection_area = goal_geom.intersection(block_geom).area
goal_area = goal_geom.area
coverage[i] = intersection_area / goal_area
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
keypoints[i] = PushTEnv.get_keypoints(block_shapes).flatten()
return coverage, keypoints
@ -134,8 +135,8 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
if mode not in ["video", "image", "keypoints"]:
raise ValueError(mode)
if (LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
if (HF_LEROBOT_HOME / repo_id).exists():
shutil.rmtree(HF_LEROBOT_HOME / repo_id)
if not raw_dir.exists():
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
@ -148,6 +149,10 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
action = zarr_data["action"][:]
image = zarr_data["img"] # (b, h, w, c)
if image.dtype == np.float32 and image.max() == np.float32(255):
# HACK: images are loaded as float32 but they actually encode uint8 data
image = image.astype(np.uint8)
episode_data_index = {
"from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
"to": zarr_data.meta["episode_ends"],
@ -175,28 +180,30 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
for frame_idx in range(num_frames):
i = from_idx + frame_idx
idx = i + (frame_idx < num_frames - 1)
frame = {
"action": torch.from_numpy(action[i]),
"action": action[i],
# Shift reward and success by +1 until the last item of the episode
"next.reward": reward[i + (frame_idx < num_frames - 1)],
"next.success": success[i + (frame_idx < num_frames - 1)],
"next.reward": reward[idx : idx + 1],
"next.success": success[idx : idx + 1],
"task": PUSHT_TASK,
}
frame["observation.state"] = torch.from_numpy(agent_pos[i])
frame["observation.state"] = agent_pos[i]
if mode == "keypoints":
frame["observation.environment_state"] = torch.from_numpy(keypoints[i])
frame["observation.environment_state"] = keypoints[i]
else:
frame["observation.image"] = torch.from_numpy(image[i])
frame["observation.image"] = image[i]
dataset.add_frame(frame)
dataset.save_episode(task=PUSHT_TASK)
dataset.consolidate()
dataset.save_episode()
if push_to_hub:
dataset.push_to_hub()
hub_api = HfApi()
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset")
if __name__ == "__main__":
@ -218,5 +225,5 @@ if __name__ == "__main__":
main(raw_dir, repo_id=repo_id, mode=mode)
# Uncomment if you want to load the local dataset and explore it
# dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
# dataset = LeRobotDataset(repo_id=repo_id)
# breakpoint()

View File

@ -1,4 +1,9 @@
# keys
import os
from pathlib import Path
from huggingface_hub.constants import HF_HOME
OBS_ENV = "observation.environment_state"
OBS_ROBOT = "observation.state"
OBS_IMAGE = "observation.image"
@ -15,3 +20,13 @@ TRAINING_STEP = "training_step.json"
OPTIMIZER_STATE = "optimizer_state.safetensors"
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
SCHEDULER_STATE = "scheduler_state.json"
# cache dir
default_cache_path = Path(HF_HOME) / "lerobot"
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
if "LEROBOT_HOME" in os.environ:
raise ValueError(
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
"'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
)

View File

@ -0,0 +1,54 @@
import packaging.version
V2_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.
We introduced a new format since v2.0 which is not backward compatible with v1.x.
Please, use our conversion script. Modify the following command with your own task description:
```
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
--repo-id {repo_id} \\
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
```
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the
peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top
cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped
target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the
sweatshirt.", ...
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""
V21_MESSAGE = """
The dataset you requested ({repo_id}) is in {version} format.
While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global
stats instead of per-episode stats. Update your dataset stats to the new format using this command:
```
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id}
```
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
"""
FUTURE_MESSAGE = """
The dataset you requested ({repo_id}) is only available in {version} format.
As we cannot ensure forward compatibility with it, please update your current version of lerobot.
"""
class CompatibilityError(Exception): ...
class BackwardCompatibilityError(CompatibilityError):
def __init__(self, repo_id: str, version: packaging.version.Version):
message = V2_MESSAGE.format(repo_id=repo_id, version=version)
super().__init__(message)
class ForwardCompatibilityError(CompatibilityError):
def __init__(self, repo_id: str, version: packaging.version.Version):
message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version)
super().__init__(message)

View File

@ -13,202 +13,164 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from math import ceil
import numpy as np
import einops
import torch
import tqdm
from lerobot.common.datasets.utils import load_image_as_numpy
def get_stats_einops_patterns(dataset, num_workers=0):
"""These einops patterns will be used to aggregate batches and compute statistics.
def estimate_num_samples(
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
) -> int:
"""Heuristic to estimate the number of samples based on dataset size.
The power controls the sample growth relative to dataset size.
Lower the power for less number of samples.
Note: We assume the images are in channel first format
For default arguments, we have:
- from 1 to ~500, num_samples=100
- at 1000, num_samples=177
- at 2000, num_samples=299
- at 5000, num_samples=594
- at 10000, num_samples=1000
- at 20000, num_samples=1681
"""
if dataset_len < min_num_samples:
min_num_samples = dataset_len
return max(min_num_samples, min(int(dataset_len**power), max_num_samples))
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_size=2,
shuffle=False,
)
batch = next(iter(dataloader))
stats_patterns = {}
def sample_indices(data_len: int) -> list[int]:
num_samples = estimate_num_samples(data_len)
return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist()
for key in dataset.features:
# sanity check that tensors are not float64
assert batch[key].dtype != torch.float64
# if isinstance(feats_type, (VideoFrame, Image)):
if key in dataset.meta.camera_keys:
# sanity check that images are channel first
_, c, h, w = batch[key].shape
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300):
_, height, width = img.shape
# sanity check that images are float32 in range [0,1]
assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}"
assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}"
assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}"
if max(width, height) < max_size_threshold:
# no downsampling needed
return img
stats_patterns[key] = "b c h w -> c 1 1"
elif batch[key].ndim == 2:
stats_patterns[key] = "b c -> c "
elif batch[key].ndim == 1:
stats_patterns[key] = "b -> 1"
downsample_factor = int(width / target_size) if width > height else int(height / target_size)
return img[:, ::downsample_factor, ::downsample_factor]
def sample_images(image_paths: list[str]) -> np.ndarray:
sampled_indices = sample_indices(len(image_paths))
images = None
for i, idx in enumerate(sampled_indices):
path = image_paths[idx]
# we load as uint8 to reduce memory usage
img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True)
img = auto_downsample_height_width(img)
if images is None:
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
images[i] = img
return images
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
return {
"min": np.min(array, axis=axis, keepdims=keepdims),
"max": np.max(array, axis=axis, keepdims=keepdims),
"mean": np.mean(array, axis=axis, keepdims=keepdims),
"std": np.std(array, axis=axis, keepdims=keepdims),
"count": np.array([len(array)]),
}
def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
continue # HACK: we should receive np.arrays of strings
elif features[key]["dtype"] in ["image", "video"]:
ep_ft_array = sample_images(data) # data is a list of image paths
axes_to_reduce = (0, 2, 3) # keep channel dim
keepdims = True
else:
raise ValueError(f"{key}, {batch[key].shape}")
ep_ft_array = data # data is alreay a np.ndarray
axes_to_reduce = 0 # compute stats over the first axis
keepdims = data.ndim == 1 # keep as np.array
return stats_patterns
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
# finally, we normalize and remove batch dim for images
if features[key]["dtype"] in ["image", "video"]:
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
}
return ep_stats
def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None):
"""Compute mean/std and min/max statistics of all data keys in a LeRobotDataset."""
if max_num_samples is None:
max_num_samples = len(dataset)
# for more info on why we need to set the same number of workers, see `load_from_videos`
stats_patterns = get_stats_einops_patterns(dataset, num_workers)
# mean and std will be computed incrementally while max and min will track the running value.
mean, std, max, min = {}, {}, {}, {}
for key in stats_patterns:
mean[key] = torch.tensor(0.0).float()
std[key] = torch.tensor(0.0).float()
max[key] = torch.tensor(-float("inf")).float()
min[key] = torch.tensor(float("inf")).float()
def create_seeded_dataloader(dataset, batch_size, seed):
generator = torch.Generator()
generator.manual_seed(seed)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_size=batch_size,
shuffle=True,
drop_last=False,
generator=generator,
)
return dataloader
# Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get
# surprises when rerunning the sampler.
first_batch = None
running_item_count = 0 # for online mean computation
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max")
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
if first_batch is None:
first_batch = deepcopy(batch)
for key, pattern in stats_patterns.items():
batch[key] = batch[key].float()
# Numerically stable update step for mean computation.
batch_mean = einops.reduce(batch[key], pattern, "mean")
# Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents
# the update step, N is the running item count, B is this batch size, x̄ is the running mean,
# and x is the current batch mean. Some rearrangement is then required to avoid risking
# numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields
# x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ
mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count
max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max"))
min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min"))
if i == ceil(max_num_samples / batch_size) - 1:
break
first_batch_ = None
running_item_count = 0 # for online std computation
dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337)
for i, batch in enumerate(
tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std")
):
this_batch_size = len(batch["index"])
running_item_count += this_batch_size
# Sanity check to make sure the batches are still in the same order as before.
if first_batch_ is None:
first_batch_ = deepcopy(batch)
for key in stats_patterns:
assert torch.equal(first_batch_[key], first_batch[key])
for key, pattern in stats_patterns.items():
batch[key] = batch[key].float()
# Numerically stable update step for mean computation (where the mean is over squared
# residuals).See notes in the mean computation loop above.
batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean")
std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count
if i == ceil(max_num_samples / batch_size) - 1:
break
for key in stats_patterns:
std[key] = torch.sqrt(std[key])
stats = {}
for key in stats_patterns:
stats[key] = {
"mean": mean[key],
"std": std[key],
"max": max[key],
"min": min[key],
}
return stats
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
for i in range(len(stats_list)):
for fkey in stats_list[i]:
for k, v in stats_list[i][fkey].items():
if not isinstance(v, np.ndarray):
raise ValueError(
f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead."
)
if v.ndim == 0:
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
if k == "count" and v.shape != (1,):
raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.")
if "image" in fkey and k != "count" and v.shape != (3, 1, 1):
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
"""Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch.
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
"""Aggregates stats for a single feature."""
means = np.stack([s["mean"] for s in stats_ft_list])
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
counts = np.stack([s["count"] for s in stats_ft_list])
total_count = counts.sum(axis=0)
The final stats will have the union of all data keys from each of the datasets.
# Prepare weighted mean by matching number of dimensions
while counts.ndim < means.ndim:
counts = np.expand_dims(counts, axis=-1)
The final stats will have the union of all data keys from each of the datasets. For instance:
- new_max = max(max_dataset_0, max_dataset_1, ...)
# Compute the weighted mean
weighted_means = means * counts
total_mean = weighted_means.sum(axis=0) / total_count
# Compute the variance using the parallel algorithm
delta_means = means - total_mean
weighted_variances = (variances + delta_means**2) * counts
total_variance = weighted_variances.sum(axis=0) / total_count
return {
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
"mean": total_mean,
"std": np.sqrt(total_variance),
"count": total_count,
}
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
The final stats will have the union of all data keys from each of the stats dicts.
For instance:
- new_min = min(min_dataset_0, min_dataset_1, ...)
- new_mean = (mean of all data)
- new_max = max(max_dataset_0, max_dataset_1, ...)
- new_mean = (mean of all data, weighted by counts)
- new_std = (std of all data)
"""
data_keys = set()
for dataset in ls_datasets:
data_keys.update(dataset.meta.stats.keys())
stats = {k: {} for k in data_keys}
for data_key in data_keys:
for stat_key in ["min", "max"]:
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
stats[data_key][stat_key] = einops.reduce(
torch.stack(
[ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats],
dim=0,
),
"n ... -> ...",
stat_key,
)
total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats)
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
# dataset, then divide by total_samples to get the overall "mean".
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
# numerical overflow!
stats[data_key]["mean"] = sum(
d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples)
for d in ls_datasets
if data_key in d.meta.stats
)
# The derivation for standard deviation is a little more involved but is much in the same spirit as
# the computation of the mean.
# Given two sets of data where the statistics are known:
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
# numerical overflow!
stats[data_key]["std"] = torch.sqrt(
sum(
(
d.meta.stats[data_key]["std"] ** 2
+ (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2
)
* (d.num_frames / total_samples)
for d in ls_datasets
if data_key in d.meta.stats
)
)
return stats
_assert_type_and_shape(stats_list)
data_keys = {key for stats in stats_list for key in stats}
aggregated_stats = {key: {} for key in data_keys}
for key in data_keys:
stats_with_key = [stats[key] for stats in stats_list if key in stats]
aggregated_stats[key] = aggregate_feature_stats(stats_with_key)
return aggregated_stats

View File

@ -83,15 +83,15 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
)
if isinstance(cfg.dataset.repo_id, str):
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only)
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, revision=cfg.dataset.revision)
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
dataset = LeRobotDataset(
cfg.dataset.repo_id,
episodes=cfg.dataset.episodes,
delta_timestamps=delta_timestamps,
image_transforms=image_transforms,
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
local_files_only=cfg.dataset.local_files_only,
)
else:
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")

View File

@ -38,22 +38,40 @@ def safe_stop_image_writer(func):
return wrapper
def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
# TODO(aliberts): handle 1 channel and 4 for depth images
if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
if image_array.ndim != 3:
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
if image_array.shape[0] == 3:
# Transpose from pytorch convention (C, H, W) to (H, W, C)
image_array = image_array.transpose(1, 2, 0)
elif image_array.shape[-1] != 3:
raise NotImplementedError(
f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
)
if image_array.dtype != np.uint8:
# Assume the image is in [0, 1] range for floating-point data
image_array = np.clip(image_array, 0, 1)
if range_check:
max_ = image_array.max().item()
min_ = image_array.min().item()
if max_ > 1.0 or min_ < 0.0:
raise ValueError(
"The image data type is float, which requires values in the range [0.0, 1.0]. "
f"However, the provided range is [{min_}, {max_}]. Please adjust the range or "
"provide a uint8 image with values in the range [0, 255]."
)
image_array = (image_array * 255).astype(np.uint8)
return PIL.Image.fromarray(image_array)
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
try:
if isinstance(image, np.ndarray):
img = image_array_to_image(image)
img = image_array_to_pil_image(image)
elif isinstance(image, PIL.Image.Image):
img = image
else:

View File

@ -14,49 +14,54 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
from functools import cached_property
from pathlib import Path
from typing import Callable
import datasets
import numpy as np
import packaging.version
import PIL.Image
import torch
import torch.utils
from datasets import load_dataset
from huggingface_hub import create_repo, snapshot_download, upload_folder
from datasets import concatenate_datasets, load_dataset
from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
from lerobot.common.datasets.utils import (
DEFAULT_FEATURES,
DEFAULT_IMAGE_PATH,
EPISODES_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
append_jsonlines,
backward_compatible_episodes_stats,
check_delta_timestamps,
check_timestamps_sync,
check_version_compatibility,
create_branch,
create_empty_dataset_info,
create_lerobot_dataset_card,
embed_images,
get_delta_indices,
get_episode_data_index,
get_features_from_robot,
get_hf_features_from_features,
get_hub_safe_version,
get_safe_version,
hf_transform_to_torch,
is_valid_version,
load_episodes,
load_episodes_stats,
load_info,
load_stats,
load_tasks,
serialize_dict,
validate_episode_buffer,
validate_frame,
write_episode,
write_episode_stats,
write_info,
write_json,
write_parquet,
)
from lerobot.common.datasets.video_utils import (
VideoFrame,
@ -66,9 +71,7 @@ from lerobot.common.datasets.video_utils import (
)
from lerobot.common.robot_devices.robots.utils import Robot
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
CODEBASE_VERSION = "v2.0"
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
CODEBASE_VERSION = "v2.1"
class LeRobotDatasetMetadata:
@ -76,19 +79,36 @@ class LeRobotDatasetMetadata:
self,
repo_id: str,
root: str | Path | None = None,
local_files_only: bool = False,
revision: str | None = None,
force_cache_sync: bool = False,
):
self.repo_id = repo_id
self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
self.local_files_only = local_files_only
self.revision = revision if revision else CODEBASE_VERSION
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
# Load metadata
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
try:
if force_cache_sync:
raise FileNotFoundError
self.load_metadata()
except (FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
(self.root / "meta").mkdir(exist_ok=True, parents=True)
self.pull_from_repo(allow_patterns="meta/")
self.load_metadata()
def load_metadata(self):
self.info = load_info(self.root)
self.stats = load_stats(self.root)
self.tasks = load_tasks(self.root)
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
self.tasks, self.task_to_task_index = load_tasks(self.root)
self.episodes = load_episodes(self.root)
if self._version < packaging.version.parse("v2.1"):
self.stats = load_stats(self.root)
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
else:
self.episodes_stats = load_episodes_stats(self.root)
self.stats = aggregate_stats(list(self.episodes_stats.values()))
def pull_from_repo(
self,
@ -98,21 +118,16 @@ class LeRobotDatasetMetadata:
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self._hub_version,
revision=self.revision,
local_dir=self.root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
local_files_only=self.local_files_only,
)
@cached_property
def _hub_version(self) -> str | None:
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
@property
def _version(self) -> str:
def _version(self) -> packaging.version.Version:
"""Codebase version used to create this dataset."""
return self.info["codebase_version"]
return packaging.version.parse(self.info["codebase_version"])
def get_data_file_path(self, ep_index: int) -> Path:
ep_chunk = self.get_episode_chunk(ep_index)
@ -202,54 +217,65 @@ class LeRobotDatasetMetadata:
"""Max number of episodes per chunk."""
return self.info["chunks_size"]
@property
def task_to_task_index(self) -> dict:
return {task: task_idx for task_idx, task in self.tasks.items()}
def get_task_index(self, task: str) -> int:
def get_task_index(self, task: str) -> int | None:
"""
Given a task in natural language, returns its task_index if the task already exists in the dataset,
otherwise creates a new task_index.
otherwise return None.
"""
task_index = self.task_to_task_index.get(task, None)
return task_index if task_index is not None else self.total_tasks
return self.task_to_task_index.get(task, None)
def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None:
def add_task(self, task: str):
"""
Given a task in natural language, add it to the dictionnary of tasks.
"""
if task in self.task_to_task_index:
raise ValueError(f"The task '{task}' already exists and can't be added twice.")
task_index = self.info["total_tasks"]
self.task_to_task_index[task] = task_index
self.tasks[task_index] = task
self.info["total_tasks"] += 1
task_dict = {
"task_index": task_index,
"task": task,
}
append_jsonlines(task_dict, self.root / TASKS_PATH)
def save_episode(
self,
episode_index: int,
episode_length: int,
episode_tasks: list[str],
episode_stats: dict[str, dict],
) -> None:
self.info["total_episodes"] += 1
self.info["total_frames"] += episode_length
if task_index not in self.tasks:
self.info["total_tasks"] += 1
self.tasks[task_index] = task
task_dict = {
"task_index": task_index,
"task": task,
}
append_jsonlines(task_dict, self.root / TASKS_PATH)
chunk = self.get_episode_chunk(episode_index)
if chunk >= self.total_chunks:
self.info["total_chunks"] += 1
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
self.info["total_videos"] += len(self.video_keys)
write_json(self.info, self.root / INFO_PATH)
if len(self.video_keys) > 0:
self.update_video_info()
write_info(self.info, self.root)
episode_dict = {
"episode_index": episode_index,
"tasks": [task],
"tasks": episode_tasks,
"length": episode_length,
}
self.episodes.append(episode_dict)
append_jsonlines(episode_dict, self.root / EPISODES_PATH)
self.episodes[episode_index] = episode_dict
write_episode(episode_dict, self.root)
# TODO(aliberts): refactor stats in save_episodes
# image_sampling = int(self.fps / 2) # sample 2 img/s for the stats
# ep_stats = compute_episode_stats(episode_buffer, self.features, episode_length, image_sampling=image_sampling)
# ep_stats = serialize_dict(ep_stats)
# append_jsonlines(ep_stats, self.root / STATS_PATH)
self.episodes_stats[episode_index] = episode_stats
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats
write_episode_stats(episode_index, episode_stats, self.root)
def write_video_info(self) -> None:
def update_video_info(self) -> None:
"""
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
been encoded the same way. Also, this means it assumes the first episode exists.
@ -259,8 +285,6 @@ class LeRobotDatasetMetadata:
video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key)
self.info["features"][key]["info"] = get_video_info(video_path)
write_json(self.info, self.root / INFO_PATH)
def __repr__(self):
feature_keys = list(self.features)
return (
@ -286,7 +310,7 @@ class LeRobotDatasetMetadata:
"""Creates metadata for a LeRobotDataset."""
obj = cls.__new__(cls)
obj.repo_id = repo_id
obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id
obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
obj.root.mkdir(parents=True, exist_ok=False)
@ -304,6 +328,7 @@ class LeRobotDatasetMetadata:
)
else:
# TODO(aliberts, rcadene): implement sanity check for features
features = {**features, **DEFAULT_FEATURES}
# check if none of the features contains a "/" in their names,
# as this would break the dict flattening in the stats computation, which uses '/' as separator
@ -313,12 +338,13 @@ class LeRobotDatasetMetadata:
features = {**features, **DEFAULT_FEATURES}
obj.tasks, obj.stats, obj.episodes = {}, {}, []
obj.tasks, obj.task_to_task_index = {}, {}
obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {}
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)
obj.local_files_only = True
obj.revision = None
return obj
@ -331,8 +357,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_transforms: Callable | None = None,
delta_timestamps: dict[list[float]] | None = None,
tolerance_s: float = 1e-4,
revision: str | None = None,
force_cache_sync: bool = False,
download_videos: bool = True,
local_files_only: bool = False,
video_backend: str | None = None,
):
"""
@ -342,7 +369,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
- On your local disk in the 'root' folder. This is typically the case when you recorded your
dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class
with 'root' will load your dataset directly from disk. This can happen while you're offline (no
internet connection), in that case, use local_files_only=True.
internet connection).
- On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on
your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download
@ -424,24 +451,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
multiples of 1/fps. Defaults to 1e-4.
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
commit hash. Defaults to current codebase version tag.
sync_cache_first (bool, optional): Flag to sync and refresh local files first. If True and files
are already present in the local cache, this will be faster. However, files loaded might not
be in sync with the version on the hub, especially if you specified 'revision'. Defaults to
False.
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
video files are already present on local disk, they won't be downloaded again. Defaults to
True.
local_files_only (bool, optional): Flag to use local files only. If True, no requests to the hub
will be made. Defaults to False.
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
"""
super().__init__()
self.repo_id = repo_id
self.root = Path(root) if root else LEROBOT_HOME / repo_id
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.episodes = episodes
self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else "pyav"
self.delta_indices = None
self.local_files_only = local_files_only
# Unused attributes
self.image_writer = None
@ -450,64 +481,86 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.root.mkdir(exist_ok=True, parents=True)
# Load metadata
self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only)
# Check version
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
self.meta = LeRobotDatasetMetadata(
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
)
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
self.stats = aggregate_stats(episodes_stats)
# Load actual data
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()
try:
if force_cache_sync:
raise FileNotFoundError
assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths())
self.hf_dataset = self.load_hf_dataset()
except (AssertionError, FileNotFoundError, NotADirectoryError):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download_episodes(download_videos)
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
# Check timestamps
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
# Setup delta_indices
if self.delta_timestamps is not None:
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
# Available stats implies all videos have been encoded and dataset is iterable
self.consolidated = self.meta.stats is not None
def push_to_hub(
self,
branch: str | None = None,
tags: list | None = None,
license: str | None = "apache-2.0",
push_videos: bool = True,
private: bool = False,
allow_patterns: list[str] | str | None = None,
upload_large_folder: bool = False,
**card_kwargs,
) -> None:
if not self.consolidated:
logging.warning(
"You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. "
"Consolidating first."
)
self.consolidate()
ignore_patterns = ["images/"]
if not push_videos:
ignore_patterns.append("videos/")
create_repo(
hub_api = HfApi()
hub_api.create_repo(
repo_id=self.repo_id,
private=private,
repo_type="dataset",
exist_ok=True,
)
if branch:
hub_api.create_branch(
repo_id=self.repo_id,
branch=branch,
revision=self.revision,
repo_type="dataset",
exist_ok=True,
)
upload_folder(
repo_id=self.repo_id,
folder_path=self.root,
repo_type="dataset",
ignore_patterns=ignore_patterns,
)
card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset")
create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
upload_kwargs = {
"repo_id": self.repo_id,
"folder_path": self.root,
"repo_type": "dataset",
"revision": branch,
"allow_patterns": allow_patterns,
"ignore_patterns": ignore_patterns,
}
if upload_large_folder:
hub_api.upload_large_folder(**upload_kwargs)
else:
hub_api.upload_folder(**upload_kwargs)
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
def pull_from_repo(
self,
@ -517,11 +570,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
snapshot_download(
self.repo_id,
repo_type="dataset",
revision=self.meta._hub_version,
revision=self.revision,
local_dir=self.root,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
local_files_only=self.local_files_only,
)
def download_episodes(self, download_videos: bool = True) -> None:
@ -535,17 +587,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
files = None
ignore_patterns = None if download_videos else "videos/"
if self.episodes is not None:
files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
if len(self.meta.video_keys) > 0 and download_videos:
video_files = [
str(self.meta.get_video_file_path(ep_idx, vid_key))
for vid_key in self.meta.video_keys
for ep_idx in self.episodes
]
files += video_files
files = self.get_episodes_file_paths()
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
def get_episodes_file_paths(self) -> list[Path]:
episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
if len(self.meta.video_keys) > 0:
video_files = [
str(self.meta.get_video_file_path(ep_idx, vid_key))
for vid_key in self.meta.video_keys
for ep_idx in episodes
]
fpaths += video_files
return fpaths
def load_hf_dataset(self) -> datasets.Dataset:
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
if self.episodes is None:
@ -557,7 +615,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(aliberts): hf_dataset.set_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
def create_hf_dataset(self) -> datasets.Dataset:
features = get_hf_features_from_features(self.features)
ft_dict = {col: [] for col in features}
hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train")
# TODO(aliberts): hf_dataset.set_format("torch")
hf_dataset.set_transform(hf_transform_to_torch)
return hf_dataset
@property
@ -624,7 +690,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if key not in self.meta.video_keys
}
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
@ -654,8 +720,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
query_indices = None
if self.delta_indices is not None:
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
query_indices, padding = self._get_query_indices(idx, current_ep_idx)
query_indices, padding = self._get_query_indices(idx, ep_idx)
query_result = self._query_hf_dataset(query_indices)
item = {**item, **padding}
for key, val in query_result.items():
@ -691,10 +756,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
return {
"size": 0,
**{key: current_ep_idx if key == "episode_index" else [] for key in self.features},
}
ep_buffer = {}
# size and task are special cases that are not in self.features
ep_buffer["size"] = 0
ep_buffer["task"] = []
for key in self.features:
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
return ep_buffer
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
fpath = DEFAULT_IMAGE_PATH.format(
@ -716,25 +784,35 @@ class LeRobotDataset(torch.utils.data.Dataset):
temporary directory nothing is written to disk. To save those frames, the 'save_episode()' method
then needs to be called.
"""
# TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
# check the dtype and shape matches, etc.
# Convert torch to numpy if needed
for name in frame:
if isinstance(frame[name], torch.Tensor):
frame[name] = frame[name].numpy()
validate_frame(frame, self.features)
if self.episode_buffer is None:
self.episode_buffer = self.create_episode_buffer()
# Automatically add frame_index and timestamp to episode buffer
frame_index = self.episode_buffer["size"]
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp)
# Add frame features to episode_buffer
for key in frame:
if key not in self.features:
raise ValueError(key)
if key == "task":
# Note: we associate the task in natural language to its task index during `save_episode`
self.episode_buffer["task"].append(frame["task"])
continue
if self.features[key]["dtype"] not in ["image", "video"]:
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
self.episode_buffer[key].append(item)
elif self.features[key]["dtype"] in ["image", "video"]:
if key not in self.features:
raise ValueError(
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
)
if self.features[key]["dtype"] in ["image", "video"]:
img_path = self._get_image_file_path(
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
)
@ -742,80 +820,95 @@ class LeRobotDataset(torch.utils.data.Dataset):
img_path.parent.mkdir(parents=True, exist_ok=True)
self._save_image(frame[key], img_path)
self.episode_buffer[key].append(str(img_path))
else:
self.episode_buffer[key].append(frame[key])
self.episode_buffer["size"] += 1
def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None:
def save_episode(self, episode_data: dict | None = None) -> None:
"""
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
the hub.
This will save to disk the current episode in self.episode_buffer.
Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise,
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
time for video encoding.
Args:
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
None.
"""
if not episode_data:
episode_buffer = self.episode_buffer
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
# size and task are special cases that won't be added to hf_dataset
episode_length = episode_buffer.pop("size")
tasks = episode_buffer.pop("task")
episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"]
if episode_index != self.meta.total_episodes:
# TODO(aliberts): Add option to use existing episode_index
raise NotImplementedError(
"You might have manually provided the episode_buffer with an episode_index that doesn't "
"match the total number of episodes in the dataset. This is not supported for now."
)
if episode_length == 0:
raise ValueError(
"You must add one or several frames with `add_frame` before calling `add_episode`."
)
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
task_index = self.meta.get_task_index(task)
# Add new tasks to the tasks dictionnary
for task in episode_tasks:
task_index = self.meta.get_task_index(task)
if task_index is None:
self.meta.add_task(task)
if not set(episode_buffer.keys()) == set(self.features):
raise ValueError()
# Given tasks in natural language, find their corresponding task indices
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
for key, ft in self.features.items():
if key == "index":
episode_buffer[key] = np.arange(
self.meta.total_frames, self.meta.total_frames + episode_length
)
elif key == "episode_index":
episode_buffer[key] = np.full((episode_length,), episode_index)
elif key == "task_index":
episode_buffer[key] = np.full((episode_length,), task_index)
elif ft["dtype"] in ["image", "video"]:
# index, episode_index, task_index are already processed above, and image and video
# are processed separately by storing image path and frame info as meta data
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
continue
elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
episode_buffer[key] = np.stack(episode_buffer[key])
else:
raise ValueError(key)
episode_buffer[key] = np.stack(episode_buffer[key])
self._wait_image_writer()
self._save_episode_table(episode_buffer, episode_index)
ep_stats = compute_episode_stats(episode_buffer, self.features)
self.meta.save_episode(episode_index, episode_length, task, task_index)
if encode_videos and len(self.meta.video_keys) > 0:
if len(self.meta.video_keys) > 0:
video_paths = self.encode_episode_videos(episode_index)
for key in self.meta.video_keys:
episode_buffer[key] = video_paths[key]
# `meta.save_episode` be executed after encoding the videos
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
check_timestamps_sync(
episode_buffer["timestamp"],
episode_buffer["episode_index"],
ep_data_index_np,
self.fps,
self.tolerance_s,
)
video_files = list(self.root.rglob("*.mp4"))
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
parquet_files = list(self.root.rglob("*.parquet"))
assert len(parquet_files) == self.num_episodes
# delete images
img_dir = self.root / "images"
if img_dir.is_dir():
shutil.rmtree(self.root / "images")
if not episode_data: # Reset the buffer
self.episode_buffer = self.create_episode_buffer()
self.consolidated = False
def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None:
episode_dict = {key: episode_buffer[key] for key in self.hf_features}
ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train")
ep_dataset = embed_images(ep_dataset)
self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset])
self.hf_dataset.set_transform(hf_transform_to_torch)
ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index)
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
write_parquet(ep_dataset, ep_data_path)
ep_dataset.to_parquet(ep_data_path)
def clear_episode_buffer(self) -> None:
episode_index = self.episode_buffer["episode_index"]
@ -884,38 +977,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
return video_paths
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
self.hf_dataset = self.load_hf_dataset()
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
if len(self.meta.video_keys) > 0:
self.encode_videos()
self.meta.write_video_info()
if not keep_image_files:
img_dir = self.root / "images"
if img_dir.is_dir():
shutil.rmtree(self.root / "images")
video_files = list(self.root.rglob("*.mp4"))
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
parquet_files = list(self.root.rglob("*.parquet"))
assert len(parquet_files) == self.num_episodes
if run_compute_stats:
self.stop_image_writer()
# TODO(aliberts): refactor stats in save_episodes
self.meta.stats = compute_stats(self)
serialized_stats = serialize_dict(self.meta.stats)
write_json(serialized_stats, self.root / STATS_PATH)
self.consolidated = True
else:
logging.warning(
"Skipping computation of the dataset statistics, dataset is not fully consolidated."
)
@classmethod
def create(
cls,
@ -944,7 +1005,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
obj.repo_id = obj.meta.repo_id
obj.root = obj.meta.root
obj.local_files_only = obj.meta.local_files_only
obj.revision = None
obj.tolerance_s = tolerance_s
obj.image_writer = None
@ -954,14 +1015,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
obj.episode_buffer = obj.create_episode_buffer()
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
# is used to know when certain operations are need (for instance, computing dataset statistics). In
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling
# self.consolidate().
obj.consolidated = True
obj.episodes = None
obj.hf_dataset = None
obj.hf_dataset = obj.create_hf_dataset()
obj.image_transforms = None
obj.delta_timestamps = None
obj.delta_indices = None
@ -986,12 +1041,11 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
delta_timestamps: dict[list[float]] | None = None,
tolerances_s: dict | None = None,
download_videos: bool = True,
local_files_only: bool = False,
video_backend: str | None = None,
):
super().__init__()
self.repo_ids = repo_ids
self.root = Path(root) if root else LEROBOT_HOME
self.root = Path(root) if root else HF_LEROBOT_HOME
self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids}
# Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which
# are handled by this class.
@ -1004,7 +1058,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
delta_timestamps=delta_timestamps,
tolerance_s=self.tolerances_s[repo_id],
download_videos=download_videos,
local_files_only=local_files_only,
video_backend=video_backend,
)
for repo_id in repo_ids
@ -1032,7 +1085,10 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
self.image_transforms = image_transforms
self.delta_timestamps = delta_timestamps
self.stats = aggregate_stats(self._datasets)
# TODO(rcadene, aliberts): We should not perform this aggregation for datasets
# with multiple robots of different ranges. Instead we should have one normalization
# per robot.
self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets])
@property
def repo_id_to_index(self):

View File

@ -1,56 +0,0 @@
## Using / Updating `CODEBASE_VERSION` (for maintainers)
Since our dataset pushed to the hub are decoupled with the evolution of this repo, we ensure compatibility of
the datasets with our code, we use a `CODEBASE_VERSION` (defined in
lerobot/common/datasets/lerobot_dataset.py) variable.
For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) has many versions to maintain backward compatibility between LeRobot codebase versions:
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5)
- [v1.6](https://huggingface.co/datasets/lerobot/pusht/tree/v1.6) <-- last version
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version
Starting with v1.6, every dataset pushed to the hub or saved locally also have this version number in their
`info.json` metadata.
### Uploading a new dataset
If you are pushing a new dataset, you don't need to worry about any of the instructions below, nor to be
compatible with previous codebase versions. The `push_dataset_to_hub.py` script will automatically tag your
dataset with the current `CODEBASE_VERSION`.
### Updating an existing dataset
If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py`
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change
intentionally or not (i.e. something not backward compatible such as modifying the reward functions used,
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
codebase won't be affected by your change and backward compatibility is maintained.
However, you will need to update the version of ALL the other datasets so that they have the new
`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way
that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF
dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed):
```python
from huggingface_hub import HfApi
from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
api = HfApi()
for repo_id in available_datasets:
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
if CODEBASE_VERSION in branches:
print(f"{repo_id} already @{CODEBASE_VERSION}, skipping.")
continue
else:
# Now create a branch named after the new version by branching out from "main"
# which is expected to be the preceding version
api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main")
print(f"{repo_id} successfully updated @{CODEBASE_VERSION}")
```

View File

@ -13,10 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import importlib.resources
import json
import logging
import textwrap
from collections.abc import Iterator
from itertools import accumulate
from pathlib import Path
@ -27,14 +27,20 @@ from typing import Any
import datasets
import jsonlines
import numpy as np
import pyarrow.compute as pc
import packaging.version
import torch
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from PIL import Image as PILImage
from torchvision import transforms
from lerobot.common.datasets.backward_compatibility import (
V21_MESSAGE,
BackwardCompatibilityError,
ForwardCompatibilityError,
)
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
@ -42,6 +48,7 @@ DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
INFO_PATH = "meta/info.json"
EPISODES_PATH = "meta/episodes.jsonl"
STATS_PATH = "meta/stats.json"
EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
TASKS_PATH = "meta/tasks.jsonl"
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
@ -112,17 +119,26 @@ def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any:
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
serialized_dict = {}
for key, value in flatten_dict(stats).items():
if isinstance(value, (torch.Tensor, np.ndarray)):
serialized_dict[key] = value.tolist()
elif isinstance(value, np.generic):
serialized_dict[key] = value.item()
elif isinstance(value, (int, float)):
serialized_dict[key] = value
else:
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
return unflatten_dict(serialized_dict)
def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
# Embed image bytes into the table before saving to parquet
format = dataset.format
dataset = dataset.with_format("arrow")
dataset = dataset.map(embed_table_storage, batched=False)
dataset = dataset.with_format(**format)
dataset.to_parquet(fpath)
return dataset
def load_json(fpath: Path) -> Any:
@ -153,6 +169,10 @@ def append_jsonlines(data: dict, fpath: Path) -> None:
writer.write(data)
def write_info(info: dict, local_dir: Path):
write_json(info, local_dir / INFO_PATH)
def load_info(local_dir: Path) -> dict:
info = load_json(local_dir / INFO_PATH)
for ft in info["features"].values():
@ -160,29 +180,76 @@ def load_info(local_dir: Path) -> dict:
return info
def load_stats(local_dir: Path) -> dict:
if not (local_dir / STATS_PATH).exists():
return None
stats = load_json(local_dir / STATS_PATH)
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
def write_stats(stats: dict, local_dir: Path):
serialized_stats = serialize_dict(stats)
write_json(serialized_stats, local_dir / STATS_PATH)
def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]:
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
def load_tasks(local_dir: Path) -> dict:
def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]:
if not (local_dir / STATS_PATH).exists():
return None
stats = load_json(local_dir / STATS_PATH)
return cast_stats_to_numpy(stats)
def write_task(task_index: int, task: dict, local_dir: Path):
task_dict = {
"task_index": task_index,
"task": task,
}
append_jsonlines(task_dict, local_dir / TASKS_PATH)
def load_tasks(local_dir: Path) -> tuple[dict, dict]:
tasks = load_jsonlines(local_dir / TASKS_PATH)
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
task_to_task_index = {task: task_index for task_index, task in tasks.items()}
return tasks, task_to_task_index
def write_episode(episode: dict, local_dir: Path):
append_jsonlines(episode, local_dir / EPISODES_PATH)
def load_episodes(local_dir: Path) -> dict:
return load_jsonlines(local_dir / EPISODES_PATH)
episodes = load_jsonlines(local_dir / EPISODES_PATH)
return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])}
def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray:
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
# We wrap episode_stats in a dictionnary since `episode_stats["episode_index"]`
# is a dictionary of stats and not an integer.
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
def load_episodes_stats(local_dir: Path) -> dict:
episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH)
return {
item["episode_index"]: cast_stats_to_numpy(item["stats"])
for item in sorted(episodes_stats, key=lambda x: x["episode_index"])
}
def backward_compatible_episodes_stats(
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
) -> dict[str, dict[str, np.ndarray]]:
return {ep_idx: stats for ep_idx in episodes}
def load_image_as_numpy(
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
) -> np.ndarray:
img = PILImage.open(fpath).convert("RGB")
img_array = np.array(img, dtype=dtype)
if channel_first: # (H, W, C) -> (C, H, W)
img_array = np.transpose(img_array, (2, 0, 1))
if "float" in dtype:
if np.issubdtype(dtype, np.floating):
img_array /= 255.0
return img_array
@ -201,77 +268,82 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
elif first_item is None:
pass
else:
items_dict[key] = [torch.tensor(x) for x in items_dict[key]]
items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]]
return items_dict
def _get_major_minor(version: str) -> tuple[int]:
split = version.strip("v").split(".")
return int(split[0]), int(split[1])
class BackwardCompatibilityError(Exception):
def __init__(self, repo_id, version):
message = textwrap.dedent(f"""
BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
We introduced a new format since v2.0 which is not backward compatible with v1.x.
Please, use our conversion script. Modify the following command with your own task description:
```
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
--repo-id {repo_id} \\
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
```
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
"Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
""")
super().__init__(message)
def is_valid_version(version: str) -> bool:
try:
packaging.version.parse(version)
return True
except packaging.version.InvalidVersion:
return False
def check_version_compatibility(
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
repo_id: str,
version_to_check: str | packaging.version.Version,
current_version: str | packaging.version.Version,
enforce_breaking_major: bool = True,
) -> None:
current_major, _ = _get_major_minor(current_version)
major_to_check, _ = _get_major_minor(version_to_check)
if major_to_check < current_major and enforce_breaking_major:
raise BackwardCompatibilityError(repo_id, version_to_check)
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
logging.warning(
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
codebase. The current codebase version is {current_version}. You should be fine since
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
)
v_check = (
packaging.version.parse(version_to_check)
if not isinstance(version_to_check, packaging.version.Version)
else version_to_check
)
v_current = (
packaging.version.parse(current_version)
if not isinstance(current_version, packaging.version.Version)
else current_version
)
if v_check.major < v_current.major and enforce_breaking_major:
raise BackwardCompatibilityError(repo_id, v_check)
elif v_check.minor < v_current.minor:
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
def get_hub_safe_version(repo_id: str, version: str) -> str:
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
"""Returns available valid versions (branches and tags) on given repo."""
api = HfApi()
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
branches = [b.name for b in dataset_info.branches]
if version not in branches:
num_version = float(version.strip("v"))
hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
raise BackwardCompatibilityError(repo_id, version)
repo_refs = api.list_repo_refs(repo_id, repo_type="dataset")
repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags]
repo_versions = []
for ref in repo_refs:
with contextlib.suppress(packaging.version.InvalidVersion):
repo_versions.append(packaging.version.parse(ref))
logging.warning(
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
codebase. The following versions are available: {branches}.
The requested version ('{version}') is not found. You should be fine since
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
)
if "main" not in branches:
raise ValueError(f"Version 'main' not found on {repo_id}")
return "main"
else:
return version
return repo_versions
def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str:
"""
Returns the version if available on repo or the latest compatible one.
Otherwise, will throw a `CompatibilityError`.
"""
target_version = (
packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version
)
hub_versions = get_repo_versions(repo_id)
if target_version in hub_versions:
return f"v{target_version}"
compatibles = [
v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor
]
if compatibles:
return_version = max(compatibles)
if return_version < target_version:
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
return f"v{return_version}"
lower_major = [v for v in hub_versions if v.major < target_version.major]
if lower_major:
raise BackwardCompatibilityError(repo_id, max(lower_major))
upper_versions = [v for v in hub_versions if v > target_version]
assert len(upper_versions) > 0
raise ForwardCompatibilityError(repo_id, min(upper_versions))
def get_hf_features_from_features(features: dict) -> datasets.Features:
@ -283,11 +355,20 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
hf_features[key] = datasets.Image()
elif ft["shape"] == (1,):
hf_features[key] = datasets.Value(dtype=ft["dtype"])
else:
assert len(ft["shape"]) == 1
elif len(ft["shape"]) == 1:
hf_features[key] = datasets.Sequence(
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
)
elif len(ft["shape"]) == 2:
hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 3:
hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 4:
hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"])
elif len(ft["shape"]) == 5:
hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"])
else:
raise ValueError(f"Corresponding feature is not valid: {ft}")
return datasets.Features(hf_features)
@ -358,9 +439,9 @@ def create_empty_dataset_info(
def get_episode_data_index(
episode_dicts: list[dict], episodes: list[int] | None = None
episode_dicts: dict[dict], episodes: list[int] | None = None
) -> dict[str, torch.Tensor]:
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()}
if episodes is not None:
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
@ -371,75 +452,72 @@ def get_episode_data_index(
}
def calculate_total_episode(
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
) -> dict[str, torch.Tensor]:
episode_indices = sorted(hf_dataset.unique("episode_index"))
total_episodes = len(episode_indices)
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
raise ValueError("episode_index values are not sorted and contiguous.")
return total_episodes
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
episode_lengths = []
table = hf_dataset.data.table
total_episodes = calculate_total_episode(hf_dataset)
for ep_idx in range(total_episodes):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table))
cumulative_lenghts = list(accumulate(episode_lengths))
return {
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
"to": torch.LongTensor(cumulative_lenghts),
}
def check_timestamps_sync(
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
timestamps: np.ndarray,
episode_indices: np.ndarray,
episode_data_index: dict[str, np.ndarray],
fps: int,
tolerance_s: float,
raise_value_error: bool = True,
) -> bool:
"""
This check is to make sure that each timestamps is separated to the next by 1/fps +/- tolerance to
account for possible numerical error.
"""
timestamps = torch.stack(hf_dataset["timestamp"])
diffs = torch.diff(timestamps)
within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance
to account for possible numerical error.
# We mask differences between the timestamp at the end of an episode
# and the one at the start of the next episode since these are expected
# to be outside tolerance.
mask = torch.ones(len(diffs), dtype=torch.bool)
ignored_diffs = episode_data_index["to"][:-1] - 1
Args:
timestamps (np.ndarray): Array of timestamps in seconds.
episode_indices (np.ndarray): Array indicating the episode index for each timestamp.
episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to',
which identifies indices for the end of each episode.
fps (int): Frames per second. Used to check the expected difference between consecutive timestamps.
tolerance_s (float): Allowed deviation from the expected (1/fps) difference.
raise_value_error (bool): Whether to raise a ValueError if the check fails.
Returns:
bool: True if all checked timestamp differences lie within tolerance, False otherwise.
Raises:
ValueError: If the check fails and `raise_value_error` is True.
"""
if timestamps.shape != episode_indices.shape:
raise ValueError(
"timestamps and episode_indices should have the same shape. "
f"Found {timestamps.shape=} and {episode_indices.shape=}."
)
# Consecutive differences
diffs = np.diff(timestamps)
within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s
# Mask to ignore differences at the boundaries between episodes
mask = np.ones(len(diffs), dtype=bool)
ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode
mask[ignored_diffs] = False
filtered_within_tolerance = within_tolerance[mask]
if not torch.all(filtered_within_tolerance):
# Check if all remaining diffs are within tolerance
if not np.all(filtered_within_tolerance):
# Track original indices before masking
original_indices = torch.arange(len(diffs))
original_indices = np.arange(len(diffs))
filtered_indices = original_indices[mask]
outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance) # .squeeze()
outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0]
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
episode_indices = torch.stack(hf_dataset["episode_index"])
outside_tolerances = []
for idx in outside_tolerance_indices:
entry = {
"timestamps": [timestamps[idx], timestamps[idx + 1]],
"diff": diffs[idx],
"episode_index": episode_indices[idx].item(),
"episode_index": episode_indices[idx].item()
if hasattr(episode_indices[idx], "item")
else episode_indices[idx],
}
outside_tolerances.append(entry)
if raise_value_error:
raise ValueError(
f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
This might be due to synchronization issues with timestamps during data collection.
This might be due to synchronization issues during data collection.
\n{pformat(outside_tolerances)}"""
)
return False
@ -604,3 +682,118 @@ class IterableNamespace(SimpleNamespace):
def keys(self):
return vars(self).keys()
def validate_frame(frame: dict, features: dict):
optional_features = {"timestamp"}
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
actual_features = set(frame.keys())
error_message = validate_features_presence(actual_features, expected_features, optional_features)
if "task" in frame:
error_message += validate_feature_string("task", frame["task"])
common_features = actual_features & (expected_features | optional_features)
for name in common_features - {"task"}:
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
if error_message:
raise ValueError(error_message)
def validate_features_presence(
actual_features: set[str], expected_features: set[str], optional_features: set[str]
):
error_message = ""
missing_features = expected_features - actual_features
extra_features = actual_features - (expected_features | optional_features)
if missing_features or extra_features:
error_message += "Feature mismatch in `frame` dictionary:\n"
if missing_features:
error_message += f"Missing features: {missing_features}\n"
if extra_features:
error_message += f"Extra features: {extra_features}\n"
return error_message
def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str):
expected_dtype = feature["dtype"]
expected_shape = feature["shape"]
if is_valid_numpy_dtype_string(expected_dtype):
return validate_feature_numpy_array(name, expected_dtype, expected_shape, value)
elif expected_dtype in ["image", "video"]:
return validate_feature_image_or_video(name, expected_shape, value)
elif expected_dtype == "string":
return validate_feature_string(name, value)
else:
raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.")
def validate_feature_numpy_array(
name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray
):
error_message = ""
if isinstance(value, np.ndarray):
actual_dtype = value.dtype
actual_shape = value.shape
if actual_dtype != np.dtype(expected_dtype):
error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
if actual_shape != expected_shape:
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
else:
error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
return error_message
def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image):
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
error_message = ""
if isinstance(value, np.ndarray):
actual_shape = value.shape
c, h, w = expected_shape
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
elif isinstance(value, PILImage.Image):
pass
else:
error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n"
return error_message
def validate_feature_string(name: str, value: str):
if not isinstance(value, str):
return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n"
return ""
def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict):
if "size" not in episode_buffer:
raise ValueError("size key not found in episode_buffer")
if "task" not in episode_buffer:
raise ValueError("task key not found in episode_buffer")
if episode_buffer["episode_index"] != total_episodes:
# TODO(aliberts): Add option to use existing episode_index
raise NotImplementedError(
"You might have manually provided the episode_buffer with an episode_index that doesn't "
"match the total number of episodes already in the dataset. This is not supported for now."
)
if episode_buffer["size"] == 0:
raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.")
buffer_keys = set(episode_buffer.keys()) - {"task", "size"}
if not buffer_keys == set(features):
raise ValueError(
f"Features from `episode_buffer` don't match the ones in `features`."
f"In episode_buffer not in features: {buffer_keys - set(features)}"
f"In features not in episode_buffer: {set(features) - buffer_keys}"
)

View File

@ -130,7 +130,7 @@ from lerobot.common.datasets.utils import (
create_branch,
create_lerobot_dataset_card,
flatten_dict,
get_hub_safe_version,
get_safe_version,
load_json,
unflatten_dict,
write_json,
@ -443,7 +443,7 @@ def convert_dataset(
test_branch: str | None = None,
**card_kwargs,
):
v1 = get_hub_safe_version(repo_id, V16)
v1 = get_safe_version(repo_id, V16)
v1x_dir = local_dir / V16 / repo_id
v20_dir = local_dir / V20 / repo_id
v1x_dir.mkdir(parents=True, exist_ok=True)

View File

@ -0,0 +1,73 @@
import logging
import traceback
from pathlib import Path
from datasets import get_dataset_config_info
from huggingface_hub import HfApi
from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.datasets.utils import INFO_PATH, write_info
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings
LOCAL_DIR = Path("data/")
hub_api = HfApi()
def fix_dataset(repo_id: str) -> str:
if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"):
return f"{repo_id}: skipped (not in {V20})."
dataset_info = get_dataset_config_info(repo_id, "default")
with SuppressWarnings():
lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True)
meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"}
parquet_features = set(dataset_info.features)
diff_parquet_meta = parquet_features - meta_features
diff_meta_parquet = meta_features - parquet_features
if diff_parquet_meta:
raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}")
if not diff_meta_parquet:
return f"{repo_id}: skipped (no diff)"
if diff_meta_parquet:
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
assert diff_meta_parquet == {"language_instruction"}
lerobot_metadata.features.pop("language_instruction")
write_info(lerobot_metadata.info, lerobot_metadata.root)
commit_info = hub_api.upload_file(
path_or_fileobj=lerobot_metadata.root / INFO_PATH,
path_in_repo=INFO_PATH,
repo_id=repo_id,
repo_type="dataset",
revision=V20,
commit_message="Remove 'language_instruction'",
create_pr=True,
)
return f"{repo_id}: success - PR: {commit_info.pr_url}"
def batch_fix():
status = {}
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
logfile = LOCAL_DIR / "fix_features_v20.txt"
for num, repo_id in enumerate(available_datasets):
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
print("---------------------------------------------------------")
try:
status = fix_dataset(repo_id)
except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}"
logging.info(status)
with open(logfile, "a") as file:
file.write(status + "\n")
if __name__ == "__main__":
batch_fix()

View File

@ -0,0 +1,54 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1.
"""
import traceback
from pathlib import Path
from huggingface_hub import HfApi
from lerobot import available_datasets
from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset
LOCAL_DIR = Path("data/")
def batch_convert():
status = {}
LOCAL_DIR.mkdir(parents=True, exist_ok=True)
logfile = LOCAL_DIR / "conversion_log_v21.txt"
hub_api = HfApi()
for num, repo_id in enumerate(available_datasets):
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
print("---------------------------------------------------------")
try:
if hub_api.revision_exists(repo_id, V21, repo_type="dataset"):
status = f"{repo_id}: success (already in {V21})."
else:
convert_dataset(repo_id)
status = f"{repo_id}: success."
except Exception:
status = f"{repo_id}: failed\n {traceback.format_exc()}"
with open(logfile, "a") as file:
file.write(status + "\n")
if __name__ == "__main__":
batch_convert()

View File

@ -0,0 +1,100 @@
"""
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to
2.1. It will:
- Generate per-episodes stats and writes them in `episodes_stats.jsonl`
- Check consistency between these new stats and the old ones.
- Remove the deprecated `stats.json`.
- Update codebase_version in `info.json`.
- Push this new version to the hub on the 'main' branch and tags it with "v2.1".
Usage:
```bash
python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \
--repo-id=aliberts/koch_tutorial
```
"""
import argparse
import logging
from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
V20 = "v2.0"
V21 = "v2.1"
class SuppressWarnings:
def __enter__(self):
self.previous_level = logging.getLogger().getEffectiveLevel()
logging.getLogger().setLevel(logging.ERROR)
def __exit__(self, exc_type, exc_val, exc_tb):
logging.getLogger().setLevel(self.previous_level)
def convert_dataset(
repo_id: str,
branch: str | None = None,
num_workers: int = 4,
):
with SuppressWarnings():
dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True)
if (dataset.root / EPISODES_STATS_PATH).is_file():
(dataset.root / EPISODES_STATS_PATH).unlink()
convert_stats(dataset, num_workers=num_workers)
ref_stats = load_stats(dataset.root)
check_aggregate_stats(dataset, ref_stats)
dataset.meta.info["codebase_version"] = CODEBASE_VERSION
write_info(dataset.meta.info, dataset.root)
dataset.push_to_hub(branch=branch, allow_patterns="meta/")
# delete old stats.json file
if (dataset.root / STATS_PATH).is_file:
(dataset.root / STATS_PATH).unlink()
hub_api = HfApi()
if hub_api.file_exists(
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
):
hub_api.delete_file(
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
)
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset "
"(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
)
parser.add_argument(
"--branch",
type=str,
default=None,
help="Repo branch to push your dataset. Defaults to the main branch.",
)
parser.add_argument(
"--num-workers",
type=int,
default=4,
help="Number of workers for parallelizing stats compute. Defaults to 4.",
)
args = parser.parse_args()
convert_dataset(**vars(args))

View File

@ -0,0 +1,85 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
from tqdm import tqdm
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import write_episode_stats
def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray:
ep_len = dataset.meta.episodes[episode_index]["length"]
sampled_indices = sample_indices(ep_len)
query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices})
video_frames = dataset._query_videos(query_timestamps, episode_index)
return video_frames[ft_key].numpy()
def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
ep_start_idx = dataset.episode_data_index["from"][ep_idx]
ep_end_idx = dataset.episode_data_index["to"][ep_idx]
ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx))
ep_stats = {}
for key, ft in dataset.features.items():
if ft["dtype"] == "video":
# We sample only for videos
ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key)
else:
ep_ft_data = np.array(ep_data[key])
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
if ft["dtype"] in ["image", "video"]: # remove batch dim
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items()
}
dataset.meta.episodes_stats[ep_idx] = ep_stats
def convert_stats(dataset: LeRobotDataset, num_workers: int = 0):
assert dataset.episodes is None
print("Computing episodes stats")
total_episodes = dataset.meta.total_episodes
if num_workers > 0:
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = {
executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx
for ep_idx in range(total_episodes)
}
for future in tqdm(as_completed(futures), total=total_episodes):
future.result()
else:
for ep_idx in tqdm(range(total_episodes)):
convert_episode_stats(dataset, ep_idx)
for ep_idx in tqdm(range(total_episodes)):
write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)
def check_aggregate_stats(
dataset: LeRobotDataset,
reference_stats: dict[str, dict[str, np.ndarray]],
video_rtol_atol: tuple[float] = (1e-2, 1e-2),
default_rtol_atol: tuple[float] = (5e-6, 6e-5),
):
"""Verifies that the aggregated stats from episodes_stats are close to reference stats."""
agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values()))
for key, ft in dataset.features.items():
# These values might need some fine-tuning
if ft["dtype"] == "video":
# to account for image sub-sampling
rtol, atol = video_rtol_atol
else:
rtol, atol = default_rtol_atol
for stat, val in agg_stats[key].items():
if key in reference_stats and stat in reference_stats[key]:
err_msg = f"feature='{key}' stats='{stat}'"
np.testing.assert_allclose(
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
)

View File

@ -69,8 +69,8 @@ def decode_video_frames_torchvision(
# set the first and last requested timestamps
# Note: previous timestamps are usually loaded, since we need to access the previous key frame
first_ts = timestamps[0]
last_ts = timestamps[-1]
first_ts = min(timestamps)
last_ts = max(timestamps)
# access closest key frame of the first requested frame
# Note: closest key frame timestamp is usally smaller than `first_ts` (e.g. key frame can be the first frame of the video)

View File

@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torch import Tensor, nn
@ -77,17 +78,29 @@ def create_stats_buffers(
}
)
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
if stats:
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone()
buffer["std"].data = stats[key]["std"].clone()
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = stats[key]["min"].clone()
buffer["max"].data = stats[key]["max"].clone()
if isinstance(stats[key]["mean"], np.ndarray):
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
elif isinstance(stats[key]["mean"], torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
else:
type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
stats_buffers[key] = buffer
return stats_buffers
@ -141,6 +154,7 @@ class Normalize(nn.Module):
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, ft in self.features.items():
if key not in batch:
# FIXME(aliberts, rcadene): This might lead to silent fail!
continue
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)

View File

@ -60,8 +60,6 @@ class RecordControlConfig(ControlConfig):
num_episodes: int = 50
# Encode frames in the dataset into video
video: bool = True
# By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.
run_compute_stats: bool = True
# Upload dataset to Hugging Face hub.
push_to_hub: bool = True
# Upload on private repository on the Hugging Face hub.
@ -83,9 +81,6 @@ class RecordControlConfig(ControlConfig):
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
# Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
local_files_only: bool = False
def __post_init__(self):
# HACK: We parse again the cli args here to get the pretrained path if there was one.
@ -130,9 +125,6 @@ class ReplayControlConfig(ControlConfig):
fps: int | None = None
# Use vocal synthesis to read events.
play_sounds: bool = True
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
# Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.
local_files_only: bool = False
@ControlConfig.register_subclass("remote_robot")

View File

@ -183,6 +183,7 @@ def record_episode(
device,
use_amp,
fps,
single_task,
):
control_loop(
robot=robot,
@ -195,6 +196,7 @@ def record_episode(
use_amp=use_amp,
fps=fps,
teleoperate=policy is None,
single_task=single_task,
)
@ -210,6 +212,7 @@ def control_loop(
device: torch.device | str | None = None,
use_amp: bool | None = None,
fps: int | None = None,
single_task: str | None = None,
):
# TODO(rcadene): Add option to record logs
if not robot.is_connected:
@ -224,6 +227,9 @@ def control_loop(
if teleoperate and policy is not None:
raise ValueError("When `teleoperate` is True, `policy` should be None.")
if dataset is not None and single_task is None:
raise ValueError("You need to provide a task as argument in `single_task`.")
if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
@ -248,7 +254,7 @@ def control_loop(
action = {"action": action}
if dataset is not None:
frame = {**observation, **action}
frame = {**observation, **action, "task": single_task}
dataset.add_frame(frame)
if display_cameras and not is_headless():

View File

@ -21,6 +21,7 @@ from copy import copy
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import torch
@ -200,5 +201,18 @@ def get_channel_first_image_shape(image_shape: tuple) -> tuple:
return shape
def has_method(cls: object, method_name: str):
def has_method(cls: object, method_name: str) -> bool:
return hasattr(cls, method_name) and callable(getattr(cls, method_name))
def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
"""
Return True if a given string can be converted to a numpy dtype.
"""
try:
# Attempt to convert the string to a numpy dtype
np.dtype(dtype_str)
return True
except TypeError:
# If a TypeError is raised, the string is not a valid dtype
return False

View File

@ -31,7 +31,7 @@ class DatasetConfig:
repo_id: str
episodes: list[int] | None = None
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
local_files_only: bool = False
revision: str | None = None
use_imagenet_stats: bool = True
video_backend: str = "pyav"

View File

@ -92,7 +92,6 @@ python lerobot/scripts/control_robot.py \
This might require a sudo permission to allow your terminal to monitor keyboard events.
**NOTE**: You can resume/continue data recording by running the same data recording command and adding `--control.resume=true`.
If the dataset you want to extend is not on the hub, you also need to add `--control.local_files_only=true`.
- Train on this dataset with the ACT policy:
```bash
@ -234,7 +233,6 @@ def record(
dataset = LeRobotDataset(
cfg.repo_id,
root=cfg.root,
local_files_only=cfg.local_files_only,
)
if len(robot.cameras) > 0:
dataset.start_image_writer(
@ -281,8 +279,8 @@ def record(
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
record_episode(
dataset=dataset,
robot=robot,
dataset=dataset,
events=events,
episode_time_s=cfg.episode_time_s,
display_cameras=cfg.display_cameras,
@ -290,6 +288,7 @@ def record(
device=cfg.device,
use_amp=cfg.use_amp,
fps=cfg.fps,
single_task=cfg.single_task,
)
# Execute a few seconds without recording to give time to manually reset the environment
@ -309,7 +308,7 @@ def record(
dataset.clear_episode_buffer()
continue
dataset.save_episode(cfg.single_task)
dataset.save_episode()
recorded_episodes += 1
if events["stop_recording"]:
@ -318,11 +317,6 @@ def record(
log_say("Stop recording", cfg.play_sounds, blocking=True)
stop_recording(robot, listener, cfg.display_cameras)
if cfg.run_compute_stats:
logging.info("Computing dataset statistics")
dataset.consolidate(cfg.run_compute_stats)
if cfg.push_to_hub:
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
@ -338,9 +332,7 @@ def replay(
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
# TODO(rcadene): Add option to record logs
dataset = LeRobotDataset(
cfg.repo_id, root=cfg.root, episodes=[cfg.episode], local_files_only=cfg.local_files_only
)
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
actions = dataset.hf_dataset.select_columns("action")
if not robot.is_connected:

View File

@ -207,12 +207,6 @@ def main():
required=True,
help="Episode to visualize.",
)
parser.add_argument(
"--local-files-only",
type=int,
default=0,
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
)
parser.add_argument(
"--root",
type=Path,
@ -275,10 +269,9 @@ def main():
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
root = kwargs.pop("root")
local_files_only = kwargs.pop("local_files_only")
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
dataset = LeRobotDataset(repo_id, root=root)
visualize_dataset(dataset, **vars(args))

View File

@ -150,7 +150,7 @@ def run_server(
400,
)
dataset_version = (
dataset.meta._version if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
)
match = re.search(r"v(\d+)\.", dataset_version)
if match:
@ -384,12 +384,6 @@ def main():
default=None,
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
)
parser.add_argument(
"--local-files-only",
type=int,
default=0,
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
)
parser.add_argument(
"--root",
type=Path,
@ -445,15 +439,10 @@ def main():
repo_id = kwargs.pop("repo_id")
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
root = kwargs.pop("root")
local_files_only = kwargs.pop("local_files_only")
dataset = None
if repo_id:
dataset = (
LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
if not load_from_hf_hub
else get_dataset_info(repo_id)
)
dataset = LeRobotDataset(repo_id, root=root) if not load_from_hf_hub else get_dataset_info(repo_id)
visualize_dataset_html(dataset, **vars(args))

View File

@ -109,7 +109,7 @@ def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR
dataset = LeRobotDataset(
repo_id=cfg.repo_id,
episodes=cfg.episodes,
local_files_only=cfg.local_files_only,
revision=cfg.revision,
video_backend=cfg.video_backend,
)

View File

@ -47,6 +47,7 @@ dependencies = [
"numba>=0.59.0",
"omegaconf>=2.3.0",
"opencv-python>=4.9.0",
"packaging>=24.2",
"pyav>=12.0.5",
"pymunk>=6.6.0",
"pynput>=1.7.7",

View File

@ -1,6 +1,6 @@
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.constants import HF_LEROBOT_HOME
LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing"
DUMMY_REPO_ID = "dummy/repo"
DUMMY_ROBOT_TYPE = "dummy_robot"
DUMMY_MOTOR_FEATURES = {
@ -27,3 +27,5 @@ DUMMY_VIDEO_INFO = {
"video.is_depth_map": False,
"has_audio": False,
}
DUMMY_CHW = (3, 96, 128)
DUMMY_HWC = (96, 128, 3)

View File

@ -1,5 +1,7 @@
import random
from functools import partial
from pathlib import Path
from typing import Protocol
from unittest.mock import patch
import datasets
@ -27,8 +29,12 @@ from tests.fixtures.constants import (
)
class LeRobotDatasetFactory(Protocol):
def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
def get_task_index(task_dicts: dict, task: str) -> int:
tasks = {d["task_index"]: d["task"] for d in task_dicts}
tasks = {d["task_index"]: d["task"] for d in task_dicts.values()}
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
return task_to_task_index[task]
@ -141,6 +147,7 @@ def stats_factory():
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(),
"min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(),
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
"count": [10],
}
else:
stats[key] = {
@ -148,20 +155,38 @@ def stats_factory():
"mean": np.full(shape, 0.5, dtype=dtype).tolist(),
"min": np.full(shape, 0, dtype=dtype).tolist(),
"std": np.full(shape, 0.25, dtype=dtype).tolist(),
"count": [10],
}
return stats
return _create_stats
@pytest.fixture(scope="session")
def episodes_stats_factory(stats_factory):
def _create_episodes_stats(
features: dict[str],
total_episodes: int = 3,
) -> dict:
episodes_stats = {}
for episode_index in range(total_episodes):
episodes_stats[episode_index] = {
"episode_index": episode_index,
"stats": stats_factory(features),
}
return episodes_stats
return _create_episodes_stats
@pytest.fixture(scope="session")
def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> int:
tasks_list = []
for i in range(total_tasks):
task_dict = {"task_index": i, "task": f"Perform action {i}."}
tasks_list.append(task_dict)
return tasks_list
tasks = {}
for task_index in range(total_tasks):
task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
tasks[task_index] = task_dict
return tasks
return _create_tasks
@ -190,10 +215,10 @@ def episodes_factory(tasks_factory):
# Generate random lengths that sum up to total_length
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
tasks_list = [task_dict["task"] for task_dict in tasks]
tasks_list = [task_dict["task"] for task_dict in tasks.values()]
num_tasks_available = len(tasks_list)
episodes_list = []
episodes = {}
remaining_tasks = tasks_list.copy()
for ep_idx in range(total_episodes):
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
@ -203,15 +228,13 @@ def episodes_factory(tasks_factory):
for task in episode_tasks:
remaining_tasks.remove(task)
episodes_list.append(
{
"episode_index": ep_idx,
"tasks": episode_tasks,
"length": lengths[ep_idx],
}
)
episodes[ep_idx] = {
"episode_index": ep_idx,
"tasks": episode_tasks,
"length": lengths[ep_idx],
}
return episodes_list
return episodes
return _create_episodes
@ -235,7 +258,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
frame_index_col = np.array([], dtype=np.int64)
episode_index_col = np.array([], dtype=np.int64)
task_index = np.array([], dtype=np.int64)
for ep_dict in episodes:
for ep_dict in episodes.values():
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
episode_index_col = np.concatenate(
@ -278,6 +301,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
def lerobot_dataset_metadata_factory(
info_factory,
stats_factory,
episodes_stats_factory,
tasks_factory,
episodes_factory,
mock_snapshot_download_factory,
@ -287,14 +311,18 @@ def lerobot_dataset_metadata_factory(
repo_id: str = DUMMY_REPO_ID,
info: dict | None = None,
stats: dict | None = None,
episodes_stats: list[dict] | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
local_files_only: bool = False,
) -> LeRobotDatasetMetadata:
if not info:
info = info_factory()
if not stats:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=info["total_episodes"]
)
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
@ -305,21 +333,20 @@ def lerobot_dataset_metadata_factory(
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episodes,
)
with (
patch(
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
) as mock_get_hub_safe_version_patch,
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
return LeRobotDatasetMetadata(repo_id=repo_id, root=root)
return _create_lerobot_dataset_metadata
@ -328,12 +355,13 @@ def lerobot_dataset_metadata_factory(
def lerobot_dataset_factory(
info_factory,
stats_factory,
episodes_stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
mock_snapshot_download_factory,
lerobot_dataset_metadata_factory,
):
) -> LeRobotDatasetFactory:
def _create_lerobot_dataset(
root: Path,
repo_id: str = DUMMY_REPO_ID,
@ -343,6 +371,7 @@ def lerobot_dataset_factory(
multi_task: bool = False,
info: dict | None = None,
stats: dict | None = None,
episodes_stats: list[dict] | None = None,
tasks: list[dict] | None = None,
episode_dicts: list[dict] | None = None,
hf_dataset: datasets.Dataset | None = None,
@ -354,6 +383,8 @@ def lerobot_dataset_factory(
)
if not stats:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episode_dicts:
@ -369,6 +400,7 @@ def lerobot_dataset_factory(
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episode_dicts,
hf_dataset=hf_dataset,
@ -378,19 +410,26 @@ def lerobot_dataset_factory(
repo_id=repo_id,
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episode_dicts,
local_files_only=kwargs.get("local_files_only", False),
)
with (
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_metadata_patch.return_value = mock_metadata
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
return _create_lerobot_dataset
@pytest.fixture(scope="session")
def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory:
return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS)

View File

@ -7,7 +7,13 @@ import pyarrow.compute as pc
import pyarrow.parquet as pq
import pytest
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
from lerobot.common.datasets.utils import (
EPISODES_PATH,
EPISODES_STATS_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
)
@pytest.fixture(scope="session")
@ -38,6 +44,20 @@ def stats_path(stats_factory):
return _create_stats_json_file
@pytest.fixture(scope="session")
def episodes_stats_path(episodes_stats_factory):
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
if not episodes_stats:
episodes_stats = episodes_stats_factory()
fpath = dir / EPISODES_STATS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(episodes_stats.values())
return fpath
return _create_episodes_stats_jsonl_file
@pytest.fixture(scope="session")
def tasks_path(tasks_factory):
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
@ -46,7 +66,7 @@ def tasks_path(tasks_factory):
fpath = dir / TASKS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(tasks)
writer.write_all(tasks.values())
return fpath
return _create_tasks_jsonl_file
@ -60,7 +80,7 @@ def episode_path(episodes_factory):
fpath = dir / EPISODES_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(episodes)
writer.write_all(episodes.values())
return fpath
return _create_episodes_jsonl_file

21
tests/fixtures/hub.py vendored
View File

@ -4,7 +4,13 @@ import datasets
import pytest
from huggingface_hub.utils import filter_repo_objects
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
from lerobot.common.datasets.utils import (
EPISODES_PATH,
EPISODES_STATS_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
)
from tests.fixtures.constants import LEROBOT_TEST_DIR
@ -14,6 +20,8 @@ def mock_snapshot_download_factory(
info_path,
stats_factory,
stats_path,
episodes_stats_factory,
episodes_stats_path,
tasks_factory,
tasks_path,
episodes_factory,
@ -29,6 +37,7 @@ def mock_snapshot_download_factory(
def _mock_snapshot_download_func(
info: dict | None = None,
stats: dict | None = None,
episodes_stats: list[dict] | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
hf_dataset: datasets.Dataset | None = None,
@ -37,6 +46,10 @@ def mock_snapshot_download_factory(
info = info_factory()
if not stats:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=info["total_episodes"]
)
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
@ -67,11 +80,11 @@ def mock_snapshot_download_factory(
# List all possible files
all_files = []
meta_files = [INFO_PATH, STATS_PATH, TASKS_PATH, EPISODES_PATH]
meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
all_files.extend(meta_files)
data_files = []
for episode_dict in episodes:
for episode_dict in episodes.values():
ep_idx = episode_dict["episode_index"]
ep_chunk = ep_idx // info["chunks_size"]
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
@ -92,6 +105,8 @@ def mock_snapshot_download_factory(
_ = info_path(local_dir, info)
elif rel_path == STATS_PATH:
_ = stats_path(local_dir, stats)
elif rel_path == EPISODES_STATS_PATH:
_ = episodes_stats_path(local_dir, episodes_stats)
elif rel_path == TASKS_PATH:
_ = tasks_path(local_dir, tasks)
elif rel_path == EPISODES_PATH:

View File

@ -182,7 +182,7 @@ def test_camera(request, camera_type, mock):
@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES)
@require_camera
def test_save_images_from_cameras(tmpdir, request, camera_type, mock):
def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
# TODO(rcadene): refactor
if camera_type == "opencv":
from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras
@ -190,4 +190,4 @@ def test_save_images_from_cameras(tmpdir, request, camera_type, mock):
from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras
# Small `record_time_s` to speedup unit tests
save_images_from_cameras(tmpdir, record_time_s=0.02, mock=mock)
save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock)

311
tests/test_compute_stats.py Normal file
View File

@ -0,0 +1,311 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch
import numpy as np
import pytest
from lerobot.common.datasets.compute_stats import (
_assert_type_and_shape,
aggregate_feature_stats,
aggregate_stats,
compute_episode_stats,
estimate_num_samples,
get_feature_stats,
sample_images,
sample_indices,
)
def mock_load_image_as_numpy(path, dtype, channel_first):
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
@pytest.fixture
def sample_array():
return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
def test_estimate_num_samples():
assert estimate_num_samples(1) == 1
assert estimate_num_samples(10) == 10
assert estimate_num_samples(100) == 100
assert estimate_num_samples(200) == 100
assert estimate_num_samples(1000) == 177
assert estimate_num_samples(2000) == 299
assert estimate_num_samples(5000) == 594
assert estimate_num_samples(10_000) == 1000
assert estimate_num_samples(20_000) == 1681
assert estimate_num_samples(50_000) == 3343
assert estimate_num_samples(500_000) == 10_000
def test_sample_indices():
indices = sample_indices(10)
assert len(indices) > 0
assert indices[0] == 0
assert indices[-1] == 9
assert len(indices) == estimate_num_samples(10)
@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
def test_sample_images(mock_load):
image_paths = [f"image_{i}.jpg" for i in range(100)]
images = sample_images(image_paths)
assert isinstance(images, np.ndarray)
assert images.shape[1:] == (3, 32, 32)
assert images.dtype == np.uint8
assert len(images) == estimate_num_samples(100)
def test_get_feature_stats_images():
data = np.random.rand(100, 3, 32, 32)
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
np.testing.assert_equal(stats["count"], np.array([100]))
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
def test_get_feature_stats_axis_0_keepdims(sample_array):
expected = {
"min": np.array([[1, 2, 3]]),
"max": np.array([[7, 8, 9]]),
"mean": np.array([[4.0, 5.0, 6.0]]),
"std": np.array([[2.44948974, 2.44948974, 2.44948974]]),
"count": np.array([3]),
}
result = get_feature_stats(sample_array, axis=(0,), keepdims=True)
for key in expected:
np.testing.assert_allclose(result[key], expected[key])
def test_get_feature_stats_axis_1(sample_array):
expected = {
"min": np.array([1, 4, 7]),
"max": np.array([3, 6, 9]),
"mean": np.array([2.0, 5.0, 8.0]),
"std": np.array([0.81649658, 0.81649658, 0.81649658]),
"count": np.array([3]),
}
result = get_feature_stats(sample_array, axis=(1,), keepdims=False)
for key in expected:
np.testing.assert_allclose(result[key], expected[key])
def test_get_feature_stats_no_axis(sample_array):
expected = {
"min": np.array(1),
"max": np.array(9),
"mean": np.array(5.0),
"std": np.array(2.5819889),
"count": np.array([3]),
}
result = get_feature_stats(sample_array, axis=None, keepdims=False)
for key in expected:
np.testing.assert_allclose(result[key], expected[key])
def test_get_feature_stats_empty_array():
array = np.array([])
with pytest.raises(ValueError):
get_feature_stats(array, axis=(0,), keepdims=True)
def test_get_feature_stats_single_value():
array = np.array([[1337]])
result = get_feature_stats(array, axis=None, keepdims=True)
np.testing.assert_equal(result["min"], np.array(1337))
np.testing.assert_equal(result["max"], np.array(1337))
np.testing.assert_equal(result["mean"], np.array(1337.0))
np.testing.assert_equal(result["std"], np.array(0.0))
np.testing.assert_equal(result["count"], np.array([1]))
def test_compute_episode_stats():
episode_data = {
"observation.image": [f"image_{i}.jpg" for i in range(100)],
"observation.state": np.random.rand(100, 10),
}
features = {
"observation.image": {"dtype": "image"},
"observation.state": {"dtype": "numeric"},
}
with patch(
"lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
):
stats = compute_episode_stats(episode_data, features)
assert "observation.image" in stats and "observation.state" in stats
assert stats["observation.image"]["count"].item() == 100
assert stats["observation.state"]["count"].item() == 100
assert stats["observation.image"]["mean"].shape == (3, 1, 1)
def test_assert_type_and_shape_valid():
valid_stats = [
{
"feature1": {
"min": np.array([1.0]),
"max": np.array([10.0]),
"mean": np.array([5.0]),
"std": np.array([2.0]),
"count": np.array([1]),
}
}
]
_assert_type_and_shape(valid_stats)
def test_assert_type_and_shape_invalid_type():
invalid_stats = [
{
"feature1": {
"min": [1.0], # Not a numpy array
"max": np.array([10.0]),
"mean": np.array([5.0]),
"std": np.array([2.0]),
"count": np.array([1]),
}
}
]
with pytest.raises(ValueError, match="Stats must be composed of numpy array"):
_assert_type_and_shape(invalid_stats)
def test_assert_type_and_shape_invalid_shape():
invalid_stats = [
{
"feature1": {
"count": np.array([1, 2]), # Wrong shape
}
}
]
with pytest.raises(ValueError, match=r"Shape of 'count' must be \(1\)"):
_assert_type_and_shape(invalid_stats)
def test_aggregate_feature_stats():
stats_ft_list = [
{
"min": np.array([1.0]),
"max": np.array([10.0]),
"mean": np.array([5.0]),
"std": np.array([2.0]),
"count": np.array([1]),
},
{
"min": np.array([2.0]),
"max": np.array([12.0]),
"mean": np.array([6.0]),
"std": np.array([2.5]),
"count": np.array([1]),
},
]
result = aggregate_feature_stats(stats_ft_list)
np.testing.assert_allclose(result["min"], np.array([1.0]))
np.testing.assert_allclose(result["max"], np.array([12.0]))
np.testing.assert_allclose(result["mean"], np.array([5.5]))
np.testing.assert_allclose(result["std"], np.array([2.318405]), atol=1e-6)
np.testing.assert_allclose(result["count"], np.array([2]))
def test_aggregate_stats():
all_stats = [
{
"observation.image": {
"min": [1, 2, 3],
"max": [10, 20, 30],
"mean": [5.5, 10.5, 15.5],
"std": [2.87, 5.87, 8.87],
"count": 10,
},
"observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
"extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
},
{
"observation.image": {
"min": [2, 1, 0],
"max": [15, 10, 5],
"mean": [8.5, 5.5, 2.5],
"std": [3.42, 2.42, 1.42],
"count": 15,
},
"observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
"extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
},
]
expected_agg_stats = {
"observation.image": {
"min": [1, 1, 0],
"max": [15, 20, 30],
"mean": [7.3, 7.5, 7.7],
"std": [3.5317, 4.8267, 8.5581],
"count": 25,
},
"observation.state": {
"min": 1,
"max": 15,
"mean": 7.3,
"std": 3.5317,
"count": 25,
},
"extra_key_0": {
"min": 5,
"max": 25,
"mean": 15.0,
"std": 6.0,
"count": 6,
},
"extra_key_1": {
"min": 0,
"max": 20,
"mean": 10.0,
"std": 5.0,
"count": 5,
},
}
# cast to numpy
for ep_stats in all_stats:
for fkey, stats in ep_stats.items():
for k in stats:
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
if fkey == "observation.image" and k != "count":
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
else:
stats[k] = stats[k].reshape(1)
# cast to numpy
for fkey, stats in expected_agg_stats.items():
for k in stats:
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
if fkey == "observation.image" and k != "count":
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
else:
stats[k] = stats[k].reshape(1)
results = aggregate_stats(all_stats)
for fkey in expected_agg_stats:
np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"])
np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"])
np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"])
np.testing.assert_allclose(
results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
)
np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])

View File

@ -24,7 +24,6 @@ pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-True]'
"""
import multiprocessing
from pathlib import Path
from unittest.mock import patch
import pytest
@ -45,7 +44,7 @@ from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
def test_teleoperate(tmpdir, request, robot_type, mock):
def test_teleoperate(tmp_path, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock and robot_type != "aloha":
@ -53,8 +52,7 @@ def test_teleoperate(tmpdir, request, robot_type, mock):
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
tmpdir = Path(tmpdir)
calibration_dir = tmpdir / robot_type
calibration_dir = tmp_path / robot_type
mock_calibration_dir(calibration_dir)
robot_kwargs["calibration_dir"] = calibration_dir
else:
@ -70,15 +68,14 @@ def test_teleoperate(tmpdir, request, robot_type, mock):
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
def test_calibrate(tmpdir, request, robot_type, mock):
def test_calibrate(tmp_path, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock:
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
tmpdir = Path(tmpdir)
calibration_dir = tmpdir / robot_type
calibration_dir = tmp_path / robot_type
robot_kwargs["calibration_dir"] = calibration_dir
robot = make_robot(**robot_kwargs)
@ -89,7 +86,7 @@ def test_calibrate(tmpdir, request, robot_type, mock):
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
def test_record_without_cameras(tmpdir, request, robot_type, mock):
def test_record_without_cameras(tmp_path, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock}
# Avoid using cameras
@ -100,7 +97,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = Path(tmpdir) / robot_type
calibration_dir = tmp_path / robot_type
mock_calibration_dir(calibration_dir)
robot_kwargs["calibration_dir"] = calibration_dir
else:
@ -108,7 +105,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
pass
repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
root = tmp_path / "data" / repo_id
single_task = "Do something."
robot = make_robot(**robot_kwargs)
@ -121,7 +118,6 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
episode_time_s=1,
reset_time_s=0.1,
num_episodes=2,
run_compute_stats=False,
push_to_hub=False,
video=False,
play_sounds=False,
@ -131,8 +127,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
tmpdir = Path(tmpdir)
def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock and robot_type != "aloha":
@ -140,7 +135,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
calibration_dir = tmp_path / robot_type
mock_calibration_dir(calibration_dir)
robot_kwargs["calibration_dir"] = calibration_dir
else:
@ -148,7 +143,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
pass
repo_id = "lerobot_test/debug"
root = tmpdir / "data" / repo_id
root = tmp_path / "data" / repo_id
single_task = "Do something."
robot = make_robot(**robot_kwargs)
@ -172,15 +167,13 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
assert dataset.meta.total_episodes == 2
assert len(dataset) == 2
replay_cfg = ReplayControlConfig(
episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False, local_files_only=True
)
replay_cfg = ReplayControlConfig(episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
replay(robot, replay_cfg)
policy_cfg = ACTConfig()
policy = make_policy(policy_cfg, ds_meta=dataset.meta, device=DEVICE)
out_dir = tmpdir / "logger"
out_dir = tmp_path / "logger"
pretrained_policy_path = out_dir / "checkpoints/last/pretrained_model"
policy.save_pretrained(pretrained_policy_path)
@ -207,7 +200,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
num_image_writer_processes = 0
eval_repo_id = "lerobot/eval_debug"
eval_root = tmpdir / "data" / eval_repo_id
eval_root = tmp_path / "data" / eval_repo_id
rec_eval_cfg = RecordControlConfig(
repo_id=eval_repo_id,
@ -218,7 +211,6 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
episode_time_s=1,
reset_time_s=0.1,
num_episodes=2,
run_compute_stats=False,
push_to_hub=False,
video=False,
display_cameras=False,
@ -240,7 +232,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot
def test_resume_record(tmpdir, request, robot_type, mock):
def test_resume_record(tmp_path, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock and robot_type != "aloha":
@ -248,7 +240,7 @@ def test_resume_record(tmpdir, request, robot_type, mock):
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
calibration_dir = tmp_path / robot_type
mock_calibration_dir(calibration_dir)
robot_kwargs["calibration_dir"] = calibration_dir
else:
@ -258,7 +250,7 @@ def test_resume_record(tmpdir, request, robot_type, mock):
robot = make_robot(**robot_kwargs)
repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
root = tmp_path / "data" / repo_id
single_task = "Do something."
rec_cfg = RecordControlConfig(
@ -272,8 +264,6 @@ def test_resume_record(tmpdir, request, robot_type, mock):
video=False,
display_cameras=False,
play_sounds=False,
run_compute_stats=False,
local_files_only=True,
num_episodes=1,
)
@ -291,7 +281,7 @@ def test_resume_record(tmpdir, request, robot_type, mock):
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot
def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock and robot_type != "aloha":
@ -299,7 +289,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
calibration_dir = tmp_path / robot_type
mock_calibration_dir(calibration_dir)
robot_kwargs["calibration_dir"] = calibration_dir
else:
@ -316,7 +306,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
mock_listener.return_value = (None, mock_events)
repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
root = tmp_path / "data" / repo_id
single_task = "Do something."
rec_cfg = RecordControlConfig(
@ -331,7 +321,6 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
video=False,
display_cameras=False,
play_sounds=False,
run_compute_stats=False,
)
dataset = record(robot, rec_cfg)
@ -342,7 +331,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@require_robot
def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock:
@ -350,7 +339,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
calibration_dir = tmp_path / robot_type
mock_calibration_dir(calibration_dir)
robot_kwargs["calibration_dir"] = calibration_dir
else:
@ -367,7 +356,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
mock_listener.return_value = (None, mock_events)
repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
root = tmp_path / "data" / repo_id
single_task = "Do something."
rec_cfg = RecordControlConfig(
@ -382,7 +371,6 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
video=False,
display_cameras=False,
play_sounds=False,
run_compute_stats=False,
)
dataset = record(robot, rec_cfg)
@ -395,7 +383,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
"robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)]
)
@require_robot
def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes):
def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes):
robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock:
@ -403,7 +391,7 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
# Create an empty calibration directory to trigger manual calibration
# and avoid writing calibration files in user .cache/calibration folder
calibration_dir = tmpdir / robot_type
calibration_dir = tmp_path / robot_type
mock_calibration_dir(calibration_dir)
robot_kwargs["calibration_dir"] = calibration_dir
else:
@ -420,7 +408,7 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
mock_listener.return_value = (None, mock_events)
repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
root = tmp_path / "data" / repo_id
single_task = "Do something."
rec_cfg = RecordControlConfig(
@ -436,7 +424,6 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
video=False,
display_cameras=False,
play_sounds=False,
run_compute_stats=False,
num_image_writer_processes=num_image_writer_processes,
)

View File

@ -15,24 +15,21 @@
# limitations under the License.
import json
import logging
import re
from copy import deepcopy
from itertools import chain
from pathlib import Path
import einops
import numpy as np
import pytest
import torch
from datasets import Dataset
from huggingface_hub import HfApi
from PIL import Image
from safetensors.torch import load_file
import lerobot
from lerobot.common.datasets.compute_stats import (
aggregate_stats,
compute_stats,
get_stats_einops_patterns,
)
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.image_writer import image_array_to_pil_image
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
MultiLeRobotDataset,
@ -40,20 +37,34 @@ from lerobot.common.datasets.lerobot_dataset import (
from lerobot.common.datasets.utils import (
create_branch,
flatten_dict,
hf_transform_to_torch,
unflatten_dict,
)
from lerobot.common.envs.factory import make_env_config
from lerobot.common.policies.factory import make_policy_config
from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.common.utils.random_utils import seeded_context
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_REPO_ID
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.utils import DEVICE, require_x86_64_kernel
def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
@pytest.fixture
def image_dataset(tmp_path, empty_lerobot_dataset_factory):
features = {
"image": {
"dtype": "image",
"shape": DUMMY_CHW,
"names": [
"channels",
"height",
"width",
],
}
}
return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
"""
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
objects have the same sets of attributes defined.
@ -66,24 +77,20 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init)
# Access the '_hub_version' cached_property in both instances to force its creation
_ = dataset_init.meta._hub_version
_ = dataset_create.meta._hub_version
init_attr = set(vars(dataset_init).keys())
create_attr = set(vars(dataset_create).keys())
assert init_attr == create_attr
def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
kwargs = {
"repo_id": DUMMY_REPO_ID,
"total_episodes": 10,
"total_frames": 400,
"episodes": [2, 5, 6],
}
dataset = lerobot_dataset_factory(root=tmp_path, **kwargs)
dataset = lerobot_dataset_factory(root=tmp_path / "test", **kwargs)
assert dataset.repo_id == kwargs["repo_id"]
assert dataset.meta.total_episodes == kwargs["total_episodes"]
@ -93,12 +100,232 @@ def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
assert dataset.num_frames == len(dataset)
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n"
):
dataset.add_frame({"state": torch.randn(1)})
def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
):
dataset.add_frame({"task": "Dummy task"})
def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
):
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
):
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError,
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
):
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError,
match=re.escape(
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'float'>' provided instead.\n"
),
):
dataset.add_frame({"state": 1.0, "task": "Dummy task"})
def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError,
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
):
dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})
def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError,
match=re.escape(
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'numpy.float32'>' provided instead.\n"
),
):
dataset.add_frame({"state": np.float32(1.0), "task": "Dummy task"})
def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
dataset.save_episode()
assert len(dataset) == 1
assert dataset[0]["task"] == "Dummy task"
assert dataset[0]["task_index"] == 0
assert dataset[0]["state"].ndim == 0
def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2])
def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4])
def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["state"].ndim == 0
def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["caption"] == "Dummy caption"
def test_add_frame_image_wrong_shape(image_dataset):
dataset = image_dataset
with pytest.raises(
ValueError,
match=re.escape(
"The feature 'image' of shape '(3, 128, 96)' does not have the expected shape '(3, 96, 128)' or '(96, 128, 3)'.\n"
),
):
c, h, w = DUMMY_CHW
dataset.add_frame({"image": torch.randn(c, w, h), "task": "Dummy task"})
def test_add_frame_image_wrong_range(image_dataset):
"""This test will display the following error message from a thread:
```
Error writing image ...test_add_frame_image_wrong_ran0/test/images/image/episode_000000/frame_000000.png:
The image data type is float, which requires values in the range [0.0, 1.0]. However, the provided range is [0.009678772038470007, 254.9776492089887].
Please adjust the range or provide a uint8 image with values in the range [0, 255]
```
Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`.
"""
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"})
with pytest.raises(FileNotFoundError):
dataset.save_episode()
def test_add_frame_image(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
def test_add_frame_image_h_w_c(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
def test_add_frame_image_uint8(image_dataset):
dataset = image_dataset
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": image, "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
def test_add_frame_image_pil(image_dataset):
dataset = image_dataset
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
def test_image_array_to_pil_image_wrong_range_float_0_255():
image = np.random.rand(*DUMMY_HWC) * 255
with pytest.raises(ValueError):
image_array_to_pil_image(image)
# TODO(aliberts):
# - [ ] test various attributes & state from init and create
# - [ ] test init with episodes and check num_frames
# - [ ] test add_frame
# - [ ] test add_episode
# - [ ] test consolidate
# - [ ] test push_to_hub
# - [ ] test smaller methods
@ -210,67 +437,6 @@ def test_multidataset_frames():
assert torch.equal(sub_dataset_item[k], dataset_item[k])
# TODO(aliberts, rcadene): Refactor and move this to a tests/test_compute_stats.py
def test_compute_stats_on_xarm():
"""Check that the statistics are computed correctly according to the stats_patterns property.
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
because we are working with a small dataset).
"""
# TODO(rcadene, aliberts): remove dataset download
dataset = LeRobotDataset("lerobot/xarm_lift_medium", episodes=[0])
# reduce size of dataset sample on which stats compute is tested to 10 frames
dataset.hf_dataset = dataset.hf_dataset.select(range(10))
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
# dataset into even batches.
computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25), num_workers=0)
# get einops patterns to aggregate batches and compute statistics
stats_patterns = get_stats_einops_patterns(dataset)
# get all frames from the dataset in the same dtype and range as during compute_stats
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=len(dataset),
shuffle=False,
)
full_batch = next(iter(dataloader))
# compute stats based on all frames from the dataset without any batching
expected_stats = {}
for k, pattern in stats_patterns.items():
full_batch[k] = full_batch[k].float()
expected_stats[k] = {}
expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean")
expected_stats[k]["std"] = torch.sqrt(
einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
)
expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min")
expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max")
# test computed stats match expected stats
for k in stats_patterns:
assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"])
assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"])
assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"])
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
# load stats used during training which are expected to match the ones returned by computed_stats
loaded_stats = dataset.meta.stats # noqa: F841
# TODO(rcadene): we can't test this because expected_stats is computed on a subset
# # test loaded stats match expected stats
# for k in stats_patterns:
# assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"])
# assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"])
# assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"])
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
# TODO(aliberts): Move to more appropriate location
def test_flatten_unflatten_dict():
d = {
@ -374,35 +540,6 @@ def test_backward_compatibility(repo_id):
# load_and_compare(i - 1)
@pytest.mark.skip("TODO after fix multidataset")
def test_multidataset_aggregate_stats():
"""Makes 3 basic datasets and checks that aggregate stats are computed correctly."""
with seeded_context(0):
data_a = torch.rand(30, dtype=torch.float32)
data_b = torch.rand(20, dtype=torch.float32)
data_c = torch.rand(20, dtype=torch.float32)
hf_dataset_1 = Dataset.from_dict(
{"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)}
)
hf_dataset_1.set_transform(hf_transform_to_torch)
hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)})
hf_dataset_2.set_transform(hf_transform_to_torch)
hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)})
hf_dataset_3.set_transform(hf_transform_to_torch)
dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1)
dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0)
dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2)
dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0)
dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3)
dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0)
stats = aggregate_stats([dataset_1, dataset_2, dataset_3])
for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True):
for agg_fn in ["mean", "min", "max"]:
assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn))
assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0))
@pytest.mark.skip("Requires internet access")
def test_create_branch():
api = HfApi()
@ -431,9 +568,9 @@ def test_create_branch():
def test_dataset_feature_with_forward_slash_raises_error():
# make sure dir does not exist
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.constants import HF_LEROBOT_HOME
dataset_dir = LEROBOT_HOME / "lerobot/test/with/slash"
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
# make sure does not exist
if dataset_dir.exists():
dataset_dir.rmdir()

View File

@ -1,55 +1,78 @@
from itertools import accumulate
import datasets
import numpy as np
import pyarrow.compute as pc
import pytest
import torch
from datasets import Dataset
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
check_delta_timestamps,
check_timestamps_sync,
get_delta_indices,
hf_transform_to_torch,
)
from tests.fixtures.constants import DUMMY_MOTOR_FEATURES
@pytest.fixture(scope="module")
def synced_hf_dataset_factory(hf_dataset_factory):
def _create_synced_hf_dataset(fps: int = 30) -> Dataset:
return hf_dataset_factory(fps=fps)
def calculate_total_episode(
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
) -> dict[str, torch.Tensor]:
episode_indices = sorted(hf_dataset.unique("episode_index"))
total_episodes = len(episode_indices)
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
raise ValueError("episode_index values are not sorted and contiguous.")
return total_episodes
return _create_synced_hf_dataset
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.ndarray]:
episode_lengths = []
table = hf_dataset.data.table
total_episodes = calculate_total_episode(hf_dataset)
for ep_idx in range(total_episodes):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table))
cumulative_lenghts = list(accumulate(episode_lengths))
return {
"from": np.array([0] + cumulative_lenghts[:-1], dtype=np.int64),
"to": np.array(cumulative_lenghts, dtype=np.int64),
}
@pytest.fixture(scope="module")
def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
def _create_unsynced_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
hf_dataset = synced_hf_dataset_factory(fps=fps)
features = hf_dataset.features
df = hf_dataset.to_pandas()
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
# Modify a single timestamp just outside tolerance
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 1.1))
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
return unsynced_hf_dataset
def synced_timestamps_factory(hf_dataset_factory):
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
hf_dataset = hf_dataset_factory(fps=fps)
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
episode_data_index = calculate_episode_data_index(hf_dataset)
return timestamps, episode_indices, episode_data_index
return _create_unsynced_hf_dataset
return _create_synced_timestamps
@pytest.fixture(scope="module")
def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
def _create_slightly_off_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
hf_dataset = synced_hf_dataset_factory(fps=fps)
features = hf_dataset.features
df = hf_dataset.to_pandas()
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
# Modify a single timestamp just inside tolerance
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 0.9))
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
return unsynced_hf_dataset
def unsynced_timestamps_factory(synced_timestamps_factory):
def _create_unsynced_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
timestamps[30] += tolerance_s * 1.1 # Modify a single timestamp just outside tolerance
return timestamps, episode_indices, episode_data_index
return _create_slightly_off_hf_dataset
return _create_unsynced_timestamps
@pytest.fixture(scope="module")
def slightly_off_timestamps_factory(synced_timestamps_factory):
def _create_slightly_off_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
timestamps[30] += tolerance_s * 0.9 # Modify a single timestamp just inside tolerance
return timestamps, episode_indices, episode_data_index
return _create_slightly_off_timestamps
@pytest.fixture(scope="module")
@ -100,42 +123,42 @@ def delta_indices_factory():
return _delta_indices
def test_check_timestamps_sync_synced(synced_hf_dataset_factory):
def test_check_timestamps_sync_synced(synced_timestamps_factory):
fps = 30
tolerance_s = 1e-4
synced_hf_dataset = synced_hf_dataset_factory(fps)
episode_data_index = calculate_episode_data_index(synced_hf_dataset)
timestamps, ep_idx, ep_data_index = synced_timestamps_factory(fps)
result = check_timestamps_sync(
hf_dataset=synced_hf_dataset,
episode_data_index=episode_data_index,
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=ep_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory):
def test_check_timestamps_sync_unsynced(unsynced_timestamps_factory):
fps = 30
tolerance_s = 1e-4
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s)
with pytest.raises(ValueError):
check_timestamps_sync(
hf_dataset=unsynced_hf_dataset,
episode_data_index=episode_data_index,
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=ep_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory):
def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory):
fps = 30
tolerance_s = 1e-4
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s)
result = check_timestamps_sync(
hf_dataset=unsynced_hf_dataset,
episode_data_index=episode_data_index,
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=ep_data_index,
fps=fps,
tolerance_s=tolerance_s,
raise_value_error=False,
@ -143,14 +166,14 @@ def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory
assert result is False
def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory):
def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory):
fps = 30
tolerance_s = 1e-4
slightly_off_hf_dataset = slightly_off_hf_dataset_factory(fps, tolerance_s)
episode_data_index = calculate_episode_data_index(slightly_off_hf_dataset)
timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s)
result = check_timestamps_sync(
hf_dataset=slightly_off_hf_dataset,
episode_data_index=episode_data_index,
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=ep_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
@ -158,33 +181,13 @@ def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory):
def test_check_timestamps_sync_single_timestamp():
single_timestamp_hf_dataset = Dataset.from_dict({"timestamp": [0.0], "episode_index": [0]})
single_timestamp_hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {"to": torch.tensor([1]), "from": torch.tensor([0])}
fps = 30
tolerance_s = 1e-4
timestamps, ep_idx = np.array([0.0]), np.array([0])
episode_data_index = {"to": np.array([1]), "from": np.array([0])}
result = check_timestamps_sync(
hf_dataset=single_timestamp_hf_dataset,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
# TODO(aliberts): Change behavior of hf_transform_to_torch so that it can work with empty dataset
@pytest.mark.skip("TODO: fix")
def test_check_timestamps_sync_empty_dataset():
fps = 30
tolerance_s = 1e-4
empty_hf_dataset = Dataset.from_dict({"timestamp": [], "episode_index": []})
empty_hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"to": torch.tensor([], dtype=torch.int64),
"from": torch.tensor([], dtype=torch.int64),
}
result = check_timestamps_sync(
hf_dataset=empty_hf_dataset,
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,

View File

@ -53,7 +53,7 @@ def test_example_1(tmp_path, lerobot_dataset_factory):
('repo_id = "lerobot/pusht"', f'repo_id = "{DUMMY_REPO_ID}"'),
(
"LeRobotDataset(repo_id",
f"LeRobotDataset(repo_id, root='{str(tmp_path)}', local_files_only=True",
f"LeRobotDataset(repo_id, root='{str(tmp_path)}'",
),
],
)

View File

@ -9,10 +9,11 @@ from PIL import Image
from lerobot.common.datasets.image_writer import (
AsyncImageWriter,
image_array_to_image,
image_array_to_pil_image,
safe_stop_image_writer,
write_image,
)
from tests.fixtures.constants import DUMMY_HWC
DUMMY_IMAGE = "test_image.png"
@ -48,49 +49,62 @@ def test_zero_threads():
AsyncImageWriter(num_processes=0, num_threads=0)
def test_image_array_to_image_rgb(img_array_factory):
def test_image_array_to_pil_image_float_array_wrong_range_0_255():
image = np.random.rand(*DUMMY_HWC) * 255
with pytest.raises(ValueError):
image_array_to_pil_image(image)
def test_image_array_to_pil_image_float_array_wrong_range_neg_1_1():
image = np.random.rand(*DUMMY_HWC) * 2 - 1
with pytest.raises(ValueError):
image_array_to_pil_image(image)
def test_image_array_to_pil_image_rgb(img_array_factory):
img_array = img_array_factory(100, 100)
result_image = image_array_to_image(img_array)
result_image = image_array_to_pil_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
def test_image_array_to_image_pytorch_format(img_array_factory):
def test_image_array_to_pil_image_pytorch_format(img_array_factory):
img_array = img_array_factory(100, 100).transpose(2, 0, 1)
result_image = image_array_to_image(img_array)
result_image = image_array_to_pil_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
@pytest.mark.skip("TODO: implement")
def test_image_array_to_image_single_channel(img_array_factory):
def test_image_array_to_pil_image_single_channel(img_array_factory):
img_array = img_array_factory(channels=1)
result_image = image_array_to_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "L"
with pytest.raises(NotImplementedError):
image_array_to_pil_image(img_array)
def test_image_array_to_image_float_array(img_array_factory):
def test_image_array_to_pil_image_4_channels(img_array_factory):
img_array = img_array_factory(channels=4)
with pytest.raises(NotImplementedError):
image_array_to_pil_image(img_array)
def test_image_array_to_pil_image_float_array(img_array_factory):
img_array = img_array_factory(dtype=np.float32)
result_image = image_array_to_image(img_array)
result_image = image_array_to_pil_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
assert np.array(result_image).dtype == np.uint8
def test_image_array_to_image_out_of_bounds_float():
# Float array with values out of [0, 1]
img_array = np.random.uniform(-1, 2, size=(100, 100, 3)).astype(np.float32)
result_image = image_array_to_image(img_array)
def test_image_array_to_pil_image_uint8_array(img_array_factory):
img_array = img_array_factory(dtype=np.float32)
result_image = image_array_to_pil_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
assert np.array(result_image).dtype == np.uint8
assert np.array(result_image).min() >= 0 and np.array(result_image).max() <= 255
def test_write_image_numpy(tmp_path, img_array_factory):

View File

@ -1,370 +0,0 @@
"""
This file contains generic tests to ensure that nothing breaks if we modify the push_dataset_to_hub API.
Also, this file contains backward compatibility tests. Because they are slow and require to download the raw datasets,
we skip them for now in our CI.
Example to run backward compatiblity tests locally:
```
python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
```
"""
from pathlib import Path
import numpy as np
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import save_images_concurrently
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
from tests.utils import require_package_arg
def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
import zarr
raw_dir.mkdir(parents=True, exist_ok=True)
zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
store = zarr.DirectoryStore(zarr_path)
zarr_data = zarr.group(store=store)
zarr_data.create_dataset(
"data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"data/img",
shape=(num_frames, 96, 96, 3),
chunks=(num_frames, 96, 96, 3),
dtype=np.uint8,
overwrite=True,
)
zarr_data.create_dataset(
"data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
)
zarr_data["data/action"][:] = np.random.randn(num_frames, 1)
zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2)
zarr_data["data/state"][:] = np.random.randn(num_frames, 5)
zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2)
zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])
store.close()
def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
import zarr
raw_dir.mkdir(parents=True, exist_ok=True)
zarr_path = raw_dir / "cup_in_the_wild.zarr"
store = zarr.DirectoryStore(zarr_path)
zarr_data = zarr.group(store=store)
zarr_data.create_dataset(
"data/camera0_rgb",
shape=(num_frames, 96, 96, 3),
chunks=(num_frames, 96, 96, 3),
dtype=np.uint8,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_demo_end_pose",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_demo_start_pose",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"data/robot0_eef_rot_axis_angle",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_gripper_width",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
)
zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_eef_rot_axis_angle"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_gripper_width"][:] = np.random.randn(num_frames, 5)
zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])
store.close()
def _mock_download_raw_xarm(raw_dir, num_frames=4):
import pickle
dataset_dict = {
"observations": {
"rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8),
"state": np.random.randn(num_frames, 4),
},
"actions": np.random.randn(num_frames, 3),
"rewards": np.random.randn(num_frames),
"masks": np.random.randn(num_frames),
"dones": np.array([False, True, True, True]),
}
raw_dir.mkdir(parents=True, exist_ok=True)
pkl_path = raw_dir / "buffer.pkl"
with open(pkl_path, "wb") as f:
pickle.dump(dataset_dict, f)
def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3):
import h5py
for ep_idx in range(num_episodes):
raw_dir.mkdir(parents=True, exist_ok=True)
path_h5 = raw_dir / f"episode_{ep_idx}.hdf5"
with h5py.File(str(path_h5), "w") as f:
f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset(
"observations/images/top",
data=np.random.randint(
0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
),
)
def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
from datetime import datetime, timedelta, timezone
import pandas
def write_parquet(key, timestamps, values):
data = {
"timestamp_utc": timestamps,
key: values,
}
df = pandas.DataFrame(data)
raw_dir.mkdir(parents=True, exist_ok=True)
df.to_parquet(raw_dir / f"{key}.parquet", engine="pyarrow")
episode_indices = [None, None, -1, None, None, -1, None, None, -1]
episode_indices_mapping = [0, 0, 0, 1, 1, 1, 2, 2, 2]
frame_indices = [0, 1, -1, 0, 1, -1, 0, 1, -1]
cam_key = "observation.images.cam_high"
timestamps = []
actions = []
states = []
frames = []
# `+ num_episodes`` for buffer frames associated to episode_index=-1
for i, frame_idx in enumerate(frame_indices):
t_utc = datetime.now(timezone.utc) + timedelta(seconds=i / fps)
action = np.random.randn(21).tolist()
state = np.random.randn(21).tolist()
ep_idx = episode_indices_mapping[i]
frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}]
timestamps.append(t_utc)
actions.append(action)
states.append(state)
frames.append(frame)
write_parquet(cam_key, timestamps, frames)
write_parquet("observation.state", timestamps, states)
write_parquet("action", timestamps, actions)
write_parquet("episode_index", timestamps, episode_indices)
# write fake mp4 file for each episode
for ep_idx in range(num_episodes):
imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8)
tmp_imgs_dir = raw_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
fname = f"{cam_key}_episode_{ep_idx:06d}.mp4"
video_path = raw_dir / "videos" / fname
encode_video_frames(tmp_imgs_dir, video_path, fps, vcodec="libx264")
def _mock_download_raw(raw_dir, repo_id):
if "wrist_gripper" in repo_id:
_mock_download_raw_dora(raw_dir)
elif "aloha" in repo_id:
_mock_download_raw_aloha(raw_dir)
elif "pusht" in repo_id:
_mock_download_raw_pusht(raw_dir)
elif "xarm" in repo_id:
_mock_download_raw_xarm(raw_dir)
elif "umi" in repo_id:
_mock_download_raw_umi(raw_dir)
else:
raise ValueError(repo_id)
@pytest.mark.skip("push_dataset_to_hub is deprecated")
def test_push_dataset_to_hub_invalid_repo_id(tmpdir):
with pytest.raises(ValueError):
push_dataset_to_hub(Path(tmpdir), "raw_format", "invalid_repo_id")
@pytest.mark.skip("push_dataset_to_hub is deprecated")
def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
tmpdir = Path(tmpdir)
out_dir = tmpdir / "out"
raw_dir = tmpdir / "raw"
# mkdir to skip download
raw_dir.mkdir(parents=True, exist_ok=True)
with pytest.raises(ValueError):
push_dataset_to_hub(
raw_dir=raw_dir,
raw_format="some_format",
repo_id="user/dataset",
local_dir=out_dir,
force_override=False,
)
@pytest.mark.skip("push_dataset_to_hub is deprecated")
@pytest.mark.parametrize(
"required_packages, raw_format, repo_id, make_test_data",
[
(["gym_pusht"], "pusht_zarr", "lerobot/pusht", False),
(["gym_pusht"], "pusht_zarr", "lerobot/pusht", True),
(None, "xarm_pkl", "lerobot/xarm_lift_medium", False),
(None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted", False),
(["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild", False),
(None, "dora_parquet", "cadene/wrist_gripper", False),
],
)
@require_package_arg
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id, make_test_data):
num_episodes = 3
tmpdir = Path(tmpdir)
raw_dir = tmpdir / f"{repo_id}_raw"
_mock_download_raw(raw_dir, repo_id)
local_dir = tmpdir / repo_id
lerobot_dataset = push_dataset_to_hub(
raw_dir=raw_dir,
raw_format=raw_format,
repo_id=repo_id,
push_to_hub=False,
local_dir=local_dir,
force_override=False,
cache_dir=tmpdir / "cache",
tests_data_dir=tmpdir / "tests/data" if make_test_data else None,
encoding={"vcodec": "libx264"},
)
# minimal generic tests on the local directory containing LeRobotDataset
assert (local_dir / "meta_data" / "info.json").exists()
assert (local_dir / "meta_data" / "stats.safetensors").exists()
assert (local_dir / "meta_data" / "episode_data_index.safetensors").exists()
for i in range(num_episodes):
for cam_key in lerobot_dataset.camera_keys:
assert (local_dir / "videos" / f"{cam_key}_episode_{i:06d}.mp4").exists()
assert (local_dir / "train" / "dataset_info.json").exists()
assert (local_dir / "train" / "state.json").exists()
assert len(list((local_dir / "train").glob("*.arrow"))) > 0
# minimal generic tests on the item
item = lerobot_dataset[0]
assert "index" in item
assert "episode_index" in item
assert "timestamp" in item
for cam_key in lerobot_dataset.camera_keys:
assert cam_key in item
if make_test_data:
# Check that only the first episode is selected.
test_dataset = LeRobotDataset(repo_id=repo_id, root=tmpdir / "tests/data")
num_frames = sum(
i == lerobot_dataset.hf_dataset["episode_index"][0]
for i in lerobot_dataset.hf_dataset["episode_index"]
).item()
assert (
test_dataset.hf_dataset["episode_index"]
== lerobot_dataset.hf_dataset["episode_index"][:num_frames]
)
for k in ["from", "to"]:
assert torch.equal(test_dataset.episode_data_index[k], lerobot_dataset.episode_data_index[k][:1])
@pytest.mark.skip("push_dataset_to_hub is deprecated")
@pytest.mark.parametrize(
"raw_format, repo_id",
[
# TODO(rcadene): add raw dataset test artifacts
("pusht_zarr", "lerobot/pusht"),
("xarm_pkl", "lerobot/xarm_lift_medium"),
("aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
("umi_zarr", "lerobot/umi_cup_in_the_wild"),
("dora_parquet", "cadene/wrist_gripper"),
],
)
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
_, dataset_id = repo_id.split("/")
tmpdir = Path(tmpdir)
raw_dir = tmpdir / f"{dataset_id}_raw"
local_dir = tmpdir / repo_id
push_dataset_to_hub(
raw_dir=raw_dir,
raw_format=raw_format,
repo_id=repo_id,
push_to_hub=False,
local_dir=local_dir,
force_override=False,
cache_dir=tmpdir / "cache",
episodes=[0],
)
ds_actual = LeRobotDataset(repo_id, root=tmpdir)
ds_reference = LeRobotDataset(repo_id)
assert len(ds_reference.hf_dataset) == len(ds_actual.hf_dataset)
def check_same_items(item1, item2):
assert item1.keys() == item2.keys(), "Keys mismatch"
for key in item1:
if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor):
assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}"
else:
assert item1[key] == item2[key], f"Mismatch found in key: {key}"
for i in range(len(ds_reference.hf_dataset)):
item_reference = ds_reference.hf_dataset[i]
item_actual = ds_actual.hf_dataset[i]
check_same_items(item_reference, item_actual)

View File

@ -23,8 +23,6 @@ pytest -sx 'tests/test_robots.py::test_robot[aloha-True]'
```
"""
from pathlib import Path
import pytest
import torch
@ -35,7 +33,7 @@ from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES)
@require_robot
def test_robot(tmpdir, request, robot_type, mock):
def test_robot(tmp_path, request, robot_type, mock):
# TODO(rcadene): measure fps in nightly?
# TODO(rcadene): test logs
# TODO(rcadene): add compatibility with other robots
@ -50,8 +48,7 @@ def test_robot(tmpdir, request, robot_type, mock):
request.getfixturevalue("patch_builtins_input")
# Create an empty calibration directory to trigger manual calibration
tmpdir = Path(tmpdir)
calibration_dir = tmpdir / robot_type
calibration_dir = tmp_path / robot_type
mock_calibration_dir(calibration_dir)
robot_kwargs["calibration_dir"] = calibration_dir