diff --git a/backend/app/services/rag/indexer.py b/backend/app/services/rag/indexer.py index 173cdc3..6d1ea48 100644 --- a/backend/app/services/rag/indexer.py +++ b/backend/app/services/rag/indexer.py @@ -760,8 +760,8 @@ class CodeIndexer: # 从 embedding_service 获取配置 self.embedding_config = { - "provider": getattr(self.embedding_service, 'provider', 'openai'), - "model": getattr(self.embedding_service, 'model', 'text-embedding-3-small'), + "provider": getattr(self.embedding_service, 'provider', None) or 'openai', + "model": getattr(self.embedding_service, 'model', None) or 'text-embedding-3-small', "dimension": getattr(self.embedding_service, 'dimension', 1536), "base_url": getattr(self.embedding_service, 'base_url', None), } diff --git a/backend/app/services/rag/retriever.py b/backend/app/services/rag/retriever.py index 1400551..fcd018d 100644 --- a/backend/app/services/rag/retriever.py +++ b/backend/app/services/rag/retriever.py @@ -104,8 +104,8 @@ class CodeRetriever: """ self.collection_name = collection_name self._provided_embedding_service = embedding_service # 用户提供的 embedding 服务 - self.embedding_service = embedding_service # 实际使用的 embedding 服务 - self._api_key = api_key + self.embedding_service = embedding_service or EmbeddingService() # 实际使用的 embedding 服务 + self._api_key = api_key or (getattr(self.embedding_service, 'api_key', None) if self.embedding_service else None) # 创建向量存储 if vector_store: diff --git a/backend/app/services/rag/splitter.py b/backend/app/services/rag/splitter.py index b29af0d..cfc6002 100644 --- a/backend/app/services/rag/splitter.py +++ b/backend/app/services/rag/splitter.py @@ -178,6 +178,8 @@ class TreeSitterParser: ".bash": "bash", ".zsh": "bash", ".sql": "sql", + ".md": "markdown", + ".markdown": "markdown", } # 各语言的函数/类节点类型 @@ -537,6 +539,19 @@ class CodeSplitter: if not chunks: chunks = self._split_by_lines(content, file_path, language) + # 🔥 最后一道防线:如果文件不为空但没有产生任何块(比如文件内容太短被过滤了) + # 我们强制创建一个文件级别的块,以确保该文件在索引中“挂名”,避免增量索引一直提示它是“新增” + if not chunks and content.strip(): + chunks.append(CodeChunk( + id="", + content=content, + file_path=file_path, + language=language, + chunk_type=ChunkType.FILE, + line_start=1, + line_end=len(content.split('\n')), + )) + # 后处理:提取安全指标 for chunk in chunks: chunk.security_indicators = self._extract_security_indicators(