feat(embedding): 支持前端配置嵌入模型的API密钥
refactor(agent): 改进任务取消逻辑,确保子Agent被正确取消
- 移除asyncio.shield()以允许取消信号传播
- 增加更频繁的取消状态检查
- 添加日志记录子Agent取消情况

feat(nginx): 添加前端构建产物和nginx配置的挂载

refactor(rag): 优化代码索引器的日志输出和元数据处理
- 添加索引文件数量的调试日志
- 将元数据字段提升到顶级以便检索

fix(parser): 修复AST定义提取中的方法识别问题
- 区分函数和方法定义
- 优化遍历逻辑避免重复匹配
This commit is contained in:
parent
7efb89d2d2
commit
f71b8da7df
|
|
@ -1583,18 +1583,28 @@ async def cancel_agent_task(
|
|||
if runner:
|
||||
runner.cancel()
|
||||
logger.info(f"[Cancel] Set cancel flag for task {task_id}")
|
||||
|
||||
# 🔥 2. 强制取消 asyncio Task(立即中断 LLM 调用)
|
||||
|
||||
# 🔥 2. 通过 agent_registry 取消所有子 Agent
|
||||
from app.services.agent.core import agent_registry
|
||||
from app.services.agent.core.graph_controller import stop_all_agents
|
||||
try:
|
||||
# 停止所有 Agent(包括子 Agent)
|
||||
stop_result = stop_all_agents(exclude_root=False)
|
||||
logger.info(f"[Cancel] Stopped all agents: {stop_result}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Cancel] Failed to stop agents via registry: {e}")
|
||||
|
||||
# 🔥 3. 强制取消 asyncio Task(立即中断 LLM 调用)
|
||||
asyncio_task = _running_asyncio_tasks.get(task_id)
|
||||
if asyncio_task and not asyncio_task.done():
|
||||
asyncio_task.cancel()
|
||||
logger.info(f"[Cancel] Cancelled asyncio task for {task_id}")
|
||||
|
||||
|
||||
# 更新状态
|
||||
task.status = AgentTaskStatus.CANCELLED
|
||||
task.completed_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
|
||||
|
||||
logger.info(f"[Cancel] Task {task_id} cancelled successfully")
|
||||
return {"message": "任务已取消", "task_id": task_id}
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from fastapi import APIRouter, Depends, HTTPException
|
|||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from app.api import deps
|
||||
from app.models.user import User
|
||||
|
|
@ -46,10 +47,10 @@ class EmbeddingConfigResponse(BaseModel):
|
|||
"""配置响应"""
|
||||
provider: str
|
||||
model: str
|
||||
api_key: Optional[str] = None # 返回 API Key
|
||||
base_url: Optional[str]
|
||||
dimensions: int
|
||||
batch_size: int
|
||||
# 不返回 API Key
|
||||
|
||||
|
||||
class TestEmbeddingRequest(BaseModel):
|
||||
|
|
@ -165,14 +166,14 @@ async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> Embedd
|
|||
select(UserConfig).where(UserConfig.user_id == user_id)
|
||||
)
|
||||
user_config = result.scalar_one_or_none()
|
||||
|
||||
|
||||
if user_config and user_config.other_config:
|
||||
try:
|
||||
other_config = json.loads(user_config.other_config) if isinstance(user_config.other_config, str) else user_config.other_config
|
||||
embedding_data = other_config.get(EMBEDDING_CONFIG_KEY)
|
||||
|
||||
|
||||
if embedding_data:
|
||||
return EmbeddingConfig(
|
||||
config = EmbeddingConfig(
|
||||
provider=embedding_data.get("provider", settings.EMBEDDING_PROVIDER),
|
||||
model=embedding_data.get("model", settings.EMBEDDING_MODEL),
|
||||
api_key=embedding_data.get("api_key"),
|
||||
|
|
@ -180,10 +181,13 @@ async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> Embedd
|
|||
dimensions=embedding_data.get("dimensions"),
|
||||
batch_size=embedding_data.get("batch_size", 100),
|
||||
)
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass
|
||||
|
||||
print(f"[EmbeddingConfig] 读取用户 {user_id} 的嵌入配置: provider={config.provider}, model={config.model}")
|
||||
return config
|
||||
except (json.JSONDecodeError, AttributeError) as e:
|
||||
print(f"[EmbeddingConfig] 解析用户 {user_id} 配置失败: {e}")
|
||||
|
||||
# 返回默认配置
|
||||
print(f"[EmbeddingConfig] 用户 {user_id} 无保存配置,返回默认值")
|
||||
return EmbeddingConfig(
|
||||
provider=settings.EMBEDDING_PROVIDER,
|
||||
model=settings.EMBEDDING_MODEL,
|
||||
|
|
@ -199,7 +203,7 @@ async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: Em
|
|||
select(UserConfig).where(UserConfig.user_id == user_id)
|
||||
)
|
||||
user_config = result.scalar_one_or_none()
|
||||
|
||||
|
||||
# 准备嵌入配置数据
|
||||
embedding_data = {
|
||||
"provider": config.provider,
|
||||
|
|
@ -209,16 +213,18 @@ async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: Em
|
|||
"dimensions": config.dimensions,
|
||||
"batch_size": config.batch_size,
|
||||
}
|
||||
|
||||
|
||||
if user_config:
|
||||
# 更新现有配置
|
||||
try:
|
||||
other_config = json.loads(user_config.other_config) if user_config.other_config else {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
other_config = {}
|
||||
|
||||
|
||||
other_config[EMBEDDING_CONFIG_KEY] = embedding_data
|
||||
user_config.other_config = json.dumps(other_config)
|
||||
# 🔥 显式标记 other_config 字段已修改,确保 SQLAlchemy 检测到变化
|
||||
flag_modified(user_config, "other_config")
|
||||
else:
|
||||
# 创建新配置
|
||||
user_config = UserConfig(
|
||||
|
|
@ -228,8 +234,9 @@ async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: Em
|
|||
other_config=json.dumps({EMBEDDING_CONFIG_KEY: embedding_data}),
|
||||
)
|
||||
db.add(user_config)
|
||||
|
||||
|
||||
await db.commit()
|
||||
print(f"[EmbeddingConfig] 已保存用户 {user_id} 的嵌入配置: provider={config.provider}, model={config.model}")
|
||||
|
||||
|
||||
# ============ API Endpoints ============
|
||||
|
|
@ -253,13 +260,14 @@ async def get_current_config(
|
|||
获取当前嵌入模型配置(从数据库读取)
|
||||
"""
|
||||
config = await get_embedding_config_from_db(db, current_user.id)
|
||||
|
||||
|
||||
# 获取维度
|
||||
dimensions = _get_model_dimensions(config.provider, config.model)
|
||||
|
||||
|
||||
return EmbeddingConfigResponse(
|
||||
provider=config.provider,
|
||||
model=config.model,
|
||||
api_key=config.api_key,
|
||||
base_url=config.base_url,
|
||||
dimensions=dimensions,
|
||||
batch_size=config.batch_size,
|
||||
|
|
@ -279,19 +287,18 @@ async def update_config(
|
|||
provider_ids = [p.id for p in EMBEDDING_PROVIDERS]
|
||||
if config.provider not in provider_ids:
|
||||
raise HTTPException(status_code=400, detail=f"不支持的提供商: {config.provider}")
|
||||
|
||||
# 验证模型
|
||||
|
||||
# 获取提供商信息(用于检查 API Key 要求)
|
||||
provider = next((p for p in EMBEDDING_PROVIDERS if p.id == config.provider), None)
|
||||
if provider and config.model not in provider.models:
|
||||
raise HTTPException(status_code=400, detail=f"不支持的模型: {config.model}")
|
||||
|
||||
# 注意:不再强制验证模型名称,允许用户输入自定义模型
|
||||
|
||||
# 检查 API Key
|
||||
if provider and provider.requires_api_key and not config.api_key:
|
||||
raise HTTPException(status_code=400, detail=f"{config.provider} 需要 API Key")
|
||||
|
||||
|
||||
# 保存到数据库
|
||||
await save_embedding_config_to_db(db, current_user.id, config)
|
||||
|
||||
|
||||
return {"message": "配置已保存", "provider": config.provider, "model": config.model}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -667,7 +667,8 @@ Action Input: {{"参数": "值"}}
|
|||
try:
|
||||
while not run_task.done():
|
||||
if self.is_cancelled:
|
||||
# 传播取消到子 Agent
|
||||
# 🔥 传播取消到子 Agent
|
||||
logger.info(f"[{self.name}] Cancelling sub-agent {agent_name} due to parent cancel")
|
||||
if hasattr(agent, 'cancel'):
|
||||
agent.cancel()
|
||||
run_task.cancel()
|
||||
|
|
@ -678,17 +679,32 @@ Action Input: {{"参数": "值"}}
|
|||
raise asyncio.CancelledError("任务已取消")
|
||||
|
||||
try:
|
||||
# 🔥 移除 asyncio.shield(),让取消信号可以直接传播
|
||||
# 使用较短的超时来更频繁地检查取消状态
|
||||
return await asyncio.wait_for(
|
||||
asyncio.shield(run_task),
|
||||
timeout=1.0 # 每秒检查一次取消状态
|
||||
run_task,
|
||||
timeout=0.5 # 🔥 每0.5秒检查一次取消状态
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except asyncio.CancelledError:
|
||||
# 🔥 捕获取消异常,确保子Agent也被取消
|
||||
logger.info(f"[{self.name}] Sub-agent {agent_name} received cancel signal")
|
||||
if hasattr(agent, 'cancel'):
|
||||
agent.cancel()
|
||||
raise
|
||||
|
||||
return await run_task
|
||||
except asyncio.CancelledError:
|
||||
# 🔥 确保子任务被取消
|
||||
if not run_task.done():
|
||||
if hasattr(agent, 'cancel'):
|
||||
agent.cancel()
|
||||
run_task.cancel()
|
||||
try:
|
||||
await run_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
raise
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -992,6 +992,8 @@ class CodeIndexer:
|
|||
indexed_file_hashes = await self.vector_store.get_file_hashes()
|
||||
indexed_files = set(indexed_file_hashes.keys())
|
||||
|
||||
logger.debug(f"📂 已索引文件数: {len(indexed_files)}, file_hashes: {list(indexed_file_hashes.keys())[:5]}...")
|
||||
|
||||
# 收集当前文件
|
||||
current_files = self._collect_files(directory, exclude_patterns, include_patterns)
|
||||
current_file_map: Dict[str, str] = {} # relative_path -> absolute_path
|
||||
|
|
@ -1002,11 +1004,15 @@ class CodeIndexer:
|
|||
|
||||
current_file_set = set(current_file_map.keys())
|
||||
|
||||
logger.debug(f"📁 当前文件数: {len(current_file_set)}, 示例: {list(current_file_set)[:5]}...")
|
||||
|
||||
# 计算差异
|
||||
files_to_add = current_file_set - indexed_files
|
||||
files_to_delete = indexed_files - current_file_set
|
||||
files_to_check = current_file_set & indexed_files
|
||||
|
||||
logger.debug(f"📊 差异分析: 交集={len(files_to_check)}, 新增候选={len(files_to_add)}, 删除候选={len(files_to_delete)}")
|
||||
|
||||
# 检查需要更新的文件(hash 变化)
|
||||
files_to_update: Set[str] = set()
|
||||
for relative_path in files_to_check:
|
||||
|
|
|
|||
|
|
@ -92,7 +92,7 @@ class CodeChunk:
|
|||
return len(self.content) // 4
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
result = {
|
||||
"id": self.id,
|
||||
"content": self.content,
|
||||
"file_path": self.file_path,
|
||||
|
|
@ -110,8 +110,13 @@ class CodeChunk:
|
|||
"definitions": self.definitions,
|
||||
"security_indicators": self.security_indicators,
|
||||
"estimated_tokens": self.estimated_tokens,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
# 将 metadata 中的字段提升到顶级,确保 file_hash 等字段可以被正确检索
|
||||
if self.metadata:
|
||||
for key, value in self.metadata.items():
|
||||
if key not in result:
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
def to_embedding_text(self) -> str:
|
||||
"""生成用于嵌入的文本"""
|
||||
|
|
@ -244,20 +249,29 @@ class TreeSitterParser:
|
|||
"""从 AST 提取定义"""
|
||||
if tree is None:
|
||||
return []
|
||||
|
||||
|
||||
definitions = []
|
||||
definition_types = self.DEFINITION_TYPES.get(language, {})
|
||||
|
||||
|
||||
def traverse(node, parent_name=None):
|
||||
node_type = node.type
|
||||
|
||||
|
||||
# 检查是否是定义节点
|
||||
matched = False
|
||||
for def_category, types in definition_types.items():
|
||||
if node_type in types:
|
||||
name = self._extract_name(node, language)
|
||||
|
||||
|
||||
# 根据是否有 parent_name 来区分 function 和 method
|
||||
actual_category = def_category
|
||||
if def_category == "function" and parent_name:
|
||||
actual_category = "method"
|
||||
elif def_category == "method" and not parent_name:
|
||||
# 跳过没有 parent 的 method 定义(由 function 类别处理)
|
||||
continue
|
||||
|
||||
definitions.append({
|
||||
"type": def_category,
|
||||
"type": actual_category,
|
||||
"name": name,
|
||||
"parent_name": parent_name,
|
||||
"start_point": node.start_point,
|
||||
|
|
@ -266,17 +280,23 @@ class TreeSitterParser:
|
|||
"end_byte": node.end_byte,
|
||||
"node_type": node_type,
|
||||
})
|
||||
|
||||
|
||||
matched = True
|
||||
|
||||
# 对于类,继续遍历子节点找方法
|
||||
if def_category == "class":
|
||||
for child in node.children:
|
||||
traverse(child, name)
|
||||
return
|
||||
|
||||
# 继续遍历子节点
|
||||
for child in node.children:
|
||||
traverse(child, parent_name)
|
||||
|
||||
|
||||
# 匹配到一个类别后就不再匹配其他类别
|
||||
break
|
||||
|
||||
# 如果没有匹配到定义,继续遍历子节点
|
||||
if not matched:
|
||||
for child in node.children:
|
||||
traverse(child, parent_name)
|
||||
|
||||
traverse(tree.root_node)
|
||||
return definitions
|
||||
|
||||
|
|
|
|||
|
|
@ -80,6 +80,9 @@ services:
|
|||
- all_proxy=
|
||||
- ALL_PROXY=
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- ./frontend/dist:/usr/share/nginx/html:ro # 挂载构建产物,本地 pnpm build 后自动生效
|
||||
- ./frontend/nginx.conf:/etc/nginx/conf.d/default.conf:ro # 挂载 nginx 配置
|
||||
ports:
|
||||
- "3000:80" # Nginx 监听 80 端口
|
||||
depends_on:
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ interface EmbeddingProvider {
|
|||
interface EmbeddingConfig {
|
||||
provider: string;
|
||||
model: string;
|
||||
api_key: string | null;
|
||||
base_url: string | null;
|
||||
dimensions: number;
|
||||
batch_size: number;
|
||||
|
|
@ -79,15 +80,15 @@ export default function EmbeddingConfigPanel() {
|
|||
loadData();
|
||||
}, []);
|
||||
|
||||
// 当 provider 改变时更新模型
|
||||
useEffect(() => {
|
||||
if (selectedProvider) {
|
||||
const provider = providers.find((p) => p.id === selectedProvider);
|
||||
if (provider) {
|
||||
setSelectedModel(provider.default_model);
|
||||
}
|
||||
// 用户手动切换 provider 时更新为默认模型
|
||||
const handleProviderChange = (newProvider: string) => {
|
||||
setSelectedProvider(newProvider);
|
||||
// 切换 provider 时重置为该 provider 的默认模型
|
||||
const provider = providers.find((p) => p.id === newProvider);
|
||||
if (provider) {
|
||||
setSelectedModel(provider.default_model);
|
||||
}
|
||||
}, [selectedProvider, providers]);
|
||||
};
|
||||
|
||||
const loadData = async () => {
|
||||
try {
|
||||
|
|
@ -104,6 +105,7 @@ export default function EmbeddingConfigPanel() {
|
|||
if (configRes.data) {
|
||||
setSelectedProvider(configRes.data.provider);
|
||||
setSelectedModel(configRes.data.model);
|
||||
setApiKey(configRes.data.api_key || "");
|
||||
setBaseUrl(configRes.data.base_url || "");
|
||||
setBatchSize(configRes.data.batch_size);
|
||||
}
|
||||
|
|
@ -230,7 +232,7 @@ export default function EmbeddingConfigPanel() {
|
|||
{/* 提供商选择 */}
|
||||
<div className="space-y-2">
|
||||
<Label className="text-xs font-bold text-gray-500 uppercase">嵌入模型提供商</Label>
|
||||
<Select value={selectedProvider} onValueChange={setSelectedProvider}>
|
||||
<Select value={selectedProvider} onValueChange={handleProviderChange}>
|
||||
<SelectTrigger className="h-12 cyber-input">
|
||||
<SelectValue placeholder="选择提供商" />
|
||||
</SelectTrigger>
|
||||
|
|
|
|||
Loading…
Reference in New Issue