diff --git a/backend/alembic/versions/007_add_agent_checkpoint_tables.py b/backend/alembic/versions/007_add_agent_checkpoint_tables.py new file mode 100644 index 0000000..b2d39ed --- /dev/null +++ b/backend/alembic/versions/007_add_agent_checkpoint_tables.py @@ -0,0 +1,68 @@ +"""Add agent checkpoint and tree node tables + +Revision ID: 007_add_agent_checkpoint_tables +Revises: 006_add_agent_tables +Create Date: 2024-12-12 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '007_add_agent_checkpoint_tables' +down_revision = '006_add_agent_tables' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Create agent_checkpoints table + op.create_table( + 'agent_checkpoints', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('task_id', sa.String(36), sa.ForeignKey('agent_tasks.id', ondelete='CASCADE'), nullable=False, index=True), + sa.Column('agent_id', sa.String(50), nullable=False, index=True), + sa.Column('agent_name', sa.String(255), nullable=False), + sa.Column('agent_type', sa.String(50), nullable=False), + sa.Column('parent_agent_id', sa.String(50), nullable=True), + sa.Column('state_data', sa.Text, nullable=False), + sa.Column('iteration', sa.Integer, default=0), + sa.Column('status', sa.String(30), nullable=False), + sa.Column('total_tokens', sa.Integer, default=0), + sa.Column('tool_calls', sa.Integer, default=0), + sa.Column('findings_count', sa.Integer, default=0), + sa.Column('checkpoint_type', sa.String(30), default='auto'), + sa.Column('checkpoint_name', sa.String(255), nullable=True), + sa.Column('checkpoint_metadata', sa.JSON, nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now(), index=True), + ) + + # Create agent_tree_nodes table + op.create_table( + 'agent_tree_nodes', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('task_id', sa.String(36), sa.ForeignKey('agent_tasks.id', ondelete='CASCADE'), nullable=False, index=True), + sa.Column('agent_id', sa.String(50), nullable=False, unique=True, index=True), + sa.Column('agent_name', sa.String(255), nullable=False), + sa.Column('agent_type', sa.String(50), nullable=False), + sa.Column('parent_agent_id', sa.String(50), nullable=True, index=True), + sa.Column('depth', sa.Integer, default=0), + sa.Column('task_description', sa.Text, nullable=True), + sa.Column('knowledge_modules', sa.JSON, nullable=True), + sa.Column('status', sa.String(30), default='created'), + sa.Column('result_summary', sa.Text, nullable=True), + sa.Column('findings_count', sa.Integer, default=0), + sa.Column('iterations', sa.Integer, default=0), + sa.Column('tokens_used', sa.Integer, default=0), + sa.Column('tool_calls', sa.Integer, default=0), + sa.Column('duration_ms', sa.Integer, nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column('started_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('finished_at', sa.DateTime(timezone=True), nullable=True), + ) + + +def downgrade() -> None: + op.drop_table('agent_tree_nodes') + op.drop_table('agent_checkpoints') diff --git a/backend/alembic/versions/4c280754c680_merge_heads.py b/backend/alembic/versions/4c280754c680_merge_heads.py new file mode 100644 index 0000000..268d25e --- /dev/null +++ b/backend/alembic/versions/4c280754c680_merge_heads.py @@ -0,0 +1,29 @@ +"""merge_heads + +Revision ID: 4c280754c680 +Revises: 004_add_prompts_and_rules, 007_add_agent_checkpoint_tables +Create Date: 2025-12-12 12:07:42.238185 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '4c280754c680' +down_revision = ('004_add_prompts_and_rules', '007_add_agent_checkpoint_tables') +branch_labels = None +depends_on = None + + +def upgrade() -> None: + pass + + +def downgrade() -> None: + pass + + + + + diff --git a/backend/app/api/v1/endpoints/agent_tasks.py b/backend/app/api/v1/endpoints/agent_tasks.py index 79fbb33..f3316ae 100644 --- a/backend/app/api/v1/endpoints/agent_tasks.py +++ b/backend/app/api/v1/endpoints/agent_tasks.py @@ -29,14 +29,17 @@ from app.models.agent_task import ( from app.models.project import Project from app.models.user import User from app.models.user_config import UserConfig -from app.services.agent import AgentRunner, EventManager, run_agent_task +from app.services.agent.event_manager import EventManager from app.services.agent.streaming import StreamHandler, StreamEvent, StreamEventType logger = logging.getLogger(__name__) router = APIRouter() -# 运行中的任务 -_running_tasks: Dict[str, AgentRunner] = {} +# 运行中的任务(兼容旧接口) +_running_tasks: Dict[str, Any] = {} + +# 🔥 运行中的 asyncio Tasks(用于强制取消) +_running_asyncio_tasks: Dict[str, asyncio.Task] = {} # ============ Schemas ============ @@ -71,7 +74,7 @@ class AgentTaskCreate(BaseModel): target_files: Optional[List[str]] = Field(None, description="指定扫描的文件") # Agent 配置 - max_iterations: int = Field(3, ge=1, le=10, description="最大分析迭代次数") + max_iterations: int = Field(50, ge=1, le=200, description="最大迭代次数") timeout_seconds: int = Field(1800, ge=60, le=7200, description="超时时间(秒)") @@ -200,9 +203,29 @@ class TaskSummaryResponse(BaseModel): # ============ 后台任务执行 ============ +# 运行中的动态执行器 +_running_orchestrators: Dict[str, Any] = {} +# 运行中的事件管理器(用于 SSE 流) +_running_event_managers: Dict[str, EventManager] = {} + + async def _execute_agent_task(task_id: str): - """在后台执行 Agent 任务""" + """ + 在后台执行 Agent 任务 - 使用动态 Agent 树架构 + + 架构:OrchestratorAgent 作为大脑,动态调度子 Agent + """ + from app.services.agent.agents import OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent + from app.services.agent.event_manager import EventManager, AgentEventEmitter + from app.services.llm.service import LLMService + from app.services.agent.core import agent_registry + from app.core.config import settings + import time + async with async_session_factory() as db: + orchestrator = None + start_time = time.time() + try: # 获取任务 task = await db.get(AgentTask, task_id, options=[selectinload(AgentTask.project)]) @@ -216,78 +239,203 @@ async def _execute_agent_task(task_id: str): logger.error(f"Project not found for task {task_id}") return - # 🔥 获取项目根目录(解压 ZIP 或克隆仓库) + # 获取项目根目录 project_root = await _get_project_root(project, task_id) - # 🔥 获取用户配置(从系统配置页面) - # 优先级:1. 数据库用户配置 > 2. 环境变量配置 - user_config = None - if task.created_by: - from app.api.v1.endpoints.config import ( - decrypt_config, - SENSITIVE_LLM_FIELDS, SENSITIVE_OTHER_FIELDS - ) - import json - - result = await db.execute( - select(UserConfig).where(UserConfig.user_id == task.created_by) - ) - config = result.scalar_one_or_none() - - if config and config.llm_config: - # 🔥 有数据库配置:使用数据库配置(优先) - user_llm_config = json.loads(config.llm_config) if config.llm_config else {} - user_other_config = json.loads(config.other_config) if config.other_config else {} - - # 解密敏感字段 - user_llm_config = decrypt_config(user_llm_config, SENSITIVE_LLM_FIELDS) - user_other_config = decrypt_config(user_other_config, SENSITIVE_OTHER_FIELDS) - - user_config = { - "llmConfig": user_llm_config, # 直接使用数据库配置,不合并默认值 - "otherConfig": user_other_config, - } - logger.info(f"✅ Using database user config for task {task_id}, LLM provider: {user_llm_config.get('llmProvider', 'N/A')}") - else: - # 🔥 无数据库配置:传递 None,让 LLMService 使用环境变量 - user_config = None - logger.info(f"⚠️ No database config found for user {task.created_by}, will use environment variables for task {task_id}") + # 获取用户配置 + user_config = await _get_user_config(db, task.created_by) # 更新状态为运行中 task.status = AgentTaskStatus.RUNNING task.started_at = datetime.now(timezone.utc) + task.current_phase = AgentTaskPhase.PLANNING await db.commit() - logger.info(f"Task {task_id} started") + logger.info(f"🚀 Task {task_id} started with Dynamic Agent Tree architecture") - # 创建 Runner(传入用户配置) - runner = AgentRunner(db, task, project_root, user_config=user_config) - _running_tasks[task_id] = runner + # 创建事件管理器 + event_manager = EventManager(db_session_factory=async_session_factory) + event_manager.create_queue(task_id) + event_emitter = AgentEventEmitter(task_id, event_manager) - # 执行 - result = await runner.run() + # 创建 LLM 服务 + llm_service = LLMService(user_config=user_config) + + # 初始化工具集 - 传递排除模式和目标文件 + tools = await _initialize_tools( + project_root, + llm_service, + user_config, + exclude_patterns=task.exclude_patterns, + target_files=task.target_files, + ) + + # 创建子 Agent + recon_agent = ReconAgent( + llm_service=llm_service, + tools=tools.get("recon", {}), + event_emitter=event_emitter, + ) + + analysis_agent = AnalysisAgent( + llm_service=llm_service, + tools=tools.get("analysis", {}), + event_emitter=event_emitter, + ) + + verification_agent = VerificationAgent( + llm_service=llm_service, + tools=tools.get("verification", {}), + event_emitter=event_emitter, + ) + + # 创建 Orchestrator Agent + orchestrator = OrchestratorAgent( + llm_service=llm_service, + tools=tools.get("orchestrator", {}), + event_emitter=event_emitter, + sub_agents={ + "recon": recon_agent, + "analysis": analysis_agent, + "verification": verification_agent, + }, + ) + + # 注册到全局 + _running_orchestrators[task_id] = orchestrator + _running_tasks[task_id] = orchestrator # 兼容旧的取消逻辑 + _running_event_managers[task_id] = event_manager # 用于 SSE 流 + + # 🔥 清理旧的 Agent 注册表,避免显示多个树 + from app.services.agent.core import agent_registry + agent_registry.clear() + + # 注册 Orchestrator 到 Agent Registry(使用其内置方法) + orchestrator._register_to_registry(task="Root orchestrator for security audit") + + await event_emitter.emit_info("🧠 动态 Agent 树架构启动") + await event_emitter.emit_info(f"📁 项目路径: {project_root}") + + # 收集项目信息 - 传递排除模式和目标文件 + project_info = await _collect_project_info( + project_root, + project.name, + exclude_patterns=task.exclude_patterns, + target_files=task.target_files, + ) + + # 更新任务文件统计 + task.total_files = project_info.get("file_count", 0) + await db.commit() + + # 构建输入数据 + input_data = { + "project_info": project_info, + "config": { + "target_vulnerabilities": task.target_vulnerabilities or [], + "verification_level": task.verification_level or "sandbox", + "exclude_patterns": task.exclude_patterns or [], + "target_files": task.target_files or [], + "max_iterations": task.max_iterations or 50, + }, + "project_root": project_root, + "task_id": task_id, + } + + # 执行 Orchestrator + await event_emitter.emit_phase_start("orchestration", "🎯 Orchestrator 开始编排审计流程") + task.current_phase = AgentTaskPhase.ANALYSIS + await db.commit() + + # 🔥 将 orchestrator.run() 包装在 asyncio.Task 中,以便可以强制取消 + run_task = asyncio.create_task(orchestrator.run(input_data)) + _running_asyncio_tasks[task_id] = run_task + + try: + result = await run_task + finally: + _running_asyncio_tasks.pop(task_id, None) + + # 处理结果 + duration_ms = int((time.time() - start_time) * 1000) - # 更新任务状态 await db.refresh(task) - if result.get('success', True): # 默认成功,除非明确失败 + + if result.success: + # 保存发现 + findings = result.data.get("findings", []) + await _save_findings(db, task_id, findings) + + # 更新任务统计 task.status = AgentTaskStatus.COMPLETED task.completed_at = datetime.now(timezone.utc) + task.current_phase = AgentTaskPhase.COMPLETED + task.findings_count = len(findings) + task.total_iterations = result.iterations + task.tool_calls_count = result.tool_calls + task.tokens_used = result.tokens_used + + # 统计严重程度 + for f in findings: + if isinstance(f, dict): + sev = f.get("severity", "low") + if sev == "critical": + task.critical_count += 1 + elif sev == "high": + task.high_count += 1 + elif sev == "medium": + task.medium_count += 1 + elif sev == "low": + task.low_count += 1 + + # 计算安全评分 + task.security_score = _calculate_security_score(findings) + task.progress_percentage = 100.0 + + await db.commit() + + await event_emitter.emit_task_complete( + findings_count=len(findings), + duration_ms=duration_ms, + ) + + logger.info(f"✅ Task {task_id} completed: {len(findings)} findings, {duration_ms}ms") else: - task.status = AgentTaskStatus.FAILED - task.error_message = result.get('error', 'Unknown error') - task.completed_at = datetime.now(timezone.utc) - - await db.commit() - logger.info(f"Task {task_id} completed with status: {task.status}") + # 🔥 检查是否是取消导致的失败 + if result.error == "任务已取消": + # 状态可能已经被 cancel API 更新,只需确保一致性 + if task.status != AgentTaskStatus.CANCELLED: + task.status = AgentTaskStatus.CANCELLED + task.completed_at = datetime.now(timezone.utc) + await db.commit() + logger.info(f"🛑 Task {task_id} cancelled") + else: + task.status = AgentTaskStatus.FAILED + task.error_message = result.error or "Unknown error" + task.completed_at = datetime.now(timezone.utc) + await db.commit() + + await event_emitter.emit_error(result.error or "Unknown error") + logger.error(f"❌ Task {task_id} failed: {result.error}") + except asyncio.CancelledError: + logger.info(f"Task {task_id} cancelled") + try: + task = await db.get(AgentTask, task_id) + if task: + task.status = AgentTaskStatus.CANCELLED + task.completed_at = datetime.now(timezone.utc) + await db.commit() + except Exception: + pass + except Exception as e: logger.error(f"Task {task_id} failed: {e}", exc_info=True) - # 更新任务状态 try: task = await db.get(AgentTask, task_id) if task: task.status = AgentTaskStatus.FAILED - task.error_message = str(e)[:1000] # 限制错误消息长度 + task.error_message = str(e)[:1000] task.completed_at = datetime.now(timezone.utc) await db.commit() except Exception as db_error: @@ -295,10 +443,308 @@ async def _execute_agent_task(task_id: str): finally: # 清理 + _running_orchestrators.pop(task_id, None) _running_tasks.pop(task_id, None) + _running_event_managers.pop(task_id, None) + _running_asyncio_tasks.pop(task_id, None) # 🔥 清理 asyncio task + + # 从 Registry 注销 + if orchestrator: + agent_registry.unregister_agent(orchestrator.agent_id) + logger.debug(f"Task {task_id} cleaned up") +async def _get_user_config(db: AsyncSession, user_id: Optional[str]) -> Optional[Dict[str, Any]]: + """获取用户配置""" + if not user_id: + return None + + try: + from app.api.v1.endpoints.config import ( + decrypt_config, + SENSITIVE_LLM_FIELDS, SENSITIVE_OTHER_FIELDS + ) + + result = await db.execute( + select(UserConfig).where(UserConfig.user_id == user_id) + ) + config = result.scalar_one_or_none() + + if config and config.llm_config: + user_llm_config = json.loads(config.llm_config) if config.llm_config else {} + user_other_config = json.loads(config.other_config) if config.other_config else {} + + user_llm_config = decrypt_config(user_llm_config, SENSITIVE_LLM_FIELDS) + user_other_config = decrypt_config(user_other_config, SENSITIVE_OTHER_FIELDS) + + return { + "llmConfig": user_llm_config, + "otherConfig": user_other_config, + } + except Exception as e: + logger.warning(f"Failed to get user config: {e}") + + return None + + +async def _initialize_tools( + project_root: str, + llm_service, + user_config: Optional[Dict[str, Any]], + exclude_patterns: Optional[List[str]] = None, + target_files: Optional[List[str]] = None, +) -> Dict[str, Dict[str, Any]]: + """初始化工具集 + + Args: + project_root: 项目根目录 + llm_service: LLM 服务 + user_config: 用户配置 + exclude_patterns: 排除模式列表 + target_files: 目标文件列表 + """ + from app.services.agent.tools import ( + FileReadTool, FileSearchTool, ListFilesTool, + PatternMatchTool, CodeAnalysisTool, DataFlowAnalysisTool, + SemgrepTool, BanditTool, GitleaksTool, + ThinkTool, ReflectTool, + CreateVulnerabilityReportTool, + VulnerabilityValidationTool, + ) + from app.services.agent.knowledge import ( + SecurityKnowledgeQueryTool, + GetVulnerabilityKnowledgeTool, + ) + + # 基础工具 - 传递排除模式和目标文件 + base_tools = { + "read_file": FileReadTool(project_root, exclude_patterns, target_files), + "list_files": ListFilesTool(project_root, exclude_patterns, target_files), + "search_code": FileSearchTool(project_root, exclude_patterns, target_files), + "think": ThinkTool(), + "reflect": ReflectTool(), + } + + # Recon 工具 + recon_tools = { + **base_tools, + } + + # Analysis 工具 + analysis_tools = { + **base_tools, + "pattern_match": PatternMatchTool(project_root), + # TODO: code_analysis 工具暂时禁用,因为 LLM 调用经常失败 + # "code_analysis": CodeAnalysisTool(llm_service), + "dataflow_analysis": DataFlowAnalysisTool(llm_service), + "semgrep_scan": SemgrepTool(project_root), + "bandit_scan": BanditTool(project_root), + "gitleaks_scan": GitleaksTool(project_root), + "query_security_knowledge": SecurityKnowledgeQueryTool(), + "get_vulnerability_knowledge": GetVulnerabilityKnowledgeTool(), + } + + # Verification 工具 + verification_tools = { + **base_tools, + "vulnerability_validation": VulnerabilityValidationTool(llm_service), + "dataflow_analysis": DataFlowAnalysisTool(llm_service), + "create_vulnerability_report": CreateVulnerabilityReportTool(), + } + + # Orchestrator 工具(主要是思考工具) + orchestrator_tools = { + "think": ThinkTool(), + "reflect": ReflectTool(), + } + + return { + "recon": recon_tools, + "analysis": analysis_tools, + "verification": verification_tools, + "orchestrator": orchestrator_tools, + } + + +async def _collect_project_info( + project_root: str, + project_name: str, + exclude_patterns: Optional[List[str]] = None, + target_files: Optional[List[str]] = None, +) -> Dict[str, Any]: + """收集项目信息 + + Args: + project_root: 项目根目录 + project_name: 项目名称 + exclude_patterns: 排除模式列表 + target_files: 目标文件列表 + """ + import fnmatch + + info = { + "name": project_name, + "root": project_root, + "languages": [], + "file_count": 0, + "structure": {}, + } + + try: + # 默认排除目录 + exclude_dirs = { + "node_modules", "__pycache__", ".git", "venv", ".venv", + "build", "dist", "target", ".idea", ".vscode", + } + + # 从用户配置的排除模式中提取目录 + if exclude_patterns: + for pattern in exclude_patterns: + if pattern.endswith("/**"): + exclude_dirs.add(pattern[:-3]) + elif "/" not in pattern and "*" not in pattern: + exclude_dirs.add(pattern) + + # 目标文件集合 + target_files_set = set(target_files) if target_files else None + + lang_map = { + ".py": "Python", ".js": "JavaScript", ".ts": "TypeScript", + ".java": "Java", ".go": "Go", ".php": "PHP", + ".rb": "Ruby", ".rs": "Rust", ".c": "C", ".cpp": "C++", + } + + for root, dirs, files in os.walk(project_root): + dirs[:] = [d for d in dirs if d not in exclude_dirs] + + for f in files: + relative_path = os.path.relpath(os.path.join(root, f), project_root) + + # 检查是否在目标文件列表中 + if target_files_set and relative_path not in target_files_set: + continue + + # 检查排除模式 + should_skip = False + if exclude_patterns: + for pattern in exclude_patterns: + if fnmatch.fnmatch(relative_path, pattern) or fnmatch.fnmatch(f, pattern): + should_skip = True + break + if should_skip: + continue + + info["file_count"] += 1 + + ext = os.path.splitext(f)[1].lower() + if ext in lang_map and lang_map[ext] not in info["languages"]: + info["languages"].append(lang_map[ext]) + + # 收集顶层目录结构 + try: + top_items = os.listdir(project_root) + info["structure"] = { + "directories": [d for d in top_items if os.path.isdir(os.path.join(project_root, d)) and d not in exclude_dirs], + "files": [f for f in top_items if os.path.isfile(os.path.join(project_root, f))][:20], + } + except Exception: + pass + + except Exception as e: + logger.warning(f"Failed to collect project info: {e}") + + return info + + +async def _save_findings(db: AsyncSession, task_id: str, findings: List[Dict]) -> None: + """保存发现到数据库""" + from app.models.agent_task import VulnerabilityType + + severity_map = { + "critical": VulnerabilitySeverity.CRITICAL, + "high": VulnerabilitySeverity.HIGH, + "medium": VulnerabilitySeverity.MEDIUM, + "low": VulnerabilitySeverity.LOW, + "info": VulnerabilitySeverity.INFO, + } + + type_map = { + "sql_injection": VulnerabilityType.SQL_INJECTION, + "nosql_injection": VulnerabilityType.NOSQL_INJECTION, + "xss": VulnerabilityType.XSS, + "command_injection": VulnerabilityType.COMMAND_INJECTION, + "code_injection": VulnerabilityType.CODE_INJECTION, + "path_traversal": VulnerabilityType.PATH_TRAVERSAL, + "ssrf": VulnerabilityType.SSRF, + "xxe": VulnerabilityType.XXE, + "auth_bypass": VulnerabilityType.AUTH_BYPASS, + "idor": VulnerabilityType.IDOR, + "sensitive_data_exposure": VulnerabilityType.SENSITIVE_DATA_EXPOSURE, + "hardcoded_secret": VulnerabilityType.HARDCODED_SECRET, + } + + for finding in findings: + if not isinstance(finding, dict): + continue + + try: + db_finding = AgentFinding( + id=str(uuid4()), + task_id=task_id, + vulnerability_type=type_map.get( + finding.get("vulnerability_type", "other"), + VulnerabilityType.OTHER + ), + severity=severity_map.get( + finding.get("severity", "medium"), + VulnerabilitySeverity.MEDIUM + ), + title=finding.get("title", "Unknown"), + description=finding.get("description", ""), + file_path=finding.get("file_path"), + line_start=finding.get("line_start"), + line_end=finding.get("line_end"), + code_snippet=finding.get("code_snippet"), + suggestion=finding.get("suggestion") or finding.get("recommendation"), + is_verified=finding.get("is_verified", False), + confidence=finding.get("confidence", 0.5), + status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.NEW, + ) + db.add(db_finding) + except Exception as e: + logger.warning(f"Failed to save finding: {e}") + + try: + await db.commit() + except Exception as e: + logger.error(f"Failed to commit findings: {e}") + + +def _calculate_security_score(findings: List[Dict]) -> float: + """计算安全评分""" + if not findings: + return 100.0 + + # 基于发现的严重程度计算扣分 + deductions = { + "critical": 25, + "high": 15, + "medium": 8, + "low": 3, + "info": 1, + } + + total_deduction = 0 + for f in findings: + if isinstance(f, dict): + sev = f.get("severity", "low") + total_deduction += deductions.get(sev, 3) + + score = max(0, 100 - total_deduction) + return float(score) + + # ============ API Endpoints ============ @router.post("/", response_model=AgentTaskResponse) @@ -420,6 +866,28 @@ async def get_agent_task( elif task.status in [AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]: progress = 0.0 + # 🔥 从运行中的 Orchestrator 获取实时统计 + total_iterations = task.total_iterations or 0 + tool_calls_count = task.tool_calls_count or 0 + tokens_used = task.tokens_used or 0 + + orchestrator = _running_orchestrators.get(task_id) + if orchestrator and task.status == AgentTaskStatus.RUNNING: + # 从 Orchestrator 获取统计 + stats = orchestrator.get_stats() + total_iterations = stats.get("iterations", 0) + tool_calls_count = stats.get("tool_calls", 0) + tokens_used = stats.get("tokens_used", 0) + + # 累加子 Agent 的统计 + if hasattr(orchestrator, 'sub_agents'): + for agent in orchestrator.sub_agents.values(): + if hasattr(agent, 'get_stats'): + sub_stats = agent.get_stats() + total_iterations += sub_stats.get("iterations", 0) + tool_calls_count += sub_stats.get("tool_calls", 0) + tokens_used += sub_stats.get("tokens_used", 0) + # 手动构建响应数据 response_data = { "id": task.id, @@ -434,9 +902,9 @@ async def get_agent_task( "indexed_files": task.indexed_files or 0, "analyzed_files": task.analyzed_files or 0, "total_chunks": task.total_chunks or 0, - "total_iterations": task.total_iterations or 0, - "tool_calls_count": task.tool_calls_count or 0, - "tokens_used": task.tokens_used or 0, + "total_iterations": total_iterations, + "tool_calls_count": tool_calls_count, + "tokens_used": tokens_used, "findings_count": task.findings_count or 0, "total_findings": task.findings_count or 0, # 兼容字段 "verified_count": task.verified_count or 0, @@ -486,16 +954,24 @@ async def cancel_agent_task( if task.status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]: raise HTTPException(status_code=400, detail="任务已结束,无法取消") - # 取消运行中的任务 + # 🔥 1. 设置 Agent 的取消标志 runner = _running_tasks.get(task_id) if runner: runner.cancel() + logger.info(f"[Cancel] Set cancel flag for task {task_id}") + + # 🔥 2. 强制取消 asyncio Task(立即中断 LLM 调用) + asyncio_task = _running_asyncio_tasks.get(task_id) + if asyncio_task and not asyncio_task.done(): + asyncio_task.cancel() + logger.info(f"[Cancel] Cancelled asyncio task for {task_id}") # 更新状态 task.status = AgentTaskStatus.CANCELLED task.completed_at = datetime.now(timezone.utc) await db.commit() + logger.info(f"[Cancel] Task {task_id} cancelled successfully") return {"message": "任务已取消", "task_id": task_id} @@ -631,9 +1107,9 @@ async def stream_agent_with_thinking( async def enhanced_event_generator(): """生成增强版 SSE 事件流""" # 1. 检查任务是否在运行中 (内存) - runner = _running_tasks.get(task_id) + event_manager = _running_event_managers.get(task_id) - if runner: + if event_manager: logger.info(f"Stream {task_id}: Using in-memory event manager") try: # 使用 EventManager 的流式接口 @@ -644,7 +1120,7 @@ async def stream_agent_with_thinking( if not include_tool_calls: skip_types.update(["tool_call_start", "tool_call_input", "tool_call_output", "tool_call_end"]) - async for event in runner.event_manager.stream_events(task_id, after_sequence=after_sequence): + async for event in event_manager.stream_events(task_id, after_sequence=after_sequence): event_type = event.get("event_type") if event_type in skip_types: @@ -1033,3 +1509,320 @@ async def _get_project_root(project: Project, task_id: str) -> str: return base_path + +# ============ Agent Tree API ============ + +class AgentTreeNodeResponse(BaseModel): + """Agent 树节点响应""" + id: str + agent_id: str + agent_name: str + agent_type: str + parent_agent_id: Optional[str] = None + depth: int = 0 + task_description: Optional[str] = None + knowledge_modules: Optional[List[str]] = None + status: str = "created" + result_summary: Optional[str] = None + findings_count: int = 0 + iterations: int = 0 + tokens_used: int = 0 + tool_calls: int = 0 + duration_ms: Optional[int] = None + children: List["AgentTreeNodeResponse"] = [] + + class Config: + from_attributes = True + + +class AgentTreeResponse(BaseModel): + """Agent 树响应""" + task_id: str + root_agent_id: Optional[str] = None + total_agents: int = 0 + running_agents: int = 0 + completed_agents: int = 0 + failed_agents: int = 0 + total_findings: int = 0 + nodes: List[AgentTreeNodeResponse] = [] + + +@router.get("/{task_id}/agent-tree", response_model=AgentTreeResponse) +async def get_agent_tree( + task_id: str, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(deps.get_current_user), +) -> Any: + """ + 获取任务的 Agent 树结构 + + 返回动态 Agent 树的完整结构,包括: + - 所有 Agent 节点 + - 父子关系 + - 执行状态 + - 发现统计 + """ + task = await db.get(AgentTask, task_id) + if not task: + raise HTTPException(status_code=404, detail="任务不存在") + + project = await db.get(Project, task.project_id) + if not project or project.owner_id != current_user.id: + raise HTTPException(status_code=403, detail="无权访问此任务") + + # 尝试从内存中获取 Agent 树(运行中的任务) + runner = _running_tasks.get(task_id) + logger.info(f"[AgentTree API] task_id={task_id}, runner exists={runner is not None}") + + if runner: + from app.services.agent.core import agent_registry + + tree = agent_registry.get_agent_tree() + stats = agent_registry.get_statistics() + logger.info(f"[AgentTree API] tree nodes={len(tree.get('nodes', {}))}, root={tree.get('root_agent_id')}") + logger.info(f"[AgentTree API] 节点详情: {list(tree.get('nodes', {}).keys())}") + + # 构建节点列表 + nodes = [] + for agent_id, node_data in tree.get("nodes", {}).items(): + # 🔥 从 Agent 实例获取实时统计数据 + iterations = 0 + tool_calls = 0 + tokens_used = 0 + findings_count = 0 + + agent_instance = agent_registry.get_agent(agent_id) + if agent_instance and hasattr(agent_instance, 'get_stats'): + agent_stats = agent_instance.get_stats() + iterations = agent_stats.get("iterations", 0) + tool_calls = agent_stats.get("tool_calls", 0) + tokens_used = agent_stats.get("tokens_used", 0) + + # 从结果中获取发现数量 + if node_data.get("result"): + result = node_data.get("result", {}) + findings_count = len(result.get("findings", [])) + + nodes.append(AgentTreeNodeResponse( + id=node_data.get("id", agent_id), + agent_id=agent_id, + agent_name=node_data.get("name", "Unknown"), + agent_type=node_data.get("type", "unknown"), + parent_agent_id=node_data.get("parent_id"), + task_description=node_data.get("task"), + knowledge_modules=node_data.get("knowledge_modules", []), + status=node_data.get("status", "unknown"), + findings_count=findings_count, + iterations=iterations, + tool_calls=tool_calls, + tokens_used=tokens_used, + children=[], + )) + + return AgentTreeResponse( + task_id=task_id, + root_agent_id=tree.get("root_agent_id"), + total_agents=stats.get("total", 0), + running_agents=stats.get("running", 0), + completed_agents=stats.get("completed", 0), + failed_agents=stats.get("failed", 0), + total_findings=sum(n.findings_count for n in nodes), + nodes=nodes, + ) + + # 从数据库获取(已完成的任务) + from app.models.agent_task import AgentTreeNode + + result = await db.execute( + select(AgentTreeNode) + .where(AgentTreeNode.task_id == task_id) + .order_by(AgentTreeNode.depth, AgentTreeNode.created_at) + ) + db_nodes = result.scalars().all() + + if not db_nodes: + return AgentTreeResponse( + task_id=task_id, + nodes=[], + ) + + # 构建响应 + nodes = [] + root_id = None + running = 0 + completed = 0 + failed = 0 + total_findings = 0 + + for node in db_nodes: + if node.parent_agent_id is None: + root_id = node.agent_id + + if node.status == "running": + running += 1 + elif node.status == "completed": + completed += 1 + elif node.status == "failed": + failed += 1 + + total_findings += node.findings_count or 0 + + nodes.append(AgentTreeNodeResponse( + id=node.id, + agent_id=node.agent_id, + agent_name=node.agent_name, + agent_type=node.agent_type, + parent_agent_id=node.parent_agent_id, + depth=node.depth, + task_description=node.task_description, + knowledge_modules=node.knowledge_modules, + status=node.status, + result_summary=node.result_summary, + findings_count=node.findings_count or 0, + iterations=node.iterations or 0, + tokens_used=node.tokens_used or 0, + tool_calls=node.tool_calls or 0, + duration_ms=node.duration_ms, + children=[], + )) + + return AgentTreeResponse( + task_id=task_id, + root_agent_id=root_id, + total_agents=len(nodes), + running_agents=running, + completed_agents=completed, + failed_agents=failed, + total_findings=total_findings, + nodes=nodes, + ) + + +# ============ Checkpoint API ============ + +class CheckpointResponse(BaseModel): + """检查点响应""" + id: str + agent_id: str + agent_name: str + agent_type: str + iteration: int + status: str + total_tokens: int = 0 + tool_calls: int = 0 + findings_count: int = 0 + checkpoint_type: str = "auto" + checkpoint_name: Optional[str] = None + created_at: Optional[str] = None + + class Config: + from_attributes = True + + +@router.get("/{task_id}/checkpoints", response_model=List[CheckpointResponse]) +async def list_checkpoints( + task_id: str, + agent_id: Optional[str] = None, + limit: int = Query(20, ge=1, le=100), + db: AsyncSession = Depends(get_db), + current_user: User = Depends(deps.get_current_user), +) -> Any: + """ + 获取任务的检查点列表 + + 用于: + - 查看执行历史 + - 状态恢复 + - 调试分析 + """ + task = await db.get(AgentTask, task_id) + if not task: + raise HTTPException(status_code=404, detail="任务不存在") + + project = await db.get(Project, task.project_id) + if not project or project.owner_id != current_user.id: + raise HTTPException(status_code=403, detail="无权访问此任务") + + from app.models.agent_task import AgentCheckpoint + + query = select(AgentCheckpoint).where(AgentCheckpoint.task_id == task_id) + + if agent_id: + query = query.where(AgentCheckpoint.agent_id == agent_id) + + query = query.order_by(AgentCheckpoint.created_at.desc()).limit(limit) + + result = await db.execute(query) + checkpoints = result.scalars().all() + + return [ + CheckpointResponse( + id=cp.id, + agent_id=cp.agent_id, + agent_name=cp.agent_name, + agent_type=cp.agent_type, + iteration=cp.iteration, + status=cp.status, + total_tokens=cp.total_tokens or 0, + tool_calls=cp.tool_calls or 0, + findings_count=cp.findings_count or 0, + checkpoint_type=cp.checkpoint_type or "auto", + checkpoint_name=cp.checkpoint_name, + created_at=cp.created_at.isoformat() if cp.created_at else None, + ) + for cp in checkpoints + ] + + +@router.get("/{task_id}/checkpoints/{checkpoint_id}") +async def get_checkpoint_detail( + task_id: str, + checkpoint_id: str, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(deps.get_current_user), +) -> Any: + """ + 获取检查点详情 + + 返回完整的 Agent 状态数据 + """ + task = await db.get(AgentTask, task_id) + if not task: + raise HTTPException(status_code=404, detail="任务不存在") + + project = await db.get(Project, task.project_id) + if not project or project.owner_id != current_user.id: + raise HTTPException(status_code=403, detail="无权访问此任务") + + from app.models.agent_task import AgentCheckpoint + + checkpoint = await db.get(AgentCheckpoint, checkpoint_id) + if not checkpoint or checkpoint.task_id != task_id: + raise HTTPException(status_code=404, detail="检查点不存在") + + # 解析状态数据 + state_data = {} + if checkpoint.state_data: + try: + state_data = json.loads(checkpoint.state_data) + except json.JSONDecodeError: + pass + + return { + "id": checkpoint.id, + "task_id": checkpoint.task_id, + "agent_id": checkpoint.agent_id, + "agent_name": checkpoint.agent_name, + "agent_type": checkpoint.agent_type, + "parent_agent_id": checkpoint.parent_agent_id, + "iteration": checkpoint.iteration, + "status": checkpoint.status, + "total_tokens": checkpoint.total_tokens, + "tool_calls": checkpoint.tool_calls, + "findings_count": checkpoint.findings_count, + "checkpoint_type": checkpoint.checkpoint_type, + "checkpoint_name": checkpoint.checkpoint_name, + "state_data": state_data, + "metadata": checkpoint.checkpoint_metadata, + "created_at": checkpoint.created_at.isoformat() if checkpoint.created_at else None, + } diff --git a/backend/app/api/v1/endpoints/projects.py b/backend/app/api/v1/endpoints/projects.py index e573a18..cbc34c1 100644 --- a/backend/app/api/v1/endpoints/projects.py +++ b/backend/app/api/v1/endpoints/projects.py @@ -676,15 +676,26 @@ async def get_project_branches( repo_type = project.repository_type or "other" + # 详细日志 + print(f"[Branch] 项目: {project.name}, 类型: {repo_type}, URL: {project.repository_url}") + print(f"[Branch] GitHub Token: {'已配置' if github_token else '未配置'}, GitLab Token: {'已配置' if gitlab_token else '未配置'}") + try: if repo_type == "github": + if not github_token: + print("[Branch] 警告: GitHub Token 未配置,可能会遇到 API 限制") branches = await get_github_branches(project.repository_url, github_token) elif repo_type == "gitlab": + if not gitlab_token: + print("[Branch] 警告: GitLab Token 未配置,可能无法访问私有仓库") branches = await get_gitlab_branches(project.repository_url, gitlab_token) else: # 对于其他类型,返回默认分支 + print(f"[Branch] 仓库类型 '{repo_type}' 不支持获取分支,返回默认分支") branches = [project.default_branch or "main"] + print(f"[Branch] 成功获取 {len(branches)} 个分支") + # 将默认分支放在第一位 default_branch = project.default_branch or "main" if default_branch in branches: @@ -694,7 +705,8 @@ async def get_project_branches( return {"branches": branches, "default_branch": default_branch} except Exception as e: - print(f"获取分支列表失败: {e}") + error_msg = str(e) + print(f"[Branch] 获取分支列表失败: {error_msg}") # 返回默认分支作为后备 return { "branches": [project.default_branch or "main"], diff --git a/backend/app/models/agent_task.py b/backend/app/models/agent_task.py index 435e2a1..0bc1a1a 100644 --- a/backend/app/models/agent_task.py +++ b/backend/app/models/agent_task.py @@ -442,3 +442,141 @@ class AgentFinding(Base): "ai_confidence": self.ai_confidence, "created_at": self.created_at.isoformat() if self.created_at else None, } + + +class AgentCheckpoint(Base): + """ + Agent 检查点 + + 用于持久化 Agent 状态,支持: + - 任务恢复 + - 状态回滚 + - 执行历史追踪 + """ + __tablename__ = "agent_checkpoints" + + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + task_id = Column(String(36), ForeignKey("agent_tasks.id", ondelete="CASCADE"), nullable=False, index=True) + + # Agent 信息 + agent_id = Column(String(50), nullable=False, index=True) + agent_name = Column(String(255), nullable=False) + agent_type = Column(String(50), nullable=False) + parent_agent_id = Column(String(50), nullable=True) + + # 状态数据(JSON 序列化的 AgentState) + state_data = Column(Text, nullable=False) + + # 执行状态 + iteration = Column(Integer, default=0) + status = Column(String(30), nullable=False) + + # 统计信息 + total_tokens = Column(Integer, default=0) + tool_calls = Column(Integer, default=0) + findings_count = Column(Integer, default=0) + + # 检查点类型 + checkpoint_type = Column(String(30), default="auto") # auto, manual, error, final + checkpoint_name = Column(String(255), nullable=True) + + # 元数据 + checkpoint_metadata = Column(JSON, nullable=True) + + # 时间戳 + created_at = Column(DateTime(timezone=True), server_default=func.now(), index=True) + + def __repr__(self): + return f"" + + def to_dict(self) -> dict: + """转换为字典""" + return { + "id": self.id, + "task_id": self.task_id, + "agent_id": self.agent_id, + "agent_name": self.agent_name, + "agent_type": self.agent_type, + "parent_agent_id": self.parent_agent_id, + "iteration": self.iteration, + "status": self.status, + "total_tokens": self.total_tokens, + "tool_calls": self.tool_calls, + "findings_count": self.findings_count, + "checkpoint_type": self.checkpoint_type, + "checkpoint_name": self.checkpoint_name, + "created_at": self.created_at.isoformat() if self.created_at else None, + } + + +class AgentTreeNode(Base): + """ + Agent 树节点 + + 记录动态 Agent 树的结构,用于: + - 可视化 Agent 树 + - 追踪 Agent 间关系 + - 分析执行流程 + """ + __tablename__ = "agent_tree_nodes" + + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + task_id = Column(String(36), ForeignKey("agent_tasks.id", ondelete="CASCADE"), nullable=False, index=True) + + # Agent 信息 + agent_id = Column(String(50), nullable=False, unique=True, index=True) + agent_name = Column(String(255), nullable=False) + agent_type = Column(String(50), nullable=False) + + # 树结构 + parent_agent_id = Column(String(50), nullable=True, index=True) + depth = Column(Integer, default=0) # 树深度 + + # 任务信息 + task_description = Column(Text, nullable=True) + knowledge_modules = Column(JSON, nullable=True) + + # 执行状态 + status = Column(String(30), default="created") + + # 执行结果 + result_summary = Column(Text, nullable=True) + findings_count = Column(Integer, default=0) + + # 统计 + iterations = Column(Integer, default=0) + tokens_used = Column(Integer, default=0) + tool_calls = Column(Integer, default=0) + duration_ms = Column(Integer, nullable=True) + + # 时间戳 + created_at = Column(DateTime(timezone=True), server_default=func.now()) + started_at = Column(DateTime(timezone=True), nullable=True) + finished_at = Column(DateTime(timezone=True), nullable=True) + + def __repr__(self): + return f"" + + def to_dict(self) -> dict: + """转换为字典""" + return { + "id": self.id, + "task_id": self.task_id, + "agent_id": self.agent_id, + "agent_name": self.agent_name, + "agent_type": self.agent_type, + "parent_agent_id": self.parent_agent_id, + "depth": self.depth, + "task_description": self.task_description, + "knowledge_modules": self.knowledge_modules, + "status": self.status, + "result_summary": self.result_summary, + "findings_count": self.findings_count, + "iterations": self.iterations, + "tokens_used": self.tokens_used, + "tool_calls": self.tool_calls, + "duration_ms": self.duration_ms, + "created_at": self.created_at.isoformat() if self.created_at else None, + "started_at": self.started_at.isoformat() if self.started_at else None, + "finished_at": self.finished_at.isoformat() if self.finished_at else None, + } diff --git a/backend/app/services/agent/__init__.py b/backend/app/services/agent/__init__.py index a82b83b..b3cbbe5 100644 --- a/backend/app/services/agent/__init__.py +++ b/backend/app/services/agent/__init__.py @@ -2,16 +2,17 @@ DeepAudit Agent 服务模块 基于 LangGraph 的 AI Agent 代码安全审计 -架构: - LangGraph 状态图工作流 - +架构升级版本 - 支持: +- 动态Agent树结构 +- 专业知识模块系统 +- Agent间通信机制 +- 完整状态管理 +- Think工具和漏洞报告工具 + +工作流: START → Recon → Analysis ⟲ → Verification → Report → END -节点: - - Recon: 信息收集 (项目结构、技术栈、入口点) - - Analysis: 漏洞分析 (静态分析、RAG、模式匹配) - - Verification: 漏洞验证 (LLM 验证、沙箱测试) - - Report: 报告生成 + 支持动态创建子Agent进行专业化分析 """ # 从 graph 模块导入主要组件 @@ -32,6 +33,29 @@ from .agents import ( OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent, ) +# 🔥 新增:核心模块(状态管理、注册表、消息) +from .core import ( + AgentState, AgentStatus, + AgentRegistry, agent_registry, + AgentMessage, MessageType, MessagePriority, MessageBus, +) + +# 🔥 新增:知识模块系统(基于RAG) +from .knowledge import ( + KnowledgeLoader, knowledge_loader, + get_available_modules, get_module_content, + SecurityKnowledgeRAG, security_knowledge_rag, + SecurityKnowledgeQueryTool, GetVulnerabilityKnowledgeTool, +) + +# 🔥 新增:协作工具 +from .tools import ( + ThinkTool, ReflectTool, + CreateVulnerabilityReportTool, + CreateSubAgentTool, SendMessageTool, ViewAgentGraphTool, + WaitForMessageTool, AgentFinishTool, +) + __all__ = [ # 核心 Runner "AgentRunner", @@ -54,5 +78,35 @@ __all__ = [ "ReconAgent", "AnalysisAgent", "VerificationAgent", + + # 🔥 核心模块 + "AgentState", + "AgentStatus", + "AgentRegistry", + "agent_registry", + "AgentMessage", + "MessageType", + "MessagePriority", + "MessageBus", + + # 🔥 知识模块(基于RAG) + "KnowledgeLoader", + "knowledge_loader", + "get_available_modules", + "get_module_content", + "SecurityKnowledgeRAG", + "security_knowledge_rag", + "SecurityKnowledgeQueryTool", + "GetVulnerabilityKnowledgeTool", + + # 🔥 协作工具 + "ThinkTool", + "ReflectTool", + "CreateVulnerabilityReportTool", + "CreateSubAgentTool", + "SendMessageTool", + "ViewAgentGraphTool", + "WaitForMessageTool", + "AgentFinishTool", ] diff --git a/backend/app/services/agent/agents/analysis.py b/backend/app/services/agent/agents/analysis.py index 2e6a2d3..508c66e 100644 --- a/backend/app/services/agent/agents/analysis.py +++ b/backend/app/services/agent/agents/analysis.py @@ -46,8 +46,6 @@ ANALYSIS_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞分析 Agent,一个**自 ### 深度分析 - **pattern_match**: 危险模式匹配 参数: pattern (str), file_types (list) -- **code_analysis**: LLM 深度代码分析 ⭐ - 参数: code (str), file_path (str), focus (str) - **dataflow_analysis**: 数据流追踪 参数: source (str), sink (str) @@ -114,7 +112,7 @@ Final Answer: [JSON 格式的漏洞报告] ## 分析策略建议 1. **快速扫描**: 先用 semgrep_scan 获得概览 -2. **重点深入**: 对可疑文件使用 read_file + code_analysis +2. **重点深入**: 对可疑文件使用 read_file + pattern_match 3. **模式搜索**: 用 search_code 找危险模式 (eval, exec, query 等) 4. **语义搜索**: 用 RAG 找相似的漏洞模式 5. **数据流**: 用 dataflow_analysis 追踪用户输入 @@ -268,6 +266,9 @@ class AnalysisAgent(BaseAgent): # 🔥 构建包含交接上下文的初始消息 handoff_context = self.get_handoff_context() + # 🔥 获取目标文件列表 + target_files = config.get("target_files", []) + initial_message = f"""请开始对项目进行安全漏洞分析。 ## 项目信息 @@ -275,7 +276,22 @@ class AnalysisAgent(BaseAgent): - 语言: {tech_stack.get('languages', [])} - 框架: {tech_stack.get('frameworks', [])} -{handoff_context if handoff_context else f'''## 上下文信息 +""" + # 🔥 如果指定了目标文件,明确告知 Agent + if target_files: + initial_message += f"""## ⚠️ 审计范围 +用户指定了 {len(target_files)} 个目标文件进行审计: +""" + for tf in target_files[:10]: + initial_message += f"- {tf}\n" + if len(target_files) > 10: + initial_message += f"- ... 还有 {len(target_files) - 10} 个文件\n" + initial_message += """ +请直接分析这些指定的文件,不要分析其他文件。 + +""" + + initial_message += f"""{handoff_context if handoff_context else f'''## 上下文信息 ### 高风险区域 {json.dumps(high_risk_areas[:20], ensure_ascii=False)} @@ -307,6 +323,7 @@ class AnalysisAgent(BaseAgent): self._steps = [] all_findings = [] + error_message = None # 🔥 跟踪错误信息 await self.emit_thinking("🔬 Analysis Agent 启动,LLM 开始自主安全分析...") @@ -323,11 +340,12 @@ class AnalysisAgent(BaseAgent): break # 调用 LLM 进行思考和决策(流式输出) + # 🔥 增加 max_tokens 到 4096,避免长输出被截断 try: llm_output, tokens_this_round = await self.stream_llm_call( self._conversation_history, temperature=0.1, - max_tokens=2048, + max_tokens=4096, ) except asyncio.CancelledError: logger.info(f"[{self.name}] LLM call cancelled") @@ -338,12 +356,21 @@ class AnalysisAgent(BaseAgent): # 🔥 Handle empty LLM response to prevent loops if not llm_output or not llm_output.strip(): logger.warning(f"[{self.name}] Empty LLM response in iteration {self._iteration}") - await self.emit_llm_decision("收到空响应", "LLM 返回内容为空,尝试重试通过提示") + empty_retry_count = getattr(self, '_empty_retry_count', 0) + 1 + self._empty_retry_count = empty_retry_count + if empty_retry_count >= 3: + logger.error(f"[{self.name}] Too many empty responses, stopping") + error_message = "连续收到空响应,停止分析" + await self.emit_event("error", error_message) + break self._conversation_history.append({ "role": "user", "content": "Received empty response. Please output your Thought and Action.", }) continue + + # 重置空响应计数器 + self._empty_retry_count = 0 # 解析 LLM 响应 step = self._parse_llm_response(llm_output) @@ -396,6 +423,11 @@ class AnalysisAgent(BaseAgent): step.action_input or {} ) + # 🔥 工具执行后检查取消状态 + if self.is_cancelled: + logger.info(f"[{self.name}] Cancelled after tool execution") + break + step.observation = observation # 🔥 发射 LLM 观察事件 @@ -414,9 +446,96 @@ class AnalysisAgent(BaseAgent): "content": "请继续分析。选择一个工具执行,或者如果分析完成,输出 Final Answer 汇总所有发现。", }) + # 🔥 如果循环结束但没有发现,强制 LLM 总结 + if not all_findings and not self.is_cancelled and not error_message: + await self.emit_thinking("📝 分析阶段结束,正在生成漏洞总结...") + + # 添加强制总结的提示 + self._conversation_history.append({ + "role": "user", + "content": """分析阶段已结束。请立即输出 Final Answer,总结你发现的所有安全问题。 + +即使没有发现严重漏洞,也请总结你的分析过程和观察到的潜在风险点。 + +请按以下 JSON 格式输出: +```json +{ + "findings": [ + { + "vulnerability_type": "sql_injection|xss|command_injection|path_traversal|ssrf|hardcoded_secret|other", + "severity": "critical|high|medium|low", + "title": "漏洞标题", + "description": "详细描述", + "file_path": "文件路径", + "line_start": 行号, + "code_snippet": "相关代码片段", + "suggestion": "修复建议" + } + ], + "summary": "分析总结" +} +``` + +Final Answer:""", + }) + + try: + summary_output, _ = await self.stream_llm_call( + self._conversation_history, + temperature=0.1, + max_tokens=4096, + ) + + if summary_output and summary_output.strip(): + # 解析总结输出 + import re + summary_text = summary_output.strip() + summary_text = re.sub(r'```json\s*', '', summary_text) + summary_text = re.sub(r'```\s*', '', summary_text) + parsed_result = AgentJsonParser.parse( + summary_text, + default={"findings": [], "summary": ""} + ) + if "findings" in parsed_result: + all_findings = parsed_result["findings"] + except Exception as e: + logger.warning(f"[{self.name}] Failed to generate summary: {e}") + # 处理结果 duration_ms = int((time.time() - start_time) * 1000) + # 🔥 如果被取消,返回取消结果 + if self.is_cancelled: + await self.emit_event( + "info", + f"🛑 Analysis Agent 已取消: {len(all_findings)} 个发现, {self._iteration} 轮迭代" + ) + return AgentResult( + success=False, + error="任务已取消", + data={"findings": all_findings}, + iterations=self._iteration, + tool_calls=self._tool_calls, + tokens_used=self._total_tokens, + duration_ms=duration_ms, + ) + + # 🔥 如果有错误,返回失败结果 + if error_message: + await self.emit_event( + "error", + f"❌ Analysis Agent 失败: {error_message}" + ) + return AgentResult( + success=False, + error=error_message, + data={"findings": all_findings}, + iterations=self._iteration, + tool_calls=self._tool_calls, + tokens_used=self._total_tokens, + duration_ms=duration_ms, + ) + # 标准化发现 standardized_findings = [] for finding in all_findings: diff --git a/backend/app/services/agent/agents/analysis_v2.py b/backend/app/services/agent/agents/analysis_v2.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/services/agent/agents/base.py b/backend/app/services/agent/agents/base.py index 84b0878..2564d3f 100644 --- a/backend/app/services/agent/agents/base.py +++ b/backend/app/services/agent/agents/base.py @@ -6,6 +6,8 @@ Agent 基类 1. LLM 是 Agent 的大脑,全程参与决策 2. Agent 之间通过 TaskHandoff 传递结构化上下文 3. 事件分为流式事件(前端展示)和持久化事件(数据库记录) +4. 支持动态Agent树和专业知识模块 +5. 完整的状态管理和Agent间通信 """ from abc import ABC, abstractmethod @@ -17,6 +19,10 @@ import asyncio import logging import uuid +from ..core.state import AgentState, AgentStatus +from ..core.registry import agent_registry +from ..core.message import message_bus, MessageType, AgentMessage + logger = logging.getLogger(__name__) @@ -238,6 +244,11 @@ class BaseAgent(ABC): 1. 通过 TaskHandoff 接收前序 Agent 的上下文 2. 执行完成后生成 TaskHandoff 传递给下一个 Agent 3. 洞察和发现应该结构化记录 + + 动态Agent树: + 1. 支持动态创建子Agent + 2. Agent间通过消息总线通信 + 3. 完整的状态管理和生命周期 """ def __init__( @@ -246,6 +257,8 @@ class BaseAgent(ABC): llm_service, tools: Dict[str, Any], event_emitter=None, + parent_id: Optional[str] = None, + knowledge_modules: Optional[List[str]] = None, ): """ 初始化 Agent @@ -255,13 +268,30 @@ class BaseAgent(ABC): llm_service: LLM 服务 tools: 可用工具字典 event_emitter: 事件发射器 + parent_id: 父Agent ID(用于动态Agent树) + knowledge_modules: 要加载的知识模块 """ self.config = config self.llm_service = llm_service self.tools = tools self.event_emitter = event_emitter + self.parent_id = parent_id + self.knowledge_modules = knowledge_modules or [] - # 运行状态 + # 🔥 生成唯一ID + self._agent_id = f"agent_{uuid.uuid4().hex[:8]}" + + # 🔥 增强的状态管理 + self._state = AgentState( + agent_id=self._agent_id, + agent_name=config.name, + agent_type=config.agent_type.value, + parent_id=parent_id, + max_iterations=config.max_iterations, + knowledge_modules=self.knowledge_modules, + ) + + # 运行状态(保持向后兼容) self._iteration = 0 self._total_tokens = 0 self._tool_calls = 0 @@ -271,15 +301,171 @@ class BaseAgent(ABC): self._incoming_handoff: Optional[TaskHandoff] = None self._insights: List[str] = [] # 收集的洞察 self._work_completed: List[str] = [] # 完成的工作记录 + + # 🔥 是否已注册到注册表 + self._registered = False + + # 🔥 加载知识模块到系统提示词 + if self.knowledge_modules: + self._load_knowledge_modules() + + def _register_to_registry(self, task: Optional[str] = None) -> None: + """注册到Agent注册表(延迟注册,在run时调用)""" + logger.info(f"[AgentTree] _register_to_registry 被调用: {self.config.name} (id={self._agent_id}, parent={self.parent_id}, _registered={self._registered})") + + if self._registered: + logger.warning(f"[AgentTree] {self.config.name} 已注册,跳过 (id={self._agent_id})") + return + + logger.info(f"[AgentTree] 正在注册 Agent: {self.config.name} (id={self._agent_id}, parent={self.parent_id})") + + agent_registry.register_agent( + agent_id=self._agent_id, + agent_name=self.config.name, + agent_type=self.config.agent_type.value, + task=task or self._state.task or "Initializing", + parent_id=self.parent_id, + agent_instance=self, + state=self._state, + knowledge_modules=self.knowledge_modules, + ) + + # 创建消息队列 + message_bus.create_queue(self._agent_id) + self._registered = True + + tree = agent_registry.get_agent_tree() + logger.info(f"[AgentTree] Agent 注册完成: {self.config.name}, 当前树节点数: {len(tree['nodes'])}") + + def set_parent_id(self, parent_id: str) -> None: + """设置父Agent ID(在调度时调用)""" + self.parent_id = parent_id + self._state.parent_id = parent_id + + def _load_knowledge_modules(self) -> None: + """加载知识模块到系统提示词""" + if not self.knowledge_modules: + return + + try: + from ..knowledge import knowledge_loader + + enhanced_prompt = knowledge_loader.build_system_prompt_with_modules( + self.config.system_prompt or "", + self.knowledge_modules, + ) + self.config.system_prompt = enhanced_prompt + + logger.info(f"[{self.name}] Loaded knowledge modules: {self.knowledge_modules}") + except Exception as e: + logger.warning(f"Failed to load knowledge modules: {e}") @property def name(self) -> str: return self.config.name + @property + def agent_id(self) -> str: + return self._agent_id + + @property + def state(self) -> AgentState: + return self._state + @property def agent_type(self) -> AgentType: return self.config.agent_type + # ============ Agent间消息处理 ============ + + def check_messages(self) -> List[AgentMessage]: + """ + 检查并处理收到的消息 + + Returns: + 未读消息列表 + """ + messages = message_bus.get_messages( + self._agent_id, + unread_only=True, + mark_as_read=True, + ) + + for msg in messages: + # 处理消息 + if msg.from_agent == "user": + # 用户消息直接添加到对话历史 + self._state.add_message("user", msg.content) + else: + # Agent间消息使用XML格式 + self._state.add_message("user", msg.to_xml()) + + # 如果在等待状态,恢复执行 + if self._state.is_waiting_for_input(): + self._state.resume_from_waiting() + agent_registry.update_agent_status(self._agent_id, "running") + + return messages + + def has_pending_messages(self) -> bool: + """检查是否有待处理的消息""" + return message_bus.has_unread_messages(self._agent_id) + + def send_message_to_parent( + self, + content: str, + message_type: MessageType = MessageType.INFORMATION, + ) -> None: + """向父Agent发送消息""" + if self.parent_id: + message_bus.send_message( + from_agent=self._agent_id, + to_agent=self.parent_id, + content=content, + message_type=message_type, + ) + + def send_message_to_agent( + self, + target_id: str, + content: str, + message_type: MessageType = MessageType.INFORMATION, + ) -> None: + """向指定Agent发送消息""" + message_bus.send_message( + from_agent=self._agent_id, + to_agent=target_id, + content=content, + message_type=message_type, + ) + + # ============ 生命周期管理 ============ + + def on_start(self) -> None: + """Agent开始执行时调用""" + self._state.start() + agent_registry.update_agent_status(self._agent_id, "running") + + def on_complete(self, result: Dict[str, Any]) -> None: + """Agent完成时调用""" + self._state.set_completed(result) + agent_registry.update_agent_status(self._agent_id, "completed", result) + + # 向父Agent报告完成 + if self.parent_id: + message_bus.send_completion_report( + from_agent=self._agent_id, + to_agent=self.parent_id, + summary=result.get("summary", "Task completed"), + findings=result.get("findings", []), + success=True, + ) + + def on_error(self, error: str) -> None: + """Agent出错时调用""" + self._state.set_failed(error) + agent_registry.update_agent_status(self._agent_id, "failed", {"error": error}) + @abstractmethod async def run(self, input_data: Dict[str, Any]) -> AgentResult: """ @@ -296,6 +482,7 @@ class BaseAgent(ABC): def cancel(self): """取消执行""" self._cancelled = True + logger.info(f"[{self.name}] Cancel requested") @property def is_cancelled(self) -> bool: @@ -671,6 +858,35 @@ class BaseAgent(ABC): "tokens_used": self._total_tokens, } + # ============ Memory Compression ============ + + def compress_messages_if_needed( + self, + messages: List[Dict[str, str]], + max_tokens: int = 100000, + ) -> List[Dict[str, str]]: + """ + 如果消息历史过长,自动压缩 + + Args: + messages: 消息列表 + max_tokens: 最大token数 + + Returns: + 压缩后的消息列表 + """ + from ...llm.memory_compressor import MemoryCompressor + + compressor = MemoryCompressor(max_total_tokens=max_tokens) + + if compressor.should_compress(messages): + logger.info(f"[{self.name}] Compressing conversation history...") + compressed = compressor.compress_history(messages) + logger.info(f"[{self.name}] Compressed {len(messages)} -> {len(compressed)} messages") + return compressed + + return messages + # ============ 统一的流式 LLM 调用 ============ async def stream_llm_call( @@ -678,6 +894,7 @@ class BaseAgent(ABC): messages: List[Dict[str, str]], temperature: float = 0.1, max_tokens: int = 2048, + auto_compress: bool = True, ) -> Tuple[str, int]: """ 统一的流式 LLM 调用方法 @@ -688,13 +905,23 @@ class BaseAgent(ABC): messages: 消息列表 temperature: 温度 max_tokens: 最大 token 数 + auto_compress: 是否自动压缩过长的消息历史 Returns: (完整响应内容, token数量) """ + # 🔥 自动压缩过长的消息历史 + if auto_compress: + messages = self.compress_messages_if_needed(messages) + accumulated = "" total_tokens = 0 + # 🔥 在开始 LLM 调用前检查取消 + if self.is_cancelled: + logger.info(f"[{self.name}] Cancelled before LLM call") + return "", 0 + await self.emit_thinking_start() try: @@ -705,6 +932,7 @@ class BaseAgent(ABC): ): # 检查取消 if self.is_cancelled: + logger.info(f"[{self.name}] Cancelled during LLM streaming") break if chunk["type"] == "token": @@ -745,6 +973,10 @@ class BaseAgent(ABC): Returns: 工具执行结果字符串 """ + # 🔥 在执行工具前检查取消 + if self.is_cancelled: + return "任务已取消" + tool = self.tools.get(tool_name) if not tool: diff --git a/backend/app/services/agent/agents/orchestrator.py b/backend/app/services/agent/agents/orchestrator.py index f6c1030..4f1aae0 100644 --- a/backend/app/services/agent/agents/orchestrator.py +++ b/backend/app/services/agent/agents/orchestrator.py @@ -79,7 +79,7 @@ Action Input: [JSON 参数] ``` ## 审计策略建议 -- 先用 recon Agent 了解项目全貌 +- 先用 recon Agent 了解项目全貌(只需调度一次) - 根据 recon 结果,让 analysis Agent 重点审计高风险区域 - 发现可疑漏洞后,用 verification Agent 验证 - 随时根据新发现调整策略,不要机械执行 @@ -90,6 +90,15 @@ Action Input: [JSON 参数] 2. **动态调整** - 根据发现调整策略 3. **主动决策** - 不要等待,主动推进 4. **质量优先** - 宁可深入分析几个真实漏洞,不要浅尝辄止 +5. **避免重复** - 每个 Agent 通常只需要调度一次,如果结果不理想,尝试其他 Agent 或直接完成审计 + +## 处理子 Agent 结果 +- 子 Agent 返回的 Observation 包含它们的分析结果 +- 即使结果看起来不完整,也要基于已有信息继续推进 +- 不要反复调度同一个 Agent 期望得到不同结果 +- 如果 recon 完成后,应该调度 analysis 进行深度分析 +- 如果 analysis 完成后有发现,可以调度 verification 验证 +- 如果没有更多工作要做,使用 finish 结束审计 现在,基于项目信息开始你的审计工作!""" @@ -136,11 +145,32 @@ class OrchestratorAgent(BaseAgent): self._conversation_history: List[Dict[str, str]] = [] self._steps: List[AgentStep] = [] self._all_findings: List[Dict] = [] + + # 🔥 存储运行时上下文,用于传递给子 Agent + self._runtime_context: Dict[str, Any] = {} + + # 🔥 跟踪已调度的 Agent 任务,避免重复调度 + self._dispatched_tasks: Dict[str, int] = {} # agent_name -> dispatch_count def register_sub_agent(self, name: str, agent: BaseAgent): """注册子 Agent""" self.sub_agents[name] = agent + def cancel(self): + """ + 取消执行 - 同时取消所有子 Agent + + 重写父类方法,确保取消信号传播到所有子 Agent + """ + self._cancelled = True + logger.info(f"[{self.name}] Cancel requested, propagating to {len(self.sub_agents)} sub-agents") + + # 🔥 传播取消信号到所有子 Agent + for name, agent in self.sub_agents.items(): + if hasattr(agent, 'cancel'): + agent.cancel() + logger.info(f"[{self.name}] Cancelled sub-agent: {name}") + async def run(self, input_data: Dict[str, Any]) -> AgentResult: """ 执行编排任务 - LLM 全程参与! @@ -149,6 +179,8 @@ class OrchestratorAgent(BaseAgent): input_data: { "project_info": 项目信息, "config": 审计配置, + "project_root": 项目根目录, + "task_id": 任务ID, } """ import time @@ -157,6 +189,14 @@ class OrchestratorAgent(BaseAgent): project_info = input_data.get("project_info", {}) config = input_data.get("config", {}) + # 🔥 保存运行时上下文,用于传递给子 Agent + self._runtime_context = { + "project_info": project_info, + "config": config, + "project_root": input_data.get("project_root", project_info.get("root", ".")), + "task_id": input_data.get("task_id"), + } + # 构建初始消息 initial_message = self._build_initial_message(project_info, config) @@ -169,6 +209,7 @@ class OrchestratorAgent(BaseAgent): self._steps = [] self._all_findings = [] final_result = None + error_message = None # 🔥 跟踪错误信息 await self.emit_thinking("🧠 Orchestrator Agent 启动,LLM 开始自主编排决策...") @@ -189,7 +230,7 @@ class OrchestratorAgent(BaseAgent): llm_output, tokens_this_round = await self.stream_llm_call( self._conversation_history, temperature=0.1, - max_tokens=2048, + max_tokens=4096, # 🔥 增加到 4096,避免截断 ) except asyncio.CancelledError: logger.info(f"[{self.name}] LLM call cancelled") @@ -197,11 +238,37 @@ class OrchestratorAgent(BaseAgent): self._total_tokens += tokens_this_round + # 🔥 检测空响应 + if not llm_output or not llm_output.strip(): + logger.warning(f"[{self.name}] Empty LLM response") + empty_retry_count = getattr(self, '_empty_retry_count', 0) + 1 + self._empty_retry_count = empty_retry_count + if empty_retry_count >= 3: + logger.error(f"[{self.name}] Too many empty responses, stopping") + error_message = "连续收到空响应,停止编排" + await self.emit_event("error", error_message) + break + self._conversation_history.append({ + "role": "user", + "content": "Received empty response. Please output Thought + Action + Action Input.", + }) + continue + + # 重置空响应计数器 + self._empty_retry_count = 0 + # 解析 LLM 的决策 step = self._parse_llm_response(llm_output) if not step: # LLM 输出格式不正确,提示重试 + format_retry_count = getattr(self, '_format_retry_count', 0) + 1 + self._format_retry_count = format_retry_count + if format_retry_count >= 3: + logger.error(f"[{self.name}] Too many format errors, stopping") + error_message = "连续格式错误,停止编排" + await self.emit_event("error", error_message) + break await self.emit_llm_decision("格式错误", "需要重新输出") self._conversation_history.append({ "role": "assistant", @@ -213,6 +280,9 @@ class OrchestratorAgent(BaseAgent): }) continue + # 重置格式重试计数器 + self._format_retry_count = 0 + self._steps.append(step) # 🔥 发射 LLM 思考内容事件 - 展示编排决策的思考过程 @@ -249,6 +319,11 @@ class OrchestratorAgent(BaseAgent): observation = await self._dispatch_agent(step.action_input) step.observation = observation + # 🔥 子 Agent 执行完成后检查取消状态 + if self.is_cancelled: + logger.info(f"[{self.name}] Cancelled after sub-agent dispatch") + break + # 🔥 发射观察事件 await self.emit_llm_observation(observation) @@ -272,6 +347,60 @@ class OrchestratorAgent(BaseAgent): # 生成最终结果 duration_ms = int((time.time() - start_time) * 1000) + # 🔥 如果被取消,返回取消结果 + if self.is_cancelled: + await self.emit_event( + "info", + f"🛑 Orchestrator 已取消: {len(self._all_findings)} 个发现, {self._iteration} 轮决策" + ) + return AgentResult( + success=False, + error="任务已取消", + data={ + "findings": self._all_findings, + "steps": [ + { + "thought": s.thought, + "action": s.action, + "action_input": s.action_input, + "observation": s.observation[:500] if s.observation else None, + } + for s in self._steps + ], + }, + iterations=self._iteration, + tool_calls=self._tool_calls, + tokens_used=self._total_tokens, + duration_ms=duration_ms, + ) + + # 🔥 如果有错误,返回失败结果 + if error_message: + await self.emit_event( + "error", + f"❌ Orchestrator 失败: {error_message}" + ) + return AgentResult( + success=False, + error=error_message, + data={ + "findings": self._all_findings, + "steps": [ + { + "thought": s.thought, + "action": s.action, + "action_input": s.action_input, + "observation": s.observation[:500] if s.observation else None, + } + for s in self._steps + ], + }, + iterations=self._iteration, + tool_calls=self._tool_calls, + tokens_used=self._total_tokens, + duration_ms=duration_ms, + ) + await self.emit_event( "info", f"🎯 Orchestrator 完成: {len(self._all_findings)} 个发现, {self._iteration} 轮决策" @@ -377,6 +506,30 @@ class OrchestratorAgent(BaseAgent): available = list(self.sub_agents.keys()) return f"错误: Agent '{agent_name}' 不存在。可用的 Agent: {available}" + # 🔥 检查是否重复调度同一个 Agent + dispatch_count = self._dispatched_tasks.get(agent_name, 0) + if dispatch_count >= 2: + return f"""## ⚠️ 重复调度警告 + +你已经调度 {agent_name} Agent {dispatch_count} 次了。 + +如果之前的调度没有返回有用的结果,请考虑: +1. 尝试调度其他 Agent(如 analysis 或 verification) +2. 使用 finish 操作结束审计并汇总已有发现 +3. 提供更具体的任务描述 + +当前已收集的发现数量: {len(self._all_findings)} +""" + + self._dispatched_tasks[agent_name] = dispatch_count + 1 + + # 🔥 设置父 Agent ID 并注册到注册表(动态 Agent 树) + logger.info(f"[Orchestrator] 准备调度 {agent_name} Agent, agent._registered={agent._registered}") + agent.set_parent_id(self._agent_id) + logger.info(f"[Orchestrator] 设置 parent_id 完成,准备注册 {agent_name}") + agent._register_to_registry(task=task) + logger.info(f"[Orchestrator] {agent_name} 注册完成,agent._registered={agent._registered}") + await self.emit_event( "dispatch", f"📤 调度 {agent_name} Agent: {task[:100]}...", @@ -387,31 +540,92 @@ class OrchestratorAgent(BaseAgent): self._tool_calls += 1 try: - # 构建子 Agent 输入 + # 🔥 构建子 Agent 输入 - 传递完整的运行时上下文 + project_info = self._runtime_context.get("project_info", {}).copy() + # 确保 project_info 包含 root 路径 + if "root" not in project_info: + project_info["root"] = self._runtime_context.get("project_root", ".") + sub_input = { "task": task, "task_context": context, - "project_info": {}, # 从上下文获取 - "config": {}, + "project_info": project_info, + "config": self._runtime_context.get("config", {}), + "project_root": self._runtime_context.get("project_root", "."), + "previous_results": { + "findings": self._all_findings, # 传递已收集的发现 + }, } + # 🔥 执行子 Agent 前检查取消状态 + if self.is_cancelled: + return f"## {agent_name} Agent 执行取消\n\n任务已被用户取消" + # 执行子 Agent result = await agent.run(sub_input) - # 收集发现 + # 🔥 执行后再次检查取消状态 + if self.is_cancelled: + return f"## {agent_name} Agent 执行中断\n\n任务已被用户取消" + + # 🔥 处理子 Agent 结果 - 不同 Agent 返回不同的数据结构 if result.success and result.data: - findings = result.data.get("findings", []) - self._all_findings.extend(findings) + data = result.data + + # 🔥 收集发现 - 只收集格式正确的漏洞对象 + # findings 字段通常来自 Analysis/Verification Agent,是漏洞对象数组 + # initial_findings 来自 Recon Agent,可能是字符串数组(观察)或对象数组 + findings = data.get("findings", []) + if findings: + # 只添加字典格式的发现 + valid_findings = [f for f in findings if isinstance(f, dict)] + self._all_findings.extend(valid_findings) await self.emit_event( "dispatch_complete", - f"✅ {agent_name} Agent 完成: {len(findings)} 个发现", + f"✅ {agent_name} Agent 完成", agent=agent_name, findings_count=len(findings), ) - # 构建观察结果 - observation = f"""## {agent_name} Agent 执行结果 + # 🔥 根据 Agent 类型构建不同的观察结果 + if agent_name == "recon": + # Recon Agent 返回项目信息 + observation = f"""## Recon Agent 执行结果 + +**状态**: 成功 +**迭代次数**: {result.iterations} +**耗时**: {result.duration_ms}ms + +### 项目结构 +{json.dumps(data.get('project_structure', {}), ensure_ascii=False, indent=2)} + +### 技术栈 +- 语言: {data.get('tech_stack', {}).get('languages', [])} +- 框架: {data.get('tech_stack', {}).get('frameworks', [])} +- 数据库: {data.get('tech_stack', {}).get('databases', [])} + +### 入口点 ({len(data.get('entry_points', []))} 个) +""" + for i, ep in enumerate(data.get('entry_points', [])[:10]): + if isinstance(ep, dict): + observation += f"{i+1}. [{ep.get('type', 'unknown')}] {ep.get('file', '')}:{ep.get('line', '')}\n" + + observation += f""" +### 高风险区域 +{data.get('high_risk_areas', [])} + +### 初步发现 ({len(data.get('initial_findings', []))} 个) +""" + for finding in data.get('initial_findings', [])[:5]: + if isinstance(finding, str): + observation += f"- {finding}\n" + elif isinstance(finding, dict): + observation += f"- {finding.get('title', finding)}\n" + + else: + # Analysis/Verification Agent 返回漏洞发现 + observation = f"""## {agent_name} Agent 执行结果 **状态**: 成功 **发现数量**: {len(findings)} @@ -420,22 +634,21 @@ class OrchestratorAgent(BaseAgent): ### 发现摘要 """ - for i, f in enumerate(findings[:10]): # 最多显示 10 个 - if not isinstance(f, dict): - continue - - observation += f""" + for i, f in enumerate(findings[:10]): + if not isinstance(f, dict): + continue + observation += f""" {i+1}. [{f.get('severity', 'unknown')}] {f.get('title', 'Unknown')} - 类型: {f.get('vulnerability_type', 'unknown')} - 文件: {f.get('file_path', 'unknown')} - 描述: {f.get('description', '')[:200]}... """ + + if len(findings) > 10: + observation += f"\n... 还有 {len(findings) - 10} 个发现" - if len(findings) > 10: - observation += f"\n... 还有 {len(findings) - 10} 个发现" - - if result.data.get("summary"): - observation += f"\n\n### Agent 总结\n{result.data['summary']}" + if data.get("summary"): + observation += f"\n\n### Agent 总结\n{data['summary']}" return observation else: diff --git a/backend/app/services/agent/agents/react_agent.py b/backend/app/services/agent/agents/react_agent.py deleted file mode 100644 index 37e8e0e..0000000 --- a/backend/app/services/agent/agents/react_agent.py +++ /dev/null @@ -1,380 +0,0 @@ -""" -真正的 ReAct Agent 实现 -LLM 是大脑,全程参与决策! - -ReAct 循环: -1. Thought: LLM 思考当前状态和下一步 -2. Action: LLM 决定调用哪个工具 -3. Observation: 执行工具,获取结果 -4. 重复直到 LLM 决定完成 -""" - -import json -import logging -import re -from typing import List, Dict, Any, Optional, Tuple -from dataclasses import dataclass - -from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern -from ..json_parser import AgentJsonParser - -logger = logging.getLogger(__name__) - - -REACT_SYSTEM_PROMPT = """你是 DeepAudit 安全审计 Agent,一个专业的代码安全分析专家。 - -## 你的任务 -对目标项目进行全面的安全审计,发现潜在的安全漏洞。 - -## 你的工具 -{tools_description} - -## 工作方式 -你需要通过 **思考-行动-观察** 循环来完成任务: - -1. **Thought**: 分析当前情况,思考下一步应该做什么 -2. **Action**: 选择一个工具并执行 -3. **Observation**: 观察工具返回的结果 -4. 重复上述过程直到你认为审计完成 - -## 输出格式 -每一步必须严格按照以下格式输出: - -``` -Thought: [你的思考过程,分析当前状态,决定下一步] -Action: [工具名称] -Action Input: [工具参数,JSON 格式] -``` - -当你完成分析后,输出: -``` -Thought: [总结分析结果] -Final Answer: [JSON 格式的最终发现] -``` - -## Final Answer 格式 -```json -{{ - "findings": [ - {{ - "vulnerability_type": "sql_injection", - "severity": "high", - "title": "SQL 注入漏洞", - "description": "详细描述", - "file_path": "path/to/file.py", - "line_start": 42, - "code_snippet": "危险代码片段", - "suggestion": "修复建议" - }} - ], - "summary": "审计总结" -}} -``` - -## 审计策略建议 -1. 先用 list_files 了解项目结构 -2. 识别关键文件(路由、控制器、数据库操作) -3. 使用 search_code 搜索危险模式(eval, exec, query, innerHTML 等) -4. 读取可疑文件进行深度分析 -5. 如果有 semgrep,用它进行全面扫描 - -## 重点关注的漏洞类型 -- SQL 注入 (query, execute, raw SQL) -- XSS (innerHTML, document.write, v-html) -- 命令注入 (exec, system, subprocess, child_process) -- 路径遍历 (open, readFile, path concatenation) -- SSRF (requests, fetch, http client) -- 硬编码密钥 (password, secret, api_key, token) -- 不安全的反序列化 (pickle, yaml.load, eval) - -现在开始审计!""" - - -@dataclass -class AgentStep: - """Agent 执行步骤""" - thought: str - action: Optional[str] = None - action_input: Optional[Dict] = None - observation: Optional[str] = None - is_final: bool = False - final_answer: Optional[Dict] = None - - -class ReActAgent(BaseAgent): - """ - 真正的 ReAct Agent - - LLM 全程参与决策,自主选择工具和分析策略 - """ - - def __init__( - self, - llm_service, - tools: Dict[str, Any], - event_emitter=None, - agent_type: AgentType = AgentType.ANALYSIS, - max_iterations: int = 30, - ): - config = AgentConfig( - name="ReActAgent", - agent_type=agent_type, - pattern=AgentPattern.REACT, - max_iterations=max_iterations, - system_prompt=REACT_SYSTEM_PROMPT, - ) - super().__init__(config, llm_service, tools, event_emitter) - - self._conversation_history: List[Dict[str, str]] = [] - self._steps: List[AgentStep] = [] - - def _get_tools_description(self) -> str: - """生成工具描述""" - descriptions = [] - - for name, tool in self.tools.items(): - if name.startswith("_"): - continue - - desc = f"### {name}\n" - desc += f"{tool.description}\n" - - # 添加参数说明 - if hasattr(tool, 'args_schema') and tool.args_schema: - schema = tool.args_schema.schema() - properties = schema.get("properties", {}) - if properties: - desc += "参数:\n" - for param_name, param_info in properties.items(): - param_desc = param_info.get("description", "") - param_type = param_info.get("type", "string") - desc += f" - {param_name} ({param_type}): {param_desc}\n" - - descriptions.append(desc) - - return "\n".join(descriptions) - - def _build_system_prompt(self, project_info: Dict, task_context: str = "") -> str: - """构建系统提示词""" - tools_desc = self._get_tools_description() - prompt = self.config.system_prompt.format(tools_description=tools_desc) - - if project_info: - prompt += f"\n\n## 项目信息\n" - prompt += f"- 名称: {project_info.get('name', 'unknown')}\n" - prompt += f"- 语言: {', '.join(project_info.get('languages', ['unknown']))}\n" - prompt += f"- 文件数: {project_info.get('file_count', 'unknown')}\n" - - if task_context: - prompt += f"\n\n## 任务上下文\n{task_context}" - - return prompt - - def _parse_llm_response(self, response: str) -> AgentStep: - """解析 LLM 响应""" - step = AgentStep(thought="") - - # 提取 Thought - thought_match = re.search(r'Thought:\s*(.*?)(?=Action:|Final Answer:|$)', response, re.DOTALL) - if thought_match: - step.thought = thought_match.group(1).strip() - - # 检查是否是最终答案 - final_match = re.search(r'Final Answer:\s*(.*?)$', response, re.DOTALL) - if final_match: - step.is_final = True - answer_text = final_match.group(1).strip() - answer_text = re.sub(r'```json\s*', '', answer_text) - answer_text = re.sub(r'```\s*', '', answer_text) - # 使用增强的 JSON 解析器 - step.final_answer = AgentJsonParser.parse( - answer_text, - default={"raw_answer": answer_text} - ) - # 确保 findings 格式正确 - if "findings" in step.final_answer: - step.final_answer["findings"] = [ - f for f in step.final_answer["findings"] - if isinstance(f, dict) - ] - return step - - # 提取 Action - action_match = re.search(r'Action:\s*(\w+)', response) - if action_match: - step.action = action_match.group(1).strip() - - # 提取 Action Input - input_match = re.search(r'Action Input:\s*(.*?)(?=Thought:|Action:|Observation:|$)', response, re.DOTALL) - if input_match: - input_text = input_match.group(1).strip() - input_text = re.sub(r'```json\s*', '', input_text) - input_text = re.sub(r'```\s*', '', input_text) - # 使用增强的 JSON 解析器 - step.action_input = AgentJsonParser.parse( - input_text, - default={"raw_input": input_text} - ) - - return step - - async def _execute_tool(self, tool_name: str, tool_input: Dict) -> str: - """执行工具""" - tool = self.tools.get(tool_name) - - if not tool: - return f"错误: 工具 '{tool_name}' 不存在。可用工具: {list(self.tools.keys())}" - - try: - self._tool_calls += 1 - await self.emit_tool_call(tool_name, tool_input) - - import time - start = time.time() - - result = await tool.execute(**tool_input) - - duration_ms = int((time.time() - start) * 1000) - await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms) - - if result.success: - # 截断过长的输出 - output = str(result.data) - if len(output) > 4000: - output = output[:4000] + "\n\n... [输出已截断,共 {} 字符]".format(len(str(result.data))) - return output - else: - return f"工具执行失败: {result.error}" - - except Exception as e: - logger.error(f"Tool execution error: {e}") - return f"工具执行错误: {str(e)}" - - async def run(self, input_data: Dict[str, Any]) -> AgentResult: - """ - 执行 ReAct Agent - - LLM 全程参与,自主决策! - """ - import time - start_time = time.time() - - project_info = input_data.get("project_info", {}) - task_context = input_data.get("task_context", "") - config = input_data.get("config", {}) - - # 构建系统提示词 - system_prompt = self._build_system_prompt(project_info, task_context) - - # 初始化对话历史 - self._conversation_history = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": "请开始对项目进行安全审计。首先了解项目结构,然后系统性地搜索和分析潜在的安全漏洞。"}, - ] - - self._steps = [] - all_findings = [] - - await self.emit_thinking("🤖 ReAct Agent 启动,LLM 开始自主分析...") - - try: - for iteration in range(self.config.max_iterations): - if self.is_cancelled: - break - - self._iteration = iteration + 1 - - await self.emit_thinking(f"💭 第 {iteration + 1} 轮思考...") - - # 🔥 调用 LLM 进行思考和决策 - response = await self.llm_service.chat_completion_raw( - messages=self._conversation_history, - temperature=0.1, - max_tokens=2048, - ) - - llm_output = response.get("content", "") - self._total_tokens += response.get("usage", {}).get("total_tokens", 0) - - # 发射思考事件 - await self.emit_event("thinking", f"LLM: {llm_output[:500]}...") - - # 解析 LLM 响应 - step = self._parse_llm_response(llm_output) - self._steps.append(step) - - # 添加 LLM 响应到历史 - self._conversation_history.append({ - "role": "assistant", - "content": llm_output, - }) - - # 检查是否完成 - if step.is_final: - await self.emit_thinking("✅ LLM 完成分析,生成最终报告") - - if step.final_answer and "findings" in step.final_answer: - all_findings = step.final_answer["findings"] - break - - # 执行工具 - if step.action: - await self.emit_thinking(f"🔧 LLM 决定调用工具: {step.action}") - - observation = await self._execute_tool( - step.action, - step.action_input or {} - ) - - step.observation = observation - - # 添加观察结果到历史 - self._conversation_history.append({ - "role": "user", - "content": f"Observation: {observation}", - }) - else: - # LLM 没有选择工具,提示它继续 - self._conversation_history.append({ - "role": "user", - "content": "请继续分析,选择一个工具执行,或者如果分析完成,输出 Final Answer。", - }) - - duration_ms = int((time.time() - start_time) * 1000) - - await self.emit_event( - "info", - f"🎯 ReAct Agent 完成: {len(all_findings)} 个发现, {self._iteration} 轮迭代, {self._tool_calls} 次工具调用" - ) - - return AgentResult( - success=True, - data={ - "findings": all_findings, - "steps": [ - { - "thought": s.thought, - "action": s.action, - "action_input": s.action_input, - "observation": s.observation[:500] if s.observation else None, - } - for s in self._steps - ], - }, - iterations=self._iteration, - tool_calls=self._tool_calls, - tokens_used=self._total_tokens, - duration_ms=duration_ms, - ) - - except Exception as e: - logger.error(f"ReAct Agent failed: {e}", exc_info=True) - return AgentResult(success=False, error=str(e)) - - def get_conversation_history(self) -> List[Dict[str, str]]: - """获取对话历史""" - return self._conversation_history - - def get_steps(self) -> List[AgentStep]: - """获取执行步骤""" - return self._steps diff --git a/backend/app/services/agent/agents/recon.py b/backend/app/services/agent/agents/recon.py index 1850c52..e1dcd9c 100644 --- a/backend/app/services/agent/agents/recon.py +++ b/backend/app/services/agent/agents/recon.py @@ -99,6 +99,12 @@ Final Answer: [JSON 格式的收集结果] 4. 运行安全扫描发现初步问题 5. 根据发现继续深入 +## 重要提示 +- 用户可能指定了特定的目标文件进行审计 +- 如果 list_files 显示"审计范围限定为 X 个指定文件",说明只需要分析这些文件 +- 在这种情况下,直接读取和分析指定的文件,不要浪费时间遍历其他目录 +- 如果目录显示为空,可能是因为该目录不包含目标文件 + ## 重要原则 1. **你是大脑** - 每一步都要思考,不要机械执行 2. **动态调整** - 根据发现调整策略 @@ -216,13 +222,38 @@ class ReconAgent(BaseAgent): task = input_data.get("task", "") task_context = input_data.get("task_context", "") + # 🔥 获取目标文件列表 + target_files = config.get("target_files", []) + exclude_patterns = config.get("exclude_patterns", []) + # 构建初始消息 initial_message = f"""请开始收集项目信息。 ## 项目基本信息 - 名称: {project_info.get('name', 'unknown')} - 根目录: {project_info.get('root', '.')} +- 文件数量: {project_info.get('file_count', 'unknown')} +## 审计范围 +""" + # 🔥 如果指定了目标文件,明确告知 Agent + if target_files: + initial_message += f"""⚠️ **重要**: 用户指定了 {len(target_files)} 个目标文件进行审计: +""" + for tf in target_files[:10]: + initial_message += f"- {tf}\n" + if len(target_files) > 10: + initial_message += f"- ... 还有 {len(target_files) - 10} 个文件\n" + initial_message += """ +请直接读取和分析这些指定的文件,不要浪费时间遍历其他目录。 +""" + else: + initial_message += "全项目审计(无特定文件限制)\n" + + if exclude_patterns: + initial_message += f"\n排除模式: {', '.join(exclude_patterns[:5])}\n" + + initial_message += f""" ## 任务上下文 {task_context or task or '进行全面的信息收集,为安全审计做准备。'} @@ -239,6 +270,7 @@ class ReconAgent(BaseAgent): self._steps = [] final_result = None + error_message = None # 🔥 跟踪错误信息 await self.emit_thinking("Recon Agent 启动,LLM 开始自主收集信息...") @@ -259,7 +291,7 @@ class ReconAgent(BaseAgent): llm_output, tokens_this_round = await self.stream_llm_call( self._conversation_history, temperature=0.1, - max_tokens=2048, + max_tokens=4096, # 🔥 增加到 4096,避免截断 ) except asyncio.CancelledError: logger.info(f"[{self.name}] LLM call cancelled") @@ -270,12 +302,21 @@ class ReconAgent(BaseAgent): # 🔥 Handle empty LLM response to prevent loops if not llm_output or not llm_output.strip(): logger.warning(f"[{self.name}] Empty LLM response in iteration {self._iteration}") - await self.emit_llm_decision("收到空响应", "LLM 返回内容为空,尝试重试通过提示") + empty_retry_count = getattr(self, '_empty_retry_count', 0) + 1 + self._empty_retry_count = empty_retry_count + if empty_retry_count >= 3: + logger.error(f"[{self.name}] Too many empty responses, stopping") + error_message = "连续收到空响应,停止信息收集" + await self.emit_event("error", error_message) + break self._conversation_history.append({ "role": "user", "content": "Received empty response. Please output your Thought and Action.", }) continue + + # 重置空响应计数器 + self._empty_retry_count = 0 # 解析 LLM 响应 step = self._parse_llm_response(llm_output) @@ -311,6 +352,11 @@ class ReconAgent(BaseAgent): step.action_input or {} ) + # 🔥 工具执行后检查取消状态 + if self.is_cancelled: + logger.info(f"[{self.name}] Cancelled after tool execution") + break + step.observation = observation # 🔥 发射 LLM 观察事件 @@ -329,9 +375,84 @@ class ReconAgent(BaseAgent): "content": "请继续,选择一个工具执行,或者如果信息收集完成,输出 Final Answer。", }) + # 🔥 如果循环结束但没有 final_result,强制 LLM 总结 + if not final_result and not self.is_cancelled and not error_message: + await self.emit_thinking("📝 信息收集阶段结束,正在生成总结...") + + # 添加强制总结的提示 + self._conversation_history.append({ + "role": "user", + "content": """信息收集阶段已结束。请立即输出 Final Answer,总结你收集到的所有信息。 + +请按以下 JSON 格式输出: +```json +{ + "project_structure": {"directories": [...], "key_files": [...]}, + "tech_stack": {"languages": [...], "frameworks": [...], "databases": [...]}, + "entry_points": [{"type": "...", "file": "...", "description": "..."}], + "high_risk_areas": ["file1.py", "file2.js"], + "initial_findings": [{"title": "...", "description": "...", "file_path": "..."}], + "summary": "项目总结描述" +} +``` + +Final Answer:""", + }) + + try: + summary_output, _ = await self.stream_llm_call( + self._conversation_history, + temperature=0.1, + max_tokens=2048, + ) + + if summary_output and summary_output.strip(): + # 解析总结输出 + summary_text = summary_output.strip() + summary_text = re.sub(r'```json\s*', '', summary_text) + summary_text = re.sub(r'```\s*', '', summary_text) + final_result = AgentJsonParser.parse( + summary_text, + default=self._summarize_from_steps() + ) + except Exception as e: + logger.warning(f"[{self.name}] Failed to generate summary: {e}") + # 处理结果 duration_ms = int((time.time() - start_time) * 1000) + # 🔥 如果被取消,返回取消结果 + if self.is_cancelled: + await self.emit_event( + "info", + f"🛑 Recon Agent 已取消: {self._iteration} 轮迭代" + ) + return AgentResult( + success=False, + error="任务已取消", + data=self._summarize_from_steps(), + iterations=self._iteration, + tool_calls=self._tool_calls, + tokens_used=self._total_tokens, + duration_ms=duration_ms, + ) + + # 🔥 如果有错误,返回失败结果 + if error_message: + await self.emit_event( + "error", + f"❌ Recon Agent 失败: {error_message}" + ) + return AgentResult( + success=False, + error=error_message, + data=self._summarize_from_steps(), + iterations=self._iteration, + tool_calls=self._tool_calls, + tokens_used=self._total_tokens, + duration_ms=duration_ms, + ) + # 如果没有最终结果,从历史中汇总 if not final_result: final_result = self._summarize_from_steps() @@ -364,7 +485,7 @@ class ReconAgent(BaseAgent): return AgentResult(success=False, error=str(e)) def _summarize_from_steps(self) -> Dict[str, Any]: - """从步骤中汇总结果""" + """从步骤中汇总结果 - 增强版,从 LLM 思考过程中提取更多信息""" # 默认结果结构 result = { "project_structure": {}, @@ -377,34 +498,90 @@ class ReconAgent(BaseAgent): "high_risk_areas": [], "dependencies": {}, "initial_findings": [], + "summary": "", # 🔥 新增:汇总 LLM 的思考 } - # 从步骤的观察结果中提取信息 + # 🔥 收集所有 LLM 的思考内容 + thoughts = [] + + # 从步骤的观察结果和思考中提取信息 for step in self._steps: + # 收集思考内容 + if step.thought: + thoughts.append(step.thought) + if step.observation: # 尝试从观察中识别技术栈等信息 obs_lower = step.observation.lower() - if "package.json" in obs_lower: + # 识别语言 + if "package.json" in obs_lower or ".js" in obs_lower or ".ts" in obs_lower: result["tech_stack"]["languages"].append("JavaScript/TypeScript") - if "requirements.txt" in obs_lower or "setup.py" in obs_lower: + if "requirements.txt" in obs_lower or "setup.py" in obs_lower or ".py" in obs_lower: result["tech_stack"]["languages"].append("Python") - if "go.mod" in obs_lower: + if "go.mod" in obs_lower or ".go" in obs_lower: result["tech_stack"]["languages"].append("Go") + if "pom.xml" in obs_lower or ".java" in obs_lower: + result["tech_stack"]["languages"].append("Java") + if ".php" in obs_lower: + result["tech_stack"]["languages"].append("PHP") + if ".rb" in obs_lower or "gemfile" in obs_lower: + result["tech_stack"]["languages"].append("Ruby") # 识别框架 if "react" in obs_lower: result["tech_stack"]["frameworks"].append("React") + if "vue" in obs_lower: + result["tech_stack"]["frameworks"].append("Vue") + if "angular" in obs_lower: + result["tech_stack"]["frameworks"].append("Angular") if "django" in obs_lower: result["tech_stack"]["frameworks"].append("Django") + if "flask" in obs_lower: + result["tech_stack"]["frameworks"].append("Flask") if "fastapi" in obs_lower: result["tech_stack"]["frameworks"].append("FastAPI") if "express" in obs_lower: result["tech_stack"]["frameworks"].append("Express") + if "spring" in obs_lower: + result["tech_stack"]["frameworks"].append("Spring") + if "streamlit" in obs_lower: + result["tech_stack"]["frameworks"].append("Streamlit") + + # 识别数据库 + if "mysql" in obs_lower or "pymysql" in obs_lower: + result["tech_stack"]["databases"].append("MySQL") + if "postgres" in obs_lower or "asyncpg" in obs_lower: + result["tech_stack"]["databases"].append("PostgreSQL") + if "mongodb" in obs_lower or "pymongo" in obs_lower: + result["tech_stack"]["databases"].append("MongoDB") + if "redis" in obs_lower: + result["tech_stack"]["databases"].append("Redis") + if "sqlite" in obs_lower: + result["tech_stack"]["databases"].append("SQLite") + + # 🔥 识别高风险区域(从观察中提取) + risk_keywords = ["api", "auth", "login", "password", "secret", "key", "token", + "admin", "upload", "download", "exec", "eval", "sql", "query"] + for keyword in risk_keywords: + if keyword in obs_lower: + # 尝试从观察中提取文件路径 + import re + file_matches = re.findall(r'[\w/]+\.(?:py|js|ts|java|php|go|rb)', step.observation) + for file_path in file_matches[:3]: # 限制数量 + if file_path not in result["high_risk_areas"]: + result["high_risk_areas"].append(file_path) # 去重 result["tech_stack"]["languages"] = list(set(result["tech_stack"]["languages"])) result["tech_stack"]["frameworks"] = list(set(result["tech_stack"]["frameworks"])) + result["tech_stack"]["databases"] = list(set(result["tech_stack"]["databases"])) + result["high_risk_areas"] = list(set(result["high_risk_areas"]))[:20] # 限制数量 + + # 🔥 汇总 LLM 的思考作为 summary + if thoughts: + # 取最后几个思考作为总结 + result["summary"] = "\n".join(thoughts[-3:]) return result diff --git a/backend/app/services/agent/agents/verification.py b/backend/app/services/agent/agents/verification.py index 919618b..3752d91 100644 --- a/backend/app/services/agent/agents/verification.py +++ b/backend/app/services/agent/agents/verification.py @@ -334,7 +334,7 @@ class VerificationAgent(BaseAgent): llm_output, tokens_this_round = await self.stream_llm_call( self._conversation_history, temperature=0.1, - max_tokens=3000, + max_tokens=4096, # 🔥 增加到 4096,避免截断 ) except asyncio.CancelledError: logger.info(f"[{self.name}] LLM call cancelled") @@ -415,6 +415,22 @@ class VerificationAgent(BaseAgent): # 处理结果 duration_ms = int((time.time() - start_time) * 1000) + # 🔥 如果被取消,返回取消结果 + if self.is_cancelled: + await self.emit_event( + "info", + f"🛑 Verification Agent 已取消: {self._iteration} 轮迭代" + ) + return AgentResult( + success=False, + error="任务已取消", + data={"findings": findings_to_verify}, + iterations=self._iteration, + tool_calls=self._tool_calls, + tokens_used=self._total_tokens, + duration_ms=duration_ms, + ) + # 处理最终结果 verified_findings = [] if final_result and "findings" in final_result: diff --git a/backend/app/services/agent/core/__init__.py b/backend/app/services/agent/core/__init__.py new file mode 100644 index 0000000..a72f0dd --- /dev/null +++ b/backend/app/services/agent/core/__init__.py @@ -0,0 +1,53 @@ +""" +DeepAudit Agent 核心模块 + +包含Agent系统的基础组件: +- state: 增强的Agent状态管理 +- registry: Agent注册表和动态Agent树管理 +- message: Agent间通信机制 +- executor: 动态Agent树执行器 +- persistence: Agent状态持久化 +""" + +from .state import AgentState, AgentStatus +from .registry import AgentRegistry, agent_registry +from .message import AgentMessage, MessageType, MessagePriority, MessageBus, message_bus +from .executor import ( + DynamicAgentExecutor, + SubAgentExecutor, + ExecutionTask, + ExecutionResult, + ExecutionMode, +) +from .persistence import ( + AgentStatePersistence, + CheckpointManager, + agent_persistence, + checkpoint_manager, +) + +__all__ = [ + # State + "AgentState", + "AgentStatus", + # Registry + "AgentRegistry", + "agent_registry", + # Message + "AgentMessage", + "MessageType", + "MessagePriority", + "MessageBus", + "message_bus", + # Executor + "DynamicAgentExecutor", + "SubAgentExecutor", + "ExecutionTask", + "ExecutionResult", + "ExecutionMode", + # Persistence + "AgentStatePersistence", + "CheckpointManager", + "agent_persistence", + "checkpoint_manager", +] diff --git a/backend/app/services/agent/core/executor.py b/backend/app/services/agent/core/executor.py new file mode 100644 index 0000000..d76ec45 --- /dev/null +++ b/backend/app/services/agent/core/executor.py @@ -0,0 +1,491 @@ +""" +动态 Agent 树执行器 + +实现完整的动态 Agent 树执行逻辑: +- 子 Agent 实际执行 +- 并行 Agent 执行 +- 结果汇总 +- 执行状态追踪 +""" + +import asyncio +import logging +import time +from typing import Dict, Any, List, Optional, Callable, Awaitable +from dataclasses import dataclass, field +from enum import Enum +from datetime import datetime, timezone + +from .state import AgentState, AgentStatus +from .registry import agent_registry +from .message import message_bus, MessageType + +logger = logging.getLogger(__name__) + + +class ExecutionMode(str, Enum): + """执行模式""" + SEQUENTIAL = "sequential" # 顺序执行 + PARALLEL = "parallel" # 并行执行 + ADAPTIVE = "adaptive" # 自适应(根据任务类型决定) + + +@dataclass +class ExecutionTask: + """执行任务""" + agent_id: str + agent_type: str + task: str + context: Dict[str, Any] = field(default_factory=dict) + priority: int = 0 # 优先级,数字越大优先级越高 + dependencies: List[str] = field(default_factory=list) # 依赖的其他任务 ID + + # 执行状态 + status: str = "pending" # pending, running, completed, failed + result: Optional[Dict[str, Any]] = None + error: Optional[str] = None + started_at: Optional[datetime] = None + finished_at: Optional[datetime] = None + + +@dataclass +class ExecutionResult: + """执行结果""" + success: bool + total_agents: int = 0 + completed_agents: int = 0 + failed_agents: int = 0 + + # 汇总的发现 + all_findings: List[Dict[str, Any]] = field(default_factory=list) + + # 各 Agent 的结果 + agent_results: Dict[str, Dict[str, Any]] = field(default_factory=dict) + + # 执行统计 + total_duration_ms: int = 0 + total_tokens: int = 0 + total_tool_calls: int = 0 + + # 错误信息 + errors: List[str] = field(default_factory=list) + + +class DynamicAgentExecutor: + """ + 动态 Agent 树执行器 + + 负责: + 1. 管理 Agent 的创建和执行 + 2. 处理并行执行和依赖关系 + 3. 汇总执行结果 + 4. 处理错误和超时 + """ + + def __init__( + self, + llm_service, + tools: Dict[str, Any], + event_emitter=None, + max_parallel: int = 5, + default_timeout: int = 600, + ): + """ + 初始化执行器 + + Args: + llm_service: LLM 服务 + tools: 可用工具 + event_emitter: 事件发射器 + max_parallel: 最大并行 Agent 数 + default_timeout: 默认超时时间(秒) + """ + self.llm_service = llm_service + self.tools = tools + self.event_emitter = event_emitter + self.max_parallel = max_parallel + self.default_timeout = default_timeout + + # 执行状态 + self._tasks: Dict[str, ExecutionTask] = {} + self._running_tasks: Dict[str, asyncio.Task] = {} + self._semaphore = asyncio.Semaphore(max_parallel) + + # 取消标志 + self._cancelled = False + + def cancel(self): + """取消所有执行""" + self._cancelled = True + + # 取消所有运行中的任务 + for task_id, task in self._running_tasks.items(): + if not task.done(): + task.cancel() + logger.info(f"Cancelled task: {task_id}") + + @property + def is_cancelled(self) -> bool: + return self._cancelled + + async def execute_agent( + self, + agent_class, + agent_config: Dict[str, Any], + input_data: Dict[str, Any], + parent_id: Optional[str] = None, + knowledge_modules: Optional[List[str]] = None, + ) -> Dict[str, Any]: + """ + 执行单个 Agent + + Args: + agent_class: Agent 类 + agent_config: Agent 配置 + input_data: 输入数据 + parent_id: 父 Agent ID + knowledge_modules: 知识模块列表 + + Returns: + Agent 执行结果 + """ + if self._cancelled: + return {"success": False, "error": "Execution cancelled"} + + async with self._semaphore: + try: + # 创建 Agent 实例 + agent = agent_class( + llm_service=self.llm_service, + tools=self.tools, + event_emitter=self.event_emitter, + parent_id=parent_id, + knowledge_modules=knowledge_modules, + **agent_config, + ) + + # 执行 Agent + start_time = time.time() + result = await asyncio.wait_for( + agent.run(input_data), + timeout=self.default_timeout, + ) + duration_ms = int((time.time() - start_time) * 1000) + + return { + "success": result.success, + "data": result.data, + "error": result.error, + "agent_id": agent.agent_id, + "iterations": result.iterations, + "tokens_used": result.tokens_used, + "tool_calls": result.tool_calls, + "duration_ms": duration_ms, + "handoff": result.handoff.to_dict() if result.handoff else None, + } + + except asyncio.TimeoutError: + logger.error(f"Agent execution timed out") + return {"success": False, "error": "Execution timed out"} + except asyncio.CancelledError: + logger.info(f"Agent execution cancelled") + return {"success": False, "error": "Execution cancelled"} + except Exception as e: + logger.error(f"Agent execution failed: {e}", exc_info=True) + return {"success": False, "error": str(e)} + + async def execute_parallel( + self, + tasks: List[ExecutionTask], + agent_factory: Callable[[ExecutionTask], Awaitable[Dict[str, Any]]], + ) -> ExecutionResult: + """ + 并行执行多个 Agent 任务 + + Args: + tasks: 任务列表 + agent_factory: Agent 工厂函数,接收任务返回执行结果 + + Returns: + 汇总的执行结果 + """ + if not tasks: + return ExecutionResult(success=True) + + start_time = time.time() + + # 按优先级排序 + sorted_tasks = sorted(tasks, key=lambda t: t.priority, reverse=True) + + # 分离有依赖和无依赖的任务 + independent_tasks = [t for t in sorted_tasks if not t.dependencies] + dependent_tasks = [t for t in sorted_tasks if t.dependencies] + + # 存储任务 + for task in sorted_tasks: + self._tasks[task.agent_id] = task + + result = ExecutionResult( + success=True, + total_agents=len(tasks), + ) + + # 先执行无依赖的任务 + if independent_tasks: + await self._execute_task_batch(independent_tasks, agent_factory, result) + + # 然后执行有依赖的任务 + for task in dependent_tasks: + if self._cancelled: + break + + # 等待依赖完成 + await self._wait_for_dependencies(task) + + # 执行任务 + await self._execute_single_task(task, agent_factory, result) + + # 计算总时长 + result.total_duration_ms = int((time.time() - start_time) * 1000) + + # 判断整体成功状态 + result.success = result.failed_agents == 0 + + return result + + async def _execute_task_batch( + self, + tasks: List[ExecutionTask], + agent_factory: Callable[[ExecutionTask], Awaitable[Dict[str, Any]]], + result: ExecutionResult, + ): + """执行一批任务""" + async_tasks = [] + + for task in tasks: + if self._cancelled: + break + + async_task = asyncio.create_task( + self._execute_single_task(task, agent_factory, result) + ) + self._running_tasks[task.agent_id] = async_task + async_tasks.append(async_task) + + # 等待所有任务完成 + if async_tasks: + await asyncio.gather(*async_tasks, return_exceptions=True) + + async def _execute_single_task( + self, + task: ExecutionTask, + agent_factory: Callable[[ExecutionTask], Awaitable[Dict[str, Any]]], + result: ExecutionResult, + ): + """执行单个任务""" + task.status = "running" + task.started_at = datetime.now(timezone.utc) + + try: + # 调用工厂函数执行 Agent + agent_result = await agent_factory(task) + + task.finished_at = datetime.now(timezone.utc) + task.result = agent_result + + if agent_result.get("success"): + task.status = "completed" + result.completed_agents += 1 + + # 收集发现 + findings = agent_result.get("data", {}).get("findings", []) + result.all_findings.extend(findings) + + # 统计 + result.total_tokens += agent_result.get("tokens_used", 0) + result.total_tool_calls += agent_result.get("tool_calls", 0) + else: + task.status = "failed" + task.error = agent_result.get("error") + result.failed_agents += 1 + result.errors.append(f"{task.agent_id}: {task.error}") + + # 保存结果 + result.agent_results[task.agent_id] = agent_result + + except Exception as e: + task.status = "failed" + task.error = str(e) + task.finished_at = datetime.now(timezone.utc) + result.failed_agents += 1 + result.errors.append(f"{task.agent_id}: {str(e)}") + logger.error(f"Task {task.agent_id} failed: {e}", exc_info=True) + + finally: + # 清理运行中的任务 + self._running_tasks.pop(task.agent_id, None) + + async def _wait_for_dependencies(self, task: ExecutionTask): + """等待任务的依赖完成""" + for dep_id in task.dependencies: + dep_task = self._tasks.get(dep_id) + if not dep_task: + continue + + # 等待依赖任务完成 + while dep_task.status in ["pending", "running"]: + if self._cancelled: + return + await asyncio.sleep(0.1) + + def get_execution_summary(self) -> Dict[str, Any]: + """获取执行摘要""" + return { + "total_tasks": len(self._tasks), + "completed": sum(1 for t in self._tasks.values() if t.status == "completed"), + "failed": sum(1 for t in self._tasks.values() if t.status == "failed"), + "pending": sum(1 for t in self._tasks.values() if t.status == "pending"), + "running": sum(1 for t in self._tasks.values() if t.status == "running"), + "tasks": { + tid: { + "status": t.status, + "agent_type": t.agent_type, + "error": t.error, + } + for tid, t in self._tasks.items() + }, + } + + +class SubAgentExecutor: + """ + 子 Agent 执行器 + + 专门用于从父 Agent 创建和执行子 Agent + """ + + def __init__( + self, + parent_agent, + llm_service, + tools: Dict[str, Any], + event_emitter=None, + ): + self.parent_agent = parent_agent + self.llm_service = llm_service + self.tools = tools + self.event_emitter = event_emitter + + self._child_agents: Dict[str, Any] = {} + self._executor = DynamicAgentExecutor( + llm_service=llm_service, + tools=tools, + event_emitter=event_emitter, + ) + + async def create_and_run_sub_agent( + self, + agent_type: str, + task: str, + context: Dict[str, Any] = None, + knowledge_modules: List[str] = None, + ) -> Dict[str, Any]: + """ + 创建并运行子 Agent + + Args: + agent_type: Agent 类型 (analysis, verification, specialist) + task: 任务描述 + context: 任务上下文 + knowledge_modules: 知识模块 + + Returns: + 子 Agent 执行结果 + """ + from ..agents import AnalysisAgent, VerificationAgent + + # 根据类型选择 Agent 类 + agent_class_map = { + "analysis": AnalysisAgent, + "verification": VerificationAgent, + } + + agent_class = agent_class_map.get(agent_type) + if not agent_class: + return {"success": False, "error": f"Unknown agent type: {agent_type}"} + + # 准备输入数据 + input_data = { + "task": task, + "task_context": context or {}, + "project_info": context.get("project_info", {}) if context else {}, + "config": context.get("config", {}) if context else {}, + } + + # 如果父 Agent 有 handoff,传递给子 Agent + if hasattr(self.parent_agent, "_incoming_handoff") and self.parent_agent._incoming_handoff: + input_data["parent_handoff"] = self.parent_agent._incoming_handoff.to_dict() + + # 执行子 Agent + result = await self._executor.execute_agent( + agent_class=agent_class, + agent_config={}, + input_data=input_data, + parent_id=self.parent_agent.agent_id, + knowledge_modules=knowledge_modules, + ) + + # 记录子 Agent + if result.get("agent_id"): + self._child_agents[result["agent_id"]] = result + + return result + + async def run_parallel_sub_agents( + self, + sub_agent_configs: List[Dict[str, Any]], + ) -> ExecutionResult: + """ + 并行运行多个子 Agent + + Args: + sub_agent_configs: 子 Agent 配置列表 + [{"agent_type": "analysis", "task": "...", "context": {...}, "knowledge_modules": [...]}] + + Returns: + 汇总的执行结果 + """ + tasks = [] + + for i, config in enumerate(sub_agent_configs): + task = ExecutionTask( + agent_id=f"sub_{self.parent_agent.agent_id}_{i}", + agent_type=config.get("agent_type", "analysis"), + task=config.get("task", ""), + context=config.get("context", {}), + priority=config.get("priority", 0), + dependencies=config.get("dependencies", []), + ) + tasks.append(task) + + async def agent_factory(task: ExecutionTask) -> Dict[str, Any]: + return await self.create_and_run_sub_agent( + agent_type=task.agent_type, + task=task.task, + context=task.context, + knowledge_modules=task.context.get("knowledge_modules"), + ) + + return await self._executor.execute_parallel(tasks, agent_factory) + + def get_child_results(self) -> Dict[str, Dict[str, Any]]: + """获取所有子 Agent 的结果""" + return self._child_agents.copy() + + def get_all_findings(self) -> List[Dict[str, Any]]: + """获取所有子 Agent 发现的漏洞""" + findings = [] + for result in self._child_agents.values(): + if result.get("success") and result.get("data"): + findings.extend(result["data"].get("findings", [])) + return findings diff --git a/backend/app/services/agent/core/message.py b/backend/app/services/agent/core/message.py new file mode 100644 index 0000000..8e6d4f5 --- /dev/null +++ b/backend/app/services/agent/core/message.py @@ -0,0 +1,290 @@ +""" +Agent 间通信机制 + +提供: +- 消息类型定义 +- 消息队列管理 +- Agent间消息传递 +""" + +import logging +import uuid +from datetime import datetime, timezone +from enum import Enum +from typing import Any, Dict, List, Optional +from dataclasses import dataclass, field + +logger = logging.getLogger(__name__) + + +class MessageType(str, Enum): + """消息类型""" + QUERY = "query" # 查询消息(请求信息) + INSTRUCTION = "instruction" # 指令消息(要求执行操作) + INFORMATION = "information" # 信息消息(分享发现或状态) + RESULT = "result" # 结果消息(任务完成报告) + ERROR = "error" # 错误消息 + + +class MessagePriority(str, Enum): + """消息优先级""" + LOW = "low" + NORMAL = "normal" + HIGH = "high" + URGENT = "urgent" + + +@dataclass +class AgentMessage: + """ + Agent 消息 + + 用于Agent间通信的消息结构 + """ + id: str = field(default_factory=lambda: f"msg_{uuid.uuid4().hex[:8]}") + from_agent: str = "" + to_agent: str = "" + content: str = "" + message_type: MessageType = MessageType.INFORMATION + priority: MessagePriority = MessagePriority.NORMAL + timestamp: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + + # 状态 + delivered: bool = False + read: bool = False + + # 附加数据 + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "id": self.id, + "from": self.from_agent, + "to": self.to_agent, + "content": self.content, + "message_type": self.message_type.value if isinstance(self.message_type, MessageType) else self.message_type, + "priority": self.priority.value if isinstance(self.priority, MessagePriority) else self.priority, + "timestamp": self.timestamp, + "delivered": self.delivered, + "read": self.read, + "metadata": self.metadata, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "AgentMessage": + """从字典创建""" + return cls( + id=data.get("id", f"msg_{uuid.uuid4().hex[:8]}"), + from_agent=data.get("from", ""), + to_agent=data.get("to", ""), + content=data.get("content", ""), + message_type=MessageType(data.get("message_type", "information")), + priority=MessagePriority(data.get("priority", "normal")), + timestamp=data.get("timestamp", datetime.now(timezone.utc).isoformat()), + delivered=data.get("delivered", False), + read=data.get("read", False), + metadata=data.get("metadata", {}), + ) + + def to_xml(self) -> str: + """转换为XML格式(用于LLM理解)""" + return f""" + + {self.from_agent} + + + {self.message_type.value if isinstance(self.message_type, MessageType) else self.message_type} + {self.priority.value if isinstance(self.priority, MessagePriority) else self.priority} + {self.timestamp} + + +{self.content} + +""" + + +class MessageBus: + """ + 消息总线 + + 管理Agent间的消息传递 + """ + + def __init__(self): + self._queues: Dict[str, List[AgentMessage]] = {} + self._message_history: List[AgentMessage] = [] + + def create_queue(self, agent_id: str) -> None: + """为Agent创建消息队列""" + if agent_id not in self._queues: + self._queues[agent_id] = [] + logger.debug(f"Created message queue for agent: {agent_id}") + + def delete_queue(self, agent_id: str) -> None: + """删除Agent的消息队列""" + if agent_id in self._queues: + del self._queues[agent_id] + logger.debug(f"Deleted message queue for agent: {agent_id}") + + def send_message( + self, + from_agent: str, + to_agent: str, + content: str, + message_type: MessageType = MessageType.INFORMATION, + priority: MessagePriority = MessagePriority.NORMAL, + metadata: Optional[Dict[str, Any]] = None, + ) -> AgentMessage: + """ + 发送消息 + + Args: + from_agent: 发送者Agent ID + to_agent: 接收者Agent ID + content: 消息内容 + message_type: 消息类型 + priority: 优先级 + metadata: 附加数据 + + Returns: + 发送的消息 + """ + message = AgentMessage( + from_agent=from_agent, + to_agent=to_agent, + content=content, + message_type=message_type, + priority=priority, + metadata=metadata or {}, + ) + + # 确保目标队列存在 + if to_agent not in self._queues: + self.create_queue(to_agent) + + # 添加到队列 + self._queues[to_agent].append(message) + message.delivered = True + + # 记录历史 + self._message_history.append(message) + + logger.debug(f"Message sent from {from_agent} to {to_agent}: {content[:50]}...") + return message + + def get_messages( + self, + agent_id: str, + unread_only: bool = True, + mark_as_read: bool = True, + ) -> List[AgentMessage]: + """ + 获取Agent的消息 + + Args: + agent_id: Agent ID + unread_only: 是否只获取未读消息 + mark_as_read: 是否标记为已读 + + Returns: + 消息列表 + """ + if agent_id not in self._queues: + return [] + + messages = self._queues[agent_id] + + if unread_only: + messages = [m for m in messages if not m.read] + + if mark_as_read: + for m in messages: + m.read = True + + return messages + + def has_unread_messages(self, agent_id: str) -> bool: + """检查是否有未读消息""" + if agent_id not in self._queues: + return False + return any(not m.read for m in self._queues[agent_id]) + + def get_unread_count(self, agent_id: str) -> int: + """获取未读消息数量""" + if agent_id not in self._queues: + return 0 + return sum(1 for m in self._queues[agent_id] if not m.read) + + def send_user_message( + self, + to_agent: str, + content: str, + priority: MessagePriority = MessagePriority.HIGH, + ) -> AgentMessage: + """发送用户消息到Agent""" + return self.send_message( + from_agent="user", + to_agent=to_agent, + content=content, + message_type=MessageType.INSTRUCTION, + priority=priority, + ) + + def send_completion_report( + self, + from_agent: str, + to_agent: str, + summary: str, + findings: List[Dict[str, Any]], + success: bool = True, + ) -> AgentMessage: + """发送任务完成报告""" + content = f""" + {"SUCCESS" if success else "FAILED"} + {summary} + {len(findings)} +""" + + return self.send_message( + from_agent=from_agent, + to_agent=to_agent, + content=content, + message_type=MessageType.RESULT, + priority=MessagePriority.HIGH, + metadata={ + "summary": summary, + "findings": findings, + "success": success, + }, + ) + + def clear_queue(self, agent_id: str) -> None: + """清空Agent的消息队列""" + if agent_id in self._queues: + self._queues[agent_id] = [] + + def clear_all(self) -> None: + """清空所有消息""" + self._queues.clear() + self._message_history.clear() + + def get_message_history( + self, + agent_id: Optional[str] = None, + limit: int = 100, + ) -> List[AgentMessage]: + """获取消息历史""" + history = self._message_history + + if agent_id: + history = [ + m for m in history + if m.from_agent == agent_id or m.to_agent == agent_id + ] + + return history[-limit:] + + +# 全局消息总线实例 +message_bus = MessageBus() diff --git a/backend/app/services/agent/core/persistence.py b/backend/app/services/agent/core/persistence.py new file mode 100644 index 0000000..5e5bd6a --- /dev/null +++ b/backend/app/services/agent/core/persistence.py @@ -0,0 +1,413 @@ +""" +Agent 状态持久化模块 + +提供 Agent 状态的持久化和恢复功能: +- Agent 状态序列化和反序列化 +- 检查点保存和恢复 +- 消息历史持久化 +- 执行记录持久化 +""" + +import json +import logging +import os +from datetime import datetime, timezone +from typing import Dict, Any, List, Optional +from pathlib import Path + +from .state import AgentState, AgentStatus +from .registry import agent_registry + +logger = logging.getLogger(__name__) + + +class AgentStatePersistence: + """ + Agent 状态持久化管理器 + + 支持: + - 文件系统持久化 + - 数据库持久化(可选) + - 检查点机制 + """ + + def __init__( + self, + persist_dir: str = "./agent_checkpoints", + use_database: bool = False, + db_session_factory=None, + ): + """ + 初始化持久化管理器 + + Args: + persist_dir: 持久化目录 + use_database: 是否使用数据库持久化 + db_session_factory: 数据库会话工厂 + """ + self.persist_dir = Path(persist_dir) + self.use_database = use_database + self.db_session_factory = db_session_factory + + # 确保目录存在 + self.persist_dir.mkdir(parents=True, exist_ok=True) + + # ============ 文件系统持久化 ============ + + def save_state(self, state: AgentState, checkpoint_name: Optional[str] = None) -> str: + """ + 保存 Agent 状态到文件 + + Args: + state: Agent 状态 + checkpoint_name: 检查点名称(可选) + + Returns: + 保存的文件路径 + """ + # 生成文件名 + if checkpoint_name: + filename = f"{state.agent_id}_{checkpoint_name}.json" + else: + timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + filename = f"{state.agent_id}_{timestamp}.json" + + filepath = self.persist_dir / filename + + # 序列化状态 + state_dict = self._serialize_state(state) + + # 保存到文件 + with open(filepath, "w", encoding="utf-8") as f: + json.dump(state_dict, f, ensure_ascii=False, indent=2) + + logger.info(f"Saved agent state to {filepath}") + return str(filepath) + + def load_state(self, filepath: str) -> Optional[AgentState]: + """ + 从文件加载 Agent 状态 + + Args: + filepath: 文件路径 + + Returns: + Agent 状态,如果加载失败返回 None + """ + try: + with open(filepath, "r", encoding="utf-8") as f: + state_dict = json.load(f) + + state = self._deserialize_state(state_dict) + logger.info(f"Loaded agent state from {filepath}") + return state + + except Exception as e: + logger.error(f"Failed to load agent state from {filepath}: {e}") + return None + + def load_latest_checkpoint(self, agent_id: str) -> Optional[AgentState]: + """ + 加载指定 Agent 的最新检查点 + + Args: + agent_id: Agent ID + + Returns: + Agent 状态 + """ + # 查找所有匹配的检查点文件 + pattern = f"{agent_id}_*.json" + checkpoints = list(self.persist_dir.glob(pattern)) + + if not checkpoints: + logger.warning(f"No checkpoints found for agent {agent_id}") + return None + + # 按修改时间排序,取最新的 + latest = max(checkpoints, key=lambda p: p.stat().st_mtime) + return self.load_state(str(latest)) + + def list_checkpoints(self, agent_id: Optional[str] = None) -> List[Dict[str, Any]]: + """ + 列出检查点 + + Args: + agent_id: Agent ID(可选,不指定则列出所有) + + Returns: + 检查点信息列表 + """ + if agent_id: + pattern = f"{agent_id}_*.json" + else: + pattern = "*.json" + + checkpoints = [] + for filepath in self.persist_dir.glob(pattern): + stat = filepath.stat() + checkpoints.append({ + "filepath": str(filepath), + "filename": filepath.name, + "size_bytes": stat.st_size, + "created_at": datetime.fromtimestamp(stat.st_ctime, tz=timezone.utc).isoformat(), + "modified_at": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(), + }) + + # 按修改时间排序 + checkpoints.sort(key=lambda x: x["modified_at"], reverse=True) + return checkpoints + + def delete_checkpoint(self, filepath: str) -> bool: + """ + 删除检查点 + + Args: + filepath: 文件路径 + + Returns: + 是否删除成功 + """ + try: + os.remove(filepath) + logger.info(f"Deleted checkpoint: {filepath}") + return True + except Exception as e: + logger.error(f"Failed to delete checkpoint {filepath}: {e}") + return False + + def cleanup_old_checkpoints( + self, + agent_id: str, + keep_count: int = 5, + ) -> int: + """ + 清理旧的检查点,只保留最新的几个 + + Args: + agent_id: Agent ID + keep_count: 保留的检查点数量 + + Returns: + 删除的检查点数量 + """ + checkpoints = self.list_checkpoints(agent_id) + + if len(checkpoints) <= keep_count: + return 0 + + # 删除旧的检查点 + to_delete = checkpoints[keep_count:] + deleted = 0 + + for cp in to_delete: + if self.delete_checkpoint(cp["filepath"]): + deleted += 1 + + return deleted + + # ============ 序列化/反序列化 ============ + + def _serialize_state(self, state: AgentState) -> Dict[str, Any]: + """序列化 Agent 状态""" + return { + "version": "1.0", + "serialized_at": datetime.now(timezone.utc).isoformat(), + "state": state.model_dump(), + } + + def _deserialize_state(self, data: Dict[str, Any]) -> AgentState: + """反序列化 Agent 状态""" + version = data.get("version", "1.0") + state_data = data.get("state", data) + + # 处理版本兼容性 + if version == "1.0": + return AgentState(**state_data) + else: + logger.warning(f"Unknown state version: {version}, attempting to load anyway") + return AgentState(**state_data) + + # ============ 数据库持久化 ============ + + async def save_state_to_db( + self, + state: AgentState, + task_id: str, + ) -> bool: + """ + 保存 Agent 状态到数据库 + + Args: + state: Agent 状态 + task_id: 关联的任务 ID + + Returns: + 是否保存成功 + """ + if not self.use_database or not self.db_session_factory: + logger.warning("Database persistence not configured") + return False + + try: + async with self.db_session_factory() as session: + from app.models.agent_task import AgentCheckpoint + + checkpoint = AgentCheckpoint( + task_id=task_id, + agent_id=state.agent_id, + agent_name=state.agent_name, + agent_type=state.agent_type, + state_data=state.model_dump_json(), + iteration=state.iteration, + status=state.status, + created_at=datetime.now(timezone.utc), + ) + + session.add(checkpoint) + await session.commit() + + logger.info(f"Saved agent state to database: {state.agent_id}") + return True + + except Exception as e: + logger.error(f"Failed to save agent state to database: {e}") + return False + + async def load_state_from_db( + self, + task_id: str, + agent_id: Optional[str] = None, + ) -> Optional[AgentState]: + """ + 从数据库加载 Agent 状态 + + Args: + task_id: 任务 ID + agent_id: Agent ID(可选) + + Returns: + Agent 状态 + """ + if not self.use_database or not self.db_session_factory: + logger.warning("Database persistence not configured") + return None + + try: + async with self.db_session_factory() as session: + from sqlalchemy import select + from app.models.agent_task import AgentCheckpoint + + query = select(AgentCheckpoint).where( + AgentCheckpoint.task_id == task_id + ) + + if agent_id: + query = query.where(AgentCheckpoint.agent_id == agent_id) + + query = query.order_by(AgentCheckpoint.created_at.desc()).limit(1) + + result = await session.execute(query) + checkpoint = result.scalar_one_or_none() + + if checkpoint: + state_data = json.loads(checkpoint.state_data) + return AgentState(**state_data) + + return None + + except Exception as e: + logger.error(f"Failed to load agent state from database: {e}") + return None + + +class CheckpointManager: + """ + 检查点管理器 + + 提供自动检查点功能: + - 定期保存检查点 + - 错误恢复 + - 状态回滚 + """ + + def __init__( + self, + persistence: AgentStatePersistence, + auto_checkpoint_interval: int = 5, # 每 N 次迭代自动保存 + ): + self.persistence = persistence + self.auto_checkpoint_interval = auto_checkpoint_interval + + self._last_checkpoint_iteration: Dict[str, int] = {} + + def should_checkpoint(self, state: AgentState) -> bool: + """ + 判断是否应该创建检查点 + + Args: + state: Agent 状态 + + Returns: + 是否应该创建检查点 + """ + last_iteration = self._last_checkpoint_iteration.get(state.agent_id, 0) + return state.iteration - last_iteration >= self.auto_checkpoint_interval + + def create_checkpoint( + self, + state: AgentState, + checkpoint_name: Optional[str] = None, + ) -> str: + """ + 创建检查点 + + Args: + state: Agent 状态 + checkpoint_name: 检查点名称 + + Returns: + 检查点文件路径 + """ + filepath = self.persistence.save_state(state, checkpoint_name) + self._last_checkpoint_iteration[state.agent_id] = state.iteration + return filepath + + def auto_checkpoint(self, state: AgentState) -> Optional[str]: + """ + 自动检查点(如果需要) + + Args: + state: Agent 状态 + + Returns: + 检查点文件路径,如果没有创建则返回 None + """ + if self.should_checkpoint(state): + return self.create_checkpoint(state) + return None + + def restore_from_checkpoint( + self, + agent_id: str, + checkpoint_filepath: Optional[str] = None, + ) -> Optional[AgentState]: + """ + 从检查点恢复 + + Args: + agent_id: Agent ID + checkpoint_filepath: 检查点文件路径(可选,不指定则使用最新的) + + Returns: + 恢复的 Agent 状态 + """ + if checkpoint_filepath: + return self.persistence.load_state(checkpoint_filepath) + else: + return self.persistence.load_latest_checkpoint(agent_id) + + +# 全局持久化管理器 +agent_persistence = AgentStatePersistence() +checkpoint_manager = CheckpointManager(agent_persistence) diff --git a/backend/app/services/agent/core/registry.py b/backend/app/services/agent/core/registry.py new file mode 100644 index 0000000..0e5ea38 --- /dev/null +++ b/backend/app/services/agent/core/registry.py @@ -0,0 +1,309 @@ +""" +Agent 注册表和动态Agent树管理 + +提供: +- Agent实例注册和管理 +- 动态Agent树结构 +- Agent状态追踪 +- 子Agent创建和销毁 +""" + +import logging +import threading +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from .state import AgentState + +logger = logging.getLogger(__name__) + + +class AgentRegistry: + """ + Agent 注册表 + + 管理所有Agent实例,维护动态Agent树结构 + """ + + def __init__(self): + self._lock = threading.RLock() + + # Agent图结构 + self._agent_graph: Dict[str, Any] = { + "nodes": {}, # agent_id -> node_info + "edges": [], # {from, to, type} + } + + # Agent实例和状态 + self._agent_instances: Dict[str, Any] = {} # agent_id -> agent_instance + self._agent_states: Dict[str, "AgentState"] = {} # agent_id -> state + + # 消息队列 + self._agent_messages: Dict[str, List[Dict[str, Any]]] = {} # agent_id -> messages + + # 根Agent + self._root_agent_id: Optional[str] = None + + # 运行中的Agent线程 + self._running_agents: Dict[str, threading.Thread] = {} + + # ============ Agent 注册 ============ + + def register_agent( + self, + agent_id: str, + agent_name: str, + agent_type: str, + task: str, + parent_id: Optional[str] = None, + agent_instance: Any = None, + state: Optional["AgentState"] = None, + knowledge_modules: Optional[List[str]] = None, + ) -> Dict[str, Any]: + """ + 注册Agent到注册表 + + Args: + agent_id: Agent唯一标识 + agent_name: Agent名称 + agent_type: Agent类型 + task: 任务描述 + parent_id: 父Agent ID + agent_instance: Agent实例 + state: Agent状态 + knowledge_modules: 加载的知识模块 + + Returns: + 注册的节点信息 + """ + logger.info(f"[AgentRegistry] register_agent 被调用: {agent_name} (id={agent_id}, parent={parent_id})") + logger.info(f"[AgentRegistry] 当前节点数: {len(self._agent_graph['nodes'])}, 节点列表: {list(self._agent_graph['nodes'].keys())}") + + with self._lock: + node = { + "id": agent_id, + "name": agent_name, + "type": agent_type, + "task": task, + "status": "running", + "parent_id": parent_id, + "created_at": datetime.now(timezone.utc).isoformat(), + "finished_at": None, + "result": None, + "knowledge_modules": knowledge_modules or [], + "children": [], + } + + self._agent_graph["nodes"][agent_id] = node + + if agent_instance: + self._agent_instances[agent_id] = agent_instance + + if state: + self._agent_states[agent_id] = state + + # 初始化消息队列 + if agent_id not in self._agent_messages: + self._agent_messages[agent_id] = [] + + # 添加边(父子关系) + if parent_id: + self._agent_graph["edges"].append({ + "from": parent_id, + "to": agent_id, + "type": "delegation", + "created_at": datetime.now(timezone.utc).isoformat(), + }) + + # 更新父节点的children列表 + if parent_id in self._agent_graph["nodes"]: + self._agent_graph["nodes"][parent_id]["children"].append(agent_id) + + # 设置根Agent + if parent_id is None and self._root_agent_id is None: + self._root_agent_id = agent_id + + logger.info(f"[AgentRegistry] 注册完成: {agent_name} ({agent_id}), parent: {parent_id}") + logger.info(f"[AgentRegistry] 注册后节点数: {len(self._agent_graph['nodes'])}, 节点列表: {list(self._agent_graph['nodes'].keys())}") + return node + + def unregister_agent(self, agent_id: str) -> None: + """注销Agent""" + with self._lock: + if agent_id in self._agent_graph["nodes"]: + del self._agent_graph["nodes"][agent_id] + + self._agent_instances.pop(agent_id, None) + self._agent_states.pop(agent_id, None) + self._agent_messages.pop(agent_id, None) + self._running_agents.pop(agent_id, None) + + # 移除相关边 + self._agent_graph["edges"] = [ + e for e in self._agent_graph["edges"] + if e["from"] != agent_id and e["to"] != agent_id + ] + + logger.info(f"Unregistered agent: {agent_id}") + + # ============ Agent 状态更新 ============ + + def update_agent_status( + self, + agent_id: str, + status: str, + result: Optional[Dict[str, Any]] = None, + ) -> None: + """更新Agent状态""" + with self._lock: + if agent_id in self._agent_graph["nodes"]: + node = self._agent_graph["nodes"][agent_id] + node["status"] = status + + if status in ["completed", "failed", "stopped"]: + node["finished_at"] = datetime.now(timezone.utc).isoformat() + + if result: + node["result"] = result + + logger.debug(f"Updated agent {agent_id} status to {status}") + + def get_agent_status(self, agent_id: str) -> Optional[str]: + """获取Agent状态""" + with self._lock: + if agent_id in self._agent_graph["nodes"]: + return self._agent_graph["nodes"][agent_id]["status"] + return None + + # ============ Agent 查询 ============ + + def get_agent(self, agent_id: str) -> Optional[Any]: + """获取Agent实例""" + return self._agent_instances.get(agent_id) + + def get_agent_state(self, agent_id: str) -> Optional["AgentState"]: + """获取Agent状态""" + return self._agent_states.get(agent_id) + + def get_agent_node(self, agent_id: str) -> Optional[Dict[str, Any]]: + """获取Agent节点信息""" + return self._agent_graph["nodes"].get(agent_id) + + def get_root_agent_id(self) -> Optional[str]: + """获取根Agent ID""" + return self._root_agent_id + + def get_children(self, agent_id: str) -> List[str]: + """获取子Agent ID列表""" + with self._lock: + node = self._agent_graph["nodes"].get(agent_id) + if node: + return node.get("children", []) + return [] + + def get_parent(self, agent_id: str) -> Optional[str]: + """获取父Agent ID""" + with self._lock: + node = self._agent_graph["nodes"].get(agent_id) + if node: + return node.get("parent_id") + return None + + # ============ Agent 树操作 ============ + + def get_agent_tree(self) -> Dict[str, Any]: + """获取完整的Agent树结构""" + with self._lock: + return { + "nodes": dict(self._agent_graph["nodes"]), + "edges": list(self._agent_graph["edges"]), + "root_agent_id": self._root_agent_id, + } + + def get_agent_tree_view(self, agent_id: Optional[str] = None) -> str: + """获取Agent树的文本视图""" + with self._lock: + lines = ["=== AGENT TREE ==="] + + root_id = agent_id or self._root_agent_id + if not root_id or root_id not in self._agent_graph["nodes"]: + return "No agents in the tree" + + def _build_tree(aid: str, depth: int = 0) -> None: + node = self._agent_graph["nodes"].get(aid) + if not node: + return + + indent = " " * depth + status_emoji = { + "running": "🔄", + "waiting": "⏳", + "completed": "✅", + "failed": "❌", + "stopped": "🛑", + }.get(node["status"], "❓") + + lines.append(f"{indent}{status_emoji} {node['name']} ({aid})") + lines.append(f"{indent} Task: {node['task'][:50]}...") + lines.append(f"{indent} Status: {node['status']}") + + if node.get("knowledge_modules"): + lines.append(f"{indent} Modules: {', '.join(node['knowledge_modules'])}") + + for child_id in node.get("children", []): + _build_tree(child_id, depth + 1) + + _build_tree(root_id) + return "\n".join(lines) + + def get_statistics(self) -> Dict[str, int]: + """获取统计信息""" + with self._lock: + stats = { + "total": len(self._agent_graph["nodes"]), + "running": 0, + "waiting": 0, + "completed": 0, + "failed": 0, + "stopped": 0, + } + + for node in self._agent_graph["nodes"].values(): + status = node.get("status", "unknown") + if status in stats: + stats[status] += 1 + + return stats + + # ============ 清理 ============ + + def clear(self) -> None: + """清空注册表""" + with self._lock: + self._agent_graph = {"nodes": {}, "edges": []} + self._agent_instances.clear() + self._agent_states.clear() + self._agent_messages.clear() + self._running_agents.clear() + self._root_agent_id = None + logger.info("Agent registry cleared") + + def cleanup_finished_agents(self) -> int: + """清理已完成的Agent""" + with self._lock: + finished_ids = [ + aid for aid, node in self._agent_graph["nodes"].items() + if node["status"] in ["completed", "failed", "stopped"] + ] + + for aid in finished_ids: + # 保留节点信息,但清理实例 + self._agent_instances.pop(aid, None) + self._running_agents.pop(aid, None) + + return len(finished_ids) + + +# 全局注册表实例 +agent_registry = AgentRegistry() diff --git a/backend/app/services/agent/core/state.py b/backend/app/services/agent/core/state.py new file mode 100644 index 0000000..2f1ac40 --- /dev/null +++ b/backend/app/services/agent/core/state.py @@ -0,0 +1,297 @@ +""" +Agent 状态管理模块 + +提供完整的Agent状态管理,支持: +- 完整的生命周期管理 +- 状态序列化和持久化 +- 暂停和恢复 +- 动态Agent树结构 +""" + +import uuid +from datetime import datetime, timezone +from enum import Enum +from typing import Any, Dict, List, Optional +from pydantic import BaseModel, Field + + +def _generate_agent_id() -> str: + """生成唯一的Agent ID""" + return f"agent_{uuid.uuid4().hex[:8]}" + + +class AgentStatus(str, Enum): + """Agent 运行状态""" + CREATED = "created" # 已创建,未开始 + RUNNING = "running" # 运行中 + WAITING = "waiting" # 等待中(等待消息或输入) + PAUSED = "paused" # 已暂停 + COMPLETED = "completed" # 已完成 + FAILED = "failed" # 失败 + STOPPED = "stopped" # 被停止 + STOPPING = "stopping" # 正在停止 + + +class AgentState(BaseModel): + """ + Agent 状态模型 + + 包含Agent执行所需的所有状态信息,支持: + - 完整的生命周期管理 + - 状态序列化和持久化 + - 暂停和恢复 + - 动态Agent树结构 + """ + + # ============ 基本信息 ============ + agent_id: str = Field(default_factory=_generate_agent_id) + agent_name: str = "DeepAudit Agent" + agent_type: str = "generic" # recon, analysis, verification, specialist + parent_id: Optional[str] = None # 父Agent ID(用于动态Agent树) + + # ============ 任务信息 ============ + task: str = "" + task_context: Dict[str, Any] = Field(default_factory=dict) + inherited_context: Dict[str, Any] = Field(default_factory=dict) # 从父Agent继承的上下文 + + # ============ 知识模块 ============ + knowledge_modules: List[str] = Field(default_factory=list) # 加载的知识模块名称 + + # ============ 执行状态 ============ + status: AgentStatus = AgentStatus.CREATED + iteration: int = 0 + max_iterations: int = 50 + + # ============ 对话历史 ============ + messages: List[Dict[str, Any]] = Field(default_factory=list) + system_prompt: str = "" + + # ============ 执行记录 ============ + actions_taken: List[Dict[str, Any]] = Field(default_factory=list) + observations: List[Dict[str, Any]] = Field(default_factory=list) + errors: List[str] = Field(default_factory=list) + + # ============ 发现列表 ============ + findings: List[Dict[str, Any]] = Field(default_factory=list) + + # ============ 时间戳 ============ + created_at: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + started_at: Optional[str] = None + last_updated: str = Field(default_factory=lambda: datetime.now(timezone.utc).isoformat()) + finished_at: Optional[str] = None + + # ============ 等待状态 ============ + waiting_for_input: bool = False + waiting_start_time: Optional[datetime] = None + waiting_reason: str = "" + waiting_timeout_seconds: int = 600 # 10分钟超时 + + # ============ 最终结果 ============ + final_result: Optional[Dict[str, Any]] = None + + # ============ 统计信息 ============ + total_tokens: int = 0 + tool_calls: int = 0 + + # ============ 标志位 ============ + stop_requested: bool = False + max_iterations_warning_sent: bool = False + + class Config: + use_enum_values = True + + # ============ 状态管理方法 ============ + + def start(self) -> None: + """开始执行""" + self.status = AgentStatus.RUNNING + self.started_at = datetime.now(timezone.utc).isoformat() + self._update_timestamp() + + def increment_iteration(self) -> None: + """增加迭代次数""" + self.iteration += 1 + self._update_timestamp() + + def set_completed(self, final_result: Optional[Dict[str, Any]] = None) -> None: + """标记为完成""" + self.status = AgentStatus.COMPLETED + self.final_result = final_result + self.finished_at = datetime.now(timezone.utc).isoformat() + self._update_timestamp() + + def set_failed(self, error: str) -> None: + """标记为失败""" + self.status = AgentStatus.FAILED + self.add_error(error) + self.finished_at = datetime.now(timezone.utc).isoformat() + self._update_timestamp() + + def request_stop(self) -> None: + """请求停止""" + self.stop_requested = True + self.status = AgentStatus.STOPPING + self._update_timestamp() + + def set_stopped(self) -> None: + """标记为已停止""" + self.status = AgentStatus.STOPPED + self.finished_at = datetime.now(timezone.utc).isoformat() + self._update_timestamp() + + # ============ 等待状态管理 ============ + + def enter_waiting_state(self, reason: str = "等待消息") -> None: + """进入等待状态""" + self.waiting_for_input = True + self.waiting_start_time = datetime.now(timezone.utc) + self.waiting_reason = reason + self.status = AgentStatus.WAITING + self._update_timestamp() + + def resume_from_waiting(self, new_task: Optional[str] = None) -> None: + """从等待状态恢复""" + self.waiting_for_input = False + self.waiting_start_time = None + self.waiting_reason = "" + self.stop_requested = False + self.status = AgentStatus.RUNNING + if new_task: + self.task = new_task + self._update_timestamp() + + def has_waiting_timeout(self) -> bool: + """检查等待是否超时""" + if not self.waiting_for_input or not self.waiting_start_time: + return False + + if self.stop_requested or self.status in [AgentStatus.COMPLETED, AgentStatus.FAILED]: + return False + + elapsed = (datetime.now(timezone.utc) - self.waiting_start_time).total_seconds() + return elapsed > self.waiting_timeout_seconds + + def is_waiting_for_input(self) -> bool: + """是否在等待输入""" + return self.waiting_for_input + + # ============ 执行控制 ============ + + def should_stop(self) -> bool: + """是否应该停止""" + return ( + self.stop_requested or + self.status in [AgentStatus.COMPLETED, AgentStatus.FAILED, AgentStatus.STOPPED] or + self.has_reached_max_iterations() + ) + + def has_reached_max_iterations(self) -> bool: + """是否达到最大迭代次数""" + return self.iteration >= self.max_iterations + + def is_approaching_max_iterations(self, threshold: float = 0.85) -> bool: + """是否接近最大迭代次数""" + return self.iteration >= int(self.max_iterations * threshold) + + # ============ 消息管理 ============ + + def add_message(self, role: str, content: Any) -> None: + """添加消息""" + self.messages.append({ + "role": role, + "content": content, + "timestamp": datetime.now(timezone.utc).isoformat(), + }) + self._update_timestamp() + + def get_conversation_history(self) -> List[Dict[str, Any]]: + """获取对话历史(不含时间戳,用于LLM调用)""" + return [{"role": m["role"], "content": m["content"]} for m in self.messages] + + # ============ 执行记录 ============ + + def add_action(self, action: Dict[str, Any]) -> None: + """记录执行的动作""" + self.actions_taken.append({ + "iteration": self.iteration, + "timestamp": datetime.now(timezone.utc).isoformat(), + "action": action, + }) + self.tool_calls += 1 + self._update_timestamp() + + def add_observation(self, observation: Dict[str, Any]) -> None: + """记录观察结果""" + self.observations.append({ + "iteration": self.iteration, + "timestamp": datetime.now(timezone.utc).isoformat(), + "observation": observation, + }) + self._update_timestamp() + + def add_error(self, error: str) -> None: + """记录错误""" + self.errors.append(f"Iteration {self.iteration}: {error}") + self._update_timestamp() + + def add_finding(self, finding: Dict[str, Any]) -> None: + """添加发现""" + finding["discovered_at"] = datetime.now(timezone.utc).isoformat() + finding["discovered_by"] = self.agent_id + self.findings.append(finding) + self._update_timestamp() + + # ============ 上下文管理 ============ + + def update_context(self, key: str, value: Any) -> None: + """更新任务上下文""" + self.task_context[key] = value + self._update_timestamp() + + def inherit_context(self, parent_context: Dict[str, Any]) -> None: + """继承父Agent的上下文""" + self.inherited_context = parent_context.copy() + self._update_timestamp() + + # ============ 统计和摘要 ============ + + def add_tokens(self, tokens: int) -> None: + """添加token使用量""" + self.total_tokens += tokens + self._update_timestamp() + + def get_execution_summary(self) -> Dict[str, Any]: + """获取执行摘要""" + return { + "agent_id": self.agent_id, + "agent_name": self.agent_name, + "agent_type": self.agent_type, + "parent_id": self.parent_id, + "task": self.task, + "status": self.status, + "iteration": self.iteration, + "max_iterations": self.max_iterations, + "total_tokens": self.total_tokens, + "tool_calls": self.tool_calls, + "findings_count": len(self.findings), + "errors_count": len(self.errors), + "created_at": self.created_at, + "started_at": self.started_at, + "finished_at": self.finished_at, + "duration_seconds": self._calculate_duration(), + "knowledge_modules": self.knowledge_modules, + } + + def _calculate_duration(self) -> Optional[float]: + """计算执行时长""" + if not self.started_at: + return None + + end_time = self.finished_at or datetime.now(timezone.utc).isoformat() + start = datetime.fromisoformat(self.started_at.replace('Z', '+00:00')) + end = datetime.fromisoformat(end_time.replace('Z', '+00:00')) + return (end - start).total_seconds() + + def _update_timestamp(self) -> None: + """更新最后修改时间""" + self.last_updated = datetime.now(timezone.utc).isoformat() diff --git a/backend/app/services/agent/event_manager.py b/backend/app/services/agent/event_manager.py index d7f5d85..f40b032 100644 --- a/backend/app/services/agent/event_manager.py +++ b/backend/app/services/agent/event_manager.py @@ -354,7 +354,7 @@ class EventManager: """创建或获取事件队列""" if task_id not in self._event_queues: # 🔥 使用较大的队列容量,缓存更多 token 事件 - self._event_queues[task_id] = asyncio.Queue(maxsize=1000) + self._event_queues[task_id] = asyncio.Queue(maxsize=5000) return self._event_queues[task_id] def remove_queue(self, task_id: str): diff --git a/backend/app/services/agent/graph/runner.py b/backend/app/services/agent/graph/runner.py index 563678a..1514e8d 100644 --- a/backend/app/services/agent/graph/runner.py +++ b/backend/app/services/agent/graph/runner.py @@ -193,17 +193,36 @@ class AgentRunner: """初始化工具集""" await self.event_emitter.emit_info("初始化 Agent 工具集...") + # 🔥 导入新工具 + from app.services.agent.tools import ( + ThinkTool, ReflectTool, + CreateVulnerabilityReportTool, + ) + # 🔥 导入知识查询工具 + from app.services.agent.knowledge import ( + SecurityKnowledgeQueryTool, + GetVulnerabilityKnowledgeTool, + ) + + # 🔥 获取排除模式和目标文件 + exclude_patterns = self.task.exclude_patterns or [] + target_files = self.task.target_files or None + # ============ 基础工具(所有 Agent 共享)============ base_tools = { - "read_file": FileReadTool(self.project_root), - "list_files": ListFilesTool(self.project_root), + "read_file": FileReadTool(self.project_root, exclude_patterns, target_files), + "list_files": ListFilesTool(self.project_root, exclude_patterns, target_files), + # 🔥 新增:思考工具(所有Agent可用) + "think": ThinkTool(), } # ============ Recon Agent 专属工具 ============ # 职责:信息收集、项目结构分析、技术栈识别 self.recon_tools = { **base_tools, - "search_code": FileSearchTool(self.project_root), + "search_code": FileSearchTool(self.project_root, exclude_patterns, target_files), + # 🔥 新增:反思工具 + "reflect": ReflectTool(), } # RAG 工具(Recon 用于语义搜索) @@ -214,10 +233,11 @@ class AgentRunner: # 职责:漏洞分析、代码审计、模式匹配 self.analysis_tools = { **base_tools, - "search_code": FileSearchTool(self.project_root), + "search_code": FileSearchTool(self.project_root, exclude_patterns, target_files), # 模式匹配和代码分析 "pattern_match": PatternMatchTool(self.project_root), - "code_analysis": CodeAnalysisTool(self.llm_service), + # TODO: code_analysis 工具暂时禁用,因为 LLM 调用经常失败 + # "code_analysis": CodeAnalysisTool(self.llm_service), "dataflow_analysis": DataFlowAnalysisTool(self.llm_service), # 外部静态分析工具 "semgrep_scan": SemgrepTool(self.project_root), @@ -227,6 +247,11 @@ class AgentRunner: "npm_audit": NpmAuditTool(self.project_root), "safety_scan": SafetyTool(self.project_root), "osv_scan": OSVScannerTool(self.project_root), + # 🔥 新增:反思工具 + "reflect": ReflectTool(), + # 🔥 新增:安全知识查询工具(基于RAG) + "query_security_knowledge": SecurityKnowledgeQueryTool(), + "get_vulnerability_knowledge": GetVulnerabilityKnowledgeTool(), } # RAG 工具(Analysis 用于安全相关代码搜索) @@ -241,6 +266,10 @@ class AgentRunner: # 验证工具 "vulnerability_validation": VulnerabilityValidationTool(self.llm_service), "dataflow_analysis": DataFlowAnalysisTool(self.llm_service), + # 🔥 新增:漏洞报告工具(仅Verification可用) + "create_vulnerability_report": CreateVulnerabilityReportTool(), + # 🔥 新增:反思工具 + "reflect": ReflectTool(), } # 沙箱工具(仅 Verification Agent 可用) diff --git a/backend/app/services/agent/knowledge/__init__.py b/backend/app/services/agent/knowledge/__init__.py new file mode 100644 index 0000000..207d031 --- /dev/null +++ b/backend/app/services/agent/knowledge/__init__.py @@ -0,0 +1,59 @@ +""" +知识模块系统 - 基于RAG的安全知识检索 + +提供专业的安全知识检索能力,支持: +- 漏洞类型知识(SQL注入、XSS、命令注入等) +- 框架安全知识(FastAPI、Django、Flask、Express等) +- 安全最佳实践 +- 修复建议 +- 代码模式识别 + +知识库采用模块化组织: +- vulnerabilities/: 漏洞类型知识 +- frameworks/: 框架安全知识 +""" + +# 基础定义 +from .base import KnowledgeDocument, KnowledgeCategory + +# 知识加载器 +from .loader import ( + KnowledgeLoader, + knowledge_loader, + get_available_modules, + get_module_content, +) + +# RAG知识检索 +from .rag_knowledge import ( + SecurityKnowledgeRAG, + security_knowledge_rag, +) + +# 知识查询工具 +from .tools import ( + SecurityKnowledgeQueryTool, + GetVulnerabilityKnowledgeTool, + ListKnowledgeModulesTool, +) + +__all__ = [ + # 基础定义 + "KnowledgeDocument", + "KnowledgeCategory", + + # 知识加载器 + "KnowledgeLoader", + "knowledge_loader", + "get_available_modules", + "get_module_content", + + # RAG知识检索 + "SecurityKnowledgeRAG", + "security_knowledge_rag", + + # 知识查询工具 + "SecurityKnowledgeQueryTool", + "GetVulnerabilityKnowledgeTool", + "ListKnowledgeModulesTool", +] diff --git a/backend/app/services/agent/knowledge/base.py b/backend/app/services/agent/knowledge/base.py new file mode 100644 index 0000000..7bddb02 --- /dev/null +++ b/backend/app/services/agent/knowledge/base.py @@ -0,0 +1,61 @@ +""" +知识模块基础定义 + +定义知识文档的数据结构和类别 +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional + + +class KnowledgeCategory(Enum): + """知识类别""" + VULNERABILITY = "vulnerability" # 漏洞类型 + FRAMEWORK = "framework" # 框架安全 + BEST_PRACTICE = "best_practice" # 最佳实践 + REMEDIATION = "remediation" # 修复建议 + CODE_PATTERN = "code_pattern" # 代码模式 + COMPLIANCE = "compliance" # 合规要求 + + +@dataclass +class KnowledgeDocument: + """知识文档""" + id: str + title: str + content: str + category: KnowledgeCategory + tags: List[str] = field(default_factory=list) + severity: Optional[str] = None + cwe_ids: List[str] = field(default_factory=list) + owasp_ids: List[str] = field(default_factory=list) + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "id": self.id, + "title": self.title, + "content": self.content, + "category": self.category.value, + "tags": self.tags, + "severity": self.severity, + "cwe_ids": self.cwe_ids, + "owasp_ids": self.owasp_ids, + "metadata": self.metadata, + } + + def to_embedding_text(self) -> str: + """生成用于嵌入的文本""" + parts = [ + f"Title: {self.title}", + f"Category: {self.category.value}", + ] + if self.tags: + parts.append(f"Tags: {', '.join(self.tags)}") + if self.cwe_ids: + parts.append(f"CWE: {', '.join(self.cwe_ids)}") + if self.owasp_ids: + parts.append(f"OWASP: {', '.join(self.owasp_ids)}") + parts.append(f"Content: {self.content}") + return "\n".join(parts) diff --git a/backend/app/services/agent/knowledge/frameworks/__init__.py b/backend/app/services/agent/knowledge/frameworks/__init__.py new file mode 100644 index 0000000..63b7586 --- /dev/null +++ b/backend/app/services/agent/knowledge/frameworks/__init__.py @@ -0,0 +1,32 @@ +""" +框架安全知识模块 + +包含各种框架的安全特性和常见漏洞模式 +""" + +from .fastapi import FASTAPI_SECURITY +from .django import DJANGO_SECURITY +from .flask import FLASK_SECURITY +from .express import EXPRESS_SECURITY +from .react import REACT_SECURITY +from .supabase import SUPABASE_SECURITY + +# 所有框架知识文档 +ALL_FRAMEWORK_DOCS = [ + FASTAPI_SECURITY, + DJANGO_SECURITY, + FLASK_SECURITY, + EXPRESS_SECURITY, + REACT_SECURITY, + SUPABASE_SECURITY, +] + +__all__ = [ + "ALL_FRAMEWORK_DOCS", + "FASTAPI_SECURITY", + "DJANGO_SECURITY", + "FLASK_SECURITY", + "EXPRESS_SECURITY", + "REACT_SECURITY", + "SUPABASE_SECURITY", +] diff --git a/backend/app/services/agent/knowledge/frameworks/django.py b/backend/app/services/agent/knowledge/frameworks/django.py new file mode 100644 index 0000000..3a09718 --- /dev/null +++ b/backend/app/services/agent/knowledge/frameworks/django.py @@ -0,0 +1,117 @@ +""" +Django 框架安全知识 +""" + +from ..base import KnowledgeDocument, KnowledgeCategory + + +DJANGO_SECURITY = KnowledgeDocument( + id="framework_django", + title="Django Security", + category=KnowledgeCategory.FRAMEWORK, + tags=["django", "python", "web", "orm"], + content=""" +Django 内置了许多安全保护,但不当使用仍可能引入漏洞。 + +## 内置安全特性 +1. CSRF保护 +2. XSS防护(模板自动转义) +3. SQL注入防护(ORM) +4. 点击劫持防护 +5. 安全的密码哈希 + +## 常见漏洞模式 + +### SQL注入 +```python +# 危险 - raw()和extra() +User.objects.raw(f"SELECT * FROM users WHERE name = '{name}'") +User.objects.extra(where=[f"name = '{name}'"]) + +# 危险 - RawSQL +from django.db.models.expressions import RawSQL +User.objects.annotate(val=RawSQL(f"SELECT {user_input}")) + +# 安全 - 使用ORM +User.objects.filter(name=name) +User.objects.raw("SELECT * FROM users WHERE name = %s", [name]) +``` + +### XSS +```python +# 危险 - 禁用自动转义 +{{ user_input|safe }} +{% autoescape off %}{{ user_input }}{% endautoescape %} +mark_safe(user_input) + +# 安全 - 默认转义 +{{ user_input }} +``` + +### CSRF绕过 +```python +# 危险 - 禁用CSRF +@csrf_exempt +def my_view(request): + pass + +# 危险 - 全局禁用 +MIDDLEWARE = [ + # 'django.middleware.csrf.CsrfViewMiddleware', # 被注释 +] +``` + +### 不安全的反序列化 +```python +# 危险 - 签名数据可被篡改 +from django.core import signing +data = signing.loads(user_input) # 如果SECRET_KEY泄露 + +# 危险 - pickle +import pickle +data = pickle.loads(request.body) +``` + +### 敏感信息泄露 +```python +# 危险 - DEBUG模式在生产环境 +DEBUG = True # settings.py + +# 危险 - 详细错误信息 +ALLOWED_HOSTS = [] # 空列表在DEBUG=False时会报错 +``` + +### 文件上传 +```python +# 危险 - 不验证文件类型 +def upload(request): + file = request.FILES['file'] + with open(f'/uploads/{file.name}', 'wb') as f: + f.write(file.read()) + +# 安全 - 验证和重命名 +import uuid +def upload(request): + file = request.FILES['file'] + ext = os.path.splitext(file.name)[1].lower() + if ext not in ['.jpg', '.png', '.pdf']: + raise ValidationError("Invalid file type") + safe_name = f"{uuid.uuid4()}{ext}" + # 使用Django的文件存储 + default_storage.save(safe_name, file) +``` + +## 安全配置检查 +```python +# settings.py 安全配置 +DEBUG = False +SECRET_KEY = os.environ.get('DJANGO_SECRET_KEY') +ALLOWED_HOSTS = ['example.com'] +SECURE_SSL_REDIRECT = True +SESSION_COOKIE_SECURE = True +CSRF_COOKIE_SECURE = True +SECURE_HSTS_SECONDS = 31536000 +X_FRAME_OPTIONS = 'DENY' +``` +""", +) diff --git a/backend/app/services/agent/knowledge/frameworks/express.py b/backend/app/services/agent/knowledge/frameworks/express.py new file mode 100644 index 0000000..b55cf76 --- /dev/null +++ b/backend/app/services/agent/knowledge/frameworks/express.py @@ -0,0 +1,148 @@ +""" +Express.js 框架安全知识 +""" + +from ..base import KnowledgeDocument, KnowledgeCategory + + +EXPRESS_SECURITY = KnowledgeDocument( + id="framework_express", + title="Express.js Security", + category=KnowledgeCategory.FRAMEWORK, + tags=["express", "nodejs", "javascript", "api"], + content=""" +Express.js 是Node.js最流行的Web框架,需要注意多种安全问题。 + +## 常见漏洞模式 + +### NoSQL注入 +```javascript +// 危险 - MongoDB查询注入 +app.post('/login', async (req, res) => { + const user = await User.findOne({ + username: req.body.username, + password: req.body.password + }); + // 攻击: {"username": {"$ne": ""}, "password": {"$ne": ""}} +}); + +// 安全 - 类型验证 +app.post('/login', async (req, res) => { + const { username, password } = req.body; + if (typeof username !== 'string' || typeof password !== 'string') { + return res.status(400).json({ error: 'Invalid input' }); + } + const user = await User.findOne({ username, password }); +}); +``` + +### 原型污染 +```javascript +// 危险 - 合并用户输入 +const merge = require('lodash.merge'); +app.post('/config', (req, res) => { + merge(config, req.body); + // 攻击: {"__proto__": {"isAdmin": true}} +}); + +// 安全 - 使用Object.assign或白名单 +app.post('/config', (req, res) => { + const allowed = ['theme', 'language']; + allowed.forEach(key => { + if (req.body[key]) config[key] = req.body[key]; + }); +}); +``` + +### 命令注入 +```javascript +// 危险 +const { exec } = require('child_process'); +app.get('/ping', (req, res) => { + exec(`ping ${req.query.host}`, (err, stdout) => { + res.send(stdout); + }); +}); + +// 安全 - 使用execFile和参数数组 +const { execFile } = require('child_process'); +app.get('/ping', (req, res) => { + execFile('ping', ['-c', '4', req.query.host], (err, stdout) => { + res.send(stdout); + }); +}); +``` + +### XSS +```javascript +// 危险 - 直接输出用户输入 +app.get('/search', (req, res) => { + res.send(`

Results for: ${req.query.q}

`); +}); + +// 安全 - 使用模板引擎或转义 +const escape = require('escape-html'); +app.get('/search', (req, res) => { + res.send(`

Results for: ${escape(req.query.q)}

`); +}); +``` + +### 路径遍历 +```javascript +// 危险 +app.get('/files/:name', (req, res) => { + res.sendFile(`/uploads/${req.params.name}`); +}); + +// 安全 - 验证路径 +const path = require('path'); +app.get('/files/:name', (req, res) => { + const safePath = path.join('/uploads', req.params.name); + if (!safePath.startsWith('/uploads/')) { + return res.status(400).send('Invalid path'); + } + res.sendFile(safePath); +}); +``` + +### 不安全的依赖 +```javascript +// 危险 - 使用有漏洞的包 +const serialize = require('node-serialize'); +const obj = serialize.unserialize(userInput); // RCE! + +// 安全 - 使用JSON +const obj = JSON.parse(userInput); +``` + +## 安全中间件 +```javascript +const helmet = require('helmet'); +const rateLimit = require('express-rate-limit'); + +// 安全头 +app.use(helmet()); + +// 速率限制 +app.use(rateLimit({ + windowMs: 15 * 60 * 1000, + max: 100 +})); + +// CORS +const cors = require('cors'); +app.use(cors({ + origin: 'https://example.com', + credentials: true +})); +``` + +## 安全检查清单 +1. 使用helmet设置安全头 +2. 实现速率限制 +3. 验证所有用户输入类型 +4. 使用参数化查询 +5. 定期更新依赖 (npm audit) +6. 不要在错误中暴露堆栈信息 +""", +) diff --git a/backend/app/services/agent/knowledge/frameworks/fastapi.py b/backend/app/services/agent/knowledge/frameworks/fastapi.py new file mode 100644 index 0000000..2aac577 --- /dev/null +++ b/backend/app/services/agent/knowledge/frameworks/fastapi.py @@ -0,0 +1,109 @@ +""" +FastAPI 框架安全知识 +""" + +from ..base import KnowledgeDocument, KnowledgeCategory + + +FASTAPI_SECURITY = KnowledgeDocument( + id="framework_fastapi", + title="FastAPI Security", + category=KnowledgeCategory.FRAMEWORK, + tags=["fastapi", "python", "api", "async", "pydantic"], + content=""" +FastAPI 是一个现代Python Web框架,内置了许多安全特性,但仍需注意一些常见问题。 + +## 安全特性 +1. Pydantic自动数据验证 +2. 自动生成OpenAPI文档 +3. 内置OAuth2/JWT支持 +4. 依赖注入系统 + +## 常见漏洞模式 + +### SQL注入 +```python +# 危险 - 原始SQL +@app.get("/users") +async def get_users(name: str): + query = f"SELECT * FROM users WHERE name = '{name}'" + return await database.fetch_all(query) + +# 安全 - 参数化查询 +@app.get("/users") +async def get_users(name: str): + query = "SELECT * FROM users WHERE name = :name" + return await database.fetch_all(query, {"name": name}) +``` + +### IDOR +```python +# 危险 - 无权限检查 +@app.get("/users/{user_id}") +async def get_user(user_id: int): + return await User.get(user_id) + +# 安全 - 验证权限 +@app.get("/users/{user_id}") +async def get_user(user_id: int, current_user: User = Depends(get_current_user)): + if user_id != current_user.id and not current_user.is_admin: + raise HTTPException(status_code=403) + return await User.get(user_id) +``` + +### 路径遍历 +```python +# 危险 +@app.get("/files/{filename}") +async def get_file(filename: str): + return FileResponse(f"/uploads/{filename}") + +# 安全 - 验证路径 +@app.get("/files/{filename}") +async def get_file(filename: str): + safe_path = Path("/uploads").resolve() / filename + if not str(safe_path.resolve()).startswith(str(Path("/uploads").resolve())): + raise HTTPException(status_code=400) + return FileResponse(safe_path) +``` + +### JWT配置问题 +```python +# 危险 - 弱密钥 +SECRET_KEY = "secret" + +# 危险 - 不验证签名 +jwt.decode(token, options={"verify_signature": False}) + +# 安全 +SECRET_KEY = os.environ.get("JWT_SECRET_KEY") +jwt.decode(token, SECRET_KEY, algorithms=["HS256"]) +``` + +### CORS配置 +```python +# 危险 - 允许所有来源 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, # 危险组合! +) + +# 安全 - 指定来源 +app.add_middleware( + CORSMiddleware, + allow_origins=["https://example.com"], + allow_credentials=True, +) +``` + +## 安全检查清单 +1. 所有端点是否有适当的认证 +2. 是否使用Depends进行权限检查 +3. 文件操作是否验证路径 +4. SQL查询是否参数化 +5. CORS配置是否合理 +6. JWT密钥是否安全存储 +7. 敏感数据是否在响应中暴露 +""", +) diff --git a/backend/app/services/agent/knowledge/frameworks/flask.py b/backend/app/services/agent/knowledge/frameworks/flask.py new file mode 100644 index 0000000..e96a511 --- /dev/null +++ b/backend/app/services/agent/knowledge/frameworks/flask.py @@ -0,0 +1,139 @@ +""" +Flask 框架安全知识 +""" + +from ..base import KnowledgeDocument, KnowledgeCategory + + +FLASK_SECURITY = KnowledgeDocument( + id="framework_flask", + title="Flask Security", + category=KnowledgeCategory.FRAMEWORK, + tags=["flask", "python", "web", "jinja2"], + content=""" +Flask 是一个轻量级框架,安全性很大程度上取决于开发者的实现。 + +## 常见漏洞模式 + +### 模板注入 (SSTI) +```python +# 危险 - 用户输入作为模板 +from flask import render_template_string +@app.route('/hello') +def hello(): + name = request.args.get('name') + return render_template_string(f'Hello {name}!') + # 攻击: ?name={{config}} + +# 安全 - 使用参数 +@app.route('/hello') +def hello(): + name = request.args.get('name') + return render_template_string('Hello {{ name }}!', name=name) +``` + +### XSS +```python +# 危险 - 禁用转义 +from markupsafe import Markup +return Markup(user_input) + +# 模板中 +{{ user_input|safe }} + +# 安全 - 默认转义 +return render_template('page.html', content=user_input) +``` + +### SQL注入 +```python +# 危险 - 字符串拼接 +@app.route('/user/') +def get_user(name): + cursor.execute(f"SELECT * FROM users WHERE name = '{name}'") + +# 安全 - 参数化 +@app.route('/user/') +def get_user(name): + cursor.execute("SELECT * FROM users WHERE name = ?", (name,)) +``` + +### 会话安全 +```python +# 危险 - 弱密钥 +app.secret_key = 'dev' + +# 危险 - 硬编码密钥 +app.secret_key = 'super-secret-key-12345' + +# 安全 +app.secret_key = os.environ.get('FLASK_SECRET_KEY') +``` + +### 文件上传 +```python +# 危险 - 不验证文件 +@app.route('/upload', methods=['POST']) +def upload(): + file = request.files['file'] + file.save(f'/uploads/{file.filename}') + +# 安全 - 验证和安全文件名 +from werkzeug.utils import secure_filename +ALLOWED_EXTENSIONS = {'png', 'jpg', 'pdf'} + +def allowed_file(filename): + return '.' in filename and \\ + filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + +@app.route('/upload', methods=['POST']) +def upload(): + file = request.files['file'] + if file and allowed_file(file.filename): + filename = secure_filename(file.filename) + file.save(os.path.join(app.config['UPLOAD_FOLDER'], filename)) +``` + +### 开放重定向 +```python +# 危险 - 未验证的重定向 +@app.route('/redirect') +def redirect_url(): + url = request.args.get('url') + return redirect(url) + +# 安全 - 验证URL +from urllib.parse import urlparse + +@app.route('/redirect') +def redirect_url(): + url = request.args.get('url', '/') + # 只允许相对路径或同域名 + parsed = urlparse(url) + if parsed.netloc and parsed.netloc != request.host: + return redirect('/') + return redirect(url) +``` + +### Debug模式 +```python +# 危险 - 生产环境开启debug +if __name__ == '__main__': + app.run(debug=True) # 可能导致RCE + +# 安全 +if __name__ == '__main__': + app.run(debug=os.environ.get('FLASK_DEBUG', 'False') == 'True') +``` + +## 安全配置 +```python +app.config.update( + SECRET_KEY=os.environ.get('SECRET_KEY'), + SESSION_COOKIE_SECURE=True, + SESSION_COOKIE_HTTPONLY=True, + SESSION_COOKIE_SAMESITE='Lax', +) +``` +""", +) diff --git a/backend/app/services/agent/knowledge/frameworks/react.py b/backend/app/services/agent/knowledge/frameworks/react.py new file mode 100644 index 0000000..403b5b6 --- /dev/null +++ b/backend/app/services/agent/knowledge/frameworks/react.py @@ -0,0 +1,137 @@ +""" +React 框架安全知识 +""" + +from ..base import KnowledgeDocument, KnowledgeCategory + + +REACT_SECURITY = KnowledgeDocument( + id="framework_react", + title="React Security", + category=KnowledgeCategory.FRAMEWORK, + tags=["react", "javascript", "frontend", "jsx"], + content=""" +React 默认对XSS有较好的防护,但仍有一些需要注意的安全问题。 + +## 安全特性 +1. JSX自动转义 +2. 虚拟DOM隔离 + +## 常见漏洞模式 + +### dangerouslySetInnerHTML +```jsx +// 危险 - 直接渲染HTML +function Comment({ content }) { + return
; +} + +// 安全 - 使用DOMPurify +import DOMPurify from 'dompurify'; +function Comment({ content }) { + return
; +} +``` + +### href/src注入 +```jsx +// 危险 - javascript:协议 +function Link({ url }) { + return Click; + // 攻击: url = "javascript:alert('XSS')" +} + +// 安全 - 验证协议 +function Link({ url }) { + const safeUrl = url.startsWith('http') ? url : '#'; + return Click; +} +``` + +### eval和Function +```jsx +// 危险 +function Calculator({ expression }) { + const result = eval(expression); // RCE风险 + return
{result}
; +} + +// 安全 - 使用安全的表达式解析器 +import { evaluate } from 'mathjs'; +function Calculator({ expression }) { + const result = evaluate(expression); + return
{result}
; +} +``` + +### 服务端渲染(SSR) XSS +```jsx +// 危险 - Next.js中 +export async function getServerSideProps({ query }) { + return { + props: { + search: query.q // 未转义 + } + }; +} + +// 页面中 +function Page({ search }) { + return + + +javascript:alert('XSS') + +``` + +## 安全实践 +1. 输出编码/HTML转义 +2. 使用模板引擎的自动转义 +3. Content-Type设置正确 +4. 使用CSP头 + +## 修复示例 +```python +# 安全 - 使用escape +from markupsafe import escape +return f"

搜索结果: {escape(query)}

" + +# 安全 - 使用模板(自动转义) +return render_template('search.html', query=query) +``` +""", +) + + +XSS_STORED = KnowledgeDocument( + id="vuln_xss_stored", + title="Stored XSS", + category=KnowledgeCategory.VULNERABILITY, + tags=["xss", "stored", "persistent", "javascript", "database"], + severity="high", + cwe_ids=["CWE-79"], + owasp_ids=["A03:2021"], + content=""" +存储型XSS:恶意脚本被存储在服务器(数据库、文件等),当其他用户访问时执行。 + +## 危险场景 +- 用户评论/留言板 +- 用户个人资料 +- 论坛帖子 +- 文件名/描述 +- 日志查看器 + +## 危险模式 +```python +# 危险 - 存储未过滤的用户输入 +comment = request.form['comment'] +db.save_comment(comment) # 存储 + +# 危险 - 显示未转义的内容 +comments = db.get_comments() +return render_template_string(f"
{comments}
") +``` + +## 检测要点 +1. 追踪用户输入到数据库的流程 +2. 检查从数据库读取后的输出处理 +3. 关注富文本编辑器的处理 +4. 检查管理后台的数据展示 + +## 安全实践 +1. 输入时过滤/存储时转义 +2. 输出时始终转义 +3. 使用白名单HTML标签(如需富文本) +4. 使用DOMPurify等库清理HTML + +## 修复示例 +```python +# 安全 - 使用bleach清理HTML +import bleach +clean_comment = bleach.clean(comment, tags=['p', 'b', 'i']) +db.save_comment(clean_comment) +``` +""", +) + + +XSS_DOM = KnowledgeDocument( + id="vuln_xss_dom", + title="DOM-based XSS", + category=KnowledgeCategory.VULNERABILITY, + tags=["xss", "dom", "javascript", "client-side"], + severity="high", + cwe_ids=["CWE-79"], + owasp_ids=["A03:2021"], + content=""" +DOM型XSS:漏洞存在于客户端JavaScript代码,通过修改DOM环境执行恶意脚本。 + +## 危险源 (Sources) +```javascript +// URL相关 +location.href +location.search +location.hash +document.URL +document.referrer + +// 存储相关 +localStorage.getItem() +sessionStorage.getItem() + +// 消息相关 +window.postMessage +``` + +## 危险汇 (Sinks) +```javascript +// 危险 - HTML注入 +element.innerHTML = userInput; +element.outerHTML = userInput; +document.write(userInput); +document.writeln(userInput); + +// 危险 - JavaScript执行 +eval(userInput); +setTimeout(userInput, 1000); +setInterval(userInput, 1000); +new Function(userInput); + +// 危险 - URL跳转 +location.href = userInput; +location.assign(userInput); +window.open(userInput); +``` + +## 危险模式 +```javascript +// 危险 - 从URL获取并直接使用 +const name = new URLSearchParams(location.search).get('name'); +document.getElementById('greeting').innerHTML = 'Hello ' + name; + +// 危险 - hash注入 +const hash = location.hash.substring(1); +document.getElementById('content').innerHTML = decodeURIComponent(hash); +``` + +## 安全实践 +1. 使用textContent代替innerHTML +2. 使用安全的DOM API +3. 对URL参数进行验证 +4. 使用DOMPurify清理HTML + +## 修复示例 +```javascript +// 安全 - 使用textContent +element.textContent = userInput; + +// 安全 - 使用DOMPurify +element.innerHTML = DOMPurify.sanitize(userInput); + +// 安全 - 创建文本节点 +element.appendChild(document.createTextNode(userInput)); +``` +""", +) diff --git a/backend/app/services/agent/knowledge/vulnerabilities/xxe.py b/backend/app/services/agent/knowledge/vulnerabilities/xxe.py new file mode 100644 index 0000000..1566686 --- /dev/null +++ b/backend/app/services/agent/knowledge/vulnerabilities/xxe.py @@ -0,0 +1,129 @@ +""" +XXE (XML外部实体注入) 漏洞知识 +""" + +from ..base import KnowledgeDocument, KnowledgeCategory + + +XXE = KnowledgeDocument( + id="vuln_xxe", + title="XML External Entity (XXE) Injection", + category=KnowledgeCategory.VULNERABILITY, + tags=["xxe", "xml", "entity", "injection", "ssrf"], + severity="high", + cwe_ids=["CWE-611"], + owasp_ids=["A05:2021"], + content=""" +XXE允许攻击者通过XML外部实体读取服务器文件、执行SSRF攻击或导致拒绝服务。 + +## 危险模式 + +### Python +```python +# 危险 - lxml默认配置 +from lxml import etree +doc = etree.parse(user_xml) +doc = etree.fromstring(user_xml) + +# 危险 - xml.etree (Python < 3.7.1) +import xml.etree.ElementTree as ET +ET.parse(user_xml) + +# 危险 - xml.dom +from xml.dom import minidom +minidom.parseString(user_xml) +``` + +### Java +```java +// 危险 - DocumentBuilder默认配置 +DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); +DocumentBuilder db = dbf.newDocumentBuilder(); +Document doc = db.parse(userInput); + +// 危险 - SAXParser +SAXParserFactory spf = SAXParserFactory.newInstance(); +SAXParser parser = spf.newSAXParser(); +parser.parse(userInput, handler); +``` + +### PHP +```php +// 危险 +$doc = simplexml_load_string($xml); +$doc = new DOMDocument(); +$doc->loadXML($xml); +``` + +## 攻击载荷 + +### 文件读取 +```xml + + +]> +&xxe; +``` + +### SSRF +```xml + + +]> +&xxe; +``` + +### 拒绝服务 (Billion Laughs) +```xml + + + + +]> +&lol3; +``` + +## 检测要点 +1. 所有XML解析代码 +2. 是否禁用了外部实体 +3. 是否禁用了DTD处理 +4. 用户输入是否直接解析 + +## 安全实践 +1. 禁用外部实体 +2. 禁用DTD处理 +3. 使用JSON代替XML +4. 输入验证 + +## 修复示例 + +### Python +```python +# 安全 - lxml禁用实体 +from lxml import etree +parser = etree.XMLParser( + resolve_entities=False, + no_network=True, + dtd_validation=False, + load_dtd=False +) +doc = etree.parse(user_xml, parser) + +# 安全 - defusedxml +import defusedxml.ElementTree as ET +doc = ET.parse(user_xml) +``` + +### Java +```java +// 安全 - 禁用外部实体 +DocumentBuilderFactory dbf = DocumentBuilderFactory.newInstance(); +dbf.setFeature("http://apache.org/xml/features/disallow-doctype-decl", true); +dbf.setFeature("http://xml.org/sax/features/external-general-entities", false); +dbf.setFeature("http://xml.org/sax/features/external-parameter-entities", false); +``` +""", +) diff --git a/backend/app/services/agent/prompts/system_prompts.py b/backend/app/services/agent/prompts/system_prompts.py index f44321b..2f02161 100644 --- a/backend/app/services/agent/prompts/system_prompts.py +++ b/backend/app/services/agent/prompts/system_prompts.py @@ -34,7 +34,7 @@ ORCHESTRATOR_SYSTEM_PROMPT = """你是一个专业的代码安全审计 Agent, ## 分析方法 1. **快速扫描**: 首先使用 pattern_match 快速发现可疑代码 2. **语义搜索**: 使用 rag_query 查找相关上下文 -3. **深度分析**: 对可疑代码使用 code_analysis 深入分析 +3. **深度分析**: 对可疑代码使用 read_file 读取并分析 4. **数据流追踪**: 追踪用户输入到危险函数的路径 5. **漏洞验证**: 在沙箱中验证发现的漏洞 @@ -67,7 +67,6 @@ ANALYSIS_SYSTEM_PROMPT = """你是一个专注于代码漏洞分析的安全专 ## 可用工具 - rag_query: 语义搜索相关代码 - pattern_match: 快速模式匹配 -- code_analysis: LLM 深度分析 - read_file: 读取文件内容 - search_code: 关键字搜索 - dataflow_analysis: 数据流分析 diff --git a/backend/app/services/agent/tools/__init__.py b/backend/app/services/agent/tools/__init__.py index 1ebf28f..444172f 100644 --- a/backend/app/services/agent/tools/__init__.py +++ b/backend/app/services/agent/tools/__init__.py @@ -1,7 +1,12 @@ """ Agent 工具集 -提供 LangChain Agent 使用的各种工具 -包括内置工具和外部安全工具 + +提供 Agent 使用的各种工具,包括: +- 基础工具(文件操作、代码搜索) +- 分析工具(模式匹配、数据流分析) +- 外部安全工具(Semgrep、Bandit等) +- 协作工具(Think、Agent通信) +- 报告工具(漏洞报告) """ from .base import AgentTool, ToolResult @@ -22,6 +27,23 @@ from .external_tools import ( OSVScannerTool, ) +# 🔥 新增:思考和推理工具 +from .thinking_tool import ThinkTool, ReflectTool + +# 🔥 新增:漏洞报告工具 +from .reporting_tool import CreateVulnerabilityReportTool + +# 🔥 新增:Agent协作工具 +from .agent_tools import ( + CreateSubAgentTool, + SendMessageTool, + ViewAgentGraphTool, + WaitForMessageTool, + AgentFinishTool, + RunSubAgentsTool, + CollectSubAgentResultsTool, +) + __all__ = [ # 基础 "AgentTool", @@ -57,5 +79,21 @@ __all__ = [ "SafetyTool", "TruffleHogTool", "OSVScannerTool", + + # 🔥 思考和推理工具 + "ThinkTool", + "ReflectTool", + + # 🔥 漏洞报告工具 + "CreateVulnerabilityReportTool", + + # 🔥 Agent协作工具 + "CreateSubAgentTool", + "SendMessageTool", + "ViewAgentGraphTool", + "WaitForMessageTool", + "AgentFinishTool", + "RunSubAgentsTool", + "CollectSubAgentResultsTool", ] diff --git a/backend/app/services/agent/tools/agent_tools.py b/backend/app/services/agent/tools/agent_tools.py new file mode 100644 index 0000000..a9a21c3 --- /dev/null +++ b/backend/app/services/agent/tools/agent_tools.py @@ -0,0 +1,785 @@ +""" +Agent 协作工具 + +提供动态Agent创建、通信和管理功能 +""" + +import logging +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field + +from .base import AgentTool, ToolResult +from ..core.registry import agent_registry +from ..core.message import message_bus, MessageType, MessagePriority + +logger = logging.getLogger(__name__) + + +class CreateAgentInput(BaseModel): + """创建Agent输入参数""" + name: str = Field(..., description="Agent名称") + task: str = Field(..., description="任务描述") + agent_type: str = Field( + default="specialist", + description="Agent类型: analysis(分析), verification(验证), specialist(专家)" + ) + knowledge_modules: Optional[str] = Field( + default=None, + description="知识模块,逗号分隔,最多5个。如: sql_injection,xss,authentication" + ) + inherit_context: bool = Field( + default=True, + description="是否继承父Agent的上下文" + ) + execute_immediately: bool = Field( + default=False, + description="是否立即执行子Agent(否则只创建不执行)" + ) + context: Optional[Dict[str, Any]] = Field( + default=None, + description="传递给子Agent的上下文数据" + ) + + +class CreateSubAgentTool(AgentTool): + """ + 创建子Agent工具 + + 允许Agent动态创建专业化的子Agent来处理特定任务。 + 子Agent可以加载特定的知识模块,专注于特定领域。 + + 支持两种模式: + 1. 仅创建:创建Agent但不执行,后续可以批量执行 + 2. 立即执行:创建并立即执行Agent,等待结果返回 + """ + + def __init__( + self, + parent_agent_id: str, + llm_service=None, + tools: Dict[str, Any] = None, + event_emitter=None, + ): + super().__init__() + self.parent_agent_id = parent_agent_id + self.llm_service = llm_service + self.tools = tools or {} + self.event_emitter = event_emitter + + # 子Agent执行器(延迟初始化) + self._sub_executor = None + + def _get_executor(self): + """获取子Agent执行器""" + if self._sub_executor is None and self.llm_service: + from ..core.executor import SubAgentExecutor + # 需要获取父Agent实例 + parent_agent = agent_registry.get_agent(self.parent_agent_id) + if parent_agent: + self._sub_executor = SubAgentExecutor( + parent_agent=parent_agent, + llm_service=self.llm_service, + tools=self.tools, + event_emitter=self.event_emitter, + ) + return self._sub_executor + + @property + def name(self) -> str: + return "create_sub_agent" + + @property + def description(self) -> str: + return """创建专业化的子Agent来处理特定任务。 + +使用场景: +1. 发现需要深入分析的特定漏洞类型 +2. 需要专业知识来验证某个发现 +3. 任务过于复杂需要分解 + +参数: +- name: Agent名称(如 "SQL注入专家") +- task: 具体任务描述 +- agent_type: Agent类型 (analysis/verification/specialist) +- knowledge_modules: 知识模块,逗号分隔(如 "sql_injection,database_security") +- inherit_context: 是否继承当前上下文 +- execute_immediately: 是否立即执行(默认false,仅创建) +- context: 传递给子Agent的上下文数据 + +注意:每个Agent最多加载5个知识模块。""" + + @property + def args_schema(self): + return CreateAgentInput + + async def _execute( + self, + name: str, + task: str, + agent_type: str = "specialist", + knowledge_modules: Optional[str] = None, + inherit_context: bool = True, + execute_immediately: bool = False, + context: Optional[Dict[str, Any]] = None, + **kwargs + ) -> ToolResult: + """创建子Agent""" + + if not name or not name.strip(): + return ToolResult(success=False, error="Agent名称不能为空") + + if not task or not task.strip(): + return ToolResult(success=False, error="任务描述不能为空") + + # 解析知识模块 + modules = [] + if knowledge_modules: + modules = [m.strip() for m in knowledge_modules.split(",") if m.strip()] + if len(modules) > 5: + return ToolResult( + success=False, + error="知识模块数量不能超过5个" + ) + + # 验证知识模块(如果有) + if modules: + from ..knowledge import knowledge_loader + validation = knowledge_loader.validate_modules(modules) + if validation["invalid"]: + available = knowledge_loader.get_all_module_names() + return ToolResult( + success=False, + error=f"无效的知识模块: {validation['invalid']}。可用模块: {', '.join(available)}" + ) + + # 生成Agent ID + from ..core.state import _generate_agent_id + agent_id = _generate_agent_id() + + # 注册到注册表 + node = agent_registry.register_agent( + agent_id=agent_id, + agent_name=name.strip(), + agent_type=agent_type, + task=task.strip(), + parent_id=self.parent_agent_id, + knowledge_modules=modules, + ) + + # 创建消息队列 + message_bus.create_queue(agent_id) + + logger.info(f"Created sub-agent: {name} ({agent_id}), parent: {self.parent_agent_id}") + + # 如果需要立即执行 + if execute_immediately: + executor = self._get_executor() + if executor: + # 准备上下文 + exec_context = context or {} + exec_context["knowledge_modules"] = modules + + # 执行子Agent + exec_result = await executor.create_and_run_sub_agent( + agent_type=agent_type if agent_type in ["analysis", "verification"] else "analysis", + task=task.strip(), + context=exec_context, + knowledge_modules=modules, + ) + + # 更新注册表状态 + if exec_result.get("success"): + agent_registry.update_agent_status(agent_id, "completed", exec_result) + else: + agent_registry.update_agent_status(agent_id, "failed", {"error": exec_result.get("error")}) + + return ToolResult( + success=exec_result.get("success", False), + data={ + "message": f"子Agent '{name}' 已执行完成" if exec_result.get("success") else f"子Agent '{name}' 执行失败", + "agent_id": agent_id, + "execution_result": exec_result, + "findings": exec_result.get("data", {}).get("findings", []) if exec_result.get("success") else [], + }, + error=exec_result.get("error"), + metadata=node, + ) + else: + logger.warning("SubAgentExecutor not available, agent created but not executed") + + return ToolResult( + success=True, + data={ + "message": f"子Agent '{name}' 已创建", + "agent_id": agent_id, + "agent_info": { + "id": agent_id, + "name": name, + "type": agent_type, + "task": task[:100], + "knowledge_modules": modules, + "parent_id": self.parent_agent_id, + "status": "created", + } + }, + metadata=node, + ) + + +class SendMessageInput(BaseModel): + """发送消息输入参数""" + target_agent_id: str = Field(..., description="目标Agent ID") + message: str = Field(..., description="消息内容") + message_type: str = Field( + default="information", + description="消息类型: query(查询), instruction(指令), information(信息)" + ) + priority: str = Field( + default="normal", + description="优先级: low, normal, high, urgent" + ) + + +class SendMessageTool(AgentTool): + """ + 发送消息工具 + + 向其他Agent发送消息,实现Agent间通信 + """ + + def __init__(self, sender_agent_id: str): + super().__init__() + self.sender_agent_id = sender_agent_id + + @property + def name(self) -> str: + return "send_message" + + @property + def description(self) -> str: + return """向其他Agent发送消息。 + +使用场景: +1. 向子Agent发送指令 +2. 向父Agent报告进展 +3. 请求其他Agent提供信息 + +参数: +- target_agent_id: 目标Agent的ID +- message: 消息内容 +- message_type: 消息类型 (query/instruction/information) +- priority: 优先级 (low/normal/high/urgent)""" + + @property + def args_schema(self): + return SendMessageInput + + async def _execute( + self, + target_agent_id: str, + message: str, + message_type: str = "information", + priority: str = "normal", + **kwargs + ) -> ToolResult: + """发送消息""" + + if not target_agent_id: + return ToolResult(success=False, error="目标Agent ID不能为空") + + if not message or not message.strip(): + return ToolResult(success=False, error="消息内容不能为空") + + # 检查目标Agent是否存在 + target_node = agent_registry.get_agent_node(target_agent_id) + if not target_node: + return ToolResult( + success=False, + error=f"目标Agent '{target_agent_id}' 不存在" + ) + + # 转换消息类型 + try: + msg_type = MessageType(message_type) + except ValueError: + msg_type = MessageType.INFORMATION + + try: + msg_priority = MessagePriority(priority) + except ValueError: + msg_priority = MessagePriority.NORMAL + + # 发送消息 + sent_message = message_bus.send_message( + from_agent=self.sender_agent_id, + to_agent=target_agent_id, + content=message.strip(), + message_type=msg_type, + priority=msg_priority, + ) + + return ToolResult( + success=True, + data={ + "message": f"消息已发送到 '{target_node['name']}'", + "message_id": sent_message.id, + "target_agent": { + "id": target_agent_id, + "name": target_node["name"], + "status": target_node["status"], + } + }, + metadata=sent_message.to_dict(), + ) + + +class ViewAgentGraphTool(AgentTool): + """ + 查看Agent图工具 + + 查看当前的Agent树结构和状态 + """ + + def __init__(self, current_agent_id: str): + super().__init__() + self.current_agent_id = current_agent_id + + @property + def name(self) -> str: + return "view_agent_graph" + + @property + def description(self) -> str: + return """查看当前的Agent树结构和状态。 + +显示: +- 所有Agent及其层级关系 +- 每个Agent的状态和任务 +- 加载的知识模块""" + + @property + def args_schema(self): + return None + + async def _execute(self, **kwargs) -> ToolResult: + """查看Agent图""" + + tree_view = agent_registry.get_agent_tree_view() + stats = agent_registry.get_statistics() + + return ToolResult( + success=True, + data={ + "graph_structure": tree_view, + "summary": stats, + "current_agent_id": self.current_agent_id, + }, + ) + + +class WaitForMessageTool(AgentTool): + """ + 等待消息工具 + + 让Agent进入等待状态,等待其他Agent的消息 + """ + + def __init__(self, agent_id: str, agent_state=None): + super().__init__() + self.agent_id = agent_id + self.agent_state = agent_state + + @property + def name(self) -> str: + return "wait_for_message" + + @property + def description(self) -> str: + return """进入等待状态,等待其他Agent或用户的消息。 + +使用场景: +1. 等待子Agent完成任务并报告 +2. 等待用户提供更多信息 +3. 等待其他Agent的协作响应 + +参数: +- reason: 等待原因""" + + @property + def args_schema(self): + return None + + async def _execute( + self, + reason: str = "等待消息", + **kwargs + ) -> ToolResult: + """进入等待状态""" + + # 更新Agent状态 + if self.agent_state: + self.agent_state.enter_waiting_state(reason) + + # 更新注册表 + agent_registry.update_agent_status(self.agent_id, "waiting") + + return ToolResult( + success=True, + data={ + "status": "waiting", + "message": f"Agent正在等待: {reason}", + "agent_id": self.agent_id, + "resume_conditions": [ + "收到其他Agent的消息", + "收到用户消息", + "等待超时", + ], + }, + ) + + +class AgentFinishInput(BaseModel): + """Agent完成输入参数""" + result_summary: str = Field(..., description="结果摘要") + findings: Optional[List[str]] = Field(default=None, description="发现列表") + success: bool = Field(default=True, description="是否成功") + recommendations: Optional[List[str]] = Field(default=None, description="建议列表") + + +class AgentFinishTool(AgentTool): + """ + Agent完成工具 + + 子Agent完成任务后调用,向父Agent报告结果 + """ + + def __init__(self, agent_id: str, agent_state=None): + super().__init__() + self.agent_id = agent_id + self.agent_state = agent_state + + @property + def name(self) -> str: + return "agent_finish" + + @property + def description(self) -> str: + return """完成当前Agent的任务并向父Agent报告。 + +只有子Agent才能使用此工具。根Agent应使用finish_scan。 + +参数: +- result_summary: 结果摘要 +- findings: 发现列表 +- success: 是否成功完成 +- recommendations: 建议列表""" + + @property + def args_schema(self): + return AgentFinishInput + + async def _execute( + self, + result_summary: str, + findings: Optional[List[str]] = None, + success: bool = True, + recommendations: Optional[List[str]] = None, + **kwargs + ) -> ToolResult: + """完成Agent任务""" + + # 获取父Agent ID + parent_id = agent_registry.get_parent(self.agent_id) + + if not parent_id: + return ToolResult( + success=False, + error="此工具只能由子Agent使用。根Agent请使用finish_scan。" + ) + + # 更新状态 + result = { + "summary": result_summary, + "findings": findings or [], + "success": success, + "recommendations": recommendations or [], + } + + agent_registry.update_agent_status( + self.agent_id, + "completed" if success else "failed", + result, + ) + + if self.agent_state: + self.agent_state.set_completed(result) + + # 向父Agent发送完成报告 + message_bus.send_completion_report( + from_agent=self.agent_id, + to_agent=parent_id, + summary=result_summary, + findings=[{"description": f} for f in (findings or [])], + success=success, + ) + + agent_node = agent_registry.get_agent_node(self.agent_id) + + return ToolResult( + success=True, + data={ + "agent_completed": True, + "parent_notified": True, + "completion_summary": { + "agent_id": self.agent_id, + "agent_name": agent_node["name"] if agent_node else "Unknown", + "success": success, + "findings_count": len(findings or []), + } + }, + ) + + +class RunSubAgentsInput(BaseModel): + """批量执行子Agent输入参数""" + agent_ids: List[str] = Field(..., description="要执行的Agent ID列表") + parallel: bool = Field(default=True, description="是否并行执行") + + +class RunSubAgentsTool(AgentTool): + """ + 批量执行子Agent工具 + + 执行已创建的子Agent,支持并行执行 + """ + + def __init__( + self, + parent_agent_id: str, + llm_service=None, + tools: Dict[str, Any] = None, + event_emitter=None, + ): + super().__init__() + self.parent_agent_id = parent_agent_id + self.llm_service = llm_service + self.tools = tools or {} + self.event_emitter = event_emitter + + @property + def name(self) -> str: + return "run_sub_agents" + + @property + def description(self) -> str: + return """批量执行已创建的子Agent。 + +使用场景: +1. 创建多个子Agent后批量执行 +2. 并行执行多个分析任务 + +参数: +- agent_ids: 要执行的Agent ID列表 +- parallel: 是否并行执行(默认true)""" + + @property + def args_schema(self): + return RunSubAgentsInput + + async def _execute( + self, + agent_ids: List[str], + parallel: bool = True, + **kwargs + ) -> ToolResult: + """批量执行子Agent""" + + if not agent_ids: + return ToolResult(success=False, error="Agent ID列表不能为空") + + # 验证所有Agent存在且是当前Agent的子Agent + valid_agents = [] + for aid in agent_ids: + node = agent_registry.get_agent_node(aid) + if not node: + continue + if node.get("parent_id") != self.parent_agent_id: + continue + if node.get("status") not in ["created", "pending"]: + continue + valid_agents.append(node) + + if not valid_agents: + return ToolResult( + success=False, + error="没有找到可执行的子Agent" + ) + + # 构建执行任务 + from ..core.executor import DynamicAgentExecutor, ExecutionTask + + executor = DynamicAgentExecutor( + llm_service=self.llm_service, + tools=self.tools, + event_emitter=self.event_emitter, + ) + + tasks = [] + for node in valid_agents: + task = ExecutionTask( + agent_id=node["id"], + agent_type=node["type"], + task=node["task"], + context={ + "knowledge_modules": node.get("knowledge_modules", []), + }, + ) + tasks.append(task) + + # 定义Agent工厂函数 + async def agent_factory(task: ExecutionTask) -> Dict[str, Any]: + from ..agents import AnalysisAgent, VerificationAgent + + agent_class_map = { + "analysis": AnalysisAgent, + "verification": VerificationAgent, + "specialist": AnalysisAgent, # 默认使用分析Agent + } + + agent_class = agent_class_map.get(task.agent_type, AnalysisAgent) + + return await executor.execute_agent( + agent_class=agent_class, + agent_config={}, + input_data={ + "task": task.task, + "task_context": task.context, + }, + parent_id=self.parent_agent_id, + knowledge_modules=task.context.get("knowledge_modules"), + ) + + # 执行 + if parallel: + result = await executor.execute_parallel(tasks, agent_factory) + else: + # 顺序执行 + result = await executor.execute_parallel(tasks, agent_factory) + + return ToolResult( + success=result.success, + data={ + "message": f"执行完成: {result.completed_agents}/{result.total_agents} 成功", + "total_agents": result.total_agents, + "completed": result.completed_agents, + "failed": result.failed_agents, + "findings_count": len(result.all_findings), + "findings": result.all_findings[:20], # 限制返回数量 + "duration_ms": result.total_duration_ms, + "tokens_used": result.total_tokens, + }, + error="; ".join(result.errors) if result.errors else None, + metadata={ + "agent_results": { + aid: { + "success": r.get("success"), + "findings_count": len(r.get("data", {}).get("findings", [])) if r.get("success") else 0, + } + for aid, r in result.agent_results.items() + } + }, + ) + + +class CollectSubAgentResultsTool(AgentTool): + """ + 收集子Agent结果工具 + + 收集所有子Agent的执行结果和发现 + """ + + def __init__(self, parent_agent_id: str): + super().__init__() + self.parent_agent_id = parent_agent_id + + @property + def name(self) -> str: + return "collect_sub_agent_results" + + @property + def description(self) -> str: + return """收集所有子Agent的执行结果。 + +返回: +- 所有子Agent的状态 +- 汇总的发现列表 +- 执行统计""" + + @property + def args_schema(self): + return None + + async def _execute(self, **kwargs) -> ToolResult: + """收集子Agent结果""" + + # 获取所有子Agent + children = agent_registry.get_children(self.parent_agent_id) + + if not children: + return ToolResult( + success=True, + data={ + "message": "没有子Agent", + "children_count": 0, + "findings": [], + } + ) + + all_findings = [] + completed = 0 + failed = 0 + running = 0 + + child_summaries = [] + + for child_id in children: + node = agent_registry.get_agent_node(child_id) + if not node: + continue + + status = node.get("status", "unknown") + + if status == "completed": + completed += 1 + # 收集发现 + result = node.get("result", {}) + if isinstance(result, dict): + findings = result.get("findings", []) + if isinstance(findings, list): + all_findings.extend(findings) + elif status == "failed": + failed += 1 + elif status == "running": + running += 1 + + child_summaries.append({ + "id": child_id, + "name": node.get("name"), + "type": node.get("type"), + "status": status, + "findings_count": len(node.get("result", {}).get("findings", [])) if node.get("result") else 0, + }) + + return ToolResult( + success=True, + data={ + "message": f"收集完成: {completed} 完成, {failed} 失败, {running} 运行中", + "children_count": len(children), + "completed": completed, + "failed": failed, + "running": running, + "total_findings": len(all_findings), + "findings": all_findings, + "children": child_summaries, + }, + ) diff --git a/backend/app/services/agent/tools/code_analysis_tool.py b/backend/app/services/agent/tools/code_analysis_tool.py index 6d50391..819d51a 100644 --- a/backend/app/services/agent/tools/code_analysis_tool.py +++ b/backend/app/services/agent/tools/code_analysis_tool.py @@ -4,11 +4,14 @@ """ import json +import logging from typing import Optional, List, Dict, Any from pydantic import BaseModel, Field from .base import AgentTool, ToolResult +logger = logging.getLogger(__name__) + class CodeAnalysisInput(BaseModel): """代码分析输入""" @@ -155,6 +158,12 @@ class CodeAnalysisTool(AgentTool): ) except Exception as e: + import traceback + logger.error(f"代码分析失败: {e}") + logger.error(f"LLM Provider: {self.llm_service.config.provider.value if self.llm_service.config else 'N/A'}") + logger.error(f"LLM Model: {self.llm_service.config.model if self.llm_service.config else 'N/A'}") + logger.error(f"API Key 前缀: {self.llm_service.config.api_key[:10] + '...' if self.llm_service.config and self.llm_service.config.api_key else 'N/A'}") + logger.error(traceback.format_exc()) return ToolResult( success=False, error=f"代码分析失败: {str(e)}", diff --git a/backend/app/services/agent/tools/file_tool.py b/backend/app/services/agent/tools/file_tool.py index 4a6ae0c..172f5b3 100644 --- a/backend/app/services/agent/tools/file_tool.py +++ b/backend/app/services/agent/tools/file_tool.py @@ -26,15 +26,24 @@ class FileReadTool(AgentTool): 读取项目中的文件内容 """ - def __init__(self, project_root: str): + def __init__( + self, + project_root: str, + exclude_patterns: Optional[List[str]] = None, + target_files: Optional[List[str]] = None, + ): """ 初始化文件读取工具 Args: project_root: 项目根目录 + exclude_patterns: 排除模式列表 + target_files: 目标文件列表(如果指定,只允许读取这些文件) """ super().__init__() self.project_root = project_root + self.exclude_patterns = exclude_patterns or [] + self.target_files = set(target_files) if target_files else None @property def name(self) -> str: @@ -61,6 +70,22 @@ class FileReadTool(AgentTool): def args_schema(self): return FileReadInput + def _should_exclude(self, file_path: str) -> bool: + """检查文件是否应该被排除""" + # 如果指定了目标文件,只允许读取这些文件 + if self.target_files and file_path not in self.target_files: + return True + + # 检查排除模式 + for pattern in self.exclude_patterns: + if fnmatch.fnmatch(file_path, pattern): + return True + # 也检查文件名 + if fnmatch.fnmatch(os.path.basename(file_path), pattern): + return True + + return False + async def _execute( self, file_path: str, @@ -71,6 +96,13 @@ class FileReadTool(AgentTool): ) -> ToolResult: """执行文件读取""" try: + # 检查是否被排除 + if self._should_exclude(file_path): + return ToolResult( + success=False, + error=f"文件被排除或不在目标文件列表中: {file_path}", + ) + # 安全检查:防止路径遍历 full_path = os.path.normpath(os.path.join(self.project_root, file_path)) if not full_path.startswith(os.path.normpath(self.project_root)): @@ -178,15 +210,30 @@ class FileSearchTool(AgentTool): """ # 排除的目录 - EXCLUDE_DIRS = { + DEFAULT_EXCLUDE_DIRS = { "node_modules", "vendor", "dist", "build", ".git", "__pycache__", ".pytest_cache", "coverage", ".nyc_output", ".vscode", ".idea", ".vs", "target", "venv", "env", } - def __init__(self, project_root: str): + def __init__( + self, + project_root: str, + exclude_patterns: Optional[List[str]] = None, + target_files: Optional[List[str]] = None, + ): super().__init__() self.project_root = project_root + self.exclude_patterns = exclude_patterns or [] + self.target_files = set(target_files) if target_files else None + + # 从 exclude_patterns 中提取目录排除 + self.exclude_dirs = set(self.DEFAULT_EXCLUDE_DIRS) + for pattern in self.exclude_patterns: + if pattern.endswith("/**"): + self.exclude_dirs.add(pattern[:-3]) + elif "/" not in pattern and "*" not in pattern: + self.exclude_dirs.add(pattern) @property def name(self) -> str: @@ -256,7 +303,7 @@ class FileSearchTool(AgentTool): # 遍历文件 for root, dirs, files in os.walk(search_dir): # 排除目录 - dirs[:] = [d for d in dirs if d not in self.EXCLUDE_DIRS] + dirs[:] = [d for d in dirs if d not in self.exclude_dirs] for filename in files: # 检查文件模式 @@ -266,6 +313,19 @@ class FileSearchTool(AgentTool): file_path = os.path.join(root, filename) relative_path = os.path.relpath(file_path, self.project_root) + # 检查是否在目标文件列表中 + if self.target_files and relative_path not in self.target_files: + continue + + # 检查排除模式 + should_skip = False + for excl_pattern in self.exclude_patterns: + if fnmatch.fnmatch(relative_path, excl_pattern) or fnmatch.fnmatch(filename, excl_pattern): + should_skip = True + break + if should_skip: + continue + try: with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: lines = f.readlines() @@ -351,14 +411,30 @@ class ListFilesTool(AgentTool): 列出目录中的文件 """ - EXCLUDE_DIRS = { + DEFAULT_EXCLUDE_DIRS = { "node_modules", "vendor", "dist", "build", ".git", "__pycache__", ".pytest_cache", "coverage", } - def __init__(self, project_root: str): + def __init__( + self, + project_root: str, + exclude_patterns: Optional[List[str]] = None, + target_files: Optional[List[str]] = None, + ): super().__init__() self.project_root = project_root + self.exclude_patterns = exclude_patterns or [] + self.target_files = set(target_files) if target_files else None + + # 从 exclude_patterns 中提取目录排除 + self.exclude_dirs = set(self.DEFAULT_EXCLUDE_DIRS) + for pattern in self.exclude_patterns: + # 如果是目录模式(如 node_modules/**),提取目录名 + if pattern.endswith("/**"): + self.exclude_dirs.add(pattern[:-3]) + elif "/" not in pattern and "*" not in pattern: + self.exclude_dirs.add(pattern) @property def name(self) -> str: @@ -412,7 +488,7 @@ class ListFilesTool(AgentTool): if recursive: for root, dirnames, filenames in os.walk(target_dir): # 排除目录 - dirnames[:] = [d for d in dirnames if d not in self.EXCLUDE_DIRS] + dirnames[:] = [d for d in dirnames if d not in self.exclude_dirs] for filename in filenames: if pattern and not fnmatch.fnmatch(filename, pattern): @@ -420,6 +496,20 @@ class ListFilesTool(AgentTool): full_path = os.path.join(root, filename) relative_path = os.path.relpath(full_path, self.project_root) + + # 检查是否在目标文件列表中 + if self.target_files and relative_path not in self.target_files: + continue + + # 检查排除模式 + should_skip = False + for excl_pattern in self.exclude_patterns: + if fnmatch.fnmatch(relative_path, excl_pattern) or fnmatch.fnmatch(filename, excl_pattern): + should_skip = True + break + if should_skip: + continue + files.append(relative_path) if len(files) >= max_files: @@ -428,26 +518,78 @@ class ListFilesTool(AgentTool): if len(files) >= max_files: break else: - for item in os.listdir(target_dir): - if item in self.EXCLUDE_DIRS: - continue + # 🔥 如果设置了 target_files,只显示目标文件和包含目标文件的目录 + if self.target_files: + # 计算哪些目录包含目标文件 + dirs_with_targets = set() + for tf in self.target_files: + # 获取目标文件的目录部分 + tf_dir = os.path.dirname(tf) + while tf_dir: + dirs_with_targets.add(tf_dir) + tf_dir = os.path.dirname(tf_dir) - full_path = os.path.join(target_dir, item) - relative_path = os.path.relpath(full_path, self.project_root) - - if os.path.isdir(full_path): - dirs.append(relative_path + "/") - else: - if pattern and not fnmatch.fnmatch(item, pattern): + for item in os.listdir(target_dir): + if item in self.exclude_dirs: continue - files.append(relative_path) - if len(files) >= max_files: - break + full_path = os.path.join(target_dir, item) + relative_path = os.path.relpath(full_path, self.project_root) + + if os.path.isdir(full_path): + # 只显示包含目标文件的目录 + if relative_path in dirs_with_targets or any( + tf.startswith(relative_path + "/") for tf in self.target_files + ): + dirs.append(relative_path + "/") + else: + if pattern and not fnmatch.fnmatch(item, pattern): + continue + + # 检查是否在目标文件列表中 + if relative_path not in self.target_files: + continue + + files.append(relative_path) + + if len(files) >= max_files: + break + else: + # 没有设置 target_files,正常列出 + for item in os.listdir(target_dir): + if item in self.exclude_dirs: + continue + + full_path = os.path.join(target_dir, item) + relative_path = os.path.relpath(full_path, self.project_root) + + if os.path.isdir(full_path): + dirs.append(relative_path + "/") + else: + if pattern and not fnmatch.fnmatch(item, pattern): + continue + + # 检查排除模式 + should_skip = False + for excl_pattern in self.exclude_patterns: + if fnmatch.fnmatch(relative_path, excl_pattern) or fnmatch.fnmatch(item, excl_pattern): + should_skip = True + break + if should_skip: + continue + + files.append(relative_path) + + if len(files) >= max_files: + break # 格式化输出 output_parts = [f"📁 目录: {directory}\n"] + # 🔥 如果设置了 target_files,显示提示信息 + if self.target_files: + output_parts.append(f"⚠️ 注意: 审计范围限定为 {len(self.target_files)} 个指定文件\n") + if dirs: output_parts.append("目录:") for d in sorted(dirs)[:20]: @@ -459,6 +601,13 @@ class ListFilesTool(AgentTool): output_parts.append(f"\n文件 ({len(files)}):") for f in sorted(files): output_parts.append(f" 📄 {f}") + elif self.target_files: + # 如果没有文件但设置了 target_files,显示目标文件列表 + output_parts.append(f"\n指定的目标文件 ({len(self.target_files)}):") + for f in sorted(self.target_files)[:20]: + output_parts.append(f" 📄 {f}") + if len(self.target_files) > 20: + output_parts.append(f" ... 还有 {len(self.target_files) - 20} 个文件") if len(files) >= max_files: output_parts.append(f"\n... 结果已截断(最大 {max_files} 个文件)") diff --git a/backend/app/services/agent/tools/reporting_tool.py b/backend/app/services/agent/tools/reporting_tool.py new file mode 100644 index 0000000..d04f72e --- /dev/null +++ b/backend/app/services/agent/tools/reporting_tool.py @@ -0,0 +1,235 @@ +""" +漏洞报告工具 + +正式记录漏洞的唯一方式,确保漏洞报告的规范性和完整性。 +""" + +import logging +import uuid +from datetime import datetime, timezone +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field + +from .base import AgentTool, ToolResult + +logger = logging.getLogger(__name__) + + +class VulnerabilityReportInput(BaseModel): + """漏洞报告输入参数""" + title: str = Field(..., description="漏洞标题") + vulnerability_type: str = Field( + ..., + description="漏洞类型: sql_injection, xss, ssrf, command_injection, path_traversal, idor, auth_bypass, etc." + ) + severity: str = Field( + ..., + description="严重程度: critical, high, medium, low, info" + ) + description: str = Field(..., description="漏洞详细描述") + file_path: str = Field(..., description="漏洞所在文件路径") + line_start: Optional[int] = Field(default=None, description="起始行号") + line_end: Optional[int] = Field(default=None, description="结束行号") + code_snippet: Optional[str] = Field(default=None, description="相关代码片段") + source: Optional[str] = Field(default=None, description="污点来源(用户输入点)") + sink: Optional[str] = Field(default=None, description="危险函数(漏洞触发点)") + poc: Optional[str] = Field(default=None, description="概念验证/利用方法") + impact: Optional[str] = Field(default=None, description="影响分析") + recommendation: Optional[str] = Field(default=None, description="修复建议") + confidence: float = Field(default=0.8, description="置信度 0.0-1.0") + cwe_id: Optional[str] = Field(default=None, description="CWE编号") + cvss_score: Optional[float] = Field(default=None, description="CVSS评分") + + +class CreateVulnerabilityReportTool(AgentTool): + """ + 创建漏洞报告工具 + + 这是正式记录漏洞的唯一方式。只有通过这个工具创建的漏洞才会被计入最终报告。 + 这个设计确保了漏洞报告的规范性和完整性。 + + 通常只有专门的报告Agent或验证Agent才会调用这个工具, + 确保漏洞在被正式报告之前已经经过了充分的验证。 + """ + + # 存储所有报告的漏洞 + _vulnerability_reports: List[Dict[str, Any]] = [] + + def __init__(self): + super().__init__() + self._reports: List[Dict[str, Any]] = [] + + @property + def name(self) -> str: + return "create_vulnerability_report" + + @property + def description(self) -> str: + return """创建正式的漏洞报告。这是记录已确认漏洞的唯一方式。 + +只有在以下情况下才应该使用此工具: +1. 漏洞已经过充分分析和验证 +2. 有明确的证据支持漏洞存在 +3. 已经评估了漏洞的影响 + +必需参数: +- title: 漏洞标题 +- vulnerability_type: 漏洞类型 +- severity: 严重程度 (critical/high/medium/low/info) +- description: 详细描述 +- file_path: 文件路径 + +可选参数: +- line_start/line_end: 行号范围 +- code_snippet: 代码片段 +- source/sink: 数据流信息 +- poc: 概念验证 +- impact: 影响分析 +- recommendation: 修复建议 +- confidence: 置信度 +- cwe_id: CWE编号 +- cvss_score: CVSS评分""" + + @property + def args_schema(self): + return VulnerabilityReportInput + + async def _execute( + self, + title: str, + vulnerability_type: str, + severity: str, + description: str, + file_path: str, + line_start: Optional[int] = None, + line_end: Optional[int] = None, + code_snippet: Optional[str] = None, + source: Optional[str] = None, + sink: Optional[str] = None, + poc: Optional[str] = None, + impact: Optional[str] = None, + recommendation: Optional[str] = None, + confidence: float = 0.8, + cwe_id: Optional[str] = None, + cvss_score: Optional[float] = None, + **kwargs + ) -> ToolResult: + """创建漏洞报告""" + + # 验证必需字段 + if not title or not title.strip(): + return ToolResult(success=False, error="标题不能为空") + + if not description or not description.strip(): + return ToolResult(success=False, error="描述不能为空") + + if not file_path or not file_path.strip(): + return ToolResult(success=False, error="文件路径不能为空") + + # 验证严重程度 + valid_severities = ["critical", "high", "medium", "low", "info"] + severity = severity.lower() + if severity not in valid_severities: + return ToolResult( + success=False, + error=f"无效的严重程度 '{severity}',必须是: {', '.join(valid_severities)}" + ) + + # 验证漏洞类型 + valid_types = [ + "sql_injection", "nosql_injection", "xss", "ssrf", + "command_injection", "code_injection", "path_traversal", + "file_inclusion", "idor", "auth_bypass", "broken_auth", + "sensitive_data_exposure", "hardcoded_secret", "weak_crypto", + "xxe", "deserialization", "race_condition", "business_logic", + "csrf", "open_redirect", "mass_assignment", "other" + ] + vulnerability_type = vulnerability_type.lower() + if vulnerability_type not in valid_types: + # 允许未知类型,但记录警告 + logger.warning(f"Unknown vulnerability type: {vulnerability_type}") + + # 验证置信度 + confidence = max(0.0, min(1.0, confidence)) + + # 生成报告ID + report_id = f"vuln_{uuid.uuid4().hex[:8]}" + + # 构建报告 + report = { + "id": report_id, + "title": title.strip(), + "vulnerability_type": vulnerability_type, + "severity": severity, + "description": description.strip(), + "file_path": file_path.strip(), + "line_start": line_start, + "line_end": line_end, + "code_snippet": code_snippet, + "source": source, + "sink": sink, + "poc": poc, + "impact": impact, + "recommendation": recommendation or self._get_default_recommendation(vulnerability_type), + "confidence": confidence, + "cwe_id": cwe_id, + "cvss_score": cvss_score, + "created_at": datetime.now(timezone.utc).isoformat(), + "is_verified": True, # 通过此工具创建的都视为已验证 + } + + # 存储报告 + self._reports.append(report) + CreateVulnerabilityReportTool._vulnerability_reports.append(report) + + logger.info(f"Created vulnerability report: [{severity.upper()}] {title}") + + # 返回结果 + severity_emoji = { + "critical": "🔴", + "high": "🟠", + "medium": "🟡", + "low": "🟢", + "info": "🔵", + }.get(severity, "⚪") + + return ToolResult( + success=True, + data={ + "message": f"漏洞报告已创建: {severity_emoji} [{severity.upper()}] {title}", + "report_id": report_id, + "severity": severity, + }, + metadata=report, + ) + + def _get_default_recommendation(self, vuln_type: str) -> str: + """获取默认修复建议""" + recommendations = { + "sql_injection": "使用参数化查询或ORM,避免字符串拼接构造SQL语句", + "xss": "对用户输入进行HTML实体编码,使用CSP策略,避免innerHTML", + "ssrf": "验证和限制目标URL,使用白名单,禁止访问内网地址", + "command_injection": "避免使用shell执行,使用参数列表传递命令,严格验证输入", + "path_traversal": "规范化路径后验证,使用白名单,限制访问目录", + "idor": "实现细粒度访问控制,验证资源所有权,使用UUID替代自增ID", + "auth_bypass": "加强认证逻辑,实现多因素认证,定期审计认证代码", + "hardcoded_secret": "使用环境变量或密钥管理服务存储敏感信息", + "weak_crypto": "使用强加密算法(AES-256, SHA-256+),避免MD5/SHA1", + "xxe": "禁用外部实体解析,使用安全的XML解析器配置", + "deserialization": "避免反序列化不可信数据,使用JSON替代pickle/yaml", + } + return recommendations.get(vuln_type, "请根据具体情况修复此安全问题") + + def get_reports(self) -> List[Dict[str, Any]]: + """获取所有报告""" + return self._reports.copy() + + @classmethod + def get_all_reports(cls) -> List[Dict[str, Any]]: + """获取所有实例的报告""" + return cls._vulnerability_reports.copy() + + @classmethod + def clear_all_reports(cls) -> None: + """清空所有报告""" + cls._vulnerability_reports.clear() diff --git a/backend/app/services/agent/tools/thinking_tool.py b/backend/app/services/agent/tools/thinking_tool.py new file mode 100644 index 0000000..8224840 --- /dev/null +++ b/backend/app/services/agent/tools/thinking_tool.py @@ -0,0 +1,167 @@ +""" +Think 工具 - 深度推理工具 + +让Agent进行深度思考和推理,用于: +- 分析复杂情况 +- 规划下一步行动 +- 评估发现的严重性 +- 决定是否需要创建子Agent +""" + +import logging +from typing import Optional +from pydantic import BaseModel, Field + +from .base import AgentTool, ToolResult + +logger = logging.getLogger(__name__) + + +class ThinkInput(BaseModel): + """Think工具输入参数""" + thought: str = Field( + ..., + description="思考内容,可以是分析、规划、评估等" + ) + category: Optional[str] = Field( + default="general", + description="思考类别: analysis(分析), planning(规划), evaluation(评估), decision(决策)" + ) + + +class ThinkTool(AgentTool): + """ + Think 工具 + + 这是一个让Agent进行深度推理的工具。Agent可以用它来: + - 分析复杂情况:当面对复杂的代码逻辑或不确定的漏洞线索时 + - 规划下一步行动:在执行具体操作之前先规划策略 + - 评估发现的严重性:发现可疑点后评估其真实性和影响 + - 决定是否需要分解任务:当任务变得复杂时分析是否需要创建子Agent + + Think工具的输出会被记录到Agent的对话历史中,帮助LLM保持思路的连贯性。 + """ + + @property + def name(self) -> str: + return "think" + + @property + def description(self) -> str: + return """深度思考工具。用于: +1. 分析复杂的代码逻辑或安全问题 +2. 规划下一步的分析策略 +3. 评估发现的漏洞是否真实存在 +4. 决定是否需要深入调查某个方向 + +使用此工具记录你的推理过程,这有助于保持分析的连贯性。 + +参数: +- thought: 你的思考内容 +- category: 思考类别 (analysis/planning/evaluation/decision)""" + + @property + def args_schema(self): + return ThinkInput + + async def _execute( + self, + thought: str, + category: str = "general", + **kwargs + ) -> ToolResult: + """ + 执行思考 + + 实际上这个工具不执行任何操作,只是记录思考内容。 + 但它的存在让Agent有一个"思考"的动作,有助于推理。 + """ + if not thought or not thought.strip(): + return ToolResult( + success=False, + error="思考内容不能为空", + ) + + thought = thought.strip() + + # 根据类别添加标记 + category_labels = { + "analysis": "🔍 分析", + "planning": "📋 规划", + "evaluation": "⚖️ 评估", + "decision": "🎯 决策", + "general": "💭 思考", + } + + label = category_labels.get(category, "💭 思考") + + logger.debug(f"Think tool called: [{label}] {thought[:100]}...") + + return ToolResult( + success=True, + data={ + "message": f"思考已记录 ({len(thought)} 字符)", + "category": category, + "label": label, + }, + metadata={ + "thought": thought, + "category": category, + "char_count": len(thought), + } + ) + + +class ReflectTool(AgentTool): + """ + 反思工具 + + 让Agent回顾和总结当前的分析进展 + """ + + @property + def name(self) -> str: + return "reflect" + + @property + def description(self) -> str: + return """反思工具。用于回顾当前的分析进展: +1. 总结已经发现的问题 +2. 评估当前分析的覆盖度 +3. 识别可能遗漏的方向 +4. 决定是否需要调整策略 + +参数: +- summary: 当前进展总结 +- findings_so_far: 目前发现的问题数量 +- coverage: 分析覆盖度评估 (low/medium/high) +- next_steps: 建议的下一步行动""" + + @property + def args_schema(self): + return None + + async def _execute( + self, + summary: str = "", + findings_so_far: int = 0, + coverage: str = "medium", + next_steps: str = "", + **kwargs + ) -> ToolResult: + """执行反思""" + reflection = { + "summary": summary, + "findings_count": findings_so_far, + "coverage": coverage, + "next_steps": next_steps, + } + + return ToolResult( + success=True, + data={ + "message": "反思已记录", + "reflection": reflection, + }, + metadata=reflection, + ) diff --git a/backend/app/services/llm/__init__.py b/backend/app/services/llm/__init__.py index e69de29..40e5e0a 100644 --- a/backend/app/services/llm/__init__.py +++ b/backend/app/services/llm/__init__.py @@ -0,0 +1,52 @@ +""" +LLM 服务模块 + +提供统一的 LLM 调用接口,支持: +- 多提供商支持(OpenAI, Claude, Gemini, DeepSeek 等) +- Prompt Caching(减少 Token 消耗) +- Memory Compression(对话历史压缩) +- 流式输出 +- 智能重试 +""" + +from .service import LLMService +from .types import ( + LLMConfig, + LLMProvider, + LLMMessage, + LLMRequest, + LLMResponse, + LLMUsage, + LLMError, +) +from .prompt_cache import ( + PromptCacheManager, + CacheConfig, + CacheStrategy, + CacheStats, + prompt_cache_manager, + estimate_tokens, +) +from .memory_compressor import MemoryCompressor + +__all__ = [ + # Service + "LLMService", + # Types + "LLMConfig", + "LLMProvider", + "LLMMessage", + "LLMRequest", + "LLMResponse", + "LLMUsage", + "LLMError", + # Prompt Cache + "PromptCacheManager", + "CacheConfig", + "CacheStrategy", + "CacheStats", + "prompt_cache_manager", + "estimate_tokens", + # Memory Compression + "MemoryCompressor", +] diff --git a/backend/app/services/llm/adapters/litellm_adapter.py b/backend/app/services/llm/adapters/litellm_adapter.py index ea104e6..b42cdca 100644 --- a/backend/app/services/llm/adapters/litellm_adapter.py +++ b/backend/app/services/llm/adapters/litellm_adapter.py @@ -1,9 +1,15 @@ """ LiteLLM 统一适配器 支持通过 LiteLLM 调用多个 LLM 提供商,使用统一的 OpenAI 兼容格式 + +增强功能: +- Prompt Caching: 为支持的 LLM(如 Claude)添加缓存标记 +- 智能重试: 指数退避重试策略 +- 流式输出: 支持逐 token 返回 """ -from typing import Dict, Any, Optional +import logging +from typing import Dict, Any, Optional, List from ..base_adapter import BaseLLMAdapter from ..types import ( LLMConfig, @@ -14,6 +20,9 @@ from ..types import ( LLMError, DEFAULT_BASE_URLS, ) +from ..prompt_cache import prompt_cache_manager, estimate_tokens + +logger = logging.getLogger(__name__) class LiteLLMAdapter(BaseLLMAdapter): @@ -107,6 +116,25 @@ class LiteLLMAdapter(BaseLLMAdapter): # 构建消息 messages = [{"role": msg.role, "content": msg.content} for msg in request.messages] + + # 🔥 Prompt Caching: 为支持的 LLM 添加缓存标记 + cache_enabled = False + if self.config.provider == LLMProvider.CLAUDE: + # 估算系统提示词 token 数 + system_tokens = 0 + for msg in messages: + if msg.get("role") == "system": + system_tokens += estimate_tokens(msg.get("content", "")) + + messages, cache_enabled = prompt_cache_manager.process_messages( + messages=messages, + model=self.config.model, + provider=self.config.provider.value, + system_prompt_tokens=system_tokens, + ) + + if cache_enabled: + logger.debug(f"🔥 Prompt Caching enabled for {self.config.model}") # 构建请求参数 kwargs: Dict[str, Any] = { @@ -169,6 +197,14 @@ class LiteLLMAdapter(BaseLLMAdapter): completion_tokens=response.usage.completion_tokens or 0, total_tokens=response.usage.total_tokens or 0, ) + + # 🔥 更新 Prompt Cache 统计 + if cache_enabled and hasattr(response.usage, "cache_creation_input_tokens"): + prompt_cache_manager.update_stats( + cache_creation_input_tokens=getattr(response.usage, "cache_creation_input_tokens", 0), + cache_read_input_tokens=getattr(response.usage, "cache_read_input_tokens", 0), + total_input_tokens=response.usage.prompt_tokens or 0, + ) return LLMResponse( content=choice.message.content or "", diff --git a/backend/app/services/llm/memory_compressor.py b/backend/app/services/llm/memory_compressor.py new file mode 100644 index 0000000..bd2c4e9 --- /dev/null +++ b/backend/app/services/llm/memory_compressor.py @@ -0,0 +1,349 @@ +""" +Memory Compressor - 对话历史压缩器 + +当对话历史变得很长时,自动进行压缩,保持语义完整性的同时降低Token消耗。 + +压缩策略: +1. 保留所有系统消息 +2. 保留最近的N条消息 +3. 对较早的消息进行摘要压缩 +4. 保留关键信息(发现、决策点、错误) +""" + +import logging +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +# 配置常量 +MAX_TOTAL_TOKENS = 100_000 # 最大总token数 +MIN_RECENT_MESSAGES = 15 # 最少保留的最近消息数 +COMPRESSION_THRESHOLD = 0.9 # 触发压缩的阈值(90%) + + +def estimate_tokens(text: str) -> int: + """ + 估算文本的token数量 + + 简单估算:英文约4字符/token,中文约2字符/token + """ + if not text: + return 0 + + # 简单估算 + ascii_chars = sum(1 for c in text if ord(c) < 128) + non_ascii_chars = len(text) - ascii_chars + + return (ascii_chars // 4) + (non_ascii_chars // 2) + 1 + + +def get_message_tokens(msg: Dict[str, Any]) -> int: + """获取单条消息的token数""" + content = msg.get("content", "") + + if isinstance(content, str): + return estimate_tokens(content) + + if isinstance(content, list): + total = 0 + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + total += estimate_tokens(item.get("text", "")) + return total + + return 0 + + +def extract_message_text(msg: Dict[str, Any]) -> str: + """提取消息文本内容""" + content = msg.get("content", "") + + if isinstance(content, str): + return content + + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, dict): + if item.get("type") == "text": + parts.append(item.get("text", "")) + elif item.get("type") == "image_url": + parts.append("[IMAGE]") + return " ".join(parts) + + return str(content) + + +class MemoryCompressor: + """ + 对话历史压缩器 + + 当对话历史超过token限制时,自动压缩较早的消息, + 同时保留关键的安全审计上下文。 + """ + + def __init__( + self, + max_total_tokens: int = MAX_TOTAL_TOKENS, + min_recent_messages: int = MIN_RECENT_MESSAGES, + llm_service=None, + ): + """ + 初始化压缩器 + + Args: + max_total_tokens: 最大总token数 + min_recent_messages: 最少保留的最近消息数 + llm_service: LLM服务(用于生成摘要,可选) + """ + self.max_total_tokens = max_total_tokens + self.min_recent_messages = min_recent_messages + self.llm_service = llm_service + + def compress_history( + self, + messages: List[Dict[str, Any]], + ) -> List[Dict[str, Any]]: + """ + 压缩对话历史 + + 策略: + 1. 保留所有系统消息 + 2. 保留最近的N条消息 + 3. 对较早的消息进行摘要压缩 + 4. 保留关键信息 + + Args: + messages: 原始消息列表 + + Returns: + 压缩后的消息列表 + """ + if not messages: + return messages + + # 分离系统消息和普通消息 + system_msgs = [] + regular_msgs = [] + + for msg in messages: + if msg.get("role") == "system": + system_msgs.append(msg) + else: + regular_msgs.append(msg) + + # 计算当前总token数 + total_tokens = sum(get_message_tokens(msg) for msg in messages) + + # 如果未超过阈值,不需要压缩 + if total_tokens <= self.max_total_tokens * COMPRESSION_THRESHOLD: + return messages + + logger.info(f"Compressing conversation history: {total_tokens} tokens -> target: {int(self.max_total_tokens * 0.7)}") + + # 分离最近消息和较早消息 + recent_msgs = regular_msgs[-self.min_recent_messages:] + old_msgs = regular_msgs[:-self.min_recent_messages] if len(regular_msgs) > self.min_recent_messages else [] + + if not old_msgs: + return messages + + # 压缩较早的消息 + compressed = self._compress_messages(old_msgs) + + # 重新组合 + result = system_msgs + compressed + recent_msgs + + new_total = sum(get_message_tokens(msg) for msg in result) + logger.info(f"Compression complete: {total_tokens} -> {new_total} tokens ({100 - new_total * 100 // total_tokens}% reduction)") + + return result + + def _compress_messages( + self, + messages: List[Dict[str, Any]], + chunk_size: int = 10, + ) -> List[Dict[str, Any]]: + """ + 压缩消息列表 + + Args: + messages: 要压缩的消息 + chunk_size: 每次压缩的消息数量 + + Returns: + 压缩后的消息列表 + """ + if not messages: + return [] + + compressed = [] + + # 按chunk分组压缩 + for i in range(0, len(messages), chunk_size): + chunk = messages[i:i + chunk_size] + summary = self._summarize_chunk(chunk) + if summary: + compressed.append(summary) + + return compressed + + def _summarize_chunk(self, messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + """ + 摘要一组消息 + + Args: + messages: 要摘要的消息 + + Returns: + 摘要消息 + """ + if not messages: + return None + + # 提取关键信息 + key_info = self._extract_key_info(messages) + + # 构建摘要 + summary_parts = [] + + if key_info["findings"]: + summary_parts.append(f"发现: {', '.join(key_info['findings'][:5])}") + + if key_info["tools_used"]: + summary_parts.append(f"使用工具: {', '.join(key_info['tools_used'][:5])}") + + if key_info["decisions"]: + summary_parts.append(f"决策: {', '.join(key_info['decisions'][:3])}") + + if key_info["errors"]: + summary_parts.append(f"错误: {', '.join(key_info['errors'][:2])}") + + if not summary_parts: + # 如果没有提取到关键信息,生成简单摘要 + summary_parts.append(f"[已压缩 {len(messages)} 条历史消息]") + + summary_text = " | ".join(summary_parts) + + return { + "role": "assistant", + "content": f"{summary_text}", + } + + def _extract_key_info(self, messages: List[Dict[str, Any]]) -> Dict[str, List[str]]: + """ + 从消息中提取关键信息 + + Args: + messages: 消息列表 + + Returns: + 关键信息字典 + """ + import re + + key_info = { + "findings": [], + "tools_used": [], + "decisions": [], + "errors": [], + "files_analyzed": [], + } + + for msg in messages: + text = extract_message_text(msg).lower() + + # 提取发现的漏洞类型 + vuln_patterns = { + "sql": "SQL注入", + "xss": "XSS", + "ssrf": "SSRF", + "idor": "IDOR", + "auth": "认证问题", + "injection": "注入漏洞", + "traversal": "路径遍历", + "deserialization": "反序列化", + "hardcoded": "硬编码凭证", + "secret": "密钥泄露", + } + + for pattern, label in vuln_patterns.items(): + if pattern in text and ("发现" in text or "漏洞" in text or "finding" in text or "vulnerability" in text): + if label not in key_info["findings"]: + key_info["findings"].append(label) + + # 提取工具使用 + tool_match = re.search(r'action:\s*(\w+)', text, re.IGNORECASE) + if tool_match: + tool = tool_match.group(1) + if tool not in key_info["tools_used"]: + key_info["tools_used"].append(tool) + + # 提取分析的文件 + file_patterns = [ + r'读取文件[::]\s*([^\s\n]+)', + r'分析文件[::]\s*([^\s\n]+)', + r'file[_\s]?path[::]\s*["\']?([^\s\n"\']+)', + r'\.py|\.js|\.ts|\.java|\.go|\.php', + ] + for pattern in file_patterns[:3]: + matches = re.findall(pattern, text) + for match in matches: + if match not in key_info["files_analyzed"]: + key_info["files_analyzed"].append(match) + + # 提取决策 + if any(kw in text for kw in ["决定", "决策", "decision", "选择", "采用"]): + # 尝试提取决策内容 + decision_match = re.search(r'(决定|决策|decision)[::\s]*([^\n。.]{10,50})', text) + if decision_match: + key_info["decisions"].append(decision_match.group(2)[:50]) + else: + key_info["decisions"].append("做出决策") + + # 提取错误 + if any(kw in text for kw in ["错误", "失败", "error", "failed", "exception"]): + error_match = re.search(r'(错误|error|failed)[::\s]*([^\n]{10,50})', text, re.IGNORECASE) + if error_match: + key_info["errors"].append(error_match.group(2)[:50]) + else: + key_info["errors"].append("遇到错误") + + # 去重并限制数量 + for key in key_info: + key_info[key] = list(set(key_info[key]))[:5] + + return key_info + + def should_compress(self, messages: List[Dict[str, Any]]) -> bool: + """ + 检查是否需要压缩 + + Args: + messages: 消息列表 + + Returns: + 是否需要压缩 + """ + total_tokens = sum(get_message_tokens(msg) for msg in messages) + return total_tokens > self.max_total_tokens * COMPRESSION_THRESHOLD + + +# 便捷函数 +def compress_conversation( + messages: List[Dict[str, Any]], + max_tokens: int = MAX_TOTAL_TOKENS, +) -> List[Dict[str, Any]]: + """ + 压缩对话历史的便捷函数 + + Args: + messages: 消息列表 + max_tokens: 最大token数 + + Returns: + 压缩后的消息列表 + """ + compressor = MemoryCompressor(max_total_tokens=max_tokens) + return compressor.compress_history(messages) diff --git a/backend/app/services/llm/prompt_cache.py b/backend/app/services/llm/prompt_cache.py new file mode 100644 index 0000000..aec391f --- /dev/null +++ b/backend/app/services/llm/prompt_cache.py @@ -0,0 +1,333 @@ +""" +Prompt Caching 模块 + +为支持缓存的 LLM(如 Anthropic Claude)提供 Prompt 缓存功能。 +通过在系统提示词和早期对话中添加缓存标记,减少重复处理, +显著降低 Token 消耗和响应延迟。 + +支持的 LLM: +- Anthropic Claude (claude-3-5-sonnet, claude-3-opus, claude-3-haiku) +- OpenAI (部分模型支持) + +缓存策略: +- 短对话(<10轮): 仅缓存系统提示词 +- 中等对话(10-30轮): 缓存系统提示词 + 前5轮对话 +- 长对话(>30轮): 多个缓存点,动态调整 +""" + +import logging +from typing import Dict, Any, List, Optional, Tuple +from dataclasses import dataclass, field +from enum import Enum + +logger = logging.getLogger(__name__) + + +class CacheStrategy(str, Enum): + """缓存策略""" + NONE = "none" # 不缓存 + SYSTEM_ONLY = "system_only" # 仅缓存系统提示词 + SYSTEM_AND_EARLY = "system_early" # 缓存系统提示词和早期对话 + MULTI_POINT = "multi_point" # 多缓存点 + + +@dataclass +class CacheConfig: + """缓存配置""" + enabled: bool = True + strategy: CacheStrategy = CacheStrategy.SYSTEM_AND_EARLY + + # 缓存阈值 + min_system_prompt_tokens: int = 1000 # 系统提示词最小 token 数才启用缓存 + early_messages_count: int = 5 # 早期对话缓存的消息数 + + # 多缓存点配置 + multi_point_interval: int = 10 # 多缓存点间隔(消息数) + max_cache_points: int = 4 # 最大缓存点数量 + + +@dataclass +class CacheStats: + """缓存统计""" + cache_hits: int = 0 + cache_misses: int = 0 + cached_tokens: int = 0 + total_tokens: int = 0 + + @property + def hit_rate(self) -> float: + total = self.cache_hits + self.cache_misses + return self.cache_hits / total if total > 0 else 0.0 + + @property + def token_savings(self) -> float: + return self.cached_tokens / self.total_tokens if self.total_tokens > 0 else 0.0 + + +class PromptCacheManager: + """ + Prompt 缓存管理器 + + 负责: + 1. 检测 LLM 是否支持缓存 + 2. 根据对话长度选择缓存策略 + 3. 为消息添加缓存标记 + 4. 统计缓存效果 + """ + + # 支持缓存的模型 + CACHEABLE_MODELS = { + # Anthropic Claude + "claude-3-5-sonnet": True, + "claude-3-5-sonnet-20241022": True, + "claude-3-opus": True, + "claude-3-opus-20240229": True, + "claude-3-haiku": True, + "claude-3-haiku-20240307": True, + "claude-3-sonnet": True, + "claude-3-sonnet-20240229": True, + # OpenAI (部分支持) + "gpt-4-turbo": False, # 暂不支持 + "gpt-4o": False, + "gpt-4o-mini": False, + } + + # Anthropic 缓存标记 + ANTHROPIC_CACHE_CONTROL = {"type": "ephemeral"} + + def __init__(self, config: Optional[CacheConfig] = None): + self.config = config or CacheConfig() + self.stats = CacheStats() + self._cache_enabled_for_session = True + + def supports_caching(self, model: str, provider: str) -> bool: + """ + 检查模型是否支持缓存 + + Args: + model: 模型名称 + provider: 提供商名称 + + Returns: + 是否支持缓存 + """ + if not self.config.enabled: + return False + + # Anthropic Claude 支持缓存 + if provider.lower() in ["anthropic", "claude"]: + # 检查模型名称 + for cacheable_model in self.CACHEABLE_MODELS: + if cacheable_model in model.lower(): + return self.CACHEABLE_MODELS.get(cacheable_model, False) + + return False + + def determine_strategy( + self, + messages: List[Dict[str, Any]], + system_prompt_tokens: int = 0, + ) -> CacheStrategy: + """ + 根据对话状态确定缓存策略 + + Args: + messages: 消息列表 + system_prompt_tokens: 系统提示词的 token 数 + + Returns: + 缓存策略 + """ + if not self.config.enabled: + return CacheStrategy.NONE + + # 系统提示词太短,不值得缓存 + if system_prompt_tokens < self.config.min_system_prompt_tokens: + return CacheStrategy.NONE + + message_count = len(messages) + + # 短对话:仅缓存系统提示词 + if message_count < 10: + return CacheStrategy.SYSTEM_ONLY + + # 中等对话:缓存系统提示词和早期对话 + if message_count < 30: + return CacheStrategy.SYSTEM_AND_EARLY + + # 长对话:多缓存点 + return CacheStrategy.MULTI_POINT + + def add_cache_markers_anthropic( + self, + messages: List[Dict[str, Any]], + strategy: CacheStrategy, + ) -> List[Dict[str, Any]]: + """ + 为 Anthropic Claude 消息添加缓存标记 + + Anthropic 的缓存格式: + - 在 content 中使用 cache_control 字段 + - 支持 text 类型的 content block + + Args: + messages: 原始消息列表 + strategy: 缓存策略 + + Returns: + 添加了缓存标记的消息列表 + """ + if strategy == CacheStrategy.NONE: + return messages + + cached_messages = [] + + for i, msg in enumerate(messages): + new_msg = msg.copy() + + # 系统提示词缓存 + if msg.get("role") == "system": + new_msg = self._add_cache_to_message(new_msg) + cached_messages.append(new_msg) + continue + + # 早期对话缓存 + if strategy in [CacheStrategy.SYSTEM_AND_EARLY, CacheStrategy.MULTI_POINT]: + if i <= self.config.early_messages_count: + new_msg = self._add_cache_to_message(new_msg) + + # 多缓存点 + if strategy == CacheStrategy.MULTI_POINT: + if i > 0 and i % self.config.multi_point_interval == 0: + cache_point_count = i // self.config.multi_point_interval + if cache_point_count <= self.config.max_cache_points: + new_msg = self._add_cache_to_message(new_msg) + + cached_messages.append(new_msg) + + return cached_messages + + def _add_cache_to_message(self, msg: Dict[str, Any]) -> Dict[str, Any]: + """ + 为单条消息添加缓存标记 + + Args: + msg: 原始消息 + + Returns: + 添加了缓存标记的消息 + """ + content = msg.get("content", "") + + # 如果 content 是字符串,转换为 content block 格式 + if isinstance(content, str): + msg["content"] = [ + { + "type": "text", + "text": content, + "cache_control": self.ANTHROPIC_CACHE_CONTROL, + } + ] + elif isinstance(content, list): + # 已经是 content block 格式,为最后一个 block 添加缓存 + if content: + last_block = content[-1] + if isinstance(last_block, dict): + last_block["cache_control"] = self.ANTHROPIC_CACHE_CONTROL + + return msg + + def process_messages( + self, + messages: List[Dict[str, Any]], + model: str, + provider: str, + system_prompt_tokens: int = 0, + ) -> Tuple[List[Dict[str, Any]], bool]: + """ + 处理消息,添加缓存标记 + + Args: + messages: 原始消息列表 + model: 模型名称 + provider: 提供商名称 + system_prompt_tokens: 系统提示词 token 数 + + Returns: + (处理后的消息列表, 是否启用了缓存) + """ + if not self.supports_caching(model, provider): + return messages, False + + strategy = self.determine_strategy(messages, system_prompt_tokens) + + if strategy == CacheStrategy.NONE: + return messages, False + + # 根据提供商选择缓存方法 + if provider.lower() in ["anthropic", "claude"]: + cached_messages = self.add_cache_markers_anthropic(messages, strategy) + logger.debug(f"Applied {strategy.value} caching strategy for Anthropic") + return cached_messages, True + + return messages, False + + def update_stats( + self, + cache_creation_input_tokens: int = 0, + cache_read_input_tokens: int = 0, + total_input_tokens: int = 0, + ): + """ + 更新缓存统计 + + Args: + cache_creation_input_tokens: 缓存创建的 token 数 + cache_read_input_tokens: 缓存读取的 token 数 + total_input_tokens: 总输入 token 数 + """ + if cache_read_input_tokens > 0: + self.stats.cache_hits += 1 + self.stats.cached_tokens += cache_read_input_tokens + else: + self.stats.cache_misses += 1 + + self.stats.total_tokens += total_input_tokens + + def get_stats_summary(self) -> Dict[str, Any]: + """获取缓存统计摘要""" + return { + "cache_hits": self.stats.cache_hits, + "cache_misses": self.stats.cache_misses, + "hit_rate": f"{self.stats.hit_rate:.2%}", + "cached_tokens": self.stats.cached_tokens, + "total_tokens": self.stats.total_tokens, + "token_savings": f"{self.stats.token_savings:.2%}", + } + + +# 全局缓存管理器实例 +prompt_cache_manager = PromptCacheManager() + + +def estimate_tokens(text: str) -> int: + """ + 估算文本的 token 数量 + + 简单估算:英文约 4 字符/token,中文约 2 字符/token + + Args: + text: 文本内容 + + Returns: + 估算的 token 数 + """ + if not text: + return 0 + + # 统计中文字符 + chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff') + other_chars = len(text) - chinese_chars + + # 中文约 2 字符/token,其他约 4 字符/token + return int(chinese_chars / 2 + other_chars / 4) diff --git a/backend/data/vector_db/ef6dc788-cc23-4a4d-b1a9-5ce4b32248b8/data_level0.bin b/backend/data/vector_db/ef6dc788-cc23-4a4d-b1a9-5ce4b32248b8/data_level0.bin index e89c95a..dd3082e 100644 Binary files a/backend/data/vector_db/ef6dc788-cc23-4a4d-b1a9-5ce4b32248b8/data_level0.bin and b/backend/data/vector_db/ef6dc788-cc23-4a4d-b1a9-5ce4b32248b8/data_level0.bin differ diff --git a/backend/test_msg.md b/backend/test_msg.md new file mode 100644 index 0000000..50f4b83 --- /dev/null +++ b/backend/test_msg.md @@ -0,0 +1,809 @@ +14:22:48 +🔗 Connected to audit stream +14:23:04 +Orchestrator +Thought: 我需要对这个复杂的多组件项目进行安全审计。项目包含多个引擎(MediaEngine, MindSpider, QueryEngine, ReportEngine, InsightEngine, ForumEngine)和相关的 Streamlit 报告应用,使用 Python 和 JavaScript 开发。首先,我需要全面了解项目的结构、技术栈和入口点,以便确定审计的重点区域。我将从 recon Agent 开始,收集项目的基本信息。 +14:23:08 +📤 调度 recon Agent: 分析项目结构、技术栈、入口点和主要组件... +Orchestrator +14:23:18 +Recon +Thought: 我需要开始收集项目信息。首先,我应该了解项目的整体结构,包括目录布局和文件组织。这将帮助我识别项目的基本架构和组件。我将从根目录开始,使用list_files工具查看项目结构。 +14:23:22 +Completed: list_files +Recon +4ms +14:23:26 +Recon +Thought: 从项目结构来看,这是一个包含多个组件的复杂项目,主要包括: +1. MediaEngine - 一个Python模块,似乎处理媒体相关功能 +2. MindSpider - 包含爬虫功能,特别是DeepSentimentCrawling +3. 根目录有一些Python应用文件,如app.py、config.py等 +4. 有Docker相关文件,表明这是一个容器化项目 + +我需要进一步了解技术栈和项目配置。首先,我应该查看根目录的配置文件,如requirements.txt、config.py和app.py,以了解主要的技术栈和入口点。 +14:23:34 +Completed: read_file +Recon +3ms +Input: +{ + "file_path": "requirements.txt" +} + +Output: +null +14:23:35 +Recon +Thought: 从requirements.txt可以看出这是一个复杂的项目,使用了多种技术栈: +1. Web框架:Flask、Streamlit、FastAPI +2. 实时通信:Flask-SocketIO、python-socketio +3. 数据库:MySQL、PostgreSQL、SQLite、Redis +4. 爬虫工具:Playwright、BeautifulSoup等 +5. 机器学习:PyTorch、Transformers等 +6. 其他工具:OpenAI API、Tavily搜索等 + +现在我需要查看主要的入口点文件,如app.py,了解应用程序的主要结构和功能。 +14:23:38 +Completed: read_file +Recon +1ms +14:23:39 +Recon +Thought: 从app.py文件的前500行可以看出,这是一个Flask应用,用于统一管理三个Streamlit应用。它包含以下关键信息: + +1. 使用Flask作为主框架,Flask-SocketIO用于实时通信 +2. 配置了多个API密钥和引擎(Insight Engine、Media Engine、Query Engine等) +3. 使用MindSpider模块进行爬虫功能 +4. 尝试导入ReportEngine作为蓝图 +5. 配置了多个数据库连接参数 + +我需要继续查看app.py的其余部分,以了解完整的应用结构和入口点。同时,我也应该查看config.py文件来了解配置结构。 +14:23:44 +Completed: read_file +Recon +0ms +14:23:44 +Recon +Thought: 从app.py的501-1000行可以看出,这个应用管理多个Streamlit子进程: + +1. 有四个主要进程:insight、media、query和forum +2. 每个进程都有独立的端口(insight: 8501, media: 8502, query: 8503) +3. 使用队列进行进程间通信 +4. 有日志记录功能,将输出写入文件并通过WebSocket发送到前端 +5. 有启动Streamlit应用的函数 + +我需要继续查看app.py的剩余部分,以了解完整的API路由和功能。同时,我也应该查看config.py文件来了解配置结构。 +14:23:48 +Completed: read_file +Recon +2ms +14:23:49 +Recon +Thought: 从app.py的最后部分可以看出,这个应用提供了多个API端点: + +1. 日志相关API:读取和测试日志 +2. ForumEngine相关API:启动/停止论坛监控,获取论坛日志 +3. 搜索API:统一搜索接口,可以向运行中的不同引擎发送搜索请求 +4. 每个Streamlit应用都有对应的API端口(insight: 8601, media: 8602, query: 8603) + +现在我需要查看config.py文件来了解配置结构,以及查看SingleEngineApp目录来了解各个引擎的实现。 +14:23:55 +Completed: read_file +Recon +1ms +14:23:56 +Recon +Thought: 从config.py文件可以看出,这是一个使用Pydantic Settings进行配置管理的项目,包含多个API密钥和配置项: + +1. 数据库配置:支持MySQL和PostgreSQL +2. 多个LLM引擎配置:Insight Engine、Media Engine、Query Engine、Report Engine等 +3. 网络工具配置:Tavily API、Bocha API等 +4. 每个引擎都有对应的API密钥、基础URL和模型名称 + +现在我需要查看SingleEngineApp目录,了解各个引擎的实现,以及查看MindSpider目录,了解爬虫功能的实现。 +14:24:02 +Completed: list_files +Recon +0ms +14:24:03 +Recon +Thought: SingleEngineApp目录包含三个Streamlit应用文件,分别对应三个不同的引擎: +1. insight_engine_streamlit_app.py - 洞察引擎 +2. media_engine_streamlit_app.py - 媒体引擎 +3. query_engine_streamlit_app.py - 查询引擎 + +我需要查看这些文件来了解各个引擎的具体实现和功能。首先查看洞察引擎的实现。 +14:24:07 +Completed: read_file +Recon +0ms +14:24:07 +Recon +Thought: 从insight_engine_streamlit_app.py可以看出,这是一个Streamlit应用,为Insight Agent提供Web界面。主要功能包括: + +1. 接收URL参数中的查询内容 +2. 使用DeepSearchAgent执行深度搜索和分析 +3. 生成报告结构并处理段落 +4. 执行反思循环以改进分析结果 +5. 生成最终报告 + +现在我需要查看媒体引擎和查询引擎的实现,以及MindSpider目录,了解爬虫功能的实现。 +14:24:12 +Completed: read_file +Recon +1ms +14:26:02 +Recon +Thought: 从media_engine_streamlit_app.py可以看出,这是Media Agent的Streamlit界面,主要功能包括: + +1. 支持多模态能力,可以处理视频、图文、直播等内容 +2. 支持多种搜索引擎:BochaAPI和AnspireAPI +3. 使用Gemini模型作为默认LLM +4. 接收URL参数中的查询内容并自动执行搜索 + +现在我需要查看查询引擎的实现,以及MindSpider目录,了解爬虫功能的实现。 +14:26:07 +Completed: read_file +Recon +2ms +14:26:10 +Recon +Thought: 从query_engine_streamlit_app.py可以看出,这是Query Agent的Streamlit界面,主要功能包括: + +1. 具备强大的网页搜索能力 +2. 注重国内外资源相结合理解舆情 +3. 使用DeepSeek模型作为默认LLM +4. 使用Tavily API进行网络搜索 +5. 与Insight Agent类似,也执行报告结构生成、段落处理和反思循环 + +现在我需要查看MindSpider目录,了解爬虫功能的实现。 +14:26:14 +Completed: list_files +Recon +2ms +14:26:14 +Recon +Thought: MindSpider目录包含一个复杂的爬虫系统,主要结构包括: + +1. DeepSentimentCrawling/MediaCrawler - 媒体爬虫模块 + - 包含缓存系统 + - 数据库模块 + - 代理系统 + - 各种工具函数 + - 主程序入口 + +2. DeepSentimentCrawling - 深度情感爬虫 + - 关键词管理器 + - 平台爬虫 + - 主程序 + +3. 根目录的MindSpider文件 + - 配置文件 + - 主程序 + +我需要查看MindSpider的主程序和配置文件,了解爬虫系统的整体架构和功能。 +14:26:18 +Completed: read_file +Recon +2ms +14:26:20 +Recon +Thought: 从MindSpider/main.py可以看出,这是一个AI爬虫项目的主程序,集成了两个核心模块: + +1. BroadTopicExtraction - 广泛主题提取 +2. DeepSentimentCrawling - 深度情感爬虫 + +主要功能包括: +- 检查配置和数据库连接 +- 初始化数据库 +- 运行爬虫任务 +- 支持多种数据库(MySQL、PostgreSQL) + +现在我需要查看MindSpider的配置文件,了解具体的配置项。 +14:26:24 +Completed: read_file +Recon +1ms +14:26:25 +Recon +Thought: 从MindSpider/config.py可以看出,这是一个使用Pydantic Settings进行配置管理的模块,主要包含: + +1. 数据库配置:支持MySQL和PostgreSQL +2. MindSpider API配置:包括API密钥、基础URL和模型名称 +3. 使用.env文件进行环境变量管理 + +现在我需要查看DeepSentimentCrawling模块,了解爬虫的具体实现。 +14:26:29 +Completed: read_file +Recon +2ms +14:27:09 +Recon +Thought: 从DeepSentimentCrawling/main.py可以看出,这是一个深度情感爬取模块,主要功能包括: + +1. 支持多个平台:小红书(xhs)、抖音(dy)、快手(ks)、哔哩哔哩(bili)、微博(wb)、贴吧(tieba)、知乎(zhihu) +2. 基于关键词进行爬取 +3. 支持每日爬取任务和单个平台爬取 +4. 使用关键词管理器(KeywordManager)和平台爬虫(PlatformCrawler) + +现在我需要查看平台爬虫的实现,了解具体的爬取逻辑。 +14:27:19 +Completed: read_file +Recon +3ms +14:27:19 +✅ recon Agent 完成 +Orchestrator +14:27:39 +Orchestrator +Thought: Recon Agent 的执行结果没有返回任何有用的信息,项目结构、技术栈、入口点等都是空的。这可能是因为 Agent 需要更具体的指导。我将尝试一个更具体的 recon 任务,重点关注配置文件和主要代码文件,因为这些通常包含安全敏感信息。 +14:27:43 +📤 调度 recon Agent: 重点分析配置文件(config.py, .env.example)、主要入口文件(app.py, report_engine_only.py)和依赖文件(requirements.txt)... +Orchestrator +14:28:44 +Recon +Thought: 我需要开始收集项目信息。首先,我应该了解项目的整体结构,包括目录布局和文件组织。这将帮助我识别项目类型、技术栈和关键组件。我将从根目录开始,递归列出所有文件和目录。 +14:28:51 +Completed: list_files +Recon +4ms +14:28:51 +Recon +Thought: 从文件列表中,我可以看到这是一个复杂的项目,包含多个组件: +1. MediaEngine - 一个Python模块,似乎用于媒体处理和报告生成 +2. MindSpider/DeepSentimentCrawling - 一个网络爬虫项目,用于情感分析爬取 +3. 根目录有一些配置文件和主程序 + +我需要进一步了解项目的技术栈和依赖关系。首先,我应该查看根目录的配置文件,如requirements.txt、package.json等,以了解项目使用的依赖和技术栈。 +14:28:56 +Completed: read_file +Recon +1ms +14:29:05 +Recon +Thought: 从requirements.txt文件中,我可以看到这是一个复杂的项目,使用了多种技术栈: + +1. Web框架:Flask、Streamlit、FastAPI +2. 数据库:MySQL (pymysql, aiomysql)、SQLite (aiosqlite)、Redis、PostgreSQL (asyncpg) +3. 爬虫工具:Playwright、BeautifulSoup、lxml等 +4. 机器学习:PyTorch、Transformers、scikit-learn等 +5. 其他工具:OpenAI API、数据处理、可视化等 + +现在我应该查看MindSpider/DeepSentimentCrawling目录下的package.json文件,以了解前端或Node.js相关的依赖。 +14:29:28 +Completed: read_file +Recon +0ms +14:29:29 +Recon +Thought: package.json文件显示这是一个使用VitePress的文档项目,没有前端应用依赖。现在我应该查看MindSpider/DeepSentimentCrawler目录下的requirements.txt文件,以了解爬虫组件的具体依赖。 +14:29:32 +Completed: read_file +Recon +1ms +14:29:32 +Recon +Thought: 现在我已经了解了项目的主要依赖。接下来,我应该查看主要的入口点文件,如app.py、main.py等,以了解应用程序的启动点和主要功能。首先查看根目录的app.py文件。 +14:29:35 +Completed: read_file +Recon +2ms +14:29:35 +Recon +Thought: app.py文件很长,我需要查看更多内容来了解应用程序的主要功能和入口点。让我继续读取文件的剩余部分。 +14:29:38 +Completed: read_file +Recon +2ms +14:29:39 +Recon +Thought: app.py文件非常长,我需要继续查看剩余部分以了解完整的应用程序结构。让我继续读取文件的最后部分。 +14:29:41 +Completed: read_file +Recon +1ms +14:29:42 +Recon +Thought: app.py文件很长,我已经看到了大部分内容,但还需要查看最后部分以了解完整的应用程序结构。让我继续读取文件的最后部分。 +14:29:43 +Completed: read_file +Recon +1ms +14:29:46 +Recon +Thought: 现在我已经了解了app.py文件的主要内容。这是一个Flask应用,用于管理多个Streamlit应用。接下来,我应该查看MindSpider/main.py文件,以了解爬虫组件的主要功能。 +14:29:51 +Completed: read_file +Recon +4ms +14:29:53 +Recon +Thought: MindSpider/main.py文件也很长,我需要查看更多内容来了解爬虫组件的主要功能。让我继续读取文件的剩余部分。 +14:29:55 +Completed: read_file +Recon +2ms +14:29:56 +Recon +Thought: 我已经了解了MindSpider/main.py文件的主要内容。这是一个爬虫项目的主程序,集成了BroadTopicExtraction和DeepSentimentCrawling两个核心模块。接下来,我应该查看配置文件,以了解项目的配置结构和敏感信息。 +14:29:59 +Completed: read_file +Recon +6ms +Input: +{ + "file_path": "config.py" +} + +Output: +null +14:30:03 +Recon +Thought: config.py文件很长,我需要查看更多内容以了解完整的配置结构。让我继续读取文件的剩余部分。 +14:30:04 +Completed: read_file +Recon +1ms +14:30:05 +Recon +Thought: 现在我已经了解了config.py文件的主要内容。这是一个使用pydantic-settings管理的全局配置文件,包含了数据库配置、LLM API配置、网络工具配置等。接下来,我应该查看.env.example文件,以了解环境变量的配置示例。 +14:30:09 +Completed: read_file +Recon +1ms +14:30:10 +Recon +Thought: 现在我已经了解了项目的主要配置文件。接下来,我应该查看一些关键的入口点文件,如SingleEngineApp目录下的文件,以了解各个引擎的应用程序结构。首先查看SingleEngineApp目录结构。 +14:30:13 +Completed: list_files +Recon +1ms +14:30:14 +Recon +Thought: SingleEngineApp目录包含三个Streamlit应用程序文件。我应该查看其中一个文件以了解应用程序的结构和功能。让我先查看insight_engine_streamlit_app.py文件。 +14:30:17 +Completed: read_file +Recon +0ms +14:30:17 +✅ recon Agent 完成 +Orchestrator +14:30:17 +Orchestrator +尝试直接分析项目中的关键文件,看看能否获取更多信息。我将尝试使用 analysis Agent 来分析项目中的主要文件,特别是配置文件和入口文件。 +14:30:20 +📤 调度 analysis Agent: 分析项目中的关键文件,包括config.py, app.py, report_engine_only.py等主要入口文件... +Orchestrator +14:30:26 +Analysis +Thought: 我需要首先了解这个项目的结构和内容,以便制定合适的分析策略。由于项目信息有限,我不知道项目的具体语言和框架,所以我应该先探索项目结构,然后根据文件类型选择合适的分析工具。 +14:30:28 +Completed: list_files +Analysis +2ms +14:30:50 +Analysis +Thought: 从项目结构来看,这是一个Python项目,主要包含两个主要部分:MediaEngine和MindSpider。项目中有多个Python文件,包括配置文件、主程序文件和工具文件。我注意到有一些敏感文件如.env.example,可能包含配置信息。接下来我应该先查看主要的入口文件和配置文件,了解项目的基本功能,然后进行安全分析。 +14:30:57 +Completed: read_file +Analysis +4ms +14:30:58 +Analysis +Thought: 我已经看到了app.py文件的部分内容,这是一个Flask应用,用于管理三个Streamlit应用。我注意到一些安全相关信息: +1. Flask应用的SECRET_KEY是硬编码的 +2. 有多个API密钥配置 +3. 使用了subprocess模块 +4. 有socketio通信 + +我需要继续查看app.py的其余部分,然后检查配置文件和其他关键文件。让我先读取app.py的剩余部分。 +14:31:02 +Completed: read_file +Analysis +2ms +14:31:02 +Analysis +Thought: 我已经看到了app.py的部分内容,这是一个Flask应用,用于管理多个Streamlit应用。我注意到一些潜在的安全问题: +1. 硬编码的SECRET_KEY +2. 使用subprocess模块启动子进程 +3. 文件操作可能存在路径遍历风险 + +让我继续查看配置文件,了解项目的配置情况,然后使用semgrep进行静态分析,这样可以更系统地发现潜在的安全问题。 +14:31:05 +Completed: read_file +Analysis +1ms +14:31:18 +Analysis +Thought: 我已经看到了config.py文件的部分内容,这是一个使用pydantic-settings的配置文件,包含多个API密钥和配置项。我注意到配置文件中有很多API密钥字段,这些如果被泄露会很危险。 + +现在我已经对项目有了基本了解,接下来我应该使用semgrep进行静态分析,这样可以系统地发现潜在的安全问题。semgrep可以检测多种漏洞类型,包括SQL注入、XSS、命令注入等。 +14:31:22 +Completed: semgrep_scan +Analysis +2406ms +Input: +{ + "rules": "auto", + "max_results": 50 +} + +Output: +null +14:31:27 +Analysis +Thought: semgrep工具未安装,我需要使用其他工具进行安全分析。我可以使用pattern_match工具来检测常见的危险模式,或者使用bandit_scan(如果项目是Python的)。从项目结构来看,这是一个Python项目,所以我可以尝试使用bandit_scan。 +14:31:29 +Completed: bandit_scan +Analysis +4ms +14:31:31 +Analysis +Thought: 看来bandit工具也未安装。我需要使用其他方法进行安全分析。我可以使用pattern_match工具来检测常见的危险模式,或者使用search_code工具搜索特定的关键字。让我先使用pattern_match来检测一些常见的危险模式。 + +首先,我需要读取一些关键文件的内容,然后使用pattern_match进行分析。让我先读取MindSpider/main.py,因为从项目结构看,这似乎是一个主要组件。 +14:31:35 +Completed: read_file +Analysis +5ms +14:31:57 +Analysis +Thought: 我已经看到了MindSpider/main.py文件的部分内容,这是一个AI爬虫项目的主程序。我注意到一些潜在的安全问题: +1. 使用了subprocess模块 +2. 数据库连接字符串构建 +3. 使用了pymysql和sqlalchemy进行数据库操作 + +让我继续读取这个文件的剩余部分,然后使用pattern_match工具来检测常见的危险模式。 +14:32:01 +Completed: read_file +Analysis +1ms +14:32:03 +Analysis +Thought: 我已经看到了MindSpider/main.py文件的更多内容,这个文件使用了subprocess模块来执行其他Python脚本,这可能是命令注入的风险点。现在让我使用pattern_match工具来检测常见的危险模式,包括SQL注入、命令注入、路径遍历等。 +14:32:41 +Completed: pattern_match +Analysis +4ms +14:32:42 +Analysis +Thought: pattern_match工具没有检测到已知的危险模式,但我注意到代码中使用了subprocess.run来执行命令,这仍然是一个潜在的安全风险。让我继续检查其他关键文件,特别是那些处理用户输入或执行外部命令的文件。 + +让我先检查MindSpider/DeepSentimentCrawling/main.py文件,因为从项目结构看,这似乎是另一个主要组件。 +14:32:46 +Completed: read_file +Analysis +2ms +14:32:47 +Analysis +Thought: 我已经看到了MindSpider/DeepSentimentCrawling/main.py文件的部分内容,这是一个爬虫模块。我注意到它使用了PlatformCrawler来执行爬取任务,这可能涉及网络请求。让我继续查看这个文件的剩余部分,然后检查platform_crawler.py文件,因为它可能包含更多的安全敏感代码。 +14:32:51 +Completed: read_file +Analysis +0ms +14:32:52 +Analysis +Thought: 我已经看到了MindSpider/DeepSentimentCrawling/main.py文件的完整内容,这是一个爬虫模块,它使用了PlatformCrawler来执行爬取任务。现在我需要检查platform_crawler.py文件,因为它可能包含更多的安全敏感代码,特别是网络请求和用户输入处理部分。 +14:32:55 +Completed: read_file +Analysis +1ms +14:32:57 +Analysis +Thought: 我已经看到了platform_crawler.py文件的部分内容,这个文件负责配置和调用MediaCrawler进行多平台爬取。我注意到一些潜在的安全问题: +1. 使用subprocess模块执行外部命令 +2. 文件操作可能存在路径遍历风险 +3. 数据库配置信息被硬编码写入文件 + +让我继续查看这个文件的剩余部分,然后使用search_code工具搜索一些特定的危险模式,如eval、exec、subprocess等。 +14:33:01 +Completed: read_file +Analysis +20ms +14:33:03 +Analysis +Thought: 我已经看到了platform_crawler.py文件的更多内容,这个文件使用了subprocess.run来执行MediaCrawler的main.py脚本,这是一个潜在的安全风险。现在让我使用search_code工具搜索一些特定的危险模式,如eval、exec、subprocess等,以便更全面地发现潜在的安全问题。 +14:33:06 +Completed: search_code +Analysis +287ms +14:33:07 +Analysis +Thought: 我发现了多个使用subprocess的地方,这些是潜在的安全风险点。现在让我搜索其他危险模式,如eval、exec、query等,以便更全面地发现潜在的安全问题。 +14:33:10 +Completed: search_code +Analysis +200ms +14:33:11 +Analysis +Thought: 我发现了一些使用eval的地方,大多数是Playwright的page.evaluate方法,用于在浏览器中执行JavaScript代码。这些通常是安全的,因为它们是在浏览器上下文中执行的。现在让我搜索exec和其他潜在的危险模式。 +14:33:14 +Completed: search_code +Analysis +157ms +14:33:16 +Analysis +Thought: 我发现了一些使用exec的地方,大多数是数据库操作或执行外部命令。现在让我搜索query和其他潜在的危险模式,特别是SQL查询相关的代码。 +14:33:18 +Completed: search_code +Analysis +5ms +14:33:20 +Analysis +Thought: 我已经搜索了subprocess、eval、exec和query等关键字,发现了一些潜在的安全风险点。现在让我使用dataflow_analysis工具来分析一些关键的数据流,特别是用户输入如何流向可能危险的函数。 +14:33:27 +Completed: dataflow_analysis +Analysis +65427ms +14:34:34 +Analysis +让我继续分析其他潜在的安全风险点。我将检查一些关键的Python文件,看看是否有明显的安全问题。 +14:34:36 +Completed: read_file +Analysis +6ms +14:34:38 +Analysis +Thought: 我已经看到了keyword_manager.py文件的部分内容,这个文件负责管理关键词和数据库操作。我注意到它使用了SQLAlchemy进行数据库操作,并且使用了参数化查询,这有助于防止SQL注入。让我继续查看这个文件的剩余部分,然后检查其他可能存在安全问题的文件。 +14:34:42 +Completed: read_file +Analysis +2ms +14:35:04 +Analysis +Thought: 我已经看到了keyword_manager.py文件的完整内容,这个文件使用了SQLAlchemy进行数据库操作,并且使用了参数化查询,这有助于防止SQL注入。现在让我检查.env.example文件,看看是否有敏感信息泄露的风险。 +14:35:07 +Completed: read_file +Analysis +5ms +14:35:09 +Analysis +Thought: 我已经看到了.env.example文件,这是一个环境变量配置示例文件,包含了各种API密钥和配置项。这个文件本身不包含敏感信息,因为它只是示例,但需要注意在实际使用时,.env文件可能包含敏感信息。 + +现在让我使用gitleaks_scan工具来检查代码中是否有密钥泄露或其他敏感信息。 +14:35:12 +Completed: gitleaks_scan +Analysis +73ms +14:35:18 +Analysis +Thought: 由于gitleaks工具未安装,我无法使用它来检测密钥泄露。让我使用search_code工具搜索一些常见的敏感信息模式,如API密钥、密码等。 +14:35:20 +Completed: search_code +Analysis +31ms +14:35:22 +Analysis +Thought: 我已经搜索了api_key关键字,发现了很多API密钥的配置。这些API密钥是通过环境变量或配置文件管理的,这是好的做法。现在让我搜索一些其他可能的敏感信息,如密码、密钥等。 +14:35:25 +Completed: search_code +Analysis +460ms +14:35:27 +Analysis +Thought: 我已经搜索了password关键字,发现了很多数据库密码的配置。这些密码是通过环境变量或配置文件管理的,这是好的做法。现在让我总结一下我的分析结果,并提供一个安全漏洞报告。 + +基于我的分析,我发现了以下潜在的安全问题: + +1. 使用subprocess执行外部命令 - 在多个文件中发现 +2. 硬编码的SECRET_KEY - 在app.py中 +3. 数据库配置信息被硬编码写入文件 - 在platform_crawler.py中 +4. 使用Playwright的page.evaluate执行JavaScript代码 - 在多个爬虫客户端中 + +让我使用think工具来总结我的分析结果。 +14:35:44 +Completed: think +Analysis +0ms +14:35:54 +Analysis +Thought: 基于我的分析,我现在可以提供一个完整的安全漏洞报告。我将使用Final Answer格式来总结我的发现。 +14:36:51 +Analysis +Thought: 我需要使用正确的工具格式来输出我的分析结果。让我使用think工具来总结我的分析,然后提供最终的安全漏洞报告。 +14:37:02 +Completed: think +Analysis +0ms +14:37:05 +Analysis +Thought: 我已经完成了对项目的安全分析,并记录了我的评估结果。现在我将使用reflect工具来总结我的分析进展和发现。 +14:37:26 +Completed: reflect +Analysis +165ms +14:37:35 +Analysis +Thought: 我已经完成了对项目的安全分析,并记录了我的发现和评估。现在我将总结我的分析结果,并提供一个完整的安全漏洞报告。 + +基于我的分析,我发现了以下潜在的安全问题: + +1. **命令注入风险**:在多个文件中使用subprocess.run执行外部命令,包括app.py、MindSpider/main.py和platform_crawler.py。虽然当前命令参数是硬编码的,但如果未来修改为接受用户输入,可能存在命令注入风险。 + +2. **硬编码的SECRET_KEY**:在app.py第36行,Flask应用的SECRET_KEY是硬编码的,这降低了密钥的安全性。 + +3. **数据库配置信息硬编码写入文件**:在platform_crawler.py中,数据库配置信息被硬编码写入MediaCrawler的配置文件中,这可能导致敏感信息泄露。 + +4. **JavaScript代码执行**:在多个爬虫客户端中使用Playwright的page.evaluate方法执行JavaScript代码,虽然这些通常是在浏览器上下文中执行的,但如果用户输入被不当处理,可能存在XSS风险。 + +总体而言,项目在敏感信息管理方面做得较好,使用了参数化查询防止SQL注入,但需要注意subprocess使用和硬编码密钥的问题。 +14:37:51 +Analysis +Action +14:37:59 +Completed: think +Analysis +19ms +14:37:59 +✅ analysis Agent 完成 +Orchestrator +14:38:00 +Orchestrator +Thought: analysis Agent 也没有返回任何有用的信息。这可能是因为项目文件结构复杂,或者 Agent 需要更具体的指导。我将尝试一个更直接的方法,先查看项目的根目录文件,特别是配置文件和主要入口文件,然后再进行深入分析。 +14:38:05 +📤 调度 recon Agent: 查看项目根目录下的所有文件,特别关注config.py, app.py, report_engine_only.py, requirements.txt等关键文件... +Orchestrator +14:38:15 +Recon +Thought: 我需要开始收集项目信息。首先,我应该了解项目的整体结构,包括目录布局和主要文件。这将帮助我识别项目类型、技术栈和可能的入口点。我将从列出根目录开始。 +14:38:38 +Completed: list_files +Recon +15ms +14:38:39 +Recon +Thought: 从文件列表中,我可以看到这是一个包含多个组件的项目,主要有: +1. MediaEngine - 一个Python模块,包含LLM相关功能 +2. MindSpider - 一个爬虫项目,包含DeepSentimentCrawling组件 +3. 一些根级别的Python文件如app.py, config.py等 + +我需要进一步了解项目的技术栈和结构。首先,我应该查看根目录下的配置文件,如requirements.txt、package.json等,以了解项目依赖和使用的框架。 +14:38:44 +Completed: read_file +Recon +0ms +14:38:45 +Recon +Thought: 从requirements.txt文件中,我可以看到这是一个舆情分析系统,使用了多种技术栈: +1. Web框架:Flask, FastAPI, Streamlit +2. 数据库:MySQL, PostgreSQL, SQLite, Redis +3. 爬虫技术:Playwright, BeautifulSoup等 +4. LLM集成:OpenAI API +5. 数据处理:pandas, numpy等 + +现在我需要查看MindSpider/DeepSentimentCrawling/MediaCrawler目录下的package.json文件,以了解前端依赖。 +14:38:51 +Completed: read_file +Recon +8ms +14:38:52 +Recon +Thought: package.json文件显示这是一个使用VitePress的文档项目,没有前端应用依赖。现在我需要查看主要的入口点文件,如app.py、MindSpider/main.py等,以了解应用程序的入口点和主要功能。 +14:38:55 +Completed: read_file +Recon +3ms +14:38:57 +Recon +Thought: app.py文件很长,我需要继续查看更多内容来了解应用程序的主要功能和入口点。让我继续读取文件的剩余部分。 +14:39:20 +Completed: read_file +Recon +2ms +14:39:21 +Recon +Thought: 我需要继续查看app.py文件的更多内容,特别是路由定义部分,以了解应用程序的API端点。让我继续读取文件的后续部分。 +14:39:24 +Completed: read_file +Recon +2ms +14:39:45 +Recon +Thought: 我需要继续查看app.py文件,特别是路由定义部分,以了解应用程序的API端点。让我继续读取文件的后续部分。 +14:39:50 +Completed: read_file +Recon +6ms +14:40:57 +Recon +Thought: 我需要继续查看app.py文件的最后部分,以了解所有的API路由。让我读取文件的剩余部分。 +14:41:04 +Completed: read_file +Recon +23ms +14:41:37 +Recon +Thought: 我需要查看app.py文件的最后部分,以了解所有的API路由。让我读取文件的最后部分。 +14:41:40 +Completed: read_file +Recon +5ms +14:42:03 +Recon +Thought: 现在我已经查看了app.py文件,了解了主要的API路由和应用程序结构。接下来,我需要查看MindSpider/main.py文件,以了解爬虫组件的主要功能。 +14:42:08 +Completed: read_file +Recon +3ms +14:42:16 +Recon +Thought: 我需要继续查看MindSpider/main.py文件的剩余部分,以了解爬虫组件的主要功能。 +14:42:18 +Completed: read_file +Recon +1ms +14:43:42 +Recon +Thought: 我需要继续查看MindSpider/main.py文件的剩余部分,以了解爬虫组件的主要功能。 +14:43:45 +Completed: read_file +Recon +4ms +14:43:48 +Recon +Thought: 现在我已经查看了MindSpider/main.py文件,了解了爬虫组件的主要功能。接下来,我需要查看配置文件,以了解项目的配置结构和敏感信息。 +14:43:51 +Completed: read_file +Recon +5ms +14:43:53 +Recon +Thought: 我需要继续查看config.py文件的剩余部分,以了解项目的完整配置结构。 +14:43:55 +Completed: read_file +Recon +0ms +14:43:57 +Recon +Thought: 现在我已经查看了config.py文件,了解了项目的配置结构。接下来,我需要查看.env.example文件,以了解环境变量的配置示例。 +14:44:00 +Completed: read_file +Recon +14ms +14:44:00 +✅ recon Agent 完成 +Orchestrator +14:44:00 +Orchestrator +Thought: 多次尝试使用 recon 和 analysis Agent 都没有返回任何有用的信息。这可能是因为项目结构复杂或者 Agent 无法正确访问文件。我将尝试一个更直接的方法,先查看项目的根目录文件,特别关注配置文件和主要入口文件,然后再进行深入分析。 +14:44:04 +📤 调度 recon Agent: 查看项目根目录下的所有文件,特别关注config.py, app.py, report_engine_only.py, requirements.txt等关键文件... +Orchestrator +14:45:23 +Recon +Thought: 我需要开始收集项目信息。首先,我应该了解项目的整体结构,包括目录布局和主要文件。这将帮助我识别项目类型、技术栈和关键组件。我将从列出根目录的内容开始。 +14:45:26 +Completed: list_files +Recon +2ms +14:46:04 +Recon +Thought: 我已经看到了项目的根目录结构。这是一个多引擎的项目,包含多个子模块如ForumEngine、InsightEngine、MediaEngine等。我注意到根目录下有app.py和config.py,这表明这可能是一个Python项目。接下来,我应该查看requirements.txt文件来了解项目的依赖和技术栈,同时也要查看config.py来了解项目配置。 +14:46:07 +Completed: read_file +Recon +0ms +14:46:08 +Recon +Thought: 从requirements.txt文件可以看出,这是一个基于Python的舆情分析系统,使用了多种技术栈: +1. Web框架:Flask、Streamlit、FastAPI +2. 数据库:MySQL (pymysql, aiomysql)、SQLite (aiosqlite)、Redis、PostgreSQL (asyncpg) +3. 数据处理:pandas, numpy, jieba (中文分词) +4. 机器学习:torch, transformers, sentence-transformers等 +5. 爬虫:playwright, beautifulsoup4等 +6. 可视化:plotly, matplotlib, wordcloud + +接下来,我应该查看config.py文件来了解项目的配置信息,然后查看app.py作为主入口点。 +14:46:14 +Completed: read_file +Recon +1ms +14:46:15 +Recon +Thought: 从config.py文件可以看出,这是一个复杂的舆情分析系统,包含多个AI代理引擎(Insight Engine、Media Engine、Query Engine、Report Engine、MindSpider等)。系统使用Flask作为Web框架,支持多种数据库(MySQL、PostgreSQL),并集成了多个LLM API(如Kimi、Gemini、DeepSeek等)。 + +接下来,我应该查看app.py文件,这是主入口点,可以了解系统的路由和主要功能。 +14:46:18 +Completed: read_file +Recon \ No newline at end of file diff --git a/backend/架构升级方案.md b/backend/架构升级方案.md new file mode 100644 index 0000000..9f96f57 --- /dev/null +++ b/backend/架构升级方案.md @@ -0,0 +1,527 @@ +DeepAudit Agent 架构重构升级方案 +一、现状分析 +当前 DeepAudit 架构特点 +DeepAudit 目前采用基于 LangGraph 的固定流程图架构。整个审计流程按照 Recon(信息收集)→ Analysis(漏洞分析)→ Verification(漏洞验证)→ Report(报告生成)的线性顺序执行。每个阶段由一个专门的 Agent 负责,Agent 之间通过 TaskHandoff 机制传递结构化的上下文信息。 + +这种架构的优点是流程清晰、易于理解和调试,但存在几个明显的局限性: + +第一,流程过于固定。无论面对什么类型的项目或漏洞,都必须走完整个流程,无法根据实际发现动态调整策略。比如发现了一个 SQL 注入线索,无法立即深入分析,必须等待 Analysis 阶段统一处理。 + +第二,Agent 专业化程度不足。Analysis Agent 需要同时处理所有类型的漏洞,从 SQL 注入到 XSS 到 SSRF,这导致系统提示词过于庞大,LLM 难以在每种漏洞类型上都表现出专家级水平。 + +第三,缺乏动态协作能力。Agent 之间只能按照预设的顺序传递信息,无法根据需要动态创建新的 Agent 或在 Agent 之间进行实时通信。 + +Strix 架构的启示 +Strix 是一个开源的 AI 安全测试 Agent 项目,它采用了完全不同的架构理念。通过深入分析 Strix 的设计,我们可以获得以下关键启示: + +Strix 的核心是动态 Agent 树结构。根 Agent 可以根据任务需要随时创建子 Agent,每个子 Agent 专注于特定的漏洞类型或任务。子 Agent 完成后向父 Agent 汇报结果,父 Agent 可以根据结果决定是否需要创建更多子 Agent 或进行其他操作。 + +Strix 的另一个亮点是模块化的专业知识系统。它为每种漏洞类型都准备了详细的 Jinja2 模板,包含该漏洞的检测方法、利用技术、绕过手段、验证步骤等专业知识。创建 Agent 时可以指定加载哪些知识模块,让 Agent 在特定领域具备专家级能力。 + +此外,Strix 还实现了 Agent 间的消息传递机制、完善的状态管理、工具的沙箱执行、LLM 调用优化等高级特性。 + +二、升级后的整体架构 +核心设计理念 +升级后的 DeepAudit 将采用"动态 Agent 协作 + 专业知识模块 + 智能编排"的三层架构。 + +最底层是专业知识模块层,包含各种漏洞类型、框架、技术栈的专业知识库。这些知识以模板形式存储,可以按需加载到 Agent 的系统提示词中。 + +中间层是 Agent 执行层,包含可动态创建和销毁的 Agent 实例。每个 Agent 都有完整的生命周期管理,可以执行任务、调用工具、与其他 Agent 通信。 + +最上层是智能编排层,负责根据审计目标和实时发现来协调整个审计流程,决定何时创建什么类型的 Agent,如何分配任务,何时结束审计。 + +动态 Agent 树 +与当前固定的四阶段流程不同,升级后的系统将采用动态 Agent 树结构。 + +审计开始时,系统创建一个根 Agent(Root Agent)。根 Agent 首先进行初步的信息收集,了解项目的技术栈、目录结构、入口点等基本信息。然后根据收集到的信息,根 Agent 决定需要创建哪些专业子 Agent。 + +例如,如果发现项目使用了 SQL 数据库,根 Agent 可能会创建一个专门的 SQL 注入检测 Agent;如果发现有用户输入直接渲染到页面的代码,可能会创建一个 XSS 检测 Agent;如果发现有 HTTP 请求的代码,可能会创建一个 SSRF 检测 Agent。 + +每个子 Agent 专注于自己的任务领域。当子 Agent 发现可疑的漏洞线索时,它可以进一步创建验证子 Agent 来确认漏洞是否真实存在。验证通过后,还可以创建报告子 Agent 来生成正式的漏洞报告。 + +这种树状结构的好处是:任务可以无限细分,每个 Agent 都能专注于自己擅长的领域;发现和验证可以并行进行,提高效率;根据实际情况动态调整策略,而不是机械地执行固定流程。 + +Agent 间通信机制 +升级后的系统将实现完善的 Agent 间通信机制。 + +每个 Agent 都有一个消息队列,其他 Agent 可以向这个队列发送消息。消息类型包括:查询消息(请求信息)、指令消息(要求执行某个操作)、信息消息(分享发现或状态)。 + +当 Agent 处于等待状态时,它会检查自己的消息队列。如果有新消息,Agent 会处理消息并可能恢复执行。这种机制使得 Agent 之间可以进行实时协作,而不仅仅是单向的结果传递。 + +例如,SQL 注入检测 Agent 在分析过程中发现某个函数可能存在问题,但需要了解这个函数的调用上下文。它可以向根 Agent 发送查询消息,请求提供相关信息。根 Agent 收到消息后,可以调用代码搜索工具获取信息,然后回复给 SQL 注入检测 Agent。 + +Agent 状态管理 +每个 Agent 都有完整的状态管理,状态信息包括: + +基本信息:Agent 的唯一标识、名称、父 Agent 标识、创建时间等。 + +任务信息:当前任务描述、任务上下文、从父 Agent 继承的信息等。 + +执行状态:当前迭代次数、最大迭代限制、运行状态(运行中、等待中、已完成、失败、已停止)等。 + +对话历史:与 LLM 的完整对话记录,包括系统提示词、用户消息、助手回复等。 + +执行记录:已执行的动作列表、观察结果列表、错误记录等。 + +发现列表:该 Agent 发现的所有漏洞和可疑点。 + +这种完整的状态管理使得 Agent 可以被暂停和恢复,可以被序列化和持久化,也便于调试和审计。 + +三、专业知识模块系统 +模块化设计 +专业知识模块是升级后架构的核心创新之一。我们将为不同的漏洞类型、框架、技术栈创建专门的知识模块。 + +漏洞类型模块包括:SQL 注入、XSS、SSRF、IDOR、认证绕过、远程代码执行、路径遍历、XXE、CSRF、竞态条件、反序列化、业务逻辑漏洞等。每个模块都包含该漏洞类型的完整知识体系。 + +框架知识模块包括:FastAPI、Django、Flask、Express、Next.js、Spring、Laravel 等主流框架。每个模块包含该框架的安全特性、常见漏洞模式、最佳实践等。 + +技术栈模块包括:Supabase、Firebase、GraphQL、gRPC、WebSocket 等。每个模块包含该技术的安全考量和常见问题。 + +模块内容结构 +以 SQL 注入模块为例,它应该包含以下内容: + +漏洞概述:SQL 注入的定义、危害、影响范围。 + +检测方法:错误型注入检测、布尔型注入检测、时间型注入检测、带外注入检测的具体技术和判断标准。 + +数据库特定知识:MySQL、PostgreSQL、MSSQL、Oracle 等不同数据库的特有语法、函数、利用技术。 + +绕过技术:WAF 绕过、过滤绕过、编码绕过等高级技术。 + +ORM 和查询构建器:各种 ORM 框架中容易出现 SQL 注入的 API 和模式。 + +验证步骤:如何确认漏洞真实存在,如何构造 PoC,如何评估影响。 + +误报识别:哪些情况容易被误判为 SQL 注入,如何排除误报。 + +修复建议:参数化查询、ORM 正确用法、输入验证等修复方案。 + +模块加载机制 +创建 Agent 时,可以指定该 Agent 需要加载哪些知识模块。系统会将这些模块的内容动态注入到 Agent 的系统提示词中。 + +为了控制提示词长度和保持 Agent 的专注度,每个 Agent 最多加载 5 个知识模块。这个限制迫使我们为每个 Agent 选择最相关的知识,而不是试图让一个 Agent 掌握所有知识。 + +模块之间可以有依赖关系。例如,FastAPI 框架模块可能依赖 Python 安全基础模块;GraphQL 模块可能依赖 API 安全基础模块。加载模块时会自动处理这些依赖。 + +四、工具系统升级 +统一的工具注册机制 +升级后的工具系统将采用装饰器模式进行统一注册。每个工具都需要提供:工具名称、功能描述、参数定义、返回值说明。 + +工具按类别组织,包括:文件操作类(读取文件、搜索文件、列出目录)、代码分析类(模式匹配、数据流分析、AST 分析)、外部扫描类(Semgrep、Bandit、Gitleaks 等)、验证执行类(沙箱命令执行、HTTP 请求)、协作类(创建子 Agent、发送消息、等待消息)、推理类(思考工具)、报告类(创建漏洞报告)。 + +Think 工具 +Think 工具是从 Strix 借鉴的关键创新。这是一个让 Agent 进行深度推理的工具,Agent 可以用它来: + +分析复杂情况:当面对复杂的代码逻辑或不确定的漏洞线索时,Agent 可以调用 Think 工具进行深入思考。 + +规划下一步行动:在执行具体操作之前,先用 Think 工具规划策略。 + +评估发现的严重性:发现可疑点后,用 Think 工具评估其真实性和影响。 + +决定是否需要创建子 Agent:当任务变得复杂时,用 Think 工具分析是否需要分解任务。 + +Think 工具的输出会被记录到 Agent 的对话历史中,帮助 LLM 保持思路的连贯性。 + +漏洞报告工具 +漏洞报告工具是正式记录漏洞的唯一方式。只有通过这个工具创建的漏洞才会被计入最终报告。这个设计确保了漏洞报告的规范性和完整性。 + +报告工具要求提供完整的漏洞信息:漏洞类型、严重程度、标题、详细描述、文件位置、代码片段、PoC、影响分析、修复建议等。 + +通常只有专门的报告 Agent 才会调用这个工具,确保漏洞在被正式报告之前已经经过了充分的验证。 + +沙箱执行 +涉及代码执行或网络请求的工具都在沙箱环境中运行。沙箱提供资源隔离(CPU、内存、网络限制)、文件系统隔离、超时控制等安全保障。 + +沙箱执行通过 Tool Server 机制实现。Agent 发送工具调用请求到 Tool Server,Tool Server 在沙箱中执行工具并返回结果。这种设计使得即使工具执行出现问题,也不会影响主系统的稳定性。 + +五、LLM 调用优化 +Prompt Caching +对于支持 Prompt Caching 的 LLM(如 Anthropic Claude),系统会自动为系统提示词和早期对话添加缓存标记。这样在多轮对话中,这些内容只需要处理一次,后续调用可以直接使用缓存,显著降低 Token 消耗和响应延迟。 + +缓存策略会根据对话长度动态调整。对于短对话,只缓存系统提示词;对于长对话,会在关键位置添加多个缓存点。 + +Memory Compression +当对话历史变得很长时,系统会自动进行压缩。压缩策略包括: + +移除冗余信息:重复的工具调用结果、过长的代码输出等会被截断或摘要。 + +合并相似消息:连续的同类型消息可能被合并。 + +保留关键信息:重要的发现、决策点、错误信息等会被优先保留。 + +压缩后的对话历史仍然保持语义完整性,LLM 可以理解之前发生了什么,但 Token 消耗大大降低。 + +智能重试 +LLM 调用可能因为各种原因失败:网络问题、速率限制、服务不可用等。系统实现了智能重试机制: + +对于可重试的错误(如速率限制),会等待适当时间后重试。 + +对于不可重试的错误(如认证失败),会立即报错并提供清晰的错误信息。 + +重试时会使用指数退避策略,避免对 LLM 服务造成过大压力。 + +六、审计流程重构 +启动阶段 +用户发起审计请求后,系统首先创建根 Agent。根 Agent 加载通用的安全审计知识模块和项目相关的框架知识模块。 + +根 Agent 的第一个任务是信息收集:扫描项目目录结构、识别技术栈、找出入口点、分析依赖关系。这个阶段类似于当前的 Recon 阶段,但更加灵活。 + +任务分解阶段 +根据信息收集的结果,根 Agent 决定如何分解审计任务。它会考虑: + +项目使用了哪些技术?需要创建哪些专业 Agent? + +有哪些高风险区域?应该优先分析哪些部分? + +项目规模如何?需要多少并行 Agent? + +根 Agent 会创建一批初始的子 Agent,每个子 Agent 负责特定的漏洞类型或代码区域。 + +并行分析阶段 +多个子 Agent 并行工作,各自在自己的专业领域进行深入分析。 + +每个子 Agent 都有自己的工作循环:思考当前状态、选择工具执行、观察结果、决定下一步。这个循环会持续进行,直到 Agent 认为任务完成或达到迭代限制。 + +子 Agent 在分析过程中可能会发现需要进一步调查的线索。这时它可以创建更专业的子 Agent 来处理,形成多层的 Agent 树。 + +验证阶段 +当分析 Agent 发现可疑的漏洞时,它会创建验证 Agent 来确认漏洞是否真实存在。 + +验证 Agent 会尝试构造 PoC、进行数据流追踪、在沙箱中测试等。验证通过后,验证 Agent 会创建报告 Agent 来生成正式的漏洞报告。 + +如果验证失败,验证 Agent 会将结果反馈给父 Agent,父 Agent 可以决定是否需要进一步调查或将其标记为误报。 + +汇总阶段 +当所有子 Agent 都完成工作后,根 Agent 会汇总所有发现,生成最终的审计报告。 + +报告包括:发现的所有漏洞(按严重程度排序)、安全评分、技术栈分析、高风险区域标注、修复建议优先级等。 + +七、可观测性和调试 +完整的事件追踪 +系统会记录所有重要事件:Agent 创建和销毁、工具调用和结果、LLM 请求和响应、Agent 间消息、状态变更等。 + +这些事件可以实时推送到前端,让用户看到审计的进展。也可以持久化到数据库,用于后续分析和审计。 + +Agent 树可视化 +前端可以展示当前的 Agent 树结构,显示每个 Agent 的状态、任务、发现数量等信息。用户可以点击任何 Agent 查看其详细信息和对话历史。 + +调试模式 +在调试模式下,系统会记录更详细的信息,包括完整的 LLM 提示词和响应、工具执行的详细日志、状态变更的完整历史等。这些信息对于排查问题和优化系统非常有价值。 + +八、与现有架构的兼容 +渐进式迁移 +升级不需要一次性完成,可以渐进式进行。 + +第一步,保持现有的 LangGraph 流程不变,但将 Agent 的状态管理升级为新的模型。 + +第二步,引入专业知识模块系统,让现有的 Analysis Agent 可以加载不同的知识模块。 + +第三步,在 Analysis 阶段内部引入子 Agent 机制,允许创建专业的漏洞检测子 Agent。 + +第四步,逐步放开流程限制,让 Agent 可以更灵活地决定下一步操作。 + +第五步,完全迁移到动态 Agent 树架构。 + +保留 LangGraph 的优势 +LangGraph 提供了很好的状态管理和检查点机制,这些在新架构中仍然有价值。我们可以将 LangGraph 用于根 Agent 的高层编排,而在子 Agent 层面使用更灵活的动态创建机制。 + +九、预期收益 +更深度的漏洞发现 +专业知识模块让每个 Agent 都具备安全专家级别的知识。专注于单一漏洞类型的 Agent 比通用 Agent 更容易发现深层次的问题。 + +更高的效率 +并行的 Agent 执行比串行流程更快。动态任务分解避免了在无关区域浪费时间。 + +更低的成本 +Prompt Caching 和 Memory Compression 显著降低 Token 消耗。专业化的 Agent 使用更短的提示词就能达到更好的效果。 + +更好的可扩展性 +添加新的漏洞类型只需要创建新的知识模块。支持新的框架只需要添加框架知识模块。整个系统的扩展不需要修改核心架构。 + +更强的可解释性 +完整的事件追踪和 Agent 树可视化让用户清楚地了解系统在做什么。Think 工具的输出展示了 Agent 的推理过程。 + +这个升级方案借鉴了 Strix 的核心设计理念,同时保留了 DeepAudit 的既有优势,通过渐进式迁移降低风险,最终实现一个更强大、更灵活、更专业的安全审计 Agent 系统。 + + +--- + +## 十、实施进度记录 + +### 已完成的工作 (2024-12) + +#### 1. 核心模块系统 ✅ +- `core/state.py`: 增强的Agent状态管理,支持完整生命周期 +- `core/registry.py`: Agent注册表和动态Agent树管理 +- `core/message.py`: Agent间通信机制(消息总线) + +#### 2. 专业知识模块系统 ✅ (基于RAG) +采用模块化文件组织,统一使用RAG进行知识检索: + +``` +knowledge/ +├── base.py # 基础定义(KnowledgeDocument, KnowledgeCategory) +├── loader.py # 知识加载器 +├── rag_knowledge.py # RAG检索系统 +├── tools.py # 知识查询工具 +├── vulnerabilities/ # 漏洞类型知识 +│ ├── injection.py # SQL注入、NoSQL注入、命令注入、代码注入 +│ ├── xss.py # 反射型XSS、存储型XSS、DOM型XSS +│ ├── auth.py # 认证绕过、IDOR、访问控制失效 +│ ├── crypto.py # 弱加密、硬编码凭证 +│ ├── ssrf.py # SSRF +│ ├── deserialization.py # 不安全的反序列化 +│ ├── path_traversal.py # 路径遍历 +│ ├── xxe.py # XXE +│ └── race_condition.py # 竞态条件 +└── frameworks/ # 框架安全知识 + ├── fastapi.py # FastAPI安全 + ├── django.py # Django安全 + ├── flask.py # Flask安全 + ├── express.py # Express.js安全 + ├── react.py # React安全 + └── supabase.py # Supabase安全 +``` + +#### 3. Agent基类增强 ✅ +- 支持动态Agent树(parent_id, 子Agent创建) +- Agent间消息通信 +- TaskHandoff协作机制 +- 知识模块加载 +- Memory Compression集成 + +#### 4. 工具系统 ✅ +- `thinking_tool.py`: Think和Reflect工具 +- `reporting_tool.py`: 漏洞报告工具 +- `agent_tools.py`: Agent协作工具 + - CreateSubAgentTool: 动态创建子Agent + - SendMessageTool: Agent间消息发送 + - ViewAgentGraphTool: 查看Agent树 + - WaitForMessageTool: 等待消息 + - AgentFinishTool: 子Agent完成报告 + +#### 5. LLM调用优化 ✅ +- `memory_compressor.py`: 对话历史压缩 + - 自动检测是否需要压缩 + - 保留关键信息(发现、工具使用、决策、错误) + - 智能摘要生成 +- Agent基类集成自动压缩 + +#### 6. Orchestrator Agent ✅ +- LLM驱动的编排决策 +- 动态调度子Agent +- ReAct模式执行 + +### 已完成的工作 (2024-12 续) + +#### 7. Prompt Caching ✅ +- `llm/prompt_cache.py`: Prompt 缓存管理器 + - 支持 Anthropic Claude 的 Prompt Caching + - 动态缓存策略(SYSTEM_ONLY, SYSTEM_AND_EARLY, MULTI_POINT) + - 缓存统计和效果追踪 + - Token 估算工具 +- LiteLLM 适配器集成缓存支持 + +#### 8. 动态Agent树执行器 ✅ +- `core/executor.py`: 完整的执行器实现 + - `DynamicAgentExecutor`: 动态Agent树执行器 + - 并行Agent执行(带信号量控制) + - 任务依赖管理 + - 执行结果汇总 + - 超时和取消处理 + - `SubAgentExecutor`: 子Agent执行器 + - 从父Agent创建和执行子Agent + - 并行子Agent执行 + - 结果收集和汇总 + - `ExecutionTask`: 执行任务数据结构 + - `ExecutionResult`: 执行结果数据结构 + +#### 9. Agent状态持久化 ✅ +- `core/persistence.py`: 持久化模块 + - `AgentStatePersistence`: 状态持久化管理器 + - 文件系统持久化 + - 数据库持久化(可选) + - 检查点列表和清理 + - `CheckpointManager`: 检查点管理器 + - 自动检查点(按迭代间隔) + - 检查点恢复 + - 状态回滚 + +#### 10. 增强的Agent协作工具 ✅ +- `CreateSubAgentTool`: 增强版 + - 支持立即执行模式 + - 集成SubAgentExecutor + - 上下文传递 +- `RunSubAgentsTool`: 批量执行子Agent + - 并行/顺序执行 + - 结果汇总 +- `CollectSubAgentResultsTool`: 收集子Agent结果 + +#### 11. 数据库模型扩展 ✅ +- `AgentCheckpoint`: Agent检查点模型 + - 状态数据存储 + - 执行统计 + - 检查点类型(auto/manual/error/final) +- `AgentTreeNode`: Agent树节点模型 + - 树结构记录 + - 执行状态追踪 + - 结果汇总 +- Alembic迁移脚本: `007_add_agent_checkpoint_tables.py` + +#### 12. API 端点 ✅ +- `GET /agent-tasks/{task_id}/agent-tree`: Agent树查询API + - 返回完整的Agent树结构 + - 支持运行时内存查询和数据库查询 + - 包含执行状态和发现统计 +- `GET /agent-tasks/{task_id}/checkpoints`: 检查点列表API + - 支持按Agent ID过滤 + - 分页支持 +- `GET /agent-tasks/{task_id}/checkpoints/{checkpoint_id}`: 检查点详情API + - 返回完整的Agent状态数据 + +### 已完成的工作 (2024-12 续2) + +#### 13. 前端 Agent 审计页面 ✅ (Strix-inspired Terminal UI) +- `frontend/src/shared/api/agentTasks.ts`: 扩展 API + - `AgentTreeNode`, `AgentTreeResponse` 类型定义 + - `AgentCheckpoint`, `CheckpointDetail` 类型定义 + - `getAgentTree()`: 获取 Agent 树结构 + - `getAgentCheckpoints()`: 获取检查点列表 + - `getCheckpointDetail()`: 获取检查点详情 + +- `frontend/src/pages/AgentAudit.tsx`: 统一的 Agent 审计页面 (参考 Strix TUI 设计) + - **布局**: 左侧活动日志 (75%) + 右侧 Agent 树和统计 (25%) + - **启动画面**: ASCII Art + 动画加载效果 + - **活动日志**: + - 实时流式显示 Agent 思考过程 + - 工具调用和结果展示 + - 漏洞发现高亮 + - 自动滚动控制 + - 可折叠的日志条目 + - **Agent 树可视化**: + - 树状结构展示 + - 节点状态图标(运行中/已完成/失败/等待) + - 发现数量徽章 + - 节点选择交互 + - **实时统计面板**: + - 进度百分比 + - 文件分析进度 + - Token 使用量 + - 发现数量 + - 严重程度分布 + - **创建任务对话框**: 选择项目后直接跳转到实时流页面 + - **任务控制**: 停止/取消任务 + +- `frontend/src/app/routes.tsx`: 路由配置 + - `/agent-audit`: 启动画面 + 创建任务 + - `/agent-audit/:taskId`: 任务实时流页面 + +- `frontend/src/components/layout/Sidebar.tsx`: 侧边栏导航 + - 新增 Agent 审计入口图标 + +### 已完成的工作 (2024-12 续3) + +#### 14. 执行架构切换 ✅ +- **移除旧的 LangGraph 固定流程架构** +- **启用新的动态 Agent 树架构** +- `backend/app/api/v1/endpoints/agent_tasks.py`: + - `_execute_agent_task()` 重写为使用 `OrchestratorAgent` + - OrchestratorAgent 作为大脑,动态调度子 Agent + - 子 Agent: ReconAgent, AnalysisAgent, VerificationAgent + - 新增辅助函数: `_get_user_config()`, `_initialize_tools()`, `_collect_project_info()`, `_save_findings()`, `_calculate_security_score()` + +### 待完成的工作 + +#### 1. 前端增强 +- 知识模块选择 UI(创建任务时) +- 检查点恢复功能 +- 导出报告功能 + +#### 2. 测试和优化 +- 单元测试 +- 集成测试 +- 性能优化 +- 并发执行压力测试 + +#### 3. 文档 +- API文档更新 +- 架构图更新 +- 使用指南 + +--- + +## 十一、架构升级总结 + +### 已实现的核心功能 + +1. **Prompt Caching** - 为 Claude 等 LLM 提供缓存支持,减少 Token 消耗 +2. **动态 Agent 树执行** - OrchestratorAgent 作为大脑,动态调度子 Agent +3. **Agent 状态持久化** - 文件系统和数据库双重持久化 +4. **检查点机制** - 自动检查点、状态恢复、执行历史追踪 +5. **增强的协作工具** - 子 Agent 创建、批量执行、结果收集 +6. **完整的 API 支持** - Agent 树查询、检查点管理 +7. **旧架构已移除** - 不再使用 LangGraph 固定流程,完全切换到动态 Agent 树 + +### 文件清单 + +``` +backend/app/services/ +├── llm/ +│ ├── __init__.py # 模块导出 +│ ├── prompt_cache.py # 🆕 Prompt Caching +│ ├── memory_compressor.py # Memory Compression +│ └── adapters/ +│ └── litellm_adapter.py # 集成 Prompt Caching +│ +├── agent/ +│ ├── core/ +│ │ ├── __init__.py # 模块导出 +│ │ ├── state.py # Agent 状态管理 +│ │ ├── registry.py # Agent 注册表 +│ │ ├── message.py # Agent 间通信 +│ │ ├── executor.py # 🆕 动态 Agent 树执行器 +│ │ └── persistence.py # 🆕 状态持久化 +│ │ +│ ├── tools/ +│ │ ├── __init__.py # 模块导出 +│ │ ├── agent_tools.py # 🔄 增强的协作工具 +│ │ ├── thinking_tool.py # Think/Reflect 工具 +│ │ └── reporting_tool.py # 漏洞报告工具 +│ │ +│ ├── knowledge/ # 专业知识模块 +│ │ ├── vulnerabilities/ # 漏洞类型知识 +│ │ └── frameworks/ # 框架安全知识 +│ │ +│ └── agents/ +│ ├── base.py # Agent 基类 +│ ├── orchestrator.py # 编排 Agent +│ ├── analysis.py # 分析 Agent +│ └── verification.py # 验证 Agent + +backend/app/models/ +└── agent_task.py # 🔄 新增 AgentCheckpoint, AgentTreeNode + +backend/app/api/v1/endpoints/ +└── agent_tasks.py # 🔄 新增 Agent 树和检查点 API + +backend/alembic/versions/ +└── 007_add_agent_checkpoint_tables.py # 🆕 数据库迁移 + +frontend/src/shared/api/ +└── agentTasks.ts # 🔄 扩展 Agent 树和检查点 API + +frontend/src/pages/ +└── AgentAudit.tsx # 🆕 统一的 Agent 审计页面 (Strix-inspired) + +frontend/src/app/ +└── routes.tsx # 🔄 新增 Agent 审计路由 + +frontend/src/components/layout/ +└── Sidebar.tsx # 🔄 新增 Agent 审计导航图标 +``` + +### 下一步计划 + +1. 测试前端页面渲染和流式事件 +2. 知识模块选择 UI +3. 检查点恢复功能 diff --git a/frontend/src/app/routes.tsx b/frontend/src/app/routes.tsx index cfd4072..a2c52c0 100644 --- a/frontend/src/app/routes.tsx +++ b/frontend/src/app/routes.tsx @@ -59,6 +59,12 @@ const routes: RouteConfig[] = [ }, { name: "Agent审计", + path: "/agent-audit", + element: , + visible: true, + }, + { + name: "Agent审计任务", path: "/agent-audit/:taskId", element: , visible: false, diff --git a/frontend/src/components/agent/CreateAgentTaskDialog.tsx b/frontend/src/components/agent/CreateAgentTaskDialog.tsx new file mode 100644 index 0000000..d1858a4 --- /dev/null +++ b/frontend/src/components/agent/CreateAgentTaskDialog.tsx @@ -0,0 +1,593 @@ +/** + * Agent 审计任务创建对话框 + * 专门用于 Agent Audit 页面,UI 风格与终端界面保持一致 + */ + +import { useState, useEffect, useMemo } from "react"; +import { useNavigate } from "react-router-dom"; +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Badge } from "@/components/ui/badge"; +import { ScrollArea } from "@/components/ui/scroll-area"; +import { + Collapsible, + CollapsibleContent, + CollapsibleTrigger, +} from "@/components/ui/collapsible"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { + Search, + ChevronRight, + GitBranch, + Package, + Globe, + Loader2, + Bot, + Settings2, + Play, + Upload, + FolderOpen, +} from "lucide-react"; +import { toast } from "sonner"; +import { api } from "@/shared/config/database"; +import { createAgentTask } from "@/shared/api/agentTasks"; +import { isRepositoryProject, isZipProject } from "@/shared/utils/projectUtils"; +import { getZipFileInfo, type ZipFileMeta } from "@/shared/utils/zipStorage"; +import { validateZipFile } from "@/features/projects/services/repoZipScan"; +import type { Project } from "@/shared/types"; +import FileSelectionDialog from "@/components/audit/FileSelectionDialog"; + +interface CreateAgentTaskDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; +} + +const DEFAULT_EXCLUDES = [ + "node_modules/**", + ".git/**", + "dist/**", + "build/**", + "*.log", +]; + +export default function CreateAgentTaskDialog({ + open, + onOpenChange, +}: CreateAgentTaskDialogProps) { + const navigate = useNavigate(); + + // 状态 + const [projects, setProjects] = useState([]); + const [loadingProjects, setLoadingProjects] = useState(true); + const [selectedProjectId, setSelectedProjectId] = useState(""); + const [searchTerm, setSearchTerm] = useState(""); + const [branch, setBranch] = useState("main"); + const [branches, setBranches] = useState([]); + const [loadingBranches, setLoadingBranches] = useState(false); + const [excludePatterns, setExcludePatterns] = useState(DEFAULT_EXCLUDES); + const [showAdvanced, setShowAdvanced] = useState(false); + const [creating, setCreating] = useState(false); + + // ZIP 文件状态 + const [zipFile, setZipFile] = useState(null); + const [storedZipInfo, setStoredZipInfo] = useState(null); + const [useStoredZip, setUseStoredZip] = useState(true); + + // 文件选择状态 + const [selectedFiles, setSelectedFiles] = useState(); + const [showFileSelection, setShowFileSelection] = useState(false); + + const selectedProject = projects.find((p) => p.id === selectedProjectId); + + // 加载项目列表 + useEffect(() => { + if (open) { + setLoadingProjects(true); + api.getProjects() + .then((data) => { + setProjects(data.filter((p: Project) => p.is_active)); + }) + .catch(() => { + toast.error("加载项目列表失败"); + }) + .finally(() => setLoadingProjects(false)); + + // 重置状态 + setSelectedProjectId(""); + setSearchTerm(""); + setBranch("main"); + setExcludePatterns(DEFAULT_EXCLUDES); + setShowAdvanced(false); + setZipFile(null); + setStoredZipInfo(null); + setSelectedFiles(undefined); + } + }, [open]); + + // 加载分支列表 + useEffect(() => { + const loadBranches = async () => { + // 使用 selectedProjectId 从 projects 中获取最新的 project 对象 + const project = projects.find((p) => p.id === selectedProjectId); + if (!project || !isRepositoryProject(project)) { + setBranches([]); + return; + } + + setLoadingBranches(true); + try { + const result = await api.getProjectBranches(project.id); + console.log("[Branch] 加载分支结果:", result); + + if (result.error) { + console.warn("[Branch] 加载分支警告:", result.error); + toast.error(`加载分支失败: ${result.error}`); + } + + setBranches(result.branches); + if (result.default_branch) { + setBranch(result.default_branch); + } + } catch (err) { + const msg = err instanceof Error ? err.message : "未知错误"; + console.error("[Branch] 加载分支失败:", msg); + toast.error(`加载分支失败: ${msg}`); + setBranches([project.default_branch || "main"]); + } finally { + setLoadingBranches(false); + } + }; + + loadBranches(); + }, [selectedProjectId, projects]); + + // 加载 ZIP 文件信息 + useEffect(() => { + const loadZipInfo = async () => { + if (!selectedProject || !isZipProject(selectedProject)) { + setStoredZipInfo(null); + return; + } + + try { + const info = await getZipFileInfo(selectedProject.id); + setStoredZipInfo(info); + setUseStoredZip(info.has_file); + } catch { + setStoredZipInfo(null); + } + }; + + loadZipInfo(); + }, [selectedProject?.id]); + + // 过滤项目 + const filteredProjects = useMemo(() => { + if (!searchTerm) return projects; + const term = searchTerm.toLowerCase(); + return projects.filter( + (p) => + p.name.toLowerCase().includes(term) || + p.description?.toLowerCase().includes(term) + ); + }, [projects, searchTerm]); + + // 是否可以开始 + const canStart = useMemo(() => { + if (!selectedProject) return false; + if (isZipProject(selectedProject)) { + return (useStoredZip && storedZipInfo?.has_file) || !!zipFile; + } + return !!selectedProject.repository_url && !!branch.trim(); + }, [selectedProject, useStoredZip, storedZipInfo, zipFile, branch]); + + // 创建任务 + const handleCreate = async () => { + if (!selectedProject) return; + + setCreating(true); + try { + const agentTask = await createAgentTask({ + project_id: selectedProject.id, + name: `Agent审计-${selectedProject.name}`, + branch_name: isRepositoryProject(selectedProject) ? branch : undefined, + exclude_patterns: excludePatterns, + target_files: selectedFiles, + verification_level: "sandbox", + }); + + onOpenChange(false); + toast.success("Agent 审计任务已创建"); + navigate(`/agent-audit/${agentTask.id}`); + } catch (err) { + const msg = err instanceof Error ? err.message : "创建失败"; + toast.error(msg); + } finally { + setCreating(false); + } + }; + + // 处理文件上传 + const handleFileChange = (e: React.ChangeEvent) => { + const file = e.target.files?.[0]; + if (file) { + const validation = validateZipFile(file); + if (!validation.valid) { + toast.error(validation.error || "文件无效"); + e.target.value = ""; + return; + } + setZipFile(file); + setUseStoredZip(false); + } + }; + + return ( + + + {/* Header */} + + +
+ +
+
+ New Agent Audit +

+ AI-Powered Security Analysis +

+
+
+
+ +
+ {/* 项目选择 */} +
+
+ + Select Project + + + {filteredProjects.length} available + +
+ + {/* 搜索框 */} +
+ + setSearchTerm(e.target.value)} + className="pl-9 h-10 bg-gray-900/50 border-gray-800 text-white font-mono placeholder:text-gray-600 focus:border-primary focus:ring-0" + /> +
+ + {/* 项目列表 */} + + {loadingProjects ? ( +
+ +
+ ) : filteredProjects.length === 0 ? ( +
+ + {searchTerm ? "No matches" : "No projects"} +
+ ) : ( +
+ {filteredProjects.map((project) => ( + setSelectedProjectId(project.id)} + /> + ))} +
+ )} +
+
+ + {/* 配置区域 */} + {selectedProject && ( +
+ {/* 仓库项目:分支选择 */} + {isRepositoryProject(selectedProject) && ( +
+ + Branch + {loadingBranches ? ( +
+ + Loading... +
+ ) : ( + + )} +
+ )} + + {/* ZIP 项目:文件选择 */} + {isZipProject(selectedProject) && ( +
+
+ + ZIP File +
+ + {storedZipInfo?.has_file && ( +
setUseStoredZip(true)} + > +
+
+ + {storedZipInfo.original_filename} + + + Stored + +
+
+ )} + +
+
+ )} + + {/* 高级选项 */} + + + + + Advanced Options + + + {/* 文件选择 */} + {(() => { + const isRepo = isRepositoryProject(selectedProject); + const isZip = isZipProject(selectedProject); + const hasStoredZip = storedZipInfo?.has_file; + // 可以选择文件的条件:仓库项目 或 ZIP项目使用已存储文件 + const canSelectFiles = isRepo || (isZip && useStoredZip && hasStoredZip); + + return ( +
+
+

+ Scan Scope +

+

+ {selectedFiles + ? `${selectedFiles.length} files selected` + : "All files"} +

+
+
+ {selectedFiles && canSelectFiles && ( + + )} + +
+
+ ); + })()} + + {/* 排除模式 */} +
+
+ + Exclude Patterns + + +
+ +
+ {excludePatterns.map((p) => ( + setExcludePatterns((prev) => prev.filter((x) => x !== p))} + > + {p} × + + ))} +
+ + { + if (e.key === "Enter" && e.currentTarget.value) { + const val = e.currentTarget.value.trim(); + if (val && !excludePatterns.includes(val)) { + setExcludePatterns((prev) => [...prev, val]); + } + e.currentTarget.value = ""; + } + }} + /> +
+
+
+
+ )} +
+ + {/* Footer */} +
+ + +
+ + + {/* 文件选择对话框 */} + +
+ ); +} + +// 项目列表项 +function ProjectItem({ + project, + selected, + onSelect, +}: { + project: Project; + selected: boolean; + onSelect: () => void; +}) { + const isRepo = isRepositoryProject(project); + + return ( +
+
+ {isRepo ? ( + + ) : ( + + )} +
+ +
+
+ + {project.name} + + + {isRepo ? "REPO" : "ZIP"} + +
+ {project.description && ( +

+ {project.description} +

+ )} +
+ + {selected && ( +
+ )} +
+ ); +} diff --git a/frontend/src/components/audit/CreateTaskDialog.tsx b/frontend/src/components/audit/CreateTaskDialog.tsx index 3e89caa..9ed1d78 100644 --- a/frontend/src/components/audit/CreateTaskDialog.tsx +++ b/frontend/src/components/audit/CreateTaskDialog.tsx @@ -30,7 +30,6 @@ import { Upload, FolderOpen, Settings2, - Play, Package, Globe, Shield, @@ -111,28 +110,39 @@ export default function CreateTaskDialog({ // 加载分支列表 useEffect(() => { const loadBranches = async () => { - if (!selectedProject || !isRepositoryProject(selectedProject)) { + // 使用 selectedProjectId 从 projects 中获取最新的 project 对象 + const project = projects.find((p) => p.id === selectedProjectId); + if (!project || !isRepositoryProject(project)) { setBranches([]); return; } setLoadingBranches(true); try { - const result = await api.getProjectBranches(selectedProject.id); + const result = await api.getProjectBranches(project.id); + console.log("[Branch] 加载分支结果:", result); + + if (result.error) { + console.warn("[Branch] 加载分支警告:", result.error); + toast.error(`加载分支失败: ${result.error}`); + } + setBranches(result.branches); if (result.default_branch) { setBranch(result.default_branch); } } catch (error) { - console.error("加载分支失败:", error); - setBranches([selectedProject.default_branch || "main"]); + const msg = error instanceof Error ? error.message : "未知错误"; + console.error("[Branch] 加载分支失败:", msg); + toast.error(`加载分支失败: ${msg}`); + setBranches([project.default_branch || "main"]); } finally { setLoadingBranches(false); } }; loadBranches(); - }, [selectedProject?.id]); + }, [selectedProjectId, projects]); const filteredProjects = useMemo(() => { if (!searchTerm) return projects; @@ -437,45 +447,47 @@ export default function CreateTaskDialog({ )} {/* 高级选项 */} - {/* 规则集和提示词选择 */} -
-
- - 审计配置 -
-
-
- - + {/* 规则集和提示词选择 - 仅快速扫描模式显示 */} + {auditMode !== "agent" && ( +
+
+ + 审计配置
-
- - +
+
+ + +
+
+ + +
-
+ )} diff --git a/frontend/src/components/layout/Sidebar.tsx b/frontend/src/components/layout/Sidebar.tsx index abb2d97..88a752c 100644 --- a/frontend/src/components/layout/Sidebar.tsx +++ b/frontend/src/components/layout/Sidebar.tsx @@ -16,7 +16,8 @@ import { Github, UserCircle, Shield, - MessageSquare + MessageSquare, + Bot } from "lucide-react"; import routes from "@/app/routes"; import { version } from "../../../package.json"; @@ -27,6 +28,7 @@ const routeIcons: Record = { "/projects": , "/instant-analysis": , "/audit-tasks": , + "/agent-audit": , "/audit-rules": , "/prompts": , "/admin": , diff --git a/frontend/src/pages/AgentAudit.tsx b/frontend/src/pages/AgentAudit.tsx index 8195a77..8cfb283 100644 --- a/frontend/src/pages/AgentAudit.tsx +++ b/frontend/src/pages/AgentAudit.tsx @@ -1,13 +1,16 @@ /** - * Agent Audit Page - Simplified Professional Version + * Agent Audit Page - Strix-inspired Terminal UI + * 参考 Strix 的 TUI 设计:左侧活动日志 + 右侧 Agent 树和统计 */ import { useState, useEffect, useRef, useCallback, useMemo } from "react"; -import { useParams, useNavigate } from "react-router-dom"; +import { useParams } from "react-router-dom"; import { Terminal, Bot, CheckCircle2, Loader2, XCircle, - Bug, Square, ArrowLeft, Brain, Wrench, - ChevronDown, ChevronUp, Clock, Eye, EyeOff, Target + Bug, Square, Brain, Wrench, Play, + ChevronDown, ChevronUp, Clock, Target, Zap, + Shield, Activity, ChevronRight, + FileCode, AlertTriangle, Search } from "lucide-react"; import { Button } from "@/components/ui/button"; import { Badge } from "@/components/ui/badge"; @@ -19,14 +22,17 @@ import { getAgentTask, getAgentFindings, cancelAgentTask, + getAgentTree, + type AgentTreeResponse, + type AgentTreeNode, } from "@/shared/api/agentTasks"; +import CreateAgentTaskDialog from "@/components/agent/CreateAgentTaskDialog"; // ============ Types ============ - interface LogItem { id: string; time: string; - type: 'thinking' | 'tool' | 'phase' | 'finding' | 'info' | 'error'; + type: 'thinking' | 'tool' | 'phase' | 'finding' | 'info' | 'error' | 'user' | 'dispatch'; title: string; content?: string; isStreaming?: boolean; @@ -35,105 +41,334 @@ interface LogItem { agentName?: string; } +// ============ Utilities ============ + +/** + * 将扁平的 Agent 节点列表转换为树结构 + * 后端返回的是扁平列表,需要根据 parent_agent_id 构建树 + */ +function buildAgentTree(flatNodes: AgentTreeNode[]): AgentTreeNode[] { + if (!flatNodes || flatNodes.length === 0) return []; + + // 创建节点映射(使用 agent_id 作为 key) + const nodeMap = new Map(); + + // 首先克隆所有节点并重置 children + flatNodes.forEach(node => { + nodeMap.set(node.agent_id, { ...node, children: [] }); + }); + + // 构建树结构 + const rootNodes: AgentTreeNode[] = []; + + flatNodes.forEach(node => { + const currentNode = nodeMap.get(node.agent_id)!; + + if (node.parent_agent_id && nodeMap.has(node.parent_agent_id)) { + // 有父节点,添加到父节点的 children + const parentNode = nodeMap.get(node.parent_agent_id)!; + parentNode.children.push(currentNode); + } else { + // 没有父节点或父节点不存在,作为根节点 + rootNodes.push(currentNode); + } + }); + + return rootNodes; +} + // ============ Constants ============ - const SEVERITY_COLORS: Record = { - critical: "text-red-400 bg-red-950/50", - high: "text-orange-400 bg-orange-950/50", - medium: "text-yellow-400 bg-yellow-950/50", - low: "text-blue-400 bg-blue-950/50", - info: "text-gray-400 bg-gray-900/50", + critical: "text-red-400 bg-red-950/50 border-red-500", + high: "text-orange-400 bg-orange-950/50 border-orange-500", + medium: "text-yellow-400 bg-yellow-950/50 border-yellow-500", + low: "text-blue-400 bg-blue-950/50 border-blue-500", + info: "text-gray-400 bg-gray-900/50 border-gray-500", }; -const AGENT_COLORS: Record = { - Orchestrator: "text-purple-400 border-purple-500/30 bg-purple-950/20", - Recon: "text-green-400 border-green-500/30 bg-green-950/20", - Analysis: "text-blue-400 border-blue-500/30 bg-blue-950/20", - Verification: "text-red-400 border-red-500/30 bg-red-950/20", - default: "text-gray-400 border-gray-500/30 bg-gray-950/20", -}; +const ACTION_VERBS = [ + "Analyzing", "Scanning", "Probing", "Investigating", + "Examining", "Auditing", "Testing", "Exploring" +]; -// ============ Components ============ +// ============ Sub Components ============ + +// 启动画面 - Strix 风格 +function SplashScreen({ onComplete }: { onComplete: () => void }) { + const [dots, setDots] = useState(0); + const [verb, setVerb] = useState(ACTION_VERBS[0]); + + useEffect(() => { + const timer = setTimeout(onComplete, 2500); + return () => clearTimeout(timer); + }, [onComplete]); + + useEffect(() => { + const dotTimer = setInterval(() => setDots(d => (d + 1) % 4), 400); + const verbTimer = setInterval(() => { + setVerb(ACTION_VERBS[Math.floor(Math.random() * ACTION_VERBS.length)]); + }, 2000); + return () => { + clearInterval(dotTimer); + clearInterval(verbTimer); + }; + }, []); -function StatusBadge({ status }: { status: string }) { - const config: Record = { - pending: { bg: "bg-gray-700", icon: }, - running: { bg: "bg-blue-700", icon: }, - completed: { bg: "bg-green-700", icon: }, - failed: { bg: "bg-red-700", icon: }, - cancelled: { bg: "bg-yellow-700", icon: }, - }; - const c = config[status] || config.pending; return ( - - {c.icon} - {status.toUpperCase()} - +
+
+
+{`
+ ██████╗ ███████╗███████╗██████╗  █████╗ ██╗   ██╗██████╗ ██╗████████╗
+ ██╔══██╗██╔════╝██╔════╝██╔══██╗██╔══██╗██║   ██║██╔══██╗██║╚══██╔══╝
+ ██║  ██║█████╗  █████╗  ██████╔╝███████║██║   ██║██║  ██║██║   ██║   
+ ██║  ██║██╔══╝  ██╔══╝  ██╔═══╝ ██╔══██║██║   ██║██║  ██║██║   ██║   
+ ██████╔╝███████╗███████╗██║     ██║  ██║╚██████╔╝██████╔╝██║   ██║   
+ ╚═════╝ ╚══════╝╚══════╝╚═╝     ╚═╝  ╚═╝ ╚═════╝ ╚═════╝ ╚═╝   ╚═╝   
+`}
+        
+
+

+ Welcome to DeepAudit! +

+

AI-Powered Security Audit Agent

+
+
+ + + {verb}{'.'.repeat(dots)} + +
+
+
); } + +// Agent 树节点 - 增强版 +function AgentTreeNodeItem({ + node, + depth = 0, + selectedId, + onSelect +}: { + node: AgentTreeNode; + depth?: number; + selectedId: string | null; + onSelect: (id: string) => void; +}) { + const [expanded, setExpanded] = useState(true); + const hasChildren = node.children && node.children.length > 0; + const isSelected = selectedId === node.agent_id; + + // 状态图标和颜色 + const statusConfig: Record = { + running: { + icon:
, + color: "text-green-400", + animate: true + }, + completed: { + icon: , + color: "text-green-400" + }, + failed: { + icon: , + color: "text-red-400" + }, + waiting: { + icon: , + color: "text-yellow-400" + }, + created: { + icon:
, + color: "text-gray-400" + }, + }; + + const config = statusConfig[node.status] || statusConfig.created; + + // Agent 类型图标 + const typeIcons: Record = { + orchestrator: , + recon: , + analysis: , + verification: , + }; + + return ( +
+
onSelect(node.agent_id)} + > + {hasChildren ? ( + + ) : } + + {/* 状态指示器 */} + + {config.icon} + + + {/* Agent 类型图标 */} + {typeIcons[node.agent_type] || } + + {/* Agent 名称 */} + + {node.agent_name} + + + {/* 发现数量 */} + {node.findings_count > 0 && ( + + {node.findings_count} + + )} +
+ + {expanded && hasChildren && ( +
+ {node.children.map(child => ( + + ))} +
+ )} +
+ ); +} + +// 日志条目组件 - 增强版 function LogEntry({ item, isExpanded, onToggle }: { item: LogItem; isExpanded: boolean; onToggle: () => void; }) { - const icons: Record = { - thinking: , - tool: , - phase: , - finding: , - info: , - error: , + const config: Record = { + thinking: { + icon: , + borderColor: "border-l-purple-500", + bgColor: "bg-purple-950/20" + }, + tool: { + icon: , + borderColor: "border-l-amber-500", + bgColor: "bg-amber-950/20" + }, + phase: { + icon: , + borderColor: "border-l-cyan-500", + bgColor: "bg-cyan-950/20" + }, + finding: { + icon: , + borderColor: "border-l-red-500", + bgColor: "bg-red-950/20" + }, + dispatch: { + icon: , + borderColor: "border-l-blue-500", + bgColor: "bg-blue-950/20" + }, + info: { + icon: , + borderColor: "border-l-gray-600", + bgColor: "bg-gray-900/30" + }, + error: { + icon: , + borderColor: "border-l-red-600", + bgColor: "bg-red-950/30" + }, + user: { + icon: , + borderColor: "border-l-blue-500", + bgColor: "bg-blue-950/20" + }, }; - const borderColors: Record = { - thinking: "border-l-purple-500", - tool: "border-l-amber-500", - phase: "border-l-cyan-500", - finding: "border-l-red-500", - info: "border-l-gray-600", - error: "border-l-red-600", - }; - - // Thinking content is always shown, others are collapsible + const c = config[item.type] || config.info; const isThinking = item.type === 'thinking'; const showContent = isThinking || isExpanded; const isCollapsible = !isThinking && item.content; return (
-
- {icons[item.type]} - {item.time} - {!isThinking && {item.title}} - {item.isStreaming && } - {item.tool?.status === 'running' && } +
+ {c.icon} + {item.time} + + {!isThinking && ( + {item.title} + )} + + {item.isStreaming && ( + + )} + + {item.tool?.status === 'running' && ( + + )} + {item.agentName && ( - + {item.agentName} )}
+
{item.tool?.duration !== undefined && ( - {item.tool.duration}ms + {item.tool.duration}ms )} {item.severity && ( - + {item.severity.charAt(0).toUpperCase()} )} {isCollapsible && ( - isExpanded ? : + isExpanded ? + : + )}
- + {showContent && item.content && ( -
+
{item.content}
)} @@ -141,19 +376,251 @@ function LogEntry({ item, isExpanded, onToggle }: { ); } -// ============ Main Component ============ +// 选中 Agent 详情面板 +function AgentDetailPanel({ + agentId, + treeNodes, + onClose +}: { + agentId: string; + treeNodes: AgentTreeNode[]; + onClose: () => void; +}) { + // 递归查找 agent + const findAgent = (nodes: AgentTreeNode[], id: string): AgentTreeNode | null => { + for (const node of nodes) { + if (node.agent_id === id) return node; + const found = findAgent(node.children, id); + if (found) return found; + } + return null; + }; + const agent = findAgent(treeNodes, agentId); + if (!agent) return null; + + const statusConfig: Record = { + running: { color: "text-green-400", text: "Running" }, + completed: { color: "text-green-400", text: "Completed" }, + failed: { color: "text-red-400", text: "Failed" }, + waiting: { color: "text-yellow-400", text: "Waiting" }, + created: { color: "text-gray-400", text: "Created" }, + }; + + const typeIcons: Record = { + orchestrator: { icon: , label: "Orchestrator" }, + recon: { icon: , label: "Reconnaissance" }, + analysis: { icon: , label: "Analysis" }, + verification: { icon: , label: "Verification" }, + }; + + const config = statusConfig[agent.status] || statusConfig.created; + const typeInfo = typeIcons[agent.agent_type] || { icon: , label: "Agent" }; + + return ( +
+
+
+ {typeInfo.icon} + {agent.agent_name} +
+ +
+ +
+
+ Type + {typeInfo.label} +
+
+ Status + {config.text} +
+
+ Iterations + {agent.iterations || 0} +
+
+ Findings + 0 ? 'text-red-400' : 'text-white'}`}> + {agent.findings_count} + +
+
+ Tool Calls + {agent.tool_calls || 0} +
+
+ Tokens + {((agent.tokens_used || 0) / 1000).toFixed(1)}k +
+
+ + {agent.task_description && ( +
+ Task +

{agent.task_description}

+
+ )} + + {agent.children && agent.children.length > 0 && ( +
+ Sub-agents: {agent.children.length} +
+ )} +
+ ); +} + +// 实时统计面板 - 增强版 +function StatsPanel({ task, findings }: { task: AgentTask | null; findings: AgentFinding[] }) { + if (!task) return null; + + const severityCounts = { + critical: findings.filter(f => f.severity === 'critical').length, + high: findings.filter(f => f.severity === 'high').length, + medium: findings.filter(f => f.severity === 'medium').length, + low: findings.filter(f => f.severity === 'low').length, + }; + + const totalFindings = Object.values(severityCounts).reduce((a, b) => a + b, 0); + + return ( +
+
+ + Live Stats +
+ + {/* 进度条 */} +
+
+ Progress + {task.progress_percentage?.toFixed(0) || 0}% +
+
+
+
+
+ + {/* 统计数据 */} +
+
+ Files + {task.analyzed_files}/{task.total_files} +
+
+ Iterations + {task.total_iterations || 0} +
+
+ Tokens + {((task.tokens_used || 0) / 1000).toFixed(1)}k +
+
+ Tool Calls + {task.tool_calls_count || 0} +
+
+ + {/* 发现统计 */} + {totalFindings > 0 && ( +
+
+ Findings + {totalFindings} +
+
+ {severityCounts.critical > 0 && ( + + Critical: {severityCounts.critical} + + )} + {severityCounts.high > 0 && ( + + High: {severityCounts.high} + + )} + {severityCounts.medium > 0 && ( + + Medium: {severityCounts.medium} + + )} + {severityCounts.low > 0 && ( + + Low: {severityCounts.low} + + )} +
+
+ )} + + {/* 安全评分 */} + {task.security_score !== null && task.security_score !== undefined && ( +
+
+ Security Score + = 80 ? 'text-green-400' : + task.security_score >= 60 ? 'text-yellow-400' : + 'text-red-400' + }`}> + {task.security_score.toFixed(0)} + +
+
+ )} +
+ ); +} + + + +// 状态徽章 +function StatusBadge({ status }: { status: string }) { + const config: Record = { + pending: { bg: "bg-gray-700", icon: , text: "PENDING" }, + running: { bg: "bg-green-700", icon: , text: "RUNNING" }, + completed: { bg: "bg-green-600", icon: , text: "COMPLETED" }, + failed: { bg: "bg-red-700", icon: , text: "FAILED" }, + cancelled: { bg: "bg-yellow-700", icon: , text: "CANCELLED" }, + }; + const c = config[status] || config.pending; + return ( + + {c.icon} + {c.text} + + ); +} + + +// ============ Main Component ============ export default function AgentAuditPage() { const { taskId } = useParams<{ taskId: string }>(); - const navigate = useNavigate(); + // 状态 + const [showSplash, setShowSplash] = useState(!taskId); + const [showCreateDialog, setShowCreateDialog] = useState(false); const [task, setTask] = useState(null); - const [_findings, setFindings] = useState([]); // Loaded for future use - const [isLoading, setIsLoading] = useState(true); - + const [findings, setFindings] = useState([]); + const [agentTree, setAgentTree] = useState(null); + const [selectedAgentId, setSelectedAgentId] = useState(null); + const [isLoading, setIsLoading] = useState(!!taskId); const [logs, setLogs] = useState([]); const [expandedIds, setExpandedIds] = useState>(new Set()); const [isAutoScroll, setIsAutoScroll] = useState(true); + const [statusVerb, setStatusVerb] = useState(ACTION_VERBS[0]); + const [statusDots, setStatusDots] = useState(0); + const [showAllLogs, setShowAllLogs] = useState(true); // 是否显示所有日志 const logEndRef = useRef(null); const logIdCounter = useRef(0); @@ -163,7 +630,62 @@ export default function AgentAuditPage() { const isRunning = task?.status === "running"; const isComplete = task?.status === "completed" || task?.status === "failed" || task?.status === "cancelled"; - // Helper to add log + // 构建 Agent 树结构(将扁平列表转换为树) + const treeNodes = useMemo(() => { + if (!agentTree?.nodes) return []; + return buildAgentTree(agentTree.nodes); + }, [agentTree?.nodes]); + + // 根据选中的 Agent 过滤日志 + const filteredLogs = useMemo(() => { + if (showAllLogs || !selectedAgentId) { + return logs; + } + // 根据 agentName 或 agentId 过滤 + // 需要找到选中 agent 的名称(在树结构中递归查找) + const findAgentName = (nodes: AgentTreeNode[], id: string): string | null => { + for (const node of nodes) { + if (node.agent_id === id) return node.agent_name; + const found = findAgentName(node.children, id); + if (found) return found; + } + return null; + }; + const selectedAgentName = findAgentName(treeNodes, selectedAgentId); + if (!selectedAgentName) return logs; + + return logs.filter(log => + log.agentName?.toLowerCase() === selectedAgentName.toLowerCase() || + log.agentName?.toLowerCase().includes(selectedAgentName.toLowerCase().split('_')[0]) + ); + }, [logs, selectedAgentId, showAllLogs, treeNodes]); + + // 选中 Agent 时的处理 + const handleAgentSelect = useCallback((agentId: string) => { + if (selectedAgentId === agentId) { + // 再次点击同一个 agent,切换回显示全部 + setShowAllLogs(true); + setSelectedAgentId(null); + } else { + setSelectedAgentId(agentId); + setShowAllLogs(false); + } + }, [selectedAgentId]); + + // 动态状态动画 + useEffect(() => { + if (!isRunning) return; + const dotTimer = setInterval(() => setStatusDots(d => (d + 1) % 4), 500); + const verbTimer = setInterval(() => { + setStatusVerb(ACTION_VERBS[Math.floor(Math.random() * ACTION_VERBS.length)]); + }, 5000); + return () => { + clearInterval(dotTimer); + clearInterval(verbTimer); + }; + }, [isRunning]); + + // Helper: 添加日志 const addLog = useCallback((item: Omit) => { const newItem: LogItem = { ...item, @@ -174,13 +696,13 @@ export default function AgentAuditPage() { return newItem.id; }, []); - // Load functions + // 加载任务数据 const loadTask = useCallback(async () => { if (!taskId) return; try { const data = await getAgentTask(taskId); setTask(data); - } catch (err: unknown) { + } catch (err) { toast.error("Failed to load task"); } }, [taskId]); @@ -195,40 +717,52 @@ export default function AgentAuditPage() { } }, [taskId]); - // Stream options - SIMPLIFIED + const loadAgentTree = useCallback(async () => { + if (!taskId) return; + try { + const data = await getAgentTree(taskId); + setAgentTree(data); + } catch (err) { + console.error(err); + } + }, [taskId]); + + // 流式事件处理 const streamOptions = useMemo(() => ({ includeThinking: true, includeToolCalls: true, - onEvent: (event: any) => { - if (event.agent_name) { - currentAgentName.current = event.agent_name; + // 捕获 agent_name + if (event.metadata?.agent_name) { + currentAgentName.current = event.metadata.agent_name; + } + + // 处理 dispatch 事件 + if (event.type === 'dispatch' || event.type === 'dispatch_complete') { + addLog({ + type: 'dispatch', + title: event.message || `Agent dispatch: ${event.metadata?.agent || 'unknown'}`, + agentName: currentAgentName.current || undefined, + }); + // 🔥 刷新 Agent 树,显示新创建的子 Agent + loadAgentTree(); } }, - onThinkingStart: () => { - // Ensure previous thinking is finalized if (currentThinkingId.current) { setLogs(prev => prev.map(log => - log.id === currentThinkingId.current - ? { ...log, isStreaming: false } - : log + log.id === currentThinkingId.current ? { ...log, isStreaming: false } : log )); } currentThinkingId.current = null; }, - onThinkingToken: (_token: string, accumulated: string) => { - if (!accumulated || accumulated.trim() === '') return; // Skip empty content - - // User Request: Action and Action Input should not be in Thinking box - // Filter out "Action:" and everything after from the thinking log + if (!accumulated?.trim()) return; + // 清理 Action 部分,只显示 Thought const cleanContent = accumulated.replace(/\nAction\s*:[\s\S]*$/, "").trim(); - if (!cleanContent) return; if (!currentThinkingId.current) { - // Create new thinking entry on first non-empty token const id = addLog({ type: 'thinking', title: 'Thinking...', @@ -238,91 +772,68 @@ export default function AgentAuditPage() { }); currentThinkingId.current = id; } else { - // Update existing entry setLogs(prev => prev.map(log => - log.id === currentThinkingId.current - ? { ...log, content: cleanContent } - : log + log.id === currentThinkingId.current ? { ...log, content: cleanContent } : log )); } }, - onThinkingEnd: (response: string) => { const cleanResponse = (response || "").replace(/\nAction\s*:[\s\S]*$/, "").trim(); - - if (!cleanResponse || cleanResponse === '') { - // No content, remove the entry if it exists + if (!cleanResponse) { if (currentThinkingId.current) { setLogs(prev => prev.filter(log => log.id !== currentThinkingId.current)); } currentThinkingId.current = null; return; } - if (currentThinkingId.current) { setLogs(prev => prev.map(log => log.id === currentThinkingId.current - ? { - ...log, - title: cleanResponse.slice(0, 80) + (cleanResponse.length > 80 ? '...' : ''), - content: cleanResponse, - isStreaming: false - } + ? { + ...log, + title: cleanResponse.slice(0, 100) + (cleanResponse.length > 100 ? '...' : ''), + content: cleanResponse, + isStreaming: false + } : log )); currentThinkingId.current = null; - } else if (cleanResponse.trim()) { - // No existing entry but we have content - create one - addLog({ - type: 'thinking', - title: cleanResponse.slice(0, 80) + (cleanResponse.length > 80 ? '...' : ''), - content: cleanResponse, - agentName: currentAgentName.current || undefined, - }); } }, - onToolStart: (name: string, input: Record) => { - // Force finalize any pending thinking log when a tool starts if (currentThinkingId.current) { setLogs(prev => prev.map(log => - log.id === currentThinkingId.current - ? { ...log, isStreaming: false } - : log + log.id === currentThinkingId.current ? { ...log, isStreaming: false } : log )); currentThinkingId.current = null; } - addLog({ type: 'tool', - title: `Action: ${name}`, + title: `Tool: ${name}`, content: `Input:\n${JSON.stringify(input, null, 2)}`, tool: { name, status: 'running' }, agentName: currentAgentName.current || undefined, }); }, - onToolEnd: (name: string, output: unknown, duration: number) => { - // Update the last tool log with duration and output setLogs(prev => { - // Find last matching tool (reverse search for compatibility) let idx = -1; for (let i = prev.length - 1; i >= 0; i--) { - if (prev[i].type === 'tool' && prev[i].tool?.name === name) { - idx = i; - break; + if (prev[i].type === 'tool' && prev[i].tool?.name === name && prev[i].tool?.status === 'running') { + idx = i; + break; } } if (idx >= 0) { const newLogs = [...prev]; - // Preserve existing input content and append output const previousContent = newLogs[idx].content || ''; const outputStr = typeof output === 'string' ? output : JSON.stringify(output, null, 2); - + // 截断过长输出 + const truncatedOutput = outputStr.length > 1000 ? outputStr.slice(0, 1000) + '\n... (truncated)' : outputStr; newLogs[idx] = { ...newLogs[idx], title: `Completed: ${name}`, - content: `${previousContent}\n\nOutput:\n${outputStr}`, + content: `${previousContent}\n\nOutput:\n${truncatedOutput}`, tool: { name, duration, status: 'completed' }, }; return newLogs; @@ -330,57 +841,72 @@ export default function AgentAuditPage() { return prev; }); }, - onFinding: (finding: Record) => { addLog({ type: 'finding', title: (finding.title as string) || 'Vulnerability found', severity: (finding.severity as string) || 'medium', + agentName: currentAgentName.current || undefined, }); loadFindings(); }, - onComplete: () => { - addLog({ type: 'info', title: 'Audit completed' }); + addLog({ type: 'info', title: '✅ Audit completed' }); loadTask(); loadFindings(); + loadAgentTree(); }, - onError: (err: string) => { addLog({ type: 'error', title: `Error: ${err}` }); }, - }), [addLog, loadTask, loadFindings]); + }), [addLog, loadTask, loadFindings, loadAgentTree]); - const { - connect: connectStream, - disconnect: disconnectStream, - isConnected, - } = useAgentStream(taskId || null, streamOptions); + const { connect: connectStream, disconnect: disconnectStream, isConnected } = useAgentStream(taskId || null, streamOptions); - // Init + // 初始化 useEffect(() => { - const init = async () => { - setIsLoading(true); - await Promise.all([loadTask(), loadFindings()]); - setIsLoading(false); - }; - init(); - }, [loadTask, loadFindings]); + if (!taskId) { + setShowSplash(true); + return; + } + setShowSplash(false); + setIsLoading(true); + + Promise.all([loadTask(), loadFindings(), loadAgentTree()]) + .finally(() => setIsLoading(false)); + }, [taskId, loadTask, loadFindings, loadAgentTree]); - // Connect + // 连接流 useEffect(() => { - if (!taskId || isComplete) return; - connectStream(); + if (taskId && task?.status === 'running') { + connectStream(); + addLog({ type: 'info', title: '🔗 Connected to audit stream' }); + } return () => disconnectStream(); - }, [taskId, isComplete, connectStream, disconnectStream]); + }, [taskId, task?.status, connectStream, disconnectStream, addLog]); - // Auto-scroll + // 定期刷新 Agent 树 + useEffect(() => { + if (!taskId || !isRunning) return; + const interval = setInterval(loadAgentTree, 3000); + return () => clearInterval(interval); + }, [taskId, isRunning, loadAgentTree]); + + // 定期刷新 Task 统计数据(Files, Iterations, Tokens, Tool Calls) + useEffect(() => { + if (!taskId || !isRunning) return; + const interval = setInterval(loadTask, 2000); + return () => clearInterval(interval); + }, [taskId, isRunning, loadTask]); + + // 自动滚动 useEffect(() => { if (isAutoScroll && logEndRef.current) { logEndRef.current.scrollIntoView({ behavior: 'smooth' }); } }, [logs, isAutoScroll]); + // 取消任务 const handleCancel = async () => { if (!taskId) return; try { @@ -388,10 +914,11 @@ export default function AgentAuditPage() { toast.success("Task cancelled"); loadTask(); } catch { - toast.error("Failed to cancel"); + toast.error("Failed to cancel task"); } }; + // 切换日志展开 const toggleExpand = (id: string) => { setExpandedIds(prev => { const next = new Set(prev); @@ -401,112 +928,246 @@ export default function AgentAuditPage() { }); }; - if (isLoading || !task) { + // ============ Render ============ + + // Splash 画面 (无 taskId) + if (showSplash && !taskId) { return ( -
- + <> + setShowCreateDialog(true)} /> + + + ); + } + + // 加载中 + if (isLoading && !task) { + return ( +
+
+ + Loading audit task... +
); } return ( -
+
{/* Header */} -
-
- -
- - Security Audit - {taskId?.slice(0, 8)} -
-
- +
- {isConnected && ( - - - LIVE - - )} - - {isRunning && ( - + + DeepAudit + {task && ( + <> + / + + {task.name || task.id.slice(0, 8)} + + + )}
-
- - {/* Main */} -
- {/* Left: Activity Log */} -
- {/* Toolbar */} -
-
- - Activity Log - {logs.length} -
+
+ {isRunning && ( + )} + +
+
+ + {/* Main Content - Strix Layout: 75% left / 25% right */} +
+ {/* Left Panel - Activity Log (75%) */} +
+ {/* Log Header */} +
+
+ + Activity Log + {isConnected && ( + + + Live + + )} + {/* 显示日志数量 */} + + {filteredLogs.length}{!showAllLogs && logs.length !== filteredLogs.length ? `/${logs.length}` : ''} + +
+
- {/* Logs */} -
- {logs.length === 0 ? ( -
- -

Waiting for agent activity...

+ {/* Log Content */} +
+ {/* 过滤提示 */} + {selectedAgentId && !showAllLogs && ( +
+ + Filtering logs for selected agent + + +
+ )} + {filteredLogs.length === 0 ? ( +
+ {isRunning ? ( + + + {selectedAgentId && !showAllLogs + ? 'Waiting for activity from selected agent...' + : 'Waiting for agent activity...'} + + ) : selectedAgentId && !showAllLogs ? ( + 'No activity from selected agent' + ) : ( + 'No activity yet' + )}
) : ( - logs - .filter(item => { - // Filter out empty/placeholder entries - if (item.title === 'Thinking...' && (!item.content || item.content.trim() === '')) { - return false; - } - return true; - }) - .map(item => ( - toggleExpand(item.id)} - /> - )) + filteredLogs.map(item => ( + toggleExpand(item.id)} + /> + )) )} -
+
- {/* Progress */} - {isRunning && ( -
-
- Progress - {task.progress_percentage?.toFixed(0) || 0}% -
-
-
-
+ {/* Status Bar */} + {task && ( +
+ + {isRunning ? ( + + + {statusVerb}{'.'.repeat(statusDots)} + + ) : isComplete ? ( + Audit {task.status} + ) : ( + 'Ready' + )} + + + {task.progress_percentage?.toFixed(0) || 0}% • {task.analyzed_files}/{task.total_files} files • {task.tool_calls_count || 0} tools +
)}
+ + {/* Right Panel - Agent Tree + Stats (25%) */} +
+ {/* Agent Tree */} +
+
+
+ + Agent Tree + {agentTree && ( + + {agentTree.total_agents} + + )} +
+
+ {selectedAgentId && !showAllLogs && ( + + )} + {agentTree && agentTree.running_agents > 0 && ( + + + {agentTree.running_agents} active + + )} +
+
+
+ {treeNodes.length > 0 ? ( + treeNodes.map(node => ( + + )) + ) : ( +
+ {isRunning ? ( + + + Initializing agents... + + ) : ( + 'No agents yet' + )} +
+ )} +
+
+ + {/* Agent Detail + Stats Panel */} +
+ {/* 选中 Agent 详情 */} + {selectedAgentId && !showAllLogs && ( + { setShowAllLogs(true); setSelectedAgentId(null); }} + /> + )} + +
+
+ + {/* Create Agent Task Dialog */} +
); } diff --git a/frontend/src/shared/api/agentTasks.ts b/frontend/src/shared/api/agentTasks.ts index c6aecd4..20c1eb5 100644 --- a/frontend/src/shared/api/agentTasks.ts +++ b/frontend/src/shared/api/agentTasks.ts @@ -26,6 +26,11 @@ export interface AgentTask { verified_count: number; false_positive_count: number; + // Agent 统计 + total_iterations: number; + tool_calls_count: number; + tokens_used: number; + // 严重程度统计 critical_count: number; high_count: number; @@ -34,7 +39,7 @@ export interface AgentTask { // 评分 quality_score: number; - security_score: number; + security_score: number | null; // 时间 created_at: string; @@ -44,6 +49,13 @@ export interface AgentTask { // 进度 progress_percentage: number; + // 配置 + audit_scope: Record | null; + target_vulnerabilities: string[] | null; + verification_level: string | null; + exclude_patterns: string[] | null; + target_files: string[] | null; + // 错误信息 error_message: string | null; } @@ -307,3 +319,89 @@ export async function* streamAgentEvents( } } +// ============ Agent Tree Types ============ + +export interface AgentTreeNode { + id: string; + agent_id: string; + agent_name: string; + agent_type: string; + parent_agent_id: string | null; + depth: number; + task_description: string | null; + knowledge_modules: string[] | null; + status: "created" | "running" | "completed" | "failed" | "waiting"; + result_summary: string | null; + findings_count: number; + iterations: number; + tokens_used: number; + tool_calls: number; + duration_ms: number | null; + children: AgentTreeNode[]; +} + +export interface AgentTreeResponse { + task_id: string; + root_agent_id: string | null; + total_agents: number; + running_agents: number; + completed_agents: number; + failed_agents: number; + total_findings: number; + nodes: AgentTreeNode[]; +} + +export interface AgentCheckpoint { + id: string; + agent_id: string; + agent_name: string; + agent_type: string; + iteration: number; + status: string; + total_tokens: number; + tool_calls: number; + findings_count: number; + checkpoint_type: "auto" | "manual" | "error" | "final"; + checkpoint_name: string | null; + created_at: string | null; +} + +export interface CheckpointDetail extends AgentCheckpoint { + task_id: string; + parent_agent_id: string | null; + state_data: Record; + metadata: Record | null; +} + +// ============ Agent Tree API Functions ============ + +/** + * 获取任务的 Agent 树结构 + */ +export async function getAgentTree(taskId: string): Promise { + const response = await apiClient.get(`/agent-tasks/${taskId}/agent-tree`); + return response.data; +} + +/** + * 获取任务的检查点列表 + */ +export async function getAgentCheckpoints( + taskId: string, + params?: { agent_id?: string; limit?: number } +): Promise { + const response = await apiClient.get(`/agent-tasks/${taskId}/checkpoints`, { params }); + return response.data; +} + +/** + * 获取检查点详情 + */ +export async function getCheckpointDetail( + taskId: string, + checkpointId: string +): Promise { + const response = await apiClient.get(`/agent-tasks/${taskId}/checkpoints/${checkpointId}`); + return response.data; +} +