diff --git a/backend/app/api/v1/endpoints/agent_tasks.py b/backend/app/api/v1/endpoints/agent_tasks.py
index 9806a23..bfe8b87 100644
--- a/backend/app/api/v1/endpoints/agent_tasks.py
+++ b/backend/app/api/v1/endpoints/agent_tasks.py
@@ -244,24 +244,40 @@ async def _execute_agent_task(task_id: str):
     sandbox_manager = SandboxManager()
     await sandbox_manager.initialize()
     logger.info(f"🐳 Global Sandbox Manager initialized (Available: {sandbox_manager.is_available})")
-
+
+    # 🔥 提前创建事件管理器,以便在克隆仓库和索引时发送实时日志
+    from app.services.agent.event_manager import EventManager, AgentEventEmitter
+    event_manager = EventManager(db_session_factory=async_session_factory)
+    event_manager.create_queue(task_id)
+    event_emitter = AgentEventEmitter(task_id, event_manager)
+    _running_event_managers[task_id] = event_manager
+
     async with async_session_factory() as db:
         orchestrator = None
         start_time = time.time()
-
+
         try:
            # 获取任务
            task = await db.get(AgentTask, task_id, options=[selectinload(AgentTask.project)])
            if not task:
                logger.error(f"Task {task_id} not found")
                return
-
+
            # 获取项目
            project = task.project
            if not project:
                logger.error(f"Project not found for task {task_id}")
                return
 
+            # 🔥 发送任务开始事件 - 使用 phase_start 让前端知道进入准备阶段
+            await event_emitter.emit_phase_start("preparation", f"🚀 任务开始执行: {project.name}")
+
+            # 更新任务阶段为准备中
+            task.status = AgentTaskStatus.RUNNING
+            task.started_at = datetime.now(timezone.utc)
+            task.current_phase = AgentTaskPhase.PLANNING  # preparation 对应 PLANNING
+            await db.commit()
+
            # 获取用户配置(需要在获取项目根目录之前,以便传递 token)
            user_config = await _get_user_config(db, task.created_by)
@@ -271,30 +287,23 @@ async def _execute_agent_task(task_id: str):
            gitlab_token = other_config.get('gitlabToken') or settings.GITLAB_TOKEN
 
            # 获取项目根目录(传递任务指定的分支和认证 token)
+            # 🔥 传递 event_emitter 以发送克隆进度
            project_root = await _get_project_root(
                project,
                task_id,
                task.branch_name,
                github_token=github_token,
                gitlab_token=gitlab_token,
+                event_emitter=event_emitter,  # 🔥 新增
            )
-
-            # 更新状态为运行中
-            task.status = AgentTaskStatus.RUNNING
-            task.started_at = datetime.now(timezone.utc)
-            task.current_phase = AgentTaskPhase.PLANNING
-            await db.commit()
+
            logger.info(f"🚀 Task {task_id} started with Dynamic Agent Tree architecture")
-
-            # 创建事件管理器
-            event_manager = EventManager(db_session_factory=async_session_factory)
-            event_manager.create_queue(task_id)
-            event_emitter = AgentEventEmitter(task_id, event_manager)
-
+
            # 创建 LLM 服务
            llm_service = LLMService(user_config=user_config)
-
+
            # 初始化工具集 - 传递排除模式和目标文件以及预初始化的 sandbox_manager
+            # 🔥 传递 event_emitter 以发送索引进度
            tools = await _initialize_tools(
                project_root,
                llm_service,
@@ -303,27 +312,28 @@ async def _execute_agent_task(task_id: str):
                exclude_patterns=task.exclude_patterns,
                target_files=task.target_files,
                project_id=str(project.id),  # 🔥 传递 project_id 用于 RAG
+                event_emitter=event_emitter,  # 🔥 新增
            )
-
+
            # 创建子 Agent
            recon_agent = ReconAgent(
                llm_service=llm_service,
                tools=tools.get("recon", {}),
                event_emitter=event_emitter,
            )
-
+
            analysis_agent = AnalysisAgent(
                llm_service=llm_service,
                tools=tools.get("analysis", {}),
                event_emitter=event_emitter,
            )
-
+
            verification_agent = VerificationAgent(
                llm_service=llm_service,
                tools=tools.get("verification", {}),
                event_emitter=event_emitter,
            )
-
+
            # 创建 Orchestrator Agent
            orchestrator = OrchestratorAgent(
                llm_service=llm_service,
@@ -335,7 +345,7 @@ async def _execute_agent_task(task_id: str):
                    "verification": verification_agent,
                },
            )
-
+
            # 注册到全局
            _running_orchestrators[task_id] = orchestrator
            _running_tasks[task_id] = orchestrator  # 兼容旧的取消逻辑
