CodeReview/backend/app/models/agent_task.py

444 lines
15 KiB
Python

"""
Agent 审计任务模型
支持 AI Agent 自主漏洞挖掘和验证
"""
import uuid
from datetime import datetime
from typing import Optional, List, TYPE_CHECKING
from sqlalchemy import (
Column, String, Integer, Float, Text, Boolean,
DateTime, ForeignKey, Enum as SQLEnum, JSON
)
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from app.db.base import Base
if TYPE_CHECKING:
from .project import Project
class AgentTaskStatus:
    """String values stored in ``AgentTask.status``."""

    # Waiting and in-progress states
    PENDING = "pending"            # queued, waiting to run
    INITIALIZING = "initializing"  # being initialized
    PLANNING = "planning"          # planning stage
    INDEXING = "indexing"          # indexing stage
    ANALYZING = "analyzing"        # analysis stage
    VERIFYING = "verifying"        # verification stage
    REPORTING = "reporting"        # report generation

    # Finished or suspended states
    COMPLETED = "completed"        # finished
    FAILED = "failed"              # failed
    CANCELLED = "cancelled"        # cancelled
    PAUSED = "paused"              # paused
class AgentTaskPhase:
    """Execution phases of an agent task.

    Values are compared against ``AgentTask.current_phase``; the order below
    matches the order used when estimating progress.
    """

    PLANNING = "planning"
    INDEXING = "indexing"
    RECONNAISSANCE = "reconnaissance"
    ANALYSIS = "analysis"
    VERIFICATION = "verification"
    REPORTING = "reporting"
class AgentTask(Base):
    """An agent-driven audit task (row in ``agent_tasks``)."""

    __tablename__ = "agent_tasks"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False)

    # Basic task information
    name = Column(String(255), nullable=True)
    description = Column(Text, nullable=True)
    task_type = Column(String(50), default="agent_audit")

    # Task configuration
    audit_scope = Column(JSON, nullable=True)  # audit scope configuration
    target_vulnerabilities = Column(JSON, nullable=True)  # vulnerability types to target
    verification_level = Column(String(50), default="sandbox")  # analysis_only, sandbox, generate_poc

    # Branch information (repository projects)
    branch_name = Column(String(255), nullable=True)

    # Exclusion patterns
    exclude_patterns = Column(JSON, nullable=True)

    # File scope
    target_files = Column(JSON, nullable=True)  # explicit list of files to scan

    # LLM configuration
    llm_config = Column(JSON, nullable=True)

    # Agent configuration
    agent_config = Column(JSON, nullable=True)  # agent-specific configuration
    max_iterations = Column(Integer, default=50)  # maximum number of iterations
    token_budget = Column(Integer, default=100000)  # token budget
    timeout_seconds = Column(Integer, default=1800)  # timeout in seconds

    # Status
    status = Column(String(20), default=AgentTaskStatus.PENDING)
    current_phase = Column(String(50), nullable=True)  # compared against AgentTaskPhase values
    current_step = Column(String(255), nullable=True)  # description of the step being executed
    error_message = Column(Text, nullable=True)

    # Progress counters
    total_files = Column(Integer, default=0)
    indexed_files = Column(Integer, default=0)
    analyzed_files = Column(Integer, default=0)
    total_chunks = Column(Integer, default=0)  # total number of code chunks

    # Agent statistics
    total_iterations = Column(Integer, default=0)  # agent iterations performed
    tool_calls_count = Column(Integer, default=0)  # number of tool calls
    tokens_used = Column(Integer, default=0)  # tokens consumed so far

    # Finding statistics
    findings_count = Column(Integer, default=0)  # total findings
    verified_count = Column(Integer, default=0)  # verified findings
    false_positive_count = Column(Integer, default=0)  # false positives

    # Severity breakdown
    critical_count = Column(Integer, default=0)
    high_count = Column(Integer, default=0)
    medium_count = Column(Integer, default=0)
    low_count = Column(Integer, default=0)

    # Quality scores
    quality_score = Column(Float, default=0.0)
    security_score = Column(Float, default=0.0)

    # Audit plan
    audit_plan = Column(JSON, nullable=True)  # audit plan produced by the agent

    # Timestamps
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())
    started_at = Column(DateTime(timezone=True), nullable=True)
    completed_at = Column(DateTime(timezone=True), nullable=True)

    # Creator
    created_by = Column(String(36), ForeignKey("users.id"), nullable=False)

    # Relationships
    project = relationship("Project", back_populates="agent_tasks")
    events = relationship("AgentEvent", back_populates="task", cascade="all, delete-orphan", order_by="AgentEvent.created_at")
    findings = relationship("AgentFinding", back_populates="task", cascade="all, delete-orphan")

    def __repr__(self):
        return f"<AgentTask {self.id} - {self.status}>"

    @property
    def progress_percentage(self) -> float:
        """Estimate overall progress as a percentage.

        Returns:
            100.0 when the task is COMPLETED; 0.0 when it FAILED or was
            CANCELLED, or when ``current_phase`` is not a known phase
            (e.g. ``None`` before the first phase starts). Otherwise every
            phase preceding ``current_phase`` counts fully and the current
            phase counts partially: INDEXING / ANALYSIS by the fraction of
            ``total_files`` indexed / analyzed (when ``total_files`` > 0),
            any other phase as half done. Capped at 99.0 so an unfinished
            task never reports 100.
        """
        if self.status == AgentTaskStatus.COMPLETED:
            return 100.0
        if self.status in (AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED):
            return 0.0

        # Relative weight of each phase, listed in execution order (sums to 100).
        phase_weights = {
            AgentTaskPhase.PLANNING: 5,
            AgentTaskPhase.INDEXING: 15,
            AgentTaskPhase.RECONNAISSANCE: 10,
            AgentTaskPhase.ANALYSIS: 50,
            AgentTaskPhase.VERIFICATION: 15,
            AgentTaskPhase.REPORTING: 5,
        }

        # No known phase yet: without this guard the loop below would add
        # every weight and report 99% for a task that has not started.
        if self.current_phase not in phase_weights:
            return 0.0

        # Column defaults are applied at INSERT time, so the counters can
        # still be None on an object that has not been flushed.
        total = self.total_files or 0
        done_by_phase = {
            AgentTaskPhase.INDEXING: self.indexed_files or 0,
            AgentTaskPhase.ANALYSIS: self.analyzed_files or 0,
        }

        completed_weight = 0.0
        for phase, weight in phase_weights.items():
            if phase != self.current_phase:
                # Phases before the current one are considered done.
                completed_weight += weight
                continue
            if phase in done_by_phase and total > 0:
                # Clamp in case a counter overshoots total_files.
                fraction = min(done_by_phase[phase] / total, 1.0)
            else:
                fraction = 0.5
            completed_weight += weight * fraction
            break

        return min(completed_weight, 99.0)
