resolve indexing hangs and optimize code splitting performance
Build and Push CodeReview / build (push) Waiting to run Details

This commit is contained in:
vinland100 2026-01-09 16:12:55 +08:00
parent f86e556d0d
commit cd8fb49a56
1 changed file with 80 additions and 22 deletions

View File

@ -7,6 +7,8 @@ import re
import asyncio
import hashlib
import logging
import threading
import time
from typing import List, Dict, Any, Optional, Tuple, Set
from dataclasses import dataclass, field
from enum import Enum
@ -14,6 +16,9 @@ from pathlib import Path
logger = logging.getLogger(__name__)
# 全局缓存 tiktoken 编码器,避免在 CodeChunk 初始化时重复加载
_TIKTOKEN_ENCODING = None
class ChunkType(Enum):
"""代码块类型"""
@ -85,12 +90,18 @@ class CodeChunk:
def _estimate_tokens(self) -> int:
    """Estimate the token count of this chunk's content.

    Uses tiktoken's ``cl100k_base`` encoder when available, caching it in the
    module-level ``_TIKTOKEN_ENCODING`` so the encoder is loaded once instead
    of on every CodeChunk. Falls back to a rough chars/4 heuristic when
    tiktoken is not installed or encoding fails.
    """
    global _TIKTOKEN_ENCODING
    try:
        if _TIKTOKEN_ENCODING is None:
            import tiktoken
            _TIKTOKEN_ENCODING = tiktoken.get_encoding("cl100k_base")
        return len(_TIKTOKEN_ENCODING.encode(self.content))
    except ImportError:
        # tiktoken not available: ~4 characters per token is a common estimate.
        return len(self.content) // 4
    except Exception as e:
        # Never let an encoding error take down the whole task.
        logger.debug(f"Token estimation error: {e}")
        return len(self.content) // 4
def to_dict(self) -> Dict[str, Any]:
result = {
@ -254,12 +265,21 @@ class TreeSitterParser:
}
def __init__(self):
self._parsers: Dict[str, Any] = {}
# 使用线程本地存储来保存解析器,因为 Tree-sitter Parser 对象不是线程安全的
# asyncio.to_thread 会在不同的线程中执行,如果共享同一个 parser 会导致崩溃或挂起
self._local = threading.local()
self._initialized = False
def _get_parsers(self) -> Dict[str, Any]:
"""获取当前线程的解析器字典"""
if not hasattr(self._local, 'parsers'):
self._local.parsers = {}
return self._local.parsers
def _ensure_initialized(self, language: str) -> bool:
"""确保语言解析器已初始化"""
if language in self._parsers:
parsers = self._get_parsers()
if language in parsers:
return True
# 检查语言是否受支持
@ -271,7 +291,7 @@ class TreeSitterParser:
from tree_sitter_language_pack import get_parser
parser = get_parser(language)
self._parsers[language] = parser
parsers[language] = parser
return True
except ImportError:
@ -286,7 +306,8 @@ class TreeSitterParser:
if not self._ensure_initialized(language):
return None
parser = self._parsers.get(language)
parsers = self._get_parsers()
parser = parsers.get(language)
if not parser:
return None
@ -497,7 +518,11 @@ class CodeSplitter:
def detect_language(self, file_path: str) -> str:
    """Detect the programming language of a file from its extension.

    Returns a tree-sitter language name, or ``"text"`` for unknown
    extensions.
    """
    ext = Path(file_path).suffix.lower()
    # .cs must always resolve to csharp regardless of the LANGUAGE_MAP entry;
    # checking it first also skips the redundant map lookup for this case.
    if ext == ".cs":
        return "csharp"
    return TreeSitterParser.LANGUAGE_MAP.get(ext, "text")
def split_file(
self,
@ -507,14 +532,6 @@ class CodeSplitter:
) -> List[CodeChunk]:
"""
分割单个文件
Args:
content: 文件内容
file_path: 文件路径
language: 编程语言可选
Returns:
代码块列表
"""
if not content or not content.strip():
return []
@ -522,6 +539,12 @@ class CodeSplitter:
if language is None:
language = self.detect_language(file_path)
# 记录开始处理大型文件
file_size_kb = len(content) / 1024
if file_size_kb > 100:
logger.debug(f"Starting to split large file: {file_path} ({file_size_kb:.1f} KB, lang: {language})")
start_time = time.time()
chunks = []
try:
@ -561,6 +584,10 @@ class CodeSplitter:
# 后处理:使用语义分析增强
self._enrich_chunks_with_semantics(chunks, content, language)
if file_size_kb > 100:
elapsed = time.time() - start_time
logger.debug(f"Finished splitting {file_path}: {len(chunks)} chunks, took {elapsed:.2f}s")
except Exception as e:
logger.warning(f"分块失败 {file_path}: {e}, 使用简单分块")
chunks = self._split_by_lines(content, file_path, language)
@ -939,12 +966,40 @@ class CodeSplitter:
def _filter_relevant_imports(self, all_imports: List[str], chunk_content: str) -> List[str]:
"""过滤与代码块相关的导入"""
relevant = []
if not all_imports:
return []
# 提取导入的末尾部分(通常是类名或模块名)
names_map = {}
for imp in all_imports:
module_name = imp.split('.')[-1]
if re.search(rf'\b{re.escape(module_name)}\b', chunk_content):
relevant.append(imp)
return relevant[:20]
name = imp.split('.')[-1]
if name:
if name not in names_map:
names_map[name] = []
names_map[name].append(imp)
if not names_map:
return []
# 构建一个综合正则,一次性匹配所有导入名称,效率远高于循环 re.search
try:
# 限制名称数量,避免正则过大
target_names = list(names_map.keys())[:200]
pattern = rf"\b({'|'.join(re.escape(name) for name in target_names)})\b"
matches = set(re.findall(pattern, chunk_content))
relevant = []
for m in matches:
relevant.extend(names_map[m])
return relevant[:20]
except Exception:
# 回退到简单搜索
relevant = []
for name, imps in names_map.items():
if name in chunk_content:
relevant.extend(imps)
if len(relevant) >= 20: break
return relevant[:20]
def _extract_function_calls(self, content: str, language: str) -> List[str]:
"""提取函数调用"""
@ -979,7 +1034,10 @@ class CodeSplitter:
],
"csharp": [
r"(?:class|record|struct|enum|interface)\s+(\w+)",
r"[\w<>\[\],\s]+\s+(\w+)\s*\([^)]*\)",
# 备用
# r"[^\s<>\[\],]+\s+(\w+)\s*\(",
# 究极版:支持基础类型、泛型、可空类型 (?)、数组 ([,] [][])、元组返回类型 ((int, string)) 以及构造函数
r"(?:(?:(?:[a-zA-Z_][\w.]*(?:<[\w\s,<>]+>)?\??|(?:\([\w\s,<>?]+\)))(?:\[[\s,]*\])*\s+)+)?(\w+)\s*(?:<[\w\s,<>]+>)?\s*\(",
],
"cpp": [
r"(?:class|struct)\s+(\w+)",