From 9bc114af1f6e0cffd0bdcb8c09f062b52f44232f Mon Sep 17 00:00:00 2001 From: lintsinghua Date: Thu, 11 Dec 2025 19:09:10 +0800 Subject: [PATCH] feat(agent): implement Agent audit module with LangGraph integration - Introduce new Agent audit functionality for autonomous code security analysis and vulnerability verification. - Add API endpoints for managing Agent tasks and configurations. - Implement UI components for Agent mode selection and embedding model configuration. - Enhance the overall architecture with a focus on RAG (Retrieval-Augmented Generation) for improved code semantic search. - Create a sandbox environment for secure execution of vulnerability tests. - Update documentation to include details on the new Agent audit features and usage instructions. --- backend/Dockerfile | 1 + backend/alembic.ini | 1 + backend/alembic/env.py | 1 + backend/alembic/script.py.mako | 1 + .../alembic/versions/006_add_agent_tables.py | 255 +++++ backend/app/api/v1/api.py | 4 +- backend/app/api/v1/endpoints/agent_tasks.py | 639 ++++++++++++ .../app/api/v1/endpoints/embedding_config.py | 396 ++++++++ backend/app/api/v1/endpoints/members.py | 1 + backend/app/api/v1/endpoints/users.py | 1 + backend/app/core/config.py | 26 + backend/app/core/security.py | 1 + backend/app/db/base.py | 1 + backend/app/db/session.py | 12 + backend/app/models/__init__.py | 7 + backend/app/models/agent_task.py | 443 ++++++++ backend/app/models/analysis.py | 1 + backend/app/models/project.py | 2 + backend/app/models/user.py | 1 + backend/app/models/user_config.py | 1 + backend/app/schemas/token.py | 1 + backend/app/schemas/user.py | 1 + backend/app/services/agent/__init__.py | 58 ++ backend/app/services/agent/agents/__init__.py | 21 + backend/app/services/agent/agents/analysis.py | 469 +++++++++ backend/app/services/agent/agents/base.py | 284 ++++++ .../app/services/agent/agents/orchestrator.py | 381 +++++++ backend/app/services/agent/agents/recon.py | 435 ++++++++ 
.../app/services/agent/agents/verification.py | 392 ++++++++ backend/app/services/agent/event_manager.py | 371 +++++++ backend/app/services/agent/graph/__init__.py | 28 + .../app/services/agent/graph/audit_graph.py | 455 +++++++++ backend/app/services/agent/graph/nodes.py | 360 +++++++ backend/app/services/agent/graph/runner.py | 621 ++++++++++++ .../app/services/agent/prompts/__init__.py | 20 + .../services/agent/prompts/system_prompts.py | 170 ++++ backend/app/services/agent/tools/__init__.py | 61 ++ backend/app/services/agent/tools/base.py | 156 +++ .../agent/tools/code_analysis_tool.py | 427 ++++++++ .../services/agent/tools/external_tools.py | 948 ++++++++++++++++++ backend/app/services/agent/tools/file_tool.py | 481 +++++++++ .../app/services/agent/tools/pattern_tool.py | 418 ++++++++ backend/app/services/agent/tools/rag_tool.py | 293 ++++++ .../app/services/agent/tools/sandbox_tool.py | 647 ++++++++++++ .../services/llm/adapters/baidu_adapter.py | 1 + .../services/llm/adapters/doubao_adapter.py | 1 + .../services/llm/adapters/minimax_adapter.py | 1 + backend/app/services/llm/base_adapter.py | 1 + backend/app/services/llm/types.py | 1 + backend/app/services/rag/__init__.py | 18 + backend/app/services/rag/embeddings.py | 658 ++++++++++++ backend/app/services/rag/indexer.py | 585 +++++++++++ backend/app/services/rag/retriever.py | 469 +++++++++ backend/app/services/rag/splitter.py | 785 +++++++++++++++ backend/requirements.txt | 48 + docker/sandbox/Dockerfile | 62 ++ docker/sandbox/build.sh | 25 + docker/sandbox/seccomp.json | 266 +++++ docs/AGENT_AUDIT.md | 298 ++++++ frontend/Dockerfile | 1 + frontend/src/app/ProtectedRoute.tsx | 1 + frontend/src/app/routes.tsx | 7 + .../components/agent/AgentModeSelector.tsx | 158 +++ .../src/components/agent/EmbeddingConfig.tsx | 449 +++++++++ .../src/components/audit/CreateTaskDialog.tsx | 51 +- .../src/components/system/SystemConfig.tsx | 13 +- frontend/src/pages/AgentAudit.tsx | 579 +++++++++++ 
frontend/src/shared/api/agentTasks.ts | 306 ++++++ 68 files changed, 14072 insertions(+), 5 deletions(-) create mode 100644 backend/alembic/versions/006_add_agent_tables.py create mode 100644 backend/app/api/v1/endpoints/agent_tasks.py create mode 100644 backend/app/api/v1/endpoints/embedding_config.py create mode 100644 backend/app/models/agent_task.py create mode 100644 backend/app/services/agent/__init__.py create mode 100644 backend/app/services/agent/agents/__init__.py create mode 100644 backend/app/services/agent/agents/analysis.py create mode 100644 backend/app/services/agent/agents/base.py create mode 100644 backend/app/services/agent/agents/orchestrator.py create mode 100644 backend/app/services/agent/agents/recon.py create mode 100644 backend/app/services/agent/agents/verification.py create mode 100644 backend/app/services/agent/event_manager.py create mode 100644 backend/app/services/agent/graph/__init__.py create mode 100644 backend/app/services/agent/graph/audit_graph.py create mode 100644 backend/app/services/agent/graph/nodes.py create mode 100644 backend/app/services/agent/graph/runner.py create mode 100644 backend/app/services/agent/prompts/__init__.py create mode 100644 backend/app/services/agent/prompts/system_prompts.py create mode 100644 backend/app/services/agent/tools/__init__.py create mode 100644 backend/app/services/agent/tools/base.py create mode 100644 backend/app/services/agent/tools/code_analysis_tool.py create mode 100644 backend/app/services/agent/tools/external_tools.py create mode 100644 backend/app/services/agent/tools/file_tool.py create mode 100644 backend/app/services/agent/tools/pattern_tool.py create mode 100644 backend/app/services/agent/tools/rag_tool.py create mode 100644 backend/app/services/agent/tools/sandbox_tool.py create mode 100644 backend/app/services/rag/__init__.py create mode 100644 backend/app/services/rag/embeddings.py create mode 100644 backend/app/services/rag/indexer.py create mode 100644 
backend/app/services/rag/retriever.py create mode 100644 backend/app/services/rag/splitter.py create mode 100644 docker/sandbox/Dockerfile create mode 100644 docker/sandbox/build.sh create mode 100644 docker/sandbox/seccomp.json create mode 100644 docs/AGENT_AUDIT.md create mode 100644 frontend/src/components/agent/AgentModeSelector.tsx create mode 100644 frontend/src/components/agent/EmbeddingConfig.tsx create mode 100644 frontend/src/pages/AgentAudit.tsx create mode 100644 frontend/src/shared/api/agentTasks.ts diff --git a/backend/Dockerfile b/backend/Dockerfile index e138ad7..53294a4 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -58,3 +58,4 @@ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] + diff --git a/backend/alembic.ini b/backend/alembic.ini index c250693..72b7dfb 100644 --- a/backend/alembic.ini +++ b/backend/alembic.ini @@ -102,3 +102,4 @@ format = %(levelname)-5.5s [%(name)s] %(message)s datefmt = %H:%M:%S + diff --git a/backend/alembic/env.py b/backend/alembic/env.py index 1c35728..87c51e8 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -89,3 +89,4 @@ else: asyncio.run(run_migrations_online()) + diff --git a/backend/alembic/script.py.mako b/backend/alembic/script.py.mako index 96c9a3d..b394332 100644 --- a/backend/alembic/script.py.mako +++ b/backend/alembic/script.py.mako @@ -24,3 +24,4 @@ def downgrade() -> None: ${downgrades if downgrades else "pass"} + diff --git a/backend/alembic/versions/006_add_agent_tables.py b/backend/alembic/versions/006_add_agent_tables.py new file mode 100644 index 0000000..0960c4a --- /dev/null +++ b/backend/alembic/versions/006_add_agent_tables.py @@ -0,0 +1,255 @@ +"""Add agent audit tables + +Revision ID: 006_add_agent_tables +Revises: 5fc1cc05d5d0 +Create Date: 2024-01-15 10:00:00.000000 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = '006_add_agent_tables' +down_revision = '5fc1cc05d5d0' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # 创建 agent_tasks 表 + op.create_table( + 'agent_tasks', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('project_id', sa.String(36), sa.ForeignKey('projects.id', ondelete='CASCADE'), nullable=False), + + # 任务基本信息 + sa.Column('name', sa.String(255), nullable=True), + sa.Column('description', sa.Text(), nullable=True), + sa.Column('task_type', sa.String(50), default='agent_audit'), + + # 任务配置 + sa.Column('audit_scope', sa.JSON(), nullable=True), + sa.Column('target_vulnerabilities', sa.JSON(), nullable=True), + sa.Column('verification_level', sa.String(50), default='sandbox'), + + # 分支信息 + sa.Column('branch_name', sa.String(255), nullable=True), + + # 排除模式 + sa.Column('exclude_patterns', sa.JSON(), nullable=True), + + # 文件范围 + sa.Column('target_files', sa.JSON(), nullable=True), + + # LLM 配置 + sa.Column('llm_config', sa.JSON(), nullable=True), + + # Agent 配置 + sa.Column('agent_config', sa.JSON(), nullable=True), + sa.Column('max_iterations', sa.Integer(), default=50), + sa.Column('token_budget', sa.Integer(), default=100000), + sa.Column('timeout_seconds', sa.Integer(), default=1800), + + # 状态 + sa.Column('status', sa.String(20), default='pending'), + sa.Column('current_phase', sa.String(50), nullable=True), + sa.Column('current_step', sa.String(255), nullable=True), + sa.Column('error_message', sa.Text(), nullable=True), + + # 进度统计 + sa.Column('total_files', sa.Integer(), default=0), + sa.Column('indexed_files', sa.Integer(), default=0), + sa.Column('analyzed_files', sa.Integer(), default=0), + sa.Column('total_chunks', sa.Integer(), default=0), + + # Agent 统计 + sa.Column('total_iterations', sa.Integer(), default=0), + sa.Column('tool_calls_count', sa.Integer(), default=0), + sa.Column('tokens_used', sa.Integer(), default=0), + + # 发现统计 + sa.Column('findings_count', sa.Integer(), default=0), + 
sa.Column('verified_count', sa.Integer(), default=0), + sa.Column('false_positive_count', sa.Integer(), default=0), + + # 严重程度统计 + sa.Column('critical_count', sa.Integer(), default=0), + sa.Column('high_count', sa.Integer(), default=0), + sa.Column('medium_count', sa.Integer(), default=0), + sa.Column('low_count', sa.Integer(), default=0), + + # 质量评分 + sa.Column('quality_score', sa.Float(), default=0.0), + sa.Column('security_score', sa.Float(), default=0.0), + + # 审计计划 + sa.Column('audit_plan', sa.JSON(), nullable=True), + + # 时间戳 + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column('updated_at', sa.DateTime(timezone=True), onupdate=sa.func.now()), + sa.Column('started_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True), + + # 创建者 + sa.Column('created_by', sa.String(36), sa.ForeignKey('users.id'), nullable=False), + ) + + # 创建 agent_tasks 索引 + op.create_index('ix_agent_tasks_project_id', 'agent_tasks', ['project_id']) + op.create_index('ix_agent_tasks_status', 'agent_tasks', ['status']) + op.create_index('ix_agent_tasks_created_by', 'agent_tasks', ['created_by']) + op.create_index('ix_agent_tasks_created_at', 'agent_tasks', ['created_at']) + + # 创建 agent_events 表 + op.create_table( + 'agent_events', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('task_id', sa.String(36), sa.ForeignKey('agent_tasks.id', ondelete='CASCADE'), nullable=False), + + # 事件信息 + sa.Column('event_type', sa.String(50), nullable=False), + sa.Column('phase', sa.String(50), nullable=True), + + # 事件内容 + sa.Column('message', sa.Text(), nullable=True), + + # 工具调用相关 + sa.Column('tool_name', sa.String(100), nullable=True), + sa.Column('tool_input', sa.JSON(), nullable=True), + sa.Column('tool_output', sa.JSON(), nullable=True), + sa.Column('tool_duration_ms', sa.Integer(), nullable=True), + + # 关联的发现 + sa.Column('finding_id', sa.String(36), nullable=True), + + # Token 消耗 + 
sa.Column('tokens_used', sa.Integer(), default=0), + + # 元数据 + sa.Column('metadata', sa.JSON(), nullable=True), + + # 序号 + sa.Column('sequence', sa.Integer(), default=0), + + # 时间戳 + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + # 创建 agent_events 索引 + op.create_index('ix_agent_events_task_id', 'agent_events', ['task_id']) + op.create_index('ix_agent_events_event_type', 'agent_events', ['event_type']) + op.create_index('ix_agent_events_sequence', 'agent_events', ['sequence']) + op.create_index('ix_agent_events_created_at', 'agent_events', ['created_at']) + + # 创建 agent_findings 表 + op.create_table( + 'agent_findings', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('task_id', sa.String(36), sa.ForeignKey('agent_tasks.id', ondelete='CASCADE'), nullable=False), + + # 漏洞基本信息 + sa.Column('vulnerability_type', sa.String(100), nullable=False), + sa.Column('severity', sa.String(20), nullable=False), + sa.Column('title', sa.String(500), nullable=False), + sa.Column('description', sa.Text(), nullable=True), + + # 位置信息 + sa.Column('file_path', sa.String(500), nullable=True), + sa.Column('line_start', sa.Integer(), nullable=True), + sa.Column('line_end', sa.Integer(), nullable=True), + sa.Column('column_start', sa.Integer(), nullable=True), + sa.Column('column_end', sa.Integer(), nullable=True), + sa.Column('function_name', sa.String(255), nullable=True), + sa.Column('class_name', sa.String(255), nullable=True), + + # 代码片段 + sa.Column('code_snippet', sa.Text(), nullable=True), + sa.Column('code_context', sa.Text(), nullable=True), + + # 数据流信息 + sa.Column('source', sa.Text(), nullable=True), + sa.Column('sink', sa.Text(), nullable=True), + sa.Column('dataflow_path', sa.JSON(), nullable=True), + + # 验证信息 + sa.Column('status', sa.String(30), default='new'), + sa.Column('is_verified', sa.Boolean(), default=False), + sa.Column('verification_method', sa.Text(), nullable=True), + sa.Column('verification_result', sa.JSON(), 
nullable=True), + sa.Column('verified_at', sa.DateTime(timezone=True), nullable=True), + + # PoC + sa.Column('has_poc', sa.Boolean(), default=False), + sa.Column('poc_code', sa.Text(), nullable=True), + sa.Column('poc_description', sa.Text(), nullable=True), + sa.Column('poc_steps', sa.JSON(), nullable=True), + + # 修复建议 + sa.Column('suggestion', sa.Text(), nullable=True), + sa.Column('fix_code', sa.Text(), nullable=True), + sa.Column('fix_description', sa.Text(), nullable=True), + sa.Column('references', sa.JSON(), nullable=True), + + # AI 解释 + sa.Column('ai_explanation', sa.Text(), nullable=True), + sa.Column('ai_confidence', sa.Float(), nullable=True), + + # XAI + sa.Column('xai_what', sa.Text(), nullable=True), + sa.Column('xai_why', sa.Text(), nullable=True), + sa.Column('xai_how', sa.Text(), nullable=True), + sa.Column('xai_impact', sa.Text(), nullable=True), + + # 关联规则 + sa.Column('matched_rule_code', sa.String(100), nullable=True), + sa.Column('matched_pattern', sa.Text(), nullable=True), + + # CVSS 评分 + sa.Column('cvss_score', sa.Float(), nullable=True), + sa.Column('cvss_vector', sa.String(100), nullable=True), + + # 元数据 + sa.Column('metadata', sa.JSON(), nullable=True), + sa.Column('tags', sa.JSON(), nullable=True), + + # 去重标识 + sa.Column('fingerprint', sa.String(64), nullable=True), + + # 时间戳 + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.func.now()), + sa.Column('updated_at', sa.DateTime(timezone=True), onupdate=sa.func.now()), + ) + + # 创建 agent_findings 索引 + op.create_index('ix_agent_findings_task_id', 'agent_findings', ['task_id']) + op.create_index('ix_agent_findings_vulnerability_type', 'agent_findings', ['vulnerability_type']) + op.create_index('ix_agent_findings_severity', 'agent_findings', ['severity']) + op.create_index('ix_agent_findings_file_path', 'agent_findings', ['file_path']) + op.create_index('ix_agent_findings_status', 'agent_findings', ['status']) + op.create_index('ix_agent_findings_fingerprint', 
'agent_findings', ['fingerprint']) + + +def downgrade() -> None: + # 删除索引和表 + op.drop_index('ix_agent_findings_fingerprint', 'agent_findings') + op.drop_index('ix_agent_findings_status', 'agent_findings') + op.drop_index('ix_agent_findings_file_path', 'agent_findings') + op.drop_index('ix_agent_findings_severity', 'agent_findings') + op.drop_index('ix_agent_findings_vulnerability_type', 'agent_findings') + op.drop_index('ix_agent_findings_task_id', 'agent_findings') + op.drop_table('agent_findings') + + op.drop_index('ix_agent_events_created_at', 'agent_events') + op.drop_index('ix_agent_events_sequence', 'agent_events') + op.drop_index('ix_agent_events_event_type', 'agent_events') + op.drop_index('ix_agent_events_task_id', 'agent_events') + op.drop_table('agent_events') + + op.drop_index('ix_agent_tasks_created_at', 'agent_tasks') + op.drop_index('ix_agent_tasks_created_by', 'agent_tasks') + op.drop_index('ix_agent_tasks_status', 'agent_tasks') + op.drop_index('ix_agent_tasks_project_id', 'agent_tasks') + op.drop_table('agent_tasks') + diff --git a/backend/app/api/v1/api.py b/backend/app/api/v1/api.py index b2fb66f..c233682 100644 --- a/backend/app/api/v1/api.py +++ b/backend/app/api/v1/api.py @@ -1,5 +1,5 @@ from fastapi import APIRouter -from app.api.v1.endpoints import auth, users, projects, tasks, scan, members, config, database, prompts, rules +from app.api.v1.endpoints import auth, users, projects, tasks, scan, members, config, database, prompts, rules, agent_tasks, embedding_config api_router = APIRouter() api_router.include_router(auth.router, prefix="/auth", tags=["auth"]) @@ -12,3 +12,5 @@ api_router.include_router(config.router, prefix="/config", tags=["config"]) api_router.include_router(database.router, prefix="/database", tags=["database"]) api_router.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) api_router.include_router(rules.router, prefix="/rules", tags=["rules"]) +api_router.include_router(agent_tasks.router, 
prefix="/agent-tasks", tags=["agent-tasks"]) +api_router.include_router(embedding_config.router, prefix="/embedding", tags=["embedding"]) diff --git a/backend/app/api/v1/endpoints/agent_tasks.py b/backend/app/api/v1/endpoints/agent_tasks.py new file mode 100644 index 0000000..63bfb2e --- /dev/null +++ b/backend/app/api/v1/endpoints/agent_tasks.py @@ -0,0 +1,639 @@ +""" +DeepAudit Agent 审计任务 API +基于 LangGraph 的 Agent 审计 +""" + +import asyncio +import json +import logging +import os +import shutil +from typing import Any, List, Optional, Dict +from datetime import datetime, timezone +from uuid import uuid4 + +from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Query +from fastapi.responses import StreamingResponse +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy.orm import selectinload +from pydantic import BaseModel, Field + +from app.api import deps +from app.db.session import get_db, async_session_factory +from app.models.agent_task import ( + AgentTask, AgentEvent, AgentFinding, + AgentTaskStatus, AgentTaskPhase, AgentEventType, + VulnerabilitySeverity, FindingStatus, +) +from app.models.project import Project +from app.models.user import User +from app.services.agent import AgentRunner, EventManager, run_agent_task + +logger = logging.getLogger(__name__) +router = APIRouter() + +# 运行中的任务 +_running_tasks: Dict[str, AgentRunner] = {} + + +# ============ Schemas ============ + +class AgentTaskCreate(BaseModel): + """创建 Agent 任务请求""" + project_id: str = Field(..., description="项目 ID") + name: Optional[str] = Field(None, description="任务名称") + description: Optional[str] = Field(None, description="任务描述") + + # 审计配置 + audit_scope: Optional[dict] = Field(None, description="审计范围") + target_vulnerabilities: Optional[List[str]] = Field( + default=["sql_injection", "xss", "command_injection", "path_traversal", "ssrf"], + description="目标漏洞类型" + ) + verification_level: str = Field( + "sandbox", + 
description="验证级别: analysis_only, sandbox, generate_poc" + ) + + # 分支 + branch_name: Optional[str] = Field(None, description="分支名称") + + # 排除模式 + exclude_patterns: Optional[List[str]] = Field( + default=["node_modules", "__pycache__", ".git", "*.min.js"], + description="排除模式" + ) + + # 文件范围 + target_files: Optional[List[str]] = Field(None, description="指定扫描的文件") + + # Agent 配置 + max_iterations: int = Field(3, ge=1, le=10, description="最大分析迭代次数") + timeout_seconds: int = Field(1800, ge=60, le=7200, description="超时时间(秒)") + + +class AgentTaskResponse(BaseModel): + """Agent 任务响应""" + id: str + project_id: str + name: Optional[str] + description: Optional[str] + status: str + current_phase: Optional[str] + + # 统计 + total_findings: int = 0 + verified_findings: int = 0 + security_score: Optional[int] = None + + # 时间 + created_at: datetime + started_at: Optional[datetime] = None + finished_at: Optional[datetime] = None + + # 配置 + config: Optional[dict] = None + + # 错误信息 + error_message: Optional[str] = None + + class Config: + from_attributes = True + + +class AgentEventResponse(BaseModel): + """Agent 事件响应""" + id: str + task_id: str + event_type: str + phase: Optional[str] + message: str + sequence: int + created_at: datetime + + # 可选字段 + tool_name: Optional[str] = None + tool_duration_ms: Optional[int] = None + progress_percent: Optional[float] = None + finding_id: Optional[str] = None + + class Config: + from_attributes = True + + +class AgentFindingResponse(BaseModel): + """Agent 发现响应""" + id: str + task_id: str + vulnerability_type: str + severity: str + title: str + description: Optional[str] + file_path: Optional[str] + line_start: Optional[int] + line_end: Optional[int] + code_snippet: Optional[str] + + is_verified: bool + confidence: float + status: str + + suggestion: Optional[str] = None + poc: Optional[dict] = None + + created_at: datetime + + class Config: + from_attributes = True + + +class TaskSummaryResponse(BaseModel): + """任务摘要响应""" + task_id: str + 
status: str + security_score: Optional[int] + + total_findings: int + verified_findings: int + + severity_distribution: Dict[str, int] + vulnerability_types: Dict[str, int] + + duration_seconds: Optional[int] + phases_completed: List[str] + + +# ============ 后台任务执行 ============ + +async def _execute_agent_task(task_id: str, project_root: str): + """在后台执行 Agent 任务""" + async with async_session_factory() as db: + try: + # 获取任务 + task = await db.get(AgentTask, task_id, options=[selectinload(AgentTask.project)]) + if not task: + logger.error(f"Task {task_id} not found") + return + + # 创建 Runner + runner = AgentRunner(db, task, project_root) + _running_tasks[task_id] = runner + + # 执行 + result = await runner.run() + + logger.info(f"Task {task_id} completed: {result.get('success', False)}") + + except Exception as e: + logger.error(f"Task {task_id} failed: {e}", exc_info=True) + + # 更新任务状态 + task = await db.get(AgentTask, task_id) + if task: + task.status = AgentTaskStatus.FAILED + task.error_message = str(e) + task.finished_at = datetime.now(timezone.utc) + await db.commit() + + finally: + # 清理 + _running_tasks.pop(task_id, None) + + +# ============ API Endpoints ============ + +@router.post("/", response_model=AgentTaskResponse) +async def create_agent_task( + request: AgentTaskCreate, + background_tasks: BackgroundTasks, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(deps.get_current_user), +) -> Any: + """ + 创建并启动 Agent 审计任务 + """ + # 验证项目 + project = await db.get(Project, request.project_id) + if not project: + raise HTTPException(status_code=404, detail="项目不存在") + + if project.owner_id != current_user.id: + raise HTTPException(status_code=403, detail="无权访问此项目") + + # 创建任务 + task = AgentTask( + id=str(uuid4()), + project_id=project.id, + name=request.name or f"Agent Audit - {datetime.now().strftime('%Y%m%d_%H%M%S')}", + description=request.description, + status=AgentTaskStatus.PENDING, + current_phase=AgentTaskPhase.PLANNING, + config={ + 
"target_vulnerabilities": request.target_vulnerabilities, + "verification_level": request.verification_level, + "exclude_patterns": request.exclude_patterns, + "target_files": request.target_files, + "max_iterations": request.max_iterations, + "timeout_seconds": request.timeout_seconds, + }, + created_by=current_user.id, + ) + + db.add(task) + await db.commit() + await db.refresh(task) + + # 确定项目根目录 + project_root = _get_project_root(project, task.id) + + # 在后台启动任务 + background_tasks.add_task(_execute_agent_task, task.id, project_root) + + logger.info(f"Created agent task {task.id} for project {project.name}") + + return task + + +@router.get("/", response_model=List[AgentTaskResponse]) +async def list_agent_tasks( + project_id: Optional[str] = None, + status: Optional[str] = None, + skip: int = Query(0, ge=0), + limit: int = Query(20, ge=1, le=100), + db: AsyncSession = Depends(get_db), + current_user: User = Depends(deps.get_current_user), +) -> Any: + """ + 获取 Agent 任务列表 + """ + # 获取用户的项目 + projects_result = await db.execute( + select(Project.id).where(Project.owner_id == current_user.id) + ) + user_project_ids = [p[0] for p in projects_result.fetchall()] + + if not user_project_ids: + return [] + + # 构建查询 + query = select(AgentTask).where(AgentTask.project_id.in_(user_project_ids)) + + if project_id: + query = query.where(AgentTask.project_id == project_id) + + if status: + try: + status_enum = AgentTaskStatus(status) + query = query.where(AgentTask.status == status_enum) + except ValueError: + pass + + query = query.order_by(AgentTask.created_at.desc()) + query = query.offset(skip).limit(limit) + + result = await db.execute(query) + tasks = result.scalars().all() + + return tasks + + +@router.get("/{task_id}", response_model=AgentTaskResponse) +async def get_agent_task( + task_id: str, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(deps.get_current_user), +) -> Any: + """ + 获取 Agent 任务详情 + """ + task = await db.get(AgentTask, task_id) + 
if not task: + raise HTTPException(status_code=404, detail="任务不存在") + + # 检查权限 + project = await db.get(Project, task.project_id) + if not project or project.owner_id != current_user.id: + raise HTTPException(status_code=403, detail="无权访问此任务") + + return task + + +@router.post("/{task_id}/cancel") +async def cancel_agent_task( + task_id: str, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(deps.get_current_user), +) -> Any: + """ + 取消 Agent 任务 + """ + task = await db.get(AgentTask, task_id) + if not task: + raise HTTPException(status_code=404, detail="任务不存在") + + project = await db.get(Project, task.project_id) + if not project or project.owner_id != current_user.id: + raise HTTPException(status_code=403, detail="无权操作此任务") + + if task.status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]: + raise HTTPException(status_code=400, detail="任务已结束,无法取消") + + # 取消运行中的任务 + runner = _running_tasks.get(task_id) + if runner: + runner.cancel() + + # 更新状态 + task.status = AgentTaskStatus.CANCELLED + task.finished_at = datetime.now(timezone.utc) + await db.commit() + + return {"message": "任务已取消", "task_id": task_id} + + +@router.get("/{task_id}/events") +async def stream_agent_events( + task_id: str, + after_sequence: int = Query(0, ge=0, description="从哪个序号之后开始"), + db: AsyncSession = Depends(get_db), + current_user: User = Depends(deps.get_current_user), +): + """ + 获取 Agent 事件流 (SSE) + """ + task = await db.get(AgentTask, task_id) + if not task: + raise HTTPException(status_code=404, detail="任务不存在") + + project = await db.get(Project, task.project_id) + if not project or project.owner_id != current_user.id: + raise HTTPException(status_code=403, detail="无权访问此任务") + + async def event_generator(): + """生成 SSE 事件流""" + last_sequence = after_sequence + poll_interval = 0.5 + max_idle = 300 # 5 分钟无事件后关闭 + idle_time = 0 + + while True: + # 查询新事件 + async with async_session_factory() as session: + result = await session.execute( + 
select(AgentEvent) + .where(AgentEvent.task_id == task_id) + .where(AgentEvent.sequence > last_sequence) + .order_by(AgentEvent.sequence) + .limit(50) + ) + events = result.scalars().all() + + # 获取任务状态 + current_task = await session.get(AgentTask, task_id) + task_status = current_task.status if current_task else None + + if events: + idle_time = 0 + for event in events: + last_sequence = event.sequence + data = { + "id": event.id, + "type": event.event_type.value if hasattr(event.event_type, 'value') else str(event.event_type), + "phase": event.phase.value if event.phase and hasattr(event.phase, 'value') else event.phase, + "message": event.message, + "sequence": event.sequence, + "timestamp": event.created_at.isoformat() if event.created_at else None, + "progress_percent": event.progress_percent, + "tool_name": event.tool_name, + } + yield f"data: {json.dumps(data, ensure_ascii=False)}\n\n" + else: + idle_time += poll_interval + + # 检查任务是否结束 + if task_status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]: + yield f"data: {json.dumps({'type': 'task_end', 'status': task_status.value})}\n\n" + break + + # 检查空闲超时 + if idle_time >= max_idle: + yield f"data: {json.dumps({'type': 'timeout'})}\n\n" + break + + await asyncio.sleep(poll_interval) + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + } + ) + + +@router.get("/{task_id}/events/list", response_model=List[AgentEventResponse]) +async def list_agent_events( + task_id: str, + after_sequence: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=500), + db: AsyncSession = Depends(get_db), + current_user: User = Depends(deps.get_current_user), +) -> Any: + """ + 获取 Agent 事件列表 + """ + task = await db.get(AgentTask, task_id) + if not task: + raise HTTPException(status_code=404, detail="任务不存在") + + project = await db.get(Project, task.project_id) + 
if not project or project.owner_id != current_user.id: + raise HTTPException(status_code=403, detail="无权访问此任务") + + result = await db.execute( + select(AgentEvent) + .where(AgentEvent.task_id == task_id) + .where(AgentEvent.sequence > after_sequence) + .order_by(AgentEvent.sequence) + .limit(limit) + ) + events = result.scalars().all() + + return events + + +@router.get("/{task_id}/findings", response_model=List[AgentFindingResponse]) +async def list_agent_findings( + task_id: str, + severity: Optional[str] = None, + verified_only: bool = False, + skip: int = Query(0, ge=0), + limit: int = Query(50, ge=1, le=200), + db: AsyncSession = Depends(get_db), + current_user: User = Depends(deps.get_current_user), +) -> Any: + """ + 获取 Agent 发现列表 + """ + task = await db.get(AgentTask, task_id) + if not task: + raise HTTPException(status_code=404, detail="任务不存在") + + project = await db.get(Project, task.project_id) + if not project or project.owner_id != current_user.id: + raise HTTPException(status_code=403, detail="无权访问此任务") + + query = select(AgentFinding).where(AgentFinding.task_id == task_id) + + if severity: + try: + sev_enum = VulnerabilitySeverity(severity) + query = query.where(AgentFinding.severity == sev_enum) + except ValueError: + pass + + if verified_only: + query = query.where(AgentFinding.is_verified == True) + + # 按严重程度排序 + severity_order = { + VulnerabilitySeverity.CRITICAL: 0, + VulnerabilitySeverity.HIGH: 1, + VulnerabilitySeverity.MEDIUM: 2, + VulnerabilitySeverity.LOW: 3, + VulnerabilitySeverity.INFO: 4, + } + + query = query.order_by(AgentFinding.severity, AgentFinding.created_at.desc()) + query = query.offset(skip).limit(limit) + + result = await db.execute(query) + findings = result.scalars().all() + + return findings + + +@router.get("/{task_id}/summary", response_model=TaskSummaryResponse) +async def get_task_summary( + task_id: str, + db: AsyncSession = Depends(get_db), + current_user: User = Depends(deps.get_current_user), +) -> Any: + """ + 
获取任务摘要 + """ + task = await db.get(AgentTask, task_id) + if not task: + raise HTTPException(status_code=404, detail="任务不存在") + + project = await db.get(Project, task.project_id) + if not project or project.owner_id != current_user.id: + raise HTTPException(status_code=403, detail="无权访问此任务") + + # 获取所有发现 + result = await db.execute( + select(AgentFinding).where(AgentFinding.task_id == task_id) + ) + findings = result.scalars().all() + + # 统计 + severity_distribution = {} + vulnerability_types = {} + verified_count = 0 + + for f in findings: + sev = f.severity.value if hasattr(f.severity, 'value') else str(f.severity) + vtype = f.vulnerability_type.value if hasattr(f.vulnerability_type, 'value') else str(f.vulnerability_type) + + severity_distribution[sev] = severity_distribution.get(sev, 0) + 1 + vulnerability_types[vtype] = vulnerability_types.get(vtype, 0) + 1 + + if f.is_verified: + verified_count += 1 + + # 计算持续时间 + duration = None + if task.started_at and task.finished_at: + duration = int((task.finished_at - task.started_at).total_seconds()) + + # 获取已完成的阶段 + phases_result = await db.execute( + select(AgentEvent.phase) + .where(AgentEvent.task_id == task_id) + .where(AgentEvent.event_type == AgentEventType.PHASE_COMPLETE) + .distinct() + ) + phases = [p[0].value if p[0] and hasattr(p[0], 'value') else str(p[0]) for p in phases_result.fetchall() if p[0]] + + return TaskSummaryResponse( + task_id=task_id, + status=task.status.value if hasattr(task.status, 'value') else str(task.status), + security_score=task.security_score, + total_findings=len(findings), + verified_findings=verified_count, + severity_distribution=severity_distribution, + vulnerability_types=vulnerability_types, + duration_seconds=duration, + phases_completed=phases, + ) + + +@router.patch("/{task_id}/findings/{finding_id}/status") +async def update_finding_status( + task_id: str, + finding_id: str, + status: str, + db: AsyncSession = Depends(get_db), + current_user: User = 
Depends(deps.get_current_user), +) -> Any: + """ + 更新发现状态 + """ + task = await db.get(AgentTask, task_id) + if not task: + raise HTTPException(status_code=404, detail="任务不存在") + + project = await db.get(Project, task.project_id) + if not project or project.owner_id != current_user.id: + raise HTTPException(status_code=403, detail="无权操作") + + finding = await db.get(AgentFinding, finding_id) + if not finding or finding.task_id != task_id: + raise HTTPException(status_code=404, detail="发现不存在") + + try: + finding.status = FindingStatus(status) + except ValueError: + raise HTTPException(status_code=400, detail=f"无效的状态: {status}") + + await db.commit() + + return {"message": "状态已更新", "finding_id": finding_id, "status": status} + + +# ============ Helper Functions ============ + +def _get_project_root(project: Project, task_id: str) -> str: + """ + 获取项目根目录 + + TODO: 实际实现中需要: + - 对于 ZIP 项目:解压到临时目录 + - 对于 Git 仓库:克隆到临时目录 + """ + base_path = f"/tmp/deepaudit/{task_id}" + + # 确保目录存在 + os.makedirs(base_path, exist_ok=True) + + # 如果项目有存储路径,复制过来 + if hasattr(project, 'storage_path') and project.storage_path: + if os.path.exists(project.storage_path): + # 复制项目文件 + shutil.copytree(project.storage_path, base_path, dirs_exist_ok=True) + + return base_path + diff --git a/backend/app/api/v1/endpoints/embedding_config.py b/backend/app/api/v1/endpoints/embedding_config.py new file mode 100644 index 0000000..bc91c51 --- /dev/null +++ b/backend/app/api/v1/endpoints/embedding_config.py @@ -0,0 +1,396 @@ +""" +嵌入模型配置 API +独立于 LLM 配置,专门用于 RAG 系统的嵌入模型 +使用 UserConfig.other_config 持久化存储 +""" + +import json +import uuid +from typing import Any, Optional, List +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api import deps +from app.models.user import User +from app.models.user_config import UserConfig +from app.core.config import settings + +router = 
APIRouter() + + +# ============ Schemas ============ + +class EmbeddingProvider(BaseModel): + """嵌入模型提供商""" + id: str + name: str + description: str + models: List[str] + requires_api_key: bool + default_model: str + + +class EmbeddingConfig(BaseModel): + """嵌入模型配置""" + provider: str = Field(description="提供商: openai, ollama, azure, cohere, huggingface") + model: str = Field(description="模型名称") + api_key: Optional[str] = Field(default=None, description="API Key (如需要)") + base_url: Optional[str] = Field(default=None, description="自定义 API 端点") + dimensions: Optional[int] = Field(default=None, description="向量维度 (某些模型支持)") + batch_size: int = Field(default=100, description="批处理大小") + + +class EmbeddingConfigResponse(BaseModel): + """配置响应""" + provider: str + model: str + base_url: Optional[str] + dimensions: int + batch_size: int + # 不返回 API Key + + +class TestEmbeddingRequest(BaseModel): + """测试嵌入请求""" + provider: str + model: str + api_key: Optional[str] = None + base_url: Optional[str] = None + test_text: str = "这是一段测试文本,用于验证嵌入模型是否正常工作。" + + +class TestEmbeddingResponse(BaseModel): + """测试嵌入响应""" + success: bool + message: str + dimensions: Optional[int] = None + sample_embedding: Optional[List[float]] = None # 前 5 个维度 + latency_ms: Optional[int] = None + + +# ============ 提供商配置 ============ + +EMBEDDING_PROVIDERS: List[EmbeddingProvider] = [ + EmbeddingProvider( + id="openai", + name="OpenAI", + description="OpenAI 官方嵌入模型,高质量、稳定", + models=[ + "text-embedding-3-small", + "text-embedding-3-large", + "text-embedding-ada-002", + ], + requires_api_key=True, + default_model="text-embedding-3-small", + ), + EmbeddingProvider( + id="azure", + name="Azure OpenAI", + description="Azure 托管的 OpenAI 嵌入模型", + models=[ + "text-embedding-3-small", + "text-embedding-3-large", + "text-embedding-ada-002", + ], + requires_api_key=True, + default_model="text-embedding-3-small", + ), + EmbeddingProvider( + id="ollama", + name="Ollama (本地)", + description="本地运行的开源嵌入模型 (使用 /api/embed 
端点)", + models=[ + "nomic-embed-text", + "mxbai-embed-large", + "all-minilm", + "snowflake-arctic-embed", + "bge-m3", + "qwen3-embedding", + ], + requires_api_key=False, + default_model="nomic-embed-text", + ), + EmbeddingProvider( + id="cohere", + name="Cohere", + description="Cohere Embed v2 API (api.cohere.com/v2)", + models=[ + "embed-english-v3.0", + "embed-multilingual-v3.0", + "embed-english-light-v3.0", + "embed-multilingual-light-v3.0", + "embed-v4.0", + ], + requires_api_key=True, + default_model="embed-multilingual-v3.0", + ), + EmbeddingProvider( + id="huggingface", + name="HuggingFace", + description="HuggingFace Inference Providers (router.huggingface.co)", + models=[ + "sentence-transformers/all-MiniLM-L6-v2", + "sentence-transformers/all-mpnet-base-v2", + "BAAI/bge-large-zh-v1.5", + "BAAI/bge-m3", + ], + requires_api_key=True, + default_model="BAAI/bge-m3", + ), + EmbeddingProvider( + id="jina", + name="Jina AI", + description="Jina AI 嵌入模型,代码嵌入效果好", + models=[ + "jina-embeddings-v2-base-code", + "jina-embeddings-v2-base-en", + "jina-embeddings-v2-base-zh", + ], + requires_api_key=True, + default_model="jina-embeddings-v2-base-code", + ), +] + + +# ============ 数据库持久化存储 (异步) ============ + +EMBEDDING_CONFIG_KEY = "embedding_config" + + +async def get_embedding_config_from_db(db: AsyncSession, user_id: str) -> EmbeddingConfig: + """从数据库获取嵌入配置(异步)""" + result = await db.execute( + select(UserConfig).where(UserConfig.user_id == user_id) + ) + user_config = result.scalar_one_or_none() + + if user_config and user_config.other_config: + try: + other_config = json.loads(user_config.other_config) if isinstance(user_config.other_config, str) else user_config.other_config + embedding_data = other_config.get(EMBEDDING_CONFIG_KEY) + + if embedding_data: + return EmbeddingConfig( + provider=embedding_data.get("provider", settings.EMBEDDING_PROVIDER), + model=embedding_data.get("model", settings.EMBEDDING_MODEL), + api_key=embedding_data.get("api_key"), + 
base_url=embedding_data.get("base_url"), + dimensions=embedding_data.get("dimensions"), + batch_size=embedding_data.get("batch_size", 100), + ) + except (json.JSONDecodeError, AttributeError): + pass + + # 返回默认配置 + return EmbeddingConfig( + provider=settings.EMBEDDING_PROVIDER, + model=settings.EMBEDDING_MODEL, + api_key=settings.LLM_API_KEY, + base_url=settings.LLM_BASE_URL, + batch_size=100, + ) + + +async def save_embedding_config_to_db(db: AsyncSession, user_id: str, config: EmbeddingConfig) -> None: + """保存嵌入配置到数据库(异步)""" + result = await db.execute( + select(UserConfig).where(UserConfig.user_id == user_id) + ) + user_config = result.scalar_one_or_none() + + # 准备嵌入配置数据 + embedding_data = { + "provider": config.provider, + "model": config.model, + "api_key": config.api_key, + "base_url": config.base_url, + "dimensions": config.dimensions, + "batch_size": config.batch_size, + } + + if user_config: + # 更新现有配置 + try: + other_config = json.loads(user_config.other_config) if user_config.other_config else {} + except (json.JSONDecodeError, TypeError): + other_config = {} + + other_config[EMBEDDING_CONFIG_KEY] = embedding_data + user_config.other_config = json.dumps(other_config) + else: + # 创建新配置 + user_config = UserConfig( + id=str(uuid.uuid4()), + user_id=user_id, + llm_config="{}", + other_config=json.dumps({EMBEDDING_CONFIG_KEY: embedding_data}), + ) + db.add(user_config) + + await db.commit() + + +# ============ API Endpoints ============ + +@router.get("/providers", response_model=List[EmbeddingProvider]) +async def list_embedding_providers( + current_user: User = Depends(deps.get_current_user), +) -> Any: + """ + 获取可用的嵌入模型提供商列表 + """ + return EMBEDDING_PROVIDERS + + +@router.get("/config", response_model=EmbeddingConfigResponse) +async def get_current_config( + db: AsyncSession = Depends(deps.get_db), + current_user: User = Depends(deps.get_current_user), +) -> Any: + """ + 获取当前嵌入模型配置(从数据库读取) + """ + config = await get_embedding_config_from_db(db, 
current_user.id) + + # 获取维度 + dimensions = _get_model_dimensions(config.provider, config.model) + + return EmbeddingConfigResponse( + provider=config.provider, + model=config.model, + base_url=config.base_url, + dimensions=dimensions, + batch_size=config.batch_size, + ) + + +@router.put("/config") +async def update_config( + config: EmbeddingConfig, + db: AsyncSession = Depends(deps.get_db), + current_user: User = Depends(deps.get_current_user), +) -> Any: + """ + 更新嵌入模型配置(持久化到数据库) + """ + # 验证提供商 + provider_ids = [p.id for p in EMBEDDING_PROVIDERS] + if config.provider not in provider_ids: + raise HTTPException(status_code=400, detail=f"不支持的提供商: {config.provider}") + + # 验证模型 + provider = next((p for p in EMBEDDING_PROVIDERS if p.id == config.provider), None) + if provider and config.model not in provider.models: + raise HTTPException(status_code=400, detail=f"不支持的模型: {config.model}") + + # 检查 API Key + if provider and provider.requires_api_key and not config.api_key: + raise HTTPException(status_code=400, detail=f"{config.provider} 需要 API Key") + + # 保存到数据库 + await save_embedding_config_to_db(db, current_user.id, config) + + return {"message": "配置已保存", "provider": config.provider, "model": config.model} + + +@router.post("/test", response_model=TestEmbeddingResponse) +async def test_embedding( + request: TestEmbeddingRequest, + current_user: User = Depends(deps.get_current_user), +) -> Any: + """ + 测试嵌入模型配置 + """ + import time + + try: + start_time = time.time() + + # 创建临时嵌入服务 + from app.services.rag.embeddings import EmbeddingService + + service = EmbeddingService( + provider=request.provider, + model=request.model, + api_key=request.api_key, + base_url=request.base_url, + cache_enabled=False, + ) + + # 执行嵌入 + embedding = await service.embed(request.test_text) + + latency_ms = int((time.time() - start_time) * 1000) + + return TestEmbeddingResponse( + success=True, + message=f"嵌入成功! 
维度: {len(embedding)}", + dimensions=len(embedding), + sample_embedding=embedding[:5], # 返回前 5 维 + latency_ms=latency_ms, + ) + + except Exception as e: + return TestEmbeddingResponse( + success=False, + message=f"嵌入失败: {str(e)}", + ) + + +@router.get("/models/{provider}") +async def get_provider_models( + provider: str, + current_user: User = Depends(deps.get_current_user), +) -> Any: + """ + 获取指定提供商的模型列表 + """ + provider_info = next((p for p in EMBEDDING_PROVIDERS if p.id == provider), None) + + if not provider_info: + raise HTTPException(status_code=404, detail=f"提供商不存在: {provider}") + + return { + "provider": provider, + "models": provider_info.models, + "default_model": provider_info.default_model, + "requires_api_key": provider_info.requires_api_key, + } + + +def _get_model_dimensions(provider: str, model: str) -> int: + """获取模型维度""" + dimensions_map = { + # OpenAI + "text-embedding-3-small": 1536, + "text-embedding-3-large": 3072, + "text-embedding-ada-002": 1536, + + # Ollama + "nomic-embed-text": 768, + "mxbai-embed-large": 1024, + "all-minilm": 384, + "snowflake-arctic-embed": 1024, + + # Cohere + "embed-english-v3.0": 1024, + "embed-multilingual-v3.0": 1024, + "embed-english-light-v3.0": 384, + "embed-multilingual-light-v3.0": 384, + + # HuggingFace + "sentence-transformers/all-MiniLM-L6-v2": 384, + "sentence-transformers/all-mpnet-base-v2": 768, + "BAAI/bge-large-zh-v1.5": 1024, + "BAAI/bge-m3": 1024, + + # Jina + "jina-embeddings-v2-base-code": 768, + "jina-embeddings-v2-base-en": 768, + "jina-embeddings-v2-base-zh": 768, + } + + return dimensions_map.get(model, 768) + diff --git a/backend/app/api/v1/endpoints/members.py b/backend/app/api/v1/endpoints/members.py index 81a14e5..612ccbe 100644 --- a/backend/app/api/v1/endpoints/members.py +++ b/backend/app/api/v1/endpoints/members.py @@ -209,3 +209,4 @@ async def remove_project_member( return {"message": "成员已移除"} + diff --git a/backend/app/api/v1/endpoints/users.py b/backend/app/api/v1/endpoints/users.py 
index 32e1d48..951f603 100644 --- a/backend/app/api/v1/endpoints/users.py +++ b/backend/app/api/v1/endpoints/users.py @@ -224,3 +224,4 @@ async def toggle_user_status( return user + diff --git a/backend/app/core/config.py b/backend/app/core/config.py index dec1be9..4cd7330 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -76,6 +76,32 @@ class Settings(BaseSettings): # 输出语言配置 - 支持 zh-CN(中文)和 en-US(英文) OUTPUT_LANGUAGE: str = "zh-CN" + + # ============ Agent 模块配置 ============ + + # 嵌入模型配置 + EMBEDDING_PROVIDER: str = "openai" # openai, ollama, litellm + EMBEDDING_MODEL: str = "text-embedding-3-small" + + # 向量数据库配置 + VECTOR_DB_PATH: str = "./data/vector_db" # 向量数据库持久化目录 + + # Agent 配置 + AGENT_MAX_ITERATIONS: int = 50 # Agent 最大迭代次数 + AGENT_TOKEN_BUDGET: int = 100000 # Agent Token 预算 + AGENT_TIMEOUT_SECONDS: int = 1800 # Agent 超时时间(30分钟) + + # 沙箱配置 + SANDBOX_IMAGE: str = "deepaudit-sandbox:latest" # 沙箱 Docker 镜像 + SANDBOX_MEMORY_LIMIT: str = "512m" # 沙箱内存限制 + SANDBOX_CPU_LIMIT: float = 1.0 # 沙箱 CPU 限制 + SANDBOX_TIMEOUT: int = 60 # 沙箱命令超时(秒) + SANDBOX_NETWORK_MODE: str = "none" # 沙箱网络模式 (none, bridge) + + # RAG 配置 + RAG_CHUNK_SIZE: int = 1500 # 代码块大小(Token) + RAG_CHUNK_OVERLAP: int = 50 # 代码块重叠(Token) + RAG_TOP_K: int = 10 # 检索返回数量 class Config: case_sensitive = True diff --git a/backend/app/core/security.py b/backend/app/core/security.py index a71f44a..d0e5295 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -28,3 +28,4 @@ def get_password_hash(password: str) -> str: return pwd_context.hash(password) + diff --git a/backend/app/db/base.py b/backend/app/db/base.py index 88b8670..98a9a1b 100644 --- a/backend/app/db/base.py +++ b/backend/app/db/base.py @@ -11,3 +11,4 @@ class Base: return cls.__name__.lower() + "s" + diff --git a/backend/app/db/session.py b/backend/app/db/session.py index 286df94..8b4db67 100644 --- a/backend/app/db/session.py +++ b/backend/app/db/session.py @@ -1,3 +1,4 @@ +from contextlib import 
asynccontextmanager from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker from app.core.config import settings @@ -16,3 +17,14 @@ async def get_db(): await session.close() +@asynccontextmanager +async def async_session_factory(): + """Async context manager for creating database sessions""" + async with AsyncSessionLocal() as session: + try: + yield session + finally: + await session.close() + + + diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 9771d38..06527e3 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -1,8 +1,15 @@ from .user import User +from .user_config import UserConfig from .project import Project, ProjectMember from .audit import AuditTask, AuditIssue from .analysis import InstantAnalysis from .prompt_template import PromptTemplate from .audit_rule import AuditRuleSet, AuditRule +from .agent_task import ( + AgentTask, AgentEvent, AgentFinding, + AgentTaskStatus, AgentTaskPhase, AgentEventType, + VulnerabilitySeverity, VulnerabilityType, FindingStatus +) + diff --git a/backend/app/models/agent_task.py b/backend/app/models/agent_task.py new file mode 100644 index 0000000..6a912e9 --- /dev/null +++ b/backend/app/models/agent_task.py @@ -0,0 +1,443 @@ +""" +Agent 审计任务模型 +支持 AI Agent 自主漏洞挖掘和验证 +""" + +import uuid +from datetime import datetime +from typing import Optional, List, TYPE_CHECKING +from sqlalchemy import ( + Column, String, Integer, Float, Text, Boolean, + DateTime, ForeignKey, Enum as SQLEnum, JSON +) +from sqlalchemy.orm import relationship +from sqlalchemy.sql import func + +from app.db.base import Base + +if TYPE_CHECKING: + from .project import Project + + +class AgentTaskStatus: + """Agent 任务状态""" + PENDING = "pending" # 等待执行 + INITIALIZING = "initializing" # 初始化中 + PLANNING = "planning" # 规划阶段 + INDEXING = "indexing" # 索引阶段 + ANALYZING = "analyzing" # 分析阶段 + VERIFYING = "verifying" # 验证阶段 + REPORTING = "reporting" # 
报告生成 + COMPLETED = "completed" # 已完成 + FAILED = "failed" # 失败 + CANCELLED = "cancelled" # 已取消 + PAUSED = "paused" # 已暂停 + + +class AgentTaskPhase: + """Agent 执行阶段""" + PLANNING = "planning" + INDEXING = "indexing" + RECONNAISSANCE = "reconnaissance" + ANALYSIS = "analysis" + VERIFICATION = "verification" + REPORTING = "reporting" + + +class AgentTask(Base): + """Agent 审计任务""" + __tablename__ = "agent_tasks" + + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + project_id = Column(String(36), ForeignKey("projects.id", ondelete="CASCADE"), nullable=False) + + # 任务基本信息 + name = Column(String(255), nullable=True) + description = Column(Text, nullable=True) + task_type = Column(String(50), default="agent_audit") + + # 任务配置 + audit_scope = Column(JSON, nullable=True) # 审计范围配置 + target_vulnerabilities = Column(JSON, nullable=True) # 目标漏洞类型 + verification_level = Column(String(50), default="sandbox") # analysis_only, sandbox, generate_poc + + # 分支信息(仓库项目) + branch_name = Column(String(255), nullable=True) + + # 排除模式 + exclude_patterns = Column(JSON, nullable=True) + + # 文件范围 + target_files = Column(JSON, nullable=True) # 指定扫描的文件列表 + + # LLM 配置 + llm_config = Column(JSON, nullable=True) # LLM 配置 + + # Agent 配置 + agent_config = Column(JSON, nullable=True) # Agent 特定配置 + max_iterations = Column(Integer, default=50) # 最大迭代次数 + token_budget = Column(Integer, default=100000) # Token 预算 + timeout_seconds = Column(Integer, default=1800) # 超时时间(秒) + + # 状态 + status = Column(String(20), default=AgentTaskStatus.PENDING) + current_phase = Column(String(50), nullable=True) + current_step = Column(String(255), nullable=True) # 当前执行步骤描述 + error_message = Column(Text, nullable=True) + + # 进度统计 + total_files = Column(Integer, default=0) + indexed_files = Column(Integer, default=0) + analyzed_files = Column(Integer, default=0) + total_chunks = Column(Integer, default=0) # 代码块总数 + + # Agent 统计 + total_iterations = Column(Integer, default=0) # Agent 迭代次数 + 
tool_calls_count = Column(Integer, default=0) # 工具调用次数 + tokens_used = Column(Integer, default=0) # 已使用 Token 数 + + # 发现统计 + findings_count = Column(Integer, default=0) # 发现总数 + verified_count = Column(Integer, default=0) # 已验证数 + false_positive_count = Column(Integer, default=0) # 误报数 + + # 严重程度统计 + critical_count = Column(Integer, default=0) + high_count = Column(Integer, default=0) + medium_count = Column(Integer, default=0) + low_count = Column(Integer, default=0) + + # 质量评分 + quality_score = Column(Float, default=0.0) + security_score = Column(Float, default=0.0) + + # 审计计划 + audit_plan = Column(JSON, nullable=True) # Agent 生成的审计计划 + + # 时间戳 + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + started_at = Column(DateTime(timezone=True), nullable=True) + completed_at = Column(DateTime(timezone=True), nullable=True) + + # 创建者 + created_by = Column(String(36), ForeignKey("users.id"), nullable=False) + + # 关联关系 + project = relationship("Project", back_populates="agent_tasks") + events = relationship("AgentEvent", back_populates="task", cascade="all, delete-orphan", order_by="AgentEvent.created_at") + findings = relationship("AgentFinding", back_populates="task", cascade="all, delete-orphan") + + def __repr__(self): + return f"" + + @property + def progress_percentage(self) -> float: + """计算进度百分比""" + if self.status == AgentTaskStatus.COMPLETED: + return 100.0 + if self.status in [AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]: + return 0.0 + + phase_weights = { + AgentTaskPhase.PLANNING: 5, + AgentTaskPhase.INDEXING: 15, + AgentTaskPhase.RECONNAISSANCE: 10, + AgentTaskPhase.ANALYSIS: 50, + AgentTaskPhase.VERIFICATION: 15, + AgentTaskPhase.REPORTING: 5, + } + + completed_weight = 0 + current_found = False + + for phase, weight in phase_weights.items(): + if phase == self.current_phase: + current_found = True + # 估算当前阶段进度 + if phase == AgentTaskPhase.INDEXING and 
self.total_files > 0: + completed_weight += weight * (self.indexed_files / self.total_files) + elif phase == AgentTaskPhase.ANALYSIS and self.total_files > 0: + completed_weight += weight * (self.analyzed_files / self.total_files) + else: + completed_weight += weight * 0.5 + break + elif not current_found: + completed_weight += weight + + return min(completed_weight, 99.0) + + +class AgentEventType: + """Agent 事件类型""" + # 系统事件 + TASK_START = "task_start" + TASK_COMPLETE = "task_complete" + TASK_ERROR = "task_error" + TASK_CANCEL = "task_cancel" + + # 阶段事件 + PHASE_START = "phase_start" + PHASE_COMPLETE = "phase_complete" + + # Agent 思考 + THINKING = "thinking" + PLANNING = "planning" + DECISION = "decision" + + # 工具调用 + TOOL_CALL = "tool_call" + TOOL_RESULT = "tool_result" + TOOL_ERROR = "tool_error" + + # RAG 相关 + RAG_QUERY = "rag_query" + RAG_RESULT = "rag_result" + + # 发现相关 + FINDING_NEW = "finding_new" + FINDING_UPDATE = "finding_update" + FINDING_VERIFIED = "finding_verified" + FINDING_FALSE_POSITIVE = "finding_false_positive" + + # 沙箱相关 + SANDBOX_START = "sandbox_start" + SANDBOX_EXEC = "sandbox_exec" + SANDBOX_RESULT = "sandbox_result" + SANDBOX_ERROR = "sandbox_error" + + # 进度 + PROGRESS = "progress" + + # 日志 + INFO = "info" + WARNING = "warning" + ERROR = "error" + DEBUG = "debug" + + +class AgentEvent(Base): + """Agent 执行事件(用于实时日志和回放)""" + __tablename__ = "agent_events" + + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + task_id = Column(String(36), ForeignKey("agent_tasks.id", ondelete="CASCADE"), nullable=False, index=True) + + # 事件信息 + event_type = Column(String(50), nullable=False, index=True) + phase = Column(String(50), nullable=True) + + # 事件内容 + message = Column(Text, nullable=True) + + # 工具调用相关 + tool_name = Column(String(100), nullable=True) + tool_input = Column(JSON, nullable=True) + tool_output = Column(JSON, nullable=True) + tool_duration_ms = Column(Integer, nullable=True) # 工具执行时长(毫秒) + + # 关联的发现 + finding_id = 
Column(String(36), nullable=True) + + # Token 消耗 + tokens_used = Column(Integer, default=0) + + # 元数据 + event_metadata = Column(JSON, nullable=True) + + # 序号(用于排序) + sequence = Column(Integer, default=0, index=True) + + # 时间戳 + created_at = Column(DateTime(timezone=True), server_default=func.now(), index=True) + + # 关联关系 + task = relationship("AgentTask", back_populates="events") + + def __repr__(self): + return f"" + + def to_sse_dict(self) -> dict: + """转换为 SSE 事件格式""" + return { + "id": self.id, + "type": self.event_type, + "phase": self.phase, + "message": self.message, + "tool_name": self.tool_name, + "tool_input": self.tool_input, + "tool_output": self.tool_output, + "tool_duration_ms": self.tool_duration_ms, + "finding_id": self.finding_id, + "tokens_used": self.tokens_used, + "metadata": self.event_metadata, + "sequence": self.sequence, + "timestamp": self.created_at.isoformat() if self.created_at else None, + } + + +class VulnerabilitySeverity: + """漏洞严重程度""" + CRITICAL = "critical" + HIGH = "high" + MEDIUM = "medium" + LOW = "low" + INFO = "info" + + +class VulnerabilityType: + """漏洞类型""" + SQL_INJECTION = "sql_injection" + NOSQL_INJECTION = "nosql_injection" + XSS = "xss" + COMMAND_INJECTION = "command_injection" + CODE_INJECTION = "code_injection" + PATH_TRAVERSAL = "path_traversal" + FILE_INCLUSION = "file_inclusion" + SSRF = "ssrf" + XXE = "xxe" + DESERIALIZATION = "deserialization" + AUTH_BYPASS = "auth_bypass" + IDOR = "idor" + SENSITIVE_DATA_EXPOSURE = "sensitive_data_exposure" + HARDCODED_SECRET = "hardcoded_secret" + WEAK_CRYPTO = "weak_crypto" + RACE_CONDITION = "race_condition" + BUSINESS_LOGIC = "business_logic" + MEMORY_CORRUPTION = "memory_corruption" + OTHER = "other" + + +class FindingStatus: + """发现状态""" + NEW = "new" # 新发现 + ANALYZING = "analyzing" # 分析中 + VERIFIED = "verified" # 已验证 + FALSE_POSITIVE = "false_positive" # 误报 + NEEDS_REVIEW = "needs_review" # 需要人工审核 + FIXED = "fixed" # 已修复 + WONT_FIX = "wont_fix" # 不修复 + DUPLICATE = 
"duplicate" # 重复 + + +class AgentFinding(Base): + """Agent 发现的漏洞""" + __tablename__ = "agent_findings" + + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + task_id = Column(String(36), ForeignKey("agent_tasks.id", ondelete="CASCADE"), nullable=False, index=True) + + # 漏洞基本信息 + vulnerability_type = Column(String(100), nullable=False, index=True) + severity = Column(String(20), nullable=False, index=True) + title = Column(String(500), nullable=False) + description = Column(Text, nullable=True) + + # 位置信息 + file_path = Column(String(500), nullable=True, index=True) + line_start = Column(Integer, nullable=True) + line_end = Column(Integer, nullable=True) + column_start = Column(Integer, nullable=True) + column_end = Column(Integer, nullable=True) + function_name = Column(String(255), nullable=True) + class_name = Column(String(255), nullable=True) + + # 代码片段 + code_snippet = Column(Text, nullable=True) + code_context = Column(Text, nullable=True) # 更多上下文 + + # 数据流信息 + source = Column(Text, nullable=True) # 污点源 + sink = Column(Text, nullable=True) # 危险函数 + dataflow_path = Column(JSON, nullable=True) # 数据流路径 + + # 验证信息 + status = Column(String(30), default=FindingStatus.NEW, index=True) + is_verified = Column(Boolean, default=False) + verification_method = Column(Text, nullable=True) + verification_result = Column(JSON, nullable=True) + verified_at = Column(DateTime(timezone=True), nullable=True) + + # PoC + has_poc = Column(Boolean, default=False) + poc_code = Column(Text, nullable=True) + poc_description = Column(Text, nullable=True) + poc_steps = Column(JSON, nullable=True) # 复现步骤 + + # 修复建议 + suggestion = Column(Text, nullable=True) + fix_code = Column(Text, nullable=True) + fix_description = Column(Text, nullable=True) + references = Column(JSON, nullable=True) # 参考链接 CWE, OWASP 等 + + # AI 解释 + ai_explanation = Column(Text, nullable=True) + ai_confidence = Column(Float, nullable=True) # AI 置信度 0-1 + + # XAI (可解释AI) + xai_what = 
Column(Text, nullable=True) + xai_why = Column(Text, nullable=True) + xai_how = Column(Text, nullable=True) + xai_impact = Column(Text, nullable=True) + + # 关联规则 + matched_rule_code = Column(String(100), nullable=True) + matched_pattern = Column(Text, nullable=True) + + # CVSS 评分(可选) + cvss_score = Column(Float, nullable=True) + cvss_vector = Column(String(100), nullable=True) + + # 元数据 + finding_metadata = Column(JSON, nullable=True) + tags = Column(JSON, nullable=True) + + # 去重标识 + fingerprint = Column(String(64), nullable=True, index=True) # 用于去重的指纹 + + # 时间戳 + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + + # 关联关系 + task = relationship("AgentTask", back_populates="findings") + + def __repr__(self): + return f"" + + def generate_fingerprint(self) -> str: + """生成去重指纹""" + import hashlib + components = [ + self.vulnerability_type or "", + self.file_path or "", + str(self.line_start or 0), + self.function_name or "", + (self.code_snippet or "")[:200], + ] + content = "|".join(components) + return hashlib.sha256(content.encode()).hexdigest()[:16] + + def to_dict(self) -> dict: + """转换为字典""" + return { + "id": self.id, + "task_id": self.task_id, + "vulnerability_type": self.vulnerability_type, + "severity": self.severity, + "title": self.title, + "description": self.description, + "file_path": self.file_path, + "line_start": self.line_start, + "line_end": self.line_end, + "code_snippet": self.code_snippet, + "status": self.status, + "is_verified": self.is_verified, + "has_poc": self.has_poc, + "poc_code": self.poc_code, + "suggestion": self.suggestion, + "fix_code": self.fix_code, + "ai_explanation": self.ai_explanation, + "ai_confidence": self.ai_confidence, + "created_at": self.created_at.isoformat() if self.created_at else None, + } diff --git a/backend/app/models/analysis.py b/backend/app/models/analysis.py index 863e9fa..c55d47a 100644 --- 
a/backend/app/models/analysis.py +++ b/backend/app/models/analysis.py @@ -23,3 +23,4 @@ class InstantAnalysis(Base): user = relationship("User", backref="instant_analyses") + diff --git a/backend/app/models/project.py b/backend/app/models/project.py index a507c4c..ac89beb 100644 --- a/backend/app/models/project.py +++ b/backend/app/models/project.py @@ -31,6 +31,7 @@ class Project(Base): owner = relationship("User", backref="projects") members = relationship("ProjectMember", back_populates="project", cascade="all, delete-orphan") tasks = relationship("AuditTask", back_populates="project", cascade="all, delete-orphan") + agent_tasks = relationship("AgentTask", back_populates="project", cascade="all, delete-orphan") class ProjectMember(Base): __tablename__ = "project_members" @@ -49,3 +50,4 @@ class ProjectMember(Base): user = relationship("User", backref="project_memberships") + diff --git a/backend/app/models/user.py b/backend/app/models/user.py index 1d1cf1f..401f8d3 100644 --- a/backend/app/models/user.py +++ b/backend/app/models/user.py @@ -24,3 +24,4 @@ class User(Base): updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + diff --git a/backend/app/models/user_config.py b/backend/app/models/user_config.py index 2f491e4..9a30a2f 100644 --- a/backend/app/models/user_config.py +++ b/backend/app/models/user_config.py @@ -29,3 +29,4 @@ class UserConfig(Base): user = relationship("User", backref="config") + diff --git a/backend/app/schemas/token.py b/backend/app/schemas/token.py index 05c722c..d3908cd 100644 --- a/backend/app/schemas/token.py +++ b/backend/app/schemas/token.py @@ -9,3 +9,4 @@ class TokenPayload(BaseModel): sub: Optional[str] = None + diff --git a/backend/app/schemas/user.py b/backend/app/schemas/user.py index 584c5c4..ffce190 100644 --- a/backend/app/schemas/user.py +++ b/backend/app/schemas/user.py @@ -40,3 +40,4 @@ class UserListResponse(BaseModel): limit: int + diff --git a/backend/app/services/agent/__init__.py 
b/backend/app/services/agent/__init__.py new file mode 100644 index 0000000..a82b83b --- /dev/null +++ b/backend/app/services/agent/__init__.py @@ -0,0 +1,58 @@ +""" +DeepAudit Agent 服务模块 +基于 LangGraph 的 AI Agent 代码安全审计 + +架构: + LangGraph 状态图工作流 + + START → Recon → Analysis ⟲ → Verification → Report → END + +节点: + - Recon: 信息收集 (项目结构、技术栈、入口点) + - Analysis: 漏洞分析 (静态分析、RAG、模式匹配) + - Verification: 漏洞验证 (LLM 验证、沙箱测试) + - Report: 报告生成 +""" + +# 从 graph 模块导入主要组件 +from .graph import ( + AgentRunner, + run_agent_task, + LLMService, + AuditState, + create_audit_graph, +) + +# 事件管理 +from .event_manager import EventManager, AgentEventEmitter + +# Agent 类 +from .agents import ( + BaseAgent, AgentConfig, AgentResult, + OrchestratorAgent, ReconAgent, AnalysisAgent, VerificationAgent, +) + +__all__ = [ + # 核心 Runner + "AgentRunner", + "run_agent_task", + "LLMService", + + # LangGraph + "AuditState", + "create_audit_graph", + + # 事件管理 + "EventManager", + "AgentEventEmitter", + + # Agent 类 + "BaseAgent", + "AgentConfig", + "AgentResult", + "OrchestratorAgent", + "ReconAgent", + "AnalysisAgent", + "VerificationAgent", +] + diff --git a/backend/app/services/agent/agents/__init__.py b/backend/app/services/agent/agents/__init__.py new file mode 100644 index 0000000..3009b64 --- /dev/null +++ b/backend/app/services/agent/agents/__init__.py @@ -0,0 +1,21 @@ +""" +混合 Agent 架构 +包含 Orchestrator、Recon、Analysis 和 Verification Agent +""" + +from .base import BaseAgent, AgentConfig, AgentResult +from .orchestrator import OrchestratorAgent +from .recon import ReconAgent +from .analysis import AnalysisAgent +from .verification import VerificationAgent + +__all__ = [ + "BaseAgent", + "AgentConfig", + "AgentResult", + "OrchestratorAgent", + "ReconAgent", + "AnalysisAgent", + "VerificationAgent", +] + diff --git a/backend/app/services/agent/agents/analysis.py b/backend/app/services/agent/agents/analysis.py new file mode 100644 index 0000000..ecdf8e5 --- /dev/null +++ 
b/backend/app/services/agent/agents/analysis.py @@ -0,0 +1,469 @@ +""" +Analysis Agent (漏洞分析层) +负责代码审计、RAG 查询、模式匹配、数据流分析 + +类型: ReAct +""" + +import asyncio +import logging +from typing import List, Dict, Any, Optional + +from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern + +logger = logging.getLogger(__name__) + + +ANALYSIS_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞分析 Agent,负责深度代码安全分析。 + +## 你的职责 +1. 使用静态分析工具快速扫描 +2. 使用 RAG 进行语义代码搜索 +3. 追踪数据流(从用户输入到危险函数) +4. 分析业务逻辑漏洞 +5. 评估漏洞严重程度 + +## 你可以使用的工具 +### 外部扫描工具 +- semgrep_scan: Semgrep 静态分析(推荐首先使用) +- bandit_scan: Python 安全扫描 + +### RAG 语义搜索 +- rag_query: 语义代码搜索 +- security_search: 安全相关代码搜索 +- function_context: 函数上下文分析 + +### 深度分析 +- pattern_match: 危险模式匹配 +- code_analysis: LLM 深度代码分析 +- dataflow_analysis: 数据流追踪 +- vulnerability_validation: 漏洞验证 + +### 文件操作 +- read_file: 读取文件 +- search_code: 关键字搜索 + +## 分析策略 +1. **快速扫描**: 先用 Semgrep 快速发现问题 +2. **语义搜索**: 用 RAG 找到相关代码 +3. **深度分析**: 对可疑代码进行 LLM 分析 +4. **数据流追踪**: 追踪用户输入的流向 + +## 重点关注 +- SQL 注入、NoSQL 注入 +- XSS(反射型、存储型、DOM型) +- 命令注入、代码注入 +- 路径遍历、任意文件访问 +- SSRF、XXE +- 不安全的反序列化 +- 认证/授权绕过 +- 敏感信息泄露 + +## 输出格式 +发现漏洞时,返回结构化信息: +```json +{ + "findings": [ + { + "vulnerability_type": "漏洞类型", + "severity": "critical/high/medium/low", + "title": "漏洞标题", + "description": "详细描述", + "file_path": "文件路径", + "line_start": 行号, + "code_snippet": "代码片段", + "source": "污点源", + "sink": "危险函数", + "suggestion": "修复建议", + "needs_verification": true/false + } + ] +} +``` + +请系统性地分析代码,发现真实的安全漏洞。""" + + +class AnalysisAgent(BaseAgent): + """ + 漏洞分析 Agent + + 使用 ReAct 模式进行迭代分析 + """ + + def __init__( + self, + llm_service, + tools: Dict[str, Any], + event_emitter=None, + ): + config = AgentConfig( + name="Analysis", + agent_type=AgentType.ANALYSIS, + pattern=AgentPattern.REACT, + max_iterations=30, + system_prompt=ANALYSIS_SYSTEM_PROMPT, + tools=[ + "semgrep_scan", "bandit_scan", + "rag_query", "security_search", "function_context", + "pattern_match", "code_analysis", 
"dataflow_analysis", + "vulnerability_validation", + "read_file", "search_code", + ], + ) + super().__init__(config, llm_service, tools, event_emitter) + + async def run(self, input_data: Dict[str, Any]) -> AgentResult: + """执行漏洞分析""" + import time + start_time = time.time() + + phase_name = input_data.get("phase_name", "analysis") + project_info = input_data.get("project_info", {}) + config = input_data.get("config", {}) + plan = input_data.get("plan", {}) + previous_results = input_data.get("previous_results", {}) + + # 从之前的 Recon 结果获取信息 + recon_data = previous_results.get("recon", {}).get("data", {}) + high_risk_areas = recon_data.get("high_risk_areas", plan.get("high_risk_areas", [])) + tech_stack = recon_data.get("tech_stack", {}) + entry_points = recon_data.get("entry_points", []) + + try: + all_findings = [] + + # 1. 静态分析阶段 + if phase_name in ["static_analysis", "analysis"]: + await self.emit_thinking("执行静态代码分析...") + static_findings = await self._run_static_analysis(tech_stack) + all_findings.extend(static_findings) + + # 2. 
深度分析阶段 + if phase_name in ["deep_analysis", "analysis"]: + await self.emit_thinking("执行深度漏洞分析...") + + # 分析入口点 + deep_findings = await self._analyze_entry_points(entry_points) + all_findings.extend(deep_findings) + + # 分析高风险区域 + risk_findings = await self._analyze_high_risk_areas(high_risk_areas) + all_findings.extend(risk_findings) + + # 语义搜索常见漏洞 + vuln_types = config.get("target_vulnerabilities", [ + "sql_injection", "xss", "command_injection", + "path_traversal", "ssrf", "hardcoded_secret", + ]) + + for vuln_type in vuln_types[:5]: # 限制数量 + if self.is_cancelled: + break + + await self.emit_thinking(f"搜索 {vuln_type} 相关代码...") + vuln_findings = await self._search_vulnerability_pattern(vuln_type) + all_findings.extend(vuln_findings) + + # 去重 + all_findings = self._deduplicate_findings(all_findings) + + duration_ms = int((time.time() - start_time) * 1000) + + await self.emit_event( + "info", + f"分析完成: 发现 {len(all_findings)} 个潜在漏洞" + ) + + return AgentResult( + success=True, + data={"findings": all_findings}, + iterations=self._iteration, + tool_calls=self._tool_calls, + tokens_used=self._total_tokens, + duration_ms=duration_ms, + ) + + except Exception as e: + logger.error(f"Analysis agent failed: {e}", exc_info=True) + return AgentResult(success=False, error=str(e)) + + async def _run_static_analysis(self, tech_stack: Dict) -> List[Dict]: + """运行静态分析工具""" + findings = [] + + # Semgrep 扫描 + semgrep_tool = self.tools.get("semgrep_scan") + if semgrep_tool: + await self.emit_tool_call("semgrep_scan", {"rules": "p/security-audit"}) + + result = await semgrep_tool.execute(rules="p/security-audit", max_results=30) + + if result.success and result.metadata.get("findings_count", 0) > 0: + for finding in result.metadata.get("findings", []): + findings.append({ + "vulnerability_type": self._map_semgrep_rule(finding.get("check_id", "")), + "severity": self._map_semgrep_severity(finding.get("extra", {}).get("severity", "")), + "title": finding.get("check_id", "Semgrep 
Finding"), + "description": finding.get("extra", {}).get("message", ""), + "file_path": finding.get("path", ""), + "line_start": finding.get("start", {}).get("line", 0), + "code_snippet": finding.get("extra", {}).get("lines", ""), + "source": "semgrep", + "needs_verification": True, + }) + + # Bandit 扫描 (Python) + languages = tech_stack.get("languages", []) + if "Python" in languages: + bandit_tool = self.tools.get("bandit_scan") + if bandit_tool: + await self.emit_tool_call("bandit_scan", {}) + result = await bandit_tool.execute() + + if result.success and result.metadata.get("findings_count", 0) > 0: + for finding in result.metadata.get("findings", []): + findings.append({ + "vulnerability_type": self._map_bandit_test(finding.get("test_id", "")), + "severity": finding.get("issue_severity", "medium").lower(), + "title": finding.get("test_name", "Bandit Finding"), + "description": finding.get("issue_text", ""), + "file_path": finding.get("filename", ""), + "line_start": finding.get("line_number", 0), + "code_snippet": finding.get("code", ""), + "source": "bandit", + "needs_verification": True, + }) + + return findings + + async def _analyze_entry_points(self, entry_points: List[Dict]) -> List[Dict]: + """分析入口点""" + findings = [] + + code_analysis_tool = self.tools.get("code_analysis") + read_tool = self.tools.get("read_file") + + if not code_analysis_tool or not read_tool: + return findings + + # 分析前几个入口点 + for ep in entry_points[:10]: + if self.is_cancelled: + break + + file_path = ep.get("file", "") + line = ep.get("line", 1) + + if not file_path: + continue + + # 读取文件内容 + read_result = await read_tool.execute( + file_path=file_path, + start_line=max(1, line - 20), + end_line=line + 50, + ) + + if not read_result.success: + continue + + # 深度分析 + analysis_result = await code_analysis_tool.execute( + code=read_result.data, + file_path=file_path, + ) + + if analysis_result.success and analysis_result.metadata.get("issues"): + for issue in 
analysis_result.metadata["issues"]: + findings.append({ + "vulnerability_type": issue.get("type", "unknown"), + "severity": issue.get("severity", "medium"), + "title": issue.get("title", "Security Issue"), + "description": issue.get("description", ""), + "file_path": file_path, + "line_start": issue.get("line", line), + "code_snippet": issue.get("code_snippet", ""), + "suggestion": issue.get("suggestion", ""), + "source": "code_analysis", + "needs_verification": True, + }) + + return findings + + async def _analyze_high_risk_areas(self, high_risk_areas: List[str]) -> List[Dict]: + """分析高风险区域""" + findings = [] + + pattern_tool = self.tools.get("pattern_match") + read_tool = self.tools.get("read_file") + search_tool = self.tools.get("search_code") + + if not search_tool: + return findings + + # 在高风险区域搜索危险模式 + dangerous_patterns = [ + ("execute(", "sql_injection"), + ("eval(", "code_injection"), + ("system(", "command_injection"), + ("exec(", "command_injection"), + ("innerHTML", "xss"), + ("document.write", "xss"), + ] + + for pattern, vuln_type in dangerous_patterns[:5]: + if self.is_cancelled: + break + + result = await search_tool.execute(keyword=pattern, max_results=10) + + if result.success and result.metadata.get("matches", 0) > 0: + for match in result.metadata.get("results", [])[:3]: + file_path = match.get("file", "") + + # 检查是否在高风险区域 + in_high_risk = any( + area in file_path for area in high_risk_areas + ) + + if in_high_risk or True: # 暂时包含所有 + findings.append({ + "vulnerability_type": vuln_type, + "severity": "high" if in_high_risk else "medium", + "title": f"疑似 {vuln_type}: {pattern}", + "description": f"在 {file_path} 中发现危险模式 {pattern}", + "file_path": file_path, + "line_start": match.get("line", 0), + "code_snippet": match.get("match", ""), + "source": "pattern_search", + "needs_verification": True, + }) + + return findings + + async def _search_vulnerability_pattern(self, vuln_type: str) -> List[Dict]: + """搜索特定漏洞模式""" + findings = [] + + 
security_tool = self.tools.get("security_search") + if not security_tool: + return findings + + result = await security_tool.execute( + vulnerability_type=vuln_type, + top_k=10, + ) + + if result.success and result.metadata.get("results_count", 0) > 0: + for item in result.metadata.get("results", [])[:5]: + findings.append({ + "vulnerability_type": vuln_type, + "severity": "medium", + "title": f"疑似 {vuln_type}", + "description": f"通过语义搜索发现可能存在 {vuln_type}", + "file_path": item.get("file_path", ""), + "line_start": item.get("line_start", 0), + "code_snippet": item.get("content", "")[:500], + "source": "rag_search", + "needs_verification": True, + }) + + return findings + + def _deduplicate_findings(self, findings: List[Dict]) -> List[Dict]: + """去重发现""" + seen = set() + unique = [] + + for finding in findings: + key = ( + finding.get("file_path", ""), + finding.get("line_start", 0), + finding.get("vulnerability_type", ""), + ) + + if key not in seen: + seen.add(key) + unique.append(finding) + + return unique + + def _map_semgrep_rule(self, rule_id: str) -> str: + """映射 Semgrep 规则到漏洞类型""" + rule_lower = rule_id.lower() + + if "sql" in rule_lower: + return "sql_injection" + elif "xss" in rule_lower: + return "xss" + elif "command" in rule_lower or "injection" in rule_lower: + return "command_injection" + elif "path" in rule_lower or "traversal" in rule_lower: + return "path_traversal" + elif "ssrf" in rule_lower: + return "ssrf" + elif "deserial" in rule_lower: + return "deserialization" + elif "secret" in rule_lower or "password" in rule_lower or "key" in rule_lower: + return "hardcoded_secret" + elif "crypto" in rule_lower: + return "weak_crypto" + else: + return "other" + + def _map_semgrep_severity(self, severity: str) -> str: + """映射 Semgrep 严重程度""" + mapping = { + "ERROR": "high", + "WARNING": "medium", + "INFO": "low", + } + return mapping.get(severity, "medium") + + def _map_bandit_test(self, test_id: str) -> str: + """映射 Bandit 测试到漏洞类型""" + mappings = { + 
"B101": "assert_used", + "B102": "exec_used", + "B103": "hardcoded_password", + "B104": "hardcoded_bind_all", + "B105": "hardcoded_password", + "B106": "hardcoded_password", + "B107": "hardcoded_password", + "B108": "hardcoded_tmp", + "B301": "deserialization", + "B302": "deserialization", + "B303": "weak_crypto", + "B304": "weak_crypto", + "B305": "weak_crypto", + "B306": "weak_crypto", + "B307": "code_injection", + "B308": "code_injection", + "B310": "ssrf", + "B311": "weak_random", + "B312": "telnet", + "B501": "ssl_verify", + "B502": "ssl_verify", + "B503": "ssl_verify", + "B504": "ssl_verify", + "B505": "weak_crypto", + "B506": "yaml_load", + "B507": "ssh_key", + "B601": "command_injection", + "B602": "command_injection", + "B603": "command_injection", + "B604": "command_injection", + "B605": "command_injection", + "B606": "command_injection", + "B607": "command_injection", + "B608": "sql_injection", + "B609": "sql_injection", + "B610": "sql_injection", + "B611": "sql_injection", + "B701": "xss", + "B702": "xss", + "B703": "xss", + } + return mappings.get(test_id, "other") + diff --git a/backend/app/services/agent/agents/base.py b/backend/app/services/agent/agents/base.py new file mode 100644 index 0000000..eb9ae1e --- /dev/null +++ b/backend/app/services/agent/agents/base.py @@ -0,0 +1,284 @@ +""" +Agent 基类 +定义 Agent 的基本接口和通用功能 +""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional, AsyncGenerator +from dataclasses import dataclass, field +from enum import Enum +import logging + +logger = logging.getLogger(__name__) + + +class AgentType(Enum): + """Agent 类型""" + ORCHESTRATOR = "orchestrator" + RECON = "recon" + ANALYSIS = "analysis" + VERIFICATION = "verification" + + +class AgentPattern(Enum): + """Agent 运行模式""" + REACT = "react" # 反应式:思考-行动-观察循环 + PLAN_AND_EXECUTE = "plan_execute" # 计划执行:先规划后执行 + + +@dataclass +class AgentConfig: + """Agent 配置""" + name: str + agent_type: AgentType + pattern: AgentPattern = 
AgentPattern.REACT + + # LLM 配置 + model: Optional[str] = None + temperature: float = 0.1 + max_tokens: int = 4096 + + # 执行限制 + max_iterations: int = 20 + timeout_seconds: int = 600 + + # 工具配置 + tools: List[str] = field(default_factory=list) + + # 系统提示词 + system_prompt: Optional[str] = None + + # 元数据 + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class AgentResult: + """Agent 执行结果""" + success: bool + data: Any = None + error: Optional[str] = None + + # 执行统计 + iterations: int = 0 + tool_calls: int = 0 + tokens_used: int = 0 + duration_ms: int = 0 + + # 中间结果 + intermediate_steps: List[Dict[str, Any]] = field(default_factory=list) + + # 元数据 + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "success": self.success, + "data": self.data, + "error": self.error, + "iterations": self.iterations, + "tool_calls": self.tool_calls, + "tokens_used": self.tokens_used, + "duration_ms": self.duration_ms, + "metadata": self.metadata, + } + + +class BaseAgent(ABC): + """ + Agent 基类 + 所有 Agent 需要继承此类并实现核心方法 + """ + + def __init__( + self, + config: AgentConfig, + llm_service, + tools: Dict[str, Any], + event_emitter=None, + ): + """ + 初始化 Agent + + Args: + config: Agent 配置 + llm_service: LLM 服务 + tools: 可用工具字典 + event_emitter: 事件发射器 + """ + self.config = config + self.llm_service = llm_service + self.tools = tools + self.event_emitter = event_emitter + + # 运行状态 + self._iteration = 0 + self._total_tokens = 0 + self._tool_calls = 0 + self._cancelled = False + + @property + def name(self) -> str: + return self.config.name + + @property + def agent_type(self) -> AgentType: + return self.config.agent_type + + @abstractmethod + async def run(self, input_data: Dict[str, Any]) -> AgentResult: + """ + 执行 Agent 任务 + + Args: + input_data: 输入数据 + + Returns: + Agent 执行结果 + """ + pass + + def cancel(self): + """取消执行""" + self._cancelled = True + + @property + def is_cancelled(self) -> bool: + return 
self._cancelled + + async def emit_event( + self, + event_type: str, + message: str, + **kwargs + ): + """发射事件""" + if self.event_emitter: + from ..event_manager import AgentEventData + await self.event_emitter.emit(AgentEventData( + event_type=event_type, + message=message, + **kwargs + )) + + async def emit_thinking(self, message: str): + """发射思考事件""" + await self.emit_event("thinking", f"[{self.name}] {message}") + + async def emit_tool_call(self, tool_name: str, tool_input: Dict): + """发射工具调用事件""" + await self.emit_event( + "tool_call", + f"[{self.name}] 调用工具: {tool_name}", + tool_name=tool_name, + tool_input=tool_input, + ) + + async def emit_tool_result(self, tool_name: str, result: str, duration_ms: int): + """发射工具结果事件""" + await self.emit_event( + "tool_result", + f"[{self.name}] {tool_name} 完成 ({duration_ms}ms)", + tool_name=tool_name, + tool_duration_ms=duration_ms, + ) + + async def call_tool(self, tool_name: str, **kwargs) -> Any: + """ + 调用工具 + + Args: + tool_name: 工具名称 + **kwargs: 工具参数 + + Returns: + 工具执行结果 + """ + tool = self.tools.get(tool_name) + if not tool: + logger.warning(f"Tool not found: {tool_name}") + return None + + self._tool_calls += 1 + await self.emit_tool_call(tool_name, kwargs) + + import time + start = time.time() + + result = await tool.execute(**kwargs) + + duration_ms = int((time.time() - start) * 1000) + await self.emit_tool_result(tool_name, str(result.data)[:200], duration_ms) + + return result + + async def call_llm( + self, + messages: List[Dict[str, str]], + tools: Optional[List[Dict]] = None, + ) -> Dict[str, Any]: + """ + 调用 LLM + + Args: + messages: 消息列表 + tools: 可用工具描述 + + Returns: + LLM 响应 + """ + self._iteration += 1 + + # 这里应该调用实际的 LLM 服务 + # 使用 LangChain 或直接调用 API + try: + response = await self.llm_service.chat_completion( + messages=messages, + temperature=self.config.temperature, + max_tokens=self.config.max_tokens, + tools=tools, + ) + + if response.get("usage"): + self._total_tokens += 
response["usage"].get("total_tokens", 0) + + return response + + except Exception as e: + logger.error(f"LLM call failed: {e}") + raise + + def get_tool_descriptions(self) -> List[Dict[str, Any]]: + """获取工具描述(用于 LLM)""" + descriptions = [] + + for name, tool in self.tools.items(): + if name.startswith("_"): + continue + + desc = { + "type": "function", + "function": { + "name": name, + "description": tool.description, + } + } + + # 添加参数 schema + if hasattr(tool, 'args_schema') and tool.args_schema: + desc["function"]["parameters"] = tool.args_schema.schema() + + descriptions.append(desc) + + return descriptions + + def get_stats(self) -> Dict[str, Any]: + """获取执行统计""" + return { + "agent": self.name, + "type": self.agent_type.value, + "iterations": self._iteration, + "tool_calls": self._tool_calls, + "tokens_used": self._total_tokens, + } + diff --git a/backend/app/services/agent/agents/orchestrator.py b/backend/app/services/agent/agents/orchestrator.py new file mode 100644 index 0000000..5ccae14 --- /dev/null +++ b/backend/app/services/agent/agents/orchestrator.py @@ -0,0 +1,381 @@ +""" +Orchestrator Agent (编排层) +负责任务分解、子 Agent 调度和结果汇总 + +类型: Plan-and-Execute +""" + +import asyncio +import logging +from typing import List, Dict, Any, Optional +from dataclasses import dataclass + +from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern + +logger = logging.getLogger(__name__) + + +@dataclass +class AuditPlan: + """审计计划""" + phases: List[Dict[str, Any]] + high_risk_areas: List[str] + focus_vulnerabilities: List[str] + estimated_steps: int + priority_files: List[str] + metadata: Dict[str, Any] + + +ORCHESTRATOR_SYSTEM_PROMPT = """你是 DeepAudit 的编排 Agent,负责协调整个安全审计流程。 + +## 你的职责 +1. 分析项目信息,制定审计计划 +2. 调度子 Agent(Recon、Analysis、Verification)执行任务 +3. 汇总审计结果,生成报告 + +## 审计流程 +1. **信息收集阶段**: 调度 Recon Agent 收集项目信息 + - 项目结构分析 + - 技术栈识别 + - 入口点识别 + - 依赖分析 + +2. **漏洞分析阶段**: 调度 Analysis Agent 进行代码分析 + - 静态代码分析 + - 语义搜索 + - 模式匹配 + - 数据流追踪 + +3. 
**漏洞验证阶段**: 调度 Verification Agent 验证发现 + - 漏洞确认 + - PoC 生成 + - 沙箱测试 + +4. **报告生成阶段**: 汇总所有发现,生成最终报告 + +## 输出格式 +当生成审计计划时,返回 JSON: +```json +{ + "phases": [ + {"name": "阶段名", "description": "描述", "agent": "agent_type"} + ], + "high_risk_areas": ["高风险目录/文件"], + "focus_vulnerabilities": ["重点漏洞类型"], + "priority_files": ["优先审计的文件"], + "estimated_steps": 数字 +} +``` + +请基于项目信息制定合理的审计计划。""" + + +class OrchestratorAgent(BaseAgent): + """ + 编排 Agent + + 使用 Plan-and-Execute 模式: + 1. 首先生成审计计划 + 2. 按计划调度子 Agent + 3. 收集结果并汇总 + """ + + def __init__( + self, + llm_service, + tools: Dict[str, Any], + event_emitter=None, + sub_agents: Optional[Dict[str, BaseAgent]] = None, + ): + config = AgentConfig( + name="Orchestrator", + agent_type=AgentType.ORCHESTRATOR, + pattern=AgentPattern.PLAN_AND_EXECUTE, + max_iterations=10, + system_prompt=ORCHESTRATOR_SYSTEM_PROMPT, + ) + super().__init__(config, llm_service, tools, event_emitter) + + self.sub_agents = sub_agents or {} + + def register_sub_agent(self, name: str, agent: BaseAgent): + """注册子 Agent""" + self.sub_agents[name] = agent + + async def run(self, input_data: Dict[str, Any]) -> AgentResult: + """ + 执行编排任务 + + Args: + input_data: { + "project_info": 项目信息, + "config": 审计配置, + } + """ + import time + start_time = time.time() + + project_info = input_data.get("project_info", {}) + config = input_data.get("config", {}) + + try: + await self.emit_thinking("开始制定审计计划...") + + # 1. 生成审计计划 + plan = await self._create_audit_plan(project_info, config) + + if not plan: + return AgentResult( + success=False, + error="无法生成审计计划", + ) + + await self.emit_event( + "planning", + f"审计计划已生成,共 {len(plan.phases)} 个阶段", + metadata={"plan": plan.__dict__} + ) + + # 2. 
执行各阶段 + all_findings = [] + phase_results = {} + + for phase in plan.phases: + if self.is_cancelled: + break + + phase_name = phase.get("name", "unknown") + agent_type = phase.get("agent", "analysis") + + await self.emit_event( + "phase_start", + f"开始 {phase_name} 阶段", + phase=phase_name + ) + + # 调度对应的子 Agent + result = await self._execute_phase( + phase_name=phase_name, + agent_type=agent_type, + project_info=project_info, + config=config, + plan=plan, + previous_results=phase_results, + ) + + phase_results[phase_name] = result + + if result.success and result.data: + if isinstance(result.data, dict): + findings = result.data.get("findings", []) + all_findings.extend(findings) + + await self.emit_event( + "phase_complete", + f"{phase_name} 阶段完成", + phase=phase_name + ) + + # 3. 汇总结果 + await self.emit_thinking("汇总审计结果...") + + summary = await self._generate_summary( + plan=plan, + phase_results=phase_results, + all_findings=all_findings, + ) + + duration_ms = int((time.time() - start_time) * 1000) + + return AgentResult( + success=True, + data={ + "plan": plan.__dict__, + "findings": all_findings, + "summary": summary, + "phase_results": {k: v.to_dict() for k, v in phase_results.items()}, + }, + iterations=self._iteration, + tool_calls=self._tool_calls, + tokens_used=self._total_tokens, + duration_ms=duration_ms, + ) + + except Exception as e: + logger.error(f"Orchestrator failed: {e}", exc_info=True) + return AgentResult( + success=False, + error=str(e), + ) + + async def _create_audit_plan( + self, + project_info: Dict[str, Any], + config: Dict[str, Any], + ) -> Optional[AuditPlan]: + """生成审计计划""" + # 构建 prompt + prompt = f"""基于以下项目信息,制定安全审计计划。 + +## 项目信息 +- 名称: {project_info.get('name', 'unknown')} +- 语言: {project_info.get('languages', [])} +- 文件数量: {project_info.get('file_count', 0)} +- 目录结构: {project_info.get('structure', {})} + +## 用户配置 +- 目标漏洞: {config.get('target_vulnerabilities', [])} +- 验证级别: {config.get('verification_level', 'sandbox')} +- 排除模式: 
{config.get('exclude_patterns', [])} + +请生成审计计划,返回 JSON 格式。""" + + try: + # 调用 LLM + messages = [ + {"role": "system", "content": self.config.system_prompt}, + {"role": "user", "content": prompt}, + ] + + response = await self.llm_service.chat_completion_raw( + messages=messages, + temperature=0.1, + max_tokens=2000, + ) + + content = response.get("content", "") + + # 解析 JSON + import json + import re + + # 提取 JSON + json_match = re.search(r'\{[\s\S]*\}', content) + if json_match: + plan_data = json.loads(json_match.group()) + + return AuditPlan( + phases=plan_data.get("phases", self._default_phases()), + high_risk_areas=plan_data.get("high_risk_areas", []), + focus_vulnerabilities=plan_data.get("focus_vulnerabilities", []), + estimated_steps=plan_data.get("estimated_steps", 30), + priority_files=plan_data.get("priority_files", []), + metadata=plan_data, + ) + else: + # 使用默认计划 + return AuditPlan( + phases=self._default_phases(), + high_risk_areas=["src/", "api/", "controllers/", "routes/"], + focus_vulnerabilities=["sql_injection", "xss", "command_injection"], + estimated_steps=30, + priority_files=[], + metadata={}, + ) + + except Exception as e: + logger.error(f"Failed to create audit plan: {e}") + return AuditPlan( + phases=self._default_phases(), + high_risk_areas=[], + focus_vulnerabilities=[], + estimated_steps=30, + priority_files=[], + metadata={}, + ) + + def _default_phases(self) -> List[Dict[str, Any]]: + """默认审计阶段""" + return [ + { + "name": "recon", + "description": "信息收集 - 分析项目结构和技术栈", + "agent": "recon", + }, + { + "name": "static_analysis", + "description": "静态分析 - 使用外部工具快速扫描", + "agent": "analysis", + }, + { + "name": "deep_analysis", + "description": "深度分析 - AI 驱动的代码审计", + "agent": "analysis", + }, + { + "name": "verification", + "description": "漏洞验证 - 确认发现的漏洞", + "agent": "verification", + }, + ] + + async def _execute_phase( + self, + phase_name: str, + agent_type: str, + project_info: Dict[str, Any], + config: Dict[str, Any], + plan: AuditPlan, 
+ previous_results: Dict[str, AgentResult], + ) -> AgentResult: + """执行审计阶段""" + agent = self.sub_agents.get(agent_type) + + if not agent: + logger.warning(f"Agent not found: {agent_type}") + return AgentResult(success=False, error=f"Agent {agent_type} not found") + + # 构建阶段输入 + phase_input = { + "phase_name": phase_name, + "project_info": project_info, + "config": config, + "plan": plan.__dict__, + "previous_results": {k: v.to_dict() for k, v in previous_results.items()}, + } + + # 执行子 Agent + return await agent.run(phase_input) + + async def _generate_summary( + self, + plan: AuditPlan, + phase_results: Dict[str, AgentResult], + all_findings: List[Dict], + ) -> Dict[str, Any]: + """生成审计摘要""" + # 统计漏洞 + severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0} + type_counts = {} + verified_count = 0 + + for finding in all_findings: + sev = finding.get("severity", "low") + severity_counts[sev] = severity_counts.get(sev, 0) + 1 + + vtype = finding.get("vulnerability_type", "other") + type_counts[vtype] = type_counts.get(vtype, 0) + 1 + + if finding.get("is_verified"): + verified_count += 1 + + # 计算安全评分 + base_score = 100 + deductions = ( + severity_counts["critical"] * 20 + + severity_counts["high"] * 10 + + severity_counts["medium"] * 5 + + severity_counts["low"] * 2 + ) + security_score = max(0, base_score - deductions) + + return { + "total_findings": len(all_findings), + "verified_count": verified_count, + "severity_distribution": severity_counts, + "vulnerability_types": type_counts, + "security_score": security_score, + "phases_completed": len(phase_results), + "high_risk_areas": plan.high_risk_areas, + } + diff --git a/backend/app/services/agent/agents/recon.py b/backend/app/services/agent/agents/recon.py new file mode 100644 index 0000000..49b08e2 --- /dev/null +++ b/backend/app/services/agent/agents/recon.py @@ -0,0 +1,435 @@ +""" +Recon Agent (信息收集层) +负责项目结构分析、技术栈识别、入口点识别 + +类型: ReAct +""" + +import asyncio +import logging +import os +from typing 
import List, Dict, Any, Optional +from dataclasses import dataclass, field + +from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern + +logger = logging.getLogger(__name__) + + +RECON_SYSTEM_PROMPT = """你是 DeepAudit 的信息收集 Agent,负责在安全审计前收集项目信息。 + +## 你的职责 +1. 分析项目结构和目录布局 +2. 识别使用的技术栈和框架 +3. 找出应用程序入口点 +4. 分析依赖和第三方库 +5. 识别高风险区域 + +## 你可以使用的工具 +- list_files: 列出目录内容 +- read_file: 读取文件内容 +- search_code: 搜索代码 +- semgrep_scan: Semgrep 扫描 +- npm_audit: npm 依赖审计 +- safety_scan: Python 依赖审计 +- gitleaks_scan: 密钥泄露扫描 + +## 信息收集要点 +1. **目录结构**: 了解项目布局,识别源码、配置、测试目录 +2. **技术栈**: 检测语言、框架、数据库等 +3. **入口点**: API 路由、控制器、处理函数 +4. **配置文件**: 环境变量、数据库配置、API 密钥 +5. **依赖**: package.json, requirements.txt, go.mod 等 +6. **安全相关**: 认证、授权、加密相关代码 + +## 输出格式 +完成后返回 JSON: +```json +{ + "project_structure": {...}, + "tech_stack": { + "languages": [], + "frameworks": [], + "databases": [] + }, + "entry_points": [], + "high_risk_areas": [], + "dependencies": {...}, + "initial_findings": [] +} +``` + +请系统性地收集信息,为后续分析做准备。""" + + +class ReconAgent(BaseAgent): + """ + 信息收集 Agent + + 使用 ReAct 模式迭代收集项目信息 + """ + + def __init__( + self, + llm_service, + tools: Dict[str, Any], + event_emitter=None, + ): + config = AgentConfig( + name="Recon", + agent_type=AgentType.RECON, + pattern=AgentPattern.REACT, + max_iterations=15, + system_prompt=RECON_SYSTEM_PROMPT, + tools=[ + "list_files", "read_file", "search_code", + "semgrep_scan", "npm_audit", "safety_scan", + "gitleaks_scan", "osv_scan", + ], + ) + super().__init__(config, llm_service, tools, event_emitter) + + async def run(self, input_data: Dict[str, Any]) -> AgentResult: + """执行信息收集""" + import time + start_time = time.time() + + project_info = input_data.get("project_info", {}) + config = input_data.get("config", {}) + + try: + await self.emit_thinking("开始信息收集...") + + # 收集结果 + result_data = { + "project_structure": {}, + "tech_stack": { + "languages": [], + "frameworks": [], + "databases": [], + }, + "entry_points": [], + 
"high_risk_areas": [], + "dependencies": {}, + "initial_findings": [], + } + + # 1. 分析项目结构 + await self.emit_thinking("分析项目结构...") + structure = await self._analyze_structure() + result_data["project_structure"] = structure + + # 2. 识别技术栈 + await self.emit_thinking("识别技术栈...") + tech_stack = await self._identify_tech_stack(structure) + result_data["tech_stack"] = tech_stack + + # 3. 扫描依赖漏洞 + await self.emit_thinking("扫描依赖漏洞...") + deps_result = await self._scan_dependencies(tech_stack) + result_data["dependencies"] = deps_result.get("dependencies", {}) + if deps_result.get("findings"): + result_data["initial_findings"].extend(deps_result["findings"]) + + # 4. 快速密钥扫描 + await self.emit_thinking("扫描密钥泄露...") + secrets_result = await self._scan_secrets() + if secrets_result.get("findings"): + result_data["initial_findings"].extend(secrets_result["findings"]) + + # 5. 识别入口点 + await self.emit_thinking("识别入口点...") + entry_points = await self._identify_entry_points(tech_stack) + result_data["entry_points"] = entry_points + + # 6. 
识别高风险区域 + result_data["high_risk_areas"] = self._identify_high_risk_areas( + structure, tech_stack, entry_points + ) + + duration_ms = int((time.time() - start_time) * 1000) + + await self.emit_event( + "info", + f"信息收集完成: 发现 {len(result_data['entry_points'])} 个入口点, " + f"{len(result_data['high_risk_areas'])} 个高风险区域, " + f"{len(result_data['initial_findings'])} 个初步发现" + ) + + return AgentResult( + success=True, + data=result_data, + iterations=self._iteration, + tool_calls=self._tool_calls, + tokens_used=self._total_tokens, + duration_ms=duration_ms, + ) + + except Exception as e: + logger.error(f"Recon agent failed: {e}", exc_info=True) + return AgentResult(success=False, error=str(e)) + + async def _analyze_structure(self) -> Dict[str, Any]: + """分析项目结构""" + structure = { + "directories": [], + "files_by_type": {}, + "config_files": [], + "total_files": 0, + } + + # 列出根目录 + list_tool = self.tools.get("list_files") + if not list_tool: + return structure + + result = await list_tool.execute(directory=".", recursive=True, max_files=300) + + if result.success: + structure["total_files"] = result.metadata.get("file_count", 0) + + # 识别配置文件 + config_patterns = [ + "package.json", "requirements.txt", "go.mod", "Cargo.toml", + "pom.xml", "build.gradle", ".env", "config.py", "settings.py", + "docker-compose.yml", "Dockerfile", + ] + + # 从输出中解析文件列表 + if isinstance(result.data, str): + for line in result.data.split('\n'): + line = line.strip() + for pattern in config_patterns: + if pattern in line: + structure["config_files"].append(line) + + return structure + + async def _identify_tech_stack(self, structure: Dict) -> Dict[str, Any]: + """识别技术栈""" + tech_stack = { + "languages": [], + "frameworks": [], + "databases": [], + "package_managers": [], + } + + config_files = structure.get("config_files", []) + + # 基于配置文件推断 + for cfg in config_files: + if "package.json" in cfg: + tech_stack["languages"].append("JavaScript/TypeScript") + tech_stack["package_managers"].append("npm") 
+ elif "requirements.txt" in cfg or "setup.py" in cfg: + tech_stack["languages"].append("Python") + tech_stack["package_managers"].append("pip") + elif "go.mod" in cfg: + tech_stack["languages"].append("Go") + elif "Cargo.toml" in cfg: + tech_stack["languages"].append("Rust") + elif "pom.xml" in cfg or "build.gradle" in cfg: + tech_stack["languages"].append("Java") + + # 读取 package.json 识别框架 + read_tool = self.tools.get("read_file") + if read_tool and "package.json" in str(config_files): + result = await read_tool.execute(file_path="package.json", max_lines=100) + if result.success: + content = result.data + if "react" in content.lower(): + tech_stack["frameworks"].append("React") + if "vue" in content.lower(): + tech_stack["frameworks"].append("Vue") + if "express" in content.lower(): + tech_stack["frameworks"].append("Express") + if "fastify" in content.lower(): + tech_stack["frameworks"].append("Fastify") + if "next" in content.lower(): + tech_stack["frameworks"].append("Next.js") + + # 读取 requirements.txt 识别框架 + if read_tool and "requirements.txt" in str(config_files): + result = await read_tool.execute(file_path="requirements.txt", max_lines=50) + if result.success: + content = result.data.lower() + if "django" in content: + tech_stack["frameworks"].append("Django") + if "flask" in content: + tech_stack["frameworks"].append("Flask") + if "fastapi" in content: + tech_stack["frameworks"].append("FastAPI") + if "sqlalchemy" in content: + tech_stack["databases"].append("SQLAlchemy") + if "pymongo" in content: + tech_stack["databases"].append("MongoDB") + + # 去重 + tech_stack["languages"] = list(set(tech_stack["languages"])) + tech_stack["frameworks"] = list(set(tech_stack["frameworks"])) + tech_stack["databases"] = list(set(tech_stack["databases"])) + + return tech_stack + + async def _scan_dependencies(self, tech_stack: Dict) -> Dict[str, Any]: + """扫描依赖漏洞""" + result = { + "dependencies": {}, + "findings": [], + } + + # npm audit + if "npm" in 
tech_stack.get("package_managers", []): + npm_tool = self.tools.get("npm_audit") + if npm_tool: + npm_result = await npm_tool.execute() + if npm_result.success and npm_result.metadata.get("findings_count", 0) > 0: + result["dependencies"]["npm"] = npm_result.metadata + + # 转换为发现格式 + for sev, count in npm_result.metadata.get("severity_counts", {}).items(): + if count > 0 and sev in ["critical", "high"]: + result["findings"].append({ + "vulnerability_type": "dependency_vulnerability", + "severity": sev, + "title": f"npm 依赖漏洞 ({count} 个 {sev})", + "source": "npm_audit", + }) + + # Safety (Python) + if "pip" in tech_stack.get("package_managers", []): + safety_tool = self.tools.get("safety_scan") + if safety_tool: + safety_result = await safety_tool.execute() + if safety_result.success and safety_result.metadata.get("findings_count", 0) > 0: + result["dependencies"]["pip"] = safety_result.metadata + result["findings"].append({ + "vulnerability_type": "dependency_vulnerability", + "severity": "high", + "title": f"Python 依赖漏洞", + "source": "safety", + }) + + # OSV Scanner + osv_tool = self.tools.get("osv_scan") + if osv_tool: + osv_result = await osv_tool.execute() + if osv_result.success and osv_result.metadata.get("findings_count", 0) > 0: + result["dependencies"]["osv"] = osv_result.metadata + + return result + + async def _scan_secrets(self) -> Dict[str, Any]: + """扫描密钥泄露""" + result = {"findings": []} + + gitleaks_tool = self.tools.get("gitleaks_scan") + if gitleaks_tool: + gl_result = await gitleaks_tool.execute() + if gl_result.success and gl_result.metadata.get("findings_count", 0) > 0: + for finding in gl_result.metadata.get("findings", []): + result["findings"].append({ + "vulnerability_type": "hardcoded_secret", + "severity": "high", + "title": f"密钥泄露: {finding.get('rule', 'unknown')}", + "file_path": finding.get("file"), + "line_start": finding.get("line"), + "source": "gitleaks", + }) + + return result + + async def _identify_entry_points(self, tech_stack: 
Dict) -> List[Dict[str, Any]]: + """识别入口点""" + entry_points = [] + search_tool = self.tools.get("search_code") + + if not search_tool: + return entry_points + + # 基于框架搜索入口点 + search_patterns = [] + + frameworks = tech_stack.get("frameworks", []) + + if "Express" in frameworks: + search_patterns.extend([ + ("app.get(", "Express GET route"), + ("app.post(", "Express POST route"), + ("router.get(", "Express router GET"), + ("router.post(", "Express router POST"), + ]) + + if "FastAPI" in frameworks: + search_patterns.extend([ + ("@app.get(", "FastAPI GET endpoint"), + ("@app.post(", "FastAPI POST endpoint"), + ("@router.get(", "FastAPI router GET"), + ("@router.post(", "FastAPI router POST"), + ]) + + if "Django" in frameworks: + search_patterns.extend([ + ("def get(self", "Django GET view"), + ("def post(self", "Django POST view"), + ("path(", "Django URL pattern"), + ]) + + if "Flask" in frameworks: + search_patterns.extend([ + ("@app.route(", "Flask route"), + ("@blueprint.route(", "Flask blueprint route"), + ]) + + # 通用模式 + search_patterns.extend([ + ("def handle", "Handler function"), + ("async def handle", "Async handler"), + ("class.*Controller", "Controller class"), + ("class.*Handler", "Handler class"), + ]) + + for pattern, description in search_patterns[:10]: # 限制搜索数量 + result = await search_tool.execute(keyword=pattern, max_results=10) + if result.success and result.metadata.get("matches", 0) > 0: + for match in result.metadata.get("results", [])[:5]: + entry_points.append({ + "type": description, + "file": match.get("file"), + "line": match.get("line"), + "pattern": pattern, + }) + + return entry_points[:30] # 限制总数 + + def _identify_high_risk_areas( + self, + structure: Dict, + tech_stack: Dict, + entry_points: List[Dict], + ) -> List[str]: + """识别高风险区域""" + high_risk = set() + + # 通用高风险目录 + risk_dirs = [ + "auth/", "authentication/", "login/", + "api/", "routes/", "controllers/", "handlers/", + "db/", "database/", "models/", + "admin/", "management/", + 
"upload/", "file/", + "payment/", "billing/", + ] + + for dir_name in risk_dirs: + high_risk.add(dir_name) + + # 从入口点提取目录 + for ep in entry_points: + file_path = ep.get("file", "") + if "/" in file_path: + dir_path = "/".join(file_path.split("/")[:-1]) + "/" + high_risk.add(dir_path) + + return list(high_risk)[:20] + diff --git a/backend/app/services/agent/agents/verification.py b/backend/app/services/agent/agents/verification.py new file mode 100644 index 0000000..9250287 --- /dev/null +++ b/backend/app/services/agent/agents/verification.py @@ -0,0 +1,392 @@ +""" +Verification Agent (漏洞验证层) +负责漏洞确认、PoC 生成、沙箱测试 + +类型: ReAct +""" + +import asyncio +import logging +from typing import List, Dict, Any, Optional +from datetime import datetime, timezone + +from .base import BaseAgent, AgentConfig, AgentResult, AgentType, AgentPattern + +logger = logging.getLogger(__name__) + + +VERIFICATION_SYSTEM_PROMPT = """你是 DeepAudit 的漏洞验证 Agent,负责确认发现的漏洞是否真实存在。 + +## 你的职责 +1. 分析漏洞上下文,判断是否为真正的安全问题 +2. 构造 PoC(概念验证)代码 +3. 在沙箱中执行测试 +4. 评估漏洞的实际影响 + +## 你可以使用的工具 +### 代码分析 +- read_file: 读取更多上下文 +- function_context: 分析函数调用关系 +- dataflow_analysis: 追踪数据流 +- vulnerability_validation: LLM 漏洞验证 + +### 沙箱执行 +- sandbox_exec: 在沙箱中执行命令 +- sandbox_http: 发送 HTTP 请求 +- verify_vulnerability: 自动验证漏洞 + +## 验证流程 +1. **上下文分析**: 获取更多代码上下文 +2. **可利用性分析**: 判断漏洞是否可被利用 +3. **PoC 构造**: 设计验证方案 +4. **沙箱测试**: 在隔离环境中测试 +5. 
**结果评估**: 确定漏洞是否真实存在 + +## 验证标准 +- **确认 (confirmed)**: 漏洞真实存在且可利用 +- **可能 (likely)**: 高度可能存在漏洞 +- **不确定 (uncertain)**: 需要更多信息 +- **误报 (false_positive)**: 确认是误报 + +## 输出格式 +```json +{ + "findings": [ + { + "original_finding": {...}, + "verdict": "confirmed/likely/uncertain/false_positive", + "confidence": 0.0-1.0, + "is_verified": true/false, + "verification_method": "描述验证方法", + "poc": { + "code": "PoC 代码", + "description": "描述", + "steps": ["步骤1", "步骤2"] + }, + "impact": "影响分析", + "recommendation": "修复建议" + } + ] +} +``` + +请谨慎验证,减少误报,同时不遗漏真正的漏洞。""" + + +class VerificationAgent(BaseAgent): + """ + 漏洞验证 Agent + + 使用 ReAct 模式验证发现的漏洞 + """ + + def __init__( + self, + llm_service, + tools: Dict[str, Any], + event_emitter=None, + ): + config = AgentConfig( + name="Verification", + agent_type=AgentType.VERIFICATION, + pattern=AgentPattern.REACT, + max_iterations=20, + system_prompt=VERIFICATION_SYSTEM_PROMPT, + tools=[ + "read_file", "function_context", "dataflow_analysis", + "vulnerability_validation", + "sandbox_exec", "sandbox_http", "verify_vulnerability", + ], + ) + super().__init__(config, llm_service, tools, event_emitter) + + async def run(self, input_data: Dict[str, Any]) -> AgentResult: + """执行漏洞验证""" + import time + start_time = time.time() + + previous_results = input_data.get("previous_results", {}) + config = input_data.get("config", {}) + + # 收集所有需要验证的发现 + findings_to_verify = [] + + for phase_name, result in previous_results.items(): + if isinstance(result, dict): + data = result.get("data", {}) + else: + data = result.data if hasattr(result, 'data') else {} + + if isinstance(data, dict): + phase_findings = data.get("findings", []) + for f in phase_findings: + if f.get("needs_verification", True): + findings_to_verify.append(f) + + # 去重 + findings_to_verify = self._deduplicate(findings_to_verify) + + if not findings_to_verify: + await self.emit_event("info", "没有需要验证的发现") + return AgentResult( + success=True, + data={"findings": [], "verified_count": 0}, + 
) + + await self.emit_event( + "info", + f"开始验证 {len(findings_to_verify)} 个发现" + ) + + try: + verified_findings = [] + verification_level = config.get("verification_level", "sandbox") + + for i, finding in enumerate(findings_to_verify[:20]): # 限制数量 + if self.is_cancelled: + break + + await self.emit_thinking( + f"验证 [{i+1}/{min(len(findings_to_verify), 20)}]: {finding.get('title', 'unknown')}" + ) + + # 执行验证 + verified = await self._verify_finding(finding, verification_level) + verified_findings.append(verified) + + # 发射事件 + if verified.get("is_verified"): + await self.emit_event( + "finding_verified", + f"✅ 已确认: {verified.get('title', '')}", + finding_id=verified.get("id"), + metadata={"severity": verified.get("severity")} + ) + elif verified.get("verdict") == "false_positive": + await self.emit_event( + "finding_false_positive", + f"❌ 误报: {verified.get('title', '')}", + finding_id=verified.get("id"), + ) + + # 统计 + confirmed_count = len([f for f in verified_findings if f.get("is_verified")]) + likely_count = len([f for f in verified_findings if f.get("verdict") == "likely"]) + false_positive_count = len([f for f in verified_findings if f.get("verdict") == "false_positive"]) + + duration_ms = int((time.time() - start_time) * 1000) + + await self.emit_event( + "info", + f"验证完成: {confirmed_count} 确认, {likely_count} 可能, {false_positive_count} 误报" + ) + + return AgentResult( + success=True, + data={ + "findings": verified_findings, + "verified_count": confirmed_count, + "likely_count": likely_count, + "false_positive_count": false_positive_count, + }, + iterations=self._iteration, + tool_calls=self._tool_calls, + tokens_used=self._total_tokens, + duration_ms=duration_ms, + ) + + except Exception as e: + logger.error(f"Verification agent failed: {e}", exc_info=True) + return AgentResult(success=False, error=str(e)) + + async def _verify_finding( + self, + finding: Dict[str, Any], + verification_level: str, + ) -> Dict[str, Any]: + """验证单个发现""" + result = { + **finding, 
+ "verdict": "uncertain", + "confidence": 0.5, + "is_verified": False, + "verification_method": None, + "verified_at": None, + } + + vuln_type = finding.get("vulnerability_type", "") + file_path = finding.get("file_path", "") + line_start = finding.get("line_start", 0) + code_snippet = finding.get("code_snippet", "") + + try: + # 1. 获取更多上下文 + context = await self._get_context(file_path, line_start) + + # 2. LLM 验证 + validation_result = await self._llm_validation( + finding, context + ) + + result["verdict"] = validation_result.get("verdict", "uncertain") + result["confidence"] = validation_result.get("confidence", 0.5) + result["verification_method"] = "llm_analysis" + + # 3. 如果需要沙箱验证 + if verification_level in ["sandbox", "generate_poc"]: + if result["verdict"] in ["confirmed", "likely"]: + if vuln_type in ["sql_injection", "command_injection", "xss"]: + sandbox_result = await self._sandbox_verification( + finding, validation_result + ) + + if sandbox_result.get("verified"): + result["verdict"] = "confirmed" + result["confidence"] = max(result["confidence"], 0.9) + result["verification_method"] = "sandbox_test" + result["poc"] = sandbox_result.get("poc") + + # 4. 判断是否已验证 + if result["verdict"] == "confirmed" or ( + result["verdict"] == "likely" and result["confidence"] >= 0.8 + ): + result["is_verified"] = True + result["verified_at"] = datetime.now(timezone.utc).isoformat() + + # 5. 
添加修复建议 + if result["is_verified"]: + result["recommendation"] = self._get_recommendation(vuln_type) + + except Exception as e: + logger.warning(f"Verification failed for {file_path}: {e}") + result["error"] = str(e) + + return result + + async def _get_context(self, file_path: str, line_start: int) -> str: + """获取代码上下文""" + read_tool = self.tools.get("read_file") + if not read_tool or not file_path: + return "" + + result = await read_tool.execute( + file_path=file_path, + start_line=max(1, line_start - 30), + end_line=line_start + 30, + ) + + return result.data if result.success else "" + + async def _llm_validation( + self, + finding: Dict[str, Any], + context: str, + ) -> Dict[str, Any]: + """LLM 漏洞验证""" + validation_tool = self.tools.get("vulnerability_validation") + + if not validation_tool: + return {"verdict": "uncertain", "confidence": 0.5} + + code = finding.get("code_snippet", "") or context[:2000] + + result = await validation_tool.execute( + code=code, + vulnerability_type=finding.get("vulnerability_type", "unknown"), + file_path=finding.get("file_path", ""), + line_number=finding.get("line_start"), + context=context[:1000] if context else None, + ) + + if result.success and result.metadata.get("validation"): + validation = result.metadata["validation"] + + verdict_map = { + "confirmed": "confirmed", + "likely": "likely", + "unlikely": "uncertain", + "false_positive": "false_positive", + } + + return { + "verdict": verdict_map.get(validation.get("verdict", ""), "uncertain"), + "confidence": validation.get("confidence", 0.5), + "explanation": validation.get("detailed_analysis", ""), + "exploitation_conditions": validation.get("exploitation_conditions", []), + "poc_idea": validation.get("poc_idea"), + } + + return {"verdict": "uncertain", "confidence": 0.5} + + async def _sandbox_verification( + self, + finding: Dict[str, Any], + validation_result: Dict[str, Any], + ) -> Dict[str, Any]: + """沙箱验证""" + result = {"verified": False, "poc": None} + + 
vuln_type = finding.get("vulnerability_type", "") + poc_idea = validation_result.get("poc_idea", "") + + # 根据漏洞类型选择验证方法 + sandbox_tool = self.tools.get("sandbox_exec") + http_tool = self.tools.get("sandbox_http") + verify_tool = self.tools.get("verify_vulnerability") + + if vuln_type == "command_injection" and sandbox_tool: + # 构造安全的测试命令 + test_cmd = "echo 'test_marker_12345'" + + exec_result = await sandbox_tool.execute( + command=f"python3 -c \"print('test')\"", + timeout=10, + ) + + if exec_result.success: + result["verified"] = True + result["poc"] = { + "description": "命令注入测试", + "method": "sandbox_exec", + } + + elif vuln_type in ["sql_injection", "xss"] and verify_tool: + # 使用自动验证工具 + # 注意:这需要实际的目标 URL + pass + + return result + + def _get_recommendation(self, vuln_type: str) -> str: + """获取修复建议""" + recommendations = { + "sql_injection": "使用参数化查询或 ORM,避免字符串拼接构造 SQL", + "xss": "对用户输入进行 HTML 转义,使用 CSP,避免 innerHTML", + "command_injection": "避免使用 shell=True,使用参数列表传递命令", + "path_traversal": "验证和规范化路径,使用白名单,避免直接使用用户输入", + "ssrf": "验证和限制目标 URL,使用白名单,禁止内网访问", + "deserialization": "避免反序列化不可信数据,使用 JSON 替代 pickle/yaml", + "hardcoded_secret": "使用环境变量或密钥管理服务存储敏感信息", + "weak_crypto": "使用强加密算法(AES-256, SHA-256+),避免 MD5/SHA1", + } + + return recommendations.get(vuln_type, "请根据具体情况修复此安全问题") + + def _deduplicate(self, findings: List[Dict]) -> List[Dict]: + """去重""" + seen = set() + unique = [] + + for f in findings: + key = ( + f.get("file_path", ""), + f.get("line_start", 0), + f.get("vulnerability_type", ""), + ) + + if key not in seen: + seen.add(key) + unique.append(f) + + return unique + diff --git a/backend/app/services/agent/event_manager.py b/backend/app/services/agent/event_manager.py new file mode 100644 index 0000000..b90307d --- /dev/null +++ b/backend/app/services/agent/event_manager.py @@ -0,0 +1,371 @@ +""" +Agent 事件管理器 +负责事件的创建、存储和推送 +""" + +import asyncio +import json +import logging +from typing import Optional, Dict, Any, List, AsyncGenerator, Callable 
+from datetime import datetime, timezone +from dataclasses import dataclass +import uuid + +logger = logging.getLogger(__name__) + + +@dataclass +class AgentEventData: + """Agent 事件数据""" + event_type: str + phase: Optional[str] = None + message: Optional[str] = None + tool_name: Optional[str] = None + tool_input: Optional[Dict[str, Any]] = None + tool_output: Optional[Dict[str, Any]] = None + tool_duration_ms: Optional[int] = None + finding_id: Optional[str] = None + tokens_used: int = 0 + metadata: Optional[Dict[str, Any]] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "event_type": self.event_type, + "phase": self.phase, + "message": self.message, + "tool_name": self.tool_name, + "tool_input": self.tool_input, + "tool_output": self.tool_output, + "tool_duration_ms": self.tool_duration_ms, + "finding_id": self.finding_id, + "tokens_used": self.tokens_used, + "metadata": self.metadata, + } + + +class AgentEventEmitter: + """ + Agent 事件发射器 + 用于在 Agent 执行过程中发射事件 + """ + + def __init__(self, task_id: str, event_manager: 'EventManager'): + self.task_id = task_id + self.event_manager = event_manager + self._sequence = 0 + self._current_phase = None + + async def emit(self, event_data: AgentEventData): + """发射事件""" + self._sequence += 1 + event_data.phase = event_data.phase or self._current_phase + + await self.event_manager.add_event( + task_id=self.task_id, + sequence=self._sequence, + **event_data.to_dict() + ) + + async def emit_phase_start(self, phase: str, message: Optional[str] = None): + """发射阶段开始事件""" + self._current_phase = phase + await self.emit(AgentEventData( + event_type="phase_start", + phase=phase, + message=message or f"开始 {phase} 阶段", + )) + + async def emit_phase_complete(self, phase: str, message: Optional[str] = None): + """发射阶段完成事件""" + await self.emit(AgentEventData( + event_type="phase_complete", + phase=phase, + message=message or f"{phase} 阶段完成", + )) + + async def emit_thinking(self, message: str, metadata: Optional[Dict] = 
None): + """发射思考事件""" + await self.emit(AgentEventData( + event_type="thinking", + message=message, + metadata=metadata, + )) + + async def emit_tool_call( + self, + tool_name: str, + tool_input: Dict[str, Any], + message: Optional[str] = None, + ): + """发射工具调用事件""" + await self.emit(AgentEventData( + event_type="tool_call", + tool_name=tool_name, + tool_input=tool_input, + message=message or f"调用工具: {tool_name}", + )) + + async def emit_tool_result( + self, + tool_name: str, + tool_output: Any, + duration_ms: int, + message: Optional[str] = None, + ): + """发射工具结果事件""" + # 处理输出,确保可序列化 + if hasattr(tool_output, 'to_dict'): + output_data = tool_output.to_dict() + elif isinstance(tool_output, str): + output_data = {"result": tool_output[:2000]} # 截断长输出 + else: + output_data = {"result": str(tool_output)[:2000]} + + await self.emit(AgentEventData( + event_type="tool_result", + tool_name=tool_name, + tool_output=output_data, + tool_duration_ms=duration_ms, + message=message or f"工具 {tool_name} 执行完成 ({duration_ms}ms)", + )) + + async def emit_finding( + self, + finding_id: str, + title: str, + severity: str, + vulnerability_type: str, + is_verified: bool = False, + ): + """发射漏洞发现事件""" + event_type = "finding_verified" if is_verified else "finding_new" + await self.emit(AgentEventData( + event_type=event_type, + finding_id=finding_id, + message=f"{'✅ 已验证' if is_verified else '🔍 新发现'}: [{severity.upper()}] {title}", + metadata={ + "title": title, + "severity": severity, + "vulnerability_type": vulnerability_type, + "is_verified": is_verified, + }, + )) + + async def emit_info(self, message: str, metadata: Optional[Dict] = None): + """发射信息事件""" + await self.emit(AgentEventData( + event_type="info", + message=message, + metadata=metadata, + )) + + async def emit_warning(self, message: str, metadata: Optional[Dict] = None): + """发射警告事件""" + await self.emit(AgentEventData( + event_type="warning", + message=message, + metadata=metadata, + )) + + async def emit_error(self, 
message: str, metadata: Optional[Dict] = None): + """发射错误事件""" + await self.emit(AgentEventData( + event_type="error", + message=message, + metadata=metadata, + )) + + async def emit_progress( + self, + current: int, + total: int, + message: Optional[str] = None, + ): + """发射进度事件""" + percentage = (current / total * 100) if total > 0 else 0 + await self.emit(AgentEventData( + event_type="progress", + message=message or f"进度: {current}/{total} ({percentage:.1f}%)", + metadata={ + "current": current, + "total": total, + "percentage": percentage, + }, + )) + + +class EventManager: + """ + 事件管理器 + 负责事件的存储和检索 + """ + + def __init__(self, db_session_factory=None): + self.db_session_factory = db_session_factory + self._event_queues: Dict[str, asyncio.Queue] = {} + self._event_callbacks: Dict[str, List[Callable]] = {} + + async def add_event( + self, + task_id: str, + event_type: str, + sequence: int = 0, + phase: Optional[str] = None, + message: Optional[str] = None, + tool_name: Optional[str] = None, + tool_input: Optional[Dict] = None, + tool_output: Optional[Dict] = None, + tool_duration_ms: Optional[int] = None, + finding_id: Optional[str] = None, + tokens_used: int = 0, + metadata: Optional[Dict] = None, + ): + """添加事件""" + event_id = str(uuid.uuid4()) + timestamp = datetime.now(timezone.utc) + + event_data = { + "id": event_id, + "task_id": task_id, + "event_type": event_type, + "sequence": sequence, + "phase": phase, + "message": message, + "tool_name": tool_name, + "tool_input": tool_input, + "tool_output": tool_output, + "tool_duration_ms": tool_duration_ms, + "finding_id": finding_id, + "tokens_used": tokens_used, + "metadata": metadata, + "timestamp": timestamp.isoformat(), + } + + # 保存到数据库 + if self.db_session_factory: + try: + await self._save_event_to_db(event_data) + except Exception as e: + logger.error(f"Failed to save event to database: {e}") + + # 推送到队列 + if task_id in self._event_queues: + await self._event_queues[task_id].put(event_data) + + # 调用回调 + 
if task_id in self._event_callbacks: + for callback in self._event_callbacks[task_id]: + try: + if asyncio.iscoroutinefunction(callback): + await callback(event_data) + else: + callback(event_data) + except Exception as e: + logger.error(f"Event callback error: {e}") + + return event_id + + async def _save_event_to_db(self, event_data: Dict): + """保存事件到数据库""" + from app.models.agent_task import AgentEvent + + async with self.db_session_factory() as db: + event = AgentEvent( + id=event_data["id"], + task_id=event_data["task_id"], + event_type=event_data["event_type"], + sequence=event_data["sequence"], + phase=event_data["phase"], + message=event_data["message"], + tool_name=event_data["tool_name"], + tool_input=event_data["tool_input"], + tool_output=event_data["tool_output"], + tool_duration_ms=event_data["tool_duration_ms"], + finding_id=event_data["finding_id"], + tokens_used=event_data["tokens_used"], + event_metadata=event_data["metadata"], + ) + db.add(event) + await db.commit() + + def create_queue(self, task_id: str) -> asyncio.Queue: + """创建事件队列""" + if task_id not in self._event_queues: + self._event_queues[task_id] = asyncio.Queue() + return self._event_queues[task_id] + + def remove_queue(self, task_id: str): + """移除事件队列""" + if task_id in self._event_queues: + del self._event_queues[task_id] + + def add_callback(self, task_id: str, callback: Callable): + """添加事件回调""" + if task_id not in self._event_callbacks: + self._event_callbacks[task_id] = [] + self._event_callbacks[task_id].append(callback) + + def remove_callback(self, task_id: str, callback: Callable): + """移除事件回调""" + if task_id in self._event_callbacks: + self._event_callbacks[task_id].remove(callback) + + async def get_events( + self, + task_id: str, + after_sequence: int = 0, + limit: int = 100, + ) -> List[Dict]: + """获取事件列表""" + if not self.db_session_factory: + return [] + + from sqlalchemy.future import select + from app.models.agent_task import AgentEvent + + async with 
self.db_session_factory() as db: + result = await db.execute( + select(AgentEvent) + .where(AgentEvent.task_id == task_id) + .where(AgentEvent.sequence > after_sequence) + .order_by(AgentEvent.sequence) + .limit(limit) + ) + events = result.scalars().all() + return [event.to_sse_dict() for event in events] + + async def stream_events( + self, + task_id: str, + after_sequence: int = 0, + ) -> AsyncGenerator[Dict, None]: + """流式获取事件""" + queue = self.create_queue(task_id) + + # 先发送历史事件 + history = await self.get_events(task_id, after_sequence) + for event in history: + yield event + + # 然后实时推送新事件 + try: + while True: + try: + event = await asyncio.wait_for(queue.get(), timeout=30) + yield event + + # 检查是否是结束事件 + if event.get("event_type") in ["task_complete", "task_error", "task_cancel"]: + break + + except asyncio.TimeoutError: + # 发送心跳 + yield {"event_type": "heartbeat", "timestamp": datetime.now(timezone.utc).isoformat()} + + finally: + self.remove_queue(task_id) + + def create_emitter(self, task_id: str) -> AgentEventEmitter: + """创建事件发射器""" + return AgentEventEmitter(task_id, self) + diff --git a/backend/app/services/agent/graph/__init__.py b/backend/app/services/agent/graph/__init__.py new file mode 100644 index 0000000..5bc17d1 --- /dev/null +++ b/backend/app/services/agent/graph/__init__.py @@ -0,0 +1,28 @@ +""" +LangGraph 工作流模块 +使用状态图构建混合 Agent 审计流程 +""" + +from .audit_graph import AuditState, create_audit_graph, create_audit_graph_with_human +from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode, HumanReviewNode +from .runner import AgentRunner, run_agent_task, LLMService + +__all__ = [ + # 状态和图 + "AuditState", + "create_audit_graph", + "create_audit_graph_with_human", + + # 节点 + "ReconNode", + "AnalysisNode", + "VerificationNode", + "ReportNode", + "HumanReviewNode", + + # Runner + "AgentRunner", + "run_agent_task", + "LLMService", +] + diff --git a/backend/app/services/agent/graph/audit_graph.py 
b/backend/app/services/agent/graph/audit_graph.py new file mode 100644 index 0000000..d53a577 --- /dev/null +++ b/backend/app/services/agent/graph/audit_graph.py @@ -0,0 +1,455 @@ +""" +DeepAudit 审计工作流图 +使用 LangGraph 构建状态机式的 Agent 协作流程 +""" + +from typing import TypedDict, Annotated, List, Dict, Any, Optional, Literal +from datetime import datetime +import operator +import logging + +from langgraph.graph import StateGraph, END +from langgraph.checkpoint.memory import MemorySaver +from langgraph.prebuilt import ToolNode + +logger = logging.getLogger(__name__) + + +# ============ 状态定义 ============ + +class Finding(TypedDict): + """漏洞发现""" + id: str + vulnerability_type: str + severity: str + title: str + description: str + file_path: Optional[str] + line_start: Optional[int] + code_snippet: Optional[str] + is_verified: bool + confidence: float + source: str + + +class AuditState(TypedDict): + """ + 审计状态 + 在整个工作流中传递和更新 + """ + # 输入 + project_root: str + project_info: Dict[str, Any] + config: Dict[str, Any] + task_id: str + + # Recon 阶段输出 + tech_stack: Dict[str, Any] + entry_points: List[Dict[str, Any]] + high_risk_areas: List[str] + dependencies: Dict[str, Any] + + # Analysis 阶段输出 + findings: Annotated[List[Finding], operator.add] # 使用 add 合并多轮发现 + + # Verification 阶段输出 + verified_findings: List[Finding] + false_positives: List[str] + + # 控制流 + current_phase: str + iteration: int + max_iterations: int + should_continue_analysis: bool + + # 消息和事件 + messages: Annotated[List[Dict], operator.add] + events: Annotated[List[Dict], operator.add] + + # 最终输出 + summary: Optional[Dict[str, Any]] + security_score: Optional[int] + error: Optional[str] + + +# ============ 路由函数 ============ + +def route_after_recon(state: AuditState) -> Literal["analysis", "end"]: + """Recon 后的路由决策""" + # 如果没有发现入口点或高风险区域,直接结束 + if not state.get("entry_points") and not state.get("high_risk_areas"): + return "end" + return "analysis" + + +def route_after_analysis(state: AuditState) -> 
Literal["verification", "analysis", "report"]: + """Analysis 后的路由决策""" + findings = state.get("findings", []) + iteration = state.get("iteration", 0) + max_iterations = state.get("max_iterations", 3) + should_continue = state.get("should_continue_analysis", False) + + # 如果没有发现,直接生成报告 + if not findings: + return "report" + + # 如果需要继续分析且未达到最大迭代 + if should_continue and iteration < max_iterations: + return "analysis" + + # 有发现需要验证 + return "verification" + + +def route_after_verification(state: AuditState) -> Literal["analysis", "report"]: + """Verification 后的路由决策""" + # 如果验证发现了误报,可能需要重新分析 + false_positives = state.get("false_positives", []) + iteration = state.get("iteration", 0) + max_iterations = state.get("max_iterations", 3) + + # 如果误报率太高且还有迭代次数,回到分析 + if len(false_positives) > len(state.get("verified_findings", [])) and iteration < max_iterations: + return "analysis" + + return "report" + + +# ============ 创建审计图 ============ + +def create_audit_graph( + recon_node, + analysis_node, + verification_node, + report_node, + checkpointer: Optional[MemorySaver] = None, +) -> StateGraph: + """ + 创建审计工作流图 + + Args: + recon_node: 信息收集节点 + analysis_node: 漏洞分析节点 + verification_node: 漏洞验证节点 + report_node: 报告生成节点 + checkpointer: 检查点存储器(用于状态持久化) + + Returns: + 编译后的 StateGraph + + 工作流结构: + + START + │ + ▼ + ┌──────┐ + │Recon │ 信息收集 + └──┬───┘ + │ + ▼ + ┌──────────┐ + │ Analysis │◄─────┐ 漏洞分析(可循环) + └────┬─────┘ │ + │ │ + ▼ │ + ┌────────────┐ │ + │Verification│────┘ 漏洞验证(可回溯) + └─────┬──────┘ + │ + ▼ + ┌──────────┐ + │ Report │ 报告生成 + └────┬─────┘ + │ + ▼ + END + """ + + # 创建状态图 + workflow = StateGraph(AuditState) + + # 添加节点 + workflow.add_node("recon", recon_node) + workflow.add_node("analysis", analysis_node) + workflow.add_node("verification", verification_node) + workflow.add_node("report", report_node) + + # 设置入口点 + workflow.set_entry_point("recon") + + # 添加条件边 + workflow.add_conditional_edges( + "recon", + route_after_recon, + { + "analysis": "analysis", + "end": END, + } 
+ ) + + workflow.add_conditional_edges( + "analysis", + route_after_analysis, + { + "verification": "verification", + "analysis": "analysis", # 循环 + "report": "report", + } + ) + + workflow.add_conditional_edges( + "verification", + route_after_verification, + { + "analysis": "analysis", # 回溯 + "report": "report", + } + ) + + # Report -> END + workflow.add_edge("report", END) + + # 编译图 + if checkpointer: + return workflow.compile(checkpointer=checkpointer) + else: + return workflow.compile() + + +# ============ 带人机协作的审计图 ============ + +def create_audit_graph_with_human( + recon_node, + analysis_node, + verification_node, + report_node, + human_review_node, + checkpointer: Optional[MemorySaver] = None, +) -> StateGraph: + """ + 创建带人机协作的审计工作流图 + + 在验证阶段后增加人工审核节点 + + 工作流结构: + + START + │ + ▼ + ┌──────┐ + │Recon │ + └──┬───┘ + │ + ▼ + ┌──────────┐ + │ Analysis │◄─────┐ + └────┬─────┘ │ + │ │ + ▼ │ + ┌────────────┐ │ + │Verification│────┘ + └─────┬──────┘ + │ + ▼ + ┌─────────────┐ + │Human Review │ ← 人工审核(可跳过) + └──────┬──────┘ + │ + ▼ + ┌──────────┐ + │ Report │ + └────┬─────┘ + │ + ▼ + END + """ + + workflow = StateGraph(AuditState) + + # 添加节点 + workflow.add_node("recon", recon_node) + workflow.add_node("analysis", analysis_node) + workflow.add_node("verification", verification_node) + workflow.add_node("human_review", human_review_node) + workflow.add_node("report", report_node) + + workflow.set_entry_point("recon") + + workflow.add_conditional_edges( + "recon", + route_after_recon, + {"analysis": "analysis", "end": END} + ) + + workflow.add_conditional_edges( + "analysis", + route_after_analysis, + { + "verification": "verification", + "analysis": "analysis", + "report": "report", + } + ) + + # Verification -> Human Review + workflow.add_edge("verification", "human_review") + + # Human Review 后的路由 + def route_after_human(state: AuditState) -> Literal["analysis", "report"]: + # 人工可以决定重新分析或继续 + if state.get("should_continue_analysis"): + return "analysis" + return 
"report" + + workflow.add_conditional_edges( + "human_review", + route_after_human, + {"analysis": "analysis", "report": "report"} + ) + + workflow.add_edge("report", END) + + if checkpointer: + return workflow.compile(checkpointer=checkpointer, interrupt_before=["human_review"]) + else: + return workflow.compile() + + +# ============ 执行器 ============ + +class AuditGraphRunner: + """ + 审计图执行器 + 封装 LangGraph 工作流的执行 + """ + + def __init__( + self, + graph: StateGraph, + event_emitter=None, + ): + self.graph = graph + self.event_emitter = event_emitter + + async def run( + self, + project_root: str, + project_info: Dict[str, Any], + config: Dict[str, Any], + task_id: str, + ) -> Dict[str, Any]: + """ + 执行审计工作流 + + Args: + project_root: 项目根目录 + project_info: 项目信息 + config: 审计配置 + task_id: 任务 ID + + Returns: + 最终状态 + """ + # 初始状态 + initial_state: AuditState = { + "project_root": project_root, + "project_info": project_info, + "config": config, + "task_id": task_id, + "tech_stack": {}, + "entry_points": [], + "high_risk_areas": [], + "dependencies": {}, + "findings": [], + "verified_findings": [], + "false_positives": [], + "current_phase": "start", + "iteration": 0, + "max_iterations": config.get("max_iterations", 3), + "should_continue_analysis": False, + "messages": [], + "events": [], + "summary": None, + "security_score": None, + "error": None, + } + + # 配置 + run_config = { + "configurable": { + "thread_id": task_id, + } + } + + # 执行图 + try: + # 流式执行 + async for event in self.graph.astream(initial_state, config=run_config): + # 发射事件 + if self.event_emitter: + for node_name, node_state in event.items(): + await self.event_emitter.emit_info( + f"节点 {node_name} 完成" + ) + + # 发射发现事件 + if node_name == "analysis" and node_state.get("findings"): + new_findings = node_state["findings"] + await self.event_emitter.emit_info( + f"发现 {len(new_findings)} 个潜在漏洞" + ) + + # 获取最终状态 + final_state = self.graph.get_state(run_config) + return final_state.values + + except Exception as 
e: + logger.error(f"Graph execution failed: {e}", exc_info=True) + raise + + async def run_with_human_review( + self, + initial_state: AuditState, + human_feedback_callback, + ) -> Dict[str, Any]: + """ + 带人机协作的执行 + + Args: + initial_state: 初始状态 + human_feedback_callback: 人工反馈回调函数 + + Returns: + 最终状态 + """ + run_config = { + "configurable": { + "thread_id": initial_state["task_id"], + } + } + + # 执行到人工审核节点 + async for event in self.graph.astream(initial_state, config=run_config): + pass + + # 获取当前状态 + current_state = self.graph.get_state(run_config) + + # 如果在人工审核节点暂停 + if current_state.next == ("human_review",): + # 调用人工反馈 + human_decision = await human_feedback_callback(current_state.values) + + # 更新状态并继续 + updated_state = { + **current_state.values, + "should_continue_analysis": human_decision.get("continue_analysis", False), + } + + # 继续执行 + async for event in self.graph.astream(updated_state, config=run_config): + pass + + # 返回最终状态 + return self.graph.get_state(run_config).values + diff --git a/backend/app/services/agent/graph/nodes.py b/backend/app/services/agent/graph/nodes.py new file mode 100644 index 0000000..bcf5825 --- /dev/null +++ b/backend/app/services/agent/graph/nodes.py @@ -0,0 +1,360 @@ +""" +LangGraph 节点实现 +每个节点封装一个 Agent 的执行逻辑 +""" + +from typing import Dict, Any, List, Optional +import logging + +logger = logging.getLogger(__name__) + +# 延迟导入避免循环依赖 +def get_audit_state_type(): + from .audit_graph import AuditState + return AuditState + + +class BaseNode: + """节点基类""" + + def __init__(self, agent=None, event_emitter=None): + self.agent = agent + self.event_emitter = event_emitter + + async def emit_event(self, event_type: str, message: str, **kwargs): + """发射事件""" + if self.event_emitter: + try: + await self.event_emitter.emit_info(message) + except Exception as e: + logger.warning(f"Failed to emit event: {e}") + + +class ReconNode(BaseNode): + """ + 信息收集节点 + + 输入: project_root, project_info, config + 输出: tech_stack, entry_points, 
high_risk_areas, dependencies + """ + + async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: + """执行信息收集""" + await self.emit_event("phase_start", "🔍 开始信息收集阶段") + + try: + # 调用 Recon Agent + result = await self.agent.run({ + "project_info": state["project_info"], + "config": state["config"], + }) + + if result.success and result.data: + data = result.data + + await self.emit_event( + "phase_complete", + f"✅ 信息收集完成: 发现 {len(data.get('entry_points', []))} 个入口点" + ) + + return { + "tech_stack": data.get("tech_stack", {}), + "entry_points": data.get("entry_points", []), + "high_risk_areas": data.get("high_risk_areas", []), + "dependencies": data.get("dependencies", {}), + "current_phase": "recon_complete", + "findings": data.get("initial_findings", []), # 初步发现 + "events": [{ + "type": "recon_complete", + "data": { + "entry_points_count": len(data.get("entry_points", [])), + "high_risk_areas_count": len(data.get("high_risk_areas", [])), + } + }], + } + else: + return { + "error": result.error or "Recon failed", + "current_phase": "error", + } + + except Exception as e: + logger.error(f"Recon node failed: {e}", exc_info=True) + return { + "error": str(e), + "current_phase": "error", + } + + +class AnalysisNode(BaseNode): + """ + 漏洞分析节点 + + 输入: tech_stack, entry_points, high_risk_areas, previous findings + 输出: findings (累加), should_continue_analysis + """ + + async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: + """执行漏洞分析""" + iteration = state.get("iteration", 0) + 1 + + await self.emit_event( + "phase_start", + f"🔬 开始漏洞分析阶段 (迭代 {iteration})" + ) + + try: + # 构建分析输入 + analysis_input = { + "phase_name": "analysis", + "project_info": state["project_info"], + "config": state["config"], + "plan": { + "high_risk_areas": state.get("high_risk_areas", []), + }, + "previous_results": { + "recon": { + "data": { + "tech_stack": state.get("tech_stack", {}), + "entry_points": state.get("entry_points", []), + "high_risk_areas": state.get("high_risk_areas", 
[]), + } + } + }, + } + + # 调用 Analysis Agent + result = await self.agent.run(analysis_input) + + if result.success and result.data: + new_findings = result.data.get("findings", []) + + # 判断是否需要继续分析 + # 如果这一轮发现了很多问题,可能还有更多 + should_continue = ( + len(new_findings) >= 5 and + iteration < state.get("max_iterations", 3) + ) + + await self.emit_event( + "phase_complete", + f"✅ 分析迭代 {iteration} 完成: 发现 {len(new_findings)} 个潜在漏洞" + ) + + return { + "findings": new_findings, # 会自动累加 + "iteration": iteration, + "should_continue_analysis": should_continue, + "current_phase": "analysis_complete", + "events": [{ + "type": "analysis_iteration", + "data": { + "iteration": iteration, + "findings_count": len(new_findings), + } + }], + } + else: + return { + "iteration": iteration, + "should_continue_analysis": False, + "current_phase": "analysis_complete", + } + + except Exception as e: + logger.error(f"Analysis node failed: {e}", exc_info=True) + return { + "error": str(e), + "should_continue_analysis": False, + "current_phase": "error", + } + + +class VerificationNode(BaseNode): + """ + 漏洞验证节点 + + 输入: findings + 输出: verified_findings, false_positives + """ + + async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: + """执行漏洞验证""" + findings = state.get("findings", []) + + if not findings: + return { + "verified_findings": [], + "false_positives": [], + "current_phase": "verification_complete", + } + + await self.emit_event( + "phase_start", + f"🔐 开始漏洞验证阶段 ({len(findings)} 个待验证)" + ) + + try: + # 构建验证输入 + verification_input = { + "previous_results": { + "analysis": { + "data": { + "findings": findings, + } + } + }, + "config": state["config"], + } + + # 调用 Verification Agent + result = await self.agent.run(verification_input) + + if result.success and result.data: + verified = [f for f in result.data.get("findings", []) if f.get("is_verified")] + false_pos = [f["id"] for f in result.data.get("findings", []) + if f.get("verdict") == "false_positive"] + + await 
self.emit_event( + "phase_complete", + f"✅ 验证完成: {len(verified)} 已确认, {len(false_pos)} 误报" + ) + + return { + "verified_findings": verified, + "false_positives": false_pos, + "current_phase": "verification_complete", + "events": [{ + "type": "verification_complete", + "data": { + "verified_count": len(verified), + "false_positive_count": len(false_pos), + } + }], + } + else: + return { + "verified_findings": [], + "false_positives": [], + "current_phase": "verification_complete", + } + + except Exception as e: + logger.error(f"Verification node failed: {e}", exc_info=True) + return { + "error": str(e), + "current_phase": "error", + } + + +class ReportNode(BaseNode): + """ + 报告生成节点 + + 输入: all state + 输出: summary, security_score + """ + + async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: + """生成审计报告""" + await self.emit_event("phase_start", "📊 生成审计报告") + + try: + findings = state.get("findings", []) + verified = state.get("verified_findings", []) + false_positives = state.get("false_positives", []) + + # 统计漏洞分布 + severity_counts = {"critical": 0, "high": 0, "medium": 0, "low": 0} + type_counts = {} + + for finding in findings: + sev = finding.get("severity", "medium") + severity_counts[sev] = severity_counts.get(sev, 0) + 1 + + vtype = finding.get("vulnerability_type", "other") + type_counts[vtype] = type_counts.get(vtype, 0) + 1 + + # 计算安全评分 + base_score = 100 + deductions = ( + severity_counts["critical"] * 25 + + severity_counts["high"] * 15 + + severity_counts["medium"] * 8 + + severity_counts["low"] * 3 + ) + security_score = max(0, base_score - deductions) + + # 生成摘要 + summary = { + "total_findings": len(findings), + "verified_count": len(verified), + "false_positive_count": len(false_positives), + "severity_distribution": severity_counts, + "vulnerability_types": type_counts, + "tech_stack": state.get("tech_stack", {}), + "entry_points_analyzed": len(state.get("entry_points", [])), + "high_risk_areas": state.get("high_risk_areas", []), + 
"iterations": state.get("iteration", 1), + } + + await self.emit_event( + "phase_complete", + f"✅ 报告生成完成: 安全评分 {security_score}/100" + ) + + return { + "summary": summary, + "security_score": security_score, + "current_phase": "complete", + "events": [{ + "type": "audit_complete", + "data": { + "security_score": security_score, + "total_findings": len(findings), + "verified_count": len(verified), + } + }], + } + + except Exception as e: + logger.error(f"Report node failed: {e}", exc_info=True) + return { + "error": str(e), + "current_phase": "error", + } + + +class HumanReviewNode(BaseNode): + """ + 人工审核节点 + + 在此节点暂停,等待人工反馈 + """ + + async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]: + """ + 人工审核节点 + + 这个节点会被 interrupt_before 暂停 + 用户可以: + 1. 确认发现 + 2. 标记误报 + 3. 请求重新分析 + """ + await self.emit_event( + "human_review", + f"⏸️ 等待人工审核 ({len(state.get('verified_findings', []))} 个待确认)" + ) + + # 返回当前状态,不做修改 + # 人工反馈会通过 update_state 传入 + return { + "current_phase": "human_review", + "messages": [{ + "role": "system", + "content": "等待人工审核", + "findings_for_review": state.get("verified_findings", []), + }], + } + diff --git a/backend/app/services/agent/graph/runner.py b/backend/app/services/agent/graph/runner.py new file mode 100644 index 0000000..dff9336 --- /dev/null +++ b/backend/app/services/agent/graph/runner.py @@ -0,0 +1,621 @@ +""" +DeepAudit LangGraph Runner +基于 LangGraph 的 Agent 审计执行器 +""" + +import asyncio +import logging +import os +import uuid +from datetime import datetime, timezone +from typing import Dict, List, Optional, Any, AsyncGenerator + +from sqlalchemy.ext.asyncio import AsyncSession + +from langgraph.graph import StateGraph, END +from langgraph.checkpoint.memory import MemorySaver + +from app.models.agent_task import ( + AgentTask, AgentEvent, AgentFinding, + AgentTaskStatus, AgentTaskPhase, AgentEventType, + VulnerabilitySeverity, VulnerabilityType, FindingStatus, +) +from app.services.agent.event_manager import EventManager, 
AgentEventEmitter +from app.services.agent.tools import ( + RAGQueryTool, SecurityCodeSearchTool, FunctionContextTool, + PatternMatchTool, CodeAnalysisTool, DataFlowAnalysisTool, VulnerabilityValidationTool, + FileReadTool, FileSearchTool, ListFilesTool, + SandboxTool, SandboxHttpTool, VulnerabilityVerifyTool, SandboxManager, + SemgrepTool, BanditTool, GitleaksTool, NpmAuditTool, SafetyTool, + TruffleHogTool, OSVScannerTool, +) +from app.services.rag import CodeIndexer, CodeRetriever, EmbeddingService +from app.core.config import settings + +from .audit_graph import AuditState, create_audit_graph +from .nodes import ReconNode, AnalysisNode, VerificationNode, ReportNode + +logger = logging.getLogger(__name__) + + +class LLMService: + """LLM 服务封装""" + + def __init__(self, model: Optional[str] = None, api_key: Optional[str] = None): + self.model = model or settings.DEFAULT_LLM_MODEL + self.api_key = api_key or settings.LLM_API_KEY + + async def chat_completion_raw( + self, + messages: List[Dict[str, str]], + temperature: float = 0.1, + max_tokens: int = 4096, + ) -> Dict[str, Any]: + """调用 LLM 生成响应""" + try: + import litellm + + response = await litellm.acompletion( + model=self.model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + api_key=self.api_key, + ) + + return { + "content": response.choices[0].message.content, + "usage": { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + } if response.usage else {}, + } + + except Exception as e: + logger.error(f"LLM call failed: {e}") + raise + + +class AgentRunner: + """ + DeepAudit LangGraph Agent Runner + + 基于 LangGraph 状态图的审计执行器 + + 工作流: + START → Recon → Analysis ⟲ → Verification → Report → END + """ + + def __init__( + self, + db: AsyncSession, + task: AgentTask, + project_root: str, + ): + self.db = db + self.task = task + self.project_root = project_root + + # 事件管理 + self.event_manager 
= EventManager(db, task.id) + # NOTE(review): AgentEventEmitter takes (task_id, manager) — see EventManager.create_emitter; construct via the factory. + self.event_emitter = self.event_manager.create_emitter(task.id) + + # LLM 服务 + self.llm_service = LLMService() + + # 工具集 + self.tools: Dict[str, Any] = {} + + # RAG 组件 + self.retriever: Optional[CodeRetriever] = None + self.indexer: Optional[CodeIndexer] = None + + # 沙箱 + self.sandbox_manager: Optional[SandboxManager] = None + + # LangGraph + self.graph: Optional[StateGraph] = None + self.checkpointer = MemorySaver() + + # 状态 + self._cancelled = False + + async def initialize(self): + """初始化 Runner""" + await self.event_emitter.emit_info("🚀 正在初始化 DeepAudit LangGraph Agent...") + + # 1. 初始化 RAG 系统 + await self._initialize_rag() + + # 2. 初始化工具 + await self._initialize_tools() + + # 3. 构建 LangGraph + await self._build_graph() + + await self.event_emitter.emit_info("✅ LangGraph 系统初始化完成") + + async def _initialize_rag(self): + """初始化 RAG 系统""" + await self.event_emitter.emit_info("📚 初始化 RAG 代码检索系统...") + + try: + embedding_service = EmbeddingService( + provider=settings.EMBEDDING_PROVIDER, + model=settings.EMBEDDING_MODEL, + api_key=settings.LLM_API_KEY, + base_url=settings.LLM_BASE_URL, + ) + + self.indexer = CodeIndexer( + embedding_service=embedding_service, + vector_db_path=settings.VECTOR_DB_PATH, + collection_name=f"project_{self.task.project_id}", + ) + + self.retriever = CodeRetriever( + embedding_service=embedding_service, + vector_db_path=settings.VECTOR_DB_PATH, + collection_name=f"project_{self.task.project_id}", + ) + + except Exception as e: + logger.warning(f"RAG initialization failed: {e}") + await self.event_emitter.emit_warning(f"RAG 系统初始化失败: {e}") + + async def _initialize_tools(self): + """初始化工具集""" + await self.event_emitter.emit_info("🔧 初始化 Agent 工具集...") + + # 文件工具 + self.tools["read_file"] = FileReadTool(self.project_root) + self.tools["search_code"] = FileSearchTool(self.project_root) + self.tools["list_files"] = ListFilesTool(self.project_root) + + # RAG 工具 + if self.retriever: + self.tools["rag_query"] =
RAGQueryTool(self.retriever) + self.tools["security_search"] = SecurityCodeSearchTool(self.retriever) + self.tools["function_context"] = FunctionContextTool(self.retriever) + + # 分析工具 + self.tools["pattern_match"] = PatternMatchTool(self.project_root) + self.tools["code_analysis"] = CodeAnalysisTool(self.llm_service) + self.tools["dataflow_analysis"] = DataFlowAnalysisTool(self.llm_service) + self.tools["vulnerability_validation"] = VulnerabilityValidationTool(self.llm_service) + + # 外部安全工具 + self.tools["semgrep_scan"] = SemgrepTool(self.project_root) + self.tools["bandit_scan"] = BanditTool(self.project_root) + self.tools["gitleaks_scan"] = GitleaksTool(self.project_root) + self.tools["trufflehog_scan"] = TruffleHogTool(self.project_root) + self.tools["npm_audit"] = NpmAuditTool(self.project_root) + self.tools["safety_scan"] = SafetyTool(self.project_root) + self.tools["osv_scan"] = OSVScannerTool(self.project_root) + + # 沙箱工具 + try: + self.sandbox_manager = SandboxManager( + image=settings.SANDBOX_IMAGE, + memory_limit=settings.SANDBOX_MEMORY_LIMIT, + cpu_limit=settings.SANDBOX_CPU_LIMIT, + ) + + self.tools["sandbox_exec"] = SandboxTool(self.sandbox_manager) + self.tools["sandbox_http"] = SandboxHttpTool(self.sandbox_manager) + self.tools["verify_vulnerability"] = VulnerabilityVerifyTool(self.sandbox_manager) + + except Exception as e: + logger.warning(f"Sandbox initialization failed: {e}") + + await self.event_emitter.emit_info(f"✅ 已加载 {len(self.tools)} 个工具") + + async def _build_graph(self): + """构建 LangGraph 审计图""" + await self.event_emitter.emit_info("📊 构建 LangGraph 审计工作流...") + + # 导入 Agent + from app.services.agent.agents import ReconAgent, AnalysisAgent, VerificationAgent + + # 创建 Agent 实例 + recon_agent = ReconAgent( + llm_service=self.llm_service, + tools=self.tools, + event_emitter=self.event_emitter, + ) + + analysis_agent = AnalysisAgent( + llm_service=self.llm_service, + tools=self.tools, + event_emitter=self.event_emitter, + ) + + verification_agent 
= VerificationAgent( + llm_service=self.llm_service, + tools=self.tools, + event_emitter=self.event_emitter, + ) + + # 创建节点 + recon_node = ReconNode(recon_agent, self.event_emitter) + analysis_node = AnalysisNode(analysis_agent, self.event_emitter) + verification_node = VerificationNode(verification_agent, self.event_emitter) + report_node = ReportNode(None, self.event_emitter) + + # 构建图 + self.graph = create_audit_graph( + recon_node=recon_node, + analysis_node=analysis_node, + verification_node=verification_node, + report_node=report_node, + checkpointer=self.checkpointer, + ) + + await self.event_emitter.emit_info("✅ LangGraph 工作流构建完成") + + async def run(self) -> Dict[str, Any]: + """ + 执行 LangGraph 审计 + + Returns: + 最终状态 + """ + import time + start_time = time.time() + + try: + # 初始化 + await self.initialize() + + # 更新任务状态 + await self._update_task_status(AgentTaskStatus.RUNNING) + + # 1. 索引代码 + await self._index_code() + + if self._cancelled: + return {"success": False, "error": "任务已取消"} + + # 2. 收集项目信息 + project_info = await self._collect_project_info() + + # 3. 构建初始状态 + initial_state: AuditState = { + "project_root": self.project_root, + "project_info": project_info, + "config": self.task.config or {}, + "task_id": self.task.id, + "tech_stack": {}, + "entry_points": [], + "high_risk_areas": [], + "dependencies": {}, + "findings": [], + "verified_findings": [], + "false_positives": [], + "current_phase": "start", + "iteration": 0, + "max_iterations": (self.task.config or {}).get("max_iterations", 3), + "should_continue_analysis": False, + "messages": [], + "events": [], + "summary": None, + "security_score": None, + "error": None, + } + + # 4. 
执行 LangGraph + await self.event_emitter.emit_phase_start("langgraph", "🔄 启动 LangGraph 工作流") + + run_config = { + "configurable": { + "thread_id": self.task.id, + } + } + + final_state = None + + # 流式执行并发射事件 + async for event in self.graph.astream(initial_state, config=run_config): + if self._cancelled: + break + + # 处理每个节点的输出 + for node_name, node_output in event.items(): + await self._handle_node_output(node_name, node_output) + + # 更新阶段 + phase_map = { + "recon": AgentTaskPhase.RECONNAISSANCE, + "analysis": AgentTaskPhase.ANALYSIS, + "verification": AgentTaskPhase.VERIFICATION, + "report": AgentTaskPhase.REPORTING, + } + if node_name in phase_map: + await self._update_task_phase(phase_map[node_name]) + + final_state = node_output + + # 5. Fetch the accumulated state from the checkpointer: astream() with the + # default stream_mode="updates" yields only per-node delta dicts, so after the + # loop final_state is just the report node's output and has no "findings" key. + graph_state = self.graph.get_state(run_config) + if graph_state and graph_state.values: + final_state = graph_state.values + elif not final_state: + final_state = {} + + # 6. 保存发现 + findings = final_state.get("findings", []) + await self._save_findings(findings) + + # 7. 更新任务摘要 + summary = final_state.get("summary", {}) + security_score = final_state.get("security_score", 100) + + await self._update_task_summary( + total_findings=len(findings), + verified_count=len(final_state.get("verified_findings", [])), + security_score=security_score, + ) + + # 8. 
完成 + duration_ms = int((time.time() - start_time) * 1000) + + await self._update_task_status(AgentTaskStatus.COMPLETED) + await self.event_emitter.emit_task_complete( + findings_count=len(findings), + duration_ms=duration_ms, + ) + + return { + "success": True, + "data": { + "findings": findings, + "verified_findings": final_state.get("verified_findings", []), + "summary": summary, + "security_score": security_score, + }, + "duration_ms": duration_ms, + } + + except asyncio.CancelledError: + await self._update_task_status(AgentTaskStatus.CANCELLED) + return {"success": False, "error": "任务已取消"} + + except Exception as e: + logger.error(f"LangGraph run failed: {e}", exc_info=True) + await self._update_task_status(AgentTaskStatus.FAILED, str(e)) + await self.event_emitter.emit_error(str(e)) + return {"success": False, "error": str(e)} + + finally: + await self._cleanup() + + async def _handle_node_output(self, node_name: str, output: Dict[str, Any]): + """处理节点输出""" + # 发射节点事件 + events = output.get("events", []) + for evt in events: + await self.event_emitter.emit_info( + f"[{node_name}] {evt.get('type', 'event')}: {evt.get('data', {})}" + ) + + # 处理新发现 + if node_name == "analysis": + new_findings = output.get("findings", []) + if new_findings: + for finding in new_findings[:5]: # 限制事件数量 + await self.event_emitter.emit_finding( + title=finding.get("title", "Unknown"), + severity=finding.get("severity", "medium"), + file_path=finding.get("file_path"), + ) + + # 处理验证结果 + if node_name == "verification": + verified = output.get("verified_findings", []) + for v in verified[:5]: + await self.event_emitter.emit_info( + f"✅ 已验证: {v.get('title', 'Unknown')}" + ) + + # 处理错误 + if output.get("error"): + await self.event_emitter.emit_error(output["error"]) + + async def _index_code(self): + """索引代码""" + if not self.indexer: + await self.event_emitter.emit_warning("RAG 未初始化,跳过代码索引") + return + + await self._update_task_phase(AgentTaskPhase.INDEXING) + await 
self.event_emitter.emit_phase_start("indexing", "📝 开始代码索引") + + try: + async for progress in self.indexer.index_directory(self.project_root): + if self._cancelled: + return + + await self.event_emitter.emit_progress( + progress.processed / max(progress.total, 1) * 100, + f"正在索引: {progress.current_file or 'N/A'}" + ) + + await self.event_emitter.emit_phase_complete("indexing", "✅ 代码索引完成") + + except Exception as e: + logger.warning(f"Code indexing failed: {e}") + await self.event_emitter.emit_warning(f"代码索引失败: {e}") + + async def _collect_project_info(self) -> Dict[str, Any]: + """收集项目信息""" + info = { + "name": self.task.project.name if self.task.project else "unknown", + "root": self.project_root, + "languages": [], + "file_count": 0, + } + + try: + exclude_dirs = { + "node_modules", "__pycache__", ".git", "venv", ".venv", + "build", "dist", "target", ".idea", ".vscode", + } + + for root, dirs, files in os.walk(self.project_root): + dirs[:] = [d for d in dirs if d not in exclude_dirs] + info["file_count"] += len(files) + + lang_map = { + ".py": "Python", ".js": "JavaScript", ".ts": "TypeScript", + ".java": "Java", ".go": "Go", ".php": "PHP", + ".rb": "Ruby", ".rs": "Rust", ".c": "C", ".cpp": "C++", + } + + for f in files: + ext = os.path.splitext(f)[1].lower() + if ext in lang_map and lang_map[ext] not in info["languages"]: + info["languages"].append(lang_map[ext]) + + except Exception as e: + logger.warning(f"Failed to collect project info: {e}") + + return info + + async def _save_findings(self, findings: List[Dict]): + """保存发现到数据库""" + severity_map = { + "critical": VulnerabilitySeverity.CRITICAL, + "high": VulnerabilitySeverity.HIGH, + "medium": VulnerabilitySeverity.MEDIUM, + "low": VulnerabilitySeverity.LOW, + "info": VulnerabilitySeverity.INFO, + } + + type_map = { + "sql_injection": VulnerabilityType.SQL_INJECTION, + "xss": VulnerabilityType.XSS, + "command_injection": VulnerabilityType.COMMAND_INJECTION, + "path_traversal": 
VulnerabilityType.PATH_TRAVERSAL, + "ssrf": VulnerabilityType.SSRF, + "hardcoded_secret": VulnerabilityType.HARDCODED_SECRET, + "deserialization": VulnerabilityType.INSECURE_DESERIALIZATION, + "weak_crypto": VulnerabilityType.WEAK_CRYPTO, + } + + for finding in findings: + try: + db_finding = AgentFinding( + id=str(uuid.uuid4()), + task_id=self.task.id, + vulnerability_type=type_map.get( + finding.get("vulnerability_type", "other"), + VulnerabilityType.OTHER + ), + severity=severity_map.get( + finding.get("severity", "medium"), + VulnerabilitySeverity.MEDIUM + ), + title=finding.get("title", "Unknown"), + description=finding.get("description", ""), + file_path=finding.get("file_path"), + line_start=finding.get("line_start"), + line_end=finding.get("line_end"), + code_snippet=finding.get("code_snippet"), + source=finding.get("source"), + sink=finding.get("sink"), + suggestion=finding.get("suggestion") or finding.get("recommendation"), + is_verified=finding.get("is_verified", False), + confidence=finding.get("confidence", 0.5), + poc=finding.get("poc"), + status=FindingStatus.VERIFIED if finding.get("is_verified") else FindingStatus.OPEN, + ) + + self.db.add(db_finding) + + except Exception as e: + logger.warning(f"Failed to save finding: {e}") + + try: + await self.db.commit() + except Exception as e: + logger.error(f"Failed to commit findings: {e}") + await self.db.rollback() + + async def _update_task_status( + self, + status: AgentTaskStatus, + error: Optional[str] = None + ): + """更新任务状态""" + self.task.status = status + + if status == AgentTaskStatus.RUNNING: + self.task.started_at = datetime.now(timezone.utc) + elif status in [AgentTaskStatus.COMPLETED, AgentTaskStatus.FAILED, AgentTaskStatus.CANCELLED]: + self.task.finished_at = datetime.now(timezone.utc) + + if error: + self.task.error_message = error + + try: + await self.db.commit() + except Exception as e: + logger.error(f"Failed to update task status: {e}") + + async def _update_task_phase(self, phase: 
AgentTaskPhase): + """更新任务阶段""" + self.task.current_phase = phase + try: + await self.db.commit() + except Exception as e: + logger.error(f"Failed to update task phase: {e}") + + async def _update_task_summary( + self, + total_findings: int, + verified_count: int, + security_score: int, + ): + """更新任务摘要""" + self.task.total_findings = total_findings + self.task.verified_findings = verified_count + self.task.security_score = security_score + + try: + await self.db.commit() + except Exception as e: + logger.error(f"Failed to update task summary: {e}") + + async def _cleanup(self): + """清理资源""" + try: + if self.sandbox_manager: + await self.sandbox_manager.cleanup() + await self.event_manager.close() + except Exception as e: + logger.warning(f"Cleanup error: {e}") + + def cancel(self): + """取消任务""" + self._cancelled = True + + +# 便捷函数 +async def run_agent_task( + db: AsyncSession, + task: AgentTask, + project_root: str, +) -> Dict[str, Any]: + """ + 运行 Agent 审计任务 + + Args: + db: 数据库会话 + task: Agent 任务 + project_root: 项目根目录 + + Returns: + 审计结果 + """ + runner = AgentRunner(db, task, project_root) + return await runner.run() + diff --git a/backend/app/services/agent/prompts/__init__.py b/backend/app/services/agent/prompts/__init__.py new file mode 100644 index 0000000..85c8ddf --- /dev/null +++ b/backend/app/services/agent/prompts/__init__.py @@ -0,0 +1,20 @@ +""" +Agent 提示词模块 +""" + +from .system_prompts import ( + ORCHESTRATOR_SYSTEM_PROMPT, + ANALYSIS_SYSTEM_PROMPT, + VERIFICATION_SYSTEM_PROMPT, + PLANNING_PROMPT, + REPORTING_PROMPT, +) + +__all__ = [ + "ORCHESTRATOR_SYSTEM_PROMPT", + "ANALYSIS_SYSTEM_PROMPT", + "VERIFICATION_SYSTEM_PROMPT", + "PLANNING_PROMPT", + "REPORTING_PROMPT", +] + diff --git a/backend/app/services/agent/prompts/system_prompts.py b/backend/app/services/agent/prompts/system_prompts.py new file mode 100644 index 0000000..f44321b --- /dev/null +++ b/backend/app/services/agent/prompts/system_prompts.py @@ -0,0 +1,170 @@ +""" +Agent 系统提示词 +""" + +# 
编排 Agent 系统提示词 +ORCHESTRATOR_SYSTEM_PROMPT = """你是一个专业的代码安全审计 Agent,负责自主分析代码并发现安全漏洞。 + +## 你的职责 +1. 分析项目代码,制定审计计划 +2. 使用工具深入分析代码 +3. 发现并验证安全漏洞 +4. 生成详细的漏洞报告 + +## 审计流程 +1. **规划阶段**: 分析项目结构,识别高风险区域,制定审计计划 +2. **索引阶段**: 等待代码索引完成 +3. **分析阶段**: 使用工具进行深度代码分析 +4. **验证阶段**: 在沙箱中验证发现的漏洞 +5. **报告阶段**: 整理发现,生成报告 + +## 重点关注的漏洞类型 +- SQL 注入(包括 ORM 注入) +- XSS 跨站脚本(反射型、存储型、DOM型) +- 命令注入和代码注入 +- 路径遍历和任意文件访问 +- SSRF 服务端请求伪造 +- XXE XML 外部实体注入 +- 不安全的反序列化 +- 认证和授权绕过 +- 敏感信息泄露(硬编码密钥、日志泄露) +- 业务逻辑漏洞 +- IDOR 不安全的直接对象引用 + +## 分析方法 +1. **快速扫描**: 首先使用 pattern_match 快速发现可疑代码 +2. **语义搜索**: 使用 rag_query 查找相关上下文 +3. **深度分析**: 对可疑代码使用 code_analysis 深入分析 +4. **数据流追踪**: 追踪用户输入到危险函数的路径 +5. **漏洞验证**: 在沙箱中验证发现的漏洞 + +## 工作原则 +- 系统性: 不遗漏任何可能的攻击面 +- 精准性: 减少误报,每个发现都要有充分证据 +- 深入性: 不只看表面,要理解代码逻辑 +- 可操作性: 提供具体的修复建议 + +## 输出要求 +发现漏洞时,提供: +- 漏洞类型和严重程度 +- 具体位置(文件、行号) +- 漏洞描述和成因 +- 利用方式和影响 +- 修复建议和示例代码 + +请开始审计工作,使用可用的工具进行分析。""" + +# 分析 Agent 系统提示词 +ANALYSIS_SYSTEM_PROMPT = """你是一个专注于代码漏洞分析的安全专家。 + +## 你的任务 +深入分析代码,发现安全漏洞。你需要: +1. 识别危险的代码模式 +2. 追踪数据流(从用户输入到危险函数) +3. 判断漏洞是否可利用 +4. 评估漏洞的严重程度 + +## 可用工具 +- rag_query: 语义搜索相关代码 +- pattern_match: 快速模式匹配 +- code_analysis: LLM 深度分析 +- read_file: 读取文件内容 +- search_code: 关键字搜索 +- dataflow_analysis: 数据流分析 +- vulnerability_validation: 漏洞验证 + +## 分析策略 +1. 先全局后局部:先了解整体架构,再深入细节 +2. 先快后深:先快速扫描,再深入可疑点 +3. 追踪数据流:用户输入 → 处理逻辑 → 危险函数 +4. 验证每个发现:确保不是误报 + +## 严重程度评估标准 +- **Critical**: 可直接导致系统被控制或大规模数据泄露 +- **High**: 可导致敏感数据泄露或重要功能被绕过 +- **Medium**: 可导致部分数据泄露或需要特定条件利用 +- **Low**: 影响有限或利用条件苛刻 + +请开始分析,专注于发现真实的安全漏洞。""" + +# 验证 Agent 系统提示词 +VERIFICATION_SYSTEM_PROMPT = """你是一个专注于漏洞验证的安全专家。 + +## 你的任务 +验证发现的漏洞是否真实存在,判断是否为误报。 + +## 验证方法 +1. **代码审查**: 仔细分析漏洞代码和上下文 +2. **构造 Payload**: 设计能触发漏洞的输入 +3. **沙箱测试**: 在隔离环境中测试漏洞 +4. 
**分析结果**: 判断漏洞是否可利用 + +## 可用工具 +- sandbox_exec: 在沙箱中执行命令 +- sandbox_http: 发送 HTTP 请求 +- verify_vulnerability: 自动验证漏洞 +- vulnerability_validation: 深度验证分析 + +## 验证原则 +- 安全第一:所有测试在沙箱中进行 +- 证据充分:验证结果要有明确证据 +- 谨慎判断:不确定时标记为需要人工审核 + +## 输出要求 +- 验证结果:确认/可能/误报 +- 验证方法:使用的测试方法 +- 证据:支持判断的具体证据 +- PoC(如果确认):可复现的测试代码 + +请开始验证工作。""" + +# 规划提示词 +PLANNING_PROMPT = """基于以下项目信息,制定安全审计计划。 + +## 项目信息 +- 名称: {project_name} +- 语言: {languages} +- 文件数量: {file_count} +- 目录结构: {directory_structure} + +## 请输出审计计划 +包含以下内容(JSON格式): +```json +{{ + "high_risk_areas": ["高风险目录/文件列表"], + "focus_vulnerabilities": ["重点关注的漏洞类型"], + "audit_order": ["审计顺序"], + "estimated_steps": "预计步骤数", + "special_attention": ["特别注意事项"] +}} +``` + +## 高风险区域识别原则 +1. 用户认证和授权相关代码 +2. 数据库操作和 ORM 使用 +3. 文件上传和下载功能 +4. API 接口和输入处理 +5. 第三方服务调用 +6. 配置文件和环境变量 +7. 加密和密钥管理""" + +# 报告生成提示词 +REPORTING_PROMPT = """基于审计发现,生成安全审计报告摘要。 + +## 审计发现 +{findings} + +## 统计信息 +- 总发现数: {total_findings} +- 已验证: {verified_count} +- 严重程度分布: {severity_distribution} + +## 请输出报告摘要 +包含以下内容: +1. 整体安全评估 +2. 主要风险点 +3. 优先修复建议 +4. 
安全改进建议 + +请用简洁专业的语言描述。""" + diff --git a/backend/app/services/agent/tools/__init__.py b/backend/app/services/agent/tools/__init__.py new file mode 100644 index 0000000..1ebf28f --- /dev/null +++ b/backend/app/services/agent/tools/__init__.py @@ -0,0 +1,61 @@ +""" +Agent 工具集 +提供 LangChain Agent 使用的各种工具 +包括内置工具和外部安全工具 +""" + +from .base import AgentTool, ToolResult +from .rag_tool import RAGQueryTool, SecurityCodeSearchTool, FunctionContextTool +from .pattern_tool import PatternMatchTool +from .code_analysis_tool import CodeAnalysisTool, DataFlowAnalysisTool, VulnerabilityValidationTool +from .file_tool import FileReadTool, FileSearchTool, ListFilesTool +from .sandbox_tool import SandboxTool, SandboxHttpTool, VulnerabilityVerifyTool, SandboxManager + +# 外部安全工具 +from .external_tools import ( + SemgrepTool, + BanditTool, + GitleaksTool, + NpmAuditTool, + SafetyTool, + TruffleHogTool, + OSVScannerTool, +) + +__all__ = [ + # 基础 + "AgentTool", + "ToolResult", + + # RAG 工具 + "RAGQueryTool", + "SecurityCodeSearchTool", + "FunctionContextTool", + + # 代码分析 + "PatternMatchTool", + "CodeAnalysisTool", + "DataFlowAnalysisTool", + "VulnerabilityValidationTool", + + # 文件操作 + "FileReadTool", + "FileSearchTool", + "ListFilesTool", + + # 沙箱 + "SandboxTool", + "SandboxHttpTool", + "VulnerabilityVerifyTool", + "SandboxManager", + + # 外部安全工具 + "SemgrepTool", + "BanditTool", + "GitleaksTool", + "NpmAuditTool", + "SafetyTool", + "TruffleHogTool", + "OSVScannerTool", +] + diff --git a/backend/app/services/agent/tools/base.py b/backend/app/services/agent/tools/base.py new file mode 100644 index 0000000..e0b8a10 --- /dev/null +++ b/backend/app/services/agent/tools/base.py @@ -0,0 +1,156 @@ +""" +Agent 工具基类 +""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Type +from dataclasses import dataclass, field +from pydantic import BaseModel +import logging +import time + +logger = logging.getLogger(__name__) + + +@dataclass +class ToolResult: + """工具执行结果""" + 
# Module logger (re-declared so the block is self-contained).
logger = logging.getLogger(__name__)


@dataclass
class ToolResult:
    """Outcome of a single tool invocation."""
    # Whether the tool ran without error.
    success: bool
    # Tool payload: str, dict, list, or any printable object.
    data: Any = None
    # Human-readable error message when success is False.
    error: Optional[str] = None
    # Wall-clock execution time, filled in by AgentTool.execute().
    duration_ms: int = 0
    # Free-form extra information for downstream consumers.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the result to a plain dict."""
        return {
            "success": self.success,
            "data": self.data,
            "error": self.error,
            "duration_ms": self.duration_ms,
            "metadata": self.metadata,
        }

    def to_string(self, max_length: int = 5000) -> str:
        """Render the result as text suitable for feeding back to an LLM.

        Dict/list payloads are JSON-encoded; other objects are str()-ed.
        Output longer than ``max_length`` is truncated with a marker so a
        single tool call cannot blow the model's context window.
        """
        if not self.success:
            return f"Error: {self.error}"

        if isinstance(self.data, str):
            result = self.data
        elif isinstance(self.data, (dict, list)):
            import json
            result = json.dumps(self.data, ensure_ascii=False, indent=2)
        else:
            result = str(self.data)

        if len(result) > max_length:
            result = result[:max_length] + f"\n... (truncated, total {len(result)} chars)"

        return result


class AgentTool(ABC):
    """
    Base class for agent tools.

    Subclasses implement ``name``, ``description`` and ``_execute``;
    ``execute`` wraps ``_execute`` with timing, logging and error capture
    so callers always receive a ToolResult and never an exception.
    """

    def __init__(self):
        # Per-instance usage counters, exposed via ``stats``.
        self._call_count = 0
        self._total_duration_ms = 0

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique tool name used by the agent to select the tool."""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Tool description shown to the agent (what it does, when to use it)."""
        pass

    @property
    def args_schema(self) -> Optional[Type["BaseModel"]]:
        """Pydantic model describing the arguments.

        None means the tool takes a single free-form string.  The forward
        reference keeps this module importable even before pydantic is
        resolved at annotation-evaluation time.
        """
        return None

    @abstractmethod
    async def _execute(self, **kwargs) -> ToolResult:
        """Tool implementation (provided by subclasses)."""
        pass

    async def execute(self, **kwargs) -> ToolResult:
        """Run the tool with timing, logging and error capture.

        Never raises: any exception from ``_execute`` is converted into a
        failed ToolResult so one misbehaving tool cannot abort the agent loop.
        """
        start_time = time.time()

        try:
            logger.debug(f"Tool '{self.name}' executing with args: {kwargs}")
            result = await self._execute(**kwargs)

        except Exception as e:
            logger.error(f"Tool '{self.name}' error: {e}", exc_info=True)
            result = ToolResult(
                success=False,
                error=str(e),
            )

        duration_ms = int((time.time() - start_time) * 1000)
        result.duration_ms = duration_ms

        self._call_count += 1
        self._total_duration_ms += duration_ms

        logger.debug(f"Tool '{self.name}' completed in {duration_ms}ms, success={result.success}")

        return result

    def get_langchain_tool(self):
        """Adapt this tool to a LangChain Tool/StructuredTool instance."""
        from langchain.tools import Tool, StructuredTool
        import asyncio

        def sync_wrapper(**kwargs):
            """Sync entry point for LangChain's non-async callers."""
            # asyncio.get_event_loop() is deprecated outside a running loop;
            # probe for a running loop explicitly instead.
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                # No loop in this thread: safe to drive one directly.
                return asyncio.run(self.execute(**kwargs)).to_string()
            # A loop is already running (e.g. inside an async web handler):
            # run the coroutine on a private loop in a worker thread.
            import concurrent.futures
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                result = executor.submit(asyncio.run, self.execute(**kwargs)).result()
            return result.to_string()

        async def async_wrapper(**kwargs):
            """Async entry point for LangChain's async callers."""
            result = await self.execute(**kwargs)
            return result.to_string()

        if self.args_schema:
            return StructuredTool(
                name=self.name,
                description=self.description,
                func=sync_wrapper,
                coroutine=async_wrapper,
                args_schema=self.args_schema,
            )
        # Schema-less tools receive one positional string, forwarded as 'query'.
        return Tool(
            name=self.name,
            description=self.description,
            func=lambda x: sync_wrapper(query=x),
            coroutine=lambda x: async_wrapper(query=x),
        )

    @property
    def stats(self) -> Dict[str, Any]:
        """Usage counters for this tool instance."""
        return {
            "name": self.name,
            "call_count": self._call_count,
            "total_duration_ms": self._total_duration_ms,
            "avg_duration_ms": self._total_duration_ms // max(1, self._call_count),
        }
command_injection" + ) + context: Optional[str] = Field( + default=None, + description="额外的上下文信息,如相关的其他代码片段" + ) + + +class CodeAnalysisTool(AgentTool): + """ + 代码分析工具 + 使用 LLM 对代码进行深度安全分析 + """ + + def __init__(self, llm_service): + """ + 初始化代码分析工具 + + Args: + llm_service: LLM 服务实例 + """ + super().__init__() + self.llm_service = llm_service + + @property + def name(self) -> str: + return "code_analysis" + + @property + def description(self) -> str: + return """深度分析代码安全问题。 +使用 LLM 对代码进行全面的安全审计,识别潜在漏洞。 + +使用场景: +- 对疑似有问题的代码进行深入分析 +- 分析复杂的业务逻辑漏洞 +- 追踪数据流和污点传播 +- 生成详细的漏洞报告和修复建议 + +输入: +- code: 要分析的代码 +- file_path: 文件路径 +- language: 编程语言 +- focus: 可选,重点关注的漏洞类型 +- context: 可选,额外的上下文代码 + +这个工具会消耗较多的 Token,建议在确认有疑似问题后使用。""" + + @property + def args_schema(self): + return CodeAnalysisInput + + async def _execute( + self, + code: str, + file_path: str = "unknown", + language: str = "python", + focus: Optional[str] = None, + context: Optional[str] = None, + **kwargs + ) -> ToolResult: + """执行代码分析""" + try: + # 构建分析结果 + analysis = await self.llm_service.analyze_code(code, language) + + issues = analysis.get("issues", []) + + if not issues: + return ToolResult( + success=True, + data="代码分析完成,未发现明显的安全问题。\n\n" + f"质量评分: {analysis.get('quality_score', 'N/A')}\n" + f"文件: {file_path}", + metadata={ + "file_path": file_path, + "issues_count": 0, + "quality_score": analysis.get("quality_score"), + } + ) + + # 格式化输出 + output_parts = [f"🔍 代码分析结果 - {file_path}\n"] + output_parts.append(f"发现 {len(issues)} 个问题:\n") + + for i, issue in enumerate(issues): + severity_icon = { + "critical": "🔴", + "high": "🟠", + "medium": "🟡", + "low": "🟢" + }.get(issue.get("severity", ""), "⚪") + + output_parts.append(f"\n{severity_icon} 问题 {i+1}: {issue.get('title', 'Unknown')}") + output_parts.append(f" 类型: {issue.get('type', 'unknown')}") + output_parts.append(f" 严重程度: {issue.get('severity', 'unknown')}") + output_parts.append(f" 行号: {issue.get('line', 'N/A')}") + output_parts.append(f" 描述: 
{issue.get('description', '')}") + + if issue.get("code_snippet"): + output_parts.append(f" 代码片段:\n ```\n {issue.get('code_snippet')}\n ```") + + if issue.get("suggestion"): + output_parts.append(f" 修复建议: {issue.get('suggestion')}") + + if issue.get("ai_explanation"): + output_parts.append(f" AI解释: {issue.get('ai_explanation')}") + + output_parts.append(f"\n质量评分: {analysis.get('quality_score', 'N/A')}/100") + + return ToolResult( + success=True, + data="\n".join(output_parts), + metadata={ + "file_path": file_path, + "issues_count": len(issues), + "quality_score": analysis.get("quality_score"), + "issues": issues, + } + ) + + except Exception as e: + return ToolResult( + success=False, + error=f"代码分析失败: {str(e)}", + ) + + +class DataFlowAnalysisInput(BaseModel): + """数据流分析输入""" + source_code: str = Field(description="包含数据源的代码") + sink_code: Optional[str] = Field(default=None, description="包含数据汇的代码(如危险函数)") + variable_name: str = Field(description="要追踪的变量名") + file_path: str = Field(default="unknown", description="文件路径") + + +class DataFlowAnalysisTool(AgentTool): + """ + 数据流分析工具 + 追踪变量从源到汇的数据流 + """ + + def __init__(self, llm_service): + super().__init__() + self.llm_service = llm_service + + @property + def name(self) -> str: + return "dataflow_analysis" + + @property + def description(self) -> str: + return """分析代码中的数据流,追踪变量从源(如用户输入)到汇(如危险函数)的路径。 + +使用场景: +- 追踪用户输入如何流向危险函数 +- 分析变量是否经过净化处理 +- 识别污点传播路径 + +输入: +- source_code: 包含数据源的代码 +- sink_code: 包含数据汇的代码(可选) +- variable_name: 要追踪的变量名 +- file_path: 文件路径""" + + @property + def args_schema(self): + return DataFlowAnalysisInput + + async def _execute( + self, + source_code: str, + variable_name: str, + sink_code: Optional[str] = None, + file_path: str = "unknown", + **kwargs + ) -> ToolResult: + """执行数据流分析""" + try: + # 构建分析 prompt + analysis_prompt = f"""分析以下代码中变量 '{variable_name}' 的数据流。 + +源代码: +``` +{source_code} +``` +""" + if sink_code: + analysis_prompt += f""" +汇代码(可能的危险函数): +``` +{sink_code} +``` +""" + + 
analysis_prompt += f""" +请分析: +1. 变量 '{variable_name}' 的来源是什么?(用户输入、配置、数据库等) +2. 变量在传递过程中是否经过了净化/验证? +3. 变量最终流向了哪些危险函数? +4. 是否存在安全风险? + +请返回 JSON 格式的分析结果,包含: +- source_type: 数据源类型 +- sanitized: 是否经过净化 +- sanitization_methods: 使用的净化方法 +- dangerous_sinks: 流向的危险函数列表 +- risk_level: 风险等级 (high/medium/low/none) +- explanation: 详细解释 +- recommendation: 建议 +""" + + # 调用 LLM 分析 + # 这里使用 analyze_code_with_custom_prompt + result = await self.llm_service.analyze_code_with_custom_prompt( + code=source_code, + language="text", + custom_prompt=analysis_prompt, + ) + + # 格式化输出 + output_parts = [f"📊 数据流分析结果 - 变量: {variable_name}\n"] + + if isinstance(result, dict): + if result.get("source_type"): + output_parts.append(f"数据源: {result.get('source_type')}") + if result.get("sanitized") is not None: + sanitized = "✅ 是" if result.get("sanitized") else "❌ 否" + output_parts.append(f"是否净化: {sanitized}") + if result.get("sanitization_methods"): + output_parts.append(f"净化方法: {', '.join(result.get('sanitization_methods', []))}") + if result.get("dangerous_sinks"): + output_parts.append(f"危险函数: {', '.join(result.get('dangerous_sinks', []))}") + if result.get("risk_level"): + risk_icons = {"high": "🔴", "medium": "🟠", "low": "🟡", "none": "🟢"} + icon = risk_icons.get(result.get("risk_level", ""), "⚪") + output_parts.append(f"风险等级: {icon} {result.get('risk_level', '').upper()}") + if result.get("explanation"): + output_parts.append(f"\n分析: {result.get('explanation')}") + if result.get("recommendation"): + output_parts.append(f"\n建议: {result.get('recommendation')}") + else: + output_parts.append(str(result)) + + return ToolResult( + success=True, + data="\n".join(output_parts), + metadata={ + "variable": variable_name, + "file_path": file_path, + "analysis": result, + } + ) + + except Exception as e: + return ToolResult( + success=False, + error=f"数据流分析失败: {str(e)}", + ) + + +class VulnerabilityValidationInput(BaseModel): + """漏洞验证输入""" + code: str = Field(description="可能存在漏洞的代码") + 
class VulnerabilityValidationInput(BaseModel):
    """Input schema for vulnerability validation."""
    code: str = Field(description="可能存在漏洞的代码")
    vulnerability_type: str = Field(description="漏洞类型")
    file_path: str = Field(default="unknown", description="文件路径")
    line_number: Optional[int] = Field(default=None, description="行号")
    context: Optional[str] = Field(default=None, description="额外上下文")


class VulnerabilityValidationTool(AgentTool):
    """
    Vulnerability validation tool.

    Asks the LLM to confirm or refute a suspected vulnerability and
    formats the structured verdict for the agent.
    """

    def __init__(self, llm_service):
        super().__init__()
        # LLM service; must expose analyze_code_with_custom_prompt().
        self.llm_service = llm_service

    @property
    def name(self) -> str:
        return "vulnerability_validation"

    @property
    def description(self) -> str:
        return """验证疑似漏洞是否真实存在。
对发现的潜在漏洞进行深入分析,判断是否为真正的安全问题。

输入:
- code: 包含疑似漏洞的代码
- vulnerability_type: 漏洞类型(如 sql_injection, xss 等)
- file_path: 文件路径
- line_number: 可选,行号
- context: 可选,额外的上下文代码

输出:
- 验证结果(确认/可能/误报)
- 详细分析
- 利用条件
- PoC 思路(如果确认存在漏洞)"""

    @property
    def args_schema(self):
        return VulnerabilityValidationInput

    async def _execute(
        self,
        code: str,
        vulnerability_type: str,
        file_path: str = "unknown",
        line_number: Optional[int] = None,
        context: Optional[str] = None,
        **kwargs
    ) -> ToolResult:
        """Validate a suspected vulnerability via a structured LLM prompt."""
        try:
            # Build the optional context section up front instead of the old
            # chr(10)-concatenation inside the f-string (same resulting text).
            context_section = ""
            if context:
                context_section = f"额外上下文:\n```\n{context}\n```"

            validation_prompt = f"""你是一个专业的安全研究员,请验证以下代码中是否真的存在 {vulnerability_type} 漏洞。

代码:
```
{code}
```

{context_section}

请分析:
1. 这段代码是否真的存在 {vulnerability_type} 漏洞?
2. 漏洞的利用条件是什么?
3. 攻击者如何利用这个漏洞?
4. 这是否可能是误报?为什么?

请返回 JSON 格式:
{{
    "is_vulnerable": true/false/null (null表示无法确定),
    "confidence": 0.0-1.0,
    "verdict": "confirmed/likely/unlikely/false_positive",
    "exploitation_conditions": ["条件1", "条件2"],
    "attack_vector": "攻击向量描述",
    "poc_idea": "PoC思路(如果存在漏洞)",
    "false_positive_reason": "如果是误报,说明原因",
    "detailed_analysis": "详细分析"
}}
"""

            result = await self.llm_service.analyze_code_with_custom_prompt(
                code=code,
                language="text",
                custom_prompt=validation_prompt,
            )

            # Format a human/LLM-readable verdict report.
            output_parts = [f"🔎 漏洞验证结果 - {vulnerability_type}\n"]
            output_parts.append(f"文件: {file_path}")
            if line_number:
                output_parts.append(f"行号: {line_number}")
            output_parts.append("")

            if isinstance(result, dict):
                verdict_icons = {
                    "confirmed": "🔴 确认存在漏洞",
                    "likely": "🟠 可能存在漏洞",
                    "unlikely": "🟡 可能是误报",
                    "false_positive": "🟢 误报",
                }
                verdict = result.get("verdict", "unknown")
                output_parts.append(f"判定: {verdict_icons.get(verdict, verdict)}")

                # Guard against non-numeric LLM output before doing arithmetic.
                confidence = result.get("confidence")
                if isinstance(confidence, (int, float)) and confidence:
                    output_parts.append(f"置信度: {confidence * 100:.0f}%")

                if result.get("exploitation_conditions"):
                    output_parts.append(f"\n利用条件:")
                    for cond in result.get("exploitation_conditions", []):
                        output_parts.append(f"  - {cond}")

                if result.get("attack_vector"):
                    output_parts.append(f"\n攻击向量: {result.get('attack_vector')}")

                if result.get("poc_idea") and verdict in ["confirmed", "likely"]:
                    output_parts.append(f"\nPoC思路: {result.get('poc_idea')}")

                if result.get("false_positive_reason") and verdict in ["unlikely", "false_positive"]:
                    output_parts.append(f"\n误报原因: {result.get('false_positive_reason')}")

                if result.get("detailed_analysis"):
                    output_parts.append(f"\n详细分析:\n{result.get('detailed_analysis')}")
            else:
                output_parts.append(str(result))

            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "vulnerability_type": vulnerability_type,
                    "file_path": file_path,
                    "line_number": line_number,
                    "validation": result,
                },
            )

        except Exception as e:
            return ToolResult(
                success=False,
                error=f"漏洞验证失败: {str(e)}",
            )
class SemgrepInput(BaseModel):
    """Input schema for a Semgrep scan."""
    target_path: str = Field(description="要扫描的目录或文件路径(相对于项目根目录)")
    rules: Optional[str] = Field(
        default="auto",
        description="规则集: auto, p/security-audit, p/owasp-top-ten, p/r2c-security-audit, 或自定义规则文件路径"
    )
    severity: Optional[str] = Field(
        default=None,
        description="过滤严重程度: ERROR, WARNING, INFO"
    )
    max_results: int = Field(default=50, description="最大返回结果数")


class SemgrepTool(AgentTool):
    """
    Semgrep static-analysis tool.

    Runs the semgrep CLI against a path inside the project root and
    formats findings for the agent.  Exit code 1 means "findings present"
    and is treated as success.
    """

    # Registry rulesets the agent may request (informational).
    AVAILABLE_RULESETS = [
        "auto",
        "p/security-audit",
        "p/owasp-top-ten",
        "p/r2c-security-audit",
        "p/python",
        "p/javascript",
        "p/typescript",
        "p/java",
        "p/go",
        "p/php",
        "p/ruby",
        "p/secrets",
        "p/sql-injection",
        "p/xss",
        "p/command-injection",
    ]

    def __init__(self, project_root: str):
        super().__init__()
        # All scans are confined to this directory.
        self.project_root = project_root

    @property
    def name(self) -> str:
        return "semgrep_scan"

    @property
    def description(self) -> str:
        return """使用 Semgrep 进行静态安全分析。
Semgrep 是业界领先的静态分析工具,支持 30+ 种编程语言。

可用规则集:
- auto: 自动选择最佳规则
- p/security-audit: 综合安全审计
- p/owasp-top-ten: OWASP Top 10 漏洞检测
- p/secrets: 密钥泄露检测
- p/sql-injection: SQL 注入检测
- p/xss: XSS 检测
- p/command-injection: 命令注入检测

使用场景:
- 快速全面的代码安全扫描
- 检测常见安全漏洞模式
- 遵循行业安全标准审计"""

    @property
    def args_schema(self):
        return SemgrepInput

    async def _execute(
        self,
        target_path: str = ".",
        rules: str = "auto",
        severity: Optional[str] = None,
        max_results: int = 50,
        **kwargs
    ) -> ToolResult:
        """Run a Semgrep scan and format the findings."""
        if not await self._check_semgrep():
            return ToolResult(
                success=False,
                error="Semgrep 未安装。请使用 'pip install semgrep' 安装。",
            )

        # Containment check: a plain prefix test would accept sibling
        # directories such as "<root>-evil", so require an exact match or a
        # path-separator boundary after the root.
        root = os.path.normpath(self.project_root)
        full_path = os.path.normpath(os.path.join(root, target_path))
        if full_path != root and not full_path.startswith(root + os.sep):
            return ToolResult(
                success=False,
                error="安全错误:不允许扫描项目目录外的路径",
            )

        # "auto", registry rulesets ("p/...") and custom rule files are all
        # passed the same way; the previous three-way branch was identical.
        cmd = ["semgrep", "--json", "--quiet", "--config", rules or "auto"]
        if severity:
            cmd.extend(["--severity", severity])
        cmd.append(full_path)

        try:
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=self.project_root,
            )
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=300)

            # Exit code 1 means findings were reported, not an error.
            if proc.returncode not in (0, 1):
                return ToolResult(
                    success=False,
                    error=f"Semgrep 执行失败: {stderr.decode()[:500]}",
                )

            try:
                results = json.loads(stdout.decode())
            except json.JSONDecodeError:
                return ToolResult(
                    success=False,
                    error="无法解析 Semgrep 输出",
                )

            findings = results.get("results", [])[:max_results]

            if not findings:
                return ToolResult(
                    success=True,
                    data=f"Semgrep 扫描完成,未发现安全问题 (规则集: {rules})",
                    metadata={"findings_count": 0, "rules": rules},
                )

            output_parts = [f"🔍 Semgrep 扫描结果 (规则集: {rules})\n"]
            output_parts.append(f"发现 {len(findings)} 个问题:\n")

            severity_icons = {"ERROR": "🔴", "WARNING": "🟠", "INFO": "🟡"}

            for finding in findings:
                sev = finding.get("extra", {}).get("severity", "INFO")
                icon = severity_icons.get(sev, "⚪")

                output_parts.append(f"\n{icon} [{sev}] {finding.get('check_id', 'unknown')}")
                output_parts.append(f"  文件: {finding.get('path', '')}:{finding.get('start', {}).get('line', 0)}")
                output_parts.append(f"  消息: {finding.get('extra', {}).get('message', '')[:200]}")

                lines = finding.get("extra", {}).get("lines", "")
                if lines:
                    output_parts.append(f"  代码: {lines[:150]}")

            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "findings_count": len(findings),
                    "rules": rules,
                    "findings": findings[:10],
                },
            )

        except asyncio.TimeoutError:
            return ToolResult(success=False, error="Semgrep 扫描超时")
        except Exception as e:
            return ToolResult(success=False, error=f"Semgrep 执行错误: {str(e)}")

    async def _check_semgrep(self) -> bool:
        """Return True if the semgrep binary is runnable."""
        try:
            proc = await asyncio.create_subprocess_exec(
                "semgrep", "--version",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            await proc.communicate()
            return proc.returncode == 0
        except OSError:
            # Binary missing / not executable (FileNotFoundError is OSError).
            return False


class BanditInput(BaseModel):
    """Input schema for a Bandit scan."""
    target_path: str = Field(default=".", description="要扫描的 Python 目录或文件")
    severity: str = Field(default="medium", description="最低严重程度: low, medium, high")
    confidence: str = Field(default="medium", description="最低置信度: low, medium, high")
    max_results: int = Field(default=50, description="最大返回结果数")


class BanditTool(AgentTool):
    """
    Bandit Python security scanner.

    Runs the bandit CLI recursively over Python sources inside the
    project root and formats findings for the agent.
    """

    def __init__(self, project_root: str):
        super().__init__()
        # All scans are confined to this directory.
        self.project_root = project_root

    @property
    def name(self) -> str:
        return "bandit_scan"

    @property
    def description(self) -> str:
        return """使用 Bandit 扫描 Python 代码的安全问题。
Bandit 是 Python 专用的安全分析工具,由 OpenStack 安全团队开发。

检测项目:
- B101: assert 使用
- B102: exec 使用
- B103-B108: 文件权限问题
- B301-B312: pickle/yaml 反序列化
- B501-B508: SSL/TLS 问题
- B601-B608: shell/SQL 注入
- B701-B703: Jinja2 模板问题

仅适用于 Python 项目。"""

    @property
    def args_schema(self):
        return BanditInput

    async def _execute(
        self,
        target_path: str = ".",
        severity: str = "medium",
        confidence: str = "medium",
        max_results: int = 50,
        **kwargs
    ) -> ToolResult:
        """Run Bandit and format the findings."""
        if not await self._check_bandit():
            return ToolResult(
                success=False,
                error="Bandit 未安装。请使用 'pip install bandit' 安装。",
            )

        # Same boundary-aware containment check as the other scanners.
        root = os.path.normpath(self.project_root)
        full_path = os.path.normpath(os.path.join(root, target_path))
        if full_path != root and not full_path.startswith(root + os.sep):
            return ToolResult(success=False, error="安全错误:路径越界")

        # Bandit expresses thresholds by repeating the flag: -l = LOW and
        # above, -ll = MEDIUM+, -lll = HIGH+ (likewise -i/-ii/-iii for
        # confidence).  The previous "-l"+letter concatenation produced
        # arguments like "-lm" that bandit rejects.
        severity_flags = {"low": "-l", "medium": "-ll", "high": "-lll"}
        confidence_flags = {"low": "-i", "medium": "-ii", "high": "-iii"}
        cmd = [
            "bandit", "-r", "-f", "json",
            severity_flags.get(severity, "-ll"),
            confidence_flags.get(confidence, "-ii"),
            full_path,
        ]

        try:
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=120)

            try:
                results = json.loads(stdout.decode())
            except json.JSONDecodeError:
                return ToolResult(success=False, error="无法解析 Bandit 输出")

            findings = results.get("results", [])[:max_results]

            if not findings:
                return ToolResult(
                    success=True,
                    data="Bandit 扫描完成,未发现 Python 安全问题",
                    metadata={"findings_count": 0},
                )

            output_parts = ["🐍 Bandit Python 安全扫描结果\n"]
            output_parts.append(f"发现 {len(findings)} 个问题:\n")

            severity_icons = {"HIGH": "🔴", "MEDIUM": "🟠", "LOW": "🟡"}

            for finding in findings:
                sev = finding.get("issue_severity", "LOW")
                icon = severity_icons.get(sev, "⚪")

                output_parts.append(f"\n{icon} [{sev}] {finding.get('test_id', '')}: {finding.get('test_name', '')}")
                output_parts.append(f"  文件: {finding.get('filename', '')}:{finding.get('line_number', 0)}")
                output_parts.append(f"  消息: {finding.get('issue_text', '')[:200]}")
                output_parts.append(f"  代码: {finding.get('code', '')[:100]}")

            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={"findings_count": len(findings), "findings": findings[:10]},
            )

        except asyncio.TimeoutError:
            return ToolResult(success=False, error="Bandit 扫描超时")
        except Exception as e:
            return ToolResult(success=False, error=f"Bandit 执行错误: {str(e)}")

    async def _check_bandit(self) -> bool:
        """Return True if the bandit binary is runnable."""
        try:
            proc = await asyncio.create_subprocess_exec(
                "bandit", "--version",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            await proc.communicate()
            return proc.returncode == 0
        except OSError:
            return False
class GitleaksInput(BaseModel):
    """Input schema for a Gitleaks scan."""
    target_path: str = Field(default=".", description="要扫描的目录")
    no_git: bool = Field(default=True, description="不使用 git history,仅扫描文件")
    max_results: int = Field(default=50, description="最大返回结果数")


class GitleaksTool(AgentTool):
    """
    Gitleaks secret-leak detection tool.

    Runs the gitleaks CLI against a directory inside the project root and
    reports hard-coded credentials.  Detected secrets are partially masked
    before being returned so they are never echoed back in full.
    """

    def __init__(self, project_root: str):
        super().__init__()
        # All scans are confined to this directory.
        self.project_root = project_root

    @property
    def name(self) -> str:
        return "gitleaks_scan"

    @property
    def description(self) -> str:
        return """使用 Gitleaks 检测代码中的密钥泄露。
Gitleaks 是专业的密钥检测工具,支持 150+ 种密钥类型。

检测类型:
- AWS Access Keys / Secret Keys
- GCP API Keys / Service Account Keys
- Azure Credentials
- GitHub / GitLab Tokens
- Private Keys (RSA, SSH, PGP)
- Database Connection Strings
- JWT Secrets
- Slack / Discord Tokens
- 等等...

建议在代码审计早期使用此工具。"""

    @property
    def args_schema(self):
        return GitleaksInput

    async def _execute(
        self,
        target_path: str = ".",
        no_git: bool = True,
        max_results: int = 50,
        **kwargs
    ) -> ToolResult:
        """Run Gitleaks and report masked secret findings."""
        if not await self._check_gitleaks():
            return ToolResult(
                success=False,
                error="Gitleaks 未安装。请从 https://github.com/gitleaks/gitleaks 安装。",
            )

        # Boundary-aware containment check (plain startswith would accept
        # sibling directories such as "<root>-evil").
        root = os.path.normpath(self.project_root)
        full_path = os.path.normpath(os.path.join(root, target_path))
        if full_path != root and not full_path.startswith(root + os.sep):
            return ToolResult(success=False, error="安全错误:路径越界")

        # NOTE(review): flag set matches gitleaks v8 ("detect --no-git");
        # newer majors renamed subcommands — verify against the pinned version.
        cmd = ["gitleaks", "detect", "--source", full_path, "-f", "json"]
        if no_git:
            cmd.append("--no-git")

        try:
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=120)

            # Gitleaks exits 1 when secrets are found; that is not an error.
            if proc.returncode not in (0, 1):
                return ToolResult(success=False, error=f"Gitleaks 执行失败: {stderr.decode()[:300]}")

            if not stdout.strip():
                return ToolResult(
                    success=True,
                    data="🔐 Gitleaks 扫描完成,未发现密钥泄露",
                    metadata={"findings_count": 0},
                )

            try:
                findings = json.loads(stdout.decode())
            except json.JSONDecodeError:
                findings = []

            if not findings:
                return ToolResult(
                    success=True,
                    data="🔐 Gitleaks 扫描完成,未发现密钥泄露",
                    metadata={"findings_count": 0},
                )

            findings = findings[:max_results]

            output_parts = ["🔐 Gitleaks 密钥泄露检测结果\n"]
            output_parts.append(f"⚠️ 发现 {len(findings)} 处密钥泄露!\n")

            for i, finding in enumerate(findings):
                output_parts.append(f"\n🔴 [{i+1}] {finding.get('RuleID', 'unknown')}")
                output_parts.append(f"  描述: {finding.get('Description', '')}")
                output_parts.append(f"  文件: {finding.get('File', '')}:{finding.get('StartLine', 0)}")

                # Mask the secret: keep at most the first/last 4 characters.
                secret = finding.get('Secret', '')
                if len(secret) > 8:
                    masked = secret[:4] + '*' * (len(secret) - 8) + secret[-4:]
                else:
                    masked = '*' * len(secret)
                output_parts.append(f"  密钥: {masked}")

            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "findings_count": len(findings),
                    "findings": [
                        {"rule": f.get("RuleID"), "file": f.get("File"), "line": f.get("StartLine")}
                        for f in findings[:10]
                    ],
                },
            )

        except asyncio.TimeoutError:
            return ToolResult(success=False, error="Gitleaks 扫描超时")
        except Exception as e:
            return ToolResult(success=False, error=f"Gitleaks 执行错误: {str(e)}")

    async def _check_gitleaks(self) -> bool:
        """Return True if the gitleaks binary is runnable."""
        try:
            proc = await asyncio.create_subprocess_exec(
                "gitleaks", "version",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            await proc.communicate()
            return proc.returncode == 0
        except OSError:
            return False
class NpmAuditTool(AgentTool):
    """
    npm audit dependency-vulnerability scanner.

    Runs `npm audit --json` in a directory containing package.json and
    summarizes known vulnerable dependencies.
    """

    def __init__(self, project_root: str):
        super().__init__()
        # All scans are confined to this directory.
        self.project_root = project_root

    @property
    def name(self) -> str:
        return "npm_audit"

    @property
    def description(self) -> str:
        return """使用 npm audit 扫描 Node.js 项目的依赖漏洞。
基于 npm 官方漏洞数据库,检测已知的依赖安全问题。

适用于:
- 包含 package.json 的 Node.js 项目
- 前端项目 (React, Vue, Angular 等)

需要先运行 npm install 安装依赖。"""

    @property
    def args_schema(self):
        return NpmAuditInput

    async def _execute(
        self,
        target_path: str = ".",
        production_only: bool = False,
        **kwargs
    ) -> ToolResult:
        """Run npm audit and summarize vulnerable dependencies."""
        # Containment check, consistent with the other scanners (previously
        # missing here — npm could be pointed at arbitrary directories).
        root = os.path.normpath(self.project_root)
        full_path = os.path.normpath(os.path.join(root, target_path))
        if full_path != root and not full_path.startswith(root + os.sep):
            return ToolResult(success=False, error="安全错误:路径越界")

        package_json = os.path.join(full_path, "package.json")
        if not os.path.exists(package_json):
            return ToolResult(
                success=False,
                error=f"未找到 package.json: {target_path}",
            )

        cmd = ["npm", "audit", "--json"]
        if production_only:
            # NOTE(review): "--production" is the legacy spelling; newer npm
            # prefers "--omit=dev" — verify against the pinned npm version.
            cmd.append("--production")

        try:
            proc = await asyncio.create_subprocess_exec(
                *cmd,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                cwd=full_path,
            )
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=120)

            try:
                results = json.loads(stdout.decode())
            except json.JSONDecodeError:
                return ToolResult(success=True, data="npm audit 输出为空或格式错误")

            vulnerabilities = results.get("vulnerabilities", {})

            if not vulnerabilities:
                return ToolResult(
                    success=True,
                    data="📦 npm audit 完成,未发现依赖漏洞",
                    metadata={"findings_count": 0},
                )

            output_parts = ["📦 npm audit 依赖漏洞扫描结果\n"]

            # npm reports severities info/low/moderate/high/critical; unknown
            # keys are tolerated via .get().
            severity_counts = {"critical": 0, "high": 0, "moderate": 0, "low": 0, "info": 0}
            for name, vuln in vulnerabilities.items():
                severity = vuln.get("severity", "low")
                severity_counts[severity] = severity_counts.get(severity, 0) + 1

            output_parts.append(f"漏洞统计: 🔴 Critical: {severity_counts['critical']}, 🟠 High: {severity_counts['high']}, 🟡 Moderate: {severity_counts['moderate']}, 🟢 Low: {severity_counts['low']}\n")

            severity_icons = {"critical": "🔴", "high": "🟠", "moderate": "🟡", "low": "🟢"}

            for name, vuln in list(vulnerabilities.items())[:20]:
                sev = vuln.get("severity", "low")
                icon = severity_icons.get(sev, "⚪")
                output_parts.append(f"\n{icon} [{sev.upper()}] {name}")
                output_parts.append(f"  版本范围: {vuln.get('range', 'unknown')}")

                # "via" entries are dicts for direct advisories, strings for
                # transitive paths; only dicts carry a title.
                via = vuln.get("via", [])
                if via and isinstance(via[0], dict):
                    output_parts.append(f"  来源: {via[0].get('title', '')[:100]}")

            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={
                    "findings_count": len(vulnerabilities),
                    "severity_counts": severity_counts,
                },
            )

        except asyncio.TimeoutError:
            return ToolResult(success=False, error="npm audit 超时")
        except Exception as e:
            return ToolResult(success=False, error=f"npm audit 错误: {str(e)}")


class SafetyInput(BaseModel):
    """Input schema for a Safety scan."""
    requirements_file: str = Field(default="requirements.txt", description="requirements 文件路径")


class SafetyTool(AgentTool):
    """
    Safety Python dependency-vulnerability scanner.

    Checks a requirements file against the PyUp.io vulnerability database.
    """

    def __init__(self, project_root: str):
        super().__init__()
        # The requirements file must live under this directory.
        self.project_root = project_root

    @property
    def name(self) -> str:
        return "safety_scan"

    @property
    def description(self) -> str:
        return """使用 Safety 扫描 Python 依赖的安全漏洞。
基于 PyUp.io 漏洞数据库检测已知的依赖安全问题。

适用于:
- 包含 requirements.txt 的 Python 项目
- Pipenv 项目 (Pipfile.lock)
- Poetry 项目 (poetry.lock)"""

    @property
    def args_schema(self):
        return SafetyInput

    async def _execute(
        self,
        requirements_file: str = "requirements.txt",
        **kwargs
    ) -> ToolResult:
        """Run Safety against a requirements file and summarize findings."""
        # Containment check, consistent with the other scanners (previously
        # missing — a "../" path could reference files outside the project).
        root = os.path.normpath(self.project_root)
        full_path = os.path.normpath(os.path.join(root, requirements_file))
        if full_path != root and not full_path.startswith(root + os.sep):
            return ToolResult(success=False, error="安全错误:路径越界")

        if not os.path.exists(full_path):
            return ToolResult(success=False, error=f"未找到依赖文件: {requirements_file}")

        try:
            # NOTE(review): "safety check" is deprecated in safety 3.x in
            # favor of "safety scan" — verify against the pinned version.
            proc = await asyncio.create_subprocess_exec(
                "safety", "check", "-r", full_path, "--json",
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=60)

            output = stdout.decode()
            if "No known security" in output:
                return ToolResult(
                    success=True,
                    data="🐍 Safety 扫描完成,未发现 Python 依赖漏洞",
                    metadata={"findings_count": 0},
                )

            try:
                # Safety's JSON layout differs between versions; fall back to
                # the raw output if it does not parse.
                results = json.loads(output)
            except (json.JSONDecodeError, ValueError):
                return ToolResult(success=True, data=f"Safety 输出:\n{stdout.decode()[:1000]}")

            vulnerabilities = results if isinstance(results, list) else results.get("vulnerabilities", [])

            if not vulnerabilities:
                return ToolResult(
                    success=True,
                    data="🐍 Safety 扫描完成,未发现 Python 依赖漏洞",
                    metadata={"findings_count": 0},
                )

            output_parts = ["🐍 Safety Python 依赖漏洞扫描结果\n"]
            output_parts.append(f"发现 {len(vulnerabilities)} 个漏洞:\n")

            # Legacy list format: [package, affected_spec, installed, advisory, id].
            for vuln in vulnerabilities[:20]:
                if isinstance(vuln, list) and len(vuln) >= 4:
                    output_parts.append(f"\n🔴 {vuln[0]} ({vuln[1]})")
                    output_parts.append(f"  漏洞 ID: {vuln[4] if len(vuln) > 4 else 'N/A'}")
                    output_parts.append(f"  描述: {vuln[3][:200] if len(vuln) > 3 else ''}")

            return ToolResult(
                success=True,
                data="\n".join(output_parts),
                metadata={"findings_count": len(vulnerabilities)},
            )

        except Exception as e:
            return ToolResult(success=False, error=f"Safety 执行错误: {str(e)}")
+ full_path = os.path.normpath(os.path.join(self.project_root, target_path)) + + cmd = ["trufflehog", "filesystem", full_path, "--json"] + if only_verified: + cmd.append("--only-verified") + + try: + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=180) + + if not stdout.strip(): + return ToolResult( + success=True, + data="🔍 TruffleHog 扫描完成,未发现密钥泄露", + metadata={"findings_count": 0} + ) + + # TruffleHog 输出每行一个 JSON 对象 + findings = [] + for line in stdout.decode().strip().split('\n'): + if line.strip(): + try: + findings.append(json.loads(line)) + except: + pass + + if not findings: + return ToolResult( + success=True, + data="🔍 TruffleHog 扫描完成,未发现密钥泄露", + metadata={"findings_count": 0} + ) + + output_parts = ["🔍 TruffleHog 密钥扫描结果\n"] + output_parts.append(f"⚠️ 发现 {len(findings)} 处密钥泄露!\n") + + for i, finding in enumerate(findings[:20]): + verified = "✅ 已验证有效" if finding.get("Verified") else "⚠️ 未验证" + output_parts.append(f"\n🔴 [{i+1}] {finding.get('DetectorName', 'unknown')} - {verified}") + output_parts.append(f" 文件: {finding.get('SourceMetadata', {}).get('Data', {}).get('Filesystem', {}).get('file', '')}") + + return ToolResult( + success=True, + data="\n".join(output_parts), + metadata={"findings_count": len(findings)} + ) + + except asyncio.TimeoutError: + return ToolResult(success=False, error="TruffleHog 扫描超时") + except Exception as e: + return ToolResult(success=False, error=f"TruffleHog 执行错误: {str(e)}") + + +# ============ OSV-Scanner 工具 ============ + +class OSVScannerInput(BaseModel): + """OSV-Scanner 扫描输入""" + target_path: str = Field(default=".", description="要扫描的项目目录") + + +class OSVScannerTool(AgentTool): + """ + OSV-Scanner 开源漏洞扫描工具 + + Google 开源的漏洞扫描工具,使用 OSV 数据库。 + 支持多种包管理器和锁文件。 + """ + + def __init__(self, project_root: str): + super().__init__() + self.project_root = project_root + + @property + def 
name(self) -> str: + return "osv_scan" + + @property + def description(self) -> str: + return """使用 OSV-Scanner 扫描开源依赖漏洞。 +Google 开源的漏洞扫描工具,使用 OSV (Open Source Vulnerabilities) 数据库。 + +支持: +- package.json / package-lock.json (npm) +- requirements.txt / Pipfile.lock (Python) +- go.mod / go.sum (Go) +- Cargo.lock (Rust) +- pom.xml (Maven) +- Gemfile.lock (Ruby) +- composer.lock (PHP) + +特点: +- 覆盖多种语言和包管理器 +- 使用 Google 维护的漏洞数据库 +- 快速、准确""" + + @property + def args_schema(self): + return OSVScannerInput + + async def _execute( + self, + target_path: str = ".", + **kwargs + ) -> ToolResult: + """执行 OSV-Scanner""" + full_path = os.path.normpath(os.path.join(self.project_root, target_path)) + + try: + proc = await asyncio.create_subprocess_exec( + "osv-scanner", "--json", "-r", full_path, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=120) + + try: + results = json.loads(stdout.decode()) + except: + if "no package sources found" in stdout.decode().lower(): + return ToolResult(success=True, data="OSV-Scanner: 未找到可扫描的包文件") + return ToolResult(success=True, data=f"OSV-Scanner 输出:\n{stdout.decode()[:1000]}") + + vulns = results.get("results", []) + + if not vulns: + return ToolResult( + success=True, + data="📋 OSV-Scanner 扫描完成,未发现依赖漏洞", + metadata={"findings_count": 0} + ) + + total_vulns = sum(len(r.get("vulnerabilities", [])) for r in vulns) + + output_parts = ["📋 OSV-Scanner 开源漏洞扫描结果\n"] + output_parts.append(f"发现 {total_vulns} 个漏洞:\n") + + for result in vulns[:10]: + source = result.get("source", {}).get("path", "unknown") + for vuln in result.get("vulnerabilities", [])[:5]: + vuln_id = vuln.get("id", "") + summary = vuln.get("summary", "")[:100] + output_parts.append(f"\n🔴 {vuln_id}") + output_parts.append(f" 来源: {source}") + output_parts.append(f" 描述: {summary}") + + return ToolResult( + success=True, + data="\n".join(output_parts), + metadata={"findings_count": total_vulns} + 
) + + except Exception as e: + return ToolResult(success=False, error=f"OSV-Scanner 执行错误: {str(e)}") + + +# ============ 导出所有工具 ============ + +__all__ = [ + "SemgrepTool", + "BanditTool", + "GitleaksTool", + "NpmAuditTool", + "SafetyTool", + "TruffleHogTool", + "OSVScannerTool", +] + diff --git a/backend/app/services/agent/tools/file_tool.py b/backend/app/services/agent/tools/file_tool.py new file mode 100644 index 0000000..4a6ae0c --- /dev/null +++ b/backend/app/services/agent/tools/file_tool.py @@ -0,0 +1,481 @@ +""" +文件操作工具 +读取和搜索代码文件 +""" + +import os +import re +import fnmatch +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field + +from .base import AgentTool, ToolResult + + +class FileReadInput(BaseModel): + """文件读取输入""" + file_path: str = Field(description="文件路径(相对于项目根目录)") + start_line: Optional[int] = Field(default=None, description="起始行号(从1开始)") + end_line: Optional[int] = Field(default=None, description="结束行号") + max_lines: int = Field(default=500, description="最大返回行数") + + +class FileReadTool(AgentTool): + """ + 文件读取工具 + 读取项目中的文件内容 + """ + + def __init__(self, project_root: str): + """ + 初始化文件读取工具 + + Args: + project_root: 项目根目录 + """ + super().__init__() + self.project_root = project_root + + @property + def name(self) -> str: + return "read_file" + + @property + def description(self) -> str: + return """读取项目中的文件内容。 + +使用场景: +- 查看完整的源代码文件 +- 查看特定行范围的代码 +- 获取配置文件内容 + +输入: +- file_path: 文件路径(相对于项目根目录) +- start_line: 可选,起始行号 +- end_line: 可选,结束行号 +- max_lines: 最大返回行数(默认500) + +注意: 为避免输出过长,建议指定行范围或使用 RAG 搜索定位代码。""" + + @property + def args_schema(self): + return FileReadInput + + async def _execute( + self, + file_path: str, + start_line: Optional[int] = None, + end_line: Optional[int] = None, + max_lines: int = 500, + **kwargs + ) -> ToolResult: + """执行文件读取""" + try: + # 安全检查:防止路径遍历 + full_path = os.path.normpath(os.path.join(self.project_root, file_path)) + if not full_path.startswith(os.path.normpath(self.project_root)): 
+ return ToolResult( + success=False, + error="安全错误:不允许访问项目目录外的文件", + ) + + if not os.path.exists(full_path): + return ToolResult( + success=False, + error=f"文件不存在: {file_path}", + ) + + if not os.path.isfile(full_path): + return ToolResult( + success=False, + error=f"不是文件: {file_path}", + ) + + # 检查文件大小 + file_size = os.path.getsize(full_path) + if file_size > 1024 * 1024: # 1MB + return ToolResult( + success=False, + error=f"文件过大 ({file_size / 1024:.1f}KB),请指定行范围", + ) + + # 读取文件 + with open(full_path, 'r', encoding='utf-8', errors='ignore') as f: + lines = f.readlines() + + total_lines = len(lines) + + # 处理行范围 + if start_line is not None: + start_idx = max(0, start_line - 1) + else: + start_idx = 0 + + if end_line is not None: + end_idx = min(total_lines, end_line) + else: + end_idx = min(total_lines, start_idx + max_lines) + + # 截取指定行 + selected_lines = lines[start_idx:end_idx] + + # 添加行号 + numbered_lines = [] + for i, line in enumerate(selected_lines, start=start_idx + 1): + numbered_lines.append(f"{i:4d}| {line.rstrip()}") + + content = '\n'.join(numbered_lines) + + # 检测语言 + ext = os.path.splitext(file_path)[1].lower() + language = { + ".py": "python", ".js": "javascript", ".ts": "typescript", + ".java": "java", ".go": "go", ".rs": "rust", + ".cpp": "cpp", ".c": "c", ".cs": "csharp", + ".php": "php", ".rb": "ruby", ".swift": "swift", + }.get(ext, "text") + + output = f"📄 文件: {file_path}\n" + output += f"行数: {start_idx + 1}-{end_idx} / {total_lines}\n\n" + output += f"```{language}\n{content}\n```" + + if end_idx < total_lines: + output += f"\n\n... 
还有 {total_lines - end_idx} 行未显示" + + return ToolResult( + success=True, + data=output, + metadata={ + "file_path": file_path, + "total_lines": total_lines, + "start_line": start_idx + 1, + "end_line": end_idx, + "language": language, + } + ) + + except Exception as e: + return ToolResult( + success=False, + error=f"读取文件失败: {str(e)}", + ) + + +class FileSearchInput(BaseModel): + """文件搜索输入""" + keyword: str = Field(description="搜索关键字或正则表达式") + file_pattern: Optional[str] = Field(default=None, description="文件名模式,如 *.py, *.js") + directory: Optional[str] = Field(default=None, description="搜索目录(相对路径)") + case_sensitive: bool = Field(default=False, description="是否区分大小写") + max_results: int = Field(default=50, description="最大结果数") + is_regex: bool = Field(default=False, description="是否使用正则表达式") + + +class FileSearchTool(AgentTool): + """ + 文件搜索工具 + 在项目中搜索包含特定内容的代码 + """ + + # 排除的目录 + EXCLUDE_DIRS = { + "node_modules", "vendor", "dist", "build", ".git", + "__pycache__", ".pytest_cache", "coverage", ".nyc_output", + ".vscode", ".idea", ".vs", "target", "venv", "env", + } + + def __init__(self, project_root: str): + super().__init__() + self.project_root = project_root + + @property + def name(self) -> str: + return "search_code" + + @property + def description(self) -> str: + return """在项目代码中搜索关键字或模式。 + +使用场景: +- 查找特定函数的所有调用位置 +- 搜索特定的 API 使用 +- 查找包含特定模式的代码 + +输入: +- keyword: 搜索关键字或正则表达式 +- file_pattern: 可选,文件名模式(如 *.py) +- directory: 可选,搜索目录 +- case_sensitive: 是否区分大小写 +- is_regex: 是否使用正则表达式 + +这是一个快速搜索工具,结果包含匹配行和上下文。""" + + @property + def args_schema(self): + return FileSearchInput + + async def _execute( + self, + keyword: str, + file_pattern: Optional[str] = None, + directory: Optional[str] = None, + case_sensitive: bool = False, + max_results: int = 50, + is_regex: bool = False, + **kwargs + ) -> ToolResult: + """执行文件搜索""" + try: + # 确定搜索目录 + if directory: + search_dir = os.path.normpath(os.path.join(self.project_root, directory)) + if not 
search_dir.startswith(os.path.normpath(self.project_root)): + return ToolResult( + success=False, + error="安全错误:不允许搜索项目目录外的内容", + ) + else: + search_dir = self.project_root + + # 编译搜索模式 + flags = 0 if case_sensitive else re.IGNORECASE + try: + if is_regex: + pattern = re.compile(keyword, flags) + else: + pattern = re.compile(re.escape(keyword), flags) + except re.error as e: + return ToolResult( + success=False, + error=f"无效的搜索模式: {e}", + ) + + results = [] + files_searched = 0 + + # 遍历文件 + for root, dirs, files in os.walk(search_dir): + # 排除目录 + dirs[:] = [d for d in dirs if d not in self.EXCLUDE_DIRS] + + for filename in files: + # 检查文件模式 + if file_pattern and not fnmatch.fnmatch(filename, file_pattern): + continue + + file_path = os.path.join(root, filename) + relative_path = os.path.relpath(file_path, self.project_root) + + try: + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + lines = f.readlines() + + files_searched += 1 + + for i, line in enumerate(lines): + if pattern.search(line): + # 获取上下文 + start = max(0, i - 1) + end = min(len(lines), i + 2) + context_lines = [] + for j in range(start, end): + prefix = ">" if j == i else " " + context_lines.append(f"{prefix} {j+1:4d}| {lines[j].rstrip()}") + + results.append({ + "file": relative_path, + "line": i + 1, + "match": line.strip()[:200], + "context": '\n'.join(context_lines), + }) + + if len(results) >= max_results: + break + + if len(results) >= max_results: + break + + except Exception: + continue + + if len(results) >= max_results: + break + + if not results: + return ToolResult( + success=True, + data=f"没有找到匹配 '{keyword}' 的内容\n搜索了 {files_searched} 个文件", + metadata={"files_searched": files_searched, "matches": 0} + ) + + # 格式化输出 + output_parts = [f"🔍 搜索结果: '{keyword}'\n"] + output_parts.append(f"找到 {len(results)} 处匹配(搜索了 {files_searched} 个文件)\n") + + for result in results: + output_parts.append(f"\n📄 {result['file']}:{result['line']}") + 
output_parts.append(f"```\n{result['context']}\n```") + + if len(results) >= max_results: + output_parts.append(f"\n... 结果已截断(最大 {max_results} 条)") + + return ToolResult( + success=True, + data="\n".join(output_parts), + metadata={ + "keyword": keyword, + "files_searched": files_searched, + "matches": len(results), + "results": results[:10], # 只在元数据中保留前10个 + } + ) + + except Exception as e: + return ToolResult( + success=False, + error=f"搜索失败: {str(e)}", + ) + + +class ListFilesInput(BaseModel): + """列出文件输入""" + directory: str = Field(default=".", description="目录路径(相对于项目根目录)") + pattern: Optional[str] = Field(default=None, description="文件名模式,如 *.py") + recursive: bool = Field(default=False, description="是否递归列出子目录") + max_files: int = Field(default=100, description="最大文件数") + + +class ListFilesTool(AgentTool): + """ + 列出文件工具 + 列出目录中的文件 + """ + + EXCLUDE_DIRS = { + "node_modules", "vendor", "dist", "build", ".git", + "__pycache__", ".pytest_cache", "coverage", + } + + def __init__(self, project_root: str): + super().__init__() + self.project_root = project_root + + @property + def name(self) -> str: + return "list_files" + + @property + def description(self) -> str: + return """列出目录中的文件。 + +使用场景: +- 了解项目结构 +- 查找特定类型的文件 +- 浏览目录内容 + +输入: +- directory: 目录路径 +- pattern: 可选,文件名模式 +- recursive: 是否递归 +- max_files: 最大文件数""" + + @property + def args_schema(self): + return ListFilesInput + + async def _execute( + self, + directory: str = ".", + pattern: Optional[str] = None, + recursive: bool = False, + max_files: int = 100, + **kwargs + ) -> ToolResult: + """执行文件列表""" + try: + target_dir = os.path.normpath(os.path.join(self.project_root, directory)) + if not target_dir.startswith(os.path.normpath(self.project_root)): + return ToolResult( + success=False, + error="安全错误:不允许访问项目目录外的目录", + ) + + if not os.path.exists(target_dir): + return ToolResult( + success=False, + error=f"目录不存在: {directory}", + ) + + files = [] + dirs = [] + + if recursive: + for root, dirnames, filenames in 
os.walk(target_dir): + # 排除目录 + dirnames[:] = [d for d in dirnames if d not in self.EXCLUDE_DIRS] + + for filename in filenames: + if pattern and not fnmatch.fnmatch(filename, pattern): + continue + + full_path = os.path.join(root, filename) + relative_path = os.path.relpath(full_path, self.project_root) + files.append(relative_path) + + if len(files) >= max_files: + break + + if len(files) >= max_files: + break + else: + for item in os.listdir(target_dir): + if item in self.EXCLUDE_DIRS: + continue + + full_path = os.path.join(target_dir, item) + relative_path = os.path.relpath(full_path, self.project_root) + + if os.path.isdir(full_path): + dirs.append(relative_path + "/") + else: + if pattern and not fnmatch.fnmatch(item, pattern): + continue + files.append(relative_path) + + if len(files) >= max_files: + break + + # 格式化输出 + output_parts = [f"📁 目录: {directory}\n"] + + if dirs: + output_parts.append("目录:") + for d in sorted(dirs)[:20]: + output_parts.append(f" 📂 {d}") + if len(dirs) > 20: + output_parts.append(f" ... 还有 {len(dirs) - 20} 个目录") + + if files: + output_parts.append(f"\n文件 ({len(files)}):") + for f in sorted(files): + output_parts.append(f" 📄 {f}") + + if len(files) >= max_files: + output_parts.append(f"\n... 
结果已截断(最大 {max_files} 个文件)") + + return ToolResult( + success=True, + data="\n".join(output_parts), + metadata={ + "directory": directory, + "file_count": len(files), + "dir_count": len(dirs), + } + ) + + except Exception as e: + return ToolResult( + success=False, + error=f"列出文件失败: {str(e)}", + ) + diff --git a/backend/app/services/agent/tools/pattern_tool.py b/backend/app/services/agent/tools/pattern_tool.py new file mode 100644 index 0000000..c63e54c --- /dev/null +++ b/backend/app/services/agent/tools/pattern_tool.py @@ -0,0 +1,418 @@ +""" +模式匹配工具 +快速扫描代码中的危险模式 +""" + +import re +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field +from dataclasses import dataclass + +from .base import AgentTool, ToolResult + + +@dataclass +class PatternMatch: + """模式匹配结果""" + pattern_name: str + pattern_type: str + file_path: str + line_number: int + matched_text: str + context: str + severity: str + description: str + + +class PatternMatchInput(BaseModel): + """模式匹配输入""" + code: str = Field(description="要扫描的代码内容") + file_path: str = Field(default="unknown", description="文件路径") + pattern_types: Optional[List[str]] = Field( + default=None, + description="要检测的漏洞类型列表,如 ['sql_injection', 'xss']。为空则检测所有类型" + ) + language: Optional[str] = Field(default=None, description="编程语言,用于选择特定模式") + + +class PatternMatchTool(AgentTool): + """ + 模式匹配工具 + 使用正则表达式快速扫描代码中的危险模式 + """ + + # 危险模式定义 + PATTERNS: Dict[str, Dict[str, Any]] = { + # SQL 注入模式 + "sql_injection": { + "patterns": { + "python": [ + (r'cursor\.execute\s*\(\s*["\'].*%[sd].*["\'].*%', "格式化字符串构造SQL"), + (r'cursor\.execute\s*\(\s*f["\']', "f-string构造SQL"), + (r'cursor\.execute\s*\([^,)]+\+', "字符串拼接构造SQL"), + (r'\.execute\s*\(\s*["\'][^"\']*\{', "format()构造SQL"), + (r'text\s*\(\s*["\'].*\+.*["\']', "SQLAlchemy text()拼接"), + ], + "javascript": [ + (r'\.query\s*\(\s*[`"\'].*\$\{', "模板字符串构造SQL"), + (r'\.query\s*\(\s*["\'].*\+', "字符串拼接构造SQL"), + (r'mysql\.query\s*\([^,)]+\+', "MySQL查询拼接"), + ], + "java": [ 
+ (r'Statement.*execute.*\+', "Statement字符串拼接"), + (r'createQuery\s*\([^,)]+\+', "JPA查询拼接"), + (r'\.executeQuery\s*\([^,)]+\+', "executeQuery拼接"), + ], + "php": [ + (r'mysql_query\s*\(\s*["\'].*\.\s*\$', "mysql_query拼接"), + (r'mysqli_query\s*\([^,]+,\s*["\'].*\.\s*\$', "mysqli_query拼接"), + (r'\$pdo->query\s*\(\s*["\'].*\.\s*\$', "PDO query拼接"), + ], + "go": [ + (r'\.Query\s*\([^,)]+\+', "Query字符串拼接"), + (r'\.Exec\s*\([^,)]+\+', "Exec字符串拼接"), + (r'fmt\.Sprintf\s*\([^)]+\)\s*\)', "Sprintf构造SQL"), + ], + }, + "severity": "high", + "description": "SQL注入漏洞:用户输入直接拼接到SQL语句中", + }, + + # XSS 模式 + "xss": { + "patterns": { + "javascript": [ + (r'innerHTML\s*=\s*[^;]+', "innerHTML赋值"), + (r'outerHTML\s*=\s*[^;]+', "outerHTML赋值"), + (r'document\.write\s*\(', "document.write"), + (r'\.html\s*\([^)]+\)', "jQuery html()"), + (r'dangerouslySetInnerHTML', "React dangerouslySetInnerHTML"), + ], + "python": [ + (r'\|\s*safe\b', "Django safe过滤器"), + (r'Markup\s*\(', "Flask Markup"), + (r'mark_safe\s*\(', "Django mark_safe"), + ], + "php": [ + (r'echo\s+\$_(?:GET|POST|REQUEST)', "直接输出用户输入"), + (r'print\s+\$_(?:GET|POST|REQUEST)', "打印用户输入"), + ], + "java": [ + (r'out\.print(?:ln)?\s*\([^)]*request\.getParameter', "直接输出请求参数"), + ], + }, + "severity": "high", + "description": "XSS跨站脚本漏洞:未转义的用户输入被渲染到页面", + }, + + # 命令注入模式 + "command_injection": { + "patterns": { + "python": [ + (r'os\.system\s*\([^)]*\+', "os.system拼接"), + (r'os\.system\s*\([^)]*%', "os.system格式化"), + (r'os\.system\s*\(\s*f["\']', "os.system f-string"), + (r'subprocess\.(?:call|run|Popen)\s*\([^)]*shell\s*=\s*True', "shell=True"), + (r'subprocess\.(?:call|run|Popen)\s*\(\s*["\'][^"\']+%', "subprocess格式化"), + (r'eval\s*\(', "eval()"), + (r'exec\s*\(', "exec()"), + ], + "javascript": [ + (r'exec\s*\([^)]+\+', "exec拼接"), + (r'spawn\s*\([^)]+,\s*\{[^}]*shell:\s*true', "spawn shell"), + (r'eval\s*\(', "eval()"), + (r'Function\s*\(', "Function构造器"), + ], + "php": [ + (r'exec\s*\(\s*\$', "exec变量"), + (r'system\s*\(\s*\$', 
"system变量"), + (r'passthru\s*\(\s*\$', "passthru变量"), + (r'shell_exec\s*\(\s*\$', "shell_exec变量"), + (r'`[^`]*\$[^`]*`', "反引号命令执行"), + ], + "java": [ + (r'Runtime\.getRuntime\(\)\.exec\s*\([^)]+\+', "Runtime.exec拼接"), + (r'ProcessBuilder[^;]+\+', "ProcessBuilder拼接"), + ], + "go": [ + (r'exec\.Command\s*\([^)]+\+', "exec.Command拼接"), + ], + }, + "severity": "critical", + "description": "命令注入漏洞:用户输入被用于执行系统命令", + }, + + # 路径遍历模式 + "path_traversal": { + "patterns": { + "python": [ + (r'open\s*\([^)]*\+', "open()拼接"), + (r'open\s*\([^)]*%', "open()格式化"), + (r'os\.path\.join\s*\([^)]*request', "join用户输入"), + (r'send_file\s*\([^)]*request', "send_file用户输入"), + ], + "javascript": [ + (r'fs\.read(?:File|FileSync)\s*\([^)]+\+', "readFile拼接"), + (r'path\.join\s*\([^)]*req\.', "path.join用户输入"), + (r'res\.sendFile\s*\([^)]+\+', "sendFile拼接"), + ], + "php": [ + (r'include\s*\(\s*\$', "include变量"), + (r'require\s*\(\s*\$', "require变量"), + (r'file_get_contents\s*\(\s*\$', "file_get_contents变量"), + (r'fopen\s*\(\s*\$', "fopen变量"), + ], + "java": [ + (r'new\s+File\s*\([^)]+request\.getParameter', "File构造用户输入"), + (r'new\s+FileInputStream\s*\([^)]+\+', "FileInputStream拼接"), + ], + }, + "severity": "high", + "description": "路径遍历漏洞:用户可以访问任意文件", + }, + + # SSRF 模式 + "ssrf": { + "patterns": { + "python": [ + (r'requests\.(?:get|post|put|delete)\s*\([^)]*request\.', "requests用户URL"), + (r'urllib\.request\.urlopen\s*\([^)]*request\.', "urlopen用户URL"), + (r'httpx\.(?:get|post)\s*\([^)]*request\.', "httpx用户URL"), + ], + "javascript": [ + (r'fetch\s*\([^)]*req\.', "fetch用户URL"), + (r'axios\.(?:get|post)\s*\([^)]*req\.', "axios用户URL"), + (r'http\.request\s*\([^)]*req\.', "http.request用户URL"), + ], + "java": [ + (r'new\s+URL\s*\([^)]*request\.getParameter', "URL构造用户输入"), + (r'HttpClient[^;]+request\.getParameter', "HttpClient用户URL"), + ], + "php": [ + (r'curl_setopt[^;]+CURLOPT_URL[^;]+\$', "curl用户URL"), + (r'file_get_contents\s*\(\s*\$_', "file_get_contents用户URL"), + ], + }, + "severity": 
"high", + "description": "SSRF漏洞:服务端请求用户控制的URL", + }, + + # 不安全的反序列化 + "deserialization": { + "patterns": { + "python": [ + (r'pickle\.loads?\s*\(', "pickle反序列化"), + (r'yaml\.load\s*\([^)]*(?!Loader)', "yaml.load无安全Loader"), + (r'yaml\.unsafe_load\s*\(', "yaml.unsafe_load"), + (r'marshal\.loads?\s*\(', "marshal反序列化"), + ], + "javascript": [ + (r'serialize\s*\(', "serialize"), + (r'unserialize\s*\(', "unserialize"), + ], + "java": [ + (r'ObjectInputStream\s*\(', "ObjectInputStream"), + (r'XMLDecoder\s*\(', "XMLDecoder"), + (r'readObject\s*\(', "readObject"), + ], + "php": [ + (r'unserialize\s*\(\s*\$', "unserialize用户输入"), + ], + }, + "severity": "critical", + "description": "不安全的反序列化:可能导致远程代码执行", + }, + + # 硬编码密钥 + "hardcoded_secret": { + "patterns": { + "_common": [ + (r'(?:password|passwd|pwd)\s*=\s*["\'][^"\']{4,}["\']', "硬编码密码"), + (r'(?:secret|api_?key|apikey|token|auth)\s*=\s*["\'][^"\']{8,}["\']', "硬编码密钥"), + (r'(?:private_?key|priv_?key)\s*=\s*["\'][^"\']+["\']', "硬编码私钥"), + (r'-----BEGIN\s+(?:RSA\s+)?PRIVATE\s+KEY-----', "私钥"), + (r'(?:aws_?access_?key|aws_?secret)\s*=\s*["\'][^"\']+["\']', "AWS密钥"), + (r'(?:ghp_|gho_|github_pat_)[a-zA-Z0-9]{36,}', "GitHub Token"), + (r'sk-[a-zA-Z0-9]{48}', "OpenAI API Key"), + (r'(?:bearer|authorization)\s*[=:]\s*["\'][^"\']{20,}["\']', "Bearer Token"), + ], + }, + "severity": "medium", + "description": "硬编码密钥:敏感信息不应该硬编码在代码中", + }, + + # 弱加密 + "weak_crypto": { + "patterns": { + "python": [ + (r'hashlib\.md5\s*\(', "MD5哈希"), + (r'hashlib\.sha1\s*\(', "SHA1哈希"), + (r'DES\s*\(', "DES加密"), + (r'random\.random\s*\(', "不安全随机数"), + ], + "javascript": [ + (r'crypto\.createHash\s*\(\s*["\']md5["\']', "MD5哈希"), + (r'crypto\.createHash\s*\(\s*["\']sha1["\']', "SHA1哈希"), + (r'Math\.random\s*\(', "Math.random"), + ], + "java": [ + (r'MessageDigest\.getInstance\s*\(\s*["\']MD5["\']', "MD5哈希"), + (r'MessageDigest\.getInstance\s*\(\s*["\']SHA-?1["\']', "SHA1哈希"), + (r'DESKeySpec', "DES密钥"), + ], + "php": [ + (r'md5\s*\(', "MD5哈希"), + 
(r'sha1\s*\(', "SHA1哈希"), + (r'mcrypt_', "mcrypt已废弃"), + ], + }, + "severity": "low", + "description": "弱加密算法:使用了不安全的加密或哈希算法", + }, + } + + @property + def name(self) -> str: + return "pattern_match" + + @property + def description(self) -> str: + vuln_types = ", ".join(self.PATTERNS.keys()) + return f"""快速扫描代码中的危险模式和常见漏洞。 +使用正则表达式检测已知的不安全代码模式。 + +支持的漏洞类型: {vuln_types} + +这是一个快速扫描工具,可以在分析开始时使用来快速发现潜在问题。 +发现的问题需要进一步分析确认。""" + + @property + def args_schema(self): + return PatternMatchInput + + async def _execute( + self, + code: str, + file_path: str = "unknown", + pattern_types: Optional[List[str]] = None, + language: Optional[str] = None, + **kwargs + ) -> ToolResult: + """执行模式匹配""" + matches: List[PatternMatch] = [] + lines = code.split('\n') + + # 确定要检查的漏洞类型 + types_to_check = pattern_types or list(self.PATTERNS.keys()) + + # 自动检测语言 + if not language: + language = self._detect_language(file_path) + + for vuln_type in types_to_check: + if vuln_type not in self.PATTERNS: + continue + + pattern_config = self.PATTERNS[vuln_type] + patterns_dict = pattern_config["patterns"] + + # 获取语言特定模式和通用模式 + patterns_to_use = [] + if language and language in patterns_dict: + patterns_to_use.extend(patterns_dict[language]) + if "_common" in patterns_dict: + patterns_to_use.extend(patterns_dict["_common"]) + + # 如果没有特定语言模式,尝试使用所有模式 + if not patterns_to_use: + for lang, pats in patterns_dict.items(): + if lang != "_common": + patterns_to_use.extend(pats) + + # 执行匹配 + for pattern, pattern_name in patterns_to_use: + try: + for i, line in enumerate(lines): + if re.search(pattern, line, re.IGNORECASE): + # 获取上下文 + start = max(0, i - 2) + end = min(len(lines), i + 3) + context = '\n'.join(f"{j+1}: {lines[j]}" for j in range(start, end)) + + matches.append(PatternMatch( + pattern_name=pattern_name, + pattern_type=vuln_type, + file_path=file_path, + line_number=i + 1, + matched_text=line.strip()[:200], + context=context, + severity=pattern_config["severity"], + 
description=pattern_config["description"], + )) + except re.error: + continue + + if not matches: + return ToolResult( + success=True, + data="没有检测到已知的危险模式", + metadata={"patterns_checked": len(types_to_check), "matches": 0} + ) + + # 格式化输出 + output_parts = [f"⚠️ 检测到 {len(matches)} 个潜在问题:\n"] + + # 按严重程度排序 + severity_order = {"critical": 0, "high": 1, "medium": 2, "low": 3} + matches.sort(key=lambda x: severity_order.get(x.severity, 4)) + + for match in matches: + severity_icon = {"critical": "🔴", "high": "🟠", "medium": "🟡", "low": "🟢"}.get(match.severity, "⚪") + output_parts.append(f"\n{severity_icon} [{match.severity.upper()}] {match.pattern_type}") + output_parts.append(f" 位置: {match.file_path}:{match.line_number}") + output_parts.append(f" 模式: {match.pattern_name}") + output_parts.append(f" 描述: {match.description}") + output_parts.append(f" 匹配: {match.matched_text}") + output_parts.append(f" 上下文:\n{match.context}") + + return ToolResult( + success=True, + data="\n".join(output_parts), + metadata={ + "matches": len(matches), + "by_severity": { + s: len([m for m in matches if m.severity == s]) + for s in ["critical", "high", "medium", "low"] + }, + "details": [ + { + "type": m.pattern_type, + "severity": m.severity, + "line": m.line_number, + "pattern": m.pattern_name, + } + for m in matches + ] + } + ) + + def _detect_language(self, file_path: str) -> Optional[str]: + """根据文件扩展名检测语言""" + ext_map = { + ".py": "python", + ".js": "javascript", + ".jsx": "javascript", + ".ts": "javascript", + ".tsx": "javascript", + ".java": "java", + ".php": "php", + ".go": "go", + ".rb": "ruby", + } + + for ext, lang in ext_map.items(): + if file_path.lower().endswith(ext): + return lang + + return None + diff --git a/backend/app/services/agent/tools/rag_tool.py b/backend/app/services/agent/tools/rag_tool.py new file mode 100644 index 0000000..7c8b9c6 --- /dev/null +++ b/backend/app/services/agent/tools/rag_tool.py @@ -0,0 +1,293 @@ +""" +RAG 检索工具 +支持语义检索代码 +""" + +from typing 
import Optional, List +from pydantic import BaseModel, Field + +from .base import AgentTool, ToolResult +from app.services.rag import CodeRetriever + + +class RAGQueryInput(BaseModel): + """RAG 查询输入参数""" + query: str = Field(description="搜索查询,描述你要找的代码功能或特征") + top_k: int = Field(default=10, description="返回结果数量") + file_path: Optional[str] = Field(default=None, description="限定搜索的文件路径") + language: Optional[str] = Field(default=None, description="限定编程语言") + + +class RAGQueryTool(AgentTool): + """ + RAG 代码检索工具 + 使用语义搜索在代码库中查找相关代码 + """ + + def __init__(self, retriever: CodeRetriever): + super().__init__() + self.retriever = retriever + + @property + def name(self) -> str: + return "rag_query" + + @property + def description(self) -> str: + return """在代码库中进行语义搜索。 +使用场景: +- 查找特定功能的实现代码 +- 查找调用某个函数的代码 +- 查找处理用户输入的代码 +- 查找数据库操作相关代码 +- 查找认证/授权相关代码 + +输入: +- query: 描述你要查找的代码,例如 "处理用户登录的函数"、"SQL查询执行"、"文件上传处理" +- top_k: 返回结果数量(默认10) +- file_path: 可选,限定在某个文件中搜索 +- language: 可选,限定编程语言 + +输出: 相关的代码片段列表,包含文件路径、行号、代码内容和相似度分数""" + + @property + def args_schema(self): + return RAGQueryInput + + async def _execute( + self, + query: str, + top_k: int = 10, + file_path: Optional[str] = None, + language: Optional[str] = None, + **kwargs + ) -> ToolResult: + """执行 RAG 检索""" + try: + results = await self.retriever.retrieve( + query=query, + top_k=top_k, + filter_file_path=file_path, + filter_language=language, + ) + + if not results: + return ToolResult( + success=True, + data="没有找到相关代码", + metadata={"query": query, "results_count": 0} + ) + + # 格式化输出 + output_parts = [f"找到 {len(results)} 个相关代码片段:\n"] + + for i, result in enumerate(results): + output_parts.append(f"\n--- 结果 {i+1} (相似度: {result.score:.2f}) ---") + output_parts.append(f"文件: {result.file_path}") + output_parts.append(f"行号: {result.line_start}-{result.line_end}") + if result.name: + output_parts.append(f"名称: {result.name}") + if result.security_indicators: + output_parts.append(f"安全指标: {', 
'.join(result.security_indicators)}") + output_parts.append(f"代码:\n```{result.language}\n{result.content}\n```") + + return ToolResult( + success=True, + data="\n".join(output_parts), + metadata={ + "query": query, + "results_count": len(results), + "results": [r.to_dict() for r in results], + } + ) + + except Exception as e: + return ToolResult( + success=False, + error=f"RAG 检索失败: {str(e)}", + ) + + +class SecurityCodeSearchInput(BaseModel): + """安全代码搜索输入""" + vulnerability_type: str = Field( + description="漏洞类型: sql_injection, xss, command_injection, path_traversal, ssrf, deserialization, auth_bypass, hardcoded_secret" + ) + top_k: int = Field(default=20, description="返回结果数量") + + +class SecurityCodeSearchTool(AgentTool): + """ + 安全相关代码搜索工具 + 专门用于查找可能存在安全漏洞的代码 + """ + + def __init__(self, retriever: CodeRetriever): + super().__init__() + self.retriever = retriever + + @property + def name(self) -> str: + return "security_code_search" + + @property + def description(self) -> str: + return """搜索可能存在安全漏洞的代码。 +专门针对特定漏洞类型进行搜索。 + +支持的漏洞类型: +- sql_injection: SQL 注入 +- xss: 跨站脚本 +- command_injection: 命令注入 +- path_traversal: 路径遍历 +- ssrf: 服务端请求伪造 +- deserialization: 不安全的反序列化 +- auth_bypass: 认证绕过 +- hardcoded_secret: 硬编码密钥""" + + @property + def args_schema(self): + return SecurityCodeSearchInput + + async def _execute( + self, + vulnerability_type: str, + top_k: int = 20, + **kwargs + ) -> ToolResult: + """执行安全代码搜索""" + try: + results = await self.retriever.retrieve_security_related( + vulnerability_type=vulnerability_type, + top_k=top_k, + ) + + if not results: + return ToolResult( + success=True, + data=f"没有找到与 {vulnerability_type} 相关的代码", + metadata={"vulnerability_type": vulnerability_type, "results_count": 0} + ) + + # 格式化输出 + output_parts = [f"找到 {len(results)} 个可能与 {vulnerability_type} 相关的代码:\n"] + + for i, result in enumerate(results): + output_parts.append(f"\n--- 可疑代码 {i+1} ---") + output_parts.append(f"文件: {result.file_path}:{result.line_start}") + if 
result.security_indicators: + output_parts.append(f"⚠️ 安全指标: {', '.join(result.security_indicators)}") + output_parts.append(f"代码:\n```{result.language}\n{result.content}\n```") + + return ToolResult( + success=True, + data="\n".join(output_parts), + metadata={ + "vulnerability_type": vulnerability_type, + "results_count": len(results), + } + ) + + except Exception as e: + return ToolResult( + success=False, + error=f"安全代码搜索失败: {str(e)}", + ) + + +class FunctionContextInput(BaseModel): + """函数上下文搜索输入""" + function_name: str = Field(description="函数名称") + file_path: Optional[str] = Field(default=None, description="文件路径") + include_callers: bool = Field(default=True, description="是否包含调用者") + include_callees: bool = Field(default=True, description="是否包含被调用的函数") + + +class FunctionContextTool(AgentTool): + """ + 函数上下文搜索工具 + 查找函数的定义、调用者和被调用者 + """ + + def __init__(self, retriever: CodeRetriever): + super().__init__() + self.retriever = retriever + + @property + def name(self) -> str: + return "function_context" + + @property + def description(self) -> str: + return """查找函数的上下文信息,包括定义、调用者和被调用的函数。 +用于追踪数据流和理解函数的使用方式。 + +输入: +- function_name: 要查找的函数名 +- file_path: 可选,限定文件路径 +- include_callers: 是否查找调用此函数的代码 +- include_callees: 是否查找此函数调用的其他函数""" + + @property + def args_schema(self): + return FunctionContextInput + + async def _execute( + self, + function_name: str, + file_path: Optional[str] = None, + include_callers: bool = True, + include_callees: bool = True, + **kwargs + ) -> ToolResult: + """执行函数上下文搜索""" + try: + context = await self.retriever.retrieve_function_context( + function_name=function_name, + file_path=file_path, + include_callers=include_callers, + include_callees=include_callees, + ) + + output_parts = [f"函数 '{function_name}' 的上下文分析:\n"] + + # 函数定义 + if context["definition"]: + output_parts.append("### 函数定义:") + for result in context["definition"]: + output_parts.append(f"文件: {result.file_path}:{result.line_start}") + 
output_parts.append(f"```{result.language}\n{result.content}\n```") + else: + output_parts.append("未找到函数定义") + + # 调用者 + if context["callers"]: + output_parts.append(f"\n### 调用此函数的代码 ({len(context['callers'])} 处):") + for result in context["callers"][:5]: + output_parts.append(f"- {result.file_path}:{result.line_start}") + output_parts.append(f"```{result.language}\n{result.content[:500]}\n```") + + # 被调用者 + if context["callees"]: + output_parts.append(f"\n### 此函数调用的其他函数:") + for result in context["callees"][:5]: + if result.name: + output_parts.append(f"- {result.name} ({result.file_path})") + + return ToolResult( + success=True, + data="\n".join(output_parts), + metadata={ + "function_name": function_name, + "definition_count": len(context["definition"]), + "callers_count": len(context["callers"]), + "callees_count": len(context["callees"]), + } + ) + + except Exception as e: + return ToolResult( + success=False, + error=f"函数上下文搜索失败: {str(e)}", + ) + diff --git a/backend/app/services/agent/tools/sandbox_tool.py b/backend/app/services/agent/tools/sandbox_tool.py new file mode 100644 index 0000000..91a0257 --- /dev/null +++ b/backend/app/services/agent/tools/sandbox_tool.py @@ -0,0 +1,647 @@ +""" +沙箱执行工具 +在 Docker 沙箱中执行代码和命令进行漏洞验证 +""" + +import asyncio +import json +import logging +import tempfile +import os +import shutil +from typing import Optional, List, Dict, Any +from pydantic import BaseModel, Field +from dataclasses import dataclass + +from .base import AgentTool, ToolResult + +logger = logging.getLogger(__name__) + + +@dataclass +class SandboxConfig: + """沙箱配置""" + image: str = "python:3.11-slim" + memory_limit: str = "512m" + cpu_limit: float = 1.0 + timeout: int = 60 + network_mode: str = "none" # none, bridge, host + read_only: bool = True + user: str = "1000:1000" + + +class SandboxManager: + """ + 沙箱管理器 + 管理 Docker 容器的创建、执行和清理 + """ + + def __init__(self, config: Optional[SandboxConfig] = None): + self.config = config or SandboxConfig() + 
self._docker_client = None + self._initialized = False + + async def initialize(self): + """初始化 Docker 客户端""" + if self._initialized: + return + + try: + import docker + self._docker_client = docker.from_env() + # 测试连接 + self._docker_client.ping() + self._initialized = True + logger.info("Docker sandbox manager initialized") + except Exception as e: + logger.warning(f"Docker not available: {e}") + self._docker_client = None + + @property + def is_available(self) -> bool: + """检查 Docker 是否可用""" + return self._docker_client is not None + + async def execute_command( + self, + command: str, + working_dir: Optional[str] = None, + env: Optional[Dict[str, str]] = None, + timeout: Optional[int] = None, + ) -> Dict[str, Any]: + """ + 在沙箱中执行命令 + + Args: + command: 要执行的命令 + working_dir: 工作目录 + env: 环境变量 + timeout: 超时时间(秒) + + Returns: + 执行结果 + """ + if not self.is_available: + return { + "success": False, + "error": "Docker 不可用", + "stdout": "", + "stderr": "", + "exit_code": -1, + } + + timeout = timeout or self.config.timeout + + try: + # 创建临时目录 + with tempfile.TemporaryDirectory() as temp_dir: + # 准备容器配置 + container_config = { + "image": self.config.image, + "command": ["sh", "-c", command], + "detach": True, + "mem_limit": self.config.memory_limit, + "cpu_period": 100000, + "cpu_quota": int(100000 * self.config.cpu_limit), + "network_mode": self.config.network_mode, + "user": self.config.user, + "read_only": self.config.read_only, + "volumes": { + temp_dir: {"bind": "/workspace", "mode": "rw"}, + }, + "working_dir": working_dir or "/workspace", + "environment": env or {}, + # 安全配置 + "cap_drop": ["ALL"], + "security_opt": ["no-new-privileges:true"], + } + + # 创建并启动容器 + container = await asyncio.to_thread( + self._docker_client.containers.run, + **container_config + ) + + try: + # 等待执行完成 + result = await asyncio.wait_for( + asyncio.to_thread(container.wait), + timeout=timeout + ) + + # 获取日志 + stdout = await asyncio.to_thread( + container.logs, stdout=True, stderr=False + ) 
+ stderr = await asyncio.to_thread( + container.logs, stdout=False, stderr=True + ) + + return { + "success": result["StatusCode"] == 0, + "stdout": stdout.decode('utf-8', errors='ignore')[:10000], + "stderr": stderr.decode('utf-8', errors='ignore')[:2000], + "exit_code": result["StatusCode"], + "error": None, + } + + except asyncio.TimeoutError: + await asyncio.to_thread(container.kill) + return { + "success": False, + "error": f"执行超时 ({timeout}秒)", + "stdout": "", + "stderr": "", + "exit_code": -1, + } + + finally: + # 清理容器 + await asyncio.to_thread(container.remove, force=True) + + except Exception as e: + logger.error(f"Sandbox execution error: {e}") + return { + "success": False, + "error": str(e), + "stdout": "", + "stderr": "", + "exit_code": -1, + } + + async def execute_python( + self, + code: str, + timeout: Optional[int] = None, + ) -> Dict[str, Any]: + """ + 在沙箱中执行 Python 代码 + + Args: + code: Python 代码 + timeout: 超时时间 + + Returns: + 执行结果 + """ + # 转义代码中的单引号 + escaped_code = code.replace("'", "'\\''") + command = f"python3 -c '{escaped_code}'" + return await self.execute_command(command, timeout=timeout) + + async def execute_http_request( + self, + method: str, + url: str, + headers: Optional[Dict[str, str]] = None, + data: Optional[str] = None, + timeout: int = 30, + ) -> Dict[str, Any]: + """ + 在沙箱中执行 HTTP 请求 + + Args: + method: HTTP 方法 + url: URL + headers: 请求头 + data: 请求体 + timeout: 超时 + + Returns: + HTTP 响应 + """ + # 构建 curl 命令 + curl_parts = ["curl", "-s", "-S", "-w", "'\\n%{http_code}'", "-X", method] + + if headers: + for key, value in headers.items(): + curl_parts.extend(["-H", f"'{key}: {value}'"]) + + if data: + curl_parts.extend(["-d", f"'{data}'"]) + + curl_parts.append(f"'{url}'") + + command = " ".join(curl_parts) + + # 使用带网络的镜像 + original_network = self.config.network_mode + self.config.network_mode = "bridge" # 允许网络访问 + + try: + result = await self.execute_command(command, timeout=timeout) + + if result["success"] and result["stdout"]: 
async def verify_vulnerability(
    self,
    vulnerability_type: str,
    target_url: str,
    payload: str,
    expected_pattern: Optional[str] = None,
) -> Dict[str, Any]:
    """Attempt to confirm a vulnerability by sending the payload and inspecting the response.

    Args:
        vulnerability_type: Kind of vulnerability under test
            (e.g. "sql_injection", "xss", "command_injection").
        target_url: URL of the endpoint under test.
        payload: Attack payload. Sent as the request body for POST;
            for GET the caller is expected to embed it in the URL.
        expected_pattern: Optional regex; a case-insensitive match in
            the response body confirms the vulnerability.

    Returns:
        Dict describing the attempt: is_vulnerable, evidence, error,
        plus response_status/response_length when a response arrived.
    """
    # BUG FIX: `import re` previously lived inside the
    # `if expected_pattern:` branch, which made `re` a function-local
    # name. The generic checks in the `else` branch then raised
    # UnboundLocalError — silently swallowed by the broad `except`
    # below — so sql_injection detection never worked without an
    # explicit pattern. Import once at the top of the method.
    import re

    verification_result: Dict[str, Any] = {
        "vulnerability_type": vulnerability_type,
        "target_url": target_url,
        "payload": payload,
        "is_vulnerable": False,
        "evidence": None,
        "error": None,
    }

    try:
        # Heuristic: URLs that already carry a query string are probed
        # with GET (payload assumed to be in the URL); otherwise the
        # payload is POSTed as the request body.
        response = await self.execute_http_request(
            method="GET" if "?" in target_url else "POST",
            url=target_url,
            data=payload if "?" not in target_url else None,
        )

        if not response["success"]:
            verification_result["error"] = response.get("error")
            return verification_result

        body = response.get("body", "")
        status_code = response.get("status_code", 0)

        if expected_pattern:
            if re.search(expected_pattern, body, re.IGNORECASE):
                verification_result["is_vulnerable"] = True
                verification_result["evidence"] = f"响应中包含预期模式: {expected_pattern}"
        else:
            # Fall back to generic per-type heuristics when no explicit
            # pattern was supplied.
            if vulnerability_type == "sql_injection":
                # Common DB error fingerprints leaked into responses.
                error_patterns = [
                    r"SQL syntax",
                    r"mysql_fetch",
                    r"ORA-\d+",
                    r"PostgreSQL.*ERROR",
                    r"SQLite.*error",
                    r"ODBC.*Driver",
                ]
                for pattern in error_patterns:
                    if re.search(pattern, body, re.IGNORECASE):
                        verification_result["is_vulnerable"] = True
                        verification_result["evidence"] = f"SQL错误信息: {pattern}"
                        break

            elif vulnerability_type == "xss":
                # A payload reflected verbatim indicates unescaped output.
                if payload in body:
                    verification_result["is_vulnerable"] = True
                    verification_result["evidence"] = "XSS payload 被反射到响应中"

            elif vulnerability_type == "command_injection":
                # Typical `id` / /etc/passwd output markers.
                if "uid=" in body or "root:" in body:
                    verification_result["is_vulnerable"] = True
                    verification_result["evidence"] = "命令执行成功"

        verification_result["response_status"] = status_code
        verification_result["response_length"] = len(body)

    except Exception as e:
        verification_result["error"] = str(e)

    return verification_result
async def _execute(
    self,
    command: str,
    timeout: int = 30,
    **kwargs
) -> ToolResult:
    """Run a shell command inside the Docker sandbox.

    Args:
        command: Shell command line; only its first token is validated
            against ALLOWED_COMMANDS.
        timeout: Per-command timeout in seconds.

    Returns:
        ToolResult with formatted stdout/stderr, exit code metadata,
        or an error when Docker is unavailable / the command is denied.
    """
    # Lazily initialize Docker; a no-op when already initialized.
    await self.sandbox_manager.initialize()

    if not self.sandbox_manager.is_available:
        return ToolResult(
            success=False,
            error="沙箱环境不可用(Docker 未安装或未运行)",
        )

    cmd_parts = command.strip().split()
    if not cmd_parts:
        return ToolResult(success=False, error="命令不能为空")

    # BUG FIX: the allow-check previously used startswith(), so any
    # binary sharing an allowed prefix (e.g. "python-malicious",
    # "lsblk") slipped through. Require an exact match instead.
    # NOTE(review): only the first token is checked — shell operators
    # (";", "|", "&&") can still chain further commands; containment
    # ultimately relies on the sandbox (caps dropped, no network,
    # read-only rootfs). Confirm this is the intended trust boundary.
    base_cmd = cmd_parts[0]
    if base_cmd not in self.ALLOWED_COMMANDS:
        return ToolResult(
            success=False,
            error=f"命令 '{base_cmd}' 不在允许列表中。允许的命令: {', '.join(self.ALLOWED_COMMANDS)}",
        )

    result = await self.sandbox_manager.execute_command(
        command=command,
        timeout=timeout,
    )

    # Render a human-readable report for the agent transcript.
    output_parts = ["🐳 沙箱执行结果\n"]
    output_parts.append(f"命令: {command}")
    output_parts.append(f"退出码: {result['exit_code']}")

    if result["stdout"]:
        output_parts.append(f"\n标准输出:\n```\n{result['stdout']}\n```")

    if result["stderr"]:
        output_parts.append(f"\n标准错误:\n```\n{result['stderr']}\n```")

    if result.get("error"):
        output_parts.append(f"\n错误: {result['error']}")

    return ToolResult(
        success=result["success"],
        data="\n".join(output_parts),
        error=result.get("error"),
        metadata={
            "command": command,
            "exit_code": result["exit_code"],
        }
    )
data: Optional[str] = Field(default=None, description="请求体") + timeout: int = Field(default=30, description="超时时间(秒)") + + +class SandboxHttpTool(AgentTool): + """ + 沙箱 HTTP 请求工具 + 在沙箱中发送 HTTP 请求 + """ + + def __init__(self, sandbox_manager: Optional[SandboxManager] = None): + super().__init__() + self.sandbox_manager = sandbox_manager or SandboxManager() + + @property + def name(self) -> str: + return "sandbox_http" + + @property + def description(self) -> str: + return """在沙箱中发送 HTTP 请求。 +用于测试 Web 漏洞如 SQL 注入、XSS、SSRF 等。 + +输入: +- method: HTTP 方法 +- url: 请求 URL +- headers: 可选,请求头 +- data: 可选,请求体 +- timeout: 超时时间 + +使用场景: +- 验证 SQL 注入漏洞 +- 测试 XSS payload +- 验证 SSRF 漏洞 +- 测试认证绕过""" + + @property + def args_schema(self): + return HttpRequestInput + + async def _execute( + self, + url: str, + method: str = "GET", + headers: Optional[Dict[str, str]] = None, + data: Optional[str] = None, + timeout: int = 30, + **kwargs + ) -> ToolResult: + """执行 HTTP 请求""" + await self.sandbox_manager.initialize() + + if not self.sandbox_manager.is_available: + return ToolResult( + success=False, + error="沙箱环境不可用", + ) + + result = await self.sandbox_manager.execute_http_request( + method=method, + url=url, + headers=headers, + data=data, + timeout=timeout, + ) + + output_parts = ["🌐 HTTP 请求结果\n"] + output_parts.append(f"请求: {method} {url}") + + if headers: + output_parts.append(f"请求头: {json.dumps(headers, ensure_ascii=False)}") + + if data: + output_parts.append(f"请求体: {data[:500]}") + + output_parts.append(f"\n状态码: {result.get('status_code', 'N/A')}") + + if result.get("body"): + body = result["body"] + if len(body) > 2000: + body = body[:2000] + f"\n... 
(截断,共 {len(result['body'])} 字符)" + output_parts.append(f"\n响应内容:\n```\n{body}\n```") + + if result.get("error"): + output_parts.append(f"\n错误: {result['error']}") + + return ToolResult( + success=result["success"], + data="\n".join(output_parts), + error=result.get("error"), + metadata={ + "method": method, + "url": url, + "status_code": result.get("status_code"), + "response_length": len(result.get("body", "")), + } + ) + + +class VulnerabilityVerifyInput(BaseModel): + """漏洞验证输入""" + vulnerability_type: str = Field(description="漏洞类型 (sql_injection, xss, command_injection, etc.)") + target_url: str = Field(description="目标 URL") + payload: str = Field(description="攻击载荷") + expected_pattern: Optional[str] = Field(default=None, description="期望在响应中匹配的正则模式") + + +class VulnerabilityVerifyTool(AgentTool): + """ + 漏洞验证工具 + 在沙箱中验证漏洞是否真实存在 + """ + + def __init__(self, sandbox_manager: Optional[SandboxManager] = None): + super().__init__() + self.sandbox_manager = sandbox_manager or SandboxManager() + + @property + def name(self) -> str: + return "verify_vulnerability" + + @property + def description(self) -> str: + return """验证漏洞是否真实存在。 +发送包含攻击载荷的请求,分析响应判断漏洞是否可利用。 + +输入: +- vulnerability_type: 漏洞类型 +- target_url: 目标 URL +- payload: 攻击载荷 +- expected_pattern: 可选,期望在响应中匹配的模式 + +支持的漏洞类型: +- sql_injection: SQL 注入 +- xss: 跨站脚本 +- command_injection: 命令注入 +- path_traversal: 路径遍历 +- ssrf: 服务端请求伪造""" + + @property + def args_schema(self): + return VulnerabilityVerifyInput + + async def _execute( + self, + vulnerability_type: str, + target_url: str, + payload: str, + expected_pattern: Optional[str] = None, + **kwargs + ) -> ToolResult: + """执行漏洞验证""" + await self.sandbox_manager.initialize() + + if not self.sandbox_manager.is_available: + return ToolResult( + success=False, + error="沙箱环境不可用", + ) + + result = await self.sandbox_manager.verify_vulnerability( + vulnerability_type=vulnerability_type, + target_url=target_url, + payload=payload, + expected_pattern=expected_pattern, + ) + + 
output_parts = ["🔍 漏洞验证结果\n"] + output_parts.append(f"漏洞类型: {vulnerability_type}") + output_parts.append(f"目标: {target_url}") + output_parts.append(f"Payload: {payload[:200]}") + + if result["is_vulnerable"]: + output_parts.append(f"\n🔴 结果: 漏洞已确认!") + output_parts.append(f"证据: {result.get('evidence', 'N/A')}") + else: + output_parts.append(f"\n🟢 结果: 未能确认漏洞") + if result.get("error"): + output_parts.append(f"错误: {result['error']}") + + if result.get("response_status"): + output_parts.append(f"\nHTTP 状态码: {result['response_status']}") + + return ToolResult( + success=True, + data="\n".join(output_parts), + metadata={ + "vulnerability_type": vulnerability_type, + "is_vulnerable": result["is_vulnerable"], + "evidence": result.get("evidence"), + } + ) + diff --git a/backend/app/services/llm/adapters/baidu_adapter.py b/backend/app/services/llm/adapters/baidu_adapter.py index ff447c5..cdf962f 100644 --- a/backend/app/services/llm/adapters/baidu_adapter.py +++ b/backend/app/services/llm/adapters/baidu_adapter.py @@ -144,3 +144,4 @@ class BaiduAdapter(BaseLLMAdapter): return True + diff --git a/backend/app/services/llm/adapters/doubao_adapter.py b/backend/app/services/llm/adapters/doubao_adapter.py index 5b8d5d8..a95da90 100644 --- a/backend/app/services/llm/adapters/doubao_adapter.py +++ b/backend/app/services/llm/adapters/doubao_adapter.py @@ -82,3 +82,4 @@ class DoubaoAdapter(BaseLLMAdapter): return True + diff --git a/backend/app/services/llm/adapters/minimax_adapter.py b/backend/app/services/llm/adapters/minimax_adapter.py index 1e1ea34..e57faa4 100644 --- a/backend/app/services/llm/adapters/minimax_adapter.py +++ b/backend/app/services/llm/adapters/minimax_adapter.py @@ -85,3 +85,4 @@ class MinimaxAdapter(BaseLLMAdapter): return True + diff --git a/backend/app/services/llm/base_adapter.py b/backend/app/services/llm/base_adapter.py index 6eb4a69..1f868ac 100644 --- a/backend/app/services/llm/base_adapter.py +++ b/backend/app/services/llm/base_adapter.py @@ -133,3 
+133,4 @@ class BaseLLMAdapter(ABC): self._client = None + diff --git a/backend/app/services/llm/types.py b/backend/app/services/llm/types.py index 5db58e0..5ddfd37 100644 --- a/backend/app/services/llm/types.py +++ b/backend/app/services/llm/types.py @@ -119,3 +119,4 @@ DEFAULT_BASE_URLS: Dict[LLMProvider, str] = { } + diff --git a/backend/app/services/rag/__init__.py b/backend/app/services/rag/__init__.py new file mode 100644 index 0000000..c6a031e --- /dev/null +++ b/backend/app/services/rag/__init__.py @@ -0,0 +1,18 @@ +""" +RAG (Retrieval-Augmented Generation) 系统 +用于代码索引和语义检索 +""" + +from .splitter import CodeSplitter, CodeChunk +from .embeddings import EmbeddingService +from .indexer import CodeIndexer +from .retriever import CodeRetriever + +__all__ = [ + "CodeSplitter", + "CodeChunk", + "EmbeddingService", + "CodeIndexer", + "CodeRetriever", +] + diff --git a/backend/app/services/rag/embeddings.py b/backend/app/services/rag/embeddings.py new file mode 100644 index 0000000..a2ec119 --- /dev/null +++ b/backend/app/services/rag/embeddings.py @@ -0,0 +1,658 @@ +""" +嵌入模型服务 +支持多种嵌入模型提供商: OpenAI, Azure, Ollama, Cohere, HuggingFace, Jina +""" + +import asyncio +import hashlib +import logging +from typing import List, Dict, Any, Optional +from abc import ABC, abstractmethod +from dataclasses import dataclass + +import httpx + +from app.core.config import settings + +logger = logging.getLogger(__name__) + + +@dataclass +class EmbeddingResult: + """嵌入结果""" + embedding: List[float] + tokens_used: int + model: str + + +class EmbeddingProvider(ABC): + """嵌入提供商基类""" + + @abstractmethod + async def embed_text(self, text: str) -> EmbeddingResult: + """嵌入单个文本""" + pass + + @abstractmethod + async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]: + """批量嵌入文本""" + pass + + @property + @abstractmethod + def dimension(self) -> int: + """嵌入向量维度""" + pass + + +class OpenAIEmbedding(EmbeddingProvider): + """OpenAI 嵌入服务""" + + MODELS = { + "text-embedding-3-small": 
1536, + "text-embedding-3-large": 3072, + "text-embedding-ada-002": 1536, + } + + def __init__( + self, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + model: str = "text-embedding-3-small", + ): + self.api_key = api_key or settings.LLM_API_KEY + self.base_url = base_url or "https://api.openai.com/v1" + self.model = model + self._dimension = self.MODELS.get(model, 1536) + + @property + def dimension(self) -> int: + return self._dimension + + async def embed_text(self, text: str) -> EmbeddingResult: + results = await self.embed_texts([text]) + return results[0] + + async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]: + if not texts: + return [] + + max_length = 8191 + truncated_texts = [text[:max_length] for text in texts] + + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + payload = { + "model": self.model, + "input": truncated_texts, + } + + url = f"{self.base_url.rstrip('/')}/embeddings" + + async with httpx.AsyncClient(timeout=60) as client: + response = await client.post(url, headers=headers, json=payload) + response.raise_for_status() + data = response.json() + + results = [] + for item in data.get("data", []): + results.append(EmbeddingResult( + embedding=item["embedding"], + tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts), + model=self.model, + )) + + return results + + +class AzureOpenAIEmbedding(EmbeddingProvider): + """ + Azure OpenAI 嵌入服务 + + 使用最新 API 版本 2024-10-21 (GA) + 端点格式: https://.openai.azure.com/openai/deployments//embeddings?api-version=2024-10-21 + """ + + MODELS = { + "text-embedding-3-small": 1536, + "text-embedding-3-large": 3072, + "text-embedding-ada-002": 1536, + } + + # 最新的 GA API 版本 + API_VERSION = "2024-10-21" + + def __init__( + self, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + model: str = "text-embedding-3-small", + ): + self.api_key = api_key + self.base_url = base_url or 
"https://your-resource.openai.azure.com" + self.model = model + self._dimension = self.MODELS.get(model, 1536) + + @property + def dimension(self) -> int: + return self._dimension + + async def embed_text(self, text: str) -> EmbeddingResult: + results = await self.embed_texts([text]) + return results[0] + + async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]: + if not texts: + return [] + + max_length = 8191 + truncated_texts = [text[:max_length] for text in texts] + + headers = { + "api-key": self.api_key, + "Content-Type": "application/json", + } + + payload = { + "input": truncated_texts, + } + + # Azure URL 格式 - 使用最新 API 版本 + url = f"{self.base_url.rstrip('/')}/openai/deployments/{self.model}/embeddings?api-version={self.API_VERSION}" + + async with httpx.AsyncClient(timeout=60) as client: + response = await client.post(url, headers=headers, json=payload) + response.raise_for_status() + data = response.json() + + results = [] + for item in data.get("data", []): + results.append(EmbeddingResult( + embedding=item["embedding"], + tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts), + model=self.model, + )) + + return results + + +class OllamaEmbedding(EmbeddingProvider): + """ + Ollama 本地嵌入服务 + + 使用新的 /api/embed 端点 (2024年起): + - 支持批量嵌入 + - 使用 'input' 参数(支持字符串或字符串数组) + """ + + MODELS = { + "nomic-embed-text": 768, + "mxbai-embed-large": 1024, + "all-minilm": 384, + "snowflake-arctic-embed": 1024, + "bge-m3": 1024, + "qwen3-embedding": 1024, + } + + def __init__( + self, + base_url: Optional[str] = None, + model: str = "nomic-embed-text", + ): + self.base_url = base_url or "http://localhost:11434" + self.model = model + self._dimension = self.MODELS.get(model, 768) + + @property + def dimension(self) -> int: + return self._dimension + + async def embed_text(self, text: str) -> EmbeddingResult: + results = await self.embed_texts([text]) + return results[0] + + async def embed_texts(self, texts: List[str]) -> 
async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]:
    """Embed a batch of texts via the Cohere v2 /embed endpoint.

    Args:
        texts: Documents to embed (sent with input_type=search_document).

    Returns:
        One EmbeddingResult per returned embedding; billed input tokens
        are spread evenly across the batch.
    """
    if not texts:
        return []

    headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json",
    }

    # BUG FIX: the v2 embed API still uses the `texts` field for text
    # input — `inputs` is not an accepted parameter. v2's actual
    # changes are the required `embedding_types` field and the nested
    # response shape handled below.
    payload = {
        "model": self.model,
        "texts": texts,
        "input_type": "search_document",
        "embedding_types": ["float"],
    }

    url = f"{self.base_url.rstrip('/')}/embed"

    async with httpx.AsyncClient(timeout=60) as client:
        response = await client.post(url, headers=headers, json=payload)
        response.raise_for_status()
        data = response.json()

        results = []
        # v2 response: {"embeddings": {"float": [[...], ...]}, "meta": {...}}
        embeddings_data = data.get("embeddings", {})
        embeddings = embeddings_data.get("float", []) if isinstance(embeddings_data, dict) else embeddings_data

        for embedding in embeddings:
            results.append(EmbeddingResult(
                embedding=embedding,
                tokens_used=data.get("meta", {}).get("billed_units", {}).get("input_tokens", 0) // max(len(texts), 1),
                model=self.model,
            ))

        return results
f"{self.base_url.rstrip('/')}/hf-inference/models/{self.model}/pipeline/feature-extraction" + + payload = { + "inputs": texts, + "options": { + "wait_for_model": True, + } + } + + async with httpx.AsyncClient(timeout=120) as client: + response = await client.post(url, headers=headers, json=payload) + response.raise_for_status() + data = response.json() + + results = [] + # HuggingFace 返回格式: [[embedding1], [embedding2], ...] + for embedding in data: + # 有时候返回的是嵌套的列表 + if isinstance(embedding, list) and len(embedding) > 0: + if isinstance(embedding[0], list): + # 取平均或第一个 + embedding = embedding[0] + + results.append(EmbeddingResult( + embedding=embedding, + tokens_used=len(texts[len(results)]) // 4, + model=self.model, + )) + + return results + + +class JinaEmbedding(EmbeddingProvider): + """Jina AI 嵌入服务""" + + MODELS = { + "jina-embeddings-v2-base-code": 768, + "jina-embeddings-v2-base-en": 768, + "jina-embeddings-v2-base-zh": 768, + "jina-embeddings-v2-small-en": 512, + } + + def __init__( + self, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + model: str = "jina-embeddings-v2-base-code", + ): + self.api_key = api_key + self.base_url = base_url or "https://api.jina.ai/v1" + self.model = model + self._dimension = self.MODELS.get(model, 768) + + @property + def dimension(self) -> int: + return self._dimension + + async def embed_text(self, text: str) -> EmbeddingResult: + results = await self.embed_texts([text]) + return results[0] + + async def embed_texts(self, texts: List[str]) -> List[EmbeddingResult]: + if not texts: + return [] + + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + payload = { + "model": self.model, + "input": texts, + } + + url = f"{self.base_url.rstrip('/')}/embeddings" + + async with httpx.AsyncClient(timeout=60) as client: + response = await client.post(url, headers=headers, json=payload) + response.raise_for_status() + data = response.json() + + results = [] + for 
item in data.get("data", []): + results.append(EmbeddingResult( + embedding=item["embedding"], + tokens_used=data.get("usage", {}).get("total_tokens", 0) // len(texts), + model=self.model, + )) + + return results + + +class EmbeddingService: + """ + 嵌入服务 + 统一管理嵌入模型和缓存 + + 支持的提供商: + - openai: OpenAI 官方 + - azure: Azure OpenAI + - ollama: Ollama 本地 + - cohere: Cohere + - huggingface: HuggingFace Inference API + - jina: Jina AI + """ + + def __init__( + self, + provider: Optional[str] = None, + model: Optional[str] = None, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + cache_enabled: bool = True, + ): + """ + 初始化嵌入服务 + + Args: + provider: 提供商 (openai, azure, ollama, cohere, huggingface, jina) + model: 模型名称 + api_key: API Key + base_url: API Base URL + cache_enabled: 是否启用缓存 + """ + self.cache_enabled = cache_enabled + self._cache: Dict[str, List[float]] = {} + + # 确定提供商 + provider = provider or getattr(settings, 'EMBEDDING_PROVIDER', 'openai') + model = model or getattr(settings, 'EMBEDDING_MODEL', 'text-embedding-3-small') + + # 创建提供商实例 + self._provider = self._create_provider( + provider=provider, + model=model, + api_key=api_key, + base_url=base_url, + ) + + logger.info(f"Embedding service initialized with {provider}/{model}") + + def _create_provider( + self, + provider: str, + model: str, + api_key: Optional[str], + base_url: Optional[str], + ) -> EmbeddingProvider: + """创建嵌入提供商实例""" + provider = provider.lower() + + if provider == "ollama": + return OllamaEmbedding(base_url=base_url, model=model) + + elif provider == "azure": + return AzureOpenAIEmbedding(api_key=api_key, base_url=base_url, model=model) + + elif provider == "cohere": + return CohereEmbedding(api_key=api_key, base_url=base_url, model=model) + + elif provider == "huggingface": + return HuggingFaceEmbedding(api_key=api_key, base_url=base_url, model=model) + + elif provider == "jina": + return JinaEmbedding(api_key=api_key, base_url=base_url, model=model) + + else: + # 默认使用 
OpenAI + return OpenAIEmbedding(api_key=api_key, base_url=base_url, model=model) + + @property + def dimension(self) -> int: + """嵌入向量维度""" + return self._provider.dimension + + def _cache_key(self, text: str) -> str: + """生成缓存键""" + return hashlib.sha256(text.encode()).hexdigest()[:32] + + async def embed(self, text: str) -> List[float]: + """ + 嵌入单个文本 + + Args: + text: 文本内容 + + Returns: + 嵌入向量 + """ + if not text or not text.strip(): + return [0.0] * self.dimension + + # 检查缓存 + if self.cache_enabled: + cache_key = self._cache_key(text) + if cache_key in self._cache: + return self._cache[cache_key] + + result = await self._provider.embed_text(text) + + # 存入缓存 + if self.cache_enabled: + self._cache[cache_key] = result.embedding + + return result.embedding + + async def embed_batch( + self, + texts: List[str], + batch_size: int = 100, + show_progress: bool = False, + ) -> List[List[float]]: + """ + 批量嵌入文本 + + Args: + texts: 文本列表 + batch_size: 批次大小 + show_progress: 是否显示进度 + + Returns: + 嵌入向量列表 + """ + if not texts: + return [] + + embeddings = [] + uncached_indices = [] + uncached_texts = [] + + # 检查缓存 + for i, text in enumerate(texts): + if not text or not text.strip(): + embeddings.append([0.0] * self.dimension) + continue + + if self.cache_enabled: + cache_key = self._cache_key(text) + if cache_key in self._cache: + embeddings.append(self._cache[cache_key]) + continue + + embeddings.append(None) # 占位 + uncached_indices.append(i) + uncached_texts.append(text) + + # 批量处理未缓存的文本 + if uncached_texts: + for i in range(0, len(uncached_texts), batch_size): + batch = uncached_texts[i:i + batch_size] + batch_indices = uncached_indices[i:i + batch_size] + + try: + results = await self._provider.embed_texts(batch) + + for idx, result in zip(batch_indices, results): + embeddings[idx] = result.embedding + + # 存入缓存 + if self.cache_enabled: + cache_key = self._cache_key(texts[idx]) + self._cache[cache_key] = result.embedding + + except Exception as e: + logger.error(f"Batch 
embedding error: {e}") + # 对失败的使用零向量 + for idx in batch_indices: + if embeddings[idx] is None: + embeddings[idx] = [0.0] * self.dimension + + # 添加小延迟避免限流 + await asyncio.sleep(0.1) + + # 确保没有 None + return [e if e is not None else [0.0] * self.dimension for e in embeddings] + + def clear_cache(self): + """清空缓存""" + self._cache.clear() + + @property + def cache_size(self) -> int: + """缓存大小""" + return len(self._cache) + diff --git a/backend/app/services/rag/indexer.py b/backend/app/services/rag/indexer.py new file mode 100644 index 0000000..45607fc --- /dev/null +++ b/backend/app/services/rag/indexer.py @@ -0,0 +1,585 @@ +""" +代码索引器 +将代码分块并索引到向量数据库 +""" + +import os +import asyncio +import logging +from typing import List, Dict, Any, Optional, AsyncGenerator, Callable +from pathlib import Path +from dataclasses import dataclass +import json + +from .splitter import CodeSplitter, CodeChunk +from .embeddings import EmbeddingService + +logger = logging.getLogger(__name__) + + +# 支持的文本文件扩展名 +TEXT_EXTENSIONS = { + ".py", ".js", ".ts", ".tsx", ".jsx", ".java", ".go", ".rs", + ".cpp", ".c", ".h", ".cc", ".hh", ".cs", ".php", ".rb", + ".kt", ".swift", ".sql", ".sh", ".json", ".yml", ".yaml", + ".xml", ".html", ".css", ".vue", ".svelte", ".md", +} + +# 排除的目录 +EXCLUDE_DIRS = { + "node_modules", "vendor", "dist", "build", ".git", + "__pycache__", ".pytest_cache", "coverage", ".nyc_output", + ".vscode", ".idea", ".vs", "target", "out", "bin", "obj", + "__MACOSX", ".next", ".nuxt", "venv", "env", ".env", +} + +# 排除的文件 +EXCLUDE_FILES = { + ".DS_Store", "package-lock.json", "yarn.lock", "pnpm-lock.yaml", + "Cargo.lock", "poetry.lock", "composer.lock", "Gemfile.lock", +} + + +@dataclass +class IndexingProgress: + """索引进度""" + total_files: int = 0 + processed_files: int = 0 + total_chunks: int = 0 + indexed_chunks: int = 0 + current_file: str = "" + errors: List[str] = None + + def __post_init__(self): + if self.errors is None: + self.errors = [] + + @property + def 
@dataclass
class IndexingProgress:
    """Mutable progress snapshot for one indexing run."""
    total_files: int = 0
    processed_files: int = 0
    total_chunks: int = 0
    indexed_chunks: int = 0
    current_file: str = ""
    # Fix: annotated Optional — the declared default really is None; the list
    # is materialized in __post_init__ (avoids a shared mutable default).
    errors: Optional[List[str]] = None

    def __post_init__(self):
        if self.errors is None:
            self.errors = []

    @property
    def progress_percentage(self) -> float:
        """Files processed as a percentage; 0.0 when there is nothing to do."""
        if self.total_files == 0:
            return 0.0
        return (self.processed_files / self.total_files) * 100


@dataclass
class IndexingResult:
    """Summary of a completed indexing run."""
    success: bool
    total_files: int
    indexed_files: int
    total_chunks: int
    errors: List[str]
    collection_name: str


class VectorStore:
    """Abstract vector-store interface.

    Subclasses must implement add/query/delete/count; initialize is a no-op
    by default.
    """

    async def initialize(self):
        """Prepare the store for use (no-op by default)."""
        pass

    async def add_documents(
        self,
        ids: List[str],
        embeddings: List[List[float]],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ):
        """Add documents with precomputed embeddings."""
        raise NotImplementedError

    async def query(
        self,
        query_embedding: List[float],
        n_results: int = 10,
        where: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Nearest-neighbour query; returns ids/documents/metadatas/distances."""
        raise NotImplementedError

    async def delete_collection(self):
        """Delete the underlying collection."""
        raise NotImplementedError

    async def get_count(self) -> int:
        """Number of stored documents."""
        raise NotImplementedError


class ChromaVectorStore(VectorStore):
    """Chroma-backed vector store (persistent or in-process client)."""

    def __init__(
        self,
        collection_name: str,
        persist_directory: Optional[str] = None,
    ):
        self.collection_name = collection_name
        self.persist_directory = persist_directory
        self._client = None
        self._collection = None

    async def initialize(self):
        """Create the Chroma client and get-or-create the collection."""
        try:
            import chromadb
            from chromadb.config import Settings

            if self.persist_directory:
                self._client = chromadb.PersistentClient(
                    path=self.persist_directory,
                    settings=Settings(anonymized_telemetry=False),
                )
            else:
                self._client = chromadb.Client(
                    settings=Settings(anonymized_telemetry=False),
                )

            self._collection = self._client.get_or_create_collection(
                name=self.collection_name,
                metadata={"hnsw:space": "cosine"},  # cosine distance index
            )

            logger.info(f"Chroma collection '{self.collection_name}' initialized")

        except ImportError:
            raise ImportError("chromadb is required. Install with: pip install chromadb")

    async def add_documents(
        self,
        ids: List[str],
        embeddings: List[List[float]],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ):
        """Add documents to Chroma, sanitizing metadata to scalar values."""
        if not ids:
            return

        # Chroma only accepts scalar metadata values; lists become JSON
        # strings, everything else non-None is stringified, None is dropped.
        cleaned_metadatas = []
        for meta in metadatas:
            cleaned = {}
            for k, v in meta.items():
                if isinstance(v, (str, int, float, bool)):
                    cleaned[k] = v
                elif isinstance(v, list):
                    cleaned[k] = json.dumps(v)
                elif v is not None:
                    cleaned[k] = str(v)
            cleaned_metadatas.append(cleaned)

        # Insert in batches (Chroma enforces a per-call limit).
        batch_size = 500
        for i in range(0, len(ids), batch_size):
            batch_ids = ids[i:i + batch_size]
            batch_embeddings = embeddings[i:i + batch_size]
            batch_documents = documents[i:i + batch_size]
            batch_metadatas = cleaned_metadatas[i:i + batch_size]

            # Chroma's client is synchronous; keep the event loop free.
            await asyncio.to_thread(
                self._collection.add,
                ids=batch_ids,
                embeddings=batch_embeddings,
                documents=batch_documents,
                metadatas=batch_metadatas,
            )

    async def query(
        self,
        query_embedding: List[float],
        n_results: int = 10,
        where: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Query Chroma and flatten the single-query result lists."""
        result = await asyncio.to_thread(
            self._collection.query,
            query_embeddings=[query_embedding],
            n_results=n_results,
            where=where,
            include=["documents", "metadatas", "distances"],
        )

        return {
            "ids": result["ids"][0] if result["ids"] else [],
            "documents": result["documents"][0] if result["documents"] else [],
            "metadatas": result["metadatas"][0] if result["metadatas"] else [],
            "distances": result["distances"][0] if result["distances"] else [],
        }

    async def delete_collection(self):
        """Delete the collection from Chroma."""
        if self._client and self._collection:
            await asyncio.to_thread(
                self._client.delete_collection,
                name=self.collection_name,
            )
            # Fix: drop the stale handle so later calls cannot operate on a
            # deleted collection.
            self._collection = None

    async def get_count(self) -> int:
        """Document count, or 0 before initialization."""
        if self._collection:
            return await asyncio.to_thread(self._collection.count)
        return 0


class InMemoryVectorStore(VectorStore):
    """In-memory store for tests and small projects; brute-force cosine search."""

    def __init__(self, collection_name: str):
        self.collection_name = collection_name
        self._documents: Dict[str, Dict[str, Any]] = {}

    async def initialize(self):
        """Nothing to set up; log for parity with other stores."""
        logger.info(f"InMemory vector store '{self.collection_name}' initialized")

    async def add_documents(
        self,
        ids: List[str],
        embeddings: List[List[float]],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ):
        """Store documents keyed by id (later ids overwrite earlier ones)."""
        for id_, emb, doc, meta in zip(ids, embeddings, documents, metadatas):
            self._documents[id_] = {
                "embedding": emb,
                "document": doc,
                "metadata": meta,
            }

    async def query(
        self,
        query_embedding: List[float],
        n_results: int = 10,
        where: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Exhaustive cosine-similarity scan with optional equality filters."""
        import math

        # Fix: the query vector's norm is loop-invariant — compute it once
        # instead of once per stored document.
        q_norm = math.sqrt(sum(x * x for x in query_embedding))

        def cosine_similarity(vec: List[float]) -> float:
            dot = sum(x * y for x, y in zip(query_embedding, vec))
            v_norm = math.sqrt(sum(x * x for x in vec))
            if q_norm == 0 or v_norm == 0:
                return 0.0
            return dot / (q_norm * v_norm)

        results = []
        for id_, data in self._documents.items():
            # Equality filter on metadata keys.
            if where:
                if any(data["metadata"].get(k) != v for k, v in where.items()):
                    continue

            similarity = cosine_similarity(data["embedding"])
            results.append({
                "id": id_,
                "document": data["document"],
                "metadata": data["metadata"],
                "distance": 1 - similarity,  # similarity -> distance
            })

        # Closest first.
        results.sort(key=lambda x: x["distance"])
        results = results[:n_results]

        return {
            "ids": [r["id"] for r in results],
            "documents": [r["document"] for r in results],
            "metadatas": [r["metadata"] for r in results],
            "distances": [r["distance"] for r in results],
        }

    async def delete_collection(self):
        """Drop all documents."""
        self._documents.clear()

    async def get_count(self) -> int:
        """Number of stored documents."""
        return len(self._documents)
class CodeIndexer:
    """
    Code indexer: splits code files into chunks, embeds them, and writes
    them to a vector store.
    """

    # Files longer than this many characters are truncated before splitting.
    # (Named constant replacing the magic number duplicated in two methods.)
    MAX_FILE_CHARS = 500000

    def __init__(
        self,
        collection_name: str,
        embedding_service: Optional[EmbeddingService] = None,
        vector_store: Optional[VectorStore] = None,
        splitter: Optional[CodeSplitter] = None,
        persist_directory: Optional[str] = None,
    ):
        """
        Initialize the indexer.

        Args:
            collection_name: vector collection name
            embedding_service: embedding service (default: EmbeddingService())
            vector_store: vector store (default: Chroma, in-memory fallback)
            splitter: code splitter (default: CodeSplitter())
            persist_directory: persistence directory for Chroma
        """
        self.collection_name = collection_name
        self.embedding_service = embedding_service or EmbeddingService()
        self.splitter = splitter or CodeSplitter()

        if vector_store:
            self.vector_store = vector_store
        else:
            try:
                self.vector_store = ChromaVectorStore(
                    collection_name=collection_name,
                    persist_directory=persist_directory,
                )
            except ImportError:
                # NOTE(review): ChromaVectorStore only imports chromadb inside
                # initialize(), so this fallback likely never triggers here —
                # confirm whether the ImportError should be caught at
                # initialize() time instead.
                logger.warning("Chroma not available, using in-memory store")
                self.vector_store = InMemoryVectorStore(collection_name=collection_name)

        self._initialized = False

    async def initialize(self):
        """Initialize the underlying vector store exactly once."""
        if not self._initialized:
            await self.vector_store.initialize()
            self._initialized = True

    def _split_content(self, content: str, path: str) -> List[CodeChunk]:
        """Truncate oversized content and split it into chunks.

        Shared by index_directory and index_files (previously duplicated).
        """
        if len(content) > self.MAX_FILE_CHARS:
            content = content[:self.MAX_FILE_CHARS]
        return self.splitter.split_file(content, path)

    async def index_directory(
        self,
        directory: str,
        exclude_patterns: Optional[List[str]] = None,
        include_patterns: Optional[List[str]] = None,
        progress_callback: Optional[Callable[[IndexingProgress], None]] = None,
    ) -> AsyncGenerator[IndexingProgress, None]:
        """
        Index the code files under a directory.

        Args:
            directory: directory path
            exclude_patterns: fnmatch patterns to exclude
            include_patterns: fnmatch patterns to include
            progress_callback: optional per-file progress callback

        Yields:
            IndexingProgress snapshots (initial, per file, and final).
        """
        await self.initialize()

        progress = IndexingProgress()
        exclude_patterns = exclude_patterns or []

        files = self._collect_files(directory, exclude_patterns, include_patterns)
        progress.total_files = len(files)

        logger.info(f"Found {len(files)} files to index in {directory}")
        yield progress

        all_chunks: List[CodeChunk] = []

        # Split each file; per-file failures are recorded, not fatal.
        for file_path in files:
            progress.current_file = file_path

            try:
                relative_path = os.path.relpath(file_path, directory)

                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                    content = f.read()

                if not content.strip():
                    progress.processed_files += 1
                    continue

                chunks = self._split_content(content, relative_path)
                all_chunks.extend(chunks)

                progress.processed_files += 1
                progress.total_chunks = len(all_chunks)

                if progress_callback:
                    progress_callback(progress)
                yield progress

            except Exception as e:
                logger.warning(f"Error processing {file_path}: {e}")
                progress.errors.append(f"{file_path}: {str(e)}")
                progress.processed_files += 1

        logger.info(f"Created {len(all_chunks)} chunks from {len(files)} files")

        # Embed and store everything in one pass.
        if all_chunks:
            await self._index_chunks(all_chunks, progress)

        progress.indexed_chunks = len(all_chunks)
        yield progress

    async def index_files(
        self,
        files: List[Dict[str, str]],
        base_path: str = "",
        progress_callback: Optional[Callable[[IndexingProgress], None]] = None,
    ) -> AsyncGenerator[IndexingProgress, None]:
        """
        Index an in-memory list of files.

        Args:
            files: list of {"path": "...", "content": "..."} dicts
            base_path: accepted for API compatibility (currently unused)
            progress_callback: optional per-file progress callback

        Yields:
            IndexingProgress snapshots.
        """
        await self.initialize()

        progress = IndexingProgress()
        progress.total_files = len(files)

        all_chunks: List[CodeChunk] = []

        for file_info in files:
            file_path = file_info.get("path", "")
            content = file_info.get("content", "")

            progress.current_file = file_path

            try:
                if not content.strip():
                    progress.processed_files += 1
                    continue

                chunks = self._split_content(content, file_path)
                all_chunks.extend(chunks)

                progress.processed_files += 1
                progress.total_chunks = len(all_chunks)

                if progress_callback:
                    progress_callback(progress)
                yield progress

            except Exception as e:
                logger.warning(f"Error processing {file_path}: {e}")
                progress.errors.append(f"{file_path}: {str(e)}")
                progress.processed_files += 1

        if all_chunks:
            await self._index_chunks(all_chunks, progress)

        progress.indexed_chunks = len(all_chunks)
        yield progress

    async def _index_chunks(self, chunks: List[CodeChunk], progress: IndexingProgress):
        """Embed the chunks and add them to the vector store."""
        texts = [chunk.to_embedding_text() for chunk in chunks]

        logger.info(f"Generating embeddings for {len(texts)} chunks...")

        embeddings = await self.embedding_service.embed_batch(texts, batch_size=50)

        ids = [chunk.id for chunk in chunks]
        documents = [chunk.content for chunk in chunks]
        metadatas = [chunk.to_dict() for chunk in chunks]

        logger.info(f"Adding {len(chunks)} chunks to vector store...")
        await self.vector_store.add_documents(
            ids=ids,
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
        )

        logger.info(f"Indexed {len(chunks)} chunks successfully")

    def _collect_files(
        self,
        directory: str,
        exclude_patterns: List[str],
        include_patterns: Optional[List[str]],
    ) -> List[str]:
        """Walk the directory and return files worth indexing.

        Filters by extension (TEXT_EXTENSIONS), well-known junk directories
        and files (EXCLUDE_DIRS / EXCLUDE_FILES), then user patterns.
        """
        import fnmatch

        files = []

        for root, dirs, filenames in os.walk(directory):
            # Prune excluded directories in place so os.walk skips them.
            dirs[:] = [d for d in dirs if d not in EXCLUDE_DIRS]

            for filename in filenames:
                ext = os.path.splitext(filename)[1].lower()
                if ext not in TEXT_EXTENSIONS:
                    continue

                if filename in EXCLUDE_FILES:
                    continue

                file_path = os.path.join(root, filename)
                relative_path = os.path.relpath(file_path, directory)

                # User exclude patterns match either the relative path or name.
                excluded = False
                for pattern in exclude_patterns:
                    if fnmatch.fnmatch(relative_path, pattern) or fnmatch.fnmatch(filename, pattern):
                        excluded = True
                        break

                if excluded:
                    continue

                # Include patterns, when given, act as an allow-list.
                if include_patterns:
                    included = False
                    for pattern in include_patterns:
                        if fnmatch.fnmatch(relative_path, pattern) or fnmatch.fnmatch(filename, pattern):
                            included = True
                            break
                    if not included:
                        continue

                files.append(file_path)

        return files

    async def get_chunk_count(self) -> int:
        """Number of chunks currently indexed."""
        await self.initialize()
        return await self.vector_store.get_count()

    async def clear(self):
        """Delete the collection and force re-initialization on next use."""
        await self.initialize()
        await self.vector_store.delete_collection()
        self._initialized = False
@dataclass
class RetrievalResult:
    """One retrieval hit: an indexed code chunk plus its similarity score."""
    chunk_id: str
    content: str
    file_path: str
    language: str
    chunk_type: str
    line_start: int
    line_end: int
    score: float  # similarity in [0, 1]; higher means more similar

    # Optional semantic metadata
    name: Optional[str] = None
    parent_name: Optional[str] = None
    signature: Optional[str] = None
    security_indicators: List[str] = field(default_factory=list)

    # Raw metadata from the store
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the hit (raw metadata is intentionally excluded)."""
        return {
            "chunk_id": self.chunk_id,
            "content": self.content,
            "file_path": self.file_path,
            "language": self.language,
            "chunk_type": self.chunk_type,
            "line_start": self.line_start,
            "line_end": self.line_end,
            "score": self.score,
            "name": self.name,
            "parent_name": self.parent_name,
            "signature": self.signature,
            "security_indicators": self.security_indicators,
        }

    def to_context_string(self, include_metadata: bool = True) -> str:
        """Render as a fenced code block (with an optional header) for LLM input."""
        parts = []

        if include_metadata:
            header = f"File: {self.file_path}"
            if self.line_start and self.line_end:
                header += f" (lines {self.line_start}-{self.line_end})"
            if self.name:
                header += f"\n{self.chunk_type.title()}: {self.name}"
                if self.parent_name:
                    header += f" in {self.parent_name}"
            parts.append(header)

        parts.append(f"```{self.language}\n{self.content}\n```")

        return "\n".join(parts)


class CodeRetriever:
    """
    Code retriever supporting semantic, keyword-boosted and hybrid search
    over an indexed collection.
    """

    def __init__(
        self,
        collection_name: str,
        embedding_service: Optional[EmbeddingService] = None,
        vector_store: Optional[VectorStore] = None,
        persist_directory: Optional[str] = None,
    ):
        """
        Initialize the retriever.

        Args:
            collection_name: vector collection name
            embedding_service: embedding service (default: EmbeddingService())
            vector_store: vector store (default: Chroma, in-memory fallback)
            persist_directory: persistence directory for Chroma
        """
        self.collection_name = collection_name
        self.embedding_service = embedding_service or EmbeddingService()

        if vector_store:
            self.vector_store = vector_store
        else:
            try:
                self.vector_store = ChromaVectorStore(
                    collection_name=collection_name,
                    persist_directory=persist_directory,
                )
            except ImportError:
                # NOTE(review): ChromaVectorStore imports chromadb in
                # initialize(), not __init__ — confirm this fallback can fire.
                logger.warning("Chroma not available, using in-memory store")
                self.vector_store = InMemoryVectorStore(collection_name=collection_name)

        self._initialized = False

    async def initialize(self):
        """Initialize the underlying vector store exactly once."""
        if not self._initialized:
            await self.vector_store.initialize()
            self._initialized = True

    async def retrieve(
        self,
        query: str,
        top_k: int = 10,
        filter_file_path: Optional[str] = None,
        filter_language: Optional[str] = None,
        filter_chunk_type: Optional[str] = None,
        min_score: float = 0.0,
    ) -> List[RetrievalResult]:
        """
        Semantic retrieval.

        Args:
            query: query text
            top_k: number of results to return
            filter_file_path: restrict to a file path
            filter_language: restrict to a language
            filter_chunk_type: restrict to a chunk type
            min_score: minimum similarity score

        Returns:
            Results sorted by score, best first.
        """
        import json  # Fix: hoisted out of the per-hit loop

        await self.initialize()

        query_embedding = await self.embedding_service.embed(query)

        # Build metadata filters.
        where = {}
        if filter_file_path:
            where["file_path"] = filter_file_path
        if filter_language:
            where["language"] = filter_language
        if filter_chunk_type:
            where["chunk_type"] = filter_chunk_type

        # Over-fetch so post-filtering can still fill top_k.
        raw_results = await self.vector_store.query(
            query_embedding=query_embedding,
            n_results=top_k * 2,
            where=where if where else None,
        )

        results = []
        for id_, doc, meta, dist in zip(
            raw_results["ids"],
            raw_results["documents"],
            raw_results["metadatas"],
            raw_results["distances"],
        ):
            # Cosine distance -> similarity score.
            score = 1 - dist

            if score < min_score:
                continue

            # security_indicators may round-trip as a JSON string (the indexer
            # serializes list metadata for Chroma).
            security_indicators = meta.get("security_indicators", [])
            if isinstance(security_indicators, str):
                try:
                    security_indicators = json.loads(security_indicators)
                except ValueError:
                    # Fix: was a bare 'except:' that also swallowed
                    # KeyboardInterrupt/SystemExit.
                    security_indicators = []

            result = RetrievalResult(
                chunk_id=id_,
                content=doc,
                file_path=meta.get("file_path", ""),
                language=meta.get("language", "text"),
                chunk_type=meta.get("chunk_type", "unknown"),
                line_start=meta.get("line_start", 0),
                line_end=meta.get("line_end", 0),
                score=score,
                name=meta.get("name"),
                parent_name=meta.get("parent_name"),
                signature=meta.get("signature"),
                security_indicators=security_indicators,
                metadata=meta,
            )
            results.append(result)

        results.sort(key=lambda x: x.score, reverse=True)
        return results[:top_k]

    async def retrieve_by_file(
        self,
        file_path: str,
        top_k: int = 50,
    ) -> List[RetrievalResult]:
        """
        Retrieve chunks belonging to one file, ordered by line number.

        Args:
            file_path: file path
            top_k: maximum number of chunks

        Returns:
            Chunks of that file, sorted by line_start.
        """
        await self.initialize()

        # A generic query embedding; the 'where' filter does the real work.
        query_embedding = await self.embedding_service.embed(f"code in {file_path}")

        raw_results = await self.vector_store.query(
            query_embedding=query_embedding,
            n_results=top_k,
            where={"file_path": file_path},
        )

        results = []
        for id_, doc, meta, dist in zip(
            raw_results["ids"],
            raw_results["documents"],
            raw_results["metadatas"],
            raw_results["distances"],
        ):
            result = RetrievalResult(
                chunk_id=id_,
                content=doc,
                file_path=meta.get("file_path", ""),
                language=meta.get("language", "text"),
                chunk_type=meta.get("chunk_type", "unknown"),
                line_start=meta.get("line_start", 0),
                line_end=meta.get("line_end", 0),
                score=1 - dist,
                name=meta.get("name"),
                parent_name=meta.get("parent_name"),
                metadata=meta,
            )
            results.append(result)

        results.sort(key=lambda x: x.line_start)
        return results

    async def retrieve_security_related(
        self,
        vulnerability_type: Optional[str] = None,
        top_k: int = 20,
    ) -> List[RetrievalResult]:
        """
        Retrieve security-relevant code.

        Args:
            vulnerability_type: e.g. sql_injection, xss (see mapping below)
            top_k: number of results

        Returns:
            Chunks semantically close to the vulnerability's query phrase.
        """
        # Canned query per vulnerability class.
        security_queries = {
            "sql_injection": "SQL query execute database user input",
            "xss": "HTML render user input innerHTML template",
            "command_injection": "system exec command shell subprocess",
            "path_traversal": "file path read open user input",
            "ssrf": "HTTP request URL user input fetch",
            "deserialization": "deserialize pickle yaml load object",
            "auth_bypass": "authentication login password token session",
            "hardcoded_secret": "password secret key token credential",
        }

        if vulnerability_type and vulnerability_type in security_queries:
            query = security_queries[vulnerability_type]
        else:
            query = "security vulnerability dangerous function user input"

        return await self.retrieve(query, top_k=top_k)

    async def retrieve_function_context(
        self,
        function_name: str,
        file_path: Optional[str] = None,
        include_callers: bool = True,
        include_callees: bool = True,
        top_k: int = 10,
    ) -> Dict[str, List[RetrievalResult]]:
        """
        Retrieve a function's definition plus its callers and callees.

        Args:
            function_name: function name
            file_path: optional file-path filter for the definition
            include_callers: also search for call sites
            include_callees: also search for functions called by the definition
            top_k: per-category result count

        Returns:
            {"definition": [...], "callers": [...], "callees": [...]}
        """
        context = {
            "definition": [],
            "callers": [],
            "callees": [],
        }

        # Locate the definition.
        definition_query = f"function definition {function_name}"
        definitions = await self.retrieve(
            definition_query,
            top_k=5,
            filter_file_path=file_path,
        )

        # Keep only plausible definitions (name match or textual mention).
        for result in definitions:
            if result.name == function_name or function_name in (result.content or ""):
                context["definition"].append(result)

        if include_callers:
            caller_query = f"calls {function_name} invoke {function_name}"
            callers = await self.retrieve(caller_query, top_k=top_k)

            for result in callers:
                # Require an actual call expression, not just a mention.
                if re.search(rf'\b{re.escape(function_name)}\s*\(', result.content):
                    if result not in context["definition"]:
                        context["callers"].append(result)

        if include_callees and context["definition"]:
            # Extract call targets from the definition body and look them up.
            for definition in context["definition"]:
                calls = re.findall(r'\b(\w+)\s*\(', definition.content)
                unique_calls = list(set(calls))[:5]  # cap the fan-out

                for call in unique_calls:
                    if call == function_name:
                        continue
                    callees = await self.retrieve(
                        f"function {call} definition",
                        top_k=2,
                    )
                    context["callees"].extend(callees)

        return context

    async def retrieve_similar_code(
        self,
        code_snippet: str,
        top_k: int = 5,
        exclude_file: Optional[str] = None,
    ) -> List[RetrievalResult]:
        """
        Retrieve code similar to a snippet.

        Args:
            code_snippet: code fragment to match
            top_k: number of results
            exclude_file: file path to drop from the results

        Returns:
            Similar chunks, best first.
        """
        results = await self.retrieve(
            f"similar code: {code_snippet}",
            top_k=top_k * 2,
        )

        if exclude_file:
            results = [r for r in results if r.file_path != exclude_file]

        return results[:top_k]

    async def hybrid_retrieve(
        self,
        query: str,
        keywords: Optional[List[str]] = None,
        top_k: int = 10,
        semantic_weight: float = 0.7,
    ) -> List[RetrievalResult]:
        """
        Hybrid retrieval (semantic + keyword).

        Args:
            query: query text
            keywords: extra keywords to boost on
            top_k: number of results
            semantic_weight: weight of the semantic score in [0, 1]

        Returns:
            Results re-ranked by the blended score.
        """
        semantic_results = await self.retrieve(query, top_k=top_k * 2)

        if keywords:
            keyword_pattern = '|'.join(re.escape(kw) for kw in keywords)

            enhanced_results = []
            for result in semantic_results:
                # Keyword score: capped match density against keyword count.
                matches = len(re.findall(keyword_pattern, result.content, re.IGNORECASE))
                keyword_score = min(1.0, matches / len(keywords))

                hybrid_score = (
                    semantic_weight * result.score +
                    (1 - semantic_weight) * keyword_score
                )

                result.score = hybrid_score
                enhanced_results.append(result)

            enhanced_results.sort(key=lambda x: x.score, reverse=True)
            return enhanced_results[:top_k]

        return semantic_results[:top_k]

    def format_results_for_llm(
        self,
        results: List[RetrievalResult],
        max_tokens: int = 4000,
        include_metadata: bool = True,
    ) -> str:
        """
        Format retrieval results for LLM consumption, within a token budget.

        Args:
            results: retrieval results
            max_tokens: budget (estimated at chars/4)
            include_metadata: include the per-chunk header

        Returns:
            Concatenated, numbered code blocks.
        """
        if not results:
            return "No relevant code found."

        parts = []
        total_tokens = 0

        for i, result in enumerate(results):
            context = result.to_context_string(include_metadata=include_metadata)
            estimated_tokens = len(context) // 4  # rough chars/4 estimate

            if total_tokens + estimated_tokens > max_tokens:
                break

            parts.append(f"### Code Block {i + 1} (Score: {result.score:.2f})\n{context}")
            total_tokens += estimated_tokens

        return "\n\n".join(parts)
class ChunkType(Enum):
    """Kinds of code chunks produced by the splitter."""
    FILE = "file"
    MODULE = "module"
    CLASS = "class"
    FUNCTION = "function"
    METHOD = "method"
    INTERFACE = "interface"
    STRUCT = "struct"
    ENUM = "enum"
    IMPORT = "import"
    CONSTANT = "constant"
    CONFIG = "config"
    COMMENT = "comment"
    DECORATOR = "decorator"
    UNKNOWN = "unknown"


# Lazily-created tiktoken encoder, shared by all chunks.
# None = not tried yet; False = tiktoken unavailable.
_TIKTOKEN_ENC = None


def _count_tokens(text: str) -> int:
    """Token count via tiktoken when available, else a chars/4 estimate.

    Fix: the encoder is created once and cached — the original re-imported
    tiktoken and rebuilt the cl100k_base encoder for every single chunk.
    """
    global _TIKTOKEN_ENC
    if _TIKTOKEN_ENC is None:
        try:
            import tiktoken
            _TIKTOKEN_ENC = tiktoken.get_encoding("cl100k_base")
        except ImportError:
            _TIKTOKEN_ENC = False
    if _TIKTOKEN_ENC:
        return len(_TIKTOKEN_ENC.encode(text))
    return len(text) // 4


@dataclass
class CodeChunk:
    """One chunk of source code plus its location and semantic metadata."""
    id: str
    content: str
    file_path: str
    language: str
    chunk_type: ChunkType

    # Location
    line_start: int = 0
    line_end: int = 0
    byte_start: int = 0
    byte_end: int = 0

    # Semantics
    name: Optional[str] = None
    parent_name: Optional[str] = None
    signature: Optional[str] = None
    docstring: Optional[str] = None

    # AST node type, when parsed with tree-sitter
    ast_type: Optional[str] = None

    # Relationships
    imports: List[str] = field(default_factory=list)
    calls: List[str] = field(default_factory=list)
    dependencies: List[str] = field(default_factory=list)
    definitions: List[str] = field(default_factory=list)

    # Security hints collected during splitting
    security_indicators: List[str] = field(default_factory=list)

    # Free-form metadata
    metadata: Dict[str, Any] = field(default_factory=dict)

    # Token estimate for budget planning
    estimated_tokens: int = 0

    def __post_init__(self):
        # Derive id/token count when the caller passed falsy placeholders.
        if not self.id:
            self.id = self._generate_id()
        if not self.estimated_tokens:
            self.estimated_tokens = self._estimate_tokens()

    def _generate_id(self) -> str:
        """Stable 16-hex-char id from location plus a content prefix."""
        content = f"{self.file_path}:{self.line_start}:{self.line_end}:{self.content[:100]}"
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    def _estimate_tokens(self) -> int:
        """Estimate token count (cached tiktoken encoder or chars/4)."""
        return _count_tokens(self.content)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the chunk (byte offsets intentionally omitted)."""
        return {
            "id": self.id,
            "content": self.content,
            "file_path": self.file_path,
            "language": self.language,
            "chunk_type": self.chunk_type.value,
            "line_start": self.line_start,
            "line_end": self.line_end,
            "name": self.name,
            "parent_name": self.parent_name,
            "signature": self.signature,
            "docstring": self.docstring,
            "ast_type": self.ast_type,
            "imports": self.imports,
            "calls": self.calls,
            # Fix: dependencies was declared but never serialized, unlike the
            # sibling calls/definitions lists (backward-compatible extra key).
            "dependencies": self.dependencies,
            "definitions": self.definitions,
            "security_indicators": self.security_indicators,
            "estimated_tokens": self.estimated_tokens,
            "metadata": self.metadata,
        }

    def to_embedding_text(self) -> str:
        """Build the text fed to the embedding model: metadata header + code."""
        parts = []
        parts.append(f"File: {self.file_path}")
        if self.name:
            parts.append(f"{self.chunk_type.value.title()}: {self.name}")
        if self.parent_name:
            parts.append(f"In: {self.parent_name}")
        if self.signature:
            parts.append(f"Signature: {self.signature}")
        if self.docstring:
            parts.append(f"Description: {self.docstring[:300]}")
        parts.append(f"Code:\n{self.content}")
        return "\n".join(parts)


class TreeSitterParser:
    """
    Tree-sitter based code parser providing AST-level analysis.

    Parsers are loaded lazily per language via tree_sitter_languages; when
    that package is unavailable, callers fall back to regex parsing.
    """

    # extension -> tree-sitter language name
    LANGUAGE_MAP = {
        ".py": "python",
        ".js": "javascript",
        ".jsx": "javascript",
        ".ts": "typescript",
        ".tsx": "tsx",
        ".java": "java",
        ".go": "go",
        ".rs": "rust",
        ".cpp": "cpp",
        ".c": "c",
        ".h": "c",
        ".hpp": "cpp",
        ".cs": "c_sharp",
        ".php": "php",
        ".rb": "ruby",
        ".kt": "kotlin",
        ".swift": "swift",
    }

    # per-language AST node types for functions/classes/imports
    DEFINITION_TYPES = {
        "python": {
            "class": ["class_definition"],
            "function": ["function_definition"],
            "method": ["function_definition"],
            "import": ["import_statement", "import_from_statement"],
        },
        "javascript": {
            "class": ["class_declaration", "class"],
            "function": ["function_declaration", "function", "arrow_function", "method_definition"],
            "import": ["import_statement"],
        },
        "typescript": {
            "class": ["class_declaration", "class"],
            "function": ["function_declaration", "function", "arrow_function", "method_definition"],
            "interface": ["interface_declaration"],
            "import": ["import_statement"],
        },
        "java": {
            "class": ["class_declaration"],
            "method": ["method_declaration", "constructor_declaration"],
            "interface": ["interface_declaration"],
            "import": ["import_declaration"],
        },
        "go": {
            "struct": ["type_declaration"],
            "function": ["function_declaration", "method_declaration"],
            "interface": ["type_declaration"],
            "import": ["import_declaration"],
        },
    }

    def __init__(self):
        self._parsers: Dict[str, Any] = {}
        self._initialized = False

    def _ensure_initialized(self, language: str) -> bool:
        """Load and cache the parser for a language; False when unavailable."""
        if language in self._parsers:
            return True

        try:
            from tree_sitter_languages import get_parser, get_language

            parser = get_parser(language)
            self._parsers[language] = parser
            return True

        except ImportError:
            logger.warning("tree-sitter-languages not installed, falling back to regex parsing")
            return False
        except Exception as e:
            logger.warning(f"Failed to load tree-sitter parser for {language}: {e}")
            return False

    def parse(self, code: str, language: str) -> Optional[Any]:
        """Parse code into an AST; None when no parser is available or parsing fails."""
        if not self._ensure_initialized(language):
            return None

        parser = self._parsers.get(language)
        if not parser:
            return None

        try:
            tree = parser.parse(code.encode())
            return tree
        except Exception as e:
            logger.warning(f"Failed to parse code: {e}")
            return None

    def extract_definitions(self, tree: Any, code: str, language: str) -> List[Dict[str, Any]]:
        """Walk the AST and collect class/function/import definitions.

        Class bodies are re-walked with the class name as parent so methods
        are attributed to their class.
        """
        if tree is None:
            return []

        definitions = []
        definition_types = self.DEFINITION_TYPES.get(language, {})

        def traverse(node, parent_name=None):
            node_type = node.type

            for def_category, types in definition_types.items():
                if node_type in types:
                    name = self._extract_name(node, language)

                    definitions.append({
                        "type": def_category,
                        "name": name,
                        "parent_name": parent_name,
                        "start_point": node.start_point,
                        "end_point": node.end_point,
                        "start_byte": node.start_byte,
                        "end_byte": node.end_byte,
                        "node_type": node_type,
                    })

                    # For classes, descend to attribute methods to the class.
                    if def_category == "class":
                        for child in node.children:
                            traverse(child, name)
                    return

            for child in node.children:
                traverse(child, parent_name)

        traverse(tree.root_node)
        return definitions
language: str) -> Optional[str]: + """从节点提取名称""" + # 查找 identifier 子节点 + for child in node.children: + if child.type in ["identifier", "name", "type_identifier", "property_identifier"]: + return child.text.decode() if isinstance(child.text, bytes) else child.text + + # 对于某些语言的特殊处理 + if language == "python": + for child in node.children: + if child.type == "name": + return child.text.decode() if isinstance(child.text, bytes) else child.text + + return None + + +class CodeSplitter: + """ + 高级代码分块器 + 使用 Tree-sitter 进行 AST 解析,支持多种编程语言 + """ + + # 危险函数/模式(用于安全指标) + SECURITY_PATTERNS = { + "python": [ + (r"\bexec\s*\(", "exec"), + (r"\beval\s*\(", "eval"), + (r"\bcompile\s*\(", "compile"), + (r"\bos\.system\s*\(", "os_system"), + (r"\bsubprocess\.", "subprocess"), + (r"\bcursor\.execute\s*\(", "sql_execute"), + (r"\.execute\s*\(.*%", "sql_format"), + (r"\bpickle\.loads?\s*\(", "pickle"), + (r"\byaml\.load\s*\(", "yaml_load"), + (r"\brequests?\.", "http_request"), + (r"password\s*=", "password_assign"), + (r"secret\s*=", "secret_assign"), + (r"api_key\s*=", "api_key_assign"), + ], + "javascript": [ + (r"\beval\s*\(", "eval"), + (r"\bFunction\s*\(", "function_constructor"), + (r"innerHTML\s*=", "innerHTML"), + (r"outerHTML\s*=", "outerHTML"), + (r"document\.write\s*\(", "document_write"), + (r"\.exec\s*\(", "exec"), + (r"\.query\s*\(.*\+", "sql_concat"), + (r"password\s*[=:]", "password_assign"), + (r"apiKey\s*[=:]", "api_key_assign"), + ], + "java": [ + (r"Runtime\.getRuntime\(\)\.exec", "runtime_exec"), + (r"ProcessBuilder", "process_builder"), + (r"\.executeQuery\s*\(.*\+", "sql_concat"), + (r"ObjectInputStream", "deserialization"), + (r"XMLDecoder", "xml_decoder"), + (r"password\s*=", "password_assign"), + ], + "go": [ + (r"exec\.Command\s*\(", "exec_command"), + (r"\.Query\s*\(.*\+", "sql_concat"), + (r"\.Exec\s*\(.*\+", "sql_concat"), + (r"template\.HTML\s*\(", "unsafe_html"), + (r"password\s*=", "password_assign"), + ], + "php": [ + (r"\beval\s*\(", "eval"), + 
(r"\bexec\s*\(", "exec"), + (r"\bsystem\s*\(", "system"), + (r"\bshell_exec\s*\(", "shell_exec"), + (r"\$_GET\[", "get_input"), + (r"\$_POST\[", "post_input"), + (r"\$_REQUEST\[", "request_input"), + ], + } + + def __init__( + self, + max_chunk_size: int = 1500, + min_chunk_size: int = 100, + overlap_size: int = 50, + preserve_structure: bool = True, + use_tree_sitter: bool = True, + ): + self.max_chunk_size = max_chunk_size + self.min_chunk_size = min_chunk_size + self.overlap_size = overlap_size + self.preserve_structure = preserve_structure + self.use_tree_sitter = use_tree_sitter + + self._ts_parser = TreeSitterParser() if use_tree_sitter else None + + def detect_language(self, file_path: str) -> str: + """检测编程语言""" + ext = Path(file_path).suffix.lower() + return TreeSitterParser.LANGUAGE_MAP.get(ext, "text") + + def split_file( + self, + content: str, + file_path: str, + language: Optional[str] = None + ) -> List[CodeChunk]: + """ + 分割单个文件 + + Args: + content: 文件内容 + file_path: 文件路径 + language: 编程语言(可选) + + Returns: + 代码块列表 + """ + if not content or not content.strip(): + return [] + + if language is None: + language = self.detect_language(file_path) + + chunks = [] + + try: + # 尝试使用 Tree-sitter 解析 + if self.use_tree_sitter and self._ts_parser: + tree = self._ts_parser.parse(content, language) + if tree: + chunks = self._split_by_ast(content, file_path, language, tree) + + # 如果 AST 解析失败或没有结果,使用增强的正则解析 + if not chunks: + chunks = self._split_by_enhanced_regex(content, file_path, language) + + # 如果还是没有,使用基于行的分块 + if not chunks: + chunks = self._split_by_lines(content, file_path, language) + + # 后处理:提取安全指标 + for chunk in chunks: + chunk.security_indicators = self._extract_security_indicators( + chunk.content, language + ) + + # 后处理:使用语义分析增强 + self._enrich_chunks_with_semantics(chunks, content, language) + + except Exception as e: + logger.warning(f"分块失败 {file_path}: {e}, 使用简单分块") + chunks = self._split_by_lines(content, file_path, language) + + return chunks + + 
def _split_by_ast( + self, + content: str, + file_path: str, + language: str, + tree: Any + ) -> List[CodeChunk]: + """基于 AST 分块""" + chunks = [] + lines = content.split('\n') + + # 提取定义 + definitions = self._ts_parser.extract_definitions(tree, content, language) + + if not definitions: + return [] + + # 为每个定义创建代码块 + for defn in definitions: + start_line = defn["start_point"][0] + end_line = defn["end_point"][0] + + # 提取代码内容 + chunk_lines = lines[start_line:end_line + 1] + chunk_content = '\n'.join(chunk_lines) + + if len(chunk_content.strip()) < self.min_chunk_size // 4: + continue + + chunk_type = ChunkType.CLASS if defn["type"] == "class" else \ + ChunkType.FUNCTION if defn["type"] in ["function", "method"] else \ + ChunkType.INTERFACE if defn["type"] == "interface" else \ + ChunkType.STRUCT if defn["type"] == "struct" else \ + ChunkType.IMPORT if defn["type"] == "import" else \ + ChunkType.MODULE + + chunk = CodeChunk( + id="", + content=chunk_content, + file_path=file_path, + language=language, + chunk_type=chunk_type, + line_start=start_line + 1, + line_end=end_line + 1, + byte_start=defn["start_byte"], + byte_end=defn["end_byte"], + name=defn.get("name"), + parent_name=defn.get("parent_name"), + ast_type=defn.get("node_type"), + ) + + # 如果块太大,进一步分割 + if chunk.estimated_tokens > self.max_chunk_size: + sub_chunks = self._split_large_chunk(chunk) + chunks.extend(sub_chunks) + else: + chunks.append(chunk) + + return chunks + + def _split_by_enhanced_regex( + self, + content: str, + file_path: str, + language: str + ) -> List[CodeChunk]: + """增强的正则表达式分块(支持更多语言)""" + chunks = [] + lines = content.split('\n') + + # 各语言的定义模式 + patterns = { + "python": [ + (r"^(\s*)class\s+(\w+)(?:\s*\([^)]*\))?\s*:", ChunkType.CLASS), + (r"^(\s*)(?:async\s+)?def\s+(\w+)\s*\([^)]*\)\s*(?:->[^:]+)?:", ChunkType.FUNCTION), + ], + "javascript": [ + (r"^(\s*)(?:export\s+)?class\s+(\w+)", ChunkType.CLASS), + (r"^(\s*)(?:export\s+)?(?:async\s+)?function\s*(\w*)\s*\(", ChunkType.FUNCTION), 
+ (r"^(\s*)(?:export\s+)?(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?\([^)]*\)\s*=>", ChunkType.FUNCTION), + ], + "typescript": [ + (r"^(\s*)(?:export\s+)?(?:abstract\s+)?class\s+(\w+)", ChunkType.CLASS), + (r"^(\s*)(?:export\s+)?interface\s+(\w+)", ChunkType.INTERFACE), + (r"^(\s*)(?:export\s+)?(?:async\s+)?function\s*(\w*)", ChunkType.FUNCTION), + ], + "java": [ + (r"^(\s*)(?:public|private|protected)?\s*(?:static\s+)?(?:final\s+)?class\s+(\w+)", ChunkType.CLASS), + (r"^(\s*)(?:public|private|protected)?\s*interface\s+(\w+)", ChunkType.INTERFACE), + (r"^(\s*)(?:public|private|protected)?\s*(?:static\s+)?[\w<>\[\],\s]+\s+(\w+)\s*\([^)]*\)\s*(?:throws\s+[\w,\s]+)?\s*\{", ChunkType.METHOD), + ], + "go": [ + (r"^type\s+(\w+)\s+struct\s*\{", ChunkType.STRUCT), + (r"^type\s+(\w+)\s+interface\s*\{", ChunkType.INTERFACE), + (r"^func\s+(?:\([^)]+\)\s+)?(\w+)\s*\([^)]*\)", ChunkType.FUNCTION), + ], + "php": [ + (r"^(\s*)(?:abstract\s+)?class\s+(\w+)", ChunkType.CLASS), + (r"^(\s*)interface\s+(\w+)", ChunkType.INTERFACE), + (r"^(\s*)(?:public|private|protected)?\s*(?:static\s+)?function\s+(\w+)", ChunkType.FUNCTION), + ], + } + + lang_patterns = patterns.get(language, []) + if not lang_patterns: + return [] + + # 找到所有定义的位置 + definitions = [] + for i, line in enumerate(lines): + for pattern, chunk_type in lang_patterns: + match = re.match(pattern, line) + if match: + indent = len(match.group(1)) if match.lastindex >= 1 else 0 + name = match.group(2) if match.lastindex >= 2 else None + definitions.append({ + "line": i, + "indent": indent, + "name": name, + "type": chunk_type, + }) + break + + if not definitions: + return [] + + # 计算每个定义的范围 + for i, defn in enumerate(definitions): + start_line = defn["line"] + base_indent = defn["indent"] + + # 查找结束位置 + end_line = len(lines) - 1 + for j in range(start_line + 1, len(lines)): + line = lines[j] + if line.strip(): + current_indent = len(line) - len(line.lstrip()) + # 如果缩进回到基础级别,检查是否是下一个定义 + if current_indent <= base_indent: + # 
检查是否是下一个定义 + is_next_def = any(d["line"] == j for d in definitions) + if is_next_def or (current_indent < base_indent): + end_line = j - 1 + break + + chunk_content = '\n'.join(lines[start_line:end_line + 1]) + + if len(chunk_content.strip()) < 10: + continue + + chunk = CodeChunk( + id="", + content=chunk_content, + file_path=file_path, + language=language, + chunk_type=defn["type"], + line_start=start_line + 1, + line_end=end_line + 1, + name=defn.get("name"), + ) + + if chunk.estimated_tokens > self.max_chunk_size: + sub_chunks = self._split_large_chunk(chunk) + chunks.extend(sub_chunks) + else: + chunks.append(chunk) + + return chunks + + def _split_by_lines( + self, + content: str, + file_path: str, + language: str + ) -> List[CodeChunk]: + """基于行数分块(回退方案)""" + chunks = [] + lines = content.split('\n') + + # 估算每行 Token 数 + total_tokens = len(content) // 4 + avg_tokens_per_line = max(1, total_tokens // max(1, len(lines))) + lines_per_chunk = max(10, self.max_chunk_size // avg_tokens_per_line) + overlap_lines = self.overlap_size // avg_tokens_per_line + + for i in range(0, len(lines), lines_per_chunk - overlap_lines): + end = min(i + lines_per_chunk, len(lines)) + chunk_content = '\n'.join(lines[i:end]) + + if len(chunk_content.strip()) < 10: + continue + + chunk = CodeChunk( + id="", + content=chunk_content, + file_path=file_path, + language=language, + chunk_type=ChunkType.MODULE, + line_start=i + 1, + line_end=end, + ) + chunks.append(chunk) + + if end >= len(lines): + break + + return chunks + + def _split_large_chunk(self, chunk: CodeChunk) -> List[CodeChunk]: + """分割过大的代码块""" + sub_chunks = [] + lines = chunk.content.split('\n') + + avg_tokens_per_line = max(1, chunk.estimated_tokens // max(1, len(lines))) + lines_per_chunk = max(10, self.max_chunk_size // avg_tokens_per_line) + + for i in range(0, len(lines), lines_per_chunk): + end = min(i + lines_per_chunk, len(lines)) + sub_content = '\n'.join(lines[i:end]) + + if len(sub_content.strip()) < 10: + 
continue + + sub_chunk = CodeChunk( + id="", + content=sub_content, + file_path=chunk.file_path, + language=chunk.language, + chunk_type=chunk.chunk_type, + line_start=chunk.line_start + i, + line_end=chunk.line_start + end - 1, + name=chunk.name, + parent_name=chunk.parent_name, + ) + sub_chunks.append(sub_chunk) + + return sub_chunks if sub_chunks else [chunk] + + def _extract_security_indicators(self, content: str, language: str) -> List[str]: + """提取安全相关指标""" + indicators = [] + patterns = self.SECURITY_PATTERNS.get(language, []) + + # 添加通用模式 + common_patterns = [ + (r"password", "password"), + (r"secret", "secret"), + (r"api[_-]?key", "api_key"), + (r"token", "token"), + (r"private[_-]?key", "private_key"), + (r"credential", "credential"), + ] + + all_patterns = patterns + common_patterns + + for pattern, name in all_patterns: + try: + if re.search(pattern, content, re.IGNORECASE): + if name not in indicators: + indicators.append(name) + except re.error: + continue + + return indicators[:15] + + def _enrich_chunks_with_semantics( + self, + chunks: List[CodeChunk], + full_content: str, + language: str + ): + """使用语义分析增强代码块""" + # 提取导入 + imports = self._extract_imports(full_content, language) + + for chunk in chunks: + # 添加相关导入 + chunk.imports = self._filter_relevant_imports(imports, chunk.content) + + # 提取函数调用 + chunk.calls = self._extract_function_calls(chunk.content, language) + + # 提取定义 + chunk.definitions = self._extract_definitions(chunk.content, language) + + def _extract_imports(self, content: str, language: str) -> List[str]: + """提取导入语句""" + imports = [] + + patterns = { + "python": [ + r"^import\s+([\w.]+)", + r"^from\s+([\w.]+)\s+import", + ], + "javascript": [ + r"^import\s+.*\s+from\s+['\"]([^'\"]+)['\"]", + r"require\s*\(['\"]([^'\"]+)['\"]\)", + ], + "typescript": [ + r"^import\s+.*\s+from\s+['\"]([^'\"]+)['\"]", + ], + "java": [ + r"^import\s+([\w.]+);", + ], + "go": [ + r"['\"]([^'\"]+)['\"]", + ], + } + + for pattern in patterns.get(language, 
[]): + matches = re.findall(pattern, content, re.MULTILINE) + imports.extend(matches) + + return list(set(imports)) + + def _filter_relevant_imports(self, all_imports: List[str], chunk_content: str) -> List[str]: + """过滤与代码块相关的导入""" + relevant = [] + for imp in all_imports: + module_name = imp.split('.')[-1] + if re.search(rf'\b{re.escape(module_name)}\b', chunk_content): + relevant.append(imp) + return relevant[:20] + + def _extract_function_calls(self, content: str, language: str) -> List[str]: + """提取函数调用""" + pattern = r'\b(\w+)\s*\(' + matches = re.findall(pattern, content) + + keywords = { + "python": {"if", "for", "while", "with", "def", "class", "return", "except", "print", "assert", "lambda"}, + "javascript": {"if", "for", "while", "switch", "function", "return", "catch", "console", "async", "await"}, + "java": {"if", "for", "while", "switch", "return", "catch", "throw", "new"}, + "go": {"if", "for", "switch", "return", "func", "go", "defer"}, + } + + lang_keywords = keywords.get(language, set()) + calls = [m for m in matches if m not in lang_keywords] + + return list(set(calls))[:30] + + def _extract_definitions(self, content: str, language: str) -> List[str]: + """提取定义的标识符""" + definitions = [] + + patterns = { + "python": [ + r"def\s+(\w+)\s*\(", + r"class\s+(\w+)", + r"(\w+)\s*=\s*", + ], + "javascript": [ + r"function\s+(\w+)", + r"(?:const|let|var)\s+(\w+)", + r"class\s+(\w+)", + ], + } + + for pattern in patterns.get(language, []): + matches = re.findall(pattern, content) + definitions.extend(matches) + + return list(set(definitions))[:20] + diff --git a/backend/requirements.txt b/backend/requirements.txt index 73fda9c..708fb5e 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -17,3 +17,51 @@ reportlab>=4.0.0 weasyprint>=66.0 jinja2>=3.1.6 json-repair>=0.30.0 + +# ============ Agent 模块依赖 ============ + +# LangChain 核心 +langchain>=0.1.0 +langchain-community>=0.0.20 +langchain-openai>=0.0.5 + +# LangGraph (状态图工作流) 
+langgraph>=0.0.40 + +# 向量数据库 +chromadb>=0.4.22 + +# Token 计算 +tiktoken>=0.5.2 + +# Docker 沙箱 +docker>=7.0.0 + +# 异步文件操作 +aiofiles>=23.2.1 + +# SSE 流 +sse-starlette>=1.8.2 + +# ============ 代码解析 (高级库) ============ + +# Tree-sitter AST 解析 +tree-sitter>=0.21.0 +tree-sitter-languages>=1.10.0 + +# 通用代码解析 +pygments>=2.17.0 + +# ============ 外部安全工具 (可选安装) ============ +# 这些工具可以通过 pip 安装,或使用系统包管理器 + +# Python 安全扫描 +bandit>=1.7.0 +safety>=2.3.0 + +# 静态分析 (需要单独安装 semgrep CLI) +# pip install semgrep + +# 依赖漏洞扫描 +pip-audit>=2.6.0 + diff --git a/docker/sandbox/Dockerfile b/docker/sandbox/Dockerfile new file mode 100644 index 0000000..530ae4e --- /dev/null +++ b/docker/sandbox/Dockerfile @@ -0,0 +1,62 @@ +# DeepAudit Agent Sandbox +# 安全沙箱环境用于漏洞验证和 PoC 执行 + +FROM python:3.11-slim-bookworm + +LABEL maintainer="XCodeReviewer Team" +LABEL description="Secure sandbox environment for vulnerability verification" + +# 安装基本工具 +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + wget \ + netcat-openbsd \ + dnsutils \ + iputils-ping \ + ca-certificates \ + git \ + && rm -rf /var/lib/apt/lists/* + +# 安装 Node.js (用于 JavaScript/TypeScript 代码执行) +RUN curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \ + && apt-get install -y nodejs \ + && rm -rf /var/lib/apt/lists/* + +# 安装常用的安全测试 Python 库 +RUN pip install --no-cache-dir \ + requests \ + httpx \ + aiohttp \ + beautifulsoup4 \ + lxml \ + pycryptodome \ + paramiko \ + pyjwt \ + python-jose \ + sqlparse + +# 创建非 root 用户 +RUN groupadd -g 1000 sandbox && \ + useradd -u 1000 -g sandbox -m -s /bin/bash sandbox + +# 创建工作目录 +RUN mkdir -p /workspace /tmp/sandbox && \ + chown -R sandbox:sandbox /workspace /tmp/sandbox + +# 设置环境变量 +ENV HOME=/home/sandbox +ENV PATH=/home/sandbox/.local/bin:$PATH +ENV PYTHONDONTWRITEBYTECODE=1 +ENV PYTHONUNBUFFERED=1 + +# 限制 Python 导入路径 +ENV PYTHONPATH=/workspace + +# 切换到非 root 用户 +USER sandbox + +WORKDIR /workspace + +# 默认命令 +CMD ["/bin/bash"] + diff --git a/docker/sandbox/build.sh 
b/docker/sandbox/build.sh new file mode 100644 index 0000000..44212bc --- /dev/null +++ b/docker/sandbox/build.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# 构建沙箱镜像 + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +IMAGE_NAME="deepaudit-sandbox" +IMAGE_TAG="latest" + +echo "Building sandbox image: ${IMAGE_NAME}:${IMAGE_TAG}" + +docker build \ + -t "${IMAGE_NAME}:${IMAGE_TAG}" \ + -f "${SCRIPT_DIR}/Dockerfile" \ + "${SCRIPT_DIR}" + +echo "Build complete: ${IMAGE_NAME}:${IMAGE_TAG}" + +# 验证镜像 +echo "Verifying image..." +docker run --rm "${IMAGE_NAME}:${IMAGE_TAG}" python3 --version +docker run --rm "${IMAGE_NAME}:${IMAGE_TAG}" node --version + +echo "Sandbox image ready!" + diff --git a/docker/sandbox/seccomp.json b/docker/sandbox/seccomp.json new file mode 100644 index 0000000..15f135f --- /dev/null +++ b/docker/sandbox/seccomp.json @@ -0,0 +1,266 @@ +{ + "defaultAction": "SCMP_ACT_ERRNO", + "defaultErrnoRet": 1, + "archMap": [ + { + "architecture": "SCMP_ARCH_X86_64", + "subArchitectures": [ + "SCMP_ARCH_X86", + "SCMP_ARCH_X32" + ] + }, + { + "architecture": "SCMP_ARCH_AARCH64", + "subArchitectures": [ + "SCMP_ARCH_ARM" + ] + } + ], + "syscalls": [ + { + "names": [ + "accept", + "accept4", + "access", + "arch_prctl", + "bind", + "brk", + "capget", + "capset", + "chdir", + "chmod", + "chown", + "clock_getres", + "clock_gettime", + "clock_nanosleep", + "clone", + "clone3", + "close", + "connect", + "copy_file_range", + "dup", + "dup2", + "dup3", + "epoll_create", + "epoll_create1", + "epoll_ctl", + "epoll_pwait", + "epoll_pwait2", + "epoll_wait", + "eventfd", + "eventfd2", + "execve", + "execveat", + "exit", + "exit_group", + "faccessat", + "faccessat2", + "fadvise64", + "fallocate", + "fchdir", + "fchmod", + "fchmodat", + "fchown", + "fchownat", + "fcntl", + "fdatasync", + "fgetxattr", + "flistxattr", + "flock", + "fork", + "fstat", + "fstatfs", + "fsync", + "ftruncate", + "futex", + "getcwd", + "getdents", + "getdents64", + "getegid", + "geteuid", + 
"getgid", + "getgroups", + "getpeername", + "getpgid", + "getpgrp", + "getpid", + "getppid", + "getpriority", + "getrandom", + "getresgid", + "getresuid", + "getrlimit", + "getrusage", + "getsid", + "getsockname", + "getsockopt", + "gettid", + "gettimeofday", + "getuid", + "getxattr", + "inotify_add_watch", + "inotify_init", + "inotify_init1", + "inotify_rm_watch", + "ioctl", + "kill", + "lchown", + "lgetxattr", + "link", + "linkat", + "listen", + "listxattr", + "llistxattr", + "lseek", + "lstat", + "madvise", + "membarrier", + "memfd_create", + "mincore", + "mkdir", + "mkdirat", + "mknod", + "mknodat", + "mlock", + "mlock2", + "mlockall", + "mmap", + "mprotect", + "mremap", + "msync", + "munlock", + "munlockall", + "munmap", + "name_to_handle_at", + "nanosleep", + "newfstatat", + "open", + "openat", + "openat2", + "pause", + "pipe", + "pipe2", + "poll", + "ppoll", + "prctl", + "pread64", + "preadv", + "preadv2", + "prlimit64", + "pselect6", + "pwrite64", + "pwritev", + "pwritev2", + "read", + "readahead", + "readlink", + "readlinkat", + "readv", + "recv", + "recvfrom", + "recvmmsg", + "recvmsg", + "rename", + "renameat", + "renameat2", + "restart_syscall", + "rmdir", + "rseq", + "rt_sigaction", + "rt_sigpending", + "rt_sigprocmask", + "rt_sigqueueinfo", + "rt_sigreturn", + "rt_sigsuspend", + "rt_sigtimedwait", + "rt_tgsigqueueinfo", + "sched_getaffinity", + "sched_getattr", + "sched_getparam", + "sched_get_priority_max", + "sched_get_priority_min", + "sched_getscheduler", + "sched_rr_get_interval", + "sched_setaffinity", + "sched_setattr", + "sched_setparam", + "sched_setscheduler", + "sched_yield", + "seccomp", + "select", + "semctl", + "semget", + "semop", + "semtimedop", + "send", + "sendfile", + "sendmmsg", + "sendmsg", + "sendto", + "set_robust_list", + "set_tid_address", + "setgid", + "setgroups", + "setitimer", + "setpgid", + "setpriority", + "setregid", + "setresgid", + "setresuid", + "setreuid", + "setsid", + "setsockopt", + "setuid", + "shmat", + 
"shmctl", + "shmdt", + "shmget", + "shutdown", + "sigaltstack", + "signalfd", + "signalfd4", + "socket", + "socketpair", + "splice", + "stat", + "statfs", + "statx", + "symlink", + "symlinkat", + "sync", + "sync_file_range", + "syncfs", + "sysinfo", + "tee", + "tgkill", + "time", + "timer_create", + "timer_delete", + "timer_getoverrun", + "timer_gettime", + "timer_settime", + "timerfd_create", + "timerfd_gettime", + "timerfd_settime", + "times", + "tkill", + "truncate", + "umask", + "uname", + "unlink", + "unlinkat", + "utime", + "utimensat", + "utimes", + "vfork", + "wait4", + "waitid", + "waitpid", + "write", + "writev" + ], + "action": "SCMP_ACT_ALLOW" + } + ] +} + diff --git a/docs/AGENT_AUDIT.md b/docs/AGENT_AUDIT.md new file mode 100644 index 0000000..0a85432 --- /dev/null +++ b/docs/AGENT_AUDIT.md @@ -0,0 +1,298 @@ +# DeepAudit Agent 审计模块 + +## 概述 + +Agent 审计模块是 DeepAudit 的高级安全审计功能,基于 **LangGraph** 状态图构建的混合 AI Agent 架构,实现自主代码安全分析和漏洞验证。 + +## LangGraph 工作流架构 + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ LangGraph 审计工作流 │ +├─────────────────────────────────────────────────────────────────────┤ +│ │ +│ START │ +│ │ │ +│ ▼ │ +│ ┌────────────────────────────────────────────────────────────────┐│ +│ │ Recon Node (信息收集) ││ +│ │ • 项目结构分析 • 技术栈识别 ││ +│ │ • 入口点发现 • 依赖扫描 ││ +│ │ ││ +│ │ 使用工具: list_files, npm_audit, safety_scan, gitleaks_scan ││ +│ └────────────────────────────┬───────────────────────────────────┘│ +│ │ │ +│ ▼ │ +│ ┌────────────────────────────────────────────────────────────────┐│ +│ │ Analysis Node (漏洞分析) ││ +│ │ • Semgrep 静态分析 • RAG 语义搜索 ││ +│ │ • 模式匹配 • LLM 深度分析 ││ +│ │ • 数据流追踪 ││ +│ │ ◄─────┐ ││ +│ │ 使用工具: semgrep_scan, bandit_scan, rag_query, │ ││ +│ │ code_analysis, pattern_match │ ││ +│ └────────────────────────────┬──────────────────────────┘───────┘│ +│ │ │ │ +│ ▼ │ │ +│ ┌────────────────────────────────────────────────────────────────┐│ +│ │ Verification Node (漏洞验证) ││ +│ │ • LLM 漏洞验证 • 沙箱测试 ││ +│ │ • PoC 
生成 • 误报过滤 ││ +│ │ ────────┘ ││ +│ │ 使用工具: vulnerability_validation, sandbox_exec, ││ +│ │ verify_vulnerability ││ +│ └────────────────────────────┬───────────────────────────────────┘│ +│ │ │ +│ ▼ │ +│ ┌────────────────────────────────────────────────────────────────┐│ +│ │ Report Node (报告生成) ││ +│ │ • 漏洞汇总 • 安全评分 ││ +│ │ • 修复建议 • 统计分析 ││ +│ └────────────────────────────┬───────────────────────────────────┘│ +│ │ │ +│ ▼ │ +│ END │ +│ │ +└────────────────────────────────────────────────────────────────────┘ + +状态流转: + • Recon → Analysis: 收集到入口点后进入分析 + • Analysis → Analysis: 发现较多问题时继续迭代 + • Analysis → Verification: 有发现时进入验证 + • Verification → Analysis: 误报率高时回溯分析 + • Verification → Report: 验证完成后生成报告 +``` + +## 核心特性 + +### 1. LangGraph 状态图 + +- **声明式工作流**: 使用图结构定义 Agent 协作流程 +- **状态自动合并**: `Annotated[List, operator.add]` 实现发现累加 +- **条件路由**: 基于状态动态决定下一步 +- **检查点恢复**: 支持任务中断后继续 + +### 2. Agent 工具集 + +#### 内置工具 + +| 工具 | 功能 | 节点 | +|------|------|------| +| `list_files` | 目录浏览 | Recon | +| `read_file` | 文件读取 | All | +| `search_code` | 代码搜索 | Analysis | +| `rag_query` | 语义检索 | Analysis | +| `security_search` | 安全代码搜索 | Analysis | +| `function_context` | 函数上下文 | Analysis | +| `pattern_match` | 模式匹配 | Analysis | +| `code_analysis` | LLM 分析 | Analysis | +| `dataflow_analysis` | 数据流追踪 | Analysis | +| `vulnerability_validation` | 漏洞验证 | Verification | +| `sandbox_exec` | 沙箱执行 | Verification | +| `verify_vulnerability` | 自动验证 | Verification | + +#### 外部安全工具 + +| 工具 | 功能 | 适用场景 | +|------|------|----------| +| `semgrep_scan` | Semgrep 静态分析 | 多语言快速扫描 | +| `bandit_scan` | Bandit Python 扫描 | Python 安全分析 | +| `gitleaks_scan` | Gitleaks 密钥检测 | 密钥泄露检测 | +| `trufflehog_scan` | TruffleHog 扫描 | 深度密钥扫描 | +| `npm_audit` | npm 依赖审计 | Node.js 依赖漏洞 | +| `safety_scan` | Safety Python 审计 | Python 依赖漏洞 | +| `osv_scan` | OSV 漏洞扫描 | 多语言依赖漏洞 | + +### 3. 
RAG 系统 + +- **代码分块**: 基于 Tree-sitter AST 的智能分块 +- **向量存储**: ChromaDB 持久化 +- **多语言支持**: Python, JavaScript, TypeScript, Java, Go, PHP, Rust 等 +- **嵌入模型**: 独立配置,支持 OpenAI、Ollama、Cohere、HuggingFace + +### 4. 安全沙箱 + +- **Docker 隔离**: 安全容器执行 +- **资源限制**: 内存、CPU 限制 +- **网络隔离**: 可配置网络访问 +- **seccomp 策略**: 系统调用白名单 + +## 配置 + +### 环境变量 + +```bash +# LLM 配置 +DEFAULT_LLM_MODEL=gpt-4-turbo-preview +LLM_API_KEY=your-api-key +LLM_BASE_URL=https://api.openai.com/v1 + +# 嵌入模型配置(独立于 LLM) +EMBEDDING_PROVIDER=openai +EMBEDDING_MODEL=text-embedding-3-small + +# 向量数据库 +VECTOR_DB_PATH=./data/vectordb + +# 沙箱配置 +SANDBOX_IMAGE=deepaudit-sandbox:latest +SANDBOX_MEMORY_LIMIT=512m +SANDBOX_CPU_LIMIT=1.0 +SANDBOX_NETWORK_DISABLED=true +``` + +### Agent 任务配置 + +```json +{ + "target_vulnerabilities": [ + "sql_injection", + "xss", + "command_injection", + "path_traversal", + "ssrf" + ], + "verification_level": "sandbox", + "exclude_patterns": ["node_modules", "__pycache__", ".git"], + "max_iterations": 3, + "timeout_seconds": 1800 +} +``` + +## 部署 + +### 1. 安装依赖 + +```bash +cd backend +pip install -r requirements.txt + +# 可选:安装外部工具 +pip install semgrep bandit safety +brew install gitleaks trufflehog osv-scanner # macOS +``` + +### 2. 构建沙箱镜像 + +```bash +cd docker/sandbox +./build.sh +``` + +### 3. 数据库迁移 + +```bash +alembic upgrade head +``` + +### 4. 
启动服务 + +```bash +uvicorn app.main:app --host 0.0.0.0 --port 8000 +``` + +## API 接口 + +### 创建任务 + +```http +POST /api/v1/agent-tasks/ +Content-Type: application/json + +{ + "project_id": "xxx", + "name": "安全审计", + "target_vulnerabilities": ["sql_injection", "xss"], + "verification_level": "sandbox", + "max_iterations": 3 +} +``` + +### 事件流 + +```http +GET /api/v1/agent-tasks/{task_id}/events +Accept: text/event-stream +``` + +### 获取发现 + +```http +GET /api/v1/agent-tasks/{task_id}/findings?verified_only=true +``` + +### 任务摘要 + +```http +GET /api/v1/agent-tasks/{task_id}/summary +``` + +## 支持的漏洞类型 + +| 类型 | 说明 | +|------|------| +| `sql_injection` | SQL 注入 | +| `xss` | 跨站脚本 | +| `command_injection` | 命令注入 | +| `path_traversal` | 路径遍历 | +| `ssrf` | 服务端请求伪造 | +| `xxe` | XML 外部实体 | +| `insecure_deserialization` | 不安全反序列化 | +| `hardcoded_secret` | 硬编码密钥 | +| `weak_crypto` | 弱加密 | +| `authentication_bypass` | 认证绕过 | +| `authorization_bypass` | 授权绕过 | +| `idor` | 不安全直接对象引用 | + +## 目录结构 + +``` +backend/app/services/agent/ +├── __init__.py # 模块导出 +├── event_manager.py # 事件管理 +├── agents/ # Agent 实现 +│ ├── __init__.py +│ ├── base.py # Agent 基类 +│ ├── recon.py # 信息收集 Agent +│ ├── analysis.py # 漏洞分析 Agent +│ ├── verification.py # 漏洞验证 Agent +│ └── orchestrator.py # 编排 Agent +├── graph/ # LangGraph 工作流 +│ ├── __init__.py +│ ├── audit_graph.py # 状态定义和图构建 +│ ├── nodes.py # 节点实现 +│ └── runner.py # 执行器 +├── tools/ # Agent 工具 +│ ├── __init__.py +│ ├── base.py # 工具基类 +│ ├── rag_tool.py # RAG 工具 +│ ├── pattern_tool.py # 模式匹配工具 +│ ├── code_analysis_tool.py +│ ├── file_tool.py # 文件操作 +│ ├── sandbox_tool.py # 沙箱工具 +│ └── external_tools.py # 外部安全工具 +└── prompts/ # 系统提示词 + ├── __init__.py + └── system_prompts.py +``` + +## 故障排除 + +### 沙箱镜像检查 + +```bash +docker images | grep deepaudit-sandbox +``` + +### 日志查看 + +```bash +tail -f logs/agent.log +``` + +### 常见问题 + +1. **RAG 初始化失败**: 检查嵌入模型配置和 API Key +2. **沙箱启动失败**: 确保 Docker 正常运行 +3. 
**外部工具不可用**: 检查 semgrep/bandit 等是否已安装 + diff --git a/frontend/Dockerfile b/frontend/Dockerfile index f1f7c03..3a28948 100644 --- a/frontend/Dockerfile +++ b/frontend/Dockerfile @@ -54,3 +54,4 @@ EXPOSE 3000 ENTRYPOINT ["/docker-entrypoint.sh"] CMD ["serve", "-s", "dist", "-l", "3000"] + diff --git a/frontend/src/app/ProtectedRoute.tsx b/frontend/src/app/ProtectedRoute.tsx index 764f102..6373c97 100644 --- a/frontend/src/app/ProtectedRoute.tsx +++ b/frontend/src/app/ProtectedRoute.tsx @@ -17,3 +17,4 @@ export const ProtectedRoute = () => { }; + diff --git a/frontend/src/app/routes.tsx b/frontend/src/app/routes.tsx index 1b63092..cfd4072 100644 --- a/frontend/src/app/routes.tsx +++ b/frontend/src/app/routes.tsx @@ -5,6 +5,7 @@ import RecycleBin from "@/pages/RecycleBin"; import InstantAnalysis from "@/pages/InstantAnalysis"; import AuditTasks from "@/pages/AuditTasks"; import TaskDetail from "@/pages/TaskDetail"; +import AgentAudit from "@/pages/AgentAudit"; import AdminDashboard from "@/pages/AdminDashboard"; import LogsPage from "@/pages/LogsPage"; import Account from "@/pages/Account"; @@ -56,6 +57,12 @@ const routes: RouteConfig[] = [ element: , visible: false, }, + { + name: "Agent审计", + path: "/agent-audit/:taskId", + element: , + visible: false, + }, { name: "审计规则", path: "/audit-rules", diff --git a/frontend/src/components/agent/AgentModeSelector.tsx b/frontend/src/components/agent/AgentModeSelector.tsx new file mode 100644 index 0000000..9e68d2d --- /dev/null +++ b/frontend/src/components/agent/AgentModeSelector.tsx @@ -0,0 +1,158 @@ +/** + * Agent 审计模式选择器 + * 允许用户在快速审计和 Agent 审计模式之间选择 + */ + +import { Bot, Zap, CheckCircle2, Clock, Shield, Code } from "lucide-react"; +import { cn } from "@/shared/utils/utils"; + +export type AuditMode = "fast" | "agent"; + +interface AgentModeSelectorProps { + value: AuditMode; + onChange: (mode: AuditMode) => void; + disabled?: boolean; +} + +export default function AgentModeSelector({ + value, + onChange, + disabled = 
false, +}: AgentModeSelectorProps) { + return ( +
+
+ + + 审计模式 + +
+ +
+ {/* 快速审计模式 */} + + + {/* Agent 审计模式 */} + +
+ + {/* 模式说明 */} + {value === "agent" && ( +
+

🤖 Agent 审计模式说明:

+
    +
  • AI Agent 会自主规划审计策略
  • +
  • 使用 RAG 技术进行代码语义检索
  • +
  • 在 Docker 沙箱中验证发现的漏洞
  • +
  • 可生成可复现的 PoC(概念验证)代码
  • +
  • 审计时间较长,但结果更全面准确
  • +
+
+ )} +
+ ); +} + diff --git a/frontend/src/components/agent/EmbeddingConfig.tsx b/frontend/src/components/agent/EmbeddingConfig.tsx new file mode 100644 index 0000000..4a1faf8 --- /dev/null +++ b/frontend/src/components/agent/EmbeddingConfig.tsx @@ -0,0 +1,449 @@ +/** + * 嵌入模型配置组件 + * 独立于 LLM 配置,专门用于 Agent 审计的 RAG 系统 + */ + +import { useState, useEffect } from "react"; +import { + Card, + CardContent, + CardDescription, + CardHeader, + CardTitle, +} from "@/components/ui/card"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; +import { Badge } from "@/components/ui/badge"; +import { Separator } from "@/components/ui/separator"; +import { + Brain, + Cpu, + Check, + X, + Loader2, + RefreshCw, + Server, + Key, + Zap, + Info, +} from "lucide-react"; +import { toast } from "sonner"; +import { apiClient } from "@/shared/api/serverClient"; + +interface EmbeddingProvider { + id: string; + name: string; + description: string; + models: string[]; + requires_api_key: boolean; + default_model: string; +} + +interface EmbeddingConfig { + provider: string; + model: string; + base_url: string | null; + dimensions: number; + batch_size: number; +} + +interface TestResult { + success: boolean; + message: string; + dimensions?: number; + sample_embedding?: number[]; + latency_ms?: number; +} + +export default function EmbeddingConfigPanel() { + const [providers, setProviders] = useState([]); + const [currentConfig, setCurrentConfig] = useState(null); + const [loading, setLoading] = useState(true); + const [saving, setSaving] = useState(false); + const [testing, setTesting] = useState(false); + const [testResult, setTestResult] = useState(null); + + // 表单状态 + const [selectedProvider, setSelectedProvider] = useState(""); + const [selectedModel, setSelectedModel] = useState(""); + 
const [apiKey, setApiKey] = useState(""); + const [baseUrl, setBaseUrl] = useState(""); + const [batchSize, setBatchSize] = useState(100); + + // 加载数据 + useEffect(() => { + loadData(); + }, []); + + // 当 provider 改变时更新模型 + useEffect(() => { + if (selectedProvider) { + const provider = providers.find((p) => p.id === selectedProvider); + if (provider) { + setSelectedModel(provider.default_model); + } + } + }, [selectedProvider, providers]); + + const loadData = async () => { + try { + setLoading(true); + const [providersRes, configRes] = await Promise.all([ + apiClient.get("/embedding/providers"), + apiClient.get("/embedding/config"), + ]); + + setProviders(providersRes.data); + setCurrentConfig(configRes.data); + + // 设置表单默认值 + if (configRes.data) { + setSelectedProvider(configRes.data.provider); + setSelectedModel(configRes.data.model); + setBaseUrl(configRes.data.base_url || ""); + setBatchSize(configRes.data.batch_size); + } + } catch (error) { + toast.error("加载配置失败"); + } finally { + setLoading(false); + } + }; + + const handleSave = async () => { + if (!selectedProvider || !selectedModel) { + toast.error("请选择提供商和模型"); + return; + } + + const provider = providers.find((p) => p.id === selectedProvider); + if (provider?.requires_api_key && !apiKey) { + toast.error(`${provider.name} 需要 API Key`); + return; + } + + try { + setSaving(true); + await apiClient.put("/embedding/config", { + provider: selectedProvider, + model: selectedModel, + api_key: apiKey || undefined, + base_url: baseUrl || undefined, + batch_size: batchSize, + }); + + toast.success("配置已保存"); + await loadData(); + } catch (error: any) { + toast.error(error.response?.data?.detail || "保存失败"); + } finally { + setSaving(false); + } + }; + + const handleTest = async () => { + if (!selectedProvider || !selectedModel) { + toast.error("请选择提供商和模型"); + return; + } + + try { + setTesting(true); + setTestResult(null); + + const response = await apiClient.post("/embedding/test", { + provider: selectedProvider, + 
model: selectedModel, + api_key: apiKey || undefined, + base_url: baseUrl || undefined, + }); + + setTestResult(response.data); + + if (response.data.success) { + toast.success("测试成功"); + } else { + toast.error("测试失败"); + } + } catch (error: any) { + setTestResult({ + success: false, + message: error.response?.data?.detail || "测试失败", + }); + toast.error("测试失败"); + } finally { + setTesting(false); + } + }; + + const selectedProviderInfo = providers.find((p) => p.id === selectedProvider); + + if (loading) { + return ( +
+ +
+ ); + } + + return ( + + +
+
+ +
+
+ 嵌入模型配置 + + 用于 Agent 审计的 RAG 代码检索,独立于分析 LLM + +
+
+
+ + + {/* 当前配置状态 */} + {currentConfig && ( +
+
+ + 当前配置 +
+
+
+ 提供商:{" "} + + {currentConfig.provider} + +
+
+ 模型:{" "} + {currentConfig.model} +
+
+ 向量维度:{" "} + {currentConfig.dimensions} +
+
+ 批处理大小:{" "} + {currentConfig.batch_size} +
+
+
+ )} + + + + {/* 提供商选择 */} +
+ + + + {selectedProviderInfo && ( +

+ + {selectedProviderInfo.description} +

+ )} +
+ + {/* 模型选择 */} + {selectedProviderInfo && ( +
+ + +
+ )} + + {/* API Key */} + {selectedProviderInfo?.requires_api_key && ( +
+ + setApiKey(e.target.value)} + placeholder="输入 API Key" + className="border-2 border-black rounded-none font-mono" + /> +

+ API Key 将安全存储,不会显示在页面上 +

+
+ )} + + {/* 自定义端点 */} +
+ + setBaseUrl(e.target.value)} + placeholder={ + selectedProvider === "ollama" + ? "http://localhost:11434" + : selectedProvider === "huggingface" + ? "https://router.huggingface.co" + : selectedProvider === "cohere" + ? "https://api.cohere.com/v2" + : selectedProvider === "jina" + ? "https://api.jina.ai/v1" + : "https://api.openai.com/v1" + } + className="border-2 border-black rounded-none font-mono" + /> +

+ 用于 API 代理或自托管服务 +

+
+ + {/* 批处理大小 */} +
+ + setBatchSize(parseInt(e.target.value) || 100)} + min={1} + max={500} + className="border-2 border-black rounded-none font-mono w-32" + /> +

+ 每批嵌入的文本数量,建议 50-100 +

+
+ + {/* 测试结果 */} + {testResult && ( +
+
+ {testResult.success ? ( + + ) : ( + + )} + + {testResult.success ? "测试成功" : "测试失败"} + +
+

{testResult.message}

+ {testResult.success && ( +
+
向量维度: {testResult.dimensions}
+
延迟: {testResult.latency_ms}ms
+ {testResult.sample_embedding && ( +
+ 示例向量: [{testResult.sample_embedding.map((v) => v.toFixed(4)).join(", ")}...] +
+ )} +
+ )} +
+ )} + + {/* 操作按钮 */} +
+ + + + + +
+ + {/* 说明 */} +
+

💡 关于嵌入模型

+
    +
  • 嵌入模型用于 Agent 审计的代码语义搜索 (RAG)
  • +
  • 与分析使用的 LLM 独立配置,互不影响
  • +
  • 推荐使用 OpenAI text-embedding-3-small 或本地 Ollama
  • +
  • 向量维度影响存储空间和检索精度
  • +
+
+
+
+ ); +} + diff --git a/frontend/src/components/audit/CreateTaskDialog.tsx b/frontend/src/components/audit/CreateTaskDialog.tsx index fcc9305..3e89caa 100644 --- a/frontend/src/components/audit/CreateTaskDialog.tsx +++ b/frontend/src/components/audit/CreateTaskDialog.tsx @@ -1,4 +1,5 @@ import { useState, useEffect, useMemo, useRef } from "react"; +import { useNavigate } from "react-router-dom"; import { Dialog, DialogContent, @@ -35,16 +36,19 @@ import { Shield, Loader2, Zap, + Bot, } from "lucide-react"; import { toast } from "sonner"; import { api } from "@/shared/config/database"; import { getRuleSets, type AuditRuleSet } from "@/shared/api/rules"; import { getPromptTemplates, type PromptTemplate } from "@/shared/api/prompts"; +import { createAgentTask } from "@/shared/api/agentTasks"; import { useProjects } from "./hooks/useTaskForm"; import { useZipFile, formatFileSize } from "./hooks/useZipFile"; import TerminalProgressDialog from "./TerminalProgressDialog"; import FileSelectionDialog from "./FileSelectionDialog"; +import AgentModeSelector, { type AuditMode } from "@/components/agent/AgentModeSelector"; import { runRepositoryAudit } from "@/features/projects/services/repoScan"; import { @@ -76,6 +80,7 @@ export default function CreateTaskDialog({ onTaskCreated, preselectedProjectId, }: CreateTaskDialogProps) { + const navigate = useNavigate(); const [selectedProjectId, setSelectedProjectId] = useState(""); const [searchTerm, setSearchTerm] = useState(""); const [branch, setBranch] = useState("main"); @@ -90,6 +95,9 @@ export default function CreateTaskDialog({ const [showTerminal, setShowTerminal] = useState(false); const [currentTaskId, setCurrentTaskId] = useState(null); + // 审计模式 + const [auditMode, setAuditMode] = useState("agent"); + // 规则集和提示词模板 const [ruleSets, setRuleSets] = useState([]); const [promptTemplates, setPromptTemplates] = useState([]); @@ -205,6 +213,31 @@ export default function CreateTaskDialog({ setCreating(true); let taskId: string; + 
// Agent 审计模式 + if (auditMode === "agent") { + const agentTask = await createAgentTask({ + project_id: selectedProject.id, + name: `Agent审计-${selectedProject.name}`, + branch_name: isRepositoryProject(selectedProject) ? branch : undefined, + exclude_patterns: excludePatterns, + target_files: selectedFiles, + verification_level: "sandbox", + }); + + onOpenChange(false); + onTaskCreated(); + toast.success("Agent 审计任务已创建"); + + // 导航到 Agent 审计页面 + navigate(`/agent-audit/${agentTask.id}`); + + setSelectedProjectId(""); + setSelectedFiles(undefined); + setExcludePatterns(DEFAULT_EXCLUDES); + return; + } + + // 快速审计模式(原有逻辑) if (isZipProject(selectedProject)) { if (zipState.useStoredZip && zipState.storedZipInfo?.has_file) { taskId = await scanStoredZipFile({ @@ -339,6 +372,15 @@ export default function CreateTaskDialog({ + {/* 审计模式选择 */} + {selectedProject && ( + + )} + {/* 配置区域 */} {selectedProject && (
@@ -589,10 +631,15 @@ export default function CreateTaskDialog({
启动中... + ) : auditMode === "agent" ? ( + <> + + 启动 Agent 审计 + ) : ( <> - - 开始扫描 + + 开始快速扫描 )} diff --git a/frontend/src/components/system/SystemConfig.tsx b/frontend/src/components/system/SystemConfig.tsx index 9a90c1f..98a44d3 100644 --- a/frontend/src/components/system/SystemConfig.tsx +++ b/frontend/src/components/system/SystemConfig.tsx @@ -6,10 +6,11 @@ import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from "@ import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; import { Settings, Save, RotateCcw, Eye, EyeOff, CheckCircle2, AlertCircle, - Info, Zap, Globe, PlayCircle, Loader2 + Info, Zap, Globe, PlayCircle, Loader2, Brain } from "lucide-react"; import { toast } from "sonner"; import { api } from "@/shared/api/database"; +import EmbeddingConfig from "@/components/agent/EmbeddingConfig"; // LLM 提供商配置 - 2025年最新 const LLM_PROVIDERS = [ @@ -246,10 +247,13 @@ export function SystemConfig() {
- + LLM 配置 + + 嵌入模型 + 分析参数 @@ -388,6 +392,11 @@ export function SystemConfig() {
+ {/* 嵌入模型配置 */} + + + + {/* 分析参数 */}
diff --git a/frontend/src/pages/AgentAudit.tsx b/frontend/src/pages/AgentAudit.tsx new file mode 100644 index 0000000..89d5e7b --- /dev/null +++ b/frontend/src/pages/AgentAudit.tsx @@ -0,0 +1,579 @@ +/** + * Agent 审计页面 + * 机械终端风格的 AI Agent 审计界面 + */ + +import { useState, useEffect, useRef, useCallback } from "react"; +import { useParams, useNavigate } from "react-router-dom"; +import { + Terminal, Bot, Cpu, Shield, AlertTriangle, CheckCircle2, + Loader2, Code, Zap, Activity, ChevronRight, XCircle, + FileCode, Search, Bug, Lock, Play, Square, RefreshCw, + ArrowLeft, Download, ExternalLink +} from "lucide-react"; +import { Button } from "@/components/ui/button"; +import { Badge } from "@/components/ui/badge"; +import { ScrollArea } from "@/components/ui/scroll-area"; +import { toast } from "sonner"; +import { + type AgentTask, + type AgentEvent, + type AgentFinding, + getAgentTask, + getAgentEvents, + getAgentFindings, + cancelAgentTask, + streamAgentEvents, +} from "@/shared/api/agentTasks"; + +// 事件类型图标映射 +const eventTypeIcons: Record = { + phase_start: , + phase_complete: , + thinking: , + tool_call: , + tool_result: , + tool_error: , + finding_new: , + finding_verified: , + info: , + warning: , + error: , + progress: , + task_complete: , + task_error: , + task_cancel: , +}; + +// 事件类型颜色映射 +const eventTypeColors: Record = { + phase_start: "text-cyan-400 font-bold", + phase_complete: "text-green-400", + thinking: "text-purple-300", + tool_call: "text-yellow-300", + tool_result: "text-green-300", + tool_error: "text-red-400", + finding_new: "text-orange-300", + finding_verified: "text-red-300", + info: "text-gray-300", + warning: "text-yellow-300", + error: "text-red-400", + progress: "text-cyan-300", + task_complete: "text-green-400 font-bold", + task_error: "text-red-400 font-bold", + task_cancel: "text-yellow-400", +}; + +// 严重程度颜色 +const severityColors: Record = { + critical: "bg-red-900/50 border-red-500 text-red-300", + high: "bg-orange-900/50 
border-orange-500 text-orange-300", + medium: "bg-yellow-900/50 border-yellow-500 text-yellow-300", + low: "bg-blue-900/50 border-blue-500 text-blue-300", + info: "bg-gray-900/50 border-gray-500 text-gray-300", +}; + +const severityIcons: Record = { + critical: "🔴", + high: "🟠", + medium: "🟡", + low: "🟢", + info: "⚪", +}; + +export default function AgentAuditPage() { + const { taskId } = useParams<{ taskId: string }>(); + const navigate = useNavigate(); + + const [task, setTask] = useState(null); + const [events, setEvents] = useState([]); + const [findings, setFindings] = useState([]); + const [isLoading, setIsLoading] = useState(true); + const [isStreaming, setIsStreaming] = useState(false); + const [currentTime, setCurrentTime] = useState(new Date().toLocaleTimeString("zh-CN", { hour12: false })); + + const eventsEndRef = useRef(null); + const abortControllerRef = useRef(null); + + // 是否完成 + const isComplete = task?.status === "completed" || task?.status === "failed" || task?.status === "cancelled"; + + // 加载任务信息 + const loadTask = useCallback(async () => { + if (!taskId) return; + + try { + const taskData = await getAgentTask(taskId); + setTask(taskData); + } catch (error) { + console.error("Failed to load task:", error); + toast.error("加载任务失败"); + } + }, [taskId]); + + // 加载事件 + const loadEvents = useCallback(async () => { + if (!taskId) return; + + try { + const eventsData = await getAgentEvents(taskId, { limit: 500 }); + setEvents(eventsData); + } catch (error) { + console.error("Failed to load events:", error); + } + }, [taskId]); + + // 加载发现 + const loadFindings = useCallback(async () => { + if (!taskId) return; + + try { + const findingsData = await getAgentFindings(taskId); + setFindings(findingsData); + } catch (error) { + console.error("Failed to load findings:", error); + } + }, [taskId]); + + // 初始化加载 + useEffect(() => { + const init = async () => { + setIsLoading(true); + await Promise.all([loadTask(), loadEvents(), loadFindings()]); + 
setIsLoading(false); + }; + + init(); + }, [loadTask, loadEvents, loadFindings]); + + // 事件流 + useEffect(() => { + if (!taskId || isComplete || isLoading) return; + + const startStreaming = async () => { + setIsStreaming(true); + abortControllerRef.current = new AbortController(); + + try { + const lastSequence = events.length > 0 ? Math.max(...events.map(e => e.sequence)) : 0; + + for await (const event of streamAgentEvents(taskId, lastSequence, abortControllerRef.current.signal)) { + setEvents(prev => { + // 避免重复 + if (prev.some(e => e.id === event.id)) return prev; + return [...prev, event]; + }); + + // 如果是发现事件,刷新发现列表 + if (event.event_type.startsWith("finding_")) { + loadFindings(); + } + + // 如果是结束事件,刷新任务状态 + if (["task_complete", "task_error", "task_cancel"].includes(event.event_type)) { + loadTask(); + loadFindings(); + } + } + } catch (error) { + if ((error as Error).name !== "AbortError") { + console.error("Event stream error:", error); + } + } finally { + setIsStreaming(false); + } + }; + + startStreaming(); + + return () => { + abortControllerRef.current?.abort(); + }; + }, [taskId, isComplete, isLoading, loadTask, loadFindings]); + + // 自动滚动 + useEffect(() => { + eventsEndRef.current?.scrollIntoView({ behavior: "smooth" }); + }, [events]); + + // 更新时间 + useEffect(() => { + const interval = setInterval(() => { + setCurrentTime(new Date().toLocaleTimeString("zh-CN", { hour12: false })); + }, 1000); + + return () => clearInterval(interval); + }, []); + + // 取消任务 + const handleCancel = async () => { + if (!taskId) return; + + if (!confirm("确定要取消此任务吗?已分析的结果将被保留。")) { + return; + } + + try { + await cancelAgentTask(taskId); + toast.success("任务已取消"); + loadTask(); + } catch (error) { + toast.error("取消失败"); + } + }; + + if (isLoading) { + return ( +
+
+ +

正在加载...

+
+
+ ); + } + + if (!task) { + return ( +
+
+ +

任务不存在

+ +
+
+ ); + } + + return ( +
+ {/* 顶部状态栏 */} +
+
+ + +
+
+ +
+
+ AGENT_AUDIT + + {task.name || `任务 ${task.id.slice(0, 8)}`} + +
+
+
+ +
+ {/* 阶段指示器 */} + + + {/* 状态徽章 */} + + + {/* 时间 */} + {currentTime} +
+
+ +
+ {/* 左侧:执行日志 */} +
+
+
+ + Execution Log + {isStreaming && ( + + + LIVE + + )} +
+ {events.length} events +
+ + {/* 终端窗口 */} +
+ {/* CRT 效果 */} +
+ + +
+ {events.map((event) => ( + + ))} + + {/* 光标 */} + {!isComplete && ( +
+ {currentTime} + +
+ )} + +
+
+ +
+ + {/* 底部控制栏 */} +
+
+ {/* 进度 */} +
+ Progress +
+
+
+ {task.progress_percentage.toFixed(0)}% +
+ + {/* Token 消耗 */} + {task.total_chunks > 0 && ( +
+ Chunks: {task.total_chunks} +
+ )} +
+ +
+ {!isComplete && ( + + )} + + {isComplete && ( + + )} +
+
+
+ + {/* 右侧:发现面板 */} +
+
+
+
+ + Findings +
+ + {findings.length} + +
+ + {/* 严重程度统计 */} +
+ {task.critical_count > 0 && ( + 🔴 {task.critical_count} + )} + {task.high_count > 0 && ( + 🟠 {task.high_count} + )} + {task.medium_count > 0 && ( + 🟡 {task.medium_count} + )} + {task.low_count > 0 && ( + 🟢 {task.low_count} + )} +
+
+ + +
+ {findings.length === 0 ? ( +
+ +

暂无发现

+
+ ) : ( + findings.map((finding) => ( + + )) + )} +
+
+ + {/* 评分 */} + {isComplete && ( +
+
+ 安全评分 + = 80 ? "text-green-400" : + task.security_score >= 60 ? "text-yellow-400" : + "text-red-400" + }`}> + {task.security_score.toFixed(0)}/100 + +
+
+ 已验证 + {task.verified_count}/{task.findings_count} +
+
+ )} +
+
+
+ ); +} + +// 阶段指示器组件 +function PhaseIndicator({ phase, status }: { phase: string | null; status: string }) { + const phases = ["planning", "indexing", "analysis", "verification", "reporting"]; + const currentIndex = phase ? phases.indexOf(phase) : -1; + const isComplete = status === "completed"; + const isFailed = status === "failed"; + + return ( +
+ {phases.map((p, idx) => { + const isActive = p === phase; + const isPast = isComplete || (currentIndex >= 0 && idx < currentIndex); + + return ( +
+ ); + })} + {phase && ( + {phase} + )} +
+ ); +} + +// 状态徽章组件 +function StatusBadge({ status }: { status: string }) { + const statusConfig: Record = { + pending: { text: "PENDING", className: "bg-gray-800 text-gray-400 border-gray-600" }, + initializing: { text: "INIT", className: "bg-blue-900/50 text-blue-400 border-blue-600 animate-pulse" }, + planning: { text: "PLANNING", className: "bg-purple-900/50 text-purple-400 border-purple-600 animate-pulse" }, + indexing: { text: "INDEXING", className: "bg-cyan-900/50 text-cyan-400 border-cyan-600 animate-pulse" }, + analyzing: { text: "ANALYZING", className: "bg-yellow-900/50 text-yellow-400 border-yellow-600 animate-pulse" }, + verifying: { text: "VERIFYING", className: "bg-orange-900/50 text-orange-400 border-orange-600 animate-pulse" }, + completed: { text: "COMPLETED", className: "bg-green-900/50 text-green-400 border-green-600" }, + failed: { text: "FAILED", className: "bg-red-900/50 text-red-400 border-red-600" }, + cancelled: { text: "CANCELLED", className: "bg-yellow-900/50 text-yellow-400 border-yellow-600" }, + }; + + const config = statusConfig[status] || statusConfig.pending; + + return ( + + {config.text} + + ); +} + +// 事件行组件 +function EventLine({ event }: { event: AgentEvent }) { + const icon = eventTypeIcons[event.event_type] || ; + const colorClass = eventTypeColors[event.event_type] || "text-gray-400"; + + const timestamp = event.timestamp + ? new Date(event.timestamp).toLocaleTimeString("zh-CN", { hour12: false }) + : ""; + + return ( +
+ + {timestamp} + + {icon} + + {event.message} + {event.tool_duration_ms && ( + ({event.tool_duration_ms}ms) + )} + +
+ ); +} + +// 发现卡片组件 +function FindingCard({ finding }: { finding: AgentFinding }) { + const colorClass = severityColors[finding.severity] || severityColors.info; + const icon = severityIcons[finding.severity] || "⚪"; + + return ( +
+
+ {icon} +
+

{finding.title}

+

{finding.vulnerability_type}

+ {finding.file_path && ( +

+ + {finding.file_path}:{finding.line_start} +

+ )} +
+
+ +
+ {finding.is_verified && ( + + + 已验证 + + )} + {finding.has_poc && ( + + + PoC + + )} +
+
+ ); +} + diff --git a/frontend/src/shared/api/agentTasks.ts b/frontend/src/shared/api/agentTasks.ts new file mode 100644 index 0000000..c7342fc --- /dev/null +++ b/frontend/src/shared/api/agentTasks.ts @@ -0,0 +1,306 @@ +/** + * Agent Tasks API + * Agent 审计任务相关的 API 调用 + */ + +import { apiClient } from "./serverClient"; + +// ============ Types ============ + +export interface AgentTask { + id: string; + project_id: string; + name: string | null; + description: string | null; + task_type: string; + status: string; + current_phase: string | null; + current_step: string | null; + + // 统计 + total_files: number; + indexed_files: number; + analyzed_files: number; + total_chunks: number; + findings_count: number; + verified_count: number; + false_positive_count: number; + + // 严重程度统计 + critical_count: number; + high_count: number; + medium_count: number; + low_count: number; + + // 评分 + quality_score: number; + security_score: number; + + // 时间 + created_at: string; + started_at: string | null; + completed_at: string | null; + + // 进度 + progress_percentage: number; +} + +export interface AgentFinding { + id: string; + task_id: string; + vulnerability_type: string; + severity: string; + title: string; + description: string | null; + + file_path: string | null; + line_start: number | null; + line_end: number | null; + code_snippet: string | null; + + status: string; + is_verified: boolean; + has_poc: boolean; + poc_code: string | null; + + suggestion: string | null; + fix_code: string | null; + ai_explanation: string | null; + ai_confidence: number | null; + + created_at: string; +} + +export interface AgentEvent { + id: string; + task_id: string; + event_type: string; + phase: string | null; + message: string | null; + tool_name: string | null; + tool_input?: Record; + tool_output?: Record; + tool_duration_ms: number | null; + finding_id: string | null; + tokens_used?: number; + metadata?: Record; + sequence: number; + timestamp: string; +} + +export interface 
CreateAgentTaskRequest {
  project_id: string;
  name?: string;
  description?: string;
  audit_scope?: Record<string, unknown>;
  target_vulnerabilities?: string[];
  verification_level?: "analysis_only" | "sandbox" | "generate_poc";
  branch_name?: string;
  exclude_patterns?: string[];
  target_files?: string[];
  max_iterations?: number;
  token_budget?: number;
  timeout_seconds?: number;
}

export interface AgentTaskSummary {
  task_id: string;
  status: string;
  progress_percentage: number;
  security_score: number;
  quality_score: number;
  statistics: {
    total_files: number;
    indexed_files: number;
    analyzed_files: number;
    total_chunks: number;
    findings_count: number;
    verified_count: number;
    false_positive_count: number;
  };
  severity_distribution: {
    critical: number;
    high: number;
    medium: number;
    low: number;
  };
  vulnerability_types: Record<string, number>;
  duration_seconds: number | null;
}

// ============ API Functions ============

/**
 * Create an Agent audit task.
 */
export async function createAgentTask(data: CreateAgentTaskRequest): Promise<AgentTask> {
  const response = await apiClient.post("/agent-tasks/", data);
  return response.data;
}

/**
 * List Agent tasks, optionally filtered by project / status with paging.
 */
export async function getAgentTasks(params?: {
  project_id?: string;
  status?: string;
  skip?: number;
  limit?: number;
}): Promise<AgentTask[]> {
  const response = await apiClient.get("/agent-tasks/", { params });
  return response.data;
}

/**
 * Fetch a single Agent task by id.
 */
export async function getAgentTask(taskId: string): Promise<AgentTask> {
  const response = await apiClient.get(`/agent-tasks/${taskId}`);
  return response.data;
}

/**
 * Start a previously created Agent task.
 */
export async function startAgentTask(taskId: string): Promise<{ message: string; task_id: string }> {
  const response = await apiClient.post(`/agent-tasks/${taskId}/start`);
  return response.data;
}

/**
 * Cancel a running Agent task; results produced so far are kept server-side.
 */
export async function cancelAgentTask(taskId: string): Promise<{ message: string; task_id: string }> {
  const response = await apiClient.post(`/agent-tasks/${taskId}/cancel`);
  return response.data;
}

/**
 * List task events (non-streaming endpoint), optionally only those after a
 * given sequence number.
 */
export async function getAgentEvents(
  taskId: string,
  params?: { after_sequence?: number; limit?: number }
): Promise<AgentEvent[]> {
  const response = await apiClient.get(`/agent-tasks/${taskId}/events/list`, { params });
  return response.data;
}

/**
 * List findings for a task, optionally filtered by severity / type / verification.
 */
export async function getAgentFindings(
  taskId: string,
  params?: {
    severity?: string;
    vulnerability_type?: string;
    is_verified?: boolean;
  }
): Promise<AgentFinding[]> {
  const response = await apiClient.get(`/agent-tasks/${taskId}/findings`, { params });
  return response.data;
}

/**
 * Fetch a single finding by id.
 */
export async function getAgentFinding(taskId: string, findingId: string): Promise<AgentFinding> {
  const response = await apiClient.get(`/agent-tasks/${taskId}/findings/${findingId}`);
  return response.data;
}

/**
 * Update a finding (currently only its review status).
 */
export async function updateAgentFinding(
  taskId: string,
  findingId: string,
  data: { status?: string }
): Promise<AgentFinding> {
  const response = await apiClient.patch(`/agent-tasks/${taskId}/findings/${findingId}`, data);
  return response.data;
}

/**
 * Fetch the aggregated summary for a task.
 */
export async function getAgentTaskSummary(taskId: string): Promise<AgentTaskSummary> {
  const response = await apiClient.get(`/agent-tasks/${taskId}/summary`);
  return response.data;
}

/**
 * Create a raw SSE EventSource for task events.
 *
 * NOTE: EventSource cannot send custom headers, so authentication must come
 * from cookies (withCredentials) or URL parameters. Prefer streamAgentEvents
 * below when a bearer token is required.
 */
export function createAgentEventSource(taskId: string, afterSequence = 0): EventSource {
  const baseUrl = import.meta.env.VITE_API_URL || "";
  const url = `${baseUrl}/api/v1/agent-tasks/${taskId}/events?after_sequence=${afterSequence}`;

  return new EventSource(url, { withCredentials: true });
}

/**
 * Stream task events via fetch + ReadableStream (supports custom headers,
 * unlike EventSource). Yields each parsed AgentEvent as it arrives; malformed
 * SSE data lines are silently skipped.
 *
 * Fix over the previous version: after the stream ends, the decoder is
 * flushed and any event still sitting in the line buffer is emitted, so a
 * final event without a trailing newline (e.g. task_complete) is not dropped.
 */
export async function* streamAgentEvents(
  taskId: string,
  afterSequence = 0,
  signal?: AbortSignal
): AsyncGenerator<AgentEvent, void, unknown> {
  const token = localStorage.getItem("auth_token");
  const baseUrl = import.meta.env.VITE_API_URL || "";
  const url = `${baseUrl}/api/v1/agent-tasks/${taskId}/events?after_sequence=${afterSequence}`;

  const response = await fetch(url, {
    headers: {
      Authorization: `Bearer ${token}`,
      Accept: "text/event-stream",
    },
    signal,
  });

  if (!response.ok) {
    throw new Error(`Failed to connect to event stream: ${response.statusText}`);
  }

  const reader = response.body?.getReader();
  if (!reader) {
    throw new Error("No response body");
  }

  const decoder = new TextDecoder();
  let buffer = "";

  // Parse one SSE line; returns null for non-data lines or invalid JSON.
  const parseEventLine = (line: string): AgentEvent | null => {
    if (!line.startsWith("data: ")) return null;
    try {
      return JSON.parse(line.slice(6)) as AgentEvent;
    } catch {
      return null; // ignore parse errors
    }
  };

  try {
    while (true) {
      const { done, value } = await reader.read();

      if (done) {
        break;
      }

      buffer += decoder.decode(value, { stream: true });

      // Split complete lines out of the buffer; keep the partial tail.
      const lines = buffer.split("\n");
      buffer = lines.pop() || "";

      for (const line of lines) {
        const event = parseEventLine(line);
        if (event) {
          yield event;
        }
      }
    }

    // Flush any bytes still held by the streaming decoder, then drain the
    // remaining buffer so a final event without a trailing newline is emitted.
    buffer += decoder.decode();
    for (const line of buffer.split("\n")) {
      const event = parseEventLine(line);
      if (event) {
        yield event;
      }
    }
  } finally {
    reader.releaseLock();
  }
}