class AgentEventType:
    """String identifiers for ``AgentEvent.event_type``."""

    # Task lifecycle
    TASK_START = "task_start"
    TASK_COMPLETE = "task_complete"
    TASK_ERROR = "task_error"
    TASK_CANCEL = "task_cancel"

    # Phase boundaries
    PHASE_START = "phase_start"
    PHASE_COMPLETE = "phase_complete"

    # Agent reasoning
    THINKING = "thinking"
    PLANNING = "planning"
    DECISION = "decision"

    # Tool invocations
    TOOL_CALL = "tool_call"
    TOOL_RESULT = "tool_result"
    TOOL_ERROR = "tool_error"

    # Retrieval (RAG)
    RAG_QUERY = "rag_query"
    RAG_RESULT = "rag_result"

    # Findings
    FINDING_NEW = "finding_new"
    FINDING_UPDATE = "finding_update"
    FINDING_VERIFIED = "finding_verified"
    FINDING_FALSE_POSITIVE = "finding_false_positive"

    # Sandbox
    SANDBOX_START = "sandbox_start"
    SANDBOX_EXEC = "sandbox_exec"
    SANDBOX_RESULT = "sandbox_result"
    SANDBOX_ERROR = "sandbox_error"

    # Progress
    PROGRESS = "progress"

    # Log levels
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    DEBUG = "debug"
class AgentEvent(Base):
    """One event recorded during an agent task run.

    Rows back the real-time log stream and allow a run to be replayed.
    """

    __tablename__ = "agent_events"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    task_id = Column(String(36), ForeignKey("agent_tasks.id", ondelete="CASCADE"), nullable=False, index=True)

    # Event classification
    event_type = Column(String(50), nullable=False, index=True)
    phase = Column(String(50), nullable=True)

    # Human-readable content
    message = Column(Text, nullable=True)

    # Tool invocation details
    tool_name = Column(String(100), nullable=True)
    tool_input = Column(JSON, nullable=True)
    tool_output = Column(JSON, nullable=True)
    tool_duration_ms = Column(Integer, nullable=True)  # tool run time in milliseconds

    # Related finding (plain id, no foreign-key constraint)
    finding_id = Column(String(36), nullable=True)

    # Tokens consumed by this event
    tokens_used = Column(Integer, default=0)

    # Free-form metadata
    event_metadata = Column(JSON, nullable=True)

    # Ordering sequence number
    sequence = Column(Integer, default=0, index=True)

    # Timestamp
    created_at = Column(DateTime(timezone=True), server_default=func.now(), index=True)

    # Relationships
    task = relationship("AgentTask", back_populates="events")

    def __repr__(self):
        preview = self.message[:50] if self.message else ''
        return f"<AgentEvent {self.event_type} - {preview}>"

    def to_sse_dict(self) -> dict:
        """Return this event as a plain dict for Server-Sent Events.

        ``event_metadata`` is exposed under the ``metadata`` key and
        ``created_at`` as an ISO-8601 ``timestamp`` (``None`` if unset).
        """
        created = self.created_at
        timestamp = created.isoformat() if created else None
        return dict(
            id=self.id,
            type=self.event_type,
            phase=self.phase,
            message=self.message,
            tool_name=self.tool_name,
            tool_input=self.tool_input,
            tool_output=self.tool_output,
            tool_duration_ms=self.tool_duration_ms,
            finding_id=self.finding_id,
            tokens_used=self.tokens_used,
            metadata=self.event_metadata,
            sequence=self.sequence,
            timestamp=timestamp,
        )
