From f71b8da7df06bb15ca9ab4b5caa21f624026f1a8 Mon Sep 17 00:00:00 2001 From: lintsinghua Date: Tue, 16 Dec 2025 19:42:44 +0800 Subject: [PATCH] =?UTF-8?q?feat(embedding):=20=E6=94=AF=E6=8C=81=E5=89=8D?= =?UTF-8?q?=E7=AB=AF=E9=85=8D=E7=BD=AE=E5=B5=8C=E5=85=A5=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E7=9A=84API=E5=AF=86=E9=92=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit refactor(agent): 改进任务取消逻辑,确保子Agent被正确取消 - 移除asyncio.shield()以允许取消信号传播 - 增加更频繁的取消状态检查 - 添加日志记录子Agent取消情况 feat(nginx): 添加前端构建产物和nginx配置的挂载 refactor(rag): 优化代码索引器的日志输出和元数据处理 - 添加索引文件数量的调试日志 - 将元数据字段提升到顶级以便检索 fix(parser): 修复AST定义提取中的方法识别问题 - 区分函数和方法定义 - 优化遍历逻辑避免重复匹配 --- backend/app/api/v1/endpoints/agent_tasks.py | 18 +++++-- .../app/api/v1/endpoints/embedding_config.py | 47 +++++++++++-------- .../app/services/agent/agents/orchestrator.py | 22 +++++++-- backend/app/services/rag/indexer.py | 6 +++ backend/app/services/rag/splitter.py | 46 +++++++++++++----- docker-compose.yml | 3 ++ .../src/components/agent/EmbeddingConfig.tsx | 20 ++++---- 7 files changed, 113 insertions(+), 49 deletions(-) diff --git a/backend/app/api/v1/endpoints/agent_tasks.py b/backend/app/api/v1/endpoints/agent_tasks.py index c4fd398..32788cc 100644 --- a/backend/app/api/v1/endpoints/agent_tasks.py +++ b/backend/app/api/v1/endpoints/agent_tasks.py @@ -1583,18 +1583,28 @@ async def cancel_agent_task( if runner: runner.cancel() logger.info(f"[Cancel] Set cancel flag for task {task_id}") - - # 🔥 2. 强制取消 asyncio Task(立即中断 LLM 调用) + + # 🔥 2. 通过 agent_registry 取消所有子 Agent + from app.services.agent.core import agent_registry + from app.services.agent.core.graph_controller import stop_all_agents + try: + # 停止所有 Agent(包括子 Agent) + stop_result = stop_all_agents(exclude_root=False) + logger.info(f"[Cancel] Stopped all agents: {stop_result}") + except Exception as e: + logger.warning(f"[Cancel] Failed to stop agents via registry: {e}") + + # 🔥 3. 强制取消 asyncio Task(立即中断 LLM 调用) asyncio_task = _running_asyncio_tasks.get(task_id) if asyncio_task and not asyncio_task.done(): asyncio_task.cancel() logger.info(f"[Cancel] Cancelled asyncio task for {task_id}") - + # 更新状态 task.status = AgentTaskStatus.CANCELLED task.completed_at = datetime.now(timezone.utc) await db.commit() - + logger.info(f"[Cancel] Task {task_id} cancelled successfully") return {"message": "任务已取消", "task_id": task_id} diff --git a/backend/app/api/v1/endpoints/embedding_config.py b/backend/app/api/v1/endpoints/embedding_config.py index bc91c51..541bf2a 100644 --- a/backend/app/api/v1/endpoints/embedding_config.py +++ b/backend/app/api/v1/endpoints/embedding_config.py @@ -11,6 +11,7 @@ from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel, Field from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm.attributes import flag_modified from app.api import deps from app.models.user import User @@ -46,10 +47,10 @@ class EmbeddingConfigResponse(BaseModel): """配置响应""" provider: str model: str + api_key: Optional[str] = None # 返回 API Key base_url: Optional[str] dimensions: int batch_size: int - # 不返回 API Key class TestEmbeddingRequest(BaseModel): @@ -165,14 +166,14 @@ async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> Embedd select(UserConfig).where(UserConfig.user_id == user_id) ) user_config = result.scalar_one_or_none() - + if user_config and user_config.other_config: try: other_config = json.loads(user_config.other_config) if isinstance(user_config.other_config, str) else user_config.other_config embedding_data = other_config.get(EMBEDDING_CONFIG_KEY) - + if embedding_data: - return EmbeddingConfig( + config = EmbeddingConfig( provider=embedding_data.get("provider", settings.EMBEDDING_PROVIDER), model=embedding_data.get("model", settings.EMBEDDING_MODEL), api_key=embedding_data.get("api_key"), @@ -180,10 +181,13 @@ async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> Embedd dimensions=embedding_data.get("dimensions"), batch_size=embedding_data.get("batch_size", 100), ) - except (json.JSONDecodeError, AttributeError): - pass - + print(f"[EmbeddingConfig] 读取用户 {user_id} 的嵌入配置: provider={config.provider}, model={config.model}") + return config + except (json.JSONDecodeError, AttributeError) as e: + print(f"[EmbeddingConfig] 解析用户 {user_id} 配置失败: {e}") + # 返回默认配置 + print(f"[EmbeddingConfig] 用户 {user_id} 无保存配置,返回默认值") return EmbeddingConfig( provider=settings.EMBEDDING_PROVIDER, model=settings.EMBEDDING_MODEL, @@ -199,7 +203,7 @@ async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: Em select(UserConfig).where(UserConfig.user_id == user_id) ) user_config = result.scalar_one_or_none() - + # 准备嵌入配置数据 embedding_data = { "provider": config.provider, @@ -209,16 +213,18 @@ async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: Em "dimensions": config.dimensions, "batch_size": config.batch_size, } - + if user_config: # 更新现有配置 try: other_config = json.loads(user_config.other_config) if user_config.other_config else {} except (json.JSONDecodeError, TypeError): other_config = {} - + other_config[EMBEDDING_CONFIG_KEY] = embedding_data user_config.other_config = json.dumps(other_config) + # 🔥 显式标记 other_config 字段已修改,确保 SQLAlchemy 检测到变化 + flag_modified(user_config, "other_config") else: # 创建新配置 user_config = UserConfig( @@ -228,8 +234,9 @@ async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: Em other_config=json.dumps({EMBEDDING_CONFIG_KEY: embedding_data}), ) db.add(user_config) - + await db.commit() + print(f"[EmbeddingConfig] 已保存用户 {user_id} 的嵌入配置: provider={config.provider}, model={config.model}") # ============ API Endpoints ============ @@ -253,13 +260,14 @@ async def get_current_config( 获取当前嵌入模型配置(从数据库读取) """ config = await get_embedding_config_from_db(db, current_user.id) - + # 获取维度 dimensions = _get_model_dimensions(config.provider, config.model) - + return EmbeddingConfigResponse( provider=config.provider, model=config.model, + api_key=config.api_key, base_url=config.base_url, dimensions=dimensions, batch_size=config.batch_size, @@ -279,19 +287,18 @@ async def update_config( provider_ids = [p.id for p in EMBEDDING_PROVIDERS] if config.provider not in provider_ids: raise HTTPException(status_code=400, detail=f"不支持的提供商: {config.provider}") - - # 验证模型 + + # 获取提供商信息(用于检查 API Key 要求) provider = next((p for p in EMBEDDING_PROVIDERS if p.id == config.provider), None) - if provider and config.model not in provider.models: - raise HTTPException(status_code=400, detail=f"不支持的模型: {config.model}") - + # 注意:不再强制验证模型名称,允许用户输入自定义模型 + # 检查 API Key if provider and provider.requires_api_key and not config.api_key: raise HTTPException(status_code=400, detail=f"{config.provider} 需要 API Key") - + # 保存到数据库 await save_embedding_config_to_db(db, current_user.id, config) - + return {"message": "配置已保存", "provider": config.provider, "model": config.model} diff --git a/backend/app/services/agent/agents/orchestrator.py b/backend/app/services/agent/agents/orchestrator.py index b99973f..6d39491 100644 --- a/backend/app/services/agent/agents/orchestrator.py +++ b/backend/app/services/agent/agents/orchestrator.py @@ -667,7 +667,8 @@ Action Input: {{"参数": "值"}} try: while not run_task.done(): if self.is_cancelled: - # 传播取消到子 Agent + # 🔥 传播取消到子 Agent + logger.info(f"[{self.name}] Cancelling sub-agent {agent_name} due to parent cancel") if hasattr(agent, 'cancel'): agent.cancel() run_task.cancel() @@ -678,17 +679,32 @@ Action Input: {{"参数": "值"}} raise asyncio.CancelledError("任务已取消") try: + # 🔥 移除 asyncio.shield(),让取消信号可以直接传播 + # 使用较短的超时来更频繁地检查取消状态 return await asyncio.wait_for( - asyncio.shield(run_task), - timeout=1.0 # 每秒检查一次取消状态 + run_task, + timeout=0.5 # 🔥 每0.5秒检查一次取消状态 ) except asyncio.TimeoutError: continue + except asyncio.CancelledError: + # 🔥 捕获取消异常,确保子Agent也被取消 + logger.info(f"[{self.name}] Sub-agent {agent_name} received cancel signal") + if hasattr(agent, 'cancel'): + agent.cancel() + raise return await run_task except asyncio.CancelledError: + # 🔥 确保子任务被取消 if not run_task.done(): + if hasattr(agent, 'cancel'): + agent.cancel() run_task.cancel() + try: + await run_task + except asyncio.CancelledError: + pass raise try: diff --git a/backend/app/services/rag/indexer.py b/backend/app/services/rag/indexer.py index 168d489..d82ba68 100644 --- a/backend/app/services/rag/indexer.py +++ b/backend/app/services/rag/indexer.py @@ -992,6 +992,8 @@ class CodeIndexer: indexed_file_hashes = await self.vector_store.get_file_hashes() indexed_files = set(indexed_file_hashes.keys()) + logger.debug(f"📂 已索引文件数: {len(indexed_files)}, file_hashes: {list(indexed_file_hashes.keys())[:5]}...") + # 收集当前文件 current_files = self._collect_files(directory, exclude_patterns, include_patterns) current_file_map: Dict[str, str] = {} # relative_path -> absolute_path @@ -1002,11 +1004,15 @@ class CodeIndexer: current_file_set = set(current_file_map.keys()) + logger.debug(f"📁 当前文件数: {len(current_file_set)}, 示例: {list(current_file_set)[:5]}...") + # 计算差异 files_to_add = current_file_set - indexed_files files_to_delete = indexed_files - current_file_set files_to_check = current_file_set & indexed_files + logger.debug(f"📊 差异分析: 交集={len(files_to_check)}, 新增候选={len(files_to_add)}, 删除候选={len(files_to_delete)}") + # 检查需要更新的文件(hash 变化) files_to_update: Set[str] = set() for relative_path in files_to_check: diff --git a/backend/app/services/rag/splitter.py b/backend/app/services/rag/splitter.py index 4dbc89e..cb8b672 100644 --- a/backend/app/services/rag/splitter.py +++ b/backend/app/services/rag/splitter.py @@ -92,7 +92,7 @@ class CodeChunk: return len(self.content) // 4 def to_dict(self) -> Dict[str, Any]: - return { + result = { "id": self.id, "content": self.content, "file_path": self.file_path, @@ -110,8 +110,13 @@ class CodeChunk: "definitions": self.definitions, "security_indicators": self.security_indicators, "estimated_tokens": self.estimated_tokens, - "metadata": self.metadata, } + # 将 metadata 中的字段提升到顶级,确保 file_hash 等字段可以被正确检索 + if self.metadata: + for key, value in self.metadata.items(): + if key not in result: + result[key] = value + return result def to_embedding_text(self) -> str: """生成用于嵌入的文本""" @@ -244,20 +249,29 @@ class TreeSitterParser: """从 AST 提取定义""" if tree is None: return [] - + definitions = [] definition_types = self.DEFINITION_TYPES.get(language, {}) - + def traverse(node, parent_name=None): node_type = node.type - + # 检查是否是定义节点 + matched = False for def_category, types in definition_types.items(): if node_type in types: name = self._extract_name(node, language) - + + # 根据是否有 parent_name 来区分 function 和 method + actual_category = def_category + if def_category == "function" and parent_name: + actual_category = "method" + elif def_category == "method" and not parent_name: + # 跳过没有 parent 的 method 定义(由 function 类别处理) + continue + definitions.append({ - "type": def_category, + "type": actual_category, "name": name, "parent_name": parent_name, "start_point": node.start_point, @@ -266,17 +280,23 @@ class TreeSitterParser: "end_byte": node.end_byte, "node_type": node_type, }) - + + matched = True + # 对于类,继续遍历子节点找方法 if def_category == "class": for child in node.children: traverse(child, name) return - - # 继续遍历子节点 - for child in node.children: - traverse(child, parent_name) - + + # 匹配到一个类别后就不再匹配其他类别 + break + + # 如果没有匹配到定义,继续遍历子节点 + if not matched: + for child in node.children: + traverse(child, parent_name) + traverse(tree.root_node) return definitions diff --git a/docker-compose.yml b/docker-compose.yml index add6aca..484ac3c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -80,6 +80,9 @@ services: - all_proxy= - ALL_PROXY= restart: unless-stopped + volumes: + - ./frontend/dist:/usr/share/nginx/html:ro # 挂载构建产物,本地 pnpm build 后自动生效 + - ./frontend/nginx.conf:/etc/nginx/conf.d/default.conf:ro # 挂载 nginx 配置 ports: - "3000:80" # Nginx 监听 80 端口 depends_on: diff --git a/frontend/src/components/agent/EmbeddingConfig.tsx b/frontend/src/components/agent/EmbeddingConfig.tsx index a4f1c98..cbbda76 100644 --- a/frontend/src/components/agent/EmbeddingConfig.tsx +++ b/frontend/src/components/agent/EmbeddingConfig.tsx @@ -46,6 +46,7 @@ interface EmbeddingProvider { interface EmbeddingConfig { provider: string; model: string; + api_key: string | null; base_url: string | null; dimensions: number; batch_size: number; @@ -79,15 +80,15 @@ export default function EmbeddingConfigPanel() { loadData(); }, []); - // 当 provider 改变时更新模型 - useEffect(() => { - if (selectedProvider) { - const provider = providers.find((p) => p.id === selectedProvider); - if (provider) { - setSelectedModel(provider.default_model); - } + // 用户手动切换 provider 时更新为默认模型 + const handleProviderChange = (newProvider: string) => { + setSelectedProvider(newProvider); + // 切换 provider 时重置为该 provider 的默认模型 + const provider = providers.find((p) => p.id === newProvider); + if (provider) { + setSelectedModel(provider.default_model); } - }, [selectedProvider, providers]); + }; const loadData = async () => { try { @@ -104,6 +105,7 @@ export default function EmbeddingConfigPanel() { if (configRes.data) { setSelectedProvider(configRes.data.provider); setSelectedModel(configRes.data.model); + setApiKey(configRes.data.api_key || ""); setBaseUrl(configRes.data.base_url || ""); setBatchSize(configRes.data.batch_size); } @@ -230,7 +232,7 @@ export default function EmbeddingConfigPanel() { {/* 提供商选择 */}
-