feat: Introduce structured agent collaboration with `TaskHandoff` and `analysis_v2` agent, updating core agent logic, tools, and audit UI.

This commit is contained in:
lintsinghua 2025-12-11 23:29:04 +08:00
parent 8938a8a3c9
commit 70776ee5fd
25 changed files with 1657 additions and 599 deletions

View File

@ -28,6 +28,7 @@ from app.models.agent_task import (
)
from app.models.project import Project
from app.models.user import User
from app.models.user_config import UserConfig
from app.services.agent import AgentRunner, EventManager, run_agent_task
from app.services.agent.streaming import StreamHandler, StreamEvent, StreamEventType
@ -199,7 +200,7 @@ class TaskSummaryResponse(BaseModel):
# ============ 后台任务执行 ============
async def _execute_agent_task(task_id: str, project_root: str):
async def _execute_agent_task(task_id: str):
"""在后台执行 Agent 任务"""
async with async_session_factory() as db:
try:
@ -209,14 +210,57 @@ async def _execute_agent_task(task_id: str, project_root: str):
logger.error(f"Task {task_id} not found")
return
# 获取项目
project = task.project
if not project:
logger.error(f"Project not found for task {task_id}")
return
# 🔥 获取项目根目录(解压 ZIP 或克隆仓库)
project_root = await _get_project_root(project, task_id)
# 🔥 获取用户配置(从系统配置页面)
# 优先级1. 数据库用户配置 > 2. 环境变量配置
user_config = None
if task.created_by:
from app.api.v1.endpoints.config import (
decrypt_config,
SENSITIVE_LLM_FIELDS, SENSITIVE_OTHER_FIELDS
)
import json
result = await db.execute(
select(UserConfig).where(UserConfig.user_id == task.created_by)
)
config = result.scalar_one_or_none()
if config and config.llm_config:
# 🔥 有数据库配置:使用数据库配置(优先)
user_llm_config = json.loads(config.llm_config) if config.llm_config else {}
user_other_config = json.loads(config.other_config) if config.other_config else {}
# 解密敏感字段
user_llm_config = decrypt_config(user_llm_config, SENSITIVE_LLM_FIELDS)
user_other_config = decrypt_config(user_other_config, SENSITIVE_OTHER_FIELDS)
user_config = {
"llmConfig": user_llm_config, # 直接使用数据库配置,不合并默认值
"otherConfig": user_other_config,
}
logger.info(f"✅ Using database user config for task {task_id}, LLM provider: {user_llm_config.get('llmProvider', 'N/A')}")
else:
# 🔥 无数据库配置:传递 None让 LLMService 使用环境变量
user_config = None
logger.info(f"⚠️ No database config found for user {task.created_by}, will use environment variables for task {task_id}")
# 更新状态为运行中
task.status = AgentTaskStatus.RUNNING
task.started_at = datetime.now(timezone.utc)
await db.commit()
logger.info(f"Task {task_id} started")
# 创建 Runner
runner = AgentRunner(db, task, project_root)
# 创建 Runner(传入用户配置)
runner = AgentRunner(db, task, project_root, user_config=user_config)
_running_tasks[task_id] = runner
# 执行
@ -296,11 +340,8 @@ async def create_agent_task(
await db.commit()
await db.refresh(task)
# 确定项目根目录
project_root = _get_project_root(project, task.id)
# 在后台启动任务
background_tasks.add_task(_execute_agent_task, task.id, project_root)
# 在后台启动任务(项目根目录在任务内部获取)
background_tasks.add_task(_execute_agent_task, task.id)
logger.info(f"Created agent task {task.id} for project {project.name}")
@ -897,24 +938,73 @@ async def update_finding_status(
# ============ Helper Functions ============
def _get_project_root(project: Project, task_id: str) -> str:
async def _get_project_root(project: Project, task_id: str) -> str:
"""
获取项目根目录
TODO: 实际实现中需要
- 对于 ZIP 项目解压到临时目录
- 对于 Git 仓库克隆到临时目录
支持两种项目类型
- ZIP 项目解压 ZIP 文件到临时目录
- 仓库项目克隆仓库到临时目录
"""
import zipfile
import subprocess
base_path = f"/tmp/deepaudit/{task_id}"
# 确保目录存在
os.makedirs(base_path, exist_ok=True)
# 如果项目有存储路径,复制过来
if hasattr(project, 'storage_path') and project.storage_path:
if os.path.exists(project.storage_path):
# 复制项目文件
shutil.copytree(project.storage_path, base_path, dirs_exist_ok=True)
# 根据项目类型处理
if project.source_type == "zip":
# 🔥 ZIP 项目:解压 ZIP 文件
from app.services.zip_storage import load_project_zip
zip_path = await load_project_zip(project.id)
if zip_path and os.path.exists(zip_path):
try:
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(base_path)
logger.info(f"✅ Extracted ZIP project {project.id} to {base_path}")
except Exception as e:
logger.error(f"Failed to extract ZIP {zip_path}: {e}")
else:
logger.warning(f"⚠️ ZIP file not found for project {project.id}")
elif project.source_type == "repository" and project.repository_url:
# 🔥 仓库项目:克隆仓库
try:
branch = project.default_branch or "main"
repo_url = project.repository_url
# 克隆仓库
result = subprocess.run(
["git", "clone", "--depth", "1", "--branch", branch, repo_url, base_path],
capture_output=True,
text=True,
timeout=300,
)
if result.returncode == 0:
logger.info(f"✅ Cloned repository {repo_url} (branch: {branch}) to {base_path}")
else:
logger.warning(f"Failed to clone branch {branch}, trying default branch: {result.stderr}")
# 如果克隆失败,尝试使用默认分支
if branch != "main":
result = subprocess.run(
["git", "clone", "--depth", "1", repo_url, base_path],
capture_output=True,
text=True,
timeout=300,
)
if result.returncode == 0:
logger.info(f"✅ Cloned repository {repo_url} (default branch) to {base_path}")
else:
logger.error(f"Failed to clone repository: {result.stderr}")
except subprocess.TimeoutExpired:
logger.error(f"Git clone timeout for {project.repository_url}")
except Exception as e:
logger.error(f"Failed to clone repository {project.repository_url}: {e}")
return base_path

View File

@ -1,9 +1,13 @@
"""
混合 Agent 架构
包含 OrchestratorReconAnalysis Verification Agent
协作机制
- Agent 之间通过 TaskHandoff 传递结构化上下文
- 每个 Agent 完成后生成 handoff 给下一个 Agent
"""
from .base import BaseAgent, AgentConfig, AgentResult
from .base import BaseAgent, AgentConfig, AgentResult, TaskHandoff
from .orchestrator import OrchestratorAgent
from .recon import ReconAgent
from .analysis import AnalysisAgent
@ -13,6 +17,7 @@ __all__ = [
"BaseAgent",
"AgentConfig",
"AgentResult",
"TaskHandoff",
"OrchestratorAgent",
"ReconAgent",
"AnalysisAgent",

View File

@ -10,6 +10,7 @@ LLM 是真正的安全分析大脑!
类型: ReAct (真正的!)
"""
import asyncio
import json
import logging
import re
@ -17,6 +18,7 @@ from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
from ..json_parser import AgentJsonParser
logger = logging.getLogger(__name__)
@ -33,18 +35,13 @@ ANALYSIS_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞分析 Agent一个**自
## 你可以使用的工具
### 外部扫描工具
- **semgrep_scan**: Semgrep 静态分析推荐首先使用
参数: rules (str), max_results (int)
- **bandit_scan**: Python 安全扫描
### RAG 语义搜索
- **rag_query**: 语义代码搜索
参数: query (str), top_k (int)
- **security_search**: 安全相关代码搜索
参数: vulnerability_type (str), top_k (int)
- **function_context**: 函数上下文分析
参数: function_name (str)
### 文件操作
- **read_file**: 读取文件内容
参数: file_path (str), start_line (int), end_line (int)
- **list_files**: 列出目录文件
参数: directory (str), pattern (str)
- **search_code**: 代码关键字搜索
参数: keyword (str), max_results (int)
### 深度分析
- **pattern_match**: 危险模式匹配
@ -53,16 +50,28 @@ ANALYSIS_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞分析 Agent一个**自
参数: code (str), file_path (str), focus (str)
- **dataflow_analysis**: 数据流追踪
参数: source (str), sink (str)
- **vulnerability_validation**: 漏洞验证
参数: code (str), vulnerability_type (str)
### 文件操作
- **read_file**: 读取文件内容
参数: file_path (str), start_line (int), end_line (int)
- **search_code**: 代码关键字搜索
参数: keyword (str), max_results (int)
- **list_files**: 列出目录文件
参数: directory (str), pattern (str)
### 外部静态分析工具
- **semgrep_scan**: Semgrep 静态分析推荐首先使用
参数: rules (str), max_results (int)
- **bandit_scan**: Python 安全扫描
参数: target (str)
- **gitleaks_scan**: Git 密钥泄露扫描
参数: target (str)
- **trufflehog_scan**: 敏感信息扫描
参数: target (str)
- **npm_audit**: NPM 依赖漏洞扫描
参数: target (str)
- **safety_scan**: Python 依赖安全扫描
参数: target (str)
- **osv_scan**: OSV 漏洞数据库扫描
参数: target (str)
### RAG 语义搜索
- **security_search**: 安全相关代码搜索
参数: vulnerability_type (str), top_k (int)
- **function_context**: 函数上下文分析
参数: function_name (str)
## 工作方式
每一步你需要输出
@ -168,15 +177,7 @@ class AnalysisAgent(BaseAgent):
self._conversation_history: List[Dict[str, str]] = []
self._steps: List[AnalysisStep] = []
def _get_tools_description(self) -> str:
"""生成工具描述"""
tools_info = []
for name, tool in self.tools.items():
if name.startswith("_"):
continue
desc = f"- {name}: {getattr(tool, 'description', 'No description')}"
tools_info.append(desc)
return "\n".join(tools_info)
def _parse_llm_response(self, response: str) -> AnalysisStep:
"""解析 LLM 响应"""
@ -191,13 +192,20 @@ class AnalysisAgent(BaseAgent):
final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
if final_match:
step.is_final = True
try:
answer_text = final_match.group(1).strip()
answer_text = re.sub(r'```json\s*', '', answer_text)
answer_text = re.sub(r'```\s*', '', answer_text)
step.final_answer = json.loads(answer_text)
except json.JSONDecodeError:
step.final_answer = {"findings": [], "raw_answer": final_match.group(1).strip()}
answer_text = final_match.group(1).strip()
answer_text = re.sub(r'```json\s*', '', answer_text)
answer_text = re.sub(r'```\s*', '', answer_text)
# 使用增强的 JSON 解析器
step.final_answer = AgentJsonParser.parse(
answer_text,
default={"findings": [], "raw_answer": answer_text}
)
# 确保 findings 格式正确
if "findings" in step.final_answer:
step.final_answer["findings"] = [
f for f in step.final_answer["findings"]
if isinstance(f, dict)
]
return step
# 提取 Action
@ -211,51 +219,15 @@ class AnalysisAgent(BaseAgent):
input_text = input_match.group(1).strip()
input_text = re.sub(r'```json\s*', '', input_text)
input_text = re.sub(r'```\s*', '', input_text)
try:
step.action_input = json.loads(input_text)
except json.JSONDecodeError:
step.action_input = {"raw_input": input_text}
# 使用增强的 JSON 解析器
step.action_input = AgentJsonParser.parse(
input_text,
default={"raw_input": input_text}
)
return step
async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str:
"""执行工具"""
tool = self.tools.get(tool_name)
if not tool:
return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}"
try:
self._tool_calls += 1
await self.emit_tool_call(tool_name, tool_input)
import time
start = time.time()
result = await tool.execute(**tool_input)
duration_ms = int((time.time() - start) * 1000)
await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)
if result.success:
output = str(result.data)
# 如果是代码分析工具,也包含 metadata
if result.metadata:
if "issues" in result.metadata:
output += f"\n\n发现的问题:\n{json.dumps(result.metadata['issues'], ensure_ascii=False, indent=2)}"
if "findings" in result.metadata:
output += f"\n\n发现:\n{json.dumps(result.metadata['findings'][:10], ensure_ascii=False, indent=2)}"
if len(output) > 6000:
output = output[:6000] + f"\n\n... [输出已截断,共 {len(str(result.data))} 字符]"
return output
else:
return f"工具执行失败: {result.error}"
except Exception as e:
logger.error(f"Tool execution error: {e}")
return f"工具执行错误: {str(e)}"
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
"""
@ -271,6 +243,14 @@ class AnalysisAgent(BaseAgent):
task = input_data.get("task", "")
task_context = input_data.get("task_context", "")
# 🔥 处理交接信息
handoff = input_data.get("handoff")
if handoff:
from .base import TaskHandoff
if isinstance(handoff, dict):
handoff = TaskHandoff.from_dict(handoff)
self.receive_handoff(handoff)
# 从 Recon 结果获取上下文
recon_data = previous_results.get("recon", {})
if isinstance(recon_data, dict) and "data" in recon_data:
@ -281,7 +261,9 @@ class AnalysisAgent(BaseAgent):
high_risk_areas = recon_data.get("high_risk_areas", plan.get("high_risk_areas", []))
initial_findings = recon_data.get("initial_findings", [])
# 构建初始消息
# 🔥 构建包含交接上下文的初始消息
handoff_context = self.get_handoff_context()
initial_message = f"""请开始对项目进行安全漏洞分析。
## 项目信息
@ -289,7 +271,7 @@ class AnalysisAgent(BaseAgent):
- 语言: {tech_stack.get('languages', [])}
- 框架: {tech_stack.get('frameworks', [])}
## 上下文信息
{handoff_context if handoff_context else f'''## 上下文信息
### 高风险区域
{json.dumps(high_risk_areas[:20], ensure_ascii=False)}
@ -297,7 +279,7 @@ class AnalysisAgent(BaseAgent):
{json.dumps(entry_points[:10], ensure_ascii=False, indent=2)}
### 初步发现 (如果有)
{json.dumps(initial_findings[:5], ensure_ascii=False, indent=2) if initial_findings else ''}
{json.dumps(initial_findings[:5], ensure_ascii=False, indent=2) if initial_findings else ""}'''}
## 任务
{task_context or task or '进行全面的安全漏洞分析,发现代码中的安全问题。'}
@ -306,9 +288,12 @@ class AnalysisAgent(BaseAgent):
{config.get('target_vulnerabilities', ['all'])}
## 可用工具
{self._get_tools_description()}
{self.get_tools_description()}
请开始你的安全分析首先思考分析策略然后选择合适的工具开始分析"""
# 🔥 记录工作开始
self.record_work("开始安全漏洞分析")
# 初始化对话历史
self._conversation_history = [
@ -328,18 +313,22 @@ class AnalysisAgent(BaseAgent):
self._iteration = iteration + 1
# 🔥 发射 LLM 开始思考事件
await self.emit_llm_start(iteration + 1)
# 🔥 再次检查取消标志在LLM调用之前
if self.is_cancelled:
await self.emit_thinking("🛑 任务已取消,停止执行")
break
# 🔥 调用 LLM 进行思考和决策
response = await self.llm_service.chat_completion_raw(
messages=self._conversation_history,
temperature=0.1,
max_tokens=2048,
)
# 调用 LLM 进行思考和决策(流式输出)
try:
llm_output, tokens_this_round = await self.stream_llm_call(
self._conversation_history,
temperature=0.1,
max_tokens=2048,
)
except asyncio.CancelledError:
logger.info(f"[{self.name}] LLM call cancelled")
break
llm_output = response.get("content", "")
tokens_this_round = response.get("usage", {}).get("total_tokens", 0)
self._total_tokens += tokens_this_round
# 解析 LLM 响应
@ -369,6 +358,14 @@ class AnalysisAgent(BaseAgent):
finding.get("vulnerability_type", "other"),
finding.get("file_path", "")
)
# 🔥 记录洞察
self.add_insight(
f"发现 {finding.get('severity', 'medium')} 级别漏洞: {finding.get('title', 'Unknown')}"
)
# 🔥 记录工作完成
self.record_work(f"完成安全分析,发现 {len(all_findings)} 个潜在漏洞")
await self.emit_llm_complete(
f"分析完成,发现 {len(all_findings)} 个潜在漏洞",
self._total_tokens
@ -380,7 +377,7 @@ class AnalysisAgent(BaseAgent):
# 🔥 发射 LLM 动作决策事件
await self.emit_llm_action(step.action, step.action_input or {})
observation = await self._execute_tool(
observation = await self.execute_tool(
step.action,
step.action_input or {}
)
@ -427,7 +424,7 @@ class AnalysisAgent(BaseAgent):
await self.emit_event(
"info",
f"🎯 Analysis Agent 完成: {len(standardized_findings)} 个发现, {self._iteration} 轮迭代, {self._tool_calls} 次工具调用"
f"Analysis Agent 完成: {len(standardized_findings)} 个发现, {self._iteration} 轮迭代, {self._tool_calls} 次工具调用"
)
return AgentResult(

View File

@ -2,14 +2,20 @@
Agent 基类
定义 Agent 的基本接口和通用功能
核心原则LLM Agent 的大脑所有日志应该反映 LLM 的参与
核心原则
1. LLM Agent 的大脑全程参与决策
2. Agent 之间通过 TaskHandoff 传递结构化上下文
3. 事件分为流式事件前端展示和持久化事件数据库记录
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, AsyncGenerator
from typing import List, Dict, Any, Optional, AsyncGenerator, Tuple
from dataclasses import dataclass, field
from enum import Enum
from datetime import datetime, timezone
import asyncio
import logging
import uuid
logger = logging.getLogger(__name__)
@ -73,6 +79,9 @@ class AgentResult:
# 元数据
metadata: Dict[str, Any] = field(default_factory=dict)
# 🔥 协作信息 - Agent 传递给下一个 Agent 的结构化信息
handoff: Optional["TaskHandoff"] = None
def to_dict(self) -> Dict[str, Any]:
return {
"success": self.success,
@ -83,9 +92,139 @@ class AgentResult:
"tokens_used": self.tokens_used,
"duration_ms": self.duration_ms,
"metadata": self.metadata,
"handoff": self.handoff.to_dict() if self.handoff else None,
}
@dataclass
class TaskHandoff:
    """
    Task handoff protocol — structured context passed between agents.

    Design principles:
    1. Carry enough context for the next agent to understand prior work.
    2. Provide explicit suggestions and attention points.
    3. Convert directly into an LLM-readable prompt via ``to_prompt_context``.
    """

    # Basic routing information
    from_agent: str
    to_agent: str

    # Work summary
    summary: str
    work_completed: List[str] = field(default_factory=list)

    # Key findings and insights
    key_findings: List[Dict[str, Any]] = field(default_factory=list)
    insights: List[str] = field(default_factory=list)

    # Suggestions and attention points
    suggested_actions: List[Dict[str, Any]] = field(default_factory=list)
    attention_points: List[str] = field(default_factory=list)
    priority_areas: List[str] = field(default_factory=list)

    # Free-form context data for the receiving agent
    context_data: Dict[str, Any] = field(default_factory=dict)

    # Confidence of the handing-over agent in its own results
    confidence: float = 0.8

    # Creation time (UTC)
    timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (timestamp as ISO-8601 text)."""
        return {
            "from_agent": self.from_agent,
            "to_agent": self.to_agent,
            "summary": self.summary,
            "work_completed": self.work_completed,
            "key_findings": self.key_findings,
            "insights": self.insights,
            "suggested_actions": self.suggested_actions,
            "attention_points": self.attention_points,
            "priority_areas": self.priority_areas,
            "context_data": self.context_data,
            "confidence": self.confidence,
            "timestamp": self.timestamp.isoformat(),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "TaskHandoff":
        """Rebuild a handoff from ``to_dict`` output.

        Fix: the timestamp is restored from its ISO-8601 form when present,
        so a ``to_dict``/``from_dict`` round trip is lossless; missing or
        malformed values fall back to the default ("now").
        """
        handoff = cls(
            from_agent=data.get("from_agent", ""),
            to_agent=data.get("to_agent", ""),
            summary=data.get("summary", ""),
            work_completed=data.get("work_completed", []),
            key_findings=data.get("key_findings", []),
            insights=data.get("insights", []),
            suggested_actions=data.get("suggested_actions", []),
            attention_points=data.get("attention_points", []),
            priority_areas=data.get("priority_areas", []),
            context_data=data.get("context_data", {}),
            confidence=data.get("confidence", 0.8),
        )
        raw_ts = data.get("timestamp")
        if raw_ts:
            try:
                handoff.timestamp = datetime.fromisoformat(raw_ts)
            except (TypeError, ValueError):
                pass  # keep the default "now" for malformed input
        return handoff

    def to_prompt_context(self) -> str:
        """
        Render this handoff as LLM-readable context.

        This is the key step that lets the next agent's LLM understand the
        previous agent's work.
        """
        lines = [
            f"## 来自 {self.from_agent} Agent 的任务交接",
            "",
            "### 工作摘要",
            self.summary,
            "",
        ]

        if self.work_completed:
            lines.append("### 已完成的工作")
            for work in self.work_completed:
                lines.append(f"- {work}")
            lines.append("")

        if self.key_findings:
            lines.append("### 关键发现")
            # Cap at 15 findings to keep the prompt bounded
            for i, finding in enumerate(self.key_findings[:15], 1):
                severity = finding.get("severity", "medium")
                title = finding.get("title", "Unknown")
                file_path = finding.get("file_path", "")
                lines.append(f"{i}. [{severity.upper()}] {title}")
                if file_path:
                    lines.append(f"   位置: {file_path}:{finding.get('line_start', '')}")
                if finding.get("description"):
                    lines.append(f"   描述: {finding['description'][:100]}")
            lines.append("")

        if self.insights:
            lines.append("### 洞察和分析")
            for insight in self.insights:
                lines.append(f"- {insight}")
            lines.append("")

        if self.suggested_actions:
            lines.append("### 建议的下一步行动")
            for action in self.suggested_actions:
                action_type = action.get("type", "general")
                description = action.get("description", "")
                priority = action.get("priority", "medium")
                lines.append(f"- [{priority.upper()}] {action_type}: {description}")
            lines.append("")

        if self.attention_points:
            lines.append("### ⚠️ 需要特别关注")
            for point in self.attention_points:
                lines.append(f"- {point}")
            lines.append("")

        if self.priority_areas:
            lines.append("### 优先分析区域")
            for area in self.priority_areas:
                lines.append(f"- {area}")

        return "\n".join(lines)
class BaseAgent(ABC):
"""
Agent 基类
@ -94,6 +233,11 @@ class BaseAgent(ABC):
1. LLM Agent 的大脑全程参与决策
2. 所有日志应该反映 LLM 的思考过程
3. 工具调用是 LLM 的决策结果
协作原则
1. 通过 TaskHandoff 接收前序 Agent 的上下文
2. 执行完成后生成 TaskHandoff 传递给下一个 Agent
3. 洞察和发现应该结构化记录
"""
def __init__(
@ -122,6 +266,11 @@ class BaseAgent(ABC):
self._total_tokens = 0
self._tool_calls = 0
self._cancelled = False
# 🔥 协作状态
self._incoming_handoff: Optional[TaskHandoff] = None
self._insights: List[str] = [] # 收集的洞察
self._work_completed: List[str] = [] # 完成的工作记录
@property
def name(self) -> str:
@ -152,6 +301,103 @@ class BaseAgent(ABC):
def is_cancelled(self) -> bool:
return self._cancelled
# ============ 协作方法 ============
def receive_handoff(self, handoff: TaskHandoff):
    """Accept the handoff produced by the preceding agent.

    Stores it for later use by ``get_handoff_context`` and logs a short
    preview of the summary.
    """
    self._incoming_handoff = handoff
    preview = handoff.summary[:50]
    logger.info(
        f"[{self.name}] Received handoff from {handoff.from_agent}: {preview}..."
    )
def get_handoff_context(self) -> str:
"""
获取交接上下文用于构建 LLM prompt
Returns:
格式化的上下文字符串
"""
if not self._incoming_handoff:
return ""
return self._incoming_handoff.to_prompt_context()
def add_insight(self, insight: str) -> None:
    """Record a single insight gathered during this agent's run.

    Insights accumulate in ``self._insights`` and are copied into the
    outgoing ``TaskHandoff`` by ``create_handoff``.
    """
    self._insights.append(insight)
def record_work(self, work: str) -> None:
    """Record one completed unit of work.

    Entries accumulate in ``self._work_completed`` and are attached to the
    outgoing ``TaskHandoff`` by ``create_handoff``.
    """
    self._work_completed.append(work)
def create_handoff(
    self,
    to_agent: str,
    summary: str,
    key_findings: Optional[List[Dict[str, Any]]] = None,
    suggested_actions: Optional[List[Dict[str, Any]]] = None,
    attention_points: Optional[List[str]] = None,
    priority_areas: Optional[List[str]] = None,
    context_data: Optional[Dict[str, Any]] = None,
) -> TaskHandoff:
    """
    Build the outgoing task handoff for the next agent.

    Fix: the optional parameters default to ``None`` and are now annotated
    as ``Optional[...]`` accordingly (the previous bare ``List``/``Dict``
    annotations were wrong under a type checker). Behavior is unchanged.

    Args:
        to_agent: Name of the receiving agent.
        summary: Summary of the work this agent performed.
        key_findings: Key findings to pass along.
        suggested_actions: Recommended next actions.
        attention_points: Items needing special attention.
        priority_areas: Areas to analyze first.
        context_data: Free-form context payload.

    Returns:
        A ``TaskHandoff`` carrying copies of the accumulated work log and
        insights (copied so later mutation of agent state does not leak
        into an already-issued handoff).
    """
    return TaskHandoff(
        from_agent=self.name,
        to_agent=to_agent,
        summary=summary,
        work_completed=self._work_completed.copy(),
        key_findings=key_findings or [],
        insights=self._insights.copy(),
        suggested_actions=suggested_actions or [],
        attention_points=attention_points or [],
        priority_areas=priority_areas or [],
        context_data=context_data or {},
    )
def build_prompt_with_handoff(self, base_prompt: str) -> str:
    """Append the incoming handoff context to a base prompt.

    Args:
        base_prompt: The prompt to augment.

    Returns:
        ``base_prompt`` unchanged when no handoff was received; otherwise
        the prompt followed by a delimited section containing the previous
        agent's handoff context.
    """
    handoff_context = self.get_handoff_context()
    if not handoff_context:
        return base_prompt
    # NOTE: the template below is runtime output consumed by the LLM —
    # its exact text and layout are intentionally left untouched.
    return f"""{base_prompt}
---
## 前序 Agent 交接信息
{handoff_context}
---
请基于以上来自前序 Agent 的信息,结合你的专业能力开展工作。
"""
# ============ 核心事件发射方法 ============
async def emit_event(
@ -173,13 +419,13 @@ class BaseAgent(ABC):
async def emit_thinking(self, message: str):
"""发射 LLM 思考事件"""
await self.emit_event("thinking", f"🧠 [{self.name}] {message}")
await self.emit_event("thinking", f"[{self.name}] {message}")
async def emit_llm_start(self, iteration: int):
"""发射 LLM 开始思考事件"""
await self.emit_event(
"llm_start",
f"🤔 [{self.name}] LLM 开始{iteration}思考...",
f"[{self.name}] {iteration}迭代开始",
metadata={"iteration": iteration}
)
@ -189,31 +435,62 @@ class BaseAgent(ABC):
display_thought = thought[:500] + "..." if len(thought) > 500 else thought
await self.emit_event(
"llm_thought",
f"💭 [{self.name}] LLM 思考:\n{display_thought}",
f"[{self.name}] 思考: {display_thought}",
metadata={
"thought": thought,
"iteration": iteration,
}
)
async def emit_thinking_start(self):
"""发射开始思考事件(流式输出用)"""
await self.emit_event("thinking_start", f"[{self.name}] 开始思考...")
async def emit_thinking_token(self, token: str, accumulated: str):
"""发射思考 token 事件(流式输出用)"""
await self.emit_event(
"thinking_token",
"", # 不需要 message前端从 metadata 获取
metadata={
"token": token,
"accumulated": accumulated,
}
)
async def emit_thinking_end(self, full_response: str):
"""发射思考结束事件(流式输出用)"""
await self.emit_event(
"thinking_end",
f"[{self.name}] 思考完成",
metadata={"accumulated": full_response}
)
async def emit_llm_decision(self, decision: str, reason: str = ""):
"""发射 LLM 决策事件 - 展示 LLM 做了什么决定"""
await self.emit_event(
"llm_decision",
f"💡 [{self.name}] LLM 决策: {decision}" + (f" (理由: {reason})" if reason else ""),
f"[{self.name}] 决策: {decision}" + (f" ({reason})" if reason else ""),
metadata={
"decision": decision,
"reason": reason,
}
)
async def emit_llm_complete(self, result_summary: str, tokens_used: int):
"""发射 LLM 完成事件"""
await self.emit_event(
"llm_complete",
f"[{self.name}] 完成: {result_summary} (消耗 {tokens_used} tokens)",
metadata={
"tokens_used": tokens_used,
}
)
async def emit_llm_action(self, action: str, action_input: Dict):
"""发射 LLM 动作事件 - LLM 决定执行什么动作"""
import json
input_str = json.dumps(action_input, ensure_ascii=False)[:200]
"""发射 LLM 动作决策事件"""
await self.emit_event(
"llm_action",
f"⚡ [{self.name}] LLM 动作: {action}\n 参数: {input_str}",
f"[{self.name}] 执行动作: {action}",
metadata={
"action": action,
"action_input": action_input,
@ -221,43 +498,33 @@ class BaseAgent(ABC):
)
async def emit_llm_observation(self, observation: str):
"""发射 LLM 观察事件 - LLM 看到了什么"""
"""发射 LLM 观察事件"""
# 截断过长的观察结果
display_obs = observation[:300] + "..." if len(observation) > 300 else observation
await self.emit_event(
"llm_observation",
f"👁️ [{self.name}] LLM 观察到:\n{display_obs}",
metadata={"observation": observation[:2000]}
)
async def emit_llm_complete(self, result_summary: str, tokens_used: int):
"""发射 LLM 完成事件"""
await self.emit_event(
"llm_complete",
f"✅ [{self.name}] LLM 完成: {result_summary} (消耗 {tokens_used} tokens)",
f"[{self.name}] 观察结果: {display_obs}",
metadata={
"tokens_used": tokens_used,
"observation": observation[:2000], # 限制存储长度
}
)
# ============ 工具调用相关事件 ============
async def emit_tool_call(self, tool_name: str, tool_input: Dict):
"""发射工具调用事件 - LLM 决定调用工具"""
import json
input_str = json.dumps(tool_input, ensure_ascii=False)[:300]
"""发射工具调用事件"""
await self.emit_event(
"tool_call",
f"🔧 [{self.name}] LLM 调用工具: {tool_name}\n 输入: {input_str}",
f"[{self.name}] 调用工具: {tool_name}",
tool_name=tool_name,
tool_input=tool_input,
)
async def emit_tool_result(self, tool_name: str, result: str, duration_ms: int):
"""发射工具结果事件"""
result_preview = result[:200] + "..." if len(result) > 200 else result
await self.emit_event(
"tool_result",
f"📤 [{self.name}] 工具 {tool_name} 返回 ({duration_ms}ms):\n {result_preview}",
f"[{self.name}] 工具 {tool_name} 完成 ({duration_ms}ms)",
tool_name=tool_name,
tool_duration_ms=duration_ms,
)
@ -332,9 +599,6 @@ class BaseAgent(ABC):
"""
self._iteration += 1
# 发射 LLM 开始事件
await self.emit_llm_start(self._iteration)
try:
response = await self.llm_service.chat_completion(
messages=messages,
@ -385,3 +649,124 @@ class BaseAgent(ABC):
"tool_calls": self._tool_calls,
"tokens_used": self._total_tokens,
}
# ============ 统一的流式 LLM 调用 ============
async def stream_llm_call(
    self,
    messages: List[Dict[str, str]],
    temperature: float = 0.1,
    max_tokens: int = 2048,
) -> Tuple[str, int]:
    """
    Shared streaming LLM call used by every agent (avoids duplicated code).

    Emits thinking start/token/end events as the stream progresses and
    stops early when the agent is cancelled.

    Args:
        messages: Chat messages to send.
        temperature: Sampling temperature.
        max_tokens: Maximum tokens to generate.

    Returns:
        Tuple of (full response text, token count reported by the stream).
    """
    text = ""
    tokens = 0

    await self.emit_thinking_start()

    try:
        stream = self.llm_service.chat_completion_stream(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        async for chunk in stream:
            # Stop streaming as soon as cancellation is requested
            if self.is_cancelled:
                break

            kind = chunk["type"]
            if kind == "token":
                text = chunk["accumulated"]
                await self.emit_thinking_token(chunk["content"], text)
            elif kind == "done":
                text = chunk["content"]
                usage = chunk.get("usage")
                if usage:
                    tokens = usage.get("total_tokens", 0)
                break
            elif kind == "error":
                text = chunk.get("accumulated", "")
                logger.error(f"Stream error: {chunk.get('error')}")
                break
    except asyncio.CancelledError:
        logger.info(f"[{self.name}] LLM call cancelled")
        raise
    finally:
        # Always close the thinking phase, even on cancel/error
        await self.emit_thinking_end(text)

    return text, tokens
async def execute_tool(self, tool_name: str, tool_input: Dict) -> str:
    """
    Shared tool execution used by every agent.

    Emits tool_call/tool_result events around the call, appends selected
    metadata ("issues"/"findings") to the output, and truncates output to
    6000 characters.

    Fixes: duration is measured with the monotonic ``time.perf_counter``
    instead of wall-clock ``time.time`` (immune to clock adjustments), the
    duplicated inline ``import json`` is hoisted to a single import, and
    the success path is flattened with a guard clause. Behavior and all
    emitted events are otherwise unchanged.

    Args:
        tool_name: Name of the tool to run.
        tool_input: Keyword arguments for the tool.

    Returns:
        Tool output as a string, or an error description on failure.
    """
    tool = self.tools.get(tool_name)
    if not tool:
        return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}"

    try:
        self._tool_calls += 1
        await self.emit_tool_call(tool_name, tool_input)

        import json
        import time

        start = time.perf_counter()
        result = await tool.execute(**tool_input)
        duration_ms = int((time.perf_counter() - start) * 1000)

        # The result event is emitted for failures too (preserves the
        # original event stream that the audit UI consumes).
        await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)

        if not result.success:
            return f"工具执行失败: {result.error}"

        output = str(result.data)
        # Surface extra structured information from metadata, if present
        if result.metadata:
            if "issues" in result.metadata:
                output += f"\n\n发现的问题:\n{json.dumps(result.metadata['issues'], ensure_ascii=False, indent=2)}"
            if "findings" in result.metadata:
                output += f"\n\n发现:\n{json.dumps(result.metadata['findings'][:10], ensure_ascii=False, indent=2)}"

        # Truncate overly long output to keep prompts bounded
        if len(output) > 6000:
            output = output[:6000] + f"\n\n... [输出已截断,共 {len(str(result.data))} 字符]"
        return output
    except Exception as e:
        logger.error(f"Tool execution error: {e}")
        return f"工具执行错误: {str(e)}"
def get_tools_description(self) -> str:
"""生成工具描述文本(用于 prompt"""
tools_info = []
for name, tool in self.tools.items():
if name.startswith("_"):
continue
desc = f"- {name}: {getattr(tool, 'description', 'No description')}"
tools_info.append(desc)
return "\n".join(tools_info)

View File

@ -18,6 +18,7 @@ from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
from ..json_parser import AgentJsonParser
logger = logging.getLogger(__name__)
@ -178,18 +179,22 @@ class OrchestratorAgent(BaseAgent):
self._iteration = iteration + 1
# 🔥 发射 LLM 开始思考事件
await self.emit_llm_start(iteration + 1)
# 🔥 再次检查取消标志在LLM调用之前
if self.is_cancelled:
await self.emit_thinking("🛑 任务已取消,停止执行")
break
# 🔥 调用 LLM 进行思考和决策
response = await self.llm_service.chat_completion_raw(
messages=self._conversation_history,
temperature=0.1,
max_tokens=2048,
)
# 调用 LLM 进行思考和决策(流式输出)
try:
llm_output, tokens_this_round = await self.stream_llm_call(
self._conversation_history,
temperature=0.1,
max_tokens=2048,
)
except asyncio.CancelledError:
logger.info(f"[{self.name}] LLM call cancelled")
break
llm_output = response.get("content", "")
tokens_this_round = response.get("usage", {}).get("total_tokens", 0)
self._total_tokens += tokens_this_round
# 解析 LLM 的决策
@ -348,10 +353,11 @@ class OrchestratorAgent(BaseAgent):
input_text = re.sub(r'```json\s*', '', input_text)
input_text = re.sub(r'```\s*', '', input_text)
try:
action_input = json.loads(input_text)
except json.JSONDecodeError:
action_input = {"raw": input_text}
# 使用增强的 JSON 解析器
action_input = AgentJsonParser.parse(
input_text,
default={"raw": input_text}
)
return AgentStep(
thought=thought,

View File

@ -16,6 +16,7 @@ from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
from ..json_parser import AgentJsonParser
logger = logging.getLogger(__name__)
@ -182,15 +183,20 @@ class ReActAgent(BaseAgent):
final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
if final_match:
step.is_final = True
try:
# 尝试提取 JSON
answer_text = final_match.group(1).strip()
# 移除 markdown 代码块
answer_text = re.sub(r'```json\s*', '', answer_text)
answer_text = re.sub(r'```\s*', '', answer_text)
step.final_answer = json.loads(answer_text)
except json.JSONDecodeError:
step.final_answer = {"raw_answer": final_match.group(1).strip()}
answer_text = final_match.group(1).strip()
answer_text = re.sub(r'```json\s*', '', answer_text)
answer_text = re.sub(r'```\s*', '', answer_text)
# 使用增强的 JSON 解析器
step.final_answer = AgentJsonParser.parse(
answer_text,
default={"raw_answer": answer_text}
)
# 确保 findings 格式正确
if "findings" in step.final_answer:
step.final_answer["findings"] = [
f for f in step.final_answer["findings"]
if isinstance(f, dict)
]
return step
# 提取 Action
@ -202,14 +208,13 @@ class ReActAgent(BaseAgent):
input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL)
if input_match:
input_text = input_match.group(1).strip()
# 移除 markdown 代码块
input_text = re.sub(r'```json\s*', '', input_text)
input_text = re.sub(r'```\s*', '', input_text)
try:
step.action_input = json.loads(input_text)
except json.JSONDecodeError:
# 尝试简单解析
step.action_input = {"raw_input": input_text}
# 使用增强的 JSON 解析器
step.action_input = AgentJsonParser.parse(
input_text,
default={"raw_input": input_text}
)
return step

View File

@ -10,6 +10,7 @@ LLM 是真正的大脑!
类型: ReAct (真正的!)
"""
import asyncio
import json
import logging
import re
@ -17,22 +18,24 @@ from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
from ..json_parser import AgentJsonParser
logger = logging.getLogger(__name__)
RECON_SYSTEM_PROMPT = """你是 DeepAudit 的信息收集 Agent负责在安全审计前**自主**收集项目信息。
RECON_SYSTEM_PROMPT = """你是 DeepAudit 的信息收集 Agent负责在安全审计前收集项目信息。
## 你的角色
你是信息收集的**大脑**不是机械执行者你需要
1. 自主思考需要收集什么信息
2. 选择合适的工具获取信息
3. 根据发现动态调整策略
4. 判断何时信息收集足够
## 你的职责
你专注于**信息收集**为后续的漏洞分析提供基础数据
1. 分析项目结构和目录布局
2. 识别技术栈语言框架数据库
3. 找出入口点API路由用户输入处理
4. 标记高风险区域认证数据库操作文件处理
5. 收集依赖信息
## 你可以使用的工具
### 文件系统
### 文件系统工具
- **list_files**: 列出目录内容
参数: directory (str), recursive (bool), pattern (str), max_files (int)
@ -42,12 +45,14 @@ RECON_SYSTEM_PROMPT = """你是 DeepAudit 的信息收集 Agent负责在安
- **search_code**: 代码关键字搜索
参数: keyword (str), max_results (int)
### 安全扫描
- **semgrep_scan**: Semgrep 静态分析扫描
- **npm_audit**: npm 依赖漏洞审计
- **safety_scan**: Python 依赖漏洞审计
- **gitleaks_scan**: 密钥/敏感信息泄露扫描
- **osv_scan**: OSV 通用依赖漏洞扫描
### 语义搜索工具
- **rag_query**: 语义代码搜索如果可用
参数: query (str), top_k (int)
## 注意
- 你只负责信息收集不要进行漏洞分析
- 漏洞分析由 Analysis Agent 负责
- 专注于收集项目结构技术栈入口点等信息
## 工作方式
每一步你需要输出
@ -142,15 +147,7 @@ class ReconAgent(BaseAgent):
self._conversation_history: List[Dict[str, str]] = []
self._steps: List[ReconStep] = []
def _get_tools_description(self) -> str:
"""生成工具描述"""
tools_info = []
for name, tool in self.tools.items():
if name.startswith("_"):
continue
desc = f"- {name}: {getattr(tool, 'description', 'No description')}"
tools_info.append(desc)
return "\n".join(tools_info)
def _parse_llm_response(self, response: str) -> ReconStep:
"""解析 LLM 响应"""
@ -165,13 +162,14 @@ class ReconAgent(BaseAgent):
final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
if final_match:
step.is_final = True
try:
answer_text = final_match.group(1).strip()
answer_text = re.sub(r'```json\s*', '', answer_text)
answer_text = re.sub(r'```\s*', '', answer_text)
step.final_answer = json.loads(answer_text)
except json.JSONDecodeError:
step.final_answer = {"raw_answer": final_match.group(1).strip()}
answer_text = final_match.group(1).strip()
answer_text = re.sub(r'```json\s*', '', answer_text)
answer_text = re.sub(r'```\s*', '', answer_text)
# 使用增强的 JSON 解析器
step.final_answer = AgentJsonParser.parse(
answer_text,
default={"raw_answer": answer_text}
)
return step
# 提取 Action
@ -185,43 +183,15 @@ class ReconAgent(BaseAgent):
input_text = input_match.group(1).strip()
input_text = re.sub(r'```json\s*', '', input_text)
input_text = re.sub(r'```\s*', '', input_text)
try:
step.action_input = json.loads(input_text)
except json.JSONDecodeError:
step.action_input = {"raw_input": input_text}
# 使用增强的 JSON 解析器
step.action_input = AgentJsonParser.parse(
input_text,
default={"raw_input": input_text}
)
return step
async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str:
    """Execute a registered tool and return its output as a string.

    Emits tool_call / tool_result events around the invocation, increments
    the per-run tool counter, and truncates oversized output.  Failures are
    returned as strings instead of raised, so the ReAct loop can feed them
    back to the LLM as an observation.
    """
    tool = self.tools.get(tool_name)
    if not tool:
        # Unknown tool name: tell the LLM what is actually available.
        return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}"
    try:
        self._tool_calls += 1
        await self.emit_tool_call(tool_name, tool_input)
        import time
        start = time.time()
        result = await tool.execute(**tool_input)
        duration_ms = int((time.time() - start) * 1000)
        # Only a 200-char preview goes into the result event; full data may be huge.
        await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)
        if result.success:
            output = str(result.data)
            if len(output) > 4000:
                # Cap the observation so it fits in the LLM context window.
                output = output[:4000] + f"\n\n... [输出已截断,共 {len(str(result.data))} 字符]"
            return output
        else:
            return f"工具执行失败: {result.error}"
    except Exception as e:
        logger.error(f"Tool execution error: {e}")
        return f"工具执行错误: {str(e)}"
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
"""
@ -246,7 +216,7 @@ class ReconAgent(BaseAgent):
{task_context or task or '进行全面的信息收集,为安全审计做准备。'}
## 可用工具
{self._get_tools_description()}
{self.get_tools_description()}
请开始你的信息收集工作首先思考应该收集什么信息然后选择合适的工具"""
@ -259,7 +229,7 @@ class ReconAgent(BaseAgent):
self._steps = []
final_result = None
await self.emit_thinking("🔍 Recon Agent 启动LLM 开始自主收集信息...")
await self.emit_thinking("Recon Agent 启动LLM 开始自主收集信息...")
try:
for iteration in range(self.config.max_iterations):
@ -268,18 +238,22 @@ class ReconAgent(BaseAgent):
self._iteration = iteration + 1
# 🔥 发射 LLM 开始思考事件
await self.emit_llm_start(iteration + 1)
# 🔥 再次检查取消标志在LLM调用之前
if self.is_cancelled:
await self.emit_thinking("🛑 任务已取消,停止执行")
break
# 🔥 调用 LLM 进行思考和决策
response = await self.llm_service.chat_completion_raw(
messages=self._conversation_history,
temperature=0.1,
max_tokens=2048,
)
# 调用 LLM 进行思考和决策(使用基类统一方法)
try:
llm_output, tokens_this_round = await self.stream_llm_call(
self._conversation_history,
temperature=0.1,
max_tokens=2048,
)
except asyncio.CancelledError:
logger.info(f"[{self.name}] LLM call cancelled")
break
llm_output = response.get("content", "")
tokens_this_round = response.get("usage", {}).get("total_tokens", 0)
self._total_tokens += tokens_this_round
# 解析 LLM 响应
@ -311,7 +285,7 @@ class ReconAgent(BaseAgent):
# 🔥 发射 LLM 动作决策事件
await self.emit_llm_action(step.action, step.action_input or {})
observation = await self._execute_tool(
observation = await self.execute_tool(
step.action,
step.action_input or {}
)
@ -341,9 +315,18 @@ class ReconAgent(BaseAgent):
if not final_result:
final_result = self._summarize_from_steps()
# 🔥 记录工作和洞察
self.record_work(f"完成项目信息收集,发现 {len(final_result.get('entry_points', []))} 个入口点")
self.record_work(f"识别技术栈: {final_result.get('tech_stack', {})}")
if final_result.get("high_risk_areas"):
self.add_insight(f"发现 {len(final_result['high_risk_areas'])} 个高风险区域需要重点分析")
if final_result.get("initial_findings"):
self.add_insight(f"初步发现 {len(final_result['initial_findings'])} 个潜在问题")
await self.emit_event(
"info",
f"🎯 Recon Agent 完成: {self._iteration} 轮迭代, {self._tool_calls} 次工具调用"
f"Recon Agent 完成: {self._iteration} 轮迭代, {self._tool_calls} 次工具调用"
)
return AgentResult(

View File

@ -10,6 +10,7 @@ LLM 是验证的大脑!
类型: ReAct (真正的!)
"""
import asyncio
import json
import logging
import re
@ -18,6 +19,7 @@ from dataclasses import dataclass
from datetime import datetime, timezone
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
from ..json_parser import AgentJsonParser
logger = logging.getLogger(__name__)
@ -34,15 +36,17 @@ VERIFICATION_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞验证 Agent一个*
## 你可以使用的工具
### 代码分析
### 文件操作
- **read_file**: 读取更多代码上下文
参数: file_path (str), start_line (int), end_line (int)
- **function_context**: 分析函数调用关系
参数: function_name (str)
- **dataflow_analysis**: 追踪数据流
参数: source (str), sink (str), file_path (str)
- **list_files**: 列出目录文件
参数: directory (str), pattern (str)
### 验证分析
- **vulnerability_validation**: LLM 深度验证
参数: code (str), vulnerability_type (str), context (str)
- **dataflow_analysis**: 追踪数据流
参数: source (str), sink (str), file_path (str)
### 沙箱验证
- **sandbox_exec**: 在沙箱中执行命令
@ -157,16 +161,6 @@ class VerificationAgent(BaseAgent):
self._conversation_history: List[Dict[str, str]] = []
self._steps: List[VerificationStep] = []
def _get_tools_description(self) -> str:
"""生成工具描述"""
tools_info = []
for name, tool in self.tools.items():
if name.startswith("_"):
continue
desc = f"- {name}: {getattr(tool, 'description', 'No description')}"
tools_info.append(desc)
return "\n".join(tools_info)
def _parse_llm_response(self, response: str) -> VerificationStep:
"""解析 LLM 响应"""
step = VerificationStep(thought="")
@ -180,13 +174,20 @@ class VerificationAgent(BaseAgent):
final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
if final_match:
step.is_final = True
try:
answer_text = final_match.group(1).strip()
answer_text = re.sub(r'```json\s*', '', answer_text)
answer_text = re.sub(r'```\s*', '', answer_text)
step.final_answer = json.loads(answer_text)
except json.JSONDecodeError:
step.final_answer = {"findings": [], "raw_answer": final_match.group(1).strip()}
answer_text = final_match.group(1).strip()
answer_text = re.sub(r'```json\s*', '', answer_text)
answer_text = re.sub(r'```\s*', '', answer_text)
# 使用增强的 JSON 解析器
step.final_answer = AgentJsonParser.parse(
answer_text,
default={"findings": [], "raw_answer": answer_text}
)
# 确保 findings 格式正确
if "findings" in step.final_answer:
step.final_answer["findings"] = [
f for f in step.final_answer["findings"]
if isinstance(f, dict)
]
return step
# 提取 Action
@ -200,50 +201,14 @@ class VerificationAgent(BaseAgent):
input_text = input_match.group(1).strip()
input_text = re.sub(r'```json\s*', '', input_text)
input_text = re.sub(r'```\s*', '', input_text)
try:
step.action_input = json.loads(input_text)
except json.JSONDecodeError:
step.action_input = {"raw_input": input_text}
# 使用增强的 JSON 解析器
step.action_input = AgentJsonParser.parse(
input_text,
default={"raw_input": input_text}
)
return step
async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str:
    """Execute a registered tool and return its output as a string.

    Same contract as the recon variant, plus: when the tool result carries
    ``metadata["validation"]`` it is appended as pretty-printed JSON, so the
    LLM sees the deep-validation verdict alongside the raw data.  Errors are
    reported as strings rather than raised.
    """
    tool = self.tools.get(tool_name)
    if not tool:
        # Unknown tool name: tell the LLM what is actually available.
        return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}"
    try:
        self._tool_calls += 1
        await self.emit_tool_call(tool_name, tool_input)
        import time
        start = time.time()
        result = await tool.execute(**tool_input)
        duration_ms = int((time.time() - start) * 1000)
        # Only a 200-char preview goes into the result event.
        await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)
        if result.success:
            output = str(result.data)
            # Append validation metadata, if the tool produced any.
            if result.metadata:
                if "validation" in result.metadata:
                    output += f"\n\n验证结果:\n{json.dumps(result.metadata['validation'], ensure_ascii=False, indent=2)}"
            if len(output) > 4000:
                # Cap the observation so it fits in the LLM context window.
                output = output[:4000] + f"\n\n... [输出已截断]"
            return output
        else:
            return f"工具执行失败: {result.error}"
    except Exception as e:
        logger.error(f"Tool execution error: {e}")
        return f"工具执行错误: {str(e)}"
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
"""
执行漏洞验证 - LLM 全程参与
@ -256,20 +221,32 @@ class VerificationAgent(BaseAgent):
task = input_data.get("task", "")
task_context = input_data.get("task_context", "")
# 🔥 处理交接信息
handoff = input_data.get("handoff")
if handoff:
from .base import TaskHandoff
if isinstance(handoff, dict):
handoff = TaskHandoff.from_dict(handoff)
self.receive_handoff(handoff)
# 收集所有待验证的发现
findings_to_verify = []
for phase_name, result in previous_results.items():
if isinstance(result, dict):
data = result.get("data", {})
else:
data = result.data if hasattr(result, 'data') else {}
if isinstance(data, dict):
phase_findings = data.get("findings", [])
for f in phase_findings:
if f.get("needs_verification", True):
findings_to_verify.append(f)
# 🔥 优先从交接信息获取发现
if self._incoming_handoff and self._incoming_handoff.key_findings:
findings_to_verify = self._incoming_handoff.key_findings.copy()
else:
for phase_name, result in previous_results.items():
if isinstance(result, dict):
data = result.get("data", {})
else:
data = result.data if hasattr(result, 'data') else {}
if isinstance(data, dict):
phase_findings = data.get("findings", [])
for f in phase_findings:
if f.get("needs_verification", True):
findings_to_verify.append(f)
# 去重
findings_to_verify = self._deduplicate(findings_to_verify)
@ -289,7 +266,12 @@ class VerificationAgent(BaseAgent):
f"开始验证 {len(findings_to_verify)} 个发现"
)
# 构建初始消息
# 🔥 记录工作开始
self.record_work(f"开始验证 {len(findings_to_verify)} 个漏洞发现")
# 🔥 构建包含交接上下文的初始消息
handoff_context = self.get_handoff_context()
findings_summary = []
for i, f in enumerate(findings_to_verify):
findings_summary.append(f"""
@ -306,6 +288,8 @@ class VerificationAgent(BaseAgent):
initial_message = f"""请验证以下 {len(findings_to_verify)} 个安全发现。
{handoff_context if handoff_context else ''}
## 待验证发现
{''.join(findings_summary)}
@ -313,9 +297,10 @@ class VerificationAgent(BaseAgent):
- 验证级别: {config.get('verification_level', 'standard')}
## 可用工具
{self._get_tools_description()}
{self.get_tools_description()}
请开始验证对于每个发现思考如何验证它使用合适的工具获取更多信息然后判断是否为真实漏洞"""
请开始验证对于每个发现思考如何验证它使用合适的工具获取更多信息然后判断是否为真实漏洞
{f"特别注意 Analysis Agent 提到的关注点。" if handoff_context else ""}"""
# 初始化对话历史
self._conversation_history = [
@ -335,18 +320,22 @@ class VerificationAgent(BaseAgent):
self._iteration = iteration + 1
# 🔥 发射 LLM 开始思考事件
await self.emit_llm_start(iteration + 1)
# 🔥 再次检查取消标志在LLM调用之前
if self.is_cancelled:
await self.emit_thinking("🛑 任务已取消,停止执行")
break
# 🔥 调用 LLM 进行思考和决策
response = await self.llm_service.chat_completion_raw(
messages=self._conversation_history,
temperature=0.1,
max_tokens=3000,
)
# 调用 LLM 进行思考和决策(流式输出)
try:
llm_output, tokens_this_round = await self.stream_llm_call(
self._conversation_history,
temperature=0.1,
max_tokens=3000,
)
except asyncio.CancelledError:
logger.info(f"[{self.name}] LLM call cancelled")
break
llm_output = response.get("content", "")
tokens_this_round = response.get("usage", {}).get("total_tokens", 0)
self._total_tokens += tokens_this_round
# 解析 LLM 响应
@ -367,6 +356,14 @@ class VerificationAgent(BaseAgent):
if step.is_final:
await self.emit_llm_decision("完成漏洞验证", "LLM 判断验证已充分")
final_result = step.final_answer
# 🔥 记录洞察和工作
if final_result and "findings" in final_result:
verified_count = len([f for f in final_result["findings"] if f.get("is_verified")])
fp_count = len([f for f in final_result["findings"] if f.get("verdict") == "false_positive"])
self.add_insight(f"验证了 {len(final_result['findings'])} 个发现,{verified_count} 个确认,{fp_count} 个误报")
self.record_work(f"完成漏洞验证: {verified_count} 个确认, {fp_count} 个误报")
await self.emit_llm_complete(
f"验证完成",
self._total_tokens
@ -378,7 +375,7 @@ class VerificationAgent(BaseAgent):
# 🔥 发射 LLM 动作决策事件
await self.emit_llm_action(step.action, step.action_input or {})
observation = await self._execute_tool(
observation = await self.execute_tool(
step.action,
step.action_input or {}
)
@ -438,7 +435,7 @@ class VerificationAgent(BaseAgent):
await self.emit_event(
"info",
f"🎯 Verification Agent 完成: {confirmed_count} 确认, {likely_count} 可能, {false_positive_count} 误报"
f"Verification Agent 完成: {confirmed_count} 确认, {likely_count} 可能, {false_positive_count} 误报"
)
return AgentResult(

View File

@ -299,8 +299,9 @@ class EventManager:
"timestamp": timestamp.isoformat(),
}
# 保存到数据库
if self.db_session_factory:
# 保存到数据库(跳过高频事件如 thinking_token
skip_db_events = {"thinking_token", "thinking_start", "thinking_end"}
if self.db_session_factory and event_type not in skip_db_events:
try:
await self._save_event_to_db(event_data)
except Exception as e:

View File

@ -69,6 +69,11 @@ class AuditState(TypedDict):
llm_next_action: Optional[str] # LLM 建议的下一步: "continue_analysis", "verify", "report", "end"
llm_routing_reason: Optional[str] # LLM 的决策理由
# 🔥 新增Agent 间协作的任务交接信息
recon_handoff: Optional[Dict[str, Any]] # Recon -> Analysis 的交接
analysis_handoff: Optional[Dict[str, Any]] # Analysis -> Verification 的交接
verification_handoff: Optional[Dict[str, Any]] # Verification -> Report 的交接
# 消息和事件
messages: Annotated[List[Dict], operator.add]
events: Annotated[List[Dict], operator.add]
@ -146,6 +151,9 @@ class LLMRouter:
# 统计发现
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
for f in findings:
# 跳过非字典类型的 finding
if not isinstance(f, dict):
continue
sev = f.get("severity", "medium")
severity_counts[sev] = severity_counts.get(sev, 0) + 1
@ -243,6 +251,11 @@ def route_after_recon(state: AuditState) -> Literal["analysis", "end"]:
Recon 后的路由决策
优先使用 LLM 的决策否则使用默认逻辑
"""
# 🔥 检查是否有错误
if state.get("error") or state.get("current_phase") == "error":
logger.error(f"Recon phase has error, routing to end: {state.get('error')}")
return "end"
# 检查 LLM 是否有决策
llm_action = state.get("llm_next_action")
if llm_action:

View File

@ -1,6 +1,8 @@
"""
LangGraph 节点实现
每个节点封装一个 Agent 的执行逻辑
协作增强节点之间通过 TaskHandoff 传递结构化的上下文和洞察
"""
from typing import Dict, Any, List, Optional
@ -28,6 +30,14 @@ class BaseNode:
await self.event_emitter.emit_info(message)
except Exception as e:
logger.warning(f"Failed to emit event: {e}")
def _extract_handoff_from_state(self, state: Dict[str, Any], from_phase: str):
"""从状态中提取前序 Agent 的 handoff"""
handoff_data = state.get(f"{from_phase}_handoff")
if handoff_data:
from ..agents.base import TaskHandoff
return TaskHandoff.from_dict(handoff_data)
return None
class ReconNode(BaseNode):
@ -35,7 +45,7 @@ class ReconNode(BaseNode):
信息收集节点
输入: project_root, project_info, config
输出: tech_stack, entry_points, high_risk_areas, dependencies
输出: tech_stack, entry_points, high_risk_areas, dependencies, recon_handoff
"""
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
@ -52,6 +62,35 @@ class ReconNode(BaseNode):
if result.success and result.data:
data = result.data
# 🔥 创建交接信息给 Analysis Agent
handoff = self.agent.create_handoff(
to_agent="Analysis",
summary=f"项目信息收集完成。发现 {len(data.get('entry_points', []))} 个入口点,{len(data.get('high_risk_areas', []))} 个高风险区域。",
key_findings=data.get("initial_findings", []),
suggested_actions=[
{
"type": "deep_analysis",
"description": f"深入分析高风险区域: {', '.join(data.get('high_risk_areas', [])[:5])}",
"priority": "high",
},
{
"type": "entry_point_audit",
"description": "审计所有入口点的输入验证",
"priority": "high",
},
],
attention_points=[
f"技术栈: {data.get('tech_stack', {}).get('frameworks', [])}",
f"主要语言: {data.get('tech_stack', {}).get('languages', [])}",
],
priority_areas=data.get("high_risk_areas", [])[:10],
context_data={
"tech_stack": data.get("tech_stack", {}),
"entry_points": data.get("entry_points", []),
"dependencies": data.get("dependencies", {}),
},
)
await self.emit_event(
"phase_complete",
f"✅ 信息收集完成: 发现 {len(data.get('entry_points', []))} 个入口点"
@ -63,12 +102,15 @@ class ReconNode(BaseNode):
"high_risk_areas": data.get("high_risk_areas", []),
"dependencies": data.get("dependencies", {}),
"current_phase": "recon_complete",
"findings": data.get("initial_findings", []), # 初步发现
"findings": data.get("initial_findings", []),
# 🔥 保存交接信息
"recon_handoff": handoff.to_dict(),
"events": [{
"type": "recon_complete",
"data": {
"entry_points_count": len(data.get("entry_points", [])),
"high_risk_areas_count": len(data.get("high_risk_areas", [])),
"handoff_summary": handoff.summary,
}
}],
}
@ -90,8 +132,8 @@ class AnalysisNode(BaseNode):
"""
漏洞分析节点
输入: tech_stack, entry_points, high_risk_areas, previous findings
输出: findings (累加), should_continue_analysis
输入: tech_stack, entry_points, high_risk_areas, recon_handoff
输出: findings (累加), should_continue_analysis, analysis_handoff
"""
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
@ -104,6 +146,15 @@ class AnalysisNode(BaseNode):
)
try:
# 🔥 提取 Recon 的交接信息
recon_handoff = self._extract_handoff_from_state(state, "recon")
if recon_handoff:
self.agent.receive_handoff(recon_handoff)
await self.emit_event(
"handoff_received",
f"📨 收到 Recon Agent 交接: {recon_handoff.summary[:50]}..."
)
# 构建分析输入
analysis_input = {
"phase_name": "analysis",
@ -121,6 +172,8 @@ class AnalysisNode(BaseNode):
}
}
},
# 🔥 传递交接信息
"handoff": recon_handoff,
}
# 调用 Analysis Agent
@ -130,27 +183,70 @@ class AnalysisNode(BaseNode):
new_findings = result.data.get("findings", [])
# 判断是否需要继续分析
# 如果这一轮发现了很多问题,可能还有更多
should_continue = (
len(new_findings) >= 5 and
iteration < state.get("max_iterations", 3)
)
# 🔥 创建交接信息给 Verification Agent
# 统计严重程度
severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0}
for f in new_findings:
if isinstance(f, dict):
sev = f.get("severity", "medium")
severity_counts[sev] = severity_counts.get(sev, 0) + 1
handoff = self.agent.create_handoff(
to_agent="Verification",
summary=f"漏洞分析完成。发现 {len(new_findings)} 个潜在漏洞 (Critical: {severity_counts['critical']}, High: {severity_counts['high']}, Medium: {severity_counts['medium']}, Low: {severity_counts['low']})",
key_findings=new_findings[:20], # 传递前20个发现
suggested_actions=[
{
"type": "verify_critical",
"description": "优先验证 Critical 和 High 级别的漏洞",
"priority": "critical",
},
{
"type": "poc_generation",
"description": "为确认的漏洞生成 PoC",
"priority": "high",
},
],
attention_points=[
f"{severity_counts['critical']} 个 Critical 级别漏洞需要立即验证",
f"{severity_counts['high']} 个 High 级别漏洞需要优先验证",
"注意检查是否有误报,特别是静态分析工具的结果",
],
priority_areas=[
f.get("file_path", "") for f in new_findings
if f.get("severity") in ["critical", "high"]
][:10],
context_data={
"severity_distribution": severity_counts,
"total_findings": len(new_findings),
"iteration": iteration,
},
)
await self.emit_event(
"phase_complete",
f"✅ 分析迭代 {iteration} 完成: 发现 {len(new_findings)} 个潜在漏洞"
)
return {
"findings": new_findings, # 会自动累加
"findings": new_findings,
"iteration": iteration,
"should_continue_analysis": should_continue,
"current_phase": "analysis_complete",
# 🔥 保存交接信息
"analysis_handoff": handoff.to_dict(),
"events": [{
"type": "analysis_iteration",
"data": {
"iteration": iteration,
"findings_count": len(new_findings),
"severity_distribution": severity_counts,
"handoff_summary": handoff.summary,
}
}],
}
@ -174,8 +270,8 @@ class VerificationNode(BaseNode):
"""
漏洞验证节点
输入: findings
输出: verified_findings, false_positives
输入: findings, analysis_handoff
输出: verified_findings, false_positives, verification_handoff
"""
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
@ -195,6 +291,15 @@ class VerificationNode(BaseNode):
)
try:
# 🔥 提取 Analysis 的交接信息
analysis_handoff = self._extract_handoff_from_state(state, "analysis")
if analysis_handoff:
self.agent.receive_handoff(analysis_handoff)
await self.emit_event(
"handoff_received",
f"📨 收到 Analysis Agent 交接: {analysis_handoff.summary[:50]}..."
)
# 构建验证输入
verification_input = {
"previous_results": {
@ -205,16 +310,49 @@ class VerificationNode(BaseNode):
}
},
"config": state["config"],
# 🔥 传递交接信息
"handoff": analysis_handoff,
}
# 调用 Verification Agent
result = await self.agent.run(verification_input)
if result.success and result.data:
verified = [f for f in result.data.get("findings", []) if f.get("is_verified")]
false_pos = [f["id"] for f in result.data.get("findings", [])
all_verified_findings = result.data.get("findings", [])
verified = [f for f in all_verified_findings if f.get("is_verified")]
false_pos = [f.get("id", f.get("title", "unknown")) for f in all_verified_findings
if f.get("verdict") == "false_positive"]
# 🔥 创建交接信息给 Report 节点
handoff = self.agent.create_handoff(
to_agent="Report",
summary=f"漏洞验证完成。{len(verified)} 个漏洞已确认,{len(false_pos)} 个误报已排除。",
key_findings=verified,
suggested_actions=[
{
"type": "generate_report",
"description": "生成详细的安全审计报告",
"priority": "high",
},
{
"type": "remediation_plan",
"description": "为确认的漏洞制定修复计划",
"priority": "high",
},
],
attention_points=[
f"{len(verified)} 个漏洞已确认存在",
f"{len(false_pos)} 个误报已排除",
"建议按严重程度优先修复 Critical 和 High 级别漏洞",
],
context_data={
"verified_count": len(verified),
"false_positive_count": len(false_pos),
"total_analyzed": len(findings),
"verification_rate": len(verified) / len(findings) if findings else 0,
},
)
await self.emit_event(
"phase_complete",
f"✅ 验证完成: {len(verified)} 已确认, {len(false_pos)} 误报"
@ -224,11 +362,14 @@ class VerificationNode(BaseNode):
"verified_findings": verified,
"false_positives": false_pos,
"current_phase": "verification_complete",
# 🔥 保存交接信息
"verification_handoff": handoff.to_dict(),
"events": [{
"type": "verification_complete",
"data": {
"verified_count": len(verified),
"false_positive_count": len(false_pos),
"handoff_summary": handoff.summary,
}
}],
}
@ -269,6 +410,11 @@ class ReportNode(BaseNode):
type_counts = {}
for finding in findings:
# 跳过非字典类型的 finding防止数据格式异常
if not isinstance(finding, dict):
logger.warning(f"Skipping invalid finding (not a dict): {type(finding)}")
continue
sev = finding.get("severity", "medium")
severity_counts[sev] = severity_counts.get(sev, 0) + 1
@ -300,7 +446,7 @@ class ReportNode(BaseNode):
await self.emit_event(
"phase_complete",
f"报告生成完成: 安全评分 {security_score}/100"
f"报告生成完成: 安全评分 {security_score}/100"
)
return {

View File

@ -39,167 +39,8 @@ from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode
logger = logging.getLogger(__name__)
class LLMService:
    """
    LLM service wrapper.

    Provides AI helpers for code analysis and vulnerability detection on
    top of litellm's async chat-completion API.
    """

    def __init__(self, model: Optional[str] = None, api_key: Optional[str] = None):
        # Explicit arguments win; otherwise fall back to settings, then to a
        # hard-coded default model.
        self.model = model or settings.LLM_MODEL or "gpt-4o-mini"
        self.api_key = api_key or settings.LLM_API_KEY
        self.base_url = settings.LLM_BASE_URL

    async def chat_completion_raw(
        self,
        messages: List[Dict[str, str]],
        temperature: float = 0.1,
        max_tokens: int = 4096,
    ) -> Dict[str, Any]:
        """Call the LLM and return ``{"content": str, "usage": {...}}``.

        ``usage`` is empty when the provider returns no usage block.
        Re-raises whatever litellm raises on failure (after logging).
        """
        try:
            import litellm
            response = await litellm.acompletion(
                model=self.model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                api_key=self.api_key,
                base_url=self.base_url,
            )
            return {
                "content": response.choices[0].message.content,
                "usage": {
                    "prompt_tokens": response.usage.prompt_tokens,
                    "completion_tokens": response.usage.completion_tokens,
                    "total_tokens": response.usage.total_tokens,
                } if response.usage else {},
            }
        except Exception as e:
            logger.error(f"LLM call failed: {e}")
            raise

    async def analyze_code(self, code: str, language: str) -> Dict[str, Any]:
        """
        Analyze code for security issues.

        Args:
            code: source code text (truncated to 8000 chars in the prompt)
            language: programming language name

        Returns:
            Parsed analysis result with an ``issues`` list and a
            ``quality_score``; on failure an empty result carrying ``error``.
        """
        prompt = f"""请分析以下 {language} 代码的安全问题。
代码:
```{language}
{code[:8000]}
```
请识别所有潜在的安全漏洞包括但不限于:
- SQL 注入
- XSS (跨站脚本)
- 命令注入
- 路径遍历
- 不安全的反序列化
- 硬编码密钥/密码
- 不安全的加密
- SSRF
- 认证/授权问题
对于每个发现的问题请提供:
1. 漏洞类型
2. 严重程度 (critical/high/medium/low)
3. 问题描述
4. 具体行号
5. 修复建议
请以 JSON 格式返回结果:
{{
    "issues": [
        {{
            "type": "漏洞类型",
            "severity": "严重程度",
            "title": "问题标题",
            "description": "详细描述",
            "line": 行号,
            "code_snippet": "相关代码片段",
            "suggestion": "修复建议"
        }}
    ],
    "quality_score": 0-100
}}
如果没有发现安全问题返回空的 issues 数组和较高的 quality_score"""
        try:
            result = await self.chat_completion_raw(
                messages=[
                    {"role": "system", "content": "你是一位专业的代码安全审计专家,擅长发现代码中的安全漏洞。请只返回 JSON 格式的结果,不要包含其他内容。"},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.1,
                max_tokens=4096,
            )
            content = result.get("content", "{}")
            # Try to extract JSON from the model output.
            import json
            import re
            # First attempt: the whole response is JSON.
            try:
                return json.loads(content)
            except json.JSONDecodeError:
                pass
            # Second attempt: JSON wrapped in a markdown code fence.
            json_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', content)
            if json_match:
                try:
                    return json.loads(json_match.group(1))
                except json.JSONDecodeError:
                    pass
            # Fall back to an empty, optimistic result.
            return {"issues": [], "quality_score": 80}
        except Exception as e:
            logger.error(f"Code analysis failed: {e}")
            return {"issues": [], "quality_score": 0, "error": str(e)}

    async def analyze_code_with_custom_prompt(
        self,
        code: str,
        language: str,
        prompt: str,
        **kwargs
    ) -> Dict[str, Any]:
        """Analyze code with a caller-supplied prompt template.

        The ``{code}`` and ``{language}`` placeholders in *prompt* are
        replaced textually (plain ``str.replace``, not ``str.format``, so
        other braces are left untouched).  Extra kwargs are accepted but
        currently unused.
        """
        full_prompt = prompt.replace("{code}", code).replace("{language}", language)
        try:
            result = await self.chat_completion_raw(
                messages=[
                    {"role": "system", "content": "你是一位专业的代码安全审计专家。"},
                    {"role": "user", "content": full_prompt},
                ],
                temperature=0.1,
            )
            return {
                "analysis": result.get("content", ""),
                "usage": result.get("usage", {}),
            }
        except Exception as e:
            logger.error(f"Custom analysis failed: {e}")
            return {"analysis": "", "error": str(e)}
# 🔥 使用系统统一的 LLMService支持用户配置
from app.services.llm.service import LLMService
class AgentRunner:
@ -217,18 +58,22 @@ class AgentRunner:
db: AsyncSession,
task: AgentTask,
project_root: str,
user_config: Optional[Dict[str, Any]] = None,
):
self.db = db
self.task = task
self.project_root = project_root
# 🔥 保存用户配置,供 RAG 初始化使用
self.user_config = user_config or {}
# 事件管理 - 传入 db_session_factory 以持久化事件
from app.db.session import async_session_factory
self.event_manager = EventManager(db_session_factory=async_session_factory)
self.event_emitter = AgentEventEmitter(task.id, self.event_manager)
# LLM 服务
self.llm_service = LLMService()
# 🔥 LLM 服务 - 使用用户配置(从系统配置页面获取)
self.llm_service = LLMService(user_config=self.user_config)
# 工具集
self.tools: Dict[str, Any] = {}
@ -248,14 +93,26 @@ class AgentRunner:
self._cancelled = False
self._running_task: Optional[asyncio.Task] = None
# Agent 引用(用于取消传播)
self._agents: List[Any] = []
# 流式处理器
self.stream_handler = StreamHandler(task.id)
def cancel(self):
"""取消任务"""
self._cancelled = True
# 🔥 取消所有 Agent
for agent in self._agents:
if hasattr(agent, 'cancel'):
agent.cancel()
logger.debug(f"Cancelled agent: {agent.name if hasattr(agent, 'name') else 'unknown'}")
# 取消运行中的任务
if self._running_task and not self._running_task.done():
self._running_task.cancel()
logger.info(f"Task {self.task.id} cancellation requested")
@property
@ -283,11 +140,33 @@ class AgentRunner:
await self.event_emitter.emit_info("📚 初始化 RAG 代码检索系统...")
try:
# 🔥 从用户配置中获取 LLM 配置(用于 Embedding API Key
# 优先级:用户配置 > 环境变量
user_llm_config = self.user_config.get('llmConfig', {})
# 获取 Embedding 配置(优先使用用户配置的 LLM API Key
embedding_provider = getattr(settings, 'EMBEDDING_PROVIDER', 'openai')
embedding_model = getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small')
# 🔥 API Key 优先级:用户配置 > 环境变量
embedding_api_key = (
user_llm_config.get('llmApiKey') or
getattr(settings, 'LLM_API_KEY', '') or
''
)
# 🔥 Base URL 优先级:用户配置 > 环境变量
embedding_base_url = (
user_llm_config.get('llmBaseUrl') or
getattr(settings, 'LLM_BASE_URL', None) or
None
)
embedding_service = EmbeddingService(
provider=settings.EMBEDDING_PROVIDER,
model=settings.EMBEDDING_MODEL,
api_key=settings.LLM_API_KEY,
base_url=settings.LLM_BASE_URL,
provider=embedding_provider,
model=embedding_model,
api_key=embedding_api_key,
base_url=embedding_base_url,
)
self.indexer = CodeIndexer(
@ -308,35 +187,59 @@ class AgentRunner:
async def _initialize_tools(self):
"""初始化工具集"""
await self.event_emitter.emit_info("🔧 初始化 Agent 工具集...")
await self.event_emitter.emit_info("初始化 Agent 工具集...")
# 文件工具
self.tools["read_file"] = FileReadTool(self.project_root)
self.tools["search_code"] = FileSearchTool(self.project_root)
self.tools["list_files"] = ListFilesTool(self.project_root)
# ============ 基础工具(所有 Agent 共享)============
base_tools = {
"read_file": FileReadTool(self.project_root),
"list_files": ListFilesTool(self.project_root),
}
# RAG 工具
# ============ Recon Agent 专属工具 ============
# 职责:信息收集、项目结构分析、技术栈识别
self.recon_tools = {
**base_tools,
"search_code": FileSearchTool(self.project_root),
}
# RAG 工具Recon 用于语义搜索)
if self.retriever:
self.tools["rag_query"] = RAGQueryTool(self.retriever)
self.tools["security_search"] = SecurityCodeSearchTool(self.retriever)
self.tools["function_context"] = FunctionContextTool(self.retriever)
self.recon_tools["rag_query"] = RAGQueryTool(self.retriever)
# 分析工具
self.tools["pattern_match"] = PatternMatchTool(self.project_root)
self.tools["code_analysis"] = CodeAnalysisTool(self.llm_service)
self.tools["dataflow_analysis"] = DataFlowAnalysisTool(self.llm_service)
self.tools["vulnerability_validation"] = VulnerabilityValidationTool(self.llm_service)
# ============ Analysis Agent 专属工具 ============
# 职责:漏洞分析、代码审计、模式匹配
self.analysis_tools = {
**base_tools,
"search_code": FileSearchTool(self.project_root),
# 模式匹配和代码分析
"pattern_match": PatternMatchTool(self.project_root),
"code_analysis": CodeAnalysisTool(self.llm_service),
"dataflow_analysis": DataFlowAnalysisTool(self.llm_service),
# 外部静态分析工具
"semgrep_scan": SemgrepTool(self.project_root),
"bandit_scan": BanditTool(self.project_root),
"gitleaks_scan": GitleaksTool(self.project_root),
"trufflehog_scan": TruffleHogTool(self.project_root),
"npm_audit": NpmAuditTool(self.project_root),
"safety_scan": SafetyTool(self.project_root),
"osv_scan": OSVScannerTool(self.project_root),
}
# 外部安全工具
self.tools["semgrep_scan"] = SemgrepTool(self.project_root)
self.tools["bandit_scan"] = BanditTool(self.project_root)
self.tools["gitleaks_scan"] = GitleaksTool(self.project_root)
self.tools["trufflehog_scan"] = TruffleHogTool(self.project_root)
self.tools["npm_audit"] = NpmAuditTool(self.project_root)
self.tools["safety_scan"] = SafetyTool(self.project_root)
self.tools["osv_scan"] = OSVScannerTool(self.project_root)
# RAG 工具Analysis 用于安全相关代码搜索)
if self.retriever:
self.analysis_tools["security_search"] = SecurityCodeSearchTool(self.retriever)
self.analysis_tools["function_context"] = FunctionContextTool(self.retriever)
# 沙箱工具
# ============ Verification Agent 专属工具 ============
# 职责漏洞验证、PoC 执行、误报排除
self.verification_tools = {
**base_tools,
# 验证工具
"vulnerability_validation": VulnerabilityValidationTool(self.llm_service),
"dataflow_analysis": DataFlowAnalysisTool(self.llm_service),
}
# 沙箱工具(仅 Verification Agent 可用)
try:
self.sandbox_manager = SandboxManager(
image=settings.SANDBOX_IMAGE,
@ -344,14 +247,20 @@ class AgentRunner:
cpu_limit=settings.SANDBOX_CPU_LIMIT,
)
self.tools["sandbox_exec"] = SandboxTool(self.sandbox_manager)
self.tools["sandbox_http"] = SandboxHttpTool(self.sandbox_manager)
self.tools["verify_vulnerability"] = VulnerabilityVerifyTool(self.sandbox_manager)
self.verification_tools["sandbox_exec"] = SandboxTool(self.sandbox_manager)
self.verification_tools["sandbox_http"] = SandboxHttpTool(self.sandbox_manager)
self.verification_tools["verify_vulnerability"] = VulnerabilityVerifyTool(self.sandbox_manager)
except Exception as e:
logger.warning(f"Sandbox initialization failed: {e}")
await self.event_emitter.emit_info(f"✅ 已加载 {len(self.tools)} 个工具")
# 统计总工具数
total_tools = len(set(
list(self.recon_tools.keys()) +
list(self.analysis_tools.keys()) +
list(self.verification_tools.keys())
))
await self.event_emitter.emit_info(f"已加载 {total_tools} 个工具")
async def _build_graph(self):
"""构建 LangGraph 审计图"""
@ -360,25 +269,28 @@ class AgentRunner:
# 导入 Agent
from app.services.agent.agents import ReconAgent, AnalysisAgent, VerificationAgent
# 创建 Agent 实例
# 创建 Agent 实例(每个 Agent 使用专属工具集)
recon_agent = ReconAgent(
llm_service=self.llm_service,
tools=self.tools,
tools=self.recon_tools, # Recon 专属工具
event_emitter=self.event_emitter,
)
analysis_agent = AnalysisAgent(
llm_service=self.llm_service,
tools=self.tools,
tools=self.analysis_tools, # Analysis 专属工具
event_emitter=self.event_emitter,
)
verification_agent = VerificationAgent(
llm_service=self.llm_service,
tools=self.tools,
tools=self.verification_tools, # Verification 专属工具
event_emitter=self.event_emitter,
)
# 🔥 保存 Agent 引用以便取消时传播信号
self._agents = [recon_agent, analysis_agent, verification_agent]
# 创建节点
recon_node = ReconNode(recon_agent, self.event_emitter)
analysis_node = AnalysisNode(analysis_agent, self.event_emitter)
@ -481,6 +393,10 @@ class AgentRunner:
"iteration": 0,
"max_iterations": self.task.max_iterations or 50,
"should_continue_analysis": False,
# 🔥 Agent 协作交接信息
"recon_handoff": None,
"analysis_handoff": None,
"verification_handoff": None,
"messages": [],
"events": [],
"summary": None,
@ -556,6 +472,33 @@ class AgentRunner:
graph_state = self.graph.get_state(run_config)
final_state = graph_state.values if graph_state else {}
# 🔥 检查是否有错误
error = final_state.get("error")
if error:
# 检查是否是 LLM 认证错误
error_str = str(error)
if "AuthenticationError" in error_str or "API key" in error_str or "invalid_api_key" in error_str:
error_message = "LLM API 密钥配置错误。请检查环境变量 LLM_API_KEY 或配置中的 API 密钥是否正确。"
logger.error(f"LLM authentication error: {error}")
else:
error_message = error_str
duration_ms = int((time.time() - start_time) * 1000)
# 标记任务为失败
await self._update_task_status(AgentTaskStatus.FAILED, error_message)
await self.event_emitter.emit_task_error(error_message)
yield StreamEvent(
event_type=StreamEventType.TASK_ERROR,
sequence=self.stream_handler._next_sequence(),
data={
"error": error_message,
"message": f"❌ 任务失败: {error_message}",
},
)
return
# 6. 保存发现
findings = final_state.get("findings", [])
await self._save_findings(findings)

View File

@ -0,0 +1,251 @@
"""
Agent JSON 解析工具
LLM 响应中安全地解析 JSON参考 llm/service.py 的实现
"""
import json
import re
import logging
from typing import Dict, Any, List, Optional, Union
logger = logging.getLogger(__name__)
# 尝试导入 json-repair 库
try:
from json_repair import repair_json
JSON_REPAIR_AVAILABLE = True
except ImportError:
JSON_REPAIR_AVAILABLE = False
logger.debug("json-repair library not available")
class AgentJsonParser:
    """Agent-specific JSON parser.

    Safely extracts a JSON object from raw LLM output by trying a chain of
    increasingly tolerant strategies: direct parse, cleanup, markdown fence
    extraction, balanced-brace extraction, truncation repair, and finally
    the optional ``json-repair`` library.
    """

    @staticmethod
    def clean_text(text: str) -> str:
        """Strip BOM and zero-width characters from *text*.

        NOTE(review): other control characters (e.g. raw 0x00-0x1f, which
        also break ``json.loads``) are left untouched — confirm whether
        they should be handled here as well.
        """
        if not text:
            return ""
        # Remove BOM and zero-width characters
        text = text.replace('\ufeff', '').replace('\u200b', '').replace('\u200c', '').replace('\u200d', '')
        return text

    @staticmethod
    def fix_json_format(text: str) -> str:
        """Fix common JSON format problems (trailing commas, raw newlines)."""
        text = text.strip()
        # Drop trailing commas before a closing brace/bracket
        text = re.sub(r',(\s*[}\]])', r'\1', text)
        # Escape an unescaped newline inside a string value.
        # NOTE(review): [^"]* itself matches newlines, so only one newline
        # per string value gets escaped — this is a best-effort heuristic.
        text = re.sub(r':\s*"([^"]*)\n([^"]*)"', r': "\1\\n\2"', text)
        return text

    @classmethod
    def extract_from_markdown(cls, text: str) -> Dict[str, Any]:
        """Extract a JSON object from a ```json fenced code block.

        Raises:
            ValueError: when no fenced code block is present.
        """
        match = re.search(r'```(?:json)?\s*(\{[\s\S]*?\})\s*```', text)
        if match:
            return json.loads(match.group(1))
        raise ValueError("No markdown code block found")

    @classmethod
    def extract_json_object(cls, text: str) -> Dict[str, Any]:
        """Extract the first balanced JSON object from *text*.

        Walks the text with a brace counter that respects string literals
        and backslash escapes, so braces inside strings do not confuse the
        matching.
        """
        start_idx = text.find('{')
        if start_idx == -1:
            raise ValueError("No JSON object found")
        # Track brace depth, ignoring braces inside strings / after escapes
        brace_count = 0
        in_string = False
        escape_next = False
        end_idx = -1
        for i in range(start_idx, len(text)):
            char = text[i]
            if escape_next:
                escape_next = False
                continue
            if char == '\\':
                escape_next = True
                continue
            if char == '"' and not escape_next:
                in_string = not in_string
                continue
            if not in_string:
                if char == '{':
                    brace_count += 1
                elif char == '}':
                    brace_count -= 1
                    if brace_count == 0:
                        end_idx = i + 1
                        break
        if end_idx == -1:
            # No balanced close found: fall back to the last '}' in the text
            last_brace = text.rfind('}')
            if last_brace > start_idx:
                end_idx = last_brace + 1
            else:
                raise ValueError("Incomplete JSON object")
        json_str = text[start_idx:end_idx]
        # Remove trailing commas before parsing
        json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
        return json.loads(json_str)

    @classmethod
    def fix_truncated_json(cls, text: str) -> Dict[str, Any]:
        """Repair a truncated JSON object by appending the missing closers."""
        start_idx = text.find('{')
        if start_idx == -1:
            raise ValueError("Cannot fix truncated JSON")
        json_str = text[start_idx:]
        # Count how many closing symbols are missing.
        # NOTE(review): the counts include braces/brackets that appear
        # inside string literals, so this is a best-effort heuristic.
        open_braces = json_str.count('{')
        close_braces = json_str.count('}')
        open_brackets = json_str.count('[')
        close_brackets = json_str.count(']')
        # Append the missing closers (brackets first, then braces)
        json_str += ']' * max(0, open_brackets - close_brackets)
        json_str += '}' * max(0, open_braces - close_braces)
        # Fix trailing commas
        json_str = re.sub(r',(\s*[}\]])', r'\1', json_str)
        return json.loads(json_str)

    @classmethod
    def repair_with_library(cls, text: str) -> Dict[str, Any]:
        """Repair broken JSON using the optional ``json-repair`` library.

        Raises:
            ValueError: when the library is unavailable, no object is
                found, or the repair does not yield a dict.
        """
        if not JSON_REPAIR_AVAILABLE:
            raise ValueError("json-repair library not available")
        start_idx = text.find('{')
        if start_idx == -1:
            raise ValueError("No JSON object found for repair")
        end_idx = text.rfind('}')
        if end_idx > start_idx:
            json_str = text[start_idx:end_idx + 1]
        else:
            json_str = text[start_idx:]
        repaired = repair_json(json_str, return_objects=True)
        if isinstance(repaired, dict):
            return repaired
        raise ValueError(f"json-repair returned unexpected type: {type(repaired)}")

    @classmethod
    def parse(cls, text: str, default: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Parse JSON from an LLM response (robust version).

        Args:
            text: raw LLM response text.
            default: value returned when parsing fails; when None, a
                ValueError is raised instead.

        Returns:
            The parsed dictionary.
        """
        if not text or not text.strip():
            if default is not None:
                logger.warning("LLM 响应为空,返回默认值")
                return default
            raise ValueError("LLM 响应内容为空")
        clean = cls.clean_text(text)
        # Try each strategy in order, from strict to most tolerant
        attempts = [
            ("直接解析", lambda: json.loads(text)),
            ("清理后解析", lambda: json.loads(cls.fix_json_format(clean))),
            ("Markdown 提取", lambda: cls.extract_from_markdown(text)),
            ("智能提取", lambda: cls.extract_json_object(clean)),
            ("截断修复", lambda: cls.fix_truncated_json(clean)),
            ("json-repair", lambda: cls.repair_with_library(text)),
        ]
        last_error = None
        for name, attempt in attempts:
            try:
                result = attempt()
                # Only accept non-empty dict results.
                # NOTE(review): a legitimately empty "{}" response is
                # therefore treated as a failure — confirm this is intended.
                if result and isinstance(result, dict):
                    if name != "直接解析":
                        logger.debug(f"✅ JSON 解析成功(方法: {name}")
                    return result
            except Exception as e:
                last_error = e
                logger.debug(f"JSON 解析方法 '{name}' 失败: {e}")
        # All strategies failed
        if default is not None:
            logger.warning(f"JSON 解析失败,返回默认值。原始内容: {text[:200]}...")
            return default
        logger.error(f"❌ 无法解析 JSON原始内容: {text[:500]}...")
        raise ValueError(f"无法解析 JSON: {last_error}")

    @classmethod
    def parse_findings(cls, text: str) -> List[Dict[str, Any]]:
        """
        Parse the "findings" list from an LLM response.

        Args:
            text: raw LLM response text.

        Returns:
            A list of finding dicts; invalid entries are skipped with a
            warning, and any parse failure yields an empty list.
        """
        try:
            result = cls.parse(text, default={"findings": []})
            findings = result.get("findings", [])
            # Keep only dict entries; try to decode stringified dicts
            valid_findings = []
            for f in findings:
                if isinstance(f, dict):
                    valid_findings.append(f)
                elif isinstance(f, str):
                    # A finding serialized as a JSON string — try to decode it
                    try:
                        parsed = json.loads(f)
                        if isinstance(parsed, dict):
                            valid_findings.append(parsed)
                    except json.JSONDecodeError:
                        logger.warning(f"跳过无效的 finding字符串: {f[:100]}...")
                else:
                    logger.warning(f"跳过无效的 finding类型: {type(f)}")
            return valid_findings
        except Exception as e:
            logger.error(f"解析 findings 失败: {e}")
            return []

    @classmethod
    def safe_get(cls, data: Union[Dict, str, Any], key: str, default: Any = None) -> Any:
        """
        Safely read *key* from *data*.

        Args:
            data: possibly a dict, possibly any other type.
            key: the key to look up.
            default: fallback value.

        Returns:
            ``data[key]`` when *data* is a dict containing the key,
            otherwise *default*.
        """
        if isinstance(data, dict):
            return data.get(key, default)
        return default

View File

@ -79,9 +79,25 @@ class CodeAnalysisTool(AgentTool):
**kwargs
) -> ToolResult:
"""执行代码分析"""
import asyncio
try:
# 构建分析结果
analysis = await self.llm_service.analyze_code(code, language)
# 限制代码长度,避免超时
max_code_length = 50000 # 约 50KB
if len(code) > max_code_length:
code = code[:max_code_length] + "\n\n... (代码已截断,仅分析前 50000 字符)"
# 添加超时保护5分钟
try:
analysis = await asyncio.wait_for(
self.llm_service.analyze_code(code, language),
timeout=300.0 # 5分钟超时
)
except asyncio.TimeoutError:
return ToolResult(
success=False,
error="代码分析超时超过5分钟。代码可能过长或过于复杂请尝试分析较小的代码片段。",
)
issues = analysis.get("issues", [])

View File

@ -109,10 +109,14 @@ Semgrep 是业界领先的静态分析工具,支持 30+ 种编程语言。
"""执行 Semgrep 扫描"""
# 检查 semgrep 是否可用
if not await self._check_semgrep():
return ToolResult(
success=False,
error="Semgrep 未安装。请使用 'pip install semgrep' 安装。",
)
# 尝试自动安装
logger.info("Semgrep 未安装,尝试自动安装...")
install_success = await self._try_install_semgrep()
if not install_success:
return ToolResult(
success=False,
error="Semgrep 未安装。请使用 'pip install semgrep' 安装,或联系管理员安装。",
)
# 构建完整路径
full_path = os.path.normpath(os.path.join(self.project_root, target_path))
@ -216,6 +220,30 @@ Semgrep 是业界领先的静态分析工具,支持 30+ 种编程语言。
return proc.returncode == 0
except:
return False
async def _try_install_semgrep(self) -> bool:
"""尝试自动安装 Semgrep"""
try:
logger.info("正在安装 Semgrep...")
proc = await asyncio.create_subprocess_exec(
"pip", "install", "semgrep",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=120)
if proc.returncode == 0:
logger.info("Semgrep 安装成功")
# 验证安装
return await self._check_semgrep()
else:
logger.warning(f"Semgrep 安装失败: {stderr.decode()[:200]}")
return False
except asyncio.TimeoutError:
logger.warning("Semgrep 安装超时")
return False
except Exception as e:
logger.warning(f"Semgrep 安装出错: {e}")
return False
# ============ Bandit 工具 (Python) ============
@ -422,7 +450,11 @@ Gitleaks 是专业的密钥检测工具,支持 150+ 种密钥类型。
if not await self._check_gitleaks():
return ToolResult(
success=False,
error="Gitleaks 未安装。请从 https://github.com/gitleaks/gitleaks 安装。",
error="Gitleaks 未安装。Gitleaks 需要手动安装,请参考: https://github.com/gitleaks/gitleaks/releases\n"
"安装方法:\n"
"- macOS: brew install gitleaks\n"
"- Linux: 下载二进制文件并添加到 PATH\n"
"- Windows: 下载二进制文件并添加到 PATH",
)
full_path = os.path.normpath(os.path.join(self.project_root, target_path))

View File

@ -291,8 +291,19 @@ class PatternMatchTool(AgentTool):
return f"""快速扫描代码中的危险模式和常见漏洞。
使用正则表达式检测已知的不安全代码模式
重要此工具需要代码内容作为输入不是目录路径
使用步骤
1. 先用 read_file 工具读取文件内容
2. 然后将读取的代码内容传递给此工具的 code 参数
支持的漏洞类型: {vuln_types}
输入参数:
- code (必需): 要扫描的代码内容字符串
- file_path (可选): 文件路径用于上下文
- pattern_types (可选): 要检测的漏洞类型列表 ['sql_injection', 'xss']
- language (可选): 编程语言 'python', 'php', 'javascript'
这是一个快速扫描工具可以在分析开始时使用来快速发现潜在问题
发现的问题需要进一步分析确认"""

View File

@ -189,10 +189,27 @@ class SecurityCodeSearchTool(AgentTool):
)
except Exception as e:
return ToolResult(
success=False,
error=f"安全代码搜索失败: {str(e)}",
)
error_msg = str(e)
# 提供更友好的错误信息
if "401" in error_msg or "Unauthorized" in error_msg:
return ToolResult(
success=False,
error=f"安全代码搜索失败: API 认证失败401 Unauthorized\n"
f"请检查系统配置中的 LLM API Key 是否正确设置。\n"
f"错误详情: {error_msg[:200]}",
)
elif "403" in error_msg or "Forbidden" in error_msg:
return ToolResult(
success=False,
error=f"安全代码搜索失败: API 访问被拒绝403 Forbidden\n"
f"请检查 API Key 是否有足够的权限。\n"
f"错误详情: {error_msg[:200]}",
)
else:
return ToolResult(
success=False,
error=f"安全代码搜索失败: {error_msg[:500]}",
)
class FunctionContextInput(BaseModel):

View File

@ -177,6 +177,85 @@ class LiteLLMAdapter(BaseLLMAdapter):
finish_reason=choice.finish_reason,
)
async def stream_complete(self, request: LLMRequest):
"""
流式调用 LLM token 返回
Yields:
dict: {"type": "token", "content": str} {"type": "done", "content": str, "usage": dict}
"""
import litellm
await self.validate_config()
litellm.cache = None
litellm.drop_params = True
messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
kwargs = {
"model": self._litellm_model,
"messages": messages,
"temperature": request.temperature if request.temperature is not None else self.config.temperature,
"max_tokens": request.max_tokens if request.max_tokens is not None else self.config.max_tokens,
"top_p": request.top_p if request.top_p is not None else self.config.top_p,
"stream": True, # 启用流式输出
}
if self.config.api_key and self.config.api_key != "ollama":
kwargs["api_key"] = self.config.api_key
if self._api_base:
kwargs["api_base"] = self._api_base
kwargs["timeout"] = self.config.timeout
accumulated_content = ""
try:
response = await litellm.acompletion(**kwargs)
async for chunk in response:
if not chunk.choices:
continue
delta = chunk.choices[0].delta
content = getattr(delta, "content", "") or ""
finish_reason = chunk.choices[0].finish_reason
if content:
accumulated_content += content
yield {
"type": "token",
"content": content,
"accumulated": accumulated_content,
}
if finish_reason:
# 流式完成
usage = None
if hasattr(chunk, "usage") and chunk.usage:
usage = {
"prompt_tokens": chunk.usage.prompt_tokens or 0,
"completion_tokens": chunk.usage.completion_tokens or 0,
"total_tokens": chunk.usage.total_tokens or 0,
}
yield {
"type": "done",
"content": accumulated_content,
"usage": usage,
"finish_reason": finish_reason,
}
break
except Exception as e:
yield {
"type": "error",
"error": str(e),
"accumulated": accumulated_content,
}
async def validate_config(self) -> bool:
"""验证配置"""
# Ollama 不需要 API Key

View File

@ -6,7 +6,7 @@ LLM服务 - 代码分析核心服务
import json
import re
import logging
from typing import Dict, Any, Optional
from typing import Dict, Any, Optional, List
from .types import LLMConfig, LLMProvider, LLMMessage, LLMRequest, DEFAULT_MODELS
from .factory import LLMFactory
from app.core.config import settings
@ -36,15 +36,23 @@ class LLMService:
@property
def config(self) -> LLMConfig:
"""获取LLM配置优先使用用户配置然后使用系统配置"""
"""
获取LLM配置
🔥 优先级从高到低
1. 数据库用户配置系统配置页面保存的配置
2. 环境变量配置.env 文件中的配置
如果用户配置中某个字段为空则自动回退到环境变量
"""
if self._config is None:
user_llm_config = self._user_config.get('llmConfig', {})
# 优先使用用户配置的provider否则使用系统配置
# 🔥 Provider 优先级:用户配置 > 环境变量
provider_str = user_llm_config.get('llmProvider') or getattr(settings, 'LLM_PROVIDER', 'openai')
provider = self._parse_provider(provider_str)
# 获取API Key - 优先级:用户配置 > 系统通用配置 > 系统平台专属配置
# 🔥 API Key 优先级:用户配置 > 环境变量通用配置 > 环境变量平台专属配置
api_key = (
user_llm_config.get('llmApiKey') or
getattr(settings, 'LLM_API_KEY', '') or
@ -52,33 +60,33 @@ class LLMService:
self._get_provider_api_key(provider)
)
# 获取Base URL
# 🔥 Base URL 优先级:用户配置 > 环境变量
base_url = (
user_llm_config.get('llmBaseUrl') or
getattr(settings, 'LLM_BASE_URL', None) or
self._get_provider_base_url(provider)
)
# 获取模型
# 🔥 Model 优先级:用户配置 > 环境变量 > 默认模型
model = (
user_llm_config.get('llmModel') or
getattr(settings, 'LLM_MODEL', '') or
DEFAULT_MODELS.get(provider, 'gpt-4o-mini')
)
# 获取超时时间(用户配置是毫秒,系统配置是秒)
# 🔥 Timeout 优先级:用户配置(毫秒) > 环境变量(秒)
timeout_ms = user_llm_config.get('llmTimeout')
if timeout_ms:
# 用户配置是毫秒,转换为秒
timeout = int(timeout_ms / 1000) if timeout_ms > 1000 else int(timeout_ms)
else:
# 系统配置是秒
# 环境变量是秒
timeout = int(getattr(settings, 'LLM_TIMEOUT', 150))
# 获取温度
# 🔥 Temperature 优先级:用户配置 > 环境变量
temperature = user_llm_config.get('llmTemperature') if user_llm_config.get('llmTemperature') is not None else float(getattr(settings, 'LLM_TEMPERATURE', 0.1))
# 获取最大token数
# 🔥 Max Tokens 优先级:用户配置 > 环境变量
max_tokens = user_llm_config.get('llmMaxTokens') or int(getattr(settings, 'LLM_MAX_TOKENS', 4096))
self._config = LLMConfig(
@ -394,6 +402,83 @@ Please analyze the following code:
# 重新抛出异常,让调用者处理
raise
async def chat_completion_raw(
self,
messages: List[Dict[str, str]],
temperature: float = 0.1,
max_tokens: int = 4096,
) -> Dict[str, Any]:
"""
🔥 Agent 使用的原始聊天完成接口兼容旧接口
Args:
messages: 消息列表格式为 [{"role": "user", "content": "..."}]
temperature: 温度参数
max_tokens: 最大token数
Returns:
包含 content usage 的字典
"""
# 转换消息格式
llm_messages = [
LLMMessage(role=msg["role"], content=msg["content"])
for msg in messages
]
request = LLMRequest(
messages=llm_messages,
temperature=temperature,
max_tokens=max_tokens,
)
adapter = LLMFactory.create_adapter(self.config)
response = await adapter.complete(request)
return {
"content": response.content,
"usage": {
"prompt_tokens": response.usage.prompt_tokens if response.usage else 0,
"completion_tokens": response.usage.completion_tokens if response.usage else 0,
"total_tokens": response.usage.total_tokens if response.usage else 0,
},
}
async def chat_completion_stream(
self,
messages: List[Dict[str, str]],
temperature: float = 0.1,
max_tokens: int = 4096,
):
"""
流式聊天完成接口 token 返回
Args:
messages: 消息列表
temperature: 温度参数
max_tokens: 最大token数
Yields:
dict: {"type": "token", "content": str} {"type": "done", ...}
"""
from .adapters.litellm_adapter import LiteLLMAdapter
llm_messages = [
LLMMessage(role=msg["role"], content=msg["content"])
for msg in messages
]
request = LLMRequest(
messages=llm_messages,
temperature=temperature,
max_tokens=max_tokens,
)
# 使用 LiteLLM adapter 进行流式调用
adapter = LiteLLMAdapter(self.config)
async for chunk in adapter.stream_complete(request):
yield chunk
def _parse_json(self, text: str) -> Dict[str, Any]:
"""从LLM响应中解析JSON增强版"""

View File

@ -28,14 +28,12 @@ import {
cancelAgentTask,
} from "@/shared/api/agentTasks";
// 事件类型图标映射 - 🔥 重点展示 LLM 相关事件
// 事件类型图标映射
const eventTypeIcons: Record<string, React.ReactNode> = {
// 🧠 LLM 核心事件 - 最重要!
// LLM 核心事件
llm_start: <Brain className="w-3 h-3 text-purple-400 animate-pulse" />,
llm_thought: <Sparkles className="w-3 h-3 text-purple-300" />,
llm_decision: <Zap className="w-3 h-3 text-yellow-400" />,
llm_action: <Zap className="w-3 h-3 text-orange-400" />,
llm_observation: <Search className="w-3 h-3 text-blue-400" />,
llm_complete: <CheckCircle2 className="w-3 h-3 text-green-400" />,
// 阶段相关
@ -43,7 +41,7 @@ const eventTypeIcons: Record<string, React.ReactNode> = {
phase_complete: <CheckCircle2 className="w-3 h-3 text-green-400" />,
thinking: <Brain className="w-3 h-3 text-purple-400" />,
// 工具相关 - LLM 决定的工具调用
// 工具相关
tool_call: <Wrench className="w-3 h-3 text-yellow-400" />,
tool_result: <CheckCircle2 className="w-3 h-3 text-green-400" />,
tool_error: <XCircle className="w-3 h-3 text-red-400" />,
@ -65,14 +63,12 @@ const eventTypeIcons: Record<string, React.ReactNode> = {
task_cancel: <Square className="w-3 h-3 text-yellow-500" />,
};
// 事件类型颜色映射 - 🔥 LLM 事件突出显示
// 事件类型颜色映射
const eventTypeColors: Record<string, string> = {
// 🧠 LLM 核心事件 - 使用紫色系突出
// LLM 核心事件
llm_start: "text-purple-400 font-semibold",
llm_thought: "text-purple-300 bg-purple-950/30 rounded px-1", // 思考内容特别高亮
llm_decision: "text-yellow-300 font-semibold", // 决策特别突出
llm_action: "text-orange-300 font-medium",
llm_observation: "text-blue-300",
llm_thought: "text-purple-300 bg-purple-950/30 rounded px-1",
llm_decision: "text-yellow-300 font-semibold",
llm_complete: "text-green-400 font-semibold",
// 阶段相关
@ -411,7 +407,7 @@ export default function AgentAuditPage() {
{/* 左侧:执行日志 */}
<div className="flex-1 p-4 flex flex-col min-w-0">
{/* 🧠 LLM 思考过程展示区域 - 核心!展示 LLM 的大脑活动 */}
{/* LLM 思考过程展示区域 */}
{(isThinking || thinking) && showThinking && (
<div className="mb-4 bg-purple-950/40 rounded-lg border-2 border-purple-700/60 overflow-hidden shadow-lg shadow-purple-900/20">
<div
@ -423,8 +419,8 @@ export default function AgentAuditPage() {
<Brain className={`w-5 h-5 ${isThinking ? "animate-pulse" : ""}`} />
</div>
<div>
<span className="uppercase tracking-wider font-semibold">🧠 LLM Thinking</span>
<span className="text-purple-400 ml-2 text-xs">Agent </span>
<span className="uppercase tracking-wider font-semibold">LLM Thinking</span>
<span className="text-purple-400 ml-2 text-xs">Agent </span>
</div>
{isThinking && (
<span className="flex items-center gap-1 text-purple-200 bg-purple-800/50 px-2 py-0.5 rounded-full text-xs">
@ -438,7 +434,7 @@ export default function AgentAuditPage() {
<div className="max-h-52 overflow-y-auto bg-[#1a1025]">
<div className="p-4 text-sm text-purple-100 font-mono whitespace-pre-wrap leading-relaxed">
{thinking || "🤔 正在思考下一步..."}
{thinking || "正在思考下一步..."}
{isThinking && <span className="animate-pulse text-purple-400 text-lg"></span>}
</div>
<div ref={thinkingEndRef} />
@ -446,7 +442,7 @@ export default function AgentAuditPage() {
</div>
)}
{/* 🔧 LLM 工具调用展示区域 - LLM 决定调用的工具 */}
{/* 工具调用展示区域 */}
{toolCalls.length > 0 && showToolDetails && (
<div className="mb-4 bg-yellow-950/30 rounded-lg border-2 border-yellow-700/50 overflow-hidden shadow-lg shadow-yellow-900/10">
<div
@ -458,8 +454,8 @@ export default function AgentAuditPage() {
<Wrench className="w-5 h-5" />
</div>
<div>
<span className="uppercase tracking-wider font-semibold">🔧 LLM Tool Calls</span>
<span className="text-yellow-500 ml-2 text-xs">LLM </span>
<span className="uppercase tracking-wider font-semibold">Tool Calls</span>
<span className="text-yellow-500 ml-2 text-xs"></span>
</div>
<Badge variant="outline" className="text-xs px-2 py-0.5 bg-yellow-900/50 border-yellow-600 text-yellow-300">
{toolCalls.length}
@ -488,9 +484,9 @@ export default function AgentAuditPage() {
<div className="flex items-center gap-3 text-sm text-cyan-400">
<div className="flex items-center gap-2">
<Terminal className="w-4 h-4" />
<span className="uppercase tracking-wider font-semibold">LLM Execution Log</span>
<span className="uppercase tracking-wider font-semibold">Execution Log</span>
</div>
<span className="text-xs text-gray-500">LLM & </span>
<span className="text-xs text-gray-500"></span>
{(isStreaming || isStreamConnected) && (
<span className="flex items-center gap-1.5 text-green-400 bg-green-900/30 px-2 py-0.5 rounded-full text-xs">
<span className="w-2 h-2 bg-green-400 rounded-full animate-pulse" />
@ -708,7 +704,7 @@ function StatusBadge({ status }: { status: string }) {
);
}
// 事件行组件 - 增强 LLM 事件展示
// 事件行组件
function EventLine({ event }: { event: AgentEvent }) {
const icon = eventTypeIcons[event.event_type] || <ChevronRight className="w-3 h-3 text-gray-500" />;
const colorClass = eventTypeColors[event.event_type] || "text-gray-400";
@ -717,19 +713,19 @@ function EventLine({ event }: { event: AgentEvent }) {
? new Date(event.timestamp).toLocaleTimeString("zh-CN", { hour12: false })
: "";
// LLM 思考事件特殊处理 - 展示多行内容
// 特殊事件处理
const isLLMThought = event.event_type === "llm_thought";
const isLLMDecision = event.event_type === "llm_decision";
const isLLMAction = event.event_type === "llm_action";
const isImportantLLMEvent = isLLMThought || isLLMDecision || isLLMAction;
const isToolCall = event.event_type === "tool_call";
const isToolResult = event.event_type === "tool_result";
// LLM 事件背景色
// 背景色
const bgClass = isLLMThought
? "bg-purple-950/40 border-l-2 border-purple-600"
: isLLMDecision
? "bg-yellow-950/30 border-l-2 border-yellow-600"
: isLLMAction
? "bg-orange-950/30 border-l-2 border-orange-600"
: isToolCall || isToolResult
? "bg-gray-900/30"
: "";
return (
@ -738,7 +734,7 @@ function EventLine({ event }: { event: AgentEvent }) {
{timestamp}
</span>
<span className="flex-shrink-0 mt-0.5">{icon}</span>
<span className={`flex-1 text-sm break-all ${isImportantLLMEvent ? "whitespace-pre-wrap" : ""}`}>
<span className={`flex-1 text-sm break-all ${isLLMThought ? "whitespace-pre-wrap" : ""}`}>
{event.message}
{event.tool_duration_ms && (
<span className="text-gray-600 ml-2">({event.tool_duration_ms}ms)</span>