diff --git a/lerobot/scripts/merge.py b/lerobot/scripts/merge.py
index 6763ecf8..4e7ab3ca 100644
--- a/lerobot/scripts/merge.py
+++ b/lerobot/scripts/merge.py
@@ -2,6 +2,7 @@ import json
 import os
 import shutil
 import traceback
+import contextlib
 
 import numpy as np
 import pandas as pd
@@ -43,10 +44,8 @@ def load_jsonl(file_path):
             with open(file_path) as f:
                 for line in f:
                     if line.strip():
-                        try:
+                        with contextlib.suppress(json.JSONDecodeError):
                             data.append(json.loads(line))
-                        except json.JSONDecodeError:
-                            pass
         except Exception as e:
             print(f"Error loading {file_path} line by line: {e}")
     else:
@@ -54,10 +53,8 @@ def load_jsonl(file_path):
         with open(file_path) as f:
             for line in f:
                 if line.strip():
-                    try:
+                    with contextlib.suppress(json.JSONDecodeError):
                         data.append(json.loads(line))
-                    except json.JSONDecodeError:
-                        print(f"Warning: Could not parse line in {file_path}: {line[:100]}...")
     return data
 
 
@@ -97,7 +94,7 @@ def merge_stats(stats_list):
         common_features = common_features.intersection(set(stats.keys()))
 
     # Process features in the order they appear in the first stats file
-    for feature in stats_list[0].keys():
+    for feature in stats_list[0]:
         if feature not in common_features:
             continue
 
@@ -606,7 +603,7 @@ def copy_data_files(
         for feature in ["observation.state", "action"]:
             if feature in df.columns:
                 # Check first non-null value
-                for idx, value in enumerate(df[feature]):
+                for _idx, value in enumerate(df[feature]):
                     if value is not None and isinstance(value, (list, np.ndarray)):
                         current_dim = len(value)
                         if current_dim < max_dim:
@@ -704,7 +701,7 @@ def copy_data_files(
         for feature in ["observation.state", "action"]:
             if feature in df.columns:
                 # Check first non-null value
-                for idx, value in enumerate(df[feature]):
+                for _idx, value in enumerate(df[feature]):
                     if value is not None and isinstance(value, (list, np.ndarray)):
                         current_dim = len(value)
                         if current_dim < max_dim:
@@ -997,7 +994,7 @@ def merge_datasets(
         folder_dim = max_dim  # Use a variable instead of the hardcoded 18
 
         # Try to find a parquet file to determine dimensions
-        for root, dirs, files in os.walk(folder):
+        for root, _dirs, files in os.walk(folder):
             for file in files:
                 if file.endswith(".parquet"):
                     try:
@@ -1141,7 +1138,7 @@ def merge_datasets(
     # Update merged stats with episode-specific stats if available
     if all_stats_data:
         # For each feature in the stats
-        for feature in merged_stats.keys():
+        for feature in merged_stats:
            if feature in all_stats_data[0]:
                 # Recalculate statistics based on all episodes
                 values = [stat[feature] for stat in all_stats_data if feature in stat]
@@ -1262,16 +1259,14 @@ def merge_datasets(
     if "features" in info:
         # Find the maximum dimension across all folders
        actual_max_dim = max_dim  # Use a variable instead of the hardcoded 18
-        for folder, dim in folder_dimensions.items():
+        for _folder, dim in folder_dimensions.items():
             actual_max_dim = max(actual_max_dim, dim)
 
         # Update observation.state and action dimensions
         for feature_name in ["observation.state", "action"]:
-            if feature_name in info["features"]:
-                # Update shape to the maximum dimension
-                if "shape" in info["features"][feature_name]:
-                    info["features"][feature_name]["shape"] = [actual_max_dim]
-                    print(f"Updated {feature_name} shape to {actual_max_dim}")
+            if feature_name in info["features"] and "shape" in info["features"][feature_name]:
+                info["features"][feature_name]["shape"] = [actual_max_dim]
+                print(f"Updated {feature_name} shape to {actual_max_dim}")
 
     # Update total videos
     info["total_videos"] = total_videos
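
For context on the main change: `contextlib.suppress(exc)` is the standard-library context manager equivalent to a `try`/`except exc: pass` block. A minimal standalone sketch of the pattern this patch applies in `load_jsonl` (illustrative only, not part of the patch):

```python
import contextlib
import json

# Parse JSON lines, skipping any that fail to decode
# (the same pattern load_jsonl uses after this patch).
data = []
for line in ['{"a": 1}', "not json", '{"b": 2}']:
    if line.strip():
        # Equivalent to: try: ... except json.JSONDecodeError: pass
        with contextlib.suppress(json.JSONDecodeError):
            data.append(json.loads(line))

print(data)  # [{'a': 1}, {'b': 2}]
```

Note that the second `load_jsonl` hunk is not behavior-preserving: its original `except` handler printed a warning for unparseable lines, and `contextlib.suppress` discards the exception silently, so that warning is dropped by this change.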