diff --git a/README.md b/README.md index 5125ace5..59929341 100644 --- a/README.md +++ b/README.md @@ -210,7 +210,7 @@ A `LeRobotDataset` is serialised using several widespread file formats for each - videos are stored in mp4 format to save space - metadata are stored in plain json/jsonl files -Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can use the `local_files_only` argument and specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location. +Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location. ### Evaluate a pretrained policy diff --git a/examples/10_use_so100.md b/examples/10_use_so100.md index f7efcb45..b39a0239 100644 --- a/examples/10_use_so100.md +++ b/examples/10_use_so100.md @@ -335,7 +335,7 @@ python lerobot/scripts/control_robot.py \ --control.push_to_hub=true ``` -Note: You can resume recording by adding `--control.resume=true`. Also if you didn't push your dataset yet, add `--control.local_files_only=true`. +Note: You can resume recording by adding `--control.resume=true`. ## H. Visualize a dataset @@ -363,8 +363,6 @@ python lerobot/scripts/control_robot.py \ --control.episode=0 ``` -Note: If you didn't push your dataset yet, add `--control.local_files_only=true`. - ## J. Train a policy To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: @@ -378,8 +376,6 @@ python lerobot/scripts/train.py \ --wandb.enable=true ``` -Note: If you didn't push your dataset yet, add `--control.local_files_only=true`. - Let's explain it: 1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so100_test`. 2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. diff --git a/examples/11_use_lekiwi.md b/examples/11_use_lekiwi.md index f10a9396..a7024cc6 100644 --- a/examples/11_use_lekiwi.md +++ b/examples/11_use_lekiwi.md @@ -391,7 +391,7 @@ python lerobot/scripts/control_robot.py \ --control.push_to_hub=true ``` -Note: You can resume recording by adding `--control.resume=true`. Also if you didn't push your dataset yet, add `--control.local_files_only=true`. +Note: You can resume recording by adding `--control.resume=true`. # H. Visualize a dataset @@ -418,8 +418,6 @@ python lerobot/scripts/control_robot.py \ --control.episode=0 ``` -Note: If you didn't push your dataset yet, add `--control.local_files_only=true`. - ## J. Train a policy To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: @@ -433,8 +431,6 @@ python lerobot/scripts/train.py \ --wandb.enable=true ``` -Note: If you didn't push your dataset yet, add `--control.local_files_only=true`. - Let's explain it: 1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/lekiwi_test`. 2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. diff --git a/examples/11_use_moss.md b/examples/11_use_moss.md index e35ba9b2..2bbfbb18 100644 --- a/examples/11_use_moss.md +++ b/examples/11_use_moss.md @@ -256,7 +256,7 @@ python lerobot/scripts/control_robot.py \ --control.push_to_hub=true ``` -Note: You can resume recording by adding `--control.resume=true`. Also if you didn't push your dataset yet, add `--control.local_files_only=true`. +Note: You can resume recording by adding `--control.resume=true`. ## Visualize a dataset @@ -284,8 +284,6 @@ python lerobot/scripts/control_robot.py \ --control.episode=0 ``` -Note: If you didn't push your dataset yet, add `--control.local_files_only=true`. - ## Train a policy To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: @@ -299,8 +297,6 @@ python lerobot/scripts/train.py \ --wandb.enable=true ``` -Note: If you didn't push your dataset yet, add `--control.local_files_only=true`. - Let's explain it: 1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/moss_test`. 2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. diff --git a/examples/7_get_started_with_real_robot.md b/examples/7_get_started_with_real_robot.md index e57d783a..638b54d3 100644 --- a/examples/7_get_started_with_real_robot.md +++ b/examples/7_get_started_with_real_robot.md @@ -768,7 +768,7 @@ You can use the `record` function from [`lerobot/scripts/control_robot.py`](../l 1. Frames from cameras are saved on disk in threads, and encoded into videos at the end of each episode recording. 2. Video streams from cameras are displayed in window so that you can verify them. 3. Data is stored with [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py) format which is pushed to your Hugging Face page (unless `--control.push_to_hub=false` is provided). -4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again with `--control.resume=true`. You might need to add `--control.local_files_only=true` if your dataset was not uploaded to hugging face hub. Also you will need to manually delete the dataset directory to start recording from scratch. +4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again with `--control.resume=true`. You will need to manually delete the dataset directory if you want to start recording from scratch. 5. Set the flow of data recording using command line arguments: - `--control.warmup_time_s=10` defines the number of seconds before starting data collection. It allows the robot devices to warmup and synchronize (10 seconds by default). - `--control.episode_time_s=60` defines the number of seconds for data recording for each episode (60 seconds by default). @@ -883,8 +883,6 @@ python lerobot/scripts/control_robot.py \ --control.episode=0 ``` -Note: You might need to add `--control.local_files_only=true` if your dataset was not uploaded to hugging face hub. - Your robot should replicate movements similar to those you recorded. For example, check out [this video](https://x.com/RemiCadene/status/1793654950905680090) where we use `replay` on a Aloha robot from [Trossen Robotics](https://www.trossenrobotics.com). ## 4. Train a policy on your data @@ -902,8 +900,6 @@ python lerobot/scripts/train.py \ --wandb.enable=true ``` -Note: You might need to add `--dataset.local_files_only=true` if your dataset was not uploaded to hugging face hub. - Let's explain it: 1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/koch_test`. 2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. diff --git a/examples/port_datasets/pusht_zarr.py b/examples/port_datasets/pusht_zarr.py index 1506f427..eac6f63d 100644 --- a/examples/port_datasets/pusht_zarr.py +++ b/examples/port_datasets/pusht_zarr.py @@ -2,9 +2,10 @@ import shutil from pathlib import Path import numpy as np -import torch +from huggingface_hub import HfApi -from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset +from lerobot.common.constants import HF_LEROBOT_HOME +from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface." @@ -89,9 +90,9 @@ def calculate_coverage(zarr_data): num_frames = len(block_pos) - coverage = np.zeros((num_frames,)) + coverage = np.zeros((num_frames,), dtype=np.float32) # 8 keypoints with 2 coords each - keypoints = np.zeros((num_frames, 16)) + keypoints = np.zeros((num_frames, 16), dtype=np.float32) # Set x, y, theta (in radians) goal_pos_angle = np.array([256, 256, np.pi / 4]) @@ -117,7 +118,7 @@ def calculate_coverage(zarr_data): intersection_area = goal_geom.intersection(block_geom).area goal_area = goal_geom.area coverage[i] = intersection_area / goal_area - keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten()) + keypoints[i] = PushTEnv.get_keypoints(block_shapes).flatten() return coverage, keypoints @@ -134,8 +135,8 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T if mode not in ["video", "image", "keypoints"]: raise ValueError(mode) - if (LEROBOT_HOME / repo_id).exists(): - shutil.rmtree(LEROBOT_HOME / repo_id) + if (HF_LEROBOT_HOME / repo_id).exists(): + shutil.rmtree(HF_LEROBOT_HOME / repo_id) if not raw_dir.exists(): download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw") @@ -148,6 +149,10 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T action = zarr_data["action"][:] image = zarr_data["img"] # (b, h, w, c) + if image.dtype == np.float32 and image.max() == np.float32(255): + # HACK: images are loaded as float32 but they actually encode uint8 data + image = image.astype(np.uint8) + episode_data_index = { "from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])), "to": zarr_data.meta["episode_ends"], @@ -175,28 +180,30 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T for frame_idx in range(num_frames): i = from_idx + frame_idx + idx = i + (frame_idx < num_frames - 1) frame = { - "action": torch.from_numpy(action[i]), + "action": action[i], # Shift reward and success by +1 until the last item of the episode - "next.reward": reward[i + (frame_idx < num_frames - 1)], - "next.success": success[i + (frame_idx < num_frames - 1)], + "next.reward": reward[idx : idx + 1], + "next.success": success[idx : idx + 1], + "task": PUSHT_TASK, } - frame["observation.state"] = torch.from_numpy(agent_pos[i]) + frame["observation.state"] = agent_pos[i] if mode == "keypoints": - frame["observation.environment_state"] = torch.from_numpy(keypoints[i]) + frame["observation.environment_state"] = keypoints[i] else: - frame["observation.image"] = torch.from_numpy(image[i]) + frame["observation.image"] = image[i] dataset.add_frame(frame) - dataset.save_episode(task=PUSHT_TASK) - - dataset.consolidate() + dataset.save_episode() if push_to_hub: dataset.push_to_hub() + hub_api = HfApi() + hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, repo_type="dataset") if __name__ == "__main__": @@ -218,5 +225,5 @@ if __name__ == "__main__": main(raw_dir, repo_id=repo_id, mode=mode) # Uncomment if you want to load the local dataset and explore it - # dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True) + # dataset = LeRobotDataset(repo_id=repo_id) # breakpoint() diff --git a/lerobot/common/constants.py b/lerobot/common/constants.py index 34da4ac0..d0c9845a 100644 --- a/lerobot/common/constants.py +++ b/lerobot/common/constants.py @@ -1,4 +1,9 @@ # keys +import os +from pathlib import Path + +from huggingface_hub.constants import HF_HOME + OBS_ENV = "observation.environment_state" OBS_ROBOT = "observation.state" OBS_IMAGE = "observation.image" @@ -15,3 +20,13 @@ TRAINING_STEP = "training_step.json" OPTIMIZER_STATE = "optimizer_state.safetensors" OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json" SCHEDULER_STATE = "scheduler_state.json" + +# cache dir +default_cache_path = Path(HF_HOME) / "lerobot" +HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser() + +if "LEROBOT_HOME" in os.environ: + raise ValueError( + f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n" + "'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead." + ) diff --git a/lerobot/common/datasets/backward_compatibility.py b/lerobot/common/datasets/backward_compatibility.py new file mode 100644 index 00000000..d1b8926a --- /dev/null +++ b/lerobot/common/datasets/backward_compatibility.py @@ -0,0 +1,54 @@ +import packaging.version + +V2_MESSAGE = """ +The dataset you requested ({repo_id}) is in {version} format. + +We introduced a new format since v2.0 which is not backward compatible with v1.x. +Please, use our conversion script. Modify the following command with your own task description: +``` +python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\ + --repo-id {repo_id} \\ + --single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\ +``` + +A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", "Insert the +peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", "Open the top +cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped +target.", "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the +sweatshirt.", ... + +If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb) +or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose). +""" + +V21_MESSAGE = """ +The dataset you requested ({repo_id}) is in {version} format. +While current version of LeRobot is backward-compatible with it, the version of your dataset still uses global +stats instead of per-episode stats. Update your dataset stats to the new format using this command: +``` +python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py --repo-id={repo_id} +``` + +If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb) +or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose). +""" + +FUTURE_MESSAGE = """ +The dataset you requested ({repo_id}) is only available in {version} format. +As we cannot ensure forward compatibility with it, please update your current version of lerobot. +""" + + +class CompatibilityError(Exception): ... + + +class BackwardCompatibilityError(CompatibilityError): + def __init__(self, repo_id: str, version: packaging.version.Version): + message = V2_MESSAGE.format(repo_id=repo_id, version=version) + super().__init__(message) + + +class ForwardCompatibilityError(CompatibilityError): + def __init__(self, repo_id: str, version: packaging.version.Version): + message = FUTURE_MESSAGE.format(repo_id=repo_id, version=version) + super().__init__(message) diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py index c6211699..a029f892 100644 --- a/lerobot/common/datasets/compute_stats.py +++ b/lerobot/common/datasets/compute_stats.py @@ -13,202 +13,164 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy -from math import ceil +import numpy as np -import einops -import torch -import tqdm +from lerobot.common.datasets.utils import load_image_as_numpy -def get_stats_einops_patterns(dataset, num_workers=0): - """These einops patterns will be used to aggregate batches and compute statistics. +def estimate_num_samples( + dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75 +) -> int: + """Heuristic to estimate the number of samples based on dataset size. + The power controls the sample growth relative to dataset size. + Lower the power for less number of samples. - Note: We assume the images are in channel first format + For default arguments, we have: + - from 1 to ~500, num_samples=100 + - at 1000, num_samples=177 + - at 2000, num_samples=299 + - at 5000, num_samples=594 + - at 10000, num_samples=1000 + - at 20000, num_samples=1681 """ + if dataset_len < min_num_samples: + min_num_samples = dataset_len + return max(min_num_samples, min(int(dataset_len**power), max_num_samples)) - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=num_workers, - batch_size=2, - shuffle=False, - ) - batch = next(iter(dataloader)) - stats_patterns = {} +def sample_indices(data_len: int) -> list[int]: + num_samples = estimate_num_samples(data_len) + return np.round(np.linspace(0, data_len - 1, num_samples)).astype(int).tolist() - for key in dataset.features: - # sanity check that tensors are not float64 - assert batch[key].dtype != torch.float64 - # if isinstance(feats_type, (VideoFrame, Image)): - if key in dataset.meta.camera_keys: - # sanity check that images are channel first - _, c, h, w = batch[key].shape - assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}" +def auto_downsample_height_width(img: np.ndarray, target_size: int = 150, max_size_threshold: int = 300): + _, height, width = img.shape - # sanity check that images are float32 in range [0,1] - assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}" - assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}" - assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}" + if max(width, height) < max_size_threshold: + # no downsampling needed + return img - stats_patterns[key] = "b c h w -> c 1 1" - elif batch[key].ndim == 2: - stats_patterns[key] = "b c -> c " - elif batch[key].ndim == 1: - stats_patterns[key] = "b -> 1" + downsample_factor = int(width / target_size) if width > height else int(height / target_size) + return img[:, ::downsample_factor, ::downsample_factor] + + +def sample_images(image_paths: list[str]) -> np.ndarray: + sampled_indices = sample_indices(len(image_paths)) + + images = None + for i, idx in enumerate(sampled_indices): + path = image_paths[idx] + # we load as uint8 to reduce memory usage + img = load_image_as_numpy(path, dtype=np.uint8, channel_first=True) + img = auto_downsample_height_width(img) + + if images is None: + images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8) + + images[i] = img + + return images + + +def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]: + return { + "min": np.min(array, axis=axis, keepdims=keepdims), + "max": np.max(array, axis=axis, keepdims=keepdims), + "mean": np.mean(array, axis=axis, keepdims=keepdims), + "std": np.std(array, axis=axis, keepdims=keepdims), + "count": np.array([len(array)]), + } + + +def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict: + ep_stats = {} + for key, data in episode_data.items(): + if features[key]["dtype"] == "string": + continue # HACK: we should receive np.arrays of strings + elif features[key]["dtype"] in ["image", "video"]: + ep_ft_array = sample_images(data) # data is a list of image paths + axes_to_reduce = (0, 2, 3) # keep channel dim + keepdims = True else: - raise ValueError(f"{key}, {batch[key].shape}") + ep_ft_array = data # data is alreay a np.ndarray + axes_to_reduce = 0 # compute stats over the first axis + keepdims = data.ndim == 1 # keep as np.array - return stats_patterns + ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims) + + # finally, we normalize and remove batch dim for images + if features[key]["dtype"] in ["image", "video"]: + ep_stats[key] = { + k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items() + } + + return ep_stats -def compute_stats(dataset, batch_size=8, num_workers=8, max_num_samples=None): - """Compute mean/std and min/max statistics of all data keys in a LeRobotDataset.""" - if max_num_samples is None: - max_num_samples = len(dataset) - - # for more info on why we need to set the same number of workers, see `load_from_videos` - stats_patterns = get_stats_einops_patterns(dataset, num_workers) - - # mean and std will be computed incrementally while max and min will track the running value. - mean, std, max, min = {}, {}, {}, {} - for key in stats_patterns: - mean[key] = torch.tensor(0.0).float() - std[key] = torch.tensor(0.0).float() - max[key] = torch.tensor(-float("inf")).float() - min[key] = torch.tensor(float("inf")).float() - - def create_seeded_dataloader(dataset, batch_size, seed): - generator = torch.Generator() - generator.manual_seed(seed) - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=num_workers, - batch_size=batch_size, - shuffle=True, - drop_last=False, - generator=generator, - ) - return dataloader - - # Note: Due to be refactored soon. The point of storing `first_batch` is to make sure we don't get - # surprises when rerunning the sampler. - first_batch = None - running_item_count = 0 # for online mean computation - dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337) - for i, batch in enumerate( - tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max") - ): - this_batch_size = len(batch["index"]) - running_item_count += this_batch_size - if first_batch is None: - first_batch = deepcopy(batch) - for key, pattern in stats_patterns.items(): - batch[key] = batch[key].float() - # Numerically stable update step for mean computation. - batch_mean = einops.reduce(batch[key], pattern, "mean") - # Hint: to update the mean we need x̄ₙ = (Nₙ₋₁x̄ₙ₋₁ + Bₙxₙ) / Nₙ, where the subscript represents - # the update step, N is the running item count, B is this batch size, x̄ is the running mean, - # and x is the current batch mean. Some rearrangement is then required to avoid risking - # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields - # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ - mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count - max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max")) - min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min")) - - if i == ceil(max_num_samples / batch_size) - 1: - break - - first_batch_ = None - running_item_count = 0 # for online std computation - dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337) - for i, batch in enumerate( - tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std") - ): - this_batch_size = len(batch["index"]) - running_item_count += this_batch_size - # Sanity check to make sure the batches are still in the same order as before. - if first_batch_ is None: - first_batch_ = deepcopy(batch) - for key in stats_patterns: - assert torch.equal(first_batch_[key], first_batch[key]) - for key, pattern in stats_patterns.items(): - batch[key] = batch[key].float() - # Numerically stable update step for mean computation (where the mean is over squared - # residuals).See notes in the mean computation loop above. - batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean") - std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count - - if i == ceil(max_num_samples / batch_size) - 1: - break - - for key in stats_patterns: - std[key] = torch.sqrt(std[key]) - - stats = {} - for key in stats_patterns: - stats[key] = { - "mean": mean[key], - "std": std[key], - "max": max[key], - "min": min[key], - } - return stats +def _assert_type_and_shape(stats_list: list[dict[str, dict]]): + for i in range(len(stats_list)): + for fkey in stats_list[i]: + for k, v in stats_list[i][fkey].items(): + if not isinstance(v, np.ndarray): + raise ValueError( + f"Stats must be composed of numpy array, but key '{k}' of feature '{fkey}' is of type '{type(v)}' instead." + ) + if v.ndim == 0: + raise ValueError("Number of dimensions must be at least 1, and is 0 instead.") + if k == "count" and v.shape != (1,): + raise ValueError(f"Shape of 'count' must be (1), but is {v.shape} instead.") + if "image" in fkey and k != "count" and v.shape != (3, 1, 1): + raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.") -def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]: - """Aggregate stats of multiple LeRobot datasets into one set of stats without recomputing from scratch. +def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]: + """Aggregates stats for a single feature.""" + means = np.stack([s["mean"] for s in stats_ft_list]) + variances = np.stack([s["std"] ** 2 for s in stats_ft_list]) + counts = np.stack([s["count"] for s in stats_ft_list]) + total_count = counts.sum(axis=0) - The final stats will have the union of all data keys from each of the datasets. + # Prepare weighted mean by matching number of dimensions + while counts.ndim < means.ndim: + counts = np.expand_dims(counts, axis=-1) - The final stats will have the union of all data keys from each of the datasets. For instance: - - new_max = max(max_dataset_0, max_dataset_1, ...) + # Compute the weighted mean + weighted_means = means * counts + total_mean = weighted_means.sum(axis=0) / total_count + + # Compute the variance using the parallel algorithm + delta_means = means - total_mean + weighted_variances = (variances + delta_means**2) * counts + total_variance = weighted_variances.sum(axis=0) / total_count + + return { + "min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0), + "max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0), + "mean": total_mean, + "std": np.sqrt(total_variance), + "count": total_count, + } + + +def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]: + """Aggregate stats from multiple compute_stats outputs into a single set of stats. + + The final stats will have the union of all data keys from each of the stats dicts. + + For instance: - new_min = min(min_dataset_0, min_dataset_1, ...) - - new_mean = (mean of all data) + - new_max = max(max_dataset_0, max_dataset_1, ...) + - new_mean = (mean of all data, weighted by counts) - new_std = (std of all data) """ - data_keys = set() - for dataset in ls_datasets: - data_keys.update(dataset.meta.stats.keys()) - stats = {k: {} for k in data_keys} - for data_key in data_keys: - for stat_key in ["min", "max"]: - # compute `max(dataset_0["max"], dataset_1["max"], ...)` - stats[data_key][stat_key] = einops.reduce( - torch.stack( - [ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats], - dim=0, - ), - "n ... -> ...", - stat_key, - ) - total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats) - # Compute the "sum" statistic by multiplying each mean by the number of samples in the respective - # dataset, then divide by total_samples to get the overall "mean". - # NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of - # numerical overflow! - stats[data_key]["mean"] = sum( - d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples) - for d in ls_datasets - if data_key in d.meta.stats - ) - # The derivation for standard deviation is a little more involved but is much in the same spirit as - # the computation of the mean. - # Given two sets of data where the statistics are known: - # σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ] - # where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined - # NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of - # numerical overflow! - stats[data_key]["std"] = torch.sqrt( - sum( - ( - d.meta.stats[data_key]["std"] ** 2 - + (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2 - ) - * (d.num_frames / total_samples) - for d in ls_datasets - if data_key in d.meta.stats - ) - ) - return stats + + _assert_type_and_shape(stats_list) + + data_keys = {key for stats in stats_list for key in stats} + aggregated_stats = {key: {} for key in data_keys} + + for key in data_keys: + stats_with_key = [stats[key] for stats in stats_list if key in stats] + aggregated_stats[key] = aggregate_feature_stats(stats_with_key) + + return aggregated_stats diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 95ba76b8..fb1fe6d6 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -83,15 +83,15 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas ) if isinstance(cfg.dataset.repo_id, str): - ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, local_files_only=cfg.dataset.local_files_only) + ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id, revision=cfg.dataset.revision) delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta) dataset = LeRobotDataset( cfg.dataset.repo_id, episodes=cfg.dataset.episodes, delta_timestamps=delta_timestamps, image_transforms=image_transforms, + revision=cfg.dataset.revision, video_backend=cfg.dataset.video_backend, - local_files_only=cfg.dataset.local_files_only, ) else: raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.") diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py index 85dd6830..6fc0ee2f 100644 --- a/lerobot/common/datasets/image_writer.py +++ b/lerobot/common/datasets/image_writer.py @@ -38,22 +38,40 @@ def safe_stop_image_writer(func): return wrapper -def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image: +def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image: # TODO(aliberts): handle 1 channel and 4 for depth images - if image_array.ndim == 3 and image_array.shape[0] in [1, 3]: + if image_array.ndim != 3: + raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.") + + if image_array.shape[0] == 3: # Transpose from pytorch convention (C, H, W) to (H, W, C) image_array = image_array.transpose(1, 2, 0) + + elif image_array.shape[-1] != 3: + raise NotImplementedError( + f"The image has {image_array.shape[-1]} channels, but 3 is required for now." + ) + if image_array.dtype != np.uint8: - # Assume the image is in [0, 1] range for floating-point data - image_array = np.clip(image_array, 0, 1) + if range_check: + max_ = image_array.max().item() + min_ = image_array.min().item() + if max_ > 1.0 or min_ < 0.0: + raise ValueError( + "The image data type is float, which requires values in the range [0.0, 1.0]. " + f"However, the provided range is [{min_}, {max_}]. Please adjust the range or " + "provide a uint8 image with values in the range [0, 255]." + ) + image_array = (image_array * 255).astype(np.uint8) + return PIL.Image.fromarray(image_array) def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path): try: if isinstance(image, np.ndarray): - img = image_array_to_image(image) + img = image_array_to_pil_image(image) elif isinstance(image, PIL.Image.Image): img = image else: diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 9483bf0a..f1eb11a0 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -14,49 +14,54 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import os import shutil -from functools import cached_property from pathlib import Path from typing import Callable import datasets import numpy as np +import packaging.version import PIL.Image import torch import torch.utils -from datasets import load_dataset -from huggingface_hub import create_repo, snapshot_download, upload_folder +from datasets import concatenate_datasets, load_dataset +from huggingface_hub import HfApi, snapshot_download +from huggingface_hub.constants import REPOCARD_NAME -from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats +from lerobot.common.constants import HF_LEROBOT_HOME +from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image from lerobot.common.datasets.utils import ( DEFAULT_FEATURES, DEFAULT_IMAGE_PATH, - EPISODES_PATH, INFO_PATH, - STATS_PATH, TASKS_PATH, append_jsonlines, + backward_compatible_episodes_stats, check_delta_timestamps, check_timestamps_sync, check_version_compatibility, - create_branch, create_empty_dataset_info, create_lerobot_dataset_card, + embed_images, get_delta_indices, get_episode_data_index, get_features_from_robot, get_hf_features_from_features, - get_hub_safe_version, + get_safe_version, hf_transform_to_torch, + is_valid_version, load_episodes, + load_episodes_stats, load_info, load_stats, load_tasks, - serialize_dict, + validate_episode_buffer, + validate_frame, + write_episode, + write_episode_stats, + write_info, write_json, - write_parquet, ) from lerobot.common.datasets.video_utils import ( VideoFrame, @@ -66,9 +71,7 @@ from lerobot.common.datasets.video_utils import ( ) from lerobot.common.robot_devices.robots.utils import Robot -# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md -CODEBASE_VERSION = "v2.0" -LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser() +CODEBASE_VERSION = "v2.1" class LeRobotDatasetMetadata: @@ -76,19 +79,36 @@ class LeRobotDatasetMetadata: self, repo_id: str, root: str | Path | None = None, - local_files_only: bool = False, + revision: str | None = None, + force_cache_sync: bool = False, ): self.repo_id = repo_id - self.root = Path(root) if root is not None else LEROBOT_HOME / repo_id - self.local_files_only = local_files_only + self.revision = revision if revision else CODEBASE_VERSION + self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id - # Load metadata - (self.root / "meta").mkdir(exist_ok=True, parents=True) - self.pull_from_repo(allow_patterns="meta/") + try: + if force_cache_sync: + raise FileNotFoundError + self.load_metadata() + except (FileNotFoundError, NotADirectoryError): + if is_valid_version(self.revision): + self.revision = get_safe_version(self.repo_id, self.revision) + + (self.root / "meta").mkdir(exist_ok=True, parents=True) + self.pull_from_repo(allow_patterns="meta/") + self.load_metadata() + + def load_metadata(self): self.info = load_info(self.root) - self.stats = load_stats(self.root) - self.tasks = load_tasks(self.root) + check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) + self.tasks, self.task_to_task_index = load_tasks(self.root) self.episodes = load_episodes(self.root) + if self._version < packaging.version.parse("v2.1"): + self.stats = load_stats(self.root) + self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes) + else: + self.episodes_stats = load_episodes_stats(self.root) + self.stats = aggregate_stats(list(self.episodes_stats.values())) def pull_from_repo( self, @@ -98,21 +118,16 @@ class LeRobotDatasetMetadata: snapshot_download( self.repo_id, repo_type="dataset", - revision=self._hub_version, + revision=self.revision, local_dir=self.root, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, - local_files_only=self.local_files_only, ) - @cached_property - def _hub_version(self) -> str | None: - return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION) - @property - def _version(self) -> str: + def _version(self) -> packaging.version.Version: """Codebase version used to create this dataset.""" - return self.info["codebase_version"] + return packaging.version.parse(self.info["codebase_version"]) def get_data_file_path(self, ep_index: int) -> Path: ep_chunk = self.get_episode_chunk(ep_index) @@ -202,54 +217,65 @@ class LeRobotDatasetMetadata: """Max number of episodes per chunk.""" return self.info["chunks_size"] - @property - def task_to_task_index(self) -> dict: - return {task: task_idx for task_idx, task in self.tasks.items()} - - def get_task_index(self, task: str) -> int: + def get_task_index(self, task: str) -> int | None: """ Given a task in natural language, returns its task_index if the task already exists in the dataset, - otherwise creates a new task_index. + otherwise return None. """ - task_index = self.task_to_task_index.get(task, None) - return task_index if task_index is not None else self.total_tasks + return self.task_to_task_index.get(task, None) - def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None: + def add_task(self, task: str): + """ + Given a task in natural language, add it to the dictionnary of tasks. + """ + if task in self.task_to_task_index: + raise ValueError(f"The task '{task}' already exists and can't be added twice.") + + task_index = self.info["total_tasks"] + self.task_to_task_index[task] = task_index + self.tasks[task_index] = task + self.info["total_tasks"] += 1 + + task_dict = { + "task_index": task_index, + "task": task, + } + append_jsonlines(task_dict, self.root / TASKS_PATH) + + def save_episode( + self, + episode_index: int, + episode_length: int, + episode_tasks: list[str], + episode_stats: dict[str, dict], + ) -> None: self.info["total_episodes"] += 1 self.info["total_frames"] += episode_length - if task_index not in self.tasks: - self.info["total_tasks"] += 1 - self.tasks[task_index] = task - task_dict = { - "task_index": task_index, - "task": task, - } - append_jsonlines(task_dict, self.root / TASKS_PATH) - chunk = self.get_episode_chunk(episode_index) if chunk >= self.total_chunks: self.info["total_chunks"] += 1 self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} self.info["total_videos"] += len(self.video_keys) - write_json(self.info, self.root / INFO_PATH) + if len(self.video_keys) > 0: + self.update_video_info() + + write_info(self.info, self.root) episode_dict = { "episode_index": episode_index, - "tasks": [task], + "tasks": episode_tasks, "length": episode_length, } - self.episodes.append(episode_dict) - append_jsonlines(episode_dict, self.root / EPISODES_PATH) + self.episodes[episode_index] = episode_dict + write_episode(episode_dict, self.root) - # TODO(aliberts): refactor stats in save_episodes - # image_sampling = int(self.fps / 2) # sample 2 img/s for the stats - # ep_stats = compute_episode_stats(episode_buffer, self.features, episode_length, image_sampling=image_sampling) - # ep_stats = serialize_dict(ep_stats) - # append_jsonlines(ep_stats, self.root / STATS_PATH) + self.episodes_stats[episode_index] = episode_stats + self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats else episode_stats + write_episode_stats(episode_index, episode_stats, self.root) - def write_video_info(self) -> None: + def update_video_info(self) -> None: """ Warning: this function writes info from first episode videos, implicitly assuming that all videos have been encoded the same way. Also, this means it assumes the first episode exists. @@ -259,8 +285,6 @@ class LeRobotDatasetMetadata: video_path = self.root / self.get_video_file_path(ep_index=0, vid_key=key) self.info["features"][key]["info"] = get_video_info(video_path) - write_json(self.info, self.root / INFO_PATH) - def __repr__(self): feature_keys = list(self.features) return ( @@ -286,7 +310,7 @@ class LeRobotDatasetMetadata: """Creates metadata for a LeRobotDataset.""" obj = cls.__new__(cls) obj.repo_id = repo_id - obj.root = Path(root) if root is not None else LEROBOT_HOME / repo_id + obj.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id obj.root.mkdir(parents=True, exist_ok=False) @@ -304,6 +328,7 @@ class LeRobotDatasetMetadata: ) else: # TODO(aliberts, rcadene): implement sanity check for features + features = {**features, **DEFAULT_FEATURES} # check if none of the features contains a "/" in their names, # as this would break the dict flattening in the stats computation, which uses '/' as separator @@ -313,12 +338,13 @@ class LeRobotDatasetMetadata: features = {**features, **DEFAULT_FEATURES} - obj.tasks, obj.stats, obj.episodes = {}, {}, [] + obj.tasks, obj.task_to_task_index = {}, {} + obj.episodes_stats, obj.stats, obj.episodes = {}, {}, {} obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos) if len(obj.video_keys) > 0 and not use_videos: raise ValueError() write_json(obj.info, obj.root / INFO_PATH) - obj.local_files_only = True + obj.revision = None return obj @@ -331,8 +357,9 @@ class LeRobotDataset(torch.utils.data.Dataset): image_transforms: Callable | None = None, delta_timestamps: dict[list[float]] | None = None, tolerance_s: float = 1e-4, + revision: str | None = None, + force_cache_sync: bool = False, download_videos: bool = True, - local_files_only: bool = False, video_backend: str | None = None, ): """ @@ -342,7 +369,7 @@ class LeRobotDataset(torch.utils.data.Dataset): - On your local disk in the 'root' folder. This is typically the case when you recorded your dataset locally and you may or may not have pushed it to the hub yet. Instantiating this class with 'root' will load your dataset directly from disk. This can happen while you're offline (no - internet connection), in that case, use local_files_only=True. + internet connection). - On the Hugging Face Hub at the address https://huggingface.co/datasets/{repo_id} and not on your local disk in the 'root' folder. Instantiating this class with this 'repo_id' will download @@ -424,24 +451,28 @@ class LeRobotDataset(torch.utils.data.Dataset): timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames decoded from video files. It is also used to check that `delta_timestamps` (when provided) are multiples of 1/fps. Defaults to 1e-4. + revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a + commit hash. Defaults to current codebase version tag. + sync_cache_first (bool, optional): Flag to sync and refresh local files first. If True and files + are already present in the local cache, this will be faster. However, files loaded might not + be in sync with the version on the hub, especially if you specified 'revision'. Defaults to + False. download_videos (bool, optional): Flag to download the videos. Note that when set to True but the video files are already present on local disk, they won't be downloaded again. Defaults to True. - local_files_only (bool, optional): Flag to use local files only. If True, no requests to the hub - will be made. Defaults to False. video_backend (str | None, optional): Video backend to use for decoding videos. There is currently a single option which is the pyav decoder used by Torchvision. Defaults to pyav. """ super().__init__() self.repo_id = repo_id - self.root = Path(root) if root else LEROBOT_HOME / repo_id + self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps self.episodes = episodes self.tolerance_s = tolerance_s + self.revision = revision if revision else CODEBASE_VERSION self.video_backend = video_backend if video_backend else "pyav" self.delta_indices = None - self.local_files_only = local_files_only # Unused attributes self.image_writer = None @@ -450,64 +481,86 @@ class LeRobotDataset(torch.utils.data.Dataset): self.root.mkdir(exist_ok=True, parents=True) # Load metadata - self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only) - - # Check version - check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION) + self.meta = LeRobotDatasetMetadata( + self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync + ) + if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"): + episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes] + self.stats = aggregate_stats(episodes_stats) # Load actual data - self.download_episodes(download_videos) - self.hf_dataset = self.load_hf_dataset() + try: + if force_cache_sync: + raise FileNotFoundError + assert all((self.root / fpath).is_file() for fpath in self.get_episodes_file_paths()) + self.hf_dataset = self.load_hf_dataset() + except (AssertionError, FileNotFoundError, NotADirectoryError): + self.revision = get_safe_version(self.repo_id, self.revision) + self.download_episodes(download_videos) + self.hf_dataset = self.load_hf_dataset() + self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) # Check timestamps - check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s) + timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy() + episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy() + ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()} + check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s) # Setup delta_indices if self.delta_timestamps is not None: check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) - # Available stats implies all videos have been encoded and dataset is iterable - self.consolidated = self.meta.stats is not None - def push_to_hub( self, + branch: str | None = None, tags: list | None = None, license: str | None = "apache-2.0", push_videos: bool = True, private: bool = False, + allow_patterns: list[str] | str | None = None, + upload_large_folder: bool = False, **card_kwargs, ) -> None: - if not self.consolidated: - logging.warning( - "You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet. " - "Consolidating first." - ) - self.consolidate() - ignore_patterns = ["images/"] if not push_videos: ignore_patterns.append("videos/") - create_repo( + hub_api = HfApi() + hub_api.create_repo( repo_id=self.repo_id, private=private, repo_type="dataset", exist_ok=True, ) + if branch: + hub_api.create_branch( + repo_id=self.repo_id, + branch=branch, + revision=self.revision, + repo_type="dataset", + exist_ok=True, + ) - upload_folder( - repo_id=self.repo_id, - folder_path=self.root, - repo_type="dataset", - ignore_patterns=ignore_patterns, - ) - card = create_lerobot_dataset_card( - tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs - ) - card.push_to_hub(repo_id=self.repo_id, repo_type="dataset") - create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset") + upload_kwargs = { + "repo_id": self.repo_id, + "folder_path": self.root, + "repo_type": "dataset", + "revision": branch, + "allow_patterns": allow_patterns, + "ignore_patterns": ignore_patterns, + } + if upload_large_folder: + hub_api.upload_large_folder(**upload_kwargs) + else: + hub_api.upload_folder(**upload_kwargs) + + if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch): + card = create_lerobot_dataset_card( + tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs + ) + card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch) def pull_from_repo( self, @@ -517,11 +570,10 @@ class LeRobotDataset(torch.utils.data.Dataset): snapshot_download( self.repo_id, repo_type="dataset", - revision=self.meta._hub_version, + revision=self.revision, local_dir=self.root, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns, - local_files_only=self.local_files_only, ) def download_episodes(self, download_videos: bool = True) -> None: @@ -535,17 +587,23 @@ class LeRobotDataset(torch.utils.data.Dataset): files = None ignore_patterns = None if download_videos else "videos/" if self.episodes is not None: - files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes] - if len(self.meta.video_keys) > 0 and download_videos: - video_files = [ - str(self.meta.get_video_file_path(ep_idx, vid_key)) - for vid_key in self.meta.video_keys - for ep_idx in self.episodes - ] - files += video_files + files = self.get_episodes_file_paths() self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns) + def get_episodes_file_paths(self) -> list[Path]: + episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes)) + fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes] + if len(self.meta.video_keys) > 0: + video_files = [ + str(self.meta.get_video_file_path(ep_idx, vid_key)) + for vid_key in self.meta.video_keys + for ep_idx in episodes + ] + fpaths += video_files + + return fpaths + def load_hf_dataset(self) -> datasets.Dataset: """hf_dataset contains all the observations, states, actions, rewards, etc.""" if self.episodes is None: @@ -557,7 +615,15 @@ class LeRobotDataset(torch.utils.data.Dataset): # TODO(aliberts): hf_dataset.set_format("torch") hf_dataset.set_transform(hf_transform_to_torch) + return hf_dataset + def create_hf_dataset(self) -> datasets.Dataset: + features = get_hf_features_from_features(self.features) + ft_dict = {col: [] for col in features} + hf_dataset = datasets.Dataset.from_dict(ft_dict, features=features, split="train") + + # TODO(aliberts): hf_dataset.set_format("torch") + hf_dataset.set_transform(hf_transform_to_torch) return hf_dataset @property @@ -624,7 +690,7 @@ class LeRobotDataset(torch.utils.data.Dataset): if key not in self.meta.video_keys } - def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict: + def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]: """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault. This probably happens because a memory reference to the video loader is created in @@ -654,8 +720,7 @@ class LeRobotDataset(torch.utils.data.Dataset): query_indices = None if self.delta_indices is not None: - current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx - query_indices, padding = self._get_query_indices(idx, current_ep_idx) + query_indices, padding = self._get_query_indices(idx, ep_idx) query_result = self._query_hf_dataset(query_indices) item = {**item, **padding} for key, val in query_result.items(): @@ -691,10 +756,13 @@ class LeRobotDataset(torch.utils.data.Dataset): def create_episode_buffer(self, episode_index: int | None = None) -> dict: current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index - return { - "size": 0, - **{key: current_ep_idx if key == "episode_index" else [] for key in self.features}, - } + ep_buffer = {} + # size and task are special cases that are not in self.features + ep_buffer["size"] = 0 + ep_buffer["task"] = [] + for key in self.features: + ep_buffer[key] = current_ep_idx if key == "episode_index" else [] + return ep_buffer def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: fpath = DEFAULT_IMAGE_PATH.format( @@ -716,25 +784,35 @@ class LeRobotDataset(torch.utils.data.Dataset): temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method then needs to be called. """ - # TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch, - # check the dtype and shape matches, etc. + # Convert torch to numpy if needed + for name in frame: + if isinstance(frame[name], torch.Tensor): + frame[name] = frame[name].numpy() + + validate_frame(frame, self.features) if self.episode_buffer is None: self.episode_buffer = self.create_episode_buffer() + # Automatically add frame_index and timestamp to episode buffer frame_index = self.episode_buffer["size"] timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps self.episode_buffer["frame_index"].append(frame_index) self.episode_buffer["timestamp"].append(timestamp) + # Add frame features to episode_buffer for key in frame: - if key not in self.features: - raise ValueError(key) + if key == "task": + # Note: we associate the task in natural language to its task index during `save_episode` + self.episode_buffer["task"].append(frame["task"]) + continue - if self.features[key]["dtype"] not in ["image", "video"]: - item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key] - self.episode_buffer[key].append(item) - elif self.features[key]["dtype"] in ["image", "video"]: + if key not in self.features: + raise ValueError( + f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'." + ) + + if self.features[key]["dtype"] in ["image", "video"]: img_path = self._get_image_file_path( episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index ) @@ -742,80 +820,95 @@ class LeRobotDataset(torch.utils.data.Dataset): img_path.parent.mkdir(parents=True, exist_ok=True) self._save_image(frame[key], img_path) self.episode_buffer[key].append(str(img_path)) + else: + self.episode_buffer[key].append(frame[key]) self.episode_buffer["size"] += 1 - def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None: + def save_episode(self, episode_data: dict | None = None) -> None: """ - This will save to disk the current episode in self.episode_buffer. Note that since it affects files on - disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to - the hub. + This will save to disk the current episode in self.episode_buffer. - Use 'encode_videos' if you want to encode videos during the saving of this episode. Otherwise, - you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend - time for video encoding. + Args: + episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will + save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to + None. """ if not episode_data: episode_buffer = self.episode_buffer + validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features) + + # size and task are special cases that won't be added to hf_dataset episode_length = episode_buffer.pop("size") + tasks = episode_buffer.pop("task") + episode_tasks = list(set(tasks)) episode_index = episode_buffer["episode_index"] - if episode_index != self.meta.total_episodes: - # TODO(aliberts): Add option to use existing episode_index - raise NotImplementedError( - "You might have manually provided the episode_buffer with an episode_index that doesn't " - "match the total number of episodes in the dataset. This is not supported for now." - ) - if episode_length == 0: - raise ValueError( - "You must add one or several frames with `add_frame` before calling `add_episode`." - ) + episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length) + episode_buffer["episode_index"] = np.full((episode_length,), episode_index) - task_index = self.meta.get_task_index(task) + # Add new tasks to the tasks dictionnary + for task in episode_tasks: + task_index = self.meta.get_task_index(task) + if task_index is None: + self.meta.add_task(task) - if not set(episode_buffer.keys()) == set(self.features): - raise ValueError() + # Given tasks in natural language, find their corresponding task indices + episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks]) for key, ft in self.features.items(): - if key == "index": - episode_buffer[key] = np.arange( - self.meta.total_frames, self.meta.total_frames + episode_length - ) - elif key == "episode_index": - episode_buffer[key] = np.full((episode_length,), episode_index) - elif key == "task_index": - episode_buffer[key] = np.full((episode_length,), task_index) - elif ft["dtype"] in ["image", "video"]: + # index, episode_index, task_index are already processed above, and image and video + # are processed separately by storing image path and frame info as meta data + if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: continue - elif len(ft["shape"]) == 1 and ft["shape"][0] == 1: - episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"]) - elif len(ft["shape"]) == 1 and ft["shape"][0] > 1: - episode_buffer[key] = np.stack(episode_buffer[key]) - else: - raise ValueError(key) + episode_buffer[key] = np.stack(episode_buffer[key]) self._wait_image_writer() self._save_episode_table(episode_buffer, episode_index) + ep_stats = compute_episode_stats(episode_buffer, self.features) - self.meta.save_episode(episode_index, episode_length, task, task_index) - - if encode_videos and len(self.meta.video_keys) > 0: + if len(self.meta.video_keys) > 0: video_paths = self.encode_episode_videos(episode_index) for key in self.meta.video_keys: episode_buffer[key] = video_paths[key] + # `meta.save_episode` be executed after encoding the videos + self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats) + + ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index]) + ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()} + check_timestamps_sync( + episode_buffer["timestamp"], + episode_buffer["episode_index"], + ep_data_index_np, + self.fps, + self.tolerance_s, + ) + + video_files = list(self.root.rglob("*.mp4")) + assert len(video_files) == self.num_episodes * len(self.meta.video_keys) + + parquet_files = list(self.root.rglob("*.parquet")) + assert len(parquet_files) == self.num_episodes + + # delete images + img_dir = self.root / "images" + if img_dir.is_dir(): + shutil.rmtree(self.root / "images") + if not episode_data: # Reset the buffer self.episode_buffer = self.create_episode_buffer() - self.consolidated = False - def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None: episode_dict = {key: episode_buffer[key] for key in self.hf_features} ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train") + ep_dataset = embed_images(ep_dataset) + self.hf_dataset = concatenate_datasets([self.hf_dataset, ep_dataset]) + self.hf_dataset.set_transform(hf_transform_to_torch) ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index) ep_data_path.parent.mkdir(parents=True, exist_ok=True) - write_parquet(ep_dataset, ep_data_path) + ep_dataset.to_parquet(ep_data_path) def clear_episode_buffer(self) -> None: episode_index = self.episode_buffer["episode_index"] @@ -884,38 +977,6 @@ class LeRobotDataset(torch.utils.data.Dataset): return video_paths - def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None: - self.hf_dataset = self.load_hf_dataset() - self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) - check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s) - - if len(self.meta.video_keys) > 0: - self.encode_videos() - self.meta.write_video_info() - - if not keep_image_files: - img_dir = self.root / "images" - if img_dir.is_dir(): - shutil.rmtree(self.root / "images") - - video_files = list(self.root.rglob("*.mp4")) - assert len(video_files) == self.num_episodes * len(self.meta.video_keys) - - parquet_files = list(self.root.rglob("*.parquet")) - assert len(parquet_files) == self.num_episodes - - if run_compute_stats: - self.stop_image_writer() - # TODO(aliberts): refactor stats in save_episodes - self.meta.stats = compute_stats(self) - serialized_stats = serialize_dict(self.meta.stats) - write_json(serialized_stats, self.root / STATS_PATH) - self.consolidated = True - else: - logging.warning( - "Skipping computation of the dataset statistics, dataset is not fully consolidated." - ) - @classmethod def create( cls, @@ -944,7 +1005,7 @@ class LeRobotDataset(torch.utils.data.Dataset): ) obj.repo_id = obj.meta.repo_id obj.root = obj.meta.root - obj.local_files_only = obj.meta.local_files_only + obj.revision = None obj.tolerance_s = tolerance_s obj.image_writer = None @@ -954,14 +1015,8 @@ class LeRobotDataset(torch.utils.data.Dataset): # TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer obj.episode_buffer = obj.create_episode_buffer() - # This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It - # is used to know when certain operations are need (for instance, computing dataset statistics). In - # order to be able to push the dataset to the hub, it needs to be consolidated first by calling - # self.consolidate(). - obj.consolidated = True - obj.episodes = None - obj.hf_dataset = None + obj.hf_dataset = obj.create_hf_dataset() obj.image_transforms = None obj.delta_timestamps = None obj.delta_indices = None @@ -986,12 +1041,11 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): delta_timestamps: dict[list[float]] | None = None, tolerances_s: dict | None = None, download_videos: bool = True, - local_files_only: bool = False, video_backend: str | None = None, ): super().__init__() self.repo_ids = repo_ids - self.root = Path(root) if root else LEROBOT_HOME + self.root = Path(root) if root else HF_LEROBOT_HOME self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids} # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which # are handled by this class. @@ -1004,7 +1058,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): delta_timestamps=delta_timestamps, tolerance_s=self.tolerances_s[repo_id], download_videos=download_videos, - local_files_only=local_files_only, video_backend=video_backend, ) for repo_id in repo_ids @@ -1032,7 +1085,10 @@ class MultiLeRobotDataset(torch.utils.data.Dataset): self.image_transforms = image_transforms self.delta_timestamps = delta_timestamps - self.stats = aggregate_stats(self._datasets) + # TODO(rcadene, aliberts): We should not perform this aggregation for datasets + # with multiple robots of different ranges. Instead we should have one normalization + # per robot. + self.stats = aggregate_stats([dataset.meta.stats for dataset in self._datasets]) @property def repo_id_to_index(self): diff --git a/lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md b/lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md deleted file mode 100644 index 8fcc8bbe..00000000 --- a/lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md +++ /dev/null @@ -1,56 +0,0 @@ -## Using / Updating `CODEBASE_VERSION` (for maintainers) - -Since our dataset pushed to the hub are decoupled with the evolution of this repo, we ensure compatibility of -the datasets with our code, we use a `CODEBASE_VERSION` (defined in -lerobot/common/datasets/lerobot_dataset.py) variable. - -For instance, [`lerobot/pusht`](https://huggingface.co/datasets/lerobot/pusht) has many versions to maintain backward compatibility between LeRobot codebase versions: -- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0) -- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1) -- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2) -- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3) -- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4) -- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) -- [v1.6](https://huggingface.co/datasets/lerobot/pusht/tree/v1.6) <-- last version -- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main) <-- points to the last version - -Starting with v1.6, every dataset pushed to the hub or saved locally also have this version number in their -`info.json` metadata. - -### Uploading a new dataset -If you are pushing a new dataset, you don't need to worry about any of the instructions below, nor to be -compatible with previous codebase versions. The `push_dataset_to_hub.py` script will automatically tag your -dataset with the current `CODEBASE_VERSION`. - -### Updating an existing dataset -If you want to update an existing dataset, you need to change the `CODEBASE_VERSION` from `lerobot_dataset.py` -before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change -intentionally or not (i.e. something not backward compatible such as modifying the reward functions used, -deleting some frames at the end of an episode, etc.). That way, people running a previous version of the -codebase won't be affected by your change and backward compatibility is maintained. - -However, you will need to update the version of ALL the other datasets so that they have the new -`CODEBASE_VERSION` as a branch in their hugging face dataset repository. Don't worry, there is an easy way -that doesn't require to run `push_dataset_to_hub.py`. You can just "branch-out" from the `main` branch on HF -dataset repo by running this script which corresponds to a `git checkout -b` (so no copy or upload needed): - -```python -from huggingface_hub import HfApi - -from lerobot import available_datasets -from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION - -api = HfApi() - -for repo_id in available_datasets: - dataset_info = api.list_repo_refs(repo_id, repo_type="dataset") - branches = [b.name for b in dataset_info.branches] - if CODEBASE_VERSION in branches: - print(f"{repo_id} already @{CODEBASE_VERSION}, skipping.") - continue - else: - # Now create a branch named after the new version by branching out from "main" - # which is expected to be the preceding version - api.create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION, revision="main") - print(f"{repo_id} successfully updated @{CODEBASE_VERSION}") -``` diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 612bac39..2d90798f 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -13,10 +13,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import importlib.resources import json import logging -import textwrap from collections.abc import Iterator from itertools import accumulate from pathlib import Path @@ -27,14 +27,20 @@ from typing import Any import datasets import jsonlines import numpy as np -import pyarrow.compute as pc +import packaging.version import torch from datasets.table import embed_table_storage from huggingface_hub import DatasetCard, DatasetCardData, HfApi from PIL import Image as PILImage from torchvision import transforms +from lerobot.common.datasets.backward_compatibility import ( + V21_MESSAGE, + BackwardCompatibilityError, + ForwardCompatibilityError, +) from lerobot.common.robot_devices.robots.utils import Robot +from lerobot.common.utils.utils import is_valid_numpy_dtype_string from lerobot.configs.types import DictLike, FeatureType, PolicyFeature DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk @@ -42,6 +48,7 @@ DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk INFO_PATH = "meta/info.json" EPISODES_PATH = "meta/episodes.jsonl" STATS_PATH = "meta/stats.json" +EPISODES_STATS_PATH = "meta/episodes_stats.jsonl" TASKS_PATH = "meta/tasks.jsonl" DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4" @@ -112,17 +119,26 @@ def get_nested_item(obj: DictLike, flattened_key: str, sep: str = "/") -> Any: def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict: - serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()} + serialized_dict = {} + for key, value in flatten_dict(stats).items(): + if isinstance(value, (torch.Tensor, np.ndarray)): + serialized_dict[key] = value.tolist() + elif isinstance(value, np.generic): + serialized_dict[key] = value.item() + elif isinstance(value, (int, float)): + serialized_dict[key] = value + else: + raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.") return unflatten_dict(serialized_dict) -def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None: +def embed_images(dataset: datasets.Dataset) -> datasets.Dataset: # Embed image bytes into the table before saving to parquet format = dataset.format dataset = dataset.with_format("arrow") dataset = dataset.map(embed_table_storage, batched=False) dataset = dataset.with_format(**format) - dataset.to_parquet(fpath) + return dataset def load_json(fpath: Path) -> Any: @@ -153,6 +169,10 @@ def append_jsonlines(data: dict, fpath: Path) -> None: writer.write(data) +def write_info(info: dict, local_dir: Path): + write_json(info, local_dir / INFO_PATH) + + def load_info(local_dir: Path) -> dict: info = load_json(local_dir / INFO_PATH) for ft in info["features"].values(): @@ -160,29 +180,76 @@ def load_info(local_dir: Path) -> dict: return info -def load_stats(local_dir: Path) -> dict: - if not (local_dir / STATS_PATH).exists(): - return None - stats = load_json(local_dir / STATS_PATH) - stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()} +def write_stats(stats: dict, local_dir: Path): + serialized_stats = serialize_dict(stats) + write_json(serialized_stats, local_dir / STATS_PATH) + + +def cast_stats_to_numpy(stats) -> dict[str, dict[str, np.ndarray]]: + stats = {key: np.array(value) for key, value in flatten_dict(stats).items()} return unflatten_dict(stats) -def load_tasks(local_dir: Path) -> dict: +def load_stats(local_dir: Path) -> dict[str, dict[str, np.ndarray]]: + if not (local_dir / STATS_PATH).exists(): + return None + stats = load_json(local_dir / STATS_PATH) + return cast_stats_to_numpy(stats) + + +def write_task(task_index: int, task: dict, local_dir: Path): + task_dict = { + "task_index": task_index, + "task": task, + } + append_jsonlines(task_dict, local_dir / TASKS_PATH) + + +def load_tasks(local_dir: Path) -> tuple[dict, dict]: tasks = load_jsonlines(local_dir / TASKS_PATH) - return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])} + tasks = {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])} + task_to_task_index = {task: task_index for task_index, task in tasks.items()} + return tasks, task_to_task_index + + +def write_episode(episode: dict, local_dir: Path): + append_jsonlines(episode, local_dir / EPISODES_PATH) def load_episodes(local_dir: Path) -> dict: - return load_jsonlines(local_dir / EPISODES_PATH) + episodes = load_jsonlines(local_dir / EPISODES_PATH) + return {item["episode_index"]: item for item in sorted(episodes, key=lambda x: x["episode_index"])} -def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray: +def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path): + # We wrap episode_stats in a dictionnary since `episode_stats["episode_index"]` + # is a dictionary of stats and not an integer. + episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)} + append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH) + + +def load_episodes_stats(local_dir: Path) -> dict: + episodes_stats = load_jsonlines(local_dir / EPISODES_STATS_PATH) + return { + item["episode_index"]: cast_stats_to_numpy(item["stats"]) + for item in sorted(episodes_stats, key=lambda x: x["episode_index"]) + } + + +def backward_compatible_episodes_stats( + stats: dict[str, dict[str, np.ndarray]], episodes: list[int] +) -> dict[str, dict[str, np.ndarray]]: + return {ep_idx: stats for ep_idx in episodes} + + +def load_image_as_numpy( + fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True +) -> np.ndarray: img = PILImage.open(fpath).convert("RGB") img_array = np.array(img, dtype=dtype) if channel_first: # (H, W, C) -> (C, H, W) img_array = np.transpose(img_array, (2, 0, 1)) - if "float" in dtype: + if np.issubdtype(dtype, np.floating): img_array /= 255.0 return img_array @@ -201,77 +268,82 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]): elif first_item is None: pass else: - items_dict[key] = [torch.tensor(x) for x in items_dict[key]] + items_dict[key] = [x if isinstance(x, str) else torch.tensor(x) for x in items_dict[key]] return items_dict -def _get_major_minor(version: str) -> tuple[int]: - split = version.strip("v").split(".") - return int(split[0]), int(split[1]) - - -class BackwardCompatibilityError(Exception): - def __init__(self, repo_id, version): - message = textwrap.dedent(f""" - BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format. - - We introduced a new format since v2.0 which is not backward compatible with v1.x. - Please, use our conversion script. Modify the following command with your own task description: - ``` - python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\ - --repo-id {repo_id} \\ - --single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\ - ``` - - A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.", - "Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.", - "Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.", - "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ... - - If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb) - or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose). - """) - super().__init__(message) +def is_valid_version(version: str) -> bool: + try: + packaging.version.parse(version) + return True + except packaging.version.InvalidVersion: + return False def check_version_compatibility( - repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True + repo_id: str, + version_to_check: str | packaging.version.Version, + current_version: str | packaging.version.Version, + enforce_breaking_major: bool = True, ) -> None: - current_major, _ = _get_major_minor(current_version) - major_to_check, _ = _get_major_minor(version_to_check) - if major_to_check < current_major and enforce_breaking_major: - raise BackwardCompatibilityError(repo_id, version_to_check) - elif float(version_to_check.strip("v")) < float(current_version.strip("v")): - logging.warning( - f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the - codebase. The current codebase version is {current_version}. You should be fine since - backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on - Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""", - ) + v_check = ( + packaging.version.parse(version_to_check) + if not isinstance(version_to_check, packaging.version.Version) + else version_to_check + ) + v_current = ( + packaging.version.parse(current_version) + if not isinstance(current_version, packaging.version.Version) + else current_version + ) + if v_check.major < v_current.major and enforce_breaking_major: + raise BackwardCompatibilityError(repo_id, v_check) + elif v_check.minor < v_current.minor: + logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check)) -def get_hub_safe_version(repo_id: str, version: str) -> str: +def get_repo_versions(repo_id: str) -> list[packaging.version.Version]: + """Returns available valid versions (branches and tags) on given repo.""" api = HfApi() - dataset_info = api.list_repo_refs(repo_id, repo_type="dataset") - branches = [b.name for b in dataset_info.branches] - if version not in branches: - num_version = float(version.strip("v")) - hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")] - if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions): - raise BackwardCompatibilityError(repo_id, version) + repo_refs = api.list_repo_refs(repo_id, repo_type="dataset") + repo_refs = [b.name for b in repo_refs.branches + repo_refs.tags] + repo_versions = [] + for ref in repo_refs: + with contextlib.suppress(packaging.version.InvalidVersion): + repo_versions.append(packaging.version.parse(ref)) - logging.warning( - f"""You are trying to load a dataset from {repo_id} created with a previous version of the - codebase. The following versions are available: {branches}. - The requested version ('{version}') is not found. You should be fine since - backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on - Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""", - ) - if "main" not in branches: - raise ValueError(f"Version 'main' not found on {repo_id}") - return "main" - else: - return version + return repo_versions + + +def get_safe_version(repo_id: str, version: str | packaging.version.Version) -> str: + """ + Returns the version if available on repo or the latest compatible one. + Otherwise, will throw a `CompatibilityError`. + """ + target_version = ( + packaging.version.parse(version) if not isinstance(version, packaging.version.Version) else version + ) + hub_versions = get_repo_versions(repo_id) + + if target_version in hub_versions: + return f"v{target_version}" + + compatibles = [ + v for v in hub_versions if v.major == target_version.major and v.minor <= target_version.minor + ] + if compatibles: + return_version = max(compatibles) + if return_version < target_version: + logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}") + return f"v{return_version}" + + lower_major = [v for v in hub_versions if v.major < target_version.major] + if lower_major: + raise BackwardCompatibilityError(repo_id, max(lower_major)) + + upper_versions = [v for v in hub_versions if v > target_version] + assert len(upper_versions) > 0 + raise ForwardCompatibilityError(repo_id, min(upper_versions)) def get_hf_features_from_features(features: dict) -> datasets.Features: @@ -283,11 +355,20 @@ def get_hf_features_from_features(features: dict) -> datasets.Features: hf_features[key] = datasets.Image() elif ft["shape"] == (1,): hf_features[key] = datasets.Value(dtype=ft["dtype"]) - else: - assert len(ft["shape"]) == 1 + elif len(ft["shape"]) == 1: hf_features[key] = datasets.Sequence( length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"]) ) + elif len(ft["shape"]) == 2: + hf_features[key] = datasets.Array2D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 3: + hf_features[key] = datasets.Array3D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 4: + hf_features[key] = datasets.Array4D(shape=ft["shape"], dtype=ft["dtype"]) + elif len(ft["shape"]) == 5: + hf_features[key] = datasets.Array5D(shape=ft["shape"], dtype=ft["dtype"]) + else: + raise ValueError(f"Corresponding feature is not valid: {ft}") return datasets.Features(hf_features) @@ -358,9 +439,9 @@ def create_empty_dataset_info( def get_episode_data_index( - episode_dicts: list[dict], episodes: list[int] | None = None + episode_dicts: dict[dict], episodes: list[int] | None = None ) -> dict[str, torch.Tensor]: - episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)} + episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in episode_dicts.items()} if episodes is not None: episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes} @@ -371,75 +452,72 @@ def get_episode_data_index( } -def calculate_total_episode( - hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True -) -> dict[str, torch.Tensor]: - episode_indices = sorted(hf_dataset.unique("episode_index")) - total_episodes = len(episode_indices) - if raise_if_not_contiguous and episode_indices != list(range(total_episodes)): - raise ValueError("episode_index values are not sorted and contiguous.") - return total_episodes - - -def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]: - episode_lengths = [] - table = hf_dataset.data.table - total_episodes = calculate_total_episode(hf_dataset) - for ep_idx in range(total_episodes): - ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) - episode_lengths.insert(ep_idx, len(ep_table)) - - cumulative_lenghts = list(accumulate(episode_lengths)) - return { - "from": torch.LongTensor([0] + cumulative_lenghts[:-1]), - "to": torch.LongTensor(cumulative_lenghts), - } - - def check_timestamps_sync( - hf_dataset: datasets.Dataset, - episode_data_index: dict[str, torch.Tensor], + timestamps: np.ndarray, + episode_indices: np.ndarray, + episode_data_index: dict[str, np.ndarray], fps: int, tolerance_s: float, raise_value_error: bool = True, ) -> bool: """ - This check is to make sure that each timestamps is separated to the next by 1/fps +/- tolerance to - account for possible numerical error. - """ - timestamps = torch.stack(hf_dataset["timestamp"]) - diffs = torch.diff(timestamps) - within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s + This check is to make sure that each timestamp is separated from the next by (1/fps) +/- tolerance + to account for possible numerical error. - # We mask differences between the timestamp at the end of an episode - # and the one at the start of the next episode since these are expected - # to be outside tolerance. - mask = torch.ones(len(diffs), dtype=torch.bool) - ignored_diffs = episode_data_index["to"][:-1] - 1 + Args: + timestamps (np.ndarray): Array of timestamps in seconds. + episode_indices (np.ndarray): Array indicating the episode index for each timestamp. + episode_data_index (dict[str, np.ndarray]): A dictionary that includes 'to', + which identifies indices for the end of each episode. + fps (int): Frames per second. Used to check the expected difference between consecutive timestamps. + tolerance_s (float): Allowed deviation from the expected (1/fps) difference. + raise_value_error (bool): Whether to raise a ValueError if the check fails. + + Returns: + bool: True if all checked timestamp differences lie within tolerance, False otherwise. + + Raises: + ValueError: If the check fails and `raise_value_error` is True. + """ + if timestamps.shape != episode_indices.shape: + raise ValueError( + "timestamps and episode_indices should have the same shape. " + f"Found {timestamps.shape=} and {episode_indices.shape=}." + ) + + # Consecutive differences + diffs = np.diff(timestamps) + within_tolerance = np.abs(diffs - (1.0 / fps)) <= tolerance_s + + # Mask to ignore differences at the boundaries between episodes + mask = np.ones(len(diffs), dtype=bool) + ignored_diffs = episode_data_index["to"][:-1] - 1 # indices at the end of each episode mask[ignored_diffs] = False filtered_within_tolerance = within_tolerance[mask] - if not torch.all(filtered_within_tolerance): + # Check if all remaining diffs are within tolerance + if not np.all(filtered_within_tolerance): # Track original indices before masking - original_indices = torch.arange(len(diffs)) + original_indices = np.arange(len(diffs)) filtered_indices = original_indices[mask] - outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance) # .squeeze() + outside_tolerance_filtered_indices = np.nonzero(~filtered_within_tolerance)[0] outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices] - episode_indices = torch.stack(hf_dataset["episode_index"]) outside_tolerances = [] for idx in outside_tolerance_indices: entry = { "timestamps": [timestamps[idx], timestamps[idx + 1]], "diff": diffs[idx], - "episode_index": episode_indices[idx].item(), + "episode_index": episode_indices[idx].item() + if hasattr(episode_indices[idx], "item") + else episode_indices[idx], } outside_tolerances.append(entry) if raise_value_error: raise ValueError( f"""One or several timestamps unexpectedly violate the tolerance inside episode range. - This might be due to synchronization issues with timestamps during data collection. + This might be due to synchronization issues during data collection. \n{pformat(outside_tolerances)}""" ) return False @@ -604,3 +682,118 @@ class IterableNamespace(SimpleNamespace): def keys(self): return vars(self).keys() + + +def validate_frame(frame: dict, features: dict): + optional_features = {"timestamp"} + expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"} + actual_features = set(frame.keys()) + + error_message = validate_features_presence(actual_features, expected_features, optional_features) + + if "task" in frame: + error_message += validate_feature_string("task", frame["task"]) + + common_features = actual_features & (expected_features | optional_features) + for name in common_features - {"task"}: + error_message += validate_feature_dtype_and_shape(name, features[name], frame[name]) + + if error_message: + raise ValueError(error_message) + + +def validate_features_presence( + actual_features: set[str], expected_features: set[str], optional_features: set[str] +): + error_message = "" + missing_features = expected_features - actual_features + extra_features = actual_features - (expected_features | optional_features) + + if missing_features or extra_features: + error_message += "Feature mismatch in `frame` dictionary:\n" + if missing_features: + error_message += f"Missing features: {missing_features}\n" + if extra_features: + error_message += f"Extra features: {extra_features}\n" + + return error_message + + +def validate_feature_dtype_and_shape(name: str, feature: dict, value: np.ndarray | PILImage.Image | str): + expected_dtype = feature["dtype"] + expected_shape = feature["shape"] + if is_valid_numpy_dtype_string(expected_dtype): + return validate_feature_numpy_array(name, expected_dtype, expected_shape, value) + elif expected_dtype in ["image", "video"]: + return validate_feature_image_or_video(name, expected_shape, value) + elif expected_dtype == "string": + return validate_feature_string(name, value) + else: + raise NotImplementedError(f"The feature dtype '{expected_dtype}' is not implemented yet.") + + +def validate_feature_numpy_array( + name: str, expected_dtype: str, expected_shape: list[int], value: np.ndarray +): + error_message = "" + if isinstance(value, np.ndarray): + actual_dtype = value.dtype + actual_shape = value.shape + + if actual_dtype != np.dtype(expected_dtype): + error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n" + + if actual_shape != expected_shape: + error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n" + else: + error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n" + + return error_message + + +def validate_feature_image_or_video(name: str, expected_shape: list[str], value: np.ndarray | PILImage.Image): + # Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads. + error_message = "" + if isinstance(value, np.ndarray): + actual_shape = value.shape + c, h, w = expected_shape + if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)): + error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n" + elif isinstance(value, PILImage.Image): + pass + else: + error_message += f"The feature '{name}' is expected to be of type 'PIL.Image' or 'np.ndarray' channel first or channel last, but type '{type(value)}' provided instead.\n" + + return error_message + + +def validate_feature_string(name: str, value: str): + if not isinstance(value, str): + return f"The feature '{name}' is expected to be of type 'str', but type '{type(value)}' provided instead.\n" + return "" + + +def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: dict): + if "size" not in episode_buffer: + raise ValueError("size key not found in episode_buffer") + + if "task" not in episode_buffer: + raise ValueError("task key not found in episode_buffer") + + if episode_buffer["episode_index"] != total_episodes: + # TODO(aliberts): Add option to use existing episode_index + raise NotImplementedError( + "You might have manually provided the episode_buffer with an episode_index that doesn't " + "match the total number of episodes already in the dataset. This is not supported for now." + ) + + if episode_buffer["size"] == 0: + raise ValueError("You must add one or several frames with `add_frame` before calling `add_episode`.") + + buffer_keys = set(episode_buffer.keys()) - {"task", "size"} + if not buffer_keys == set(features): + raise ValueError( + f"Features from `episode_buffer` don't match the ones in `features`." + f"In episode_buffer not in features: {buffer_keys - set(features)}" + f"In features not in episode_buffer: {set(features) - buffer_keys}" + ) diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py index 62ca9932..99864e3b 100644 --- a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py +++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py @@ -130,7 +130,7 @@ from lerobot.common.datasets.utils import ( create_branch, create_lerobot_dataset_card, flatten_dict, - get_hub_safe_version, + get_safe_version, load_json, unflatten_dict, write_json, @@ -443,7 +443,7 @@ def convert_dataset( test_branch: str | None = None, **card_kwargs, ): - v1 = get_hub_safe_version(repo_id, V16) + v1 = get_safe_version(repo_id, V16) v1x_dir = local_dir / V16 / repo_id v20_dir = local_dir / V20 / repo_id v1x_dir.mkdir(parents=True, exist_ok=True) diff --git a/lerobot/common/datasets/v21/_remove_language_instruction.py b/lerobot/common/datasets/v21/_remove_language_instruction.py new file mode 100644 index 00000000..dd4604cf --- /dev/null +++ b/lerobot/common/datasets/v21/_remove_language_instruction.py @@ -0,0 +1,73 @@ +import logging +import traceback +from pathlib import Path + +from datasets import get_dataset_config_info +from huggingface_hub import HfApi + +from lerobot import available_datasets +from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.common.datasets.utils import INFO_PATH, write_info +from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V20, SuppressWarnings + +LOCAL_DIR = Path("data/") + +hub_api = HfApi() + + +def fix_dataset(repo_id: str) -> str: + if not hub_api.revision_exists(repo_id, V20, repo_type="dataset"): + return f"{repo_id}: skipped (not in {V20})." + + dataset_info = get_dataset_config_info(repo_id, "default") + with SuppressWarnings(): + lerobot_metadata = LeRobotDatasetMetadata(repo_id, revision=V20, force_cache_sync=True) + + meta_features = {key for key, ft in lerobot_metadata.features.items() if ft["dtype"] != "video"} + parquet_features = set(dataset_info.features) + + diff_parquet_meta = parquet_features - meta_features + diff_meta_parquet = meta_features - parquet_features + + if diff_parquet_meta: + raise ValueError(f"In parquet not in info.json: {parquet_features - meta_features}") + + if not diff_meta_parquet: + return f"{repo_id}: skipped (no diff)" + + if diff_meta_parquet: + logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}") + assert diff_meta_parquet == {"language_instruction"} + lerobot_metadata.features.pop("language_instruction") + write_info(lerobot_metadata.info, lerobot_metadata.root) + commit_info = hub_api.upload_file( + path_or_fileobj=lerobot_metadata.root / INFO_PATH, + path_in_repo=INFO_PATH, + repo_id=repo_id, + repo_type="dataset", + revision=V20, + commit_message="Remove 'language_instruction'", + create_pr=True, + ) + return f"{repo_id}: success - PR: {commit_info.pr_url}" + + +def batch_fix(): + status = {} + LOCAL_DIR.mkdir(parents=True, exist_ok=True) + logfile = LOCAL_DIR / "fix_features_v20.txt" + for num, repo_id in enumerate(available_datasets): + print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})") + print("---------------------------------------------------------") + try: + status = fix_dataset(repo_id) + except Exception: + status = f"{repo_id}: failed\n {traceback.format_exc()}" + + logging.info(status) + with open(logfile, "a") as file: + file.write(status + "\n") + + +if __name__ == "__main__": + batch_fix() diff --git a/lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py new file mode 100644 index 00000000..cc9272a8 --- /dev/null +++ b/lerobot/common/datasets/v21/batch_convert_dataset_v20_to_v21.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.1. +""" + +import traceback +from pathlib import Path + +from huggingface_hub import HfApi + +from lerobot import available_datasets +from lerobot.common.datasets.v21.convert_dataset_v20_to_v21 import V21, convert_dataset + +LOCAL_DIR = Path("data/") + + +def batch_convert(): + status = {} + LOCAL_DIR.mkdir(parents=True, exist_ok=True) + logfile = LOCAL_DIR / "conversion_log_v21.txt" + hub_api = HfApi() + for num, repo_id in enumerate(available_datasets): + print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})") + print("---------------------------------------------------------") + try: + if hub_api.revision_exists(repo_id, V21, repo_type="dataset"): + status = f"{repo_id}: success (already in {V21})." + else: + convert_dataset(repo_id) + status = f"{repo_id}: success." + except Exception: + status = f"{repo_id}: failed\n {traceback.format_exc()}" + + with open(logfile, "a") as file: + file.write(status + "\n") + + +if __name__ == "__main__": + batch_convert() diff --git a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py new file mode 100644 index 00000000..20bda75b --- /dev/null +++ b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py @@ -0,0 +1,100 @@ +""" +This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 2.0 to +2.1. It will: + +- Generate per-episodes stats and writes them in `episodes_stats.jsonl` +- Check consistency between these new stats and the old ones. +- Remove the deprecated `stats.json`. +- Update codebase_version in `info.json`. +- Push this new version to the hub on the 'main' branch and tags it with "v2.1". + +Usage: + +```bash +python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \ + --repo-id=aliberts/koch_tutorial +``` + +""" + +import argparse +import logging + +from huggingface_hub import HfApi + +from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset +from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info +from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats + +V20 = "v2.0" +V21 = "v2.1" + + +class SuppressWarnings: + def __enter__(self): + self.previous_level = logging.getLogger().getEffectiveLevel() + logging.getLogger().setLevel(logging.ERROR) + + def __exit__(self, exc_type, exc_val, exc_tb): + logging.getLogger().setLevel(self.previous_level) + + +def convert_dataset( + repo_id: str, + branch: str | None = None, + num_workers: int = 4, +): + with SuppressWarnings(): + dataset = LeRobotDataset(repo_id, revision=V20, force_cache_sync=True) + + if (dataset.root / EPISODES_STATS_PATH).is_file(): + (dataset.root / EPISODES_STATS_PATH).unlink() + + convert_stats(dataset, num_workers=num_workers) + ref_stats = load_stats(dataset.root) + check_aggregate_stats(dataset, ref_stats) + + dataset.meta.info["codebase_version"] = CODEBASE_VERSION + write_info(dataset.meta.info, dataset.root) + + dataset.push_to_hub(branch=branch, allow_patterns="meta/") + + # delete old stats.json file + if (dataset.root / STATS_PATH).is_file: + (dataset.root / STATS_PATH).unlink() + + hub_api = HfApi() + if hub_api.file_exists( + repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset" + ): + hub_api.delete_file( + path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset" + ) + + hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--repo-id", + type=str, + required=True, + help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset " + "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).", + ) + parser.add_argument( + "--branch", + type=str, + default=None, + help="Repo branch to push your dataset. Defaults to the main branch.", + ) + parser.add_argument( + "--num-workers", + type=int, + default=4, + help="Number of workers for parallelizing stats compute. Defaults to 4.", + ) + + args = parser.parse_args() + convert_dataset(**vars(args)) diff --git a/lerobot/common/datasets/v21/convert_stats.py b/lerobot/common/datasets/v21/convert_stats.py new file mode 100644 index 00000000..cbf584b7 --- /dev/null +++ b/lerobot/common/datasets/v21/convert_stats.py @@ -0,0 +1,85 @@ +from concurrent.futures import ThreadPoolExecutor, as_completed + +import numpy as np +from tqdm import tqdm + +from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.utils import write_episode_stats + + +def sample_episode_video_frames(dataset: LeRobotDataset, episode_index: int, ft_key: str) -> np.ndarray: + ep_len = dataset.meta.episodes[episode_index]["length"] + sampled_indices = sample_indices(ep_len) + query_timestamps = dataset._get_query_timestamps(0.0, {ft_key: sampled_indices}) + video_frames = dataset._query_videos(query_timestamps, episode_index) + return video_frames[ft_key].numpy() + + +def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int): + ep_start_idx = dataset.episode_data_index["from"][ep_idx] + ep_end_idx = dataset.episode_data_index["to"][ep_idx] + ep_data = dataset.hf_dataset.select(range(ep_start_idx, ep_end_idx)) + + ep_stats = {} + for key, ft in dataset.features.items(): + if ft["dtype"] == "video": + # We sample only for videos + ep_ft_data = sample_episode_video_frames(dataset, ep_idx, key) + else: + ep_ft_data = np.array(ep_data[key]) + + axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0 + keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1 + ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims) + + if ft["dtype"] in ["image", "video"]: # remove batch dim + ep_stats[key] = { + k: v if k == "count" else np.squeeze(v, axis=0) for k, v in ep_stats[key].items() + } + + dataset.meta.episodes_stats[ep_idx] = ep_stats + + +def convert_stats(dataset: LeRobotDataset, num_workers: int = 0): + assert dataset.episodes is None + print("Computing episodes stats") + total_episodes = dataset.meta.total_episodes + if num_workers > 0: + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = { + executor.submit(convert_episode_stats, dataset, ep_idx): ep_idx + for ep_idx in range(total_episodes) + } + for future in tqdm(as_completed(futures), total=total_episodes): + future.result() + else: + for ep_idx in tqdm(range(total_episodes)): + convert_episode_stats(dataset, ep_idx) + + for ep_idx in tqdm(range(total_episodes)): + write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root) + + +def check_aggregate_stats( + dataset: LeRobotDataset, + reference_stats: dict[str, dict[str, np.ndarray]], + video_rtol_atol: tuple[float] = (1e-2, 1e-2), + default_rtol_atol: tuple[float] = (5e-6, 6e-5), +): + """Verifies that the aggregated stats from episodes_stats are close to reference stats.""" + agg_stats = aggregate_stats(list(dataset.meta.episodes_stats.values())) + for key, ft in dataset.features.items(): + # These values might need some fine-tuning + if ft["dtype"] == "video": + # to account for image sub-sampling + rtol, atol = video_rtol_atol + else: + rtol, atol = default_rtol_atol + + for stat, val in agg_stats[key].items(): + if key in reference_stats and stat in reference_stats[key]: + err_msg = f"feature='{key}' stats='{stat}'" + np.testing.assert_allclose( + val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg + ) diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 8ed3318d..8be53483 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -69,8 +69,8 @@ def decode_video_frames_torchvision( # set the first and last requested timestamps # Note: previous timestamps are usually loaded, since we need to access the previous key frame - first_ts = timestamps[0] - last_ts = timestamps[-1] + first_ts = min(timestamps) + last_ts = max(timestamps) # access closest key frame of the first requested frame # Note: closest key frame timestamp is usally smaller than `first_ts` (e.g. key frame can be the first frame of the video) diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index 95219273..b3255ec1 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import torch from torch import Tensor, nn @@ -77,17 +78,29 @@ def create_stats_buffers( } ) + # TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch) if stats: - # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated - # tensors anywhere (for example, when we use the same stats for normalization and - # unnormalization). See the logic here - # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. - if norm_mode is NormalizationMode.MEAN_STD: - buffer["mean"].data = stats[key]["mean"].clone() - buffer["std"].data = stats[key]["std"].clone() - elif norm_mode is NormalizationMode.MIN_MAX: - buffer["min"].data = stats[key]["min"].clone() - buffer["max"].data = stats[key]["max"].clone() + if isinstance(stats[key]["mean"], np.ndarray): + if norm_mode is NormalizationMode.MEAN_STD: + buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32) + buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32) + elif norm_mode is NormalizationMode.MIN_MAX: + buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32) + buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32) + elif isinstance(stats[key]["mean"], torch.Tensor): + # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated + # tensors anywhere (for example, when we use the same stats for normalization and + # unnormalization). See the logic here + # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. + if norm_mode is NormalizationMode.MEAN_STD: + buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32) + buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32) + elif norm_mode is NormalizationMode.MIN_MAX: + buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32) + buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32) + else: + type_ = type(stats[key]["mean"]) + raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.") stats_buffers[key] = buffer return stats_buffers @@ -141,6 +154,7 @@ class Normalize(nn.Module): batch = dict(batch) # shallow copy avoids mutating the input batch for key, ft in self.features.items(): if key not in batch: + # FIXME(aliberts, rcadene): This might lead to silent fail! continue norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) diff --git a/lerobot/common/robot_devices/control_configs.py b/lerobot/common/robot_devices/control_configs.py index 6dae8cb6..c3e920b1 100644 --- a/lerobot/common/robot_devices/control_configs.py +++ b/lerobot/common/robot_devices/control_configs.py @@ -60,8 +60,6 @@ class RecordControlConfig(ControlConfig): num_episodes: int = 50 # Encode frames in the dataset into video video: bool = True - # By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode. - run_compute_stats: bool = True # Upload dataset to Hugging Face hub. push_to_hub: bool = True # Upload on private repository on the Hugging Face hub. @@ -83,9 +81,6 @@ class RecordControlConfig(ControlConfig): play_sounds: bool = True # Resume recording on an existing dataset. resume: bool = False - # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument - # Use local files only. By default, this script will try to fetch the dataset from the hub if it exists. - local_files_only: bool = False def __post_init__(self): # HACK: We parse again the cli args here to get the pretrained path if there was one. @@ -130,9 +125,6 @@ class ReplayControlConfig(ControlConfig): fps: int | None = None # Use vocal synthesis to read events. play_sounds: bool = True - # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument - # Use local files only. By default, this script will try to fetch the dataset from the hub if it exists. - local_files_only: bool = False @ControlConfig.register_subclass("remote_robot") diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 7264f078..6c97d0cb 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -183,6 +183,7 @@ def record_episode( device, use_amp, fps, + single_task, ): control_loop( robot=robot, @@ -195,6 +196,7 @@ def record_episode( use_amp=use_amp, fps=fps, teleoperate=policy is None, + single_task=single_task, ) @@ -210,6 +212,7 @@ def control_loop( device: torch.device | str | None = None, use_amp: bool | None = None, fps: int | None = None, + single_task: str | None = None, ): # TODO(rcadene): Add option to record logs if not robot.is_connected: @@ -224,6 +227,9 @@ def control_loop( if teleoperate and policy is not None: raise ValueError("When `teleoperate` is True, `policy` should be None.") + if dataset is not None and single_task is None: + raise ValueError("You need to provide a task as argument in `single_task`.") + if dataset is not None and fps is not None and dataset.fps != fps: raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).") @@ -248,7 +254,7 @@ def control_loop( action = {"action": action} if dataset is not None: - frame = {**observation, **action} + frame = {**observation, **action, "task": single_task} dataset.add_frame(frame) if display_cameras and not is_headless(): diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index 015d1ede..d0c12b30 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -21,6 +21,7 @@ from copy import copy from datetime import datetime, timezone from pathlib import Path +import numpy as np import torch @@ -200,5 +201,18 @@ def get_channel_first_image_shape(image_shape: tuple) -> tuple: return shape -def has_method(cls: object, method_name: str): +def has_method(cls: object, method_name: str) -> bool: return hasattr(cls, method_name) and callable(getattr(cls, method_name)) + + +def is_valid_numpy_dtype_string(dtype_str: str) -> bool: + """ + Return True if a given string can be converted to a numpy dtype. + """ + try: + # Attempt to convert the string to a numpy dtype + np.dtype(dtype_str) + return True + except TypeError: + # If a TypeError is raised, the string is not a valid dtype + return False diff --git a/lerobot/configs/default.py b/lerobot/configs/default.py index 5dd2f898..a5013431 100644 --- a/lerobot/configs/default.py +++ b/lerobot/configs/default.py @@ -31,7 +31,7 @@ class DatasetConfig: repo_id: str episodes: list[int] | None = None image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig) - local_files_only: bool = False + revision: str | None = None use_imagenet_stats: bool = True video_backend: str = "pyav" diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 9129c9e3..32f3b181 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -92,7 +92,6 @@ python lerobot/scripts/control_robot.py \ This might require a sudo permission to allow your terminal to monitor keyboard events. **NOTE**: You can resume/continue data recording by running the same data recording command and adding `--control.resume=true`. -If the dataset you want to extend is not on the hub, you also need to add `--control.local_files_only=true`. - Train on this dataset with the ACT policy: ```bash @@ -234,7 +233,6 @@ def record( dataset = LeRobotDataset( cfg.repo_id, root=cfg.root, - local_files_only=cfg.local_files_only, ) if len(robot.cameras) > 0: dataset.start_image_writer( @@ -281,8 +279,8 @@ def record( log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds) record_episode( - dataset=dataset, robot=robot, + dataset=dataset, events=events, episode_time_s=cfg.episode_time_s, display_cameras=cfg.display_cameras, @@ -290,6 +288,7 @@ def record( device=cfg.device, use_amp=cfg.use_amp, fps=cfg.fps, + single_task=cfg.single_task, ) # Execute a few seconds without recording to give time to manually reset the environment @@ -309,7 +308,7 @@ def record( dataset.clear_episode_buffer() continue - dataset.save_episode(cfg.single_task) + dataset.save_episode() recorded_episodes += 1 if events["stop_recording"]: @@ -318,11 +317,6 @@ def record( log_say("Stop recording", cfg.play_sounds, blocking=True) stop_recording(robot, listener, cfg.display_cameras) - if cfg.run_compute_stats: - logging.info("Computing dataset statistics") - - dataset.consolidate(cfg.run_compute_stats) - if cfg.push_to_hub: dataset.push_to_hub(tags=cfg.tags, private=cfg.private) @@ -338,9 +332,7 @@ def replay( # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset # TODO(rcadene): Add option to record logs - dataset = LeRobotDataset( - cfg.repo_id, root=cfg.root, episodes=[cfg.episode], local_files_only=cfg.local_files_only - ) + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode]) actions = dataset.hf_dataset.select_columns("action") if not robot.is_connected: diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 626b0bde..11feb1af 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -207,12 +207,6 @@ def main(): required=True, help="Episode to visualize.", ) - parser.add_argument( - "--local-files-only", - type=int, - default=0, - help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.", - ) parser.add_argument( "--root", type=Path, @@ -275,10 +269,9 @@ def main(): kwargs = vars(args) repo_id = kwargs.pop("repo_id") root = kwargs.pop("root") - local_files_only = kwargs.pop("local_files_only") logging.info("Loading dataset") - dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only) + dataset = LeRobotDataset(repo_id, root=root) visualize_dataset(dataset, **vars(args)) diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index cc3f3930..ed748c9a 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -150,7 +150,7 @@ def run_server( 400, ) dataset_version = ( - dataset.meta._version if isinstance(dataset, LeRobotDataset) else dataset.codebase_version + str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version ) match = re.search(r"v(\d+)\.", dataset_version) if match: @@ -384,12 +384,6 @@ def main(): default=None, help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).", ) - parser.add_argument( - "--local-files-only", - type=int, - default=0, - help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.", - ) parser.add_argument( "--root", type=Path, @@ -445,15 +439,10 @@ def main(): repo_id = kwargs.pop("repo_id") load_from_hf_hub = kwargs.pop("load_from_hf_hub") root = kwargs.pop("root") - local_files_only = kwargs.pop("local_files_only") dataset = None if repo_id: - dataset = ( - LeRobotDataset(repo_id, root=root, local_files_only=local_files_only) - if not load_from_hf_hub - else get_dataset_info(repo_id) - ) + dataset = LeRobotDataset(repo_id, root=root) if not load_from_hf_hub else get_dataset_info(repo_id) visualize_dataset_html(dataset, **vars(args)) diff --git a/lerobot/scripts/visualize_image_transforms.py b/lerobot/scripts/visualize_image_transforms.py index 727fe178..80935d32 100644 --- a/lerobot/scripts/visualize_image_transforms.py +++ b/lerobot/scripts/visualize_image_transforms.py @@ -109,7 +109,7 @@ def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR dataset = LeRobotDataset( repo_id=cfg.repo_id, episodes=cfg.episodes, - local_files_only=cfg.local_files_only, + revision=cfg.revision, video_backend=cfg.video_backend, ) diff --git a/pyproject.toml b/pyproject.toml index 21c3fc78..25cadb3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "numba>=0.59.0", "omegaconf>=2.3.0", "opencv-python>=4.9.0", + "packaging>=24.2", "pyav>=12.0.5", "pymunk>=6.6.0", "pynput>=1.7.7", diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py index bfe6c339..3201dcf2 100644 --- a/tests/fixtures/constants.py +++ b/tests/fixtures/constants.py @@ -1,6 +1,6 @@ -from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME +from lerobot.common.constants import HF_LEROBOT_HOME -LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing" +LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing" DUMMY_REPO_ID = "dummy/repo" DUMMY_ROBOT_TYPE = "dummy_robot" DUMMY_MOTOR_FEATURES = { @@ -27,3 +27,5 @@ DUMMY_VIDEO_INFO = { "video.is_depth_map": False, "has_audio": False, } +DUMMY_CHW = (3, 96, 128) +DUMMY_HWC = (96, 128, 3) diff --git a/tests/fixtures/dataset_factories.py b/tests/fixtures/dataset_factories.py index c28a1165..2259e0e6 100644 --- a/tests/fixtures/dataset_factories.py +++ b/tests/fixtures/dataset_factories.py @@ -1,5 +1,7 @@ import random +from functools import partial from pathlib import Path +from typing import Protocol from unittest.mock import patch import datasets @@ -27,8 +29,12 @@ from tests.fixtures.constants import ( ) +class LeRobotDatasetFactory(Protocol): + def __call__(self, *args, **kwargs) -> LeRobotDataset: ... + + def get_task_index(task_dicts: dict, task: str) -> int: - tasks = {d["task_index"]: d["task"] for d in task_dicts} + tasks = {d["task_index"]: d["task"] for d in task_dicts.values()} task_to_task_index = {task: task_idx for task_idx, task in tasks.items()} return task_to_task_index[task] @@ -141,6 +147,7 @@ def stats_factory(): "mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(), "min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(), "std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(), + "count": [10], } else: stats[key] = { @@ -148,20 +155,38 @@ def stats_factory(): "mean": np.full(shape, 0.5, dtype=dtype).tolist(), "min": np.full(shape, 0, dtype=dtype).tolist(), "std": np.full(shape, 0.25, dtype=dtype).tolist(), + "count": [10], } return stats return _create_stats +@pytest.fixture(scope="session") +def episodes_stats_factory(stats_factory): + def _create_episodes_stats( + features: dict[str], + total_episodes: int = 3, + ) -> dict: + episodes_stats = {} + for episode_index in range(total_episodes): + episodes_stats[episode_index] = { + "episode_index": episode_index, + "stats": stats_factory(features), + } + return episodes_stats + + return _create_episodes_stats + + @pytest.fixture(scope="session") def tasks_factory(): def _create_tasks(total_tasks: int = 3) -> int: - tasks_list = [] - for i in range(total_tasks): - task_dict = {"task_index": i, "task": f"Perform action {i}."} - tasks_list.append(task_dict) - return tasks_list + tasks = {} + for task_index in range(total_tasks): + task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."} + tasks[task_index] = task_dict + return tasks return _create_tasks @@ -190,10 +215,10 @@ def episodes_factory(tasks_factory): # Generate random lengths that sum up to total_length lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist() - tasks_list = [task_dict["task"] for task_dict in tasks] + tasks_list = [task_dict["task"] for task_dict in tasks.values()] num_tasks_available = len(tasks_list) - episodes_list = [] + episodes = {} remaining_tasks = tasks_list.copy() for ep_idx in range(total_episodes): num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1 @@ -203,15 +228,13 @@ def episodes_factory(tasks_factory): for task in episode_tasks: remaining_tasks.remove(task) - episodes_list.append( - { - "episode_index": ep_idx, - "tasks": episode_tasks, - "length": lengths[ep_idx], - } - ) + episodes[ep_idx] = { + "episode_index": ep_idx, + "tasks": episode_tasks, + "length": lengths[ep_idx], + } - return episodes_list + return episodes return _create_episodes @@ -235,7 +258,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar frame_index_col = np.array([], dtype=np.int64) episode_index_col = np.array([], dtype=np.int64) task_index = np.array([], dtype=np.int64) - for ep_dict in episodes: + for ep_dict in episodes.values(): timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps)) frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int))) episode_index_col = np.concatenate( @@ -278,6 +301,7 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar def lerobot_dataset_metadata_factory( info_factory, stats_factory, + episodes_stats_factory, tasks_factory, episodes_factory, mock_snapshot_download_factory, @@ -287,14 +311,18 @@ def lerobot_dataset_metadata_factory( repo_id: str = DUMMY_REPO_ID, info: dict | None = None, stats: dict | None = None, + episodes_stats: list[dict] | None = None, tasks: list[dict] | None = None, episodes: list[dict] | None = None, - local_files_only: bool = False, ) -> LeRobotDatasetMetadata: if not info: info = info_factory() if not stats: stats = stats_factory(features=info["features"]) + if not episodes_stats: + episodes_stats = episodes_stats_factory( + features=info["features"], total_episodes=info["total_episodes"] + ) if not tasks: tasks = tasks_factory(total_tasks=info["total_tasks"]) if not episodes: @@ -305,21 +333,20 @@ def lerobot_dataset_metadata_factory( mock_snapshot_download = mock_snapshot_download_factory( info=info, stats=stats, + episodes_stats=episodes_stats, tasks=tasks, episodes=episodes, ) with ( - patch( - "lerobot.common.datasets.lerobot_dataset.get_hub_safe_version" - ) as mock_get_hub_safe_version_patch, + patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch, patch( "lerobot.common.datasets.lerobot_dataset.snapshot_download" ) as mock_snapshot_download_patch, ): - mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version + mock_get_safe_version_patch.side_effect = lambda repo_id, version: version mock_snapshot_download_patch.side_effect = mock_snapshot_download - return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only) + return LeRobotDatasetMetadata(repo_id=repo_id, root=root) return _create_lerobot_dataset_metadata @@ -328,12 +355,13 @@ def lerobot_dataset_metadata_factory( def lerobot_dataset_factory( info_factory, stats_factory, + episodes_stats_factory, tasks_factory, episodes_factory, hf_dataset_factory, mock_snapshot_download_factory, lerobot_dataset_metadata_factory, -): +) -> LeRobotDatasetFactory: def _create_lerobot_dataset( root: Path, repo_id: str = DUMMY_REPO_ID, @@ -343,6 +371,7 @@ def lerobot_dataset_factory( multi_task: bool = False, info: dict | None = None, stats: dict | None = None, + episodes_stats: list[dict] | None = None, tasks: list[dict] | None = None, episode_dicts: list[dict] | None = None, hf_dataset: datasets.Dataset | None = None, @@ -354,6 +383,8 @@ def lerobot_dataset_factory( ) if not stats: stats = stats_factory(features=info["features"]) + if not episodes_stats: + episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes) if not tasks: tasks = tasks_factory(total_tasks=info["total_tasks"]) if not episode_dicts: @@ -369,6 +400,7 @@ def lerobot_dataset_factory( mock_snapshot_download = mock_snapshot_download_factory( info=info, stats=stats, + episodes_stats=episodes_stats, tasks=tasks, episodes=episode_dicts, hf_dataset=hf_dataset, @@ -378,19 +410,26 @@ def lerobot_dataset_factory( repo_id=repo_id, info=info, stats=stats, + episodes_stats=episodes_stats, tasks=tasks, episodes=episode_dicts, - local_files_only=kwargs.get("local_files_only", False), ) with ( patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch, + patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch, patch( "lerobot.common.datasets.lerobot_dataset.snapshot_download" ) as mock_snapshot_download_patch, ): mock_metadata_patch.return_value = mock_metadata + mock_get_safe_version_patch.side_effect = lambda repo_id, version: version mock_snapshot_download_patch.side_effect = mock_snapshot_download return LeRobotDataset(repo_id=repo_id, root=root, **kwargs) return _create_lerobot_dataset + + +@pytest.fixture(scope="session") +def empty_lerobot_dataset_factory() -> LeRobotDatasetFactory: + return partial(LeRobotDataset.create, repo_id=DUMMY_REPO_ID, fps=DEFAULT_FPS) diff --git a/tests/fixtures/files.py b/tests/fixtures/files.py index 5fe8a314..4ef12e49 100644 --- a/tests/fixtures/files.py +++ b/tests/fixtures/files.py @@ -7,7 +7,13 @@ import pyarrow.compute as pc import pyarrow.parquet as pq import pytest -from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH +from lerobot.common.datasets.utils import ( + EPISODES_PATH, + EPISODES_STATS_PATH, + INFO_PATH, + STATS_PATH, + TASKS_PATH, +) @pytest.fixture(scope="session") @@ -38,6 +44,20 @@ def stats_path(stats_factory): return _create_stats_json_file +@pytest.fixture(scope="session") +def episodes_stats_path(episodes_stats_factory): + def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path: + if not episodes_stats: + episodes_stats = episodes_stats_factory() + fpath = dir / EPISODES_STATS_PATH + fpath.parent.mkdir(parents=True, exist_ok=True) + with jsonlines.open(fpath, "w") as writer: + writer.write_all(episodes_stats.values()) + return fpath + + return _create_episodes_stats_jsonl_file + + @pytest.fixture(scope="session") def tasks_path(tasks_factory): def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path: @@ -46,7 +66,7 @@ def tasks_path(tasks_factory): fpath = dir / TASKS_PATH fpath.parent.mkdir(parents=True, exist_ok=True) with jsonlines.open(fpath, "w") as writer: - writer.write_all(tasks) + writer.write_all(tasks.values()) return fpath return _create_tasks_jsonl_file @@ -60,7 +80,7 @@ def episode_path(episodes_factory): fpath = dir / EPISODES_PATH fpath.parent.mkdir(parents=True, exist_ok=True) with jsonlines.open(fpath, "w") as writer: - writer.write_all(episodes) + writer.write_all(episodes.values()) return fpath return _create_episodes_jsonl_file diff --git a/tests/fixtures/hub.py b/tests/fixtures/hub.py index 351768c0..ae309cb4 100644 --- a/tests/fixtures/hub.py +++ b/tests/fixtures/hub.py @@ -4,7 +4,13 @@ import datasets import pytest from huggingface_hub.utils import filter_repo_objects -from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH +from lerobot.common.datasets.utils import ( + EPISODES_PATH, + EPISODES_STATS_PATH, + INFO_PATH, + STATS_PATH, + TASKS_PATH, +) from tests.fixtures.constants import LEROBOT_TEST_DIR @@ -14,6 +20,8 @@ def mock_snapshot_download_factory( info_path, stats_factory, stats_path, + episodes_stats_factory, + episodes_stats_path, tasks_factory, tasks_path, episodes_factory, @@ -29,6 +37,7 @@ def mock_snapshot_download_factory( def _mock_snapshot_download_func( info: dict | None = None, stats: dict | None = None, + episodes_stats: list[dict] | None = None, tasks: list[dict] | None = None, episodes: list[dict] | None = None, hf_dataset: datasets.Dataset | None = None, @@ -37,6 +46,10 @@ def mock_snapshot_download_factory( info = info_factory() if not stats: stats = stats_factory(features=info["features"]) + if not episodes_stats: + episodes_stats = episodes_stats_factory( + features=info["features"], total_episodes=info["total_episodes"] + ) if not tasks: tasks = tasks_factory(total_tasks=info["total_tasks"]) if not episodes: @@ -67,11 +80,11 @@ def mock_snapshot_download_factory( # List all possible files all_files = [] - meta_files = [INFO_PATH, STATS_PATH, TASKS_PATH, EPISODES_PATH] + meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH] all_files.extend(meta_files) data_files = [] - for episode_dict in episodes: + for episode_dict in episodes.values(): ep_idx = episode_dict["episode_index"] ep_chunk = ep_idx // info["chunks_size"] data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx) @@ -92,6 +105,8 @@ def mock_snapshot_download_factory( _ = info_path(local_dir, info) elif rel_path == STATS_PATH: _ = stats_path(local_dir, stats) + elif rel_path == EPISODES_STATS_PATH: + _ = episodes_stats_path(local_dir, episodes_stats) elif rel_path == TASKS_PATH: _ = tasks_path(local_dir, tasks) elif rel_path == EPISODES_PATH: diff --git a/tests/test_cameras.py b/tests/test_cameras.py index 1a1812f7..7c043c25 100644 --- a/tests/test_cameras.py +++ b/tests/test_cameras.py @@ -182,7 +182,7 @@ def test_camera(request, camera_type, mock): @pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES) @require_camera -def test_save_images_from_cameras(tmpdir, request, camera_type, mock): +def test_save_images_from_cameras(tmp_path, request, camera_type, mock): # TODO(rcadene): refactor if camera_type == "opencv": from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras @@ -190,4 +190,4 @@ def test_save_images_from_cameras(tmpdir, request, camera_type, mock): from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras # Small `record_time_s` to speedup unit tests - save_images_from_cameras(tmpdir, record_time_s=0.02, mock=mock) + save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock) diff --git a/tests/test_compute_stats.py b/tests/test_compute_stats.py new file mode 100644 index 00000000..d9032c8a --- /dev/null +++ b/tests/test_compute_stats.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from unittest.mock import patch + +import numpy as np +import pytest + +from lerobot.common.datasets.compute_stats import ( + _assert_type_and_shape, + aggregate_feature_stats, + aggregate_stats, + compute_episode_stats, + estimate_num_samples, + get_feature_stats, + sample_images, + sample_indices, +) + + +def mock_load_image_as_numpy(path, dtype, channel_first): + return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype) + + +@pytest.fixture +def sample_array(): + return np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + +def test_estimate_num_samples(): + assert estimate_num_samples(1) == 1 + assert estimate_num_samples(10) == 10 + assert estimate_num_samples(100) == 100 + assert estimate_num_samples(200) == 100 + assert estimate_num_samples(1000) == 177 + assert estimate_num_samples(2000) == 299 + assert estimate_num_samples(5000) == 594 + assert estimate_num_samples(10_000) == 1000 + assert estimate_num_samples(20_000) == 1681 + assert estimate_num_samples(50_000) == 3343 + assert estimate_num_samples(500_000) == 10_000 + + +def test_sample_indices(): + indices = sample_indices(10) + assert len(indices) > 0 + assert indices[0] == 0 + assert indices[-1] == 9 + assert len(indices) == estimate_num_samples(10) + + +@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy) +def test_sample_images(mock_load): + image_paths = [f"image_{i}.jpg" for i in range(100)] + images = sample_images(image_paths) + assert isinstance(images, np.ndarray) + assert images.shape[1:] == (3, 32, 32) + assert images.dtype == np.uint8 + assert len(images) == estimate_num_samples(100) + + +def test_get_feature_stats_images(): + data = np.random.rand(100, 3, 32, 32) + stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True) + assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats + np.testing.assert_equal(stats["count"], np.array([100])) + assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape + + +def test_get_feature_stats_axis_0_keepdims(sample_array): + expected = { + "min": np.array([[1, 2, 3]]), + "max": np.array([[7, 8, 9]]), + "mean": np.array([[4.0, 5.0, 6.0]]), + "std": np.array([[2.44948974, 2.44948974, 2.44948974]]), + "count": np.array([3]), + } + result = get_feature_stats(sample_array, axis=(0,), keepdims=True) + for key in expected: + np.testing.assert_allclose(result[key], expected[key]) + + +def test_get_feature_stats_axis_1(sample_array): + expected = { + "min": np.array([1, 4, 7]), + "max": np.array([3, 6, 9]), + "mean": np.array([2.0, 5.0, 8.0]), + "std": np.array([0.81649658, 0.81649658, 0.81649658]), + "count": np.array([3]), + } + result = get_feature_stats(sample_array, axis=(1,), keepdims=False) + for key in expected: + np.testing.assert_allclose(result[key], expected[key]) + + +def test_get_feature_stats_no_axis(sample_array): + expected = { + "min": np.array(1), + "max": np.array(9), + "mean": np.array(5.0), + "std": np.array(2.5819889), + "count": np.array([3]), + } + result = get_feature_stats(sample_array, axis=None, keepdims=False) + for key in expected: + np.testing.assert_allclose(result[key], expected[key]) + + +def test_get_feature_stats_empty_array(): + array = np.array([]) + with pytest.raises(ValueError): + get_feature_stats(array, axis=(0,), keepdims=True) + + +def test_get_feature_stats_single_value(): + array = np.array([[1337]]) + result = get_feature_stats(array, axis=None, keepdims=True) + np.testing.assert_equal(result["min"], np.array(1337)) + np.testing.assert_equal(result["max"], np.array(1337)) + np.testing.assert_equal(result["mean"], np.array(1337.0)) + np.testing.assert_equal(result["std"], np.array(0.0)) + np.testing.assert_equal(result["count"], np.array([1])) + + +def test_compute_episode_stats(): + episode_data = { + "observation.image": [f"image_{i}.jpg" for i in range(100)], + "observation.state": np.random.rand(100, 10), + } + features = { + "observation.image": {"dtype": "image"}, + "observation.state": {"dtype": "numeric"}, + } + + with patch( + "lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy + ): + stats = compute_episode_stats(episode_data, features) + + assert "observation.image" in stats and "observation.state" in stats + assert stats["observation.image"]["count"].item() == 100 + assert stats["observation.state"]["count"].item() == 100 + assert stats["observation.image"]["mean"].shape == (3, 1, 1) + + +def test_assert_type_and_shape_valid(): + valid_stats = [ + { + "feature1": { + "min": np.array([1.0]), + "max": np.array([10.0]), + "mean": np.array([5.0]), + "std": np.array([2.0]), + "count": np.array([1]), + } + } + ] + _assert_type_and_shape(valid_stats) + + +def test_assert_type_and_shape_invalid_type(): + invalid_stats = [ + { + "feature1": { + "min": [1.0], # Not a numpy array + "max": np.array([10.0]), + "mean": np.array([5.0]), + "std": np.array([2.0]), + "count": np.array([1]), + } + } + ] + with pytest.raises(ValueError, match="Stats must be composed of numpy array"): + _assert_type_and_shape(invalid_stats) + + +def test_assert_type_and_shape_invalid_shape(): + invalid_stats = [ + { + "feature1": { + "count": np.array([1, 2]), # Wrong shape + } + } + ] + with pytest.raises(ValueError, match=r"Shape of 'count' must be \(1\)"): + _assert_type_and_shape(invalid_stats) + + +def test_aggregate_feature_stats(): + stats_ft_list = [ + { + "min": np.array([1.0]), + "max": np.array([10.0]), + "mean": np.array([5.0]), + "std": np.array([2.0]), + "count": np.array([1]), + }, + { + "min": np.array([2.0]), + "max": np.array([12.0]), + "mean": np.array([6.0]), + "std": np.array([2.5]), + "count": np.array([1]), + }, + ] + result = aggregate_feature_stats(stats_ft_list) + np.testing.assert_allclose(result["min"], np.array([1.0])) + np.testing.assert_allclose(result["max"], np.array([12.0])) + np.testing.assert_allclose(result["mean"], np.array([5.5])) + np.testing.assert_allclose(result["std"], np.array([2.318405]), atol=1e-6) + np.testing.assert_allclose(result["count"], np.array([2])) + + +def test_aggregate_stats(): + all_stats = [ + { + "observation.image": { + "min": [1, 2, 3], + "max": [10, 20, 30], + "mean": [5.5, 10.5, 15.5], + "std": [2.87, 5.87, 8.87], + "count": 10, + }, + "observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10}, + "extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6}, + }, + { + "observation.image": { + "min": [2, 1, 0], + "max": [15, 10, 5], + "mean": [8.5, 5.5, 2.5], + "std": [3.42, 2.42, 1.42], + "count": 15, + }, + "observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15}, + "extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5}, + }, + ] + + expected_agg_stats = { + "observation.image": { + "min": [1, 1, 0], + "max": [15, 20, 30], + "mean": [7.3, 7.5, 7.7], + "std": [3.5317, 4.8267, 8.5581], + "count": 25, + }, + "observation.state": { + "min": 1, + "max": 15, + "mean": 7.3, + "std": 3.5317, + "count": 25, + }, + "extra_key_0": { + "min": 5, + "max": 25, + "mean": 15.0, + "std": 6.0, + "count": 6, + }, + "extra_key_1": { + "min": 0, + "max": 20, + "mean": 10.0, + "std": 5.0, + "count": 5, + }, + } + + # cast to numpy + for ep_stats in all_stats: + for fkey, stats in ep_stats.items(): + for k in stats: + stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) + if fkey == "observation.image" and k != "count": + stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels + else: + stats[k] = stats[k].reshape(1) + + # cast to numpy + for fkey, stats in expected_agg_stats.items(): + for k in stats: + stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32) + if fkey == "observation.image" and k != "count": + stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels + else: + stats[k] = stats[k].reshape(1) + + results = aggregate_stats(all_stats) + + for fkey in expected_agg_stats: + np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"]) + np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"]) + np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"]) + np.testing.assert_allclose( + results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04 + ) + np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"]) diff --git a/tests/test_control_robot.py b/tests/test_control_robot.py index 36ee096f..12b68641 100644 --- a/tests/test_control_robot.py +++ b/tests/test_control_robot.py @@ -24,7 +24,6 @@ pytest -sx 'tests/test_control_robot.py::test_teleoperate[aloha-True]' """ import multiprocessing -from pathlib import Path from unittest.mock import patch import pytest @@ -45,7 +44,7 @@ from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_ @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) @require_robot -def test_teleoperate(tmpdir, request, robot_type, mock): +def test_teleoperate(tmp_path, request, robot_type, mock): robot_kwargs = {"robot_type": robot_type, "mock": mock} if mock and robot_type != "aloha": @@ -53,8 +52,7 @@ def test_teleoperate(tmpdir, request, robot_type, mock): # Create an empty calibration directory to trigger manual calibration # and avoid writing calibration files in user .cache/calibration folder - tmpdir = Path(tmpdir) - calibration_dir = tmpdir / robot_type + calibration_dir = tmp_path / robot_type mock_calibration_dir(calibration_dir) robot_kwargs["calibration_dir"] = calibration_dir else: @@ -70,15 +68,14 @@ def test_teleoperate(tmpdir, request, robot_type, mock): @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) @require_robot -def test_calibrate(tmpdir, request, robot_type, mock): +def test_calibrate(tmp_path, request, robot_type, mock): robot_kwargs = {"robot_type": robot_type, "mock": mock} if mock: request.getfixturevalue("patch_builtins_input") # Create an empty calibration directory to trigger manual calibration - tmpdir = Path(tmpdir) - calibration_dir = tmpdir / robot_type + calibration_dir = tmp_path / robot_type robot_kwargs["calibration_dir"] = calibration_dir robot = make_robot(**robot_kwargs) @@ -89,7 +86,7 @@ def test_calibrate(tmpdir, request, robot_type, mock): @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) @require_robot -def test_record_without_cameras(tmpdir, request, robot_type, mock): +def test_record_without_cameras(tmp_path, request, robot_type, mock): robot_kwargs = {"robot_type": robot_type, "mock": mock} # Avoid using cameras @@ -100,7 +97,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock): # Create an empty calibration directory to trigger manual calibration # and avoid writing calibration files in user .cache/calibration folder - calibration_dir = Path(tmpdir) / robot_type + calibration_dir = tmp_path / robot_type mock_calibration_dir(calibration_dir) robot_kwargs["calibration_dir"] = calibration_dir else: @@ -108,7 +105,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock): pass repo_id = "lerobot/debug" - root = Path(tmpdir) / "data" / repo_id + root = tmp_path / "data" / repo_id single_task = "Do something." robot = make_robot(**robot_kwargs) @@ -121,7 +118,6 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock): episode_time_s=1, reset_time_s=0.1, num_episodes=2, - run_compute_stats=False, push_to_hub=False, video=False, play_sounds=False, @@ -131,8 +127,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock): @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) @require_robot -def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): - tmpdir = Path(tmpdir) +def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock): robot_kwargs = {"robot_type": robot_type, "mock": mock} if mock and robot_type != "aloha": @@ -140,7 +135,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): # Create an empty calibration directory to trigger manual calibration # and avoid writing calibration files in user .cache/calibration folder - calibration_dir = tmpdir / robot_type + calibration_dir = tmp_path / robot_type mock_calibration_dir(calibration_dir) robot_kwargs["calibration_dir"] = calibration_dir else: @@ -148,7 +143,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): pass repo_id = "lerobot_test/debug" - root = tmpdir / "data" / repo_id + root = tmp_path / "data" / repo_id single_task = "Do something." robot = make_robot(**robot_kwargs) @@ -172,15 +167,13 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): assert dataset.meta.total_episodes == 2 assert len(dataset) == 2 - replay_cfg = ReplayControlConfig( - episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False, local_files_only=True - ) + replay_cfg = ReplayControlConfig(episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False) replay(robot, replay_cfg) policy_cfg = ACTConfig() policy = make_policy(policy_cfg, ds_meta=dataset.meta, device=DEVICE) - out_dir = tmpdir / "logger" + out_dir = tmp_path / "logger" pretrained_policy_path = out_dir / "checkpoints/last/pretrained_model" policy.save_pretrained(pretrained_policy_path) @@ -207,7 +200,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): num_image_writer_processes = 0 eval_repo_id = "lerobot/eval_debug" - eval_root = tmpdir / "data" / eval_repo_id + eval_root = tmp_path / "data" / eval_repo_id rec_eval_cfg = RecordControlConfig( repo_id=eval_repo_id, @@ -218,7 +211,6 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): episode_time_s=1, reset_time_s=0.1, num_episodes=2, - run_compute_stats=False, push_to_hub=False, video=False, display_cameras=False, @@ -240,7 +232,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock): @pytest.mark.parametrize("robot_type, mock", [("koch", True)]) @require_robot -def test_resume_record(tmpdir, request, robot_type, mock): +def test_resume_record(tmp_path, request, robot_type, mock): robot_kwargs = {"robot_type": robot_type, "mock": mock} if mock and robot_type != "aloha": @@ -248,7 +240,7 @@ def test_resume_record(tmpdir, request, robot_type, mock): # Create an empty calibration directory to trigger manual calibration # and avoid writing calibration files in user .cache/calibration folder - calibration_dir = tmpdir / robot_type + calibration_dir = tmp_path / robot_type mock_calibration_dir(calibration_dir) robot_kwargs["calibration_dir"] = calibration_dir else: @@ -258,7 +250,7 @@ def test_resume_record(tmpdir, request, robot_type, mock): robot = make_robot(**robot_kwargs) repo_id = "lerobot/debug" - root = Path(tmpdir) / "data" / repo_id + root = tmp_path / "data" / repo_id single_task = "Do something." rec_cfg = RecordControlConfig( @@ -272,8 +264,6 @@ def test_resume_record(tmpdir, request, robot_type, mock): video=False, display_cameras=False, play_sounds=False, - run_compute_stats=False, - local_files_only=True, num_episodes=1, ) @@ -291,7 +281,7 @@ def test_resume_record(tmpdir, request, robot_type, mock): @pytest.mark.parametrize("robot_type, mock", [("koch", True)]) @require_robot -def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock): +def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock): robot_kwargs = {"robot_type": robot_type, "mock": mock} if mock and robot_type != "aloha": @@ -299,7 +289,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock): # Create an empty calibration directory to trigger manual calibration # and avoid writing calibration files in user .cache/calibration folder - calibration_dir = tmpdir / robot_type + calibration_dir = tmp_path / robot_type mock_calibration_dir(calibration_dir) robot_kwargs["calibration_dir"] = calibration_dir else: @@ -316,7 +306,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock): mock_listener.return_value = (None, mock_events) repo_id = "lerobot/debug" - root = Path(tmpdir) / "data" / repo_id + root = tmp_path / "data" / repo_id single_task = "Do something." rec_cfg = RecordControlConfig( @@ -331,7 +321,6 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock): video=False, display_cameras=False, play_sounds=False, - run_compute_stats=False, ) dataset = record(robot, rec_cfg) @@ -342,7 +331,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock): @pytest.mark.parametrize("robot_type, mock", [("koch", True)]) @require_robot -def test_record_with_event_exit_early(tmpdir, request, robot_type, mock): +def test_record_with_event_exit_early(tmp_path, request, robot_type, mock): robot_kwargs = {"robot_type": robot_type, "mock": mock} if mock: @@ -350,7 +339,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock): # Create an empty calibration directory to trigger manual calibration # and avoid writing calibration files in user .cache/calibration folder - calibration_dir = tmpdir / robot_type + calibration_dir = tmp_path / robot_type mock_calibration_dir(calibration_dir) robot_kwargs["calibration_dir"] = calibration_dir else: @@ -367,7 +356,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock): mock_listener.return_value = (None, mock_events) repo_id = "lerobot/debug" - root = Path(tmpdir) / "data" / repo_id + root = tmp_path / "data" / repo_id single_task = "Do something." rec_cfg = RecordControlConfig( @@ -382,7 +371,6 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock): video=False, display_cameras=False, play_sounds=False, - run_compute_stats=False, ) dataset = record(robot, rec_cfg) @@ -395,7 +383,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock): "robot_type, mock, num_image_writer_processes", [("koch", True, 0), ("koch", True, 1)] ) @require_robot -def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num_image_writer_processes): +def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes): robot_kwargs = {"robot_type": robot_type, "mock": mock} if mock: @@ -403,7 +391,7 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num # Create an empty calibration directory to trigger manual calibration # and avoid writing calibration files in user .cache/calibration folder - calibration_dir = tmpdir / robot_type + calibration_dir = tmp_path / robot_type mock_calibration_dir(calibration_dir) robot_kwargs["calibration_dir"] = calibration_dir else: @@ -420,7 +408,7 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num mock_listener.return_value = (None, mock_events) repo_id = "lerobot/debug" - root = Path(tmpdir) / "data" / repo_id + root = tmp_path / "data" / repo_id single_task = "Do something." rec_cfg = RecordControlConfig( @@ -436,7 +424,6 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num video=False, display_cameras=False, play_sounds=False, - run_compute_stats=False, num_image_writer_processes=num_image_writer_processes, ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 8664d33e..61b68aa8 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -15,24 +15,21 @@ # limitations under the License. import json import logging +import re from copy import deepcopy from itertools import chain from pathlib import Path -import einops +import numpy as np import pytest import torch -from datasets import Dataset from huggingface_hub import HfApi +from PIL import Image from safetensors.torch import load_file import lerobot -from lerobot.common.datasets.compute_stats import ( - aggregate_stats, - compute_stats, - get_stats_einops_patterns, -) from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.image_writer import image_array_to_pil_image from lerobot.common.datasets.lerobot_dataset import ( LeRobotDataset, MultiLeRobotDataset, @@ -40,20 +37,34 @@ from lerobot.common.datasets.lerobot_dataset import ( from lerobot.common.datasets.utils import ( create_branch, flatten_dict, - hf_transform_to_torch, unflatten_dict, ) from lerobot.common.envs.factory import make_env_config from lerobot.common.policies.factory import make_policy_config from lerobot.common.robot_devices.robots.utils import make_robot -from lerobot.common.utils.random_utils import seeded_context from lerobot.configs.default import DatasetConfig from lerobot.configs.train import TrainPipelineConfig -from tests.fixtures.constants import DUMMY_REPO_ID +from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID from tests.utils import DEVICE, require_x86_64_kernel -def test_same_attributes_defined(lerobot_dataset_factory, tmp_path): +@pytest.fixture +def image_dataset(tmp_path, empty_lerobot_dataset_factory): + features = { + "image": { + "dtype": "image", + "shape": DUMMY_CHW, + "names": [ + "channels", + "height", + "width", + ], + } + } + return empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + + +def test_same_attributes_defined(tmp_path, lerobot_dataset_factory): """ Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated objects have the same sets of attributes defined. @@ -66,24 +77,20 @@ def test_same_attributes_defined(lerobot_dataset_factory, tmp_path): root_init = tmp_path / "init" dataset_init = lerobot_dataset_factory(root=root_init) - # Access the '_hub_version' cached_property in both instances to force its creation - _ = dataset_init.meta._hub_version - _ = dataset_create.meta._hub_version - init_attr = set(vars(dataset_init).keys()) create_attr = set(vars(dataset_create).keys()) assert init_attr == create_attr -def test_dataset_initialization(lerobot_dataset_factory, tmp_path): +def test_dataset_initialization(tmp_path, lerobot_dataset_factory): kwargs = { "repo_id": DUMMY_REPO_ID, "total_episodes": 10, "total_frames": 400, "episodes": [2, 5, 6], } - dataset = lerobot_dataset_factory(root=tmp_path, **kwargs) + dataset = lerobot_dataset_factory(root=tmp_path / "test", **kwargs) assert dataset.repo_id == kwargs["repo_id"] assert dataset.meta.total_episodes == kwargs["total_episodes"] @@ -93,12 +100,232 @@ def test_dataset_initialization(lerobot_dataset_factory, tmp_path): assert dataset.num_frames == len(dataset) +def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + with pytest.raises( + ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n" + ): + dataset.add_frame({"state": torch.randn(1)}) + + +def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + with pytest.raises( + ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n" + ): + dataset.add_frame({"task": "Dummy task"}) + + +def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + with pytest.raises( + ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n" + ): + dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"}) + + +def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + with pytest.raises( + ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n" + ): + dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"}) + + +def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + with pytest.raises( + ValueError, + match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"), + ): + dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"}) + + +def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + with pytest.raises( + ValueError, + match=re.escape( + "The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '' provided instead.\n" + ), + ): + dataset.add_frame({"state": 1.0, "task": "Dummy task"}) + + +def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + with pytest.raises( + ValueError, + match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"), + ): + dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"}) + + +def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + with pytest.raises( + ValueError, + match=re.escape( + "The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '' provided instead.\n" + ), + ): + dataset.add_frame({"state": np.float32(1.0), "task": "Dummy task"}) + + +def test_add_frame(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"}) + dataset.save_episode() + + assert len(dataset) == 1 + assert dataset[0]["task"] == "Dummy task" + assert dataset[0]["task_index"] == 0 + assert dataset[0]["state"].ndim == 0 + + +def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"}) + dataset.save_episode() + + assert dataset[0]["state"].shape == torch.Size([2]) + + +def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"}) + dataset.save_episode() + + assert dataset[0]["state"].shape == torch.Size([2, 4]) + + +def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"}) + dataset.save_episode() + + assert dataset[0]["state"].shape == torch.Size([2, 4, 3]) + + +def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"}) + dataset.save_episode() + + assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5]) + + +def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"}) + dataset.save_episode() + + assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1]) + + +def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory): + features = {"state": {"dtype": "float32", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"}) + dataset.save_episode() + + assert dataset[0]["state"].ndim == 0 + + +def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory): + features = {"caption": {"dtype": "string", "shape": (1,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features) + dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"}) + dataset.save_episode() + + assert dataset[0]["caption"] == "Dummy caption" + + +def test_add_frame_image_wrong_shape(image_dataset): + dataset = image_dataset + with pytest.raises( + ValueError, + match=re.escape( + "The feature 'image' of shape '(3, 128, 96)' does not have the expected shape '(3, 96, 128)' or '(96, 128, 3)'.\n" + ), + ): + c, h, w = DUMMY_CHW + dataset.add_frame({"image": torch.randn(c, w, h), "task": "Dummy task"}) + + +def test_add_frame_image_wrong_range(image_dataset): + """This test will display the following error message from a thread: + ``` + Error writing image ...test_add_frame_image_wrong_ran0/test/images/image/episode_000000/frame_000000.png: + The image data type is float, which requires values in the range [0.0, 1.0]. However, the provided range is [0.009678772038470007, 254.9776492089887]. + Please adjust the range or provide a uint8 image with values in the range [0, 255] + ``` + Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`. + """ + dataset = image_dataset + dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"}) + with pytest.raises(FileNotFoundError): + dataset.save_episode() + + +def test_add_frame_image(image_dataset): + dataset = image_dataset + dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) + dataset.save_episode() + + assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) + + +def test_add_frame_image_h_w_c(image_dataset): + dataset = image_dataset + dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"}) + dataset.save_episode() + + assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) + + +def test_add_frame_image_uint8(image_dataset): + dataset = image_dataset + image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) + dataset.add_frame({"image": image, "task": "Dummy task"}) + dataset.save_episode() + + assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) + + +def test_add_frame_image_pil(image_dataset): + dataset = image_dataset + image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8) + dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"}) + dataset.save_episode() + + assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW) + + +def test_image_array_to_pil_image_wrong_range_float_0_255(): + image = np.random.rand(*DUMMY_HWC) * 255 + with pytest.raises(ValueError): + image_array_to_pil_image(image) + + # TODO(aliberts): # - [ ] test various attributes & state from init and create # - [ ] test init with episodes and check num_frames -# - [ ] test add_frame # - [ ] test add_episode -# - [ ] test consolidate # - [ ] test push_to_hub # - [ ] test smaller methods @@ -210,67 +437,6 @@ def test_multidataset_frames(): assert torch.equal(sub_dataset_item[k], dataset_item[k]) -# TODO(aliberts, rcadene): Refactor and move this to a tests/test_compute_stats.py -def test_compute_stats_on_xarm(): - """Check that the statistics are computed correctly according to the stats_patterns property. - - We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do - because we are working with a small dataset). - """ - # TODO(rcadene, aliberts): remove dataset download - dataset = LeRobotDataset("lerobot/xarm_lift_medium", episodes=[0]) - - # reduce size of dataset sample on which stats compute is tested to 10 frames - dataset.hf_dataset = dataset.hf_dataset.select(range(10)) - - # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched - # computation of the statistics. While doing this, we also make sure it works when we don't divide the - # dataset into even batches. - computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25), num_workers=0) - - # get einops patterns to aggregate batches and compute statistics - stats_patterns = get_stats_einops_patterns(dataset) - - # get all frames from the dataset in the same dtype and range as during compute_stats - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=0, - batch_size=len(dataset), - shuffle=False, - ) - full_batch = next(iter(dataloader)) - - # compute stats based on all frames from the dataset without any batching - expected_stats = {} - for k, pattern in stats_patterns.items(): - full_batch[k] = full_batch[k].float() - expected_stats[k] = {} - expected_stats[k]["mean"] = einops.reduce(full_batch[k], pattern, "mean") - expected_stats[k]["std"] = torch.sqrt( - einops.reduce((full_batch[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean") - ) - expected_stats[k]["min"] = einops.reduce(full_batch[k], pattern, "min") - expected_stats[k]["max"] = einops.reduce(full_batch[k], pattern, "max") - - # test computed stats match expected stats - for k in stats_patterns: - assert torch.allclose(computed_stats[k]["mean"], expected_stats[k]["mean"]) - assert torch.allclose(computed_stats[k]["std"], expected_stats[k]["std"]) - assert torch.allclose(computed_stats[k]["min"], expected_stats[k]["min"]) - assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"]) - - # load stats used during training which are expected to match the ones returned by computed_stats - loaded_stats = dataset.meta.stats # noqa: F841 - - # TODO(rcadene): we can't test this because expected_stats is computed on a subset - # # test loaded stats match expected stats - # for k in stats_patterns: - # assert torch.allclose(loaded_stats[k]["mean"], expected_stats[k]["mean"]) - # assert torch.allclose(loaded_stats[k]["std"], expected_stats[k]["std"]) - # assert torch.allclose(loaded_stats[k]["min"], expected_stats[k]["min"]) - # assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"]) - - # TODO(aliberts): Move to more appropriate location def test_flatten_unflatten_dict(): d = { @@ -374,35 +540,6 @@ def test_backward_compatibility(repo_id): # load_and_compare(i - 1) -@pytest.mark.skip("TODO after fix multidataset") -def test_multidataset_aggregate_stats(): - """Makes 3 basic datasets and checks that aggregate stats are computed correctly.""" - with seeded_context(0): - data_a = torch.rand(30, dtype=torch.float32) - data_b = torch.rand(20, dtype=torch.float32) - data_c = torch.rand(20, dtype=torch.float32) - - hf_dataset_1 = Dataset.from_dict( - {"a": data_a[:10], "b": data_b[:10], "c": data_c[:10], "index": torch.arange(10)} - ) - hf_dataset_1.set_transform(hf_transform_to_torch) - hf_dataset_2 = Dataset.from_dict({"a": data_a[10:20], "b": data_b[10:], "index": torch.arange(10)}) - hf_dataset_2.set_transform(hf_transform_to_torch) - hf_dataset_3 = Dataset.from_dict({"a": data_a[20:], "c": data_c[10:], "index": torch.arange(10)}) - hf_dataset_3.set_transform(hf_transform_to_torch) - dataset_1 = LeRobotDataset.from_preloaded("d1", hf_dataset=hf_dataset_1) - dataset_1.stats = compute_stats(dataset_1, batch_size=len(hf_dataset_1), num_workers=0) - dataset_2 = LeRobotDataset.from_preloaded("d2", hf_dataset=hf_dataset_2) - dataset_2.stats = compute_stats(dataset_2, batch_size=len(hf_dataset_2), num_workers=0) - dataset_3 = LeRobotDataset.from_preloaded("d3", hf_dataset=hf_dataset_3) - dataset_3.stats = compute_stats(dataset_3, batch_size=len(hf_dataset_3), num_workers=0) - stats = aggregate_stats([dataset_1, dataset_2, dataset_3]) - for data_key, data in zip(["a", "b", "c"], [data_a, data_b, data_c], strict=True): - for agg_fn in ["mean", "min", "max"]: - assert torch.allclose(stats[data_key][agg_fn], einops.reduce(data, "n -> 1", agg_fn)) - assert torch.allclose(stats[data_key]["std"], torch.std(data, correction=0)) - - @pytest.mark.skip("Requires internet access") def test_create_branch(): api = HfApi() @@ -431,9 +568,9 @@ def test_create_branch(): def test_dataset_feature_with_forward_slash_raises_error(): # make sure dir does not exist - from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME + from lerobot.common.constants import HF_LEROBOT_HOME - dataset_dir = LEROBOT_HOME / "lerobot/test/with/slash" + dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash" # make sure does not exist if dataset_dir.exists(): dataset_dir.rmdir() diff --git a/tests/test_delta_timestamps.py b/tests/test_delta_timestamps.py index 3516583d..3e3b83ac 100644 --- a/tests/test_delta_timestamps.py +++ b/tests/test_delta_timestamps.py @@ -1,55 +1,78 @@ +from itertools import accumulate + +import datasets +import numpy as np +import pyarrow.compute as pc import pytest import torch -from datasets import Dataset from lerobot.common.datasets.utils import ( - calculate_episode_data_index, check_delta_timestamps, check_timestamps_sync, get_delta_indices, - hf_transform_to_torch, ) from tests.fixtures.constants import DUMMY_MOTOR_FEATURES -@pytest.fixture(scope="module") -def synced_hf_dataset_factory(hf_dataset_factory): - def _create_synced_hf_dataset(fps: int = 30) -> Dataset: - return hf_dataset_factory(fps=fps) +def calculate_total_episode( + hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True +) -> dict[str, torch.Tensor]: + episode_indices = sorted(hf_dataset.unique("episode_index")) + total_episodes = len(episode_indices) + if raise_if_not_contiguous and episode_indices != list(range(total_episodes)): + raise ValueError("episode_index values are not sorted and contiguous.") + return total_episodes - return _create_synced_hf_dataset + +def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.ndarray]: + episode_lengths = [] + table = hf_dataset.data.table + total_episodes = calculate_total_episode(hf_dataset) + for ep_idx in range(total_episodes): + ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) + episode_lengths.insert(ep_idx, len(ep_table)) + + cumulative_lenghts = list(accumulate(episode_lengths)) + return { + "from": np.array([0] + cumulative_lenghts[:-1], dtype=np.int64), + "to": np.array(cumulative_lenghts, dtype=np.int64), + } @pytest.fixture(scope="module") -def unsynced_hf_dataset_factory(synced_hf_dataset_factory): - def _create_unsynced_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset: - hf_dataset = synced_hf_dataset_factory(fps=fps) - features = hf_dataset.features - df = hf_dataset.to_pandas() - dtype = df["timestamp"].dtype # This is to avoid pandas type warning - # Modify a single timestamp just outside tolerance - df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 1.1)) - unsynced_hf_dataset = Dataset.from_pandas(df, features=features) - unsynced_hf_dataset.set_transform(hf_transform_to_torch) - return unsynced_hf_dataset +def synced_timestamps_factory(hf_dataset_factory): + def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + hf_dataset = hf_dataset_factory(fps=fps) + timestamps = torch.stack(hf_dataset["timestamp"]).numpy() + episode_indices = torch.stack(hf_dataset["episode_index"]).numpy() + episode_data_index = calculate_episode_data_index(hf_dataset) + return timestamps, episode_indices, episode_data_index - return _create_unsynced_hf_dataset + return _create_synced_timestamps @pytest.fixture(scope="module") -def slightly_off_hf_dataset_factory(synced_hf_dataset_factory): - def _create_slightly_off_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset: - hf_dataset = synced_hf_dataset_factory(fps=fps) - features = hf_dataset.features - df = hf_dataset.to_pandas() - dtype = df["timestamp"].dtype # This is to avoid pandas type warning - # Modify a single timestamp just inside tolerance - df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 0.9)) - unsynced_hf_dataset = Dataset.from_pandas(df, features=features) - unsynced_hf_dataset.set_transform(hf_transform_to_torch) - return unsynced_hf_dataset +def unsynced_timestamps_factory(synced_timestamps_factory): + def _create_unsynced_timestamps( + fps: int = 30, tolerance_s: float = 1e-4 + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps) + timestamps[30] += tolerance_s * 1.1 # Modify a single timestamp just outside tolerance + return timestamps, episode_indices, episode_data_index - return _create_slightly_off_hf_dataset + return _create_unsynced_timestamps + + +@pytest.fixture(scope="module") +def slightly_off_timestamps_factory(synced_timestamps_factory): + def _create_slightly_off_timestamps( + fps: int = 30, tolerance_s: float = 1e-4 + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps) + timestamps[30] += tolerance_s * 0.9 # Modify a single timestamp just inside tolerance + return timestamps, episode_indices, episode_data_index + + return _create_slightly_off_timestamps @pytest.fixture(scope="module") @@ -100,42 +123,42 @@ def delta_indices_factory(): return _delta_indices -def test_check_timestamps_sync_synced(synced_hf_dataset_factory): +def test_check_timestamps_sync_synced(synced_timestamps_factory): fps = 30 tolerance_s = 1e-4 - synced_hf_dataset = synced_hf_dataset_factory(fps) - episode_data_index = calculate_episode_data_index(synced_hf_dataset) + timestamps, ep_idx, ep_data_index = synced_timestamps_factory(fps) result = check_timestamps_sync( - hf_dataset=synced_hf_dataset, - episode_data_index=episode_data_index, + timestamps=timestamps, + episode_indices=ep_idx, + episode_data_index=ep_data_index, fps=fps, tolerance_s=tolerance_s, ) assert result is True -def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory): +def test_check_timestamps_sync_unsynced(unsynced_timestamps_factory): fps = 30 tolerance_s = 1e-4 - unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s) - episode_data_index = calculate_episode_data_index(unsynced_hf_dataset) + timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s) with pytest.raises(ValueError): check_timestamps_sync( - hf_dataset=unsynced_hf_dataset, - episode_data_index=episode_data_index, + timestamps=timestamps, + episode_indices=ep_idx, + episode_data_index=ep_data_index, fps=fps, tolerance_s=tolerance_s, ) -def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory): +def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory): fps = 30 tolerance_s = 1e-4 - unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s) - episode_data_index = calculate_episode_data_index(unsynced_hf_dataset) + timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s) result = check_timestamps_sync( - hf_dataset=unsynced_hf_dataset, - episode_data_index=episode_data_index, + timestamps=timestamps, + episode_indices=ep_idx, + episode_data_index=ep_data_index, fps=fps, tolerance_s=tolerance_s, raise_value_error=False, @@ -143,14 +166,14 @@ def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory assert result is False -def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory): +def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory): fps = 30 tolerance_s = 1e-4 - slightly_off_hf_dataset = slightly_off_hf_dataset_factory(fps, tolerance_s) - episode_data_index = calculate_episode_data_index(slightly_off_hf_dataset) + timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s) result = check_timestamps_sync( - hf_dataset=slightly_off_hf_dataset, - episode_data_index=episode_data_index, + timestamps=timestamps, + episode_indices=ep_idx, + episode_data_index=ep_data_index, fps=fps, tolerance_s=tolerance_s, ) @@ -158,33 +181,13 @@ def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory): def test_check_timestamps_sync_single_timestamp(): - single_timestamp_hf_dataset = Dataset.from_dict({"timestamp": [0.0], "episode_index": [0]}) - single_timestamp_hf_dataset.set_transform(hf_transform_to_torch) - episode_data_index = {"to": torch.tensor([1]), "from": torch.tensor([0])} fps = 30 tolerance_s = 1e-4 + timestamps, ep_idx = np.array([0.0]), np.array([0]) + episode_data_index = {"to": np.array([1]), "from": np.array([0])} result = check_timestamps_sync( - hf_dataset=single_timestamp_hf_dataset, - episode_data_index=episode_data_index, - fps=fps, - tolerance_s=tolerance_s, - ) - assert result is True - - -# TODO(aliberts): Change behavior of hf_transform_to_torch so that it can work with empty dataset -@pytest.mark.skip("TODO: fix") -def test_check_timestamps_sync_empty_dataset(): - fps = 30 - tolerance_s = 1e-4 - empty_hf_dataset = Dataset.from_dict({"timestamp": [], "episode_index": []}) - empty_hf_dataset.set_transform(hf_transform_to_torch) - episode_data_index = { - "to": torch.tensor([], dtype=torch.int64), - "from": torch.tensor([], dtype=torch.int64), - } - result = check_timestamps_sync( - hf_dataset=empty_hf_dataset, + timestamps=timestamps, + episode_indices=ep_idx, episode_data_index=episode_data_index, fps=fps, tolerance_s=tolerance_s, diff --git a/tests/test_examples.py b/tests/test_examples.py index f3b7948c..aabec69a 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -53,7 +53,7 @@ def test_example_1(tmp_path, lerobot_dataset_factory): ('repo_id = "lerobot/pusht"', f'repo_id = "{DUMMY_REPO_ID}"'), ( "LeRobotDataset(repo_id", - f"LeRobotDataset(repo_id, root='{str(tmp_path)}', local_files_only=True", + f"LeRobotDataset(repo_id, root='{str(tmp_path)}'", ), ], ) diff --git a/tests/test_image_writer.py b/tests/test_image_writer.py index f51e86b4..c7fc11f2 100644 --- a/tests/test_image_writer.py +++ b/tests/test_image_writer.py @@ -9,10 +9,11 @@ from PIL import Image from lerobot.common.datasets.image_writer import ( AsyncImageWriter, - image_array_to_image, + image_array_to_pil_image, safe_stop_image_writer, write_image, ) +from tests.fixtures.constants import DUMMY_HWC DUMMY_IMAGE = "test_image.png" @@ -48,49 +49,62 @@ def test_zero_threads(): AsyncImageWriter(num_processes=0, num_threads=0) -def test_image_array_to_image_rgb(img_array_factory): +def test_image_array_to_pil_image_float_array_wrong_range_0_255(): + image = np.random.rand(*DUMMY_HWC) * 255 + with pytest.raises(ValueError): + image_array_to_pil_image(image) + + +def test_image_array_to_pil_image_float_array_wrong_range_neg_1_1(): + image = np.random.rand(*DUMMY_HWC) * 2 - 1 + with pytest.raises(ValueError): + image_array_to_pil_image(image) + + +def test_image_array_to_pil_image_rgb(img_array_factory): img_array = img_array_factory(100, 100) - result_image = image_array_to_image(img_array) + result_image = image_array_to_pil_image(img_array) assert isinstance(result_image, Image.Image) assert result_image.size == (100, 100) assert result_image.mode == "RGB" -def test_image_array_to_image_pytorch_format(img_array_factory): +def test_image_array_to_pil_image_pytorch_format(img_array_factory): img_array = img_array_factory(100, 100).transpose(2, 0, 1) - result_image = image_array_to_image(img_array) + result_image = image_array_to_pil_image(img_array) assert isinstance(result_image, Image.Image) assert result_image.size == (100, 100) assert result_image.mode == "RGB" -@pytest.mark.skip("TODO: implement") -def test_image_array_to_image_single_channel(img_array_factory): +def test_image_array_to_pil_image_single_channel(img_array_factory): img_array = img_array_factory(channels=1) - result_image = image_array_to_image(img_array) - assert isinstance(result_image, Image.Image) - assert result_image.size == (100, 100) - assert result_image.mode == "L" + with pytest.raises(NotImplementedError): + image_array_to_pil_image(img_array) -def test_image_array_to_image_float_array(img_array_factory): +def test_image_array_to_pil_image_4_channels(img_array_factory): + img_array = img_array_factory(channels=4) + with pytest.raises(NotImplementedError): + image_array_to_pil_image(img_array) + + +def test_image_array_to_pil_image_float_array(img_array_factory): img_array = img_array_factory(dtype=np.float32) - result_image = image_array_to_image(img_array) + result_image = image_array_to_pil_image(img_array) assert isinstance(result_image, Image.Image) assert result_image.size == (100, 100) assert result_image.mode == "RGB" assert np.array(result_image).dtype == np.uint8 -def test_image_array_to_image_out_of_bounds_float(): - # Float array with values out of [0, 1] - img_array = np.random.uniform(-1, 2, size=(100, 100, 3)).astype(np.float32) - result_image = image_array_to_image(img_array) +def test_image_array_to_pil_image_uint8_array(img_array_factory): + img_array = img_array_factory(dtype=np.float32) + result_image = image_array_to_pil_image(img_array) assert isinstance(result_image, Image.Image) assert result_image.size == (100, 100) assert result_image.mode == "RGB" assert np.array(result_image).dtype == np.uint8 - assert np.array(result_image).min() >= 0 and np.array(result_image).max() <= 255 def test_write_image_numpy(tmp_path, img_array_factory): diff --git a/tests/test_push_dataset_to_hub.py b/tests/test_push_dataset_to_hub.py deleted file mode 100644 index a0c8d908..00000000 --- a/tests/test_push_dataset_to_hub.py +++ /dev/null @@ -1,370 +0,0 @@ -""" -This file contains generic tests to ensure that nothing breaks if we modify the push_dataset_to_hub API. -Also, this file contains backward compatibility tests. Because they are slow and require to download the raw datasets, -we skip them for now in our CI. - -Example to run backward compatiblity tests locally: -``` -python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility -``` -""" - -from pathlib import Path - -import numpy as np -import pytest -import torch - -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.datasets.push_dataset_to_hub.utils import save_images_concurrently -from lerobot.common.datasets.video_utils import encode_video_frames -from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub -from tests.utils import require_package_arg - - -def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3): - import zarr - - raw_dir.mkdir(parents=True, exist_ok=True) - zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr" - store = zarr.DirectoryStore(zarr_path) - zarr_data = zarr.group(store=store) - - zarr_data.create_dataset( - "data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True - ) - zarr_data.create_dataset( - "data/img", - shape=(num_frames, 96, 96, 3), - chunks=(num_frames, 96, 96, 3), - dtype=np.uint8, - overwrite=True, - ) - zarr_data.create_dataset( - "data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True - ) - zarr_data.create_dataset( - "data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True - ) - zarr_data.create_dataset( - "data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True - ) - zarr_data.create_dataset( - "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True - ) - - zarr_data["data/action"][:] = np.random.randn(num_frames, 1) - zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8) - zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2) - zarr_data["data/state"][:] = np.random.randn(num_frames, 5) - zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2) - zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4]) - - store.close() - - -def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3): - import zarr - - raw_dir.mkdir(parents=True, exist_ok=True) - zarr_path = raw_dir / "cup_in_the_wild.zarr" - store = zarr.DirectoryStore(zarr_path) - zarr_data = zarr.group(store=store) - - zarr_data.create_dataset( - "data/camera0_rgb", - shape=(num_frames, 96, 96, 3), - chunks=(num_frames, 96, 96, 3), - dtype=np.uint8, - overwrite=True, - ) - zarr_data.create_dataset( - "data/robot0_demo_end_pose", - shape=(num_frames, 5), - chunks=(num_frames, 5), - dtype=np.float32, - overwrite=True, - ) - zarr_data.create_dataset( - "data/robot0_demo_start_pose", - shape=(num_frames, 5), - chunks=(num_frames, 5), - dtype=np.float32, - overwrite=True, - ) - zarr_data.create_dataset( - "data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True - ) - zarr_data.create_dataset( - "data/robot0_eef_rot_axis_angle", - shape=(num_frames, 5), - chunks=(num_frames, 5), - dtype=np.float32, - overwrite=True, - ) - zarr_data.create_dataset( - "data/robot0_gripper_width", - shape=(num_frames, 5), - chunks=(num_frames, 5), - dtype=np.float32, - overwrite=True, - ) - zarr_data.create_dataset( - "meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True - ) - - zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8) - zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5) - zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5) - zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5) - zarr_data["data/robot0_eef_rot_axis_angle"][:] = np.random.randn(num_frames, 5) - zarr_data["data/robot0_gripper_width"][:] = np.random.randn(num_frames, 5) - zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4]) - - store.close() - - -def _mock_download_raw_xarm(raw_dir, num_frames=4): - import pickle - - dataset_dict = { - "observations": { - "rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8), - "state": np.random.randn(num_frames, 4), - }, - "actions": np.random.randn(num_frames, 3), - "rewards": np.random.randn(num_frames), - "masks": np.random.randn(num_frames), - "dones": np.array([False, True, True, True]), - } - - raw_dir.mkdir(parents=True, exist_ok=True) - pkl_path = raw_dir / "buffer.pkl" - with open(pkl_path, "wb") as f: - pickle.dump(dataset_dict, f) - - -def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3): - import h5py - - for ep_idx in range(num_episodes): - raw_dir.mkdir(parents=True, exist_ok=True) - path_h5 = raw_dir / f"episode_{ep_idx}.hdf5" - with h5py.File(str(path_h5), "w") as f: - f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14)) - f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14)) - f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14)) - f.create_dataset( - "observations/images/top", - data=np.random.randint( - 0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8 - ), - ) - - -def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30): - from datetime import datetime, timedelta, timezone - - import pandas - - def write_parquet(key, timestamps, values): - data = { - "timestamp_utc": timestamps, - key: values, - } - df = pandas.DataFrame(data) - raw_dir.mkdir(parents=True, exist_ok=True) - df.to_parquet(raw_dir / f"{key}.parquet", engine="pyarrow") - - episode_indices = [None, None, -1, None, None, -1, None, None, -1] - episode_indices_mapping = [0, 0, 0, 1, 1, 1, 2, 2, 2] - frame_indices = [0, 1, -1, 0, 1, -1, 0, 1, -1] - - cam_key = "observation.images.cam_high" - timestamps = [] - actions = [] - states = [] - frames = [] - # `+ num_episodes`` for buffer frames associated to episode_index=-1 - for i, frame_idx in enumerate(frame_indices): - t_utc = datetime.now(timezone.utc) + timedelta(seconds=i / fps) - action = np.random.randn(21).tolist() - state = np.random.randn(21).tolist() - ep_idx = episode_indices_mapping[i] - frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}] - timestamps.append(t_utc) - actions.append(action) - states.append(state) - frames.append(frame) - - write_parquet(cam_key, timestamps, frames) - write_parquet("observation.state", timestamps, states) - write_parquet("action", timestamps, actions) - write_parquet("episode_index", timestamps, episode_indices) - - # write fake mp4 file for each episode - for ep_idx in range(num_episodes): - imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8) - - tmp_imgs_dir = raw_dir / "tmp_images" - save_images_concurrently(imgs_array, tmp_imgs_dir) - - fname = f"{cam_key}_episode_{ep_idx:06d}.mp4" - video_path = raw_dir / "videos" / fname - encode_video_frames(tmp_imgs_dir, video_path, fps, vcodec="libx264") - - -def _mock_download_raw(raw_dir, repo_id): - if "wrist_gripper" in repo_id: - _mock_download_raw_dora(raw_dir) - elif "aloha" in repo_id: - _mock_download_raw_aloha(raw_dir) - elif "pusht" in repo_id: - _mock_download_raw_pusht(raw_dir) - elif "xarm" in repo_id: - _mock_download_raw_xarm(raw_dir) - elif "umi" in repo_id: - _mock_download_raw_umi(raw_dir) - else: - raise ValueError(repo_id) - - -@pytest.mark.skip("push_dataset_to_hub is deprecated") -def test_push_dataset_to_hub_invalid_repo_id(tmpdir): - with pytest.raises(ValueError): - push_dataset_to_hub(Path(tmpdir), "raw_format", "invalid_repo_id") - - -@pytest.mark.skip("push_dataset_to_hub is deprecated") -def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir): - tmpdir = Path(tmpdir) - out_dir = tmpdir / "out" - raw_dir = tmpdir / "raw" - # mkdir to skip download - raw_dir.mkdir(parents=True, exist_ok=True) - with pytest.raises(ValueError): - push_dataset_to_hub( - raw_dir=raw_dir, - raw_format="some_format", - repo_id="user/dataset", - local_dir=out_dir, - force_override=False, - ) - - -@pytest.mark.skip("push_dataset_to_hub is deprecated") -@pytest.mark.parametrize( - "required_packages, raw_format, repo_id, make_test_data", - [ - (["gym_pusht"], "pusht_zarr", "lerobot/pusht", False), - (["gym_pusht"], "pusht_zarr", "lerobot/pusht", True), - (None, "xarm_pkl", "lerobot/xarm_lift_medium", False), - (None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted", False), - (["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild", False), - (None, "dora_parquet", "cadene/wrist_gripper", False), - ], -) -@require_package_arg -def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id, make_test_data): - num_episodes = 3 - tmpdir = Path(tmpdir) - - raw_dir = tmpdir / f"{repo_id}_raw" - _mock_download_raw(raw_dir, repo_id) - - local_dir = tmpdir / repo_id - - lerobot_dataset = push_dataset_to_hub( - raw_dir=raw_dir, - raw_format=raw_format, - repo_id=repo_id, - push_to_hub=False, - local_dir=local_dir, - force_override=False, - cache_dir=tmpdir / "cache", - tests_data_dir=tmpdir / "tests/data" if make_test_data else None, - encoding={"vcodec": "libx264"}, - ) - - # minimal generic tests on the local directory containing LeRobotDataset - assert (local_dir / "meta_data" / "info.json").exists() - assert (local_dir / "meta_data" / "stats.safetensors").exists() - assert (local_dir / "meta_data" / "episode_data_index.safetensors").exists() - for i in range(num_episodes): - for cam_key in lerobot_dataset.camera_keys: - assert (local_dir / "videos" / f"{cam_key}_episode_{i:06d}.mp4").exists() - assert (local_dir / "train" / "dataset_info.json").exists() - assert (local_dir / "train" / "state.json").exists() - assert len(list((local_dir / "train").glob("*.arrow"))) > 0 - - # minimal generic tests on the item - item = lerobot_dataset[0] - assert "index" in item - assert "episode_index" in item - assert "timestamp" in item - for cam_key in lerobot_dataset.camera_keys: - assert cam_key in item - - if make_test_data: - # Check that only the first episode is selected. - test_dataset = LeRobotDataset(repo_id=repo_id, root=tmpdir / "tests/data") - num_frames = sum( - i == lerobot_dataset.hf_dataset["episode_index"][0] - for i in lerobot_dataset.hf_dataset["episode_index"] - ).item() - assert ( - test_dataset.hf_dataset["episode_index"] - == lerobot_dataset.hf_dataset["episode_index"][:num_frames] - ) - for k in ["from", "to"]: - assert torch.equal(test_dataset.episode_data_index[k], lerobot_dataset.episode_data_index[k][:1]) - - -@pytest.mark.skip("push_dataset_to_hub is deprecated") -@pytest.mark.parametrize( - "raw_format, repo_id", - [ - # TODO(rcadene): add raw dataset test artifacts - ("pusht_zarr", "lerobot/pusht"), - ("xarm_pkl", "lerobot/xarm_lift_medium"), - ("aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"), - ("umi_zarr", "lerobot/umi_cup_in_the_wild"), - ("dora_parquet", "cadene/wrist_gripper"), - ], -) -def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id): - _, dataset_id = repo_id.split("/") - - tmpdir = Path(tmpdir) - raw_dir = tmpdir / f"{dataset_id}_raw" - local_dir = tmpdir / repo_id - - push_dataset_to_hub( - raw_dir=raw_dir, - raw_format=raw_format, - repo_id=repo_id, - push_to_hub=False, - local_dir=local_dir, - force_override=False, - cache_dir=tmpdir / "cache", - episodes=[0], - ) - - ds_actual = LeRobotDataset(repo_id, root=tmpdir) - ds_reference = LeRobotDataset(repo_id) - - assert len(ds_reference.hf_dataset) == len(ds_actual.hf_dataset) - - def check_same_items(item1, item2): - assert item1.keys() == item2.keys(), "Keys mismatch" - - for key in item1: - if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor): - assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}" - else: - assert item1[key] == item2[key], f"Mismatch found in key: {key}" - - for i in range(len(ds_reference.hf_dataset)): - item_reference = ds_reference.hf_dataset[i] - item_actual = ds_actual.hf_dataset[i] - check_same_items(item_reference, item_actual) diff --git a/tests/test_robots.py b/tests/test_robots.py index e03b5f78..6c300b71 100644 --- a/tests/test_robots.py +++ b/tests/test_robots.py @@ -23,8 +23,6 @@ pytest -sx 'tests/test_robots.py::test_robot[aloha-True]' ``` """ -from pathlib import Path - import pytest import torch @@ -35,7 +33,7 @@ from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot @pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) @require_robot -def test_robot(tmpdir, request, robot_type, mock): +def test_robot(tmp_path, request, robot_type, mock): # TODO(rcadene): measure fps in nightly? # TODO(rcadene): test logs # TODO(rcadene): add compatibility with other robots @@ -50,8 +48,7 @@ def test_robot(tmpdir, request, robot_type, mock): request.getfixturevalue("patch_builtins_input") # Create an empty calibration directory to trigger manual calibration - tmpdir = Path(tmpdir) - calibration_dir = tmpdir / robot_type + calibration_dir = tmp_path / robot_type mock_calibration_dir(calibration_dir) robot_kwargs["calibration_dir"] = calibration_dir