feat: Introduce structured agent collaboration with `TaskHandoff` and `analysis_v2` agent, updating core agent logic, tools, and audit UI.
This commit is contained in:
parent
8938a8a3c9
commit
70776ee5fd
|
|
@ -28,6 +28,7 @@ from app.models.agent_task import (
|
|||
)
|
||||
from app.models.project import Project
|
||||
from app.models.user import User
|
||||
from app.models.user_config import UserConfig
|
||||
from app.services.agent import AgentRunner, EventManager, run_agent_task
|
||||
from app.services.agent.streaming import StreamHandler, StreamEvent, StreamEventType
|
||||
|
||||
|
|
@ -199,7 +200,7 @@ class TaskSummaryResponse(BaseModel):
|
|||
|
||||
# ============ 后台任务执行 ============
|
||||
|
||||
async def _execute_agent_task(task_id: str, project_root: str):
|
||||
async def _execute_agent_task(task_id: str):
|
||||
"""在后台执行 Agent 任务"""
|
||||
async with async_session_factory() as db:
|
||||
try:
|
||||
|
|
@ -209,14 +210,57 @@ async def _execute_agent_task(task_id: str, project_root: str):
|
|||
logger.error(f"Task {task_id} not found")
|
||||
return
|
||||
|
||||
# 获取项目
|
||||
project = task.project
|
||||
if not project:
|
||||
logger.error(f"Project not found for task {task_id}")
|
||||
return
|
||||
|
||||
# 🔥 获取项目根目录(解压 ZIP 或克隆仓库)
|
||||
project_root = await _get_project_root(project, task_id)
|
||||
|
||||
# 🔥 获取用户配置(从系统配置页面)
|
||||
# 优先级:1. 数据库用户配置 > 2. 环境变量配置
|
||||
user_config = None
|
||||
if task.created_by:
|
||||
from app.api.v1.endpoints.config import (
|
||||
decrypt_config,
|
||||
SENSITIVE_LLM_FIELDS, SENSITIVE_OTHER_FIELDS
|
||||
)
|
||||
import json
|
||||
|
||||
result = await db.execute(
|
||||
select(UserConfig).where(UserConfig.user_id == task.created_by)
|
||||
)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config and config.llm_config:
|
||||
# 🔥 有数据库配置:使用数据库配置(优先)
|
||||
user_llm_config = json.loads(config.llm_config) if config.llm_config else {}
|
||||
user_other_config = json.loads(config.other_config) if config.other_config else {}
|
||||
|
||||
# 解密敏感字段
|
||||
user_llm_config = decrypt_config(user_llm_config, SENSITIVE_LLM_FIELDS)
|
||||
user_other_config = decrypt_config(user_other_config, SENSITIVE_OTHER_FIELDS)
|
||||
|
||||
user_config = {
|
||||
"llmConfig": user_llm_config, # 直接使用数据库配置,不合并默认值
|
||||
"otherConfig": user_other_config,
|
||||
}
|
||||
logger.info(f"✅ Using database user config for task {task_id}, LLM provider: {user_llm_config.get('llmProvider', 'N/A')}")
|
||||
else:
|
||||
# 🔥 无数据库配置:传递 None,让 LLMService 使用环境变量
|
||||
user_config = None
|
||||
logger.info(f"⚠️ No database config found for user {task.created_by}, will use environment variables for task {task_id}")
|
||||
|
||||
# 更新状态为运行中
|
||||
task.status = AgentTaskStatus.RUNNING
|
||||
task.started_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
logger.info(f"Task {task_id} started")
|
||||
|
||||
# 创建 Runner
|
||||
runner = AgentRunner(db, task, project_root)
|
||||
# 创建 Runner(传入用户配置)
|
||||
runner = AgentRunner(db, task, project_root, user_config=user_config)
|
||||
_running_tasks[task_id] = runner
|
||||
|
||||
# 执行
|
||||
|
|
@ -296,11 +340,8 @@ async def create_agent_task(
|
|||
await db.commit()
|
||||
await db.refresh(task)
|
||||
|
||||
# 确定项目根目录
|
||||
project_root = _get_project_root(project, task.id)
|
||||
|
||||
# 在后台启动任务
|
||||
background_tasks.add_task(_execute_agent_task, task.id, project_root)
|
||||
# 在后台启动任务(项目根目录在任务内部获取)
|
||||
background_tasks.add_task(_execute_agent_task, task.id)
|
||||
|
||||
logger.info(f"Created agent task {task.id} for project {project.name}")
|
||||
|
||||
|
|
@ -897,24 +938,73 @@ async def update_finding_status(
|
|||
|
||||
# ============ Helper Functions ============
|
||||
|
||||
def _get_project_root(project: Project, task_id: str) -> str:
|
||||
async def _get_project_root(project: Project, task_id: str) -> str:
|
||||
"""
|
||||
获取项目根目录
|
||||
|
||||
TODO: 实际实现中需要:
|
||||
- 对于 ZIP 项目:解压到临时目录
|
||||
- 对于 Git 仓库:克隆到临时目录
|
||||
支持两种项目类型:
|
||||
- ZIP 项目:解压 ZIP 文件到临时目录
|
||||
- 仓库项目:克隆仓库到临时目录
|
||||
"""
|
||||
import zipfile
|
||||
import subprocess
|
||||
|
||||
base_path = f"/tmp/deepaudit/{task_id}"
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
|
||||
# 如果项目有存储路径,复制过来
|
||||
if hasattr(project, 'storage_path') and project.storage_path:
|
||||
if os.path.exists(project.storage_path):
|
||||
# 复制项目文件
|
||||
shutil.copytree(project.storage_path, base_path, dirs_exist_ok=True)
|
||||
# 根据项目类型处理
|
||||
if project.source_type == "zip":
|
||||
# 🔥 ZIP 项目:解压 ZIP 文件
|
||||
from app.services.zip_storage import load_project_zip
|
||||
|
||||
zip_path = await load_project_zip(project.id)
|
||||
|
||||
if zip_path and os.path.exists(zip_path):
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(base_path)
|
||||
logger.info(f"✅ Extracted ZIP project {project.id} to {base_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to extract ZIP {zip_path}: {e}")
|
||||
else:
|
||||
logger.warning(f"⚠️ ZIP file not found for project {project.id}")
|
||||
|
||||
elif project.source_type == "repository" and project.repository_url:
|
||||
# 🔥 仓库项目:克隆仓库
|
||||
try:
|
||||
branch = project.default_branch or "main"
|
||||
repo_url = project.repository_url
|
||||
|
||||
# 克隆仓库
|
||||
result = subprocess.run(
|
||||
["git", "clone", "--depth", "1", "--branch", branch, repo_url, base_path],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
logger.info(f"✅ Cloned repository {repo_url} (branch: {branch}) to {base_path}")
|
||||
else:
|
||||
logger.warning(f"Failed to clone branch {branch}, trying default branch: {result.stderr}")
|
||||
# 如果克隆失败,尝试使用默认分支
|
||||
if branch != "main":
|
||||
result = subprocess.run(
|
||||
["git", "clone", "--depth", "1", repo_url, base_path],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
logger.info(f"✅ Cloned repository {repo_url} (default branch) to {base_path}")
|
||||
else:
|
||||
logger.error(f"Failed to clone repository: {result.stderr}")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error(f"Git clone timeout for {project.repository_url}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clone repository {project.repository_url}: {e}")
|
||||
|
||||
return base_path
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,13 @@
|
|||
"""
|
||||
混合 Agent 架构
|
||||
包含 Orchestrator、Recon、Analysis 和 Verification Agent
|
||||
|
||||
协作机制:
|
||||
- Agent 之间通过 TaskHandoff 传递结构化上下文
|
||||
- 每个 Agent 完成后生成 handoff 给下一个 Agent
|
||||
"""
|
||||
|
||||
from .base import BaseAgent, AgentConfig, AgentResult
|
||||
from .base import BaseAgent, AgentConfig, AgentResult, TaskHandoff
|
||||
from .orchestrator import OrchestratorAgent
|
||||
from .recon import ReconAgent
|
||||
from .analysis import AnalysisAgent
|
||||
|
|
@ -13,6 +17,7 @@ __all__ = [
|
|||
"BaseAgent",
|
||||
"AgentConfig",
|
||||
"AgentResult",
|
||||
"TaskHandoff",
|
||||
"OrchestratorAgent",
|
||||
"ReconAgent",
|
||||
"AnalysisAgent",
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ LLM 是真正的安全分析大脑!
|
|||
类型: ReAct (真正的!)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
|
@ -17,6 +18,7 @@ from typing import List, Dict, Any, Optional
|
|||
from dataclasses import dataclass
|
||||
|
||||
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
|
||||
from ..json_parser import AgentJsonParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -33,18 +35,13 @@ ANALYSIS_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞分析 Agent,一个**自
|
|||
|
||||
## 你可以使用的工具
|
||||
|
||||
### 外部扫描工具
|
||||
- **semgrep_scan**: Semgrep 静态分析(推荐首先使用)
|
||||
参数: rules (str), max_results (int)
|
||||
- **bandit_scan**: Python 安全扫描
|
||||
|
||||
### RAG 语义搜索
|
||||
- **rag_query**: 语义代码搜索
|
||||
参数: query (str), top_k (int)
|
||||
- **security_search**: 安全相关代码搜索
|
||||
参数: vulnerability_type (str), top_k (int)
|
||||
- **function_context**: 函数上下文分析
|
||||
参数: function_name (str)
|
||||
### 文件操作
|
||||
- **read_file**: 读取文件内容
|
||||
参数: file_path (str), start_line (int), end_line (int)
|
||||
- **list_files**: 列出目录文件
|
||||
参数: directory (str), pattern (str)
|
||||
- **search_code**: 代码关键字搜索
|
||||
参数: keyword (str), max_results (int)
|
||||
|
||||
### 深度分析
|
||||
- **pattern_match**: 危险模式匹配
|
||||
|
|
@ -53,16 +50,28 @@ ANALYSIS_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞分析 Agent,一个**自
|
|||
参数: code (str), file_path (str), focus (str)
|
||||
- **dataflow_analysis**: 数据流追踪
|
||||
参数: source (str), sink (str)
|
||||
- **vulnerability_validation**: 漏洞验证
|
||||
参数: code (str), vulnerability_type (str)
|
||||
|
||||
### 文件操作
|
||||
- **read_file**: 读取文件内容
|
||||
参数: file_path (str), start_line (int), end_line (int)
|
||||
- **search_code**: 代码关键字搜索
|
||||
参数: keyword (str), max_results (int)
|
||||
- **list_files**: 列出目录文件
|
||||
参数: directory (str), pattern (str)
|
||||
### 外部静态分析工具
|
||||
- **semgrep_scan**: Semgrep 静态分析(推荐首先使用)
|
||||
参数: rules (str), max_results (int)
|
||||
- **bandit_scan**: Python 安全扫描
|
||||
参数: target (str)
|
||||
- **gitleaks_scan**: Git 密钥泄露扫描
|
||||
参数: target (str)
|
||||
- **trufflehog_scan**: 敏感信息扫描
|
||||
参数: target (str)
|
||||
- **npm_audit**: NPM 依赖漏洞扫描
|
||||
参数: target (str)
|
||||
- **safety_scan**: Python 依赖安全扫描
|
||||
参数: target (str)
|
||||
- **osv_scan**: OSV 漏洞数据库扫描
|
||||
参数: target (str)
|
||||
|
||||
### RAG 语义搜索
|
||||
- **security_search**: 安全相关代码搜索
|
||||
参数: vulnerability_type (str), top_k (int)
|
||||
- **function_context**: 函数上下文分析
|
||||
参数: function_name (str)
|
||||
|
||||
## 工作方式
|
||||
每一步,你需要输出:
|
||||
|
|
@ -168,15 +177,7 @@ class AnalysisAgent(BaseAgent):
|
|||
self._conversation_history: List[Dict[str, str]] = []
|
||||
self._steps: List[AnalysisStep] = []
|
||||
|
||||
def _get_tools_description(self) -> str:
|
||||
"""生成工具描述"""
|
||||
tools_info = []
|
||||
for name, tool in self.tools.items():
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
desc = f"- {name}: {getattr(tool, 'description', 'No description')}"
|
||||
tools_info.append(desc)
|
||||
return "\n".join(tools_info)
|
||||
|
||||
|
||||
def _parse_llm_response(self, response: str) -> AnalysisStep:
|
||||
"""解析 LLM 响应"""
|
||||
|
|
@ -191,13 +192,20 @@ class AnalysisAgent(BaseAgent):
|
|||
final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
|
||||
if final_match:
|
||||
step.is_final = True
|
||||
try:
|
||||
answer_text = final_match.group(1).strip()
|
||||
answer_text = re.sub(r'```json\s*', '', answer_text)
|
||||
answer_text = re.sub(r'```\s*', '', answer_text)
|
||||
step.final_answer = json.loads(answer_text)
|
||||
except json.JSONDecodeError:
|
||||
step.final_answer = {"findings": [], "raw_answer": final_match.group(1).strip()}
|
||||
# 使用增强的 JSON 解析器
|
||||
step.final_answer = AgentJsonParser.parse(
|
||||
answer_text,
|
||||
default={"findings": [], "raw_answer": answer_text}
|
||||
)
|
||||
# 确保 findings 格式正确
|
||||
if "findings" in step.final_answer:
|
||||
step.final_answer["findings"] = [
|
||||
f for f in step.final_answer["findings"]
|
||||
if isinstance(f, dict)
|
||||
]
|
||||
return step
|
||||
|
||||
# 提取 Action
|
||||
|
|
@ -211,51 +219,15 @@ class AnalysisAgent(BaseAgent):
|
|||
input_text = input_match.group(1).strip()
|
||||
input_text = re.sub(r'```json\s*', '', input_text)
|
||||
input_text = re.sub(r'```\s*', '', input_text)
|
||||
try:
|
||||
step.action_input = json.loads(input_text)
|
||||
except json.JSONDecodeError:
|
||||
step.action_input = {"raw_input": input_text}
|
||||
# 使用增强的 JSON 解析器
|
||||
step.action_input = AgentJsonParser.parse(
|
||||
input_text,
|
||||
default={"raw_input": input_text}
|
||||
)
|
||||
|
||||
return step
|
||||
|
||||
async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str:
|
||||
"""执行工具"""
|
||||
tool = self.tools.get(tool_name)
|
||||
|
||||
if not tool:
|
||||
return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}"
|
||||
|
||||
try:
|
||||
self._tool_calls += 1
|
||||
await self.emit_tool_call(tool_name, tool_input)
|
||||
|
||||
import time
|
||||
start = time.time()
|
||||
|
||||
result = await tool.execute(**tool_input)
|
||||
|
||||
duration_ms = int((time.time() - start) * 1000)
|
||||
await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)
|
||||
|
||||
if result.success:
|
||||
output = str(result.data)
|
||||
|
||||
# 如果是代码分析工具,也包含 metadata
|
||||
if result.metadata:
|
||||
if "issues" in result.metadata:
|
||||
output += f"\n\n发现的问题:\n{json.dumps(result.metadata['issues'], ensure_ascii=False, indent=2)}"
|
||||
if "findings" in result.metadata:
|
||||
output += f"\n\n发现:\n{json.dumps(result.metadata['findings'][:10], ensure_ascii=False, indent=2)}"
|
||||
|
||||
if len(output) > 6000:
|
||||
output = output[:6000] + f"\n\n... [输出已截断,共 {len(str(result.data))} 字符]"
|
||||
return output
|
||||
else:
|
||||
return f"工具执行失败: {result.error}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution error: {e}")
|
||||
return f"工具执行错误: {str(e)}"
|
||||
|
||||
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
|
||||
"""
|
||||
|
|
@ -271,6 +243,14 @@ class AnalysisAgent(BaseAgent):
|
|||
task = input_data.get("task", "")
|
||||
task_context = input_data.get("task_context", "")
|
||||
|
||||
# 🔥 处理交接信息
|
||||
handoff = input_data.get("handoff")
|
||||
if handoff:
|
||||
from .base import TaskHandoff
|
||||
if isinstance(handoff, dict):
|
||||
handoff = TaskHandoff.from_dict(handoff)
|
||||
self.receive_handoff(handoff)
|
||||
|
||||
# 从 Recon 结果获取上下文
|
||||
recon_data = previous_results.get("recon", {})
|
||||
if isinstance(recon_data, dict) and "data" in recon_data:
|
||||
|
|
@ -281,7 +261,9 @@ class AnalysisAgent(BaseAgent):
|
|||
high_risk_areas = recon_data.get("high_risk_areas", plan.get("high_risk_areas", []))
|
||||
initial_findings = recon_data.get("initial_findings", [])
|
||||
|
||||
# 构建初始消息
|
||||
# 🔥 构建包含交接上下文的初始消息
|
||||
handoff_context = self.get_handoff_context()
|
||||
|
||||
initial_message = f"""请开始对项目进行安全漏洞分析。
|
||||
|
||||
## 项目信息
|
||||
|
|
@ -289,7 +271,7 @@ class AnalysisAgent(BaseAgent):
|
|||
- 语言: {tech_stack.get('languages', [])}
|
||||
- 框架: {tech_stack.get('frameworks', [])}
|
||||
|
||||
## 上下文信息
|
||||
{handoff_context if handoff_context else f'''## 上下文信息
|
||||
### 高风险区域
|
||||
{json.dumps(high_risk_areas[:20], ensure_ascii=False)}
|
||||
|
||||
|
|
@ -297,7 +279,7 @@ class AnalysisAgent(BaseAgent):
|
|||
{json.dumps(entry_points[:10], ensure_ascii=False, indent=2)}
|
||||
|
||||
### 初步发现 (如果有)
|
||||
{json.dumps(initial_findings[:5], ensure_ascii=False, indent=2) if initial_findings else '无'}
|
||||
{json.dumps(initial_findings[:5], ensure_ascii=False, indent=2) if initial_findings else "无"}'''}
|
||||
|
||||
## 任务
|
||||
{task_context or task or '进行全面的安全漏洞分析,发现代码中的安全问题。'}
|
||||
|
|
@ -306,10 +288,13 @@ class AnalysisAgent(BaseAgent):
|
|||
{config.get('target_vulnerabilities', ['all'])}
|
||||
|
||||
## 可用工具
|
||||
{self._get_tools_description()}
|
||||
{self.get_tools_description()}
|
||||
|
||||
请开始你的安全分析。首先思考分析策略,然后选择合适的工具开始分析。"""
|
||||
|
||||
# 🔥 记录工作开始
|
||||
self.record_work("开始安全漏洞分析")
|
||||
|
||||
# 初始化对话历史
|
||||
self._conversation_history = [
|
||||
{"role": "system", "content": self.config.system_prompt},
|
||||
|
|
@ -328,18 +313,22 @@ class AnalysisAgent(BaseAgent):
|
|||
|
||||
self._iteration = iteration + 1
|
||||
|
||||
# 🔥 发射 LLM 开始思考事件
|
||||
await self.emit_llm_start(iteration + 1)
|
||||
# 🔥 再次检查取消标志(在LLM调用之前)
|
||||
if self.is_cancelled:
|
||||
await self.emit_thinking("🛑 任务已取消,停止执行")
|
||||
break
|
||||
|
||||
# 🔥 调用 LLM 进行思考和决策
|
||||
response = await self.llm_service.chat_completion_raw(
|
||||
messages=self._conversation_history,
|
||||
# 调用 LLM 进行思考和决策(流式输出)
|
||||
try:
|
||||
llm_output, tokens_this_round = await self.stream_llm_call(
|
||||
self._conversation_history,
|
||||
temperature=0.1,
|
||||
max_tokens=2048,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.name}] LLM call cancelled")
|
||||
break
|
||||
|
||||
llm_output = response.get("content", "")
|
||||
tokens_this_round = response.get("usage", {}).get("total_tokens", 0)
|
||||
self._total_tokens += tokens_this_round
|
||||
|
||||
# 解析 LLM 响应
|
||||
|
|
@ -369,6 +358,14 @@ class AnalysisAgent(BaseAgent):
|
|||
finding.get("vulnerability_type", "other"),
|
||||
finding.get("file_path", "")
|
||||
)
|
||||
# 🔥 记录洞察
|
||||
self.add_insight(
|
||||
f"发现 {finding.get('severity', 'medium')} 级别漏洞: {finding.get('title', 'Unknown')}"
|
||||
)
|
||||
|
||||
# 🔥 记录工作完成
|
||||
self.record_work(f"完成安全分析,发现 {len(all_findings)} 个潜在漏洞")
|
||||
|
||||
await self.emit_llm_complete(
|
||||
f"分析完成,发现 {len(all_findings)} 个潜在漏洞",
|
||||
self._total_tokens
|
||||
|
|
@ -380,7 +377,7 @@ class AnalysisAgent(BaseAgent):
|
|||
# 🔥 发射 LLM 动作决策事件
|
||||
await self.emit_llm_action(step.action, step.action_input or {})
|
||||
|
||||
observation = await self._execute_tool(
|
||||
observation = await self.execute_tool(
|
||||
step.action,
|
||||
step.action_input or {}
|
||||
)
|
||||
|
|
@ -427,7 +424,7 @@ class AnalysisAgent(BaseAgent):
|
|||
|
||||
await self.emit_event(
|
||||
"info",
|
||||
f"🎯 Analysis Agent 完成: {len(standardized_findings)} 个发现, {self._iteration} 轮迭代, {self._tool_calls} 次工具调用"
|
||||
f"Analysis Agent 完成: {len(standardized_findings)} 个发现, {self._iteration} 轮迭代, {self._tool_calls} 次工具调用"
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
|
|
|
|||
|
|
@ -2,14 +2,20 @@
|
|||
Agent 基类
|
||||
定义 Agent 的基本接口和通用功能
|
||||
|
||||
核心原则:LLM 是 Agent 的大脑,所有日志应该反映 LLM 的参与!
|
||||
核心原则:
|
||||
1. LLM 是 Agent 的大脑,全程参与决策
|
||||
2. Agent 之间通过 TaskHandoff 传递结构化上下文
|
||||
3. 事件分为流式事件(前端展示)和持久化事件(数据库记录)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from datetime import datetime, timezone
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -73,6 +79,9 @@ class AgentResult:
|
|||
# 元数据
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# 🔥 协作信息 - Agent 传递给下一个 Agent 的结构化信息
|
||||
handoff: Optional["TaskHandoff"] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"success": self.success,
|
||||
|
|
@ -83,9 +92,139 @@ class AgentResult:
|
|||
"tokens_used": self.tokens_used,
|
||||
"duration_ms": self.duration_ms,
|
||||
"metadata": self.metadata,
|
||||
"handoff": self.handoff.to_dict() if self.handoff else None,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskHandoff:
|
||||
"""
|
||||
任务交接协议 - Agent 之间传递的结构化信息
|
||||
|
||||
设计原则:
|
||||
1. 包含足够的上下文让下一个 Agent 理解前序工作
|
||||
2. 提供明确的建议和关注点
|
||||
3. 可直接转换为 LLM 可理解的 prompt
|
||||
"""
|
||||
# 基本信息
|
||||
from_agent: str
|
||||
to_agent: str
|
||||
|
||||
# 工作摘要
|
||||
summary: str
|
||||
work_completed: List[str] = field(default_factory=list)
|
||||
|
||||
# 关键发现和洞察
|
||||
key_findings: List[Dict[str, Any]] = field(default_factory=list)
|
||||
insights: List[str] = field(default_factory=list)
|
||||
|
||||
# 建议和关注点
|
||||
suggested_actions: List[Dict[str, Any]] = field(default_factory=list)
|
||||
attention_points: List[str] = field(default_factory=list)
|
||||
priority_areas: List[str] = field(default_factory=list)
|
||||
|
||||
# 上下文数据
|
||||
context_data: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# 置信度
|
||||
confidence: float = 0.8
|
||||
|
||||
# 时间戳
|
||||
timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"from_agent": self.from_agent,
|
||||
"to_agent": self.to_agent,
|
||||
"summary": self.summary,
|
||||
"work_completed": self.work_completed,
|
||||
"key_findings": self.key_findings,
|
||||
"insights": self.insights,
|
||||
"suggested_actions": self.suggested_actions,
|
||||
"attention_points": self.attention_points,
|
||||
"priority_areas": self.priority_areas,
|
||||
"context_data": self.context_data,
|
||||
"confidence": self.confidence,
|
||||
"timestamp": self.timestamp.isoformat(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "TaskHandoff":
|
||||
return cls(
|
||||
from_agent=data.get("from_agent", ""),
|
||||
to_agent=data.get("to_agent", ""),
|
||||
summary=data.get("summary", ""),
|
||||
work_completed=data.get("work_completed", []),
|
||||
key_findings=data.get("key_findings", []),
|
||||
insights=data.get("insights", []),
|
||||
suggested_actions=data.get("suggested_actions", []),
|
||||
attention_points=data.get("attention_points", []),
|
||||
priority_areas=data.get("priority_areas", []),
|
||||
context_data=data.get("context_data", {}),
|
||||
confidence=data.get("confidence", 0.8),
|
||||
)
|
||||
|
||||
def to_prompt_context(self) -> str:
|
||||
"""
|
||||
转换为 LLM 可理解的上下文格式
|
||||
这是关键!让 LLM 能够理解前序 Agent 的工作
|
||||
"""
|
||||
lines = [
|
||||
f"## 来自 {self.from_agent} Agent 的任务交接",
|
||||
"",
|
||||
f"### 工作摘要",
|
||||
self.summary,
|
||||
"",
|
||||
]
|
||||
|
||||
if self.work_completed:
|
||||
lines.append("### 已完成的工作")
|
||||
for work in self.work_completed:
|
||||
lines.append(f"- {work}")
|
||||
lines.append("")
|
||||
|
||||
if self.key_findings:
|
||||
lines.append("### 关键发现")
|
||||
for i, finding in enumerate(self.key_findings[:15], 1):
|
||||
severity = finding.get("severity", "medium")
|
||||
title = finding.get("title", "Unknown")
|
||||
file_path = finding.get("file_path", "")
|
||||
lines.append(f"{i}. [{severity.upper()}] {title}")
|
||||
if file_path:
|
||||
lines.append(f" 位置: {file_path}:{finding.get('line_start', '')}")
|
||||
if finding.get("description"):
|
||||
lines.append(f" 描述: {finding['description'][:100]}")
|
||||
lines.append("")
|
||||
|
||||
if self.insights:
|
||||
lines.append("### 洞察和分析")
|
||||
for insight in self.insights:
|
||||
lines.append(f"- {insight}")
|
||||
lines.append("")
|
||||
|
||||
if self.suggested_actions:
|
||||
lines.append("### 建议的下一步行动")
|
||||
for action in self.suggested_actions:
|
||||
action_type = action.get("type", "general")
|
||||
description = action.get("description", "")
|
||||
priority = action.get("priority", "medium")
|
||||
lines.append(f"- [{priority.upper()}] {action_type}: {description}")
|
||||
lines.append("")
|
||||
|
||||
if self.attention_points:
|
||||
lines.append("### ⚠️ 需要特别关注")
|
||||
for point in self.attention_points:
|
||||
lines.append(f"- {point}")
|
||||
lines.append("")
|
||||
|
||||
if self.priority_areas:
|
||||
lines.append("### 优先分析区域")
|
||||
for area in self.priority_areas:
|
||||
lines.append(f"- {area}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
"""
|
||||
Agent 基类
|
||||
|
|
@ -94,6 +233,11 @@ class BaseAgent(ABC):
|
|||
1. LLM 是 Agent 的大脑,全程参与决策
|
||||
2. 所有日志应该反映 LLM 的思考过程
|
||||
3. 工具调用是 LLM 的决策结果
|
||||
|
||||
协作原则:
|
||||
1. 通过 TaskHandoff 接收前序 Agent 的上下文
|
||||
2. 执行完成后生成 TaskHandoff 传递给下一个 Agent
|
||||
3. 洞察和发现应该结构化记录
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -123,6 +267,11 @@ class BaseAgent(ABC):
|
|||
self._tool_calls = 0
|
||||
self._cancelled = False
|
||||
|
||||
# 🔥 协作状态
|
||||
self._incoming_handoff: Optional[TaskHandoff] = None
|
||||
self._insights: List[str] = [] # 收集的洞察
|
||||
self._work_completed: List[str] = [] # 完成的工作记录
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.config.name
|
||||
|
|
@ -152,6 +301,103 @@ class BaseAgent(ABC):
|
|||
def is_cancelled(self) -> bool:
|
||||
return self._cancelled
|
||||
|
||||
# ============ 协作方法 ============
|
||||
|
||||
def receive_handoff(self, handoff: TaskHandoff):
|
||||
"""
|
||||
接收来自前序 Agent 的任务交接
|
||||
|
||||
Args:
|
||||
handoff: 任务交接对象
|
||||
"""
|
||||
self._incoming_handoff = handoff
|
||||
logger.info(
|
||||
f"[{self.name}] Received handoff from {handoff.from_agent}: "
|
||||
f"{handoff.summary[:50]}..."
|
||||
)
|
||||
|
||||
def get_handoff_context(self) -> str:
|
||||
"""
|
||||
获取交接上下文(用于构建 LLM prompt)
|
||||
|
||||
Returns:
|
||||
格式化的上下文字符串
|
||||
"""
|
||||
if not self._incoming_handoff:
|
||||
return ""
|
||||
return self._incoming_handoff.to_prompt_context()
|
||||
|
||||
def add_insight(self, insight: str):
|
||||
"""记录洞察"""
|
||||
self._insights.append(insight)
|
||||
|
||||
def record_work(self, work: str):
|
||||
"""记录完成的工作"""
|
||||
self._work_completed.append(work)
|
||||
|
||||
def create_handoff(
|
||||
self,
|
||||
to_agent: str,
|
||||
summary: str,
|
||||
key_findings: List[Dict[str, Any]] = None,
|
||||
suggested_actions: List[Dict[str, Any]] = None,
|
||||
attention_points: List[str] = None,
|
||||
priority_areas: List[str] = None,
|
||||
context_data: Dict[str, Any] = None,
|
||||
) -> TaskHandoff:
|
||||
"""
|
||||
创建任务交接
|
||||
|
||||
Args:
|
||||
to_agent: 目标 Agent
|
||||
summary: 工作摘要
|
||||
key_findings: 关键发现
|
||||
suggested_actions: 建议的行动
|
||||
attention_points: 需要关注的点
|
||||
priority_areas: 优先分析区域
|
||||
context_data: 上下文数据
|
||||
|
||||
Returns:
|
||||
TaskHandoff 对象
|
||||
"""
|
||||
return TaskHandoff(
|
||||
from_agent=self.name,
|
||||
to_agent=to_agent,
|
||||
summary=summary,
|
||||
work_completed=self._work_completed.copy(),
|
||||
key_findings=key_findings or [],
|
||||
insights=self._insights.copy(),
|
||||
suggested_actions=suggested_actions or [],
|
||||
attention_points=attention_points or [],
|
||||
priority_areas=priority_areas or [],
|
||||
context_data=context_data or {},
|
||||
)
|
||||
|
||||
def build_prompt_with_handoff(self, base_prompt: str) -> str:
|
||||
"""
|
||||
构建包含交接上下文的 prompt
|
||||
|
||||
Args:
|
||||
base_prompt: 基础 prompt
|
||||
|
||||
Returns:
|
||||
增强后的 prompt
|
||||
"""
|
||||
handoff_context = self.get_handoff_context()
|
||||
if not handoff_context:
|
||||
return base_prompt
|
||||
|
||||
return f"""{base_prompt}
|
||||
|
||||
---
|
||||
## 前序 Agent 交接信息
|
||||
|
||||
{handoff_context}
|
||||
|
||||
---
|
||||
请基于以上来自前序 Agent 的信息,结合你的专业能力开展工作。
|
||||
"""
|
||||
|
||||
# ============ 核心事件发射方法 ============
|
||||
|
||||
async def emit_event(
|
||||
|
|
@ -173,13 +419,13 @@ class BaseAgent(ABC):
|
|||
|
||||
async def emit_thinking(self, message: str):
|
||||
"""发射 LLM 思考事件"""
|
||||
await self.emit_event("thinking", f"🧠 [{self.name}] {message}")
|
||||
await self.emit_event("thinking", f"[{self.name}] {message}")
|
||||
|
||||
async def emit_llm_start(self, iteration: int):
|
||||
"""发射 LLM 开始思考事件"""
|
||||
await self.emit_event(
|
||||
"llm_start",
|
||||
f"🤔 [{self.name}] LLM 开始第 {iteration} 轮思考...",
|
||||
f"[{self.name}] 第 {iteration} 轮迭代开始",
|
||||
metadata={"iteration": iteration}
|
||||
)
|
||||
|
||||
|
|
@ -189,31 +435,62 @@ class BaseAgent(ABC):
|
|||
display_thought = thought[:500] + "..." if len(thought) > 500 else thought
|
||||
await self.emit_event(
|
||||
"llm_thought",
|
||||
f"💭 [{self.name}] LLM 思考:\n{display_thought}",
|
||||
f"[{self.name}] 思考: {display_thought}",
|
||||
metadata={
|
||||
"thought": thought,
|
||||
"iteration": iteration,
|
||||
}
|
||||
)
|
||||
|
||||
async def emit_thinking_start(self):
|
||||
"""发射开始思考事件(流式输出用)"""
|
||||
await self.emit_event("thinking_start", f"[{self.name}] 开始思考...")
|
||||
|
||||
async def emit_thinking_token(self, token: str, accumulated: str):
|
||||
"""发射思考 token 事件(流式输出用)"""
|
||||
await self.emit_event(
|
||||
"thinking_token",
|
||||
"", # 不需要 message,前端从 metadata 获取
|
||||
metadata={
|
||||
"token": token,
|
||||
"accumulated": accumulated,
|
||||
}
|
||||
)
|
||||
|
||||
async def emit_thinking_end(self, full_response: str):
|
||||
"""发射思考结束事件(流式输出用)"""
|
||||
await self.emit_event(
|
||||
"thinking_end",
|
||||
f"[{self.name}] 思考完成",
|
||||
metadata={"accumulated": full_response}
|
||||
)
|
||||
|
||||
async def emit_llm_decision(self, decision: str, reason: str = ""):
|
||||
"""发射 LLM 决策事件 - 展示 LLM 做了什么决定"""
|
||||
await self.emit_event(
|
||||
"llm_decision",
|
||||
f"💡 [{self.name}] LLM 决策: {decision}" + (f" (理由: {reason})" if reason else ""),
|
||||
f"[{self.name}] 决策: {decision}" + (f" ({reason})" if reason else ""),
|
||||
metadata={
|
||||
"decision": decision,
|
||||
"reason": reason,
|
||||
}
|
||||
)
|
||||
|
||||
async def emit_llm_complete(self, result_summary: str, tokens_used: int):
|
||||
"""发射 LLM 完成事件"""
|
||||
await self.emit_event(
|
||||
"llm_complete",
|
||||
f"[{self.name}] 完成: {result_summary} (消耗 {tokens_used} tokens)",
|
||||
metadata={
|
||||
"tokens_used": tokens_used,
|
||||
}
|
||||
)
|
||||
|
||||
async def emit_llm_action(self, action: str, action_input: Dict):
|
||||
"""发射 LLM 动作事件 - LLM 决定执行什么动作"""
|
||||
import json
|
||||
input_str = json.dumps(action_input, ensure_ascii=False)[:200]
|
||||
"""发射 LLM 动作决策事件"""
|
||||
await self.emit_event(
|
||||
"llm_action",
|
||||
f"⚡ [{self.name}] LLM 动作: {action}\n 参数: {input_str}",
|
||||
f"[{self.name}] 执行动作: {action}",
|
||||
metadata={
|
||||
"action": action,
|
||||
"action_input": action_input,
|
||||
|
|
@ -221,43 +498,33 @@ class BaseAgent(ABC):
|
|||
)
|
||||
|
||||
async def emit_llm_observation(self, observation: str):
|
||||
"""发射 LLM 观察事件 - LLM 看到了什么"""
|
||||
"""发射 LLM 观察事件"""
|
||||
# 截断过长的观察结果
|
||||
display_obs = observation[:300] + "..." if len(observation) > 300 else observation
|
||||
await self.emit_event(
|
||||
"llm_observation",
|
||||
f"👁️ [{self.name}] LLM 观察到:\n{display_obs}",
|
||||
metadata={"observation": observation[:2000]}
|
||||
)
|
||||
|
||||
async def emit_llm_complete(self, result_summary: str, tokens_used: int):
|
||||
"""发射 LLM 完成事件"""
|
||||
await self.emit_event(
|
||||
"llm_complete",
|
||||
f"✅ [{self.name}] LLM 完成: {result_summary} (消耗 {tokens_used} tokens)",
|
||||
f"[{self.name}] 观察结果: {display_obs}",
|
||||
metadata={
|
||||
"tokens_used": tokens_used,
|
||||
"observation": observation[:2000], # 限制存储长度
|
||||
}
|
||||
)
|
||||
|
||||
# ============ 工具调用相关事件 ============
|
||||
|
||||
async def emit_tool_call(self, tool_name: str, tool_input: Dict):
|
||||
"""发射工具调用事件 - LLM 决定调用工具"""
|
||||
import json
|
||||
input_str = json.dumps(tool_input, ensure_ascii=False)[:300]
|
||||
"""发射工具调用事件"""
|
||||
await self.emit_event(
|
||||
"tool_call",
|
||||
f"🔧 [{self.name}] LLM 调用工具: {tool_name}\n 输入: {input_str}",
|
||||
f"[{self.name}] 调用工具: {tool_name}",
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
)
|
||||
|
||||
async def emit_tool_result(self, tool_name: str, result: str, duration_ms: int):
|
||||
"""发射工具结果事件"""
|
||||
result_preview = result[:200] + "..." if len(result) > 200 else result
|
||||
await self.emit_event(
|
||||
"tool_result",
|
||||
f"📤 [{self.name}] 工具 {tool_name} 返回 ({duration_ms}ms):\n {result_preview}",
|
||||
f"[{self.name}] 工具 {tool_name} 完成 ({duration_ms}ms)",
|
||||
tool_name=tool_name,
|
||||
tool_duration_ms=duration_ms,
|
||||
)
|
||||
|
|
@ -332,9 +599,6 @@ class BaseAgent(ABC):
|
|||
"""
|
||||
self._iteration += 1
|
||||
|
||||
# 发射 LLM 开始事件
|
||||
await self.emit_llm_start(self._iteration)
|
||||
|
||||
try:
|
||||
response = await self.llm_service.chat_completion(
|
||||
messages=messages,
|
||||
|
|
@ -385,3 +649,124 @@ class BaseAgent(ABC):
|
|||
"tool_calls": self._tool_calls,
|
||||
"tokens_used": self._total_tokens,
|
||||
}
|
||||
|
||||
# ============ 统一的流式 LLM 调用 ============
|
||||
|
||||
async def stream_llm_call(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 2048,
|
||||
) -> Tuple[str, int]:
|
||||
"""
|
||||
统一的流式 LLM 调用方法
|
||||
|
||||
所有 Agent 共用此方法,避免重复代码
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
temperature: 温度
|
||||
max_tokens: 最大 token 数
|
||||
|
||||
Returns:
|
||||
(完整响应内容, token数量)
|
||||
"""
|
||||
accumulated = ""
|
||||
total_tokens = 0
|
||||
|
||||
await self.emit_thinking_start()
|
||||
|
||||
try:
|
||||
async for chunk in self.llm_service.chat_completion_stream(
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
# 检查取消
|
||||
if self.is_cancelled:
|
||||
break
|
||||
|
||||
if chunk["type"] == "token":
|
||||
token = chunk["content"]
|
||||
accumulated = chunk["accumulated"]
|
||||
await self.emit_thinking_token(token, accumulated)
|
||||
|
||||
elif chunk["type"] == "done":
|
||||
accumulated = chunk["content"]
|
||||
if chunk.get("usage"):
|
||||
total_tokens = chunk["usage"].get("total_tokens", 0)
|
||||
break
|
||||
|
||||
elif chunk["type"] == "error":
|
||||
accumulated = chunk.get("accumulated", "")
|
||||
logger.error(f"Stream error: {chunk.get('error')}")
|
||||
break
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.name}] LLM call cancelled")
|
||||
raise
|
||||
finally:
|
||||
await self.emit_thinking_end(accumulated)
|
||||
|
||||
return accumulated, total_tokens
|
||||
|
||||
async def execute_tool(self, tool_name: str, tool_input: Dict) -> str:
|
||||
"""
|
||||
统一的工具执行方法
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
tool_input: 工具参数
|
||||
|
||||
Returns:
|
||||
工具执行结果字符串
|
||||
"""
|
||||
tool = self.tools.get(tool_name)
|
||||
|
||||
if not tool:
|
||||
return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}"
|
||||
|
||||
try:
|
||||
self._tool_calls += 1
|
||||
await self.emit_tool_call(tool_name, tool_input)
|
||||
|
||||
import time
|
||||
start = time.time()
|
||||
|
||||
result = await tool.execute(**tool_input)
|
||||
|
||||
duration_ms = int((time.time() - start) * 1000)
|
||||
await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)
|
||||
|
||||
if result.success:
|
||||
output = str(result.data)
|
||||
|
||||
# 包含 metadata 中的额外信息
|
||||
if result.metadata:
|
||||
if "issues" in result.metadata:
|
||||
import json
|
||||
output += f"\n\n发现的问题:\n{json.dumps(result.metadata['issues'], ensure_ascii=False, indent=2)}"
|
||||
if "findings" in result.metadata:
|
||||
import json
|
||||
output += f"\n\n发现:\n{json.dumps(result.metadata['findings'][:10], ensure_ascii=False, indent=2)}"
|
||||
|
||||
# 截断过长输出
|
||||
if len(output) > 6000:
|
||||
output = output[:6000] + f"\n\n... [输出已截断,共 {len(str(result.data))} 字符]"
|
||||
return output
|
||||
else:
|
||||
return f"工具执行失败: {result.error}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution error: {e}")
|
||||
return f"工具执行错误: {str(e)}"
|
||||
|
||||
def get_tools_description(self) -> str:
|
||||
"""生成工具描述文本(用于 prompt)"""
|
||||
tools_info = []
|
||||
for name, tool in self.tools.items():
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
desc = f"- {name}: {getattr(tool, 'description', 'No description')}"
|
||||
tools_info.append(desc)
|
||||
return "\n".join(tools_info)
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from typing import List, Dict, Any, Optional
|
|||
from dataclasses import dataclass
|
||||
|
||||
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
|
||||
from ..json_parser import AgentJsonParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -178,18 +179,22 @@ class OrchestratorAgent(BaseAgent):
|
|||
|
||||
self._iteration = iteration + 1
|
||||
|
||||
# 🔥 发射 LLM 开始思考事件
|
||||
await self.emit_llm_start(iteration + 1)
|
||||
# 🔥 再次检查取消标志(在LLM调用之前)
|
||||
if self.is_cancelled:
|
||||
await self.emit_thinking("🛑 任务已取消,停止执行")
|
||||
break
|
||||
|
||||
# 🔥 调用 LLM 进行思考和决策
|
||||
response = await self.llm_service.chat_completion_raw(
|
||||
messages=self._conversation_history,
|
||||
# 调用 LLM 进行思考和决策(流式输出)
|
||||
try:
|
||||
llm_output, tokens_this_round = await self.stream_llm_call(
|
||||
self._conversation_history,
|
||||
temperature=0.1,
|
||||
max_tokens=2048,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.name}] LLM call cancelled")
|
||||
break
|
||||
|
||||
llm_output = response.get("content", "")
|
||||
tokens_this_round = response.get("usage", {}).get("total_tokens", 0)
|
||||
self._total_tokens += tokens_this_round
|
||||
|
||||
# 解析 LLM 的决策
|
||||
|
|
@ -348,10 +353,11 @@ class OrchestratorAgent(BaseAgent):
|
|||
input_text = re.sub(r'```json\s*', '', input_text)
|
||||
input_text = re.sub(r'```\s*', '', input_text)
|
||||
|
||||
try:
|
||||
action_input = json.loads(input_text)
|
||||
except json.JSONDecodeError:
|
||||
action_input = {"raw": input_text}
|
||||
# 使用增强的 JSON 解析器
|
||||
action_input = AgentJsonParser.parse(
|
||||
input_text,
|
||||
default={"raw": input_text}
|
||||
)
|
||||
|
||||
return AgentStep(
|
||||
thought=thought,
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from typing import List, Dict, Any, Optional, Tuple
|
|||
from dataclasses import dataclass
|
||||
|
||||
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
|
||||
from ..json_parser import AgentJsonParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -182,15 +183,20 @@ class ReActAgent(BaseAgent):
|
|||
final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
|
||||
if final_match:
|
||||
step.is_final = True
|
||||
try:
|
||||
# 尝试提取 JSON
|
||||
answer_text = final_match.group(1).strip()
|
||||
# 移除 markdown 代码块
|
||||
answer_text = re.sub(r'```json\s*', '', answer_text)
|
||||
answer_text = re.sub(r'```\s*', '', answer_text)
|
||||
step.final_answer = json.loads(answer_text)
|
||||
except json.JSONDecodeError:
|
||||
step.final_answer = {"raw_answer": final_match.group(1).strip()}
|
||||
# 使用增强的 JSON 解析器
|
||||
step.final_answer = AgentJsonParser.parse(
|
||||
answer_text,
|
||||
default={"raw_answer": answer_text}
|
||||
)
|
||||
# 确保 findings 格式正确
|
||||
if "findings" in step.final_answer:
|
||||
step.final_answer["findings"] = [
|
||||
f for f in step.final_answer["findings"]
|
||||
if isinstance(f, dict)
|
||||
]
|
||||
return step
|
||||
|
||||
# 提取 Action
|
||||
|
|
@ -202,14 +208,13 @@ class ReActAgent(BaseAgent):
|
|||
input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL)
|
||||
if input_match:
|
||||
input_text = input_match.group(1).strip()
|
||||
# 移除 markdown 代码块
|
||||
input_text = re.sub(r'```json\s*', '', input_text)
|
||||
input_text = re.sub(r'```\s*', '', input_text)
|
||||
try:
|
||||
step.action_input = json.loads(input_text)
|
||||
except json.JSONDecodeError:
|
||||
# 尝试简单解析
|
||||
step.action_input = {"raw_input": input_text}
|
||||
# 使用增强的 JSON 解析器
|
||||
step.action_input = AgentJsonParser.parse(
|
||||
input_text,
|
||||
default={"raw_input": input_text}
|
||||
)
|
||||
|
||||
return step
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ LLM 是真正的大脑!
|
|||
类型: ReAct (真正的!)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
|
@ -17,22 +18,24 @@ from typing import List, Dict, Any, Optional
|
|||
from dataclasses import dataclass
|
||||
|
||||
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
|
||||
from ..json_parser import AgentJsonParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
RECON_SYSTEM_PROMPT = """你是 DeepAudit 的信息收集 Agent,负责在安全审计前**自主**收集项目信息。
|
||||
RECON_SYSTEM_PROMPT = """你是 DeepAudit 的信息收集 Agent,负责在安全审计前收集项目信息。
|
||||
|
||||
## 你的角色
|
||||
你是信息收集的**大脑**,不是机械执行者。你需要:
|
||||
1. 自主思考需要收集什么信息
|
||||
2. 选择合适的工具获取信息
|
||||
3. 根据发现动态调整策略
|
||||
4. 判断何时信息收集足够
|
||||
## 你的职责
|
||||
你专注于**信息收集**,为后续的漏洞分析提供基础数据:
|
||||
1. 分析项目结构和目录布局
|
||||
2. 识别技术栈(语言、框架、数据库)
|
||||
3. 找出入口点(API、路由、用户输入处理)
|
||||
4. 标记高风险区域(认证、数据库操作、文件处理)
|
||||
5. 收集依赖信息
|
||||
|
||||
## 你可以使用的工具
|
||||
|
||||
### 文件系统
|
||||
### 文件系统工具
|
||||
- **list_files**: 列出目录内容
|
||||
参数: directory (str), recursive (bool), pattern (str), max_files (int)
|
||||
|
||||
|
|
@ -42,12 +45,14 @@ RECON_SYSTEM_PROMPT = """你是 DeepAudit 的信息收集 Agent,负责在安
|
|||
- **search_code**: 代码关键字搜索
|
||||
参数: keyword (str), max_results (int)
|
||||
|
||||
### 安全扫描
|
||||
- **semgrep_scan**: Semgrep 静态分析扫描
|
||||
- **npm_audit**: npm 依赖漏洞审计
|
||||
- **safety_scan**: Python 依赖漏洞审计
|
||||
- **gitleaks_scan**: 密钥/敏感信息泄露扫描
|
||||
- **osv_scan**: OSV 通用依赖漏洞扫描
|
||||
### 语义搜索工具
|
||||
- **rag_query**: 语义代码搜索(如果可用)
|
||||
参数: query (str), top_k (int)
|
||||
|
||||
## 注意
|
||||
- 你只负责信息收集,不要进行漏洞分析
|
||||
- 漏洞分析由 Analysis Agent 负责
|
||||
- 专注于收集项目结构、技术栈、入口点等信息
|
||||
|
||||
## 工作方式
|
||||
每一步,你需要输出:
|
||||
|
|
@ -142,15 +147,7 @@ class ReconAgent(BaseAgent):
|
|||
self._conversation_history: List[Dict[str, str]] = []
|
||||
self._steps: List[ReconStep] = []
|
||||
|
||||
def _get_tools_description(self) -> str:
|
||||
"""生成工具描述"""
|
||||
tools_info = []
|
||||
for name, tool in self.tools.items():
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
desc = f"- {name}: {getattr(tool, 'description', 'No description')}"
|
||||
tools_info.append(desc)
|
||||
return "\n".join(tools_info)
|
||||
|
||||
|
||||
def _parse_llm_response(self, response: str) -> ReconStep:
|
||||
"""解析 LLM 响应"""
|
||||
|
|
@ -165,13 +162,14 @@ class ReconAgent(BaseAgent):
|
|||
final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
|
||||
if final_match:
|
||||
step.is_final = True
|
||||
try:
|
||||
answer_text = final_match.group(1).strip()
|
||||
answer_text = re.sub(r'```json\s*', '', answer_text)
|
||||
answer_text = re.sub(r'```\s*', '', answer_text)
|
||||
step.final_answer = json.loads(answer_text)
|
||||
except json.JSONDecodeError:
|
||||
step.final_answer = {"raw_answer": final_match.group(1).strip()}
|
||||
# 使用增强的 JSON 解析器
|
||||
step.final_answer = AgentJsonParser.parse(
|
||||
answer_text,
|
||||
default={"raw_answer": answer_text}
|
||||
)
|
||||
return step
|
||||
|
||||
# 提取 Action
|
||||
|
|
@ -185,43 +183,15 @@ class ReconAgent(BaseAgent):
|
|||
input_text = input_match.group(1).strip()
|
||||
input_text = re.sub(r'```json\s*', '', input_text)
|
||||
input_text = re.sub(r'```\s*', '', input_text)
|
||||
try:
|
||||
step.action_input = json.loads(input_text)
|
||||
except json.JSONDecodeError:
|
||||
step.action_input = {"raw_input": input_text}
|
||||
# 使用增强的 JSON 解析器
|
||||
step.action_input = AgentJsonParser.parse(
|
||||
input_text,
|
||||
default={"raw_input": input_text}
|
||||
)
|
||||
|
||||
return step
|
||||
|
||||
async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str:
|
||||
"""执行工具"""
|
||||
tool = self.tools.get(tool_name)
|
||||
|
||||
if not tool:
|
||||
return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}"
|
||||
|
||||
try:
|
||||
self._tool_calls += 1
|
||||
await self.emit_tool_call(tool_name, tool_input)
|
||||
|
||||
import time
|
||||
start = time.time()
|
||||
|
||||
result = await tool.execute(**tool_input)
|
||||
|
||||
duration_ms = int((time.time() - start) * 1000)
|
||||
await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)
|
||||
|
||||
if result.success:
|
||||
output = str(result.data)
|
||||
if len(output) > 4000:
|
||||
output = output[:4000] + f"\n\n... [输出已截断,共 {len(str(result.data))} 字符]"
|
||||
return output
|
||||
else:
|
||||
return f"工具执行失败: {result.error}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution error: {e}")
|
||||
return f"工具执行错误: {str(e)}"
|
||||
|
||||
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
|
||||
"""
|
||||
|
|
@ -246,7 +216,7 @@ class ReconAgent(BaseAgent):
|
|||
{task_context or task or '进行全面的信息收集,为安全审计做准备。'}
|
||||
|
||||
## 可用工具
|
||||
{self._get_tools_description()}
|
||||
{self.get_tools_description()}
|
||||
|
||||
请开始你的信息收集工作。首先思考应该收集什么信息,然后选择合适的工具。"""
|
||||
|
||||
|
|
@ -259,7 +229,7 @@ class ReconAgent(BaseAgent):
|
|||
self._steps = []
|
||||
final_result = None
|
||||
|
||||
await self.emit_thinking("🔍 Recon Agent 启动,LLM 开始自主收集信息...")
|
||||
await self.emit_thinking("Recon Agent 启动,LLM 开始自主收集信息...")
|
||||
|
||||
try:
|
||||
for iteration in range(self.config.max_iterations):
|
||||
|
|
@ -268,18 +238,22 @@ class ReconAgent(BaseAgent):
|
|||
|
||||
self._iteration = iteration + 1
|
||||
|
||||
# 🔥 发射 LLM 开始思考事件
|
||||
await self.emit_llm_start(iteration + 1)
|
||||
# 🔥 再次检查取消标志(在LLM调用之前)
|
||||
if self.is_cancelled:
|
||||
await self.emit_thinking("🛑 任务已取消,停止执行")
|
||||
break
|
||||
|
||||
# 🔥 调用 LLM 进行思考和决策
|
||||
response = await self.llm_service.chat_completion_raw(
|
||||
messages=self._conversation_history,
|
||||
# 调用 LLM 进行思考和决策(使用基类统一方法)
|
||||
try:
|
||||
llm_output, tokens_this_round = await self.stream_llm_call(
|
||||
self._conversation_history,
|
||||
temperature=0.1,
|
||||
max_tokens=2048,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.name}] LLM call cancelled")
|
||||
break
|
||||
|
||||
llm_output = response.get("content", "")
|
||||
tokens_this_round = response.get("usage", {}).get("total_tokens", 0)
|
||||
self._total_tokens += tokens_this_round
|
||||
|
||||
# 解析 LLM 响应
|
||||
|
|
@ -311,7 +285,7 @@ class ReconAgent(BaseAgent):
|
|||
# 🔥 发射 LLM 动作决策事件
|
||||
await self.emit_llm_action(step.action, step.action_input or {})
|
||||
|
||||
observation = await self._execute_tool(
|
||||
observation = await self.execute_tool(
|
||||
step.action,
|
||||
step.action_input or {}
|
||||
)
|
||||
|
|
@ -341,9 +315,18 @@ class ReconAgent(BaseAgent):
|
|||
if not final_result:
|
||||
final_result = self._summarize_from_steps()
|
||||
|
||||
# 🔥 记录工作和洞察
|
||||
self.record_work(f"完成项目信息收集,发现 {len(final_result.get('entry_points', []))} 个入口点")
|
||||
self.record_work(f"识别技术栈: {final_result.get('tech_stack', {})}")
|
||||
|
||||
if final_result.get("high_risk_areas"):
|
||||
self.add_insight(f"发现 {len(final_result['high_risk_areas'])} 个高风险区域需要重点分析")
|
||||
if final_result.get("initial_findings"):
|
||||
self.add_insight(f"初步发现 {len(final_result['initial_findings'])} 个潜在问题")
|
||||
|
||||
await self.emit_event(
|
||||
"info",
|
||||
f"🎯 Recon Agent 完成: {self._iteration} 轮迭代, {self._tool_calls} 次工具调用"
|
||||
f"Recon Agent 完成: {self._iteration} 轮迭代, {self._tool_calls} 次工具调用"
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ LLM 是验证的大脑!
|
|||
类型: ReAct (真正的!)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
|
@ -18,6 +19,7 @@ from dataclasses import dataclass
|
|||
from datetime import datetime, timezone
|
||||
|
||||
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
|
||||
from ..json_parser import AgentJsonParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -34,15 +36,17 @@ VERIFICATION_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞验证 Agent,一个*
|
|||
|
||||
## 你可以使用的工具
|
||||
|
||||
### 代码分析
|
||||
### 文件操作
|
||||
- **read_file**: 读取更多代码上下文
|
||||
参数: file_path (str), start_line (int), end_line (int)
|
||||
- **function_context**: 分析函数调用关系
|
||||
参数: function_name (str)
|
||||
- **dataflow_analysis**: 追踪数据流
|
||||
参数: source (str), sink (str), file_path (str)
|
||||
- **list_files**: 列出目录文件
|
||||
参数: directory (str), pattern (str)
|
||||
|
||||
### 验证分析
|
||||
- **vulnerability_validation**: LLM 深度验证 ⭐
|
||||
参数: code (str), vulnerability_type (str), context (str)
|
||||
- **dataflow_analysis**: 追踪数据流
|
||||
参数: source (str), sink (str), file_path (str)
|
||||
|
||||
### 沙箱验证
|
||||
- **sandbox_exec**: 在沙箱中执行命令
|
||||
|
|
@ -157,16 +161,6 @@ class VerificationAgent(BaseAgent):
|
|||
self._conversation_history: List[Dict[str, str]] = []
|
||||
self._steps: List[VerificationStep] = []
|
||||
|
||||
def _get_tools_description(self) -> str:
|
||||
"""生成工具描述"""
|
||||
tools_info = []
|
||||
for name, tool in self.tools.items():
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
desc = f"- {name}: {getattr(tool, 'description', 'No description')}"
|
||||
tools_info.append(desc)
|
||||
return "\n".join(tools_info)
|
||||
|
||||
def _parse_llm_response(self, response: str) -> VerificationStep:
|
||||
"""解析 LLM 响应"""
|
||||
step = VerificationStep(thought="")
|
||||
|
|
@ -180,13 +174,20 @@ class VerificationAgent(BaseAgent):
|
|||
final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
|
||||
if final_match:
|
||||
step.is_final = True
|
||||
try:
|
||||
answer_text = final_match.group(1).strip()
|
||||
answer_text = re.sub(r'```json\s*', '', answer_text)
|
||||
answer_text = re.sub(r'```\s*', '', answer_text)
|
||||
step.final_answer = json.loads(answer_text)
|
||||
except json.JSONDecodeError:
|
||||
step.final_answer = {"findings": [], "raw_answer": final_match.group(1).strip()}
|
||||
# 使用增强的 JSON 解析器
|
||||
step.final_answer = AgentJsonParser.parse(
|
||||
answer_text,
|
||||
default={"findings": [], "raw_answer": answer_text}
|
||||
)
|
||||
# 确保 findings 格式正确
|
||||
if "findings" in step.final_answer:
|
||||
step.final_answer["findings"] = [
|
||||
f for f in step.final_answer["findings"]
|
||||
if isinstance(f, dict)
|
||||
]
|
||||
return step
|
||||
|
||||
# 提取 Action
|
||||
|
|
@ -200,50 +201,14 @@ class VerificationAgent(BaseAgent):
|
|||
input_text = input_match.group(1).strip()
|
||||
input_text = re.sub(r'```json\s*', '', input_text)
|
||||
input_text = re.sub(r'```\s*', '', input_text)
|
||||
try:
|
||||
step.action_input = json.loads(input_text)
|
||||
except json.JSONDecodeError:
|
||||
step.action_input = {"raw_input": input_text}
|
||||
# 使用增强的 JSON 解析器
|
||||
step.action_input = AgentJsonParser.parse(
|
||||
input_text,
|
||||
default={"raw_input": input_text}
|
||||
)
|
||||
|
||||
return step
|
||||
|
||||
async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str:
|
||||
"""执行工具"""
|
||||
tool = self.tools.get(tool_name)
|
||||
|
||||
if not tool:
|
||||
return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}"
|
||||
|
||||
try:
|
||||
self._tool_calls += 1
|
||||
await self.emit_tool_call(tool_name, tool_input)
|
||||
|
||||
import time
|
||||
start = time.time()
|
||||
|
||||
result = await tool.execute(**tool_input)
|
||||
|
||||
duration_ms = int((time.time() - start) * 1000)
|
||||
await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)
|
||||
|
||||
if result.success:
|
||||
output = str(result.data)
|
||||
|
||||
# 包含 metadata
|
||||
if result.metadata:
|
||||
if "validation" in result.metadata:
|
||||
output += f"\n\n验证结果:\n{json.dumps(result.metadata['validation'], ensure_ascii=False, indent=2)}"
|
||||
|
||||
if len(output) > 4000:
|
||||
output = output[:4000] + f"\n\n... [输出已截断]"
|
||||
return output
|
||||
else:
|
||||
return f"工具执行失败: {result.error}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tool execution error: {e}")
|
||||
return f"工具执行错误: {str(e)}"
|
||||
|
||||
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
|
||||
"""
|
||||
执行漏洞验证 - LLM 全程参与!
|
||||
|
|
@ -256,9 +221,21 @@ class VerificationAgent(BaseAgent):
|
|||
task = input_data.get("task", "")
|
||||
task_context = input_data.get("task_context", "")
|
||||
|
||||
# 🔥 处理交接信息
|
||||
handoff = input_data.get("handoff")
|
||||
if handoff:
|
||||
from .base import TaskHandoff
|
||||
if isinstance(handoff, dict):
|
||||
handoff = TaskHandoff.from_dict(handoff)
|
||||
self.receive_handoff(handoff)
|
||||
|
||||
# 收集所有待验证的发现
|
||||
findings_to_verify = []
|
||||
|
||||
# 🔥 优先从交接信息获取发现
|
||||
if self._incoming_handoff and self._incoming_handoff.key_findings:
|
||||
findings_to_verify = self._incoming_handoff.key_findings.copy()
|
||||
else:
|
||||
for phase_name, result in previous_results.items():
|
||||
if isinstance(result, dict):
|
||||
data = result.get("data", {})
|
||||
|
|
@ -289,7 +266,12 @@ class VerificationAgent(BaseAgent):
|
|||
f"开始验证 {len(findings_to_verify)} 个发现"
|
||||
)
|
||||
|
||||
# 构建初始消息
|
||||
# 🔥 记录工作开始
|
||||
self.record_work(f"开始验证 {len(findings_to_verify)} 个漏洞发现")
|
||||
|
||||
# 🔥 构建包含交接上下文的初始消息
|
||||
handoff_context = self.get_handoff_context()
|
||||
|
||||
findings_summary = []
|
||||
for i, f in enumerate(findings_to_verify):
|
||||
findings_summary.append(f"""
|
||||
|
|
@ -306,6 +288,8 @@ class VerificationAgent(BaseAgent):
|
|||
|
||||
initial_message = f"""请验证以下 {len(findings_to_verify)} 个安全发现。
|
||||
|
||||
{handoff_context if handoff_context else ''}
|
||||
|
||||
## 待验证发现
|
||||
{''.join(findings_summary)}
|
||||
|
||||
|
|
@ -313,9 +297,10 @@ class VerificationAgent(BaseAgent):
|
|||
- 验证级别: {config.get('verification_level', 'standard')}
|
||||
|
||||
## 可用工具
|
||||
{self._get_tools_description()}
|
||||
{self.get_tools_description()}
|
||||
|
||||
请开始验证。对于每个发现,思考如何验证它,使用合适的工具获取更多信息,然后判断是否为真实漏洞。"""
|
||||
请开始验证。对于每个发现,思考如何验证它,使用合适的工具获取更多信息,然后判断是否为真实漏洞。
|
||||
{f"特别注意 Analysis Agent 提到的关注点。" if handoff_context else ""}"""
|
||||
|
||||
# 初始化对话历史
|
||||
self._conversation_history = [
|
||||
|
|
@ -335,18 +320,22 @@ class VerificationAgent(BaseAgent):
|
|||
|
||||
self._iteration = iteration + 1
|
||||
|
||||
# 🔥 发射 LLM 开始思考事件
|
||||
await self.emit_llm_start(iteration + 1)
|
||||
# 🔥 再次检查取消标志(在LLM调用之前)
|
||||
if self.is_cancelled:
|
||||
await self.emit_thinking("🛑 任务已取消,停止执行")
|
||||
break
|
||||
|
||||
# 🔥 调用 LLM 进行思考和决策
|
||||
response = await self.llm_service.chat_completion_raw(
|
||||
messages=self._conversation_history,
|
||||
# 调用 LLM 进行思考和决策(流式输出)
|
||||
try:
|
||||
llm_output, tokens_this_round = await self.stream_llm_call(
|
||||
self._conversation_history,
|
||||
temperature=0.1,
|
||||
max_tokens=3000,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.name}] LLM call cancelled")
|
||||
break
|
||||
|
||||
llm_output = response.get("content", "")
|
||||
tokens_this_round = response.get("usage", {}).get("total_tokens", 0)
|
||||
self._total_tokens += tokens_this_round
|
||||
|
||||
# 解析 LLM 响应
|
||||
|
|
@ -367,6 +356,14 @@ class VerificationAgent(BaseAgent):
|
|||
if step.is_final:
|
||||
await self.emit_llm_decision("完成漏洞验证", "LLM 判断验证已充分")
|
||||
final_result = step.final_answer
|
||||
|
||||
# 🔥 记录洞察和工作
|
||||
if final_result and "findings" in final_result:
|
||||
verified_count = len([f for f in final_result["findings"] if f.get("is_verified")])
|
||||
fp_count = len([f for f in final_result["findings"] if f.get("verdict") == "false_positive"])
|
||||
self.add_insight(f"验证了 {len(final_result['findings'])} 个发现,{verified_count} 个确认,{fp_count} 个误报")
|
||||
self.record_work(f"完成漏洞验证: {verified_count} 个确认, {fp_count} 个误报")
|
||||
|
||||
await self.emit_llm_complete(
|
||||
f"验证完成",
|
||||
self._total_tokens
|
||||
|
|
@ -378,7 +375,7 @@ class VerificationAgent(BaseAgent):
|
|||
# 🔥 发射 LLM 动作决策事件
|
||||
await self.emit_llm_action(step.action, step.action_input or {})
|
||||
|
||||
observation = await self._execute_tool(
|
||||
observation = await self.execute_tool(
|
||||
step.action,
|
||||
step.action_input or {}
|
||||
)
|
||||
|
|
@ -438,7 +435,7 @@ class VerificationAgent(BaseAgent):
|
|||
|
||||
await self.emit_event(
|
||||
"info",
|
||||
f"🎯 Verification Agent 完成: {confirmed_count} 确认, {likely_count} 可能, {false_positive_count} 误报"
|
||||
f"Verification Agent 完成: {confirmed_count} 确认, {likely_count} 可能, {false_positive_count} 误报"
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
|
|
|
|||
|
|
@ -299,8 +299,9 @@ class EventManager:
|
|||
"timestamp": timestamp.isoformat(),
|
||||
}
|
||||
|
||||
# 保存到数据库
|
||||
if self.db_session_factory:
|
||||
# 保存到数据库(跳过高频事件如 thinking_token)
|
||||
skip_db_events = {"thinking_token", "thinking_start", "thinking_end"}
|
||||
if self.db_session_factory and event_type not in skip_db_events:
|
||||
try:
|
||||
await self._save_event_to_db(event_data)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -69,6 +69,11 @@ class AuditState(TypedDict):
|
|||
llm_next_action: Optional[str] # LLM 建议的下一步: "continue_analysis", "verify", "report", "end"
|
||||
llm_routing_reason: Optional[str] # LLM 的决策理由
|
||||
|
||||
# 🔥 新增:Agent 间协作的任务交接信息
|
||||
recon_handoff: Optional[Dict[str, Any]] # Recon -> Analysis 的交接
|
||||
analysis_handoff: Optional[Dict[str, Any]] # Analysis -> Verification 的交接
|
||||
verification_handoff: Optional[Dict[str, Any]] # Verification -> Report 的交接
|
||||
|
||||
# 消息和事件
|
||||
messages: Annotated[List[Dict], operator.add]
|
||||
events: Annotated[List[Dict], operator.add]
|
||||
|
|
@ -146,6 +151,9 @@ class LLMRouter:
|
|||
# 统计发现
|
||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
|
||||
for f in findings:
|
||||
# 跳过非字典类型的 finding
|
||||
if not isinstance(f, dict):
|
||||
continue
|
||||
sev = f.get("severity", "medium")
|
||||
severity_counts[sev] = severity_counts.get(sev, 0) + 1
|
||||
|
||||
|
|
@ -243,6 +251,11 @@ def route_after_recon(state: AuditState) -> Literal["analysis", "end"]:
|
|||
Recon 后的路由决策
|
||||
优先使用 LLM 的决策,否则使用默认逻辑
|
||||
"""
|
||||
# 🔥 检查是否有错误
|
||||
if state.get("error") or state.get("current_phase") == "error":
|
||||
logger.error(f"Recon phase has error, routing to end: {state.get('error')}")
|
||||
return "end"
|
||||
|
||||
# 检查 LLM 是否有决策
|
||||
llm_action = state.get("llm_next_action")
|
||||
if llm_action:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
"""
|
||||
LangGraph 节点实现
|
||||
每个节点封装一个 Agent 的执行逻辑
|
||||
|
||||
协作增强:节点之间通过 TaskHandoff 传递结构化的上下文和洞察
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
|
@ -29,13 +31,21 @@ class BaseNode:
|
|||
except Exception as e:
|
||||
logger.warning(f"Failed to emit event: {e}")
|
||||
|
||||
def _extract_handoff_from_state(self, state: Dict[str, Any], from_phase: str):
|
||||
"""从状态中提取前序 Agent 的 handoff"""
|
||||
handoff_data = state.get(f"{from_phase}_handoff")
|
||||
if handoff_data:
|
||||
from ..agents.base import TaskHandoff
|
||||
return TaskHandoff.from_dict(handoff_data)
|
||||
return None
|
||||
|
||||
|
||||
class ReconNode(BaseNode):
|
||||
"""
|
||||
信息收集节点
|
||||
|
||||
输入: project_root, project_info, config
|
||||
输出: tech_stack, entry_points, high_risk_areas, dependencies
|
||||
输出: tech_stack, entry_points, high_risk_areas, dependencies, recon_handoff
|
||||
"""
|
||||
|
||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
|
@ -52,6 +62,35 @@ class ReconNode(BaseNode):
|
|||
if result.success and result.data:
|
||||
data = result.data
|
||||
|
||||
# 🔥 创建交接信息给 Analysis Agent
|
||||
handoff = self.agent.create_handoff(
|
||||
to_agent="Analysis",
|
||||
summary=f"项目信息收集完成。发现 {len(data.get('entry_points', []))} 个入口点,{len(data.get('high_risk_areas', []))} 个高风险区域。",
|
||||
key_findings=data.get("initial_findings", []),
|
||||
suggested_actions=[
|
||||
{
|
||||
"type": "deep_analysis",
|
||||
"description": f"深入分析高风险区域: {', '.join(data.get('high_risk_areas', [])[:5])}",
|
||||
"priority": "high",
|
||||
},
|
||||
{
|
||||
"type": "entry_point_audit",
|
||||
"description": "审计所有入口点的输入验证",
|
||||
"priority": "high",
|
||||
},
|
||||
],
|
||||
attention_points=[
|
||||
f"技术栈: {data.get('tech_stack', {}).get('frameworks', [])}",
|
||||
f"主要语言: {data.get('tech_stack', {}).get('languages', [])}",
|
||||
],
|
||||
priority_areas=data.get("high_risk_areas", [])[:10],
|
||||
context_data={
|
||||
"tech_stack": data.get("tech_stack", {}),
|
||||
"entry_points": data.get("entry_points", []),
|
||||
"dependencies": data.get("dependencies", {}),
|
||||
},
|
||||
)
|
||||
|
||||
await self.emit_event(
|
||||
"phase_complete",
|
||||
f"✅ 信息收集完成: 发现 {len(data.get('entry_points', []))} 个入口点"
|
||||
|
|
@ -63,12 +102,15 @@ class ReconNode(BaseNode):
|
|||
"high_risk_areas": data.get("high_risk_areas", []),
|
||||
"dependencies": data.get("dependencies", {}),
|
||||
"current_phase": "recon_complete",
|
||||
"findings": data.get("initial_findings", []), # 初步发现
|
||||
"findings": data.get("initial_findings", []),
|
||||
# 🔥 保存交接信息
|
||||
"recon_handoff": handoff.to_dict(),
|
||||
"events": [{
|
||||
"type": "recon_complete",
|
||||
"data": {
|
||||
"entry_points_count": len(data.get("entry_points", [])),
|
||||
"high_risk_areas_count": len(data.get("high_risk_areas", [])),
|
||||
"handoff_summary": handoff.summary,
|
||||
}
|
||||
}],
|
||||
}
|
||||
|
|
@ -90,8 +132,8 @@ class AnalysisNode(BaseNode):
|
|||
"""
|
||||
漏洞分析节点
|
||||
|
||||
输入: tech_stack, entry_points, high_risk_areas, previous findings
|
||||
输出: findings (累加), should_continue_analysis
|
||||
输入: tech_stack, entry_points, high_risk_areas, recon_handoff
|
||||
输出: findings (累加), should_continue_analysis, analysis_handoff
|
||||
"""
|
||||
|
||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
|
@ -104,6 +146,15 @@ class AnalysisNode(BaseNode):
|
|||
)
|
||||
|
||||
try:
|
||||
# 🔥 提取 Recon 的交接信息
|
||||
recon_handoff = self._extract_handoff_from_state(state, "recon")
|
||||
if recon_handoff:
|
||||
self.agent.receive_handoff(recon_handoff)
|
||||
await self.emit_event(
|
||||
"handoff_received",
|
||||
f"📨 收到 Recon Agent 交接: {recon_handoff.summary[:50]}..."
|
||||
)
|
||||
|
||||
# 构建分析输入
|
||||
analysis_input = {
|
||||
"phase_name": "analysis",
|
||||
|
|
@ -121,6 +172,8 @@ class AnalysisNode(BaseNode):
|
|||
}
|
||||
}
|
||||
},
|
||||
# 🔥 传递交接信息
|
||||
"handoff": recon_handoff,
|
||||
}
|
||||
|
||||
# 调用 Analysis Agent
|
||||
|
|
@ -130,27 +183,70 @@ class AnalysisNode(BaseNode):
|
|||
new_findings = result.data.get("findings", [])
|
||||
|
||||
# 判断是否需要继续分析
|
||||
# 如果这一轮发现了很多问题,可能还有更多
|
||||
should_continue = (
|
||||
len(new_findings) >= 5 and
|
||||
iteration < state.get("max_iterations", 3)
|
||||
)
|
||||
|
||||
# 🔥 创建交接信息给 Verification Agent
|
||||
# 统计严重程度
|
||||
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
|
||||
for f in new_findings:
|
||||
if isinstance(f, dict):
|
||||
sev = f.get("severity", "medium")
|
||||
severity_counts[sev] = severity_counts.get(sev, 0) + 1
|
||||
|
||||
handoff = self.agent.create_handoff(
|
||||
to_agent="Verification",
|
||||
summary=f"漏洞分析完成。发现 {len(new_findings)} 个潜在漏洞 (Critical: {severity_counts['critical']}, High: {severity_counts['high']}, Medium: {severity_counts['medium']}, Low: {severity_counts['low']})",
|
||||
key_findings=new_findings[:20], # 传递前20个发现
|
||||
suggested_actions=[
|
||||
{
|
||||
"type": "verify_critical",
|
||||
"description": "优先验证 Critical 和 High 级别的漏洞",
|
||||
"priority": "critical",
|
||||
},
|
||||
{
|
||||
"type": "poc_generation",
|
||||
"description": "为确认的漏洞生成 PoC",
|
||||
"priority": "high",
|
||||
},
|
||||
],
|
||||
attention_points=[
|
||||
f"共 {severity_counts['critical']} 个 Critical 级别漏洞需要立即验证",
|
||||
f"共 {severity_counts['high']} 个 High 级别漏洞需要优先验证",
|
||||
"注意检查是否有误报,特别是静态分析工具的结果",
|
||||
],
|
||||
priority_areas=[
|
||||
f.get("file_path", "") for f in new_findings
|
||||
if f.get("severity") in ["critical", "high"]
|
||||
][:10],
|
||||
context_data={
|
||||
"severity_distribution": severity_counts,
|
||||
"total_findings": len(new_findings),
|
||||
"iteration": iteration,
|
||||
},
|
||||
)
|
||||
|
||||
await self.emit_event(
|
||||
"phase_complete",
|
||||
f"✅ 分析迭代 {iteration} 完成: 发现 {len(new_findings)} 个潜在漏洞"
|
||||
)
|
||||
|
||||
return {
|
||||
"findings": new_findings, # 会自动累加
|
||||
"findings": new_findings,
|
||||
"iteration": iteration,
|
||||
"should_continue_analysis": should_continue,
|
||||
"current_phase": "analysis_complete",
|
||||
# 🔥 保存交接信息
|
||||
"analysis_handoff": handoff.to_dict(),
|
||||
"events": [{
|
||||
"type": "analysis_iteration",
|
||||
"data": {
|
||||
"iteration": iteration,
|
||||
"findings_count": len(new_findings),
|
||||
"severity_distribution": severity_counts,
|
||||
"handoff_summary": handoff.summary,
|
||||
}
|
||||
}],
|
||||
}
|
||||
|
|
@ -174,8 +270,8 @@ class VerificationNode(BaseNode):
|
|||
"""
|
||||
漏洞验证节点
|
||||
|
||||
输入: findings
|
||||
输出: verified_findings, false_positives
|
||||
输入: findings, analysis_handoff
|
||||
输出: verified_findings, false_positives, verification_handoff
|
||||
"""
|
||||
|
||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
|
@ -195,6 +291,15 @@ class VerificationNode(BaseNode):
|
|||
)
|
||||
|
||||
try:
|
||||
# 🔥 提取 Analysis 的交接信息
|
||||
analysis_handoff = self._extract_handoff_from_state(state, "analysis")
|
||||
if analysis_handoff:
|
||||
self.agent.receive_handoff(analysis_handoff)
|
||||
await self.emit_event(
|
||||
"handoff_received",
|
||||
f"📨 收到 Analysis Agent 交接: {analysis_handoff.summary[:50]}..."
|
||||
)
|
||||
|
||||
# 构建验证输入
|
||||
verification_input = {
|
||||
"previous_results": {
|
||||
|
|
@ -205,16 +310,49 @@ class VerificationNode(BaseNode):
|
|||
}
|
||||
},
|
||||
"config": state["config"],
|
||||
# 🔥 传递交接信息
|
||||
"handoff": analysis_handoff,
|
||||
}
|
||||
|
||||
# 调用 Verification Agent
|
||||
result = await self.agent.run(verification_input)
|
||||
|
||||
if result.success and result.data:
|
||||
verified = [f for f in result.data.get("findings", []) if f.get("is_verified")]
|
||||
false_pos = [f["id"] for f in result.data.get("findings", [])
|
||||
all_verified_findings = result.data.get("findings", [])
|
||||
verified = [f for f in all_verified_findings if f.get("is_verified")]
|
||||
false_pos = [f.get("id", f.get("title", "unknown")) for f in all_verified_findings
|
||||
if f.get("verdict") == "false_positive"]
|
||||
|
||||
# 🔥 创建交接信息给 Report 节点
|
||||
handoff = self.agent.create_handoff(
|
||||
to_agent="Report",
|
||||
summary=f"漏洞验证完成。{len(verified)} 个漏洞已确认,{len(false_pos)} 个误报已排除。",
|
||||
key_findings=verified,
|
||||
suggested_actions=[
|
||||
{
|
||||
"type": "generate_report",
|
||||
"description": "生成详细的安全审计报告",
|
||||
"priority": "high",
|
||||
},
|
||||
{
|
||||
"type": "remediation_plan",
|
||||
"description": "为确认的漏洞制定修复计划",
|
||||
"priority": "high",
|
||||
},
|
||||
],
|
||||
attention_points=[
|
||||
f"共 {len(verified)} 个漏洞已确认存在",
|
||||
f"共 {len(false_pos)} 个误报已排除",
|
||||
"建议按严重程度优先修复 Critical 和 High 级别漏洞",
|
||||
],
|
||||
context_data={
|
||||
"verified_count": len(verified),
|
||||
"false_positive_count": len(false_pos),
|
||||
"total_analyzed": len(findings),
|
||||
"verification_rate": len(verified) / len(findings) if findings else 0,
|
||||
},
|
||||
)
|
||||
|
||||
await self.emit_event(
|
||||
"phase_complete",
|
||||
f"✅ 验证完成: {len(verified)} 已确认, {len(false_pos)} 误报"
|
||||
|
|
@ -224,11 +362,14 @@ class VerificationNode(BaseNode):
|
|||
"verified_findings": verified,
|
||||
"false_positives": false_pos,
|
||||
"current_phase": "verification_complete",
|
||||
# 🔥 保存交接信息
|
||||
"verification_handoff": handoff.to_dict(),
|
||||
"events": [{
|
||||
"type": "verification_complete",
|
||||
"data": {
|
||||
"verified_count": len(verified),
|
||||
"false_positive_count": len(false_pos),
|
||||
"handoff_summary": handoff.summary,
|
||||
}
|
||||
}],
|
||||
}
|
||||
|
|
@ -269,6 +410,11 @@ class ReportNode(BaseNode):
|
|||
type_counts = {}
|
||||
|
||||
for finding in findings:
|
||||
# 跳过非字典类型的 finding(防止数据格式异常)
|
||||
if not isinstance(finding, dict):
|
||||
logger.warning(f"Skipping invalid finding (not a dict): {type(finding)}")
|
||||
continue
|
||||
|
||||
sev = finding.get("severity", "medium")
|
||||
severity_counts[sev] = severity_counts.get(sev, 0) + 1
|
||||
|
||||
|
|
@ -300,7 +446,7 @@ class ReportNode(BaseNode):
|
|||
|
||||
await self.emit_event(
|
||||
"phase_complete",
|
||||
f"✅ 报告生成完成: 安全评分 {security_score}/100"
|
||||
f"报告生成完成: 安全评分 {security_score}/100"
|
||||
)
|
||||
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -39,167 +39,8 @@ from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMService:
|
||||
"""
|
||||
LLM 服务封装
|
||||
提供代码分析、漏洞检测等 AI 功能
|
||||
"""
|
||||
|
||||
def __init__(self, model: Optional[str] = None, api_key: Optional[str] = None):
|
||||
self.model = model or settings.LLM_MODEL or "gpt-4o-mini"
|
||||
self.api_key = api_key or settings.LLM_API_KEY
|
||||
self.base_url = settings.LLM_BASE_URL
|
||||
|
||||
async def chat_completion_raw(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 4096,
|
||||
) -> Dict[str, Any]:
|
||||
"""调用 LLM 生成响应"""
|
||||
try:
|
||||
import litellm
|
||||
|
||||
response = await litellm.acompletion(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
)
|
||||
|
||||
return {
|
||||
"content": response.choices[0].message.content,
|
||||
"usage": {
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
} if response.usage else {},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LLM call failed: {e}")
|
||||
raise
|
||||
|
||||
async def analyze_code(self, code: str, language: str) -> Dict[str, Any]:
|
||||
"""
|
||||
分析代码安全问题
|
||||
|
||||
Args:
|
||||
code: 代码内容
|
||||
language: 编程语言
|
||||
|
||||
Returns:
|
||||
分析结果,包含 issues 列表
|
||||
"""
|
||||
prompt = f"""请分析以下 {language} 代码的安全问题。
|
||||
|
||||
代码:
|
||||
```{language}
|
||||
{code[:8000]}
|
||||
```
|
||||
|
||||
请识别所有潜在的安全漏洞,包括但不限于:
|
||||
- SQL 注入
|
||||
- XSS (跨站脚本)
|
||||
- 命令注入
|
||||
- 路径遍历
|
||||
- 不安全的反序列化
|
||||
- 硬编码密钥/密码
|
||||
- 不安全的加密
|
||||
- SSRF
|
||||
- 认证/授权问题
|
||||
|
||||
对于每个发现的问题,请提供:
|
||||
1. 漏洞类型
|
||||
2. 严重程度 (critical/high/medium/low)
|
||||
3. 问题描述
|
||||
4. 具体行号
|
||||
5. 修复建议
|
||||
|
||||
请以 JSON 格式返回结果:
|
||||
{{
|
||||
"issues": [
|
||||
{{
|
||||
"type": "漏洞类型",
|
||||
"severity": "严重程度",
|
||||
"title": "问题标题",
|
||||
"description": "详细描述",
|
||||
"line": 行号,
|
||||
"code_snippet": "相关代码片段",
|
||||
"suggestion": "修复建议"
|
||||
}}
|
||||
],
|
||||
"quality_score": 0-100
|
||||
}}
|
||||
|
||||
如果没有发现安全问题,返回空的 issues 数组和较高的 quality_score。"""
|
||||
|
||||
try:
|
||||
result = await self.chat_completion_raw(
|
||||
messages=[
|
||||
{"role": "system", "content": "你是一位专业的代码安全审计专家,擅长发现代码中的安全漏洞。请只返回 JSON 格式的结果,不要包含其他内容。"},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
temperature=0.1,
|
||||
max_tokens=4096,
|
||||
)
|
||||
|
||||
content = result.get("content", "{}")
|
||||
|
||||
# 尝试提取 JSON
|
||||
import json
|
||||
import re
|
||||
|
||||
# 尝试直接解析
|
||||
try:
|
||||
return json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 尝试从 markdown 代码块提取
|
||||
json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', content)
|
||||
if json_match:
|
||||
try:
|
||||
return json.loads(json_match.group(1))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# 返回空结果
|
||||
return {"issues": [], "quality_score": 80}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Code analysis failed: {e}")
|
||||
return {"issues": [], "quality_score": 0, "error": str(e)}
|
||||
|
||||
async def analyze_code_with_custom_prompt(
|
||||
self,
|
||||
code: str,
|
||||
language: str,
|
||||
prompt: str,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""使用自定义提示词分析代码"""
|
||||
full_prompt = prompt.replace("{code}", code).replace("{language}", language)
|
||||
|
||||
try:
|
||||
result = await self.chat_completion_raw(
|
||||
messages=[
|
||||
{"role": "system", "content": "你是一位专业的代码安全审计专家。"},
|
||||
{"role": "user", "content": full_prompt},
|
||||
],
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
return {
|
||||
"analysis": result.get("content", ""),
|
||||
"usage": result.get("usage", {}),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Custom analysis failed: {e}")
|
||||
return {"analysis": "", "error": str(e)}
|
||||
# 🔥 使用系统统一的 LLMService(支持用户配置)
|
||||
from app.services.llm.service import LLMService
|
||||
|
||||
|
||||
class AgentRunner:
|
||||
|
|
@ -217,18 +58,22 @@ class AgentRunner:
|
|||
db: AsyncSession,
|
||||
task: AgentTask,
|
||||
project_root: str,
|
||||
user_config: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.db = db
|
||||
self.task = task
|
||||
self.project_root = project_root
|
||||
|
||||
# 🔥 保存用户配置,供 RAG 初始化使用
|
||||
self.user_config = user_config or {}
|
||||
|
||||
# 事件管理 - 传入 db_session_factory 以持久化事件
|
||||
from app.db.session import async_session_factory
|
||||
self.event_manager = EventManager(db_session_factory=async_session_factory)
|
||||
self.event_emitter = AgentEventEmitter(task.id, self.event_manager)
|
||||
|
||||
# LLM 服务
|
||||
self.llm_service = LLMService()
|
||||
# 🔥 LLM 服务 - 使用用户配置(从系统配置页面获取)
|
||||
self.llm_service = LLMService(user_config=self.user_config)
|
||||
|
||||
# 工具集
|
||||
self.tools: Dict[str, Any] = {}
|
||||
|
|
@ -248,14 +93,26 @@ class AgentRunner:
|
|||
self._cancelled = False
|
||||
self._running_task: Optional[asyncio.Task] = None
|
||||
|
||||
# Agent 引用(用于取消传播)
|
||||
self._agents: List[Any] = []
|
||||
|
||||
# 流式处理器
|
||||
self.stream_handler = StreamHandler(task.id)
|
||||
|
||||
def cancel(self):
|
||||
"""取消任务"""
|
||||
self._cancelled = True
|
||||
|
||||
# 🔥 取消所有 Agent
|
||||
for agent in self._agents:
|
||||
if hasattr(agent, 'cancel'):
|
||||
agent.cancel()
|
||||
logger.debug(f"Cancelled agent: {agent.name if hasattr(agent, 'name') else 'unknown'}")
|
||||
|
||||
# 取消运行中的任务
|
||||
if self._running_task and not self._running_task.done():
|
||||
self._running_task.cancel()
|
||||
|
||||
logger.info(f"Task {self.task.id} cancellation requested")
|
||||
|
||||
@property
|
||||
|
|
@ -283,11 +140,33 @@ class AgentRunner:
|
|||
await self.event_emitter.emit_info("📚 初始化 RAG 代码检索系统...")
|
||||
|
||||
try:
|
||||
# 🔥 从用户配置中获取 LLM 配置(用于 Embedding API Key)
|
||||
# 优先级:用户配置 > 环境变量
|
||||
user_llm_config = self.user_config.get('llmConfig', {})
|
||||
|
||||
# 获取 Embedding 配置(优先使用用户配置的 LLM API Key)
|
||||
embedding_provider = getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
|
||||
embedding_model = getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
|
||||
|
||||
# 🔥 API Key 优先级:用户配置 > 环境变量
|
||||
embedding_api_key = (
|
||||
user_llm_config.get('llmApiKey') or
|
||||
getattr(settings, 'LLM_API_KEY', '') or
|
||||
''
|
||||
)
|
||||
|
||||
# 🔥 Base URL 优先级:用户配置 > 环境变量
|
||||
embedding_base_url = (
|
||||
user_llm_config.get('llmBaseUrl') or
|
||||
getattr(settings, 'LLM_BASE_URL', None) or
|
||||
None
|
||||
)
|
||||
|
||||
embedding_service = EmbeddingService(
|
||||
provider=settings.EMBEDDING_PROVIDER,
|
||||
model=settings.EMBEDDING_MODEL,
|
||||
api_key=settings.LLM_API_KEY,
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
provider=embedding_provider,
|
||||
model=embedding_model,
|
||||
api_key=embedding_api_key,
|
||||
base_url=embedding_base_url,
|
||||
)
|
||||
|
||||
self.indexer = CodeIndexer(
|
||||
|
|
@ -308,35 +187,59 @@ class AgentRunner:
|
|||
|
||||
async def _initialize_tools(self):
|
||||
"""初始化工具集"""
|
||||
await self.event_emitter.emit_info("🔧 初始化 Agent 工具集...")
|
||||
await self.event_emitter.emit_info("初始化 Agent 工具集...")
|
||||
|
||||
# 文件工具
|
||||
self.tools["read_file"] = FileReadTool(self.project_root)
|
||||
self.tools["search_code"] = FileSearchTool(self.project_root)
|
||||
self.tools["list_files"] = ListFilesTool(self.project_root)
|
||||
# ============ 基础工具(所有 Agent 共享)============
|
||||
base_tools = {
|
||||
"read_file": FileReadTool(self.project_root),
|
||||
"list_files": ListFilesTool(self.project_root),
|
||||
}
|
||||
|
||||
# RAG 工具
|
||||
# ============ Recon Agent 专属工具 ============
|
||||
# 职责:信息收集、项目结构分析、技术栈识别
|
||||
self.recon_tools = {
|
||||
**base_tools,
|
||||
"search_code": FileSearchTool(self.project_root),
|
||||
}
|
||||
|
||||
# RAG 工具(Recon 用于语义搜索)
|
||||
if self.retriever:
|
||||
self.tools["rag_query"] = RAGQueryTool(self.retriever)
|
||||
self.tools["security_search"] = SecurityCodeSearchTool(self.retriever)
|
||||
self.tools["function_context"] = FunctionContextTool(self.retriever)
|
||||
self.recon_tools["rag_query"] = RAGQueryTool(self.retriever)
|
||||
|
||||
# 分析工具
|
||||
self.tools["pattern_match"] = PatternMatchTool(self.project_root)
|
||||
self.tools["code_analysis"] = CodeAnalysisTool(self.llm_service)
|
||||
self.tools["dataflow_analysis"] = DataFlowAnalysisTool(self.llm_service)
|
||||
self.tools["vulnerability_validation"] = VulnerabilityValidationTool(self.llm_service)
|
||||
# ============ Analysis Agent 专属工具 ============
|
||||
# 职责:漏洞分析、代码审计、模式匹配
|
||||
self.analysis_tools = {
|
||||
**base_tools,
|
||||
"search_code": FileSearchTool(self.project_root),
|
||||
# 模式匹配和代码分析
|
||||
"pattern_match": PatternMatchTool(self.project_root),
|
||||
"code_analysis": CodeAnalysisTool(self.llm_service),
|
||||
"dataflow_analysis": DataFlowAnalysisTool(self.llm_service),
|
||||
# 外部静态分析工具
|
||||
"semgrep_scan": SemgrepTool(self.project_root),
|
||||
"bandit_scan": BanditTool(self.project_root),
|
||||
"gitleaks_scan": GitleaksTool(self.project_root),
|
||||
"trufflehog_scan": TruffleHogTool(self.project_root),
|
||||
"npm_audit": NpmAuditTool(self.project_root),
|
||||
"safety_scan": SafetyTool(self.project_root),
|
||||
"osv_scan": OSVScannerTool(self.project_root),
|
||||
}
|
||||
|
||||
# 外部安全工具
|
||||
self.tools["semgrep_scan"] = SemgrepTool(self.project_root)
|
||||
self.tools["bandit_scan"] = BanditTool(self.project_root)
|
||||
self.tools["gitleaks_scan"] = GitleaksTool(self.project_root)
|
||||
self.tools["trufflehog_scan"] = TruffleHogTool(self.project_root)
|
||||
self.tools["npm_audit"] = NpmAuditTool(self.project_root)
|
||||
self.tools["safety_scan"] = SafetyTool(self.project_root)
|
||||
self.tools["osv_scan"] = OSVScannerTool(self.project_root)
|
||||
# RAG 工具(Analysis 用于安全相关代码搜索)
|
||||
if self.retriever:
|
||||
self.analysis_tools["security_search"] = SecurityCodeSearchTool(self.retriever)
|
||||
self.analysis_tools["function_context"] = FunctionContextTool(self.retriever)
|
||||
|
||||
# 沙箱工具
|
||||
# ============ Verification Agent 专属工具 ============
|
||||
# 职责:漏洞验证、PoC 执行、误报排除
|
||||
self.verification_tools = {
|
||||
**base_tools,
|
||||
# 验证工具
|
||||
"vulnerability_validation": VulnerabilityValidationTool(self.llm_service),
|
||||
"dataflow_analysis": DataFlowAnalysisTool(self.llm_service),
|
||||
}
|
||||
|
||||
# 沙箱工具(仅 Verification Agent 可用)
|
||||
try:
|
||||
self.sandbox_manager = SandboxManager(
|
||||
image=settings.SANDBOX_IMAGE,
|
||||
|
|
@ -344,14 +247,20 @@ class AgentRunner:
|
|||
cpu_limit=settings.SANDBOX_CPU_LIMIT,
|
||||
)
|
||||
|
||||
self.tools["sandbox_exec"] = SandboxTool(self.sandbox_manager)
|
||||
self.tools["sandbox_http"] = SandboxHttpTool(self.sandbox_manager)
|
||||
self.tools["verify_vulnerability"] = VulnerabilityVerifyTool(self.sandbox_manager)
|
||||
self.verification_tools["sandbox_exec"] = SandboxTool(self.sandbox_manager)
|
||||
self.verification_tools["sandbox_http"] = SandboxHttpTool(self.sandbox_manager)
|
||||
self.verification_tools["verify_vulnerability"] = VulnerabilityVerifyTool(self.sandbox_manager)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Sandbox initialization failed: {e}")
|
||||
|
||||
await self.event_emitter.emit_info(f"✅ 已加载 {len(self.tools)} 个工具")
|
||||
# 统计总工具数
|
||||
total_tools = len(set(
|
||||
list(self.recon_tools.keys()) +
|
||||
list(self.analysis_tools.keys()) +
|
||||
list(self.verification_tools.keys())
|
||||
))
|
||||
await self.event_emitter.emit_info(f"已加载 {total_tools} 个工具")
|
||||
|
||||
async def _build_graph(self):
|
||||
"""构建 LangGraph 审计图"""
|
||||
|
|
@ -360,25 +269,28 @@ class AgentRunner:
|
|||
# 导入 Agent
|
||||
from app.services.agent.agents import ReconAgent, AnalysisAgent, VerificationAgent
|
||||
|
||||
# 创建 Agent 实例
|
||||
# 创建 Agent 实例(每个 Agent 使用专属工具集)
|
||||
recon_agent = ReconAgent(
|
||||
llm_service=self.llm_service,
|
||||
tools=self.tools,
|
||||
tools=self.recon_tools, # Recon 专属工具
|
||||
event_emitter=self.event_emitter,
|
||||
)
|
||||
|
||||
analysis_agent = AnalysisAgent(
|
||||
llm_service=self.llm_service,
|
||||
tools=self.tools,
|
||||
tools=self.analysis_tools, # Analysis 专属工具
|
||||
event_emitter=self.event_emitter,
|
||||
)
|
||||
|
||||
verification_agent = VerificationAgent(
|
||||
llm_service=self.llm_service,
|
||||
tools=self.tools,
|
||||
tools=self.verification_tools, # Verification 专属工具
|
||||
event_emitter=self.event_emitter,
|
||||
)
|
||||
|
||||
# 🔥 保存 Agent 引用以便取消时传播信号
|
||||
self._agents = [recon_agent, analysis_agent, verification_agent]
|
||||
|
||||
# 创建节点
|
||||
recon_node = ReconNode(recon_agent, self.event_emitter)
|
||||
analysis_node = AnalysisNode(analysis_agent, self.event_emitter)
|
||||
|
|
@ -481,6 +393,10 @@ class AgentRunner:
|
|||
"iteration": 0,
|
||||
"max_iterations": self.task.max_iterations or 50,
|
||||
"should_continue_analysis": False,
|
||||
# 🔥 Agent 协作交接信息
|
||||
"recon_handoff": None,
|
||||
"analysis_handoff": None,
|
||||
"verification_handoff": None,
|
||||
"messages": [],
|
||||
"events": [],
|
||||
"summary": None,
|
||||
|
|
@ -556,6 +472,33 @@ class AgentRunner:
|
|||
graph_state = self.graph.get_state(run_config)
|
||||
final_state = graph_state.values if graph_state else {}
|
||||
|
||||
# 🔥 检查是否有错误
|
||||
error = final_state.get("error")
|
||||
if error:
|
||||
# 检查是否是 LLM 认证错误
|
||||
error_str = str(error)
|
||||
if "AuthenticationError" in error_str or "API key" in error_str or "invalid_api_key" in error_str:
|
||||
error_message = "LLM API 密钥配置错误。请检查环境变量 LLM_API_KEY 或配置中的 API 密钥是否正确。"
|
||||
logger.error(f"LLM authentication error: {error}")
|
||||
else:
|
||||
error_message = error_str
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 标记任务为失败
|
||||
await self._update_task_status(AgentTaskStatus.FAILED, error_message)
|
||||
await self.event_emitter.emit_task_error(error_message)
|
||||
|
||||
yield StreamEvent(
|
||||
event_type=StreamEventType.TASK_ERROR,
|
||||
sequence=self.stream_handler._next_sequence(),
|
||||
data={
|
||||
"error": error_message,
|
||||
"message": f"❌ 任务失败: {error_message}",
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
# 6. 保存发现
|
||||
findings = final_state.get("findings", [])
|
||||
await self._save_findings(findings)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,251 @@
|
|||
"""
|
||||
Agent JSON 解析工具
|
||||
从 LLM 响应中安全地解析 JSON,参考 llm/service.py 的实现
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 尝试导入 json-repair 库
|
||||
try:
|
||||
from json_repair import repair_json
|
||||
JSON_REPAIR_AVAILABLE = True
|
||||
except ImportError:
|
||||
JSON_REPAIR_AVAILABLE = False
|
||||
logger.debug("json-repair library not available")
|
||||
|
||||
|
||||
class AgentJsonParser:
|
||||
"""Agent 专用的 JSON 解析器"""
|
||||
|
||||
@staticmethod
|
||||
def clean_text(text: str) -> str:
|
||||
"""清理文本中的控制字符"""
|
||||
if not text:
|
||||
return ""
|
||||
# 移除 BOM 和零宽字符
|
||||
text = text.replace('\ufeff', '').replace('\u200b', '').replace('\u200c', '').replace('\u200d', '')
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def fix_json_format(text: str) -> str:
|
||||
"""修复常见的 JSON 格式问题"""
|
||||
text = text.strip()
|
||||
# 移除尾部逗号
|
||||
text = re.sub(r',(\s*[}\]])', r'\1', text)
|
||||
# 修复未转义的换行符(在字符串值中)
|
||||
text = re.sub(r':\s*"([^"]*)\n([^"]*)"', r': "\1\\n\2"', text)
|
||||
return text
|
||||
|
||||
@classmethod
|
||||
def extract_from_markdown(cls, text: str) -> Dict[str, Any]:
|
||||
"""从 markdown 代码块提取 JSON"""
|
||||
match = re.search(r'```(?:json)?\s*(\{[\s\S]*?\})\s*```', text)
|
||||
if match:
|
||||
return json.loads(match.group(1))
|
||||
raise ValueError("No markdown code block found")
|
||||
|
||||
@classmethod
|
||||
def extract_json_object(cls, text: str) -> Dict[str, Any]:
|
||||
"""智能提取 JSON 对象"""
|
||||
start_idx = text.find('{')
|
||||
if start_idx == -1:
|
||||
raise ValueError("No JSON object found")
|
||||
|
||||
# 考虑字符串内的花括号和转义字符
|
||||
brace_count = 0
|
||||
in_string = False
|
||||
escape_next = False
|
||||
end_idx = -1
|
||||
|
||||
for i in range(start_idx, len(text)):
|
||||
char = text[i]
|
||||
|
||||
if escape_next:
|
||||
escape_next = False
|
||||
continue
|
||||
|
||||
if char == '\\':
|
||||
escape_next = True
|
||||
continue
|
||||
|
||||
if char == '"' and not escape_next:
|
||||
in_string = not in_string
|
||||
continue
|
||||
|
||||
if not in_string:
|
||||
if char == '{':
|
||||
brace_count += 1
|
||||
elif char == '}':
|
||||
brace_count -= 1
|
||||
if brace_count == 0:
|
||||
end_idx = i + 1
|
||||
break
|
||||
|
||||
if end_idx == -1:
|
||||
# 如果找不到完整的 JSON,尝试使用最后一个 }
|
||||
last_brace = text.rfind('}')
|
||||
if last_brace > start_idx:
|
||||
end_idx = last_brace + 1
|
||||
else:
|
||||
raise ValueError("Incomplete JSON object")
|
||||
|
||||
json_str = text[start_idx:end_idx]
|
||||
# 修复格式问题
|
||||
json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
|
||||
|
||||
return json.loads(json_str)
|
||||
|
||||
@classmethod
|
||||
def fix_truncated_json(cls, text: str) -> Dict[str, Any]:
|
||||
"""修复截断的 JSON"""
|
||||
start_idx = text.find('{')
|
||||
if start_idx == -1:
|
||||
raise ValueError("Cannot fix truncated JSON")
|
||||
|
||||
json_str = text[start_idx:]
|
||||
|
||||
# 计算缺失的闭合符号
|
||||
open_braces = json_str.count('{')
|
||||
close_braces = json_str.count('}')
|
||||
open_brackets = json_str.count('[')
|
||||
close_brackets = json_str.count(']')
|
||||
|
||||
# 补全缺失的闭合符号
|
||||
json_str += ']' * max(0, open_brackets - close_brackets)
|
||||
json_str += '}' * max(0, open_braces - close_braces)
|
||||
|
||||
# 修复格式
|
||||
json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
|
||||
return json.loads(json_str)
|
||||
|
||||
@classmethod
|
||||
def repair_with_library(cls, text: str) -> Dict[str, Any]:
|
||||
"""使用 json-repair 库修复损坏的 JSON"""
|
||||
if not JSON_REPAIR_AVAILABLE:
|
||||
raise ValueError("json-repair library not available")
|
||||
|
||||
start_idx = text.find('{')
|
||||
if start_idx == -1:
|
||||
raise ValueError("No JSON object found for repair")
|
||||
|
||||
end_idx = text.rfind('}')
|
||||
if end_idx > start_idx:
|
||||
json_str = text[start_idx:end_idx + 1]
|
||||
else:
|
||||
json_str = text[start_idx:]
|
||||
|
||||
repaired = repair_json(json_str, return_objects=True)
|
||||
|
||||
if isinstance(repaired, dict):
|
||||
return repaired
|
||||
|
||||
raise ValueError(f"json-repair returned unexpected type: {type(repaired)}")
|
||||
|
||||
@classmethod
|
||||
def parse(cls, text: str, default: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
从 LLM 响应中解析 JSON(增强版)
|
||||
|
||||
Args:
|
||||
text: LLM 响应文本
|
||||
default: 解析失败时返回的默认值,如果为 None 则抛出异常
|
||||
|
||||
Returns:
|
||||
解析后的字典
|
||||
"""
|
||||
if not text or not text.strip():
|
||||
if default is not None:
|
||||
logger.warning("LLM 响应为空,返回默认值")
|
||||
return default
|
||||
raise ValueError("LLM 响应内容为空")
|
||||
|
||||
clean = cls.clean_text(text)
|
||||
|
||||
# 尝试多种方式解析
|
||||
attempts = [
|
||||
("直接解析", lambda: json.loads(text)),
|
||||
("清理后解析", lambda: json.loads(cls.fix_json_format(clean))),
|
||||
("Markdown 提取", lambda: cls.extract_from_markdown(text)),
|
||||
("智能提取", lambda: cls.extract_json_object(clean)),
|
||||
("截断修复", lambda: cls.fix_truncated_json(clean)),
|
||||
("json-repair", lambda: cls.repair_with_library(text)),
|
||||
]
|
||||
|
||||
last_error = None
|
||||
for name, attempt in attempts:
|
||||
try:
|
||||
result = attempt()
|
||||
if result and isinstance(result, dict):
|
||||
if name != "直接解析":
|
||||
logger.debug(f"✅ JSON 解析成功(方法: {name})")
|
||||
return result
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
logger.debug(f"JSON 解析方法 '{name}' 失败: {e}")
|
||||
|
||||
# 所有尝试都失败
|
||||
if default is not None:
|
||||
logger.warning(f"JSON 解析失败,返回默认值。原始内容: {text[:200]}...")
|
||||
return default
|
||||
|
||||
logger.error(f"❌ 无法解析 JSON,原始内容: {text[:500]}...")
|
||||
raise ValueError(f"无法解析 JSON: {last_error}")
|
||||
|
||||
@classmethod
|
||||
def parse_findings(cls, text: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
专门解析 findings 列表
|
||||
|
||||
Args:
|
||||
text: LLM 响应文本
|
||||
|
||||
Returns:
|
||||
findings 列表(每个元素都是字典)
|
||||
"""
|
||||
try:
|
||||
result = cls.parse(text, default={"findings": []})
|
||||
findings = result.get("findings", [])
|
||||
|
||||
# 确保每个 finding 都是字典
|
||||
valid_findings = []
|
||||
for f in findings:
|
||||
if isinstance(f, dict):
|
||||
valid_findings.append(f)
|
||||
elif isinstance(f, str):
|
||||
# 尝试将字符串解析为 JSON
|
||||
try:
|
||||
parsed = json.loads(f)
|
||||
if isinstance(parsed, dict):
|
||||
valid_findings.append(parsed)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"跳过无效的 finding(字符串): {f[:100]}...")
|
||||
else:
|
||||
logger.warning(f"跳过无效的 finding(类型: {type(f)})")
|
||||
|
||||
return valid_findings
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析 findings 失败: {e}")
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def safe_get(cls, data: Union[Dict, str, Any], key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
安全地从数据中获取值
|
||||
|
||||
Args:
|
||||
data: 可能是字典或其他类型
|
||||
key: 要获取的键
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
获取的值或默认值
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
return data.get(key, default)
|
||||
return default
|
||||
|
|
@ -79,9 +79,25 @@ class CodeAnalysisTool(AgentTool):
|
|||
**kwargs
|
||||
) -> ToolResult:
|
||||
"""执行代码分析"""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
# 构建分析结果
|
||||
analysis = await self.llm_service.analyze_code(code, language)
|
||||
# 限制代码长度,避免超时
|
||||
max_code_length = 50000 # 约 50KB
|
||||
if len(code) > max_code_length:
|
||||
code = code[:max_code_length] + "\n\n... (代码已截断,仅分析前 50000 字符)"
|
||||
|
||||
# 添加超时保护(5分钟)
|
||||
try:
|
||||
analysis = await asyncio.wait_for(
|
||||
self.llm_service.analyze_code(code, language),
|
||||
timeout=300.0 # 5分钟超时
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error="代码分析超时(超过5分钟)。代码可能过长或过于复杂,请尝试分析较小的代码片段。",
|
||||
)
|
||||
|
||||
issues = analysis.get("issues", [])
|
||||
|
||||
|
|
|
|||
|
|
@ -109,9 +109,13 @@ Semgrep 是业界领先的静态分析工具,支持 30+ 种编程语言。
|
|||
"""执行 Semgrep 扫描"""
|
||||
# 检查 semgrep 是否可用
|
||||
if not await self._check_semgrep():
|
||||
# 尝试自动安装
|
||||
logger.info("Semgrep 未安装,尝试自动安装...")
|
||||
install_success = await self._try_install_semgrep()
|
||||
if not install_success:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error="Semgrep 未安装。请使用 'pip install semgrep' 安装。",
|
||||
error="Semgrep 未安装。请使用 'pip install semgrep' 安装,或联系管理员安装。",
|
||||
)
|
||||
|
||||
# 构建完整路径
|
||||
|
|
@ -217,6 +221,30 @@ Semgrep 是业界领先的静态分析工具,支持 30+ 种编程语言。
|
|||
except:
|
||||
return False
|
||||
|
||||
async def _try_install_semgrep(self) -> bool:
|
||||
"""尝试自动安装 Semgrep"""
|
||||
try:
|
||||
logger.info("正在安装 Semgrep...")
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"pip", "install", "semgrep",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=120)
|
||||
if proc.returncode == 0:
|
||||
logger.info("Semgrep 安装成功")
|
||||
# 验证安装
|
||||
return await self._check_semgrep()
|
||||
else:
|
||||
logger.warning(f"Semgrep 安装失败: {stderr.decode()[:200]}")
|
||||
return False
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Semgrep 安装超时")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"Semgrep 安装出错: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# ============ Bandit 工具 (Python) ============
|
||||
|
||||
|
|
@ -422,7 +450,11 @@ Gitleaks 是专业的密钥检测工具,支持 150+ 种密钥类型。
|
|||
if not await self._check_gitleaks():
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error="Gitleaks 未安装。请从 https://github.com/gitleaks/gitleaks 安装。",
|
||||
error="Gitleaks 未安装。Gitleaks 需要手动安装,请参考: https://github.com/gitleaks/gitleaks/releases\n"
|
||||
"安装方法:\n"
|
||||
"- macOS: brew install gitleaks\n"
|
||||
"- Linux: 下载二进制文件并添加到 PATH\n"
|
||||
"- Windows: 下载二进制文件并添加到 PATH",
|
||||
)
|
||||
|
||||
full_path = os.path.normpath(os.path.join(self.project_root, target_path))
|
||||
|
|
|
|||
|
|
@ -291,8 +291,19 @@ class PatternMatchTool(AgentTool):
|
|||
return f"""快速扫描代码中的危险模式和常见漏洞。
|
||||
使用正则表达式检测已知的不安全代码模式。
|
||||
|
||||
⚠️ 重要:此工具需要代码内容作为输入,不是目录路径!
|
||||
使用步骤:
|
||||
1. 先用 read_file 工具读取文件内容
|
||||
2. 然后将读取的代码内容传递给此工具的 code 参数
|
||||
|
||||
支持的漏洞类型: {vuln_types}
|
||||
|
||||
输入参数:
|
||||
- code (必需): 要扫描的代码内容(字符串)
|
||||
- file_path (可选): 文件路径,用于上下文
|
||||
- pattern_types (可选): 要检测的漏洞类型列表,如 ['sql_injection', 'xss']
|
||||
- language (可选): 编程语言,如 'python', 'php', 'javascript'
|
||||
|
||||
这是一个快速扫描工具,可以在分析开始时使用来快速发现潜在问题。
|
||||
发现的问题需要进一步分析确认。"""
|
||||
|
||||
|
|
|
|||
|
|
@ -189,9 +189,26 @@ class SecurityCodeSearchTool(AgentTool):
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# 提供更友好的错误信息
|
||||
if "401" in error_msg or "Unauthorized" in error_msg:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=f"安全代码搜索失败: {str(e)}",
|
||||
error=f"安全代码搜索失败: API 认证失败(401 Unauthorized)。\n"
|
||||
f"请检查系统配置中的 LLM API Key 是否正确设置。\n"
|
||||
f"错误详情: {error_msg[:200]}",
|
||||
)
|
||||
elif "403" in error_msg or "Forbidden" in error_msg:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=f"安全代码搜索失败: API 访问被拒绝(403 Forbidden)。\n"
|
||||
f"请检查 API Key 是否有足够的权限。\n"
|
||||
f"错误详情: {error_msg[:200]}",
|
||||
)
|
||||
else:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=f"安全代码搜索失败: {error_msg[:500]}",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -177,6 +177,85 @@ class LiteLLMAdapter(BaseLLMAdapter):
|
|||
finish_reason=choice.finish_reason,
|
||||
)
|
||||
|
||||
async def stream_complete(self, request: LLMRequest):
|
||||
"""
|
||||
流式调用 LLM,逐 token 返回
|
||||
|
||||
Yields:
|
||||
dict: {"type": "token", "content": str} 或 {"type": "done", "content": str, "usage": dict}
|
||||
"""
|
||||
import litellm
|
||||
|
||||
await self.validate_config()
|
||||
|
||||
litellm.cache = None
|
||||
litellm.drop_params = True
|
||||
|
||||
messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
|
||||
|
||||
kwargs = {
|
||||
"model": self._litellm_model,
|
||||
"messages": messages,
|
||||
"temperature": request.temperature if request.temperature is not None else self.config.temperature,
|
||||
"max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
|
||||
"top_p": request.top_p if request.top_p is not None else self.config.top_p,
|
||||
"stream": True, # 启用流式输出
|
||||
}
|
||||
|
||||
if self.config.api_key and self.config.api_key != "ollama":
|
||||
kwargs["api_key"] = self.config.api_key
|
||||
|
||||
if self._api_base:
|
||||
kwargs["api_base"] = self._api_base
|
||||
|
||||
kwargs["timeout"] = self.config.timeout
|
||||
|
||||
accumulated_content = ""
|
||||
|
||||
try:
|
||||
response = await litellm.acompletion(**kwargs)
|
||||
|
||||
async for chunk in response:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0].delta
|
||||
content = getattr(delta, "content", "") or ""
|
||||
finish_reason = chunk.choices[0].finish_reason
|
||||
|
||||
if content:
|
||||
accumulated_content += content
|
||||
yield {
|
||||
"type": "token",
|
||||
"content": content,
|
||||
"accumulated": accumulated_content,
|
||||
}
|
||||
|
||||
if finish_reason:
|
||||
# 流式完成
|
||||
usage = None
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage = {
|
||||
"prompt_tokens": chunk.usage.prompt_tokens or 0,
|
||||
"completion_tokens": chunk.usage.completion_tokens or 0,
|
||||
"total_tokens": chunk.usage.total_tokens or 0,
|
||||
}
|
||||
|
||||
yield {
|
||||
"type": "done",
|
||||
"content": accumulated_content,
|
||||
"usage": usage,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
yield {
|
||||
"type": "error",
|
||||
"error": str(e),
|
||||
"accumulated": accumulated_content,
|
||||
}
|
||||
|
||||
async def validate_config(self) -> bool:
|
||||
"""验证配置"""
|
||||
# Ollama 不需要 API Key
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ LLM服务 - 代码分析核心服务
|
|||
import json
|
||||
import re
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Dict, Any, Optional, List
|
||||
from .types import LLMConfig, LLMProvider, LLMMessage, LLMRequest, DEFAULT_MODELS
|
||||
from .factory import LLMFactory
|
||||
from app.core.config import settings
|
||||
|
|
@ -36,15 +36,23 @@ class LLMService:
|
|||
|
||||
@property
|
||||
def config(self) -> LLMConfig:
|
||||
"""获取LLM配置(优先使用用户配置,然后使用系统配置)"""
|
||||
"""
|
||||
获取LLM配置
|
||||
|
||||
🔥 优先级(从高到低):
|
||||
1. 数据库用户配置(系统配置页面保存的配置)
|
||||
2. 环境变量配置(.env 文件中的配置)
|
||||
|
||||
如果用户配置中某个字段为空,则自动回退到环境变量。
|
||||
"""
|
||||
if self._config is None:
|
||||
user_llm_config = self._user_config.get('llmConfig', {})
|
||||
|
||||
# 优先使用用户配置的provider,否则使用系统配置
|
||||
# 🔥 Provider 优先级:用户配置 > 环境变量
|
||||
provider_str = user_llm_config.get('llmProvider') or getattr(settings, 'LLM_PROVIDER', 'openai')
|
||||
provider = self._parse_provider(provider_str)
|
||||
|
||||
# 获取API Key - 优先级:用户配置 > 系统通用配置 > 系统平台专属配置
|
||||
# 🔥 API Key 优先级:用户配置 > 环境变量通用配置 > 环境变量平台专属配置
|
||||
api_key = (
|
||||
user_llm_config.get('llmApiKey') or
|
||||
getattr(settings, 'LLM_API_KEY', '') or
|
||||
|
|
@ -52,33 +60,33 @@ class LLMService:
|
|||
self._get_provider_api_key(provider)
|
||||
)
|
||||
|
||||
# 获取Base URL
|
||||
# 🔥 Base URL 优先级:用户配置 > 环境变量
|
||||
base_url = (
|
||||
user_llm_config.get('llmBaseUrl') or
|
||||
getattr(settings, 'LLM_BASE_URL', None) or
|
||||
self._get_provider_base_url(provider)
|
||||
)
|
||||
|
||||
# 获取模型
|
||||
# 🔥 Model 优先级:用户配置 > 环境变量 > 默认模型
|
||||
model = (
|
||||
user_llm_config.get('llmModel') or
|
||||
getattr(settings, 'LLM_MODEL', '') or
|
||||
DEFAULT_MODELS.get(provider, 'gpt-4o-mini')
|
||||
)
|
||||
|
||||
# 获取超时时间(用户配置是毫秒,系统配置是秒)
|
||||
# 🔥 Timeout 优先级:用户配置(毫秒) > 环境变量(秒)
|
||||
timeout_ms = user_llm_config.get('llmTimeout')
|
||||
if timeout_ms:
|
||||
# 用户配置是毫秒,转换为秒
|
||||
timeout = int(timeout_ms / 1000) if timeout_ms > 1000 else int(timeout_ms)
|
||||
else:
|
||||
# 系统配置是秒
|
||||
# 环境变量是秒
|
||||
timeout = int(getattr(settings, 'LLM_TIMEOUT', 150))
|
||||
|
||||
# 获取温度
|
||||
# 🔥 Temperature 优先级:用户配置 > 环境变量
|
||||
temperature = user_llm_config.get('llmTemperature') if user_llm_config.get('llmTemperature') is not None else float(getattr(settings, 'LLM_TEMPERATURE', 0.1))
|
||||
|
||||
# 获取最大token数
|
||||
# 🔥 Max Tokens 优先级:用户配置 > 环境变量
|
||||
max_tokens = user_llm_config.get('llmMaxTokens') or int(getattr(settings, 'LLM_MAX_TOKENS', 4096))
|
||||
|
||||
self._config = LLMConfig(
|
||||
|
|
@ -394,6 +402,83 @@ Please analyze the following code:
|
|||
# 重新抛出异常,让调用者处理
|
||||
raise
|
||||
|
||||
async def chat_completion_raw(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 4096,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
🔥 Agent 使用的原始聊天完成接口(兼容旧接口)
|
||||
|
||||
Args:
|
||||
messages: 消息列表,格式为 [{"role": "user", "content": "..."}]
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
|
||||
Returns:
|
||||
包含 content 和 usage 的字典
|
||||
"""
|
||||
# 转换消息格式
|
||||
llm_messages = [
|
||||
LLMMessage(role=msg["role"], content=msg["content"])
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
request = LLMRequest(
|
||||
messages=llm_messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
adapter = LLMFactory.create_adapter(self.config)
|
||||
response = await adapter.complete(request)
|
||||
|
||||
return {
|
||||
"content": response.content,
|
||||
"usage": {
|
||||
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
|
||||
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
|
||||
"total_tokens": response.usage.total_tokens if response.usage else 0,
|
||||
},
|
||||
}
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 4096,
|
||||
):
|
||||
"""
|
||||
流式聊天完成接口,逐 token 返回
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
|
||||
Yields:
|
||||
dict: {"type": "token", "content": str} 或 {"type": "done", ...}
|
||||
"""
|
||||
from .adapters.litellm_adapter import LiteLLMAdapter
|
||||
|
||||
llm_messages = [
|
||||
LLMMessage(role=msg["role"], content=msg["content"])
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
request = LLMRequest(
|
||||
messages=llm_messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
# 使用 LiteLLM adapter 进行流式调用
|
||||
adapter = LiteLLMAdapter(self.config)
|
||||
|
||||
async for chunk in adapter.stream_complete(request):
|
||||
yield chunk
|
||||
|
||||
def _parse_json(self, text: str) -> Dict[str, Any]:
|
||||
"""从LLM响应中解析JSON(增强版)"""
|
||||
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -28,14 +28,12 @@ import {
|
|||
cancelAgentTask,
|
||||
} from "@/shared/api/agentTasks";
|
||||
|
||||
// 事件类型图标映射 - 🔥 重点展示 LLM 相关事件
|
||||
// 事件类型图标映射
|
||||
const eventTypeIcons: Record<string, React.ReactNode> = {
|
||||
// 🧠 LLM 核心事件 - 最重要!
|
||||
// LLM 核心事件
|
||||
llm_start: <Brain className="w-3 h-3 text-purple-400 animate-pulse" />,
|
||||
llm_thought: <Sparkles className="w-3 h-3 text-purple-300" />,
|
||||
llm_decision: <Zap className="w-3 h-3 text-yellow-400" />,
|
||||
llm_action: <Zap className="w-3 h-3 text-orange-400" />,
|
||||
llm_observation: <Search className="w-3 h-3 text-blue-400" />,
|
||||
llm_complete: <CheckCircle2 className="w-3 h-3 text-green-400" />,
|
||||
|
||||
// 阶段相关
|
||||
|
|
@ -43,7 +41,7 @@ const eventTypeIcons: Record<string, React.ReactNode> = {
|
|||
phase_complete: <CheckCircle2 className="w-3 h-3 text-green-400" />,
|
||||
thinking: <Brain className="w-3 h-3 text-purple-400" />,
|
||||
|
||||
// 工具相关 - LLM 决定的工具调用
|
||||
// 工具相关
|
||||
tool_call: <Wrench className="w-3 h-3 text-yellow-400" />,
|
||||
tool_result: <CheckCircle2 className="w-3 h-3 text-green-400" />,
|
||||
tool_error: <XCircle className="w-3 h-3 text-red-400" />,
|
||||
|
|
@ -65,14 +63,12 @@ const eventTypeIcons: Record<string, React.ReactNode> = {
|
|||
task_cancel: <Square className="w-3 h-3 text-yellow-500" />,
|
||||
};
|
||||
|
||||
// 事件类型颜色映射 - 🔥 LLM 事件突出显示
|
||||
// 事件类型颜色映射
|
||||
const eventTypeColors: Record<string, string> = {
|
||||
// 🧠 LLM 核心事件 - 使用紫色系突出
|
||||
// LLM 核心事件
|
||||
llm_start: "text-purple-400 font-semibold",
|
||||
llm_thought: "text-purple-300 bg-purple-950/30 rounded px-1", // 思考内容特别高亮
|
||||
llm_decision: "text-yellow-300 font-semibold", // 决策特别突出
|
||||
llm_action: "text-orange-300 font-medium",
|
||||
llm_observation: "text-blue-300",
|
||||
llm_thought: "text-purple-300 bg-purple-950/30 rounded px-1",
|
||||
llm_decision: "text-yellow-300 font-semibold",
|
||||
llm_complete: "text-green-400 font-semibold",
|
||||
|
||||
// 阶段相关
|
||||
|
|
@ -411,7 +407,7 @@ export default function AgentAuditPage() {
|
|||
{/* 左侧:执行日志 */}
|
||||
<div className="flex-1 p-4 flex flex-col min-w-0">
|
||||
|
||||
{/* 🧠 LLM 思考过程展示区域 - 核心!展示 LLM 的大脑活动 */}
|
||||
{/* LLM 思考过程展示区域 */}
|
||||
{(isThinking || thinking) && showThinking && (
|
||||
<div className="mb-4 bg-purple-950/40 rounded-lg border-2 border-purple-700/60 overflow-hidden shadow-lg shadow-purple-900/20">
|
||||
<div
|
||||
|
|
@ -423,8 +419,8 @@ export default function AgentAuditPage() {
|
|||
<Brain className={`w-5 h-5 ${isThinking ? "animate-pulse" : ""}`} />
|
||||
</div>
|
||||
<div>
|
||||
<span className="uppercase tracking-wider font-semibold">🧠 LLM Thinking</span>
|
||||
<span className="text-purple-400 ml-2 text-xs">Agent 的大脑正在工作</span>
|
||||
<span className="uppercase tracking-wider font-semibold">LLM Thinking</span>
|
||||
<span className="text-purple-400 ml-2 text-xs">Agent 思考过程</span>
|
||||
</div>
|
||||
{isThinking && (
|
||||
<span className="flex items-center gap-1 text-purple-200 bg-purple-800/50 px-2 py-0.5 rounded-full text-xs">
|
||||
|
|
@ -438,7 +434,7 @@ export default function AgentAuditPage() {
|
|||
|
||||
<div className="max-h-52 overflow-y-auto bg-[#1a1025]">
|
||||
<div className="p-4 text-sm text-purple-100 font-mono whitespace-pre-wrap leading-relaxed">
|
||||
{thinking || "🤔 正在思考下一步..."}
|
||||
{thinking || "正在思考下一步..."}
|
||||
{isThinking && <span className="animate-pulse text-purple-400 text-lg">▌</span>}
|
||||
</div>
|
||||
<div ref={thinkingEndRef} />
|
||||
|
|
@ -446,7 +442,7 @@ export default function AgentAuditPage() {
|
|||
</div>
|
||||
)}
|
||||
|
||||
{/* 🔧 LLM 工具调用展示区域 - LLM 决定调用的工具 */}
|
||||
{/* 工具调用展示区域 */}
|
||||
{toolCalls.length > 0 && showToolDetails && (
|
||||
<div className="mb-4 bg-yellow-950/30 rounded-lg border-2 border-yellow-700/50 overflow-hidden shadow-lg shadow-yellow-900/10">
|
||||
<div
|
||||
|
|
@ -458,8 +454,8 @@ export default function AgentAuditPage() {
|
|||
<Wrench className="w-5 h-5" />
|
||||
</div>
|
||||
<div>
|
||||
<span className="uppercase tracking-wider font-semibold">🔧 LLM Tool Calls</span>
|
||||
<span className="text-yellow-500 ml-2 text-xs">LLM 决定调用的工具</span>
|
||||
<span className="uppercase tracking-wider font-semibold">Tool Calls</span>
|
||||
<span className="text-yellow-500 ml-2 text-xs">工具调用记录</span>
|
||||
</div>
|
||||
<Badge variant="outline" className="text-xs px-2 py-0.5 bg-yellow-900/50 border-yellow-600 text-yellow-300">
|
||||
{toolCalls.length} 次调用
|
||||
|
|
@ -488,9 +484,9 @@ export default function AgentAuditPage() {
|
|||
<div className="flex items-center gap-3 text-sm text-cyan-400">
|
||||
<div className="flex items-center gap-2">
|
||||
<Terminal className="w-4 h-4" />
|
||||
<span className="uppercase tracking-wider font-semibold">LLM Execution Log</span>
|
||||
<span className="uppercase tracking-wider font-semibold">Execution Log</span>
|
||||
</div>
|
||||
<span className="text-xs text-gray-500">LLM 思考 & 工具调用记录</span>
|
||||
<span className="text-xs text-gray-500">执行日志</span>
|
||||
{(isStreaming || isStreamConnected) && (
|
||||
<span className="flex items-center gap-1.5 text-green-400 bg-green-900/30 px-2 py-0.5 rounded-full text-xs">
|
||||
<span className="w-2 h-2 bg-green-400 rounded-full animate-pulse" />
|
||||
|
|
@ -708,7 +704,7 @@ function StatusBadge({ status }: { status: string }) {
|
|||
);
|
||||
}
|
||||
|
||||
// 事件行组件 - 增强 LLM 事件展示
|
||||
// 事件行组件
|
||||
function EventLine({ event }: { event: AgentEvent }) {
|
||||
const icon = eventTypeIcons[event.event_type] || <ChevronRight className="w-3 h-3 text-gray-500" />;
|
||||
const colorClass = eventTypeColors[event.event_type] || "text-gray-400";
|
||||
|
|
@ -717,19 +713,19 @@ function EventLine({ event }: { event: AgentEvent }) {
|
|||
? new Date(event.timestamp).toLocaleTimeString("zh-CN", { hour12: false })
|
||||
: "";
|
||||
|
||||
// LLM 思考事件特殊处理 - 展示多行内容
|
||||
// 特殊事件处理
|
||||
const isLLMThought = event.event_type === "llm_thought";
|
||||
const isLLMDecision = event.event_type === "llm_decision";
|
||||
const isLLMAction = event.event_type === "llm_action";
|
||||
const isImportantLLMEvent = isLLMThought || isLLMDecision || isLLMAction;
|
||||
const isToolCall = event.event_type === "tool_call";
|
||||
const isToolResult = event.event_type === "tool_result";
|
||||
|
||||
// LLM 事件背景色
|
||||
// 背景色
|
||||
const bgClass = isLLMThought
|
||||
? "bg-purple-950/40 border-l-2 border-purple-600"
|
||||
: isLLMDecision
|
||||
? "bg-yellow-950/30 border-l-2 border-yellow-600"
|
||||
: isLLMAction
|
||||
? "bg-orange-950/30 border-l-2 border-orange-600"
|
||||
: isToolCall || isToolResult
|
||||
? "bg-gray-900/30"
|
||||
: "";
|
||||
|
||||
return (
|
||||
|
|
@ -738,7 +734,7 @@ function EventLine({ event }: { event: AgentEvent }) {
|
|||
{timestamp}
|
||||
</span>
|
||||
<span className="flex-shrink-0 mt-0.5">{icon}</span>
|
||||
<span className={`flex-1 text-sm break-all ${isImportantLLMEvent ? "whitespace-pre-wrap" : ""}`}>
|
||||
<span className={`flex-1 text-sm break-all ${isLLMThought ? "whitespace-pre-wrap" : ""}`}>
|
||||
{event.message}
|
||||
{event.tool_duration_ms && (
|
||||
<span className="text-gray-600 ml-2">({event.tool_duration_ms}ms)</span>
|
||||
|
|
|
|||
Loading…
Reference in New Issue