feat(embedding): 支持前端配置嵌入模型的API密钥
refactor(agent): 改进任务取消逻辑,确保子Agent被正确取消 - 移除asyncio.shield()以允许取消信号传播 - 增加更频繁的取消状态检查 - 添加日志记录子Agent取消情况 feat(nginx): 添加前端构建产物和nginx配置的挂载 refactor(rag): 优化代码索引器的日志输出和元数据处理 - 添加索引文件数量的调试日志 - 将元数据字段提升到顶级以便检索 fix(parser): 修复AST定义提取中的方法识别问题 - 区分函数和方法定义 - 优化遍历逻辑避免重复匹配
This commit is contained in:
parent
7efb89d2d2
commit
f71b8da7df
|
|
@ -1584,7 +1584,17 @@ async def cancel_agent_task(
|
||||||
runner.cancel()
|
runner.cancel()
|
||||||
logger.info(f"[Cancel] Set cancel flag for task {task_id}")
|
logger.info(f"[Cancel] Set cancel flag for task {task_id}")
|
||||||
|
|
||||||
# 🔥 2. 强制取消 asyncio Task(立即中断 LLM 调用)
|
# 🔥 2. 通过 agent_registry 取消所有子 Agent
|
||||||
|
from app.services.agent.core import agent_registry
|
||||||
|
from app.services.agent.core.graph_controller import stop_all_agents
|
||||||
|
try:
|
||||||
|
# 停止所有 Agent(包括子 Agent)
|
||||||
|
stop_result = stop_all_agents(exclude_root=False)
|
||||||
|
logger.info(f"[Cancel] Stopped all agents: {stop_result}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[Cancel] Failed to stop agents via registry: {e}")
|
||||||
|
|
||||||
|
# 🔥 3. 强制取消 asyncio Task(立即中断 LLM 调用)
|
||||||
asyncio_task = _running_asyncio_tasks.get(task_id)
|
asyncio_task = _running_asyncio_tasks.get(task_id)
|
||||||
if asyncio_task and not asyncio_task.done():
|
if asyncio_task and not asyncio_task.done():
|
||||||
asyncio_task.cancel()
|
asyncio_task.cancel()
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm.attributes import flag_modified
|
||||||
|
|
||||||
from app.api import deps
|
from app.api import deps
|
||||||
from app.models.user import User
|
from app.models.user import User
|
||||||
|
|
@ -46,10 +47,10 @@ class EmbeddingConfigResponse(BaseModel):
|
||||||
"""配置响应"""
|
"""配置响应"""
|
||||||
provider: str
|
provider: str
|
||||||
model: str
|
model: str
|
||||||
|
api_key: Optional[str] = None # 返回 API Key
|
||||||
base_url: Optional[str]
|
base_url: Optional[str]
|
||||||
dimensions: int
|
dimensions: int
|
||||||
batch_size: int
|
batch_size: int
|
||||||
# 不返回 API Key
|
|
||||||
|
|
||||||
|
|
||||||
class TestEmbeddingRequest(BaseModel):
|
class TestEmbeddingRequest(BaseModel):
|
||||||
|
|
@ -172,7 +173,7 @@ async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> Embedd
|
||||||
embedding_data = other_config.get(EMBEDDING_CONFIG_KEY)
|
embedding_data = other_config.get(EMBEDDING_CONFIG_KEY)
|
||||||
|
|
||||||
if embedding_data:
|
if embedding_data:
|
||||||
return EmbeddingConfig(
|
config = EmbeddingConfig(
|
||||||
provider=embedding_data.get("provider", settings.EMBEDDING_PROVIDER),
|
provider=embedding_data.get("provider", settings.EMBEDDING_PROVIDER),
|
||||||
model=embedding_data.get("model", settings.EMBEDDING_MODEL),
|
model=embedding_data.get("model", settings.EMBEDDING_MODEL),
|
||||||
api_key=embedding_data.get("api_key"),
|
api_key=embedding_data.get("api_key"),
|
||||||
|
|
@ -180,10 +181,13 @@ async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> Embedd
|
||||||
dimensions=embedding_data.get("dimensions"),
|
dimensions=embedding_data.get("dimensions"),
|
||||||
batch_size=embedding_data.get("batch_size", 100),
|
batch_size=embedding_data.get("batch_size", 100),
|
||||||
)
|
)
|
||||||
except (json.JSONDecodeError, AttributeError):
|
print(f"[EmbeddingConfig] 读取用户 {user_id} 的嵌入配置: provider={config.provider}, model={config.model}")
|
||||||
pass
|
return config
|
||||||
|
except (json.JSONDecodeError, AttributeError) as e:
|
||||||
|
print(f"[EmbeddingConfig] 解析用户 {user_id} 配置失败: {e}")
|
||||||
|
|
||||||
# 返回默认配置
|
# 返回默认配置
|
||||||
|
print(f"[EmbeddingConfig] 用户 {user_id} 无保存配置,返回默认值")
|
||||||
return EmbeddingConfig(
|
return EmbeddingConfig(
|
||||||
provider=settings.EMBEDDING_PROVIDER,
|
provider=settings.EMBEDDING_PROVIDER,
|
||||||
model=settings.EMBEDDING_MODEL,
|
model=settings.EMBEDDING_MODEL,
|
||||||
|
|
@ -219,6 +223,8 @@ async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: Em
|
||||||
|
|
||||||
other_config[EMBEDDING_CONFIG_KEY] = embedding_data
|
other_config[EMBEDDING_CONFIG_KEY] = embedding_data
|
||||||
user_config.other_config = json.dumps(other_config)
|
user_config.other_config = json.dumps(other_config)
|
||||||
|
# 🔥 显式标记 other_config 字段已修改,确保 SQLAlchemy 检测到变化
|
||||||
|
flag_modified(user_config, "other_config")
|
||||||
else:
|
else:
|
||||||
# 创建新配置
|
# 创建新配置
|
||||||
user_config = UserConfig(
|
user_config = UserConfig(
|
||||||
|
|
@ -230,6 +236,7 @@ async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: Em
|
||||||
db.add(user_config)
|
db.add(user_config)
|
||||||
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
print(f"[EmbeddingConfig] 已保存用户 {user_id} 的嵌入配置: provider={config.provider}, model={config.model}")
|
||||||
|
|
||||||
|
|
||||||
# ============ API Endpoints ============
|
# ============ API Endpoints ============
|
||||||
|
|
@ -260,6 +267,7 @@ async def get_current_config(
|
||||||
return EmbeddingConfigResponse(
|
return EmbeddingConfigResponse(
|
||||||
provider=config.provider,
|
provider=config.provider,
|
||||||
model=config.model,
|
model=config.model,
|
||||||
|
api_key=config.api_key,
|
||||||
base_url=config.base_url,
|
base_url=config.base_url,
|
||||||
dimensions=dimensions,
|
dimensions=dimensions,
|
||||||
batch_size=config.batch_size,
|
batch_size=config.batch_size,
|
||||||
|
|
@ -280,10 +288,9 @@ async def update_config(
|
||||||
if config.provider not in provider_ids:
|
if config.provider not in provider_ids:
|
||||||
raise HTTPException(status_code=400, detail=f"不支持的提供商: {config.provider}")
|
raise HTTPException(status_code=400, detail=f"不支持的提供商: {config.provider}")
|
||||||
|
|
||||||
# 验证模型
|
# 获取提供商信息(用于检查 API Key 要求)
|
||||||
provider = next((p for p in EMBEDDING_PROVIDERS if p.id == config.provider), None)
|
provider = next((p for p in EMBEDDING_PROVIDERS if p.id == config.provider), None)
|
||||||
if provider and config.model not in provider.models:
|
# 注意:不再强制验证模型名称,允许用户输入自定义模型
|
||||||
raise HTTPException(status_code=400, detail=f"不支持的模型: {config.model}")
|
|
||||||
|
|
||||||
# 检查 API Key
|
# 检查 API Key
|
||||||
if provider and provider.requires_api_key and not config.api_key:
|
if provider and provider.requires_api_key and not config.api_key:
|
||||||
|
|
|
||||||
|
|
@ -667,7 +667,8 @@ Action Input: {{"参数": "值"}}
|
||||||
try:
|
try:
|
||||||
while not run_task.done():
|
while not run_task.done():
|
||||||
if self.is_cancelled:
|
if self.is_cancelled:
|
||||||
# 传播取消到子 Agent
|
# 🔥 传播取消到子 Agent
|
||||||
|
logger.info(f"[{self.name}] Cancelling sub-agent {agent_name} due to parent cancel")
|
||||||
if hasattr(agent, 'cancel'):
|
if hasattr(agent, 'cancel'):
|
||||||
agent.cancel()
|
agent.cancel()
|
||||||
run_task.cancel()
|
run_task.cancel()
|
||||||
|
|
@ -678,17 +679,32 @@ Action Input: {{"参数": "值"}}
|
||||||
raise asyncio.CancelledError("任务已取消")
|
raise asyncio.CancelledError("任务已取消")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 🔥 移除 asyncio.shield(),让取消信号可以直接传播
|
||||||
|
# 使用较短的超时来更频繁地检查取消状态
|
||||||
return await asyncio.wait_for(
|
return await asyncio.wait_for(
|
||||||
asyncio.shield(run_task),
|
run_task,
|
||||||
timeout=1.0 # 每秒检查一次取消状态
|
timeout=0.5 # 🔥 每0.5秒检查一次取消状态
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
continue
|
continue
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# 🔥 捕获取消异常,确保子Agent也被取消
|
||||||
|
logger.info(f"[{self.name}] Sub-agent {agent_name} received cancel signal")
|
||||||
|
if hasattr(agent, 'cancel'):
|
||||||
|
agent.cancel()
|
||||||
|
raise
|
||||||
|
|
||||||
return await run_task
|
return await run_task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
|
# 🔥 确保子任务被取消
|
||||||
if not run_task.done():
|
if not run_task.done():
|
||||||
|
if hasattr(agent, 'cancel'):
|
||||||
|
agent.cancel()
|
||||||
run_task.cancel()
|
run_task.cancel()
|
||||||
|
try:
|
||||||
|
await run_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
raise
|
raise
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -992,6 +992,8 @@ class CodeIndexer:
|
||||||
indexed_file_hashes = await self.vector_store.get_file_hashes()
|
indexed_file_hashes = await self.vector_store.get_file_hashes()
|
||||||
indexed_files = set(indexed_file_hashes.keys())
|
indexed_files = set(indexed_file_hashes.keys())
|
||||||
|
|
||||||
|
logger.debug(f"📂 已索引文件数: {len(indexed_files)}, file_hashes: {list(indexed_file_hashes.keys())[:5]}...")
|
||||||
|
|
||||||
# 收集当前文件
|
# 收集当前文件
|
||||||
current_files = self._collect_files(directory, exclude_patterns, include_patterns)
|
current_files = self._collect_files(directory, exclude_patterns, include_patterns)
|
||||||
current_file_map: Dict[str, str] = {} # relative_path -> absolute_path
|
current_file_map: Dict[str, str] = {} # relative_path -> absolute_path
|
||||||
|
|
@ -1002,11 +1004,15 @@ class CodeIndexer:
|
||||||
|
|
||||||
current_file_set = set(current_file_map.keys())
|
current_file_set = set(current_file_map.keys())
|
||||||
|
|
||||||
|
logger.debug(f"📁 当前文件数: {len(current_file_set)}, 示例: {list(current_file_set)[:5]}...")
|
||||||
|
|
||||||
# 计算差异
|
# 计算差异
|
||||||
files_to_add = current_file_set - indexed_files
|
files_to_add = current_file_set - indexed_files
|
||||||
files_to_delete = indexed_files - current_file_set
|
files_to_delete = indexed_files - current_file_set
|
||||||
files_to_check = current_file_set & indexed_files
|
files_to_check = current_file_set & indexed_files
|
||||||
|
|
||||||
|
logger.debug(f"📊 差异分析: 交集={len(files_to_check)}, 新增候选={len(files_to_add)}, 删除候选={len(files_to_delete)}")
|
||||||
|
|
||||||
# 检查需要更新的文件(hash 变化)
|
# 检查需要更新的文件(hash 变化)
|
||||||
files_to_update: Set[str] = set()
|
files_to_update: Set[str] = set()
|
||||||
for relative_path in files_to_check:
|
for relative_path in files_to_check:
|
||||||
|
|
|
||||||
|
|
@ -92,7 +92,7 @@ class CodeChunk:
|
||||||
return len(self.content) // 4
|
return len(self.content) // 4
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
return {
|
result = {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
"content": self.content,
|
"content": self.content,
|
||||||
"file_path": self.file_path,
|
"file_path": self.file_path,
|
||||||
|
|
@ -110,8 +110,13 @@ class CodeChunk:
|
||||||
"definitions": self.definitions,
|
"definitions": self.definitions,
|
||||||
"security_indicators": self.security_indicators,
|
"security_indicators": self.security_indicators,
|
||||||
"estimated_tokens": self.estimated_tokens,
|
"estimated_tokens": self.estimated_tokens,
|
||||||
"metadata": self.metadata,
|
|
||||||
}
|
}
|
||||||
|
# 将 metadata 中的字段提升到顶级,确保 file_hash 等字段可以被正确检索
|
||||||
|
if self.metadata:
|
||||||
|
for key, value in self.metadata.items():
|
||||||
|
if key not in result:
|
||||||
|
result[key] = value
|
||||||
|
return result
|
||||||
|
|
||||||
def to_embedding_text(self) -> str:
|
def to_embedding_text(self) -> str:
|
||||||
"""生成用于嵌入的文本"""
|
"""生成用于嵌入的文本"""
|
||||||
|
|
@ -252,12 +257,21 @@ class TreeSitterParser:
|
||||||
node_type = node.type
|
node_type = node.type
|
||||||
|
|
||||||
# 检查是否是定义节点
|
# 检查是否是定义节点
|
||||||
|
matched = False
|
||||||
for def_category, types in definition_types.items():
|
for def_category, types in definition_types.items():
|
||||||
if node_type in types:
|
if node_type in types:
|
||||||
name = self._extract_name(node, language)
|
name = self._extract_name(node, language)
|
||||||
|
|
||||||
|
# 根据是否有 parent_name 来区分 function 和 method
|
||||||
|
actual_category = def_category
|
||||||
|
if def_category == "function" and parent_name:
|
||||||
|
actual_category = "method"
|
||||||
|
elif def_category == "method" and not parent_name:
|
||||||
|
# 跳过没有 parent 的 method 定义(由 function 类别处理)
|
||||||
|
continue
|
||||||
|
|
||||||
definitions.append({
|
definitions.append({
|
||||||
"type": def_category,
|
"type": actual_category,
|
||||||
"name": name,
|
"name": name,
|
||||||
"parent_name": parent_name,
|
"parent_name": parent_name,
|
||||||
"start_point": node.start_point,
|
"start_point": node.start_point,
|
||||||
|
|
@ -267,13 +281,19 @@ class TreeSitterParser:
|
||||||
"node_type": node_type,
|
"node_type": node_type,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
matched = True
|
||||||
|
|
||||||
# 对于类,继续遍历子节点找方法
|
# 对于类,继续遍历子节点找方法
|
||||||
if def_category == "class":
|
if def_category == "class":
|
||||||
for child in node.children:
|
for child in node.children:
|
||||||
traverse(child, name)
|
traverse(child, name)
|
||||||
return
|
return
|
||||||
|
|
||||||
# 继续遍历子节点
|
# 匹配到一个类别后就不再匹配其他类别
|
||||||
|
break
|
||||||
|
|
||||||
|
# 如果没有匹配到定义,继续遍历子节点
|
||||||
|
if not matched:
|
||||||
for child in node.children:
|
for child in node.children:
|
||||||
traverse(child, parent_name)
|
traverse(child, parent_name)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -80,6 +80,9 @@ services:
|
||||||
- all_proxy=
|
- all_proxy=
|
||||||
- ALL_PROXY=
|
- ALL_PROXY=
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
volumes:
|
||||||
|
- ./frontend/dist:/usr/share/nginx/html:ro # 挂载构建产物,本地 pnpm build 后自动生效
|
||||||
|
- ./frontend/nginx.conf:/etc/nginx/conf.d/default.conf:ro # 挂载 nginx 配置
|
||||||
ports:
|
ports:
|
||||||
- "3000:80" # Nginx 监听 80 端口
|
- "3000:80" # Nginx 监听 80 端口
|
||||||
depends_on:
|
depends_on:
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@ interface EmbeddingProvider {
|
||||||
interface EmbeddingConfig {
|
interface EmbeddingConfig {
|
||||||
provider: string;
|
provider: string;
|
||||||
model: string;
|
model: string;
|
||||||
|
api_key: string | null;
|
||||||
base_url: string | null;
|
base_url: string | null;
|
||||||
dimensions: number;
|
dimensions: number;
|
||||||
batch_size: number;
|
batch_size: number;
|
||||||
|
|
@ -79,15 +80,15 @@ export default function EmbeddingConfigPanel() {
|
||||||
loadData();
|
loadData();
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
// 当 provider 改变时更新模型
|
// 用户手动切换 provider 时更新为默认模型
|
||||||
useEffect(() => {
|
const handleProviderChange = (newProvider: string) => {
|
||||||
if (selectedProvider) {
|
setSelectedProvider(newProvider);
|
||||||
const provider = providers.find((p) => p.id === selectedProvider);
|
// 切换 provider 时重置为该 provider 的默认模型
|
||||||
|
const provider = providers.find((p) => p.id === newProvider);
|
||||||
if (provider) {
|
if (provider) {
|
||||||
setSelectedModel(provider.default_model);
|
setSelectedModel(provider.default_model);
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}, [selectedProvider, providers]);
|
|
||||||
|
|
||||||
const loadData = async () => {
|
const loadData = async () => {
|
||||||
try {
|
try {
|
||||||
|
|
@ -104,6 +105,7 @@ export default function EmbeddingConfigPanel() {
|
||||||
if (configRes.data) {
|
if (configRes.data) {
|
||||||
setSelectedProvider(configRes.data.provider);
|
setSelectedProvider(configRes.data.provider);
|
||||||
setSelectedModel(configRes.data.model);
|
setSelectedModel(configRes.data.model);
|
||||||
|
setApiKey(configRes.data.api_key || "");
|
||||||
setBaseUrl(configRes.data.base_url || "");
|
setBaseUrl(configRes.data.base_url || "");
|
||||||
setBatchSize(configRes.data.batch_size);
|
setBatchSize(configRes.data.batch_size);
|
||||||
}
|
}
|
||||||
|
|
@ -230,7 +232,7 @@ export default function EmbeddingConfigPanel() {
|
||||||
{/* 提供商选择 */}
|
{/* 提供商选择 */}
|
||||||
<div className="space-y-2">
|
<div className="space-y-2">
|
||||||
<Label className="text-xs font-bold text-gray-500 uppercase">嵌入模型提供商</Label>
|
<Label className="text-xs font-bold text-gray-500 uppercase">嵌入模型提供商</Label>
|
||||||
<Select value={selectedProvider} onValueChange={setSelectedProvider}>
|
<Select value={selectedProvider} onValueChange={handleProviderChange}>
|
||||||
<SelectTrigger className="h-12 cyber-input">
|
<SelectTrigger className="h-12 cyber-input">
|
||||||
<SelectValue placeholder="选择提供商" />
|
<SelectValue placeholder="选择提供商" />
|
||||||
</SelectTrigger>
|
</SelectTrigger>
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue