[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2025-04-09 07:02:42 +00:00
parent dcd0f5c519
commit f97bcd30e2
3 changed files with 11 additions and 11 deletions

View File

@@ -38,7 +38,10 @@ from huggingface_hub import HfApi
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats, convert_stats_parallel
from lerobot.common.datasets.v21.convert_stats import (
    check_aggregate_stats,
    convert_stats_parallel,
)

V20 = "v2.0"
V21 = "v2.1"

View File

@@ -13,11 +13,10 @@
# limitations under the License.
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing import cpu_count
import numpy as np
from tqdm import tqdm
from multiprocessing import cpu_count
from concurrent.futures import ProcessPoolExecutor, as_completed
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@@ -56,7 +55,7 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int, is_parallel: bool
    if not is_parallel:
        dataset.meta.episodes_stats[ep_idx] = ep_stats
    return ep_stats, ep_idx
@@ -86,14 +85,12 @@ def convert_stats_parallel(dataset: LeRobotDataset, num_workers: int = 0):
    print("Computing episodes stats")
    total_episodes = dataset.meta.total_episodes
    futures = []
    max_workers = min(cpu_count(), num_workers)
    if num_workers > 0:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            for ep_idx in range(total_episodes):
                futures.append(
                    executor.submit(convert_episode_stats, dataset, ep_idx, True)
                )
                futures.append(executor.submit(convert_episode_stats, dataset, ep_idx, True))
            for future in tqdm(as_completed(futures), total=total_episodes, desc="Converting episodes stats"):
                ep_stats, ep_data = future.result()
                dataset.meta.episodes_stats[ep_idx] = ep_data
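
The hunk above fans per-episode stats computation out to a thread pool and gathers results with as_completed. A minimal, self-contained sketch of that pattern follows; compute_episode_stats and compute_all_stats are hypothetical stand-ins for convert_episode_stats and convert_stats_parallel, and each result is stored under the index returned by its future so that completion order does not matter.

# Illustrative sketch, not the repository's code.
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing import cpu_count


def compute_episode_stats(ep_idx: int) -> tuple[int, dict]:
    # Placeholder for per-episode work; stands in for convert_episode_stats.
    return ep_idx, {"num_frames": ep_idx * 10}


def compute_all_stats(total_episodes: int, num_workers: int = 4) -> dict[int, dict]:
    stats: dict[int, dict] = {}
    max_workers = min(cpu_count(), num_workers)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(compute_episode_stats, i) for i in range(total_episodes)]
        for future in as_completed(futures):
            # Use the index returned by the worker, not the submission loop variable.
            ep_idx, ep_stats = future.result()
            stats[ep_idx] = ep_stats
    return stats
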
@@ -103,7 +100,7 @@ def convert_stats_parallel(dataset: LeRobotDataset, num_workers: int = 0):
    for ep_idx in tqdm(range(total_episodes)):
        write_episode_stats(ep_idx, dataset.meta.episodes_stats[ep_idx], dataset.root)


def check_aggregate_stats(
    dataset: LeRobotDataset,

View File

@@ -257,11 +257,11 @@ def decode_video_frames_decord(
    frame_ts: np.ndarray = vr.get_frame_timestamp(range(num_frames))
    indices = np.abs(frame_ts[:, :1] - timestamps).argmin(axis=0)
    frames = vr.get_batch(indices)
    frames_tensor = torch.tensor(frames.asnumpy()).type(torch.float32).permute(0, 3, 1, 2) / 255
    return frames_tensor
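
decode_video_frames_decord above matches each requested timestamp to the decoded frame whose start time is nearest, then normalizes the batch to a float tensor in (batch, channel, height, width) order. A rough sketch of that lookup, assuming decord and torch are available; decode_frames_near_timestamps is a hypothetical helper, not the repository's or decord's API.

# Illustrative sketch only.
import numpy as np
import torch
from decord import VideoReader, cpu


def decode_frames_near_timestamps(video_path: str, timestamps: list[float]) -> torch.Tensor:
    vr = VideoReader(video_path, ctx=cpu(0))
    # (num_frames, 2) array of per-frame start/end times in seconds.
    frame_ts = vr.get_frame_timestamp(range(len(vr)))
    # Broadcast (num_frames, 1) against (num_queries,) and keep the nearest start time.
    indices = np.abs(frame_ts[:, :1] - np.asarray(timestamps)).argmin(axis=0)
    frames = vr.get_batch(list(indices))
    # uint8 (batch, height, width, channel) -> float (batch, channel, height, width) in [0, 1].
    return torch.tensor(frames.asnumpy()).float().permute(0, 3, 1, 2) / 255
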
def encode_video_frames(
    imgs_dir: Path | str,
    video_path: Path | str,