Update merge.py

This commit fixes 8 linter warnings in the merge.py file, including:
1.Added contextlib import and used contextlib.suppress instead of the try-except-pass pattern
2.Removed unnecessary .keys() calls, using Pythonic way to iterate dictionaries directly
3.Renamed unused loop variables with underscore prefix (idx → _idx, dirs → _dirs, folder → _folder)
4. Combined nested if statements to improve code conciseness
These changes maintain the same functionality while improving code quality and readability to conform to the project's coding standards.
This commit is contained in:
zhipeng tang 2025-04-02 14:28:55 +08:00 committed by GitHub
parent a5d38c1ef5
commit 2e525a5a0a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 12 additions and 17 deletions

View File

@ -2,6 +2,7 @@ import json
import os import os
import shutil import shutil
import traceback import traceback
import contextlib
import numpy as np import numpy as np
import pandas as pd import pandas as pd
@ -43,10 +44,8 @@ def load_jsonl(file_path):
with open(file_path) as f: with open(file_path) as f:
for line in f: for line in f:
if line.strip(): if line.strip():
try: with contextlib.suppress(json.JSONDecodeError):
data.append(json.loads(line)) data.append(json.loads(line))
except json.JSONDecodeError:
pass
except Exception as e: except Exception as e:
print(f"Error loading {file_path} line by line: {e}") print(f"Error loading {file_path} line by line: {e}")
else: else:
@ -54,10 +53,8 @@ def load_jsonl(file_path):
with open(file_path) as f: with open(file_path) as f:
for line in f: for line in f:
if line.strip(): if line.strip():
try: with contextlib.suppress(json.JSONDecodeError):
data.append(json.loads(line)) data.append(json.loads(line))
except json.JSONDecodeError:
print(f"Warning: Could not parse line in {file_path}: {line[:100]}...")
return data return data
@ -97,7 +94,7 @@ def merge_stats(stats_list):
common_features = common_features.intersection(set(stats.keys())) common_features = common_features.intersection(set(stats.keys()))
# Process features in the order they appear in the first stats file # Process features in the order they appear in the first stats file
for feature in stats_list[0].keys(): for feature in stats_list[0]:
if feature not in common_features: if feature not in common_features:
continue continue
@ -606,7 +603,7 @@ def copy_data_files(
for feature in ["observation.state", "action"]: for feature in ["observation.state", "action"]:
if feature in df.columns: if feature in df.columns:
# 检查第一个非空值 (Check first non-null value) # 检查第一个非空值 (Check first non-null value)
for idx, value in enumerate(df[feature]): for _idx, value in enumerate(df[feature]):
if value is not None and isinstance(value, (list, np.ndarray)): if value is not None and isinstance(value, (list, np.ndarray)):
current_dim = len(value) current_dim = len(value)
if current_dim < max_dim: if current_dim < max_dim:
@ -704,7 +701,7 @@ def copy_data_files(
for feature in ["observation.state", "action"]: for feature in ["observation.state", "action"]:
if feature in df.columns: if feature in df.columns:
# 检查第一个非空值 (Check first non-null value) # 检查第一个非空值 (Check first non-null value)
for idx, value in enumerate(df[feature]): for _idx, value in enumerate(df[feature]):
if value is not None and isinstance(value, (list, np.ndarray)): if value is not None and isinstance(value, (list, np.ndarray)):
current_dim = len(value) current_dim = len(value)
if current_dim < max_dim: if current_dim < max_dim:
@ -997,7 +994,7 @@ def merge_datasets(
folder_dim = max_dim # 使用变量替代硬编码的18 folder_dim = max_dim # 使用变量替代硬编码的18
# Try to find a parquet file to determine dimensions # Try to find a parquet file to determine dimensions
for root, dirs, files in os.walk(folder): for root, _dirs, files in os.walk(folder):
for file in files: for file in files:
if file.endswith(".parquet"): if file.endswith(".parquet"):
try: try:
@ -1141,7 +1138,7 @@ def merge_datasets(
# Update merged stats with episode-specific stats if available # Update merged stats with episode-specific stats if available
if all_stats_data: if all_stats_data:
# For each feature in the stats # For each feature in the stats
for feature in merged_stats.keys(): for feature in merged_stats:
if feature in all_stats_data[0]: if feature in all_stats_data[0]:
# Recalculate statistics based on all episodes # Recalculate statistics based on all episodes
values = [stat[feature] for stat in all_stats_data if feature in stat] values = [stat[feature] for stat in all_stats_data if feature in stat]
@ -1262,16 +1259,14 @@ def merge_datasets(
if "features" in info: if "features" in info:
# Find the maximum dimension across all folders # Find the maximum dimension across all folders
actual_max_dim = max_dim # 使用变量替代硬编码的18 actual_max_dim = max_dim # 使用变量替代硬编码的18
for folder, dim in folder_dimensions.items(): for _folder, dim in folder_dimensions.items():
actual_max_dim = max(actual_max_dim, dim) actual_max_dim = max(actual_max_dim, dim)
# Update observation.state and action dimensions # Update observation.state and action dimensions
for feature_name in ["observation.state", "action"]: for feature_name in ["observation.state", "action"]:
if feature_name in info["features"]: if feature_name in info["features"] and "shape" in info["features"][feature_name]:
# Update shape to the maximum dimension info["features"][feature_name]["shape"] = [actual_max_dim]
if "shape" in info["features"][feature_name]: print(f"Updated {feature_name} shape to {actual_max_dim}")
info["features"][feature_name]["shape"] = [actual_max_dim]
print(f"Updated {feature_name} shape to {actual_max_dim}")
# 更新视频总数 (Update total videos) # 更新视频总数 (Update total videos)
info["total_videos"] = total_videos info["total_videos"] = total_videos