# CodeReview/backend/app/services/rag/splitter.py
"""
代码分块器 - 基于 Tree-sitter AST 的智能代码分块
使用先进的 Python 库实现专业级代码解析
"""
import re
import asyncio
import hashlib
import logging
import threading
import time
from typing import List, Dict, Any, Optional, Tuple, Set
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
logger = logging.getLogger(__name__)
# Cache the tiktoken encoder globally so it is not reloaded every time a CodeChunk is initialized
_TIKTOKEN_ENCODING = None
class ChunkType(Enum):
"""代码块类型"""
FILE = "file"
MODULE = "module"
CLASS = "class"
FUNCTION = "function"
METHOD = "method"
INTERFACE = "interface"
STRUCT = "struct"
ENUM = "enum"
IMPORT = "import"
CONSTANT = "constant"
CONFIG = "config"
COMMENT = "comment"
DECORATOR = "decorator"
UNKNOWN = "unknown"
@dataclass
class CodeChunk:
"""代码块"""
id: str
content: str
file_path: str
language: str
chunk_type: ChunkType
# Position info
line_start: int = 0
line_end: int = 0
byte_start: int = 0
byte_end: int = 0
# Semantic info
name: Optional[str] = None
parent_name: Optional[str] = None
signature: Optional[str] = None
docstring: Optional[str] = None
# AST info
ast_type: Optional[str] = None
# Relationship info
imports: List[str] = field(default_factory=list)
calls: List[str] = field(default_factory=list)
dependencies: List[str] = field(default_factory=list)
definitions: List[str] = field(default_factory=list)
# Security-related
security_indicators: List[str] = field(default_factory=list)
# Metadata
metadata: Dict[str, Any] = field(default_factory=dict)
# Token estimate
estimated_tokens: int = 0
def __post_init__(self):
if not self.id:
self.id = self._generate_id()
if not self.estimated_tokens:
self.estimated_tokens = self._estimate_tokens()
def _generate_id(self) -> str:
# Hash the full content to guarantee uniqueness
content = f"{self.file_path}:{self.line_start}:{self.line_end}:{self.content}"
return hashlib.sha256(content.encode()).hexdigest()[:16]
def _estimate_tokens(self) -> int:
# Use tiktoken when available
global _TIKTOKEN_ENCODING
try:
if _TIKTOKEN_ENCODING is None:
import tiktoken
_TIKTOKEN_ENCODING = tiktoken.get_encoding("cl100k_base")
return len(_TIKTOKEN_ENCODING.encode(self.content))
except ImportError:
return len(self.content) // 4
except Exception as e:
# Don't let encoding errors fail the whole task
logger.debug(f"Token estimation error: {e}")
return len(self.content) // 4
def to_dict(self) -> Dict[str, Any]:
result = {
"id": self.id,
"content": self.content,
"file_path": self.file_path,
"language": self.language,
"chunk_type": self.chunk_type.value,
"line_start": self.line_start,
"line_end": self.line_end,
"name": self.name,
"parent_name": self.parent_name,
"signature": self.signature,
"docstring": self.docstring,
"ast_type": self.ast_type,
"imports": self.imports,
"calls": self.calls,
"definitions": self.definitions,
"security_indicators": self.security_indicators,
"estimated_tokens": self.estimated_tokens,
}
# Promote metadata fields to the top level so fields such as file_hash can be retrieved correctly
if self.metadata:
for key, value in self.metadata.items():
if key not in result:
result[key] = value
return result
def to_embedding_text(self) -> str:
"""生成用于嵌入的文本"""
parts = []
parts.append(f"File: {self.file_path}")
if self.name:
parts.append(f"{self.chunk_type.value.title()}: {self.name}")
if self.parent_name:
parts.append(f"In: {self.parent_name}")
if self.signature:
parts.append(f"Signature: {self.signature}")
if self.docstring:
parts.append(f"Description: {self.docstring[:300]}")
parts.append(f"Code:\n{self.content}")
return "\n".join(parts)
class TreeSitterParser:
"""
基于 Tree-sitter 的代码解析器
提供 AST 级别的代码分析
"""
# Extension-to-language map
LANGUAGE_MAP = {
".py": "python",
".js": "javascript",
".jsx": "javascript",
".ts": "typescript",
".tsx": "tsx",
".java": "java",
".go": "go",
".rs": "rust",
".cpp": "cpp",
".c": "c",
".h": "c",
".hpp": "cpp",
".hxx": "cpp",
".cs": "csharp",
".php": "php",
".rb": "ruby",
".kt": "kotlin",
".ktm": "kotlin",
".kts": "kotlin",
".swift": "swift",
".dart": "dart",
".scala": "scala",
".sc": "scala",
".groovy": "groovy",
".lua": "lua",
".hs": "haskell",
".clj": "clojure",
".ex": "elixir",
".erl": "erlang",
".m": "objective-c",
".mm": "objective-c",
".sh": "bash",
".bash": "bash",
".zsh": "bash",
".sql": "sql",
".md": "markdown",
".markdown": "markdown",
}
# Function/class node types per language
DEFINITION_TYPES = {
"python": {
"class": ["class_definition"],
"function": ["function_definition"],
"method": ["function_definition"],
"import": ["import_statement", "import_from_statement"],
},
"javascript": {
"class": ["class_declaration", "class"],
"function": ["function_declaration", "function", "arrow_function", "method_definition"],
"import": ["import_statement"],
},
"typescript": {
"class": ["class_declaration", "class"],
"function": ["function_declaration", "function", "arrow_function", "method_definition"],
"interface": ["interface_declaration"],
"import": ["import_statement"],
},
"java": {
"class": ["class_declaration", "enum_declaration", "record_declaration"],
"method": ["method_declaration", "constructor_declaration"],
"interface": ["interface_declaration", "annotation_type_declaration"],
"import": ["import_declaration"],
},
"csharp": {
"class": ["class_declaration", "record_declaration", "struct_declaration", "enum_declaration"],
"method": ["method_declaration", "constructor_declaration", "destructor_declaration"],
"interface": ["interface_declaration"],
"import": ["using_directive"],
},
"cpp": {
"class": ["class_specifier", "struct_specifier", "enum_specifier"],
"function": ["function_definition"],
},
"go": {
"struct": ["type_declaration"],
"function": ["function_declaration", "method_declaration"],
"interface": ["type_declaration"],
"import": ["import_declaration"],
},
"rust": {
"struct": ["struct_item", "union_item"],
"enum": ["enum_item"],
"function": ["function_item"],
"class": ["impl_item", "trait_item"],
},
"php": {
"class": ["class_declaration"],
"function": ["function_definition", "method_definition"],
"interface": ["interface_declaration"],
},
"ruby": {
"class": ["class", "module"],
"function": ["method"],
},
"swift": {
"class": ["class_declaration", "struct_declaration", "enum_declaration"],
"function": ["function_declaration"],
"interface": ["protocol_declaration"],
},
}
# Languages supported by tree_sitter_language_pack
SUPPORTED_LANGUAGES = {
"python", "javascript", "typescript", "tsx", "java", "go", "rust",
"c", "cpp", "csharp", "php", "ruby", "kotlin", "swift", "bash",
"json", "yaml", "html", "css", "sql", "markdown", "dart", "scala",
"lua", "haskell", "clojure", "elixir", "erlang", "objective-c"
}
def __init__(self):
# Keep parsers in thread-local storage: Tree-sitter Parser objects are not
# thread-safe, and asyncio.to_thread runs work on different threads, so
# sharing one parser across threads can crash or hang
self._local = threading.local()
self._initialized = False
def _get_parsers(self) -> Dict[str, Any]:
"""获取当前线程的解析器字典"""
if not hasattr(self._local, 'parsers'):
self._local.parsers = {}
return self._local.parsers
def _ensure_initialized(self, language: str) -> bool:
"""确保语言解析器已初始化"""
parsers = self._get_parsers()
if language in parsers:
return True
# Check whether the language is supported
if language not in self.SUPPORTED_LANGUAGES:
# Not a language tree-sitter supports; skip silently
return False
try:
from tree_sitter_language_pack import get_parser
parser = get_parser(language)
parsers[language] = parser
return True
except ImportError:
logger.warning("tree-sitter-languages not installed, falling back to regex parsing")
return False
except Exception as e:
logger.warning(f"Failed to load tree-sitter parser for {language}: {e}")
return False
def parse(self, code: str, language: str) -> Optional[Any]:
"""解析代码返回 AST同步方法"""
if not self._ensure_initialized(language):
return None
parsers = self._get_parsers()
parser = parsers.get(language)
if not parser:
return None
try:
tree = parser.parse(code.encode())
return tree
except Exception as e:
logger.warning(f"Failed to parse code: {e}")
return None
async def parse_async(self, code: str, language: str) -> Optional[Any]:
"""
异步解析代码返回 AST
将 CPU 密集型的 Tree-sitter 解析操作放到线程池中执行,
避免阻塞事件循环
"""
return await asyncio.to_thread(self.parse, code, language)
def extract_definitions(self, tree: Any, code: str, language: str) -> List[Dict[str, Any]]:
"""从 AST 提取定义"""
if tree is None:
return []
definitions = []
definition_types = self.DEFINITION_TYPES.get(language, {})
def traverse(node, parent_name=None):
node_type = node.type
# Is this a definition node?
matched = False
for def_category, types in definition_types.items():
if node_type in types:
name = self._extract_name(node, language)
# Distinguish function vs. method by whether a parent_name exists
actual_category = def_category
if def_category == "function" and parent_name:
actual_category = "method"
elif def_category == "method" and not parent_name:
# Skip parent-less "method" definitions (handled by the "function" category)
continue
definitions.append({
"type": actual_category,
"name": name,
"parent_name": parent_name,
"start_point": node.start_point,
"end_point": node.end_point,
"start_byte": node.start_byte,
"end_byte": node.end_byte,
"node_type": node_type,
})
matched = True
# For classes, keep traversing children to find methods
if def_category == "class":
for child in node.children:
traverse(child, name)
return
# Stop after the first matching category
break
# No definition matched: keep traversing children
if not matched:
for child in node.children:
traverse(child, parent_name)
traverse(tree.root_node)
return definitions
def _extract_name(self, node: Any, language: str) -> Optional[str]:
"""从节点提取名称"""
# Look for an identifier-like child node
for child in node.children:
if child.type in ["identifier", "name", "type_identifier", "property_identifier"]:
return child.text.decode() if isinstance(child.text, bytes) else child.text
# Language-specific fallbacks
if language == "python":
for child in node.children:
if child.type == "name":
return child.text.decode() if isinstance(child.text, bytes) else child.text
return None
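# Usage sketch (assumes tree_sitter_language_pack is available; parse() returns
# None otherwise, and extract_definitions(None, ...) returns []):
#
#     parser = TreeSitterParser()
#     src = "class A:\n    def m(self):\n        pass\n"
#     tree = parser.parse(src, "python")
#     defs = parser.extract_definitions(tree, src, "python")
#     # e.g. [{"type": "class", "name": "A", ...},
#     #       {"type": "method", "name": "m", "parent_name": "A", ...}]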
class CodeSplitter:
"""
高级代码分块器
使用 Tree-sitter 进行 AST 解析,支持多种编程语言
"""
# Dangerous functions/patterns (used as security indicators)
SECURITY_PATTERNS = {
"python": [
(r"\bexec\s*\(", "exec"),
(r"\beval\s*\(", "eval"),
(r"\bcompile\s*\(", "compile"),
(r"\bos\.system\s*\(", "os_system"),
(r"\bsubprocess\.", "subprocess"),
(r"\bcursor\.execute\s*\(", "sql_execute"),
(r"\.execute\s*\(.*%", "sql_format"),
(r"\bpickle\.loads?\s*\(", "pickle"),
(r"\byaml\.load\s*\(", "yaml_load"),
(r"\brequests?\.", "http_request"),
(r"password\s*=", "password_assign"),
(r"secret\s*=", "secret_assign"),
(r"api_key\s*=", "api_key_assign"),
],
"javascript": [
(r"\beval\s*\(", "eval"),
(r"\bFunction\s*\(", "function_constructor"),
(r"innerHTML\s*=", "innerHTML"),
(r"outerHTML\s*=", "outerHTML"),
(r"document\.write\s*\(", "document_write"),
(r"\.exec\s*\(", "exec"),
(r"\.query\s*\(.*\+", "sql_concat"),
(r"password\s*[=:]", "password_assign"),
(r"apiKey\s*[=:]", "api_key_assign"),
],
"java": [
(r"Runtime\.getRuntime\(\)\.exec", "runtime_exec"),
(r"ProcessBuilder", "process_builder"),
(r"\.executeQuery\s*\(.*\+", "sql_concat"),
(r"ObjectInputStream", "deserialization"),
(r"XMLDecoder", "xml_decoder"),
(r"password\s*=", "password_assign"),
],
"go": [
(r"exec\.Command\s*\(", "exec_command"),
(r"\.Query\s*\(.*\+", "sql_concat"),
(r"\.Exec\s*\(.*\+", "sql_concat"),
(r"template\.HTML\s*\(", "unsafe_html"),
(r"password\s*=", "password_assign"),
],
"php": [
(r"\beval\s*\(", "eval"),
(r"\bexec\s*\(", "exec"),
(r"\bsystem\s*\(", "system"),
(r"\bshell_exec\s*\(", "shell_exec"),
(r"\$_GET\[", "get_input"),
(r"\$_POST\[", "post_input"),
(r"\$_REQUEST\[", "request_input"),
],
"csharp": [
(r"Process\.Start\s*\(", "process_start"),
(r"SqlCommand\s*\(.*\+", "sql_concat"),
(r"Deserialize\s*\(", "deserialization"),
(r"AllowHtml\s*=", "unsafe_html"),
(r"password\s*=", "password_assign"),
],
"cpp": [
(r"\bsystem\s*\(", "system_call"),
(r"\bpopen\s*\(", "popen"),
(r"\bstrcpy\s*\(", "unsafe_string_copy"),
(r"\bsprintf\s*\(", "unsafe_string_format"),
(r"\bmalloc\s*\(", "memory_allocation"),
],
"ruby": [
(r"\beval\s*\(", "eval"),
(r"`.*`", "shell_execution"),
(r"system\s*\(", "system_call"),
(r"send\s*\(", "dynamic_method_call"),
],
"rust": [
(r"\bunsafe\b", "unsafe_code"),
(r"Command::new", "process_execution"),
(r"File::create", "file_creation"),
(r"File::open", "file_access"),
],
"swift": [
(r"Process\(\)", "process_start"),
(r"try!", "force_try"),
(r"eval\s*\(", "eval"),
],
"kotlin": [
(r"Runtime\.getRuntime\(\)\.exec", "runtime_exec"),
(r"ProcessBuilder", "process_builder"),
(r"password\s*=", "password_assign"),
],
"sql": [
(r"DROP\s+TABLE", "drop_table"),
(r"DELETE\s+FROM", "delete_records"),
(r"GRANT\s+ALL", "grant_privileges"),
],
}
def __init__(
self,
max_chunk_size: int = 1500,
min_chunk_size: int = 100,
overlap_size: int = 50,
preserve_structure: bool = True,
use_tree_sitter: bool = True,
):
self.max_chunk_size = max_chunk_size
self.min_chunk_size = min_chunk_size
self.overlap_size = overlap_size
self.preserve_structure = preserve_structure
self.use_tree_sitter = use_tree_sitter
self._ts_parser = TreeSitterParser() if use_tree_sitter else None
def detect_language(self, file_path: str) -> str:
"""Detect the programming language from the file extension."""
ext = Path(file_path).suffix.lower()
# LANGUAGE_MAP already maps .cs to csharp, so no special-casing is needed
return TreeSitterParser.LANGUAGE_MAP.get(ext, "text")
def split_file(
self,
content: str,
file_path: str,
language: Optional[str] = None
) -> List[CodeChunk]:
"""
分割单个文件
"""
if not content or not content.strip():
return []
if language is None:
language = self.detect_language(file_path)
# Note when we start splitting a large file
file_size_kb = len(content) / 1024
if file_size_kb > 100:
logger.debug(f"Starting to split large file: {file_path} ({file_size_kb:.1f} KB, lang: {language})")
start_time = time.time()
chunks = []
try:
# Try Tree-sitter parsing first
if self.use_tree_sitter and self._ts_parser:
tree = self._ts_parser.parse(content, language)
if tree:
chunks = self._split_by_ast(content, file_path, language, tree)
# If AST parsing failed or produced nothing, fall back to enhanced regex parsing
if not chunks:
chunks = self._split_by_enhanced_regex(content, file_path, language)
# Still nothing: fall back to line-based chunking
if not chunks:
chunks = self._split_by_lines(content, file_path, language)
# Last line of defense: if the file is non-empty but produced no chunks
# (e.g. the content was too short and got filtered out), force a single
# file-level chunk so the file is represented in the index; otherwise
# incremental indexing keeps reporting it as "new"
if not chunks and content.strip():
chunks.append(CodeChunk(
id="",
content=content,
file_path=file_path,
language=language,
chunk_type=ChunkType.FILE,
line_start=1,
line_end=len(content.split('\n')),
))
# Post-processing: extract security indicators
for chunk in chunks:
chunk.security_indicators = self._extract_security_indicators(
chunk.content, language
)
# Post-processing: enrich with semantic analysis
self._enrich_chunks_with_semantics(chunks, content, language)
if file_size_kb > 100:
elapsed = time.time() - start_time
logger.debug(f"Finished splitting {file_path}: {len(chunks)} chunks, took {elapsed:.2f}s")
except Exception as e:
logger.warning(f"分块失败 {file_path}: {e}, 使用简单分块")
chunks = self._split_by_lines(content, file_path, language)
return chunks
async def split_file_async(
self,
content: str,
file_path: str,
language: Optional[str] = None
) -> List[CodeChunk]:
"""
异步分割单个文件
将 CPU 密集型的分块操作(包括 Tree-sitter 解析)放到线程池中执行,
避免阻塞事件循环。
Args:
content: 文件内容
file_path: 文件路径
language: 编程语言(可选)
Returns:
代码块列表
"""
return await asyncio.to_thread(self.split_file, content, file_path, language)
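# Usage sketch: from async code, prefer split_file_async so parsing does not
# block the event loop (the names below are hypothetical):
#
#     splitter = CodeSplitter()
#     chunks = await splitter.split_file_async(source_text, "service.py")
#     total_tokens = sum(c.estimated_tokens for c in chunks)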
def _split_by_ast(
self,
content: str,
file_path: str,
language: str,
tree: Any
) -> List[CodeChunk]:
"""基于 AST 分块"""
chunks = []
lines = content.split('\n')
# Extract definitions
definitions = self._ts_parser.extract_definitions(tree, content, language)
if not definitions:
return []
# Create a chunk for each definition
for defn in definitions:
start_line = defn["start_point"][0]
end_line = defn["end_point"][0]
# Extract the code content
chunk_lines = lines[start_line:end_line + 1]
chunk_content = '\n'.join(chunk_lines)
if len(chunk_content.strip()) < self.min_chunk_size // 4:
continue
chunk_type = {
"class": ChunkType.CLASS,
"function": ChunkType.FUNCTION,
"method": ChunkType.FUNCTION,
"interface": ChunkType.INTERFACE,
"struct": ChunkType.STRUCT,
"import": ChunkType.IMPORT,
}.get(defn["type"], ChunkType.MODULE)
chunk = CodeChunk(
id="",
content=chunk_content,
file_path=file_path,
language=language,
chunk_type=chunk_type,
line_start=start_line + 1,
line_end=end_line + 1,
byte_start=defn["start_byte"],
byte_end=defn["end_byte"],
name=defn.get("name"),
parent_name=defn.get("parent_name"),
ast_type=defn.get("node_type"),
)
# Split further if the chunk is too large
if chunk.estimated_tokens > self.max_chunk_size:
sub_chunks = self._split_large_chunk(chunk)
chunks.extend(sub_chunks)
else:
chunks.append(chunk)
return chunks
def _split_by_enhanced_regex(
self,
content: str,
file_path: str,
language: str
) -> List[CodeChunk]:
"""增强的正则表达式分块(支持更多语言)"""
chunks = []
lines = content.split('\n')
# Definition patterns per language
patterns = {
"python": [
(r"^(\s*)class\s+(\w+)(?:\s*\([^)]*\))?\s*:", ChunkType.CLASS),
(r"^(\s*)(?:async\s+)?def\s+(\w+)\s*\([^)]*\)\s*(?:->[^:]+)?:", ChunkType.FUNCTION),
],
"javascript": [
(r"^(\s*)(?:export\s+)?class\s+(\w+)", ChunkType.CLASS),
(r"^(\s*)(?:export\s+)?(?:async\s+)?function\s*(\w*)\s*\(", ChunkType.FUNCTION),
(r"^(\s*)(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?\([^)]*\)\s*=>", ChunkType.FUNCTION),
],
"typescript": [
(r"^(\s*)(?:export\s+)?(?:abstract\s+)?class\s+(\w+)", ChunkType.CLASS),
(r"^(\s*)(?:export\s+)?interface\s+(\w+)", ChunkType.INTERFACE),
(r"^(\s*)(?:export\s+)?(?:async\s+)?function\s*(\w*)", ChunkType.FUNCTION),
],
"java": [
(r"^(\s*)(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?class\s+(\w+)", ChunkType.CLASS),
(r"^(\s*)(?:public|private|protected)?\s*interface\s+(\w+)", ChunkType.INTERFACE),
(r"^(\s*)(?:public|private|protected)?\s*(?:static\s+)?[\w<>\[\],\s]+\s+(\w+)\s*\([^)]*\)\s*(?:throws\s+[\w,\s]+)?\s*\{", ChunkType.METHOD),
],
"go": [
(r"^type\s+(\w+)\s+struct\s*\{", ChunkType.STRUCT),
(r"^type\s+(\w+)\s+interface\s*\{", ChunkType.INTERFACE),
(r"^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\([^)]*\)", ChunkType.FUNCTION),
],
"php": [
(r"^(\s*)(?:abstract\s+)?class\s+(\w+)", ChunkType.CLASS),
(r"^(\s*)interface\s+(\w+)", ChunkType.INTERFACE),
(r"^(\s*)(?:public|private|protected)?\s*(?:static\s+)?function\s+(\w+)", ChunkType.FUNCTION),
],
"csharp": [
(r"^(\s*)(?:public|private|protected|internal)?\s*(?:static\s+)?(?:partial\s+)?(?:class|record|struct|enum)\s+(\w+)", ChunkType.CLASS),
(r"^(\s*)(?:public|private|protected|internal)?\s*interface\s+(\w+)", ChunkType.INTERFACE),
(r"^(\s*)(?:public|private|protected|internal)?\s*(?:async\s+)?(?:static\s+)?[\w<>\[\],\s]+\s+(\w+)\s*\([^)]*\)", ChunkType.METHOD),
],
"cpp": [
(r"^(\s*)(?:class|struct)\s+(\w+)", ChunkType.CLASS),
(r"^(\s*)[\w<>:]+\s+(\w+)\s*\([^)]*\)\s*\{", ChunkType.FUNCTION),
],
"ruby": [
(r"^(\s*)(?:class|module)\s+(\w+)", ChunkType.CLASS),
(r"^(\s*)def\s+(\w+)", ChunkType.FUNCTION),
],
"rust": [
(r"^(\s*)(?:pub\s+)?(?:struct|enum|union)\s+(\w+)", ChunkType.CLASS),
(r"^(\s*)(?:pub\s+)?(?:async\s+)?fn\s+(\w+)", ChunkType.FUNCTION),
(r"^(\s*)(?:pub\s+)?impl", ChunkType.CLASS),
(r"^(\s*)(?:pub\s+)?trait\s+(\w+)", ChunkType.INTERFACE),
],
}
lang_patterns = patterns.get(language, [])
if not lang_patterns:
return []
# Find the location of every definition
definitions = []
for i, line in enumerate(lines):
for pattern, chunk_type in lang_patterns:
match = re.match(pattern, line)
if match:
indent = len(match.group(1)) if match.lastindex >= 1 else 0
name = match.group(2) if match.lastindex >= 2 else None
definitions.append({
"line": i,
"indent": indent,
"name": name,
"type": chunk_type,
})
break
if not definitions:
return []
# Compute each definition's line range
for i, defn in enumerate(definitions):
start_line = defn["line"]
base_indent = defn["indent"]
# Find the end position
end_line = len(lines) - 1
for j in range(start_line + 1, len(lines)):
line = lines[j]
if line.strip():
current_indent = len(line) - len(line.lstrip())
# Indentation back at the base level: check whether this line starts the next definition
if current_indent <= base_indent:
is_next_def = any(d["line"] == j for d in definitions)
if is_next_def or (current_indent < base_indent):
end_line = j - 1
break
chunk_content = '\n'.join(lines[start_line:end_line + 1])
if len(chunk_content.strip()) < 10:
continue
chunk = CodeChunk(
id="",
content=chunk_content,
file_path=file_path,
language=language,
chunk_type=defn["type"],
line_start=start_line + 1,
line_end=end_line + 1,
name=defn.get("name"),
)
if chunk.estimated_tokens > self.max_chunk_size:
sub_chunks = self._split_large_chunk(chunk)
chunks.extend(sub_chunks)
else:
chunks.append(chunk)
return chunks
def _split_by_lines(
self,
content: str,
file_path: str,
language: str
) -> List[CodeChunk]:
"""基于行数分块(回退方案)"""
chunks = []
lines = content.split('\n')
# Estimate tokens per line
total_tokens = len(content) // 4
avg_tokens_per_line = max(1, total_tokens // max(1, len(lines)))
lines_per_chunk = max(10, self.max_chunk_size // avg_tokens_per_line)
overlap_lines = self.overlap_size // avg_tokens_per_line
# Guard against a non-positive stride (possible with unusual size settings)
step = max(1, lines_per_chunk - overlap_lines)
for i in range(0, len(lines), step):
end = min(i + lines_per_chunk, len(lines))
chunk_content = '\n'.join(lines[i:end])
if len(chunk_content.strip()) < 10:
continue
chunk = CodeChunk(
id="",
content=chunk_content,
file_path=file_path,
language=language,
chunk_type=ChunkType.MODULE,
line_start=i + 1,
line_end=end,
)
chunks.append(chunk)
if end >= len(lines):
break
return chunks
def _split_large_chunk(self, chunk: CodeChunk) -> List[CodeChunk]:
"""分割过大的代码块"""
sub_chunks = []
lines = chunk.content.split('\n')
avg_tokens_per_line = max(1, chunk.estimated_tokens // max(1, len(lines)))
lines_per_chunk = max(10, self.max_chunk_size // avg_tokens_per_line)
for i in range(0, len(lines), lines_per_chunk):
end = min(i + lines_per_chunk, len(lines))
sub_content = '\n'.join(lines[i:end])
if len(sub_content.strip()) < 10:
continue
sub_chunk = CodeChunk(
id="",
content=sub_content,
file_path=chunk.file_path,
language=chunk.language,
chunk_type=chunk.chunk_type,
line_start=chunk.line_start + i,
line_end=chunk.line_start + end - 1,
name=chunk.name,
parent_name=chunk.parent_name,
)
sub_chunks.append(sub_chunk)
return sub_chunks if sub_chunks else [chunk]
def _extract_security_indicators(self, content: str, language: str) -> List[str]:
"""提取安全相关指标"""
indicators = []
patterns = self.SECURITY_PATTERNS.get(language, [])
# Add common patterns
common_patterns = [
(r"password", "password"),
(r"secret", "secret"),
(r"api[_-]?key", "api_key"),
(r"token", "token"),
(r"private[_-]?key", "private_key"),
(r"credential", "credential"),
]
all_patterns = patterns + common_patterns
for pattern, name in all_patterns:
try:
if re.search(pattern, content, re.IGNORECASE):
if name not in indicators:
indicators.append(name)
except re.error:
continue
return indicators[:15]
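# Illustrative example: Python source containing "yaml.load(data)" and
# "password = pw" yields ["yaml_load", "password_assign", "password"]
# (results follow pattern order and are capped at 15 entries).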
def _enrich_chunks_with_semantics(
self,
chunks: List[CodeChunk],
full_content: str,
language: str
):
"""使用语义分析增强代码块"""
# Extract imports
imports = self._extract_imports(full_content, language)
for chunk in chunks:
# Attach the relevant imports
chunk.imports = self._filter_relevant_imports(imports, chunk.content)
# Extract function calls
chunk.calls = self._extract_function_calls(chunk.content, language)
# Extract definitions
chunk.definitions = self._extract_definitions(chunk.content, language)
def _extract_imports(self, content: str, language: str) -> List[str]:
"""提取导入语句"""
imports = []
patterns = {
"python": [
r"^import\s+([\w.]+)",
r"^from\s+([\w.]+)\s+import",
],
"javascript": [
r"^import\s+.*\s+from\s+['\"]([^'\"]+)['\"]",
r"require\s*\(['\"]([^'\"]+)['\"]\)",
],
"typescript": [
r"^import\s+.*\s+from\s+['\"]([^'\"]+)['\"]",
],
"java": [
r"^import\s+([\w.]+);",
],
"csharp": [
r"^using\s+([\w.]+);",
],
"go": [
r"['\"]([^'\"]+)['\"]",
],
"cpp": [
r'^#include\s+["<]([^">]+)[">]',
],
"php": [
r"^use\s+([\w\\]+);",
r"require(?:_once)?\s*\(['\"]([^'\"]+)['\"]\)",
],
"ruby": [
r"require\s+['\"]([^'\"]+)['\"]",
r"require_relative\s+['\"]([^'\"]+)['\"]",
],
}
for pattern in patterns.get(language, []):
matches = re.findall(pattern, content, re.MULTILINE)
imports.extend(matches)
return list(set(imports))
def _filter_relevant_imports(self, all_imports: List[str], chunk_content: str) -> List[str]:
"""过滤与代码块相关的导入"""
if not all_imports:
return []
# Take the last segment of each import path (usually the class or module name)
names_map = {}
for imp in all_imports:
name = imp.split('.')[-1]
if name:
if name not in names_map:
names_map[name] = []
names_map[name].append(imp)
if not names_map:
return []
# Build one combined regex matching every import name at once; far faster than looping re.search
try:
# Cap the number of names to keep the regex from growing too large
target_names = list(names_map.keys())[:200]
pattern = rf"\b({'|'.join(re.escape(name) for name in target_names)})\b"
matches = set(re.findall(pattern, chunk_content))
relevant = []
for m in matches:
relevant.extend(names_map[m])
return relevant[:20]
except Exception:
# Fall back to a simple substring search
relevant = []
for name, imps in names_map.items():
if name in chunk_content:
relevant.extend(imps)
if len(relevant) >= 20:
break
return relevant[:20]
def _extract_function_calls(self, content: str, language: str) -> List[str]:
"""提取函数调用"""
pattern = r'\b(\w+)\s*\('
matches = re.findall(pattern, content)
keywords = {
"python": {"if", "for", "while", "with", "def", "class", "return", "except", "print", "assert", "lambda"},
"javascript": {"if", "for", "while", "switch", "function", "return", "catch", "console", "async", "await"},
"java": {"if", "for", "while", "switch", "return", "catch", "throw", "new"},
"go": {"if", "for", "switch", "return", "func", "go", "defer"},
}
lang_keywords = keywords.get(language, set())
calls = [m for m in matches if m not in lang_keywords]
return list(set(calls))[:30]
def _extract_definitions(self, content: str, language: str) -> List[str]:
"""提取定义的标识符"""
definitions = []
patterns = {
"python": [
r"def\s+(\w+)\s*\(",
r"class\s+(\w+)",
r"(\w+)\s*=\s*",
],
"java": [
r"(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?(?:class|interface|enum|record)\s+(\w+)",
r"(?:public|private|protected)?\s*(?:static\s+)?[\w<>\[\],\s]+\s+(\w+)\s*\([^)]*\)",
],
"csharp": [
r"(?:class|record|struct|enum|interface)\s+(\w+)",
# Simple fallback:
# r"[^\s<>\[\],]+\s+(\w+)\s*\(",
# Extended version: supports primitive types, generics, nullable types (?),
# arrays ([,], [][]), tuple return types ((int, string)), and constructors
r"(?:(?:(?:[a-zA-Z_][\w.]*(?:<[\w\s,<>]+>)?\??|(?:\([\w\s,<>?]+\)))(?:\[[\s,]*\])*\s+)+)?(\w+)\s*(?:<[\w\s,<>]+>)?\s*\(",
],
"cpp": [
r"(?:class|struct)\s+(\w+)",
r"(?:[\w<>:]+)\s+(\w+)\s*\([^)]*\)\s*\{",
],
"rust": [
r"(?:struct|enum|union|trait)\s+(\w+)",
r"fn\s+(\w+)",
],
}
for pattern in patterns.get(language, []):
matches = re.findall(pattern, content)
definitions.extend(matches)
return list(set(definitions))[:20]
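# Minimal usage sketch / smoke test. This works even without
# tree_sitter_language_pack installed: the splitter then falls back to
# regex- and line-based chunking. The sample source below is hypothetical.
if __name__ == "__main__":
    _sample = (
        "import os\n\n"
        "class Greeter:\n"
        "    def greet(self, name):\n"
        "        return f'hello {name}'\n"
    )
    for _chunk in CodeSplitter().split_file(_sample, "example.py"):
        print(
            _chunk.chunk_type.value,
            _chunk.name,
            f"lines {_chunk.line_start}-{_chunk.line_end}",
            f"~{_chunk.estimated_tokens} tokens",
        )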