feat(embedding): 支持前端配置嵌入模型的API密钥

refactor(agent): 改进任务取消逻辑,确保子Agent被正确取消
- 移除asyncio.shield()以允许取消信号传播
- 增加更频繁的取消状态检查
- 添加日志记录子Agent取消情况

feat(nginx): 添加前端构建产物和nginx配置的挂载

refactor(rag): 优化代码索引器的日志输出和元数据处理
- 添加索引文件数量的调试日志
- 将元数据字段提升到顶级以便检索

fix(parser): 修复AST定义提取中的方法识别问题
- 区分函数和方法定义
- 优化遍历逻辑避免重复匹配
This commit is contained in:
lintsinghua 2025-12-16 19:42:44 +08:00
parent 7efb89d2d2
commit f71b8da7df
7 changed files with 113 additions and 49 deletions

View File

@ -1583,18 +1583,28 @@ async def cancel_agent_task(
if runner: if runner:
runner.cancel() runner.cancel()
logger.info(f"[Cancel] Set cancel flag for task {task_id}") logger.info(f"[Cancel] Set cancel flag for task {task_id}")
# 🔥 2. 强制取消 asyncio Task立即中断 LLM 调用) # 🔥 2. 通过 agent_registry 取消所有子 Agent
from app.services.agent.core import agent_registry
from app.services.agent.core.graph_controller import stop_all_agents
try:
# 停止所有 Agent包括子 Agent
stop_result = stop_all_agents(exclude_root=False)
logger.info(f"[Cancel] Stopped all agents: {stop_result}")
except Exception as e:
logger.warning(f"[Cancel] Failed to stop agents via registry: {e}")
# 🔥 3. 强制取消 asyncio Task立即中断 LLM 调用)
asyncio_task = _running_asyncio_tasks.get(task_id) asyncio_task = _running_asyncio_tasks.get(task_id)
if asyncio_task and not asyncio_task.done(): if asyncio_task and not asyncio_task.done():
asyncio_task.cancel() asyncio_task.cancel()
logger.info(f"[Cancel] Cancelled asyncio task for {task_id}") logger.info(f"[Cancel] Cancelled asyncio task for {task_id}")
# 更新状态 # 更新状态
task.status = AgentTaskStatus.CANCELLED task.status = AgentTaskStatus.CANCELLED
task.completed_at = datetime.now(timezone.utc) task.completed_at = datetime.now(timezone.utc)
await db.commit() await db.commit()
logger.info(f"[Cancel] Task {task_id} cancelled successfully") logger.info(f"[Cancel] Task {task_id} cancelled successfully")
return {"message": "任务已取消", "task_id": task_id} return {"message": "任务已取消", "task_id": task_id}

View File

@ -11,6 +11,7 @@ from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.api import deps from app.api import deps
from app.models.user import User from app.models.user import User
@ -46,10 +47,10 @@ class EmbeddingConfigResponse(BaseModel):
"""配置响应""" """配置响应"""
provider: str provider: str
model: str model: str
api_key: Optional[str] = None # 返回 API Key
base_url: Optional[str] base_url: Optional[str]
dimensions: int dimensions: int
batch_size: int batch_size: int
# 不返回 API Key
class TestEmbeddingRequest(BaseModel): class TestEmbeddingRequest(BaseModel):
@ -165,14 +166,14 @@ async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> Embedd
select(UserConfig).where(UserConfig.user_id == user_id) select(UserConfig).where(UserConfig.user_id == user_id)
) )
user_config = result.scalar_one_or_none() user_config = result.scalar_one_or_none()
if user_config and user_config.other_config: if user_config and user_config.other_config:
try: try:
other_config = json.loads(user_config.other_config) if isinstance(user_config.other_config, str) else user_config.other_config other_config = json.loads(user_config.other_config) if isinstance(user_config.other_config, str) else user_config.other_config
embedding_data = other_config.get(EMBEDDING_CONFIG_KEY) embedding_data = other_config.get(EMBEDDING_CONFIG_KEY)
if embedding_data: if embedding_data:
return EmbeddingConfig( config = EmbeddingConfig(
provider=embedding_data.get("provider", settings.EMBEDDING_PROVIDER), provider=embedding_data.get("provider", settings.EMBEDDING_PROVIDER),
model=embedding_data.get("model", settings.EMBEDDING_MODEL), model=embedding_data.get("model", settings.EMBEDDING_MODEL),
api_key=embedding_data.get("api_key"), api_key=embedding_data.get("api_key"),
@ -180,10 +181,13 @@ async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> Embedd
dimensions=embedding_data.get("dimensions"), dimensions=embedding_data.get("dimensions"),
batch_size=embedding_data.get("batch_size", 100), batch_size=embedding_data.get("batch_size", 100),
) )
except (json.JSONDecodeError, AttributeError): print(f"[EmbeddingConfig] 读取用户 {user_id} 的嵌入配置: provider={config.provider}, model={config.model}")
pass return config
except (json.JSONDecodeError, AttributeError) as e:
print(f"[EmbeddingConfig] 解析用户 {user_id} 配置失败: {e}")
# 返回默认配置 # 返回默认配置
print(f"[EmbeddingConfig] 用户 {user_id} 无保存配置,返回默认值")
return EmbeddingConfig( return EmbeddingConfig(
provider=settings.EMBEDDING_PROVIDER, provider=settings.EMBEDDING_PROVIDER,
model=settings.EMBEDDING_MODEL, model=settings.EMBEDDING_MODEL,
@ -199,7 +203,7 @@ async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: Em
select(UserConfig).where(UserConfig.user_id == user_id) select(UserConfig).where(UserConfig.user_id == user_id)
) )
user_config = result.scalar_one_or_none() user_config = result.scalar_one_or_none()
# 准备嵌入配置数据 # 准备嵌入配置数据
embedding_data = { embedding_data = {
"provider": config.provider, "provider": config.provider,
@ -209,16 +213,18 @@ async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: Em
"dimensions": config.dimensions, "dimensions": config.dimensions,
"batch_size": config.batch_size, "batch_size": config.batch_size,
} }
if user_config: if user_config:
# 更新现有配置 # 更新现有配置
try: try:
other_config = json.loads(user_config.other_config) if user_config.other_config else {} other_config = json.loads(user_config.other_config) if user_config.other_config else {}
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
other_config = {} other_config = {}
other_config[EMBEDDING_CONFIG_KEY] = embedding_data other_config[EMBEDDING_CONFIG_KEY] = embedding_data
user_config.other_config = json.dumps(other_config) user_config.other_config = json.dumps(other_config)
# 🔥 显式标记 other_config 字段已修改,确保 SQLAlchemy 检测到变化
flag_modified(user_config, "other_config")
else: else:
# 创建新配置 # 创建新配置
user_config = UserConfig( user_config = UserConfig(
@ -228,8 +234,9 @@ async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: Em
other_config=json.dumps({EMBEDDING_CONFIG_KEY: embedding_data}), other_config=json.dumps({EMBEDDING_CONFIG_KEY: embedding_data}),
) )
db.add(user_config) db.add(user_config)
await db.commit() await db.commit()
print(f"[EmbeddingConfig] 已保存用户 {user_id} 的嵌入配置: provider={config.provider}, model={config.model}")
# ============ API Endpoints ============ # ============ API Endpoints ============
@ -253,13 +260,14 @@ async def get_current_config(
获取当前嵌入模型配置从数据库读取 获取当前嵌入模型配置从数据库读取
""" """
config = await get_embedding_config_from_db(db, current_user.id) config = await get_embedding_config_from_db(db, current_user.id)
# 获取维度 # 获取维度
dimensions = _get_model_dimensions(config.provider, config.model) dimensions = _get_model_dimensions(config.provider, config.model)
return EmbeddingConfigResponse( return EmbeddingConfigResponse(
provider=config.provider, provider=config.provider,
model=config.model, model=config.model,
api_key=config.api_key,
base_url=config.base_url, base_url=config.base_url,
dimensions=dimensions, dimensions=dimensions,
batch_size=config.batch_size, batch_size=config.batch_size,
@ -279,19 +287,18 @@ async def update_config(
provider_ids = [p.id for p in EMBEDDING_PROVIDERS] provider_ids = [p.id for p in EMBEDDING_PROVIDERS]
if config.provider not in provider_ids: if config.provider not in provider_ids:
raise HTTPException(status_code=400, detail=f"不支持的提供商: {config.provider}") raise HTTPException(status_code=400, detail=f"不支持的提供商: {config.provider}")
# 验证模型 # 获取提供商信息(用于检查 API Key 要求)
provider = next((p for p in EMBEDDING_PROVIDERS if p.id == config.provider), None) provider = next((p for p in EMBEDDING_PROVIDERS if p.id == config.provider), None)
if provider and config.model not in provider.models: # 注意:不再强制验证模型名称,允许用户输入自定义模型
raise HTTPException(status_code=400, detail=f"不支持的模型: {config.model}")
# 检查 API Key # 检查 API Key
if provider and provider.requires_api_key and not config.api_key: if provider and provider.requires_api_key and not config.api_key:
raise HTTPException(status_code=400, detail=f"{config.provider} 需要 API Key") raise HTTPException(status_code=400, detail=f"{config.provider} 需要 API Key")
# 保存到数据库 # 保存到数据库
await save_embedding_config_to_db(db, current_user.id, config) await save_embedding_config_to_db(db, current_user.id, config)
return {"message": "配置已保存", "provider": config.provider, "model": config.model} return {"message": "配置已保存", "provider": config.provider, "model": config.model}

View File

@ -667,7 +667,8 @@ Action Input: {{"参数": "值"}}
try: try:
while not run_task.done(): while not run_task.done():
if self.is_cancelled: if self.is_cancelled:
# 传播取消到子 Agent # 🔥 传播取消到子 Agent
logger.info(f"[{self.name}] Cancelling sub-agent {agent_name} due to parent cancel")
if hasattr(agent, 'cancel'): if hasattr(agent, 'cancel'):
agent.cancel() agent.cancel()
run_task.cancel() run_task.cancel()
@ -678,17 +679,32 @@ Action Input: {{"参数": "值"}}
raise asyncio.CancelledError("任务已取消") raise asyncio.CancelledError("任务已取消")
try: try:
# 🔥 移除 asyncio.shield(),让取消信号可以直接传播
# 使用较短的超时来更频繁地检查取消状态
return await asyncio.wait_for( return await asyncio.wait_for(
asyncio.shield(run_task), run_task,
timeout=1.0 # 每秒检查一次取消状态 timeout=0.5 # 🔥 0.5秒检查一次取消状态
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
continue continue
except asyncio.CancelledError:
# 🔥 捕获取消异常确保子Agent也被取消
logger.info(f"[{self.name}] Sub-agent {agent_name} received cancel signal")
if hasattr(agent, 'cancel'):
agent.cancel()
raise
return await run_task return await run_task
except asyncio.CancelledError: except asyncio.CancelledError:
# 🔥 确保子任务被取消
if not run_task.done(): if not run_task.done():
if hasattr(agent, 'cancel'):
agent.cancel()
run_task.cancel() run_task.cancel()
try:
await run_task
except asyncio.CancelledError:
pass
raise raise
try: try:

View File

@ -992,6 +992,8 @@ class CodeIndexer:
indexed_file_hashes = await self.vector_store.get_file_hashes() indexed_file_hashes = await self.vector_store.get_file_hashes()
indexed_files = set(indexed_file_hashes.keys()) indexed_files = set(indexed_file_hashes.keys())
logger.debug(f"📂 已索引文件数: {len(indexed_files)}, file_hashes: {list(indexed_file_hashes.keys())[:5]}...")
# 收集当前文件 # 收集当前文件
current_files = self._collect_files(directory, exclude_patterns, include_patterns) current_files = self._collect_files(directory, exclude_patterns, include_patterns)
current_file_map: Dict[str, str] = {} # relative_path -> absolute_path current_file_map: Dict[str, str] = {} # relative_path -> absolute_path
@ -1002,11 +1004,15 @@ class CodeIndexer:
current_file_set = set(current_file_map.keys()) current_file_set = set(current_file_map.keys())
logger.debug(f"📁 当前文件数: {len(current_file_set)}, 示例: {list(current_file_set)[:5]}...")
# 计算差异 # 计算差异
files_to_add = current_file_set - indexed_files files_to_add = current_file_set - indexed_files
files_to_delete = indexed_files - current_file_set files_to_delete = indexed_files - current_file_set
files_to_check = current_file_set & indexed_files files_to_check = current_file_set & indexed_files
logger.debug(f"📊 差异分析: 交集={len(files_to_check)}, 新增候选={len(files_to_add)}, 删除候选={len(files_to_delete)}")
# 检查需要更新的文件hash 变化) # 检查需要更新的文件hash 变化)
files_to_update: Set[str] = set() files_to_update: Set[str] = set()
for relative_path in files_to_check: for relative_path in files_to_check:

View File

@ -92,7 +92,7 @@ class CodeChunk:
return len(self.content) // 4 return len(self.content) // 4
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
return { result = {
"id": self.id, "id": self.id,
"content": self.content, "content": self.content,
"file_path": self.file_path, "file_path": self.file_path,
@ -110,8 +110,13 @@ class CodeChunk:
"definitions": self.definitions, "definitions": self.definitions,
"security_indicators": self.security_indicators, "security_indicators": self.security_indicators,
"estimated_tokens": self.estimated_tokens, "estimated_tokens": self.estimated_tokens,
"metadata": self.metadata,
} }
# 将 metadata 中的字段提升到顶级,确保 file_hash 等字段可以被正确检索
if self.metadata:
for key, value in self.metadata.items():
if key not in result:
result[key] = value
return result
def to_embedding_text(self) -> str: def to_embedding_text(self) -> str:
"""生成用于嵌入的文本""" """生成用于嵌入的文本"""
@ -244,20 +249,29 @@ class TreeSitterParser:
"""从 AST 提取定义""" """从 AST 提取定义"""
if tree is None: if tree is None:
return [] return []
definitions = [] definitions = []
definition_types = self.DEFINITION_TYPES.get(language, {}) definition_types = self.DEFINITION_TYPES.get(language, {})
def traverse(node, parent_name=None): def traverse(node, parent_name=None):
node_type = node.type node_type = node.type
# 检查是否是定义节点 # 检查是否是定义节点
matched = False
for def_category, types in definition_types.items(): for def_category, types in definition_types.items():
if node_type in types: if node_type in types:
name = self._extract_name(node, language) name = self._extract_name(node, language)
# 根据是否有 parent_name 来区分 function 和 method
actual_category = def_category
if def_category == "function" and parent_name:
actual_category = "method"
elif def_category == "method" and not parent_name:
# 跳过没有 parent 的 method 定义(由 function 类别处理)
continue
definitions.append({ definitions.append({
"type": def_category, "type": actual_category,
"name": name, "name": name,
"parent_name": parent_name, "parent_name": parent_name,
"start_point": node.start_point, "start_point": node.start_point,
@ -266,17 +280,23 @@ class TreeSitterParser:
"end_byte": node.end_byte, "end_byte": node.end_byte,
"node_type": node_type, "node_type": node_type,
}) })
matched = True
# 对于类,继续遍历子节点找方法 # 对于类,继续遍历子节点找方法
if def_category == "class": if def_category == "class":
for child in node.children: for child in node.children:
traverse(child, name) traverse(child, name)
return return
# 继续遍历子节点 # 匹配到一个类别后就不再匹配其他类别
for child in node.children: break
traverse(child, parent_name)
# 如果没有匹配到定义,继续遍历子节点
if not matched:
for child in node.children:
traverse(child, parent_name)
traverse(tree.root_node) traverse(tree.root_node)
return definitions return definitions

View File

@ -80,6 +80,9 @@ services:
- all_proxy= - all_proxy=
- ALL_PROXY= - ALL_PROXY=
restart: unless-stopped restart: unless-stopped
volumes:
- ./frontend/dist:/usr/share/nginx/html:ro # 挂载构建产物,本地 pnpm build 后自动生效
- ./frontend/nginx.conf:/etc/nginx/conf.d/default.conf:ro # 挂载 nginx 配置
ports: ports:
- "3000:80" # Nginx 监听 80 端口 - "3000:80" # Nginx 监听 80 端口
depends_on: depends_on:

View File

@ -46,6 +46,7 @@ interface EmbeddingProvider {
interface EmbeddingConfig { interface EmbeddingConfig {
provider: string; provider: string;
model: string; model: string;
api_key: string | null;
base_url: string | null; base_url: string | null;
dimensions: number; dimensions: number;
batch_size: number; batch_size: number;
@ -79,15 +80,15 @@ export default function EmbeddingConfigPanel() {
loadData(); loadData();
}, []); }, []);
// 当 provider 改变时更新模型 // 用户手动切换 provider 时更新为默认模型
useEffect(() => { const handleProviderChange = (newProvider: string) => {
if (selectedProvider) { setSelectedProvider(newProvider);
const provider = providers.find((p) => p.id === selectedProvider); // 切换 provider 时重置为该 provider 的默认模型
if (provider) { const provider = providers.find((p) => p.id === newProvider);
setSelectedModel(provider.default_model); if (provider) {
} setSelectedModel(provider.default_model);
} }
}, [selectedProvider, providers]); };
const loadData = async () => { const loadData = async () => {
try { try {
@ -104,6 +105,7 @@ export default function EmbeddingConfigPanel() {
if (configRes.data) { if (configRes.data) {
setSelectedProvider(configRes.data.provider); setSelectedProvider(configRes.data.provider);
setSelectedModel(configRes.data.model); setSelectedModel(configRes.data.model);
setApiKey(configRes.data.api_key || "");
setBaseUrl(configRes.data.base_url || ""); setBaseUrl(configRes.data.base_url || "");
setBatchSize(configRes.data.batch_size); setBatchSize(configRes.data.batch_size);
} }
@ -230,7 +232,7 @@ export default function EmbeddingConfigPanel() {
{/* 提供商选择 */} {/* 提供商选择 */}
<div className="space-y-2"> <div className="space-y-2">
<Label className="text-xs font-bold text-gray-500 uppercase"></Label> <Label className="text-xs font-bold text-gray-500 uppercase"></Label>
<Select value={selectedProvider} onValueChange={setSelectedProvider}> <Select value={selectedProvider} onValueChange={handleProviderChange}>
<SelectTrigger className="h-12 cyber-input"> <SelectTrigger className="h-12 cyber-input">
<SelectValue placeholder="选择提供商" /> <SelectValue placeholder="选择提供商" />
</SelectTrigger> </SelectTrigger>