class VulnerabilitySeverity:
    """Severity levels used for ``AgentFinding.severity``, highest first."""

    CRITICAL = "critical"
    HIGH = "high"
    MEDIUM = "medium"
    LOW = "low"
    INFO = "info"
class VulnerabilityType:
    """Vulnerability category identifiers for ``AgentFinding.vulnerability_type``."""

    # Injection
    SQL_INJECTION = "sql_injection"
    NOSQL_INJECTION = "nosql_injection"
    XSS = "xss"
    COMMAND_INJECTION = "command_injection"
    CODE_INJECTION = "code_injection"

    # File / request handling
    PATH_TRAVERSAL = "path_traversal"
    FILE_INCLUSION = "file_inclusion"
    SSRF = "ssrf"
    XXE = "xxe"
    DESERIALIZATION = "deserialization"

    # Access control
    AUTH_BYPASS = "auth_bypass"
    IDOR = "idor"

    # Data protection and cryptography
    SENSITIVE_DATA_EXPOSURE = "sensitive_data_exposure"
    HARDCODED_SECRET = "hardcoded_secret"
    WEAK_CRYPTO = "weak_crypto"

    # Logic and runtime
    RACE_CONDITION = "race_condition"
    BUSINESS_LOGIC = "business_logic"
    MEMORY_CORRUPTION = "memory_corruption"

    # Fallback
    OTHER = "other"
class FindingStatus:
    """Review states of a finding (stored in ``AgentFinding.status``)."""

    NEW = "new"                        # just discovered
    ANALYZING = "analyzing"            # under analysis
    VERIFIED = "verified"              # confirmed
    FALSE_POSITIVE = "false_positive"  # not a real issue
    NEEDS_REVIEW = "needs_review"      # requires manual review
    FIXED = "fixed"                    # has been fixed
    WONT_FIX = "wont_fix"              # will not be fixed
    DUPLICATE = "duplicate"            # duplicate of another finding