@@ -560,6 +570,7 @@ async def _initialize_tools(
    exclude_patterns: Optional[List[str]] = None,
    target_files: Optional[List[str]] = None,
    project_id: Optional[str] = None,  # 🔥 用于 RAG collection_name
+    event_emitter: Optional[Any] = None,  # 🔥 新增:用于发送实时日志
 ) -> Dict[str, Dict[str, Any]]:
    """初始化工具集
 
@@ -571,6 +582,7 @@ async def _initialize_tools(
        exclude_patterns: 排除模式列表
        target_files: 目标文件列表
        project_id: 项目 ID(用于 RAG collection_name)
+        event_emitter: 事件发送器(用于发送实时日志)
    """
    from app.services.agent.tools import (
        FileReadTool, FileSearchTool, ListFilesTool,
@@ -588,12 +600,27 @@ async def _initialize_tools(
        GetVulnerabilityKnowledgeTool,
    )
    # 🔥 RAG 相关导入
-    from app.services.rag import CodeIndexer, CodeRetriever, EmbeddingService
+    from app.services.rag import CodeIndexer, CodeRetriever, EmbeddingService, IndexUpdateMode
    from app.core.config import settings
+
+    # 辅助函数:发送事件
+    async def emit(message: str, level: str = "info"):
+        if event_emitter:
+            logger.debug(f"[EMIT-TOOLS] Sending {level}: {message[:60]}...")
+            if level == "info":
+                await event_emitter.emit_info(message)
+            elif level == "warning":
+                await event_emitter.emit_warning(message)
+            elif level == "error":
+                await event_emitter.emit_error(message)
+        else:
+            logger.warning(f"[EMIT-TOOLS] No event_emitter, skipping: {message[:60]}...")
+
    # ============ 🔥 初始化 RAG 系统 ============
    retriever = None
    try:
+        await emit(f"🔍 正在初始化 RAG 系统...")
+
        # 从用户配置中获取 embedding 配置
        user_llm_config = (user_config or {}).get('llmConfig', {})
        user_other_config = (user_config or {}).get('otherConfig', {})
@@ -631,6 +658,7 @@ async def _initialize_tools(
        )
 
        logger.info(f"RAG 配置: provider={embedding_provider}, model={embedding_model}, base_url={embedding_base_url or '(使用默认)'}")
+        await emit(f"📊 Embedding 配置: {embedding_provider}/{embedding_model}")
 
        # 创建 Embedding 服务
        embedding_service = EmbeddingService(
@@ -643,6 +671,47 @@ async def _initialize_tools(
        # 创建 collection_name(基于 project_id)
        collection_name = f"project_{project_id}" if project_id else "default_project"
 
+        # 🔥 v2.0: 创建 CodeIndexer 并进行智能索引
+        # 智能索引会自动:
+        # - 检测 embedding 模型变更,如需要则自动重建
+        # - 对比文件 hash,只更新变化的文件(增量更新)
+        indexer = CodeIndexer(
+            collection_name=collection_name,
+            embedding_service=embedding_service,
+            persist_directory=settings.VECTOR_DB_PATH,
+        )
+
+        logger.info(f"📝 开始智能索引项目: {project_root}")
+        await emit(f"📝 正在构建代码向量索引...")
+
+        index_progress = None
+        last_progress_update = 0
+        async for progress in indexer.smart_index_directory(
+            directory=project_root,
+            exclude_patterns=exclude_patterns or [],
+            update_mode=IndexUpdateMode.SMART,
+        ):
+            index_progress = progress
+            # 每处理 10 个文件或有重要变化时发送进度更新
+            if progress.processed_files - last_progress_update >= 10 or progress.processed_files == progress.total_files:
+                if progress.total_files > 0:
+                    await emit(
+                        f"📝 索引进度: {progress.processed_files}/{progress.total_files} 文件 "
+                        f"({progress.progress_percentage:.0f}%)"
+                    )
+                last_progress_update = progress.processed_files
+
+        if index_progress:
+            summary = (
+                f"✅ 索引完成: 模式={index_progress.update_mode}, "
+                f"新增={index_progress.added_files}, "
+                f"更新={index_progress.updated_files}, "
+                f"删除={index_progress.deleted_files}, "
+                f"代码块={index_progress.indexed_chunks}"
+            )
+            logger.info(summary)
+            await emit(summary)
+
        # 创建 CodeRetriever(用于搜索)
        # 🔥 传递 api_key,用于自动适配 collection 的 embedding 配置
        retriever = CodeRetriever(
@@ -653,9 +722,13 @@ async def _initialize_tools(
        )
 
        logger.info(f"✅ RAG 系统初始化成功: collection={collection_name}")
+        await emit(f"✅ RAG 系统初始化成功")
    except Exception as e:
        logger.warning(f"⚠️ RAG 系统初始化失败: {e}")
+        await emit(f"⚠️ RAG 系统初始化失败: 
{e}", "warning") + import traceback + logger.debug(f"RAG 初始化异常详情:\n{traceback.format_exc()}") retriever = None # 基础工具 - 传递排除模式和目标文件 @@ -1942,6 +2015,7 @@ async def _get_project_root( branch_name: Optional[str] = None, github_token: Optional[str] = None, gitlab_token: Optional[str] = None, + event_emitter: Optional[Any] = None, # 🔥 新增:用于发送实时日志 ) -> str: """ 获取项目根目录 @@ -1956,6 +2030,7 @@ async def _get_project_root( branch_name: 分支名称(仓库项目使用,优先于 project.default_branch) github_token: GitHub 访问令牌(用于私有仓库) gitlab_token: GitLab 访问令牌(用于私有仓库) + event_emitter: 事件发送器(用于发送实时日志) Returns: 项目根目录路径 @@ -1968,6 +2043,16 @@ async def _get_project_root( import shutil from urllib.parse import urlparse, urlunparse + # 辅助函数:发送事件 + async def emit(message: str, level: str = "info"): + if event_emitter: + if level == "info": + await event_emitter.emit_info(message) + elif level == "warning": + await event_emitter.emit_warning(message) + elif level == "error": + await event_emitter.emit_error(message) + base_path = f"/tmp/deepaudit/{task_id}" # 确保目录存在且为空 @@ -1978,6 +2063,7 @@ async def _get_project_root( # 根据项目类型处理 if project.source_type == "zip": # 🔥 ZIP 项目:解压 ZIP 文件 + await emit(f"📦 正在解压项目文件...") from app.services.zip_storage import load_project_zip zip_path = await load_project_zip(project.id) @@ -1987,11 +2073,14 @@ async def _get_project_root( with zipfile.ZipFile(zip_path, 'r') as zip_ref: zip_ref.extractall(base_path) logger.info(f"✅ Extracted ZIP project {project.id} to {base_path}") + await emit(f"✅ ZIP 文件解压完成") except Exception as e: logger.error(f"Failed to extract ZIP {zip_path}: {e}") + await emit(f"❌ 解压失败: {e}", "error") raise RuntimeError(f"无法解压项目文件: {e}") else: logger.warning(f"⚠️ ZIP file not found for project {project.id}") + await emit(f"❌ ZIP 文件不存在", "error") raise RuntimeError(f"项目 ZIP 文件不存在: {project.id}") elif project.source_type == "repository" and project.repository_url: @@ -1999,6 +2088,8 @@ async def _get_project_root( repo_url = project.repository_url repo_type = project.repository_type or "other" + await emit(f"🔄 正在克隆仓库: {repo_url}") + # 检查 git 是否可用(使用 git --version 更可靠) try: git_check = subprocess.run( @@ -2008,11 +2099,14 @@ async def _get_project_root( timeout=10 ) if git_check.returncode != 0: + await emit(f"❌ Git 未安装", "error") raise RuntimeError("Git 未安装,无法克隆仓库。请在 Docker 容器中安装 git。") logger.debug(f"Git version: {git_check.stdout.strip()}") except FileNotFoundError: + await emit(f"❌ Git 未安装", "error") raise RuntimeError("Git 未安装,无法克隆仓库。请在 Docker 容器中安装 git。") except subprocess.TimeoutExpired: + await emit(f"❌ Git 检测超时", "error") raise RuntimeError("Git 检测超时") # 构建带认证的 URL(用于私有仓库) @@ -2028,6 +2122,7 @@ async def _get_project_root( parsed.fragment )) logger.info(f"🔐 Using GitHub token for authentication") + await emit(f"🔐 使用 GitHub Token 认证") elif repo_type == "gitlab" and gitlab_token: parsed = urlparse(repo_url) auth_url = urlunparse(( @@ -2039,6 +2134,7 @@ async def _get_project_root( parsed.fragment )) logger.info(f"🔐 Using GitLab token for authentication") + await emit(f"🔐 使用 GitLab Token 认证") # 构建分支尝试顺序 branches_to_try = [] @@ -2061,6 +2157,7 @@ async def _get_project_root( os.makedirs(base_path, exist_ok=True) logger.info(f"🔄 Trying to clone repository (branch: {branch})...") + await emit(f"🔄 尝试克隆分支: {branch}") try: result = subprocess.run( ["git", "clone", "--depth", "1", "--branch", branch, auth_url, base_path], @@ -2071,18 +2168,22 @@ async def _get_project_root( if result.returncode == 0: logger.info(f"✅ Cloned repository {repo_url} (branch: {branch}) to {base_path}") + await 
emit(f"✅ 仓库克隆成功 (分支: {branch})") clone_success = True break else: last_error = result.stderr logger.warning(f"Failed to clone branch {branch}: {last_error[:200]}") + await emit(f"⚠️ 分支 {branch} 克隆失败,尝试其他分支...", "warning") except subprocess.TimeoutExpired: last_error = f"克隆分支 {branch} 超时" logger.warning(last_error) + await emit(f"⚠️ 分支 {branch} 克隆超时,尝试其他分支...", "warning") # 如果所有分支都失败,尝试不指定分支克隆(使用仓库默认分支) if not clone_success: logger.info(f"🔄 Trying to clone without specifying branch...") + await emit(f"🔄 尝试使用仓库默认分支克隆...") if os.path.exists(base_path) and os.listdir(base_path): shutil.rmtree(base_path) os.makedirs(base_path, exist_ok=True) @@ -2097,11 +2198,13 @@ async def _get_project_root( if result.returncode == 0: logger.info(f"✅ Cloned repository {repo_url} (default branch) to {base_path}") + await emit(f"✅ 仓库克隆成功 (默认分支)") clone_success = True else: last_error = result.stderr except subprocess.TimeoutExpired: last_error = "克隆仓库超时" + await emit(f"⚠️ 克隆超时", "warning") if not clone_success: # 分析错误原因 @@ -2118,12 +2221,15 @@ async def _get_project_root( error_msg = f"克隆仓库失败: {last_error[:200]}" logger.error(f"❌ {error_msg}") + await emit(f"❌ {error_msg}", "error") raise RuntimeError(error_msg) # 验证目录不为空 if not os.listdir(base_path): + await emit(f"❌ 项目目录为空", "error") raise RuntimeError(f"项目目录为空,可能是克隆/解压失败: {base_path}") + await emit(f"📁 项目准备完成: {base_path}") return base_path diff --git a/backend/app/services/agent/prompts/system_prompts.py b/backend/app/services/agent/prompts/system_prompts.py index 75c1c07..7e690e6 100644 --- a/backend/app/services/agent/prompts/system_prompts.py +++ b/backend/app/services/agent/prompts/system_prompts.py @@ -169,7 +169,7 @@ TOOL_USAGE_GUIDE = """ #### 第一步:快速侦察(5%时间) ``` Action: list_files -Action Input: {"path": "."} +Action Input: {"directory": "."} ``` 了解项目结构、技术栈、入口点 diff --git a/backend/app/services/agent/tools/file_tool.py b/backend/app/services/agent/tools/file_tool.py index 014b792..f8b1d49 100644 --- a/backend/app/services/agent/tools/file_tool.py +++ b/backend/app/services/agent/tools/file_tool.py @@ -502,6 +502,10 @@ class ListFilesTool(AgentTool): ) -> ToolResult: """执行文件列表""" try: + # 🔥 兼容性处理:支持 path 参数作为 directory 的别名 + if "path" in kwargs and kwargs["path"]: + directory = kwargs["path"] + target_dir = os.path.normpath(os.path.join(self.project_root, directory)) if not target_dir.startswith(os.path.normpath(self.project_root)): return ToolResult( diff --git a/backend/app/services/rag/__init__.py b/backend/app/services/rag/__init__.py index c6a031e..e49bdc4 100644 --- a/backend/app/services/rag/__init__.py +++ b/backend/app/services/rag/__init__.py @@ -1,11 +1,23 @@ """ RAG (Retrieval-Augmented Generation) 系统 用于代码索引和语义检索 + +🔥 v2.0 改进: +- 支持嵌入模型变更检测和自动重建 +- 支持增量索引更新(基于文件 hash) +- 支持索引版本控制和状态查询 """ from .splitter import CodeSplitter, CodeChunk from .embeddings import EmbeddingService -from .indexer import CodeIndexer +from .indexer import ( + CodeIndexer, + IndexingProgress, + IndexingResult, + IndexStatus, + IndexUpdateMode, + INDEX_VERSION, +) from .retriever import CodeRetriever __all__ = [ @@ -14,5 +26,10 @@ __all__ = [ "EmbeddingService", "CodeIndexer", "CodeRetriever", + "IndexingProgress", + "IndexingResult", + "IndexStatus", + "IndexUpdateMode", + "INDEX_VERSION", ] diff --git a/backend/app/services/rag/indexer.py b/backend/app/services/rag/indexer.py index 3cffa77..f66d02d 100644 --- a/backend/app/services/rag/indexer.py +++ b/backend/app/services/rag/indexer.py @@ -1,14 +1,22 @@ """ 代码索引器 将代码分块并索引到向量数据库 + +🔥 v2.0 改进: +- 支持嵌入模型变更检测和自动重建 +- 
支持增量索引更新(基于文件 hash) +- 支持索引版本控制和状态查询 """ import os import asyncio import logging -from typing import List, Dict, Any, Optional, AsyncGenerator, Callable +import hashlib +import time +from typing import List, Dict, Any, Optional, AsyncGenerator, Callable, Set, Tuple from pathlib import Path -from dataclasses import dataclass +from dataclasses import dataclass, field +from enum import Enum import json from .splitter import CodeSplitter, CodeChunk @@ -16,6 +24,9 @@ from .embeddings import EmbeddingService logger = logging.getLogger(__name__) +# 索引版本号(当索引格式变化时递增) +INDEX_VERSION = "2.0" + # 支持的文本文件扩展名 TEXT_EXTENSIONS = { @@ -40,6 +51,44 @@ EXCLUDE_FILES = { } +class IndexUpdateMode(Enum): + """索引更新模式""" + FULL = "full" # 全量重建:删除旧索引,完全重新索引 + INCREMENTAL = "incremental" # 增量更新:只更新变化的文件 + SMART = "smart" # 智能模式:根据情况自动选择 + + +@dataclass +class IndexStatus: + """索引状态信息""" + collection_name: str + exists: bool = False + index_version: str = "" + chunk_count: int = 0 + file_count: int = 0 + created_at: float = 0.0 + updated_at: float = 0.0 + embedding_provider: str = "" + embedding_model: str = "" + embedding_dimension: int = 0 + project_hash: str = "" + + def to_dict(self) -> Dict[str, Any]: + return { + "collection_name": self.collection_name, + "exists": self.exists, + "index_version": self.index_version, + "chunk_count": self.chunk_count, + "file_count": self.file_count, + "created_at": self.created_at, + "updated_at": self.updated_at, + "embedding_provider": self.embedding_provider, + "embedding_model": self.embedding_model, + "embedding_dimension": self.embedding_dimension, + "project_hash": self.project_hash, + } + + @dataclass class IndexingProgress: """索引进度""" @@ -49,11 +98,17 @@ class IndexingProgress: indexed_chunks: int = 0 current_file: str = "" errors: List[str] = None - + # 🔥 新增:增量更新统计 + added_files: int = 0 + updated_files: int = 0 + deleted_files: int = 0 + skipped_files: int = 0 + update_mode: str = "full" + def __post_init__(self): if self.errors is None: self.errors = [] - + @property def progress_percentage(self) -> float: if self.total_files == 0: @@ -74,11 +129,11 @@ class IndexingResult: class VectorStore: """向量存储抽象基类""" - + async def initialize(self): """初始化存储""" pass - + async def add_documents( self, ids: List[str], @@ -88,7 +143,25 @@ class VectorStore: ): """添加文档""" raise NotImplementedError - + + async def upsert_documents( + self, + ids: List[str], + embeddings: List[List[float]], + documents: List[str], + metadatas: List[Dict[str, Any]], + ): + """更新或插入文档""" + raise NotImplementedError + + async def delete_by_file_path(self, file_path: str) -> int: + """删除指定文件的所有文档,返回删除数量""" + raise NotImplementedError + + async def delete_by_ids(self, ids: List[str]) -> int: + """删除指定 ID 的文档""" + raise NotImplementedError + async def query( self, query_embedding: List[float], @@ -97,33 +170,58 @@ class VectorStore: ) -> Dict[str, Any]: """查询""" raise NotImplementedError - + async def delete_collection(self): """删除集合""" raise NotImplementedError - + async def get_count(self) -> int: """获取文档数量""" raise NotImplementedError + async def get_all_file_paths(self) -> Set[str]: + """获取所有已索引的文件路径""" + raise NotImplementedError + + async def get_file_hashes(self) -> Dict[str, str]: + """获取所有文件的 hash 映射 {file_path: hash}""" + raise NotImplementedError + + def get_collection_metadata(self) -> Dict[str, Any]: + """获取 collection 元数据""" + raise NotImplementedError + class ChromaVectorStore(VectorStore): - """Chroma 向量存储""" + """ + Chroma 向量存储 + + 🔥 v2.0 改进: + - 支持 embedding 配置变更检测 + - 
支持增量更新(upsert、delete) + - 支持文件级别的索引管理 + """ def __init__( self, collection_name: str, persist_directory: Optional[str] = None, - embedding_config: Optional[Dict[str, Any]] = None, # 🔥 新增:embedding 配置 + embedding_config: Optional[Dict[str, Any]] = None, ): self.collection_name = collection_name self.persist_directory = persist_directory - self.embedding_config = embedding_config or {} # 🔥 存储 embedding 配置 + self.embedding_config = embedding_config or {} self._client = None self._collection = None + self._is_new_collection = False - async def initialize(self): - """初始化 Chroma""" + async def initialize(self, force_recreate: bool = False): + """ + 初始化 Chroma + + Args: + force_recreate: 是否强制重建 collection + """ try: import chromadb from chromadb.config import Settings @@ -138,33 +236,56 @@ class ChromaVectorStore(VectorStore): settings=Settings(anonymized_telemetry=False), ) - # 🔥 构建 collection 元数据,包含 embedding 配置 - collection_metadata = {"hnsw:space": "cosine"} + # 检查 collection 是否存在 + existing_collections = [c.name for c in self._client.list_collections()] + collection_exists = self.collection_name in existing_collections + + # 如果需要强制重建,先删除 + if force_recreate and collection_exists: + logger.info(f"🗑️ 强制重建: 删除旧 collection '{self.collection_name}'") + self._client.delete_collection(name=self.collection_name) + collection_exists = False + + # 构建 collection 元数据 + current_time = time.time() + collection_metadata = { + "hnsw:space": "cosine", + "index_version": INDEX_VERSION, + } + if self.embedding_config: - # 在元数据中记录 embedding 配置 collection_metadata["embedding_provider"] = self.embedding_config.get("provider", "openai") collection_metadata["embedding_model"] = self.embedding_config.get("model", "text-embedding-3-small") collection_metadata["embedding_dimension"] = self.embedding_config.get("dimension", 1536) if self.embedding_config.get("base_url"): collection_metadata["embedding_base_url"] = self.embedding_config.get("base_url") - self._collection = self._client.get_or_create_collection( - name=self.collection_name, - metadata=collection_metadata, - ) - - logger.info(f"Chroma collection '{self.collection_name}' initialized") + if collection_exists: + # 获取现有 collection + self._collection = self._client.get_collection(name=self.collection_name) + self._is_new_collection = False + logger.info(f"📂 获取现有 collection '{self.collection_name}'") + else: + # 创建新 collection + collection_metadata["created_at"] = current_time + collection_metadata["updated_at"] = current_time + self._collection = self._client.create_collection( + name=self.collection_name, + metadata=collection_metadata, + ) + self._is_new_collection = True + logger.info(f"✨ 创建新 collection '{self.collection_name}'") except ImportError: raise ImportError("chromadb is required. 
Install with: pip install chromadb") - def get_embedding_config(self) -> Dict[str, Any]: - """ - 🔥 获取 collection 的 embedding 配置 + @property + def is_new_collection(self) -> bool: + """是否是新创建的 collection""" + return self._is_new_collection - Returns: - 包含 provider, model, dimension, base_url 的字典 - """ + def get_embedding_config(self) -> Dict[str, Any]: + """获取 collection 的 embedding 配置""" if not self._collection: return {} @@ -175,7 +296,28 @@ class ChromaVectorStore(VectorStore): "dimension": metadata.get("embedding_dimension"), "base_url": metadata.get("embedding_base_url"), } - + + def get_collection_metadata(self) -> Dict[str, Any]: + """获取 collection 完整元数据""" + if not self._collection: + return {} + return dict(self._collection.metadata or {}) + + def _clean_metadatas(self, metadatas: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """清理元数据,确保符合 Chroma 要求""" + cleaned_metadatas = [] + for meta in metadatas: + cleaned = {} + for k, v in meta.items(): + if isinstance(v, (str, int, float, bool)): + cleaned[k] = v + elif isinstance(v, list): + cleaned[k] = json.dumps(v) + elif v is not None: + cleaned[k] = str(v) + cleaned_metadatas.append(cleaned) + return cleaned_metadatas + async def add_documents( self, ids: List[str], @@ -186,21 +328,9 @@ class ChromaVectorStore(VectorStore): """添加文档到 Chroma""" if not ids: return - - # Chroma 对元数据有限制,需要清理 - cleaned_metadatas = [] - for meta in metadatas: - cleaned = {} - for k, v in meta.items(): - if isinstance(v, (str, int, float, bool)): - cleaned[k] = v - elif isinstance(v, list): - # 列表转为 JSON 字符串 - cleaned[k] = json.dumps(v) - elif v is not None: - cleaned[k] = str(v) - cleaned_metadatas.append(cleaned) - + + cleaned_metadatas = self._clean_metadatas(metadatas) + # 分批添加(Chroma 批次限制) batch_size = 500 for i in range(0, len(ids), batch_size): @@ -208,7 +338,7 @@ class ChromaVectorStore(VectorStore): batch_embeddings = embeddings[i:i + batch_size] batch_documents = documents[i:i + batch_size] batch_metadatas = cleaned_metadatas[i:i + batch_size] - + await asyncio.to_thread( self._collection.add, ids=batch_ids, @@ -216,7 +346,76 @@ class ChromaVectorStore(VectorStore): documents=batch_documents, metadatas=batch_metadatas, ) - + + async def upsert_documents( + self, + ids: List[str], + embeddings: List[List[float]], + documents: List[str], + metadatas: List[Dict[str, Any]], + ): + """更新或插入文档(用于增量更新)""" + if not ids: + return + + cleaned_metadatas = self._clean_metadatas(metadatas) + + # 分批 upsert + batch_size = 500 + for i in range(0, len(ids), batch_size): + batch_ids = ids[i:i + batch_size] + batch_embeddings = embeddings[i:i + batch_size] + batch_documents = documents[i:i + batch_size] + batch_metadatas = cleaned_metadatas[i:i + batch_size] + + await asyncio.to_thread( + self._collection.upsert, + ids=batch_ids, + embeddings=batch_embeddings, + documents=batch_documents, + metadatas=batch_metadatas, + ) + + async def delete_by_file_path(self, file_path: str) -> int: + """删除指定文件的所有文档""" + if not self._collection: + return 0 + + try: + # 查询该文件的所有文档 + result = await asyncio.to_thread( + self._collection.get, + where={"file_path": file_path}, + ) + + ids_to_delete = result.get("ids", []) + if ids_to_delete: + await asyncio.to_thread( + self._collection.delete, + ids=ids_to_delete, + ) + logger.debug(f"删除文件 '{file_path}' 的 {len(ids_to_delete)} 个文档") + + return len(ids_to_delete) + except Exception as e: + logger.warning(f"删除文件文档失败: {e}") + return 0 + + async def delete_by_ids(self, ids: List[str]) -> int: + """删除指定 ID 的文档""" + if not self._collection 
or not ids: + return 0 + + try: + await asyncio.to_thread( + self._collection.delete, + ids=ids, + ) + return len(ids) + except Exception as e: + logger.warning(f"删除文档失败: {e}") + return 0 + async def query( self, query_embedding: List[float], @@ -231,14 +430,14 @@ class ChromaVectorStore(VectorStore): where=where, include=["documents", "metadatas", "distances"], ) - + return { "ids": result["ids"][0] if result["ids"] else [], "documents": result["documents"][0] if result["documents"] else [], "metadatas": result["metadatas"][0] if result["metadatas"] else [], "distances": result["distances"][0] if result["distances"] else [], } - + async def delete_collection(self): """删除集合""" if self._client and self._collection: @@ -246,25 +445,111 @@ class ChromaVectorStore(VectorStore): self._client.delete_collection, name=self.collection_name, ) - + self._collection = None + async def get_count(self) -> int: """获取文档数量""" if self._collection: return await asyncio.to_thread(self._collection.count) return 0 + async def get_all_file_paths(self) -> Set[str]: + """获取所有已索引的文件路径""" + if not self._collection: + return set() + + try: + # 获取所有文档的元数据 + result = await asyncio.to_thread( + self._collection.get, + include=["metadatas"], + ) + + file_paths = set() + for meta in result.get("metadatas", []): + if meta and "file_path" in meta: + file_paths.add(meta["file_path"]) + + return file_paths + except Exception as e: + logger.warning(f"获取文件路径失败: {e}") + return set() + + async def get_file_hashes(self) -> Dict[str, str]: + """获取所有文件的 hash 映射 {file_path: file_hash}""" + if not self._collection: + return {} + + try: + result = await asyncio.to_thread( + self._collection.get, + include=["metadatas"], + ) + + file_hashes = {} + for meta in result.get("metadatas", []): + if meta: + file_path = meta.get("file_path") + file_hash = meta.get("file_hash") + if file_path and file_hash: + # 同一文件可能有多个 chunk,hash 应该相同 + file_hashes[file_path] = file_hash + + return file_hashes + except Exception as e: + logger.warning(f"获取文件 hash 失败: {e}") + return {} + + async def update_collection_metadata(self, updates: Dict[str, Any]): + """更新 collection 元数据""" + if not self._collection: + return + + try: + current_metadata = dict(self._collection.metadata or {}) + current_metadata.update(updates) + current_metadata["updated_at"] = time.time() + + # Chroma 不支持直接更新元数据,需要通过修改 collection + # 这里我们使用 modify 方法 + await asyncio.to_thread( + self._collection.modify, + metadata=current_metadata, + ) + except Exception as e: + logger.warning(f"更新 collection 元数据失败: {e}") + class InMemoryVectorStore(VectorStore): """内存向量存储(用于测试或小项目)""" - - def __init__(self, collection_name: str): + + def __init__(self, collection_name: str, embedding_config: Optional[Dict[str, Any]] = None): self.collection_name = collection_name + self.embedding_config = embedding_config or {} self._documents: Dict[str, Dict[str, Any]] = {} - - async def initialize(self): + self._metadata: Dict[str, Any] = { + "created_at": time.time(), + "index_version": INDEX_VERSION, + } + self._is_new_collection = True + + async def initialize(self, force_recreate: bool = False): """初始化""" + if force_recreate: + self._documents.clear() + self._is_new_collection = True logger.info(f"InMemory vector store '{self.collection_name}' initialized") - + + @property + def is_new_collection(self) -> bool: + return self._is_new_collection + + def get_embedding_config(self) -> Dict[str, Any]: + return self.embedding_config + + def get_collection_metadata(self) -> Dict[str, Any]: + return self._metadata + 
async def add_documents( self, ids: List[str], @@ -279,7 +564,37 @@ class InMemoryVectorStore(VectorStore): "document": doc, "metadata": meta, } - + self._is_new_collection = False + + async def upsert_documents( + self, + ids: List[str], + embeddings: List[List[float]], + documents: List[str], + metadatas: List[Dict[str, Any]], + ): + """更新或插入文档""" + await self.add_documents(ids, embeddings, documents, metadatas) + + async def delete_by_file_path(self, file_path: str) -> int: + """删除指定文件的所有文档""" + ids_to_delete = [ + id_ for id_, data in self._documents.items() + if data["metadata"].get("file_path") == file_path + ] + for id_ in ids_to_delete: + del self._documents[id_] + return len(ids_to_delete) + + async def delete_by_ids(self, ids: List[str]) -> int: + """删除指定 ID 的文档""" + count = 0 + for id_ in ids: + if id_ in self._documents: + del self._documents[id_] + count += 1 + return count + async def query( self, query_embedding: List[float], @@ -288,7 +603,7 @@ class InMemoryVectorStore(VectorStore): ) -> Dict[str, Any]: """查询(使用余弦相似度)""" import math - + def cosine_similarity(a: List[float], b: List[float]) -> float: dot = sum(x * y for x, y in zip(a, b)) norm_a = math.sqrt(sum(x * x for x in a)) @@ -296,7 +611,7 @@ class InMemoryVectorStore(VectorStore): if norm_a == 0 or norm_b == 0: return 0.0 return dot / (norm_a * norm_b) - + results = [] for id_, data in self._documents.items(): # 应用过滤条件 @@ -308,39 +623,66 @@ class InMemoryVectorStore(VectorStore): break if not match: continue - + similarity = cosine_similarity(query_embedding, data["embedding"]) results.append({ "id": id_, "document": data["document"], "metadata": data["metadata"], - "distance": 1 - similarity, # 转换为距离 + "distance": 1 - similarity, }) - - # 按距离排序 + results.sort(key=lambda x: x["distance"]) results = results[:n_results] - + return { "ids": [r["id"] for r in results], "documents": [r["document"] for r in results], "metadatas": [r["metadata"] for r in results], "distances": [r["distance"] for r in results], } - + async def delete_collection(self): """删除集合""" self._documents.clear() - + async def get_count(self) -> int: """获取文档数量""" return len(self._documents) + async def get_all_file_paths(self) -> Set[str]: + """获取所有已索引的文件路径""" + return { + data["metadata"].get("file_path") + for data in self._documents.values() + if data["metadata"].get("file_path") + } + + async def get_file_hashes(self) -> Dict[str, str]: + """获取所有文件的 hash 映射""" + file_hashes = {} + for data in self._documents.values(): + file_path = data["metadata"].get("file_path") + file_hash = data["metadata"].get("file_hash") + if file_path and file_hash: + file_hashes[file_path] = file_hash + return file_hashes + + async def update_collection_metadata(self, updates: Dict[str, Any]): + """更新 collection 元数据""" + self._metadata.update(updates) + self._metadata["updated_at"] = time.time() + class CodeIndexer: """ 代码索引器 将代码文件分块、嵌入并索引到向量数据库 + + 🔥 v2.0 改进: + - 自动检测 embedding 模型变更并重建索引 + - 支持增量索引更新(基于文件 hash) + - 支持索引状态查询 """ def __init__( @@ -364,9 +706,10 @@ class CodeIndexer: self.collection_name = collection_name self.embedding_service = embedding_service or EmbeddingService() self.splitter = splitter or CodeSplitter() + self.persist_directory = persist_directory - # 🔥 从 embedding_service 获取配置,用于存储到 collection 元数据 - embedding_config = { + # 从 embedding_service 获取配置 + self.embedding_config = { "provider": getattr(self.embedding_service, 'provider', 'openai'), "model": getattr(self.embedding_service, 'model', 'text-embedding-3-small'), "dimension": 
getattr(self.embedding_service, 'dimension', 1536), @@ -381,20 +724,385 @@ class CodeIndexer: self.vector_store = ChromaVectorStore( collection_name=collection_name, persist_directory=persist_directory, - embedding_config=embedding_config, # 🔥 传递 embedding 配置 + embedding_config=self.embedding_config, ) except ImportError: logger.warning("Chroma not available, using in-memory store") - self.vector_store = InMemoryVectorStore(collection_name=collection_name) + self.vector_store = InMemoryVectorStore( + collection_name=collection_name, + embedding_config=self.embedding_config, + ) self._initialized = False - - async def initialize(self): - """初始化索引器""" - if not self._initialized: - await self.vector_store.initialize() - self._initialized = True - + self._needs_rebuild = False + self._rebuild_reason = "" + + async def initialize(self, force_rebuild: bool = False) -> Tuple[bool, str]: + """ + 初始化索引器,检测是否需要重建索引 + + Args: + force_rebuild: 是否强制重建 + + Returns: + (needs_rebuild, reason) - 是否需要重建及原因 + """ + if self._initialized and not force_rebuild: + return self._needs_rebuild, self._rebuild_reason + + # 先初始化 vector_store(不强制重建,只是获取现有 collection) + await self.vector_store.initialize(force_recreate=False) + + # 检查是否需要重建 + self._needs_rebuild, self._rebuild_reason = await self._check_rebuild_needed() + + if force_rebuild: + self._needs_rebuild = True + self._rebuild_reason = "用户强制重建" + + # 如果需要重建,重新初始化 vector_store(强制重建) + if self._needs_rebuild: + logger.info(f"🔄 需要重建索引: {self._rebuild_reason}") + await self.vector_store.initialize(force_recreate=True) + + self._initialized = True + return self._needs_rebuild, self._rebuild_reason + + async def _check_rebuild_needed(self) -> Tuple[bool, str]: + """ + 检查是否需要重建索引 + + Returns: + (needs_rebuild, reason) + """ + # 如果是新 collection,不需要重建(因为本来就是空的) + if hasattr(self.vector_store, 'is_new_collection') and self.vector_store.is_new_collection: + return False, "" + + # 获取现有 collection 的配置 + stored_config = self.vector_store.get_embedding_config() + stored_metadata = self.vector_store.get_collection_metadata() + + # 检查索引版本 + stored_version = stored_metadata.get("index_version", "1.0") + if stored_version != INDEX_VERSION: + return True, f"索引版本变更: {stored_version} -> {INDEX_VERSION}" + + # 检查 embedding 提供商 + stored_provider = stored_config.get("provider") + current_provider = self.embedding_config.get("provider") + if stored_provider and current_provider and stored_provider != current_provider: + return True, f"Embedding 提供商变更: {stored_provider} -> {current_provider}" + + # 检查 embedding 模型 + stored_model = stored_config.get("model") + current_model = self.embedding_config.get("model") + if stored_model and current_model and stored_model != current_model: + return True, f"Embedding 模型变更: {stored_model} -> {current_model}" + + # 检查维度 + stored_dimension = stored_config.get("dimension") + current_dimension = self.embedding_config.get("dimension") + if stored_dimension and current_dimension and stored_dimension != current_dimension: + return True, f"Embedding 维度变更: {stored_dimension} -> {current_dimension}" + + return False, "" + + async def get_index_status(self) -> IndexStatus: + """获取索引状态""" + await self.initialize() + + metadata = self.vector_store.get_collection_metadata() + embedding_config = self.vector_store.get_embedding_config() + chunk_count = await self.vector_store.get_count() + file_paths = await self.vector_store.get_all_file_paths() + + return IndexStatus( + collection_name=self.collection_name, + exists=chunk_count > 0, + 
index_version=metadata.get("index_version", ""), + chunk_count=chunk_count, + file_count=len(file_paths), + created_at=metadata.get("created_at", 0), + updated_at=metadata.get("updated_at", 0), + embedding_provider=embedding_config.get("provider", ""), + embedding_model=embedding_config.get("model", ""), + embedding_dimension=embedding_config.get("dimension", 0), + project_hash=metadata.get("project_hash", ""), + ) + + async def smart_index_directory( + self, + directory: str, + exclude_patterns: Optional[List[str]] = None, + include_patterns: Optional[List[str]] = None, + update_mode: IndexUpdateMode = IndexUpdateMode.SMART, + progress_callback: Optional[Callable[[IndexingProgress], None]] = None, + ) -> AsyncGenerator[IndexingProgress, None]: + """ + 智能索引目录 + + Args: + directory: 目录路径 + exclude_patterns: 排除模式 + include_patterns: 包含模式 + update_mode: 更新模式 + progress_callback: 进度回调 + + Yields: + 索引进度 + """ + # 初始化并检查是否需要重建 + needs_rebuild, rebuild_reason = await self.initialize() + + progress = IndexingProgress() + exclude_patterns = exclude_patterns or [] + + # 确定实际的更新模式 + if update_mode == IndexUpdateMode.SMART: + if needs_rebuild: + actual_mode = IndexUpdateMode.FULL + logger.info(f"🔄 智能模式: 选择全量重建 (原因: {rebuild_reason})") + else: + actual_mode = IndexUpdateMode.INCREMENTAL + logger.info("📝 智能模式: 选择增量更新") + else: + actual_mode = update_mode + + progress.update_mode = actual_mode.value + + if actual_mode == IndexUpdateMode.FULL: + # 全量重建 + async for p in self._full_index(directory, exclude_patterns, include_patterns, progress, progress_callback): + yield p + else: + # 增量更新 + async for p in self._incremental_index(directory, exclude_patterns, include_patterns, progress, progress_callback): + yield p + + async def _full_index( + self, + directory: str, + exclude_patterns: List[str], + include_patterns: Optional[List[str]], + progress: IndexingProgress, + progress_callback: Optional[Callable[[IndexingProgress], None]], + ) -> AsyncGenerator[IndexingProgress, None]: + """全量索引""" + logger.info("🔄 开始全量索引...") + + # 收集文件 + files = self._collect_files(directory, exclude_patterns, include_patterns) + progress.total_files = len(files) + + logger.info(f"📁 发现 {len(files)} 个文件待索引") + yield progress + + all_chunks: List[CodeChunk] = [] + file_hashes: Dict[str, str] = {} + + # 分块处理文件 + for file_path in files: + progress.current_file = file_path + + try: + relative_path = os.path.relpath(file_path, directory) + + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + + if not content.strip(): + progress.processed_files += 1 + progress.skipped_files += 1 + continue + + # 计算文件 hash + file_hash = hashlib.md5(content.encode()).hexdigest() + file_hashes[relative_path] = file_hash + + # 限制文件大小 + if len(content) > 500000: + content = content[:500000] + + # 分块 + chunks = self.splitter.split_file(content, relative_path) + + # 为每个 chunk 添加 file_hash + for chunk in chunks: + chunk.metadata["file_hash"] = file_hash + + all_chunks.extend(chunks) + + progress.processed_files += 1 + progress.added_files += 1 + progress.total_chunks = len(all_chunks) + + if progress_callback: + progress_callback(progress) + yield progress + + except Exception as e: + logger.warning(f"处理文件失败 {file_path}: {e}") + progress.errors.append(f"{file_path}: {str(e)}") + progress.processed_files += 1 + + logger.info(f"📝 创建了 {len(all_chunks)} 个代码块") + + # 批量嵌入和索引 + if all_chunks: + await self._index_chunks(all_chunks, progress, use_upsert=False) + + # 更新 collection 元数据 + project_hash = 
hashlib.md5(json.dumps(sorted(file_hashes.items())).encode()).hexdigest() + await self.vector_store.update_collection_metadata({ + "project_hash": project_hash, + "file_count": len(file_hashes), + }) + + progress.indexed_chunks = len(all_chunks) + logger.info(f"✅ 全量索引完成: {progress.added_files} 个文件, {len(all_chunks)} 个代码块") + yield progress + + async def _incremental_index( + self, + directory: str, + exclude_patterns: List[str], + include_patterns: Optional[List[str]], + progress: IndexingProgress, + progress_callback: Optional[Callable[[IndexingProgress], None]], + ) -> AsyncGenerator[IndexingProgress, None]: + """增量索引""" + logger.info("📝 开始增量索引...") + + # 获取已索引文件的 hash + indexed_file_hashes = await self.vector_store.get_file_hashes() + indexed_files = set(indexed_file_hashes.keys()) + + # 收集当前文件 + current_files = self._collect_files(directory, exclude_patterns, include_patterns) + current_file_map: Dict[str, str] = {} # relative_path -> absolute_path + + for file_path in current_files: + relative_path = os.path.relpath(file_path, directory) + current_file_map[relative_path] = file_path + + current_file_set = set(current_file_map.keys()) + + # 计算差异 + files_to_add = current_file_set - indexed_files + files_to_delete = indexed_files - current_file_set + files_to_check = current_file_set & indexed_files + + # 检查需要更新的文件(hash 变化) + files_to_update: Set[str] = set() + for relative_path in files_to_check: + file_path = current_file_map[relative_path] + try: + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + current_hash = hashlib.md5(content.encode()).hexdigest() + if current_hash != indexed_file_hashes.get(relative_path): + files_to_update.add(relative_path) + except Exception: + files_to_update.add(relative_path) + + total_operations = len(files_to_add) + len(files_to_delete) + len(files_to_update) + progress.total_files = total_operations + + logger.info(f"📊 增量更新: 新增 {len(files_to_add)}, 删除 {len(files_to_delete)}, 更新 {len(files_to_update)}") + yield progress + + # 删除已移除的文件 + for relative_path in files_to_delete: + progress.current_file = f"删除: {relative_path}" + deleted_count = await self.vector_store.delete_by_file_path(relative_path) + progress.deleted_files += 1 + progress.processed_files += 1 + logger.debug(f"🗑️ 删除文件 '{relative_path}' 的 {deleted_count} 个代码块") + + if progress_callback: + progress_callback(progress) + yield progress + + # 处理新增和更新的文件 + files_to_process = files_to_add | files_to_update + all_chunks: List[CodeChunk] = [] + file_hashes: Dict[str, str] = dict(indexed_file_hashes) + + for relative_path in files_to_process: + file_path = current_file_map[relative_path] + progress.current_file = relative_path + is_update = relative_path in files_to_update + + try: + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + + if not content.strip(): + progress.processed_files += 1 + progress.skipped_files += 1 + continue + + # 如果是更新,先删除旧的 + if is_update: + await self.vector_store.delete_by_file_path(relative_path) + + # 计算文件 hash + file_hash = hashlib.md5(content.encode()).hexdigest() + file_hashes[relative_path] = file_hash + + # 限制文件大小 + if len(content) > 500000: + content = content[:500000] + + # 分块 + chunks = self.splitter.split_file(content, relative_path) + + # 为每个 chunk 添加 file_hash + for chunk in chunks: + chunk.metadata["file_hash"] = file_hash + + all_chunks.extend(chunks) + + progress.processed_files += 1 + if is_update: + progress.updated_files += 1 + else: + progress.added_files += 1 + 
progress.total_chunks += len(chunks) + + if progress_callback: + progress_callback(progress) + yield progress + + except Exception as e: + logger.warning(f"处理文件失败 {file_path}: {e}") + progress.errors.append(f"{file_path}: {str(e)}") + progress.processed_files += 1 + + # 批量嵌入和索引新的代码块 + if all_chunks: + await self._index_chunks(all_chunks, progress, use_upsert=True) + + # 更新 collection 元数据 + # 移除已删除文件的 hash + for relative_path in files_to_delete: + file_hashes.pop(relative_path, None) + + project_hash = hashlib.md5(json.dumps(sorted(file_hashes.items())).encode()).hexdigest() + await self.vector_store.update_collection_metadata({ + "project_hash": project_hash, + "file_count": len(file_hashes), + }) + + progress.indexed_chunks = len(all_chunks) + logger.info( + f"✅ 增量索引完成: 新增 {progress.added_files}, " + f"更新 {progress.updated_files}, 删除 {progress.deleted_files}" + ) + yield progress + + # 保留原有的 index_directory 方法作为兼容 async def index_directory( self, directory: str, @@ -403,74 +1111,26 @@ class CodeIndexer: progress_callback: Optional[Callable[[IndexingProgress], None]] = None, ) -> AsyncGenerator[IndexingProgress, None]: """ - 索引目录中的代码文件 - + 索引目录(使用智能模式) + Args: directory: 目录路径 exclude_patterns: 排除模式 include_patterns: 包含模式 progress_callback: 进度回调 - + Yields: 索引进度 """ - await self.initialize() - - progress = IndexingProgress() - exclude_patterns = exclude_patterns or [] - - # 收集文件 - files = self._collect_files(directory, exclude_patterns, include_patterns) - progress.total_files = len(files) - - logger.info(f"Found {len(files)} files to index in {directory}") - yield progress - - all_chunks: List[CodeChunk] = [] - - # 分块处理文件 - for file_path in files: - progress.current_file = file_path - - try: - relative_path = os.path.relpath(file_path, directory) - - with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: - content = f.read() - - if not content.strip(): - progress.processed_files += 1 - continue - - # 限制文件大小 - if len(content) > 500000: # 500KB - content = content[:500000] - - # 分块 - chunks = self.splitter.split_file(content, relative_path) - all_chunks.extend(chunks) - - progress.processed_files += 1 - progress.total_chunks = len(all_chunks) - - if progress_callback: - progress_callback(progress) - yield progress - - except Exception as e: - logger.warning(f"Error processing {file_path}: {e}") - progress.errors.append(f"{file_path}: {str(e)}") - progress.processed_files += 1 - - logger.info(f"Created {len(all_chunks)} chunks from {len(files)} files") - - # 批量嵌入和索引 - if all_chunks: - await self._index_chunks(all_chunks, progress) - - progress.indexed_chunks = len(all_chunks) - yield progress - + async for progress in self.smart_index_directory( + directory=directory, + exclude_patterns=exclude_patterns, + include_patterns=include_patterns, + update_mode=IndexUpdateMode.SMART, + progress_callback=progress_callback, + ): + yield progress + async def index_files( self, files: List[Dict[str, str]], @@ -479,86 +1139,113 @@ class CodeIndexer: ) -> AsyncGenerator[IndexingProgress, None]: """ 索引文件列表 - + Args: files: 文件列表 [{"path": "...", "content": "..."}] base_path: 基础路径 progress_callback: 进度回调 - + Yields: 索引进度 """ await self.initialize() - + progress = IndexingProgress() progress.total_files = len(files) - + all_chunks: List[CodeChunk] = [] - + for file_info in files: file_path = file_info.get("path", "") content = file_info.get("content", "") - + progress.current_file = file_path - + try: if not content.strip(): progress.processed_files += 1 + progress.skipped_files += 1 continue - + 
+ # 计算文件 hash + file_hash = hashlib.md5(content.encode()).hexdigest() + # 限制文件大小 if len(content) > 500000: content = content[:500000] - + # 分块 chunks = self.splitter.split_file(content, file_path) + + # 为每个 chunk 添加 file_hash + for chunk in chunks: + chunk.metadata["file_hash"] = file_hash + all_chunks.extend(chunks) - + progress.processed_files += 1 + progress.added_files += 1 progress.total_chunks = len(all_chunks) - + if progress_callback: progress_callback(progress) yield progress - + except Exception as e: - logger.warning(f"Error processing {file_path}: {e}") + logger.warning(f"处理文件失败 {file_path}: {e}") progress.errors.append(f"{file_path}: {str(e)}") progress.processed_files += 1 - + # 批量嵌入和索引 if all_chunks: - await self._index_chunks(all_chunks, progress) - + await self._index_chunks(all_chunks, progress, use_upsert=True) + progress.indexed_chunks = len(all_chunks) yield progress - - async def _index_chunks(self, chunks: List[CodeChunk], progress: IndexingProgress): + + async def _index_chunks( + self, + chunks: List[CodeChunk], + progress: IndexingProgress, + use_upsert: bool = False, + ): """索引代码块""" + if not chunks: + return + # 准备嵌入文本 texts = [chunk.to_embedding_text() for chunk in chunks] - - logger.info(f"Generating embeddings for {len(texts)} chunks...") - + + logger.info(f"🔢 生成 {len(texts)} 个代码块的嵌入向量...") + # 批量嵌入 embeddings = await self.embedding_service.embed_batch(texts, batch_size=50) - + # 准备元数据 ids = [chunk.id for chunk in chunks] documents = [chunk.content for chunk in chunks] metadatas = [chunk.to_dict() for chunk in chunks] - + # 添加到向量存储 - logger.info(f"Adding {len(chunks)} chunks to vector store...") - await self.vector_store.add_documents( - ids=ids, - embeddings=embeddings, - documents=documents, - metadatas=metadatas, - ) - - logger.info(f"Indexed {len(chunks)} chunks successfully") - + logger.info(f"💾 添加 {len(chunks)} 个代码块到向量存储...") + + if use_upsert: + await self.vector_store.upsert_documents( + ids=ids, + embeddings=embeddings, + documents=documents, + metadatas=metadatas, + ) + else: + await self.vector_store.add_documents( + ids=ids, + embeddings=embeddings, + documents=documents, + metadatas=metadatas, + ) + + logger.info(f"✅ 索引 {len(chunks)} 个代码块成功") + def _collect_files( self, directory: str, @@ -567,36 +1254,36 @@ class CodeIndexer: ) -> List[str]: """收集需要索引的文件""" import fnmatch - + files = [] - + for root, dirs, filenames in os.walk(directory): # 过滤目录 dirs[:] = [d for d in dirs if d not in EXCLUDE_DIRS] - + for filename in filenames: # 检查扩展名 ext = os.path.splitext(filename)[1].lower() if ext not in TEXT_EXTENSIONS: continue - + # 检查排除文件 if filename in EXCLUDE_FILES: continue - + file_path = os.path.join(root, filename) relative_path = os.path.relpath(file_path, directory) - + # 检查排除模式 excluded = False for pattern in exclude_patterns: if fnmatch.fnmatch(relative_path, pattern) or fnmatch.fnmatch(filename, pattern): excluded = True break - + if excluded: continue - + # 检查包含模式 if include_patterns: included = False @@ -606,19 +1293,55 @@ class CodeIndexer: break if not included: continue - + files.append(file_path) - + return files - + async def get_chunk_count(self) -> int: """获取已索引的代码块数量""" await self.initialize() return await self.vector_store.get_count() - + async def clear(self): """清空索引""" await self.initialize() await self.vector_store.delete_collection() self._initialized = False + async def delete_file(self, file_path: str) -> int: + """ + 删除指定文件的索引 + + Args: + file_path: 文件路径 + + Returns: + 删除的代码块数量 + """ + await self.initialize() + return await 
self.vector_store.delete_by_file_path(file_path) + + async def rebuild(self, directory: str, **kwargs) -> AsyncGenerator[IndexingProgress, None]: + """ + 强制重建索引 + + Args: + directory: 目录路径 + **kwargs: 传递给 smart_index_directory 的其他参数 + + Yields: + 索引进度 + """ + # 强制重新初始化 + self._initialized = False + await self.initialize(force_rebuild=True) + + async for progress in self.smart_index_directory( + directory=directory, + update_mode=IndexUpdateMode.FULL, + **kwargs, + ): + yield progress + + diff --git a/backend/app/services/rag/splitter.py b/backend/app/services/rag/splitter.py index 78ff634..2144f1c 100644 --- a/backend/app/services/rag/splitter.py +++ b/backend/app/services/rag/splitter.py @@ -188,22 +188,34 @@ class TreeSitterParser: }, } + # tree-sitter-languages 支持的语言列表 + SUPPORTED_LANGUAGES = { + "python", "javascript", "typescript", "tsx", "java", "go", "rust", + "c", "cpp", "c_sharp", "php", "ruby", "kotlin", "swift", "bash", + "json", "yaml", "html", "css", "sql", "markdown", + } + def __init__(self): self._parsers: Dict[str, Any] = {} self._initialized = False - + def _ensure_initialized(self, language: str) -> bool: """确保语言解析器已初始化""" if language in self._parsers: return True - + + # 检查语言是否受支持 + if language not in self.SUPPORTED_LANGUAGES: + # 不是 tree-sitter 支持的语言,静默跳过 + return False + try: - from tree_sitter_languages import get_parser, get_language - + from tree_sitter_languages import get_parser + parser = get_parser(language) self._parsers[language] = parser return True - + except ImportError: logger.warning("tree-sitter-languages not installed, falling back to regex parsing") return False diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 9bf5c71..a77b308 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -60,7 +60,8 @@ dependencies = [ "chromadb>=0.4.22", # ============ Code Parsing ============ - "tree-sitter>=0.21.0", + # tree-sitter-languages 1.10.x 与 tree-sitter 0.22+ 不兼容 + "tree-sitter>=0.21.0,<0.22.0", "tree-sitter-languages>=1.10.0", "pygments>=2.17.0", diff --git a/backend/requirements.txt b/backend/requirements.txt index 1b18451..37f81e8 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -47,7 +47,8 @@ langgraph>=0.0.40 chromadb>=0.4.22 # ============ Code Parsing ============ -tree-sitter>=0.21.0 +# tree-sitter-languages 1.10.x 与 tree-sitter 0.22+ 不兼容 +tree-sitter>=0.21.0,<0.22.0 tree-sitter-languages>=1.10.0 pygments>=2.17.0 diff --git a/frontend/src/pages/AgentAudit/index.tsx b/frontend/src/pages/AgentAudit/index.tsx index ac5782c..2ca9e12 100644 --- a/frontend/src/pages/AgentAudit/index.tsx +++ b/frontend/src/pages/AgentAudit/index.tsx @@ -392,19 +392,33 @@ function AgentAuditPageContent() { setCurrentAgentName(event.metadata.agent_name); } - const dispatchEvents = ['dispatch', 'dispatch_complete', 'node_start', 'phase_start']; + const dispatchEvents = ['dispatch', 'dispatch_complete', 'node_start', 'phase_start', 'phase_complete']; if (dispatchEvents.includes(event.type)) { - if (event.type === 'dispatch' || event.type === 'dispatch_complete') { - dispatch({ - type: 'ADD_LOG', - payload: { - type: 'dispatch', - title: event.message || `Agent dispatch: ${event.metadata?.agent || 'unknown'}`, - agentName: getCurrentAgentName() || undefined, - } - }); - } + // 所有 dispatch 类型事件都添加到日志 + dispatch({ + type: 'ADD_LOG', + payload: { + type: 'dispatch', + title: event.message || `Agent dispatch: ${event.metadata?.agent || 'unknown'}`, + agentName: getCurrentAgentName() || undefined, + } + }); debouncedLoadAgentTree(); + 
return; + } + + // 🔥 处理 info、warning、error 类型事件(克隆进度、索引进度等) + const infoEvents = ['info', 'warning', 'error', 'progress']; + if (infoEvents.includes(event.type)) { + dispatch({ + type: 'ADD_LOG', + payload: { + type: event.type === 'error' ? 'error' : 'info', + title: event.message || event.type, + agentName: getCurrentAgentName() || undefined, + } + }); + return; } }, onThinkingStart: () => {