From f97bcd30e2a9f859be30cd7e0f49ec666c62fb0e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Apr 2025 07:02:42 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../datasets/v21/convert_dataset_v20_to_v21.py | 5 ++++- lerobot/common/datasets/v21/convert_stats.py | 13 +++++-------- lerobot/common/datasets/video_utils.py | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py index be58188d..53cacbef 100644 --- a/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py +++ b/lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py @@ -38,7 +38,10 @@ from huggingface_hub import HfApi from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info -from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats, convert_stats_parallel +from lerobot.common.datasets.v21.convert_stats import ( + check_aggregate_stats, + convert_stats_parallel, +) V20 = "v2.0" V21 = "v2.1" diff --git a/lerobot/common/datasets/v21/convert_stats.py b/lerobot/common/datasets/v21/convert_stats.py index a3184e13..cbce5f89 100644 --- a/lerobot/common/datasets/v21/convert_stats.py +++ b/lerobot/common/datasets/v21/convert_stats.py @@ -13,11 +13,10 @@ # limitations under the License. from concurrent.futures import ThreadPoolExecutor, as_completed +from multiprocessing import cpu_count import numpy as np from tqdm import tqdm -from multiprocessing import cpu_count -from concurrent.futures import ProcessPoolExecutor, as_completed from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices from lerobot.common.datasets.lerobot_dataset import LeRobotDataset @@ -56,7 +55,7 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int, is_parallel: boo if not is_parallel: dataset.meta.episodes_stats[ep_idx] = ep_stats - + return ep_stats, ep_idx @@ -86,14 +85,12 @@ def convert_stats_parallel(dataset: LeRobotDataset, num_workers: int = 0): print("Computing episodes stats") total_episodes = dataset.meta.total_episodes futures = [] - + max_workers = min(cpu_count(), num_workers) if num_workers > 0: with ThreadPoolExecutor(max_workers=max_workers) as executor: for ep_idx in range(total_episodes): - futures.append( - executor.submit(convert_episode_stats, dataset, ep_idx, True) - ) + futures.append(executor.submit(convert_episode_stats, dataset, ep_idx, True)) for future in tqdm(as_completed(futures), total=total_episodes, desc="Converting episodes stats"): ep_stats, ep_data = future.result() dataset.meta.episodes_stats[ep_idx] = ep_data @@ -103,7 +100,7 @@ def convert_stats_parallel(dataset: LeRobotDataset, num_workers: int = 0): for ep_idx in tqdm(range(total_episodes)): write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root) - + def check_aggregate_stats( dataset: LeRobotDataset, diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 129a09b0..8d361fb8 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -257,11 +257,11 @@ def decode_video_frames_decord( frame_ts: np.ndarray = vr.get_frame_timestamp(range(num_frames)) indices = np.abs(frame_ts[:, :1] - timestamps).argmin(axis=0) frames = vr.get_batch(indices) - + frames_tensor = torch.tensor(frames.asnumpy()).type(torch.float32).permute(0, 3, 1, 2) / 255 return frames_tensor - + def encode_video_frames( imgs_dir: Path | str, video_path: Path | str,