class AgentFinding(Base):
    """A vulnerability reported by the agent (row in ``agent_findings``)."""

    __tablename__ = "agent_findings"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    task_id = Column(String(36), ForeignKey("agent_tasks.id", ondelete="CASCADE"), nullable=False, index=True)

    # Core vulnerability information
    vulnerability_type = Column(String(100), nullable=False, index=True)
    severity = Column(String(20), nullable=False, index=True)
    title = Column(String(500), nullable=False)
    description = Column(Text, nullable=True)

    # Location in the source tree
    file_path = Column(String(500), nullable=True, index=True)
    line_start = Column(Integer, nullable=True)
    line_end = Column(Integer, nullable=True)
    column_start = Column(Integer, nullable=True)
    column_end = Column(Integer, nullable=True)
    function_name = Column(String(255), nullable=True)
    class_name = Column(String(255), nullable=True)

    # Code excerpts
    code_snippet = Column(Text, nullable=True)
    code_context = Column(Text, nullable=True)  # wider surrounding context

    # Data-flow information
    source = Column(Text, nullable=True)  # taint source
    sink = Column(Text, nullable=True)  # dangerous sink
    dataflow_path = Column(JSON, nullable=True)  # data-flow path

    # Verification
    status = Column(String(30), default=FindingStatus.NEW, index=True)
    is_verified = Column(Boolean, default=False)
    verification_method = Column(Text, nullable=True)
    verification_result = Column(JSON, nullable=True)
    verified_at = Column(DateTime(timezone=True), nullable=True)

    # Proof of concept
    has_poc = Column(Boolean, default=False)
    poc_code = Column(Text, nullable=True)
    poc_description = Column(Text, nullable=True)
    poc_steps = Column(JSON, nullable=True)  # reproduction steps

    # Remediation
    suggestion = Column(Text, nullable=True)
    fix_code = Column(Text, nullable=True)
    fix_description = Column(Text, nullable=True)
    references = Column(JSON, nullable=True)  # reference links (CWE, OWASP, ...)

    # AI explanation
    ai_explanation = Column(Text, nullable=True)
    ai_confidence = Column(Float, nullable=True)  # AI confidence, 0-1

    # Explainable-AI breakdown
    xai_what = Column(Text, nullable=True)
    xai_why = Column(Text, nullable=True)
    xai_how = Column(Text, nullable=True)
    xai_impact = Column(Text, nullable=True)

    # Matched rule
    matched_rule_code = Column(String(100), nullable=True)
    matched_pattern = Column(Text, nullable=True)

    # Optional CVSS scoring
    cvss_score = Column(Float, nullable=True)
    cvss_vector = Column(String(100), nullable=True)

    # Metadata
    finding_metadata = Column(JSON, nullable=True)
    tags = Column(JSON, nullable=True)

    # De-duplication key
    fingerprint = Column(String(64), nullable=True, index=True)  # fingerprint used for de-duplication

    # Timestamps
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    updated_at = Column(DateTime(timezone=True), onupdate=func.now())

    # Relationships
    task = relationship("AgentTask", back_populates="findings")

    def __repr__(self):
        return f"<AgentFinding {self.vulnerability_type} - {self.severity} - {self.file_path}>"

    def generate_fingerprint(self) -> str:
        """Return a 16-hex-char de-duplication fingerprint.

        The key joins type, file path, start line, function name and the
        first 200 characters of the snippet with ``|``. It is then hashed
        with SHA-256, and the hex digest is truncated. Missing values
        contribute an empty string (or ``0`` for the line).
        """
        import hashlib

        snippet_head = (self.code_snippet or "")[:200]
        key = "|".join((
            self.vulnerability_type or "",
            self.file_path or "",
            str(self.line_start or 0),
            self.function_name or "",
            snippet_head,
        ))
        digest = hashlib.sha256(key.encode()).hexdigest()
        return digest[:16]

    def to_dict(self) -> dict:
        """Return a summary dict of the finding; ``created_at`` is ISO-8601 or ``None``."""
        created = self.created_at
        return dict(
            id=self.id,
            task_id=self.task_id,
            vulnerability_type=self.vulnerability_type,
            severity=self.severity,
            title=self.title,
            description=self.description,
            file_path=self.file_path,
            line_start=self.line_start,
            line_end=self.line_end,
            code_snippet=self.code_snippet,
            status=self.status,
            is_verified=self.is_verified,
            has_poc=self.has_poc,
            poc_code=self.poc_code,
            suggestion=self.suggestion,
            fix_code=self.fix_code,
            ai_explanation=self.ai_explanation,
            ai_confidence=self.ai_confidence,
            created_at=created.isoformat() if created else None,
        )