feat(embedding): 支持前端配置嵌入模型的API密钥

refactor(agent): 改进任务取消逻辑，确保子Agent被正确取消
- 移除asyncio.shield()以允许取消信号传播
- 增加更频繁的取消状态检查
- 添加日志记录子Agent取消情况

feat(nginx): 添加前端构建产物和nginx配置的挂载

refactor(rag): 优化代码索引器的日志输出和元数据处理
- 添加索引文件数量的调试日志
- 将元数据字段提升到顶级以便检索

fix(parser): 修复AST定义提取中的方法识别问题
- 区分函数和方法定义
- 优化遍历逻辑避免重复匹配
This commit is contained in:
lintsinghua 2025-12-16 19:42:44 +08:00
parent 7efb89d2d2
commit f71b8da7df
7 changed files with 113 additions and 49 deletions

View File

@ -1584,7 +1584,17 @@ async def cancel_agent_task(
runner.cancel()
logger.info(f"[Cancel] Set cancel flag for task {task_id}")
# 🔥 2. 强制取消 asyncio Task(立即中断 LLM 调用)
# 🔥 2. 通过 agent_registry 取消所有子 Agent
from app.services.agent.core import agent_registry
from app.services.agent.core.graph_controller import stop_all_agents
try:
# 停止所有 Agent(包括子 Agent)
stop_result = stop_all_agents(exclude_root=False)
logger.info(f"[Cancel] Stopped all agents: {stop_result}")
except Exception as e:
logger.warning(f"[Cancel] Failed to stop agents via registry: {e}")
# 🔥 3. 强制取消 asyncio Task(立即中断 LLM 调用)
asyncio_task = _running_asyncio_tasks.get(task_id)
if asyncio_task and not asyncio_task.done():
asyncio_task.cancel()

View File

@ -11,6 +11,7 @@ from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.attributes import flag_modified
from app.api import deps
from app.models.user import User
@ -46,10 +47,10 @@ class EmbeddingConfigResponse(BaseModel):
"""配置响应"""
provider: str
model: str
api_key: Optional[str] = None # 返回 API Key
base_url: Optional[str]
dimensions: int
batch_size: int
# 不返回 API Key
class TestEmbeddingRequest(BaseModel):
@ -172,7 +173,7 @@ async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> Embedd
embedding_data = other_config.get(EMBEDDING_CONFIG_KEY)
if embedding_data:
return EmbeddingConfig(
config = EmbeddingConfig(
provider=embedding_data.get("provider", settings.EMBEDDING_PROVIDER),
model=embedding_data.get("model", settings.EMBEDDING_MODEL),
api_key=embedding_data.get("api_key"),
@ -180,10 +181,13 @@ async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> Embedd
dimensions=embedding_data.get("dimensions"),
batch_size=embedding_data.get("batch_size", 100),
)
except (json.JSONDecodeError, AttributeError):
pass
print(f"[EmbeddingConfig] 读取用户 {user_id} 的嵌入配置: provider={config.provider}, model={config.model}")
return config
except (json.JSONDecodeError, AttributeError) as e:
print(f"[EmbeddingConfig] 解析用户 {user_id} 配置失败: {e}")
# 返回默认配置
print(f"[EmbeddingConfig] 用户 {user_id} 无保存配置,返回默认值")
return EmbeddingConfig(
provider=settings.EMBEDDING_PROVIDER,
model=settings.EMBEDDING_MODEL,
@ -219,6 +223,8 @@ async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: Em
other_config[EMBEDDING_CONFIG_KEY] = embedding_data
user_config.other_config = json.dumps(other_config)
# 🔥 显式标记 other_config 字段已修改,确保 SQLAlchemy 检测到变化
flag_modified(user_config, "other_config")
else:
# 创建新配置
user_config = UserConfig(
@ -230,6 +236,7 @@ async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: Em
db.add(user_config)
await db.commit()
print(f"[EmbeddingConfig] 已保存用户 {user_id} 的嵌入配置: provider={config.provider}, model={config.model}")
# ============ API Endpoints ============
@ -260,6 +267,7 @@ async def get_current_config(
return EmbeddingConfigResponse(
provider=config.provider,
model=config.model,
api_key=config.api_key,
base_url=config.base_url,
dimensions=dimensions,
batch_size=config.batch_size,
@ -280,10 +288,9 @@ async def update_config(
if config.provider not in provider_ids:
raise HTTPException(status_code=400, detail=f"不支持的提供商: {config.provider}")
# 验证模型
# 获取提供商信息(用于检查 API Key 要求)
provider = next((p for p in EMBEDDING_PROVIDERS if p.id == config.provider), None)
if provider and config.model not in provider.models:
raise HTTPException(status_code=400, detail=f"不支持的模型: {config.model}")
# 注意:不再强制验证模型名称,允许用户输入自定义模型
# 检查 API Key
if provider and provider.requires_api_key and not config.api_key:

View File

@ -667,7 +667,8 @@ Action Input: {{"参数": "值"}}
try:
while not run_task.done():
if self.is_cancelled:
# 传播取消到子 Agent
# 🔥 传播取消到子 Agent
logger.info(f"[{self.name}] Cancelling sub-agent {agent_name} due to parent cancel")
if hasattr(agent, 'cancel'):
agent.cancel()
run_task.cancel()
@ -678,17 +679,32 @@ Action Input: {{"参数": "值"}}
raise asyncio.CancelledError("任务已取消")
try:
# 🔥 移除 asyncio.shield(),让取消信号可以直接传播
# 使用较短的超时来更频繁地检查取消状态
return await asyncio.wait_for(
asyncio.shield(run_task),
timeout=1.0 # 每秒检查一次取消状态
run_task,
timeout=0.5 # 🔥 0.5秒检查一次取消状态
)
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
# 🔥 捕获取消异常确保子Agent也被取消
logger.info(f"[{self.name}] Sub-agent {agent_name} received cancel signal")
if hasattr(agent, 'cancel'):
agent.cancel()
raise
return await run_task
except asyncio.CancelledError:
# 🔥 确保子任务被取消
if not run_task.done():
if hasattr(agent, 'cancel'):
agent.cancel()
run_task.cancel()
try:
await run_task
except asyncio.CancelledError:
pass
raise
try:

View File

@ -992,6 +992,8 @@ class CodeIndexer:
indexed_file_hashes = await self.vector_store.get_file_hashes()
indexed_files = set(indexed_file_hashes.keys())
logger.debug(f"📂 已索引文件数: {len(indexed_files)}, file_hashes: {list(indexed_file_hashes.keys())[:5]}...")
# 收集当前文件
current_files = self._collect_files(directory, exclude_patterns, include_patterns)
current_file_map: Dict[str, str] = {} # relative_path -> absolute_path
@ -1002,11 +1004,15 @@ class CodeIndexer:
current_file_set = set(current_file_map.keys())
logger.debug(f"📁 当前文件数: {len(current_file_set)}, 示例: {list(current_file_set)[:5]}...")
# 计算差异
files_to_add = current_file_set - indexed_files
files_to_delete = indexed_files - current_file_set
files_to_check = current_file_set & indexed_files
logger.debug(f"📊 差异分析: 交集={len(files_to_check)}, 新增候选={len(files_to_add)}, 删除候选={len(files_to_delete)}")
# 检查需要更新的文件(hash 变化)
files_to_update: Set[str] = set()
for relative_path in files_to_check:

View File

@ -92,7 +92,7 @@ class CodeChunk:
return len(self.content) // 4
def to_dict(self) -> Dict[str, Any]:
return {
result = {
"id": self.id,
"content": self.content,
"file_path": self.file_path,
@ -110,8 +110,13 @@ class CodeChunk:
"definitions": self.definitions,
"security_indicators": self.security_indicators,
"estimated_tokens": self.estimated_tokens,
"metadata": self.metadata,
}
# 将 metadata 中的字段提升到顶级,确保 file_hash 等字段可以被正确检索
if self.metadata:
for key, value in self.metadata.items():
if key not in result:
result[key] = value
return result
def to_embedding_text(self) -> str:
"""生成用于嵌入的文本"""
@ -252,12 +257,21 @@ class TreeSitterParser:
node_type = node.type
# 检查是否是定义节点
matched = False
for def_category, types in definition_types.items():
if node_type in types:
name = self._extract_name(node, language)
# 根据是否有 parent_name 来区分 function 和 method
actual_category = def_category
if def_category == "function" and parent_name:
actual_category = "method"
elif def_category == "method" and not parent_name:
# 跳过没有 parent 的 method 定义(由 function 类别处理)
continue
definitions.append({
"type": def_category,
"type": actual_category,
"name": name,
"parent_name": parent_name,
"start_point": node.start_point,
@ -267,13 +281,19 @@ class TreeSitterParser:
"node_type": node_type,
})
matched = True
# 对于类,继续遍历子节点找方法
if def_category == "class":
for child in node.children:
traverse(child, name)
return
# 继续遍历子节点
# 匹配到一个类别后就不再匹配其他类别
break
# 如果没有匹配到定义,继续遍历子节点
if not matched:
for child in node.children:
traverse(child, parent_name)

View File

@ -80,6 +80,9 @@ services:
- all_proxy=
- ALL_PROXY=
restart: unless-stopped
volumes:
- ./frontend/dist:/usr/share/nginx/html:ro # 挂载构建产物,本地 pnpm build 后自动生效
- ./frontend/nginx.conf:/etc/nginx/conf.d/default.conf:ro # 挂载 nginx 配置
ports:
- "3000:80" # Nginx 监听 80 端口
depends_on:

View File

@ -46,6 +46,7 @@ interface EmbeddingProvider {
interface EmbeddingConfig {
provider: string;
model: string;
api_key: string | null;
base_url: string | null;
dimensions: number;
batch_size: number;
@ -79,15 +80,15 @@ export default function EmbeddingConfigPanel() {
loadData();
}, []);
// 当 provider 改变时更新模型
useEffect(() => {
if (selectedProvider) {
const provider = providers.find((p) => p.id === selectedProvider);
// 用户手动切换 provider 时更新为默认模型
const handleProviderChange = (newProvider: string) => {
setSelectedProvider(newProvider);
// 切换 provider 时重置为该 provider 的默认模型
const provider = providers.find((p) => p.id === newProvider);
if (provider) {
setSelectedModel(provider.default_model);
}
}
}, [selectedProvider, providers]);
};
const loadData = async () => {
try {
@ -104,6 +105,7 @@ export default function EmbeddingConfigPanel() {
if (configRes.data) {
setSelectedProvider(configRes.data.provider);
setSelectedModel(configRes.data.model);
setApiKey(configRes.data.api_key || "");
setBaseUrl(configRes.data.base_url || "");
setBatchSize(configRes.data.batch_size);
}
@ -230,7 +232,7 @@ export default function EmbeddingConfigPanel() {
{/* 提供商选择 */}
<div className="space-y-2">
<Label className="text-xs font-bold text-gray-500 uppercase"></Label>
<Select value={selectedProvider} onValueChange={setSelectedProvider}>
<Select value={selectedProvider} onValueChange={handleProviderChange}>
<SelectTrigger className="h-12 cyber-input">
<SelectValue placeholder="选择提供商" />
</SelectTrigger>