""" 代码分块器 - 基于 Tree-sitter AST 的智能代码分块 使用先进的 Python 库实现专业级代码解析 """ import re import asyncio import hashlib import logging from typing import List, Dict, Any, Optional, Tuple, Set from dataclasses import dataclass, field from enum import Enum from pathlib import Path logger = logging.getLogger(__name__) class ChunkType(Enum): """代码块类型""" FILE = "file" MODULE = "module" CLASS = "class" FUNCTION = "function" METHOD = "method" INTERFACE = "interface" STRUCT = "struct" ENUM = "enum" IMPORT = "import" CONSTANT = "constant" CONFIG = "config" COMMENT = "comment" DECORATOR = "decorator" UNKNOWN = "unknown" @dataclass class CodeChunk: """代码块""" id: str content: str file_path: str language: str chunk_type: ChunkType # 位置信息 line_start: int = 0 line_end: int = 0 byte_start: int = 0 byte_end: int = 0 # 语义信息 name: Optional[str] = None parent_name: Optional[str] = None signature: Optional[str] = None docstring: Optional[str] = None # AST 信息 ast_type: Optional[str] = None # 关联信息 imports: List[str] = field(default_factory=list) calls: List[str] = field(default_factory=list) dependencies: List[str] = field(default_factory=list) definitions: List[str] = field(default_factory=list) # 安全相关 security_indicators: List[str] = field(default_factory=list) # 元数据 metadata: Dict[str, Any] = field(default_factory=dict) # Token 估算 estimated_tokens: int = 0 def __post_init__(self): if not self.id: self.id = self._generate_id() if not self.estimated_tokens: self.estimated_tokens = self._estimate_tokens() def _generate_id(self) -> str: # 使用完整内容的 hash 确保唯一性 content = f"{self.file_path}:{self.line_start}:{self.line_end}:{self.content}" return hashlib.sha256(content.encode()).hexdigest()[:16] def _estimate_tokens(self) -> int: # 使用 tiktoken 如果可用 try: import tiktoken enc = tiktoken.get_encoding("cl100k_base") return len(enc.encode(self.content)) except ImportError: return len(self.content) // 4 def to_dict(self) -> Dict[str, Any]: result = { "id": self.id, "content": self.content, "file_path": self.file_path, "language": self.language, "chunk_type": self.chunk_type.value, "line_start": self.line_start, "line_end": self.line_end, "name": self.name, "parent_name": self.parent_name, "signature": self.signature, "docstring": self.docstring, "ast_type": self.ast_type, "imports": self.imports, "calls": self.calls, "definitions": self.definitions, "security_indicators": self.security_indicators, "estimated_tokens": self.estimated_tokens, } # 将 metadata 中的字段提升到顶级,确保 file_hash 等字段可以被正确检索 if self.metadata: for key, value in self.metadata.items(): if key not in result: result[key] = value return result def to_embedding_text(self) -> str: """生成用于嵌入的文本""" parts = [] parts.append(f"File: {self.file_path}") if self.name: parts.append(f"{self.chunk_type.value.title()}: {self.name}") if self.parent_name: parts.append(f"In: {self.parent_name}") if self.signature: parts.append(f"Signature: {self.signature}") if self.docstring: parts.append(f"Description: {self.docstring[:300]}") parts.append(f"Code:\n{self.content}") return "\n".join(parts) class TreeSitterParser: """ 基于 Tree-sitter 的代码解析器 提供 AST 级别的代码分析 """ # 语言映射 LANGUAGE_MAP = { ".py": "python", ".js": "javascript", ".jsx": "javascript", ".ts": "typescript", ".tsx": "tsx", ".java": "java", ".go": "go", ".rs": "rust", ".cpp": "cpp", ".c": "c", ".h": "c", ".hpp": "cpp", ".hxx": "cpp", ".cs": "csharp", ".php": "php", ".rb": "ruby", ".kt": "kotlin", ".ktm": "kotlin", ".kts": "kotlin", ".swift": "swift", ".dart": "dart", ".scala": "scala", ".sc": "scala", ".groovy": "groovy", ".lua": "lua", ".hs": "haskell", ".clj": "clojure", ".ex": "elixir", ".erl": "erlang", ".m": "objective-c", ".mm": "objective-c", ".sh": "bash", ".bash": "bash", ".zsh": "bash", ".sql": "sql", ".md": "markdown", ".markdown": "markdown", } # 各语言的函数/类节点类型 DEFINITION_TYPES = { "python": { "class": ["class_definition"], "function": ["function_definition"], "method": ["function_definition"], "import": ["import_statement", "import_from_statement"], }, "javascript": { "class": ["class_declaration", "class"], "function": ["function_declaration", "function", "arrow_function", "method_definition"], "import": ["import_statement"], }, "typescript": { "class": ["class_declaration", "class"], "function": ["function_declaration", "function", "arrow_function", "method_definition"], "interface": ["interface_declaration"], "import": ["import_statement"], }, "java": { "class": ["class_declaration", "enum_declaration", "record_declaration"], "method": ["method_declaration", "constructor_declaration"], "interface": ["interface_declaration", "annotation_type_declaration"], "import": ["import_declaration"], }, "csharp": { "class": ["class_declaration", "record_declaration", "struct_declaration", "enum_declaration"], "method": ["method_declaration", "constructor_declaration", "destructor_declaration"], "interface": ["interface_declaration"], "import": ["using_directive"], }, "cpp": { "class": ["class_specifier", "struct_specifier", "enum_specifier"], "function": ["function_definition"], }, "go": { "struct": ["type_declaration"], "function": ["function_declaration", "method_declaration"], "interface": ["type_declaration"], "import": ["import_declaration"], }, "rust": { "struct": ["struct_item", "union_item"], "enum": ["enum_item"], "function": ["function_item"], "class": ["impl_item", "trait_item"], }, "php": { "class": ["class_declaration"], "function": ["function_definition", "method_definition"], "interface": ["interface_declaration"], }, "ruby": { "class": ["class", "module"], "function": ["method"], }, "swift": { "class": ["class_declaration", "struct_declaration", "enum_declaration"], "function": ["function_declaration"], "interface": ["protocol_declaration"], }, } # tree-sitter-languages 支持的语言列表 SUPPORTED_LANGUAGES = { "python", "javascript", "typescript", "tsx", "java", "go", "rust", "c", "cpp", "csharp", "php", "ruby", "kotlin", "swift", "bash", "json", "yaml", "html", "css", "sql", "markdown", "dart", "scala", "lua", "haskell", "clojure", "elixir", "erlang", "objective-c" } def __init__(self): self._parsers: Dict[str, Any] = {} self._initialized = False def _ensure_initialized(self, language: str) -> bool: """确保语言解析器已初始化""" if language in self._parsers: return True # 检查语言是否受支持 if language not in self.SUPPORTED_LANGUAGES: # 不是 tree-sitter 支持的语言,静默跳过 return False try: from tree_sitter_language_pack import get_parser parser = get_parser(language) self._parsers[language] = parser return True except ImportError: logger.warning("tree-sitter-languages not installed, falling back to regex parsing") return False except Exception as e: logger.warning(f"Failed to load tree-sitter parser for {language}: {e}") return False def parse(self, code: str, language: str) -> Optional[Any]: """解析代码返回 AST(同步方法)""" if not self._ensure_initialized(language): return None parser = self._parsers.get(language) if not parser: return None try: tree = parser.parse(code.encode()) return tree except Exception as e: logger.warning(f"Failed to parse code: {e}") return None async def parse_async(self, code: str, language: str) -> Optional[Any]: """ 异步解析代码返回 AST 将 CPU 密集型的 Tree-sitter 解析操作放到线程池中执行, 避免阻塞事件循环 """ return await asyncio.to_thread(self.parse, code, language) def extract_definitions(self, tree: Any, code: str, language: str) -> List[Dict[str, Any]]: """从 AST 提取定义""" if tree is None: return [] definitions = [] definition_types = self.DEFINITION_TYPES.get(language, {}) def traverse(node, parent_name=None): node_type = node.type # 检查是否是定义节点 matched = False for def_category, types in definition_types.items(): if node_type in types: name = self._extract_name(node, language) # 根据是否有 parent_name 来区分 function 和 method actual_category = def_category if def_category == "function" and parent_name: actual_category = "method" elif def_category == "method" and not parent_name: # 跳过没有 parent 的 method 定义(由 function 类别处理) continue definitions.append({ "type": actual_category, "name": name, "parent_name": parent_name, "start_point": node.start_point, "end_point": node.end_point, "start_byte": node.start_byte, "end_byte": node.end_byte, "node_type": node_type, }) matched = True # 对于类,继续遍历子节点找方法 if def_category == "class": for child in node.children: traverse(child, name) return # 匹配到一个类别后就不再匹配其他类别 break # 如果没有匹配到定义,继续遍历子节点 if not matched: for child in node.children: traverse(child, parent_name) traverse(tree.root_node) return definitions def _extract_name(self, node: Any, language: str) -> Optional[str]: """从节点提取名称""" # 查找 identifier 子节点 for child in node.children: if child.type in ["identifier", "name", "type_identifier", "property_identifier"]: return child.text.decode() if isinstance(child.text, bytes) else child.text # 对于某些语言的特殊处理 if language == "python": for child in node.children: if child.type == "name": return child.text.decode() if isinstance(child.text, bytes) else child.text return None class CodeSplitter: """ 高级代码分块器 使用 Tree-sitter 进行 AST 解析,支持多种编程语言 """ # 危险函数/模式(用于安全指标) SECURITY_PATTERNS = { "python": [ (r"\bexec\s*\(", "exec"), (r"\beval\s*\(", "eval"), (r"\bcompile\s*\(", "compile"), (r"\bos\.system\s*\(", "os_system"), (r"\bsubprocess\.", "subprocess"), (r"\bcursor\.execute\s*\(", "sql_execute"), (r"\.execute\s*\(.*%", "sql_format"), (r"\bpickle\.loads?\s*\(", "pickle"), (r"\byaml\.load\s*\(", "yaml_load"), (r"\brequests?\.", "http_request"), (r"password\s*=", "password_assign"), (r"secret\s*=", "secret_assign"), (r"api_key\s*=", "api_key_assign"), ], "javascript": [ (r"\beval\s*\(", "eval"), (r"\bFunction\s*\(", "function_constructor"), (r"innerHTML\s*=", "innerHTML"), (r"outerHTML\s*=", "outerHTML"), (r"document\.write\s*\(", "document_write"), (r"\.exec\s*\(", "exec"), (r"\.query\s*\(.*\+", "sql_concat"), (r"password\s*[=:]", "password_assign"), (r"apiKey\s*[=:]", "api_key_assign"), ], "java": [ (r"Runtime\.getRuntime\(\)\.exec", "runtime_exec"), (r"ProcessBuilder", "process_builder"), (r"\.executeQuery\s*\(.*\+", "sql_concat"), (r"ObjectInputStream", "deserialization"), (r"XMLDecoder", "xml_decoder"), (r"password\s*=", "password_assign"), ], "go": [ (r"exec\.Command\s*\(", "exec_command"), (r"\.Query\s*\(.*\+", "sql_concat"), (r"\.Exec\s*\(.*\+", "sql_concat"), (r"template\.HTML\s*\(", "unsafe_html"), (r"password\s*=", "password_assign"), ], "php": [ (r"\beval\s*\(", "eval"), (r"\bexec\s*\(", "exec"), (r"\bsystem\s*\(", "system"), (r"\bshell_exec\s*\(", "shell_exec"), (r"\$_GET\[", "get_input"), (r"\$_POST\[", "post_input"), (r"\$_REQUEST\[", "request_input"), ], "csharp": [ (r"Process\.Start\s*\(", "process_start"), (r"SqlCommand\s*\(.*\+", "sql_concat"), (r"Deserialize\s*\(", "deserialization"), (r"AllowHtml\s*=", "unsafe_html"), (r"password\s*=", "password_assign"), ], "cpp": [ (r"\bsystem\s*\(", "system_call"), (r"\bpopen\s*\(", "popen"), (r"\bstrcpy\s*\(", "unsafe_string_copy"), (r"\bsprintf\s*\(", "unsafe_string_format"), (r"\bmalloc\s*\(", "memory_allocation"), ], "ruby": [ (r"\beval\s*\(", "eval"), (r"`.*`", "shell_execution"), (r"system\s*\(", "system_call"), (r"send\s*\(", "dynamic_method_call"), ], "rust": [ (r"\bunsafe\b", "unsafe_code"), (r"Command::new", "process_execution"), (r"File::create", "file_creation"), (r"File::open", "file_access"), ], "swift": [ (r"Process\(\)", "process_start"), (r"try!", "force_try"), (r"eval\s*\(", "eval"), ], "kotlin": [ (r"Runtime\.getRuntime\(\)\.exec", "runtime_exec"), (r"ProcessBuilder", "process_builder"), (r"password\s*=", "password_assign"), ], "sql": [ (r"DROP\s+TABLE", "drop_table"), (r"DELETE\s+FROM", "delete_records"), (r"GRANT\s+ALL", "grant_privileges"), ], } def __init__( self, max_chunk_size: int = 1500, min_chunk_size: int = 100, overlap_size: int = 50, preserve_structure: bool = True, use_tree_sitter: bool = True, ): self.max_chunk_size = max_chunk_size self.min_chunk_size = min_chunk_size self.overlap_size = overlap_size self.preserve_structure = preserve_structure self.use_tree_sitter = use_tree_sitter self._ts_parser = TreeSitterParser() if use_tree_sitter else None def detect_language(self, file_path: str) -> str: """检测编程语言""" ext = Path(file_path).suffix.lower() return TreeSitterParser.LANGUAGE_MAP.get(ext, "text") def split_file( self, content: str, file_path: str, language: Optional[str] = None ) -> List[CodeChunk]: """ 分割单个文件 Args: content: 文件内容 file_path: 文件路径 language: 编程语言(可选) Returns: 代码块列表 """ if not content or not content.strip(): return [] if language is None: language = self.detect_language(file_path) chunks = [] try: # 尝试使用 Tree-sitter 解析 if self.use_tree_sitter and self._ts_parser: tree = self._ts_parser.parse(content, language) if tree: chunks = self._split_by_ast(content, file_path, language, tree) # 如果 AST 解析失败或没有结果,使用增强的正则解析 if not chunks: chunks = self._split_by_enhanced_regex(content, file_path, language) # 如果还是没有,使用基于行的分块 if not chunks: chunks = self._split_by_lines(content, file_path, language) # 🔥 最后一道防线:如果文件不为空但没有产生任何块(比如文件内容太短被过滤了) # 我们强制创建一个文件级别的块,以确保该文件在索引中“挂名”,避免增量索引一直提示它是“新增” if not chunks and content.strip(): chunks.append(CodeChunk( id="", content=content, file_path=file_path, language=language, chunk_type=ChunkType.FILE, line_start=1, line_end=len(content.split('\n')), )) # 后处理:提取安全指标 for chunk in chunks: chunk.security_indicators = self._extract_security_indicators( chunk.content, language ) # 后处理:使用语义分析增强 self._enrich_chunks_with_semantics(chunks, content, language) except Exception as e: logger.warning(f"分块失败 {file_path}: {e}, 使用简单分块") chunks = self._split_by_lines(content, file_path, language) return chunks async def split_file_async( self, content: str, file_path: str, language: Optional[str] = None ) -> List[CodeChunk]: """ 异步分割单个文件 将 CPU 密集型的分块操作(包括 Tree-sitter 解析)放到线程池中执行, 避免阻塞事件循环。 Args: content: 文件内容 file_path: 文件路径 language: 编程语言(可选) Returns: 代码块列表 """ return await asyncio.to_thread(self.split_file, content, file_path, language) def _split_by_ast( self, content: str, file_path: str, language: str, tree: Any ) -> List[CodeChunk]: """基于 AST 分块""" chunks = [] lines = content.split('\n') # 提取定义 definitions = self._ts_parser.extract_definitions(tree, content, language) if not definitions: return [] # 为每个定义创建代码块 for defn in definitions: start_line = defn["start_point"][0] end_line = defn["end_point"][0] # 提取代码内容 chunk_lines = lines[start_line:end_line + 1] chunk_content = '\n'.join(chunk_lines) if len(chunk_content.strip()) < self.min_chunk_size // 4: continue chunk_type = ChunkType.CLASS if defn["type"] == "class" else \ ChunkType.FUNCTION if defn["type"] in ["function", "method"] else \ ChunkType.INTERFACE if defn["type"] == "interface" else \ ChunkType.STRUCT if defn["type"] == "struct" else \ ChunkType.IMPORT if defn["type"] == "import" else \ ChunkType.MODULE chunk = CodeChunk( id="", content=chunk_content, file_path=file_path, language=language, chunk_type=chunk_type, line_start=start_line + 1, line_end=end_line + 1, byte_start=defn["start_byte"], byte_end=defn["end_byte"], name=defn.get("name"), parent_name=defn.get("parent_name"), ast_type=defn.get("node_type"), ) # 如果块太大,进一步分割 if chunk.estimated_tokens > self.max_chunk_size: sub_chunks = self._split_large_chunk(chunk) chunks.extend(sub_chunks) else: chunks.append(chunk) return chunks def _split_by_enhanced_regex( self, content: str, file_path: str, language: str ) -> List[CodeChunk]: """增强的正则表达式分块(支持更多语言)""" chunks = [] lines = content.split('\n') # 各语言的定义模式 patterns = { "python": [ (r"^(\s*)class\s+(\w+)(?:\s*\([^)]*\))?\s*:", ChunkType.CLASS), (r"^(\s*)(?:async\s+)?def\s+(\w+)\s*\([^)]*\)\s*(?:->[^:]+)?:", ChunkType.FUNCTION), ], "javascript": [ (r"^(\s*)(?:export\s+)?class\s+(\w+)", ChunkType.CLASS), (r"^(\s*)(?:export\s+)?(?:async\s+)?function\s*(\w*)\s*\(", ChunkType.FUNCTION), (r"^(\s*)(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?\([^)]*\)\s*=>", ChunkType.FUNCTION), ], "typescript": [ (r"^(\s*)(?:export\s+)?(?:abstract\s+)?class\s+(\w+)", ChunkType.CLASS), (r"^(\s*)(?:export\s+)?interface\s+(\w+)", ChunkType.INTERFACE), (r"^(\s*)(?:export\s+)?(?:async\s+)?function\s*(\w*)", ChunkType.FUNCTION), ], "java": [ (r"^(\s*)(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?class\s+(\w+)", ChunkType.CLASS), (r"^(\s*)(?:public|private|protected)?\s*interface\s+(\w+)", ChunkType.INTERFACE), (r"^(\s*)(?:public|private|protected)?\s*(?:static\s+)?[\w<>\[\],\s]+\s+(\w+)\s*\([^)]*\)\s*(?:throws\s+[\w,\s]+)?\s*\{", ChunkType.METHOD), ], "go": [ (r"^type\s+(\w+)\s+struct\s*\{", ChunkType.STRUCT), (r"^type\s+(\w+)\s+interface\s*\{", ChunkType.INTERFACE), (r"^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\([^)]*\)", ChunkType.FUNCTION), ], "php": [ (r"^(\s*)(?:abstract\s+)?class\s+(\w+)", ChunkType.CLASS), (r"^(\s*)interface\s+(\w+)", ChunkType.INTERFACE), (r"^(\s*)(?:public|private|protected)?\s*(?:static\s+)?function\s+(\w+)", ChunkType.FUNCTION), ], "csharp": [ (r"^(\s*)(?:public|private|protected|internal)?\s*(?:static\s+)?(?:partial\s+)?(?:class|record|struct|enum)\s+(\w+)", ChunkType.CLASS), (r"^(\s*)(?:public|private|protected|internal)?\s*interface\s+(\w+)", ChunkType.INTERFACE), (r"^(\s*)(?:public|private|protected|internal)?\s*(?:async\s+)?(?:static\s+)?[\w<>\[\],\s]+\s+(\w+)\s*\([^)]*\)", ChunkType.METHOD), ], "cpp": [ (r"^(\s*)(?:class|struct)\s+(\w+)", ChunkType.CLASS), (r"^(\s*)[\w<>:]+\s+(\w+)\s*\([^)]*\)\s*\{", ChunkType.FUNCTION), ], "ruby": [ (r"^(\s*)(?:class|module)\s+(\w+)", ChunkType.CLASS), (r"^(\s*)def\s+(\w+)", ChunkType.FUNCTION), ], "rust": [ (r"^(\s*)(?:pub\s+)?(?:struct|enum|union)\s+(\w+)", ChunkType.CLASS), (r"^(\s*)(?:pub\s+)?(?:async\s+)?fn\s+(\w+)", ChunkType.FUNCTION), (r"^(\s*)(?:pub\s+)?impl", ChunkType.CLASS), (r"^(\s*)(?:pub\s+)?trait\s+(\w+)", ChunkType.INTERFACE), ], } lang_patterns = patterns.get(language, []) if not lang_patterns: return [] # 找到所有定义的位置 definitions = [] for i, line in enumerate(lines): for pattern, chunk_type in lang_patterns: match = re.match(pattern, line) if match: indent = len(match.group(1)) if match.lastindex >= 1 else 0 name = match.group(2) if match.lastindex >= 2 else None definitions.append({ "line": i, "indent": indent, "name": name, "type": chunk_type, }) break if not definitions: return [] # 计算每个定义的范围 for i, defn in enumerate(definitions): start_line = defn["line"] base_indent = defn["indent"] # 查找结束位置 end_line = len(lines) - 1 for j in range(start_line + 1, len(lines)): line = lines[j] if line.strip(): current_indent = len(line) - len(line.lstrip()) # 如果缩进回到基础级别,检查是否是下一个定义 if current_indent <= base_indent: # 检查是否是下一个定义 is_next_def = any(d["line"] == j for d in definitions) if is_next_def or (current_indent < base_indent): end_line = j - 1 break chunk_content = '\n'.join(lines[start_line:end_line + 1]) if len(chunk_content.strip()) < 10: continue chunk = CodeChunk( id="", content=chunk_content, file_path=file_path, language=language, chunk_type=defn["type"], line_start=start_line + 1, line_end=end_line + 1, name=defn.get("name"), ) if chunk.estimated_tokens > self.max_chunk_size: sub_chunks = self._split_large_chunk(chunk) chunks.extend(sub_chunks) else: chunks.append(chunk) return chunks def _split_by_lines( self, content: str, file_path: str, language: str ) -> List[CodeChunk]: """基于行数分块(回退方案)""" chunks = [] lines = content.split('\n') # 估算每行 Token 数 total_tokens = len(content) // 4 avg_tokens_per_line = max(1, total_tokens // max(1, len(lines))) lines_per_chunk = max(10, self.max_chunk_size // avg_tokens_per_line) overlap_lines = self.overlap_size // avg_tokens_per_line for i in range(0, len(lines), lines_per_chunk - overlap_lines): end = min(i + lines_per_chunk, len(lines)) chunk_content = '\n'.join(lines[i:end]) if len(chunk_content.strip()) < 10: continue chunk = CodeChunk( id="", content=chunk_content, file_path=file_path, language=language, chunk_type=ChunkType.MODULE, line_start=i + 1, line_end=end, ) chunks.append(chunk) if end >= len(lines): break return chunks def _split_large_chunk(self, chunk: CodeChunk) -> List[CodeChunk]: """分割过大的代码块""" sub_chunks = [] lines = chunk.content.split('\n') avg_tokens_per_line = max(1, chunk.estimated_tokens // max(1, len(lines))) lines_per_chunk = max(10, self.max_chunk_size // avg_tokens_per_line) for i in range(0, len(lines), lines_per_chunk): end = min(i + lines_per_chunk, len(lines)) sub_content = '\n'.join(lines[i:end]) if len(sub_content.strip()) < 10: continue sub_chunk = CodeChunk( id="", content=sub_content, file_path=chunk.file_path, language=chunk.language, chunk_type=chunk.chunk_type, line_start=chunk.line_start + i, line_end=chunk.line_start + end - 1, name=chunk.name, parent_name=chunk.parent_name, ) sub_chunks.append(sub_chunk) return sub_chunks if sub_chunks else [chunk] def _extract_security_indicators(self, content: str, language: str) -> List[str]: """提取安全相关指标""" indicators = [] patterns = self.SECURITY_PATTERNS.get(language, []) # 添加通用模式 common_patterns = [ (r"password", "password"), (r"secret", "secret"), (r"api[_-]?key", "api_key"), (r"token", "token"), (r"private[_-]?key", "private_key"), (r"credential", "credential"), ] all_patterns = patterns + common_patterns for pattern, name in all_patterns: try: if re.search(pattern, content, re.IGNORECASE): if name not in indicators: indicators.append(name) except re.error: continue return indicators[:15] def _enrich_chunks_with_semantics( self, chunks: List[CodeChunk], full_content: str, language: str ): """使用语义分析增强代码块""" # 提取导入 imports = self._extract_imports(full_content, language) for chunk in chunks: # 添加相关导入 chunk.imports = self._filter_relevant_imports(imports, chunk.content) # 提取函数调用 chunk.calls = self._extract_function_calls(chunk.content, language) # 提取定义 chunk.definitions = self._extract_definitions(chunk.content, language) def _extract_imports(self, content: str, language: str) -> List[str]: """提取导入语句""" imports = [] patterns = { "python": [ r"^import\s+([\w.]+)", r"^from\s+([\w.]+)\s+import", ], "javascript": [ r"^import\s+.*\s+from\s+['\"]([^'\"]+)['\"]", r"require\s*\(['\"]([^'\"]+)['\"]\)", ], "typescript": [ r"^import\s+.*\s+from\s+['\"]([^'\"]+)['\"]", ], "java": [ r"^import\s+([\w.]+);", ], "csharp": [ r"^using\s+([\w.]+);", ], "go": [ r"['\"]([^'\"]+)['\"]", ], "cpp": [ r'^#include\s+["<]([^">]+)[">]', ], "php": [ r"^use\s+([\w\\]+);", r"require(?:_once)?\s*\(['\"]([^'\"]+)['\"]\)", ], "ruby": [ r"require\s+['\"]([^'\"]+)['\"]", r"require_relative\s+['\"]([^'\"]+)['\"]", ], } for pattern in patterns.get(language, []): matches = re.findall(pattern, content, re.MULTILINE) imports.extend(matches) return list(set(imports)) def _filter_relevant_imports(self, all_imports: List[str], chunk_content: str) -> List[str]: """过滤与代码块相关的导入""" relevant = [] for imp in all_imports: module_name = imp.split('.')[-1] if re.search(rf'\b{re.escape(module_name)}\b', chunk_content): relevant.append(imp) return relevant[:20] def _extract_function_calls(self, content: str, language: str) -> List[str]: """提取函数调用""" pattern = r'\b(\w+)\s*\(' matches = re.findall(pattern, content) keywords = { "python": {"if", "for", "while", "with", "def", "class", "return", "except", "print", "assert", "lambda"}, "javascript": {"if", "for", "while", "switch", "function", "return", "catch", "console", "async", "await"}, "java": {"if", "for", "while", "switch", "return", "catch", "throw", "new"}, "go": {"if", "for", "switch", "return", "func", "go", "defer"}, } lang_keywords = keywords.get(language, set()) calls = [m for m in matches if m not in lang_keywords] return list(set(calls))[:30] def _extract_definitions(self, content: str, language: str) -> List[str]: """提取定义的标识符""" definitions = [] patterns = { "python": [ r"def\s+(\w+)\s*\(", r"class\s+(\w+)", r"(\w+)\s*=\s*", ], "java": [ r"(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?(?:class|interface|enum|record)\s+(\w+)", r"(?:public|private|protected)?\s*(?:static\s+)?[\w<>\[\],\s]+\s+(\w+)\s*\([^)]*\)", ], "csharp": [ r"(?:class|record|struct|enum|interface)\s+(\w+)", r"[\w<>\[\],\s]+\s+(\w+)\s*\([^)]*\)", ], "cpp": [ r"(?:class|struct)\s+(\w+)", r"(?:[\w<>:]+)\s+(\w+)\s*\([^)]*\)\s*\{", ], "rust": [ r"(?:struct|enum|union|trait)\s+(\w+)", r"fn\s+(\w+)", ], } for pattern in patterns.get(language, []): matches = re.findall(pattern, content) definitions.extend(matches) return list(set(definitions))[:20]